Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| d439dbcc29 | |||
| 4707f16c76 | |||
| f5a23c1fa0 | |||
| 4fcc19e91f |
6
.gitignore
vendored
6
.gitignore
vendored
@ -36,10 +36,14 @@ go.work.sum
|
||||
e2b/
|
||||
.impeccable.md
|
||||
.gstack
|
||||
.mcp.json
|
||||
|
||||
## Builds
|
||||
builds/
|
||||
|
||||
## Rust
|
||||
envd-rs/target/
|
||||
|
||||
## Frontend
|
||||
frontend/node_modules/
|
||||
frontend/.svelte-kit/
|
||||
@ -49,3 +53,5 @@ frontend/build/
|
||||
internal/dashboard/static/*
|
||||
!internal/dashboard/static/.gitkeep
|
||||
.dual-graph/
|
||||
# Added by code-review-graph
|
||||
.code-review-graph/
|
||||
|
||||
62
.woodpecker/pipeline.yml
Normal file
62
.woodpecker/pipeline.yml
Normal file
@ -0,0 +1,62 @@
|
||||
when:
|
||||
- event: push
|
||||
branch: main
|
||||
|
||||
steps:
|
||||
build-go:
|
||||
image: python:3.13
|
||||
environment:
|
||||
WRENN_API_KEY:
|
||||
from_secret: wrenn_api_key
|
||||
commands:
|
||||
- pip install wrenn
|
||||
- export GO_VERSION=$$(grep '^go ' go.mod | cut -d' ' -f2)
|
||||
- python .woodpecker/scripts/build_go.py
|
||||
depends_on: []
|
||||
|
||||
build-rust:
|
||||
image: python:3.13
|
||||
environment:
|
||||
WRENN_API_KEY:
|
||||
from_secret: wrenn_api_key
|
||||
commands:
|
||||
- pip install wrenn
|
||||
- python .woodpecker/scripts/build_rust.py
|
||||
depends_on: []
|
||||
|
||||
tag-release:
|
||||
image: python:3.13
|
||||
environment:
|
||||
GITEA_TOKEN:
|
||||
from_secret: gitea_token
|
||||
commands:
|
||||
- VERSION=$$(cat VERSION_CP)
|
||||
- git config user.name "R3dRum92"
|
||||
- git config user.email "tksadik@omukk.dev"
|
||||
- git tag "v$${VERSION}"
|
||||
- git push "https://tksadik92:$${GITEA_TOKEN}@git.omukk.dev/tksadik92/wrenn-releases.git" "v$${VERSION}"
|
||||
depends_on: [build-go, build-rust]
|
||||
|
||||
release-notes:
|
||||
image: python:3.13
|
||||
environment:
|
||||
WRENN_API_KEY:
|
||||
from_secret: wrenn_api_key
|
||||
GITEA_TOKEN:
|
||||
from_secret: gitea_token
|
||||
ZHIPU_API_KEY:
|
||||
from_secret: zhipu_api_key
|
||||
commands:
|
||||
- pip install wrenn
|
||||
- python .woodpecker/scripts/release_notes.py
|
||||
depends_on: [tag-release]
|
||||
|
||||
publish-github:
|
||||
image: python:3.13
|
||||
environment:
|
||||
GITHUB_TOKEN:
|
||||
from_secret: github_token
|
||||
commands:
|
||||
- pip install httpx
|
||||
- python .woodpecker/scripts/publish_github.py
|
||||
depends_on: [release-notes]
|
||||
136
.woodpecker/scripts/build_go.py
Normal file
136
.woodpecker/scripts/build_go.py
Normal file
@ -0,0 +1,136 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
from wrenn import Capsule, StreamExitEvent, StreamStderrEvent, StreamStdoutEvent
|
||||
from wrenn._git import GitCommandError
|
||||
|
||||
GO_VERSION = os.getenv("GO_VERSION", "1.25.8")
|
||||
REPO_URL = "https://git.omukk.dev/wrenn/wrenn.git"
|
||||
REPO_DIR = "/opt/wrenn"
|
||||
BUILDS_DIR = os.path.join(os.path.dirname(__file__), "..", "..", "builds")
|
||||
|
||||
|
||||
def read_remote_version(capsule: Capsule, filename: str) -> str:
    """Read a version file from the cloned repo inside the capsule.

    Args:
        capsule: Running capsule with the repo already cloned at REPO_DIR.
        filename: Version file name relative to the repo root
            (e.g. "VERSION_CP").

    Returns:
        The file contents decoded as UTF-8 with surrounding whitespace stripped.
    """
    # FIX: the path interpolated the literal "(unknown)" instead of the
    # `filename` parameter (which was never used), so every call read a
    # nonexistent file.
    content = capsule.files.read_bytes(f"{REPO_DIR}/{filename}")
    return content.decode("utf-8").strip()
|
||||
|
||||
|
||||
def run(capsule: Capsule, cmd: str, timeout: int = 30) -> int:
    """Execute *cmd* inside the capsule and report the outcome.

    Prints an OK/FAIL line tagged with the command's first word; on failure
    the command's stderr (if any) is echoed to our stderr as well.

    Returns the command's exit code (0 on success).
    """
    tag = cmd.split()[0]
    result = capsule.commands.run(cmd, timeout=timeout)
    if result.exit_code == 0:
        print(f"OK [{tag}]")
        return 0
    print(f"FAIL [{tag}]: exit={result.exit_code}", file=sys.stderr)
    if result.stderr:
        print(result.stderr.strip(), file=sys.stderr)
    return result.exit_code
|
||||
|
||||
|
||||
def install_go(capsule: Capsule) -> bool:
    """Install build tooling and the Go toolchain inside the capsule.

    Downloads the official go{GO_VERSION} linux-amd64 tarball, unpacks it
    under /usr/local, and appends /usr/local/go/bin to the login PATH.

    Returns True when every step succeeded and `go version` runs cleanly.
    """
    tarball = f"go{GO_VERSION}.linux-amd64.tar.gz"
    url = f"https://go.dev/dl/{tarball}"

    # Each step is (command, timeout); bail out on the first failure.
    steps = [
        ("apt update", 120),
        ("apt install -y make build-essential file", 300),
        (f"curl -LO {url}", 120),
        (f"tar -C /usr/local -xzf {tarball}", 300),
        ('echo "export PATH=$PATH:/usr/local/go/bin" >> ~/.profile', 30),
        ("rm -f " + tarball, 30),
    ]
    for command, limit in steps:
        if run(capsule, command, timeout=limit) != 0:
            return False

    # Sanity check: the freshly installed toolchain must report its version.
    result = capsule.commands.run("/usr/local/go/bin/go version")
    print(result.stdout.strip())
    return result.exit_code == 0
|
||||
|
||||
|
||||
def clone_repo(capsule: Capsule) -> bool:
    """Clone the wrenn repository into the capsule at REPO_DIR.

    Returns True on success; logs the git error and returns False otherwise.
    """
    try:
        capsule.git.clone(REPO_URL, REPO_DIR)
    except GitCommandError as exc:
        print(f"FAIL [git clone]: {exc}", file=sys.stderr)
        return False
    print("OK [git clone]")
    return True
|
||||
|
||||
|
||||
def build_go(capsule: Capsule) -> bool:
    """Run the Go build (`make build-cp build-agent`) and stream its output.

    The build is launched in the background with the Go toolchain on PATH,
    and its stdout/stderr are relayed live to our own streams until exit.

    Returns True when the build exits with status 0.
    """
    command = "CGO_ENABLED=1 make build-cp build-agent"
    go_path = (
        "/usr/local/go/bin:/usr/local/sbin:/usr/local/bin:"
        "/usr/sbin:/usr/bin:/sbin:/bin"
    )
    handle = capsule.commands.run(
        command,
        background=True,
        cwd=REPO_DIR,
        envs={"PATH": go_path},
    )
    print(f"{command} started (pid={handle.pid}), streaming output...")

    exit_code = 0
    for event in capsule.commands.connect(handle.pid):
        if isinstance(event, StreamStdoutEvent):
            print(event.data, end="")
        elif isinstance(event, StreamStderrEvent):
            print(event.data, end="", file=sys.stderr)
        elif isinstance(event, StreamExitEvent):
            exit_code = event.exit_code

    if exit_code != 0:
        print(f"FAIL [go build]: exit={exit_code}", file=sys.stderr)
        return False
    print("OK [go build]")
    return True
|
||||
|
||||
|
||||
def download_artifacts(capsule: Capsule) -> bool:
    """Copy build artifacts from the capsule's builds/ into local BUILDS_DIR.

    Known binaries (wrenn-cp, wrenn-agent) are renamed with a version suffix
    read from the repo's VERSION_* files; any other file keeps its name.

    Returns True when at least one file existed and all downloads completed.
    """
    remote_dir = f"{REPO_DIR}/builds"
    entries = capsule.files.list(remote_dir, depth=1)
    files = [entry for entry in entries if entry.type != "directory"]

    if not files:
        print("FAIL [download]: no files found in builds/", file=sys.stderr)
        return False

    local_dir = os.path.normpath(BUILDS_DIR)
    os.makedirs(local_dir, exist_ok=True)
    versions = {
        "wrenn-cp": read_remote_version(capsule, "VERSION_CP"),
        "wrenn-agent": read_remote_version(capsule, "VERSION_AGENT"),
    }

    for entry in files:
        name = entry.name or "unknown"
        remote_path = f"{remote_dir}/{name}"
        if name in versions:
            local_name = f"{name}-{versions[name]}"
        else:
            local_name = name
        local_path = os.path.join(local_dir, local_name)
        print(f"Downloading {name} as {local_name} ({entry.size or '?'} bytes)...")

        # Stream the file chunk by chunk to avoid holding it all in memory.
        with open(local_path, "wb") as out:
            for chunk in capsule.files.download_stream(remote_path):
                out.write(chunk)

        print(f"OK [download {local_name}]")

    return True
|
||||
|
||||
|
||||
def main() -> None:
    """Provision a capsule, build the Go binaries, and fetch the artifacts.

    Exits with status 1 as soon as any stage fails.
    """
    with Capsule(wait=True, vcpus=4, memory_mb=4096) as capsule:
        print(f"Capsule: {capsule.capsule_id}")
        for stage in (install_go, clone_repo, build_go, download_artifacts):
            if not stage(capsule):
                sys.exit(1)
        print("Done.")


if __name__ == "__main__":
    main()
|
||||
173
.woodpecker/scripts/build_rust.py
Normal file
173
.woodpecker/scripts/build_rust.py
Normal file
@ -0,0 +1,173 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
from wrenn import Capsule, StreamExitEvent, StreamStderrEvent, StreamStdoutEvent
|
||||
from wrenn._git import GitCommandError
|
||||
|
||||
RUST_VERSION = os.getenv("RUST_VERSION", "1.95.0")
|
||||
REPO_URL = "https://git.omukk.dev/wrenn/wrenn.git"
|
||||
REPO_DIR = "/opt/wrenn"
|
||||
BUILDS_DIR = os.path.join(os.path.dirname(__file__), "..", "..", "builds")
|
||||
RUST_PATH = (
|
||||
"/root/.cargo/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
|
||||
)
|
||||
|
||||
|
||||
def read_envd_version(capsule: Capsule) -> str:
    """Extract the package version from envd-rs/Cargo.toml in the capsule.

    Scans for the first `version = "..."` line and returns the unquoted
    value. Exits the process with status 1 if no version line is present.
    """
    raw = capsule.files.read_bytes(f"{REPO_DIR}/envd-rs/Cargo.toml")
    for raw_line in raw.decode("utf-8").splitlines():
        candidate = raw_line.strip()
        if candidate.startswith("version ="):
            _, _, value = candidate.partition("=")
            return value.strip().strip('"')
    print("FAIL [version]: envd-rs/Cargo.toml has no package version", file=sys.stderr)
    sys.exit(1)
|
||||
|
||||
|
||||
def run(capsule: Capsule, cmd: str, timeout: int = 30, envs=None) -> int:
    """Execute *cmd* inside the capsule and report the outcome.

    Args:
        capsule: Capsule to run the command in.
        cmd: Shell command to execute.
        timeout: Seconds to wait before giving up.
        envs: Optional mapping of extra environment variables.

    Returns:
        The command's exit code (0 on success).
    """
    # FIX: the default was the mutable literal `envs={}`, which is shared
    # across all calls; use None as the default and pass a fresh dict.
    result = capsule.commands.run(
        cmd, timeout=timeout, envs=envs if envs is not None else {}
    )
    if result.exit_code != 0:
        print(f"FAIL [{cmd.split()[0]}]: exit={result.exit_code}", file=sys.stderr)
        if result.stderr:
            print(result.stderr.strip(), file=sys.stderr)
        return result.exit_code
    print(f"OK [{cmd.split()[0]}]")
    return 0
|
||||
|
||||
|
||||
def install_rust(capsule: Capsule) -> bool:
    """Install build tooling and the pinned Rust toolchain inside the capsule.

    Uses rustup to install RUST_VERSION (minimal profile) and adds the
    x86_64 musl target needed for the static envd binary.

    Returns True when every step succeeded and `rustc --version` runs cleanly.
    """
    rustup_cmd = (
        "curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs "
        f"| sh -s -- -y --profile minimal --default-toolchain {RUST_VERSION}"
    )
    # Each step is (command, timeout); bail out on the first failure.
    steps = [
        ("apt update", 120),
        (
            "apt install -y make build-essential file curl musl-tools protobuf-compiler",
            300,
        ),
        (rustup_cmd, 300),
        ("/root/.cargo/bin/rustup target add x86_64-unknown-linux-musl", 120),
    ]
    for command, limit in steps:
        if run(capsule, command, timeout=limit) != 0:
            return False

    # Sanity check: the freshly installed toolchain must report its version.
    result = capsule.commands.run("/root/.cargo/bin/rustc --version")
    print(result.stdout.strip())
    return result.exit_code == 0
|
||||
|
||||
|
||||
def clone_repo(capsule: Capsule) -> bool:
    """Clone the wrenn repo and switch to the fix/large-operations branch.

    Returns True on success; logs and returns False when either the clone
    or the branch checkout fails.
    """
    try:
        capsule.git.clone(REPO_URL, REPO_DIR)
    except GitCommandError as exc:
        print(f"FAIL [git clone]: {exc}", file=sys.stderr)
        return False
    # FIX: the checkout's exit code was previously ignored, so a missing or
    # renamed branch would silently build the default branch instead.
    result = capsule.commands.run(f"cd {REPO_DIR} && git checkout fix/large-operations")
    if result.exit_code != 0:
        print(f"FAIL [git checkout]: exit={result.exit_code}", file=sys.stderr)
        if result.stderr:
            print(result.stderr.strip(), file=sys.stderr)
        return False
    print("OK [git clone]")
    return True
|
||||
|
||||
|
||||
def build_rust(capsule: Capsule) -> bool:
    """Run `make build-envd` in the capsule and stream its output.

    Ensures the builds/ output directory exists, launches the build in the
    background with cargo on PATH, and relays stdout/stderr live until the
    process exits.

    Returns True when the build exits with status 0.
    """
    if run(capsule, f"mkdir -p {REPO_DIR}/builds") != 0:
        return False

    # NOTE: a batch of commented-out debugging commands (file/readelf/which
    # checks and a manual binary copy) was removed here; re-add ad hoc if
    # envd's static linking ever needs to be re-verified.
    handle = capsule.commands.run(
        "make build-envd",
        background=True,
        cwd=REPO_DIR,
        envs={"PATH": RUST_PATH},
    )
    print(f"rust build started (pid={handle.pid}), streaming output...")

    exit_code = 0
    for event in capsule.commands.connect(handle.pid):
        if isinstance(event, StreamStdoutEvent):
            print(event.data, end="")
        elif isinstance(event, StreamStderrEvent):
            print(event.data, end="", file=sys.stderr)
        elif isinstance(event, StreamExitEvent):
            exit_code = event.exit_code

    if exit_code != 0:
        print(f"FAIL [rust build]: exit={exit_code}", file=sys.stderr)
        return False

    print("OK [rust build]")
    return True
|
||||
|
||||
|
||||
def download_artifacts(capsule: Capsule) -> bool:
    """Download the built envd binary into the local BUILDS_DIR.

    The local copy is named `envd-<version>` using the version read from
    envd-rs/Cargo.toml inside the capsule.
    """
    version = read_envd_version(capsule)
    remote_path = f"{REPO_DIR}/builds/envd"
    local_dir = os.path.normpath(BUILDS_DIR)
    local_name = f"envd-{version}"
    local_path = os.path.join(local_dir, local_name)
    os.makedirs(local_dir, exist_ok=True)

    print(f"Downloading envd as {local_name}...")
    # Stream chunk by chunk rather than buffering the whole binary.
    with open(local_path, "wb") as out:
        for chunk in capsule.files.download_stream(remote_path):
            out.write(chunk)

    print(f"OK [download {local_name}]")
    return True
|
||||
|
||||
|
||||
def main() -> None:
    """Provision a capsule, build the Rust envd binary, and fetch it.

    Exits with status 1 as soon as any stage fails.
    """
    with Capsule(wait=True, vcpus=4, memory_mb=4096) as capsule:
        print(f"Capsule: {capsule.capsule_id}")
        for stage in (install_rust, clone_repo, build_rust, download_artifacts):
            if not stage(capsule):
                sys.exit(1)
        print("Done.")


if __name__ == "__main__":
    main()
|
||||
104
.woodpecker/scripts/publish_github.py
Normal file
104
.woodpecker/scripts/publish_github.py
Normal file
@ -0,0 +1,104 @@
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
|
||||
GITHUB_REPO = "R3dRum92/wrenn-releases"
|
||||
GITHUB_API = "https://api.github.com"
|
||||
GITHUB_UPLOADS = "https://uploads.github.com"
|
||||
BUILDS_DIR = "builds"
|
||||
VERSION_FILE = "VERSION_CP"
|
||||
NOTES_FILE = os.path.join(".woodpecker", "release_notes.md")
|
||||
|
||||
|
||||
def main() -> None:
    """Create a GitHub release for the current version and upload artifacts.

    Reads the version from VERSION_FILE, optional notes from NOTES_FILE,
    creates the release via the GitHub API, then uploads every file in
    builds/ as a release asset. Exits 1 if the release cannot be created;
    individual asset-upload failures are warnings only.
    """
    token = os.environ["GITHUB_TOKEN"]

    with open(VERSION_FILE) as f:
        version = f.read().strip()
    tag = f"v{version}"

    release_notes = ""
    if os.path.exists(NOTES_FILE):
        with open(NOTES_FILE) as f:
            release_notes = f.read()

    headers = {
        "Authorization": f"token {token}",
        "Accept": "application/vnd.github+json",
        "X-GitHub-Api-Version": "2022-11-28",
    }

    # FIX: the client was previously closed manually on every exit path and
    # leaked on unexpected exceptions; a context manager guarantees cleanup.
    with httpx.Client(headers=headers, timeout=60) as client:
        print(f"Creating GitHub release for {tag}...")
        resp = client.post(
            f"{GITHUB_API}/repos/{GITHUB_REPO}/releases",
            json={
                "tag_name": tag,
                "name": tag,
                "body": release_notes,
                "draft": False,
                "prerelease": False,
            },
        )
        if resp.status_code == 422:
            # 422 is a validation error; in this workflow it almost always
            # means a release for this tag already exists.
            print(f"WARN [create release]: release for {tag} already exists, skipping")
            errors = resp.json().get("errors", [])
            if errors:
                existing_url = errors[0].get("documentation_url", "")
                print(f" See: {existing_url}")
            return
        if resp.status_code != 201:
            print(f"FAIL [create release]: {resp.status_code} {resp.text}", file=sys.stderr)
            sys.exit(1)

        release_data = resp.json()
        release_id = release_data["id"]
        release_url = release_data.get("html_url", "")
        print(f"OK [create release] id={release_id}")

        builds_path = Path(BUILDS_DIR)
        if not builds_path.exists():
            print(f"No {BUILDS_DIR}/ directory found, skipping asset upload")
            print(f"Release published: {release_url}")
            return

        upload_headers = {
            **headers,
            "Content-Type": "application/octet-stream",
        }

        for artifact in sorted(builds_path.iterdir()):
            if artifact.is_dir():
                continue
            print(f"Uploading {artifact.name}...")

            resp = client.post(
                f"{GITHUB_UPLOADS}/repos/{GITHUB_REPO}/releases/{release_id}/assets",
                params={"name": artifact.name},
                headers=upload_headers,
                content=artifact.read_bytes(),
            )
            if resp.status_code != 201:
                print(
                    f"WARN [upload {artifact.name}]: {resp.status_code} {resp.text}",
                    file=sys.stderr,
                )
            else:
                print(f"OK [upload {artifact.name}]")

        print(f"Release published: {release_url}")


if __name__ == "__main__":
    main()
|
||||
246
.woodpecker/scripts/release_notes.py
Normal file
246
.woodpecker/scripts/release_notes.py
Normal file
@ -0,0 +1,246 @@
|
||||
import base64
|
||||
import os
|
||||
import sys
|
||||
|
||||
from wrenn import Capsule
|
||||
|
||||
REPO_URL = "https://git.omukk.dev/tksadik92/wrenn-releases.git"
|
||||
REPO_DIR = "/opt/wrenn-releases"
|
||||
CAPSULE_OUTPUT = "/tmp/release_notes.md"
|
||||
LOCAL_OUTPUT = os.path.join(os.path.dirname(__file__), "..", "release_notes.md")
|
||||
|
||||
# Default starting configuration
|
||||
ZHIPU_API_KEY = os.environ.get("ZHIPU_API_KEY", "")
|
||||
if ZHIPU_API_KEY:
|
||||
DEFAULT_MODEL = "zhipuai-coding-plan/glm-5.1"
|
||||
else:
|
||||
DEFAULT_MODEL = "opencode/minimax-m2.5-free"
|
||||
|
||||
RELEASE_NOTES_EXAMPLE = """
|
||||
## What's new
|
||||
Sandbox HTTP proxying, terminal reliability, and auth robustness improvements.
|
||||
|
||||
### Proxy
|
||||
- Fixed redirect loops for apps served inside sandboxes (Python HTTP server, Jupyter, etc.)
|
||||
- Proxy traffic no longer interferes with terminal and exec connections
|
||||
- Services that take a moment to start up inside a sandbox are now retried instead of immediately failing
|
||||
|
||||
### Terminal (PTY)
|
||||
- Terminal input is no longer blocked by slow network conditions — fast typing no longer causes timeouts or disconnects
|
||||
- Input bursts are coalesced into fewer round trips — lower latency under fast typing
|
||||
|
||||
### Authentication
|
||||
- WebSocket connections now authenticate correctly for both SDK clients (header-based) and browser clients (message-based)
|
||||
|
||||
### Bug Fixes
|
||||
- Fixed crash in envd when a process exits without a PTY
|
||||
- Fixed goroutine leak on sandbox pause
|
||||
|
||||
### Others
|
||||
- Version bump
|
||||
""".strip()
|
||||
|
||||
|
||||
def run(capsule: Capsule, cmd: str, cwd: str | None = None, timeout: int = 30) -> int:
    """Execute *cmd* in the capsule, logging OK/FAIL by the command's first word.

    Returns the command's exit code (0 on success); failure output is echoed
    to our stderr.
    """
    label = cmd.split()[0]
    result = capsule.commands.run(cmd, cwd=cwd, timeout=timeout)
    if result.exit_code == 0:
        print(f"OK [{label}]")
        return 0
    print(f"FAIL [{label}]: exit={result.exit_code}", file=sys.stderr)
    if result.stderr:
        print(result.stderr.strip(), file=sys.stderr)
    return result.exit_code
|
||||
|
||||
|
||||
def get_tags(capsule: Capsule) -> tuple[str, str | None]:
    """Return (newest_tag, previous_tag) from the cloned releases repo.

    Tags are sorted newest-first by version; previous_tag is None when the
    repo has only one tag. Exits 1 when git fails or no tags exist.
    """
    result = capsule.commands.run(
        f"cd {REPO_DIR} && git tag --sort=-version:refname",
        cwd=REPO_DIR,
        timeout=30,
    )
    if result.exit_code != 0:
        print(f"FAIL [git tag]: {result.stderr}", file=sys.stderr)
        sys.exit(1)

    tags = [tag for tag in result.stdout.strip().split("\n") if tag]
    if not tags:
        print("No tags found", file=sys.stderr)
        sys.exit(1)

    current_tag = tags[0]
    previous_tag = tags[1] if len(tags) > 1 else None
    print(f"Current tag: {current_tag}")
    print(f"Previous tag: {previous_tag}")
    return current_tag, previous_tag
|
||||
|
||||
|
||||
def get_git_context(
    capsule: Capsule, current_tag: str, previous_tag: str | None
) -> tuple[str, str]:
    """Collect the commit log and diffstat between the two release tags.

    When there is no previous tag (first release), falls back to the last
    50 commits and `git show --stat` of the current tag. Exits 1 on any git
    failure. Returns (log_text, diffstat_text), both stripped.
    """
    if previous_tag is None:
        # Very first tag in the repo: cap the log at 50 commits to bound size.
        log_cmd = (
            f"cd {REPO_DIR} && git log {current_tag} --pretty=format:'%s (%h)' -n 50"
        )
        diff_cmd = f"cd {REPO_DIR} && git show {current_tag} --stat"
    else:
        # All commits between the two tags — no count limit so nothing is missed.
        log_cmd = f"cd {REPO_DIR} && git log {previous_tag}..{current_tag} --pretty=format:'%s (%h)'"
        # git diff natively compares the entire tree state between tags.
        diff_cmd = f"cd {REPO_DIR} && git diff {previous_tag}..{current_tag} --stat"

    log_result = capsule.commands.run(log_cmd, cwd=REPO_DIR, timeout=30)
    if log_result.exit_code != 0:
        print(f"FAIL [git log]: {log_result.stderr}", file=sys.stderr)
        sys.exit(1)

    diff_result = capsule.commands.run(diff_cmd, cwd=REPO_DIR, timeout=30)
    if diff_result.exit_code != 0:
        print(f"FAIL [git diff]: {diff_result.stderr}", file=sys.stderr)
        sys.exit(1)

    return log_result.stdout.strip(), diff_result.stdout.strip()
|
||||
|
||||
|
||||
def generate_release_notes(
    capsule: Capsule,
    current_tag: str,
    git_log: str,
    git_diff: str,
    output_path: str,
    model: str,
) -> None:
    """Generate release notes with an LLM (via opencode) inside the capsule.

    Builds a prompt from the git log/diffstat, ships it into the capsule as
    /tmp/oc_prompt.txt (base64 round-trip so the text survives shell
    quoting), then runs opencode with *model*. When a Zhipu model fails
    (e.g. rate limiting) it falls back to the free MiniMax model; exits 1
    when no model produces output.
    """
    import shlex  # stdlib; used to quote the secret safely for the shell

    prompt = (
        f"You are writing release notes for version {current_tag} of a software project.\n\n"
        f"Here is what changed between the previous version and this one:\n\n"
        f"Commit messages:\n{git_log}\n\n"
        f"Files and areas that changed:\n{git_diff}\n\n"
        f"Write the release notes in plain, friendly language that any developer can understand "
        f"without deep knowledge of the codebase. Avoid jargon like 'goroutine', 'PTY', 'envd', "
        f"or internal function names — describe what the change means for the user instead. "
        f"Group related changes under headings that reflect what actually changed. "
        f"Only include sections that are relevant to these specific changes. "
        f"Start with a short one-line summary of what this release is about. "
        f"Keep each bullet point to one clear sentence.\n\n"
        f"Here is an example of the style to aim for — not a template to copy:\n\n"
        f"{RELEASE_NOTES_EXAMPLE}\n\n"
        f"You MUST start the document with `## What's New`\n"
        f"The very next line MUST be a single short summary sentence.\n"
        f"Output only the markdown. No intro, no explanation."
        f"CRITICAL: Do not output any conversational filler, acknowledgments, or thoughts "
        f"like 'Let me look at the changes'. Output absolutely nothing except the final markdown."
    )

    # Base64 round-trip: the prompt contains quotes/newlines that would
    # otherwise need elaborate shell escaping.
    prompt_b64 = base64.b64encode(prompt.encode("utf-8")).decode("utf-8")
    write_prompt_cmd = f"echo '{prompt_b64}' | base64 -d > /tmp/oc_prompt.txt"

    result = capsule.commands.run(write_prompt_cmd, cwd=REPO_DIR, timeout=10)
    if result.exit_code != 0:
        print(f"FAIL [write prompt]: {result.stderr}", file=sys.stderr)
        sys.exit(1)

    def run_opencode_with_model(target_model: str) -> int:
        """Run opencode once with *target_model*; returns its exit code."""
        env = ""
        if "zhipu" in target_model.lower():
            # FIX: the API key was interpolated unquoted into the shell
            # command; shlex.quote keeps special characters from breaking it
            # or leaking into word splitting.
            env = f"ZHIPU_API_KEY={shlex.quote(os.environ.get('ZHIPU_API_KEY', ''))}"

        cmd = (
            f"{env} "
            f"~/.opencode/bin/opencode run "
            f'"Read the attached file and generate the release notes. Output ONLY markdown." '
            f"--model {target_model} "
            f"--file /tmp/oc_prompt.txt "
            f"> {output_path}"
        )

        cmd_result = capsule.commands.run(cmd, cwd=REPO_DIR, timeout=120)

        if cmd_result.exit_code != 0:
            print(
                f"FAIL [opencode via {target_model}]: exit={cmd_result.exit_code}",
                file=sys.stderr,
            )
            print(f"STDOUT:\n{cmd_result.stdout}", file=sys.stderr)
            print(f"STDERR:\n{cmd_result.stderr}", file=sys.stderr)

        return cmd_result.exit_code

    # First attempt with the requested model.
    exit_status = run_opencode_with_model(model)

    # Catch failures (like Zhipu rate limits) and fall back to MiniMax.
    if exit_status != 0:
        if "zhipu" in model.lower():
            print(
                "\n[!] Zhipu AI failed (likely rate-limited). Falling back to MiniMax...",
                file=sys.stderr,
            )
            exit_status = run_opencode_with_model("opencode/minimax-m2.5-free")
            if exit_status != 0:
                print("FAIL: Fallback model also failed. Exiting.", file=sys.stderr)
                sys.exit(1)
        else:
            sys.exit(1)

    result = capsule.commands.run(f"cat {output_path}")
    print(result.stdout)
    if result.stderr:
        print(result.stderr)

    print(f"OK [opencode] release notes written to {output_path}")
|
||||
|
||||
|
||||
def download_release_notes(capsule: Capsule) -> None:
    """Copy the generated release notes out of the capsule to LOCAL_OUTPUT.

    Writes the bytes verbatim, then echoes the decoded content so it shows
    up in the CI log.
    """
    local_path = os.path.normpath(LOCAL_OUTPUT)
    os.makedirs(os.path.dirname(local_path), exist_ok=True)

    # FIX: was an f-string with no placeholders (f-string lint noise).
    print("Downloading release notes from capsule...")
    content = capsule.files.read_bytes(CAPSULE_OUTPUT)
    with open(local_path, "wb") as f:
        f.write(content)

    print(f"OK [download] release notes → {local_path}")
    print(content.decode("utf-8", errors="replace"))
|
||||
|
||||
|
||||
def main() -> None:
    """Clone the releases repo in a capsule and produce release notes.

    Model selection comes from OPENCODE_MODEL when set, otherwise the
    module default (Zhipu when a key is present, MiniMax free otherwise).
    """
    model = os.environ.get("OPENCODE_MODEL", DEFAULT_MODEL)

    with Capsule(template="opencode", wait=True, vcpus=2, memory_mb=2048) as capsule:
        print(f"Capsule: {capsule.capsule_id}")

        capsule.git.clone(REPO_URL, REPO_DIR, username="tksadik92")
        print("OK [git clone]")

        current_tag, previous_tag = get_tags(capsule)
        git_log, git_diff = get_git_context(capsule, current_tag, previous_tag)

        output_path = os.path.normpath(CAPSULE_OUTPUT)
        generate_release_notes(
            capsule, current_tag, git_log, git_diff, output_path, model
        )

        download_release_notes(capsule)


if __name__ == "__main__":
    main()
|
||||
96
CLAUDE.md
96
CLAUDE.md
@ -14,7 +14,7 @@ All commands go through the Makefile. Never use raw `go build` or `go run`.
|
||||
make build # Build all binaries → builds/
|
||||
make build-cp # Control plane only
|
||||
make build-agent # Host agent only
|
||||
make build-envd # envd static binary (verified statically linked)
|
||||
make build-envd # envd static binary (Rust, musl, verified statically linked)
|
||||
make build-frontend # SvelteKit dashboard → frontend/build/ (served by Caddy)
|
||||
|
||||
make dev # Full local dev: infra + migrate + control plane
|
||||
@ -23,13 +23,13 @@ make dev-down # Stop dev infra
|
||||
make dev-cp # Control plane with hot reload (if air installed)
|
||||
make dev-frontend # Vite dev server with HMR (port 5173)
|
||||
make dev-agent # Host agent (sudo required)
|
||||
make dev-envd # envd in TCP debug mode
|
||||
make dev-envd # envd in debug mode (--isnotfc, port 49983)
|
||||
|
||||
make check # fmt + vet + lint + test (CI order)
|
||||
make test # Unit tests: go test -race -v ./internal/...
|
||||
make test-integration # Integration tests (require host agent + Firecracker)
|
||||
make fmt # gofmt both modules
|
||||
make vet # go vet both modules
|
||||
make fmt # gofmt
|
||||
make vet # go vet
|
||||
make lint # golangci-lint
|
||||
|
||||
make migrate-up # Apply pending migrations
|
||||
@ -38,8 +38,8 @@ make migrate-create name=xxx # Scaffold new goose migration (never create manua
|
||||
make migrate-reset # Drop + re-apply all
|
||||
|
||||
make generate # Proto (buf) + sqlc codegen
|
||||
make proto # buf generate for all proto dirs
|
||||
make tidy # go mod tidy both modules
|
||||
make proto # buf generate for proto dirs
|
||||
make tidy # go mod tidy
|
||||
```
|
||||
|
||||
Run a single test: `go test -race -v -run TestName ./internal/path/...`
|
||||
@ -50,15 +50,15 @@ Run a single test: `go test -race -v -run TestName ./internal/path/...`
|
||||
User SDK → HTTPS/WS → Control Plane → Connect RPC → Host Agent → HTTP/Connect RPC over TAP → envd (inside VM)
|
||||
```
|
||||
|
||||
**Three binaries, two Go modules:**
|
||||
**Three binaries:**
|
||||
|
||||
| Binary | Module | Entry point | Runs as |
|
||||
|--------|--------|-------------|---------|
|
||||
| wrenn-cp | `git.omukk.dev/wrenn/wrenn` | `cmd/control-plane/main.go` | Unprivileged |
|
||||
| wrenn-agent | `git.omukk.dev/wrenn/wrenn` | `cmd/host-agent/main.go` | `wrenn` user with capabilities (SYS_ADMIN, NET_ADMIN, NET_RAW, SYS_PTRACE, KILL, DAC_OVERRIDE, MKNOD) via setcap; also accepts root |
|
||||
| envd | `git.omukk.dev/wrenn/wrenn/envd` (standalone `envd/go.mod`) | `envd/main.go` | PID 1 inside guest VM |
|
||||
| Binary | Language | Entry point | Runs as |
|
||||
|--------|----------|-------------|---------|
|
||||
| wrenn-cp | Go (`git.omukk.dev/wrenn/wrenn`) | `cmd/control-plane/main.go` | Unprivileged |
|
||||
| wrenn-agent | Go (`git.omukk.dev/wrenn/wrenn`) | `cmd/host-agent/main.go` | `wrenn` user with capabilities (SYS_ADMIN, NET_ADMIN, NET_RAW, SYS_PTRACE, KILL, DAC_OVERRIDE, MKNOD) via setcap; also accepts root |
|
||||
| envd | Rust (`envd-rs/`) | `envd-rs/src/main.rs` | PID 1 inside guest VM |
|
||||
|
||||
envd is a **completely independent Go module**. It is never imported by the main module. The only connection is the protobuf contract. It compiles to a static binary baked into rootfs images.
|
||||
envd is a standalone Rust binary (Tokio + Axum + connectrpc-rs). It is completely independent from the Go module — the only connection is the protobuf contract. It compiles to a statically linked musl binary baked into rootfs images.
|
||||
|
||||
**Key architectural invariant:** The host agent is **stateful** (in-memory `boxes` map is the source of truth for running VMs). The control plane is **stateless** (all persistent state in PostgreSQL). The reconciler (`internal/api/reconciler.go`) bridges the gap — it periodically compares DB records against the host agent's live state and marks orphaned sandboxes as "stopped".
|
||||
|
||||
@ -99,13 +99,17 @@ Startup (`cmd/host-agent/main.go`) wires: root/capabilities check → enable IP
|
||||
|
||||
### envd (Guest Agent)
|
||||
|
||||
**Module:** `envd/` with its own `go.mod` (`git.omukk.dev/wrenn/wrenn/envd`)
|
||||
**Directory:** `envd-rs/` — standalone Rust crate
|
||||
|
||||
Runs as PID 1 inside the microVM via `wrenn-init.sh` (mounts procfs/sysfs/dev, sets hostname, writes resolv.conf, then execs envd). Extracted from E2B (Apache 2.0), with shared packages internalized into `envd/internal/shared/`. Listens on TCP `0.0.0.0:49983`.
|
||||
Runs as PID 1 inside the microVM via `wrenn-init.sh` (mounts procfs/sysfs/dev, sets hostname, writes resolv.conf, then execs envd via tini). Built with `cargo build --release --target x86_64-unknown-linux-musl`. Listens on TCP `0.0.0.0:49983`.
|
||||
|
||||
- **ProcessService**: start processes, stream stdout/stderr, signal handling, PTY support
|
||||
- **FilesystemService**: stat/list/mkdir/move/remove/watch files
|
||||
- **Health**: GET `/health`
|
||||
- **Stack**: Tokio (async runtime) + Axum (HTTP) + connectrpc-rs (Connect protocol RPC)
|
||||
- **ProcessService** (Connect RPC): start/connect/list/signal processes, stream stdout/stderr, PTY support
|
||||
- **FilesystemService** (Connect RPC): stat/list/mkdir/move/remove/watch files
|
||||
- **HTTP endpoints**: GET `/health`, GET `/metrics`, POST `/init`, POST `/snapshot/prepare`, GET/POST `/files`
|
||||
- **Proto codegen**: `connectrpc-build` compiles `proto/envd/*.proto` at `cargo build` time via `build.rs` — no committed stubs
|
||||
- **Build**: `make build-envd` → static musl binary in `builds/envd`
|
||||
- **Dev**: `make dev-envd` → `cargo run -- --isnotfc --port 49983`
|
||||
|
||||
### Dashboard (Frontend)
|
||||
|
||||
@ -185,17 +189,16 @@ Routes defined in `internal/api/server.go`, handlers in `internal/api/handlers_*
|
||||
|
||||
### Proto (Connect RPC)
|
||||
|
||||
Proto source of truth is `proto/envd/*.proto` and `proto/hostagent/*.proto`. Run `make proto` to regenerate. Three `buf.gen.yaml` files control output:
|
||||
Proto source of truth is `proto/envd/*.proto` and `proto/hostagent/*.proto`. Run `make proto` to regenerate Go stubs. Two `buf.gen.yaml` files control Go output:
|
||||
|
||||
| buf.gen.yaml location | Generates to | Used by |
|
||||
|---|---|---|
|
||||
| `proto/envd/buf.gen.yaml` | `proto/envd/gen/` | Main module (host agent's envd client) |
|
||||
| `proto/hostagent/buf.gen.yaml` | `proto/hostagent/gen/` | Main module (control plane ↔ host agent) |
|
||||
| `envd/spec/buf.gen.yaml` | `envd/internal/services/spec/` | envd module (guest agent server) |
|
||||
|
||||
The envd `buf.gen.yaml` reads from `../../proto/envd/` (same source protos) but generates into envd's own module. This means the same `.proto` files produce two independent sets of Go stubs — one for each Go module.
|
||||
The Rust envd (`envd-rs/`) generates its own protobuf stubs at `cargo build` time via `connectrpc-build` in `envd-rs/build.rs`, reading from the same `proto/envd/*.proto` sources. No committed Rust stubs — they live in `OUT_DIR`.
|
||||
|
||||
To add a new RPC method: edit the `.proto` file → `make proto` → implement the handler on both sides.
|
||||
To add a new RPC method: edit the `.proto` file → `make proto` (Go stubs) → rebuild envd-rs (Rust stubs generated automatically) → implement the handler on both sides.
|
||||
|
||||
### sqlc
|
||||
|
||||
@ -206,7 +209,7 @@ To add a new query: add it to the appropriate `.sql` file in `db/queries/` → `
|
||||
## Key Technical Decisions
|
||||
|
||||
- **Connect RPC** (not gRPC) for all RPC communication between components
|
||||
- **Buf + protoc-gen-connect-go** for code generation (not protoc-gen-go-grpc)
|
||||
- **Buf + protoc-gen-connect-go** for Go code generation; **connectrpc-build** for Rust code generation in envd
|
||||
- **Raw Firecracker HTTP API** via Unix socket (not firecracker-go-sdk Machine type)
|
||||
- **TAP networking** (not vsock) for host-to-envd communication
|
||||
- **Device-mapper snapshots** for rootfs CoW — shared read-only loop device per base template, per-sandbox sparse CoW file, Firecracker gets `/dev/mapper/wrenn-{id}`
|
||||
@ -218,19 +221,15 @@ To add a new query: add it to the appropriate `.sql` file in `db/queries/` → `
|
||||
|
||||
- **Go style**: `gofmt`, `go vet`, `context.Context` everywhere, errors wrapped with `fmt.Errorf("action: %w", err)`, `slog` for logging, no global state
|
||||
- **Naming**: Sandbox IDs `sb-` + 8 hex, API keys `wrn_` + 32 chars, Host IDs `host-` + 8 hex
|
||||
- **Dependencies**: Use `go get` to add deps, never hand-edit go.mod. For envd deps: `cd envd && go get ...` (separate module)
|
||||
- **Dependencies**: Use `go get` to add Go deps, never hand-edit go.mod. For envd-rs deps: edit `envd-rs/Cargo.toml`
|
||||
- **Generated code**: Always commit generated code (proto stubs, sqlc). Never add generated code to .gitignore
|
||||
- **Migrations**: Always use `make migrate-create name=xxx`, never create migration files manually
|
||||
- **Testing**: Table-driven tests for handlers and state machine transitions
|
||||
|
||||
### Two-module gotcha
|
||||
|
||||
The main module (`go.mod`) and envd (`envd/go.mod`) are fully independent. `make tidy`, `make fmt`, `make vet` already operate on both. But when adding dependencies manually, remember to target the correct module (`cd envd && go get ...` for envd deps). `make proto` also generates stubs for both modules from the same proto sources.
|
||||
|
||||
## Rootfs & Guest Init
|
||||
|
||||
- **wrenn-init** (`images/wrenn-init.sh`): the PID 1 init script baked into every rootfs. Mounts virtual filesystems, sets hostname, writes `/etc/resolv.conf`, then execs envd.
|
||||
- **Updating the rootfs** after changing envd or wrenn-init: `bash scripts/update-debug-rootfs.sh [rootfs_path]`. This builds envd via `make build-envd`, mounts the rootfs image, copies in the new binaries, and unmounts. Defaults to `/var/lib/wrenn/images/minimal.ext4`.
|
||||
- **Updating the rootfs** after changing envd or wrenn-init: `bash scripts/update-minimal-rootfs.sh`. This builds envd via `make build-envd` (Rust → static musl binary), mounts the rootfs image, copies in the new binaries, and unmounts. Defaults to `/var/lib/wrenn/images/minimal.ext4`.
|
||||
- Rootfs images are minimal debootstrap — no systemd, no coreutils beyond busybox. Use `/bin/sh -c` for shell builtins inside the guest.
|
||||
|
||||
## Fixed Paths (on host machine)
|
||||
@ -372,3 +371,42 @@ All values are CSS custom properties in `frontend/src/app.css`.
|
||||
4. **Legible at speed.** Users scan dashboards in seconds. Strong typographic contrast (serif h1, mono IDs, sans body), consistent patterns, and predictable placement let users orient instantly without reading everything.
|
||||
|
||||
5. **Craft signals trust.** For infrastructure that runs production code, the quality of the UI is a proxy for the quality of the product. Pixel-level decisions matter. Polish is not decoration — it's a trust signal.
|
||||
|
||||
<!-- code-review-graph MCP tools -->
|
||||
## MCP Tools: code-review-graph
|
||||
|
||||
**IMPORTANT: This project has a knowledge graph. ALWAYS use the
|
||||
code-review-graph MCP tools BEFORE using Grep/Glob/Read to explore
|
||||
the codebase.** The graph is faster, cheaper (fewer tokens), and gives
|
||||
you structural context (callers, dependents, test coverage) that file
|
||||
scanning cannot.
|
||||
|
||||
### When to use graph tools FIRST
|
||||
|
||||
- **Exploring code**: `semantic_search_nodes` or `query_graph` instead of Grep
|
||||
- **Understanding impact**: `get_impact_radius` instead of manually tracing imports
|
||||
- **Code review**: `detect_changes` + `get_review_context` instead of reading entire files
|
||||
- **Finding relationships**: `query_graph` with callers_of/callees_of/imports_of/tests_for
|
||||
- **Architecture questions**: `get_architecture_overview` + `list_communities`
|
||||
|
||||
Fall back to Grep/Glob/Read **only** when the graph doesn't cover what you need.
|
||||
|
||||
### Key Tools
|
||||
|
||||
| Tool | Use when |
|
||||
|------|----------|
|
||||
| `detect_changes` | Reviewing code changes — gives risk-scored analysis |
|
||||
| `get_review_context` | Need source snippets for review — token-efficient |
|
||||
| `get_impact_radius` | Understanding blast radius of a change |
|
||||
| `get_affected_flows` | Finding which execution paths are impacted |
|
||||
| `query_graph` | Tracing callers, callees, imports, tests, dependencies |
|
||||
| `semantic_search_nodes` | Finding functions/classes by name or keyword |
|
||||
| `get_architecture_overview` | Understanding high-level codebase structure |
|
||||
| `refactor_tool` | Planning renames, finding dead code |
|
||||
|
||||
### Workflow
|
||||
|
||||
1. The graph auto-updates on file changes (via hooks).
|
||||
2. Use `detect_changes` for code review.
|
||||
3. Use `get_affected_flows` to understand impact.
|
||||
4. Use `query_graph` pattern="tests_for" to check coverage.
|
||||
|
||||
38
Makefile
38
Makefile
@ -2,12 +2,10 @@
|
||||
# Variables
|
||||
# ═══════════════════════════════════════════════════
|
||||
DATABASE_URL ?= postgres://wrenn:wrenn@localhost:5432/wrenn?sslmode=disable
|
||||
GOBIN := $(shell pwd)/builds
|
||||
ENVD_DIR := envd
|
||||
BIN_DIR := $(shell pwd)/builds
|
||||
COMMIT := $(shell git rev-parse --short HEAD 2>/dev/null || echo "unknown")
|
||||
VERSION_CP := $(shell cat VERSION_CP 2>/dev/null | tr -d '[:space:]' || echo "0.0.0-dev")
|
||||
VERSION_AGENT := $(shell cat VERSION_AGENT 2>/dev/null | tr -d '[:space:]' || echo "0.0.0-dev")
|
||||
VERSION_ENVD := $(shell cat envd/VERSION 2>/dev/null | tr -d '[:space:]' || echo "0.0.0-dev")
|
||||
LDFLAGS := -s -w
|
||||
|
||||
# ═══════════════════════════════════════════════════
|
||||
@ -21,16 +19,20 @@ build-frontend:
|
||||
cd frontend && pnpm install --frozen-lockfile && pnpm build
|
||||
|
||||
build-cp:
|
||||
go build -v -ldflags="$(LDFLAGS) -X main.version=$(VERSION_CP) -X main.commit=$(COMMIT)" -o $(GOBIN)/wrenn-cp ./cmd/control-plane
|
||||
go build -v -ldflags="$(LDFLAGS) -X main.version=$(VERSION_CP) -X main.commit=$(COMMIT)" -o $(BIN_DIR)/wrenn-cp ./cmd/control-plane
|
||||
|
||||
build-agent:
|
||||
go build -v -ldflags="$(LDFLAGS) -X main.version=$(VERSION_AGENT) -X main.commit=$(COMMIT)" -o $(GOBIN)/wrenn-agent ./cmd/host-agent
|
||||
go build -v -ldflags="$(LDFLAGS) -X main.version=$(VERSION_AGENT) -X main.commit=$(COMMIT)" -o $(BIN_DIR)/wrenn-agent ./cmd/host-agent
|
||||
|
||||
build-envd:
|
||||
cd $(ENVD_DIR) && CGO_ENABLED=0 GOOS=linux GOARCH=amd64 \
|
||||
go build -ldflags="$(LDFLAGS) -X main.Version=$(VERSION_ENVD) -X main.commitSHA=$(COMMIT)" -o $(GOBIN)/envd .
|
||||
@file $(GOBIN)/envd | grep -q "statically linked" || \
|
||||
(echo "ERROR: envd is not statically linked!" && exit 1)
|
||||
cd envd-rs && ENVD_COMMIT=$(COMMIT) cargo build --release --target x86_64-unknown-linux-musl
|
||||
@cp envd-rs/target/x86_64-unknown-linux-musl/release/envd $(BIN_DIR)/envd
|
||||
@readelf -h $(BIN_DIR)/envd | grep -q 'Type:.*DYN' && \
|
||||
readelf -d $(BIN_DIR)/envd | grep -q 'FLAGS_1.*PIE' && \
|
||||
! readelf -d $(BIN_DIR)/envd | grep -q '(NEEDED)' && \
|
||||
{ ! readelf -lW $(BIN_DIR)/envd | grep -q 'Requesting program interpreter' || \
|
||||
readelf -lW $(BIN_DIR)/envd | grep -Fq '[Requesting program interpreter: /lib/ld-musl-x86_64.so.1]'; } || \
|
||||
(echo "ERROR: envd must be PIE, have no DT_NEEDED shared libs, and either have no interpreter or use /lib/ld-musl-x86_64.so.1" && exit 1)
|
||||
|
||||
# ═══════════════════════════════════════════════════
|
||||
# Development
|
||||
@ -60,8 +62,7 @@ dev-frontend:
|
||||
cd frontend && pnpm dev --port 5173 --host 0.0.0.0
|
||||
|
||||
dev-envd:
|
||||
cd $(ENVD_DIR) && go run . --debug --listen-tcp :3002
|
||||
|
||||
cd envd-rs && cargo run -- --isnotfc --port 49983
|
||||
|
||||
# ═══════════════════════════════════════════════════
|
||||
# Database (goose)
|
||||
@ -94,7 +95,6 @@ generate: proto sqlc
|
||||
proto:
|
||||
cd proto/envd && buf generate
|
||||
cd proto/hostagent && buf generate
|
||||
cd $(ENVD_DIR)/spec && buf generate
|
||||
|
||||
sqlc:
|
||||
sqlc generate
|
||||
@ -106,17 +106,16 @@ sqlc:
|
||||
|
||||
fmt:
|
||||
gofmt -w .
|
||||
cd $(ENVD_DIR) && gofmt -w .
|
||||
|
||||
lint:
|
||||
golangci-lint run ./...
|
||||
|
||||
vet:
|
||||
go vet ./...
|
||||
cd $(ENVD_DIR) && go vet ./...
|
||||
|
||||
test:
|
||||
go test -race -v ./internal/...
|
||||
cd envd-rs && cargo test
|
||||
|
||||
test-integration:
|
||||
go test -race -v -tags=integration ./tests/integration/...
|
||||
@ -125,7 +124,6 @@ test-all: test test-integration
|
||||
|
||||
tidy:
|
||||
go mod tidy
|
||||
cd $(ENVD_DIR) && go mod tidy
|
||||
|
||||
## Run all quality checks in CI order
|
||||
check: fmt vet lint test
|
||||
@ -155,8 +153,8 @@ setup-host:
|
||||
sudo bash scripts/setup-host.sh
|
||||
|
||||
install: build
|
||||
sudo cp $(GOBIN)/wrenn-cp /usr/local/bin/
|
||||
sudo cp $(GOBIN)/wrenn-agent /usr/local/bin/
|
||||
sudo cp $(BIN_DIR)/wrenn-cp /usr/local/bin/
|
||||
sudo cp $(BIN_DIR)/wrenn-agent /usr/local/bin/
|
||||
sudo cp deploy/systemd/*.service /etc/systemd/system/
|
||||
sudo systemctl daemon-reload
|
||||
|
||||
@ -167,7 +165,7 @@ install: build
|
||||
|
||||
clean:
|
||||
rm -rf builds/
|
||||
cd $(ENVD_DIR) && rm -f envd
|
||||
cd envd-rs && cargo clean
|
||||
|
||||
# ═══════════════════════════════════════════════════
|
||||
# Help
|
||||
@ -183,11 +181,11 @@ help:
|
||||
@echo " make dev-cp Control plane (hot reload if air installed)"
|
||||
@echo " make dev-frontend Vite dev server with HMR (port 5173)"
|
||||
@echo " make dev-agent Host agent (sudo required)"
|
||||
@echo " make dev-envd envd in TCP debug mode"
|
||||
@echo " make dev-envd envd in debug mode (--isnotfc, port 49983)"
|
||||
@echo ""
|
||||
@echo " make build Build all binaries → builds/"
|
||||
@echo " make build-frontend Build SvelteKit dashboard → frontend/build/"
|
||||
@echo " make build-envd Build envd static binary"
|
||||
@echo " make build-envd Build envd static binary (Rust, musl)"
|
||||
@echo ""
|
||||
@echo " make migrate-up Apply migrations"
|
||||
@echo " make migrate-create name=xxx New migration"
|
||||
|
||||
19
NOTICE
19
NOTICE
@ -1,19 +0,0 @@
|
||||
Wrenn Sandbox
|
||||
Copyright (c) 2026 M/S Omukk, Bangladesh
|
||||
|
||||
This project includes software derived from the following project:
|
||||
|
||||
Project: e2b infra
|
||||
Repository: https://github.com/e2b-dev/infra
|
||||
|
||||
The following files and directories in this repository contain code derived from the above project:
|
||||
|
||||
- envd/
|
||||
- proto/envd/*.proto
|
||||
- internal/snapshot/
|
||||
- internal/uffd/
|
||||
|
||||
Modifications to this code were made by M/S Omukk.
|
||||
|
||||
Copyright (c) 2023 FoundryLabs, Inc.
|
||||
Modifications Copyright (c) 2026 M/S Omukk, Bangladesh
|
||||
@ -8,6 +8,7 @@ Secure infrastructure for AI
|
||||
- Firecracker binary at `/usr/local/bin/firecracker`
|
||||
- PostgreSQL
|
||||
- Go 1.25+
|
||||
- Rust 1.88+ with `x86_64-unknown-linux-musl` target (`rustup target add x86_64-unknown-linux-musl`)
|
||||
- pnpm (for frontend)
|
||||
- Docker (for dev infra and rootfs builds)
|
||||
|
||||
|
||||
@ -1 +1 @@
|
||||
0.1.0
|
||||
0.1.3
|
||||
|
||||
@ -1 +1 @@
|
||||
0.1.3
|
||||
0.1.6
|
||||
|
||||
@ -80,6 +80,25 @@ func main() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Register with the control plane before touching rootfs images. If the
|
||||
// agent can't reach the CP there's no point inflating images (and crashing
|
||||
// afterward would leave them in the expanded state).
|
||||
creds, err := hostagent.Register(ctx, hostagent.RegistrationConfig{
|
||||
CPURL: cpURL,
|
||||
RegistrationToken: *registrationToken,
|
||||
TokenFile: credsFile,
|
||||
Address: *advertiseAddr,
|
||||
})
|
||||
if err != nil {
|
||||
slog.Error("host registration failed", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
slog.Info("host registered", "host_id", creds.HostID)
|
||||
|
||||
// Parse default rootfs size from env (e.g. "5G", "2Gi", "1000M").
|
||||
defaultRootfsSizeMB := sandbox.DefaultDiskSizeMB
|
||||
if sizeStr := os.Getenv("WRENN_DEFAULT_ROOTFS_SIZE"); sizeStr != "" {
|
||||
@ -128,27 +147,21 @@ func main() {
|
||||
|
||||
mgr := sandbox.New(cfg)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
mgr.StartTTLReaper(ctx)
|
||||
|
||||
// Register with the control plane and start heartbeating.
|
||||
creds, err := hostagent.Register(ctx, hostagent.RegistrationConfig{
|
||||
CPURL: cpURL,
|
||||
RegistrationToken: *registrationToken,
|
||||
TokenFile: credsFile,
|
||||
Address: *advertiseAddr,
|
||||
})
|
||||
if err != nil {
|
||||
slog.Error("host registration failed", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
slog.Info("host registered", "host_id", creds.HostID)
|
||||
|
||||
// httpServer is declared here so the shutdown func can reference it.
|
||||
httpServer := &http.Server{Addr: listenAddr}
|
||||
// ReadTimeout/WriteTimeout are intentionally omitted — they would kill
|
||||
// long-lived Connect RPC streams and WebSocket proxy connections.
|
||||
httpServer := &http.Server{
|
||||
Addr: listenAddr,
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
IdleTimeout: 620 * time.Second, // > typical LB upstream timeout (600s)
|
||||
// Disable HTTP/2: empty non-nil map prevents Go from registering
|
||||
// the h2 ALPN token. Connect RPC works over HTTP/1.1; HTTP/2
|
||||
// multiplexing causes HOL blocking when a slow sandbox RPC stalls
|
||||
// the shared connection.
|
||||
TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)),
|
||||
}
|
||||
|
||||
// mTLS is mandatory — refuse to start without a valid certificate.
|
||||
var certStore hostagent.CertStore
|
||||
@ -193,6 +206,7 @@ func main() {
|
||||
path, handler := hostagentv1connect.NewHostAgentServiceHandler(srv)
|
||||
|
||||
proxyHandler := hostagent.NewProxyHandler(mgr)
|
||||
mgr.SetOnDestroy(proxyHandler.EvictProxy)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle(path, handler)
|
||||
|
||||
@ -22,6 +22,12 @@ RETURNING *;
|
||||
-- name: SetUserAdmin :exec
|
||||
UPDATE users SET is_admin = $2, updated_at = NOW() WHERE id = $1;
|
||||
|
||||
-- name: RevokeUserAdmin :execrows
|
||||
UPDATE users u SET is_admin = false, updated_at = NOW()
|
||||
WHERE u.id = $1
|
||||
AND u.is_admin = true
|
||||
AND (SELECT COUNT(*) FROM users WHERE is_admin = true AND status != 'deleted') > 1;
|
||||
|
||||
-- name: GetAdminUsers :many
|
||||
SELECT * FROM users WHERE is_admin = TRUE ORDER BY created_at;
|
||||
|
||||
|
||||
2
envd-rs/.cargo/config.toml
Normal file
2
envd-rs/.cargo/config.toml
Normal file
@ -0,0 +1,2 @@
|
||||
[target.x86_64-unknown-linux-musl]
|
||||
linker = "musl-gcc"
|
||||
2623
envd-rs/Cargo.lock
generated
Normal file
2623
envd-rs/Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
86
envd-rs/Cargo.toml
Normal file
86
envd-rs/Cargo.toml
Normal file
@ -0,0 +1,86 @@
|
||||
[package]
|
||||
name = "envd"
|
||||
version = "0.2.1"
|
||||
edition = "2024"
|
||||
rust-version = "1.88"
|
||||
|
||||
[dependencies]
|
||||
# Async runtime
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
|
||||
# HTTP framework
|
||||
axum = { version = "0.8", features = ["multipart"] }
|
||||
tower = { version = "0.5", features = ["util"] }
|
||||
tower-http = { version = "0.6", features = ["cors", "fs"] }
|
||||
tower-service = "0.3"
|
||||
|
||||
# RPC (Connect protocol — serves Connect + gRPC + gRPC-Web on same port)
|
||||
connectrpc = { version = "0.3", features = ["axum"] }
|
||||
buffa-types = { path = "buffa-types-shim" }
|
||||
|
||||
# CLI
|
||||
clap = { version = "4", features = ["derive"] }
|
||||
|
||||
# Serialization
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
|
||||
# Logging
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
|
||||
|
||||
# System metrics
|
||||
sysinfo = "0.33"
|
||||
|
||||
# Unix syscalls
|
||||
nix = { version = "0.30", features = ["fs", "process", "signal", "user", "term", "mount", "ioctl"] }
|
||||
|
||||
# Concurrent map
|
||||
dashmap = "6"
|
||||
|
||||
# Crypto
|
||||
sha2 = "0.10"
|
||||
hmac = "0.12"
|
||||
hex = "0.4"
|
||||
base64 = "0.22"
|
||||
|
||||
# Secure memory
|
||||
zeroize = { version = "1", features = ["derive"] }
|
||||
|
||||
# File watching
|
||||
notify = "7"
|
||||
|
||||
# Compression
|
||||
flate2 = "1"
|
||||
|
||||
# HTTP client (MMDS polling)
|
||||
reqwest = { version = "0.12", default-features = false, features = ["json"] }
|
||||
|
||||
# Directory walking
|
||||
walkdir = "2"
|
||||
|
||||
# Misc
|
||||
libc = "0.2"
|
||||
bytes = "1"
|
||||
http = "1"
|
||||
http-body-util = "0.1"
|
||||
futures = "0.3"
|
||||
tokio-util = { version = "0.7", features = ["io"] }
|
||||
subtle = "2"
|
||||
http-body = "1.0.1"
|
||||
buffa = "0.3"
|
||||
async-stream = "0.3.6"
|
||||
mime_guess = "2"
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3"
|
||||
|
||||
[build-dependencies]
|
||||
connectrpc-build = "0.3"
|
||||
|
||||
[profile.release]
|
||||
strip = true
|
||||
lto = true
|
||||
opt-level = "z"
|
||||
codegen-units = 1
|
||||
panic = "abort"
|
||||
141
envd-rs/README.md
Normal file
141
envd-rs/README.md
Normal file
@ -0,0 +1,141 @@
|
||||
# envd (Rust)
|
||||
|
||||
Wrenn guest agent daemon — runs as PID 1 inside Firecracker microVMs. Provides process management, filesystem operations, file transfer, port forwarding, and VM lifecycle control over Connect RPC and HTTP.
|
||||
|
||||
Rust rewrite of `envd/` (Go). Drop-in replacement — same wire protocol, same endpoints, same CLI flags.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- Rust 1.88+ (required by `connectrpc` 0.3.3)
|
||||
- `protoc` (protobuf compiler, for proto codegen at build time)
|
||||
- `musl-tools` (for static linking)
|
||||
|
||||
```bash
|
||||
# Ubuntu/Debian
|
||||
sudo apt install musl-tools protobuf-compiler
|
||||
|
||||
# Rust musl target
|
||||
rustup target add x86_64-unknown-linux-musl
|
||||
```
|
||||
|
||||
## Building
|
||||
|
||||
### Static binary (production — what goes into the rootfs)
|
||||
|
||||
```bash
|
||||
cd envd-rs
|
||||
ENVD_COMMIT=$(git rev-parse --short HEAD) \
|
||||
cargo build --release --target x86_64-unknown-linux-musl
|
||||
```
|
||||
|
||||
Output: `target/x86_64-unknown-linux-musl/release/envd`
|
||||
|
||||
Verify static linking:
|
||||
|
||||
```bash
|
||||
file target/x86_64-unknown-linux-musl/release/envd
|
||||
# should say: "statically linked"
|
||||
|
||||
ldd target/x86_64-unknown-linux-musl/release/envd
|
||||
# should say: "not a dynamic executable"
|
||||
```
|
||||
|
||||
### Debug binary (dev machine, dynamically linked)
|
||||
|
||||
```bash
|
||||
cd envd-rs
|
||||
cargo build
|
||||
```
|
||||
|
||||
Run locally (outside a VM):
|
||||
|
||||
```bash
|
||||
./target/debug/envd --isnotfc --port 49983
|
||||
```
|
||||
|
||||
### Via Makefile (from repo root)
|
||||
|
||||
```bash
|
||||
make build-envd # static musl release build
|
||||
make build-envd-go # Go version (for comparison)
|
||||
```
|
||||
|
||||
## CLI Flags
|
||||
|
||||
```
|
||||
--port <PORT> Listen port [default: 49983]
|
||||
--isnotfc Not running inside Firecracker (disables MMDS, cgroups)
|
||||
--version Print version and exit
|
||||
--commit Print git commit and exit
|
||||
--cmd <CMD> Spawn a process at startup (e.g. --cmd "/bin/bash")
|
||||
--cgroup-root <PATH> Cgroup v2 root [default: /sys/fs/cgroup]
|
||||
```
|
||||
|
||||
## Endpoints
|
||||
|
||||
### HTTP
|
||||
|
||||
| Method | Path | Description |
|
||||
|--------|---------------------|--------------------------------------|
|
||||
| GET | `/health` | Health check, triggers post-restore |
|
||||
| GET | `/metrics` | System metrics (CPU, memory, disk) |
|
||||
| GET | `/envs` | Current environment variables |
|
||||
| POST | `/init` | Host agent init (token, env, mounts) |
|
||||
| POST | `/snapshot/prepare` | Quiesce before Firecracker snapshot |
|
||||
| GET | `/files` | Download file (gzip, range support) |
|
||||
| POST | `/files` | Upload file(s) via multipart |
|
||||
|
||||
### Connect RPC (same port)
|
||||
|
||||
| Service | RPCs |
|
||||
|------------|-------------------------------------------------------------------------|
|
||||
| Process | List, Start, Connect, Update, StreamInput, SendInput, SendSignal, CloseStdin |
|
||||
| Filesystem | Stat, MakeDir, Move, ListDir, Remove, WatchDir, CreateWatcher, GetWatcherEvents, RemoveWatcher |
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
42 files, ~4200 LOC Rust
|
||||
Binary: ~4 MB (stripped, LTO, musl static)
|
||||
|
||||
src/
|
||||
├── main.rs # Entry point, CLI, server setup
|
||||
├── state.rs # Shared AppState
|
||||
├── config.rs # Constants
|
||||
├── conntracker.rs # TCP connection tracking for snapshot/restore
|
||||
├── execcontext.rs # Default user/workdir/env
|
||||
├── logging.rs # tracing-subscriber (JSON or pretty)
|
||||
├── util.rs # AtomicMax
|
||||
├── auth/ # Token, signing, middleware
|
||||
├── crypto/ # SHA-256, SHA-512, HMAC
|
||||
├── host/ # MMDS polling, system metrics
|
||||
├── http/ # Axum handlers (health, init, snapshot, files, encoding)
|
||||
├── permissions/ # Path resolution, user lookup, chown
|
||||
├── rpc/ # Connect RPC services
|
||||
│ ├── pb.rs # Generated proto types
|
||||
│ ├── process_*.rs # Process service + handler (PTY, pipe, broadcast)
|
||||
│ ├── filesystem_*.rs # Filesystem service (stat, list, watch, mkdir, move, remove)
|
||||
│ └── entry.rs # EntryInfo builder
|
||||
├── port/ # Port subsystem
|
||||
│ ├── conn.rs # /proc/net/tcp parser
|
||||
│ ├── scanner.rs # Periodic TCP port scanner
|
||||
│ ├── forwarder.rs # socat-based port forwarding
|
||||
│ └── subsystem.rs # Lifecycle (start/stop/restart)
|
||||
└── cgroups/ # Cgroup v2 manager (pty/user/socat groups)
|
||||
```
|
||||
|
||||
## Updating the rootfs
|
||||
|
||||
After building the static binary, copy it into the rootfs:
|
||||
|
||||
```bash
|
||||
bash scripts/update-minimal-rootfs.sh
|
||||
```
|
||||
|
||||
Or manually:
|
||||
|
||||
```bash
|
||||
sudo mount -o loop /var/lib/wrenn/images/minimal.ext4 /mnt
|
||||
sudo cp target/x86_64-unknown-linux-musl/release/envd /mnt/usr/bin/envd
|
||||
sudo umount /mnt
|
||||
```
|
||||
12
envd-rs/buffa-types-shim/Cargo.toml
Normal file
12
envd-rs/buffa-types-shim/Cargo.toml
Normal file
@ -0,0 +1,12 @@
|
||||
[package]
|
||||
name = "buffa-types"
|
||||
version = "0.3.0"
|
||||
edition = "2024"
|
||||
publish = false
|
||||
|
||||
[dependencies]
|
||||
buffa = "0.3"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
|
||||
[build-dependencies]
|
||||
connectrpc-build = "0.3"
|
||||
9
envd-rs/buffa-types-shim/build.rs
Normal file
9
envd-rs/buffa-types-shim/build.rs
Normal file
@ -0,0 +1,9 @@
|
||||
// Build script for the buffa-types shim crate: compiles the protobuf
// Timestamp well-known type from the system protobuf installation into
// OUT_DIR as `_types.rs`.
fn main() {
    connectrpc_build::Config::new()
        // NOTE(review): assumes protobuf headers are installed under
        // /usr/include (e.g. via a protobuf-compiler package) — confirm.
        .files(&["/usr/include/google/protobuf/timestamp.proto"])
        .includes(&["/usr/include"])
        .include_file("_types.rs")
        // The shim only provides types, so skip the register function.
        .emit_register_fn(false)
        .compile()
        .unwrap();
}
|
||||
6
envd-rs/buffa-types-shim/src/lib.rs
Normal file
6
envd-rs/buffa-types-shim/src/lib.rs
Normal file
@ -0,0 +1,6 @@
|
||||
// buffa-types shim: exposes protobuf well-known types generated at build
// time (see build.rs). Blanket allows — presumably because the generated
// code can trigger these lints; we don't control its style.
#![allow(dead_code, non_camel_case_types, unused_imports, clippy::derivable_impls)]

// Absolute-path imports so the generated code can reference `::buffa`
// and `::serde`.
use ::buffa;
use ::serde;

// Pull in the build-script output (`_types.rs`) from OUT_DIR.
include!(concat!(env!("OUT_DIR"), "/_types.rs"));
|
||||
11
envd-rs/build.rs
Normal file
11
envd-rs/build.rs
Normal file
@ -0,0 +1,11 @@
|
||||
// Build script: generates the Rust protobuf/Connect stubs at `cargo build`
// time. Output goes to OUT_DIR (collected into `_connectrpc.rs`); no
// generated code is committed.
fn main() {
    connectrpc_build::Config::new()
        // Proto sources shared with the Go side — proto/envd is the
        // single source of truth for both modules.
        .files(&[
            "../proto/envd/process.proto",
            "../proto/envd/filesystem.proto",
        ])
        // NOTE(review): /usr/include is presumably for protobuf well-known
        // types from the system protobuf install — confirm on the builder.
        .includes(&["../proto/envd", "/usr/include"])
        // Bundle all generated modules into one includable file.
        .include_file("_connectrpc.rs")
        .compile()
        .unwrap();
}
|
||||
3
envd-rs/rust-toolchain.toml
Normal file
3
envd-rs/rust-toolchain.toml
Normal file
@ -0,0 +1,3 @@
|
||||
[toolchain]
|
||||
channel = "stable"
|
||||
targets = ["x86_64-unknown-linux-gnu", "x86_64-unknown-linux-musl"]
|
||||
56
envd-rs/src/auth/middleware.rs
Normal file
56
envd-rs/src/auth/middleware.rs
Normal file
@ -0,0 +1,56 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::extract::Request;
|
||||
use axum::http::StatusCode;
|
||||
use axum::middleware::Next;
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use serde_json::json;
|
||||
|
||||
use crate::auth::token::SecureToken;
|
||||
|
||||
const ACCESS_TOKEN_HEADER: &str = "x-access-token";
|
||||
|
||||
/// Paths excluded from general token auth.
|
||||
/// Format: "METHOD/path"
|
||||
const AUTH_EXCLUDED: &[&str] = &[
|
||||
"GET/health",
|
||||
"GET/files",
|
||||
"POST/files",
|
||||
"POST/init",
|
||||
"POST/snapshot/prepare",
|
||||
];
|
||||
|
||||
/// Axum middleware that checks X-Access-Token header.
|
||||
pub async fn auth_layer(
|
||||
request: Request,
|
||||
next: Next,
|
||||
access_token: Arc<SecureToken>,
|
||||
) -> Response {
|
||||
if access_token.is_set() {
|
||||
let method = request.method().as_str();
|
||||
let path = request.uri().path();
|
||||
let key = format!("{method}{path}");
|
||||
|
||||
let is_excluded = AUTH_EXCLUDED.iter().any(|p| *p == key);
|
||||
|
||||
let header_val = request
|
||||
.headers()
|
||||
.get(ACCESS_TOKEN_HEADER)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("");
|
||||
|
||||
if !access_token.equals(header_val) && !is_excluded {
|
||||
tracing::error!("unauthorized access attempt");
|
||||
return (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
axum::Json(json!({
|
||||
"code": 401,
|
||||
"message": "unauthorized access, please provide a valid access token or method signing if supported"
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
}
|
||||
|
||||
next.run(request).await
|
||||
}
|
||||
3
envd-rs/src/auth/mod.rs
Normal file
3
envd-rs/src/auth/mod.rs
Normal file
@ -0,0 +1,3 @@
|
||||
pub mod token;
|
||||
pub mod signing;
|
||||
pub mod middleware;
|
||||
210
envd-rs/src/auth/signing.rs
Normal file
210
envd-rs/src/auth/signing.rs
Normal file
@ -0,0 +1,210 @@
|
||||
use crate::auth::token::SecureToken;
|
||||
use crate::crypto;
|
||||
use zeroize::Zeroize;
|
||||
|
||||
pub const READ_OPERATION: &str = "read";
|
||||
pub const WRITE_OPERATION: &str = "write";
|
||||
|
||||
/// Generate a v1 signature: `v1_{sha256_base64(path:operation:username:token[:expiration])}`
|
||||
pub fn generate_signature(
|
||||
token: &SecureToken,
|
||||
path: &str,
|
||||
username: &str,
|
||||
operation: &str,
|
||||
expiration: Option<i64>,
|
||||
) -> Result<String, &'static str> {
|
||||
let mut token_bytes = token.bytes().ok_or("access token is not set")?;
|
||||
|
||||
let payload = match expiration {
|
||||
Some(exp) => format!(
|
||||
"{}:{}:{}:{}:{}",
|
||||
path,
|
||||
operation,
|
||||
username,
|
||||
String::from_utf8_lossy(&token_bytes),
|
||||
exp
|
||||
),
|
||||
None => format!(
|
||||
"{}:{}:{}:{}",
|
||||
path,
|
||||
operation,
|
||||
username,
|
||||
String::from_utf8_lossy(&token_bytes),
|
||||
),
|
||||
};
|
||||
|
||||
token_bytes.zeroize();
|
||||
|
||||
let hash = crypto::sha256::hash_without_prefix(payload.as_bytes());
|
||||
Ok(format!("v1_{hash}"))
|
||||
}
|
||||
|
||||
/// Validate a request's signing. Returns Ok(()) if valid.
|
||||
pub fn validate_signing(
|
||||
token: &SecureToken,
|
||||
header_token: Option<&str>,
|
||||
signature: Option<&str>,
|
||||
signature_expiration: Option<i64>,
|
||||
username: &str,
|
||||
path: &str,
|
||||
operation: &str,
|
||||
) -> Result<(), String> {
|
||||
if !token.is_set() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if let Some(ht) = header_token {
|
||||
if !ht.is_empty() {
|
||||
if token.equals(ht) {
|
||||
return Ok(());
|
||||
}
|
||||
return Err("access token present in header but does not match".into());
|
||||
}
|
||||
}
|
||||
|
||||
let sig = signature.ok_or("missing signature query parameter")?;
|
||||
|
||||
let expected = generate_signature(token, path, username, operation, signature_expiration)
|
||||
.map_err(|e| format!("error generating signing key: {e}"))?;
|
||||
|
||||
if expected != sig {
|
||||
return Err("invalid signature".into());
|
||||
}
|
||||
|
||||
if let Some(exp) = signature_expiration {
|
||||
let now = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs() as i64;
|
||||
if exp < now {
|
||||
return Err("signature is already expired".into());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a SecureToken pre-loaded with `val`.
    fn test_token(val: &[u8]) -> SecureToken {
        let t = SecureToken::new();
        t.set(val).unwrap();
        t
    }

    /// Unix timestamp one hour from now — a deadline that has not expired.
    fn far_future() -> i64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs() as i64
            + 3600
    }

    // Signatures carry a "v1_" scheme prefix.
    #[test]
    fn generate_starts_with_v1() {
        let token = test_token(b"secret");
        let sig = generate_signature(&token, "/file", "root", READ_OPERATION, None).unwrap();
        assert!(sig.starts_with("v1_"));
    }

    // Same inputs must always yield the same signature.
    #[test]
    fn generate_deterministic() {
        let token = test_token(b"secret");
        let s1 = generate_signature(&token, "/file", "root", READ_OPERATION, None).unwrap();
        let s2 = generate_signature(&token, "/file", "root", READ_OPERATION, None).unwrap();
        assert_eq!(s1, s2);
    }

    // The expiration timestamp is part of the signed payload.
    #[test]
    fn generate_with_expiration_differs() {
        let token = test_token(b"secret");
        let without = generate_signature(&token, "/f", "u", READ_OPERATION, None).unwrap();
        let with = generate_signature(&token, "/f", "u", READ_OPERATION, Some(9999)).unwrap();
        assert_ne!(without, with);
    }

    // Cannot sign without key material.
    #[test]
    fn generate_unset_token_errors() {
        let token = SecureToken::new();
        assert!(generate_signature(&token, "/f", "u", READ_OPERATION, None).is_err());
    }

    // With no token configured, validation is a no-op (auth disabled).
    #[test]
    fn validate_no_token_set_passes() {
        let token = SecureToken::new();
        assert!(validate_signing(&token, None, None, None, "root", "/f", READ_OPERATION).is_ok());
    }

    // A matching header token short-circuits signature checks.
    #[test]
    fn validate_correct_header_token() {
        let token = test_token(b"secret");
        assert!(validate_signing(&token, Some("secret"), None, None, "root", "/f", READ_OPERATION).is_ok());
    }

    // A present-but-wrong header token is a hard failure, not a fallthrough.
    #[test]
    fn validate_wrong_header_token() {
        let token = test_token(b"secret");
        let result = validate_signing(&token, Some("wrong"), None, None, "root", "/f", READ_OPERATION);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("does not match"));
    }

    // Round-trip: a freshly generated, unexpired signature validates.
    #[test]
    fn validate_valid_signature() {
        let token = test_token(b"secret");
        let exp = far_future();
        let sig = generate_signature(&token, "/file", "root", READ_OPERATION, Some(exp)).unwrap();
        assert!(validate_signing(&token, None, Some(&sig), Some(exp), "root", "/file", READ_OPERATION).is_ok());
    }

    // Garbage signature is rejected with a specific message.
    #[test]
    fn validate_invalid_signature() {
        let token = test_token(b"secret");
        let result = validate_signing(&token, None, Some("v1_bad"), Some(far_future()), "root", "/f", READ_OPERATION);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("invalid signature"));
    }

    // A correctly signed but past-deadline signature is rejected.
    #[test]
    fn validate_expired_signature() {
        let token = test_token(b"secret");
        let expired: i64 = 1_000_000;
        let sig = generate_signature(&token, "/f", "root", READ_OPERATION, Some(expired)).unwrap();
        let result = validate_signing(&token, None, Some(&sig), Some(expired), "root", "/f", READ_OPERATION);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("expired"));
    }

    // No header token and no signature at all -> missing-signature error.
    #[test]
    fn validate_missing_signature() {
        let token = test_token(b"secret");
        let result = validate_signing(&token, None, None, None, "root", "/f", READ_OPERATION);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("missing signature"));
    }

    // An empty header token is treated as absent, not as a mismatch.
    #[test]
    fn validate_empty_header_token_falls_through_to_signature() {
        let token = test_token(b"secret");
        let result = validate_signing(&token, Some(""), None, None, "root", "/f", READ_OPERATION);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("missing signature"));
    }

    // Expiration is optional; a signature without one validates.
    #[test]
    fn validate_valid_signature_no_expiration() {
        let token = test_token(b"secret");
        let sig = generate_signature(&token, "/file", "root", READ_OPERATION, None).unwrap();
        assert!(validate_signing(&token, None, Some(&sig), None, "root", "/file", READ_OPERATION).is_ok());
    }

    // Operation (read vs write) is bound into the signature.
    #[test]
    fn different_operations_produce_different_signatures() {
        let token = test_token(b"secret");
        let r = generate_signature(&token, "/f", "root", READ_OPERATION, None).unwrap();
        let w = generate_signature(&token, "/f", "root", WRITE_OPERATION, None).unwrap();
        assert_ne!(r, w);
    }
}
|
||||
256
envd-rs/src/auth/token.rs
Normal file
256
envd-rs/src/auth/token.rs
Normal file
@ -0,0 +1,256 @@
|
||||
use std::sync::RwLock;
|
||||
|
||||
use subtle::ConstantTimeEq;
|
||||
use zeroize::Zeroize;
|
||||
|
||||
/// Secure token storage with constant-time comparison and zeroize-on-drop.
///
/// Mirrors Go's SecureToken backed by memguard.LockedBuffer.
/// In Rust we rely on `zeroize` for Drop-based zeroing.
pub struct SecureToken {
    // None = unset; Some(bytes) = raw token material. RwLock allows many
    // concurrent comparisons with exclusive access only for mutation.
    inner: RwLock<Option<Vec<u8>>>,
}
|
||||
|
||||
impl SecureToken {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
inner: RwLock::new(None),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set(&self, token: &[u8]) -> Result<(), &'static str> {
|
||||
if token.is_empty() {
|
||||
return Err("empty token not allowed");
|
||||
}
|
||||
let mut guard = self.inner.write().unwrap();
|
||||
if let Some(ref mut old) = *guard {
|
||||
old.zeroize();
|
||||
}
|
||||
*guard = Some(token.to_vec());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn is_set(&self) -> bool {
|
||||
let guard = self.inner.read().unwrap();
|
||||
guard.is_some()
|
||||
}
|
||||
|
||||
/// Constant-time comparison.
|
||||
pub fn equals(&self, other: &str) -> bool {
|
||||
let guard = self.inner.read().unwrap();
|
||||
match guard.as_ref() {
|
||||
Some(buf) => buf.as_slice().ct_eq(other.as_bytes()).into(),
|
||||
None => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Constant-time comparison with another SecureToken.
|
||||
pub fn equals_secure(&self, other: &SecureToken) -> bool {
|
||||
let other_bytes = match other.bytes() {
|
||||
Some(b) => b,
|
||||
None => return false,
|
||||
};
|
||||
let guard = self.inner.read().unwrap();
|
||||
let result = match guard.as_ref() {
|
||||
Some(buf) => buf.as_slice().ct_eq(&other_bytes).into(),
|
||||
None => false,
|
||||
};
|
||||
// other_bytes dropped here, Vec<u8> doesn't auto-zeroize but
|
||||
// we accept this — same as Go's `defer memguard.WipeBytes(otherBytes)`
|
||||
result
|
||||
}
|
||||
|
||||
/// Returns a copy of the token bytes (for signature generation).
|
||||
pub fn bytes(&self) -> Option<Vec<u8>> {
|
||||
let guard = self.inner.read().unwrap();
|
||||
guard.as_ref().map(|b| b.clone())
|
||||
}
|
||||
|
||||
/// Transfer token from another SecureToken, clearing the source.
|
||||
pub fn take_from(&self, src: &SecureToken) {
|
||||
let taken = {
|
||||
let mut src_guard = src.inner.write().unwrap();
|
||||
src_guard.take()
|
||||
};
|
||||
let mut guard = self.inner.write().unwrap();
|
||||
if let Some(ref mut old) = *guard {
|
||||
old.zeroize();
|
||||
}
|
||||
*guard = taken;
|
||||
}
|
||||
|
||||
pub fn destroy(&self) {
|
||||
let mut guard = self.inner.write().unwrap();
|
||||
if let Some(ref mut buf) = *guard {
|
||||
buf.zeroize();
|
||||
}
|
||||
*guard = None;
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for SecureToken {
|
||||
fn drop(&mut self) {
|
||||
if let Ok(mut guard) = self.inner.write() {
|
||||
if let Some(ref mut buf) = *guard {
|
||||
buf.zeroize();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Deserialize from JSON string, matching Go's UnmarshalJSON behavior.
|
||||
/// Expects a quoted JSON string. Rejects escape sequences.
|
||||
impl SecureToken {
|
||||
pub fn from_json_bytes(data: &mut [u8]) -> Result<Self, &'static str> {
|
||||
if data.len() < 2 || data[0] != b'"' || data[data.len() - 1] != b'"' {
|
||||
data.zeroize();
|
||||
return Err("invalid secure token JSON string");
|
||||
}
|
||||
|
||||
let content = &data[1..data.len() - 1];
|
||||
if content.contains(&b'\\') {
|
||||
data.zeroize();
|
||||
return Err("invalid secure token: unexpected escape sequence");
|
||||
}
|
||||
|
||||
if content.is_empty() {
|
||||
data.zeroize();
|
||||
return Err("empty token not allowed");
|
||||
}
|
||||
|
||||
let token = Self::new();
|
||||
token.set(content).map_err(|_| "failed to set token")?;
|
||||
|
||||
data.zeroize();
|
||||
Ok(token)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // A freshly constructed token is unset and matches nothing.
    #[test]
    fn new_is_unset() {
        let t = SecureToken::new();
        assert!(!t.is_set());
        assert!(!t.equals("anything"));
    }

    // Basic set + compare round trip.
    #[test]
    fn set_and_equals() {
        let t = SecureToken::new();
        t.set(b"secret").unwrap();
        assert!(t.is_set());
        assert!(t.equals("secret"));
        assert!(!t.equals("wrong"));
    }

    // Empty input is rejected and leaves the token unset.
    #[test]
    fn set_empty_errors() {
        let t = SecureToken::new();
        assert!(t.set(b"").is_err());
        assert!(!t.is_set());
    }

    // A second set replaces (and wipes) the first value.
    #[test]
    fn set_overwrites_previous() {
        let t = SecureToken::new();
        t.set(b"first").unwrap();
        t.set(b"second").unwrap();
        assert!(!t.equals("first"));
        assert!(t.equals("second"));
    }

    // destroy() returns the token to the unset state.
    #[test]
    fn destroy_clears() {
        let t = SecureToken::new();
        t.set(b"secret").unwrap();
        t.destroy();
        assert!(!t.is_set());
        assert!(!t.equals("secret"));
    }

    // bytes() is None when unset, a copy of the material when set.
    #[test]
    fn bytes_returns_copy() {
        let t = SecureToken::new();
        assert!(t.bytes().is_none());
        t.set(b"hello").unwrap();
        assert_eq!(t.bytes().unwrap(), b"hello");
    }

    // take_from moves ownership: source ends up unset.
    #[test]
    fn take_from_transfers_and_clears_source() {
        let src = SecureToken::new();
        src.set(b"token").unwrap();
        let dst = SecureToken::new();
        dst.take_from(&src);
        assert!(!src.is_set());
        assert!(dst.equals("token"));
    }

    // take_from replaces any value the destination already held.
    #[test]
    fn take_from_overwrites_existing() {
        let src = SecureToken::new();
        src.set(b"new").unwrap();
        let dst = SecureToken::new();
        dst.set(b"old").unwrap();
        dst.take_from(&src);
        assert!(dst.equals("new"));
        assert!(!dst.equals("old"));
    }

    // equals_secure: equal material compares true.
    #[test]
    fn equals_secure_matching() {
        let a = SecureToken::new();
        a.set(b"same").unwrap();
        let b = SecureToken::new();
        b.set(b"same").unwrap();
        assert!(a.equals_secure(&b));
    }

    // equals_secure: different material compares false.
    #[test]
    fn equals_secure_different() {
        let a = SecureToken::new();
        a.set(b"one").unwrap();
        let b = SecureToken::new();
        b.set(b"two").unwrap();
        assert!(!a.equals_secure(&b));
    }

    // Two unset tokens are NOT considered equal.
    #[test]
    fn equals_secure_unset() {
        let a = SecureToken::new();
        let b = SecureToken::new();
        assert!(!a.equals_secure(&b));
    }

    // Valid quoted JSON string parses, and the input buffer is wiped.
    #[test]
    fn from_json_bytes_valid() {
        let mut data = b"\"mysecret\"".to_vec();
        let t = SecureToken::from_json_bytes(&mut data).unwrap();
        assert!(t.equals("mysecret"));
        assert!(data.iter().all(|&b| b == 0));
    }

    // Unquoted input errors; buffer still wiped.
    #[test]
    fn from_json_bytes_rejects_missing_quotes() {
        let mut data = b"noquotes".to_vec();
        assert!(SecureToken::from_json_bytes(&mut data).is_err());
        assert!(data.iter().all(|&b| b == 0));
    }

    // Escape sequences are unsupported and rejected; buffer still wiped.
    #[test]
    fn from_json_bytes_rejects_escape_sequences() {
        let mut data = b"\"has\\nescapes\"".to_vec();
        assert!(SecureToken::from_json_bytes(&mut data).is_err());
        assert!(data.iter().all(|&b| b == 0));
    }

    // Empty string content is rejected; buffer still wiped.
    #[test]
    fn from_json_bytes_rejects_empty_content() {
        let mut data = b"\"\"".to_vec();
        assert!(SecureToken::from_json_bytes(&mut data).is_err());
        assert!(data.iter().all(|&b| b == 0));
    }
}
|
||||
66
envd-rs/src/cgroups/mod.rs
Normal file
66
envd-rs/src/cgroups/mod.rs
Normal file
@ -0,0 +1,66 @@
|
||||
use std::collections::HashMap;
|
||||
use std::fs;
|
||||
use std::os::unix::io::{OwnedFd, RawFd};
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// Class of sandbox process a cgroup is provisioned for.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ProcessType {
    /// Interactive terminal (PTY) sessions.
    Pty,
    /// User-initiated commands.
    User,
    /// socat port-forwarding helpers.
    Socat,
}

/// Hands out cgroup directory file descriptors for placing new processes.
pub trait CgroupManager: Send + Sync {
    /// Raw fd of the cgroup directory for `proc_type`, or None when this
    /// process type is not managed (e.g. NoopCgroupManager).
    fn get_fd(&self, proc_type: ProcessType) -> Option<RawFd>;
}
|
||||
|
||||
/// cgroup v2 backed manager: creates one cgroup directory per process
/// type, applies its controller properties, and keeps the directory fd
/// open for later process placement.
pub struct Cgroup2Manager {
    // Open directory fds, one per configured ProcessType; closed on drop
    // via OwnedFd.
    fds: HashMap<ProcessType, OwnedFd>,
}

impl Cgroup2Manager {
    /// Create all configured cgroups under `root`.
    ///
    /// Each config entry is (process type, sub-path under `root`,
    /// property-file name/value pairs written into that cgroup dir).
    /// Fails on the first create/write/open error; fds opened so far are
    /// released when the partially built map is dropped.
    pub fn new(root: &str, configs: &[(ProcessType, &str, &[(&str, &str)])]) -> Result<Self, String> {
        let mut fds = HashMap::new();

        for (proc_type, sub_path, properties) in configs {
            let full_path = PathBuf::from(root).join(sub_path);

            // Creating the directory is what creates the cgroup in v2.
            fs::create_dir_all(&full_path).map_err(|e| {
                format!("failed to create cgroup {}: {e}", full_path.display())
            })?;

            // Controller knobs (e.g. limits) are plain files in the dir.
            for (name, value) in *properties {
                let prop_path = full_path.join(name);
                fs::write(&prop_path, value).map_err(|e| {
                    format!("failed to write cgroup property {}: {e}", prop_path.display())
                })?;
            }

            // Keep the directory open read-only so processes can later be
            // attached by fd (e.g. clone3 with CLONE_INTO_CGROUP —
            // presumably; verify against the spawn path).
            let fd = nix::fcntl::open(
                &full_path,
                nix::fcntl::OFlag::O_RDONLY,
                nix::sys::stat::Mode::empty(),
            )
            .map_err(|e| format!("failed to open cgroup {}: {e}", full_path.display()))?;

            fds.insert(*proc_type, fd);
        }

        Ok(Self { fds })
    }
}
|
||||
|
||||
impl CgroupManager for Cgroup2Manager {
    /// Borrowed raw fd for the cgroup dir; valid only while `self` lives.
    fn get_fd(&self, proc_type: ProcessType) -> Option<RawFd> {
        use std::os::unix::io::AsRawFd;
        self.fds.get(&proc_type).map(|fd| fd.as_raw_fd())
    }
}

/// No-op manager for environments without cgroup support: every lookup
/// reports "unmanaged".
pub struct NoopCgroupManager;

impl CgroupManager for NoopCgroupManager {
    fn get_fd(&self, _proc_type: ProcessType) -> Option<RawFd> {
        None
    }
}
|
||||
16
envd-rs/src/config.rs
Normal file
16
envd-rs/src/config.rs
Normal file
@ -0,0 +1,16 @@
|
||||
use std::time::Duration;
|
||||
|
||||
/// TCP port envd listens on inside the sandbox.
pub const DEFAULT_PORT: u16 = 49983;
/// Server-side idle connection timeout.
pub const IDLE_TIMEOUT: Duration = Duration::from_secs(640);
/// Max-age advertised for CORS preflight caching.
pub const CORS_MAX_AGE: Duration = Duration::from_secs(7200);
/// How often the open-port scanner re-scans.
pub const PORT_SCANNER_INTERVAL: Duration = Duration::from_millis(1000);
/// Fallback user when none is configured or supplied.
pub const DEFAULT_USER: &str = "root";
/// Runtime state directory (sandbox/template IDs are written here).
pub const WRENN_RUN_DIR: &str = "/run/wrenn";

pub const KILOBYTE: u64 = 1024;
pub const MEGABYTE: u64 = 1024 * KILOBYTE;

/// Firecracker MMDS link-local metadata address.
pub const MMDS_ADDRESS: &str = "169.254.169.254";
/// Poll cadence while waiting for MMDS metadata to appear.
pub const MMDS_POLL_INTERVAL: Duration = Duration::from_millis(50);
/// TTL requested for MMDS session tokens.
pub const MMDS_TOKEN_EXPIRATION_SECS: u64 = 60;
/// HTTP client timeout for the one-shot access-token-hash fetch.
pub const MMDS_ACCESS_TOKEN_CLIENT_TIMEOUT: Duration = Duration::from_secs(10);
||||
200
envd-rs/src/conntracker.rs
Normal file
200
envd-rs/src/conntracker.rs
Normal file
@ -0,0 +1,200 @@
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Mutex;
|
||||
|
||||
/// Tracks active TCP connections for snapshot/restore lifecycle.
///
/// Before snapshot: close idle connections, record active ones.
/// After restore: close all pre-snapshot connections (zombie TCP sockets).
///
/// In Rust/axum, we don't have Go's ConnState callback. Instead we track
/// connections via a tower middleware that registers connection IDs.
/// For the initial implementation, we track by a simple connection counter
/// and rely on axum's graceful shutdown mechanics.
pub struct ConnTracker {
    inner: Mutex<ConnTrackerInner>,
}

struct ConnTrackerInner {
    // IDs of currently open connections.
    active: HashSet<u64>,
    // Copy of `active` captured at prepare_for_snapshot; None outside the
    // snapshot window.
    pre_snapshot: Option<HashSet<u64>>,
    // Monotonic source for connection IDs (never reused).
    next_id: u64,
    // Cleared during the snapshot window so keep-alives are not offered on
    // connections that will become zombies after restore.
    keepalives_enabled: bool,
}

impl ConnTracker {
    /// New tracker: no connections, keep-alives enabled.
    pub fn new() -> Self {
        Self {
            inner: Mutex::new(ConnTrackerInner {
                active: HashSet::new(),
                pre_snapshot: None,
                next_id: 0,
                keepalives_enabled: true,
            }),
        }
    }

    /// Record a new connection and return its tracker-assigned ID.
    pub fn register_connection(&self) -> u64 {
        let mut inner = self.inner.lock().unwrap();
        let id = inner.next_id;
        inner.next_id += 1;
        inner.active.insert(id);
        id
    }

    /// Forget a closed connection. Also removed from the pre-snapshot set
    /// so it is not later counted as a zombie.
    pub fn remove_connection(&self, id: u64) {
        let mut inner = self.inner.lock().unwrap();
        inner.active.remove(&id);
        if let Some(ref mut pre) = inner.pre_snapshot {
            pre.remove(&id);
        }
    }

    /// Enter the snapshot window: record the currently active connections
    /// and stop offering keep-alives.
    pub fn prepare_for_snapshot(&self) {
        let mut inner = self.inner.lock().unwrap();
        inner.keepalives_enabled = false;
        inner.pre_snapshot = Some(inner.active.clone());
        tracing::info!(
            active_connections = inner.active.len(),
            "snapshot: recorded pre-snapshot connections, keep-alives disabled"
        );
    }

    /// Leave the snapshot window: drop every connection recorded at
    /// prepare time (their TCP state did not survive restore) and
    /// re-enable keep-alives. No-op when prepare was never called.
    pub fn restore_after_snapshot(&self) {
        let mut inner = self.inner.lock().unwrap();
        if let Some(pre) = inner.pre_snapshot.take() {
            let zombie_count = pre.len();
            for id in &pre {
                inner.active.remove(id);
            }
            if zombie_count > 0 {
                tracing::info!(zombie_count, "restore: closed zombie connections");
            }
        }
        inner.keepalives_enabled = true;
    }

    /// Whether HTTP keep-alives should currently be offered.
    pub fn keepalives_enabled(&self) -> bool {
        self.inner.lock().unwrap().keepalives_enabled
    }

    // Test-only count of tracked active connections.
    #[cfg(test)]
    fn active_count(&self) -> usize {
        self.inner.lock().unwrap().active.len()
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // IDs come from a monotonic counter starting at 0.
    #[test]
    fn register_assigns_sequential_ids() {
        let ct = ConnTracker::new();
        assert_eq!(ct.register_connection(), 0);
        assert_eq!(ct.register_connection(), 1);
        assert_eq!(ct.register_connection(), 2);
    }

    // remove_connection drops the ID from the active set.
    #[test]
    fn remove_clears_active() {
        let ct = ConnTracker::new();
        let id = ct.register_connection();
        assert_eq!(ct.active_count(), 1);
        ct.remove_connection(id);
        assert_eq!(ct.active_count(), 0);
    }

    // Removing an unknown ID must not panic or disturb state.
    #[test]
    fn remove_nonexistent_is_noop() {
        let ct = ConnTracker::new();
        ct.remove_connection(999);
        assert_eq!(ct.active_count(), 0);
    }

    // Entering the snapshot window turns keep-alives off.
    #[test]
    fn prepare_disables_keepalives() {
        let ct = ConnTracker::new();
        assert!(ct.keepalives_enabled());
        ct.register_connection();
        ct.prepare_for_snapshot();
        assert!(!ct.keepalives_enabled());
    }

    // Restore drops pre-snapshot connections and restores keep-alives.
    #[test]
    fn restore_removes_zombies_and_reenables_keepalives() {
        let ct = ConnTracker::new();
        let id0 = ct.register_connection();
        let id1 = ct.register_connection();
        ct.prepare_for_snapshot();
        ct.restore_after_snapshot();
        assert!(ct.keepalives_enabled());
        // Both pre-snapshot connections removed as zombies
        assert_eq!(ct.active_count(), 0);
        // IDs don't matter anymore, but remove shouldn't panic
        ct.remove_connection(id0);
        ct.remove_connection(id1);
    }

    // Restore without a prior prepare leaves everything untouched.
    #[test]
    fn restore_without_prepare_is_noop() {
        let ct = ConnTracker::new();
        let _id = ct.register_connection();
        ct.restore_after_snapshot();
        assert!(ct.keepalives_enabled());
        assert_eq!(ct.active_count(), 1);
    }

    // A connection closed inside the snapshot window is not re-removed.
    #[test]
    fn connection_closed_before_restore_not_zombie() {
        let ct = ConnTracker::new();
        let id0 = ct.register_connection();
        let _id1 = ct.register_connection();
        ct.prepare_for_snapshot();
        // Close id0 during snapshot window
        ct.remove_connection(id0);
        assert_eq!(ct.active_count(), 1);
        ct.restore_after_snapshot();
        // id1 was zombie (still active at restore), id0 already gone
        assert_eq!(ct.active_count(), 0);
    }

    // Connections opened during the snapshot window are kept.
    #[test]
    fn post_snapshot_connection_survives_restore() {
        let ct = ConnTracker::new();
        ct.register_connection();
        ct.prepare_for_snapshot();
        // New connection after snapshot
        let _post = ct.register_connection();
        ct.restore_after_snapshot();
        // Pre-snapshot connection removed, post-snapshot survives
        assert_eq!(ct.active_count(), 1);
    }

    // End-to-end: register/prepare/churn/restore/reuse in one scenario.
    #[test]
    fn full_lifecycle() {
        let ct = ConnTracker::new();
        let _a = ct.register_connection();
        let b = ct.register_connection();
        let _c = ct.register_connection();
        assert_eq!(ct.active_count(), 3);
        assert!(ct.keepalives_enabled());

        ct.prepare_for_snapshot();
        assert!(!ct.keepalives_enabled());

        let d = ct.register_connection();
        ct.remove_connection(b);

        ct.restore_after_snapshot();
        assert!(ct.keepalives_enabled());
        // a and c were zombies, b removed before restore, d is post-snapshot
        assert_eq!(ct.active_count(), 1);
        ct.remove_connection(d);
        assert_eq!(ct.active_count(), 0);

        // Can reuse tracker after restore
        let e = ct.register_connection();
        assert_eq!(ct.active_count(), 1);
        assert!(e > d);
    }
}
|
||||
43
envd-rs/src/crypto/hmac_sha256.rs
Normal file
43
envd-rs/src/crypto/hmac_sha256.rs
Normal file
@ -0,0 +1,43 @@
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha2::Sha256;
|
||||
|
||||
type HmacSha256 = Hmac<Sha256>;
|
||||
|
||||
pub fn compute(key: &[u8], data: &[u8]) -> String {
|
||||
let mut mac = HmacSha256::new_from_slice(key).expect("HMAC accepts any key length");
|
||||
mac.update(data);
|
||||
let result = mac.finalize();
|
||||
hex::encode(result.into_bytes())
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Known-answer test: RFC 4231 test case 1.
    #[test]
    fn rfc4231_tc1() {
        let key = &[0x0b; 20];
        let data = b"Hi There";
        assert_eq!(
            compute(key, data),
            "b0344c61d8db38535ca8afceaf0bf12b881dc200c9833da726e9376c2e32cff7"
        );
    }

    // Known-answer test: RFC 4231 test case 2 (short key).
    #[test]
    fn rfc4231_tc2() {
        let key = b"Jefe";
        let data = b"what do ya want for nothing?";
        assert_eq!(
            compute(key, data),
            "5bdcc146bf60754e6a042426089575c75a003f089d2739839dec58b964ec3843"
        );
    }

    // Output shape: 32-byte tag as 64 hex characters.
    #[test]
    fn output_is_64_hex_chars() {
        let result = compute(b"key", b"data");
        assert_eq!(result.len(), 64);
        assert!(result.chars().all(|c| c.is_ascii_hexdigit()));
    }
}
|
||||
3
envd-rs/src/crypto/mod.rs
Normal file
3
envd-rs/src/crypto/mod.rs
Normal file
@ -0,0 +1,3 @@
|
||||
pub mod sha256;
|
||||
pub mod sha512;
|
||||
pub mod hmac_sha256;
|
||||
54
envd-rs/src/crypto/sha256.rs
Normal file
54
envd-rs/src/crypto/sha256.rs
Normal file
@ -0,0 +1,54 @@
|
||||
use base64::Engine;
|
||||
use base64::engine::general_purpose::STANDARD_NO_PAD;
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
pub fn hash(data: &[u8]) -> String {
|
||||
let h = Sha256::digest(data);
|
||||
let encoded = STANDARD_NO_PAD.encode(h);
|
||||
format!("$sha256${encoded}")
|
||||
}
|
||||
|
||||
pub fn hash_without_prefix(data: &[u8]) -> String {
|
||||
let h = Sha256::digest(data);
|
||||
STANDARD_NO_PAD.encode(h)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // (input, expected unpadded-base64 SHA-256 digest) known-answer pairs.
    const VECTORS: &[(&[u8], &str)] = &[
        (b"", "47DEQpj8HBSa+/TImW+5JCeuQeRkm5NMpJWZG3hSuFU"),
        (b"abc", "ungWv48Bz+pBQUDeXa4iI7ADYaOWF3qctBD/YfIAFa0"),
        (b"abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq", "JI1qYdIGOLjlwCaTDD5gOaM85Flk/yFn9uzt1BnbBsE"),
    ];

    // hash() = "$sha256$" + digest for every vector.
    #[test]
    fn known_answer_with_prefix() {
        for (input, expected_b64) in VECTORS {
            let result = hash(input);
            assert_eq!(result, format!("$sha256${expected_b64}"), "input: {:?}", String::from_utf8_lossy(input));
        }
    }

    // hash_without_prefix() is the bare digest.
    #[test]
    fn known_answer_without_prefix() {
        for (input, expected_b64) in VECTORS {
            let result = hash_without_prefix(input);
            assert_eq!(result, *expected_b64, "input: {:?}", String::from_utf8_lossy(input));
        }
    }

    // The NO_PAD engine must never emit '=' padding.
    #[test]
    fn no_base64_padding() {
        for (input, _) in VECTORS {
            assert!(!hash(input).contains('='));
            assert!(!hash_without_prefix(input).contains('='));
        }
    }

    // Same input, same output.
    #[test]
    fn deterministic() {
        assert_eq!(hash(b"test"), hash(b"test"));
    }
}
|
||||
43
envd-rs/src/crypto/sha512.rs
Normal file
43
envd-rs/src/crypto/sha512.rs
Normal file
@ -0,0 +1,43 @@
|
||||
use sha2::{Digest, Sha512};
|
||||
|
||||
pub fn hash_access_token(token: &str) -> String {
|
||||
let h = Sha512::digest(token.as_bytes());
|
||||
hex::encode(h)
|
||||
}
|
||||
|
||||
pub fn hash_access_token_bytes(token: &[u8]) -> String {
|
||||
let h = Sha512::digest(token);
|
||||
hex::encode(h)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // (input, expected lowercase-hex SHA-512 digest) known-answer pairs.
    const VECTORS: &[(&str, &str)] = &[
        ("", "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e"),
        ("abc", "ddaf35a193617abacc417349ae20413112e6fa4e89a97ea20a9eeee64b55d39a2192992a274fc1a836ba3c23a3feebbd454d4423643ce80e2a9ac94fa54ca49f"),
        ("abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq", "204a8fc6dda82f0a0ced7beb8e08a41657c16ef468b228a8279be331a703c33596fd15c13b1b07f9aa1d3bea57789ca031ad85c7a71dd70354ec631238ca3445"),
    ];

    // Digest matches the published SHA-512 test vectors.
    #[test]
    fn known_answer() {
        for (input, expected) in VECTORS {
            assert_eq!(hash_access_token(input), *expected, "input: {input:?}");
        }
    }

    // The &str and &[u8] entry points must produce identical output.
    #[test]
    fn str_and_bytes_agree() {
        for (input, _) in VECTORS {
            assert_eq!(hash_access_token(input), hash_access_token_bytes(input.as_bytes()));
        }
    }

    // Output shape: 64-byte digest as 128 lowercase hex characters.
    #[test]
    fn output_is_lowercase_hex_128_chars() {
        let h = hash_access_token("anything");
        assert_eq!(h.len(), 128);
        assert!(h.chars().all(|c| c.is_ascii_hexdigit() && !c.is_ascii_uppercase()));
    }
}
|
||||
118
envd-rs/src/execcontext.rs
Normal file
118
envd-rs/src/execcontext.rs
Normal file
@ -0,0 +1,118 @@
|
||||
use dashmap::DashMap;
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
||||
/// Per-sandbox default execution context (env vars, user, workdir),
/// shared across request handlers.
pub struct Defaults {
    /// Default environment variables; Arc so background tasks (e.g. the
    /// MMDS poller) can hold their own handle.
    pub env_vars: Arc<DashMap<String, String>>,
    // Default username, mutable after startup.
    user: RwLock<String>,
    // Optional default working directory; None until configured.
    workdir: RwLock<Option<String>>,
}

impl Defaults {
    /// New context with `user` as default user, empty env, no workdir.
    pub fn new(user: &str) -> Self {
        Self {
            env_vars: Arc::new(DashMap::new()),
            user: RwLock::new(user.to_string()),
            workdir: RwLock::new(None),
        }
    }

    /// Snapshot of the current default username.
    pub fn user(&self) -> String {
        self.user.read().unwrap().clone()
    }

    /// Replace the default username.
    pub fn set_user(&self, user: String) {
        *self.user.write().unwrap() = user;
    }

    /// Snapshot of the current default workdir, if any.
    pub fn workdir(&self) -> Option<String> {
        self.workdir.read().unwrap().clone()
    }

    /// Replace the default workdir (None clears it).
    pub fn set_workdir(&self, workdir: Option<String>) {
        *self.workdir.write().unwrap() = workdir;
    }
}
|
||||
|
||||
/// Pick the effective working directory: an explicit non-empty `workdir`
/// wins, otherwise the configured default, otherwise the empty string.
pub fn resolve_default_workdir(workdir: &str, default_workdir: Option<&str>) -> String {
    if workdir.is_empty() {
        default_workdir.unwrap_or("").to_string()
    } else {
        workdir.to_string()
    }
}
|
||||
|
||||
/// Pick the effective username: an explicit `username` always wins; a
/// missing one falls back to `default_username`, unless that too is
/// empty, which is an error.
pub fn resolve_default_username<'a>(
    username: Option<&'a str>,
    default_username: &'a str,
) -> Result<&'a str, &'static str> {
    match username {
        Some(explicit) => Ok(explicit),
        None if !default_username.is_empty() => Ok(default_username),
        None => Err("username not provided"),
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Explicit workdir wins over the default.
    #[test]
    fn workdir_explicit_overrides_default() {
        assert_eq!(resolve_default_workdir("/explicit", Some("/default")), "/explicit");
    }

    // Empty workdir falls back to the default.
    #[test]
    fn workdir_empty_uses_default() {
        assert_eq!(resolve_default_workdir("", Some("/default")), "/default");
    }

    // Neither set -> empty string (caller decides what that means).
    #[test]
    fn workdir_empty_no_default_returns_empty() {
        assert_eq!(resolve_default_workdir("", None), "");
    }

    // Explicit workdir is unaffected by a missing default.
    #[test]
    fn workdir_explicit_ignores_none_default() {
        assert_eq!(resolve_default_workdir("/explicit", None), "/explicit");
    }

    // Explicit username wins over the default.
    #[test]
    fn username_explicit_returns_explicit() {
        assert_eq!(resolve_default_username(Some("root"), "wrenn").unwrap(), "root");
    }

    // Missing username falls back to the default.
    #[test]
    fn username_none_uses_default() {
        assert_eq!(resolve_default_username(None, "wrenn").unwrap(), "wrenn");
    }

    // No username anywhere is an error, unlike workdir.
    #[test]
    fn username_none_empty_default_errors() {
        assert!(resolve_default_username(None, "").is_err());
    }

    // Explicit username works even with an empty default.
    #[test]
    fn username_some_overrides_empty_default() {
        assert_eq!(resolve_default_username(Some("root"), "").unwrap(), "root");
    }

    // Defaults.user round-trips through set_user.
    #[test]
    fn defaults_user_set_and_get() {
        let d = Defaults::new("initial");
        assert_eq!(d.user(), "initial");
        d.set_user("changed".into());
        assert_eq!(d.user(), "changed");
    }

    // Workdir starts unset and round-trips through set_workdir.
    #[test]
    fn defaults_workdir_initially_none() {
        let d = Defaults::new("user");
        assert!(d.workdir().is_none());
        d.set_workdir(Some("/home".into()));
        assert_eq!(d.workdir().unwrap(), "/home");
    }
}
|
||||
73
envd-rs/src/host/metrics.rs
Normal file
73
envd-rs/src/host/metrics.rs
Normal file
@ -0,0 +1,73 @@
|
||||
use std::ffi::CString;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use serde::Serialize;
|
||||
|
||||
/// Point-in-time host resource usage snapshot, serialized for clients.
#[derive(Serialize)]
pub struct Metrics {
    /// Unix timestamp (seconds) when the sample was taken.
    pub ts: i64,
    /// Number of logical CPUs.
    pub cpu_count: u32,
    /// Global CPU usage percentage, rounded to two decimal places.
    pub cpu_used_pct: f32,
    /// Memory totals in MiB (convenience duplicates of the byte fields).
    pub mem_total_mib: u64,
    /// Used memory in MiB.
    pub mem_used_mib: u64,
    /// Total memory in bytes.
    pub mem_total: u64,
    /// Used memory in bytes.
    pub mem_used: u64,
    /// Root filesystem bytes used (total minus user-available).
    pub disk_used: u64,
    /// Root filesystem total bytes.
    pub disk_total: u64,
}
|
||||
|
||||
/// Collect a host metrics snapshot (CPU, memory, root-disk usage).
///
/// NOTE: blocks the calling thread for ~100ms — CPU usage is a delta, so
/// two refreshes separated by a delay are required.
pub fn get_metrics() -> Result<Metrics, String> {
    use sysinfo::System;

    let mut sys = System::new();
    sys.refresh_memory();
    sys.refresh_cpu_all();

    // Second CPU sample; global_cpu_usage() reflects the interval between
    // the two refresh calls.
    std::thread::sleep(std::time::Duration::from_millis(100));
    sys.refresh_cpu_all();

    let cpu_count = sys.cpus().len() as u32;
    let cpu_used_pct = sys.global_cpu_usage();
    // Round to two decimal places; non-positive readings collapse to 0.0.
    let cpu_used_pct_rounded = if cpu_used_pct > 0.0 {
        (cpu_used_pct * 100.0).round() / 100.0
    } else {
        0.0
    };

    let mem_total = sys.total_memory();
    let mem_used = sys.used_memory();

    // Root filesystem usage via statfs.
    let (disk_total, disk_used) = disk_stats("/")?;

    let ts = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap()
        .as_secs() as i64;

    Ok(Metrics {
        ts,
        cpu_count,
        cpu_used_pct: cpu_used_pct_rounded,
        mem_total_mib: mem_total / 1024 / 1024,
        mem_used_mib: mem_used / 1024 / 1024,
        mem_total,
        mem_used,
        disk_used,
        disk_total,
    })
}
|
||||
|
||||
fn disk_stats(path: &str) -> Result<(u64, u64), String> {
|
||||
let c_path = CString::new(path).unwrap();
|
||||
let mut stat: libc::statfs = unsafe { std::mem::zeroed() };
|
||||
let ret = unsafe { libc::statfs(c_path.as_ptr(), &mut stat) };
|
||||
if ret != 0 {
|
||||
return Err(format!("statfs failed: {}", std::io::Error::last_os_error()));
|
||||
}
|
||||
|
||||
let block = stat.f_bsize as u64;
|
||||
let total = stat.f_blocks * block;
|
||||
let available = stat.f_bavail * block;
|
||||
|
||||
Ok((total, total - available))
|
||||
}
|
||||
120
envd-rs/src/host/mmds.rs
Normal file
120
envd-rs/src/host/mmds.rs
Normal file
@ -0,0 +1,120 @@
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use dashmap::DashMap;
|
||||
use serde::Deserialize;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use crate::config::{MMDS_ADDRESS, MMDS_POLL_INTERVAL, MMDS_TOKEN_EXPIRATION_SECS, WRENN_RUN_DIR};
|
||||
|
||||
/// Sandbox metadata served by Firecracker's MMDS.
///
/// serde renames map the host-side JSON keys onto envd-side names.
#[derive(Debug, Clone, Deserialize)]
pub struct MMDSOpts {
    #[serde(rename = "instanceID")]
    pub sandbox_id: String,
    #[serde(rename = "envID")]
    pub template_id: String,
    /// Logs collector endpoint; defaults to "" when absent in metadata.
    #[serde(rename = "address", default)]
    pub logs_collector_address: String,
    /// Hash of the access token; defaults to "" when absent — presumably
    /// meaning auth is not configured (verify against the host side).
    #[serde(rename = "accessTokenHash", default)]
    pub access_token_hash: String,
}
|
||||
|
||||
async fn get_mmds_token(client: &reqwest::Client) -> Result<String, String> {
|
||||
let resp = client
|
||||
.put(format!("http://{MMDS_ADDRESS}/latest/api/token"))
|
||||
.header(
|
||||
"X-metadata-token-ttl-seconds",
|
||||
MMDS_TOKEN_EXPIRATION_SECS.to_string(),
|
||||
)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("mmds token request failed: {e}"))?;
|
||||
|
||||
let token = resp.text().await.map_err(|e| format!("mmds token read: {e}"))?;
|
||||
if token.is_empty() {
|
||||
return Err("mmds token is an empty string".into());
|
||||
}
|
||||
Ok(token)
|
||||
}
|
||||
|
||||
async fn get_mmds_opts(client: &reqwest::Client, token: &str) -> Result<MMDSOpts, String> {
|
||||
let resp = client
|
||||
.get(format!("http://{MMDS_ADDRESS}"))
|
||||
.header("X-metadata-token", token)
|
||||
.header("Accept", "application/json")
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("mmds opts request failed: {e}"))?;
|
||||
|
||||
resp.json::<MMDSOpts>()
|
||||
.await
|
||||
.map_err(|e| format!("mmds opts parse: {e}"))
|
||||
}
|
||||
|
||||
pub async fn get_access_token_hash() -> Result<String, String> {
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(10))
|
||||
.no_proxy()
|
||||
.build()
|
||||
.map_err(|e| format!("http client: {e}"))?;
|
||||
|
||||
let token = get_mmds_token(&client).await?;
|
||||
let opts = get_mmds_opts(&client, &token).await?;
|
||||
Ok(opts.access_token_hash)
|
||||
}
|
||||
|
||||
/// Polls MMDS every 50ms until metadata is available.
/// Stores sandbox_id and template_id in env_vars and writes to /run/wrenn/ files.
///
/// Returns `None` if the HTTP client can't be built or `cancel` fires first.
/// File-write failures are logged but do not abort: the in-memory env vars
/// are already set, so the opts are still returned.
/// NOTE(review): the "50ms" figure comes from `MMDS_POLL_INTERVAL` in config —
/// confirm the comment stays in sync with that constant.
pub async fn poll_for_opts(
    env_vars: Arc<DashMap<String, String>>,
    cancel: CancellationToken,
) -> Option<MMDSOpts> {
    // No proxy: MMDS is a link-local service.
    let client = reqwest::Client::builder()
        .no_proxy()
        .build()
        .ok()?;

    let mut interval = tokio::time::interval(MMDS_POLL_INTERVAL);

    loop {
        tokio::select! {
            _ = cancel.cancelled() => {
                tracing::warn!("context cancelled while waiting for mmds opts");
                return None;
            }
            _ = interval.tick() => {
                // Both fetch steps log at debug and retry on the next tick.
                let token = match get_mmds_token(&client).await {
                    Ok(t) => t,
                    Err(e) => {
                        tracing::debug!(error = %e, "mmds token poll");
                        continue;
                    }
                };

                let opts = match get_mmds_opts(&client, &token).await {
                    Ok(o) => o,
                    Err(e) => {
                        tracing::debug!(error = %e, "mmds opts poll");
                        continue;
                    }
                };

                // Expose IDs to child processes via the shared env-var map.
                env_vars.insert("WRENN_SANDBOX_ID".into(), opts.sandbox_id.clone());
                env_vars.insert("WRENN_TEMPLATE_ID".into(), opts.template_id.clone());

                // Persist the IDs on disk for tools that read them from files.
                // Failures are non-fatal (best-effort persistence).
                let run_dir = std::path::Path::new(WRENN_RUN_DIR);
                if let Err(e) = std::fs::create_dir_all(run_dir) {
                    tracing::error!(error = %e, "mmds: failed to create run dir");
                }
                if let Err(e) = std::fs::write(run_dir.join(".WRENN_SANDBOX_ID"), &opts.sandbox_id) {
                    tracing::error!(error = %e, "mmds: failed to write .WRENN_SANDBOX_ID");
                }
                if let Err(e) = std::fs::write(run_dir.join(".WRENN_TEMPLATE_ID"), &opts.template_id) {
                    tracing::error!(error = %e, "mmds: failed to write .WRENN_TEMPLATE_ID");
                }

                return Some(opts);
            }
        }
    }
}
|
||||
2
envd-rs/src/host/mod.rs
Normal file
2
envd-rs/src/host/mod.rs
Normal file
@ -0,0 +1,2 @@
|
||||
// Host-facing subsystems.
pub mod metrics; // CPU / memory / disk usage sampling
pub mod mmds; // instance metadata service (MMDS) polling
|
||||
336
envd-rs/src/http/encoding.rs
Normal file
336
envd-rs/src/http/encoding.rs
Normal file
@ -0,0 +1,336 @@
|
||||
use axum::http::Request;
|
||||
|
||||
// Token values recognized in Accept-Encoding / Content-Encoding headers.
const ENCODING_GZIP: &str = "gzip";
const ENCODING_IDENTITY: &str = "identity";
const ENCODING_WILDCARD: &str = "*";

// Compression schemes this server can actually encode/decode.
const SUPPORTED_ENCODINGS: &[&str] = &[ENCODING_GZIP];
|
||||
|
||||
// One entry from an Accept-Encoding header: the (lowercased) coding name and
// its q-value (1.0 when not specified; 0.0 means "explicitly rejected").
struct EncodingWithQuality {
    encoding: String,
    quality: f64,
}
|
||||
|
||||
fn parse_encoding_with_quality(value: &str) -> EncodingWithQuality {
|
||||
let value = value.trim();
|
||||
let mut quality = 1.0;
|
||||
|
||||
if let Some(idx) = value.find(';') {
|
||||
let params = &value[idx + 1..];
|
||||
let enc = value[..idx].trim();
|
||||
for param in params.split(';') {
|
||||
let param = param.trim();
|
||||
if let Some(stripped) = param.strip_prefix("q=").or_else(|| param.strip_prefix("Q=")) {
|
||||
if let Ok(q) = stripped.parse::<f64>() {
|
||||
quality = q;
|
||||
}
|
||||
}
|
||||
}
|
||||
return EncodingWithQuality {
|
||||
encoding: enc.to_ascii_lowercase(),
|
||||
quality,
|
||||
};
|
||||
}
|
||||
|
||||
EncodingWithQuality {
|
||||
encoding: value.to_ascii_lowercase(),
|
||||
quality,
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_accept_encoding_header(header: &str) -> (Vec<EncodingWithQuality>, bool) {
|
||||
if header.is_empty() {
|
||||
return (Vec::new(), false);
|
||||
}
|
||||
|
||||
let encodings: Vec<EncodingWithQuality> =
|
||||
header.split(',').map(|v| parse_encoding_with_quality(v)).collect();
|
||||
|
||||
let mut identity_rejected = false;
|
||||
let mut identity_explicitly_accepted = false;
|
||||
let mut wildcard_rejected = false;
|
||||
|
||||
for eq in &encodings {
|
||||
match eq.encoding.as_str() {
|
||||
ENCODING_IDENTITY => {
|
||||
if eq.quality == 0.0 {
|
||||
identity_rejected = true;
|
||||
} else {
|
||||
identity_explicitly_accepted = true;
|
||||
}
|
||||
}
|
||||
ENCODING_WILDCARD => {
|
||||
if eq.quality == 0.0 {
|
||||
wildcard_rejected = true;
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
if wildcard_rejected && !identity_explicitly_accepted {
|
||||
identity_rejected = true;
|
||||
}
|
||||
|
||||
(encodings, identity_rejected)
|
||||
}
|
||||
|
||||
pub fn is_identity_acceptable<B>(r: &Request<B>) -> bool {
|
||||
let header = r
|
||||
.headers()
|
||||
.get("accept-encoding")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("");
|
||||
let (_, rejected) = parse_accept_encoding_header(header);
|
||||
!rejected
|
||||
}
|
||||
|
||||
pub fn parse_accept_encoding<B>(r: &Request<B>) -> Result<&'static str, String> {
|
||||
let header = r
|
||||
.headers()
|
||||
.get("accept-encoding")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("");
|
||||
|
||||
if header.is_empty() {
|
||||
return Ok(ENCODING_IDENTITY);
|
||||
}
|
||||
|
||||
let (mut encodings, identity_rejected) = parse_accept_encoding_header(header);
|
||||
encodings.sort_by(|a, b| b.quality.partial_cmp(&a.quality).unwrap_or(std::cmp::Ordering::Equal));
|
||||
|
||||
for eq in &encodings {
|
||||
if eq.quality == 0.0 {
|
||||
continue;
|
||||
}
|
||||
if eq.encoding == ENCODING_IDENTITY {
|
||||
return Ok(ENCODING_IDENTITY);
|
||||
}
|
||||
if eq.encoding == ENCODING_WILDCARD {
|
||||
if identity_rejected && !SUPPORTED_ENCODINGS.is_empty() {
|
||||
return Ok(SUPPORTED_ENCODINGS[0]);
|
||||
}
|
||||
return Ok(ENCODING_IDENTITY);
|
||||
}
|
||||
if eq.encoding == ENCODING_GZIP {
|
||||
return Ok(ENCODING_GZIP);
|
||||
}
|
||||
}
|
||||
|
||||
if !identity_rejected {
|
||||
return Ok(ENCODING_IDENTITY);
|
||||
}
|
||||
|
||||
Err(format!("no acceptable encoding found, supported: {SUPPORTED_ENCODINGS:?}"))
|
||||
}
|
||||
|
||||
pub fn parse_content_encoding<B>(r: &Request<B>) -> Result<&'static str, String> {
|
||||
let header = r
|
||||
.headers()
|
||||
.get("content-encoding")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("");
|
||||
|
||||
if header.is_empty() {
|
||||
return Ok(ENCODING_IDENTITY);
|
||||
}
|
||||
|
||||
let encoding = header.trim().to_ascii_lowercase();
|
||||
if encoding == ENCODING_IDENTITY {
|
||||
return Ok(ENCODING_IDENTITY);
|
||||
}
|
||||
if SUPPORTED_ENCODINGS.contains(&encoding.as_str()) {
|
||||
return Ok(ENCODING_GZIP);
|
||||
}
|
||||
|
||||
Err(format!("unsupported Content-Encoding: {header}, supported: {SUPPORTED_ENCODINGS:?}"))
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    //! Unit tests for the Accept-Encoding / Content-Encoding parsers.
    //! Grouped by target function; each group is introduced by a comment.
    use super::*;
    use axum::http::Request;

    // Builds a request carrying only an `accept-encoding` header.
    fn req_with_accept(v: &str) -> Request<()> {
        Request::builder()
            .header("accept-encoding", v)
            .body(())
            .unwrap()
    }

    // Builds a request carrying only a `content-encoding` header.
    fn req_with_content(v: &str) -> Request<()> {
        Request::builder()
            .header("content-encoding", v)
            .body(())
            .unwrap()
    }

    // Builds a request with no headers at all.
    fn req_no_headers() -> Request<()> {
        Request::builder().body(()).unwrap()
    }

    // parse_encoding_with_quality

    #[test]
    fn encoding_quality_default_1() {
        let eq = parse_encoding_with_quality("gzip");
        assert_eq!(eq.encoding, "gzip");
        assert_eq!(eq.quality, 1.0);
    }

    #[test]
    fn encoding_quality_explicit() {
        let eq = parse_encoding_with_quality("gzip;q=0.8");
        assert_eq!(eq.encoding, "gzip");
        assert_eq!(eq.quality, 0.8);
    }

    #[test]
    fn encoding_quality_case_insensitive() {
        let eq = parse_encoding_with_quality("GZIP;Q=0.5");
        assert_eq!(eq.encoding, "gzip");
        assert_eq!(eq.quality, 0.5);
    }

    #[test]
    fn encoding_quality_zero() {
        let eq = parse_encoding_with_quality("gzip;q=0");
        assert_eq!(eq.quality, 0.0);
    }

    #[test]
    fn encoding_quality_whitespace_trimmed() {
        let eq = parse_encoding_with_quality(" gzip ; q=0.9 ");
        assert_eq!(eq.encoding, "gzip");
        assert_eq!(eq.quality, 0.9);
    }

    // parse_accept_encoding_header

    #[test]
    fn accept_header_empty() {
        let (encs, rejected) = parse_accept_encoding_header("");
        assert!(encs.is_empty());
        assert!(!rejected);
    }

    #[test]
    fn accept_header_identity_q0_rejects() {
        let (_, rejected) = parse_accept_encoding_header("identity;q=0");
        assert!(rejected);
    }

    #[test]
    fn accept_header_wildcard_q0_rejects_identity() {
        let (_, rejected) = parse_accept_encoding_header("*;q=0");
        assert!(rejected);
    }

    #[test]
    fn accept_header_wildcard_q0_but_identity_explicit_accepted() {
        let (_, rejected) = parse_accept_encoding_header("*;q=0, identity");
        assert!(!rejected);
    }

    // parse_accept_encoding (full)

    #[test]
    fn accept_encoding_no_header_returns_identity() {
        assert_eq!(parse_accept_encoding(&req_no_headers()).unwrap(), "identity");
    }

    #[test]
    fn accept_encoding_gzip() {
        assert_eq!(parse_accept_encoding(&req_with_accept("gzip")).unwrap(), "gzip");
    }

    #[test]
    fn accept_encoding_identity_explicit() {
        assert_eq!(parse_accept_encoding(&req_with_accept("identity")).unwrap(), "identity");
    }

    #[test]
    fn accept_encoding_gzip_higher_quality() {
        assert_eq!(
            parse_accept_encoding(&req_with_accept("identity;q=0.1, gzip;q=0.9")).unwrap(),
            "gzip"
        );
    }

    #[test]
    fn accept_encoding_wildcard_returns_identity() {
        assert_eq!(parse_accept_encoding(&req_with_accept("*")).unwrap(), "identity");
    }

    #[test]
    fn accept_encoding_wildcard_identity_rejected_returns_gzip() {
        assert_eq!(
            parse_accept_encoding(&req_with_accept("identity;q=0, *")).unwrap(),
            "gzip"
        );
    }

    #[test]
    fn accept_encoding_all_rejected_errors() {
        assert!(parse_accept_encoding(&req_with_accept("identity;q=0, *;q=0")).is_err());
    }

    #[test]
    fn accept_encoding_unsupported_only_falls_to_identity() {
        assert_eq!(parse_accept_encoding(&req_with_accept("br")).unwrap(), "identity");
    }

    // is_identity_acceptable

    #[test]
    fn identity_acceptable_no_header() {
        assert!(is_identity_acceptable(&req_no_headers()));
    }

    #[test]
    fn identity_acceptable_gzip_only() {
        assert!(is_identity_acceptable(&req_with_accept("gzip")));
    }

    #[test]
    fn identity_not_acceptable_identity_q0() {
        assert!(!is_identity_acceptable(&req_with_accept("identity;q=0")));
    }

    #[test]
    fn identity_not_acceptable_wildcard_q0() {
        assert!(!is_identity_acceptable(&req_with_accept("*;q=0")));
    }

    #[test]
    fn identity_acceptable_wildcard_q0_but_identity_explicit() {
        assert!(is_identity_acceptable(&req_with_accept("*;q=0, identity")));
    }

    // parse_content_encoding

    #[test]
    fn content_encoding_empty_returns_identity() {
        assert_eq!(parse_content_encoding(&req_no_headers()).unwrap(), "identity");
    }

    #[test]
    fn content_encoding_gzip() {
        assert_eq!(parse_content_encoding(&req_with_content("gzip")).unwrap(), "gzip");
    }

    #[test]
    fn content_encoding_identity_explicit() {
        assert_eq!(parse_content_encoding(&req_with_content("identity")).unwrap(), "identity");
    }

    #[test]
    fn content_encoding_unsupported_errors() {
        assert!(parse_content_encoding(&req_with_content("br")).is_err());
    }

    #[test]
    fn content_encoding_case_insensitive() {
        assert_eq!(parse_content_encoding(&req_with_content("GZIP")).unwrap(), "gzip");
    }
}
|
||||
25
envd-rs/src/http/envs.rs
Normal file
25
envd-rs/src/http/envs.rs
Normal file
@ -0,0 +1,25 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::Json;
|
||||
use axum::extract::State;
|
||||
use axum::http::header;
|
||||
use axum::response::IntoResponse;
|
||||
|
||||
use crate::state::AppState;
|
||||
|
||||
pub async fn get_envs(State(state): State<Arc<AppState>>) -> impl IntoResponse {
|
||||
tracing::debug!("getting env vars");
|
||||
|
||||
let envs: HashMap<String, String> = state
|
||||
.defaults
|
||||
.env_vars
|
||||
.iter()
|
||||
.map(|entry| (entry.key().clone(), entry.value().clone()))
|
||||
.collect();
|
||||
|
||||
(
|
||||
[(header::CACHE_CONTROL, "no-store")],
|
||||
Json(envs),
|
||||
)
|
||||
}
|
||||
20
envd-rs/src/http/error.rs
Normal file
20
envd-rs/src/http/error.rs
Normal file
@ -0,0 +1,20 @@
|
||||
use axum::Json;
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::IntoResponse;
|
||||
use serde::Serialize;
|
||||
|
||||
// JSON error payload shape: `{ "code": <status>, "message": <text> }`.
#[derive(Serialize)]
struct ErrorBody {
    // HTTP status code, duplicated into the body for clients.
    code: u16,
    // Human-readable error description.
    message: String,
}
|
||||
|
||||
pub fn json_error(status: StatusCode, message: &str) -> impl IntoResponse {
|
||||
(
|
||||
status,
|
||||
Json(ErrorBody {
|
||||
code: status.as_u16(),
|
||||
message: message.to_string(),
|
||||
}),
|
||||
)
|
||||
}
|
||||
447
envd-rs/src/http/files.rs
Normal file
447
envd-rs/src/http/files.rs
Normal file
@ -0,0 +1,447 @@
|
||||
use std::io::Write as _;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::body::Body;
|
||||
use axum::extract::{FromRequest, Query, Request, State};
|
||||
use axum::http::{StatusCode, header};
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::auth::signing;
|
||||
use crate::execcontext;
|
||||
use crate::http::encoding;
|
||||
use crate::permissions::path::{ensure_dirs, expand_and_resolve};
|
||||
use crate::permissions::user::lookup_user;
|
||||
use crate::state::AppState;
|
||||
|
||||
// Request header carrying the sandbox access token for file operations.
const ACCESS_TOKEN_HEADER: &str = "x-access-token";
|
||||
|
||||
/// Query parameters accepted by the /files endpoints.
#[derive(Deserialize)]
pub struct FileParams {
    // Target path; may be relative (resolved against workdir/home) or absent.
    pub path: Option<String>,
    // User to act as; falls back to the configured default user.
    pub username: Option<String>,
    // Signed-URL credentials, used instead of the access-token header.
    pub signature: Option<String>,
    // Unix-time expiry of the signature — presumably seconds; confirm
    // against auth::signing.
    pub signature_expiration: Option<i64>,
}
|
||||
|
||||
/// One filesystem entry in the JSON response of a file upload.
#[derive(Serialize)]
struct EntryInfo {
    // Absolute resolved path of the entry.
    path: String,
    // Final path component.
    name: String,
    // Serialized as "type"; always "file" for uploads in this module.
    r#type: &'static str,
}
|
||||
|
||||
fn json_error(status: StatusCode, msg: &str) -> Response {
|
||||
let body = serde_json::json!({ "code": status.as_u16(), "message": msg });
|
||||
(status, axum::Json(body)).into_response()
|
||||
}
|
||||
|
||||
fn extract_header_token(req: &Request) -> Option<&str> {
|
||||
req.headers()
|
||||
.get(ACCESS_TOKEN_HEADER)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
}
|
||||
|
||||
/// Authorizes a file operation: accepts either the access-token header or a
/// signed URL (signature + expiration) scoped to (username, path, operation).
/// Pure delegation to the shared validator in `auth::signing`.
fn validate_file_signing(
    state: &AppState,
    header_token: Option<&str>,
    params: &FileParams,
    path: &str,
    operation: &str,
    username: &str,
) -> Result<(), String> {
    signing::validate_signing(
        &state.access_token,
        header_token,
        params.signature.as_deref(),
        params.signature_expiration,
        username,
        path,
        operation,
    )
}
|
||||
|
||||
/// GET /files — download a file
///
/// Flow: resolve the acting user → authorize (token or signed URL) → resolve
/// the path against the user's home/workdir → validate it is a regular file
/// → negotiate encoding → return the bytes (optionally gzip-compressed).
/// NOTE(review): the whole file is read into memory (`std::fs::read`), and
/// Range/conditional headers are only *rejected* for non-identity encodings,
/// never honored — presumably intentional parity with the previous
/// implementation; confirm.
pub async fn get_files(
    State(state): State<Arc<AppState>>,
    Query(params): Query<FileParams>,
    req: Request,
) -> Response {
    let path_str = params.path.as_deref().unwrap_or("");
    let header_token = extract_header_token(&req);

    // Resolve which user this request acts as (explicit param or default).
    let default_user = state.defaults.user();
    let username = match execcontext::resolve_default_username(
        params.username.as_deref(),
        &default_user,
    ) {
        Ok(u) => u.to_string(),
        Err(e) => return json_error(StatusCode::BAD_REQUEST, e),
    };

    // Authorize: token header or signed URL must cover a READ of this path.
    if let Err(e) = validate_file_signing(
        &state,
        header_token,
        &params,
        path_str,
        signing::READ_OPERATION,
        &username,
    ) {
        return json_error(StatusCode::UNAUTHORIZED, &e);
    }

    let user = match lookup_user(&username) {
        Ok(u) => u,
        Err(e) => return json_error(StatusCode::UNAUTHORIZED, &e),
    };

    // Resolve ~, relative segments, etc. against home dir / default workdir.
    let home_dir = user.dir.to_string_lossy().to_string();
    let default_workdir = state.defaults.workdir();
    let resolved = match expand_and_resolve(path_str, &home_dir, default_workdir.as_deref())
    {
        Ok(p) => p,
        Err(e) => return json_error(StatusCode::BAD_REQUEST, &e),
    };

    let meta = match std::fs::metadata(&resolved) {
        Ok(m) => m,
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
            return json_error(
                StatusCode::NOT_FOUND,
                &format!("path '{}' does not exist", resolved),
            );
        }
        Err(e) => {
            return json_error(
                StatusCode::INTERNAL_SERVER_ERROR,
                &format!("error checking path: {e}"),
            );
        }
    };

    if meta.is_dir() {
        return json_error(
            StatusCode::BAD_REQUEST,
            &format!("path '{}' is a directory", resolved),
        );
    }

    // Rejects symlink-to-nothing, devices, sockets, etc.
    if !meta.file_type().is_file() {
        return json_error(
            StatusCode::BAD_REQUEST,
            &format!("path '{}' is not a regular file", resolved),
        );
    }

    let accept_enc = match encoding::parse_accept_encoding(&req) {
        Ok(e) => e,
        Err(e) => return json_error(StatusCode::NOT_ACCEPTABLE, &e),
    };

    // Range / conditional requests force identity encoding, since compressed
    // bodies would have different byte offsets and validators.
    let has_range_or_conditional = req.headers().get("range").is_some()
        || req.headers().get("if-modified-since").is_some()
        || req.headers().get("if-none-match").is_some()
        || req.headers().get("if-range").is_some();

    let use_encoding = if has_range_or_conditional {
        if !encoding::is_identity_acceptable(&req) {
            return json_error(
                StatusCode::NOT_ACCEPTABLE,
                "identity encoding not acceptable for Range or conditional request",
            );
        }
        "identity"
    } else {
        accept_enc
    };

    let file_data = match std::fs::read(&resolved) {
        Ok(d) => d,
        Err(e) => {
            return json_error(
                StatusCode::INTERNAL_SERVER_ERROR,
                &format!("error reading file: {e}"),
            );
        }
    };

    let filename = Path::new(&resolved)
        .file_name()
        .map(|n| n.to_string_lossy().to_string())
        .unwrap_or_default();

    let content_disposition = format!("inline; filename=\"{}\"", filename);
    let content_type = mime_guess::from_path(&resolved)
        .first_raw()
        .unwrap_or("application/octet-stream");

    if use_encoding == "gzip" {
        let mut encoder =
            flate2::write::GzEncoder::new(Vec::new(), flate2::Compression::default());
        if let Err(e) = encoder.write_all(&file_data) {
            return json_error(
                StatusCode::INTERNAL_SERVER_ERROR,
                &format!("gzip encoding error: {e}"),
            );
        }
        let compressed = match encoder.finish() {
            Ok(d) => d,
            Err(e) => {
                return json_error(
                    StatusCode::INTERNAL_SERVER_ERROR,
                    &format!("gzip finish error: {e}"),
                );
            }
        };

        // No Content-Length here: the builder/axum infers it from the body.
        return Response::builder()
            .status(StatusCode::OK)
            .header(header::CONTENT_TYPE, content_type)
            .header(header::CONTENT_ENCODING, "gzip")
            .header(header::CONTENT_DISPOSITION, content_disposition)
            .header(header::VARY, "Accept-Encoding")
            .body(Body::from(compressed))
            .unwrap();
    }

    Response::builder()
        .status(StatusCode::OK)
        .header(header::CONTENT_TYPE, content_type)
        .header(header::CONTENT_DISPOSITION, content_disposition)
        .header(header::VARY, "Accept-Encoding")
        .header(header::CONTENT_LENGTH, file_data.len())
        .body(Body::from(file_data))
        .unwrap()
}
|
||||
|
||||
/// POST /files — upload file(s) via multipart
|
||||
pub async fn post_files(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Query(params): Query<FileParams>,
|
||||
req: Request,
|
||||
) -> Response {
|
||||
let path_str = params.path.as_deref().unwrap_or("");
|
||||
let header_token = extract_header_token(&req);
|
||||
|
||||
let default_user = state.defaults.user();
|
||||
let username = match execcontext::resolve_default_username(
|
||||
params.username.as_deref(),
|
||||
&default_user,
|
||||
) {
|
||||
Ok(u) => u.to_string(),
|
||||
Err(e) => return json_error(StatusCode::BAD_REQUEST, e),
|
||||
};
|
||||
|
||||
if let Err(e) = validate_file_signing(
|
||||
&state,
|
||||
header_token,
|
||||
¶ms,
|
||||
path_str,
|
||||
signing::WRITE_OPERATION,
|
||||
&username,
|
||||
) {
|
||||
return json_error(StatusCode::UNAUTHORIZED, &e);
|
||||
}
|
||||
|
||||
let user = match lookup_user(&username) {
|
||||
Ok(u) => u,
|
||||
Err(e) => return json_error(StatusCode::UNAUTHORIZED, &e),
|
||||
};
|
||||
|
||||
let home_dir = user.dir.to_string_lossy().to_string();
|
||||
let uid = user.uid;
|
||||
let gid = user.gid;
|
||||
|
||||
let content_enc = match encoding::parse_content_encoding(&req) {
|
||||
Ok(e) => e,
|
||||
Err(e) => return json_error(StatusCode::BAD_REQUEST, &e),
|
||||
};
|
||||
|
||||
let mut multipart = match axum::extract::Multipart::from_request(req, &()).await {
|
||||
Ok(m) => m,
|
||||
Err(e) => {
|
||||
return json_error(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
&format!("error parsing multipart: {e}"),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let mut uploaded: Vec<EntryInfo> = Vec::new();
|
||||
let default_workdir = state.defaults.workdir();
|
||||
|
||||
while let Ok(Some(field)) = multipart.next_field().await {
|
||||
let field_name = field.name().unwrap_or("").to_string();
|
||||
if field_name != "file" {
|
||||
continue;
|
||||
}
|
||||
|
||||
let file_path = if !path_str.is_empty() {
|
||||
match expand_and_resolve(path_str, &home_dir, default_workdir.as_deref()) {
|
||||
Ok(p) => p,
|
||||
Err(e) => return json_error(StatusCode::BAD_REQUEST, &e),
|
||||
}
|
||||
} else {
|
||||
let fname = field
|
||||
.file_name()
|
||||
.unwrap_or("upload")
|
||||
.to_string();
|
||||
match expand_and_resolve(&fname, &home_dir, default_workdir.as_deref()) {
|
||||
Ok(p) => p,
|
||||
Err(e) => return json_error(StatusCode::BAD_REQUEST, &e),
|
||||
}
|
||||
};
|
||||
|
||||
if uploaded.iter().any(|e| e.path == file_path) {
|
||||
return json_error(
|
||||
StatusCode::BAD_REQUEST,
|
||||
&format!("cannot upload multiple files to same path '{}'", file_path),
|
||||
);
|
||||
}
|
||||
|
||||
let raw_bytes = match field.bytes().await {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
return json_error(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
&format!("error reading field: {e}"),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let data = if content_enc == "gzip" {
|
||||
use std::io::Read;
|
||||
let mut decoder = flate2::read::GzDecoder::new(&raw_bytes[..]);
|
||||
let mut buf = Vec::new();
|
||||
match decoder.read_to_end(&mut buf) {
|
||||
Ok(_) => buf,
|
||||
Err(e) => {
|
||||
return json_error(
|
||||
StatusCode::BAD_REQUEST,
|
||||
&format!("gzip decompression failed: {e}"),
|
||||
);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
raw_bytes.to_vec()
|
||||
};
|
||||
|
||||
if let Err(e) = process_file(&file_path, &data, uid, gid) {
|
||||
let (status, msg) = e;
|
||||
return json_error(status, &msg);
|
||||
}
|
||||
|
||||
let name = Path::new(&file_path)
|
||||
.file_name()
|
||||
.map(|n| n.to_string_lossy().to_string())
|
||||
.unwrap_or_default();
|
||||
|
||||
uploaded.push(EntryInfo {
|
||||
path: file_path,
|
||||
name,
|
||||
r#type: "file",
|
||||
});
|
||||
}
|
||||
|
||||
axum::Json(uploaded).into_response()
|
||||
}
|
||||
|
||||
/// Writes `data` to `path`, creating parent directories as needed and making
/// the file owned by `uid:gid`. Errors are returned as (HTTP status, message)
/// pairs for the upload handler; ENOSPC maps to 507 Insufficient Storage.
///
/// Chown ordering: if the file already exists it is chowned *before* opening
/// (so truncation happens on a file already owned by the target user);
/// otherwise it is chowned right after creation. NOTE(review): the
/// metadata/chown/open sequence is not atomic — the file can appear or
/// vanish between steps; the NotFound arms below tolerate that race.
fn process_file(
    path: &str,
    data: &[u8],
    uid: nix::unistd::Uid,
    gid: nix::unistd::Gid,
) -> Result<(), (StatusCode, String)> {
    let dir = Path::new(path)
        .parent()
        .map(|p| p.to_string_lossy().to_string())
        .unwrap_or_default();

    // Create any missing parent directories, owned by the target user.
    if !dir.is_empty() {
        ensure_dirs(&dir, uid, gid).map_err(|e| {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("error ensuring directories: {e}"),
            )
        })?;
    }

    // Existing regular file → pre-chown; directory → reject; absent → create.
    let can_pre_chown = match std::fs::metadata(path) {
        Ok(meta) => {
            if meta.is_dir() {
                return Err((
                    StatusCode::BAD_REQUEST,
                    format!("path is a directory: {path}"),
                ));
            }
            true
        }
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => false,
        Err(e) => {
            return Err((
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("error getting file info: {e}"),
            ))
        }
    };

    let mut chowned = false;
    if can_pre_chown {
        match std::os::unix::fs::chown(path, Some(uid.as_raw()), Some(gid.as_raw())) {
            Ok(()) => chowned = true,
            // File disappeared since the metadata call; fall through to create.
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
            Err(e) => {
                return Err((
                    StatusCode::INTERNAL_SERVER_ERROR,
                    format!("error changing ownership: {e}"),
                ))
            }
        }
    }

    // 0o666 before umask — world-writable intent; the effective mode depends
    // on the process umask.
    let mut file = std::fs::OpenOptions::new()
        .write(true)
        .create(true)
        .truncate(true)
        .mode(0o666)
        .open(path)
        .map_err(|e| {
            if e.raw_os_error() == Some(libc::ENOSPC) {
                return (
                    StatusCode::INSUFFICIENT_STORAGE,
                    "not enough disk space available".to_string(),
                );
            }
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("error opening file: {e}"),
            )
        })?;

    // Newly created file: hand ownership to the target user now.
    if !chowned {
        std::os::unix::fs::chown(path, Some(uid.as_raw()), Some(gid.as_raw())).map_err(|e| {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("error changing ownership: {e}"),
            )
        })?;
    }

    file.write_all(data).map_err(|e| {
        if e.raw_os_error() == Some(libc::ENOSPC) {
            return (
                StatusCode::INSUFFICIENT_STORAGE,
                "not enough disk space available".to_string(),
            );
        }
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("error writing file: {e}"),
        )
    })?;

    Ok(())
}
|
||||
|
||||
// NOTE(review): trait import for `.mode(...)` placed at the end of the file;
// conventionally this belongs with the imports at the top of the module.
use std::os::unix::fs::OpenOptionsExt;
|
||||
41
envd-rs/src/http/health.rs
Normal file
41
envd-rs/src/http/health.rs
Normal file
@ -0,0 +1,41 @@
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::Ordering;
|
||||
|
||||
use axum::Json;
|
||||
use axum::extract::State;
|
||||
use axum::http::header;
|
||||
use axum::response::IntoResponse;
|
||||
use serde_json::json;
|
||||
|
||||
use crate::state::AppState;
|
||||
|
||||
pub async fn get_health(State(state): State<Arc<AppState>>) -> impl IntoResponse {
|
||||
if state
|
||||
.needs_restore
|
||||
.compare_exchange(true, false, Ordering::AcqRel, Ordering::Relaxed)
|
||||
.is_ok()
|
||||
{
|
||||
post_restore_recovery(&state);
|
||||
}
|
||||
|
||||
tracing::trace!("health check");
|
||||
|
||||
(
|
||||
[(header::CACHE_CONTROL, "no-store")],
|
||||
Json(json!({ "version": state.version })),
|
||||
)
|
||||
}
|
||||
|
||||
fn post_restore_recovery(state: &AppState) {
|
||||
tracing::info!("restore: post-restore recovery (no GC needed in Rust)");
|
||||
|
||||
state.snapshot_in_progress.store(false, std::sync::atomic::Ordering::Release);
|
||||
|
||||
state.conn_tracker.restore_after_snapshot();
|
||||
tracing::info!("restore: zombie connections closed");
|
||||
|
||||
if let Some(ref ps) = state.port_subsystem {
|
||||
ps.restart();
|
||||
tracing::info!("restore: port subsystem restarted");
|
||||
}
|
||||
}
|
||||
281
envd-rs/src/http/init.rs
Normal file
281
envd-rs/src/http/init.rs
Normal file
@ -0,0 +1,281 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::Ordering;
|
||||
|
||||
use axum::Json;
|
||||
use axum::extract::State;
|
||||
use axum::http::{StatusCode, header};
|
||||
use axum::response::IntoResponse;
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::crypto;
|
||||
use crate::host::mmds;
|
||||
use crate::state::AppState;
|
||||
|
||||
/// Body of POST /init, sent by the host agent. All fields are optional; the
/// handler applies only what is present. camelCase on the wire.
#[derive(Deserialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct InitRequest {
    // New access token; empty string clears the current token.
    pub access_token: Option<String>,
    // Default user for subsequent operations (ignored if empty).
    pub default_user: Option<String>,
    // Default working directory (ignored if empty).
    pub default_workdir: Option<String>,
    // Environment variables merged into the shared defaults map.
    pub env_vars: Option<HashMap<String, String>>,
    // IP used for hyperloop /etc/hosts setup.
    pub hyperloop_ip: Option<String>,
    // Request timestamp used to drop stale/out-of-order init calls.
    pub timestamp: Option<String>,
    // NFS volumes to mount in the background.
    pub volume_mounts: Option<Vec<VolumeMount>>,
}
|
||||
|
||||
/// One NFS volume to mount inside the sandbox.
#[derive(Deserialize)]
pub struct VolumeMount {
    // NFS export to mount (wire key "nfsTarget" via camelCase rename).
    pub nfs_target: String,
    // Mount point inside the sandbox.
    pub path: String,
}
|
||||
|
||||
/// POST /init — called by host agent after boot and after every resume.
///
/// Validates the (optional) access token, drops stale requests via the
/// timestamp check, then applies each present field: env vars, access token,
/// default user/workdir, hyperloop hosts setup, NFS mounts, and an MMDS
/// re-poll — the last three run as background tasks. Always finishes with
/// the restore safety-net, so even a stale request triggers recovery.
pub async fn post_init(
    State(state): State<Arc<AppState>>,
    body: Option<Json<InitRequest>>,
) -> impl IntoResponse {
    // Missing body behaves like an empty request (all fields None).
    let init_req = body.map(|b| b.0).unwrap_or_default();

    // Validate access token if provided
    if let Some(ref token_str) = init_req.access_token {
        if let Err(e) = validate_init_access_token(&state, token_str).await {
            tracing::error!(error = %e, "init: access token validation failed");
            return (StatusCode::UNAUTHORIZED, e).into_response();
        }
    }

    // Idempotent timestamp check — only a strictly newer timestamp proceeds.
    // NOTE(review): an unparsable timestamp is silently treated as fresh.
    if let Some(ref ts_str) = init_req.timestamp {
        if let Ok(ts) = chrono_parse_to_nanos(ts_str) {
            if !state.last_set_time.set_to_greater(ts) {
                // Stale request, skip data updates
                return trigger_restore_and_respond(&state).await;
            }
        }
    }

    // Apply env vars (merge, existing keys are overwritten).
    if let Some(ref vars) = init_req.env_vars {
        tracing::debug!(count = vars.len(), "setting env vars");
        for (k, v) in vars {
            state.defaults.env_vars.insert(k.clone(), v.clone());
        }
    }

    // Set access token — empty string explicitly clears a previously set one.
    if let Some(ref token_str) = init_req.access_token {
        if !token_str.is_empty() {
            tracing::debug!("setting access token");
            let _ = state.access_token.set(token_str.as_bytes());
        } else if state.access_token.is_set() {
            tracing::debug!("clearing access token");
            state.access_token.destroy();
        }
    }

    // Set default user
    if let Some(ref user) = init_req.default_user {
        if !user.is_empty() {
            tracing::debug!(user = %user, "setting default user");
            state.defaults.set_user(user.clone());
        }
    }

    // Set default workdir
    if let Some(ref workdir) = init_req.default_workdir {
        if !workdir.is_empty() {
            tracing::debug!(workdir = %workdir, "setting default workdir");
            state.defaults.set_workdir(Some(workdir.clone()));
        }
    }

    // Hyperloop /etc/hosts setup — fire-and-forget.
    if let Some(ref ip) = init_req.hyperloop_ip {
        let ip = ip.clone();
        let env_vars = Arc::clone(&state.defaults.env_vars);
        tokio::spawn(async move {
            setup_hyperloop(&ip, &env_vars).await;
        });
    }

    // NFS mounts — one background task per mount.
    if let Some(ref mounts) = init_req.volume_mounts {
        for mount in mounts {
            let target = mount.nfs_target.clone();
            let path = mount.path.clone();
            tokio::spawn(async move {
                setup_nfs(&target, &path).await;
            });
        }
    }

    // Re-poll MMDS in background (Firecracker only), bounded to 60s.
    if state.is_fc {
        let env_vars = Arc::clone(&state.defaults.env_vars);
        let cancel = tokio_util::sync::CancellationToken::new();
        let cancel_clone = cancel.clone();
        tokio::spawn(async move {
            tokio::time::timeout(std::time::Duration::from_secs(60), async {
                mmds::poll_for_opts(env_vars, cancel_clone).await;
            })
            .await
            .ok();
        });
    }

    trigger_restore_and_respond(&state).await
}
|
||||
|
||||
/// Run restore-time side effects and build the canonical /init response.
///
/// Used for both fresh and stale init requests so every /init call leaves
/// the sandbox in a post-restore-recovered state.
async fn trigger_restore_and_respond(state: &AppState) -> axum::response::Response {
    // Safety net: if health check's postRestoreRecovery hasn't run yet,
    // run it here exactly once (the CAS flips needs_restore off).
    if state
        .needs_restore
        .compare_exchange(true, false, Ordering::AcqRel, Ordering::Relaxed)
        .is_ok()
    {
        post_restore_recovery(state);
    }

    // NOTE(review): post_restore_recovery above also calls
    // restore_after_snapshot() and ps.restart(); when the CAS succeeds
    // these run twice in a row — confirm both are idempotent.
    state.conn_tracker.restore_after_snapshot();
    if let Some(ref ps) = state.port_subsystem {
        ps.restart();
    }

    // 204 with no-store so intermediaries never cache init results.
    (
        StatusCode::NO_CONTENT,
        [(header::CACHE_CONTROL, "no-store")],
    )
        .into_response()
}
|
||||
|
||||
fn post_restore_recovery(state: &AppState) {
|
||||
tracing::info!("restore: post-restore recovery (no GC needed in Rust)");
|
||||
|
||||
state.snapshot_in_progress.store(false, std::sync::atomic::Ordering::Release);
|
||||
|
||||
state.conn_tracker.restore_after_snapshot();
|
||||
|
||||
if let Some(ref ps) = state.port_subsystem {
|
||||
ps.restart();
|
||||
tracing::info!("restore: port subsystem restarted");
|
||||
}
|
||||
}
|
||||
|
||||
/// Decide whether an /init request presenting `request_token` may proceed.
///
/// Accepts when: the token matches the one already stored; or (Firecracker
/// only) the MMDS-provided SHA-512 hash matches the presented — possibly
/// empty — token; or no token has ever been set (first-time setup).
/// Returns a human-readable error string otherwise.
async fn validate_init_access_token(state: &AppState, request_token: &str) -> Result<(), String> {
    // Fast path: matches existing token
    if state.access_token.is_set() && !request_token.is_empty() && state.access_token.equals(request_token) {
        return Ok(());
    }

    // Check MMDS hash. When a non-empty hash exists it is authoritative:
    // either the presented token hashes to it, or we reject.
    if state.is_fc {
        if let Ok(mmds_hash) = mmds::get_access_token_hash().await {
            if !mmds_hash.is_empty() {
                if request_token.is_empty() {
                    // An empty token is only acceptable if MMDS recorded
                    // the empty-token digest.
                    let empty_hash = crypto::sha512::hash_access_token("");
                    if mmds_hash == empty_hash {
                        return Ok(());
                    }
                } else {
                    let token_hash = crypto::sha512::hash_access_token(request_token);
                    if mmds_hash == token_hash {
                        return Ok(());
                    }
                }
                return Err("access token validation failed".into());
            }
        }
    }

    // First-time setup: no existing token and no MMDS hash to compare.
    if !state.access_token.is_set() {
        return Ok(());
    }

    // A token exists; an empty request may not silently reset it.
    if request_token.is_empty() {
        return Err("access token reset not authorized".into());
    }

    Err("access token validation failed".into())
}
|
||||
|
||||
/// Point `events.wrenn.local` at `address` via /etc/hosts and export the
/// events endpoint to child processes through WRENN_EVENTS_ADDRESS.
///
/// Any existing entries for the host name are removed first, so repeated
/// /init calls replace rather than accumulate mappings. On read/write
/// failure the env var is NOT set (early return).
/// NOTE(review): uses blocking std::fs inside an async fn — only called
/// from a spawned task so impact is limited, but consider tokio::fs or
/// spawn_blocking.
async fn setup_hyperloop(address: &str, env_vars: &dashmap::DashMap<String, String>) {
    // Write to /etc/hosts: events.wrenn.local → address
    let entry = format!("{address} events.wrenn.local\n");

    match std::fs::read_to_string("/etc/hosts") {
        Ok(contents) => {
            // Drop any previous mapping before appending the new one.
            let filtered: String = contents
                .lines()
                .filter(|line| !line.contains("events.wrenn.local"))
                .collect::<Vec<_>>()
                .join("\n");
            let new_contents = format!("{filtered}\n{entry}");
            if let Err(e) = std::fs::write("/etc/hosts", new_contents) {
                tracing::error!(error = %e, "failed to modify hosts file");
                return;
            }
        }
        Err(e) => {
            tracing::error!(error = %e, "failed to read hosts file");
            return;
        }
    }

    // Only advertise the address to children once the hosts entry exists.
    env_vars.insert(
        "WRENN_EVENTS_ADDRESS".into(),
        format!("http://{address}"),
    );
}
|
||||
|
||||
async fn setup_nfs(nfs_target: &str, path: &str) {
|
||||
let mkdir = tokio::process::Command::new("mkdir")
|
||||
.args(["-p", path])
|
||||
.output()
|
||||
.await;
|
||||
if let Err(e) = mkdir {
|
||||
tracing::error!(error = %e, path, "nfs: mkdir failed");
|
||||
return;
|
||||
}
|
||||
|
||||
let mount = tokio::process::Command::new("mount")
|
||||
.args([
|
||||
"-v",
|
||||
"-t",
|
||||
"nfs",
|
||||
"-o",
|
||||
"mountproto=tcp,mountport=2049,proto=tcp,port=2049,nfsvers=3,noacl",
|
||||
nfs_target,
|
||||
path,
|
||||
])
|
||||
.output()
|
||||
.await;
|
||||
|
||||
match mount {
|
||||
Ok(output) => {
|
||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
if output.status.success() {
|
||||
tracing::info!(nfs_target, path, stdout = %stdout, "nfs: mount success");
|
||||
} else {
|
||||
tracing::error!(nfs_target, path, stderr = %stderr, "nfs: mount failed");
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, nfs_target, path, "nfs: mount command failed");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse an /init timestamp into nanoseconds since the Unix epoch.
///
/// Accepts either a plain (possibly fractional) number of seconds, or an
/// RFC3339 timestamp such as "2023-11-14T22:13:20.5Z" or
/// "2023-11-14T17:13:20-05:00". Returns Err(()) on anything else; the
/// caller treats an unparseable timestamp as "allow the update".
/// Previously the RFC3339 branch was a stub that always failed, silently
/// disabling the idempotency check for RFC3339 callers.
fn chrono_parse_to_nanos(ts: &str) -> Result<i64, ()> {
    // Fast path: raw seconds-since-epoch, e.g. "1700000000" or "1.5".
    if let Ok(secs) = ts.parse::<f64>() {
        return Ok((secs * 1_000_000_000.0) as i64);
    }
    parse_rfc3339_nanos(ts)
}

/// Minimal stdlib-only RFC3339 parser:
/// "YYYY-MM-DDTHH:MM:SS[.frac](Z|±HH:MM)".
fn parse_rfc3339_nanos(ts: &str) -> Result<i64, ()> {
    let b = ts.as_bytes();
    // Shortest valid form is "YYYY-MM-DDTHH:MM:SSZ" (20 bytes).
    if b.len() < 20 {
        return Err(());
    }
    // Fixed-width decimal field helper (digits only — no signs/spaces).
    let num = |r: std::ops::Range<usize>| -> Result<i64, ()> {
        let s = ts.get(r).ok_or(())?;
        if !s.bytes().all(|c| c.is_ascii_digit()) {
            return Err(());
        }
        s.parse::<i64>().map_err(|_| ())
    };

    // Separator layout: 2023-11-14T22:13:20 ('t' and ' ' also allowed
    // as the date/time separator per RFC3339 §5.6).
    if b[4] != b'-'
        || b[7] != b'-'
        || !(b[10] == b'T' || b[10] == b't' || b[10] == b' ')
        || b[13] != b':'
        || b[16] != b':'
    {
        return Err(());
    }

    let (year, month, day) = (num(0..4)?, num(5..7)?, num(8..10)?);
    let (hour, min, sec) = (num(11..13)?, num(14..16)?, num(17..19)?);
    // sec == 60 permitted for leap seconds.
    if !(1..=12).contains(&month) || !(1..=31).contains(&day) || hour > 23 || min > 59 || sec > 60 {
        return Err(());
    }

    // Optional fractional seconds, truncated to nanosecond precision.
    let mut i = 19;
    let mut frac_nanos: i64 = 0;
    if i < b.len() && b[i] == b'.' {
        i += 1;
        let start = i;
        while i < b.len() && b[i].is_ascii_digit() {
            i += 1;
        }
        if i == start {
            return Err(());
        }
        let digits = &ts[start..i.min(start + 9)];
        frac_nanos = digits.parse::<i64>().map_err(|_| ())? * 10i64.pow((9 - digits.len()) as u32);
    }

    // Mandatory UTC offset: 'Z' or ±HH:MM.
    let offset_secs: i64 = match b.get(i) {
        Some(b'Z') | Some(b'z') => 0,
        Some(sign) if *sign == b'+' || *sign == b'-' => {
            if b.len() < i + 6 || b[i + 3] != b':' {
                return Err(());
            }
            let off = num(i + 1..i + 3)? * 3_600 + num(i + 4..i + 6)? * 60;
            if *sign == b'-' { -off } else { off }
        }
        _ => return Err(()),
    };

    let secs = days_from_civil(year, month, day) * 86_400
        + hour * 3_600
        + min * 60
        + sec
        - offset_secs;
    Ok(secs * 1_000_000_000 + frac_nanos)
}

/// Days from 1970-01-01 for a proleptic Gregorian civil date
/// (Howard Hinnant's `days_from_civil` algorithm).
fn days_from_civil(y: i64, m: i64, d: i64) -> i64 {
    let y = if m <= 2 { y - 1 } else { y };
    let era = if y >= 0 { y } else { y - 399 } / 400;
    let yoe = y - era * 400;
    let doy = (153 * (if m > 2 { m - 3 } else { m + 9 }) + 2) / 5 + d - 1;
    let doe = yoe * 365 + yoe / 4 - yoe / 100 + doy;
    era * 146_097 + doe - 719_468
}
|
||||
89
envd-rs/src/http/metrics.rs
Normal file
89
envd-rs/src/http/metrics.rs
Normal file
@ -0,0 +1,89 @@
|
||||
use std::sync::Arc;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use axum::Json;
|
||||
use axum::extract::State;
|
||||
use axum::http::{StatusCode, header};
|
||||
use axum::response::IntoResponse;
|
||||
use serde::Serialize;
|
||||
|
||||
use crate::state::AppState;
|
||||
|
||||
/// JSON payload for GET /metrics.
#[derive(Serialize)]
pub struct Metrics {
    /// Unix timestamp (seconds) when the sample was taken.
    ts: i64,
    cpu_count: u32,
    cpu_used_pct: f32,
    /// Memory reported both in MiB and raw units (assumed bytes — the
    /// MiB fields are derived by /1024/1024 in collect_metrics; confirm
    /// against the sysinfo crate version in use).
    mem_total_mib: u64,
    mem_used_mib: u64,
    mem_total: u64,
    mem_used: u64,
    /// Disk usage of "/" in bytes, from statfs (see disk_stats).
    disk_used: u64,
    disk_total: u64,
}
|
||||
|
||||
pub async fn get_metrics(State(state): State<Arc<AppState>>) -> impl IntoResponse {
|
||||
tracing::trace!("get metrics");
|
||||
|
||||
match collect_metrics(&state) {
|
||||
Ok(m) => (
|
||||
StatusCode::OK,
|
||||
[(header::CACHE_CONTROL, "no-store")],
|
||||
Json(m),
|
||||
)
|
||||
.into_response(),
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "failed to get metrics");
|
||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Gather a point-in-time metrics sample.
///
/// CPU figures come from AppState's own sampler; memory from sysinfo;
/// disk from statfs("/"). Only disk collection can fail.
fn collect_metrics(state: &AppState) -> Result<Metrics, String> {
    let cpu_count = state.cpu_count();
    let cpu_used_pct_rounded = state.cpu_used_pct();

    // A fresh System per call; refresh_memory() pulls only what we need.
    let mut sys = sysinfo::System::new();
    sys.refresh_memory();
    let mem_total = sys.total_memory();
    let mem_available = sys.available_memory();
    // saturating_sub guards against available > total.
    let mem_used = mem_total.saturating_sub(mem_available);
    let mem_total_mib = mem_total / 1024 / 1024;
    let mem_used_mib = mem_used / 1024 / 1024;

    let (disk_total, disk_used) = disk_stats("/").map_err(|e| e.to_string())?;

    // unwrap is safe: a system clock before 1970 is not a supported state.
    let ts = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap()
        .as_secs() as i64;

    Ok(Metrics {
        ts,
        cpu_count,
        cpu_used_pct: cpu_used_pct_rounded,
        mem_total_mib,
        mem_used_mib,
        mem_total,
        mem_used,
        disk_used,
        disk_total,
    })
}
|
||||
|
||||
/// Total and used bytes for the filesystem containing `path`, via statfs(2).
///
/// "Used" is computed as total minus f_bavail (blocks available to
/// unprivileged users), so root-reserved blocks count as used.
fn disk_stats(path: &str) -> Result<(u64, u64), nix::Error> {
    use std::ffi::CString;

    // Panics on an interior NUL byte; callers pass fixed paths like "/".
    let c_path = CString::new(path).unwrap();
    let mut stat: libc::statfs = unsafe { std::mem::zeroed() };
    // SAFETY: c_path is a valid NUL-terminated string and `stat` is a
    // properly sized, writable buffer for the syscall to fill.
    let ret = unsafe { libc::statfs(c_path.as_ptr(), &mut stat) };
    if ret != 0 {
        return Err(nix::Error::last());
    }

    let block = stat.f_bsize as u64;
    let total = stat.f_blocks * block;
    let available = stat.f_bavail * block;

    Ok((total, total - available))
}
|
||||
56
envd-rs/src/http/mod.rs
Normal file
56
envd-rs/src/http/mod.rs
Normal file
@ -0,0 +1,56 @@
|
||||
pub mod encoding;
|
||||
pub mod envs;
|
||||
pub mod error;
|
||||
pub mod files;
|
||||
pub mod health;
|
||||
pub mod init;
|
||||
pub mod metrics;
|
||||
pub mod snapshot;
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use axum::Router;
|
||||
use axum::routing::{get, post};
|
||||
use http::header::{CACHE_CONTROL, HeaderName};
|
||||
use http::Method;
|
||||
use tower_http::cors::{AllowHeaders, AllowMethods, AllowOrigin, CorsLayer};
|
||||
|
||||
use crate::config::CORS_MAX_AGE;
|
||||
use crate::state::AppState;
|
||||
|
||||
/// Build the plain-HTTP router: REST endpoints plus a permissive CORS
/// layer (any origin, any request headers, Connect/gRPC metadata headers
/// exposed to browsers).
pub fn router(state: Arc<AppState>) -> Router {
    let cors = CorsLayer::new()
        .allow_origin(AllowOrigin::any())
        .allow_methods(AllowMethods::list([
            Method::HEAD,
            Method::GET,
            Method::POST,
            Method::PUT,
            Method::PATCH,
            Method::DELETE,
        ]))
        .allow_headers(AllowHeaders::any())
        // Headers browser clients need to read from Connect/gRPC-Web
        // responses (they are hidden by default under CORS).
        .expose_headers([
            HeaderName::from_static("location"),
            CACHE_CONTROL,
            HeaderName::from_static("x-content-type-options"),
            HeaderName::from_static("connect-content-encoding"),
            HeaderName::from_static("connect-protocol-version"),
            HeaderName::from_static("grpc-encoding"),
            HeaderName::from_static("grpc-message"),
            HeaderName::from_static("grpc-status"),
            HeaderName::from_static("grpc-status-details-bin"),
        ])
        .max_age(Duration::from_secs(CORS_MAX_AGE.as_secs()));

    Router::new()
        .route("/health", get(health::get_health))
        .route("/metrics", get(metrics::get_metrics))
        .route("/envs", get(envs::get_envs))
        .route("/init", post(init::post_init))
        .route("/snapshot/prepare", post(snapshot::post_snapshot_prepare))
        .route("/files", get(files::get_files).post(files::post_files))
        .layer(cors)
        .with_state(state)
}
|
||||
49
envd-rs/src/http/snapshot.rs
Normal file
49
envd-rs/src/http/snapshot.rs
Normal file
@ -0,0 +1,49 @@
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::Ordering;
|
||||
|
||||
use axum::extract::State;
|
||||
use axum::http::{StatusCode, header};
|
||||
use axum::response::IntoResponse;
|
||||
|
||||
use crate::state::AppState;
|
||||
|
||||
/// POST /snapshot/prepare — quiesce subsystems before Firecracker snapshot.
///
/// In Rust there is no GC dance. We just:
/// 1. Drop page cache to shrink snapshot size
/// 2. Stop port subsystem
/// 3. Close idle connections via conntracker
/// 4. Set needs_restore flag
///
/// Always responds 204; individual step failures are only logged.
pub async fn post_snapshot_prepare(State(state): State<Arc<AppState>>) -> impl IntoResponse {
    // Drop page cache BEFORE blocking the reclaimer — avoids snapshotting
    // gigabytes of stale cache that inflates the memory dump on disk.
    // "1" = pagecache only (keep dentries/inodes for faster resume).
    if let Err(e) = std::fs::write("/proc/sys/vm/drop_caches", "1") {
        tracing::warn!(error = %e, "snapshot/prepare: drop_caches failed");
    } else {
        tracing::info!("snapshot/prepare: page cache dropped");
    }

    // Block memory reclaimer — prevents drop_caches from running mid-freeze
    // which would corrupt kernel page table state.
    state.snapshot_in_progress.store(true, Ordering::Release);

    if let Some(ref ps) = state.port_subsystem {
        ps.stop();
        tracing::info!("snapshot/prepare: port subsystem stopped");
    }

    state.conn_tracker.prepare_for_snapshot();
    tracing::info!("snapshot/prepare: connections prepared");

    // Sync filesystem buffers so dirty pages are flushed before freeze.
    unsafe { libc::sync(); }

    // Tells the restore path (health check / init) to run recovery once.
    state.needs_restore.store(true, Ordering::Release);
    tracing::info!("snapshot/prepare: ready for freeze");

    (
        StatusCode::NO_CONTENT,
        [(header::CACHE_CONTROL, "no-store")],
    )
}
|
||||
17
envd-rs/src/logging.rs
Normal file
17
envd-rs/src/logging.rs
Normal file
@ -0,0 +1,17 @@
|
||||
use tracing_subscriber::{EnvFilter, fmt, layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
pub fn init(json: bool) {
|
||||
let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
|
||||
|
||||
if json {
|
||||
tracing_subscriber::registry()
|
||||
.with(filter)
|
||||
.with(fmt::layer().json().flatten_event(true))
|
||||
.init();
|
||||
} else {
|
||||
tracing_subscriber::registry()
|
||||
.with(filter)
|
||||
.with(fmt::layer())
|
||||
.init();
|
||||
}
|
||||
}
|
||||
273
envd-rs/src/main.rs
Normal file
273
envd-rs/src/main.rs
Normal file
@ -0,0 +1,273 @@
|
||||
#![allow(dead_code)]
|
||||
|
||||
mod auth;
|
||||
mod cgroups;
|
||||
mod config;
|
||||
mod conntracker;
|
||||
mod crypto;
|
||||
mod execcontext;
|
||||
mod host;
|
||||
mod http;
|
||||
mod logging;
|
||||
mod permissions;
|
||||
mod port;
|
||||
mod rpc;
|
||||
mod state;
|
||||
mod util;
|
||||
|
||||
use std::fs;
|
||||
use std::net::SocketAddr;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
|
||||
use clap::Parser;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use config::{DEFAULT_PORT, DEFAULT_USER, WRENN_RUN_DIR};
|
||||
use execcontext::Defaults;
|
||||
use port::subsystem::PortSubsystem;
|
||||
use state::AppState;
|
||||
|
||||
/// Crate version, baked in at compile time from Cargo.toml.
const VERSION: &str = env!("CARGO_PKG_VERSION");

/// Build commit, injected via the ENVD_COMMIT env var at compile time;
/// "unknown" when the variable is unset.
const COMMIT: &str = {
    match option_env!("ENVD_COMMIT") {
        Some(c) => c,
        None => "unknown",
    }
};
|
||||
|
||||
// Command-line interface for the envd guest agent.
// NOTE: field comments use `//` deliberately — `///` doc comments on clap
// fields become --help text, which would change program output.
#[derive(Parser)]
#[command(name = "envd", about = "Wrenn guest agent daemon")]
struct Cli {
    // Listen port for the combined HTTP/RPC server.
    #[arg(long, default_value_t = DEFAULT_PORT)]
    port: u16,

    // Set when NOT running inside Firecracker: disables MMDS polling and
    // the memory reclaimer, and switches logging to human-readable.
    #[arg(long = "isnotfc", default_value_t = false)]
    is_not_fc: bool,

    // Print the crate version and exit.
    #[arg(long)]
    version: bool,

    // Print the build commit and exit.
    #[arg(long)]
    commit: bool,

    // Optional command to spawn once at startup (empty = none).
    #[arg(long = "cmd", default_value = "")]
    start_cmd: String,

    // cgroup2 hierarchy mount point.
    #[arg(long = "cgroup-root", default_value = "/sys/fs/cgroup")]
    cgroup_root: String,
}
|
||||
|
||||
#[tokio::main]
async fn main() {
    let cli = Cli::parse();

    // --version / --commit print and exit before any other setup.
    if cli.version {
        println!("{VERSION}");
        return;
    }
    if cli.commit {
        println!("{COMMIT}");
        return;
    }

    // JSON logs under Firecracker (machine-collected), human-readable
    // otherwise.
    let use_json = !cli.is_not_fc;
    logging::init(use_json);

    // Non-fatal: later writes into the run dir will log their own errors.
    if let Err(e) = fs::create_dir_all(WRENN_RUN_DIR) {
        tracing::error!(error = %e, "failed to create wrenn run directory");
    }

    // Advertise sandbox mode to children via env var and a marker file.
    let defaults = Defaults::new(DEFAULT_USER);
    let is_fc_str = if cli.is_not_fc { "false" } else { "true" };
    defaults
        .env_vars
        .insert("WRENN_SANDBOX".into(), is_fc_str.into());

    let wrenn_sandbox_path = Path::new(WRENN_RUN_DIR).join(".WRENN_SANDBOX");
    if let Err(e) = fs::write(&wrenn_sandbox_path, is_fc_str.as_bytes()) {
        tracing::error!(error = %e, "failed to write sandbox file");
    }

    // Cancellation token for background pollers; cancelled on shutdown.
    let cancel = CancellationToken::new();

    // MMDS polling (only in FC mode)
    if !cli.is_not_fc {
        let env_vars = Arc::clone(&defaults.env_vars);
        let cancel_clone = cancel.clone();
        tokio::spawn(async move {
            host::mmds::poll_for_opts(env_vars, cancel_clone).await;
        });
    }

    // Cgroup manager — fall back to a no-op implementation when cgroup2
    // initialization fails.
    let cgroup_manager: Arc<dyn cgroups::CgroupManager> =
        match cgroups::Cgroup2Manager::new(
            &cli.cgroup_root,
            &[
                (
                    cgroups::ProcessType::Pty,
                    "wrenn/pty",
                    &[] as &[(&str, &str)],
                ),
                (
                    cgroups::ProcessType::User,
                    "wrenn/user",
                    &[] as &[(&str, &str)],
                ),
                (
                    cgroups::ProcessType::Socat,
                    "wrenn/socat",
                    &[] as &[(&str, &str)],
                ),
            ],
        ) {
            Ok(m) => {
                tracing::info!("cgroup2 manager initialized");
                Arc::new(m)
            }
            Err(e) => {
                tracing::warn!(error = %e, "cgroup2 init failed, using noop");
                Arc::new(cgroups::NoopCgroupManager)
            }
        };

    // Port subsystem
    let port_subsystem = Arc::new(PortSubsystem::new(Arc::clone(&cgroup_manager)));
    port_subsystem.start();
    tracing::info!("port subsystem started");

    let state = AppState::new(
        defaults,
        VERSION.to_string(),
        COMMIT.to_string(),
        !cli.is_not_fc,
        Some(Arc::clone(&port_subsystem)),
    );

    // Memory reclaimer — drop page cache when available memory is low.
    // Firecracker balloon device can only reclaim pages the guest kernel freed.
    // Pauses during snapshot/prepare to avoid corrupting kernel page table state.
    if !cli.is_not_fc {
        let state_for_reclaimer = Arc::clone(&state);
        std::thread::spawn(move || memory_reclaimer(state_for_reclaimer));
    }

    // RPC services (Connect protocol — serves Connect + gRPC + gRPC-Web on same port)
    let connect_router = rpc::rpc_router(Arc::clone(&state));

    // REST routes take precedence; everything else falls through to RPC.
    let app = http::router(Arc::clone(&state))
        .fallback_service(connect_router.into_axum_service());

    // --cmd: spawn initial process if specified
    if !cli.start_cmd.is_empty() {
        let cmd = cli.start_cmd.clone();
        let state_clone = Arc::clone(&state);
        tokio::spawn(async move {
            spawn_initial_command(&cmd, &state_clone);
        });
    }

    let addr = SocketAddr::from(([0, 0, 0, 0], cli.port));
    tracing::info!(port = cli.port, version = VERSION, commit = COMMIT, "envd starting");

    let listener = TcpListener::bind(addr).await.expect("failed to bind");

    // Serve until SIGTERM, then drain gracefully.
    let graceful = axum::serve(listener, app).with_graceful_shutdown(async move {
        tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
            .expect("failed to register SIGTERM")
            .recv()
            .await;
        tracing::info!("SIGTERM received, shutting down");
    });

    if let Err(e) = graceful.await {
        tracing::error!(error = %e, "server error");
    }

    // Shutdown: stop port scanning and cancel background pollers.
    port_subsystem.stop();
    cancel.cancel();
}
|
||||
|
||||
/// Spawn the `--cmd` startup command as the default user.
///
/// Best-effort: user-lookup or spawn failures are logged and swallowed —
/// envd keeps serving either way. The process is tagged "init-cmd" and
/// runs in the default workdir, falling back to the user's home directory.
fn spawn_initial_command(cmd: &str, state: &AppState) {
    use crate::permissions::user::lookup_user;
    use crate::rpc::process_handler;
    use std::collections::HashMap;

    let default_user = state.defaults.user();
    let user = match lookup_user(&default_user) {
        Ok(u) => u,
        Err(e) => {
            tracing::error!(error = %e, "cmd: failed to lookup user");
            return;
        }
    };

    // Working directory: configured default, else the user's home.
    let home = user.dir.to_string_lossy().to_string();
    let default_workdir = state.defaults.workdir();
    let cwd = default_workdir
        .as_deref()
        .unwrap_or(&home);

    match process_handler::spawn_process(
        cmd,
        &[],
        &HashMap::new(),
        cwd,
        None,
        false,
        Some("init-cmd".to_string()),
        &user,
        &state.defaults.env_vars,
    ) {
        Ok(spawned) => {
            tracing::info!(pid = spawned.handle.pid, cmd, "initial command spawned");
        }
        Err(e) => {
            tracing::error!(error = %e, cmd, "failed to spawn initial command");
        }
    }
}
|
||||
|
||||
/// Background thread: drop the kernel page cache when memory usage is high.
///
/// The Firecracker balloon can only reclaim pages the guest kernel has
/// freed, so unchecked cache growth permanently inflates the VM footprint.
/// Checks every 10s; writes "3" (pagecache + dentries + inodes) above 80%
/// usage. Never runs while a snapshot is being prepared — the flag is
/// checked both before sampling and again right before the write to
/// narrow the race with snapshot/prepare.
fn memory_reclaimer(state: Arc<AppState>) {
    use std::sync::atomic::Ordering;

    const CHECK_INTERVAL: std::time::Duration = std::time::Duration::from_secs(10);
    const DROP_THRESHOLD_PCT: u64 = 80;

    loop {
        std::thread::sleep(CHECK_INTERVAL);

        // Skip the whole cycle while a snapshot is in flight.
        if state.snapshot_in_progress.load(Ordering::Acquire) {
            continue;
        }

        let mut sys = sysinfo::System::new();
        sys.refresh_memory();
        let total = sys.total_memory();
        let available = sys.available_memory();

        // Avoid division by zero if sysinfo reports nothing.
        if total == 0 {
            continue;
        }

        let used_pct = ((total - available) * 100) / total;
        if used_pct >= DROP_THRESHOLD_PCT {
            // Re-check: snapshot/prepare may have started during sampling.
            if state.snapshot_in_progress.load(Ordering::Acquire) {
                continue;
            }

            if let Err(e) = std::fs::write("/proc/sys/vm/drop_caches", "3") {
                tracing::debug!(error = %e, "drop_caches failed");
            } else {
                // Log roughly how much the drop freed.
                let mut sys2 = sysinfo::System::new();
                sys2.refresh_memory();
                let freed_mb =
                    sys2.available_memory().saturating_sub(available) / (1024 * 1024);
                tracing::info!(used_pct, freed_mb, "page cache dropped");
            }
        }
    }
}
|
||||
2
envd-rs/src/permissions/mod.rs
Normal file
2
envd-rs/src/permissions/mod.rs
Normal file
@ -0,0 +1,2 @@
|
||||
pub mod user;
|
||||
pub mod path;
|
||||
184
envd-rs/src/permissions/path.rs
Normal file
184
envd-rs/src/permissions/path.rs
Normal file
@ -0,0 +1,184 @@
|
||||
use std::fs;
|
||||
use std::os::unix::fs::chown;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use nix::unistd::{Gid, Uid};
|
||||
|
||||
/// Replace a leading "~" with `home_dir`.
///
/// Paths without a leading tilde (including the empty string) pass
/// through unchanged. "~user" forms are rejected — only the current
/// user's home can be expanded.
fn expand_tilde(path: &str, home_dir: &str) -> Result<String, String> {
    match path.strip_prefix('~') {
        None => Ok(path.to_string()),
        Some(rest) => match rest.as_bytes().first() {
            // Bare "~", "~/..." or "~\..." all expand to the home dir.
            None | Some(b'/') | Some(b'\\') => Ok(format!("{home_dir}{rest}")),
            Some(_) => Err("cannot expand user-specific home dir".into()),
        },
    }
}
|
||||
|
||||
pub fn expand_and_resolve(
|
||||
path: &str,
|
||||
home_dir: &str,
|
||||
default_path: Option<&str>,
|
||||
) -> Result<String, String> {
|
||||
let path = if path.is_empty() {
|
||||
default_path.unwrap_or("").to_string()
|
||||
} else {
|
||||
path.to_string()
|
||||
};
|
||||
|
||||
let path = expand_tilde(&path, home_dir)?;
|
||||
|
||||
if Path::new(&path).is_absolute() {
|
||||
return Ok(path);
|
||||
}
|
||||
|
||||
let joined = PathBuf::from(home_dir).join(&path);
|
||||
joined
|
||||
.canonicalize()
|
||||
.or_else(|_| Ok(joined))
|
||||
.map(|p| p.to_string_lossy().to_string())
|
||||
}
|
||||
|
||||
/// Create every missing component of `path` as a directory owned by
/// uid:gid — like `mkdir -p` plus a chown on each newly created level.
///
/// Existing directories are left untouched (their ownership is NOT
/// changed); an existing non-directory component is an error.
pub fn ensure_dirs(path: &str, uid: Uid, gid: Gid) -> Result<(), String> {
    let path = Path::new(path);
    let mut current = PathBuf::new();

    // Walk component by component so each new directory can be chowned
    // immediately after creation.
    for component in path.components() {
        current.push(component);
        let current_str = current.to_string_lossy();

        // The root component needs no creation or ownership change.
        if current_str == "/" {
            continue;
        }

        match fs::metadata(&current) {
            Ok(meta) => {
                if !meta.is_dir() {
                    return Err(format!("path is a file: {current_str}"));
                }
            }
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
                fs::create_dir(&current)
                    .map_err(|e| format!("failed to create directory {current_str}: {e}"))?;
                chown(&current, Some(uid.as_raw()), Some(gid.as_raw()))
                    .map_err(|e| format!("failed to chown directory {current_str}: {e}"))?;
            }
            Err(e) => {
                return Err(format!("failed to stat directory {current_str}: {e}"));
            }
        }
    }

    Ok(())
}
|
||||
|
||||
/// Unit tests for tilde expansion, path resolution, and directory
/// creation. Filesystem-touching tests use tempfile, so they are hermetic.
#[cfg(test)]
mod tests {
    use super::*;

    // expand_tilde

    #[test]
    fn tilde_empty_passthrough() {
        assert_eq!(expand_tilde("", "/home/u").unwrap(), "");
    }

    #[test]
    fn tilde_no_tilde_passthrough() {
        assert_eq!(expand_tilde("/absolute", "/home/u").unwrap(), "/absolute");
    }

    #[test]
    fn tilde_bare() {
        assert_eq!(expand_tilde("~", "/home/user").unwrap(), "/home/user");
    }

    #[test]
    fn tilde_slash_path() {
        assert_eq!(expand_tilde("~/docs", "/home/user").unwrap(), "/home/user/docs");
    }

    #[test]
    fn tilde_nested() {
        assert_eq!(expand_tilde("~/a/b/c", "/h").unwrap(), "/h/a/b/c");
    }

    #[test]
    fn tilde_other_user_errors() {
        // "~user" expansion is deliberately unsupported.
        assert!(expand_tilde("~bob/foo", "/home/user").is_err());
    }

    #[test]
    fn tilde_relative_no_tilde() {
        assert_eq!(expand_tilde("relative/path", "/home/u").unwrap(), "relative/path");
    }

    // expand_and_resolve

    #[test]
    fn resolve_absolute_passthrough() {
        assert_eq!(expand_and_resolve("/abs/path", "/home", None).unwrap(), "/abs/path");
    }

    #[test]
    fn resolve_empty_uses_default() {
        assert_eq!(expand_and_resolve("", "/home", Some("/default")).unwrap(), "/default");
    }

    #[test]
    fn resolve_empty_no_default_falls_back_to_home() {
        // Empty path with no default → joins "" with home_dir → returns home_dir
        let result = expand_and_resolve("", "/home", None).unwrap();
        assert_eq!(result, "/home");
    }

    #[test]
    fn resolve_tilde_expands() {
        assert_eq!(expand_and_resolve("~/dir", "/home/u", None).unwrap(), "/home/u/dir");
    }

    #[test]
    fn resolve_relative_joins_home() {
        let result = expand_and_resolve("subdir", "/tmp", None).unwrap();
        // Relative path joined with home and canonicalized (or raw join on missing)
        assert!(result.starts_with("/tmp"));
        assert!(result.contains("subdir"));
    }

    #[test]
    fn resolve_tilde_other_user_errors() {
        assert!(expand_and_resolve("~bob", "/home/u", None).is_err());
    }

    // ensure_dirs

    #[test]
    fn ensure_dirs_creates_nested() {
        let tmp = tempfile::TempDir::new().unwrap();
        let path = tmp.path().join("a/b/c");
        let uid = nix::unistd::getuid();
        let gid = nix::unistd::getgid();
        ensure_dirs(path.to_str().unwrap(), uid, gid).unwrap();
        assert!(path.is_dir());
    }

    #[test]
    fn ensure_dirs_existing_is_ok() {
        let tmp = tempfile::TempDir::new().unwrap();
        let uid = nix::unistd::getuid();
        let gid = nix::unistd::getgid();
        ensure_dirs(tmp.path().to_str().unwrap(), uid, gid).unwrap();
    }

    #[test]
    fn ensure_dirs_file_in_path_errors() {
        // A regular file in the middle of the path must abort creation.
        let tmp = tempfile::TempDir::new().unwrap();
        let file_path = tmp.path().join("afile");
        std::fs::write(&file_path, "").unwrap();
        let nested = file_path.join("subdir");
        let uid = nix::unistd::getuid();
        let gid = nix::unistd::getgid();
        let result = ensure_dirs(nested.to_str().unwrap(), uid, gid);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("path is a file"));
    }
}
|
||||
32
envd-rs/src/permissions/user.rs
Normal file
32
envd-rs/src/permissions/user.rs
Normal file
@ -0,0 +1,32 @@
|
||||
use nix::unistd::{Gid, Group, Uid, User};
|
||||
|
||||
/// Resolve a username to its passwd entry, distinguishing a lookup
/// failure from "user not found" in the error message.
pub fn lookup_user(username: &str) -> Result<User, String> {
    User::from_name(username)
        .map_err(|e| format!("error looking up user '{username}': {e}"))?
        .ok_or_else(|| format!("user '{username}' not found"))
}
|
||||
|
||||
/// Convenience accessor for a user's primary uid/gid pair.
pub fn get_uid_gid(user: &User) -> (Uid, Gid) {
    (user.uid, user.gid)
}
|
||||
|
||||
/// All group IDs the user belongs to (via getgrouplist with the primary
/// gid). Falls back to an empty list on lookup failure.
pub fn get_user_groups(user: &User) -> Vec<Gid> {
    // passwd usernames cannot contain NUL, so this unwrap is safe in practice.
    let c_name = std::ffi::CString::new(user.name.as_str()).unwrap();
    nix::unistd::getgrouplist(&c_name, user.gid).unwrap_or_default()
}
|
||||
|
||||
/// Best-effort reverse lookup: username for a uid, or the numeric uid as
/// a string when no passwd entry exists (or the lookup fails).
pub fn lookup_username_by_uid(uid: Uid) -> String {
    User::from_uid(uid)
        .ok()
        .flatten()
        .map(|u| u.name)
        .unwrap_or_else(|| uid.to_string())
}
|
||||
|
||||
/// Best-effort reverse lookup: group name for a gid, or the numeric gid
/// as a string when no group entry exists (or the lookup fails).
pub fn lookup_groupname_by_gid(gid: Gid) -> String {
    Group::from_gid(gid)
        .ok()
        .flatten()
        .map(|g| g.name)
        .unwrap_or_else(|| gid.to_string())
}
|
||||
260
envd-rs/src/port/conn.rs
Normal file
260
envd-rs/src/port/conn.rs
Normal file
@ -0,0 +1,260 @@
|
||||
use std::io::{self, BufRead};
|
||||
|
||||
/// One row of /proc/net/tcp{,6}: a local TCP endpoint and its state.
#[derive(Debug, Clone)]
pub struct ConnStat {
    /// Textual local address (dotted-quad IPv4 or IPv6).
    pub local_ip: String,
    pub local_port: u32,
    /// TCP state name, e.g. "LISTEN" (see tcp_state_name).
    pub status: String,
    /// Address family the row came from: libc::AF_INET or AF_INET6 as u32.
    pub family: u32,
    /// Socket inode as reported by the kernel table.
    pub inode: u64,
}
|
||||
|
||||
/// Translate a /proc/net/tcp hex state code into its kernel state name;
/// unrecognized codes map to "UNKNOWN".
fn tcp_state_name(hex: &str) -> &'static str {
    const STATES: [(&str, &str); 11] = [
        ("01", "ESTABLISHED"),
        ("02", "SYN_SENT"),
        ("03", "SYN_RECV"),
        ("04", "FIN_WAIT1"),
        ("05", "FIN_WAIT2"),
        ("06", "TIME_WAIT"),
        ("07", "CLOSE"),
        ("08", "CLOSE_WAIT"),
        ("09", "LAST_ACK"),
        ("0A", "LISTEN"),
        ("0B", "CLOSING"),
    ];

    STATES
        .iter()
        .find(|(code, _)| *code == hex)
        .map(|(_, name)| *name)
        .unwrap_or("UNKNOWN")
}
|
||||
|
||||
/// Snapshot all IPv4 and IPv6 TCP sockets from procfs.
///
/// A missing or unreadable table (e.g. IPv6 disabled) is silently
/// skipped, so the result may cover only one family.
pub fn read_tcp_connections() -> Vec<ConnStat> {
    let mut conns = Vec::new();
    if let Ok(c) = parse_proc_net_tcp("/proc/net/tcp", libc::AF_INET as u32) {
        conns.extend(c);
    }
    if let Ok(c) = parse_proc_net_tcp("/proc/net/tcp6", libc::AF_INET6 as u32) {
        conns.extend(c);
    }
    conns
}
|
||||
|
||||
/// Parse one /proc/net/tcp-style table into ConnStat rows.
///
/// Malformed rows (bad address, short field count, non-numeric inode)
/// are skipped rather than failing the whole parse; only an I/O error
/// aborts. `family` is stored verbatim on each row.
fn parse_proc_net_tcp(path: &str, family: u32) -> io::Result<Vec<ConnStat>> {
    let file = std::fs::File::open(path)?;
    let reader = io::BufReader::new(file);
    let mut conns = Vec::new();
    let mut first = true;

    for line in reader.lines() {
        let line = line?;
        // The first line is the column header row.
        if first {
            first = false;
            continue;
        }
        let line = line.trim().to_string();
        if line.is_empty() {
            continue;
        }

        let fields: Vec<&str> = line.split_whitespace().collect();
        if fields.len() < 10 {
            continue;
        }

        // fields[1] = "local_address:port" in kernel hex notation.
        let (ip, port) = match parse_hex_addr(fields[1], family) {
            Some(v) => v,
            None => continue,
        };

        // fields[3] = connection state as a two-digit hex code.
        let state = tcp_state_name(fields[3]);

        // fields[9] = socket inode.
        let inode: u64 = match fields[9].parse() {
            Ok(v) => v,
            Err(_) => continue,
        };

        conns.push(ConnStat {
            local_ip: ip,
            local_port: port,
            status: state.to_string(),
            family,
            inode,
        });
    }

    Ok(conns)
}
|
||||
|
||||
fn parse_hex_addr(s: &str, family: u32) -> Option<(String, u32)> {
|
||||
let (ip_hex, port_hex) = s.split_once(':')?;
|
||||
let port = u32::from_str_radix(port_hex, 16).ok()?;
|
||||
let ip_bytes = hex::decode(ip_hex).ok()?;
|
||||
|
||||
let ip_str = if family == libc::AF_INET as u32 {
|
||||
if ip_bytes.len() != 4 {
|
||||
return None;
|
||||
}
|
||||
format!("{}.{}.{}.{}", ip_bytes[3], ip_bytes[2], ip_bytes[1], ip_bytes[0])
|
||||
} else {
|
||||
if ip_bytes.len() != 16 {
|
||||
return None;
|
||||
}
|
||||
let mut octets = [0u8; 16];
|
||||
for i in 0..4 {
|
||||
octets[i * 4] = ip_bytes[i * 4 + 3];
|
||||
octets[i * 4 + 1] = ip_bytes[i * 4 + 2];
|
||||
octets[i * 4 + 2] = ip_bytes[i * 4 + 1];
|
||||
octets[i * 4 + 3] = ip_bytes[i * 4];
|
||||
}
|
||||
let addr = std::net::Ipv6Addr::from(octets);
|
||||
addr.to_string()
|
||||
};
|
||||
|
||||
Some((ip_str, port))
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    //! Unit tests for the /proc/net/tcp parsing helpers: `tcp_state_name`,
    //! `parse_hex_addr` and `parse_proc_net_tcp` (the latter exercised
    //! against temp files rather than the live /proc).
    use super::*;
    use std::io::Write;

    // tcp_state_name

    #[test]
    fn state_all_known_codes() {
        // Kernel TCP state codes 01..0B map to their canonical names.
        assert_eq!(tcp_state_name("01"), "ESTABLISHED");
        assert_eq!(tcp_state_name("02"), "SYN_SENT");
        assert_eq!(tcp_state_name("03"), "SYN_RECV");
        assert_eq!(tcp_state_name("04"), "FIN_WAIT1");
        assert_eq!(tcp_state_name("05"), "FIN_WAIT2");
        assert_eq!(tcp_state_name("06"), "TIME_WAIT");
        assert_eq!(tcp_state_name("07"), "CLOSE");
        assert_eq!(tcp_state_name("08"), "CLOSE_WAIT");
        assert_eq!(tcp_state_name("09"), "LAST_ACK");
        assert_eq!(tcp_state_name("0A"), "LISTEN");
        assert_eq!(tcp_state_name("0B"), "CLOSING");
    }

    #[test]
    fn state_unknown_code() {
        // Anything outside the known table falls back to "UNKNOWN".
        assert_eq!(tcp_state_name("FF"), "UNKNOWN");
        assert_eq!(tcp_state_name("00"), "UNKNOWN");
    }

    // parse_hex_addr

    #[test]
    fn ipv4_localhost() {
        let (ip, port) = parse_hex_addr("0100007F:0050", libc::AF_INET as u32).unwrap();
        assert_eq!(ip, "127.0.0.1");
        assert_eq!(port, 80);
    }

    #[test]
    fn ipv4_any() {
        let (ip, port) = parse_hex_addr("00000000:0035", libc::AF_INET as u32).unwrap();
        assert_eq!(ip, "0.0.0.0");
        assert_eq!(port, 53);
    }

    #[test]
    fn ipv4_real_address() {
        // 192.168.1.1 in little-endian = 0101A8C0
        let (ip, port) = parse_hex_addr("0101A8C0:01BB", libc::AF_INET as u32).unwrap();
        assert_eq!(ip, "192.168.1.1");
        assert_eq!(port, 443);
    }

    #[test]
    fn ipv4_wrong_byte_count_returns_none() {
        // Only 2 address bytes: AF_INET requires exactly 4.
        assert!(parse_hex_addr("0100:0050", libc::AF_INET as u32).is_none());
    }

    #[test]
    fn invalid_hex_returns_none() {
        assert!(parse_hex_addr("ZZZZZZZZ:0050", libc::AF_INET as u32).is_none());
    }

    #[test]
    fn no_colon_returns_none() {
        assert!(parse_hex_addr("0100007F0050", libc::AF_INET as u32).is_none());
    }

    #[test]
    fn ipv6_loopback() {
        // ::1 in /proc/net/tcp6 format: 00000000000000000000000001000000
        let (ip, port) = parse_hex_addr(
            "00000000000000000000000001000000:0035",
            libc::AF_INET6 as u32,
        )
        .unwrap();
        assert_eq!(ip, "::1");
        assert_eq!(port, 53);
    }

    #[test]
    fn ipv6_wrong_byte_count_returns_none() {
        // 4-byte address with AF_INET6: requires exactly 16 bytes.
        assert!(parse_hex_addr("0100007F:0050", libc::AF_INET6 as u32).is_none());
    }

    // parse_proc_net_tcp

    // Write `content` to a temp file that stands in for /proc/net/tcp.
    fn write_tcp_file(content: &str) -> tempfile::NamedTempFile {
        let mut f = tempfile::NamedTempFile::new().unwrap();
        f.write_all(content.as_bytes()).unwrap();
        f.flush().unwrap();
        f
    }

    #[test]
    fn parse_empty_file() {
        // Header only — the first line is always skipped.
        let f = write_tcp_file(
            "  sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode\n",
        );
        let conns = parse_proc_net_tcp(f.path().to_str().unwrap(), libc::AF_INET as u32).unwrap();
        assert!(conns.is_empty());
    }

    #[test]
    fn parse_single_entry() {
        let content = "\
sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode
0: 0100007F:0050 00000000:0000 0A 00000000:00000000 00:00000000 00000000 0 0 12345 1 00000000\n";
        let f = write_tcp_file(content);
        let conns = parse_proc_net_tcp(f.path().to_str().unwrap(), libc::AF_INET as u32).unwrap();
        assert_eq!(conns.len(), 1);
        assert_eq!(conns[0].local_ip, "127.0.0.1");
        assert_eq!(conns[0].local_port, 80);
        assert_eq!(conns[0].status, "LISTEN");
        assert_eq!(conns[0].inode, 12345);
        assert_eq!(conns[0].family, libc::AF_INET as u32);
    }

    #[test]
    fn parse_skips_malformed_rows() {
        // "bad line" and "1: short" have < 10 columns and are dropped.
        let content = "\
sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode
0: 0100007F:0050 00000000:0000 0A 00000000:00000000 00:00000000 00000000 0 0 12345 1 00000000
bad line
1: short\n";
        let f = write_tcp_file(content);
        let conns = parse_proc_net_tcp(f.path().to_str().unwrap(), libc::AF_INET as u32).unwrap();
        assert_eq!(conns.len(), 1);
    }

    #[test]
    fn parse_multiple_entries() {
        let content = "\
sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode
0: 0100007F:0050 00000000:0000 0A 00000000:00000000 00:00000000 00000000 0 0 100 1 00000000
1: 00000000:01BB 00000000:0000 0A 00000000:00000000 00:00000000 00000000 0 0 200 1 00000000\n";
        let f = write_tcp_file(content);
        let conns = parse_proc_net_tcp(f.path().to_str().unwrap(), libc::AF_INET as u32).unwrap();
        assert_eq!(conns.len(), 2);
        assert_eq!(conns[0].local_port, 80);
        assert_eq!(conns[1].local_port, 443);
    }

    #[test]
    fn parse_nonexistent_file_errors() {
        // Open failures are propagated, not swallowed.
        assert!(parse_proc_net_tcp("/nonexistent/path", libc::AF_INET as u32).is_err());
    }
}
|
||||
181
envd-rs/src/port/forwarder.rs
Normal file
181
envd-rs/src/port/forwarder.rs
Normal file
@ -0,0 +1,181 @@
|
||||
use std::collections::HashMap;
|
||||
use std::os::unix::process::CommandExt;
|
||||
use std::process::Command;
|
||||
use std::sync::Arc;
|
||||
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use crate::cgroups::{CgroupManager, ProcessType};
|
||||
|
||||
use super::conn::ConnStat;
|
||||
|
||||
// Link-local address socat binds its listen side to.
// NOTE(review): presumably the sandbox gateway address reachable from the
// host side — confirm against the network setup.
const DEFAULT_GATEWAY_IP: &str = "169.254.0.21";
|
||||
|
||||
/// Mark-and-sweep flag used by `Forwarder::process_scan`: every entry is
/// first marked `Delete`, then entries still present in the latest scan are
/// flipped back to `Forward`; leftovers are torn down.
#[derive(PartialEq)]
enum PortState {
    Forward,
    Delete,
}
|
||||
|
||||
/// Book-keeping for one forwarded local port.
struct PortToForward {
    pid: Option<u32>, // socat child pid once spawned; None if spawn failed
    inode: u64,       // socket inode from /proc/net/tcp*, part of the map key
    family: u32,      // IP version (4 or 6), from family_to_ip_version
    state: PortState, // mark-and-sweep flag refreshed on every scan
    port: u32,        // local TCP port being forwarded
}
|
||||
|
||||
fn family_to_ip_version(family: u32) -> u32 {
|
||||
if family == libc::AF_INET as u32 {
|
||||
4
|
||||
} else if family == libc::AF_INET6 as u32 {
|
||||
6
|
||||
} else {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
/// Forwards newly-detected localhost TCP listeners out to the gateway IP by
/// spawning one `socat` relay process per port.
pub struct Forwarder {
    cgroup_manager: Arc<dyn CgroupManager>, // supplies the cgroup fd socat children are placed into
    ports: HashMap<String, PortToForward>,  // keyed by "<inode>-<port>"
    source_ip: String,                      // bind address for the listen side (gateway IP)
}
|
||||
|
||||
impl Forwarder {
    /// Create a forwarder with no active ports, listening on the default
    /// gateway IP.
    pub fn new(cgroup_manager: Arc<dyn CgroupManager>) -> Self {
        Self {
            cgroup_manager,
            ports: HashMap::new(),
            source_ip: DEFAULT_GATEWAY_IP.to_string(),
        }
    }

    /// Consume scan results from `rx` until cancelled or the sender side
    /// closes; either way all socat children are killed before returning.
    pub async fn start_forwarding(
        &mut self,
        mut rx: mpsc::Receiver<Vec<ConnStat>>,
        cancel: CancellationToken,
    ) {
        loop {
            tokio::select! {
                _ = cancel.cancelled() => {
                    self.stop_all();
                    return;
                }
                msg = rx.recv() => {
                    match msg {
                        Some(conns) => self.process_scan(conns),
                        None => {
                            // Channel closed: scanner is gone, shut down.
                            self.stop_all();
                            return;
                        }
                    }
                }
            }
        }
    }

    /// Reconcile the forwarded-port map against one scan snapshot
    /// (mark-and-sweep): start socat for ports not seen before, keep ports
    /// still listening, and kill forwarders whose port disappeared.
    fn process_scan(&mut self, conns: Vec<ConnStat>) {
        // Mark phase: assume everything vanished until proven otherwise.
        for ptf in self.ports.values_mut() {
            ptf.state = PortState::Delete;
        }

        for conn in &conns {
            // inode+port keys the entry, so a port re-bound by a new socket
            // (new inode) is treated as a new forwarding target.
            let key = format!("{}-{}", conn.inode, conn.local_port);
            if let Some(ptf) = self.ports.get_mut(&key) {
                ptf.state = PortState::Forward;
            } else {
                tracing::debug!(
                    ip = %conn.local_ip,
                    port = conn.local_port,
                    family = family_to_ip_version(conn.family),
                    "detected new port on localhost"
                );
                let mut ptf = PortToForward {
                    pid: None,
                    inode: conn.inode,
                    family: family_to_ip_version(conn.family),
                    state: PortState::Forward,
                    port: conn.local_port,
                };
                self.start_port_forwarding(&mut ptf);
                self.ports.insert(key, ptf);
            }
        }

        // Sweep phase: collect keys first to avoid mutating while iterating.
        let to_stop: Vec<String> = self
            .ports
            .iter()
            .filter(|(_, v)| v.state == PortState::Delete)
            .map(|(k, _)| k.clone())
            .collect();

        for key in to_stop {
            if let Some(ptf) = self.ports.get(&key) {
                stop_port_forwarding(ptf);
            }
            self.ports.remove(&key);
        }
    }

    /// Spawn a socat relay: listen on the gateway IP (always TCP4 on the
    /// listen side) and connect to localhost on the detected port. On
    /// success records the child pid in `ptf`; spawn failures are logged
    /// and leave `ptf.pid` as None.
    fn start_port_forwarding(&self, ptf: &mut PortToForward) {
        let listen_arg = format!(
            "TCP4-LISTEN:{},bind={},reuseaddr,fork",
            ptf.port, self.source_ip
        );
        // Connect side uses TCP4/TCP6 matching the detected socket family.
        let connect_arg = format!("TCP{}:localhost:{}", ptf.family, ptf.port);

        let mut cmd = Command::new("socat");
        cmd.args(["-d", "-d", "-d", &listen_arg, &connect_arg]);

        // pre_exec runs in the forked child before exec: make the child its
        // own process-group leader (so the whole group can be SIGKILLed
        // later) and move it into the socat cgroup via the held dir fd.
        unsafe {
            let cgroup_fd = self.cgroup_manager.get_fd(ProcessType::Socat);
            cmd.pre_exec(move || {
                libc::setpgid(0, 0);
                if let Some(fd) = cgroup_fd {
                    let pid_str = format!("{}", libc::getpid());
                    let tasks_path = format!("/proc/self/fd/{}/cgroup.procs", fd);
                    // Best-effort: forwarding still works outside the cgroup.
                    let _ = std::fs::write(&tasks_path, pid_str.as_bytes());
                }
                Ok(())
            });
        }

        tracing::debug!(
            port = ptf.port,
            inode = ptf.inode,
            family = ptf.family,
            source_ip = %self.source_ip,
            "starting port forwarding"
        );

        match cmd.spawn() {
            Ok(child) => {
                ptf.pid = Some(child.id());
                // Detached reaper thread so the socat child never zombies.
                std::thread::spawn(move || {
                    let mut child = child;
                    let _ = child.wait();
                });
            }
            Err(e) => {
                tracing::error!(error = %e, port = ptf.port, "failed to start socat");
            }
        }
    }

    /// Kill every active socat forwarder and clear the map.
    fn stop_all(&mut self) {
        for ptf in self.ports.values() {
            stop_port_forwarding(ptf);
        }
        self.ports.clear();
    }
}
|
||||
|
||||
/// SIGKILL the socat process group for this port. The negative pid targets
/// the whole group created by the child's setpgid(0, 0), taking the
/// fork-per-connection children down with it. No-op if spawn never
/// succeeded (pid is None).
fn stop_port_forwarding(ptf: &PortToForward) {
    if let Some(pid) = ptf.pid {
        tracing::debug!(port = ptf.port, pid, "stopping port forwarding");
        unsafe {
            libc::kill(-(pid as i32), libc::SIGKILL);
        }
    }
}
|
||||
4
envd-rs/src/port/mod.rs
Normal file
4
envd-rs/src/port/mod.rs
Normal file
@ -0,0 +1,4 @@
|
||||
pub mod conn;
|
||||
pub mod forwarder;
|
||||
pub mod scanner;
|
||||
pub mod subsystem;
|
||||
79
envd-rs/src/port/scanner.rs
Normal file
79
envd-rs/src/port/scanner.rs
Normal file
@ -0,0 +1,79 @@
|
||||
use std::sync::{Arc, RwLock};
|
||||
use std::time::Duration;
|
||||
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use super::conn::{ConnStat, read_tcp_connections};
|
||||
|
||||
/// Per-subscriber filter over scanned connections.
pub struct ScannerFilter {
    pub ips: Vec<String>, // accepted local IPs (exact string match)
    pub state: String,    // accepted TCP state name, e.g. "LISTEN"
}
|
||||
|
||||
impl ScannerFilter {
|
||||
pub fn matches(&self, conn: &ConnStat) -> bool {
|
||||
if self.state.is_empty() && self.ips.is_empty() {
|
||||
return false;
|
||||
}
|
||||
self.ips.contains(&conn.local_ip) && self.state == conn.status
|
||||
}
|
||||
}
|
||||
|
||||
/// One registered consumer of scan results.
pub struct ScannerSubscriber {
    pub tx: mpsc::Sender<Vec<ConnStat>>, // bounded channel; full sends are dropped
    pub filter: Option<ScannerFilter>,   // None = receive every connection
}
|
||||
|
||||
/// Periodically reads the TCP connection tables and fans the results out to
/// registered subscribers.
pub struct Scanner {
    period: Duration,                                   // delay between scans
    subs: RwLock<Vec<(String, Arc<ScannerSubscriber>)>>, // (id, subscriber) pairs
}
|
||||
|
||||
impl Scanner {
    /// Create a scanner that rescans every `period`.
    pub fn new(period: Duration) -> Self {
        Self {
            period,
            subs: RwLock::new(Vec::new()),
        }
    }

    /// Register a subscriber under `id` and return its receiving end.
    /// The channel is bounded (capacity 4); duplicate ids are not checked.
    pub fn add_subscriber(
        &self,
        id: &str,
        filter: Option<ScannerFilter>,
    ) -> mpsc::Receiver<Vec<ConnStat>> {
        let (tx, rx) = mpsc::channel(4);
        let sub = Arc::new(ScannerSubscriber { tx, filter });
        let mut subs = self.subs.write().unwrap();
        subs.push((id.to_string(), sub));
        rx
    }

    /// Remove every subscriber registered under `id`.
    pub fn remove_subscriber(&self, id: &str) {
        let mut subs = self.subs.write().unwrap();
        subs.retain(|(sid, _)| sid != id);
    }

    /// Scan/broadcast loop: read the TCP tables, push a (filtered) snapshot
    /// to each subscriber, then sleep for `period` or exit on cancellation.
    /// try_send never blocks, so the read lock is held only briefly; scans a
    /// slow consumer cannot accept are dropped.
    pub async fn scan_and_broadcast(&self, cancel: CancellationToken) {
        loop {
            let conns = read_tcp_connections();

            {
                let subs = self.subs.read().unwrap();
                for (_, sub) in subs.iter() {
                    let payload = match &sub.filter {
                        Some(f) => conns.iter().filter(|c| f.matches(c)).cloned().collect(),
                        None => conns.clone(),
                    };
                    // Drop the payload if the subscriber's channel is full.
                    let _ = sub.tx.try_send(payload);
                }
            }

            tokio::select! {
                _ = cancel.cancelled() => return,
                _ = tokio::time::sleep(self.period) => {}
            }
        }
    }
}
|
||||
78
envd-rs/src/port/subsystem.rs
Normal file
78
envd-rs/src/port/subsystem.rs
Normal file
@ -0,0 +1,78 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use crate::cgroups::CgroupManager;
|
||||
use crate::config::PORT_SCANNER_INTERVAL;
|
||||
|
||||
use super::forwarder::Forwarder;
|
||||
use super::scanner::{Scanner, ScannerFilter};
|
||||
|
||||
/// Owns the port scanner + forwarder pair as one start/stop unit.
pub struct PortSubsystem {
    cgroup_manager: Arc<dyn CgroupManager>,              // handed to the Forwarder for socat children
    cancel: std::sync::Mutex<Option<CancellationToken>>, // Some(token) while running
}
|
||||
|
||||
impl PortSubsystem {
    /// Create a stopped subsystem.
    pub fn new(cgroup_manager: Arc<dyn CgroupManager>) -> Self {
        Self {
            cgroup_manager,
            cancel: std::sync::Mutex::new(None),
        }
    }

    /// Spawn the scanner and forwarder tasks. Idempotent: a second call
    /// while running is a no-op (guarded by the stored token).
    pub fn start(&self) {
        let mut guard = self.cancel.lock().unwrap();
        if guard.is_some() {
            return;
        }

        let cancel = CancellationToken::new();
        *guard = Some(cancel.clone());
        // Release the mutex before spawning so stop() can't deadlock.
        drop(guard);

        let cgroup_manager = Arc::clone(&self.cgroup_manager);
        let cancel_scanner = cancel.clone();
        let cancel_forwarder = cancel.clone();

        tokio::spawn(async move {
            let scanner = Arc::new(Scanner::new(PORT_SCANNER_INTERVAL));
            // The forwarder only cares about localhost listeners.
            // NOTE(review): "localhost" as a literal IP string never matches
            // the numeric addresses the scanner produces — confirm intent.
            let rx = scanner.add_subscriber(
                "port-forwarder",
                Some(ScannerFilter {
                    ips: vec![
                        "127.0.0.1".to_string(),
                        "localhost".to_string(),
                        "::1".to_string(),
                    ],
                    state: "LISTEN".to_string(),
                }),
            );

            let scanner_clone = Arc::clone(&scanner);

            let scanner_handle = tokio::spawn(async move {
                scanner_clone.scan_and_broadcast(cancel_scanner).await;
            });

            let forwarder_handle = tokio::spawn(async move {
                let mut forwarder = Forwarder::new(cgroup_manager);
                forwarder.start_forwarding(rx, cancel_forwarder).await;
            });

            let _ = tokio::join!(scanner_handle, forwarder_handle);
        });
    }

    /// Cancel the running tasks (if any). The token is taken, so a later
    /// start() creates a fresh one.
    pub fn stop(&self) {
        let mut guard = self.cancel.lock().unwrap();
        if let Some(cancel) = guard.take() {
            cancel.cancel();
        }
    }

    /// Stop then start. NOTE(review): start() may spawn before the old
    /// tasks have fully wound down — confirm overlap is acceptable.
    pub fn restart(&self) {
        self.stop();
        self.start();
    }
}
|
||||
231
envd-rs/src/rpc/entry.rs
Normal file
231
envd-rs/src/rpc/entry.rs
Normal file
@ -0,0 +1,231 @@
|
||||
use std::os::unix::fs::MetadataExt;
|
||||
use std::path::Path;
|
||||
|
||||
use connectrpc::{ConnectError, ErrorCode};
|
||||
|
||||
use crate::permissions::user::{lookup_groupname_by_gid, lookup_username_by_uid};
|
||||
use crate::rpc::pb::filesystem::{EntryInfo, FileType};
|
||||
use nix::unistd::{Gid, Uid};
|
||||
|
||||
// statfs(2) f_type magic numbers (see linux/magic.h) for filesystems
// treated as network mounts. FUSE is included wholesale — NOTE(review):
// this also matches purely local FUSE mounts.
const NFS_SUPER_MAGIC: i64 = 0x6969;
const CIFS_MAGIC: i64 = 0xFF534D42;
const SMB_SUPER_MAGIC: i64 = 0x517B;
const SMB2_MAGIC_NUMBER: i64 = 0xFE534D42;
const FUSE_SUPER_MAGIC: i64 = 0x65735546;
|
||||
/// Return whether `path` lives on a network (or FUSE) filesystem, by
/// matching the statfs(2) f_type against the known magic numbers above.
///
/// Errors if `path` contains an interior NUL byte or if statfs fails
/// (e.g. the path does not exist).
pub fn is_network_mount(path: &str) -> Result<bool, String> {
    let c_path = std::ffi::CString::new(path).map_err(|e| e.to_string())?;
    // Zeroed struct is filled in by the statfs call below.
    let mut stat: libc::statfs = unsafe { std::mem::zeroed() };
    let ret = unsafe { libc::statfs(c_path.as_ptr(), &mut stat) };
    if ret != 0 {
        return Err(format!(
            "statfs {path}: {}",
            std::io::Error::last_os_error()
        ));
    }
    // f_type's width varies by platform; normalize to i64 for matching.
    let fs_type = stat.f_type as i64;
    Ok(matches!(
        fs_type,
        NFS_SUPER_MAGIC | CIFS_MAGIC | SMB_SUPER_MAGIC | SMB2_MAGIC_NUMBER | FUSE_SUPER_MAGIC
    ))
}
|
||||
|
||||
pub fn build_entry_info(path: &str) -> Result<EntryInfo, ConnectError> {
|
||||
let p = Path::new(path);
|
||||
|
||||
let lstat = std::fs::symlink_metadata(p).map_err(|e| {
|
||||
if e.kind() == std::io::ErrorKind::NotFound {
|
||||
ConnectError::new(ErrorCode::NotFound, format!("file not found: {e}"))
|
||||
} else {
|
||||
ConnectError::new(ErrorCode::Internal, format!("error getting file info: {e}"))
|
||||
}
|
||||
})?;
|
||||
|
||||
let is_symlink = lstat.file_type().is_symlink();
|
||||
|
||||
let (file_type, mode, symlink_target) = if is_symlink {
|
||||
let target = std::fs::canonicalize(p)
|
||||
.map(|t| t.to_string_lossy().to_string())
|
||||
.unwrap_or_else(|_| path.to_string());
|
||||
|
||||
let target_type = match std::fs::metadata(p) {
|
||||
Ok(meta) => meta_to_file_type(&meta),
|
||||
Err(_) => FileType::FILE_TYPE_UNSPECIFIED,
|
||||
};
|
||||
|
||||
let target_mode = std::fs::metadata(p)
|
||||
.map(|m| m.mode() & 0o7777)
|
||||
.unwrap_or(0);
|
||||
|
||||
(target_type, target_mode, Some(target))
|
||||
} else {
|
||||
let ft = meta_to_file_type(&lstat);
|
||||
let mode = lstat.mode() & 0o7777;
|
||||
(ft, mode, None)
|
||||
};
|
||||
|
||||
let uid = lstat.uid();
|
||||
let gid = lstat.gid();
|
||||
let owner = lookup_username_by_uid(Uid::from_raw(uid));
|
||||
let group = lookup_groupname_by_gid(Gid::from_raw(gid));
|
||||
|
||||
let modified_time = {
|
||||
let mtime_sec = lstat.mtime();
|
||||
let mtime_nsec = lstat.mtime_nsec() as i32;
|
||||
if mtime_sec == 0 && mtime_nsec == 0 {
|
||||
None
|
||||
} else {
|
||||
Some(buffa_types::google::protobuf::Timestamp {
|
||||
seconds: mtime_sec,
|
||||
nanos: mtime_nsec,
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
let name = p
|
||||
.file_name()
|
||||
.map(|n| n.to_string_lossy().to_string())
|
||||
.unwrap_or_default();
|
||||
|
||||
let permissions = format_permissions(lstat.mode());
|
||||
|
||||
Ok(EntryInfo {
|
||||
name,
|
||||
r#type: buffa::EnumValue::Known(file_type),
|
||||
path: path.to_string(),
|
||||
size: lstat.len() as i64,
|
||||
mode,
|
||||
permissions,
|
||||
owner,
|
||||
group,
|
||||
modified_time: modified_time.into(),
|
||||
symlink_target: symlink_target,
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
fn meta_to_file_type(meta: &std::fs::Metadata) -> FileType {
|
||||
if meta.is_file() {
|
||||
FileType::FILE_TYPE_FILE
|
||||
} else if meta.is_dir() {
|
||||
FileType::FILE_TYPE_DIRECTORY
|
||||
} else if meta.file_type().is_symlink() {
|
||||
FileType::FILE_TYPE_SYMLINK
|
||||
} else {
|
||||
FileType::FILE_TYPE_UNSPECIFIED
|
||||
}
|
||||
}
|
||||
|
||||
fn format_permissions(mode: u32) -> String {
|
||||
let file_type = match mode & libc::S_IFMT {
|
||||
libc::S_IFDIR => 'd',
|
||||
libc::S_IFLNK => 'L',
|
||||
libc::S_IFREG => '-',
|
||||
libc::S_IFBLK => 'b',
|
||||
libc::S_IFCHR => 'c',
|
||||
libc::S_IFIFO => 'p',
|
||||
libc::S_IFSOCK => 'S',
|
||||
_ => '?',
|
||||
};
|
||||
|
||||
let perms = mode & 0o777;
|
||||
let mut s = String::with_capacity(10);
|
||||
s.push(file_type);
|
||||
for shift in [6, 3, 0] {
|
||||
let bits = (perms >> shift) & 7;
|
||||
s.push(if bits & 4 != 0 { 'r' } else { '-' });
|
||||
s.push(if bits & 2 != 0 { 'w' } else { '-' });
|
||||
s.push(if bits & 1 != 0 { 'x' } else { '-' });
|
||||
}
|
||||
s
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    //! Unit tests for `format_permissions` (pure) and `meta_to_file_type`
    //! (exercised against real temp files/dirs).
    use super::*;

    // format_permissions

    #[test]
    fn regular_file_755() {
        assert_eq!(format_permissions(libc::S_IFREG | 0o755), "-rwxr-xr-x");
    }

    #[test]
    fn directory_755() {
        assert_eq!(format_permissions(libc::S_IFDIR | 0o755), "drwxr-xr-x");
    }

    #[test]
    fn symlink_777() {
        assert_eq!(format_permissions(libc::S_IFLNK | 0o777), "Lrwxrwxrwx");
    }

    #[test]
    fn regular_file_000() {
        assert_eq!(format_permissions(libc::S_IFREG | 0o000), "----------");
    }

    #[test]
    fn regular_file_644() {
        assert_eq!(format_permissions(libc::S_IFREG | 0o644), "-rw-r--r--");
    }

    #[test]
    fn block_device() {
        assert_eq!(format_permissions(libc::S_IFBLK | 0o660), "brw-rw----");
    }

    #[test]
    fn char_device() {
        assert_eq!(format_permissions(libc::S_IFCHR | 0o666), "crw-rw-rw-");
    }

    #[test]
    fn fifo() {
        assert_eq!(format_permissions(libc::S_IFIFO | 0o644), "prw-r--r--");
    }

    #[test]
    fn socket() {
        assert_eq!(format_permissions(libc::S_IFSOCK | 0o755), "Srwxr-xr-x");
    }

    #[test]
    fn unknown_type() {
        // A mode with no S_IFMT bits set renders '?' as the type char.
        assert_eq!(format_permissions(0o755), "?rwxr-xr-x");
    }

    #[test]
    fn setuid_in_mode_only_affects_lower_bits() {
        // setuid (0o4755) — format_permissions masks with 0o777, so same as 0o755
        assert_eq!(
            format_permissions(libc::S_IFREG | 0o4755),
            format_permissions(libc::S_IFREG | 0o755),
        );
    }

    #[test]
    fn output_always_10_chars() {
        // Type char + three rwx triads = a fixed-width 10-char string.
        for mode in [0o000, 0o777, 0o644, 0o755, 0o4755] {
            assert_eq!(format_permissions(libc::S_IFREG | mode).len(), 10);
        }
    }

    // meta_to_file_type — needs real filesystem

    #[test]
    fn meta_regular_file() {
        let f = tempfile::NamedTempFile::new().unwrap();
        let meta = std::fs::metadata(f.path()).unwrap();
        assert_eq!(meta_to_file_type(&meta), FileType::FILE_TYPE_FILE);
    }

    #[test]
    fn meta_directory() {
        let d = tempfile::TempDir::new().unwrap();
        let meta = std::fs::metadata(d.path()).unwrap();
        assert_eq!(meta_to_file_type(&meta), FileType::FILE_TYPE_DIRECTORY);
    }
}
|
||||
402
envd-rs/src/rpc/filesystem_service.rs
Normal file
402
envd-rs/src/rpc/filesystem_service.rs
Normal file
@ -0,0 +1,402 @@
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::pin::Pin;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use connectrpc::{ConnectError, Context, ErrorCode};
|
||||
use dashmap::DashMap;
|
||||
use futures::Stream;
|
||||
|
||||
use crate::permissions::path::{ensure_dirs, expand_and_resolve};
|
||||
use crate::permissions::user::lookup_user;
|
||||
use crate::rpc::entry::build_entry_info;
|
||||
use crate::rpc::pb::filesystem::*;
|
||||
use crate::state::AppState;
|
||||
|
||||
/// Connect-RPC filesystem service: stat/list/mkdir/move/remove plus
/// poll-based directory watchers.
pub struct FilesystemServiceImpl {
    state: Arc<AppState>,                  // defaults (user, workdir) for path resolution
    watchers: DashMap<String, WatcherHandle>, // active watchers keyed by generated id
}
|
||||
|
||||
/// One active directory watcher.
struct WatcherHandle {
    events: Arc<Mutex<Vec<FilesystemEvent>>>, // buffer drained by get_watcher_events
    _watcher: notify::RecommendedWatcher,     // kept alive only for its Drop (stops watching)
}
|
||||
|
||||
impl FilesystemServiceImpl {
    /// Create the service with no active watchers.
    pub fn new(state: Arc<AppState>) -> Self {
        Self {
            state,
            watchers: DashMap::new(),
        }
    }

    /// Expand and absolutize a request path for the request's user: the
    /// username comes from the context (falling back to the configured
    /// default), and its home directory anchors `~` expansion.
    /// Errors: Unauthenticated for an unknown user, InvalidArgument for a
    /// path that cannot be resolved.
    fn resolve_path(&self, path: &str, ctx: &Context) -> Result<String, ConnectError> {
        let username = extract_username(ctx).unwrap_or_else(|| self.state.defaults.user());
        let user = lookup_user(&username).map_err(|e| {
            ConnectError::new(ErrorCode::Unauthenticated, format!("invalid user: {e}"))
        })?;

        let home_dir = user.dir.to_string_lossy().to_string();
        let default_workdir = self.state.defaults.workdir();

        expand_and_resolve(path, &home_dir, default_workdir.as_deref())
            .map_err(|e| ConnectError::new(ErrorCode::InvalidArgument, e))
    }
}
|
||||
|
||||
/// Pull the authenticated username out of the request context extensions,
/// if middleware placed one there.
fn extract_username(ctx: &Context) -> Option<String> {
    ctx.extensions.get::<AuthUser>().map(|u| u.0.clone())
}
|
||||
|
||||
/// Newtype carried in context extensions holding the authenticated username.
#[derive(Clone)]
pub struct AuthUser(pub String);
|
||||
|
||||
impl Filesystem for FilesystemServiceImpl {
|
||||
async fn stat(
|
||||
&self,
|
||||
ctx: Context,
|
||||
request: buffa::view::OwnedView<StatRequestView<'static>>,
|
||||
) -> Result<(StatResponse, Context), ConnectError> {
|
||||
let path = self.resolve_path(request.path, &ctx)?;
|
||||
let entry = build_entry_info(&path)?;
|
||||
Ok((
|
||||
StatResponse {
|
||||
entry: entry.into(),
|
||||
..Default::default()
|
||||
},
|
||||
ctx,
|
||||
))
|
||||
}
|
||||
|
||||
async fn make_dir(
|
||||
&self,
|
||||
ctx: Context,
|
||||
request: buffa::view::OwnedView<MakeDirRequestView<'static>>,
|
||||
) -> Result<(MakeDirResponse, Context), ConnectError> {
|
||||
let path = self.resolve_path(request.path, &ctx)?;
|
||||
|
||||
match std::fs::metadata(&path) {
|
||||
Ok(meta) => {
|
||||
if meta.is_dir() {
|
||||
return Err(ConnectError::new(
|
||||
ErrorCode::AlreadyExists,
|
||||
format!("directory already exists: {path}"),
|
||||
));
|
||||
}
|
||||
return Err(ConnectError::new(
|
||||
ErrorCode::InvalidArgument,
|
||||
format!("path exists but is not a directory: {path}"),
|
||||
));
|
||||
}
|
||||
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
|
||||
Err(e) => {
|
||||
return Err(ConnectError::new(
|
||||
ErrorCode::Internal,
|
||||
format!("error getting file info: {e}"),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let username = extract_username(&ctx).unwrap_or_else(|| self.state.defaults.user());
|
||||
let user =
|
||||
lookup_user(&username).map_err(|e| ConnectError::new(ErrorCode::Internal, e))?;
|
||||
|
||||
ensure_dirs(&path, user.uid, user.gid)
|
||||
.map_err(|e| ConnectError::new(ErrorCode::Internal, e))?;
|
||||
|
||||
let entry = build_entry_info(&path)?;
|
||||
Ok((
|
||||
MakeDirResponse {
|
||||
entry: entry.into(),
|
||||
..Default::default()
|
||||
},
|
||||
ctx,
|
||||
))
|
||||
}
|
||||
|
||||
async fn r#move(
|
||||
&self,
|
||||
ctx: Context,
|
||||
request: buffa::view::OwnedView<MoveRequestView<'static>>,
|
||||
) -> Result<(MoveResponse, Context), ConnectError> {
|
||||
let source = self.resolve_path(request.source, &ctx)?;
|
||||
let destination = self.resolve_path(request.destination, &ctx)?;
|
||||
|
||||
let username = extract_username(&ctx).unwrap_or_else(|| self.state.defaults.user());
|
||||
let user =
|
||||
lookup_user(&username).map_err(|e| ConnectError::new(ErrorCode::Internal, e))?;
|
||||
|
||||
if let Some(parent) = Path::new(&destination).parent() {
|
||||
ensure_dirs(&parent.to_string_lossy(), user.uid, user.gid)
|
||||
.map_err(|e| ConnectError::new(ErrorCode::Internal, e))?;
|
||||
}
|
||||
|
||||
std::fs::rename(&source, &destination).map_err(|e| {
|
||||
if e.kind() == std::io::ErrorKind::NotFound {
|
||||
ConnectError::new(ErrorCode::NotFound, format!("source not found: {e}"))
|
||||
} else {
|
||||
ConnectError::new(ErrorCode::Internal, format!("error renaming: {e}"))
|
||||
}
|
||||
})?;
|
||||
|
||||
let entry = build_entry_info(&destination)?;
|
||||
Ok((
|
||||
MoveResponse {
|
||||
entry: entry.into(),
|
||||
..Default::default()
|
||||
},
|
||||
ctx,
|
||||
))
|
||||
}
|
||||
|
||||
async fn list_dir(
|
||||
&self,
|
||||
ctx: Context,
|
||||
request: buffa::view::OwnedView<ListDirRequestView<'static>>,
|
||||
) -> Result<(ListDirResponse, Context), ConnectError> {
|
||||
let mut depth = request.depth as usize;
|
||||
if depth == 0 {
|
||||
depth = 1;
|
||||
}
|
||||
|
||||
let path = self.resolve_path(request.path, &ctx)?;
|
||||
|
||||
let resolved = std::fs::canonicalize(&path).map_err(|e| {
|
||||
if e.kind() == std::io::ErrorKind::NotFound {
|
||||
ConnectError::new(ErrorCode::NotFound, format!("path not found: {e}"))
|
||||
} else {
|
||||
ConnectError::new(ErrorCode::Internal, format!("error resolving path: {e}"))
|
||||
}
|
||||
})?;
|
||||
let resolved_str = resolved.to_string_lossy().to_string();
|
||||
|
||||
let meta = std::fs::metadata(&resolved).map_err(|e| {
|
||||
ConnectError::new(ErrorCode::Internal, format!("error getting file info: {e}"))
|
||||
})?;
|
||||
if !meta.is_dir() {
|
||||
return Err(ConnectError::new(
|
||||
ErrorCode::InvalidArgument,
|
||||
format!("path is not a directory: {path}"),
|
||||
));
|
||||
}
|
||||
|
||||
let entries = walk_dir(&path, &resolved_str, depth)?;
|
||||
Ok((
|
||||
ListDirResponse {
|
||||
entries,
|
||||
..Default::default()
|
||||
},
|
||||
ctx,
|
||||
))
|
||||
}
|
||||
|
||||
async fn remove(
|
||||
&self,
|
||||
ctx: Context,
|
||||
request: buffa::view::OwnedView<RemoveRequestView<'static>>,
|
||||
) -> Result<(RemoveResponse, Context), ConnectError> {
|
||||
let path = self.resolve_path(request.path, &ctx)?;
|
||||
|
||||
if let Err(e1) = std::fs::remove_dir_all(&path) {
|
||||
if let Err(e2) = std::fs::remove_file(&path) {
|
||||
return Err(ConnectError::new(
|
||||
ErrorCode::Internal,
|
||||
format!("error removing: {e1}; also tried as file: {e2}"),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
Ok((RemoveResponse { ..Default::default() }, ctx))
|
||||
}
|
||||
|
||||
async fn watch_dir(
|
||||
&self,
|
||||
_ctx: Context,
|
||||
_request: buffa::view::OwnedView<WatchDirRequestView<'static>>,
|
||||
) -> Result<
|
||||
(
|
||||
Pin<Box<dyn Stream<Item = Result<WatchDirResponse, ConnectError>> + Send>>,
|
||||
Context,
|
||||
),
|
||||
ConnectError,
|
||||
> {
|
||||
Err(ConnectError::new(
|
||||
ErrorCode::Unimplemented,
|
||||
"watch_dir streaming not yet implemented",
|
||||
))
|
||||
}
|
||||
|
||||
async fn create_watcher(
|
||||
&self,
|
||||
ctx: Context,
|
||||
request: buffa::view::OwnedView<CreateWatcherRequestView<'static>>,
|
||||
) -> Result<(CreateWatcherResponse, Context), ConnectError> {
|
||||
use notify::{RecursiveMode, Watcher};
|
||||
|
||||
let path = self.resolve_path(request.path, &ctx)?;
|
||||
let recursive = request.recursive;
|
||||
|
||||
if let Ok(true) = crate::rpc::entry::is_network_mount(&path) {
|
||||
return Err(ConnectError::new(
|
||||
ErrorCode::FailedPrecondition,
|
||||
"watching network mounts is not supported",
|
||||
));
|
||||
}
|
||||
|
||||
let watcher_id = simple_id();
|
||||
let events: Arc<Mutex<Vec<FilesystemEvent>>> = Arc::new(Mutex::new(Vec::new()));
|
||||
let events_cb = Arc::clone(&events);
|
||||
|
||||
let mut watcher = notify::recommended_watcher(
|
||||
move |res: Result<notify::Event, notify::Error>| {
|
||||
if let Ok(event) = res {
|
||||
let event_type = match event.kind {
|
||||
notify::EventKind::Create(_) => EventType::EVENT_TYPE_CREATE,
|
||||
notify::EventKind::Modify(notify::event::ModifyKind::Data(_)) => {
|
||||
EventType::EVENT_TYPE_WRITE
|
||||
}
|
||||
notify::EventKind::Modify(notify::event::ModifyKind::Metadata(_)) => {
|
||||
EventType::EVENT_TYPE_CHMOD
|
||||
}
|
||||
notify::EventKind::Remove(_) => EventType::EVENT_TYPE_REMOVE,
|
||||
notify::EventKind::Modify(notify::event::ModifyKind::Name(_)) => {
|
||||
EventType::EVENT_TYPE_RENAME
|
||||
}
|
||||
_ => return,
|
||||
};
|
||||
|
||||
for p in &event.paths {
|
||||
if let Ok(mut guard) = events_cb.lock() {
|
||||
guard.push(FilesystemEvent {
|
||||
name: p.to_string_lossy().to_string(),
|
||||
r#type: buffa::EnumValue::Known(event_type),
|
||||
..Default::default()
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
.map_err(|e| {
|
||||
ConnectError::new(ErrorCode::Internal, format!("failed to create watcher: {e}"))
|
||||
})?;
|
||||
|
||||
let mode = if recursive {
|
||||
RecursiveMode::Recursive
|
||||
} else {
|
||||
RecursiveMode::NonRecursive
|
||||
};
|
||||
|
||||
watcher.watch(Path::new(&path), mode).map_err(|e| {
|
||||
ConnectError::new(ErrorCode::Internal, format!("failed to watch path: {e}"))
|
||||
})?;
|
||||
|
||||
self.watchers.insert(
|
||||
watcher_id.clone(),
|
||||
WatcherHandle {
|
||||
events,
|
||||
_watcher: watcher,
|
||||
},
|
||||
);
|
||||
|
||||
Ok((
|
||||
CreateWatcherResponse {
|
||||
watcher_id,
|
||||
..Default::default()
|
||||
},
|
||||
ctx,
|
||||
))
|
||||
}
|
||||
|
||||
/// Drain and return all filesystem events buffered for a watcher.
///
/// Events are swapped out atomically, so each event is delivered to the
/// caller at most once; subsequent calls return only newly observed events.
///
/// # Errors
/// `ErrorCode::NotFound` when no watcher with `watcher_id` exists.
async fn get_watcher_events(
    &self,
    ctx: Context,
    request: buffa::view::OwnedView<GetWatcherEventsRequestView<'static>>,
) -> Result<(GetWatcherEventsResponse, Context), ConnectError> {
    let watcher_id: &str = request.watcher_id;
    let handle = self.watchers.get(watcher_id).ok_or_else(|| {
        ConnectError::new(
            ErrorCode::NotFound,
            format!("watcher not found: {watcher_id}"),
        )
    })?;

    // Take the buffered events out under the lock, leaving an empty Vec in
    // place (std::mem::take) so the notify callback keeps accumulating.
    let events = {
        let mut guard = handle.events.lock().unwrap();
        std::mem::take(&mut *guard)
    };

    Ok((
        GetWatcherEventsResponse {
            events,
            ..Default::default()
        },
        ctx,
    ))
}
|
||||
|
||||
/// Remove a watcher by id, dropping its `notify` watcher and any events
/// still buffered for it.
///
/// Removal is idempotent: an unknown `watcher_id` is silently ignored.
async fn remove_watcher(
    &self,
    ctx: Context,
    request: buffa::view::OwnedView<RemoveWatcherRequestView<'static>>,
) -> Result<(RemoveWatcherResponse, Context), ConnectError> {
    let watcher_id: &str = request.watcher_id;
    self.watchers.remove(watcher_id);
    Ok((RemoveWatcherResponse { ..Default::default() }, ctx))
}
|
||||
}
|
||||
|
||||
/// Recursively list directory entries under `resolved_path` up to `depth`
/// levels deep, remapping each returned path back onto `requested_path`.
///
/// `requested_path` is the path as the client asked for it; `resolved_path`
/// is the filesystem location it resolved to. Entries that disappear between
/// traversal and stat (NotFound) are skipped rather than failing the call.
/// Symlinks are not followed.
///
/// # Errors
/// `ErrorCode::Internal` for any non-NotFound I/O error during traversal.
fn walk_dir(
    requested_path: &str,
    resolved_path: &str,
    depth: usize,
) -> Result<Vec<EntryInfo>, ConnectError> {
    let mut entries = Vec::new();
    let base = Path::new(resolved_path);

    // min_depth(1) excludes the root directory itself from the listing.
    for result in walkdir::WalkDir::new(resolved_path)
        .min_depth(1)
        .max_depth(depth)
        .follow_links(false)
    {
        let dir_entry = match result {
            Ok(e) => e,
            Err(e) => {
                // Entry vanished mid-walk (e.g. concurrent delete): skip it.
                if e.io_error()
                    .is_some_and(|io| io.kind() == std::io::ErrorKind::NotFound)
                {
                    continue;
                }
                return Err(ConnectError::new(
                    ErrorCode::Internal,
                    format!("error reading directory: {e}"),
                ));
            }
        };

        let entry_path = dir_entry.path();
        let mut entry = match build_entry_info(&entry_path.to_string_lossy()) {
            Ok(e) => e,
            // Same race: stat failed because the entry is already gone.
            Err(e) if e.code == ErrorCode::NotFound => continue,
            Err(e) => return Err(e),
        };

        // Report paths relative to what the client requested, not to the
        // resolved location (which may differ via symlinks/home expansion).
        if let Ok(rel) = entry_path.strip_prefix(base) {
            let remapped = PathBuf::from(requested_path).join(rel);
            entry.path = remapped.to_string_lossy().to_string();
        }

        entries.push(entry);
    }

    Ok(entries)
}
|
||||
|
||||
/// Generate a lightweight watcher id from the current wall-clock time.
///
/// The id is `"w-"` followed by the nanoseconds since the UNIX epoch in
/// lowercase hex — unique enough for ids minted by a single process.
fn simple_id() -> String {
    use std::time::{SystemTime, UNIX_EPOCH};
    let since_epoch = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
    format!("w-{:x}", since_epoch.as_nanos())
}
|
||||
26
envd-rs/src/rpc/mod.rs
Normal file
26
envd-rs/src/rpc/mod.rs
Normal file
@ -0,0 +1,26 @@
|
||||
pub mod pb;
|
||||
pub mod entry;
|
||||
pub mod process_handler;
|
||||
pub mod process_service;
|
||||
pub mod filesystem_service;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::rpc::process_service::ProcessServiceImpl;
|
||||
use crate::rpc::filesystem_service::FilesystemServiceImpl;
|
||||
use crate::state::AppState;
|
||||
|
||||
use pb::process::ProcessExt;
|
||||
use pb::filesystem::FilesystemExt;
|
||||
|
||||
/// Build the connect-rust Router with both RPC services registered.
///
/// Both services share the same `AppState`; registration order is not
/// significant.
pub fn rpc_router(state: Arc<AppState>) -> connectrpc::Router {
    let process_svc = Arc::new(ProcessServiceImpl::new(Arc::clone(&state)));
    let filesystem_svc = Arc::new(FilesystemServiceImpl::new(Arc::clone(&state)));

    let router = connectrpc::Router::new();
    let router = process_svc.register(router);
    let router = filesystem_svc.register(router);

    router
}
|
||||
10
envd-rs/src/rpc/pb.rs
Normal file
10
envd-rs/src/rpc/pb.rs
Normal file
@ -0,0 +1,10 @@
|
||||
// Generated connect-rpc/protobuf bindings, spliced in from the build script.
// The allow() silences lints that are expected in generated code; the
// explicit `use ::crate` lines make these crates resolvable from paths
// emitted inside the generated module.
#![allow(dead_code, non_camel_case_types, unused_imports, clippy::derivable_impls)]

use ::buffa;
use ::buffa_types;
use ::connectrpc;
use ::futures;
use ::http_body;
use ::serde;

include!(concat!(env!("OUT_DIR"), "/_connectrpc.rs"));
|
||||
419
envd-rs/src/rpc/process_handler.rs
Normal file
419
envd-rs/src/rpc/process_handler.rs
Normal file
@ -0,0 +1,419 @@
|
||||
use std::io::Read;
|
||||
use std::os::unix::process::CommandExt;
|
||||
use std::process::Stdio;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use connectrpc::{ConnectError, ErrorCode};
|
||||
use nix::pty::{openpty, Winsize};
|
||||
use nix::sys::signal::{self, Signal};
|
||||
use nix::unistd::Pid;
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
use crate::rpc::pb::process::*;
|
||||
|
||||
/// Read-buffer size (bytes) for the stdout/stderr pipe reader threads.
const STD_CHUNK_SIZE: usize = 32768;
/// Read-buffer size (bytes) for the PTY master reader thread.
const PTY_CHUNK_SIZE: usize = 16384;
/// Capacity of the per-process broadcast channel carrying output chunks;
/// slow subscribers beyond this lag and drop (broadcast semantics).
const BROADCAST_CAPACITY: usize = 4096;
|
||||
|
||||
/// A chunk of process output, tagged by the channel it was read from.
#[derive(Clone)]
pub enum DataEvent {
    /// Bytes read from the child's stdout pipe (pipe mode only).
    Stdout(Vec<u8>),
    /// Bytes read from the child's stderr pipe (pipe mode only).
    Stderr(Vec<u8>),
    /// Bytes read from the PTY master — combined output in PTY mode.
    Pty(Vec<u8>),
}
|
||||
|
||||
/// Terminal status of a process, broadcast once when it exits.
#[derive(Clone)]
pub struct EndEvent {
    /// Exit code, or -1 when the process was killed by a signal or wait failed.
    pub exit_code: i32,
    /// True when the process exited normally (an exit code is available).
    pub exited: bool,
    /// Human-readable status string (Display of the ExitStatus, or "error").
    pub status: String,
    /// Error message when waiting on the child itself failed.
    pub error: Option<String>,
}
|
||||
|
||||
/// Shared handle to a spawned child process: identity, output fan-out,
/// cached exit status, and the writable input ends (stdin pipe or PTY).
pub struct ProcessHandle {
    /// The configuration the process was started with (cmd/args/envs/cwd).
    pub config: ProcessConfig,
    /// Optional client-supplied tag used for selector lookups.
    pub tag: Option<String>,
    /// OS pid of the spawned child.
    pub pid: u32,

    // Fan-out channels: every subscriber gets a copy of each event.
    data_tx: broadcast::Sender<DataEvent>,
    end_tx: broadcast::Sender<EndEvent>,
    // Set exactly once by the waiter thread; lets late subscribers observe
    // the exit of an already-finished process.
    ended: Mutex<Option<EndEvent>>,

    // Write ends; None when not enabled (or, for stdin, after close_stdin).
    stdin: Mutex<Option<std::process::ChildStdin>>,
    pty_master: Mutex<Option<std::fs::File>>,
}
|
||||
|
||||
impl ProcessHandle {
    /// Subscribe to the process's output stream (stdout/stderr/pty chunks).
    pub fn subscribe_data(&self) -> broadcast::Receiver<DataEvent> {
        self.data_tx.subscribe()
    }

    /// Subscribe to the single end-of-process notification.
    pub fn subscribe_end(&self) -> broadcast::Receiver<EndEvent> {
        self.end_tx.subscribe()
    }

    /// Return the cached end event if the process has already exited.
    pub fn cached_end(&self) -> Option<EndEvent> {
        self.ended.lock().unwrap().clone()
    }

    /// Deliver `sig` to the child via `kill(2)`.
    pub fn send_signal(&self, sig: Signal) -> Result<(), ConnectError> {
        signal::kill(Pid::from_raw(self.pid as i32), sig).map_err(|e| {
            ConnectError::new(ErrorCode::Internal, format!("error sending signal: {e}"))
        })
    }

    /// Write `data` to the child's stdin pipe.
    ///
    /// # Errors
    /// `FailedPrecondition` when stdin was not enabled or has already been
    /// closed via `close_stdin`.
    pub fn write_stdin(&self, data: &[u8]) -> Result<(), ConnectError> {
        use std::io::Write;
        let mut guard = self.stdin.lock().unwrap();
        match guard.as_mut() {
            Some(stdin) => stdin.write_all(data).map_err(|e| {
                ConnectError::new(ErrorCode::Internal, format!("error writing to stdin: {e}"))
            }),
            None => Err(ConnectError::new(
                ErrorCode::FailedPrecondition,
                "stdin not enabled or closed",
            )),
        }
    }

    /// Write `data` to the PTY master (i.e. terminal input of the child).
    ///
    /// # Errors
    /// `FailedPrecondition` when the process was not started with a PTY.
    pub fn write_pty(&self, data: &[u8]) -> Result<(), ConnectError> {
        use std::io::Write;
        let mut guard = self.pty_master.lock().unwrap();
        match guard.as_mut() {
            Some(master) => master.write_all(data).map_err(|e| {
                ConnectError::new(ErrorCode::Internal, format!("error writing to pty: {e}"))
            }),
            None => Err(ConnectError::new(
                ErrorCode::FailedPrecondition,
                "pty not assigned to process",
            )),
        }
    }

    /// Close the child's stdin by dropping our write end (delivers EOF).
    ///
    /// Not supported for PTY processes: the PTY master must stay open for
    /// output, so callers are told to send Ctrl+D instead.
    pub fn close_stdin(&self) -> Result<(), ConnectError> {
        if self.pty_master.lock().unwrap().is_some() {
            return Err(ConnectError::new(
                ErrorCode::FailedPrecondition,
                "cannot close stdin for PTY process — send Ctrl+D (0x04) instead",
            ));
        }
        // Dropping the ChildStdin closes the pipe's write end.
        let mut guard = self.stdin.lock().unwrap();
        *guard = None;
        Ok(())
    }

    /// Resize the PTY window via the TIOCSWINSZ ioctl.
    ///
    /// # Errors
    /// `FailedPrecondition` when no PTY is attached; `Internal` when the
    /// ioctl itself fails.
    pub fn resize_pty(&self, cols: u16, rows: u16) -> Result<(), ConnectError> {
        let guard = self.pty_master.lock().unwrap();
        match guard.as_ref() {
            Some(master) => {
                use std::os::unix::io::AsRawFd;
                let ws = libc::winsize {
                    ws_row: rows,
                    ws_col: cols,
                    ws_xpixel: 0,
                    ws_ypixel: 0,
                };
                // SAFETY: master is a valid, open PTY fd for the lifetime of
                // the guard; TIOCSWINSZ only reads the winsize we point at.
                let ret = unsafe { libc::ioctl(master.as_raw_fd(), libc::TIOCSWINSZ, &ws) };
                if ret != 0 {
                    return Err(ConnectError::new(
                        ErrorCode::Internal,
                        format!(
                            "ioctl TIOCSWINSZ failed: {}",
                            std::io::Error::last_os_error()
                        ),
                    ));
                }
                Ok(())
            }
            None => Err(ConnectError::new(
                ErrorCode::FailedPrecondition,
                "tty not assigned to process",
            )),
        }
    }
}
|
||||
|
||||
/// Result of `spawn_process`: the shared handle plus receivers subscribed
/// *before* any output could be produced, so the caller misses no events.
pub struct SpawnedProcess {
    pub handle: Arc<ProcessHandle>,
    pub data_rx: broadcast::Receiver<DataEvent>,
    pub end_rx: broadcast::Receiver<EndEvent>,
}
|
||||
|
||||
/// Spawn a child process, either attached to a PTY (`pty_opts = Some`) or
/// with plain stdio pipes, running as `user` with an OOM-score/nice wrapper.
///
/// The command is run through `/bin/sh -c` with a wrapper script that bumps
/// the child's oom_score_adj and re-nices it, then `exec`s the real command.
/// Reader threads forward output into a broadcast channel; a waiter thread
/// publishes a single `EndEvent` and caches it on the handle.
///
/// # Errors
/// `Internal` when the PTY cannot be allocated or the process fails to start.
pub fn spawn_process(
    cmd_str: &str,
    args: &[String],
    envs: &std::collections::HashMap<String, String>,
    cwd: &str,
    pty_opts: Option<(u16, u16)>,
    enable_stdin: bool,
    tag: Option<String>,
    user: &nix::unistd::User,
    default_env_vars: &dashmap::DashMap<String, String>,
) -> Result<SpawnedProcess, ConnectError> {
    // Build the environment: base vars, then instance defaults, then the
    // request's vars — later entries win when /bin/sh applies duplicates.
    let mut env: Vec<(String, String)> = Vec::new();
    env.push(("PATH".into(), std::env::var("PATH").unwrap_or_default()));
    let home = user.dir.to_string_lossy().to_string();
    env.push(("HOME".into(), home));
    env.push(("USER".into(), user.name.clone()));
    env.push(("LOGNAME".into(), user.name.clone()));

    default_env_vars.iter().for_each(|entry| {
        env.push((entry.key().clone(), entry.value().clone()));
    });

    for (k, v) in envs {
        env.push((k.clone(), v.clone()));
    }

    // Wrapper: raise the child's OOM score so it is killed before envd, and
    // offset nice so the child runs at nice 0 relative to this process.
    // NOTE(review): nice_delta derives from current_nice(), whose value
    // looks like raw-syscall (20 - nice) semantics — confirm against the
    // libc wrapper's return convention.
    let nice_delta = 0 - current_nice();
    let oom_script = format!(
        r#"echo 100 > /proc/$$/oom_score_adj && exec /usr/bin/nice -n {} "${{@}}""#,
        nice_delta
    );
    // sh -c 'script' -- cmd args...  →  inside the script, $@ = cmd args...
    let mut wrapper_args = vec![
        "-c".to_string(),
        oom_script,
        "--".to_string(),
        cmd_str.to_string(),
    ];
    wrapper_args.extend_from_slice(args);

    let uid = user.uid.as_raw();
    let gid = user.gid.as_raw();

    let (data_tx, _) = broadcast::channel(BROADCAST_CAPACITY);
    let (end_tx, _) = broadcast::channel(16);

    let config = ProcessConfig {
        cmd: cmd_str.to_string(),
        args: args.to_vec(),
        envs: envs.clone(),
        cwd: Some(cwd.to_string()),
        ..Default::default()
    };

    if let Some((cols, rows)) = pty_opts {
        // ---- PTY mode: all child stdio goes through a pseudo-terminal ----
        let pty_result = openpty(
            Some(&Winsize {
                ws_row: rows,
                ws_col: cols,
                ws_xpixel: 0,
                ws_ypixel: 0,
            }),
            None,
        )
        .map_err(|e| ConnectError::new(ErrorCode::Internal, format!("openpty failed: {e}")))?;

        let master_fd = pty_result.master;
        let slave_fd = pty_result.slave;

        let mut command = std::process::Command::new("/bin/sh");
        command
            .args(&wrapper_args)
            .env_clear()
            .envs(env.iter().map(|(k, v)| (k.as_str(), v.as_str())))
            .current_dir(cwd);

        unsafe {
            use std::os::unix::io::AsRawFd;
            let slave_raw = slave_fd.as_raw_fd();
            let master_raw = master_fd.as_raw_fd();
            // SAFETY / async-signal-safety: pre_exec runs in the forked
            // child before exec; only raw syscalls are used here.
            command.pre_exec(move || {
                // Child must not hold the master end open.
                libc::close(master_raw);
                // New session + controlling TTY so job control works.
                nix::unistd::setsid()
                    .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
                libc::ioctl(slave_raw, libc::TIOCSCTTY, 0);
                libc::dup2(slave_raw, 0);
                libc::dup2(slave_raw, 1);
                libc::dup2(slave_raw, 2);
                if slave_raw > 2 {
                    libc::close(slave_raw);
                }
                // NOTE(review): setgid/setuid return values are ignored — on
                // failure the child would keep the parent's privileges.
                // Confirm whether this should abort the exec instead.
                libc::setgid(gid);
                libc::setuid(uid);
                Ok(())
            });
        }

        // Stdio is the PTY (dup2'd above); the inherited descriptors are
        // irrelevant, so null them.
        command.stdin(Stdio::null());
        command.stdout(Stdio::null());
        command.stderr(Stdio::null());

        let child = command.spawn().map_err(|e| {
            ConnectError::new(ErrorCode::Internal, format!("error starting pty process: {e}"))
        })?;

        // Parent must close its slave end or EOF on the master never fires.
        drop(slave_fd);

        let pid = child.id();
        let master_file: std::fs::File = master_fd.into();
        // One clone for the reader thread; the original stays on the handle
        // for writes/resize.
        let master_clone = master_file.try_clone().unwrap();

        let handle = Arc::new(ProcessHandle {
            config,
            tag,
            pid,
            data_tx: data_tx.clone(),
            end_tx: end_tx.clone(),
            ended: Mutex::new(None),
            stdin: Mutex::new(None),
            pty_master: Mutex::new(Some(master_file)),
        });

        // Subscribe before any reader thread can publish, so the caller
        // observes output from the very first byte.
        let data_rx = handle.subscribe_data();
        let end_rx = handle.subscribe_end();

        // Reader thread: PTY master → broadcast channel, until EOF/error.
        let data_tx_clone = data_tx.clone();
        std::thread::spawn(move || {
            let mut master = master_clone;
            let mut buf = vec![0u8; PTY_CHUNK_SIZE];
            loop {
                match master.read(&mut buf) {
                    Ok(0) => break,
                    Ok(n) => {
                        let _ = data_tx_clone.send(DataEvent::Pty(buf[..n].to_vec()));
                    }
                    Err(_) => break,
                }
            }
        });

        // Waiter thread: reap the child, cache and broadcast the end event.
        let end_tx_clone = end_tx.clone();
        let handle_for_waiter = Arc::clone(&handle);
        std::thread::spawn(move || {
            let mut child = child;
            let end_event = match child.wait() {
                Ok(s) => EndEvent {
                    exit_code: s.code().unwrap_or(-1),
                    exited: s.code().is_some(),
                    status: format!("{s}"),
                    error: None,
                },
                Err(e) => EndEvent {
                    exit_code: -1,
                    exited: false,
                    status: "error".into(),
                    error: Some(e.to_string()),
                },
            };
            // Cache first so late subscribers can still see the exit.
            *handle_for_waiter.ended.lock().unwrap() = Some(end_event.clone());
            let _ = end_tx_clone.send(end_event);
        });

        tracing::info!(pid, cmd = cmd_str, "process started (pty)");
        Ok(SpawnedProcess { handle, data_rx, end_rx })
    } else {
        // ---- Pipe mode: separate stdout/stderr pipes, optional stdin ----
        let mut command = std::process::Command::new("/bin/sh");
        command
            .args(&wrapper_args)
            .env_clear()
            .envs(env.iter().map(|(k, v)| (k.as_str(), v.as_str())))
            .current_dir(cwd)
            .stdout(Stdio::piped())
            .stderr(Stdio::piped());

        if enable_stdin {
            command.stdin(Stdio::piped());
        } else {
            command.stdin(Stdio::null());
        }

        unsafe {
            command.pre_exec(move || {
                // NOTE(review): return values ignored here too — see the PTY
                // branch.
                libc::setgid(gid);
                libc::setuid(uid);
                Ok(())
            });
        }

        let mut child = command.spawn().map_err(|e| {
            ConnectError::new(ErrorCode::Internal, format!("error starting process: {e}"))
        })?;

        let pid = child.id();
        let stdin = child.stdin.take();
        let stdout = child.stdout.take();
        let stderr = child.stderr.take();

        let handle = Arc::new(ProcessHandle {
            config,
            tag,
            pid,
            data_tx: data_tx.clone(),
            end_tx: end_tx.clone(),
            ended: Mutex::new(None),
            stdin: Mutex::new(stdin),
            pty_master: Mutex::new(None),
        });

        let data_rx = handle.subscribe_data();
        let end_rx = handle.subscribe_end();

        // One reader thread per output pipe.
        if let Some(mut out) = stdout {
            let tx = data_tx.clone();
            std::thread::spawn(move || {
                let mut buf = vec![0u8; STD_CHUNK_SIZE];
                loop {
                    match out.read(&mut buf) {
                        Ok(0) => break,
                        Ok(n) => {
                            let _ = tx.send(DataEvent::Stdout(buf[..n].to_vec()));
                        }
                        Err(_) => break,
                    }
                }
            });
        }

        if let Some(mut err_pipe) = stderr {
            let tx = data_tx.clone();
            std::thread::spawn(move || {
                let mut buf = vec![0u8; STD_CHUNK_SIZE];
                loop {
                    match err_pipe.read(&mut buf) {
                        Ok(0) => break,
                        Ok(n) => {
                            let _ = tx.send(DataEvent::Stderr(buf[..n].to_vec()));
                        }
                        Err(_) => break,
                    }
                }
            });
        }

        // Waiter thread — identical contract to the PTY branch.
        let end_tx_clone = end_tx.clone();
        let handle_for_waiter = Arc::clone(&handle);
        std::thread::spawn(move || {
            let end_event = match child.wait() {
                Ok(s) => EndEvent {
                    exit_code: s.code().unwrap_or(-1),
                    exited: s.code().is_some(),
                    status: format!("{s}"),
                    error: None,
                },
                Err(e) => EndEvent {
                    exit_code: -1,
                    exited: false,
                    status: "error".into(),
                    error: Some(e.to_string()),
                },
            };
            *handle_for_waiter.ended.lock().unwrap() = Some(end_event.clone());
            let _ = end_tx_clone.send(end_event);
        });

        tracing::info!(pid, cmd = cmd_str, "process started (pipe)");
        Ok(SpawnedProcess { handle, data_rx, end_rx })
    }
}
|
||||
|
||||
/// Best-effort read of this process's scheduling priority; returns 0 if
/// `getpriority` fails.
///
/// NOTE(review): `getpriority(2)` can legitimately return -1, so errno is
/// cleared and re-checked — that part is correct. But the `20 - prio`
/// conversion matches the *raw syscall* convention (which returns
/// `20 - nice`), while the libc wrapper already returns the nice value
/// directly; if so this yields `20 - nice`, not `nice`. Confirm which
/// convention the linked libc uses. `__errno_location` is glibc-specific.
fn current_nice() -> i32 {
    unsafe {
        *libc::__errno_location() = 0;
        let prio = libc::getpriority(libc::PRIO_PROCESS, 0);
        if *libc::__errno_location() != 0 {
            return 0;
        }
        20 - prio
    }
}
|
||||
481
envd-rs/src/rpc/process_service.rs
Normal file
481
envd-rs/src/rpc/process_service.rs
Normal file
@ -0,0 +1,481 @@
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
|
||||
use connectrpc::{ConnectError, Context, ErrorCode};
|
||||
use dashmap::DashMap;
|
||||
use futures::Stream;
|
||||
|
||||
use crate::permissions::path::expand_and_resolve;
|
||||
use crate::permissions::user::lookup_user;
|
||||
use crate::rpc::pb::process::*;
|
||||
use crate::rpc::process_handler::{self, DataEvent, ProcessHandle};
|
||||
use crate::state::AppState;
|
||||
|
||||
/// connect-rpc Process service: spawns children and tracks live processes.
pub struct ProcessServiceImpl {
    // Shared daemon state (defaults, user, env vars).
    state: Arc<AppState>,
    // Live processes keyed by pid; entries are removed when a process ends.
    processes: DashMap<u32, Arc<ProcessHandle>>,
}
|
||||
|
||||
impl ProcessServiceImpl {
    /// Create the service with an empty process table.
    pub fn new(state: Arc<AppState>) -> Self {
        Self {
            state,
            processes: DashMap::new(),
        }
    }

    /// Resolve a request's process selector (pid or tag) to a live handle.
    ///
    /// # Errors
    /// `NotFound` when no matching process exists; `InvalidArgument` when
    /// the selector is empty.
    fn get_process_by_selector(
        &self,
        selector: &ProcessSelectorView,
    ) -> Result<Arc<ProcessHandle>, ConnectError> {
        match &selector.selector {
            Some(process_selector::SelectorView::Pid(pid)) => {
                let pid_val = *pid;
                self.processes
                    .get(&pid_val)
                    .map(|entry| Arc::clone(entry.value()))
                    .ok_or_else(|| {
                        ConnectError::new(
                            ErrorCode::NotFound,
                            format!("process with pid {pid_val} not found"),
                        )
                    })
            }
            Some(process_selector::SelectorView::Tag(tag)) => {
                // Tags are not indexed; linear scan over live processes.
                let tag_str: &str = tag;
                for entry in self.processes.iter() {
                    if let Some(ref t) = entry.value().tag {
                        if t == tag_str {
                            return Ok(Arc::clone(entry.value()));
                        }
                    }
                }
                Err(ConnectError::new(
                    ErrorCode::NotFound,
                    format!("process with tag {tag_str} not found"),
                ))
            }
            None => Err(ConnectError::new(
                ErrorCode::InvalidArgument,
                "process selector required",
            )),
        }
    }

    /// Validate a Start request, spawn the process, register it in the
    /// table, and arrange for its removal once it ends.
    fn spawn_from_request(
        &self,
        request: &StartRequestView<'_>,
    ) -> Result<process_handler::SpawnedProcess, ConnectError> {
        let proc_config = request.process.as_option().ok_or_else(|| {
            ConnectError::new(ErrorCode::InvalidArgument, "process config required")
        })?;

        let username = self.state.defaults.user();
        let user =
            lookup_user(&username).map_err(|e| ConnectError::new(ErrorCode::Internal, e))?;

        let cmd: &str = proc_config.cmd;
        let args: Vec<String> = proc_config.args.iter().map(|s| s.to_string()).collect();
        let envs: HashMap<String, String> = proc_config
            .envs
            .iter()
            .map(|(k, v)| (k.to_string(), v.to_string()))
            .collect();

        // Resolve the working directory: expand ~, fall back to the
        // configured default workdir, and require it to exist up front.
        let home_dir = user.dir.to_string_lossy().to_string();
        let cwd_str: &str = proc_config.cwd.unwrap_or("");
        let default_workdir = self.state.defaults.workdir();
        let cwd = expand_and_resolve(cwd_str, &home_dir, default_workdir.as_deref())
            .map_err(|e| ConnectError::new(ErrorCode::InvalidArgument, e))?;

        let effective_cwd = if cwd.is_empty() { "/" } else { &cwd };
        if let Err(_) = std::fs::metadata(effective_cwd) {
            return Err(ConnectError::new(
                ErrorCode::InvalidArgument,
                format!("cwd '{effective_cwd}' does not exist"),
            ));
        }

        let pty_opts = request.pty.as_option().and_then(|pty| {
            pty.size
                .as_option()
                .map(|sz| (sz.cols as u16, sz.rows as u16))
        });

        // Stdin defaults to enabled when the field is absent.
        let enable_stdin = request.stdin.unwrap_or(true);
        let tag = request.tag.map(|s| s.to_string());

        tracing::info!(
            cmd = cmd,
            has_pty = pty_opts.is_some(),
            pty_size = ?pty_opts,
            tag = ?tag,
            stdin = enable_stdin,
            cwd = effective_cwd,
            user = %username,
            "process.Start request"
        );

        let spawned = process_handler::spawn_process(
            cmd,
            &args,
            &envs,
            effective_cwd,
            pty_opts,
            enable_stdin,
            tag,
            &user,
            &self.state.defaults.env_vars,
        )?;

        self.processes.insert(spawned.handle.pid, Arc::clone(&spawned.handle));

        // Self-cleaning table: drop the entry as soon as the process ends.
        let processes = self.processes.clone();
        let pid = spawned.handle.pid;
        let mut cleanup_end_rx = spawned.handle.subscribe_end();
        tokio::spawn(async move {
            let _ = cleanup_end_rx.recv().await;
            processes.remove(&pid);
        });

        Ok(spawned)
    }
}
|
||||
|
||||
impl Process for ProcessServiceImpl {
    /// List all currently tracked (live) processes.
    async fn list(
        &self,
        ctx: Context,
        _request: buffa::view::OwnedView<ListRequestView<'static>>,
    ) -> Result<(ListResponse, Context), ConnectError> {
        let processes: Vec<ProcessInfo> = self
            .processes
            .iter()
            .map(|entry| {
                let h = entry.value();
                ProcessInfo {
                    config: buffa::MessageField::some(h.config.clone()),
                    pid: h.pid,
                    tag: h.tag.clone(),
                    ..Default::default()
                }
            })
            .collect();

        Ok((
            ListResponse {
                processes,
                ..Default::default()
            },
            ctx,
        ))
    }

    /// Spawn a process and stream its lifecycle: a Start event, then Data
    /// events, then a final End event.
    async fn start(
        &self,
        ctx: Context,
        request: buffa::view::OwnedView<StartRequestView<'static>>,
    ) -> Result<
        (
            Pin<Box<dyn Stream<Item = Result<StartResponse, ConnectError>> + Send>>,
            Context,
        ),
        ConnectError,
    > {
        let spawned = self.spawn_from_request(&request)?;
        let pid = spawned.handle.pid;

        let mut data_rx = spawned.data_rx;
        let mut end_rx = spawned.end_rx;

        let stream = async_stream::stream! {
            yield Ok(make_start_response(pid));

            loop {
                tokio::select! {
                    // biased: prefer draining data over observing the end,
                    // so output is not dropped when both are ready.
                    biased;
                    data = data_rx.recv() => {
                        match data {
                            Ok(ev) => yield Ok(make_data_start_response(ev)),
                            // Lagged: slow consumer dropped chunks; keep going.
                            Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
                            Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
                        }
                    }
                    end = end_rx.recv() => {
                        // Flush any buffered output before reporting the end.
                        while let Ok(ev) = data_rx.try_recv() {
                            yield Ok(make_data_start_response(ev));
                        }
                        if let Ok(end) = end {
                            yield Ok(make_end_start_response(end));
                        }
                        break;
                    }
                }
            }
        };

        Ok((Box::pin(stream), ctx))
    }

    /// Attach to an already-running process and stream its events.
    ///
    /// If the process has already exited, replays the cached end event
    /// immediately after the Start event.
    async fn connect(
        &self,
        ctx: Context,
        request: buffa::view::OwnedView<ConnectRequestView<'static>>,
    ) -> Result<
        (
            Pin<Box<dyn Stream<Item = Result<ConnectResponse, ConnectError>> + Send>>,
            Context,
        ),
        ConnectError,
    > {
        let selector = request.process.as_option().ok_or_else(|| {
            ConnectError::new(ErrorCode::InvalidArgument, "process selector required")
        })?;
        let handle = self.get_process_by_selector(selector)?;
        let pid = handle.pid;

        let mut data_rx = handle.subscribe_data();
        let mut end_rx = handle.subscribe_end();
        let cached_end = handle.cached_end();

        let stream = async_stream::stream! {
            yield Ok(ConnectResponse {
                event: buffa::MessageField::some(ProcessEvent {
                    event: Some(process_event::Event::Start(Box::new(
                        process_event::StartEvent { pid, ..Default::default() },
                    ))),
                    ..Default::default()
                }),
                ..Default::default()
            });

            if let Some(end) = cached_end {
                // Process already finished: deliver the cached end and stop.
                yield Ok(ConnectResponse {
                    event: buffa::MessageField::some(make_end_event(end)),
                    ..Default::default()
                });
            } else {
                loop {
                    tokio::select! {
                        // Same drain-before-end discipline as start().
                        biased;
                        data = data_rx.recv() => {
                            match data {
                                Ok(ev) => {
                                    yield Ok(ConnectResponse {
                                        event: buffa::MessageField::some(make_data_event(ev)),
                                        ..Default::default()
                                    });
                                }
                                Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
                                Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
                            }
                        }
                        end = end_rx.recv() => {
                            while let Ok(ev) = data_rx.try_recv() {
                                yield Ok(ConnectResponse {
                                    event: buffa::MessageField::some(make_data_event(ev)),
                                    ..Default::default()
                                });
                            }
                            if let Ok(end) = end {
                                yield Ok(ConnectResponse {
                                    event: buffa::MessageField::some(make_end_event(end)),
                                    ..Default::default()
                                });
                            }
                            break;
                        }
                    }
                }
            }
        };

        Ok((Box::pin(stream), ctx))
    }

    /// Update a running process; currently only PTY resize is supported.
    async fn update(
        &self,
        ctx: Context,
        request: buffa::view::OwnedView<UpdateRequestView<'static>>,
    ) -> Result<(UpdateResponse, Context), ConnectError> {
        let selector = request.process.as_option().ok_or_else(|| {
            ConnectError::new(ErrorCode::InvalidArgument, "process selector required")
        })?;
        let handle = self.get_process_by_selector(selector)?;

        if let Some(pty) = request.pty.as_option() {
            if let Some(size) = pty.size.as_option() {
                handle.resize_pty(size.cols as u16, size.rows as u16)?;
            }
        }

        Ok((UpdateResponse { ..Default::default() }, ctx))
    }

    /// Client-streaming input: a Start message selects the process, then
    /// Data messages carry stdin/PTY bytes; Keepalives are ignored.
    async fn stream_input(
        &self,
        ctx: Context,
        mut requests: Pin<
            Box<
                dyn Stream<
                        Item = Result<
                            buffa::view::OwnedView<StreamInputRequestView<'static>>,
                            ConnectError,
                        >,
                    > + Send,
            >,
        >,
    ) -> Result<(StreamInputResponse, Context), ConnectError> {
        use futures::StreamExt;

        let mut handle: Option<Arc<ProcessHandle>> = None;

        while let Some(result) = requests.next().await {
            let req = result?;
            match &req.event {
                Some(stream_input_request::EventView::Start(start)) => {
                    if let Some(selector) = start.process.as_option() {
                        handle = Some(self.get_process_by_selector(selector)?);
                    }
                }
                Some(stream_input_request::EventView::Data(data)) => {
                    // Data before Start is a protocol violation.
                    let h = handle.as_ref().ok_or_else(|| {
                        ConnectError::new(ErrorCode::FailedPrecondition, "no start event received")
                    })?;
                    if let Some(input) = data.input.as_option() {
                        write_input(h, input)?;
                    }
                }
                Some(stream_input_request::EventView::Keepalive(_)) => {}
                None => {}
            }
        }

        Ok((StreamInputResponse { ..Default::default() }, ctx))
    }

    /// One-shot input write to a process's stdin or PTY.
    async fn send_input(
        &self,
        ctx: Context,
        request: buffa::view::OwnedView<SendInputRequestView<'static>>,
    ) -> Result<(SendInputResponse, Context), ConnectError> {
        let selector = request.process.as_option().ok_or_else(|| {
            ConnectError::new(ErrorCode::InvalidArgument, "process selector required")
        })?;
        let handle = self.get_process_by_selector(selector)?;

        if let Some(input) = request.input.as_option() {
            write_input(&handle, input)?;
        }

        Ok((SendInputResponse { ..Default::default() }, ctx))
    }

    /// Send SIGKILL or SIGTERM to the selected process; other signal values
    /// are rejected.
    async fn send_signal(
        &self,
        ctx: Context,
        request: buffa::view::OwnedView<SendSignalRequestView<'static>>,
    ) -> Result<(SendSignalResponse, Context), ConnectError> {
        let selector = request.process.as_option().ok_or_else(|| {
            ConnectError::new(ErrorCode::InvalidArgument, "process selector required")
        })?;
        let handle = self.get_process_by_selector(selector)?;

        let sig = match request.signal.as_known() {
            Some(Signal::SIGNAL_SIGKILL) => nix::sys::signal::Signal::SIGKILL,
            Some(Signal::SIGNAL_SIGTERM) => nix::sys::signal::Signal::SIGTERM,
            _ => {
                return Err(ConnectError::new(
                    ErrorCode::InvalidArgument,
                    "invalid or unspecified signal",
                ))
            }
        };

        handle.send_signal(sig)?;
        Ok((SendSignalResponse { ..Default::default() }, ctx))
    }

    /// Close the selected process's stdin, delivering EOF (pipe mode only).
    async fn close_stdin(
        &self,
        ctx: Context,
        request: buffa::view::OwnedView<CloseStdinRequestView<'static>>,
    ) -> Result<(CloseStdinResponse, Context), ConnectError> {
        let selector = request.process.as_option().ok_or_else(|| {
            ConnectError::new(ErrorCode::InvalidArgument, "process selector required")
        })?;
        let handle = self.get_process_by_selector(selector)?;
        handle.close_stdin()?;
        Ok((CloseStdinResponse { ..Default::default() }, ctx))
    }
}
|
||||
|
||||
/// Route an input message to the process's PTY or stdin pipe based on which
/// oneof variant is set; an empty input is a no-op.
fn write_input(handle: &ProcessHandle, input: &ProcessInputView) -> Result<(), ConnectError> {
    match &input.input {
        Some(process_input::InputView::Pty(d)) => handle.write_pty(d),
        Some(process_input::InputView::Stdin(d)) => handle.write_stdin(d),
        None => Ok(()),
    }
}
|
||||
|
||||
/// Wrap a pid in a StartResponse carrying a StartEvent.
fn make_start_response(pid: u32) -> StartResponse {
    StartResponse {
        event: buffa::MessageField::some(ProcessEvent {
            event: Some(process_event::Event::Start(Box::new(
                process_event::StartEvent {
                    pid,
                    ..Default::default()
                },
            ))),
            ..Default::default()
        }),
        ..Default::default()
    }
}
|
||||
|
||||
/// Convert an internal DataEvent into the wire-format ProcessEvent,
/// preserving the output channel (stdout/stderr/pty).
fn make_data_event(ev: DataEvent) -> ProcessEvent {
    let output = match ev {
        DataEvent::Stdout(d) => Some(process_event::data_event::Output::Stdout(d.into())),
        DataEvent::Stderr(d) => Some(process_event::data_event::Output::Stderr(d.into())),
        DataEvent::Pty(d) => Some(process_event::data_event::Output::Pty(d.into())),
    };
    ProcessEvent {
        event: Some(process_event::Event::Data(Box::new(
            process_event::DataEvent {
                output,
                ..Default::default()
            },
        ))),
        ..Default::default()
    }
}
|
||||
|
||||
/// Wrap a DataEvent in a StartResponse (for the Start streaming RPC).
fn make_data_start_response(ev: DataEvent) -> StartResponse {
    StartResponse {
        event: buffa::MessageField::some(make_data_event(ev)),
        ..Default::default()
    }
}
|
||||
|
||||
/// Convert an internal EndEvent into the wire-format ProcessEvent.
fn make_end_event(end: process_handler::EndEvent) -> ProcessEvent {
    ProcessEvent {
        event: Some(process_event::Event::End(Box::new(
            process_event::EndEvent {
                exit_code: end.exit_code,
                exited: end.exited,
                status: end.status,
                error: end.error,
                ..Default::default()
            },
        ))),
        ..Default::default()
    }
}
|
||||
|
||||
/// Wrap an EndEvent in a StartResponse (for the Start streaming RPC).
fn make_end_start_response(end: process_handler::EndEvent) -> StartResponse {
    StartResponse {
        event: buffa::MessageField::some(make_end_event(end)),
        ..Default::default()
    }
}
|
||||
89
envd-rs/src/state.rs
Normal file
89
envd-rs/src/state.rs
Normal file
@ -0,0 +1,89 @@
|
||||
use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::auth::token::SecureToken;
|
||||
use crate::conntracker::ConnTracker;
|
||||
use crate::execcontext::Defaults;
|
||||
use crate::port::subsystem::PortSubsystem;
|
||||
use crate::util::AtomicMax;
|
||||
|
||||
/// Process-wide shared state for the daemon, handed around as `Arc<AppState>`.
pub struct AppState {
    /// Default user, workdir, and env vars for spawned processes.
    pub defaults: Defaults,
    /// Build version string reported to clients.
    pub version: String,
    /// Build commit hash reported to clients.
    pub commit: String,
    /// Whether we are running inside a Firecracker VM.
    pub is_fc: bool,
    /// Set when state must be restored (e.g. after snapshot resume).
    pub needs_restore: AtomicBool,
    /// Monotonically increasing "last set" timestamp (max-only updates).
    pub last_set_time: AtomicMax,
    /// Access token guarding authenticated endpoints.
    pub access_token: SecureToken,
    /// Tracks live client connections.
    pub conn_tracker: ConnTracker,
    /// Port-scan subsystem; None when disabled.
    pub port_subsystem: Option<Arc<PortSubsystem>>,
    // f32 CPU percentage stored as its bit pattern (see cpu_used_pct()).
    pub cpu_used_pct: AtomicU32,
    // Number of logical CPUs, updated by the sampler thread.
    pub cpu_count: AtomicU32,
    /// True while a snapshot is being taken.
    pub snapshot_in_progress: AtomicBool,
}
|
||||
|
||||
impl AppState {
|
||||
pub fn new(
|
||||
defaults: Defaults,
|
||||
version: String,
|
||||
commit: String,
|
||||
is_fc: bool,
|
||||
port_subsystem: Option<Arc<PortSubsystem>>,
|
||||
) -> Arc<Self> {
|
||||
let state = Arc::new(Self {
|
||||
defaults,
|
||||
version,
|
||||
commit,
|
||||
is_fc,
|
||||
needs_restore: AtomicBool::new(false),
|
||||
last_set_time: AtomicMax::new(),
|
||||
access_token: SecureToken::new(),
|
||||
conn_tracker: ConnTracker::new(),
|
||||
port_subsystem,
|
||||
cpu_used_pct: AtomicU32::new(0),
|
||||
cpu_count: AtomicU32::new(0),
|
||||
snapshot_in_progress: AtomicBool::new(false),
|
||||
});
|
||||
|
||||
let state_clone = Arc::clone(&state);
|
||||
std::thread::spawn(move || {
|
||||
cpu_sampler(state_clone);
|
||||
});
|
||||
|
||||
state
|
||||
}
|
||||
|
||||
pub fn cpu_used_pct(&self) -> f32 {
|
||||
f32::from_bits(self.cpu_used_pct.load(Ordering::Relaxed))
|
||||
}
|
||||
|
||||
pub fn cpu_count(&self) -> u32 {
|
||||
self.cpu_count.load(Ordering::Relaxed)
|
||||
}
|
||||
}
|
||||
|
||||
fn cpu_sampler(state: Arc<AppState>) {
|
||||
use sysinfo::System;
|
||||
|
||||
let mut sys = System::new();
|
||||
sys.refresh_cpu_all();
|
||||
|
||||
loop {
|
||||
std::thread::sleep(std::time::Duration::from_secs(1));
|
||||
sys.refresh_cpu_all();
|
||||
|
||||
let pct = sys.global_cpu_usage();
|
||||
let rounded = if pct > 0.0 {
|
||||
(pct * 100.0).round() / 100.0
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
state
|
||||
.cpu_used_pct
|
||||
.store(rounded.to_bits(), Ordering::Relaxed);
|
||||
state
|
||||
.cpu_count
|
||||
.store(sys.cpus().len() as u32, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
102
envd-rs/src/util.rs
Normal file
102
envd-rs/src/util.rs
Normal file
@ -0,0 +1,102 @@
|
||||
use std::sync::atomic::{AtomicI64, Ordering};
|
||||
|
||||
pub struct AtomicMax {
|
||||
val: AtomicI64,
|
||||
}
|
||||
|
||||
impl AtomicMax {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
val: AtomicI64::new(i64::MIN),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get(&self) -> i64 {
|
||||
self.val.load(Ordering::Acquire)
|
||||
}
|
||||
|
||||
/// Sets the stored value to `new` if `new` is strictly greater than
|
||||
/// the current value. Returns `true` if the value was updated.
|
||||
pub fn set_to_greater(&self, new: i64) -> bool {
|
||||
loop {
|
||||
let current = self.val.load(Ordering::Acquire);
|
||||
if new <= current {
|
||||
return false;
|
||||
}
|
||||
match self.val.compare_exchange_weak(
|
||||
current,
|
||||
new,
|
||||
Ordering::Release,
|
||||
Ordering::Relaxed,
|
||||
) {
|
||||
Ok(_) => return true,
|
||||
Err(_) => continue,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[test]
|
||||
fn initial_value_is_i64_min() {
|
||||
let m = AtomicMax::new();
|
||||
assert_eq!(m.get(), i64::MIN);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn updates_when_larger() {
|
||||
let m = AtomicMax::new();
|
||||
assert!(m.set_to_greater(0));
|
||||
assert_eq!(m.get(), 0);
|
||||
assert!(m.set_to_greater(100));
|
||||
assert_eq!(m.get(), 100);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn returns_false_when_equal() {
|
||||
let m = AtomicMax::new();
|
||||
m.set_to_greater(42);
|
||||
assert!(!m.set_to_greater(42));
|
||||
assert_eq!(m.get(), 42);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn returns_false_when_smaller() {
|
||||
let m = AtomicMax::new();
|
||||
m.set_to_greater(100);
|
||||
assert!(!m.set_to_greater(50));
|
||||
assert_eq!(m.get(), 100);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn concurrent_convergence() {
|
||||
let m = Arc::new(AtomicMax::new());
|
||||
let threads: Vec<_> = (0..8)
|
||||
.map(|t| {
|
||||
let m = Arc::clone(&m);
|
||||
std::thread::spawn(move || {
|
||||
for i in (t * 100)..((t + 1) * 100) {
|
||||
m.set_to_greater(i);
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
for t in threads {
|
||||
t.join().unwrap();
|
||||
}
|
||||
assert_eq!(m.get(), 799);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn i64_max_boundary() {
|
||||
let m = AtomicMax::new();
|
||||
assert!(m.set_to_greater(i64::MAX));
|
||||
assert!(!m.set_to_greater(i64::MAX));
|
||||
assert!(!m.set_to_greater(0));
|
||||
assert_eq!(m.get(), i64::MAX);
|
||||
}
|
||||
}
|
||||
202
envd/LICENSE
202
envd/LICENSE
@ -1,202 +0,0 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright 2023 FoundryLabs, Inc.
|
||||
Modifications Copyright (c) 2026 M/S Omukk, Bangladesh
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
@ -1,62 +0,0 @@
|
||||
BUILD := $(shell git rev-parse --short HEAD 2>/dev/null || echo "unknown")
|
||||
LDFLAGS := -s -w -X=main.commitSHA=$(BUILD)
|
||||
BUILDS := ../builds
|
||||
|
||||
# ═══════════════════════════════════════════════════
|
||||
# Build
|
||||
# ═══════════════════════════════════════════════════
|
||||
.PHONY: build build-debug
|
||||
|
||||
build:
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags="$(LDFLAGS)" -o $(BUILDS)/envd .
|
||||
@file $(BUILDS)/envd | grep -q "statically linked" || \
|
||||
(echo "ERROR: envd is not statically linked!" && exit 1)
|
||||
|
||||
build-debug:
|
||||
CGO_ENABLED=1 go build -race -gcflags=all="-N -l" -ldflags="-X=main.commitSHA=$(BUILD)" -o $(BUILDS)/debug/envd .
|
||||
|
||||
# ═══════════════════════════════════════════════════
|
||||
# Run (debug mode, not inside a VM)
|
||||
# ═══════════════════════════════════════════════════
|
||||
.PHONY: run-debug
|
||||
|
||||
run-debug: build-debug
|
||||
$(BUILDS)/debug/envd -isnotfc -port 49983
|
||||
|
||||
# ═══════════════════════════════════════════════════
|
||||
# Code Generation
|
||||
# ═══════════════════════════════════════════════════
|
||||
.PHONY: generate proto openapi
|
||||
|
||||
generate: proto openapi
|
||||
|
||||
proto:
|
||||
cd spec && buf generate --template buf.gen.yaml
|
||||
|
||||
openapi:
|
||||
go generate ./internal/api/...
|
||||
|
||||
# ═══════════════════════════════════════════════════
|
||||
# Quality
|
||||
# ═══════════════════════════════════════════════════
|
||||
.PHONY: fmt vet test tidy
|
||||
|
||||
fmt:
|
||||
gofmt -w .
|
||||
|
||||
vet:
|
||||
go vet ./...
|
||||
|
||||
test:
|
||||
go test -race -v ./...
|
||||
|
||||
tidy:
|
||||
go mod tidy
|
||||
|
||||
# ═══════════════════════════════════════════════════
|
||||
# Clean
|
||||
# ═══════════════════════════════════════════════════
|
||||
.PHONY: clean
|
||||
|
||||
clean:
|
||||
rm -f $(BUILDS)/envd $(BUILDS)/debug/envd
|
||||
@ -1 +0,0 @@
|
||||
0.1.0
|
||||
42
envd/go.mod
42
envd/go.mod
@ -1,42 +0,0 @@
|
||||
module git.omukk.dev/wrenn/sandbox/envd
|
||||
|
||||
go 1.25.8
|
||||
|
||||
require (
|
||||
connectrpc.com/authn v0.1.0
|
||||
connectrpc.com/connect v1.19.1
|
||||
connectrpc.com/cors v0.1.0
|
||||
github.com/awnumar/memguard v0.23.0
|
||||
github.com/creack/pty v1.1.24
|
||||
github.com/dchest/uniuri v1.2.0
|
||||
github.com/e2b-dev/fsnotify v0.0.1
|
||||
github.com/go-chi/chi/v5 v5.2.5
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/oapi-codegen/runtime v1.2.0
|
||||
github.com/orcaman/concurrent-map/v2 v2.0.1
|
||||
github.com/rs/cors v1.11.1
|
||||
github.com/rs/zerolog v1.34.0
|
||||
github.com/shirou/gopsutil/v4 v4.26.2
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/txn2/txeh v1.8.0
|
||||
golang.org/x/sys v0.43.0
|
||||
google.golang.org/protobuf v1.36.11
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect
|
||||
github.com/awnumar/memcall v0.4.0 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/ebitengine/purego v0.10.0 // indirect
|
||||
github.com/go-ole/go-ole v1.2.6 // indirect
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
|
||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
|
||||
github.com/tklauser/go-sysconf v0.3.16 // indirect
|
||||
github.com/tklauser/numcpus v0.11.0 // indirect
|
||||
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||
golang.org/x/crypto v0.50.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
92
envd/go.sum
92
envd/go.sum
@ -1,92 +0,0 @@
|
||||
connectrpc.com/authn v0.1.0 h1:m5weACjLWwgwcjttvUDyTPICJKw74+p2obBVrf8hT9E=
|
||||
connectrpc.com/authn v0.1.0/go.mod h1:AwNZK/KYbqaJzRYadTuAaoz6sYQSPdORPqh1TOPIkgY=
|
||||
connectrpc.com/connect v1.19.1 h1:R5M57z05+90EfEvCY1b7hBxDVOUl45PrtXtAV2fOC14=
|
||||
connectrpc.com/connect v1.19.1/go.mod h1:tN20fjdGlewnSFeZxLKb0xwIZ6ozc3OQs2hTXy4du9w=
|
||||
connectrpc.com/cors v0.1.0 h1:f3gTXJyDZPrDIZCQ567jxfD9PAIpopHiRDnJRt3QuOQ=
|
||||
connectrpc.com/cors v0.1.0/go.mod h1:v8SJZCPfHtGH1zsm+Ttajpozd4cYIUryl4dFB6QEpfg=
|
||||
github.com/RaveNoX/go-jsoncommentstrip v1.0.0/go.mod h1:78ihd09MekBnJnxpICcwzCMzGrKSKYe4AqU6PDYYpjk=
|
||||
github.com/apapsch/go-jsonmerge/v2 v2.0.0 h1:axGnT1gRIfimI7gJifB699GoE/oq+F2MU7Dml6nw9rQ=
|
||||
github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP2+08jFMw88y4klk=
|
||||
github.com/awnumar/memcall v0.4.0 h1:B7hgZYdfH6Ot1Goaz8jGne/7i8xD4taZie/PNSFZ29g=
|
||||
github.com/awnumar/memcall v0.4.0/go.mod h1:8xOx1YbfyuCg3Fy6TO8DK0kZUua3V42/goA5Ru47E8w=
|
||||
github.com/awnumar/memguard v0.23.0 h1:sJ3a1/SWlcuKIQ7MV+R9p0Pvo9CWsMbGZvcZQtmc68A=
|
||||
github.com/awnumar/memguard v0.23.0/go.mod h1:olVofBrsPdITtJ2HgxQKrEYEMyIBAIciVG4wNnZhW9M=
|
||||
github.com/bmatcuk/doublestar v1.1.1/go.mod h1:UD6OnuiIn0yFxxA2le/rnRU1G4RaI4UvFv1sNto9p6w=
|
||||
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||
github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s=
|
||||
github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dchest/uniuri v1.2.0 h1:koIcOUdrTIivZgSLhHQvKgqdWZq5d7KdMEWF1Ud6+5g=
|
||||
github.com/dchest/uniuri v1.2.0/go.mod h1:fSzm4SLHzNZvWLvWJew423PhAzkpNQYq+uNLq4kxhkY=
|
||||
github.com/e2b-dev/fsnotify v0.0.1 h1:7j0I98HD6VehAuK/bcslvW4QDynAULtOuMZtImihjVk=
|
||||
github.com/e2b-dev/fsnotify v0.0.1/go.mod h1:jAuDjregRrUixKneTRQwPI847nNuPFg3+n5QM/ku/JM=
|
||||
github.com/ebitengine/purego v0.10.0 h1:QIw4xfpWT6GWTzaW5XEKy3HXoqrJGx1ijYHzTF0/ISU=
|
||||
github.com/ebitengine/purego v0.10.0/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
|
||||
github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug=
|
||||
github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0=
|
||||
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
|
||||
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/juju/gnuflag v0.0.0-20171113085948-2ce1bb71843d/go.mod h1:2PavIy+JPciBPrBUjwbNvtwB6RQlve+hkpll6QSNmOE=
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
|
||||
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
|
||||
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/oapi-codegen/runtime v1.2.0 h1:RvKc1CVS1QeKSNzO97FBQbSMZyQ8s6rZd+LpmzwHMP4=
|
||||
github.com/oapi-codegen/runtime v1.2.0/go.mod h1:Y7ZhmmlE8ikZOmuHRRndiIm7nf3xcVv+YMweKgG1DT0=
|
||||
github.com/orcaman/concurrent-map/v2 v2.0.1 h1:jOJ5Pg2w1oeB6PeDurIYf6k9PQ+aTITr/6lP/L/zp6c=
|
||||
github.com/orcaman/concurrent-map/v2 v2.0.1/go.mod h1:9Eq3TG2oBe5FirmYWQfYO5iH1q0Jv47PLaNK++uCdOM=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU=
|
||||
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
|
||||
github.com/rs/cors v1.11.1 h1:eU3gRzXLRK57F5rKMGMZURNdIG4EoAmX8k94r9wXWHA=
|
||||
github.com/rs/cors v1.11.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU=
|
||||
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
|
||||
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
|
||||
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
|
||||
github.com/shirou/gopsutil/v4 v4.26.2 h1:X8i6sicvUFih4BmYIGT1m2wwgw2VG9YgrDTi7cIRGUI=
|
||||
github.com/shirou/gopsutil/v4 v4.26.2/go.mod h1:LZ6ewCSkBqUpvSOf+LsTGnRinC6iaNUNMGBtDkJBaLQ=
|
||||
github.com/spkg/bom v0.0.0-20160624110644-59b7046e48ad/go.mod h1:qLr4V1qq6nMqFKkMo8ZTx3f+BZEkzsRUY10Xsm2mwU0=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/tklauser/go-sysconf v0.3.16 h1:frioLaCQSsF5Cy1jgRBrzr6t502KIIwQ0MArYICU0nA=
|
||||
github.com/tklauser/go-sysconf v0.3.16/go.mod h1:/qNL9xxDhc7tx3HSRsLWNnuzbVfh3e7gh/BmM179nYI=
|
||||
github.com/tklauser/numcpus v0.11.0 h1:nSTwhKH5e1dMNsCdVBukSZrURJRoHbSEQjdEbY+9RXw=
|
||||
github.com/tklauser/numcpus v0.11.0/go.mod h1:z+LwcLq54uWZTX0u/bGobaV34u6V7KNlTZejzM6/3MQ=
|
||||
github.com/txn2/txeh v1.8.0 h1:G1vZgom6+P/xWwU53AMOpcZgC5ni382ukcPP1TDVYHk=
|
||||
github.com/txn2/txeh v1.8.0/go.mod h1:rRI3Egi3+AFmEXQjft051YdYbxeCT3nFmBLsNCZZaxM=
|
||||
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
|
||||
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
||||
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
|
||||
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
|
||||
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
pgregory.net/rapid v1.2.0 h1:keKAYRcjm+e1F0oAuU5F5+YPAWcyxNNRK2wud503Gnk=
|
||||
pgregory.net/rapid v1.2.0/go.mod h1:PY5XlDGj0+V1FCq0o192FdRhpKHGTRIWBgqjDBTrq04=
|
||||
@ -1,604 +0,0 @@
|
||||
// Package api provides primitives to interact with the openapi HTTP API.
|
||||
//
|
||||
// Code generated by github.com/oapi-codegen/oapi-codegen/v2 version v2.6.0 DO NOT EDIT.
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/oapi-codegen/runtime"
|
||||
openapi_types "github.com/oapi-codegen/runtime/types"
|
||||
)
|
||||
|
||||
const (
|
||||
AccessTokenAuthScopes = "AccessTokenAuth.Scopes"
|
||||
)
|
||||
|
||||
// Defines values for EntryInfoType.
|
||||
const (
|
||||
File EntryInfoType = "file"
|
||||
)
|
||||
|
||||
// Valid indicates whether the value is a known member of the EntryInfoType enum.
|
||||
func (e EntryInfoType) Valid() bool {
|
||||
switch e {
|
||||
case File:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// EntryInfo defines model for EntryInfo.
|
||||
type EntryInfo struct {
|
||||
// Name Name of the file
|
||||
Name string `json:"name"`
|
||||
|
||||
// Path Path to the file
|
||||
Path string `json:"path"`
|
||||
|
||||
// Type Type of the file
|
||||
Type EntryInfoType `json:"type"`
|
||||
}
|
||||
|
||||
// EntryInfoType Type of the file
|
||||
type EntryInfoType string
|
||||
|
||||
// EnvVars Environment variables to set
|
||||
type EnvVars map[string]string
|
||||
|
||||
// Error defines model for Error.
|
||||
type Error struct {
|
||||
// Code Error code
|
||||
Code int `json:"code"`
|
||||
|
||||
// Message Error message
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// Metrics Resource usage metrics
|
||||
type Metrics struct {
|
||||
// CpuCount Number of CPU cores
|
||||
CpuCount *int `json:"cpu_count,omitempty"`
|
||||
|
||||
// CpuUsedPct CPU usage percentage
|
||||
CpuUsedPct *float32 `json:"cpu_used_pct,omitempty"`
|
||||
|
||||
// DiskTotal Total disk space in bytes
|
||||
DiskTotal *int `json:"disk_total,omitempty"`
|
||||
|
||||
// DiskUsed Used disk space in bytes
|
||||
DiskUsed *int `json:"disk_used,omitempty"`
|
||||
|
||||
// MemTotal Total virtual memory in bytes
|
||||
MemTotal *int `json:"mem_total,omitempty"`
|
||||
|
||||
// MemUsed Used virtual memory in bytes
|
||||
MemUsed *int `json:"mem_used,omitempty"`
|
||||
|
||||
// Ts Unix timestamp in UTC for current sandbox time
|
||||
Ts *int64 `json:"ts,omitempty"`
|
||||
}
|
||||
|
||||
// VolumeMount Volume
// Describes an NFS share mounted into the sandbox at Path.
type VolumeMount struct {
	NfsTarget string `json:"nfs_target"`
	Path      string `json:"path"`
}

// FilePath defines model for FilePath.
type FilePath = string

// Signature defines model for Signature.
type Signature = string

// SignatureExpiration defines model for SignatureExpiration.
type SignatureExpiration = int

// User defines model for User.
type User = string

// FileNotFound defines model for FileNotFound.
// The response aliases below all share the Error payload shape; the alias
// names mirror the OpenAPI response components they came from.
type FileNotFound = Error

// InternalServerError defines model for InternalServerError.
type InternalServerError = Error

// InvalidPath defines model for InvalidPath.
type InvalidPath = Error

// InvalidUser defines model for InvalidUser.
type InvalidUser = Error

// NotEnoughDiskSpace defines model for NotEnoughDiskSpace.
type NotEnoughDiskSpace = Error

// UploadSuccess defines model for UploadSuccess.
type UploadSuccess = []EntryInfo
|
||||
|
||||
// GetFilesParams defines parameters for GetFiles.
// All query parameters are optional; nil means the parameter was not sent.
type GetFilesParams struct {
	// Path Path to the file, URL encoded. Can be relative to user's home directory.
	Path *FilePath `form:"path,omitempty" json:"path,omitempty"`

	// Username User used for setting the owner, or resolving relative paths.
	Username *User `form:"username,omitempty" json:"username,omitempty"`

	// Signature Signature used for file access permission verification.
	Signature *Signature `form:"signature,omitempty" json:"signature,omitempty"`

	// SignatureExpiration Signature expiration used for defining the expiration time of the signature.
	SignatureExpiration *SignatureExpiration `form:"signature_expiration,omitempty" json:"signature_expiration,omitempty"`
}

// PostFilesMultipartBody defines parameters for PostFiles.
type PostFilesMultipartBody struct {
	File *openapi_types.File `json:"file,omitempty"`
}

// PostFilesParams defines parameters for PostFiles.
// Mirrors GetFilesParams; kept separate because the generator emits one
// params struct per operation.
type PostFilesParams struct {
	// Path Path to the file, URL encoded. Can be relative to user's home directory.
	Path *FilePath `form:"path,omitempty" json:"path,omitempty"`

	// Username User used for setting the owner, or resolving relative paths.
	Username *User `form:"username,omitempty" json:"username,omitempty"`

	// Signature Signature used for file access permission verification.
	Signature *Signature `form:"signature,omitempty" json:"signature,omitempty"`

	// SignatureExpiration Signature expiration used for defining the expiration time of the signature.
	SignatureExpiration *SignatureExpiration `form:"signature_expiration,omitempty" json:"signature_expiration,omitempty"`
}
|
||||
|
||||
// PostInitJSONBody defines parameters for PostInit.
// Request payload for POST /init; every field is optional.
type PostInitJSONBody struct {
	// AccessToken Access token for secure access to envd service
	AccessToken *SecureToken `json:"accessToken,omitempty"`

	// DefaultUser The default user to use for operations
	DefaultUser *string `json:"defaultUser,omitempty"`

	// DefaultWorkdir The default working directory to use for operations
	DefaultWorkdir *string `json:"defaultWorkdir,omitempty"`

	// EnvVars Environment variables to set
	EnvVars *EnvVars `json:"envVars,omitempty"`

	// HyperloopIP IP address of the hyperloop server to connect to
	HyperloopIP *string `json:"hyperloopIP,omitempty"`

	// Timestamp The current timestamp in RFC3339 format
	Timestamp *time.Time `json:"timestamp,omitempty"`

	// VolumeMounts NFS volumes to mount into the sandbox.
	VolumeMounts *[]VolumeMount `json:"volumeMounts,omitempty"`
}

// PostFilesMultipartRequestBody defines body for PostFiles for multipart/form-data ContentType.
type PostFilesMultipartRequestBody PostFilesMultipartBody

// PostInitJSONRequestBody defines body for PostInit for application/json ContentType.
type PostInitJSONRequestBody PostInitJSONBody
|
||||
|
||||
// ServerInterface represents all server handlers.
// One method per OpenAPI operation; the routing in HandlerWithOptions maps
// each method/path pair to the matching method here.
type ServerInterface interface {
	// Get the environment variables
	// (GET /envs)
	GetEnvs(w http.ResponseWriter, r *http.Request)
	// Download a file
	// (GET /files)
	GetFiles(w http.ResponseWriter, r *http.Request, params GetFilesParams)
	// Upload a file and ensure the parent directories exist. If the file exists, it will be overwritten.
	// (POST /files)
	PostFiles(w http.ResponseWriter, r *http.Request, params PostFilesParams)
	// Check the health of the service
	// (GET /health)
	GetHealth(w http.ResponseWriter, r *http.Request)
	// Set initial vars, ensure the time and metadata is synced with the host
	// (POST /init)
	PostInit(w http.ResponseWriter, r *http.Request)
	// Get the stats of the service
	// (GET /metrics)
	GetMetrics(w http.ResponseWriter, r *http.Request)
	// Quiesce continuous goroutines before Firecracker snapshot
	// (POST /snapshot/prepare)
	PostSnapshotPrepare(w http.ResponseWriter, r *http.Request)
}
|
||||
|
||||
// Unimplemented server implementation that returns http.StatusNotImplemented for each endpoint.
// Embed this in a server type to satisfy ServerInterface while implementing
// only a subset of the operations.

type Unimplemented struct{}

// Get the environment variables
// (GET /envs)
func (_ Unimplemented) GetEnvs(w http.ResponseWriter, r *http.Request) {
	w.WriteHeader(http.StatusNotImplemented)
}

// Download a file
// (GET /files)
func (_ Unimplemented) GetFiles(w http.ResponseWriter, r *http.Request, params GetFilesParams) {
	w.WriteHeader(http.StatusNotImplemented)
}

// Upload a file and ensure the parent directories exist. If the file exists, it will be overwritten.
// (POST /files)
func (_ Unimplemented) PostFiles(w http.ResponseWriter, r *http.Request, params PostFilesParams) {
	w.WriteHeader(http.StatusNotImplemented)
}

// Check the health of the service
// (GET /health)
func (_ Unimplemented) GetHealth(w http.ResponseWriter, r *http.Request) {
	w.WriteHeader(http.StatusNotImplemented)
}

// Set initial vars, ensure the time and metadata is synced with the host
// (POST /init)
func (_ Unimplemented) PostInit(w http.ResponseWriter, r *http.Request) {
	w.WriteHeader(http.StatusNotImplemented)
}

// Get the stats of the service
// (GET /metrics)
func (_ Unimplemented) GetMetrics(w http.ResponseWriter, r *http.Request) {
	w.WriteHeader(http.StatusNotImplemented)
}

// Quiesce continuous goroutines before Firecracker snapshot
// (POST /snapshot/prepare)
func (_ Unimplemented) PostSnapshotPrepare(w http.ResponseWriter, r *http.Request) {
	w.WriteHeader(http.StatusNotImplemented)
}
|
||||
|
||||
// ServerInterfaceWrapper converts contexts to parameters.
// It decodes query parameters, threads auth scopes through the request
// context, and applies HandlerMiddlewares around each operation handler.
type ServerInterfaceWrapper struct {
	Handler            ServerInterface
	HandlerMiddlewares []MiddlewareFunc
	// ErrorHandlerFunc is invoked when parameter binding fails.
	ErrorHandlerFunc func(w http.ResponseWriter, r *http.Request, err error)
}

// MiddlewareFunc wraps an http.Handler with additional behavior.
type MiddlewareFunc func(http.Handler) http.Handler
|
||||
|
||||
// GetEnvs operation middleware
|
||||
func (siw *ServerInterfaceWrapper) GetEnvs(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
ctx = context.WithValue(ctx, AccessTokenAuthScopes, []string{})
|
||||
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
siw.Handler.GetEnvs(w, r)
|
||||
}))
|
||||
|
||||
for _, middleware := range siw.HandlerMiddlewares {
|
||||
handler = middleware(handler)
|
||||
}
|
||||
|
||||
handler.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// GetFiles operation middleware
// Binds the optional query parameters for GET /files and invokes the wrapped
// handler through the middleware chain. Binding failures are routed to
// ErrorHandlerFunc and short-circuit the request.
func (siw *ServerInterfaceWrapper) GetFiles(w http.ResponseWriter, r *http.Request) {

	var err error

	ctx := r.Context()

	ctx = context.WithValue(ctx, AccessTokenAuthScopes, []string{})

	r = r.WithContext(ctx)

	// Parameter object where we will unmarshal all parameters from the context
	var params GetFilesParams

	// ------------- Optional query parameter "path" -------------

	err = runtime.BindQueryParameterWithOptions("form", true, false, "path", r.URL.Query(), &params.Path, runtime.BindQueryParameterOptions{Type: "string", Format: ""})
	if err != nil {
		siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "path", Err: err})
		return
	}

	// ------------- Optional query parameter "username" -------------

	err = runtime.BindQueryParameterWithOptions("form", true, false, "username", r.URL.Query(), &params.Username, runtime.BindQueryParameterOptions{Type: "string", Format: ""})
	if err != nil {
		siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "username", Err: err})
		return
	}

	// ------------- Optional query parameter "signature" -------------

	err = runtime.BindQueryParameterWithOptions("form", true, false, "signature", r.URL.Query(), &params.Signature, runtime.BindQueryParameterOptions{Type: "string", Format: ""})
	if err != nil {
		siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "signature", Err: err})
		return
	}

	// ------------- Optional query parameter "signature_expiration" -------------

	err = runtime.BindQueryParameterWithOptions("form", true, false, "signature_expiration", r.URL.Query(), &params.SignatureExpiration, runtime.BindQueryParameterOptions{Type: "integer", Format: ""})
	if err != nil {
		siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "signature_expiration", Err: err})
		return
	}

	handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		siw.Handler.GetFiles(w, r, params)
	}))

	for _, middleware := range siw.HandlerMiddlewares {
		handler = middleware(handler)
	}

	handler.ServeHTTP(w, r)
}
|
||||
|
||||
// PostFiles operation middleware
// Binds the optional query parameters for POST /files (same set as GetFiles)
// and invokes the wrapped handler through the middleware chain.
func (siw *ServerInterfaceWrapper) PostFiles(w http.ResponseWriter, r *http.Request) {

	var err error

	ctx := r.Context()

	ctx = context.WithValue(ctx, AccessTokenAuthScopes, []string{})

	r = r.WithContext(ctx)

	// Parameter object where we will unmarshal all parameters from the context
	var params PostFilesParams

	// ------------- Optional query parameter "path" -------------

	err = runtime.BindQueryParameterWithOptions("form", true, false, "path", r.URL.Query(), &params.Path, runtime.BindQueryParameterOptions{Type: "string", Format: ""})
	if err != nil {
		siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "path", Err: err})
		return
	}

	// ------------- Optional query parameter "username" -------------

	err = runtime.BindQueryParameterWithOptions("form", true, false, "username", r.URL.Query(), &params.Username, runtime.BindQueryParameterOptions{Type: "string", Format: ""})
	if err != nil {
		siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "username", Err: err})
		return
	}

	// ------------- Optional query parameter "signature" -------------

	err = runtime.BindQueryParameterWithOptions("form", true, false, "signature", r.URL.Query(), &params.Signature, runtime.BindQueryParameterOptions{Type: "string", Format: ""})
	if err != nil {
		siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "signature", Err: err})
		return
	}

	// ------------- Optional query parameter "signature_expiration" -------------

	err = runtime.BindQueryParameterWithOptions("form", true, false, "signature_expiration", r.URL.Query(), &params.SignatureExpiration, runtime.BindQueryParameterOptions{Type: "integer", Format: ""})
	if err != nil {
		siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "signature_expiration", Err: err})
		return
	}

	handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		siw.Handler.PostFiles(w, r, params)
	}))

	for _, middleware := range siw.HandlerMiddlewares {
		handler = middleware(handler)
	}

	handler.ServeHTTP(w, r)
}
|
||||
|
||||
// GetHealth operation middleware
|
||||
func (siw *ServerInterfaceWrapper) GetHealth(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
siw.Handler.GetHealth(w, r)
|
||||
}))
|
||||
|
||||
for _, middleware := range siw.HandlerMiddlewares {
|
||||
handler = middleware(handler)
|
||||
}
|
||||
|
||||
handler.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// PostInit operation middleware
// Tags the request context with the auth scopes for this operation, then
// dispatches through the middleware chain to the underlying handler.
func (siw *ServerInterfaceWrapper) PostInit(w http.ResponseWriter, r *http.Request) {

	ctx := r.Context()

	ctx = context.WithValue(ctx, AccessTokenAuthScopes, []string{})

	r = r.WithContext(ctx)

	handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		siw.Handler.PostInit(w, r)
	}))

	for _, middleware := range siw.HandlerMiddlewares {
		handler = middleware(handler)
	}

	handler.ServeHTTP(w, r)
}
|
||||
|
||||
// GetMetrics operation middleware
// Tags the request context with the auth scopes for this operation, then
// dispatches through the middleware chain to the underlying handler.
func (siw *ServerInterfaceWrapper) GetMetrics(w http.ResponseWriter, r *http.Request) {

	ctx := r.Context()

	ctx = context.WithValue(ctx, AccessTokenAuthScopes, []string{})

	r = r.WithContext(ctx)

	handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		siw.Handler.GetMetrics(w, r)
	}))

	for _, middleware := range siw.HandlerMiddlewares {
		handler = middleware(handler)
	}

	handler.ServeHTTP(w, r)
}
|
||||
|
||||
// PostSnapshotPrepare operation middleware
|
||||
func (siw *ServerInterfaceWrapper) PostSnapshotPrepare(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
siw.Handler.PostSnapshotPrepare(w, r)
|
||||
}))
|
||||
|
||||
for _, middleware := range siw.HandlerMiddlewares {
|
||||
handler = middleware(handler)
|
||||
}
|
||||
|
||||
handler.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// UnescapedCookieParamError reports a cookie parameter that could not be
// URL-unescaped. All error types below implement error (and Unwrap where a
// cause exists) so callers can use errors.Is/As on binding failures.
type UnescapedCookieParamError struct {
	ParamName string
	Err       error
}

func (e *UnescapedCookieParamError) Error() string {
	return fmt.Sprintf("error unescaping cookie parameter '%s'", e.ParamName)
}

func (e *UnescapedCookieParamError) Unwrap() error {
	return e.Err
}

// UnmarshalingParamError reports a parameter whose JSON payload failed to
// unmarshal.
type UnmarshalingParamError struct {
	ParamName string
	Err       error
}

func (e *UnmarshalingParamError) Error() string {
	return fmt.Sprintf("Error unmarshaling parameter %s as JSON: %s", e.ParamName, e.Err.Error())
}

func (e *UnmarshalingParamError) Unwrap() error {
	return e.Err
}

// RequiredParamError reports a required query parameter that was absent.
type RequiredParamError struct {
	ParamName string
}

func (e *RequiredParamError) Error() string {
	return fmt.Sprintf("Query argument %s is required, but not found", e.ParamName)
}

// RequiredHeaderError reports a required header parameter that was absent.
type RequiredHeaderError struct {
	ParamName string
	Err       error
}

func (e *RequiredHeaderError) Error() string {
	return fmt.Sprintf("Header parameter %s is required, but not found", e.ParamName)
}

func (e *RequiredHeaderError) Unwrap() error {
	return e.Err
}

// InvalidParamFormatError reports a parameter value that failed to bind to
// its declared type.
type InvalidParamFormatError struct {
	ParamName string
	Err       error
}

func (e *InvalidParamFormatError) Error() string {
	return fmt.Sprintf("Invalid format for parameter %s: %s", e.ParamName, e.Err.Error())
}

func (e *InvalidParamFormatError) Unwrap() error {
	return e.Err
}

// TooManyValuesForParamError reports a parameter that was supplied more than
// once when a single value was expected.
type TooManyValuesForParamError struct {
	ParamName string
	Count     int
}

func (e *TooManyValuesForParamError) Error() string {
	return fmt.Sprintf("Expected one value for %s, got %d", e.ParamName, e.Count)
}
|
||||
|
||||
// Handler creates http.Handler with routing matching OpenAPI spec.
func Handler(si ServerInterface) http.Handler {
	return HandlerWithOptions(si, ChiServerOptions{})
}

// ChiServerOptions configures HandlerWithOptions: base URL prefix, the chi
// router to mount on, middlewares, and the binding-error handler.
type ChiServerOptions struct {
	BaseURL          string
	BaseRouter       chi.Router
	Middlewares      []MiddlewareFunc
	ErrorHandlerFunc func(w http.ResponseWriter, r *http.Request, err error)
}

// HandlerFromMux creates http.Handler with routing matching OpenAPI spec based on the provided mux.
func HandlerFromMux(si ServerInterface, r chi.Router) http.Handler {
	return HandlerWithOptions(si, ChiServerOptions{
		BaseRouter: r,
	})
}

// HandlerFromMuxWithBaseURL is HandlerFromMux with all routes mounted under
// the given baseURL prefix.
func HandlerFromMuxWithBaseURL(si ServerInterface, r chi.Router, baseURL string) http.Handler {
	return HandlerWithOptions(si, ChiServerOptions{
		BaseURL:    baseURL,
		BaseRouter: r,
	})
}
|
||||
|
||||
// HandlerWithOptions creates http.Handler with additional options
// It wires every OpenAPI operation to its ServerInterfaceWrapper method on
// the supplied (or a fresh) chi router. Binding errors default to a plain
// 400 response when no ErrorHandlerFunc is provided.
func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handler {
	r := options.BaseRouter

	if r == nil {
		r = chi.NewRouter()
	}
	if options.ErrorHandlerFunc == nil {
		options.ErrorHandlerFunc = func(w http.ResponseWriter, r *http.Request, err error) {
			http.Error(w, err.Error(), http.StatusBadRequest)
		}
	}
	wrapper := ServerInterfaceWrapper{
		Handler:            si,
		HandlerMiddlewares: options.Middlewares,
		ErrorHandlerFunc:   options.ErrorHandlerFunc,
	}

	// Each route gets its own group so per-route middleware stays isolated.
	r.Group(func(r chi.Router) {
		r.Get(options.BaseURL+"/envs", wrapper.GetEnvs)
	})
	r.Group(func(r chi.Router) {
		r.Get(options.BaseURL+"/files", wrapper.GetFiles)
	})
	r.Group(func(r chi.Router) {
		r.Post(options.BaseURL+"/files", wrapper.PostFiles)
	})
	r.Group(func(r chi.Router) {
		r.Get(options.BaseURL+"/health", wrapper.GetHealth)
	})
	r.Group(func(r chi.Router) {
		r.Post(options.BaseURL+"/init", wrapper.PostInit)
	})
	r.Group(func(r chi.Router) {
		r.Get(options.BaseURL+"/metrics", wrapper.GetMetrics)
	})
	r.Group(func(r chi.Router) {
		r.Post(options.BaseURL+"/snapshot/prepare", wrapper.PostSnapshotPrepare)
	})

	return r
}
|
||||
@ -1,133 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// Modifications by M/S Omukk
|
||||
|
||||
package api
|
||||
|
||||
import (
	"crypto/subtle"
	"errors"
	"fmt"
	"net/http"
	"slices"
	"strconv"
	"strings"
	"time"

	"github.com/awnumar/memguard"

	"git.omukk.dev/wrenn/sandbox/envd/internal/shared/keys"
)
|
||||
|
||||
const (
	// Operation names baked into the signing payload; they distinguish a
	// read signature from a write signature for the same path.
	SigningReadOperation  = "read"
	SigningWriteOperation = "write"

	// Header carrying the envd access token on authenticated requests.
	accessTokenHeader = "X-Access-Token"
)

// paths that are always allowed without general authentication
// POST/init is secured via MMDS hash validation instead
// Entries are keyed as METHOD+PATH (no separator), matching the lookup in
// WithAuthorization. GET/files and POST/files rely on per-request signature
// validation instead of the global token check.
var authExcludedPaths = []string{
	"GET/health",
	"GET/files",
	"POST/files",
	"POST/init",
	"POST/snapshot/prepare",
}
|
||||
|
||||
func (a *API) WithAuthorization(handler http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
if a.accessToken.IsSet() {
|
||||
authHeader := req.Header.Get(accessTokenHeader)
|
||||
|
||||
// check if this path is allowed without authentication (e.g., health check, endpoints supporting signing)
|
||||
allowedPath := slices.Contains(authExcludedPaths, req.Method+req.URL.Path)
|
||||
|
||||
if !a.accessToken.Equals(authHeader) && !allowedPath {
|
||||
a.logger.Error().Msg("Trying to access secured envd without correct access token")
|
||||
|
||||
err := fmt.Errorf("unauthorized access, please provide a valid access token or method signing if supported")
|
||||
jsonError(w, http.StatusUnauthorized, err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
})
|
||||
}
|
||||
|
||||
func (a *API) generateSignature(path string, username string, operation string, signatureExpiration *int64) (string, error) {
|
||||
tokenBytes, err := a.accessToken.Bytes()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("access token is not set: %w", err)
|
||||
}
|
||||
defer memguard.WipeBytes(tokenBytes)
|
||||
|
||||
var signature string
|
||||
hasher := keys.NewSHA256Hashing()
|
||||
|
||||
if signatureExpiration == nil {
|
||||
signature = strings.Join([]string{path, operation, username, string(tokenBytes)}, ":")
|
||||
} else {
|
||||
signature = strings.Join([]string{path, operation, username, string(tokenBytes), strconv.FormatInt(*signatureExpiration, 10)}, ":")
|
||||
}
|
||||
|
||||
return fmt.Sprintf("v1_%s", hasher.HashWithoutPrefix([]byte(signature))), nil
|
||||
}
|
||||
|
||||
func (a *API) validateSigning(r *http.Request, signature *string, signatureExpiration *int, username *string, path string, operation string) (err error) {
|
||||
var expectedSignature string
|
||||
|
||||
// no need to validate signing key if access token is not set
|
||||
if !a.accessToken.IsSet() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// check if access token is sent in the header
|
||||
tokenFromHeader := r.Header.Get(accessTokenHeader)
|
||||
if tokenFromHeader != "" {
|
||||
if !a.accessToken.Equals(tokenFromHeader) {
|
||||
return fmt.Errorf("access token present in header but does not match")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if signature == nil {
|
||||
return fmt.Errorf("missing signature query parameter")
|
||||
}
|
||||
|
||||
// Empty string is used when no username is provided and the default user should be used
|
||||
signatureUsername := ""
|
||||
if username != nil {
|
||||
signatureUsername = *username
|
||||
}
|
||||
|
||||
if signatureExpiration == nil {
|
||||
expectedSignature, err = a.generateSignature(path, signatureUsername, operation, nil)
|
||||
} else {
|
||||
exp := int64(*signatureExpiration)
|
||||
expectedSignature, err = a.generateSignature(path, signatureUsername, operation, &exp)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
a.logger.Error().Err(err).Msg("error generating signing key")
|
||||
|
||||
return errors.New("invalid signature")
|
||||
}
|
||||
|
||||
// signature validation
|
||||
if expectedSignature != *signature {
|
||||
return fmt.Errorf("invalid signature")
|
||||
}
|
||||
|
||||
// signature expiration
|
||||
if signatureExpiration != nil {
|
||||
exp := int64(*signatureExpiration)
|
||||
if exp < time.Now().Unix() {
|
||||
return fmt.Errorf("signature is already expired")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -1,64 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/keys"
|
||||
)
|
||||
|
||||
// TestKeyGenerationAlgorithmIsStable pins the v1 signature format with an
// expiration — SHA-256 over "path:operation:username:token:expiration",
// prefixed "v1_" — so refactors of generateSignature cannot silently change
// signatures already issued to clients.
func TestKeyGenerationAlgorithmIsStable(t *testing.T) {
	t.Parallel()
	apiToken := "secret-access-token"
	secureToken := &SecureToken{}
	err := secureToken.Set([]byte(apiToken))
	require.NoError(t, err)
	api := &API{accessToken: secureToken}

	path := "/path/to/demo.txt"
	username := "root"
	operation := "write"
	timestamp := time.Now().Unix()

	signature, err := api.generateSignature(path, username, operation, &timestamp)
	require.NoError(t, err)
	assert.NotEmpty(t, signature)

	// locally generated signature
	hasher := keys.NewSHA256Hashing()
	localSignatureTmp := fmt.Sprintf("%s:%s:%s:%s:%s", path, operation, username, apiToken, strconv.FormatInt(timestamp, 10))
	localSignature := fmt.Sprintf("v1_%s", hasher.HashWithoutPrefix([]byte(localSignatureTmp)))

	assert.Equal(t, localSignature, signature)
}
|
||||
|
||||
// TestKeyGenerationAlgorithmWithoutExpirationIsStable pins the v1 signature
// format for the nil-expiration case — SHA-256 over
// "path:operation:username:token" with no trailing expiration component.
func TestKeyGenerationAlgorithmWithoutExpirationIsStable(t *testing.T) {
	t.Parallel()
	apiToken := "secret-access-token"
	secureToken := &SecureToken{}
	err := secureToken.Set([]byte(apiToken))
	require.NoError(t, err)
	api := &API{accessToken: secureToken}

	path := "/path/to/resource.txt"
	username := "user"
	operation := "read"

	signature, err := api.generateSignature(path, username, operation, nil)
	require.NoError(t, err)
	assert.NotEmpty(t, signature)

	// locally generated signature
	hasher := keys.NewSHA256Hashing()
	localSignatureTmp := fmt.Sprintf("%s:%s:%s:%s", path, operation, username, apiToken)
	localSignature := fmt.Sprintf("v1_%s", hasher.HashWithoutPrefix([]byte(localSignatureTmp)))

	assert.Equal(t, localSignature, signature)
}
|
||||
@ -1,10 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# yaml-language-server: $schema=https://raw.githubusercontent.com/deepmap/oapi-codegen/HEAD/configuration-schema.json
|
||||
|
||||
package: api
|
||||
output: api.gen.go
|
||||
generate:
|
||||
models: true
|
||||
chi-server: true
|
||||
client: false
|
||||
@ -1,187 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// Modifications by M/S Omukk
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
|
||||
)
|
||||
|
||||
// GetFiles streams a file from the sandbox filesystem.
//
// Request flow: signature/token validation, username resolution, path
// expansion, stat checks (must exist and be a regular file), Accept-Encoding
// negotiation, then the body — gzip-compressed when negotiated, otherwise via
// http.ServeContent (which also handles Range and conditional requests).
func (a *API) GetFiles(w http.ResponseWriter, r *http.Request, params GetFilesParams) {
	defer r.Body.Close()

	// Filled in by the error paths below and consumed by the deferred
	// request-summary log statement.
	var errorCode int
	var errMsg error

	var path string
	if params.Path != nil {
		path = *params.Path
	}

	operationID := logs.AssignOperationID()

	// signing authorization if needed
	err := a.validateSigning(r, params.Signature, params.SignatureExpiration, params.Username, path, SigningReadOperation)
	if err != nil {
		a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("error during auth validation")
		jsonError(w, http.StatusUnauthorized, err)

		return
	}

	username, err := execcontext.ResolveDefaultUsername(params.Username, a.defaults.User)
	if err != nil {
		a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("no user specified")
		jsonError(w, http.StatusBadRequest, err)

		return
	}

	// One structured "File read" log line per request, success or failure.
	defer func() {
		l := a.logger.
			Err(errMsg).
			Str("method", r.Method+" "+r.URL.Path).
			Str(string(logs.OperationIDKey), operationID).
			Str("path", path).
			Str("username", username)

		if errMsg != nil {
			l = l.Int("error_code", errorCode)
		}

		l.Msg("File read")
	}()

	u, err := user.Lookup(username)
	if err != nil {
		errMsg = fmt.Errorf("error looking up user '%s': %w", username, err)
		errorCode = http.StatusUnauthorized
		jsonError(w, errorCode, errMsg)

		return
	}

	resolvedPath, err := permissions.ExpandAndResolve(path, u, a.defaults.Workdir)
	if err != nil {
		errMsg = fmt.Errorf("error expanding and resolving path '%s': %w", path, err)
		errorCode = http.StatusBadRequest
		jsonError(w, errorCode, errMsg)

		return
	}

	stat, err := os.Stat(resolvedPath)
	if err != nil {
		if errors.Is(err, os.ErrNotExist) {
			errMsg = fmt.Errorf("path '%s' does not exist", resolvedPath)
			errorCode = http.StatusNotFound
			jsonError(w, errorCode, errMsg)

			return
		}

		errMsg = fmt.Errorf("error checking if path exists '%s': %w", resolvedPath, err)
		errorCode = http.StatusInternalServerError
		jsonError(w, errorCode, errMsg)

		return
	}

	if stat.IsDir() {
		errMsg = fmt.Errorf("path '%s' is a directory", resolvedPath)
		errorCode = http.StatusBadRequest
		jsonError(w, errorCode, errMsg)

		return
	}

	// Reject anything that isn't a regular file (devices, pipes, sockets, etc.).
	// Reading device files like /dev/zero or /dev/urandom produces infinite data
	// and will exhaust memory on all layers of the stack.
	if !stat.Mode().IsRegular() {
		errMsg = fmt.Errorf("path '%s' is not a regular file", resolvedPath)
		errorCode = http.StatusBadRequest
		jsonError(w, errorCode, errMsg)

		return
	}

	// Validate Accept-Encoding header
	encoding, err := parseAcceptEncoding(r)
	if err != nil {
		errMsg = fmt.Errorf("error parsing Accept-Encoding: %w", err)
		errorCode = http.StatusNotAcceptable
		jsonError(w, errorCode, errMsg)

		return
	}

	// Tell caches to store separate variants for different Accept-Encoding values
	w.Header().Set("Vary", "Accept-Encoding")

	// Fall back to identity for Range or conditional requests to preserve http.ServeContent
	// behavior (206 Partial Content, 304 Not Modified). However, we must check if identity
	// is acceptable per the Accept-Encoding header.
	hasRangeOrConditional := r.Header.Get("Range") != "" ||
		r.Header.Get("If-Modified-Since") != "" ||
		r.Header.Get("If-None-Match") != "" ||
		r.Header.Get("If-Range") != ""
	if hasRangeOrConditional {
		if !isIdentityAcceptable(r) {
			errMsg = fmt.Errorf("identity encoding not acceptable for Range or conditional request")
			errorCode = http.StatusNotAcceptable
			jsonError(w, errorCode, errMsg)

			return
		}
		encoding = EncodingIdentity
	}

	file, err := os.Open(resolvedPath)
	if err != nil {
		errMsg = fmt.Errorf("error opening file '%s': %w", resolvedPath, err)
		errorCode = http.StatusInternalServerError
		jsonError(w, errorCode, errMsg)

		return
	}
	defer file.Close()

	w.Header().Set("Content-Disposition", mime.FormatMediaType("inline", map[string]string{"filename": filepath.Base(resolvedPath)}))

	// Serve with gzip encoding if requested.
	if encoding == EncodingGzip {
		w.Header().Set("Content-Encoding", EncodingGzip)

		// Set Content-Type based on file extension, preserving the original type
		contentType := mime.TypeByExtension(filepath.Ext(path))
		if contentType == "" {
			contentType = "application/octet-stream"
		}
		w.Header().Set("Content-Type", contentType)

		gw := gzip.NewWriter(w)
		defer gw.Close()

		// Headers are already sent once copying starts, so a mid-stream error
		// can only be logged, not reported to the client.
		_, err = io.Copy(gw, file)
		if err != nil {
			a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("error writing gzip response")
		}

		return
	}

	http.ServeContent(w, r, path, stat.ModTime(), file)
}
|
||||
@ -1,405 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// Modifications by M/S Omukk
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
|
||||
)
|
||||
|
||||
// TestGetFilesContentDisposition verifies that GetFiles sets a
// Content-Disposition header whose filename parameter is correctly
// quoted/encoded for a variety of filenames: plain names, spaces, embedded
// quotes and backslashes, non-ASCII names (RFC 5987 filename* form), and
// dotfiles.
func TestGetFilesContentDisposition(t *testing.T) {
	t.Parallel()

	// The handler resolves files as a concrete user; run as the test user.
	currentUser, err := user.Current()
	require.NoError(t, err)

	tests := []struct {
		name           string // subtest name
		filename       string // file created on disk and requested
		expectedHeader string // exact Content-Disposition value expected
	}{
		{
			name:           "simple filename",
			filename:       "test.txt",
			expectedHeader: `inline; filename=test.txt`,
		},
		{
			name:           "filename with extension",
			filename:       "presentation.pptx",
			expectedHeader: `inline; filename=presentation.pptx`,
		},
		{
			name:           "filename with multiple dots",
			filename:       "archive.tar.gz",
			expectedHeader: `inline; filename=archive.tar.gz`,
		},
		{
			name:           "filename with spaces",
			filename:       "my document.pdf",
			expectedHeader: `inline; filename="my document.pdf"`,
		},
		{
			name:           "filename with quotes",
			filename:       `file"name.txt`,
			expectedHeader: `inline; filename="file\"name.txt"`,
		},
		{
			name:           "filename with backslash",
			filename:       `file\name.txt`,
			expectedHeader: `inline; filename="file\\name.txt"`,
		},
		{
			name:           "unicode filename",
			filename:       "\u6587\u6863.pdf", // 文档.pdf in Chinese
			expectedHeader: "inline; filename*=utf-8''%E6%96%87%E6%A1%A3.pdf",
		},
		{
			name:           "dotfile preserved",
			filename:       ".env",
			expectedHeader: `inline; filename=.env`,
		},
		{
			name:           "dotfile with extension preserved",
			filename:       ".gitignore",
			expectedHeader: `inline; filename=.gitignore`,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			// Create a temp directory and file
			tempDir := t.TempDir()
			tempFile := filepath.Join(tempDir, tt.filename)
			err := os.WriteFile(tempFile, []byte("test content"), 0o644)
			require.NoError(t, err)

			// Create test API
			logger := zerolog.Nop()
			defaults := &execcontext.Defaults{
				EnvVars: utils.NewMap[string, string](),
				User:    currentUser.Username,
			}
			api := New(&logger, defaults, nil, false, context.Background(), nil, "test")

			// Create request and response recorder
			req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil)
			w := httptest.NewRecorder()

			// Call the handler
			params := GetFilesParams{
				Path:     &tempFile,
				Username: &currentUser.Username,
			}
			api.GetFiles(w, req, params)

			// Check response
			resp := w.Result()
			defer resp.Body.Close()

			assert.Equal(t, http.StatusOK, resp.StatusCode)

			// Verify Content-Disposition header
			contentDisposition := resp.Header.Get("Content-Disposition")
			assert.Equal(t, tt.expectedHeader, contentDisposition, "Content-Disposition header should be set with correct filename")
		})
	}
}
|
||||
|
||||
// TestGetFilesContentDispositionWithNestedPath verifies that when a file in a
// nested directory is requested, the Content-Disposition filename parameter is
// the base name only — no directory components leak into the header.
func TestGetFilesContentDispositionWithNestedPath(t *testing.T) {
	t.Parallel()

	currentUser, err := user.Current()
	require.NoError(t, err)

	// Create a temp directory with nested structure
	tempDir := t.TempDir()
	nestedDir := filepath.Join(tempDir, "subdir", "another")
	err = os.MkdirAll(nestedDir, 0o755)
	require.NoError(t, err)

	filename := "document.pdf"
	tempFile := filepath.Join(nestedDir, filename)
	err = os.WriteFile(tempFile, []byte("test content"), 0o644)
	require.NoError(t, err)

	// Create test API
	logger := zerolog.Nop()
	defaults := &execcontext.Defaults{
		EnvVars: utils.NewMap[string, string](),
		User:    currentUser.Username,
	}
	api := New(&logger, defaults, nil, false, context.Background(), nil, "test")

	// Create request and response recorder
	req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil)
	w := httptest.NewRecorder()

	// Call the handler
	params := GetFilesParams{
		Path:     &tempFile,
		Username: &currentUser.Username,
	}
	api.GetFiles(w, req, params)

	// Check response
	resp := w.Result()
	defer resp.Body.Close()

	assert.Equal(t, http.StatusOK, resp.StatusCode)

	// Verify Content-Disposition header uses only the base filename, not the full path
	contentDisposition := resp.Header.Get("Content-Disposition")
	assert.Equal(t, `inline; filename=document.pdf`, contentDisposition, "Content-Disposition should contain only the filename, not the path")
}
|
||||
|
||||
// TestGetFiles_GzipEncoding_ExplicitIdentityOffWithRange sends
// "Accept-Encoding: gzip; q=1,*; q=0" (identity rejected via the wildcard)
// together with a Range request and expects 406 Not Acceptable.
// NOTE(review): the 406 presumably comes from the handler refusing to serve a
// Range request when it cannot fall back to identity encoding — confirm
// against the GetFiles implementation.
func TestGetFiles_GzipEncoding_ExplicitIdentityOffWithRange(t *testing.T) {
	t.Parallel()

	currentUser, err := user.Current()
	require.NoError(t, err)

	// Create a temp directory with a test file
	tempDir := t.TempDir()
	filename := "document.pdf"
	tempFile := filepath.Join(tempDir, filename)
	err = os.WriteFile(tempFile, []byte("test content"), 0o644)
	require.NoError(t, err)

	// Create test API
	logger := zerolog.Nop()
	defaults := &execcontext.Defaults{
		EnvVars: utils.NewMap[string, string](),
		User:    currentUser.Username,
	}
	api := New(&logger, defaults, nil, false, context.Background(), nil, "test")

	// Create request and response recorder
	req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil)
	req.Header.Set("Accept-Encoding", "gzip; q=1,*; q=0")
	req.Header.Set("Range", "bytes=0-4") // Request first 5 bytes
	w := httptest.NewRecorder()

	// Call the handler
	params := GetFilesParams{
		Path:     &tempFile,
		Username: &currentUser.Username,
	}
	api.GetFiles(w, req, params)

	// Check response
	resp := w.Result()
	defer resp.Body.Close()

	assert.Equal(t, http.StatusNotAcceptable, resp.StatusCode)
}
|
||||
|
||||
// TestGetFiles_GzipDownload requests a file with "Accept-Encoding: gzip" and
// verifies the response carries Content-Encoding: gzip, a Content-Type derived
// from the file extension, and a body that decompresses to the original file
// content.
func TestGetFiles_GzipDownload(t *testing.T) {
	t.Parallel()

	currentUser, err := user.Current()
	require.NoError(t, err)

	originalContent := []byte("hello world, this is a test file for gzip compression")

	// Create a temp file with known content
	tempDir := t.TempDir()
	tempFile := filepath.Join(tempDir, "test.txt")
	err = os.WriteFile(tempFile, originalContent, 0o644)
	require.NoError(t, err)

	logger := zerolog.Nop()
	defaults := &execcontext.Defaults{
		EnvVars: utils.NewMap[string, string](),
		User:    currentUser.Username,
	}
	api := New(&logger, defaults, nil, false, context.Background(), nil, "test")

	req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil)
	req.Header.Set("Accept-Encoding", "gzip")
	w := httptest.NewRecorder()

	params := GetFilesParams{
		Path:     &tempFile,
		Username: &currentUser.Username,
	}
	api.GetFiles(w, req, params)

	resp := w.Result()
	defer resp.Body.Close()

	assert.Equal(t, http.StatusOK, resp.StatusCode)
	assert.Equal(t, "gzip", resp.Header.Get("Content-Encoding"))
	assert.Equal(t, "text/plain; charset=utf-8", resp.Header.Get("Content-Type"))

	// Decompress the gzip response body
	gzReader, err := gzip.NewReader(resp.Body)
	require.NoError(t, err)
	defer gzReader.Close()

	decompressed, err := io.ReadAll(gzReader)
	require.NoError(t, err)

	assert.Equal(t, originalContent, decompressed)
}
|
||||
|
||||
// TestPostFiles_GzipUpload uploads a gzip-compressed multipart body (the whole
// multipart payload is compressed, signalled via "Content-Encoding: gzip") and
// verifies PostFiles writes the decompressed content to the destination path.
func TestPostFiles_GzipUpload(t *testing.T) {
	t.Parallel()

	currentUser, err := user.Current()
	require.NoError(t, err)

	originalContent := []byte("hello world, this is a test file uploaded with gzip")

	// Build a multipart body
	var multipartBuf bytes.Buffer
	mpWriter := multipart.NewWriter(&multipartBuf)
	part, err := mpWriter.CreateFormFile("file", "uploaded.txt")
	require.NoError(t, err)
	_, err = part.Write(originalContent)
	require.NoError(t, err)
	err = mpWriter.Close()
	require.NoError(t, err)

	// Gzip-compress the entire multipart body
	var gzBuf bytes.Buffer
	gzWriter := gzip.NewWriter(&gzBuf)
	_, err = gzWriter.Write(multipartBuf.Bytes())
	require.NoError(t, err)
	err = gzWriter.Close()
	require.NoError(t, err)

	// Create test API
	tempDir := t.TempDir()
	destPath := filepath.Join(tempDir, "uploaded.txt")

	logger := zerolog.Nop()
	defaults := &execcontext.Defaults{
		EnvVars: utils.NewMap[string, string](),
		User:    currentUser.Username,
	}
	api := New(&logger, defaults, nil, false, context.Background(), nil, "test")

	// Content-Type still advertises the multipart boundary; only the transfer
	// encoding of the body is gzip.
	req := httptest.NewRequest(http.MethodPost, "/files?path="+url.QueryEscape(destPath), &gzBuf)
	req.Header.Set("Content-Type", mpWriter.FormDataContentType())
	req.Header.Set("Content-Encoding", "gzip")
	w := httptest.NewRecorder()

	params := PostFilesParams{
		Path:     &destPath,
		Username: &currentUser.Username,
	}
	api.PostFiles(w, req, params)

	resp := w.Result()
	defer resp.Body.Close()

	assert.Equal(t, http.StatusOK, resp.StatusCode)

	// Verify the file was written with the original (decompressed) content
	data, err := os.ReadFile(destPath)
	require.NoError(t, err)
	assert.Equal(t, originalContent, data)
}
|
||||
|
||||
// TestGzipUploadThenGzipDownload is an end-to-end round trip: upload a file
// with a gzip-compressed multipart body, then download it with
// "Accept-Encoding: gzip", and verify the decompressed download matches the
// originally uploaded bytes.
func TestGzipUploadThenGzipDownload(t *testing.T) {
	t.Parallel()

	currentUser, err := user.Current()
	require.NoError(t, err)

	originalContent := []byte("round-trip gzip test: upload compressed, download compressed, verify match")

	// --- Upload with gzip ---

	// Build a multipart body
	var multipartBuf bytes.Buffer
	mpWriter := multipart.NewWriter(&multipartBuf)
	part, err := mpWriter.CreateFormFile("file", "roundtrip.txt")
	require.NoError(t, err)
	_, err = part.Write(originalContent)
	require.NoError(t, err)
	err = mpWriter.Close()
	require.NoError(t, err)

	// Gzip-compress the entire multipart body
	var gzBuf bytes.Buffer
	gzWriter := gzip.NewWriter(&gzBuf)
	_, err = gzWriter.Write(multipartBuf.Bytes())
	require.NoError(t, err)
	err = gzWriter.Close()
	require.NoError(t, err)

	tempDir := t.TempDir()
	destPath := filepath.Join(tempDir, "roundtrip.txt")

	logger := zerolog.Nop()
	defaults := &execcontext.Defaults{
		EnvVars: utils.NewMap[string, string](),
		User:    currentUser.Username,
	}
	api := New(&logger, defaults, nil, false, context.Background(), nil, "test")

	uploadReq := httptest.NewRequest(http.MethodPost, "/files?path="+url.QueryEscape(destPath), &gzBuf)
	uploadReq.Header.Set("Content-Type", mpWriter.FormDataContentType())
	uploadReq.Header.Set("Content-Encoding", "gzip")
	uploadW := httptest.NewRecorder()

	uploadParams := PostFilesParams{
		Path:     &destPath,
		Username: &currentUser.Username,
	}
	api.PostFiles(uploadW, uploadReq, uploadParams)

	uploadResp := uploadW.Result()
	defer uploadResp.Body.Close()

	// Upload must succeed before the download half is meaningful.
	require.Equal(t, http.StatusOK, uploadResp.StatusCode)

	// --- Download with gzip ---

	downloadReq := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(destPath), nil)
	downloadReq.Header.Set("Accept-Encoding", "gzip")
	downloadW := httptest.NewRecorder()

	downloadParams := GetFilesParams{
		Path:     &destPath,
		Username: &currentUser.Username,
	}
	api.GetFiles(downloadW, downloadReq, downloadParams)

	downloadResp := downloadW.Result()
	defer downloadResp.Body.Close()

	require.Equal(t, http.StatusOK, downloadResp.StatusCode)
	assert.Equal(t, "gzip", downloadResp.Header.Get("Content-Encoding"))

	// Decompress and verify content matches original
	gzReader, err := gzip.NewReader(downloadResp.Body)
	require.NoError(t, err)
	defer gzReader.Close()

	decompressed, err := io.ReadAll(gzReader)
	require.NoError(t, err)

	assert.Equal(t, originalContent, decompressed)
}
|
||||
@ -1,229 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"slices"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
	// EncodingGzip is the gzip content encoding.
	EncodingGzip = "gzip"
	// EncodingIdentity means no encoding (passthrough).
	EncodingIdentity = "identity"
	// EncodingWildcard means any encoding is acceptable.
	EncodingWildcard = "*"
)

// SupportedEncodings lists the content encodings supported for file transfer.
// The order matters - encodings are checked in order of preference, and
// SupportedEncodings[0] is what a wildcard resolves to when identity is
// rejected (see parseAcceptEncoding).
var SupportedEncodings = []string{
	EncodingGzip,
}

// encodingWithQuality holds an encoding name and its quality value, i.e. one
// parsed element of an Accept-Encoding/Content-Encoding header.
type encodingWithQuality struct {
	encoding string  // lowercased content-coding name, e.g. "gzip"
	quality  float64 // RFC 7231 qvalue; 1.0 when not specified
}
|
||||
|
||||
// isSupportedEncoding checks if the given encoding is in the supported list.
|
||||
// Per RFC 7231, content-coding values are case-insensitive.
|
||||
func isSupportedEncoding(encoding string) bool {
|
||||
return slices.Contains(SupportedEncodings, strings.ToLower(encoding))
|
||||
}
|
||||
|
||||
// parseEncodingWithQuality parses an encoding value and extracts the quality.
|
||||
// Returns the encoding name (lowercased) and quality value (default 1.0 if not specified).
|
||||
// Per RFC 7231, content-coding values are case-insensitive.
|
||||
func parseEncodingWithQuality(value string) encodingWithQuality {
|
||||
value = strings.TrimSpace(value)
|
||||
quality := 1.0
|
||||
|
||||
if idx := strings.Index(value, ";"); idx != -1 {
|
||||
params := value[idx+1:]
|
||||
value = strings.TrimSpace(value[:idx])
|
||||
|
||||
// Parse q=X.X parameter
|
||||
for param := range strings.SplitSeq(params, ";") {
|
||||
param = strings.TrimSpace(param)
|
||||
if strings.HasPrefix(strings.ToLower(param), "q=") {
|
||||
if q, err := strconv.ParseFloat(param[2:], 64); err == nil {
|
||||
quality = q
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Normalize encoding to lowercase per RFC 7231
|
||||
return encodingWithQuality{encoding: strings.ToLower(value), quality: quality}
|
||||
}
|
||||
|
||||
// parseEncoding extracts the encoding name from a header value, stripping quality.
|
||||
func parseEncoding(value string) string {
|
||||
return parseEncodingWithQuality(value).encoding
|
||||
}
|
||||
|
||||
// parseContentEncoding parses the Content-Encoding header and returns the encoding.
|
||||
// Returns an error if an unsupported encoding is specified.
|
||||
// If no Content-Encoding header is present, returns empty string.
|
||||
func parseContentEncoding(r *http.Request) (string, error) {
|
||||
header := r.Header.Get("Content-Encoding")
|
||||
if header == "" {
|
||||
return EncodingIdentity, nil
|
||||
}
|
||||
|
||||
encoding := parseEncoding(header)
|
||||
|
||||
if encoding == EncodingIdentity {
|
||||
return EncodingIdentity, nil
|
||||
}
|
||||
|
||||
if !isSupportedEncoding(encoding) {
|
||||
return "", fmt.Errorf("unsupported Content-Encoding: %s, supported: %v", header, SupportedEncodings)
|
||||
}
|
||||
|
||||
return encoding, nil
|
||||
}
|
||||
|
||||
// parseAcceptEncodingHeader parses the Accept-Encoding header and returns
// the parsed encodings along with the identity rejection state.
// Per RFC 7231 Section 5.3.4, identity is acceptable unless excluded by
// "identity;q=0" or "*;q=0" without a more specific entry for identity with q>0.
func parseAcceptEncodingHeader(header string) ([]encodingWithQuality, bool) {
	if header == "" {
		return nil, false // identity not rejected when header is empty
	}

	// Parse all encodings with their quality values
	var encodings []encodingWithQuality
	for value := range strings.SplitSeq(header, ",") {
		eq := parseEncodingWithQuality(value)
		encodings = append(encodings, eq)
	}

	// Check if identity is rejected per RFC 7231 Section 5.3.4:
	// identity is acceptable unless excluded by "identity;q=0" or "*;q=0"
	// without a more specific entry for identity with q>0.
	identityRejected := false
	identityExplicitlyAccepted := false
	wildcardRejected := false

	for _, eq := range encodings {
		switch eq.encoding {
		case EncodingIdentity:
			if eq.quality == 0 {
				identityRejected = true
			} else {
				identityExplicitlyAccepted = true
			}
		case EncodingWildcard:
			if eq.quality == 0 {
				wildcardRejected = true
			}
		}
	}

	// "*;q=0" rejects identity only when no explicit "identity;q>0" entry
	// overrides the wildcard.
	if wildcardRejected && !identityExplicitlyAccepted {
		identityRejected = true
	}

	return encodings, identityRejected
}
|
||||
|
||||
// isIdentityAcceptable checks if identity encoding is acceptable based on the
|
||||
// Accept-Encoding header. Per RFC 7231 section 5.3.4, identity is always
|
||||
// implicitly acceptable unless explicitly rejected with q=0.
|
||||
func isIdentityAcceptable(r *http.Request) bool {
|
||||
header := r.Header.Get("Accept-Encoding")
|
||||
_, identityRejected := parseAcceptEncodingHeader(header)
|
||||
|
||||
return !identityRejected
|
||||
}
|
||||
|
||||
// parseAcceptEncoding parses the Accept-Encoding header and returns the best
|
||||
// supported encoding based on quality values. Per RFC 7231 section 5.3.4,
|
||||
// identity is always implicitly acceptable unless explicitly rejected with q=0.
|
||||
// If no Accept-Encoding header is present, returns empty string (identity).
|
||||
func parseAcceptEncoding(r *http.Request) (string, error) {
|
||||
header := r.Header.Get("Accept-Encoding")
|
||||
if header == "" {
|
||||
return EncodingIdentity, nil
|
||||
}
|
||||
|
||||
encodings, identityRejected := parseAcceptEncodingHeader(header)
|
||||
|
||||
// Sort by quality value (highest first)
|
||||
sort.Slice(encodings, func(i, j int) bool {
|
||||
return encodings[i].quality > encodings[j].quality
|
||||
})
|
||||
|
||||
// Find the best supported encoding
|
||||
for _, eq := range encodings {
|
||||
// Skip encodings with q=0 (explicitly rejected)
|
||||
if eq.quality == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if eq.encoding == EncodingIdentity {
|
||||
return EncodingIdentity, nil
|
||||
}
|
||||
|
||||
// Wildcard means any encoding is acceptable - return a supported encoding if identity is rejected
|
||||
if eq.encoding == EncodingWildcard {
|
||||
if identityRejected && len(SupportedEncodings) > 0 {
|
||||
return SupportedEncodings[0], nil
|
||||
}
|
||||
|
||||
return EncodingIdentity, nil
|
||||
}
|
||||
|
||||
if isSupportedEncoding(eq.encoding) {
|
||||
return eq.encoding, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Per RFC 7231, identity is implicitly acceptable unless rejected
|
||||
if !identityRejected {
|
||||
return EncodingIdentity, nil
|
||||
}
|
||||
|
||||
// Identity rejected and no supported encodings found
|
||||
return "", fmt.Errorf("no acceptable encoding found, supported: %v", SupportedEncodings)
|
||||
}
|
||||
|
||||
// getDecompressedBody returns a reader that decompresses the request body based on
|
||||
// Content-Encoding header. Returns the original body if no encoding is specified.
|
||||
// Returns an error if an unsupported encoding is specified.
|
||||
// The caller is responsible for closing both the returned ReadCloser and the
|
||||
// original request body (r.Body) separately.
|
||||
func getDecompressedBody(r *http.Request) (io.ReadCloser, error) {
|
||||
encoding, err := parseContentEncoding(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if encoding == EncodingIdentity {
|
||||
return r.Body, nil
|
||||
}
|
||||
|
||||
switch encoding {
|
||||
case EncodingGzip:
|
||||
gzReader, err := gzip.NewReader(r.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create gzip reader: %w", err)
|
||||
}
|
||||
|
||||
return gzReader, nil
|
||||
default:
|
||||
// This shouldn't happen if isSupportedEncoding is correct
|
||||
return nil, fmt.Errorf("encoding %s is supported but not implemented", encoding)
|
||||
}
|
||||
}
|
||||
@ -1,496 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIsSupportedEncoding(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("gzip is supported", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.True(t, isSupportedEncoding("gzip"))
|
||||
})
|
||||
|
||||
t.Run("GZIP is supported (case-insensitive)", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.True(t, isSupportedEncoding("GZIP"))
|
||||
})
|
||||
|
||||
t.Run("Gzip is supported (case-insensitive)", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.True(t, isSupportedEncoding("Gzip"))
|
||||
})
|
||||
|
||||
t.Run("br is not supported", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.False(t, isSupportedEncoding("br"))
|
||||
})
|
||||
|
||||
t.Run("deflate is not supported", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.False(t, isSupportedEncoding("deflate"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseEncodingWithQuality(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("returns encoding with default quality 1.0", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
eq := parseEncodingWithQuality("gzip")
|
||||
assert.Equal(t, "gzip", eq.encoding)
|
||||
assert.InDelta(t, 1.0, eq.quality, 0.001)
|
||||
})
|
||||
|
||||
t.Run("parses quality value", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
eq := parseEncodingWithQuality("gzip;q=0.5")
|
||||
assert.Equal(t, "gzip", eq.encoding)
|
||||
assert.InDelta(t, 0.5, eq.quality, 0.001)
|
||||
})
|
||||
|
||||
t.Run("parses quality value with whitespace", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
eq := parseEncodingWithQuality("gzip ; q=0.8")
|
||||
assert.Equal(t, "gzip", eq.encoding)
|
||||
assert.InDelta(t, 0.8, eq.quality, 0.001)
|
||||
})
|
||||
|
||||
t.Run("handles q=0", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
eq := parseEncodingWithQuality("gzip;q=0")
|
||||
assert.Equal(t, "gzip", eq.encoding)
|
||||
assert.InDelta(t, 0.0, eq.quality, 0.001)
|
||||
})
|
||||
|
||||
t.Run("handles invalid quality value", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
eq := parseEncodingWithQuality("gzip;q=invalid")
|
||||
assert.Equal(t, "gzip", eq.encoding)
|
||||
assert.InDelta(t, 1.0, eq.quality, 0.001) // defaults to 1.0 on parse error
|
||||
})
|
||||
|
||||
t.Run("trims whitespace from encoding", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
eq := parseEncodingWithQuality(" gzip ")
|
||||
assert.Equal(t, "gzip", eq.encoding)
|
||||
assert.InDelta(t, 1.0, eq.quality, 0.001)
|
||||
})
|
||||
|
||||
t.Run("normalizes encoding to lowercase", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
eq := parseEncodingWithQuality("GZIP")
|
||||
assert.Equal(t, "gzip", eq.encoding)
|
||||
})
|
||||
|
||||
t.Run("normalizes mixed case encoding", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
eq := parseEncodingWithQuality("Gzip;q=0.5")
|
||||
assert.Equal(t, "gzip", eq.encoding)
|
||||
assert.InDelta(t, 0.5, eq.quality, 0.001)
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseEncoding(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("returns encoding as-is", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.Equal(t, "gzip", parseEncoding("gzip"))
|
||||
})
|
||||
|
||||
t.Run("trims whitespace", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.Equal(t, "gzip", parseEncoding(" gzip "))
|
||||
})
|
||||
|
||||
t.Run("strips quality value", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.Equal(t, "gzip", parseEncoding("gzip;q=1.0"))
|
||||
})
|
||||
|
||||
t.Run("strips quality value with whitespace", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.Equal(t, "gzip", parseEncoding("gzip ; q=0.5"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseContentEncoding(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("returns identity when no header", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
|
||||
|
||||
encoding, err := parseContentEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingIdentity, encoding)
|
||||
})
|
||||
|
||||
t.Run("returns gzip when Content-Encoding is gzip", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
|
||||
req.Header.Set("Content-Encoding", "gzip")
|
||||
|
||||
encoding, err := parseContentEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gzip", encoding)
|
||||
})
|
||||
|
||||
t.Run("returns gzip when Content-Encoding is GZIP (case-insensitive)", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
|
||||
req.Header.Set("Content-Encoding", "GZIP")
|
||||
|
||||
encoding, err := parseContentEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gzip", encoding)
|
||||
})
|
||||
|
||||
t.Run("returns gzip when Content-Encoding is Gzip (case-insensitive)", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
|
||||
req.Header.Set("Content-Encoding", "Gzip")
|
||||
|
||||
encoding, err := parseContentEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gzip", encoding)
|
||||
})
|
||||
|
||||
t.Run("returns identity for identity encoding", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
|
||||
req.Header.Set("Content-Encoding", "identity")
|
||||
|
||||
encoding, err := parseContentEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingIdentity, encoding)
|
||||
})
|
||||
|
||||
t.Run("returns error for unsupported encoding", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
|
||||
req.Header.Set("Content-Encoding", "br")
|
||||
|
||||
_, err := parseContentEncoding(req)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unsupported Content-Encoding")
|
||||
assert.Contains(t, err.Error(), "supported: [gzip]")
|
||||
})
|
||||
|
||||
t.Run("handles gzip with quality value", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
|
||||
req.Header.Set("Content-Encoding", "gzip;q=1.0")
|
||||
|
||||
encoding, err := parseContentEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gzip", encoding)
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseAcceptEncoding(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("returns identity when no header", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingIdentity, encoding)
|
||||
})
|
||||
|
||||
t.Run("returns gzip when Accept-Encoding is gzip", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gzip", encoding)
|
||||
})
|
||||
|
||||
t.Run("returns gzip when Accept-Encoding is GZIP (case-insensitive)", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "GZIP")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gzip", encoding)
|
||||
})
|
||||
|
||||
t.Run("returns gzip when gzip is among multiple encodings", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "deflate, gzip, br")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gzip", encoding)
|
||||
})
|
||||
|
||||
t.Run("returns gzip with quality value", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip;q=1.0")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gzip", encoding)
|
||||
})
|
||||
|
||||
t.Run("returns identity for identity encoding", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "identity")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingIdentity, encoding)
|
||||
})
|
||||
|
||||
t.Run("returns identity for wildcard encoding", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "*")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingIdentity, encoding)
|
||||
})
|
||||
|
||||
t.Run("falls back to identity for unsupported encoding only", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "br")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingIdentity, encoding)
|
||||
})
|
||||
|
||||
t.Run("falls back to identity when only unsupported encodings", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "deflate, br")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingIdentity, encoding)
|
||||
})
|
||||
|
||||
t.Run("selects gzip when it has highest quality", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "br;q=0.5, gzip;q=1.0, deflate;q=0.8")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gzip", encoding)
|
||||
})
|
||||
|
||||
t.Run("selects gzip even with lower quality when others unsupported", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "br;q=1.0, gzip;q=0.5")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gzip", encoding)
|
||||
})
|
||||
|
||||
t.Run("returns identity when it has higher quality than gzip", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip;q=0.5, identity;q=1.0")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingIdentity, encoding)
|
||||
})
|
||||
|
||||
t.Run("skips encoding with q=0", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip;q=0, identity")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingIdentity, encoding)
|
||||
})
|
||||
|
||||
t.Run("falls back to identity when gzip rejected and no other supported", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip;q=0, br")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingIdentity, encoding)
|
||||
})
|
||||
|
||||
t.Run("returns error when identity explicitly rejected and no supported encoding", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "br, identity;q=0")
|
||||
|
||||
_, err := parseAcceptEncoding(req)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no acceptable encoding found")
|
||||
})
|
||||
|
||||
t.Run("returns gzip for wildcard when identity rejected", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "*, identity;q=0")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gzip", encoding) // wildcard with identity rejected returns supported encoding
|
||||
})
|
||||
|
||||
t.Run("returns error when wildcard rejected and no explicit identity", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "*;q=0")
|
||||
|
||||
_, err := parseAcceptEncoding(req)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no acceptable encoding found")
|
||||
})
|
||||
|
||||
t.Run("returns identity when wildcard rejected but identity explicitly accepted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "*;q=0, identity")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingIdentity, encoding)
|
||||
})
|
||||
|
||||
t.Run("returns gzip when wildcard rejected but gzip explicitly accepted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "*;q=0, gzip")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingGzip, encoding)
|
||||
})
|
||||
}
|
||||
|
||||
// TestGetDecompressedBody verifies getDecompressedBody: the request body is
// passed through unchanged when there is no Content-Encoding (or it is
// "identity"), gzip bodies are transparently decompressed, and invalid gzip
// data or an unsupported Content-Encoding yields an error.
func TestGetDecompressedBody(t *testing.T) {
	t.Parallel()

	t.Run("returns original body when no Content-Encoding header", func(t *testing.T) {
		t.Parallel()
		content := []byte("test content")
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(content))

		body, err := getDecompressedBody(req)
		require.NoError(t, err)
		// Identity semantics: the very same ReadCloser must come back.
		assert.Equal(t, req.Body, body, "should return original body")

		data, err := io.ReadAll(body)
		require.NoError(t, err)
		assert.Equal(t, content, data)
	})

	t.Run("decompresses gzip body when Content-Encoding is gzip", func(t *testing.T) {
		t.Parallel()
		originalContent := []byte("test content to compress")

		// Build a real gzip payload to feed through the request body.
		var compressed bytes.Buffer
		gw := gzip.NewWriter(&compressed)
		_, err := gw.Write(originalContent)
		require.NoError(t, err)
		err = gw.Close()
		require.NoError(t, err)

		req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(compressed.Bytes()))
		req.Header.Set("Content-Encoding", "gzip")

		body, err := getDecompressedBody(req)
		require.NoError(t, err)
		defer body.Close()

		// A wrapping reader, not the raw body, must be returned for gzip.
		assert.NotEqual(t, req.Body, body, "should return a new gzip reader")

		data, err := io.ReadAll(body)
		require.NoError(t, err)
		assert.Equal(t, originalContent, data)
	})

	t.Run("returns error for invalid gzip data", func(t *testing.T) {
		t.Parallel()
		invalidGzip := []byte("this is not gzip data")
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(invalidGzip))
		req.Header.Set("Content-Encoding", "gzip")

		_, err := getDecompressedBody(req)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "failed to create gzip reader")
	})

	t.Run("returns original body for identity encoding", func(t *testing.T) {
		t.Parallel()
		content := []byte("test content")
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(content))
		req.Header.Set("Content-Encoding", "identity")

		body, err := getDecompressedBody(req)
		require.NoError(t, err)
		assert.Equal(t, req.Body, body, "should return original body")

		data, err := io.ReadAll(body)
		require.NoError(t, err)
		assert.Equal(t, content, data)
	})

	t.Run("returns error for unsupported encoding", func(t *testing.T) {
		t.Parallel()
		content := []byte("test content")
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(content))
		// br is advertised-unsupported for request bodies in this server.
		req.Header.Set("Content-Encoding", "br")

		_, err := getDecompressedBody(req)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "unsupported Content-Encoding")
	})

	t.Run("handles gzip with quality value", func(t *testing.T) {
		t.Parallel()
		originalContent := []byte("test content to compress")

		var compressed bytes.Buffer
		gw := gzip.NewWriter(&compressed)
		_, err := gw.Write(originalContent)
		require.NoError(t, err)
		err = gw.Close()
		require.NoError(t, err)

		req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(compressed.Bytes()))
		// Parameters after ';' must be tolerated when matching the coding.
		req.Header.Set("Content-Encoding", "gzip;q=1.0")

		body, err := getDecompressedBody(req)
		require.NoError(t, err)
		defer body.Close()

		data, err := io.ReadAll(body)
		require.NoError(t, err)
		assert.Equal(t, originalContent, data)
	})
}
|
||||
@ -1,31 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
||||
)
|
||||
|
||||
func (a *API) GetEnvs(w http.ResponseWriter, _ *http.Request) {
|
||||
operationID := logs.AssignOperationID()
|
||||
|
||||
a.logger.Debug().Str(string(logs.OperationIDKey), operationID).Msg("Getting env vars")
|
||||
|
||||
envs := make(EnvVars)
|
||||
a.defaults.EnvVars.Range(func(key, value string) bool {
|
||||
envs[key] = value
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
w.Header().Set("Cache-Control", "no-store")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if err := json.NewEncoder(w).Encode(envs); err != nil {
|
||||
a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("Failed to encode env vars")
|
||||
}
|
||||
}
|
||||
@ -1,23 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func jsonError(w http.ResponseWriter, code int, err error) {
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
|
||||
w.WriteHeader(code)
|
||||
encodeErr := json.NewEncoder(w).Encode(Error{
|
||||
Code: code,
|
||||
Message: err.Error(),
|
||||
})
|
||||
if encodeErr != nil {
|
||||
http.Error(w, errors.Join(encodeErr, err).Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
@ -1,5 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package api
|
||||
|
||||
//go:generate go run github.com/oapi-codegen/oapi-codegen/v2/cmd/oapi-codegen -config cfg.yaml ../../spec/envd.yaml
|
||||
@ -1,296 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// Modifications by M/S Omukk
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"time"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/host"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/keys"
|
||||
"github.com/awnumar/memguard"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/txn2/txeh"
|
||||
)
|
||||
|
||||
var (
	// ErrAccessTokenMismatch is returned when the supplied token matches
	// neither the stored access token nor the MMDS-authorized hash.
	ErrAccessTokenMismatch = errors.New("access token validation failed")
	// ErrAccessTokenResetNotAuthorized is returned when no token is supplied
	// while one is set and MMDS has not authorized clearing it.
	ErrAccessTokenResetNotAuthorized = errors.New("access token reset not authorized")
)
|
||||
|
||||
// validateInitAccessToken validates the access token for /init requests.
|
||||
// Token is valid if it matches the existing token OR the MMDS hash.
|
||||
// If neither exists, first-time setup is allowed.
|
||||
func (a *API) validateInitAccessToken(ctx context.Context, requestToken *SecureToken) error {
|
||||
requestTokenSet := requestToken.IsSet()
|
||||
|
||||
// Fast path: token matches existing
|
||||
if a.accessToken.IsSet() && requestTokenSet && a.accessToken.EqualsSecure(requestToken) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check MMDS only if token didn't match existing
|
||||
matchesMMDS, mmdsExists := a.checkMMDSHash(ctx, requestToken)
|
||||
|
||||
switch {
|
||||
case matchesMMDS:
|
||||
return nil
|
||||
case !a.accessToken.IsSet() && !mmdsExists:
|
||||
return nil // first-time setup
|
||||
case !requestTokenSet:
|
||||
return ErrAccessTokenResetNotAuthorized
|
||||
default:
|
||||
return ErrAccessTokenMismatch
|
||||
}
|
||||
}
|
||||
|
||||
// checkMMDSHash checks if the request token matches the MMDS hash.
// Returns (matches, mmdsExists).
//
// The MMDS hash is set by the orchestrator during Resume:
//   - hash(token): requires this specific token
//   - hash(""): explicitly allows nil token (token reset authorized)
//   - "": MMDS not properly configured, no authorization granted
func (a *API) checkMMDSHash(ctx context.Context, requestToken *SecureToken) (bool, bool) {
	// Outside Firecracker there is no MMDS endpoint to consult.
	if a.isNotFC {
		return false, false
	}

	mmdsHash, err := a.mmdsClient.GetAccessTokenHash(ctx)
	if err != nil {
		// Unreachable/failed MMDS is treated as "no authority", not a match.
		return false, false
	}

	if mmdsHash == "" {
		return false, false
	}

	// A nil request token is only authorized by an explicit hash("") marker.
	if !requestToken.IsSet() {
		return mmdsHash == keys.HashAccessToken(""), true
	}

	tokenBytes, err := requestToken.Bytes()
	if err != nil {
		return false, true
	}
	// Wipe the plaintext copy once the hash comparison is done.
	defer memguard.WipeBytes(tokenBytes)

	return keys.HashAccessTokenBytes(tokenBytes) == mmdsHash, true
}
|
||||
|
||||
// PostInit handles the /init request: it parses the (optionally empty) JSON
// body, applies defaults/token/env updates via SetData under initLock, then
// kicks off background MMDS polling and restarts the port subsystem.
// Responds 204 on success, 400 on malformed input, 401 on auth failure.
func (a *API) PostInit(w http.ResponseWriter, r *http.Request) {
	defer r.Body.Close()

	ctx := r.Context()

	operationID := logs.AssignOperationID()
	logger := a.logger.With().Str(string(logs.OperationIDKey), operationID).Logger()

	if r.Body != nil {
		// Read raw body so we can wipe it after parsing
		body, err := io.ReadAll(r.Body)
		// Ensure body is wiped after we're done
		defer memguard.WipeBytes(body)
		if err != nil {
			logger.Error().Msgf("Failed to read request body: %v", err)
			w.WriteHeader(http.StatusBadRequest)

			return
		}

		// An empty body is legal and leaves initRequest zero-valued.
		var initRequest PostInitJSONBody
		if len(body) > 0 {
			err = json.Unmarshal(body, &initRequest)
			if err != nil {
				logger.Error().Msgf("Failed to decode request: %v", err)
				w.WriteHeader(http.StatusBadRequest)

				return
			}
		}

		// Ensure request token is destroyed if not transferred via TakeFrom.
		// This handles: validation failures, timestamp-based skips, and any early returns.
		// Safe because Destroy() is nil-safe and TakeFrom clears the source.
		defer initRequest.AccessToken.Destroy()

		a.initLock.Lock()
		defer a.initLock.Unlock()

		// Update data only if the request is newer or if there's no timestamp at all
		if initRequest.Timestamp == nil || a.lastSetTime.SetToGreater(initRequest.Timestamp.UnixNano()) {
			err = a.SetData(ctx, logger, initRequest)
			if err != nil {
				switch {
				case errors.Is(err, ErrAccessTokenMismatch), errors.Is(err, ErrAccessTokenResetNotAuthorized):
					// Auth failures are 401; anything else is the caller's fault.
					w.WriteHeader(http.StatusUnauthorized)
				default:
					logger.Error().Msgf("Failed to set data: %v", err)
					w.WriteHeader(http.StatusBadRequest)
				}
				w.Write([]byte(err.Error()))

				return
			}
		}
	}

	// Detached from the request context on purpose: polling must outlive
	// this handler. 60s bounds the background work.
	go func() { //nolint:contextcheck // TODO: fix this later
		ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
		defer cancel()
		host.PollForMMDSOpts(ctx, a.mmdsChan, a.defaults.EnvVars)
	}()

	// Start the port scanner and forwarder if they were stopped by a
	// pre-snapshot prepare call. Start is a no-op if already running,
	// so this is safe on first boot and only takes effect after restore.
	if a.portSubsystem != nil {
		a.portSubsystem.Start(a.rootCtx)
	}

	w.Header().Set("Cache-Control", "no-store")
	w.Header().Set("Content-Type", "")

	w.WriteHeader(http.StatusNoContent)
}
|
||||
|
||||
// SetData validates the request's access token and then applies every field
// present in data: env vars, access token (set or clear), hyperloop address,
// default user/workdir, and NFS volume mounts. Hyperloop setup and NFS
// mounting run in background goroutines. Returns an auth error (and applies
// nothing) when token validation fails.
func (a *API) SetData(ctx context.Context, logger zerolog.Logger, data PostInitJSONBody) error {
	// Validate access token before proceeding with any action
	// The request must provide a token that is either:
	// 1. Matches the existing access token (if set), OR
	// 2. Matches the MMDS hash (for token change during resume)
	if err := a.validateInitAccessToken(ctx, data.AccessToken); err != nil {
		return err
	}

	if data.EnvVars != nil {
		logger.Debug().Msg(fmt.Sprintf("Setting %d env vars", len(*data.EnvVars)))

		for key, value := range *data.EnvVars {
			// Values are intentionally not logged — they may hold secrets.
			logger.Debug().Msgf("Setting env var for %s", key)
			a.defaults.EnvVars.Store(key, value)
		}
	}

	if data.AccessToken.IsSet() {
		logger.Debug().Msg("Setting access token")
		// TakeFrom moves ownership, leaving the request token cleared.
		a.accessToken.TakeFrom(data.AccessToken)
	} else if a.accessToken.IsSet() {
		// A nil token past validation means a reset was authorized.
		logger.Debug().Msg("Clearing access token")
		a.accessToken.Destroy()
	}

	if data.HyperloopIP != nil {
		go a.SetupHyperloop(*data.HyperloopIP)
	}

	if data.DefaultUser != nil && *data.DefaultUser != "" {
		logger.Debug().Msgf("Setting default user to: %s", *data.DefaultUser)
		a.defaults.User = *data.DefaultUser
	}

	if data.DefaultWorkdir != nil && *data.DefaultWorkdir != "" {
		logger.Debug().Msgf("Setting default workdir to: %s", *data.DefaultWorkdir)
		a.defaults.Workdir = data.DefaultWorkdir
	}

	if data.VolumeMounts != nil {
		for _, volume := range *data.VolumeMounts {
			logger.Debug().Msgf("Mounting %s at %q", volume.NfsTarget, volume.Path)

			// WithoutCancel: mounting must survive the request's lifetime.
			go a.setupNfs(context.WithoutCancel(ctx), volume.NfsTarget, volume.Path)
		}
	}

	return nil
}
|
||||
|
||||
func (a *API) setupNfs(ctx context.Context, nfsTarget, path string) {
|
||||
commands := [][]string{
|
||||
{"mkdir", "-p", path},
|
||||
{"mount", "-v", "-t", "nfs", "-o", "mountproto=tcp,mountport=2049,proto=tcp,port=2049,nfsvers=3,noacl", nfsTarget, path},
|
||||
}
|
||||
|
||||
for _, command := range commands {
|
||||
data, err := exec.CommandContext(ctx, command[0], command[1:]...).CombinedOutput()
|
||||
|
||||
logger := a.getLogger(err)
|
||||
|
||||
logger.
|
||||
Strs("command", command).
|
||||
Str("output", string(data)).
|
||||
Msg("Mount NFS")
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *API) SetupHyperloop(address string) {
|
||||
a.hyperloopLock.Lock()
|
||||
defer a.hyperloopLock.Unlock()
|
||||
|
||||
if err := rewriteHostsFile(address, "/etc/hosts"); err != nil {
|
||||
a.logger.Error().Err(err).Msg("failed to modify hosts file")
|
||||
} else {
|
||||
a.defaults.EnvVars.Store("WRENN_EVENTS_ADDRESS", fmt.Sprintf("http://%s", address))
|
||||
}
|
||||
}
|
||||
|
||||
// eventsHost is the hostname sandboxed processes resolve to reach the events
// endpoint; rewriteHostsFile maps it to the hyperloop IP in /etc/hosts.
const eventsHost = "events.wrenn.local"
|
||||
|
||||
func rewriteHostsFile(address, path string) error {
|
||||
hosts, err := txeh.NewHosts(&txeh.HostsConfig{
|
||||
ReadFilePath: path,
|
||||
WriteFilePath: path,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create hosts: %w", err)
|
||||
}
|
||||
|
||||
// Update /etc/hosts to point events.wrenn.local to the hyperloop IP
|
||||
// This will remove any existing entries for events.wrenn.local first
|
||||
ipFamily, err := getIPFamily(address)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get ip family: %w", err)
|
||||
}
|
||||
|
||||
if ok, current, _ := hosts.HostAddressLookup(eventsHost, ipFamily); ok && current == address {
|
||||
return nil // nothing to be done
|
||||
}
|
||||
|
||||
hosts.AddHost(address, eventsHost)
|
||||
|
||||
return hosts.Save()
|
||||
}
|
||||
|
||||
var (
	// ErrInvalidAddress signals a malformed IP address.
	// NOTE(review): not referenced in the visible code — confirm callers exist.
	ErrInvalidAddress = errors.New("invalid IP address")
	// ErrUnknownAddressFormat signals an address that parsed but is neither
	// IPv4 nor IPv6.
	ErrUnknownAddressFormat = errors.New("unknown IP address format")
)
|
||||
|
||||
func getIPFamily(address string) (txeh.IPFamily, error) {
|
||||
addressIP, err := netip.ParseAddr(address)
|
||||
if err != nil {
|
||||
return txeh.IPFamilyV4, fmt.Errorf("failed to parse IP address: %w", err)
|
||||
}
|
||||
|
||||
switch {
|
||||
case addressIP.Is4():
|
||||
return txeh.IPFamilyV4, nil
|
||||
case addressIP.Is6():
|
||||
return txeh.IPFamilyV6, nil
|
||||
default:
|
||||
return txeh.IPFamilyV4, fmt.Errorf("%w: %s", ErrUnknownAddressFormat, address)
|
||||
}
|
||||
}
|
||||
@ -1,524 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// Modifications by M/S Omukk
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/keys"
|
||||
utilsShared "git.omukk.dev/wrenn/sandbox/envd/internal/shared/utils"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
|
||||
)
|
||||
|
||||
// TestSimpleCases verifies that rewriteHostsFile appends the events.wrenn.local
// mapping to an existing hosts file regardless of leading/trailing newlines in
// the input, preserving the original entries and comments.
func TestSimpleCases(t *testing.T) {
	t.Parallel()
	// Each case perturbs the fixture's surrounding whitespace.
	testCases := map[string]func(string) string{
		"both newlines":               func(s string) string { return s },
		"no newline prefix":           func(s string) string { return strings.TrimPrefix(s, "\n") },
		"no newline suffix":           func(s string) string { return strings.TrimSuffix(s, "\n") },
		"no newline prefix or suffix": strings.TrimSpace,
	}

	for name, preprocessor := range testCases {
		t.Run(name, func(t *testing.T) {
			t.Parallel()
			tempDir := t.TempDir()

			value := `
# comment
127.0.0.1 one.host
127.0.0.2 two.host
`
			value = preprocessor(value)
			inputPath := filepath.Join(tempDir, "hosts")
			err := os.WriteFile(inputPath, []byte(value), 0o644)
			require.NoError(t, err)

			err = rewriteHostsFile("127.0.0.3", inputPath)
			require.NoError(t, err)

			data, err := os.ReadFile(inputPath)
			require.NoError(t, err)

			// Existing lines survive; the new mapping is appended.
			assert.Equal(t, `# comment
127.0.0.1 one.host
127.0.0.2 two.host
127.0.0.3 events.wrenn.local`, strings.TrimSpace(string(data)))
		})
	}
}
|
||||
|
||||
func secureTokenPtr(s string) *SecureToken {
|
||||
token := &SecureToken{}
|
||||
_ = token.Set([]byte(s))
|
||||
|
||||
return token
|
||||
}
|
||||
|
||||
// mockMMDSClient is a test double for MMDSClient that returns a fixed
// hash/error pair from GetAccessTokenHash.
type mockMMDSClient struct {
	hash string // hash to return
	err  error  // error to return
}

func (m *mockMMDSClient) GetAccessTokenHash(_ context.Context) (string, error) {
	return m.hash, m.err
}
|
||||
|
||||
// newTestAPI builds an API wired with a no-op logger, empty defaults, the
// given mock MMDS client, and (optionally) a pre-installed access token.
// TakeFrom moves ownership, so the caller's token is consumed.
func newTestAPI(accessToken *SecureToken, mmdsClient MMDSClient) *API {
	logger := zerolog.Nop()
	defaults := &execcontext.Defaults{
		EnvVars: utils.NewMap[string, string](),
	}
	api := New(&logger, defaults, nil, false, context.Background(), nil, "test")
	if accessToken != nil {
		api.accessToken.TakeFrom(accessToken)
	}
	api.mmdsClient = mmdsClient

	return api
}
|
||||
|
||||
// TestValidateInitAccessToken is a table test over the /init token rules:
// match against the stored token, match against the MMDS hash, first-time
// setup when neither exists, and the mismatch / unauthorized-reset failures.
func TestValidateInitAccessToken(t *testing.T) {
	t.Parallel()
	ctx := t.Context()

	tests := []struct {
		name         string
		accessToken  *SecureToken // token pre-installed in the API (nil = unset)
		requestToken *SecureToken // token supplied by the request (nil = unset)
		mmdsHash     string       // hash returned by the mock MMDS
		mmdsErr      error        // error returned by the mock MMDS
		wantErr      error        // nil means validation should succeed
	}{
		{
			name:         "fast path: token matches existing",
			accessToken:  secureTokenPtr("secret-token"),
			requestToken: secureTokenPtr("secret-token"),
			mmdsHash:     "",
			mmdsErr:      nil,
			wantErr:      nil,
		},
		{
			name:         "MMDS match: token hash matches MMDS hash",
			accessToken:  secureTokenPtr("old-token"),
			requestToken: secureTokenPtr("new-token"),
			mmdsHash:     keys.HashAccessToken("new-token"),
			mmdsErr:      nil,
			wantErr:      nil,
		},
		{
			name:         "first-time setup: no existing token, MMDS error",
			accessToken:  nil,
			requestToken: secureTokenPtr("new-token"),
			mmdsHash:     "",
			mmdsErr:      assert.AnError,
			wantErr:      nil,
		},
		{
			name:         "first-time setup: no existing token, empty MMDS hash",
			accessToken:  nil,
			requestToken: secureTokenPtr("new-token"),
			mmdsHash:     "",
			mmdsErr:      nil,
			wantErr:      nil,
		},
		{
			name:         "first-time setup: both tokens nil, no MMDS",
			accessToken:  nil,
			requestToken: nil,
			mmdsHash:     "",
			mmdsErr:      assert.AnError,
			wantErr:      nil,
		},
		{
			name:         "mismatch: existing token differs from request, no MMDS",
			accessToken:  secureTokenPtr("existing-token"),
			requestToken: secureTokenPtr("wrong-token"),
			mmdsHash:     "",
			mmdsErr:      assert.AnError,
			wantErr:      ErrAccessTokenMismatch,
		},
		{
			name:         "mismatch: existing token differs from request, MMDS hash mismatch",
			accessToken:  secureTokenPtr("existing-token"),
			requestToken: secureTokenPtr("wrong-token"),
			mmdsHash:     keys.HashAccessToken("different-token"),
			mmdsErr:      nil,
			wantErr:      ErrAccessTokenMismatch,
		},
		{
			name:         "conflict: existing token, nil request, MMDS exists",
			accessToken:  secureTokenPtr("existing-token"),
			requestToken: nil,
			mmdsHash:     keys.HashAccessToken("some-token"),
			mmdsErr:      nil,
			wantErr:      ErrAccessTokenResetNotAuthorized,
		},
		{
			name:         "conflict: existing token, nil request, no MMDS",
			accessToken:  secureTokenPtr("existing-token"),
			requestToken: nil,
			mmdsHash:     "",
			mmdsErr:      assert.AnError,
			wantErr:      ErrAccessTokenResetNotAuthorized,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			mmdsClient := &mockMMDSClient{hash: tt.mmdsHash, err: tt.mmdsErr}
			api := newTestAPI(tt.accessToken, mmdsClient)

			err := api.validateInitAccessToken(ctx, tt.requestToken)

			if tt.wantErr != nil {
				require.Error(t, err)
				assert.ErrorIs(t, err, tt.wantErr)
			} else {
				require.NoError(t, err)
			}
		})
	}
}
|
||||
|
||||
// TestCheckMMDSHash covers every (matches, exists) combination of
// checkMMDSHash: hash match/mismatch, nil request tokens, MMDS errors,
// empty hashes, and the hash("") explicit-reset marker.
func TestCheckMMDSHash(t *testing.T) {
	t.Parallel()
	ctx := t.Context()

	t.Run("returns match when token hash equals MMDS hash", func(t *testing.T) {
		t.Parallel()
		token := "my-secret-token"
		mmdsClient := &mockMMDSClient{hash: keys.HashAccessToken(token), err: nil}
		api := newTestAPI(nil, mmdsClient)

		matches, exists := api.checkMMDSHash(ctx, secureTokenPtr(token))

		assert.True(t, matches)
		assert.True(t, exists)
	})

	t.Run("returns no match when token hash differs from MMDS hash", func(t *testing.T) {
		t.Parallel()
		mmdsClient := &mockMMDSClient{hash: keys.HashAccessToken("different-token"), err: nil}
		api := newTestAPI(nil, mmdsClient)

		matches, exists := api.checkMMDSHash(ctx, secureTokenPtr("my-token"))

		assert.False(t, matches)
		assert.True(t, exists)
	})

	t.Run("returns exists but no match when request token is nil", func(t *testing.T) {
		t.Parallel()
		// Non-empty MMDS hash of a real token: a nil request must not match.
		mmdsClient := &mockMMDSClient{hash: keys.HashAccessToken("some-token"), err: nil}
		api := newTestAPI(nil, mmdsClient)

		matches, exists := api.checkMMDSHash(ctx, nil)

		assert.False(t, matches)
		assert.True(t, exists)
	})

	t.Run("returns false, false when MMDS returns error", func(t *testing.T) {
		t.Parallel()
		mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
		api := newTestAPI(nil, mmdsClient)

		matches, exists := api.checkMMDSHash(ctx, secureTokenPtr("any-token"))

		assert.False(t, matches)
		assert.False(t, exists)
	})

	t.Run("returns false, false when MMDS returns empty hash with non-nil request", func(t *testing.T) {
		t.Parallel()
		mmdsClient := &mockMMDSClient{hash: "", err: nil}
		api := newTestAPI(nil, mmdsClient)

		matches, exists := api.checkMMDSHash(ctx, secureTokenPtr("any-token"))

		assert.False(t, matches)
		assert.False(t, exists)
	})

	t.Run("returns false, false when MMDS returns empty hash with nil request", func(t *testing.T) {
		t.Parallel()
		mmdsClient := &mockMMDSClient{hash: "", err: nil}
		api := newTestAPI(nil, mmdsClient)

		matches, exists := api.checkMMDSHash(ctx, nil)

		assert.False(t, matches)
		assert.False(t, exists)
	})

	t.Run("returns true, true when MMDS returns hash of empty string with nil request (explicit reset)", func(t *testing.T) {
		t.Parallel()
		// hash("") is the orchestrator's marker authorizing a token reset.
		mmdsClient := &mockMMDSClient{hash: keys.HashAccessToken(""), err: nil}
		api := newTestAPI(nil, mmdsClient)

		matches, exists := api.checkMMDSHash(ctx, nil)

		assert.True(t, matches)
		assert.True(t, exists)
	})
}
|
||||
|
||||
// TestSetData exercises SetData end to end: token set/clear/reject rules via
// a table test, plus env vars, default user, default workdir, and the
// combined multi-field case.
func TestSetData(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	logger := zerolog.Nop()

	t.Run("access token updates", func(t *testing.T) {
		t.Parallel()

		tests := []struct {
			name           string
			existingToken  *SecureToken // token pre-installed in the API
			requestToken   *SecureToken // token supplied in the request body
			mmdsHash       string       // mock MMDS hash
			mmdsErr        error        // mock MMDS error
			wantErr        error        // expected SetData error (nil = success)
			wantFinalToken *SecureToken // expected stored token afterwards
		}{
			{
				name:           "first-time setup: sets initial token",
				existingToken:  nil,
				requestToken:   secureTokenPtr("initial-token"),
				mmdsHash:       "",
				mmdsErr:        assert.AnError,
				wantErr:        nil,
				wantFinalToken: secureTokenPtr("initial-token"),
			},
			{
				name:           "first-time setup: nil request token leaves token unset",
				existingToken:  nil,
				requestToken:   nil,
				mmdsHash:       "",
				mmdsErr:        assert.AnError,
				wantErr:        nil,
				wantFinalToken: nil,
			},
			{
				name:           "re-init with same token: token unchanged",
				existingToken:  secureTokenPtr("same-token"),
				requestToken:   secureTokenPtr("same-token"),
				mmdsHash:       "",
				mmdsErr:        assert.AnError,
				wantErr:        nil,
				wantFinalToken: secureTokenPtr("same-token"),
			},
			{
				name:           "resume with MMDS: updates token when hash matches",
				existingToken:  secureTokenPtr("old-token"),
				requestToken:   secureTokenPtr("new-token"),
				mmdsHash:       keys.HashAccessToken("new-token"),
				mmdsErr:        nil,
				wantErr:        nil,
				wantFinalToken: secureTokenPtr("new-token"),
			},
			{
				name:           "resume with MMDS: fails when hash doesn't match",
				existingToken:  secureTokenPtr("old-token"),
				requestToken:   secureTokenPtr("new-token"),
				mmdsHash:       keys.HashAccessToken("different-token"),
				mmdsErr:        nil,
				wantErr:        ErrAccessTokenMismatch,
				wantFinalToken: secureTokenPtr("old-token"),
			},
			{
				name:           "fails when existing token and request token mismatch without MMDS",
				existingToken:  secureTokenPtr("existing-token"),
				requestToken:   secureTokenPtr("wrong-token"),
				mmdsHash:       "",
				mmdsErr:        assert.AnError,
				wantErr:        ErrAccessTokenMismatch,
				wantFinalToken: secureTokenPtr("existing-token"),
			},
			{
				name:           "conflict when existing token but nil request token",
				existingToken:  secureTokenPtr("existing-token"),
				requestToken:   nil,
				mmdsHash:       "",
				mmdsErr:        assert.AnError,
				wantErr:        ErrAccessTokenResetNotAuthorized,
				wantFinalToken: secureTokenPtr("existing-token"),
			},
			{
				name:           "conflict when existing token but nil request with MMDS present",
				existingToken:  secureTokenPtr("existing-token"),
				requestToken:   nil,
				mmdsHash:       keys.HashAccessToken("some-token"),
				mmdsErr:        nil,
				wantErr:        ErrAccessTokenResetNotAuthorized,
				wantFinalToken: secureTokenPtr("existing-token"),
			},
			{
				name:           "conflict when MMDS returns empty hash and request is nil (prevents unauthorized reset)",
				existingToken:  secureTokenPtr("existing-token"),
				requestToken:   nil,
				mmdsHash:       "",
				mmdsErr:        nil,
				wantErr:        ErrAccessTokenResetNotAuthorized,
				wantFinalToken: secureTokenPtr("existing-token"),
			},
			{
				name:           "resets token when MMDS returns hash of empty string and request is nil (explicit reset)",
				existingToken:  secureTokenPtr("existing-token"),
				requestToken:   nil,
				mmdsHash:       keys.HashAccessToken(""),
				mmdsErr:        nil,
				wantErr:        nil,
				wantFinalToken: nil,
			},
		}

		for _, tt := range tests {
			t.Run(tt.name, func(t *testing.T) {
				t.Parallel()
				mmdsClient := &mockMMDSClient{hash: tt.mmdsHash, err: tt.mmdsErr}
				api := newTestAPI(tt.existingToken, mmdsClient)

				data := PostInitJSONBody{
					AccessToken: tt.requestToken,
				}

				err := api.SetData(ctx, logger, data)

				if tt.wantErr != nil {
					require.ErrorIs(t, err, tt.wantErr)
				} else {
					require.NoError(t, err)
				}

				if tt.wantFinalToken == nil {
					assert.False(t, api.accessToken.IsSet(), "expected token to not be set")
				} else {
					require.True(t, api.accessToken.IsSet(), "expected token to be set")
					assert.True(t, api.accessToken.EqualsSecure(tt.wantFinalToken), "expected token to match")
				}
			})
		}
	})

	t.Run("sets environment variables", func(t *testing.T) {
		t.Parallel()
		mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
		api := newTestAPI(nil, mmdsClient)

		envVars := EnvVars{"FOO": "bar", "BAZ": "qux"}
		data := PostInitJSONBody{
			EnvVars: &envVars,
		}

		err := api.SetData(ctx, logger, data)

		require.NoError(t, err)
		val, ok := api.defaults.EnvVars.Load("FOO")
		assert.True(t, ok)
		assert.Equal(t, "bar", val)
		val, ok = api.defaults.EnvVars.Load("BAZ")
		assert.True(t, ok)
		assert.Equal(t, "qux", val)
	})

	t.Run("sets default user", func(t *testing.T) {
		t.Parallel()
		mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
		api := newTestAPI(nil, mmdsClient)

		data := PostInitJSONBody{
			DefaultUser: utilsShared.ToPtr("testuser"),
		}

		err := api.SetData(ctx, logger, data)

		require.NoError(t, err)
		assert.Equal(t, "testuser", api.defaults.User)
	})

	t.Run("does not set default user when empty", func(t *testing.T) {
		t.Parallel()
		mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
		api := newTestAPI(nil, mmdsClient)
		api.defaults.User = "original"

		// An empty string must be ignored, not clear the existing user.
		data := PostInitJSONBody{
			DefaultUser: utilsShared.ToPtr(""),
		}

		err := api.SetData(ctx, logger, data)

		require.NoError(t, err)
		assert.Equal(t, "original", api.defaults.User)
	})

	t.Run("sets default workdir", func(t *testing.T) {
		t.Parallel()
		mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
		api := newTestAPI(nil, mmdsClient)

		data := PostInitJSONBody{
			DefaultWorkdir: utilsShared.ToPtr("/home/user"),
		}

		err := api.SetData(ctx, logger, data)

		require.NoError(t, err)
		require.NotNil(t, api.defaults.Workdir)
		assert.Equal(t, "/home/user", *api.defaults.Workdir)
	})

	t.Run("does not set default workdir when empty", func(t *testing.T) {
		t.Parallel()
		mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
		api := newTestAPI(nil, mmdsClient)
		originalWorkdir := "/original"
		api.defaults.Workdir = &originalWorkdir

		data := PostInitJSONBody{
			DefaultWorkdir: utilsShared.ToPtr(""),
		}

		err := api.SetData(ctx, logger, data)

		require.NoError(t, err)
		require.NotNil(t, api.defaults.Workdir)
		assert.Equal(t, "/original", *api.defaults.Workdir)
	})

	t.Run("sets multiple fields at once", func(t *testing.T) {
		t.Parallel()
		mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
		api := newTestAPI(nil, mmdsClient)

		envVars := EnvVars{"KEY": "value"}
		data := PostInitJSONBody{
			AccessToken:    secureTokenPtr("token"),
			DefaultUser:    utilsShared.ToPtr("user"),
			DefaultWorkdir: utilsShared.ToPtr("/workdir"),
			EnvVars:        &envVars,
		}

		err := api.SetData(ctx, logger, data)

		require.NoError(t, err)
		assert.True(t, api.accessToken.Equals("token"), "expected token to match")
		assert.Equal(t, "user", api.defaults.User)
		assert.Equal(t, "/workdir", *api.defaults.Workdir)
		val, ok := api.defaults.EnvVars.Load("KEY")
		assert.True(t, ok)
		assert.Equal(t, "value", val)
	})
}
|
||||
@ -1,214 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
"github.com/awnumar/memguard"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrTokenNotSet = errors.New("access token not set")
|
||||
ErrTokenEmpty = errors.New("empty token not allowed")
|
||||
)
|
||||
|
||||
// SecureToken wraps memguard for secure token storage.
// It uses LockedBuffer which provides memory locking, guard pages,
// and secure zeroing on destroy.
//
// The zero value is ready to use. All methods are safe for concurrent
// use: mu guards buffer, which is nil whenever no token is stored.
type SecureToken struct {
	mu     sync.RWMutex
	buffer *memguard.LockedBuffer
}
|
||||
|
||||
// Set securely replaces the token, destroying the old one first.
|
||||
// The old token memory is zeroed before the new token is stored.
|
||||
// The input byte slice is wiped after copying to secure memory.
|
||||
// Returns ErrTokenEmpty if token is empty - use Destroy() to clear the token instead.
|
||||
func (s *SecureToken) Set(token []byte) error {
|
||||
if len(token) == 0 {
|
||||
return ErrTokenEmpty
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Destroy old token first (zeros memory)
|
||||
if s.buffer != nil {
|
||||
s.buffer.Destroy()
|
||||
s.buffer = nil
|
||||
}
|
||||
|
||||
// Create new LockedBuffer from bytes (source slice is wiped by memguard)
|
||||
s.buffer = memguard.NewBufferFromBytes(token)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler to securely parse a JSON string
// directly into memguard, wiping the input bytes after copying.
//
// Access tokens are hex-encoded HMAC-SHA256 hashes (64 chars of [0-9a-f]),
// so they never contain JSON escape sequences.
func (s *SecureToken) UnmarshalJSON(data []byte) error {
	// JSON strings are quoted, so minimum valid is `""` (2 bytes).
	if len(data) < 2 || data[0] != '"' || data[len(data)-1] != '"' {
		// Wipe even on failure so the (possibly secret) input never lingers.
		memguard.WipeBytes(data)

		return errors.New("invalid secure token JSON string")
	}

	content := data[1 : len(data)-1]

	// Access tokens are hex strings - reject if contains backslash.
	// Decoding an escape sequence would require copying the secret around,
	// so refusing is both simpler and safer.
	if bytes.ContainsRune(content, '\\') {
		memguard.WipeBytes(data)

		return errors.New("invalid secure token: unexpected escape sequence")
	}

	if len(content) == 0 {
		memguard.WipeBytes(data)

		return ErrTokenEmpty
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	// Destroy (zero) any previously stored token before replacing it.
	if s.buffer != nil {
		s.buffer.Destroy()
		s.buffer = nil
	}

	// Allocate secure buffer and copy directly into it
	s.buffer = memguard.NewBuffer(len(content))
	copy(s.buffer.Bytes(), content)

	// Wipe the input data
	memguard.WipeBytes(data)

	return nil
}
|
||||
|
||||
// TakeFrom transfers the token from src to this SecureToken, destroying any
// existing token. The source token is cleared after transfer.
// This avoids copying the underlying bytes.
//
// Lock-ordering note: src and s are never locked at the same time — the
// buffer is detached under src's lock, then installed under s's lock —
// so a concurrent transfer in the opposite direction cannot deadlock.
func (s *SecureToken) TakeFrom(src *SecureToken) {
	// A self-transfer would lock the same mutex twice; treat it as a no-op.
	if src == nil || s == src {
		return
	}

	// Extract buffer from source
	src.mu.Lock()
	buffer := src.buffer
	src.buffer = nil
	src.mu.Unlock()

	// Install buffer in destination
	s.mu.Lock()
	if s.buffer != nil {
		s.buffer.Destroy()
	}
	s.buffer = buffer
	s.mu.Unlock()
}
|
||||
|
||||
// Equals checks if token matches using constant-time comparison.
|
||||
// Returns false if the receiver is nil.
|
||||
func (s *SecureToken) Equals(token string) bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if s.buffer == nil || !s.buffer.IsAlive() {
|
||||
return false
|
||||
}
|
||||
|
||||
return s.buffer.EqualTo([]byte(token))
|
||||
}
|
||||
|
||||
// EqualsSecure compares this token with another SecureToken using constant-time comparison.
// Returns false if either receiver or other is nil.
func (s *SecureToken) EqualsSecure(other *SecureToken) bool {
	if s == nil || other == nil {
		return false
	}

	// Same instance: equal exactly when a token is stored. This also
	// avoids read-locking the same mutex twice.
	if s == other {
		return s.IsSet()
	}

	// Get a copy of other's bytes (avoids holding two locks simultaneously,
	// which could deadlock against a concurrent comparison or TakeFrom in
	// the opposite direction).
	otherBytes, err := other.Bytes()
	if err != nil {
		return false
	}
	// Zero the temporary copy of the secret once the comparison is done.
	defer memguard.WipeBytes(otherBytes)

	s.mu.RLock()
	defer s.mu.RUnlock()

	if s.buffer == nil || !s.buffer.IsAlive() {
		return false
	}

	return s.buffer.EqualTo(otherBytes)
}
|
||||
|
||||
// IsSet returns true if a token is stored.
|
||||
// Returns false if the receiver is nil.
|
||||
func (s *SecureToken) IsSet() bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
return s.buffer != nil && s.buffer.IsAlive()
|
||||
}
|
||||
|
||||
// Bytes returns a copy of the token bytes (for signature generation).
|
||||
// The caller should zero the returned slice after use.
|
||||
// Returns ErrTokenNotSet if the receiver is nil.
|
||||
func (s *SecureToken) Bytes() ([]byte, error) {
|
||||
if s == nil {
|
||||
return nil, ErrTokenNotSet
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if s.buffer == nil || !s.buffer.IsAlive() {
|
||||
return nil, ErrTokenNotSet
|
||||
}
|
||||
|
||||
// Return a copy (unavoidable for signature generation)
|
||||
src := s.buffer.Bytes()
|
||||
result := make([]byte, len(src))
|
||||
copy(result, src)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Destroy securely wipes the token from memory.
|
||||
// No-op if the receiver is nil.
|
||||
func (s *SecureToken) Destroy() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.buffer != nil {
|
||||
s.buffer.Destroy()
|
||||
s.buffer = nil
|
||||
}
|
||||
}
|
||||
@ -1,463 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/awnumar/memguard"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestSecureTokenSetAndEquals covers the basic lifecycle:
// unset state, setting a token, and constant-time comparison results.
func TestSecureTokenSetAndEquals(t *testing.T) {
	t.Parallel()

	st := &SecureToken{}

	// Initially not set
	assert.False(t, st.IsSet(), "token should not be set initially")
	assert.False(t, st.Equals("any-token"), "equals should return false when not set")

	// Set token
	err := st.Set([]byte("test-token"))
	require.NoError(t, err)
	assert.True(t, st.IsSet(), "token should be set after Set()")
	assert.True(t, st.Equals("test-token"), "equals should return true for correct token")
	assert.False(t, st.Equals("wrong-token"), "equals should return false for wrong token")
	assert.False(t, st.Equals(""), "equals should return false for empty token")
}
|
||||
|
||||
// TestSecureTokenReplace verifies that Set on a populated token replaces it
// and the old value no longer compares equal.
func TestSecureTokenReplace(t *testing.T) {
	t.Parallel()

	st := &SecureToken{}

	// Set initial token
	err := st.Set([]byte("first-token"))
	require.NoError(t, err)
	assert.True(t, st.Equals("first-token"))

	// Replace with new token (old one should be destroyed)
	err = st.Set([]byte("second-token"))
	require.NoError(t, err)
	assert.True(t, st.Equals("second-token"), "should match new token")
	assert.False(t, st.Equals("first-token"), "should not match old token")
}
|
||||
|
||||
// TestSecureTokenDestroy verifies Destroy clears the token, is idempotent,
// and that all methods are safe on a nil receiver.
func TestSecureTokenDestroy(t *testing.T) {
	t.Parallel()

	st := &SecureToken{}

	// Set and then destroy
	err := st.Set([]byte("test-token"))
	require.NoError(t, err)
	assert.True(t, st.IsSet())

	st.Destroy()
	assert.False(t, st.IsSet(), "token should not be set after Destroy()")
	assert.False(t, st.Equals("test-token"), "equals should return false after Destroy()")

	// Destroy on already destroyed should be safe
	st.Destroy()
	assert.False(t, st.IsSet())

	// Nil receiver should be safe
	var nilToken *SecureToken
	assert.False(t, nilToken.IsSet(), "nil receiver should return false for IsSet()")
	assert.False(t, nilToken.Equals("anything"), "nil receiver should return false for Equals()")
	assert.False(t, nilToken.EqualsSecure(st), "nil receiver should return false for EqualsSecure()")
	nilToken.Destroy() // should not panic

	_, err = nilToken.Bytes()
	require.ErrorIs(t, err, ErrTokenNotSet, "nil receiver should return ErrTokenNotSet for Bytes()")
}
|
||||
|
||||
// TestSecureTokenBytes verifies Bytes returns an independent copy of the
// token and errors once the token is unset or destroyed.
func TestSecureTokenBytes(t *testing.T) {
	t.Parallel()

	st := &SecureToken{}

	// Bytes should return error when not set
	_, err := st.Bytes()
	require.ErrorIs(t, err, ErrTokenNotSet)

	// Set token and get bytes
	err = st.Set([]byte("test-token"))
	require.NoError(t, err)

	bytes, err := st.Bytes()
	require.NoError(t, err)
	assert.Equal(t, []byte("test-token"), bytes)

	// Zero out the bytes (as caller should do)
	memguard.WipeBytes(bytes)

	// Original should still be intact — proves Bytes returned a copy.
	assert.True(t, st.Equals("test-token"), "original token should still work after zeroing copy")

	// After destroy, bytes should fail
	st.Destroy()
	_, err = st.Bytes()
	assert.ErrorIs(t, err, ErrTokenNotSet)
}
|
||||
|
||||
// TestSecureTokenConcurrentAccess hammers the token with parallel readers
// and writers; meaningful mainly under -race. Only final validity is
// asserted, since interleaving makes the final value nondeterministic.
func TestSecureTokenConcurrentAccess(t *testing.T) {
	t.Parallel()

	st := &SecureToken{}
	err := st.Set([]byte("initial-token"))
	require.NoError(t, err)

	var wg sync.WaitGroup
	const numGoroutines = 100

	// Concurrent reads
	for range numGoroutines {
		wg.Go(func() {
			st.IsSet()
			st.Equals("initial-token")
		})
	}

	// Concurrent writes
	// NOTE(review): mixes wg.Go with manual wg.Add/go/Done and discards
	// Set's error — functional, but could be unified for consistency.
	for i := range 10 {
		wg.Add(1)
		go func(idx int) {
			defer wg.Done()
			st.Set([]byte("token-" + string(rune('a'+idx))))
		}(i)
	}

	wg.Wait()

	// Should still be in a valid state
	assert.True(t, st.IsSet())
}
|
||||
|
||||
// TestSecureTokenEmptyToken verifies that empty and nil inputs are rejected
// with ErrTokenEmpty and leave the token unset.
func TestSecureTokenEmptyToken(t *testing.T) {
	t.Parallel()

	st := &SecureToken{}

	// Setting empty token should return an error
	err := st.Set([]byte{})
	require.ErrorIs(t, err, ErrTokenEmpty)
	assert.False(t, st.IsSet(), "token should not be set after empty token error")

	// Setting nil should also return an error
	err = st.Set(nil)
	require.ErrorIs(t, err, ErrTokenEmpty)
	assert.False(t, st.IsSet(), "token should not be set after nil token error")
}
|
||||
|
||||
// TestSecureTokenEmptyTokenDoesNotClearExisting verifies that a failed Set
// with an empty token leaves a previously stored token intact.
func TestSecureTokenEmptyTokenDoesNotClearExisting(t *testing.T) {
	t.Parallel()

	st := &SecureToken{}

	// Set a valid token first
	err := st.Set([]byte("valid-token"))
	require.NoError(t, err)
	assert.True(t, st.IsSet())

	// Attempting to set empty token should fail and preserve existing token
	err = st.Set([]byte{})
	require.ErrorIs(t, err, ErrTokenEmpty)
	assert.True(t, st.IsSet(), "existing token should be preserved after empty token error")
	assert.True(t, st.Equals("valid-token"), "existing token value should be unchanged")
}
|
||||
|
||||
// TestSecureTokenUnmarshalJSON covers JSON parsing into secure memory:
// valid/invalid inputs, replacement, input wiping, and escape rejection.
func TestSecureTokenUnmarshalJSON(t *testing.T) {
	t.Parallel()

	t.Run("unmarshals valid JSON string", func(t *testing.T) {
		t.Parallel()
		st := &SecureToken{}
		err := st.UnmarshalJSON([]byte(`"my-secret-token"`))
		require.NoError(t, err)
		assert.True(t, st.IsSet())
		assert.True(t, st.Equals("my-secret-token"))
	})

	t.Run("returns error for empty string", func(t *testing.T) {
		t.Parallel()
		st := &SecureToken{}
		err := st.UnmarshalJSON([]byte(`""`))
		require.ErrorIs(t, err, ErrTokenEmpty)
		assert.False(t, st.IsSet())
	})

	t.Run("returns error for invalid JSON", func(t *testing.T) {
		t.Parallel()
		st := &SecureToken{}
		err := st.UnmarshalJSON([]byte(`not-valid-json`))
		require.Error(t, err)
		assert.False(t, st.IsSet())
	})

	t.Run("replaces existing token", func(t *testing.T) {
		t.Parallel()
		st := &SecureToken{}
		err := st.Set([]byte("old-token"))
		require.NoError(t, err)

		err = st.UnmarshalJSON([]byte(`"new-token"`))
		require.NoError(t, err)
		assert.True(t, st.Equals("new-token"))
		assert.False(t, st.Equals("old-token"))
	})

	t.Run("wipes input buffer after parsing", func(t *testing.T) {
		t.Parallel()
		// Create a buffer with a known token
		// NOTE(review): `original` is captured but never compared afterwards;
		// it could be removed or used to assert the pre-wipe content.
		input := []byte(`"secret-token-12345"`)
		original := make([]byte, len(input))
		copy(original, input)

		st := &SecureToken{}
		err := st.UnmarshalJSON(input)
		require.NoError(t, err)

		// Verify the token was stored correctly
		assert.True(t, st.Equals("secret-token-12345"))

		// Verify the input buffer was wiped (all zeros)
		for i, b := range input {
			assert.Equal(t, byte(0), b, "byte at position %d should be zero, got %d", i, b)
		}
	})

	t.Run("wipes input buffer on error", func(t *testing.T) {
		t.Parallel()
		// Create a buffer with an empty token (will error)
		input := []byte(`""`)

		st := &SecureToken{}
		err := st.UnmarshalJSON(input)
		require.Error(t, err)

		// Verify the input buffer was still wiped
		for i, b := range input {
			assert.Equal(t, byte(0), b, "byte at position %d should be zero, got %d", i, b)
		}
	})

	t.Run("rejects escape sequences", func(t *testing.T) {
		t.Parallel()
		st := &SecureToken{}
		err := st.UnmarshalJSON([]byte(`"token\nwith\nnewlines"`))
		require.Error(t, err)
		assert.Contains(t, err.Error(), "escape sequence")
		assert.False(t, st.IsSet())
	})
}
|
||||
|
||||
// TestSecureTokenSetWipesInput verifies that Set zeroes the caller's input
// slice after copying it into locked memory.
func TestSecureTokenSetWipesInput(t *testing.T) {
	t.Parallel()

	t.Run("wipes input buffer after storing", func(t *testing.T) {
		t.Parallel()
		// Create a buffer with a known token
		// NOTE(review): `original` is captured but never compared afterwards.
		input := []byte("my-secret-token")
		original := make([]byte, len(input))
		copy(original, input)

		st := &SecureToken{}
		err := st.Set(input)
		require.NoError(t, err)

		// Verify the token was stored correctly
		assert.True(t, st.Equals("my-secret-token"))

		// Verify the input buffer was wiped (all zeros)
		for i, b := range input {
			assert.Equal(t, byte(0), b, "byte at position %d should be zero, got %d", i, b)
		}
	})
}
|
||||
|
||||
// TestSecureTokenTakeFrom covers ownership transfer semantics: move,
// replace, nil source, empty source, and safe self-transfer.
func TestSecureTokenTakeFrom(t *testing.T) {
	t.Parallel()

	t.Run("transfers token from source to destination", func(t *testing.T) {
		t.Parallel()
		src := &SecureToken{}
		err := src.Set([]byte("source-token"))
		require.NoError(t, err)

		dst := &SecureToken{}
		dst.TakeFrom(src)

		assert.True(t, dst.IsSet())
		assert.True(t, dst.Equals("source-token"))
		assert.False(t, src.IsSet(), "source should be empty after transfer")
	})

	t.Run("replaces existing destination token", func(t *testing.T) {
		t.Parallel()
		src := &SecureToken{}
		err := src.Set([]byte("new-token"))
		require.NoError(t, err)

		dst := &SecureToken{}
		err = dst.Set([]byte("old-token"))
		require.NoError(t, err)

		dst.TakeFrom(src)

		assert.True(t, dst.Equals("new-token"))
		assert.False(t, dst.Equals("old-token"))
		assert.False(t, src.IsSet())
	})

	t.Run("handles nil source", func(t *testing.T) {
		t.Parallel()
		dst := &SecureToken{}
		err := dst.Set([]byte("existing-token"))
		require.NoError(t, err)

		dst.TakeFrom(nil)

		assert.True(t, dst.IsSet(), "destination should be unchanged with nil source")
		assert.True(t, dst.Equals("existing-token"))
	})

	t.Run("handles empty source", func(t *testing.T) {
		t.Parallel()
		// An unset source transfers its nil buffer, clearing the destination.
		src := &SecureToken{}
		dst := &SecureToken{}
		err := dst.Set([]byte("existing-token"))
		require.NoError(t, err)

		dst.TakeFrom(src)

		assert.False(t, dst.IsSet(), "destination should be cleared when source is empty")
	})

	t.Run("self-transfer is no-op and does not deadlock", func(t *testing.T) {
		t.Parallel()
		st := &SecureToken{}
		err := st.Set([]byte("token"))
		require.NoError(t, err)

		st.TakeFrom(st)

		assert.True(t, st.IsSet(), "token should remain set after self-transfer")
		assert.True(t, st.Equals("token"), "token value should be unchanged")
	})
}
|
||||
|
||||
// TestSecureTokenEqualsSecure covers token-to-token comparison, including
// the lock-ordering regression test for concurrent TakeFrom/EqualsSecure.
func TestSecureTokenEqualsSecure(t *testing.T) {
	t.Parallel()

	t.Run("returns true for matching tokens", func(t *testing.T) {
		t.Parallel()
		st1 := &SecureToken{}
		err := st1.Set([]byte("same-token"))
		require.NoError(t, err)

		st2 := &SecureToken{}
		err = st2.Set([]byte("same-token"))
		require.NoError(t, err)

		assert.True(t, st1.EqualsSecure(st2))
		assert.True(t, st2.EqualsSecure(st1))
	})

	t.Run("concurrent TakeFrom and EqualsSecure do not deadlock", func(t *testing.T) {
		t.Parallel()
		// This test verifies the fix for the lock ordering deadlock bug.
		// Repetition raises the chance of hitting the bad interleaving.

		const iterations = 100

		for range iterations {
			a := &SecureToken{}
			err := a.Set([]byte("token-a"))
			require.NoError(t, err)

			b := &SecureToken{}
			err = b.Set([]byte("token-b"))
			require.NoError(t, err)

			var wg sync.WaitGroup
			wg.Add(2)

			// Goroutine 1: a.TakeFrom(b)
			go func() {
				defer wg.Done()
				a.TakeFrom(b)
			}()

			// Goroutine 2: b.EqualsSecure(a)
			go func() {
				defer wg.Done()
				b.EqualsSecure(a)
			}()

			wg.Wait()
		}
	})

	t.Run("returns false for different tokens", func(t *testing.T) {
		t.Parallel()
		st1 := &SecureToken{}
		err := st1.Set([]byte("token-a"))
		require.NoError(t, err)

		st2 := &SecureToken{}
		err = st2.Set([]byte("token-b"))
		require.NoError(t, err)

		assert.False(t, st1.EqualsSecure(st2))
	})

	t.Run("returns false when comparing with nil", func(t *testing.T) {
		t.Parallel()
		st := &SecureToken{}
		err := st.Set([]byte("token"))
		require.NoError(t, err)

		assert.False(t, st.EqualsSecure(nil))
	})

	t.Run("returns false when other is not set", func(t *testing.T) {
		t.Parallel()
		st1 := &SecureToken{}
		err := st1.Set([]byte("token"))
		require.NoError(t, err)

		st2 := &SecureToken{}

		assert.False(t, st1.EqualsSecure(st2))
	})

	t.Run("returns false when self is not set", func(t *testing.T) {
		t.Parallel()
		st1 := &SecureToken{}

		st2 := &SecureToken{}
		err := st2.Set([]byte("token"))
		require.NoError(t, err)

		assert.False(t, st1.EqualsSecure(st2))
	})

	t.Run("self-comparison returns true when set", func(t *testing.T) {
		t.Parallel()
		st := &SecureToken{}
		err := st.Set([]byte("token"))
		require.NoError(t, err)

		assert.True(t, st.EqualsSecure(st), "self-comparison should return true and not deadlock")
	})

	t.Run("self-comparison returns false when not set", func(t *testing.T) {
		t.Parallel()
		st := &SecureToken{}

		assert.False(t, st.EqualsSecure(st), "self-comparison on unset token should return false")
	})
}
|
||||
@ -1,25 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// Modifications by M/S Omukk
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// PostSnapshotPrepare quiesces continuous goroutines (the port subsystem:
// scanner and forwarder) before Firecracker takes a VM snapshot, so no
// background work is mid-flight when vCPUs are frozen.
//
// NOTE(review): the previous comment also claimed this "forces a GC cycle",
// but no runtime.GC() call is present — confirm whether that step was
// removed intentionally or is still expected here.
//
// Called by the host agent as a best-effort signal before vm.Pause().
func (a *API) PostSnapshotPrepare(w http.ResponseWriter, r *http.Request) {
	defer r.Body.Close()

	// portSubsystem can be nil; guard before stopping it.
	if a.portSubsystem != nil {
		a.portSubsystem.Stop()
		a.logger.Info().Msg("snapshot/prepare: port subsystem quiesced")
	}

	w.Header().Set("Cache-Control", "no-store")
	w.WriteHeader(http.StatusNoContent)
}
|
||||
@ -1,108 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// Modifications by M/S Omukk
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/host"
|
||||
publicport "git.omukk.dev/wrenn/sandbox/envd/internal/port"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
|
||||
)
|
||||
|
||||
// MMDSClient provides access to MMDS metadata.
// The interface exists so tests can substitute a mock implementation
// (see mockMMDSClient in the API tests).
type MMDSClient interface {
	// GetAccessTokenHash returns the access-token hash published via MMDS.
	GetAccessTokenHash(ctx context.Context) (string, error)
}
|
||||
|
||||
// DefaultMMDSClient is the production implementation that calls the real MMDS endpoint.
type DefaultMMDSClient struct{}

// GetAccessTokenHash delegates to the host package's MMDS lookup.
func (c *DefaultMMDSClient) GetAccessTokenHash(ctx context.Context) (string, error) {
	return host.GetAccessTokenHashFromMMDS(ctx)
}
|
||||
|
||||
// API holds the state behind the envd HTTP handlers.
type API struct {
	isNotFC     bool             // NOTE(review): appears to flag non-Firecracker environments — confirm
	logger      *zerolog.Logger
	accessToken *SecureToken     // auth token, kept in memguard locked memory
	defaults    *execcontext.Defaults // default user/workdir/env applied to requests (see SetData)
	version     string           // reported by GetHealth

	mmdsChan      chan *host.MMDSOpts
	hyperloopLock sync.Mutex
	mmdsClient    MMDSClient // swappable indirection; tests inject mockMMDSClient

	lastSetTime *utils.AtomicMax
	initLock    sync.Mutex

	// rootCtx is the parent context from main(), used to restart
	// long-lived goroutines after snapshot restore.
	rootCtx       context.Context
	portSubsystem *publicport.PortSubsystem // may be nil; quiesced in PostSnapshotPrepare
}
|
||||
|
||||
func New(l *zerolog.Logger, defaults *execcontext.Defaults, mmdsChan chan *host.MMDSOpts, isNotFC bool, rootCtx context.Context, portSubsystem *publicport.PortSubsystem, version string) *API {
|
||||
return &API{
|
||||
logger: l,
|
||||
defaults: defaults,
|
||||
mmdsChan: mmdsChan,
|
||||
isNotFC: isNotFC,
|
||||
mmdsClient: &DefaultMMDSClient{},
|
||||
lastSetTime: utils.NewAtomicMax(),
|
||||
accessToken: &SecureToken{},
|
||||
rootCtx: rootCtx,
|
||||
portSubsystem: portSubsystem,
|
||||
version: version,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *API) GetHealth(w http.ResponseWriter, r *http.Request) {
|
||||
defer r.Body.Close()
|
||||
|
||||
a.logger.Trace().Msg("Health check")
|
||||
|
||||
w.Header().Set("Cache-Control", "no-store")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||
"version": a.version,
|
||||
})
|
||||
}
|
||||
|
||||
func (a *API) GetMetrics(w http.ResponseWriter, r *http.Request) {
|
||||
defer r.Body.Close()
|
||||
|
||||
a.logger.Trace().Msg("Get metrics")
|
||||
|
||||
w.Header().Set("Cache-Control", "no-store")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
metrics, err := host.GetMetrics()
|
||||
if err != nil {
|
||||
a.logger.Error().Err(err).Msg("Failed to get metrics")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if err := json.NewEncoder(w).Encode(metrics); err != nil {
|
||||
a.logger.Error().Err(err).Msg("Failed to encode metrics")
|
||||
}
|
||||
}
|
||||
|
||||
func (a *API) getLogger(err error) *zerolog.Event {
|
||||
if err != nil {
|
||||
return a.logger.Error().Err(err) //nolint:zerologlint // this is only prep
|
||||
}
|
||||
|
||||
return a.logger.Info() //nolint:zerologlint // this is only prep
|
||||
}
|
||||
@ -1,311 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
|
||||
)
|
||||
|
||||
var ErrNoDiskSpace = fmt.Errorf("not enough disk space available")
|
||||
|
||||
func processFile(r *http.Request, path string, part io.Reader, uid, gid int, logger zerolog.Logger) (int, error) {
|
||||
logger.Debug().
|
||||
Str("path", path).
|
||||
Msg("File processing")
|
||||
|
||||
err := permissions.EnsureDirs(filepath.Dir(path), uid, gid)
|
||||
if err != nil {
|
||||
err := fmt.Errorf("error ensuring directories: %w", err)
|
||||
|
||||
return http.StatusInternalServerError, err
|
||||
}
|
||||
|
||||
canBePreChowned := false
|
||||
stat, err := os.Stat(path)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
errMsg := fmt.Errorf("error getting file info: %w", err)
|
||||
|
||||
return http.StatusInternalServerError, errMsg
|
||||
} else if err == nil {
|
||||
if stat.IsDir() {
|
||||
err := fmt.Errorf("path is a directory: %s", path)
|
||||
|
||||
return http.StatusBadRequest, err
|
||||
}
|
||||
canBePreChowned = true
|
||||
}
|
||||
|
||||
hasBeenChowned := false
|
||||
if canBePreChowned {
|
||||
err = os.Chown(path, uid, gid)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
err = fmt.Errorf("error changing file ownership: %w", err)
|
||||
|
||||
return http.StatusInternalServerError, err
|
||||
}
|
||||
} else {
|
||||
hasBeenChowned = true
|
||||
}
|
||||
}
|
||||
|
||||
file, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o666)
|
||||
if err != nil {
|
||||
if errors.Is(err, syscall.ENOSPC) {
|
||||
err = fmt.Errorf("not enough inodes available: %w", err)
|
||||
|
||||
return http.StatusInsufficientStorage, err
|
||||
}
|
||||
|
||||
err := fmt.Errorf("error opening file: %w", err)
|
||||
|
||||
return http.StatusInternalServerError, err
|
||||
}
|
||||
|
||||
defer file.Close()
|
||||
|
||||
if !hasBeenChowned {
|
||||
err = os.Chown(path, uid, gid)
|
||||
if err != nil {
|
||||
err := fmt.Errorf("error changing file ownership: %w", err)
|
||||
|
||||
return http.StatusInternalServerError, err
|
||||
}
|
||||
}
|
||||
|
||||
_, err = file.ReadFrom(part)
|
||||
if err != nil {
|
||||
if errors.Is(err, syscall.ENOSPC) {
|
||||
err = ErrNoDiskSpace
|
||||
if r.ContentLength > 0 {
|
||||
err = fmt.Errorf("attempted to write %d bytes: %w", r.ContentLength, err)
|
||||
}
|
||||
|
||||
return http.StatusInsufficientStorage, err
|
||||
}
|
||||
|
||||
err = fmt.Errorf("error writing file: %w", err)
|
||||
|
||||
return http.StatusInternalServerError, err
|
||||
}
|
||||
|
||||
return http.StatusNoContent, nil
|
||||
}
|
||||
|
||||
func resolvePath(part *multipart.Part, paths *UploadSuccess, u *user.User, defaultPath *string, params PostFilesParams) (string, error) {
|
||||
var pathToResolve string
|
||||
|
||||
if params.Path != nil {
|
||||
pathToResolve = *params.Path
|
||||
} else {
|
||||
var err error
|
||||
customPart := utils.NewCustomPart(part)
|
||||
pathToResolve, err = customPart.FileNameWithPath()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error getting multipart custom part file name: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
filePath, err := permissions.ExpandAndResolve(pathToResolve, u, defaultPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error resolving path: %w", err)
|
||||
}
|
||||
|
||||
for _, entry := range *paths {
|
||||
if entry.Path == filePath {
|
||||
var alreadyUploaded []string
|
||||
for _, uploadedFile := range *paths {
|
||||
if uploadedFile.Path != filePath {
|
||||
alreadyUploaded = append(alreadyUploaded, uploadedFile.Path)
|
||||
}
|
||||
}
|
||||
|
||||
errMsg := fmt.Errorf("you cannot upload multiple files to the same path '%s' in one upload request, only the first specified file was uploaded", filePath)
|
||||
|
||||
if len(alreadyUploaded) > 1 {
|
||||
errMsg = fmt.Errorf("%w, also the following files were uploaded: %v", errMsg, strings.Join(alreadyUploaded, ", "))
|
||||
}
|
||||
|
||||
return "", errMsg
|
||||
}
|
||||
}
|
||||
|
||||
return filePath, nil
|
||||
}
|
||||
|
||||
func (a *API) handlePart(r *http.Request, part *multipart.Part, paths UploadSuccess, u *user.User, uid, gid int, operationID string, params PostFilesParams) (*EntryInfo, int, error) {
|
||||
defer part.Close()
|
||||
|
||||
if part.FormName() != "file" {
|
||||
return nil, http.StatusOK, nil
|
||||
}
|
||||
|
||||
filePath, err := resolvePath(part, &paths, u, a.defaults.Workdir, params)
|
||||
if err != nil {
|
||||
return nil, http.StatusBadRequest, err
|
||||
}
|
||||
|
||||
logger := a.logger.
|
||||
With().
|
||||
Str(string(logs.OperationIDKey), operationID).
|
||||
Str("event_type", "file_processing").
|
||||
Logger()
|
||||
|
||||
status, err := processFile(r, filePath, part, uid, gid, logger)
|
||||
if err != nil {
|
||||
return nil, status, err
|
||||
}
|
||||
|
||||
return &EntryInfo{
|
||||
Path: filePath,
|
||||
Name: filepath.Base(filePath),
|
||||
Type: File,
|
||||
}, http.StatusOK, nil
|
||||
}
|
||||
|
||||
func (a *API) PostFiles(w http.ResponseWriter, r *http.Request, params PostFilesParams) {
|
||||
// Capture original body to ensure it's always closed
|
||||
originalBody := r.Body
|
||||
defer originalBody.Close()
|
||||
|
||||
var errorCode int
|
||||
var errMsg error
|
||||
|
||||
var path string
|
||||
if params.Path != nil {
|
||||
path = *params.Path
|
||||
}
|
||||
|
||||
operationID := logs.AssignOperationID()
|
||||
|
||||
// signing authorization if needed
|
||||
err := a.validateSigning(r, params.Signature, params.SignatureExpiration, params.Username, path, SigningWriteOperation)
|
||||
if err != nil {
|
||||
a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("error during auth validation")
|
||||
jsonError(w, http.StatusUnauthorized, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
username, err := execcontext.ResolveDefaultUsername(params.Username, a.defaults.User)
|
||||
if err != nil {
|
||||
a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("no user specified")
|
||||
jsonError(w, http.StatusBadRequest, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
l := a.logger.
|
||||
Err(errMsg).
|
||||
Str("method", r.Method+" "+r.URL.Path).
|
||||
Str(string(logs.OperationIDKey), operationID).
|
||||
Str("path", path).
|
||||
Str("username", username)
|
||||
|
||||
if errMsg != nil {
|
||||
l = l.Int("error_code", errorCode)
|
||||
}
|
||||
|
||||
l.Msg("File write")
|
||||
}()
|
||||
|
||||
// Handle gzip-encoded request body
|
||||
body, err := getDecompressedBody(r)
|
||||
if err != nil {
|
||||
errMsg = fmt.Errorf("error decompressing request body: %w", err)
|
||||
errorCode = http.StatusBadRequest
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
defer body.Close()
|
||||
r.Body = body
|
||||
|
||||
f, err := r.MultipartReader()
|
||||
if err != nil {
|
||||
errMsg = fmt.Errorf("error parsing multipart form: %w", err)
|
||||
errorCode = http.StatusInternalServerError
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
u, err := user.Lookup(username)
|
||||
if err != nil {
|
||||
errMsg = fmt.Errorf("error looking up user '%s': %w", username, err)
|
||||
errorCode = http.StatusUnauthorized
|
||||
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
uid, gid, err := permissions.GetUserIdInts(u)
|
||||
if err != nil {
|
||||
errMsg = fmt.Errorf("error getting user ids: %w", err)
|
||||
|
||||
jsonError(w, http.StatusInternalServerError, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
paths := UploadSuccess{}
|
||||
|
||||
for {
|
||||
part, partErr := f.NextPart()
|
||||
|
||||
if partErr == io.EOF {
|
||||
// We're done reading the parts.
|
||||
break
|
||||
} else if partErr != nil {
|
||||
errMsg = fmt.Errorf("error reading form: %w", partErr)
|
||||
errorCode = http.StatusInternalServerError
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
entry, status, err := a.handlePart(r, part, paths, u, uid, gid, operationID, params)
|
||||
if err != nil {
|
||||
errorCode = status
|
||||
errMsg = err
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if entry != nil {
|
||||
paths = append(paths, *entry)
|
||||
}
|
||||
}
|
||||
|
||||
data, err := json.Marshal(paths)
|
||||
if err != nil {
|
||||
errMsg = fmt.Errorf("error marshaling response: %w", err)
|
||||
errorCode = http.StatusInternalServerError
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(data)
|
||||
}
|
||||
@ -1,251 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestProcessFile covers processFile's error paths (bad directories,
// directory targets, disk/inode exhaustion) and success paths (create,
// overwrite, sysfs write). Subtests needing mount/root privileges skip
// themselves when not running as root.
func TestProcessFile(t *testing.T) {
	t.Parallel()
	uid := os.Getuid()
	gid := os.Getgid()

	// newRequest builds a minimal request whose ContentLength matches the
	// payload, plus a reader serving the payload as the multipart part.
	newRequest := func(content []byte) (*http.Request, io.Reader) {
		request := &http.Request{
			ContentLength: int64(len(content)),
		}
		buffer := bytes.NewBuffer(content)

		return request, buffer
	}

	var emptyReq http.Request
	var emptyPart *bytes.Buffer
	var emptyLogger zerolog.Logger

	t.Run("failed to ensure directories", func(t *testing.T) {
		t.Parallel()
		httpStatus, err := processFile(&emptyReq, "/proc/invalid/not-real", emptyPart, uid, gid, emptyLogger)
		require.Error(t, err)
		assert.Equal(t, http.StatusInternalServerError, httpStatus)
		assert.ErrorContains(t, err, "error ensuring directories: ")
	})

	t.Run("attempt to replace directory with a file", func(t *testing.T) {
		t.Parallel()
		tempDir := t.TempDir()

		httpStatus, err := processFile(&emptyReq, tempDir, emptyPart, uid, gid, emptyLogger)
		require.Error(t, err)
		assert.Equal(t, http.StatusBadRequest, httpStatus, err.Error())
		assert.ErrorContains(t, err, "path is a directory: ")
	})

	t.Run("fail to create file", func(t *testing.T) {
		t.Parallel()
		httpStatus, err := processFile(&emptyReq, "/proc/invalid-filename", emptyPart, uid, gid, emptyLogger)
		require.Error(t, err)
		assert.Equal(t, http.StatusInternalServerError, httpStatus)
		assert.ErrorContains(t, err, "error opening file: ")
	})

	t.Run("out of disk space", func(t *testing.T) {
		t.Parallel()
		// make a tiny tmpfs mount
		mountSize := 1024
		tempDir := createTmpfsMount(t, mountSize)

		// create test file
		firstFileSize := mountSize / 2
		tempFile1 := filepath.Join(tempDir, "test-file-1")

		// fill it up
		cmd := exec.CommandContext(t.Context(),
			"dd", "if=/dev/zero", "of="+tempFile1, fmt.Sprintf("bs=%d", firstFileSize), "count=1")
		err := cmd.Run()
		require.NoError(t, err)

		// create a new file that would fill up the
		secondFileContents := make([]byte, mountSize*2)
		for index := range secondFileContents {
			secondFileContents[index] = 'a'
		}

		// try to replace it
		request, buffer := newRequest(secondFileContents)
		tempFile2 := filepath.Join(tempDir, "test-file-2")
		httpStatus, err := processFile(request, tempFile2, buffer, uid, gid, emptyLogger)
		require.Error(t, err)
		assert.Equal(t, http.StatusInsufficientStorage, httpStatus)
		assert.ErrorContains(t, err, "attempted to write 2048 bytes: not enough disk space")
	})

	t.Run("happy path", func(t *testing.T) {
		t.Parallel()
		tempDir := t.TempDir()
		tempFile := filepath.Join(tempDir, "test-file")

		content := []byte("test-file-contents")
		request, buffer := newRequest(content)

		httpStatus, err := processFile(request, tempFile, buffer, uid, gid, emptyLogger)
		require.NoError(t, err)
		assert.Equal(t, http.StatusNoContent, httpStatus)

		data, err := os.ReadFile(tempFile)
		require.NoError(t, err)
		assert.Equal(t, content, data)
	})

	// Overwriting an existing file must succeed even when the filesystem
	// is full, because the existing blocks are reused.
	t.Run("overwrite file on full disk", func(t *testing.T) {
		t.Parallel()
		// make a tiny tmpfs mount
		sizeInBytes := 1024
		tempDir := createTmpfsMount(t, 1024)

		// create test file
		tempFile := filepath.Join(tempDir, "test-file")

		// fill it up
		cmd := exec.CommandContext(t.Context(), "dd", "if=/dev/zero", "of="+tempFile, fmt.Sprintf("bs=%d", sizeInBytes), "count=1")
		err := cmd.Run()
		require.NoError(t, err)

		// try to replace it
		content := []byte("test-file-contents")
		request, buffer := newRequest(content)
		httpStatus, err := processFile(request, tempFile, buffer, uid, gid, emptyLogger)
		require.NoError(t, err)
		assert.Equal(t, http.StatusNoContent, httpStatus)
	})

	t.Run("write new file on full disk", func(t *testing.T) {
		t.Parallel()
		// make a tiny tmpfs mount
		sizeInBytes := 1024
		tempDir := createTmpfsMount(t, 1024)

		// create test file
		tempFile1 := filepath.Join(tempDir, "test-file")

		// fill it up
		cmd := exec.CommandContext(t.Context(), "dd", "if=/dev/zero", "of="+tempFile1, fmt.Sprintf("bs=%d", sizeInBytes), "count=1")
		err := cmd.Run()
		require.NoError(t, err)

		// try to write a new file
		tempFile2 := filepath.Join(tempDir, "test-file-2")
		content := []byte("test-file-contents")
		request, buffer := newRequest(content)
		httpStatus, err := processFile(request, tempFile2, buffer, uid, gid, emptyLogger)
		require.ErrorContains(t, err, "not enough disk space available")
		assert.Equal(t, http.StatusInsufficientStorage, httpStatus)
	})

	t.Run("write new file with no inodes available", func(t *testing.T) {
		t.Parallel()
		// make a tiny tmpfs mount
		tempDir := createTmpfsMountWithInodes(t, 1024, 2)

		// create test file
		tempFile1 := filepath.Join(tempDir, "test-file")

		// fill it up
		cmd := exec.CommandContext(t.Context(), "dd", "if=/dev/zero", "of="+tempFile1, fmt.Sprintf("bs=%d", 100), "count=1")
		err := cmd.Run()
		require.NoError(t, err)

		// try to write a new file
		tempFile2 := filepath.Join(tempDir, "test-file-2")
		content := []byte("test-file-contents")
		request, buffer := newRequest(content)
		httpStatus, err := processFile(request, tempFile2, buffer, uid, gid, emptyLogger)
		require.ErrorContains(t, err, "not enough inodes available")
		assert.Equal(t, http.StatusInsufficientStorage, httpStatus)
	})

	t.Run("update sysfs or other virtual fs", func(t *testing.T) {
		t.Parallel()
		if os.Geteuid() != 0 {
			t.Skip("skipping sysfs updates: Operation not permitted with non-root user")
		}

		filePath := "/sys/fs/cgroup/user.slice/cpu.weight"
		newContent := []byte("102\n")
		request, buffer := newRequest(newContent)

		httpStatus, err := processFile(request, filePath, buffer, uid, gid, emptyLogger)
		require.NoError(t, err)
		assert.Equal(t, http.StatusNoContent, httpStatus)

		data, err := os.ReadFile(filePath)
		require.NoError(t, err)
		assert.Equal(t, newContent, data)
	})

	t.Run("replace file", func(t *testing.T) {
		t.Parallel()
		tempDir := t.TempDir()
		tempFile := filepath.Join(tempDir, "test-file")

		err := os.WriteFile(tempFile, []byte("old-contents"), 0o644)
		require.NoError(t, err)

		newContent := []byte("new-file-contents")
		request, buffer := newRequest(newContent)

		httpStatus, err := processFile(request, tempFile, buffer, uid, gid, emptyLogger)
		require.NoError(t, err)
		assert.Equal(t, http.StatusNoContent, httpStatus)

		data, err := os.ReadFile(tempFile)
		require.NoError(t, err)
		assert.Equal(t, newContent, data)
	})
}
|
||||
|
||||
func createTmpfsMount(t *testing.T, sizeInBytes int) string {
|
||||
t.Helper()
|
||||
|
||||
return createTmpfsMountWithInodes(t, sizeInBytes, 5)
|
||||
}
|
||||
|
||||
// createTmpfsMountWithInodes mounts a tmpfs limited to sizeInBytes and
// inodesCount under a fresh temp directory and returns its path. Requires
// root (mount/umount); otherwise the calling test is skipped. The mount is
// unmounted via t.Cleanup.
func createTmpfsMountWithInodes(t *testing.T, sizeInBytes, inodesCount int) string {
	t.Helper()

	if os.Geteuid() != 0 {
		t.Skip("skipping sysfs updates: Operation not permitted with non-root user")
	}

	tempDir := t.TempDir()

	cmd := exec.CommandContext(t.Context(),
		"mount",
		"tmpfs",
		tempDir,
		"-t", "tmpfs",
		"-o", fmt.Sprintf("size=%d,nr_inodes=%d", sizeInBytes, inodesCount))
	err := cmd.Run()
	require.NoError(t, err)
	t.Cleanup(func() {
		// Detach from the (possibly cancelled) test context so the umount
		// still runs during cleanup.
		ctx := context.WithoutCancel(t.Context())
		cmd := exec.CommandContext(ctx, "umount", tempDir)
		err := cmd.Run()
		require.NoError(t, err)
	})

	return tempDir
}
|
||||
@ -1,39 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package execcontext
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
|
||||
)
|
||||
|
||||
type Defaults struct {
|
||||
EnvVars *utils.Map[string, string]
|
||||
User string
|
||||
Workdir *string
|
||||
}
|
||||
|
||||
// ResolveDefaultWorkdir returns workdir when non-empty, otherwise the
// configured default when present, and finally the empty string.
func ResolveDefaultWorkdir(workdir string, defaultWorkdir *string) string {
	switch {
	case workdir != "":
		return workdir
	case defaultWorkdir != nil:
		return *defaultWorkdir
	default:
		return ""
	}
}
|
||||
|
||||
// ResolveDefaultUsername returns the explicitly supplied username when
// present, falls back to the configured default, and errors when neither
// is available.
func ResolveDefaultUsername(username *string, defaultUsername string) (string, error) {
	if username != nil {
		return *username, nil
	}

	if defaultUsername == "" {
		return "", errors.New("username not provided")
	}

	return defaultUsername, nil
}
|
||||
@ -1,96 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// Modifications by M/S Omukk
|
||||
|
||||
package host
|
||||
|
||||
import (
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/shirou/gopsutil/v4/cpu"
|
||||
"github.com/shirou/gopsutil/v4/mem"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
type Metrics struct {
|
||||
Timestamp int64 `json:"ts"` // Unix Timestamp in UTC
|
||||
|
||||
CPUCount uint32 `json:"cpu_count"` // Total CPU cores
|
||||
CPUUsedPercent float32 `json:"cpu_used_pct"` // Percent rounded to 2 decimal places
|
||||
|
||||
// Deprecated: kept for backwards compatibility with older orchestrators.
|
||||
MemTotalMiB uint64 `json:"mem_total_mib"` // Total virtual memory in MiB
|
||||
|
||||
// Deprecated: kept for backwards compatibility with older orchestrators.
|
||||
MemUsedMiB uint64 `json:"mem_used_mib"` // Used virtual memory in MiB
|
||||
|
||||
MemTotal uint64 `json:"mem_total"` // Total virtual memory in bytes
|
||||
MemUsed uint64 `json:"mem_used"` // Used virtual memory in bytes
|
||||
|
||||
DiskUsed uint64 `json:"disk_used"` // Used disk space in bytes
|
||||
DiskTotal uint64 `json:"disk_total"` // Total disk space in bytes
|
||||
}
|
||||
|
||||
func GetMetrics() (*Metrics, error) {
|
||||
v, err := mem.VirtualMemory()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
memUsedMiB := v.Used / 1024 / 1024
|
||||
memTotalMiB := v.Total / 1024 / 1024
|
||||
|
||||
cpuTotal, err := cpu.Counts(true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cpuUsedPcts, err := cpu.Percent(0, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cpuUsedPct := cpuUsedPcts[0]
|
||||
cpuUsedPctRounded := float32(cpuUsedPct)
|
||||
if cpuUsedPct > 0 {
|
||||
cpuUsedPctRounded = float32(math.Round(cpuUsedPct*100) / 100)
|
||||
}
|
||||
|
||||
diskMetrics, err := diskStats("/")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Metrics{
|
||||
Timestamp: time.Now().UTC().Unix(),
|
||||
CPUCount: uint32(cpuTotal),
|
||||
CPUUsedPercent: cpuUsedPctRounded,
|
||||
MemUsedMiB: memUsedMiB,
|
||||
MemTotalMiB: memTotalMiB,
|
||||
MemTotal: v.Total,
|
||||
MemUsed: v.Used,
|
||||
DiskUsed: diskMetrics.Total - diskMetrics.Available,
|
||||
DiskTotal: diskMetrics.Total,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type diskSpace struct {
|
||||
Total uint64
|
||||
Available uint64
|
||||
}
|
||||
|
||||
func diskStats(path string) (diskSpace, error) {
|
||||
var st unix.Statfs_t
|
||||
if err := unix.Statfs(path, &st); err != nil {
|
||||
return diskSpace{}, err
|
||||
}
|
||||
|
||||
block := uint64(st.Bsize)
|
||||
|
||||
// all data blocks
|
||||
total := st.Blocks * block
|
||||
// blocks available
|
||||
available := st.Bavail * block
|
||||
|
||||
return diskSpace{Total: total, Available: available}, nil
|
||||
}
|
||||
@ -1,185 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// Modifications by M/S Omukk
|
||||
|
||||
package host
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
WrennRunDir = "/run/wrenn" // store sandbox metadata files here
|
||||
|
||||
mmdsDefaultAddress = "169.254.169.254"
|
||||
mmdsTokenExpiration = 60 * time.Second
|
||||
|
||||
mmdsAccessTokenRequestClientTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
var mmdsAccessTokenClient = &http.Client{
|
||||
Timeout: mmdsAccessTokenRequestClientTimeout,
|
||||
Transport: &http.Transport{
|
||||
DisableKeepAlives: true,
|
||||
},
|
||||
}
|
||||
|
||||
type MMDSOpts struct {
|
||||
SandboxID string `json:"instanceID"`
|
||||
TemplateID string `json:"envID"`
|
||||
LogsCollectorAddress string `json:"address"`
|
||||
AccessTokenHash string `json:"accessTokenHash"`
|
||||
}
|
||||
|
||||
// Update overwrites the sandbox ID, template ID and logs-collector address
// in place. Callers are responsible for any required synchronization.
func (opts *MMDSOpts) Update(sandboxID, templateID, collectorAddress string) {
	opts.SandboxID = sandboxID
	opts.TemplateID = templateID
	opts.LogsCollectorAddress = collectorAddress
}
|
||||
|
||||
func (opts *MMDSOpts) AddOptsToJSON(jsonLogs []byte) ([]byte, error) {
|
||||
parsed := make(map[string]any)
|
||||
|
||||
err := json.Unmarshal(jsonLogs, &parsed)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parsed["instanceID"] = opts.SandboxID
|
||||
parsed["envID"] = opts.TemplateID
|
||||
|
||||
data, err := json.Marshal(parsed)
|
||||
|
||||
return data, err
|
||||
}
|
||||
|
||||
// getMMDSToken requests a session token from the MMDS (microVM metadata
// service) token endpoint; the token is required for subsequent metadata
// reads and expires after mmdsTokenExpiration.
func getMMDSToken(ctx context.Context, client *http.Client) (string, error) {
	request, err := http.NewRequestWithContext(ctx, http.MethodPut, "http://"+mmdsDefaultAddress+"/latest/api/token", &bytes.Buffer{})
	if err != nil {
		return "", err
	}

	// Header assigned via the map (not Header.Set) so the key casing is
	// preserved exactly — NOTE(review): presumably MMDS matches the header
	// name case-sensitively; confirm before changing to Header.Set.
	request.Header["X-metadata-token-ttl-seconds"] = []string{fmt.Sprint(mmdsTokenExpiration.Seconds())}

	response, err := client.Do(request)
	if err != nil {
		return "", err
	}
	defer response.Body.Close()

	body, err := io.ReadAll(response.Body)
	if err != nil {
		return "", err
	}

	token := string(body)

	// An empty body means the service answered but issued no token.
	if len(token) == 0 {
		return "", fmt.Errorf("mmds token is an empty string")
	}

	return token, nil
}
|
||||
|
||||
// getMMDSOpts fetches the sandbox metadata document from MMDS using a
// previously obtained session token and decodes it into MMDSOpts.
func getMMDSOpts(ctx context.Context, client *http.Client, token string) (*MMDSOpts, error) {
	request, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://"+mmdsDefaultAddress, &bytes.Buffer{})
	if err != nil {
		return nil, err
	}

	// Headers assigned via the map to keep exact key casing —
	// NOTE(review): presumably MMDS matches header names case-sensitively;
	// confirm before switching to Header.Set.
	request.Header["X-metadata-token"] = []string{token}
	request.Header["Accept"] = []string{"application/json"}

	response, err := client.Do(request)
	if err != nil {
		return nil, err
	}

	defer response.Body.Close()

	body, err := io.ReadAll(response.Body)
	if err != nil {
		return nil, err
	}

	var opts MMDSOpts

	err = json.Unmarshal(body, &opts)
	if err != nil {
		return nil, err
	}

	return &opts, nil
}
|
||||
|
||||
// GetAccessTokenHashFromMMDS reads the access token hash from MMDS.
|
||||
// This is used to validate that /init requests come from the orchestrator.
|
||||
func GetAccessTokenHashFromMMDS(ctx context.Context) (string, error) {
|
||||
token, err := getMMDSToken(ctx, mmdsAccessTokenClient)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get MMDS token: %w", err)
|
||||
}
|
||||
|
||||
opts, err := getMMDSOpts(ctx, mmdsAccessTokenClient, token)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get MMDS opts: %w", err)
|
||||
}
|
||||
|
||||
return opts.AccessTokenHash, nil
|
||||
}
|
||||
|
||||
func PollForMMDSOpts(ctx context.Context, mmdsChan chan<- *MMDSOpts, envVars *utils.Map[string, string]) {
|
||||
httpClient := &http.Client{}
|
||||
defer httpClient.CloseIdleConnections()
|
||||
|
||||
ticker := time.NewTicker(50 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
fmt.Fprintf(os.Stderr, "context cancelled while waiting for mmds opts")
|
||||
|
||||
return
|
||||
case <-ticker.C:
|
||||
token, err := getMMDSToken(ctx, httpClient)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error getting mmds token: %v\n", err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
mmdsOpts, err := getMMDSOpts(ctx, httpClient, token)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error getting mmds opts: %v\n", err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
envVars.Store("WRENN_SANDBOX_ID", mmdsOpts.SandboxID)
|
||||
envVars.Store("WRENN_TEMPLATE_ID", mmdsOpts.TemplateID)
|
||||
|
||||
if err := os.WriteFile(filepath.Join(WrennRunDir, ".WRENN_SANDBOX_ID"), []byte(mmdsOpts.SandboxID), 0o666); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error writing sandbox ID file: %v\n", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(WrennRunDir, ".WRENN_TEMPLATE_ID"), []byte(mmdsOpts.TemplateID), 0o666); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error writing template ID file: %v\n", err)
|
||||
}
|
||||
|
||||
if mmdsOpts.LogsCollectorAddress != "" {
|
||||
mmdsChan <- mmdsOpts
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,49 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package logs
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultMaxBufferSize = 2 << 15
|
||||
defaultTimeout = 2 * time.Second
|
||||
)
|
||||
|
||||
// LogBufferedDataEvents batches byte chunks from dataCh and emits them as
// single log events, flushing whenever the buffer reaches
// defaultMaxBufferSize or defaultTimeout elapses. Any remainder is flushed
// when dataCh is closed. eventType is used as the log field name.
func LogBufferedDataEvents(dataCh <-chan []byte, logger *zerolog.Logger, eventType string) {
	timer := time.NewTicker(defaultTimeout)
	defer timer.Stop()

	var buffer []byte
	// Flush whatever is still buffered when the channel closes.
	defer func() {
		if len(buffer) > 0 {
			logger.Info().Str(eventType, string(buffer)).Msg("Streaming process event (flush)")
		}
	}()

	for {
		select {
		case <-timer.C:
			// Time-based flush so slow producers still surface output.
			if len(buffer) > 0 {
				logger.Info().Str(eventType, string(buffer)).Msg("Streaming process event")
				buffer = nil
			}
		case data, ok := <-dataCh:
			if !ok {
				return
			}

			buffer = append(buffer, data...)

			// Size-based flush to bound memory usage.
			if len(buffer) >= defaultMaxBufferSize {
				logger.Info().Str(eventType, string(buffer)).Msg("Streaming process event")
				buffer = nil

				continue
			}
		}
	}
}
|
||||
@ -1,174 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package exporter
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/host"
|
||||
)
|
||||
|
||||
const ExporterTimeout = 10 * time.Second
|
||||
|
||||
type HTTPExporter struct {
|
||||
client http.Client
|
||||
logs [][]byte
|
||||
isNotFC bool
|
||||
mmdsOpts *host.MMDSOpts
|
||||
|
||||
// Concurrency coordination
|
||||
triggers chan struct{}
|
||||
logLock sync.RWMutex
|
||||
mmdsLock sync.RWMutex
|
||||
startOnce sync.Once
|
||||
}
|
||||
|
||||
// NewHTTPLogsExporter creates an exporter that queues log lines and ships
// them to the logs collector advertised over MMDS. Until MMDS options
// arrive, placeholder identifiers are used and nothing is sent; the
// background pipeline starts lazily on the first MMDS update.
func NewHTTPLogsExporter(ctx context.Context, isNotFC bool, mmdsChan <-chan *host.MMDSOpts) *HTTPExporter {
	exporter := &HTTPExporter{
		client: http.Client{
			Timeout: ExporterTimeout,
		},
		// Capacity 1 so resumeProcessing can signal without blocking.
		triggers:  make(chan struct{}, 1),
		isNotFC:   isNotFC,
		startOnce: sync.Once{},
		mmdsOpts: &host.MMDSOpts{
			SandboxID:            "unknown",
			TemplateID:           "unknown",
			LogsCollectorAddress: "",
		},
	}

	go exporter.listenForMMDSOptsAndStart(ctx, mmdsChan)

	return exporter
}
|
||||
|
||||
// sendInstanceLogs POSTs a single JSON log line to the collector address.
// A blank address is a silent no-op (collector not yet known).
// NOTE(review): the response status code is not inspected — non-2xx
// replies count as success; confirm this is intended.
func (w *HTTPExporter) sendInstanceLogs(ctx context.Context, logs []byte, address string) error {
	if address == "" {
		return nil
	}

	request, err := http.NewRequestWithContext(ctx, http.MethodPost, address, bytes.NewBuffer(logs))
	if err != nil {
		return err
	}

	request.Header.Set("Content-Type", "application/json")

	response, err := w.client.Do(request)
	if err != nil {
		return err
	}
	defer response.Body.Close()

	return nil
}
|
||||
|
||||
// printLog writes a raw log line to stdout as a fallback delivery path.
func printLog(logs []byte) {
	fmt.Fprint(os.Stdout, string(logs))
}
|
||||
|
||||
// listenForMMDSOptsAndStart applies MMDS option updates to the exporter
// under the mmds lock and starts the drain loop exactly once, on the first
// update. Returns when ctx is cancelled or mmdsChan is closed.
func (w *HTTPExporter) listenForMMDSOptsAndStart(ctx context.Context, mmdsChan <-chan *host.MMDSOpts) {
	for {
		select {
		case <-ctx.Done():
			return
		case mmdsOpts, ok := <-mmdsChan:
			if !ok {
				return
			}

			w.mmdsLock.Lock()
			w.mmdsOpts.Update(mmdsOpts.SandboxID, mmdsOpts.TemplateID, mmdsOpts.LogsCollectorAddress)
			w.mmdsLock.Unlock()

			// Start the processing goroutine only after the first options
			// arrive; subsequent updates just refresh the fields above.
			w.startOnce.Do(func() {
				go w.start(ctx)
			})
		}
	}
}
|
||||
|
||||
func (w *HTTPExporter) start(ctx context.Context) {
|
||||
for range w.triggers {
|
||||
logs := w.getAllLogs()
|
||||
|
||||
if len(logs) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if w.isNotFC {
|
||||
for _, log := range logs {
|
||||
fmt.Fprintf(os.Stdout, "%v", string(log))
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
for _, logLine := range logs {
|
||||
w.mmdsLock.RLock()
|
||||
logLineWithOpts, err := w.mmdsOpts.AddOptsToJSON(logLine)
|
||||
w.mmdsLock.RUnlock()
|
||||
if err != nil {
|
||||
log.Printf("error adding instance logging options (%+v) to JSON (%+v) with logs : %v\n", w.mmdsOpts, logLine, err)
|
||||
|
||||
printLog(logLine)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
err = w.sendInstanceLogs(ctx, logLineWithOpts, w.mmdsOpts.LogsCollectorAddress)
|
||||
if err != nil {
|
||||
log.Printf("error sending instance logs: %+v", err)
|
||||
|
||||
printLog(logLine)
|
||||
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// resumeProcessing nudges the drain loop in start(). The send is
// non-blocking: triggers has capacity 1, and one pending trigger already
// guarantees the queue will be drained.
func (w *HTTPExporter) resumeProcessing() {
	select {
	case w.triggers <- struct{}{}:
	default:
		// Exporter processing already triggered
		// This is expected behavior if the exporter is already processing logs
	}
}
|
||||
|
||||
func (w *HTTPExporter) Write(logs []byte) (int, error) {
|
||||
logsCopy := make([]byte, len(logs))
|
||||
copy(logsCopy, logs)
|
||||
|
||||
go w.addLogs(logsCopy)
|
||||
|
||||
return len(logs), nil
|
||||
}
|
||||
|
||||
func (w *HTTPExporter) getAllLogs() [][]byte {
|
||||
w.logLock.Lock()
|
||||
defer w.logLock.Unlock()
|
||||
|
||||
logs := w.logs
|
||||
w.logs = nil
|
||||
|
||||
return logs
|
||||
}
|
||||
|
||||
func (w *HTTPExporter) addLogs(logs []byte) {
|
||||
w.logLock.Lock()
|
||||
defer w.logLock.Unlock()
|
||||
|
||||
w.logs = append(w.logs, logs)
|
||||
|
||||
w.resumeProcessing()
|
||||
}
|
||||
@ -1,174 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package logs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
type OperationID string
|
||||
|
||||
const (
|
||||
OperationIDKey OperationID = "operation_id"
|
||||
DefaultHTTPMethod string = "POST"
|
||||
)
|
||||
|
||||
var operationID = atomic.Int32{}
|
||||
|
||||
func AssignOperationID() string {
|
||||
id := operationID.Add(1)
|
||||
|
||||
return strconv.Itoa(int(id))
|
||||
}
|
||||
|
||||
func AddRequestIDToContext(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, OperationIDKey, AssignOperationID())
|
||||
}
|
||||
|
||||
func formatMethod(method string) string {
|
||||
parts := strings.Split(method, ".")
|
||||
if len(parts) < 2 {
|
||||
return method
|
||||
}
|
||||
|
||||
split := strings.Split(parts[1], "/")
|
||||
if len(split) < 2 {
|
||||
return method
|
||||
}
|
||||
|
||||
servicePart := split[0]
|
||||
servicePart = strings.ToUpper(servicePart[:1]) + servicePart[1:]
|
||||
|
||||
methodPart := split[1]
|
||||
methodPart = strings.ToLower(methodPart[:1]) + methodPart[1:]
|
||||
|
||||
return fmt.Sprintf("%s %s", servicePart, methodPart)
|
||||
}
|
||||
|
||||
// NewUnaryLogInterceptor returns a connect unary interceptor that assigns
// an operation ID to each request's context and logs the method, request
// payload, response payload (on success) and error code (on failure).
func NewUnaryLogInterceptor(logger *zerolog.Logger) connect.UnaryInterceptorFunc {
	interceptor := func(next connect.UnaryFunc) connect.UnaryFunc {
		return connect.UnaryFunc(func(
			ctx context.Context,
			req connect.AnyRequest,
		) (connect.AnyResponse, error) {
			ctx = AddRequestIDToContext(ctx)

			res, err := next(ctx, req)

			// Err(nil) keeps the event at a non-error level, so a single
			// builder serves both outcomes.
			l := logger.
				Err(err).
				Str("method", DefaultHTTPMethod+" "+req.Spec().Procedure).
				Str(string(OperationIDKey), ctx.Value(OperationIDKey).(string))

			if err != nil {
				l = l.Int("error_code", int(connect.CodeOf(err)))
			}

			if req != nil {
				l = l.Interface("request", req.Any())
			}

			if res != nil && err == nil {
				l = l.Interface("response", res.Any())
			}

			// Explicit nil response field so successful empty responses are
			// still visible in the log.
			if res == nil && err == nil {
				l = l.Interface("response", nil)
			}

			l.Msg(formatMethod(req.Spec().Procedure))

			return res, err
		})
	}

	return connect.UnaryInterceptorFunc(interceptor)
}
|
||||
|
||||
// LogServerStreamWithoutEvents wraps a connect server-stream handler with
// start/end logging (request payload at start, error code at end) without
// logging the individual streamed events.
func LogServerStreamWithoutEvents[T any, R any](
	ctx context.Context,
	logger *zerolog.Logger,
	req *connect.Request[R],
	stream *connect.ServerStream[T],
	handler func(ctx context.Context, req *connect.Request[R], stream *connect.ServerStream[T]) error,
) error {
	ctx = AddRequestIDToContext(ctx)

	l := logger.Debug().
		Str("method", DefaultHTTPMethod+" "+req.Spec().Procedure).
		Str(string(OperationIDKey), ctx.Value(OperationIDKey).(string))

	if req != nil {
		l = l.Interface("request", req.Any())
	}

	l.Msg(fmt.Sprintf("%s (server stream start)", formatMethod(req.Spec().Procedure)))

	err := handler(ctx, req, stream)

	// End-of-stream event: error level on failure, debug on success.
	logEvent := getErrDebugLogEvent(logger, err).
		Str("method", DefaultHTTPMethod+" "+req.Spec().Procedure).
		Str(string(OperationIDKey), ctx.Value(OperationIDKey).(string))

	if err != nil {
		logEvent = logEvent.Int("error_code", int(connect.CodeOf(err)))
	} else {
		logEvent = logEvent.Interface("response", nil)
	}

	logEvent.Msg(fmt.Sprintf("%s (server stream end)", formatMethod(req.Spec().Procedure)))

	return err
}
|
||||
|
||||
// LogClientStreamWithoutEvents wraps a connect client-stream handler with
// start/end logging (error code on failure, final response on success)
// without logging the individual streamed events.
func LogClientStreamWithoutEvents[T any, R any](
	ctx context.Context,
	logger *zerolog.Logger,
	stream *connect.ClientStream[T],
	handler func(ctx context.Context, stream *connect.ClientStream[T]) (*connect.Response[R], error),
) (*connect.Response[R], error) {
	ctx = AddRequestIDToContext(ctx)

	logger.Debug().
		Str("method", DefaultHTTPMethod+" "+stream.Spec().Procedure).
		Str(string(OperationIDKey), ctx.Value(OperationIDKey).(string)).
		Msg(fmt.Sprintf("%s (client stream start)", formatMethod(stream.Spec().Procedure)))

	res, err := handler(ctx, stream)

	// End-of-stream event: error level on failure, debug on success.
	logEvent := getErrDebugLogEvent(logger, err).
		Str("method", DefaultHTTPMethod+" "+stream.Spec().Procedure).
		Str(string(OperationIDKey), ctx.Value(OperationIDKey).(string))

	if err != nil {
		logEvent = logEvent.Int("error_code", int(connect.CodeOf(err)))
	}

	if res != nil && err == nil {
		logEvent = logEvent.Interface("response", res.Any())
	}

	// Explicit nil response field so successful empty responses are still
	// visible in the log.
	if res == nil && err == nil {
		logEvent = logEvent.Interface("response", nil)
	}

	logEvent.Msg(fmt.Sprintf("%s (client stream end)", formatMethod(stream.Spec().Procedure)))

	return res, err
}
|
||||
|
||||
// Return logger with error level if err is not nil, otherwise return logger with debug level
|
||||
func getErrDebugLogEvent(logger *zerolog.Logger, err error) *zerolog.Event {
|
||||
if err != nil {
|
||||
return logger.Error().Err(err) //nolint:zerologlint // this builds an event, it is not expected to return it
|
||||
}
|
||||
|
||||
return logger.Debug() //nolint:zerologlint // this builds an event, it is not expected to return it
|
||||
}
|
||||
@ -1,37 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package logs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/host"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/logs/exporter"
|
||||
)
|
||||
|
||||
func NewLogger(ctx context.Context, isNotFC bool, mmdsChan <-chan *host.MMDSOpts) *zerolog.Logger {
|
||||
zerolog.TimestampFieldName = "timestamp"
|
||||
zerolog.TimeFieldFormat = time.RFC3339Nano
|
||||
|
||||
exporters := []io.Writer{}
|
||||
|
||||
if isNotFC {
|
||||
exporters = append(exporters, os.Stdout)
|
||||
} else {
|
||||
exporters = append(exporters, exporter.NewHTTPLogsExporter(ctx, isNotFC, mmdsChan), os.Stdout)
|
||||
}
|
||||
|
||||
l := zerolog.
|
||||
New(io.MultiWriter(exporters...)).
|
||||
With().
|
||||
Timestamp().
|
||||
Logger().
|
||||
Level(zerolog.DebugLevel)
|
||||
|
||||
return &l
|
||||
}
|
||||
@ -1,49 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package permissions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os/user"
|
||||
|
||||
"connectrpc.com/authn"
|
||||
"connectrpc.com/connect"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
||||
)
|
||||
|
||||
func AuthenticateUsername(_ context.Context, req authn.Request) (any, error) {
|
||||
username, _, ok := req.BasicAuth()
|
||||
if !ok {
|
||||
// When no username is provided, ignore the authentication method (not all endpoints require it)
|
||||
// Missing user is then handled in the GetAuthUser function
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
u, err := GetUser(username)
|
||||
if err != nil {
|
||||
return nil, authn.Errorf("invalid username: '%s'", username)
|
||||
}
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func GetAuthUser(ctx context.Context, defaultUser string) (*user.User, error) {
|
||||
u, ok := authn.GetInfo(ctx).(*user.User)
|
||||
if !ok {
|
||||
username, err := execcontext.ResolveDefaultUsername(nil, defaultUser)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeUnauthenticated, fmt.Errorf("no user specified"))
|
||||
}
|
||||
|
||||
u, err := GetUser(username)
|
||||
if err != nil {
|
||||
return nil, authn.Errorf("invalid default user: '%s'", username)
|
||||
}
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
||||
return u, nil
|
||||
}
|
||||
@ -1,31 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package permissions
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
)
|
||||
|
||||
const defaultKeepAliveInterval = 90 * time.Second
|
||||
|
||||
func GetKeepAliveTicker[T any](req *connect.Request[T]) (*time.Ticker, func()) {
|
||||
keepAliveIntervalHeader := req.Header().Get("Keepalive-Ping-Interval")
|
||||
|
||||
var interval time.Duration
|
||||
|
||||
keepAliveIntervalInt, err := strconv.Atoi(keepAliveIntervalHeader)
|
||||
if err != nil {
|
||||
interval = defaultKeepAliveInterval
|
||||
} else {
|
||||
interval = time.Duration(keepAliveIntervalInt) * time.Second
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(interval)
|
||||
|
||||
return ticker, func() {
|
||||
ticker.Reset(interval)
|
||||
}
|
||||
}
|
||||
@ -1,98 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package permissions
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
||||
)
|
||||
|
||||
// expand replaces a leading "~" in path with homedir. Paths that do not start
// with "~" — including the empty path — are returned unchanged. The "~user"
// form is not supported and yields an error.
func expand(path, homedir string) (string, error) {
	if path == "" || path[0] != '~' {
		return path, nil
	}

	// Only the bare "~" and "~/..." (or "~\...") forms are expandable;
	// "~someuser/..." would require a per-user home lookup.
	if len(path) > 1 && path[1] != '/' && path[1] != '\\' {
		return "", errors.New("cannot expand user-specific home dir")
	}

	return filepath.Join(homedir, path[1:]), nil
}
|
||||
|
||||
func ExpandAndResolve(path string, user *user.User, defaultPath *string) (string, error) {
|
||||
path = execcontext.ResolveDefaultWorkdir(path, defaultPath)
|
||||
|
||||
path, err := expand(path, user.HomeDir)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to expand path '%s' for user '%s': %w", path, user.Username, err)
|
||||
}
|
||||
|
||||
if filepath.IsAbs(path) {
|
||||
return path, nil
|
||||
}
|
||||
|
||||
// The filepath.Abs can correctly resolve paths like /home/user/../file
|
||||
path = filepath.Join(user.HomeDir, path)
|
||||
|
||||
abs, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to resolve path '%s' for user '%s' with home dir '%s': %w", path, user.Username, user.HomeDir, err)
|
||||
}
|
||||
|
||||
return abs, nil
|
||||
}
|
||||
|
||||
func getSubpaths(path string) (subpaths []string) {
|
||||
for {
|
||||
subpaths = append(subpaths, path)
|
||||
|
||||
path = filepath.Dir(path)
|
||||
if path == "/" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
slices.Reverse(subpaths)
|
||||
|
||||
return subpaths
|
||||
}
|
||||
|
||||
func EnsureDirs(path string, uid, gid int) error {
|
||||
subpaths := getSubpaths(path)
|
||||
for _, subpath := range subpaths {
|
||||
info, err := os.Stat(subpath)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to stat directory: %w", err)
|
||||
}
|
||||
|
||||
if err != nil && os.IsNotExist(err) {
|
||||
err = os.Mkdir(subpath, 0o755)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create directory: %w", err)
|
||||
}
|
||||
|
||||
err = os.Chown(subpath, uid, gid)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to chown directory: %w", err)
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if !info.IsDir() {
|
||||
return fmt.Errorf("path is a file: %s", subpath)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -1,46 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package permissions
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/user"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
func GetUserIdUints(u *user.User) (uid, gid uint32, err error) {
|
||||
newUID, err := strconv.ParseUint(u.Uid, 10, 32)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("error parsing uid '%s': %w", u.Uid, err)
|
||||
}
|
||||
|
||||
newGID, err := strconv.ParseUint(u.Gid, 10, 32)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("error parsing gid '%s': %w", u.Gid, err)
|
||||
}
|
||||
|
||||
return uint32(newUID), uint32(newGID), nil
|
||||
}
|
||||
|
||||
func GetUserIdInts(u *user.User) (uid, gid int, err error) {
|
||||
newUID, err := strconv.ParseInt(u.Uid, 10, strconv.IntSize)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("error parsing uid '%s': %w", u.Uid, err)
|
||||
}
|
||||
|
||||
newGID, err := strconv.ParseInt(u.Gid, 10, strconv.IntSize)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("error parsing gid '%s': %w", u.Gid, err)
|
||||
}
|
||||
|
||||
return int(newUID), int(newGID), nil
|
||||
}
|
||||
|
||||
// GetUser looks up the OS user record for username in the system user
// database, wrapping any lookup failure with context.
func GetUser(username string) (u *user.User, err error) {
	record, lookupErr := user.Lookup(username)
	if lookupErr != nil {
		return nil, fmt.Errorf("error looking up user '%s': %w", username, lookupErr)
	}

	return record, nil
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user