1
0
forked from wrenn/wrenn

2 Commits
v0.1.5 ... main

Author SHA1 Message Date
d439dbcc29 Add CI for release notes 2
All checks were successful
ci/woodpecker/push/pipeline Pipeline was successful
2026-05-13 22:40:11 +06:00
4707f16c76 v0.1.6 (#45)
## What's New?
Performance updates for large capsules, admin panel enhancements, and bug fixes

### Envd
- Fixed bug with sandbox metrics calculation
- Page cache drop and balloon inflation to reduce memfile snapshot size
- Updated rpc timeout logic for better control
- Added tests

### Admin Panel
- Add/Remove platform admin
- Updated template deletion logic for fine grained permission

### Others
- Minor frontend visual improvement
- Minor bugfixes
- Version bump

Co-authored-by: Tasnim Kabir Sadik <tksadik92@gmail.com>
Reviewed-on: wrenn/wrenn#45
Co-authored-by: pptx704 <rafeed@omukk.dev>
Co-committed-by: pptx704 <rafeed@omukk.dev>
2026-05-13 05:05:35 +00:00
60 changed files with 2763 additions and 238 deletions

62
.woodpecker/pipeline.yml Normal file
View File

@ -0,0 +1,62 @@
# Woodpecker CI release pipeline (runs on pushes to main):
#   build-go / build-rust  — build artifacts inside wrenn capsules (parallel)
#   tag-release            — tag the release repo with the new version
#   release-notes          — generate markdown notes from the git history
#   publish-github         — mirror the release (and assets) to GitHub
when:
  - event: push
    branch: main

steps:
  build-go:
    image: python:3.13
    environment:
      WRENN_API_KEY:
        from_secret: wrenn_api_key
    commands:
      - pip install wrenn
      # Use the Go version pinned in go.mod ($$ escapes $ for Woodpecker).
      - export GO_VERSION=$$(grep '^go ' go.mod | cut -d' ' -f2)
      - python .woodpecker/scripts/build_go.py
    depends_on: []

  build-rust:
    image: python:3.13
    environment:
      WRENN_API_KEY:
        from_secret: wrenn_api_key
    commands:
      - pip install wrenn
      - python .woodpecker/scripts/build_rust.py
    depends_on: []

  tag-release:
    image: python:3.13
    environment:
      GITEA_TOKEN:
        from_secret: gitea_token
    commands:
      - VERSION=$$(cat VERSION_CP)
      - git config user.name "R3dRum92"
      - git config user.email "tksadik@omukk.dev"
      - git tag "v$${VERSION}"
      - git push "https://tksadik92:$${GITEA_TOKEN}@git.omukk.dev/tksadik92/wrenn-releases.git" "v$${VERSION}"
    depends_on: [build-go, build-rust]

  release-notes:
    image: python:3.13
    environment:
      WRENN_API_KEY:
        from_secret: wrenn_api_key
      GITEA_TOKEN:
        from_secret: gitea_token
      ZHIPU_API_KEY:
        from_secret: zhipu_api_key
    commands:
      - pip install wrenn
      - python .woodpecker/scripts/release_notes.py
    depends_on: [tag-release]

  publish-github:
    image: python:3.13
    environment:
      GITHUB_TOKEN:
        from_secret: github_token
    commands:
      - pip install httpx
      - python .woodpecker/scripts/publish_github.py
    depends_on: [release-notes]

View File

@ -0,0 +1,136 @@
import os
import sys
from wrenn import Capsule, StreamExitEvent, StreamStderrEvent, StreamStdoutEvent
from wrenn._git import GitCommandError
# Go toolchain version installed inside the capsule; overridable via env
# (the pipeline exports GO_VERSION from go.mod before running this script).
GO_VERSION = os.getenv("GO_VERSION", "1.25.8")
# Repository cloned and built inside the capsule.
REPO_URL = "https://git.omukk.dev/wrenn/wrenn.git"
REPO_DIR = "/opt/wrenn"
# Local directory (relative to this script) where downloaded artifacts land.
BUILDS_DIR = os.path.join(os.path.dirname(__file__), "..", "..", "builds")
def read_remote_version(capsule: Capsule, filename: str) -> str:
    """Read a version file from the cloned repo inside the capsule.

    Args:
        capsule: Running capsule with the repo cloned at REPO_DIR.
        filename: Version file name relative to the repo root (e.g. "VERSION_CP").

    Returns:
        The version string with surrounding whitespace stripped.
    """
    # BUG FIX: the path previously hard-coded a literal instead of using the
    # `filename` parameter, so the parameter was ignored and every call tried
    # to read the same (nonexistent) file.
    content = capsule.files.read_bytes(f"{REPO_DIR}/{filename}")
    return content.decode("utf-8").strip()
def run(capsule: Capsule, cmd: str, timeout: int = 30) -> int:
    """Run a command in the capsule, log OK/FAIL, and return its exit code."""
    label = cmd.split()[0]
    outcome = capsule.commands.run(cmd, timeout=timeout)
    if outcome.exit_code == 0:
        print(f"OK [{label}]")
        return 0
    print(f"FAIL [{label}]: exit={outcome.exit_code}", file=sys.stderr)
    if outcome.stderr:
        print(outcome.stderr.strip(), file=sys.stderr)
    return outcome.exit_code
def install_go(capsule: Capsule) -> bool:
    """Install build prerequisites and the Go toolchain inside the capsule."""
    tarball = f"go{GO_VERSION}.linux-amd64.tar.gz"
    url = f"https://go.dev/dl/{tarball}"
    # Each step is (command, timeout); abort on the first failure.
    steps = [
        ("apt update", 120),
        ("apt install -y make build-essential file", 300),
        (f"curl -LO {url}", 120),
        (f"tar -C /usr/local -xzf {tarball}", 300),
        ('echo "export PATH=$PATH:/usr/local/go/bin" >> ~/.profile', 30),
        ("rm -f " + tarball, 30),
    ]
    for cmd, timeout in steps:
        if run(capsule, cmd, timeout=timeout) != 0:
            return False
    # Sanity-check the install by printing the toolchain version.
    result = capsule.commands.run("/usr/local/go/bin/go version")
    print(result.stdout.strip())
    return result.exit_code == 0
def clone_repo(capsule: Capsule) -> bool:
    """Clone the wrenn repository into the capsule at REPO_DIR."""
    try:
        capsule.git.clone(REPO_URL, REPO_DIR)
    except GitCommandError as e:
        print(f"FAIL [git clone]: {e}", file=sys.stderr)
        return False
    print("OK [git clone]")
    return True
def build_go(capsule: Capsule) -> bool:
    """Build the Go binaries via make, streaming build output as it arrives."""
    command = "CGO_ENABLED=1 make build-cp build-agent"
    build_env = {
        "PATH": "/usr/local/go/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
    }
    handle = capsule.commands.run(
        command,
        background=True,
        cwd=REPO_DIR,
        envs=build_env,
    )
    print(f"{command} started (pid={handle.pid}), streaming output...")
    status = 0
    # Relay the build's stdout/stderr live; remember the final exit code.
    for event in capsule.commands.connect(handle.pid):
        if isinstance(event, StreamExitEvent):
            status = event.exit_code
        elif isinstance(event, StreamStderrEvent):
            print(event.data, end="", file=sys.stderr)
        elif isinstance(event, StreamStdoutEvent):
            print(event.data, end="")
    if status == 0:
        print("OK [go build]")
        return True
    print(f"FAIL [go build]: exit={status}", file=sys.stderr)
    return False
def download_artifacts(capsule: Capsule) -> bool:
    """Stream built binaries from the capsule into the local builds directory."""
    remote_dir = f"{REPO_DIR}/builds"
    entries = capsule.files.list(remote_dir, depth=1)
    files = [entry for entry in entries if entry.type != "directory"]
    if not files:
        print("FAIL [download]: no files found in builds/", file=sys.stderr)
        return False
    local_dir = os.path.normpath(BUILDS_DIR)
    os.makedirs(local_dir, exist_ok=True)
    # Version-suffix the known binaries so multiple releases can coexist locally.
    versions = {
        "wrenn-cp": read_remote_version(capsule, "VERSION_CP"),
        "wrenn-agent": read_remote_version(capsule, "VERSION_AGENT"),
    }
    for entry in files:
        name = entry.name or "unknown"
        suffix = versions.get(name)
        local_name = name if suffix is None else f"{name}-{suffix}"
        remote_path = f"{remote_dir}/{name}"
        local_path = os.path.join(local_dir, local_name)
        print(f"Downloading {name} as {local_name} ({entry.size or '?'} bytes)...")
        with open(local_path, "wb") as sink:
            for chunk in capsule.files.download_stream(remote_path):
                sink.write(chunk)
        print(f"OK [download {local_name}]")
    return True
def main() -> None:
    """Provision a capsule, install Go, build the binaries, and pull artifacts."""
    with Capsule(wait=True, vcpus=4, memory_mb=4096) as capsule:
        print(f"Capsule: {capsule.capsule_id}")
        # Run the stages in order; any failure aborts the pipeline step.
        for stage in (install_go, clone_repo, build_go, download_artifacts):
            if not stage(capsule):
                sys.exit(1)
        print("Done.")


if __name__ == "__main__":
    main()

View File

@ -0,0 +1,173 @@
import os
import sys
from wrenn import Capsule, StreamExitEvent, StreamStderrEvent, StreamStdoutEvent
from wrenn._git import GitCommandError
# Rust toolchain version installed via rustup; overridable via env.
RUST_VERSION = os.getenv("RUST_VERSION", "1.95.0")
# Repository cloned and built inside the capsule.
REPO_URL = "https://git.omukk.dev/wrenn/wrenn.git"
REPO_DIR = "/opt/wrenn"
# Local directory (relative to this script) where the downloaded binary lands.
BUILDS_DIR = os.path.join(os.path.dirname(__file__), "..", "..", "builds")
# PATH for build commands so cargo/rustc resolve inside the capsule.
RUST_PATH = (
    "/root/.cargo/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
)
def read_envd_version(capsule: Capsule) -> str:
    """Extract the package version from envd-rs/Cargo.toml inside the capsule.

    Exits the process if no `version =` line is found.
    """
    raw = capsule.files.read_bytes(f"{REPO_DIR}/envd-rs/Cargo.toml")
    for raw_line in raw.decode("utf-8").splitlines():
        candidate = raw_line.strip()
        if not candidate.startswith("version ="):
            continue
        _, _, value = candidate.partition("=")
        return value.strip().strip('"')
    print("FAIL [version]: envd-rs/Cargo.toml has no package version", file=sys.stderr)
    sys.exit(1)
def run(capsule: Capsule, cmd: str, timeout: int = 30, envs=None) -> int:
    """Run a command in the capsule, log OK/FAIL, and return its exit code.

    Args:
        capsule: Running capsule to execute the command in.
        cmd: Shell command line to run.
        timeout: Per-command timeout in seconds.
        envs: Optional extra environment variables for the command.
    """
    # BUG FIX: `envs={}` was a mutable default argument (one dict shared across
    # all calls); use None as the sentinel and materialize a fresh dict per call.
    result = capsule.commands.run(cmd, timeout=timeout, envs={} if envs is None else envs)
    if result.exit_code != 0:
        print(f"FAIL [{cmd.split()[0]}]: exit={result.exit_code}", file=sys.stderr)
        if result.stderr:
            print(result.stderr.strip(), file=sys.stderr)
        return result.exit_code
    print(f"OK [{cmd.split()[0]}]")
    return 0
def install_rust(capsule: Capsule) -> bool:
    """Install build prerequisites, rustup, and the musl target in the capsule."""
    rustup_cmd = (
        "curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | "
        f"sh -s -- -y --profile minimal --default-toolchain {RUST_VERSION}"
    )
    # Each step is (command, timeout); abort on the first failure.
    steps = [
        ("apt update", 120),
        ("apt install -y make build-essential file curl musl-tools protobuf-compiler", 300),
        (rustup_cmd, 300),
        ("/root/.cargo/bin/rustup target add x86_64-unknown-linux-musl", 120),
    ]
    for cmd, timeout in steps:
        if run(capsule, cmd, timeout=timeout) != 0:
            return False
    # Sanity-check the install by printing the compiler version.
    result = capsule.commands.run("/root/.cargo/bin/rustc --version")
    print(result.stdout.strip())
    return result.exit_code == 0
def clone_repo(capsule: Capsule) -> bool:
    """Clone the wrenn repo and switch to the fix/large-operations branch.

    Returns True on success; logs and returns False on any git failure.
    """
    try:
        capsule.git.clone(REPO_URL, REPO_DIR)
    except GitCommandError as e:
        print(f"FAIL [git clone]: {e}", file=sys.stderr)
        return False
    # BUG FIX: the checkout's exit code was previously ignored, so a failed
    # branch switch would silently build the default branch instead.
    result = capsule.commands.run(f"cd {REPO_DIR} && git checkout fix/large-operations")
    if result.exit_code != 0:
        print(f"FAIL [git checkout]: exit={result.exit_code}", file=sys.stderr)
        if result.stderr:
            print(result.stderr.strip(), file=sys.stderr)
        return False
    print("OK [git clone]")
    return True
def build_rust(capsule: Capsule) -> bool:
    """Build the envd binary via `make build-envd`, streaming build output.

    The Makefile target copies the binary into REPO_DIR/builds, which
    download_artifacts later pulls locally. Returns True on a zero exit.

    IMPROVEMENT: removed the large blocks of commented-out debugging code
    (file/readelf/musl-gcc probes) that were left in the original.
    """
    if run(capsule, f"mkdir -p {REPO_DIR}/builds") != 0:
        return False
    handle = capsule.commands.run(
        "make build-envd",
        background=True,
        cwd=REPO_DIR,
        envs={"PATH": RUST_PATH},
    )
    print(f"rust build started (pid={handle.pid}), streaming output...")
    exit_code = 0
    # Relay the build's stdout/stderr live; remember the final exit code.
    for event in capsule.commands.connect(handle.pid):
        if isinstance(event, StreamStdoutEvent):
            print(event.data, end="")
        elif isinstance(event, StreamStderrEvent):
            print(event.data, end="", file=sys.stderr)
        elif isinstance(event, StreamExitEvent):
            exit_code = event.exit_code
    if exit_code != 0:
        print(f"FAIL [rust build]: exit={exit_code}", file=sys.stderr)
        return False
    print("OK [rust build]")
    return True
def download_artifacts(capsule: Capsule) -> bool:
    """Download the built envd binary from the capsule, version-suffixed."""
    version = read_envd_version(capsule)
    local_dir = os.path.normpath(BUILDS_DIR)
    os.makedirs(local_dir, exist_ok=True)
    local_name = f"envd-{version}"
    destination = os.path.join(local_dir, local_name)
    print(f"Downloading envd as {local_name}...")
    with open(destination, "wb") as sink:
        for chunk in capsule.files.download_stream(f"{REPO_DIR}/builds/envd"):
            sink.write(chunk)
    print(f"OK [download {local_name}]")
    return True
def main() -> None:
    """Provision a capsule, install Rust, build envd, and pull the artifact."""
    with Capsule(wait=True, vcpus=4, memory_mb=4096) as capsule:
        print(f"Capsule: {capsule.capsule_id}")
        # Run the stages in order; any failure aborts the pipeline step.
        for stage in (install_rust, clone_repo, build_rust, download_artifacts):
            if not stage(capsule):
                sys.exit(1)
        print("Done.")


if __name__ == "__main__":
    main()

View File

@ -0,0 +1,104 @@
import os
import sys
from pathlib import Path
import httpx
# GitHub mirror repository that receives the release and its assets.
GITHUB_REPO = "R3dRum92/wrenn-releases"
GITHUB_API = "https://api.github.com"
# Release assets must be uploaded to GitHub's dedicated uploads host.
GITHUB_UPLOADS = "https://uploads.github.com"
# Local directory containing artifacts produced by the build steps.
BUILDS_DIR = "builds"
# File whose contents name the release version (tagged as v<version>).
VERSION_FILE = "VERSION_CP"
# Markdown notes produced by the release-notes pipeline step (optional).
NOTES_FILE = os.path.join(".woodpecker", "release_notes.md")
def _read_version() -> str:
    """Read the release version from VERSION_FILE (stripped)."""
    with open(VERSION_FILE) as f:
        return f.read().strip()


def _load_release_notes() -> str:
    """Return the release-notes markdown, or "" if the notes file is absent."""
    if not os.path.exists(NOTES_FILE):
        return ""
    with open(NOTES_FILE) as f:
        return f.read()


def _upload_assets(client, release_id, base_headers) -> None:
    """Upload every regular file in BUILDS_DIR as a release asset.

    Upload failures are logged as warnings; they do not abort the run.
    """
    upload_headers = {
        **base_headers,
        "Content-Type": "application/octet-stream",
    }
    for artifact in sorted(Path(BUILDS_DIR).iterdir()):
        if artifact.is_dir():
            continue
        print(f"Uploading {artifact.name}...")
        with open(artifact, "rb") as f:
            data = f.read()
        resp = client.post(
            f"{GITHUB_UPLOADS}/repos/{GITHUB_REPO}/releases/{release_id}/assets",
            params={"name": artifact.name},
            headers=upload_headers,
            content=data,
        )
        if resp.status_code != 201:
            print(
                f"WARN [upload {artifact.name}]: {resp.status_code} {resp.text}",
                file=sys.stderr,
            )
        else:
            print(f"OK [upload {artifact.name}]")


def main() -> None:
    """Create a GitHub release for the current version and upload build assets.

    Exits 1 if the release cannot be created; an already-existing release
    (HTTP 422) is treated as a benign skip.
    """
    token = os.environ["GITHUB_TOKEN"]
    tag = f"v{_read_version()}"
    headers = {
        "Authorization": f"token {token}",
        "Accept": "application/vnd.github+json",
        "X-GitHub-Api-Version": "2022-11-28",
    }
    # IMPROVEMENT: the client is now a context manager so it is closed on every
    # exit path; the original repeated client.close() before each return and
    # leaked the client on sys.exit(1).
    with httpx.Client(headers=headers, timeout=60) as client:
        print(f"Creating GitHub release for {tag}...")
        resp = client.post(
            f"{GITHUB_API}/repos/{GITHUB_REPO}/releases",
            json={
                "tag_name": tag,
                "name": tag,
                "body": _load_release_notes(),
                "draft": False,
                "prerelease": False,
            },
        )
        if resp.status_code == 422:
            # GitHub returns 422 when a release for this tag already exists.
            print(f"WARN [create release]: release for {tag} already exists, skipping")
            errors = resp.json().get("errors", [])
            if errors:
                print(f" See: {errors[0].get('documentation_url', '')}")
            return
        if resp.status_code != 201:
            print(f"FAIL [create release]: {resp.status_code} {resp.text}", file=sys.stderr)
            sys.exit(1)
        release_data = resp.json()
        release_id = release_data["id"]
        release_url = release_data.get("html_url", "")
        print(f"OK [create release] id={release_id}")
        if not Path(BUILDS_DIR).exists():
            print(f"No {BUILDS_DIR}/ directory found, skipping asset upload")
        else:
            _upload_assets(client, release_id, headers)
        print(f"Release published: {release_url}")


if __name__ == "__main__":
    main()

View File

@ -0,0 +1,246 @@
import base64
import os
import sys
from wrenn import Capsule
# Release repository cloned inside the capsule (the version tags live here).
REPO_URL = "https://git.omukk.dev/tksadik92/wrenn-releases.git"
REPO_DIR = "/opt/wrenn-releases"
# Where the generated markdown lands inside the capsule...
CAPSULE_OUTPUT = "/tmp/release_notes.md"
# ...and where it is copied locally for the publish step to pick up.
LOCAL_OUTPUT = os.path.join(os.path.dirname(__file__), "..", "release_notes.md")
# Default starting configuration
# Prefer the Zhipu model when an API key is configured; otherwise use the
# free MiniMax model.
ZHIPU_API_KEY = os.environ.get("ZHIPU_API_KEY", "")
if ZHIPU_API_KEY:
    DEFAULT_MODEL = "zhipuai-coding-plan/glm-5.1"
else:
    DEFAULT_MODEL = "opencode/minimax-m2.5-free"
# Few-shot style example embedded in the LLM prompt (a style guide for the
# model, explicitly not a template to copy).
RELEASE_NOTES_EXAMPLE = """
## What's new
Sandbox HTTP proxying, terminal reliability, and auth robustness improvements.
### Proxy
- Fixed redirect loops for apps served inside sandboxes (Python HTTP server, Jupyter, etc.)
- Proxy traffic no longer interferes with terminal and exec connections
- Services that take a moment to start up inside a sandbox are now retried instead of immediately failing
### Terminal (PTY)
- Terminal input is no longer blocked by slow network conditions — fast typing no longer causes timeouts or disconnects
- Input bursts are coalesced into fewer round trips — lower latency under fast typing
### Authentication
- WebSocket connections now authenticate correctly for both SDK clients (header-based) and browser clients (message-based)
### Bug Fixes
- Fixed crash in envd when a process exits without a PTY
- Fixed goroutine leak on sandbox pause
### Others
- Version bump
""".strip()
def run(capsule: Capsule, cmd: str, cwd: str | None = None, timeout: int = 30) -> int:
    """Run a command in the capsule, log OK/FAIL, and return its exit code."""
    label = cmd.split()[0]
    outcome = capsule.commands.run(cmd, cwd=cwd, timeout=timeout)
    if outcome.exit_code == 0:
        print(f"OK [{label}]")
        return 0
    print(f"FAIL [{label}]: exit={outcome.exit_code}", file=sys.stderr)
    if outcome.stderr:
        print(outcome.stderr.strip(), file=sys.stderr)
    return outcome.exit_code
def get_tags(capsule: Capsule) -> tuple[str, str | None]:
    """Return (current_tag, previous_tag) sorted by descending version.

    Exits the process when the repo has no tags or git fails.
    """
    result = capsule.commands.run(
        f"cd {REPO_DIR} && git tag --sort=-version:refname",
        cwd=REPO_DIR,
        timeout=30,
    )
    if result.exit_code != 0:
        print(f"FAIL [git tag]: {result.stderr}", file=sys.stderr)
        sys.exit(1)
    tags = [tag for tag in result.stdout.strip().split("\n") if tag]
    if not tags:
        print("No tags found", file=sys.stderr)
        sys.exit(1)
    current_tag = tags[0]
    previous_tag = tags[1] if len(tags) > 1 else None
    print(f"Current tag: {current_tag}")
    print(f"Previous tag: {previous_tag}")
    return current_tag, previous_tag
def get_git_context(
    capsule: Capsule, current_tag: str, previous_tag: str | None
) -> tuple[str, str]:
    """Collect the commit log and diffstat for the release range.

    With a previous tag, covers everything between the two tags; for the very
    first tag, caps the log at 50 commits and shows the tagged commit's stat.
    Exits the process on git failure.
    """
    if previous_tag:
        log_cmd = f"cd {REPO_DIR} && git log {previous_tag}..{current_tag} --pretty=format:'%s (%h)'"
        diff_cmd = f"cd {REPO_DIR} && git diff {previous_tag}..{current_tag} --stat"
    else:
        log_cmd = (
            f"cd {REPO_DIR} && git log {current_tag} --pretty=format:'%s (%h)' -n 50"
        )
        diff_cmd = f"cd {REPO_DIR} && git show {current_tag} --stat"
    log_result = capsule.commands.run(log_cmd, cwd=REPO_DIR, timeout=30)
    if log_result.exit_code != 0:
        print(f"FAIL [git log]: {log_result.stderr}", file=sys.stderr)
        sys.exit(1)
    diff_result = capsule.commands.run(diff_cmd, cwd=REPO_DIR, timeout=30)
    if diff_result.exit_code != 0:
        print(f"FAIL [git diff]: {diff_result.stderr}", file=sys.stderr)
        sys.exit(1)
    return log_result.stdout.strip(), diff_result.stdout.strip()
def generate_release_notes(
    capsule: Capsule,
    current_tag: str,
    git_log: str,
    git_diff: str,
    output_path: str,
    model: str,
) -> None:
    """Generate release notes with opencode (LLM) inside the capsule.

    Builds a prompt from the git log/diff, ships it into the capsule via
    base64, and runs opencode with `model`. If a Zhipu model fails (e.g.
    rate-limited), retries once with the free MiniMax model. Exits the
    process on unrecoverable failure.
    """
    prompt = (
        f"You are writing release notes for version {current_tag} of a software project.\n\n"
        f"Here is what changed between the previous version and this one:\n\n"
        f"Commit messages:\n{git_log}\n\n"
        f"Files and areas that changed:\n{git_diff}\n\n"
        f"Write the release notes in plain, friendly language that any developer can understand "
        f"without deep knowledge of the codebase. Avoid jargon like 'goroutine', 'PTY', 'envd', "
        f"or internal function names — describe what the change means for the user instead. "
        f"Group related changes under headings that reflect what actually changed. "
        f"Only include sections that are relevant to these specific changes. "
        f"Start with a short one-line summary of what this release is about. "
        f"Keep each bullet point to one clear sentence.\n\n"
        f"Here is an example of the style to aim for — not a template to copy:\n\n"
        f"{RELEASE_NOTES_EXAMPLE}\n\n"
        f"You MUST start the document with `## What's New`\n"
        f"The very next line MUST be a single short summary sentence.\n"
        f"Output only the markdown. No intro, no explanation."
        f"CRITICAL: Do not output any conversational filler, acknowledgments, or thoughts "
        f"like 'Let me look at the changes'. Output absolutely nothing except the final markdown."
    )
    # Base64-encode the prompt so quoting/newlines survive the shell round-trip
    # (the base64 alphabet contains no characters that need shell escaping).
    prompt_b64 = base64.b64encode(prompt.encode("utf-8")).decode("utf-8")
    write_prompt_cmd = f"echo '{prompt_b64}' | base64 -d > /tmp/oc_prompt.txt"
    result = capsule.commands.run(
        write_prompt_cmd,
        cwd=REPO_DIR,
        timeout=10,
    )
    if result.exit_code != 0:
        print(f"FAIL [write prompt]: {result.stderr}", file=sys.stderr)
        sys.exit(1)

    def run_opencode_with_model(target_model: str) -> int:
        """Run opencode with `target_model`; return its exit code (0 = ok)."""
        env = ""
        if "zhipu" in target_model.lower():
            # Zhipu models authenticate via an API key in the environment.
            env = f"ZHIPU_API_KEY={os.environ.get('ZHIPU_API_KEY', '')}"
        cmd = (
            f"{env} "
            f"~/.opencode/bin/opencode run "
            f'"Read the attached file and generate the release notes. Output ONLY markdown." '
            f"--model {target_model} "
            f"--file /tmp/oc_prompt.txt "
            f"> {output_path}"
        )
        cmd_result = capsule.commands.run(cmd, cwd=REPO_DIR, timeout=120)
        if cmd_result.exit_code != 0:
            print(
                f"FAIL [opencode via {target_model}]: exit={cmd_result.exit_code}",
                file=sys.stderr,
            )
            print(f"STDOUT:\n{cmd_result.stdout}", file=sys.stderr)
            print(f"STDERR:\n{cmd_result.stderr}", file=sys.stderr)
        # BUG FIX: always return the exit code. The success path previously
        # appeared to fall off the end and return None, which compares unequal
        # to 0 and made the caller treat a successful run as a failure.
        return cmd_result.exit_code

    exit_status = run_opencode_with_model(model)
    if exit_status != 0:
        if "zhipu" not in model.lower():
            sys.exit(1)
        # Zhipu failures (commonly rate limits) fall back to the free model.
        print(
            "\n[!] Zhipu AI failed (likely rate-limited). Falling back to MiniMax...",
            file=sys.stderr,
        )
        if run_opencode_with_model("opencode/minimax-m2.5-free") != 0:
            print("FAIL: Fallback model also failed. Exiting.", file=sys.stderr)
            sys.exit(1)
    # Echo the generated notes into the CI log for easy inspection.
    result = capsule.commands.run(f"cat {output_path}")
    print(result.stdout)
    if result.stderr:
        print(result.stderr)
    print(f"OK [opencode] release notes written to {output_path}")
def download_release_notes(capsule: Capsule) -> None:
    """Copy the generated notes from the capsule to LOCAL_OUTPUT and echo them."""
    local_path = os.path.normpath(LOCAL_OUTPUT)
    os.makedirs(os.path.dirname(local_path), exist_ok=True)
    print("Downloading release notes from capsule...")
    content = capsule.files.read_bytes(CAPSULE_OUTPUT)
    with open(local_path, "wb") as sink:
        sink.write(content)
    print(f"OK [download] release notes → {local_path}")
    print(content.decode("utf-8", errors="replace"))
def main() -> None:
    """Clone the release repo in a capsule, generate notes, and download them."""
    model = os.environ.get("OPENCODE_MODEL", DEFAULT_MODEL)
    with Capsule(template="opencode", wait=True, vcpus=2, memory_mb=2048) as capsule:
        print(f"Capsule: {capsule.capsule_id}")
        capsule.git.clone(
            REPO_URL,
            REPO_DIR,
            username="tksadik92",
        )
        print("OK [git clone]")
        current_tag, previous_tag = get_tags(capsule)
        git_log, git_diff = get_git_context(capsule, current_tag, previous_tag)
        generate_release_notes(
            capsule,
            current_tag,
            git_log,
            git_diff,
            os.path.normpath(CAPSULE_OUTPUT),
            model,
        )
        download_release_notes(capsule)


if __name__ == "__main__":
    main()

View File

@ -27,8 +27,12 @@ build-agent:
build-envd: build-envd:
cd envd-rs && ENVD_COMMIT=$(COMMIT) cargo build --release --target x86_64-unknown-linux-musl cd envd-rs && ENVD_COMMIT=$(COMMIT) cargo build --release --target x86_64-unknown-linux-musl
@cp envd-rs/target/x86_64-unknown-linux-musl/release/envd $(BIN_DIR)/envd @cp envd-rs/target/x86_64-unknown-linux-musl/release/envd $(BIN_DIR)/envd
@file $(BIN_DIR)/envd | grep -q "static-pie linked" || \ @readelf -h $(BIN_DIR)/envd | grep -q 'Type:.*DYN' && \
(echo "ERROR: envd is not statically linked!" && exit 1) readelf -d $(BIN_DIR)/envd | grep -q 'FLAGS_1.*PIE' && \
! readelf -d $(BIN_DIR)/envd | grep -q '(NEEDED)' && \
{ ! readelf -lW $(BIN_DIR)/envd | grep -q 'Requesting program interpreter' || \
readelf -lW $(BIN_DIR)/envd | grep -Fq '[Requesting program interpreter: /lib/ld-musl-x86_64.so.1]'; } || \
(echo "ERROR: envd must be PIE, have no DT_NEEDED shared libs, and either have no interpreter or use /lib/ld-musl-x86_64.so.1" && exit 1)
# ═══════════════════════════════════════════════════ # ═══════════════════════════════════════════════════
# Development # Development
@ -111,6 +115,7 @@ vet:
test: test:
go test -race -v ./internal/... go test -race -v ./internal/...
cd envd-rs && cargo test
test-integration: test-integration:
go test -race -v -tags=integration ./tests/integration/... go test -race -v -tags=integration ./tests/integration/...

View File

@ -1 +1 @@
0.1.2 0.1.3

View File

@ -1 +1 @@
0.1.5 0.1.6

View File

@ -80,6 +80,25 @@ func main() {
os.Exit(1) os.Exit(1)
} }
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Register with the control plane before touching rootfs images. If the
// agent can't reach the CP there's no point inflating images (and crashing
// afterward would leave them in the expanded state).
creds, err := hostagent.Register(ctx, hostagent.RegistrationConfig{
CPURL: cpURL,
RegistrationToken: *registrationToken,
TokenFile: credsFile,
Address: *advertiseAddr,
})
if err != nil {
slog.Error("host registration failed", "error", err)
os.Exit(1)
}
slog.Info("host registered", "host_id", creds.HostID)
// Parse default rootfs size from env (e.g. "5G", "2Gi", "1000M"). // Parse default rootfs size from env (e.g. "5G", "2Gi", "1000M").
defaultRootfsSizeMB := sandbox.DefaultDiskSizeMB defaultRootfsSizeMB := sandbox.DefaultDiskSizeMB
if sizeStr := os.Getenv("WRENN_DEFAULT_ROOTFS_SIZE"); sizeStr != "" { if sizeStr := os.Getenv("WRENN_DEFAULT_ROOTFS_SIZE"); sizeStr != "" {
@ -128,25 +147,8 @@ func main() {
mgr := sandbox.New(cfg) mgr := sandbox.New(cfg)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
mgr.StartTTLReaper(ctx) mgr.StartTTLReaper(ctx)
// Register with the control plane and start heartbeating.
creds, err := hostagent.Register(ctx, hostagent.RegistrationConfig{
CPURL: cpURL,
RegistrationToken: *registrationToken,
TokenFile: credsFile,
Address: *advertiseAddr,
})
if err != nil {
slog.Error("host registration failed", "error", err)
os.Exit(1)
}
slog.Info("host registered", "host_id", creds.HostID)
// httpServer is declared here so the shutdown func can reference it. // httpServer is declared here so the shutdown func can reference it.
// ReadTimeout/WriteTimeout are intentionally omitted — they would kill // ReadTimeout/WriteTimeout are intentionally omitted — they would kill
// long-lived Connect RPC streams and WebSocket proxy connections. // long-lived Connect RPC streams and WebSocket proxy connections.

View File

@ -22,6 +22,12 @@ RETURNING *;
-- name: SetUserAdmin :exec -- name: SetUserAdmin :exec
UPDATE users SET is_admin = $2, updated_at = NOW() WHERE id = $1; UPDATE users SET is_admin = $2, updated_at = NOW() WHERE id = $1;
-- name: RevokeUserAdmin :execrows
-- Demote a user from admin, but only while at least one other non-deleted
-- admin remains (guards against removing the last admin). Affected-row count
-- is 0 when the user is missing, not an admin, or is the sole remaining admin.
UPDATE users u SET is_admin = false, updated_at = NOW()
WHERE u.id = $1
  AND u.is_admin = true
  AND (SELECT COUNT(*) FROM users WHERE is_admin = true AND status != 'deleted') > 1;
-- name: GetAdminUsers :many -- name: GetAdminUsers :many
SELECT * FROM users WHERE is_admin = TRUE ORDER BY created_at; SELECT * FROM users WHERE is_admin = TRUE ORDER BY created_at;

3
envd-rs/Cargo.lock generated
View File

@ -514,7 +514,7 @@ dependencies = [
[[package]] [[package]]
name = "envd" name = "envd"
version = "0.2.0" version = "0.2.1"
dependencies = [ dependencies = [
"async-stream", "async-stream",
"axum", "axum",
@ -543,6 +543,7 @@ dependencies = [
"sha2", "sha2",
"subtle", "subtle",
"sysinfo", "sysinfo",
"tempfile",
"tokio", "tokio",
"tokio-util", "tokio-util",
"tower", "tower",

View File

@ -1,6 +1,6 @@
[package] [package]
name = "envd" name = "envd"
version = "0.2.0" version = "0.2.1"
edition = "2024" edition = "2024"
rust-version = "1.88" rust-version = "1.88"
@ -72,6 +72,9 @@ buffa = "0.3"
async-stream = "0.3.6" async-stream = "0.3.6"
mime_guess = "2" mime_guess = "2"
[dev-dependencies]
tempfile = "3"
[build-dependencies] [build-dependencies]
connectrpc-build = "0.3" connectrpc-build = "0.3"

View File

@ -83,3 +83,128 @@ pub fn validate_signing(
Ok(()) Ok(())
} }
// Unit tests for signature generation and validation.
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `SecureToken` already populated with `val`.
    fn test_token(val: &[u8]) -> SecureToken {
        let t = SecureToken::new();
        t.set(val).unwrap();
        t
    }

    /// Unix timestamp one hour from now, so "valid" signatures cannot
    /// expire while a test is running.
    fn far_future() -> i64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs() as i64
            + 3600
    }

    #[test]
    fn generate_starts_with_v1() {
        let token = test_token(b"secret");
        let sig = generate_signature(&token, "/file", "root", READ_OPERATION, None).unwrap();
        assert!(sig.starts_with("v1_"));
    }

    #[test]
    fn generate_deterministic() {
        let token = test_token(b"secret");
        let s1 = generate_signature(&token, "/file", "root", READ_OPERATION, None).unwrap();
        let s2 = generate_signature(&token, "/file", "root", READ_OPERATION, None).unwrap();
        assert_eq!(s1, s2);
    }

    #[test]
    fn generate_with_expiration_differs() {
        // The expiration timestamp must be part of the signed payload.
        let token = test_token(b"secret");
        let without = generate_signature(&token, "/f", "u", READ_OPERATION, None).unwrap();
        let with = generate_signature(&token, "/f", "u", READ_OPERATION, Some(9999)).unwrap();
        assert_ne!(without, with);
    }

    #[test]
    fn generate_unset_token_errors() {
        let token = SecureToken::new();
        assert!(generate_signature(&token, "/f", "u", READ_OPERATION, None).is_err());
    }

    #[test]
    fn validate_no_token_set_passes() {
        // With no token configured, validation is a no-op (auth disabled).
        let token = SecureToken::new();
        assert!(validate_signing(&token, None, None, None, "root", "/f", READ_OPERATION).is_ok());
    }

    #[test]
    fn validate_correct_header_token() {
        let token = test_token(b"secret");
        assert!(validate_signing(&token, Some("secret"), None, None, "root", "/f", READ_OPERATION).is_ok());
    }

    #[test]
    fn validate_wrong_header_token() {
        let token = test_token(b"secret");
        let result = validate_signing(&token, Some("wrong"), None, None, "root", "/f", READ_OPERATION);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("does not match"));
    }

    #[test]
    fn validate_valid_signature() {
        let token = test_token(b"secret");
        let exp = far_future();
        let sig = generate_signature(&token, "/file", "root", READ_OPERATION, Some(exp)).unwrap();
        assert!(validate_signing(&token, None, Some(&sig), Some(exp), "root", "/file", READ_OPERATION).is_ok());
    }

    #[test]
    fn validate_invalid_signature() {
        let token = test_token(b"secret");
        let result = validate_signing(&token, None, Some("v1_bad"), Some(far_future()), "root", "/f", READ_OPERATION);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("invalid signature"));
    }

    #[test]
    fn validate_expired_signature() {
        // A correctly-signed but past-dated signature must still be rejected.
        let token = test_token(b"secret");
        let expired: i64 = 1_000_000;
        let sig = generate_signature(&token, "/f", "root", READ_OPERATION, Some(expired)).unwrap();
        let result = validate_signing(&token, None, Some(&sig), Some(expired), "root", "/f", READ_OPERATION);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("expired"));
    }

    #[test]
    fn validate_missing_signature() {
        let token = test_token(b"secret");
        let result = validate_signing(&token, None, None, None, "root", "/f", READ_OPERATION);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("missing signature"));
    }

    #[test]
    fn validate_empty_header_token_falls_through_to_signature() {
        // An empty header token must not count as header auth; validation
        // then requires a signature, which is absent here.
        let token = test_token(b"secret");
        let result = validate_signing(&token, Some(""), None, None, "root", "/f", READ_OPERATION);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("missing signature"));
    }

    #[test]
    fn validate_valid_signature_no_expiration() {
        let token = test_token(b"secret");
        let sig = generate_signature(&token, "/file", "root", READ_OPERATION, None).unwrap();
        assert!(validate_signing(&token, None, Some(&sig), None, "root", "/file", READ_OPERATION).is_ok());
    }

    #[test]
    fn different_operations_produce_different_signatures() {
        // The operation kind must be part of the signed payload.
        let token = test_token(b"secret");
        let r = generate_signature(&token, "/f", "root", READ_OPERATION, None).unwrap();
        let w = generate_signature(&token, "/f", "root", WRITE_OPERATION, None).unwrap();
        assert_ne!(r, w);
    }
}

View File

@ -125,3 +125,132 @@ impl SecureToken {
Ok(token) Ok(token)
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn new_is_unset() {
let t = SecureToken::new();
assert!(!t.is_set());
assert!(!t.equals("anything"));
}
#[test]
fn set_and_equals() {
let t = SecureToken::new();
t.set(b"secret").unwrap();
assert!(t.is_set());
assert!(t.equals("secret"));
assert!(!t.equals("wrong"));
}
#[test]
fn set_empty_errors() {
let t = SecureToken::new();
assert!(t.set(b"").is_err());
assert!(!t.is_set());
}
#[test]
fn set_overwrites_previous() {
let t = SecureToken::new();
t.set(b"first").unwrap();
t.set(b"second").unwrap();
assert!(!t.equals("first"));
assert!(t.equals("second"));
}
#[test]
fn destroy_clears() {
let t = SecureToken::new();
t.set(b"secret").unwrap();
t.destroy();
assert!(!t.is_set());
assert!(!t.equals("secret"));
}
#[test]
fn bytes_returns_copy() {
let t = SecureToken::new();
assert!(t.bytes().is_none());
t.set(b"hello").unwrap();
assert_eq!(t.bytes().unwrap(), b"hello");
}
#[test]
fn take_from_transfers_and_clears_source() {
let src = SecureToken::new();
src.set(b"token").unwrap();
let dst = SecureToken::new();
dst.take_from(&src);
assert!(!src.is_set());
assert!(dst.equals("token"));
}
#[test]
fn take_from_overwrites_existing() {
let src = SecureToken::new();
src.set(b"new").unwrap();
let dst = SecureToken::new();
dst.set(b"old").unwrap();
dst.take_from(&src);
assert!(dst.equals("new"));
assert!(!dst.equals("old"));
}
#[test]
fn equals_secure_matching() {
let a = SecureToken::new();
a.set(b"same").unwrap();
let b = SecureToken::new();
b.set(b"same").unwrap();
assert!(a.equals_secure(&b));
}
#[test]
fn equals_secure_different() {
let a = SecureToken::new();
a.set(b"one").unwrap();
let b = SecureToken::new();
b.set(b"two").unwrap();
assert!(!a.equals_secure(&b));
}
#[test]
fn equals_secure_unset() {
let a = SecureToken::new();
let b = SecureToken::new();
assert!(!a.equals_secure(&b));
}
#[test]
fn from_json_bytes_valid() {
let mut data = b"\"mysecret\"".to_vec();
let t = SecureToken::from_json_bytes(&mut data).unwrap();
assert!(t.equals("mysecret"));
assert!(data.iter().all(|&b| b == 0));
}
#[test]
fn from_json_bytes_rejects_missing_quotes() {
let mut data = b"noquotes".to_vec();
assert!(SecureToken::from_json_bytes(&mut data).is_err());
assert!(data.iter().all(|&b| b == 0));
}
#[test]
fn from_json_bytes_rejects_escape_sequences() {
    // Backslash escapes inside the JSON string are not supported.
    let mut raw = b"\"has\\nescapes\"".to_vec();
    assert!(SecureToken::from_json_bytes(&mut raw).is_err());
    assert!(raw.iter().all(|&byte| byte == 0));
}
#[test]
fn from_json_bytes_rejects_empty_content() {
    // "" (an empty JSON string) is rejected like an empty set().
    let mut raw = b"\"\"".to_vec();
    assert!(SecureToken::from_json_bytes(&mut raw).is_err());
    assert!(raw.iter().all(|&byte| byte == 0));
}
}

View File

@ -76,4 +76,125 @@ impl ConnTracker {
pub fn keepalives_enabled(&self) -> bool { pub fn keepalives_enabled(&self) -> bool {
self.inner.lock().unwrap().keepalives_enabled self.inner.lock().unwrap().keepalives_enabled
} }
#[cfg(test)]
fn active_count(&self) -> usize {
    // Test-only helper: number of currently tracked connections.
    let guard = self.inner.lock().unwrap();
    guard.active.len()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn register_assigns_sequential_ids() {
    // IDs come from a monotonically increasing counter starting at 0.
    let tracker = ConnTracker::new();
    for expected in 0..3 {
        assert_eq!(tracker.register_connection(), expected);
    }
}
#[test]
fn remove_clears_active() {
    let tracker = ConnTracker::new();
    let conn_id = tracker.register_connection();
    assert_eq!(tracker.active_count(), 1);
    tracker.remove_connection(conn_id);
    assert_eq!(tracker.active_count(), 0);
}
#[test]
fn remove_nonexistent_is_noop() {
    // Removing an ID that was never registered must not panic.
    let tracker = ConnTracker::new();
    tracker.remove_connection(999);
    assert_eq!(tracker.active_count(), 0);
}
#[test]
fn prepare_disables_keepalives() {
    // Keepalives are on by default and turned off by snapshot prepare.
    let tracker = ConnTracker::new();
    assert!(tracker.keepalives_enabled());
    tracker.register_connection();
    tracker.prepare_for_snapshot();
    assert!(!tracker.keepalives_enabled());
}
#[test]
fn restore_removes_zombies_and_reenables_keepalives() {
    let tracker = ConnTracker::new();
    let first = tracker.register_connection();
    let second = tracker.register_connection();
    tracker.prepare_for_snapshot();
    tracker.restore_after_snapshot();
    assert!(tracker.keepalives_enabled());
    // Every connection live across the snapshot boundary is a zombie.
    assert_eq!(tracker.active_count(), 0);
    // Removing already-purged IDs stays a harmless no-op.
    tracker.remove_connection(first);
    tracker.remove_connection(second);
}
#[test]
fn restore_without_prepare_is_noop() {
    // restore_after_snapshot with no prior prepare leaves state untouched.
    let tracker = ConnTracker::new();
    let _conn = tracker.register_connection();
    tracker.restore_after_snapshot();
    assert!(tracker.keepalives_enabled());
    assert_eq!(tracker.active_count(), 1);
}
#[test]
fn connection_closed_before_restore_not_zombie() {
    let tracker = ConnTracker::new();
    let closed_early = tracker.register_connection();
    let _survivor = tracker.register_connection();
    tracker.prepare_for_snapshot();
    // This connection closes cleanly inside the snapshot window.
    tracker.remove_connection(closed_early);
    assert_eq!(tracker.active_count(), 1);
    tracker.restore_after_snapshot();
    // Only the one still active at restore time counted as a zombie.
    assert_eq!(tracker.active_count(), 0);
}
#[test]
fn post_snapshot_connection_survives_restore() {
    let tracker = ConnTracker::new();
    tracker.register_connection();
    tracker.prepare_for_snapshot();
    // Opened after prepare — must not be treated as a zombie.
    let _fresh = tracker.register_connection();
    tracker.restore_after_snapshot();
    assert_eq!(tracker.active_count(), 1);
}
#[test]
fn full_lifecycle() {
    // End-to-end: register, snapshot, mixed closes, restore, reuse.
    let tracker = ConnTracker::new();
    let _first = tracker.register_connection();
    let second = tracker.register_connection();
    let _third = tracker.register_connection();
    assert_eq!(tracker.active_count(), 3);
    assert!(tracker.keepalives_enabled());
    tracker.prepare_for_snapshot();
    assert!(!tracker.keepalives_enabled());
    let fourth = tracker.register_connection();
    tracker.remove_connection(second);
    tracker.restore_after_snapshot();
    assert!(tracker.keepalives_enabled());
    // first/third were zombies, second closed pre-restore; only fourth lives.
    assert_eq!(tracker.active_count(), 1);
    tracker.remove_connection(fourth);
    assert_eq!(tracker.active_count(), 0);
    // Tracker remains usable after a restore cycle; IDs keep climbing.
    let fifth = tracker.register_connection();
    assert_eq!(tracker.active_count(), 1);
    assert!(fifth > fourth);
}
} }

View File

@ -15,8 +15,29 @@ mod tests {
use super::*; use super::*;
#[test] #[test]
fn test_hmac_sha256() { fn rfc4231_tc1() {
let result = compute(b"key", b"message"); let key = &[0x0b; 20];
assert_eq!(result.len(), 64); // SHA-256 hex = 64 chars let data = b"Hi There";
assert_eq!(
compute(key, data),
"b0344c61d8db38535ca8afceaf0bf12b881dc200c9833da726e9376c2e32cff7"
);
}
#[test]
fn rfc4231_tc2() {
let key = b"Jefe";
let data = b"what do ya want for nothing?";
assert_eq!(
compute(key, data),
"5bdcc146bf60754e6a042426089575c75a003f089d2739839dec58b964ec3843"
);
}
#[test]
fn output_is_64_hex_chars() {
let result = compute(b"key", b"data");
assert_eq!(result.len(), 64);
assert!(result.chars().all(|c| c.is_ascii_hexdigit()));
} }
} }

View File

@ -17,17 +17,38 @@ pub fn hash_without_prefix(data: &[u8]) -> String {
mod tests { mod tests {
use super::*; use super::*;
const VECTORS: &[(&[u8], &str)] = &[
(b"", "47DEQpj8HBSa+/TImW+5JCeuQeRkm5NMpJWZG3hSuFU"),
(b"abc", "ungWv48Bz+pBQUDeXa4iI7ADYaOWF3qctBD/YfIAFa0"),
(b"abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq", "JI1qYdIGOLjlwCaTDD5gOaM85Flk/yFn9uzt1BnbBsE"),
];
#[test] #[test]
fn test_hash_format() { fn known_answer_with_prefix() {
let result = hash(b"test"); for (input, expected_b64) in VECTORS {
assert!(result.starts_with("$sha256$")); let result = hash(input);
assert!(!result.contains('=')); assert_eq!(result, format!("$sha256${expected_b64}"), "input: {:?}", String::from_utf8_lossy(input));
}
} }
#[test] #[test]
fn test_hash_without_prefix() { fn known_answer_without_prefix() {
let result = hash_without_prefix(b"test"); for (input, expected_b64) in VECTORS {
assert!(!result.starts_with("$sha256$")); let result = hash_without_prefix(input);
assert!(!result.contains('=')); assert_eq!(result, *expected_b64, "input: {:?}", String::from_utf8_lossy(input));
}
}
#[test]
fn no_base64_padding() {
for (input, _) in VECTORS {
assert!(!hash(input).contains('='));
assert!(!hash_without_prefix(input).contains('='));
}
}
#[test]
fn deterministic() {
assert_eq!(hash(b"test"), hash(b"test"));
} }
} }

View File

@ -14,11 +14,30 @@ pub fn hash_access_token_bytes(token: &[u8]) -> String {
mod tests { mod tests {
use super::*; use super::*;
const VECTORS: &[(&str, &str)] = &[
("", "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e"),
("abc", "ddaf35a193617abacc417349ae20413112e6fa4e89a97ea20a9eeee64b55d39a2192992a274fc1a836ba3c23a3feebbd454d4423643ce80e2a9ac94fa54ca49f"),
("abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq", "204a8fc6dda82f0a0ced7beb8e08a41657c16ef468b228a8279be331a703c33596fd15c13b1b07f9aa1d3bea57789ca031ad85c7a71dd70354ec631238ca3445"),
];
#[test] #[test]
fn test_hash_access_token() { fn known_answer() {
let h1 = hash_access_token("test"); for (input, expected) in VECTORS {
let h2 = hash_access_token_bytes(b"test"); assert_eq!(hash_access_token(input), *expected, "input: {input:?}");
assert_eq!(h1, h2); }
assert_eq!(h1.len(), 128); // SHA-512 hex = 128 chars }
#[test]
fn str_and_bytes_agree() {
for (input, _) in VECTORS {
assert_eq!(hash_access_token(input), hash_access_token_bytes(input.as_bytes()));
}
}
#[test]
fn output_is_lowercase_hex_128_chars() {
let h = hash_access_token("anything");
assert_eq!(h.len(), 128);
assert!(h.chars().all(|c| c.is_ascii_hexdigit() && !c.is_ascii_uppercase()));
} }
} }

View File

@ -1,21 +1,36 @@
use dashmap::DashMap; use dashmap::DashMap;
use std::sync::Arc; use std::sync::{Arc, RwLock};
#[derive(Clone)]
pub struct Defaults { pub struct Defaults {
pub env_vars: Arc<DashMap<String, String>>, pub env_vars: Arc<DashMap<String, String>>,
pub user: String, user: RwLock<String>,
pub workdir: Option<String>, workdir: RwLock<Option<String>>,
} }
impl Defaults { impl Defaults {
pub fn new(user: &str) -> Self { pub fn new(user: &str) -> Self {
Self { Self {
env_vars: Arc::new(DashMap::new()), env_vars: Arc::new(DashMap::new()),
user: user.to_string(), user: RwLock::new(user.to_string()),
workdir: None, workdir: RwLock::new(None),
} }
} }
pub fn user(&self) -> String {
self.user.read().unwrap().clone()
}
pub fn set_user(&self, user: String) {
*self.user.write().unwrap() = user;
}
pub fn workdir(&self) -> Option<String> {
self.workdir.read().unwrap().clone()
}
pub fn set_workdir(&self, workdir: Option<String>) {
*self.workdir.write().unwrap() = workdir;
}
} }
pub fn resolve_default_workdir(workdir: &str, default_workdir: Option<&str>) -> String { pub fn resolve_default_workdir(workdir: &str, default_workdir: Option<&str>) -> String {
@ -40,3 +55,64 @@ pub fn resolve_default_username<'a>(
} }
Err("username not provided") Err("username not provided")
} }
#[cfg(test)]
mod tests {
    use super::*;

    // ---- resolve_default_workdir ----

    #[test]
    fn workdir_explicit_overrides_default() {
        let resolved = resolve_default_workdir("/explicit", Some("/default"));
        assert_eq!(resolved, "/explicit");
    }

    #[test]
    fn workdir_empty_uses_default() {
        let resolved = resolve_default_workdir("", Some("/default"));
        assert_eq!(resolved, "/default");
    }

    #[test]
    fn workdir_empty_no_default_returns_empty() {
        let resolved = resolve_default_workdir("", None);
        assert_eq!(resolved, "");
    }

    #[test]
    fn workdir_explicit_ignores_none_default() {
        let resolved = resolve_default_workdir("/explicit", None);
        assert_eq!(resolved, "/explicit");
    }

    // ---- resolve_default_username ----

    #[test]
    fn username_explicit_returns_explicit() {
        let name = resolve_default_username(Some("root"), "wrenn").unwrap();
        assert_eq!(name, "root");
    }

    #[test]
    fn username_none_uses_default() {
        let name = resolve_default_username(None, "wrenn").unwrap();
        assert_eq!(name, "wrenn");
    }

    #[test]
    fn username_none_empty_default_errors() {
        // Neither explicit nor default supplied — must error.
        assert!(resolve_default_username(None, "").is_err());
    }

    #[test]
    fn username_some_overrides_empty_default() {
        let name = resolve_default_username(Some("root"), "").unwrap();
        assert_eq!(name, "root");
    }

    // ---- Defaults accessors ----

    #[test]
    fn defaults_user_set_and_get() {
        let defaults = Defaults::new("initial");
        assert_eq!(defaults.user(), "initial");
        defaults.set_user("changed".into());
        assert_eq!(defaults.user(), "changed");
    }

    #[test]
    fn defaults_workdir_initially_none() {
        let defaults = Defaults::new("user");
        assert!(defaults.workdir().is_none());
        defaults.set_workdir(Some("/home".into()));
        assert_eq!(defaults.workdir().unwrap(), "/home");
    }
}

View File

@ -145,3 +145,192 @@ pub fn parse_content_encoding<B>(r: &Request<B>) -> Result<&'static str, String>
Err(format!("unsupported Content-Encoding: {header}, supported: {SUPPORTED_ENCODINGS:?}")) Err(format!("unsupported Content-Encoding: {header}, supported: {SUPPORTED_ENCODINGS:?}"))
} }
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::Request;

    /// Build a request carrying exactly one header.
    fn req_with(name: &'static str, value: &str) -> Request<()> {
        Request::builder().header(name, value).body(()).unwrap()
    }

    fn req_with_accept(v: &str) -> Request<()> {
        req_with("accept-encoding", v)
    }

    fn req_with_content(v: &str) -> Request<()> {
        req_with("content-encoding", v)
    }

    fn req_no_headers() -> Request<()> {
        Request::builder().body(()).unwrap()
    }

    // ---- parse_encoding_with_quality ----

    #[test]
    fn encoding_quality_default_1() {
        // A token without a q-parameter defaults to quality 1.0.
        let parsed = parse_encoding_with_quality("gzip");
        assert_eq!(parsed.encoding, "gzip");
        assert_eq!(parsed.quality, 1.0);
    }

    #[test]
    fn encoding_quality_explicit() {
        let parsed = parse_encoding_with_quality("gzip;q=0.8");
        assert_eq!(parsed.encoding, "gzip");
        assert_eq!(parsed.quality, 0.8);
    }

    #[test]
    fn encoding_quality_case_insensitive() {
        // Token and q-parameter are matched case-insensitively.
        let parsed = parse_encoding_with_quality("GZIP;Q=0.5");
        assert_eq!(parsed.encoding, "gzip");
        assert_eq!(parsed.quality, 0.5);
    }

    #[test]
    fn encoding_quality_zero() {
        let parsed = parse_encoding_with_quality("gzip;q=0");
        assert_eq!(parsed.quality, 0.0);
    }

    #[test]
    fn encoding_quality_whitespace_trimmed() {
        let parsed = parse_encoding_with_quality(" gzip ; q=0.9 ");
        assert_eq!(parsed.encoding, "gzip");
        assert_eq!(parsed.quality, 0.9);
    }

    // ---- parse_accept_encoding_header ----

    #[test]
    fn accept_header_empty() {
        let (encodings, identity_rejected) = parse_accept_encoding_header("");
        assert!(encodings.is_empty());
        assert!(!identity_rejected);
    }

    #[test]
    fn accept_header_identity_q0_rejects() {
        let (_, identity_rejected) = parse_accept_encoding_header("identity;q=0");
        assert!(identity_rejected);
    }

    #[test]
    fn accept_header_wildcard_q0_rejects_identity() {
        // "*;q=0" covers identity when identity is not listed explicitly.
        let (_, identity_rejected) = parse_accept_encoding_header("*;q=0");
        assert!(identity_rejected);
    }

    #[test]
    fn accept_header_wildcard_q0_but_identity_explicit_accepted() {
        // An explicit identity entry overrides the wildcard rejection.
        let (_, identity_rejected) = parse_accept_encoding_header("*;q=0, identity");
        assert!(!identity_rejected);
    }

    // ---- parse_accept_encoding (full) ----

    #[test]
    fn accept_encoding_no_header_returns_identity() {
        assert_eq!(parse_accept_encoding(&req_no_headers()).unwrap(), "identity");
    }

    #[test]
    fn accept_encoding_gzip() {
        assert_eq!(parse_accept_encoding(&req_with_accept("gzip")).unwrap(), "gzip");
    }

    #[test]
    fn accept_encoding_identity_explicit() {
        let picked = parse_accept_encoding(&req_with_accept("identity"));
        assert_eq!(picked.unwrap(), "identity");
    }

    #[test]
    fn accept_encoding_gzip_higher_quality() {
        // The higher q-value wins the negotiation.
        let picked = parse_accept_encoding(&req_with_accept("identity;q=0.1, gzip;q=0.9"));
        assert_eq!(picked.unwrap(), "gzip");
    }

    #[test]
    fn accept_encoding_wildcard_returns_identity() {
        assert_eq!(parse_accept_encoding(&req_with_accept("*")).unwrap(), "identity");
    }

    #[test]
    fn accept_encoding_wildcard_identity_rejected_returns_gzip() {
        // identity is forbidden; the wildcard admits gzip instead.
        let picked = parse_accept_encoding(&req_with_accept("identity;q=0, *"));
        assert_eq!(picked.unwrap(), "gzip");
    }

    #[test]
    fn accept_encoding_all_rejected_errors() {
        assert!(parse_accept_encoding(&req_with_accept("identity;q=0, *;q=0")).is_err());
    }

    #[test]
    fn accept_encoding_unsupported_only_falls_to_identity() {
        // Unknown encodings are ignored; identity remains acceptable.
        assert_eq!(parse_accept_encoding(&req_with_accept("br")).unwrap(), "identity");
    }

    // ---- is_identity_acceptable ----

    #[test]
    fn identity_acceptable_no_header() {
        assert!(is_identity_acceptable(&req_no_headers()));
    }

    #[test]
    fn identity_acceptable_gzip_only() {
        assert!(is_identity_acceptable(&req_with_accept("gzip")));
    }

    #[test]
    fn identity_not_acceptable_identity_q0() {
        assert!(!is_identity_acceptable(&req_with_accept("identity;q=0")));
    }

    #[test]
    fn identity_not_acceptable_wildcard_q0() {
        assert!(!is_identity_acceptable(&req_with_accept("*;q=0")));
    }

    #[test]
    fn identity_acceptable_wildcard_q0_but_identity_explicit() {
        assert!(is_identity_acceptable(&req_with_accept("*;q=0, identity")));
    }

    // ---- parse_content_encoding ----

    #[test]
    fn content_encoding_empty_returns_identity() {
        assert_eq!(parse_content_encoding(&req_no_headers()).unwrap(), "identity");
    }

    #[test]
    fn content_encoding_gzip() {
        assert_eq!(parse_content_encoding(&req_with_content("gzip")).unwrap(), "gzip");
    }

    #[test]
    fn content_encoding_identity_explicit() {
        let parsed = parse_content_encoding(&req_with_content("identity"));
        assert_eq!(parsed.unwrap(), "identity");
    }

    #[test]
    fn content_encoding_unsupported_errors() {
        assert!(parse_content_encoding(&req_with_content("br")).is_err());
    }

    #[test]
    fn content_encoding_case_insensitive() {
        assert_eq!(parse_content_encoding(&req_with_content("GZIP")).unwrap(), "gzip");
    }
}

View File

@ -71,9 +71,10 @@ pub async fn get_files(
let path_str = params.path.as_deref().unwrap_or(""); let path_str = params.path.as_deref().unwrap_or("");
let header_token = extract_header_token(&req); let header_token = extract_header_token(&req);
let default_user = state.defaults.user();
let username = match execcontext::resolve_default_username( let username = match execcontext::resolve_default_username(
params.username.as_deref(), params.username.as_deref(),
&state.defaults.user, &default_user,
) { ) {
Ok(u) => u.to_string(), Ok(u) => u.to_string(),
Err(e) => return json_error(StatusCode::BAD_REQUEST, e), Err(e) => return json_error(StatusCode::BAD_REQUEST, e),
@ -96,7 +97,8 @@ pub async fn get_files(
}; };
let home_dir = user.dir.to_string_lossy().to_string(); let home_dir = user.dir.to_string_lossy().to_string();
let resolved = match expand_and_resolve(path_str, &home_dir, state.defaults.workdir.as_deref()) let default_workdir = state.defaults.workdir();
let resolved = match expand_and_resolve(path_str, &home_dir, default_workdir.as_deref())
{ {
Ok(p) => p, Ok(p) => p,
Err(e) => return json_error(StatusCode::BAD_REQUEST, &e), Err(e) => return json_error(StatusCode::BAD_REQUEST, &e),
@ -222,9 +224,10 @@ pub async fn post_files(
let path_str = params.path.as_deref().unwrap_or(""); let path_str = params.path.as_deref().unwrap_or("");
let header_token = extract_header_token(&req); let header_token = extract_header_token(&req);
let default_user = state.defaults.user();
let username = match execcontext::resolve_default_username( let username = match execcontext::resolve_default_username(
params.username.as_deref(), params.username.as_deref(),
&state.defaults.user, &default_user,
) { ) {
Ok(u) => u.to_string(), Ok(u) => u.to_string(),
Err(e) => return json_error(StatusCode::BAD_REQUEST, e), Err(e) => return json_error(StatusCode::BAD_REQUEST, e),
@ -266,6 +269,7 @@ pub async fn post_files(
}; };
let mut uploaded: Vec<EntryInfo> = Vec::new(); let mut uploaded: Vec<EntryInfo> = Vec::new();
let default_workdir = state.defaults.workdir();
while let Ok(Some(field)) = multipart.next_field().await { while let Ok(Some(field)) = multipart.next_field().await {
let field_name = field.name().unwrap_or("").to_string(); let field_name = field.name().unwrap_or("").to_string();
@ -274,7 +278,7 @@ pub async fn post_files(
} }
let file_path = if !path_str.is_empty() { let file_path = if !path_str.is_empty() {
match expand_and_resolve(path_str, &home_dir, state.defaults.workdir.as_deref()) { match expand_and_resolve(path_str, &home_dir, default_workdir.as_deref()) {
Ok(p) => p, Ok(p) => p,
Err(e) => return json_error(StatusCode::BAD_REQUEST, &e), Err(e) => return json_error(StatusCode::BAD_REQUEST, &e),
} }
@ -283,7 +287,7 @@ pub async fn post_files(
.file_name() .file_name()
.unwrap_or("upload") .unwrap_or("upload")
.to_string(); .to_string();
match expand_and_resolve(&fname, &home_dir, state.defaults.workdir.as_deref()) { match expand_and_resolve(&fname, &home_dir, default_workdir.as_deref()) {
Ok(p) => p, Ok(p) => p,
Err(e) => return json_error(StatusCode::BAD_REQUEST, &e), Err(e) => return json_error(StatusCode::BAD_REQUEST, &e),
} }

View File

@ -29,6 +29,8 @@ pub async fn get_health(State(state): State<Arc<AppState>>) -> impl IntoResponse
fn post_restore_recovery(state: &AppState) { fn post_restore_recovery(state: &AppState) {
tracing::info!("restore: post-restore recovery (no GC needed in Rust)"); tracing::info!("restore: post-restore recovery (no GC needed in Rust)");
state.snapshot_in_progress.store(false, std::sync::atomic::Ordering::Release);
state.conn_tracker.restore_after_snapshot(); state.conn_tracker.restore_after_snapshot();
tracing::info!("restore: zombie connections closed"); tracing::info!("restore: zombie connections closed");

View File

@ -78,11 +78,15 @@ pub async fn post_init(
if let Some(ref user) = init_req.default_user { if let Some(ref user) = init_req.default_user {
if !user.is_empty() { if !user.is_empty() {
tracing::debug!(user = %user, "setting default user"); tracing::debug!(user = %user, "setting default user");
let mut defaults = state.defaults.clone(); state.defaults.set_user(user.clone());
defaults.user = user.clone(); }
// Note: In Rust we'd need interior mutability for this. }
// For now, env_vars (DashMap) handles concurrent access.
// User/workdir mutation deferred to full state refactor. // Set default workdir
if let Some(ref workdir) = init_req.default_workdir {
if !workdir.is_empty() {
tracing::debug!(workdir = %workdir, "setting default workdir");
state.defaults.set_workdir(Some(workdir.clone()));
} }
} }
@ -147,6 +151,9 @@ async fn trigger_restore_and_respond(state: &AppState) -> axum::response::Respon
fn post_restore_recovery(state: &AppState) { fn post_restore_recovery(state: &AppState) {
tracing::info!("restore: post-restore recovery (no GC needed in Rust)"); tracing::info!("restore: post-restore recovery (no GC needed in Rust)");
state.snapshot_in_progress.store(false, std::sync::atomic::Ordering::Release);
state.conn_tracker.restore_after_snapshot(); state.conn_tracker.restore_after_snapshot();
if let Some(ref ps) = state.port_subsystem { if let Some(ref ps) = state.port_subsystem {

View File

@ -46,7 +46,8 @@ fn collect_metrics(state: &AppState) -> Result<Metrics, String> {
let mut sys = sysinfo::System::new(); let mut sys = sysinfo::System::new();
sys.refresh_memory(); sys.refresh_memory();
let mem_total = sys.total_memory(); let mem_total = sys.total_memory();
let mem_used = sys.used_memory(); let mem_available = sys.available_memory();
let mem_used = mem_total.saturating_sub(mem_available);
let mem_total_mib = mem_total / 1024 / 1024; let mem_total_mib = mem_total / 1024 / 1024;
let mem_used_mib = mem_used / 1024 / 1024; let mem_used_mib = mem_used / 1024 / 1024;

View File

@ -10,10 +10,24 @@ use crate::state::AppState;
/// POST /snapshot/prepare — quiesce subsystems before Firecracker snapshot. /// POST /snapshot/prepare — quiesce subsystems before Firecracker snapshot.
/// ///
/// In Rust there is no GC dance. We just: /// In Rust there is no GC dance. We just:
/// 1. Stop port subsystem /// 1. Drop page cache to shrink snapshot size
/// 2. Close idle connections via conntracker /// 2. Stop port subsystem
/// 3. Set needs_restore flag /// 3. Close idle connections via conntracker
/// 4. Set needs_restore flag
pub async fn post_snapshot_prepare(State(state): State<Arc<AppState>>) -> impl IntoResponse { pub async fn post_snapshot_prepare(State(state): State<Arc<AppState>>) -> impl IntoResponse {
// Drop page cache BEFORE blocking the reclaimer — avoids snapshotting
// gigabytes of stale cache that inflates the memory dump on disk.
// "1" = pagecache only (keep dentries/inodes for faster resume).
if let Err(e) = std::fs::write("/proc/sys/vm/drop_caches", "1") {
tracing::warn!(error = %e, "snapshot/prepare: drop_caches failed");
} else {
tracing::info!("snapshot/prepare: page cache dropped");
}
// Block memory reclaimer — prevents drop_caches from running mid-freeze
// which would corrupt kernel page table state.
state.snapshot_in_progress.store(true, Ordering::Release);
if let Some(ref ps) = state.port_subsystem { if let Some(ref ps) = state.port_subsystem {
ps.stop(); ps.stop();
tracing::info!("snapshot/prepare: port subsystem stopped"); tracing::info!("snapshot/prepare: port subsystem stopped");
@ -22,6 +36,9 @@ pub async fn post_snapshot_prepare(State(state): State<Arc<AppState>>) -> impl I
state.conn_tracker.prepare_for_snapshot(); state.conn_tracker.prepare_for_snapshot();
tracing::info!("snapshot/prepare: connections prepared"); tracing::info!("snapshot/prepare: connections prepared");
// Sync filesystem buffers so dirty pages are flushed before freeze.
unsafe { libc::sync(); }
state.needs_restore.store(true, Ordering::Release); state.needs_restore.store(true, Ordering::Release);
tracing::info!("snapshot/prepare: ready for freeze"); tracing::info!("snapshot/prepare: ready for freeze");

View File

@ -147,6 +147,14 @@ async fn main() {
Some(Arc::clone(&port_subsystem)), Some(Arc::clone(&port_subsystem)),
); );
// Memory reclaimer — drop page cache when available memory is low.
// Firecracker balloon device can only reclaim pages the guest kernel freed.
// Pauses during snapshot/prepare to avoid corrupting kernel page table state.
if !cli.is_not_fc {
let state_for_reclaimer = Arc::clone(&state);
std::thread::spawn(move || memory_reclaimer(state_for_reclaimer));
}
// RPC services (Connect protocol — serves Connect + gRPC + gRPC-Web on same port) // RPC services (Connect protocol — serves Connect + gRPC + gRPC-Web on same port)
let connect_router = rpc::rpc_router(Arc::clone(&state)); let connect_router = rpc::rpc_router(Arc::clone(&state));
@ -188,7 +196,8 @@ fn spawn_initial_command(cmd: &str, state: &AppState) {
use crate::rpc::process_handler; use crate::rpc::process_handler;
use std::collections::HashMap; use std::collections::HashMap;
let user = match lookup_user(&state.defaults.user) { let default_user = state.defaults.user();
let user = match lookup_user(&default_user) {
Ok(u) => u, Ok(u) => u,
Err(e) => { Err(e) => {
tracing::error!(error = %e, "cmd: failed to lookup user"); tracing::error!(error = %e, "cmd: failed to lookup user");
@ -197,9 +206,8 @@ fn spawn_initial_command(cmd: &str, state: &AppState) {
}; };
let home = user.dir.to_string_lossy().to_string(); let home = user.dir.to_string_lossy().to_string();
let cwd = state let default_workdir = state.defaults.workdir();
.defaults let cwd = default_workdir
.workdir
.as_deref() .as_deref()
.unwrap_or(&home); .unwrap_or(&home);
@ -214,11 +222,52 @@ fn spawn_initial_command(cmd: &str, state: &AppState) {
&user, &user,
&state.defaults.env_vars, &state.defaults.env_vars,
) { ) {
Ok(handle) => { Ok(spawned) => {
tracing::info!(pid = handle.pid, cmd, "initial command spawned"); tracing::info!(pid = spawned.handle.pid, cmd, "initial command spawned");
} }
Err(e) => { Err(e) => {
tracing::error!(error = %e, cmd, "failed to spawn initial command"); tracing::error!(error = %e, cmd, "failed to spawn initial command");
} }
} }
} }
/// Background loop: drops the kernel page cache when guest memory usage
/// crosses `DROP_THRESHOLD_PCT`, so the Firecracker balloon can reclaim the
/// freed pages. Skips work while a snapshot is in progress (drop_caches
/// mid-freeze could corrupt kernel page table state — see snapshot/prepare).
/// Never returns; intended to run on a dedicated thread.
fn memory_reclaimer(state: Arc<AppState>) {
    use std::sync::atomic::Ordering;
    const CHECK_INTERVAL: std::time::Duration = std::time::Duration::from_secs(10);
    const DROP_THRESHOLD_PCT: u64 = 80;
    // Reuse one System handle across ticks instead of re-allocating every loop;
    // refresh_memory() updates it in place.
    let mut sys = sysinfo::System::new();
    loop {
        std::thread::sleep(CHECK_INTERVAL);
        if state.snapshot_in_progress.load(Ordering::Acquire) {
            continue;
        }
        sys.refresh_memory();
        let total = sys.total_memory();
        let available = sys.available_memory();
        if total == 0 {
            continue;
        }
        // saturating_sub guards against a racy available > total reading,
        // matching how the metrics collector computes used memory.
        let used_pct = (total.saturating_sub(available) * 100) / total;
        if used_pct >= DROP_THRESHOLD_PCT {
            // Re-check: snapshot/prepare may have started since the first load
            // (best-effort; a small TOCTOU window remains).
            if state.snapshot_in_progress.load(Ordering::Acquire) {
                continue;
            }
            // "3" = drop pagecache + dentries/inodes.
            if let Err(e) = std::fs::write("/proc/sys/vm/drop_caches", "3") {
                tracing::debug!(error = %e, "drop_caches failed");
            } else {
                sys.refresh_memory();
                let freed_mb =
                    sys.available_memory().saturating_sub(available) / (1024 * 1024);
                tracing::info!(used_pct, freed_mb, "page cache dropped");
            }
        }
    }
}

View File

@ -70,3 +70,115 @@ pub fn ensure_dirs(path: &str, uid: Uid, gid: Gid) -> Result<(), String> {
Ok(()) Ok(())
} }
#[cfg(test)]
mod tests {
    use super::*;

    // ---- expand_tilde ----

    #[test]
    fn tilde_empty_passthrough() {
        // Empty input is returned untouched.
        assert_eq!(expand_tilde("", "/home/u").unwrap(), "");
    }

    #[test]
    fn tilde_no_tilde_passthrough() {
        assert_eq!(expand_tilde("/absolute", "/home/u").unwrap(), "/absolute");
    }

    #[test]
    fn tilde_bare() {
        // A lone "~" becomes the home directory itself.
        assert_eq!(expand_tilde("~", "/home/user").unwrap(), "/home/user");
    }

    #[test]
    fn tilde_slash_path() {
        let expanded = expand_tilde("~/docs", "/home/user").unwrap();
        assert_eq!(expanded, "/home/user/docs");
    }

    #[test]
    fn tilde_nested() {
        let expanded = expand_tilde("~/a/b/c", "/h").unwrap();
        assert_eq!(expanded, "/h/a/b/c");
    }

    #[test]
    fn tilde_other_user_errors() {
        // "~otheruser" expansion is unsupported and must error.
        assert!(expand_tilde("~bob/foo", "/home/user").is_err());
    }

    #[test]
    fn tilde_relative_no_tilde() {
        assert_eq!(expand_tilde("relative/path", "/home/u").unwrap(), "relative/path");
    }

    // ---- expand_and_resolve ----

    #[test]
    fn resolve_absolute_passthrough() {
        let resolved = expand_and_resolve("/abs/path", "/home", None).unwrap();
        assert_eq!(resolved, "/abs/path");
    }

    #[test]
    fn resolve_empty_uses_default() {
        let resolved = expand_and_resolve("", "/home", Some("/default")).unwrap();
        assert_eq!(resolved, "/default");
    }

    #[test]
    fn resolve_empty_no_default_falls_back_to_home() {
        // Empty path and no default workdir: joining "" onto home yields home.
        assert_eq!(expand_and_resolve("", "/home", None).unwrap(), "/home");
    }

    #[test]
    fn resolve_tilde_expands() {
        let resolved = expand_and_resolve("~/dir", "/home/u", None).unwrap();
        assert_eq!(resolved, "/home/u/dir");
    }

    #[test]
    fn resolve_relative_joins_home() {
        // Relative paths are joined onto home (canonicalized when present,
        // raw join when the path does not exist).
        let resolved = expand_and_resolve("subdir", "/tmp", None).unwrap();
        assert!(resolved.starts_with("/tmp"));
        assert!(resolved.contains("subdir"));
    }

    #[test]
    fn resolve_tilde_other_user_errors() {
        assert!(expand_and_resolve("~bob", "/home/u", None).is_err());
    }

    // ---- ensure_dirs ----

    #[test]
    fn ensure_dirs_creates_nested() {
        let scratch = tempfile::TempDir::new().unwrap();
        let nested = scratch.path().join("a/b/c");
        ensure_dirs(
            nested.to_str().unwrap(),
            nix::unistd::getuid(),
            nix::unistd::getgid(),
        )
        .unwrap();
        assert!(nested.is_dir());
    }

    #[test]
    fn ensure_dirs_existing_is_ok() {
        // Idempotent on an already-existing directory.
        let scratch = tempfile::TempDir::new().unwrap();
        ensure_dirs(
            scratch.path().to_str().unwrap(),
            nix::unistd::getuid(),
            nix::unistd::getgid(),
        )
        .unwrap();
    }

    #[test]
    fn ensure_dirs_file_in_path_errors() {
        // A regular file blocking the path must produce a descriptive error.
        let scratch = tempfile::TempDir::new().unwrap();
        let blocker = scratch.path().join("afile");
        std::fs::write(&blocker, "").unwrap();
        let nested = blocker.join("subdir");
        let outcome = ensure_dirs(
            nested.to_str().unwrap(),
            nix::unistd::getuid(),
            nix::unistd::getgid(),
        );
        assert!(outcome.is_err());
        assert!(outcome.unwrap_err().contains("path is a file"));
    }
}

View File

@ -110,3 +110,151 @@ fn parse_hex_addr(s: &str, family: u32) -> Option<(String, u32)> {
Some((ip_str, port)) Some((ip_str, port))
} }
#[cfg(test)]
mod tests {
use super::*;
use std::io::Write;
// tcp_state_name
#[test]
fn state_all_known_codes() {
assert_eq!(tcp_state_name("01"), "ESTABLISHED");
assert_eq!(tcp_state_name("02"), "SYN_SENT");
assert_eq!(tcp_state_name("03"), "SYN_RECV");
assert_eq!(tcp_state_name("04"), "FIN_WAIT1");
assert_eq!(tcp_state_name("05"), "FIN_WAIT2");
assert_eq!(tcp_state_name("06"), "TIME_WAIT");
assert_eq!(tcp_state_name("07"), "CLOSE");
assert_eq!(tcp_state_name("08"), "CLOSE_WAIT");
assert_eq!(tcp_state_name("09"), "LAST_ACK");
assert_eq!(tcp_state_name("0A"), "LISTEN");
assert_eq!(tcp_state_name("0B"), "CLOSING");
}
#[test]
fn state_unknown_code() {
assert_eq!(tcp_state_name("FF"), "UNKNOWN");
assert_eq!(tcp_state_name("00"), "UNKNOWN");
}
// parse_hex_addr
#[test]
fn ipv4_localhost() {
let (ip, port) = parse_hex_addr("0100007F:0050", libc::AF_INET as u32).unwrap();
assert_eq!(ip, "127.0.0.1");
assert_eq!(port, 80);
}
#[test]
fn ipv4_any() {
let (ip, port) = parse_hex_addr("00000000:0035", libc::AF_INET as u32).unwrap();
assert_eq!(ip, "0.0.0.0");
assert_eq!(port, 53);
}
#[test]
fn ipv4_real_address() {
// 192.168.1.1 in little-endian = 0101A8C0
let (ip, port) = parse_hex_addr("0101A8C0:01BB", libc::AF_INET as u32).unwrap();
assert_eq!(ip, "192.168.1.1");
assert_eq!(port, 443);
}
#[test]
fn ipv4_wrong_byte_count_returns_none() {
assert!(parse_hex_addr("0100:0050", libc::AF_INET as u32).is_none());
}
#[test]
fn invalid_hex_returns_none() {
assert!(parse_hex_addr("ZZZZZZZZ:0050", libc::AF_INET as u32).is_none());
}
#[test]
fn no_colon_returns_none() {
assert!(parse_hex_addr("0100007F0050", libc::AF_INET as u32).is_none());
}
#[test]
fn ipv6_loopback() {
// ::1 in /proc/net/tcp6 format: 00000000000000000000000001000000
let (ip, port) = parse_hex_addr(
"00000000000000000000000001000000:0035",
libc::AF_INET6 as u32,
)
.unwrap();
assert_eq!(ip, "::1");
assert_eq!(port, 53);
}
#[test]
fn ipv6_wrong_byte_count_returns_none() {
assert!(parse_hex_addr("0100007F:0050", libc::AF_INET6 as u32).is_none());
}
// parse_proc_net_tcp
fn write_tcp_file(content: &str) -> tempfile::NamedTempFile {
let mut f = tempfile::NamedTempFile::new().unwrap();
f.write_all(content.as_bytes()).unwrap();
f.flush().unwrap();
f
}
#[test]
fn parse_empty_file() {
let f = write_tcp_file(
" sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode\n",
);
let conns = parse_proc_net_tcp(f.path().to_str().unwrap(), libc::AF_INET as u32).unwrap();
assert!(conns.is_empty());
}
#[test]
fn parse_single_entry() {
let content = "\
sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode
0: 0100007F:0050 00000000:0000 0A 00000000:00000000 00:00000000 00000000 0 0 12345 1 00000000\n";
let f = write_tcp_file(content);
let conns = parse_proc_net_tcp(f.path().to_str().unwrap(), libc::AF_INET as u32).unwrap();
assert_eq!(conns.len(), 1);
assert_eq!(conns[0].local_ip, "127.0.0.1");
assert_eq!(conns[0].local_port, 80);
assert_eq!(conns[0].status, "LISTEN");
assert_eq!(conns[0].inode, 12345);
assert_eq!(conns[0].family, libc::AF_INET as u32);
}
#[test]
fn parse_skips_malformed_rows() {
let content = "\
sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode
0: 0100007F:0050 00000000:0000 0A 00000000:00000000 00:00000000 00000000 0 0 12345 1 00000000
bad line
1: short\n";
let f = write_tcp_file(content);
let conns = parse_proc_net_tcp(f.path().to_str().unwrap(), libc::AF_INET as u32).unwrap();
assert_eq!(conns.len(), 1);
}
#[test]
fn parse_multiple_entries() {
    // Two valid rows: 0x0050 = port 80 and 0x01BB = port 443.
    let content = "\
 sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode
0: 0100007F:0050 00000000:0000 0A 00000000:00000000 00:00000000 00000000 0 0 100 1 00000000
1: 00000000:01BB 00000000:0000 0A 00000000:00000000 00:00000000 00000000 0 0 200 1 00000000\n";
    let tmp = write_tcp_file(content);
    let parsed = parse_proc_net_tcp(tmp.path().to_str().unwrap(), libc::AF_INET as u32).unwrap();
    assert_eq!(parsed.len(), 2);
    assert_eq!(parsed[0].local_port, 80);
    assert_eq!(parsed[1].local_port, 443);
}
#[test]
fn parse_nonexistent_file_errors() {
    // A missing path must surface an Err, not an empty connection list.
    let result = parse_proc_net_tcp("/nonexistent/path", libc::AF_INET as u32);
    assert!(result.is_err());
}
}

View File

@ -140,3 +140,92 @@ fn format_permissions(mode: u32) -> String {
} }
s s
} }
#[cfg(test)]
mod tests {
    use super::*;

    // format_permissions: mode bits -> ls-style permission string

    #[test]
    fn regular_file_755() {
        let rendered = format_permissions(libc::S_IFREG | 0o755);
        assert_eq!(rendered, "-rwxr-xr-x");
    }

    #[test]
    fn directory_755() {
        let rendered = format_permissions(libc::S_IFDIR | 0o755);
        assert_eq!(rendered, "drwxr-xr-x");
    }

    #[test]
    fn symlink_777() {
        let rendered = format_permissions(libc::S_IFLNK | 0o777);
        assert_eq!(rendered, "Lrwxrwxrwx");
    }

    #[test]
    fn regular_file_000() {
        let rendered = format_permissions(libc::S_IFREG | 0o000);
        assert_eq!(rendered, "----------");
    }

    #[test]
    fn regular_file_644() {
        let rendered = format_permissions(libc::S_IFREG | 0o644);
        assert_eq!(rendered, "-rw-r--r--");
    }

    #[test]
    fn block_device() {
        let rendered = format_permissions(libc::S_IFBLK | 0o660);
        assert_eq!(rendered, "brw-rw----");
    }

    #[test]
    fn char_device() {
        let rendered = format_permissions(libc::S_IFCHR | 0o666);
        assert_eq!(rendered, "crw-rw-rw-");
    }

    #[test]
    fn fifo() {
        let rendered = format_permissions(libc::S_IFIFO | 0o644);
        assert_eq!(rendered, "prw-r--r--");
    }

    #[test]
    fn socket() {
        let rendered = format_permissions(libc::S_IFSOCK | 0o755);
        assert_eq!(rendered, "Srwxr-xr-x");
    }

    #[test]
    fn unknown_type() {
        // No file-type bits set at all -> '?' type marker.
        let rendered = format_permissions(0o755);
        assert_eq!(rendered, "?rwxr-xr-x");
    }

    #[test]
    fn setuid_in_mode_only_affects_lower_bits() {
        // setuid (0o4755) — format_permissions masks with 0o777, so same as 0o755
        let with_setuid = format_permissions(libc::S_IFREG | 0o4755);
        let without_setuid = format_permissions(libc::S_IFREG | 0o755);
        assert_eq!(with_setuid, without_setuid);
    }

    #[test]
    fn output_always_10_chars() {
        // One type character plus three rwx triplets, whatever the bits.
        for &mode in &[0o000, 0o777, 0o644, 0o755, 0o4755] {
            assert_eq!(format_permissions(libc::S_IFREG | mode).len(), 10);
        }
    }

    // meta_to_file_type — needs real filesystem

    #[test]
    fn meta_regular_file() {
        let tmp = tempfile::NamedTempFile::new().unwrap();
        let meta = std::fs::metadata(tmp.path()).unwrap();
        assert_eq!(meta_to_file_type(&meta), FileType::FILE_TYPE_FILE);
    }

    #[test]
    fn meta_directory() {
        let dir = tempfile::TempDir::new().unwrap();
        let meta = std::fs::metadata(dir.path()).unwrap();
        assert_eq!(meta_to_file_type(&meta), FileType::FILE_TYPE_DIRECTORY);
    }
}

View File

@ -31,15 +31,15 @@ impl FilesystemServiceImpl {
} }
fn resolve_path(&self, path: &str, ctx: &Context) -> Result<String, ConnectError> { fn resolve_path(&self, path: &str, ctx: &Context) -> Result<String, ConnectError> {
let username = extract_username(ctx).unwrap_or_else(|| self.state.defaults.user.clone()); let username = extract_username(ctx).unwrap_or_else(|| self.state.defaults.user());
let user = lookup_user(&username).map_err(|e| { let user = lookup_user(&username).map_err(|e| {
ConnectError::new(ErrorCode::Unauthenticated, format!("invalid user: {e}")) ConnectError::new(ErrorCode::Unauthenticated, format!("invalid user: {e}"))
})?; })?;
let home_dir = user.dir.to_string_lossy().to_string(); let home_dir = user.dir.to_string_lossy().to_string();
let default_workdir = self.state.defaults.workdir.as_deref(); let default_workdir = self.state.defaults.workdir();
expand_and_resolve(path, &home_dir, default_workdir) expand_and_resolve(path, &home_dir, default_workdir.as_deref())
.map_err(|e| ConnectError::new(ErrorCode::InvalidArgument, e)) .map_err(|e| ConnectError::new(ErrorCode::InvalidArgument, e))
} }
} }
@ -97,7 +97,7 @@ impl Filesystem for FilesystemServiceImpl {
} }
} }
let username = extract_username(&ctx).unwrap_or_else(|| self.state.defaults.user.clone()); let username = extract_username(&ctx).unwrap_or_else(|| self.state.defaults.user());
let user = let user =
lookup_user(&username).map_err(|e| ConnectError::new(ErrorCode::Internal, e))?; lookup_user(&username).map_err(|e| ConnectError::new(ErrorCode::Internal, e))?;
@ -122,7 +122,7 @@ impl Filesystem for FilesystemServiceImpl {
let source = self.resolve_path(request.source, &ctx)?; let source = self.resolve_path(request.source, &ctx)?;
let destination = self.resolve_path(request.destination, &ctx)?; let destination = self.resolve_path(request.destination, &ctx)?;
let username = extract_username(&ctx).unwrap_or_else(|| self.state.defaults.user.clone()); let username = extract_username(&ctx).unwrap_or_else(|| self.state.defaults.user());
let user = let user =
lookup_user(&username).map_err(|e| ConnectError::new(ErrorCode::Internal, e))?; lookup_user(&username).map_err(|e| ConnectError::new(ErrorCode::Internal, e))?;

View File

@ -37,6 +37,7 @@ pub struct ProcessHandle {
data_tx: broadcast::Sender<DataEvent>, data_tx: broadcast::Sender<DataEvent>,
end_tx: broadcast::Sender<EndEvent>, end_tx: broadcast::Sender<EndEvent>,
ended: Mutex<Option<EndEvent>>,
stdin: Mutex<Option<std::process::ChildStdin>>, stdin: Mutex<Option<std::process::ChildStdin>>,
pty_master: Mutex<Option<std::fs::File>>, pty_master: Mutex<Option<std::fs::File>>,
@ -51,6 +52,10 @@ impl ProcessHandle {
self.end_tx.subscribe() self.end_tx.subscribe()
} }
pub fn cached_end(&self) -> Option<EndEvent> {
self.ended.lock().unwrap().clone()
}
pub fn send_signal(&self, sig: Signal) -> Result<(), ConnectError> { pub fn send_signal(&self, sig: Signal) -> Result<(), ConnectError> {
signal::kill(Pid::from_raw(self.pid as i32), sig).map_err(|e| { signal::kill(Pid::from_raw(self.pid as i32), sig).map_err(|e| {
ConnectError::new(ErrorCode::Internal, format!("error sending signal: {e}")) ConnectError::new(ErrorCode::Internal, format!("error sending signal: {e}"))
@ -128,6 +133,12 @@ impl ProcessHandle {
} }
} }
pub struct SpawnedProcess {
pub handle: Arc<ProcessHandle>,
pub data_rx: broadcast::Receiver<DataEvent>,
pub end_rx: broadcast::Receiver<EndEvent>,
}
pub fn spawn_process( pub fn spawn_process(
cmd_str: &str, cmd_str: &str,
args: &[String], args: &[String],
@ -138,7 +149,7 @@ pub fn spawn_process(
tag: Option<String>, tag: Option<String>,
user: &nix::unistd::User, user: &nix::unistd::User,
default_env_vars: &dashmap::DashMap<String, String>, default_env_vars: &dashmap::DashMap<String, String>,
) -> Result<Arc<ProcessHandle>, ConnectError> { ) -> Result<SpawnedProcess, ConnectError> {
let mut env: Vec<(String, String)> = Vec::new(); let mut env: Vec<(String, String)> = Vec::new();
env.push(("PATH".into(), std::env::var("PATH").unwrap_or_default())); env.push(("PATH".into(), std::env::var("PATH").unwrap_or_default()));
let home = user.dir.to_string_lossy().to_string(); let home = user.dir.to_string_lossy().to_string();
@ -244,10 +255,14 @@ pub fn spawn_process(
pid, pid,
data_tx: data_tx.clone(), data_tx: data_tx.clone(),
end_tx: end_tx.clone(), end_tx: end_tx.clone(),
ended: Mutex::new(None),
stdin: Mutex::new(None), stdin: Mutex::new(None),
pty_master: Mutex::new(Some(master_file)), pty_master: Mutex::new(Some(master_file)),
}); });
let data_rx = handle.subscribe_data();
let end_rx = handle.subscribe_end();
let data_tx_clone = data_tx.clone(); let data_tx_clone = data_tx.clone();
std::thread::spawn(move || { std::thread::spawn(move || {
let mut master = master_clone; let mut master = master_clone;
@ -264,30 +279,29 @@ pub fn spawn_process(
}); });
let end_tx_clone = end_tx.clone(); let end_tx_clone = end_tx.clone();
let handle_for_waiter = Arc::clone(&handle);
std::thread::spawn(move || { std::thread::spawn(move || {
let mut child = child; let mut child = child;
match child.wait() { let end_event = match child.wait() {
Ok(s) => { Ok(s) => EndEvent {
let _ = end_tx_clone.send(EndEvent {
exit_code: s.code().unwrap_or(-1), exit_code: s.code().unwrap_or(-1),
exited: s.code().is_some(), exited: s.code().is_some(),
status: format!("{s}"), status: format!("{s}"),
error: None, error: None,
}); },
} Err(e) => EndEvent {
Err(e) => {
let _ = end_tx_clone.send(EndEvent {
exit_code: -1, exit_code: -1,
exited: false, exited: false,
status: "error".into(), status: "error".into(),
error: Some(e.to_string()), error: Some(e.to_string()),
}); },
} };
} *handle_for_waiter.ended.lock().unwrap() = Some(end_event.clone());
let _ = end_tx_clone.send(end_event);
}); });
tracing::info!(pid, cmd = cmd_str, "process started (pty)"); tracing::info!(pid, cmd = cmd_str, "process started (pty)");
Ok(handle) Ok(SpawnedProcess { handle, data_rx, end_rx })
} else { } else {
let mut command = std::process::Command::new("/bin/sh"); let mut command = std::process::Command::new("/bin/sh");
command command
@ -327,10 +341,14 @@ pub fn spawn_process(
pid, pid,
data_tx: data_tx.clone(), data_tx: data_tx.clone(),
end_tx: end_tx.clone(), end_tx: end_tx.clone(),
ended: Mutex::new(None),
stdin: Mutex::new(stdin), stdin: Mutex::new(stdin),
pty_master: Mutex::new(None), pty_master: Mutex::new(None),
}); });
let data_rx = handle.subscribe_data();
let end_rx = handle.subscribe_end();
if let Some(mut out) = stdout { if let Some(mut out) = stdout {
let tx = data_tx.clone(); let tx = data_tx.clone();
std::thread::spawn(move || { std::thread::spawn(move || {
@ -364,29 +382,28 @@ pub fn spawn_process(
} }
let end_tx_clone = end_tx.clone(); let end_tx_clone = end_tx.clone();
let handle_for_waiter = Arc::clone(&handle);
std::thread::spawn(move || { std::thread::spawn(move || {
match child.wait() { let end_event = match child.wait() {
Ok(s) => { Ok(s) => EndEvent {
let _ = end_tx_clone.send(EndEvent {
exit_code: s.code().unwrap_or(-1), exit_code: s.code().unwrap_or(-1),
exited: s.code().is_some(), exited: s.code().is_some(),
status: format!("{s}"), status: format!("{s}"),
error: None, error: None,
}); },
} Err(e) => EndEvent {
Err(e) => {
let _ = end_tx_clone.send(EndEvent {
exit_code: -1, exit_code: -1,
exited: false, exited: false,
status: "error".into(), status: "error".into(),
error: Some(e.to_string()), error: Some(e.to_string()),
}); },
} };
} *handle_for_waiter.ended.lock().unwrap() = Some(end_event.clone());
let _ = end_tx_clone.send(end_event);
}); });
tracing::info!(pid, cmd = cmd_str, "process started (pipe)"); tracing::info!(pid, cmd = cmd_str, "process started (pipe)");
Ok(handle) Ok(SpawnedProcess { handle, data_rx, end_rx })
} }
} }

View File

@ -66,12 +66,12 @@ impl ProcessServiceImpl {
fn spawn_from_request( fn spawn_from_request(
&self, &self,
request: &StartRequestView<'_>, request: &StartRequestView<'_>,
) -> Result<Arc<ProcessHandle>, ConnectError> { ) -> Result<process_handler::SpawnedProcess, ConnectError> {
let proc_config = request.process.as_option().ok_or_else(|| { let proc_config = request.process.as_option().ok_or_else(|| {
ConnectError::new(ErrorCode::InvalidArgument, "process config required") ConnectError::new(ErrorCode::InvalidArgument, "process config required")
})?; })?;
let username = self.state.defaults.user.clone(); let username = self.state.defaults.user();
let user = let user =
lookup_user(&username).map_err(|e| ConnectError::new(ErrorCode::Internal, e))?; lookup_user(&username).map_err(|e| ConnectError::new(ErrorCode::Internal, e))?;
@ -85,7 +85,8 @@ impl ProcessServiceImpl {
let home_dir = user.dir.to_string_lossy().to_string(); let home_dir = user.dir.to_string_lossy().to_string();
let cwd_str: &str = proc_config.cwd.unwrap_or(""); let cwd_str: &str = proc_config.cwd.unwrap_or("");
let cwd = expand_and_resolve(cwd_str, &home_dir, self.state.defaults.workdir.as_deref()) let default_workdir = self.state.defaults.workdir();
let cwd = expand_and_resolve(cwd_str, &home_dir, default_workdir.as_deref())
.map_err(|e| ConnectError::new(ErrorCode::InvalidArgument, e))?; .map_err(|e| ConnectError::new(ErrorCode::InvalidArgument, e))?;
let effective_cwd = if cwd.is_empty() { "/" } else { &cwd }; let effective_cwd = if cwd.is_empty() { "/" } else { &cwd };
@ -116,7 +117,7 @@ impl ProcessServiceImpl {
"process.Start request" "process.Start request"
); );
let handle = process_handler::spawn_process( let spawned = process_handler::spawn_process(
cmd, cmd,
&args, &args,
&envs, &envs,
@ -128,17 +129,17 @@ impl ProcessServiceImpl {
&self.state.defaults.env_vars, &self.state.defaults.env_vars,
)?; )?;
self.processes.insert(handle.pid, Arc::clone(&handle)); self.processes.insert(spawned.handle.pid, Arc::clone(&spawned.handle));
let processes = self.processes.clone(); let processes = self.processes.clone();
let pid = handle.pid; let pid = spawned.handle.pid;
let mut end_rx = handle.subscribe_end(); let mut cleanup_end_rx = spawned.handle.subscribe_end();
tokio::spawn(async move { tokio::spawn(async move {
let _ = end_rx.recv().await; let _ = cleanup_end_rx.recv().await;
processes.remove(&pid); processes.remove(&pid);
}); });
Ok(handle) Ok(spawned)
} }
} }
@ -182,26 +183,36 @@ impl Process for ProcessServiceImpl {
), ),
ConnectError, ConnectError,
> { > {
let handle = self.spawn_from_request(&request)?; let spawned = self.spawn_from_request(&request)?;
let pid = handle.pid; let pid = spawned.handle.pid;
let mut data_rx = handle.subscribe_data(); let mut data_rx = spawned.data_rx;
let mut end_rx = handle.subscribe_end(); let mut end_rx = spawned.end_rx;
let stream = async_stream::stream! { let stream = async_stream::stream! {
yield Ok(make_start_response(pid)); yield Ok(make_start_response(pid));
loop { loop {
match data_rx.recv().await { tokio::select! {
biased;
data = data_rx.recv() => {
match data {
Ok(ev) => yield Ok(make_data_start_response(ev)), Ok(ev) => yield Ok(make_data_start_response(ev)),
Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue, Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
Err(tokio::sync::broadcast::error::RecvError::Closed) => break, Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
} }
} }
end = end_rx.recv() => {
if let Ok(end) = end_rx.recv().await { while let Ok(ev) = data_rx.try_recv() {
yield Ok(make_data_start_response(ev));
}
if let Ok(end) = end {
yield Ok(make_end_start_response(end)); yield Ok(make_end_start_response(end));
} }
break;
}
}
}
}; };
Ok((Box::pin(stream), ctx)) Ok((Box::pin(stream), ctx))
@ -226,6 +237,7 @@ impl Process for ProcessServiceImpl {
let mut data_rx = handle.subscribe_data(); let mut data_rx = handle.subscribe_data();
let mut end_rx = handle.subscribe_end(); let mut end_rx = handle.subscribe_end();
let cached_end = handle.cached_end();
let stream = async_stream::stream! { let stream = async_stream::stream! {
yield Ok(ConnectResponse { yield Ok(ConnectResponse {
@ -238,8 +250,17 @@ impl Process for ProcessServiceImpl {
..Default::default() ..Default::default()
}); });
if let Some(end) = cached_end {
yield Ok(ConnectResponse {
event: buffa::MessageField::some(make_end_event(end)),
..Default::default()
});
} else {
loop { loop {
match data_rx.recv().await { tokio::select! {
biased;
data = data_rx.recv() => {
match data {
Ok(ev) => { Ok(ev) => {
yield Ok(ConnectResponse { yield Ok(ConnectResponse {
event: buffa::MessageField::some(make_data_event(ev)), event: buffa::MessageField::some(make_data_event(ev)),
@ -250,13 +271,24 @@ impl Process for ProcessServiceImpl {
Err(tokio::sync::broadcast::error::RecvError::Closed) => break, Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
} }
} }
end = end_rx.recv() => {
if let Ok(end) = end_rx.recv().await { while let Ok(ev) = data_rx.try_recv() {
yield Ok(ConnectResponse {
event: buffa::MessageField::some(make_data_event(ev)),
..Default::default()
});
}
if let Ok(end) = end {
yield Ok(ConnectResponse { yield Ok(ConnectResponse {
event: buffa::MessageField::some(make_end_event(end)), event: buffa::MessageField::some(make_end_event(end)),
..Default::default() ..Default::default()
}); });
} }
break;
}
}
}
}
}; };
Ok((Box::pin(stream), ctx)) Ok((Box::pin(stream), ctx))

View File

@ -19,6 +19,7 @@ pub struct AppState {
pub port_subsystem: Option<Arc<PortSubsystem>>, pub port_subsystem: Option<Arc<PortSubsystem>>,
pub cpu_used_pct: AtomicU32, pub cpu_used_pct: AtomicU32,
pub cpu_count: AtomicU32, pub cpu_count: AtomicU32,
pub snapshot_in_progress: AtomicBool,
} }
impl AppState { impl AppState {
@ -41,6 +42,7 @@ impl AppState {
port_subsystem, port_subsystem,
cpu_used_pct: AtomicU32::new(0), cpu_used_pct: AtomicU32::new(0),
cpu_count: AtomicU32::new(0), cpu_count: AtomicU32::new(0),
snapshot_in_progress: AtomicBool::new(false),
}); });
let state_clone = Arc::clone(&state); let state_clone = Arc::clone(&state);

View File

@ -11,6 +11,10 @@ impl AtomicMax {
} }
} }
pub fn get(&self) -> i64 {
self.val.load(Ordering::Acquire)
}
/// Sets the stored value to `new` if `new` is strictly greater than /// Sets the stored value to `new` if `new` is strictly greater than
/// the current value. Returns `true` if the value was updated. /// the current value. Returns `true` if the value was updated.
pub fn set_to_greater(&self, new: i64) -> bool { pub fn set_to_greater(&self, new: i64) -> bool {
@ -31,3 +35,68 @@ impl AtomicMax {
} }
} }
} }
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    #[test]
    fn initial_value_is_i64_min() {
        // A fresh AtomicMax starts at the identity element for max().
        assert_eq!(AtomicMax::new().get(), i64::MIN);
    }

    #[test]
    fn updates_when_larger() {
        let tracker = AtomicMax::new();
        assert!(tracker.set_to_greater(0));
        assert_eq!(tracker.get(), 0);
        assert!(tracker.set_to_greater(100));
        assert_eq!(tracker.get(), 100);
    }

    #[test]
    fn returns_false_when_equal() {
        // Strictly-greater semantics: an equal value is not an update.
        let tracker = AtomicMax::new();
        tracker.set_to_greater(42);
        assert!(!tracker.set_to_greater(42));
        assert_eq!(tracker.get(), 42);
    }

    #[test]
    fn returns_false_when_smaller() {
        let tracker = AtomicMax::new();
        tracker.set_to_greater(100);
        assert!(!tracker.set_to_greater(50));
        assert_eq!(tracker.get(), 100);
    }

    #[test]
    fn concurrent_convergence() {
        // Eight writers race over disjoint ranges; the stored value must end
        // at the global maximum (7 * 100 + 99 = 799) under any interleaving.
        let shared = Arc::new(AtomicMax::new());
        let mut workers = Vec::new();
        for t in 0..8 {
            let shared = Arc::clone(&shared);
            workers.push(std::thread::spawn(move || {
                for i in (t * 100)..((t + 1) * 100) {
                    shared.set_to_greater(i);
                }
            }));
        }
        for worker in workers {
            worker.join().unwrap();
        }
        assert_eq!(shared.get(), 799);
    }

    #[test]
    fn i64_max_boundary() {
        // Once saturated at i64::MAX no further update can ever succeed.
        let tracker = AtomicMax::new();
        assert!(tracker.set_to_greater(i64::MAX));
        assert!(!tracker.set_to_greater(i64::MAX));
        assert!(!tracker.set_to_greater(0));
        assert_eq!(tracker.get(), i64::MAX);
    }
}

View File

@ -26,3 +26,7 @@ export async function listAdminUsers(page: number = 1): Promise<ApiResult<AdminU
export async function setUserActive(id: string, active: boolean): Promise<ApiResult<void>> { export async function setUserActive(id: string, active: boolean): Promise<ApiResult<void>> {
return apiFetch('PUT', `/api/v1/admin/users/${id}/active`, { active }); return apiFetch('PUT', `/api/v1/admin/users/${id}/active`, { active });
} }
export async function setUserAdmin(id: string, admin: boolean): Promise<ApiResult<void>> {
return apiFetch('PUT', `/api/v1/admin/users/${id}/admin`, { admin });
}

View File

@ -213,6 +213,7 @@
}, },
}); });
lastDataKey = '';
updateCharts(); updateCharts();
} }
@ -233,6 +234,7 @@
onMount(async () => { onMount(async () => {
if (!available) return; if (!available) return;
loadMetrics();
const mod = await import('chart.js/auto'); const mod = await import('chart.js/auto');
ChartJS = mod.Chart; ChartJS = mod.Chart;

View File

@ -2,6 +2,7 @@
import CopyButton from '$lib/components/CopyButton.svelte'; import CopyButton from '$lib/components/CopyButton.svelte';
import { onMount, onDestroy } from 'svelte'; import { onMount, onDestroy } from 'svelte';
import { toast } from '$lib/toast.svelte'; import { toast } from '$lib/toast.svelte';
import { auth } from '$lib/auth.svelte';
import { formatDate, timeAgo } from '$lib/utils/format'; import { formatDate, timeAgo } from '$lib/utils/format';
import { import {
listBuilds, listBuilds,
@ -13,6 +14,7 @@
type BuildLogEntry, type BuildLogEntry,
type AdminTemplate type AdminTemplate
} from '$lib/api/builds'; } from '$lib/api/builds';
import { listAdminTeams } from '$lib/api/team';
let activeTab = $state<'templates' | 'builds'>('templates'); let activeTab = $state<'templates' | 'builds'>('templates');
@ -35,6 +37,9 @@
let expandedBuildId = $state<string | null>(null); let expandedBuildId = $state<string | null>(null);
let expandedSteps = $state<Set<number>>(new Set()); let expandedSteps = $state<Set<number>>(new Set());
// Team name lookup
let teamNames = $state<Map<string, string>>(new Map());
// Delete template state // Delete template state
let deleteTarget = $state<AdminTemplate | null>(null); let deleteTarget = $state<AdminTemplate | null>(null);
let deleting = $state(false); let deleting = $state(false);
@ -64,6 +69,28 @@
let baseCount = $derived(templates.filter((t) => t.type === 'base').length); let baseCount = $derived(templates.filter((t) => t.type === 'base').length);
let runningBuilds = $derived(builds.filter((b) => b.status === 'running').length); let runningBuilds = $derived(builds.filter((b) => b.status === 'running').length);
async function fetchTeamNames() {
const names = new Map<string, string>();
let page = 1;
while (true) {
const result = await listAdminTeams(page);
if (!result.ok) break;
for (const team of result.data.teams) {
names.set(team.id, team.name);
}
if (page >= result.data.total_pages) break;
page++;
}
teamNames = names;
}
const PLATFORM_TEAM_ID = 'team-0000000000000000000000000';
function canDeleteTemplate(tmpl: AdminTemplate): boolean {
if (tmpl.name === 'minimal') return false;
return tmpl.team_id === PLATFORM_TEAM_ID;
}
async function fetchTemplates() { async function fetchTemplates() {
templatesLoading = true; templatesLoading = true;
templatesError = null; templatesError = null;
@ -238,6 +265,7 @@
} }
onMount(() => { onMount(() => {
fetchTeamNames();
fetchTemplates(); fetchTemplates();
fetchBuilds().then(startPolling); fetchBuilds().then(startPolling);
@ -339,7 +367,7 @@
<div class="flex-1 overflow-y-auto px-8 py-6"> <div class="flex-1 overflow-y-auto px-8 py-6">
{#if activeTab === 'templates'} {#if activeTab === 'templates'}
{#if templatesLoading} {#if templatesLoading}
{@render skeletonRows(5, ['Name', 'Type', 'Specs', 'Size', 'Created', ''])} {@render skeletonRows(5, ['Name', 'Type', 'Owner', 'Specs', 'Size', 'Created', ''])}
{:else if templatesError} {:else if templatesError}
<div class="rounded-[var(--radius-card)] border border-[var(--color-red)]/30 bg-[var(--color-red)]/5 px-4 py-3 text-ui text-[var(--color-red)]"> <div class="rounded-[var(--radius-card)] border border-[var(--color-red)]/30 bg-[var(--color-red)]/5 px-4 py-3 text-ui text-[var(--color-red)]">
{templatesError} {templatesError}
@ -442,6 +470,7 @@
<tr class="border-b border-[var(--color-border)] bg-[var(--color-bg-0)]/40"> <tr class="border-b border-[var(--color-border)] bg-[var(--color-bg-0)]/40">
<th class="table-header">Name</th> <th class="table-header">Name</th>
<th class="table-header">Type</th> <th class="table-header">Type</th>
<th class="table-header hidden md:table-cell">Owner</th>
<th class="table-header hidden md:table-cell">Specs</th> <th class="table-header hidden md:table-cell">Specs</th>
<th class="table-header hidden lg:table-cell">Size</th> <th class="table-header hidden lg:table-cell">Size</th>
<th class="table-header hidden lg:table-cell">Created</th> <th class="table-header hidden lg:table-cell">Created</th>
@ -473,6 +502,13 @@
</span> </span>
{/if} {/if}
</td> </td>
<td class="hidden px-5 py-3.5 md:table-cell">
{#if tmpl.team_id === PLATFORM_TEAM_ID}
<span class="text-meta text-[var(--color-text-muted)]">Platform</span>
{:else}
<span class="text-meta text-[var(--color-text-secondary)]">{teamNames.get(tmpl.team_id) ?? tmpl.team_id}</span>
{/if}
</td>
<td class="hidden px-5 py-3.5 md:table-cell"> <td class="hidden px-5 py-3.5 md:table-cell">
{#if tmpl.vcpus && tmpl.memory_mb} {#if tmpl.vcpus && tmpl.memory_mb}
<span class="font-mono text-meta tabular-nums text-[var(--color-text-secondary)]"> <span class="font-mono text-meta tabular-nums text-[var(--color-text-secondary)]">
@ -495,7 +531,11 @@
<td class="px-5 py-3.5 text-right"> <td class="px-5 py-3.5 text-right">
<button <button
onclick={() => { deleteTarget = tmpl; deleteError = null; }} onclick={() => { deleteTarget = tmpl; deleteError = null; }}
class="rounded-[var(--radius-button)] px-3 py-1.5 text-meta text-[var(--color-text-tertiary)] transition-all duration-150 hover:bg-[var(--color-red)]/10 hover:text-[var(--color-red)]" disabled={!canDeleteTemplate(tmpl)}
title={tmpl.name === 'minimal' ? 'The minimal template cannot be deleted' : !canDeleteTemplate(tmpl) ? 'Cannot delete templates owned by other teams' : undefined}
class="rounded-[var(--radius-button)] px-3 py-1.5 text-meta transition-all duration-150 {canDeleteTemplate(tmpl)
? 'text-[var(--color-text-tertiary)] hover:bg-[var(--color-red)]/10 hover:text-[var(--color-red)]'
: 'text-[var(--color-text-muted)] cursor-not-allowed opacity-40'}"
> >
Delete Delete
</button> </button>

View File

@ -5,8 +5,10 @@
import { import {
listAdminUsers, listAdminUsers,
setUserActive, setUserActive,
setUserAdmin,
type AdminUser, type AdminUser,
} from '$lib/api/admin-users'; } from '$lib/api/admin-users';
import { auth } from '$lib/auth.svelte';
// Data state // Data state
let users = $state<AdminUser[]>([]); let users = $state<AdminUser[]>([]);
@ -22,6 +24,11 @@
// Toggle state // Toggle state
let togglingId = $state<string | null>(null); let togglingId = $state<string | null>(null);
// Admin dialog state
let adminTarget = $state<AdminUser | null>(null);
let togglingAdmin = $state(false);
let adminError = $state<string | null>(null);
async function fetchUsers(page: number = 1) { async function fetchUsers(page: number = 1) {
const wasEmpty = users.length === 0; const wasEmpty = users.length === 0;
if (wasEmpty) loading = true; if (wasEmpty) loading = true;
@ -56,6 +63,23 @@
togglingId = null; togglingId = null;
} }
async function handleConfirmAdminToggle() {
if (!adminTarget) return;
togglingAdmin = true;
adminError = null;
const target = adminTarget;
const newAdmin = !target.is_admin;
const result = await setUserAdmin(target.id, newAdmin);
if (result.ok) {
adminTarget = null;
target.is_admin = newAdmin;
toast.success(`${target.email} ${newAdmin ? 'granted' : 'revoked'} admin`);
} else {
adminError = result.error;
}
togglingAdmin = false;
}
function goToPage(page: number) { function goToPage(page: number) {
if (page < 1 || page > totalPages) return; if (page < 1 || page > totalPages) return;
fetchUsers(page); fetchUsers(page);
@ -222,8 +246,18 @@
</div> </div>
<!-- Role --> <!-- Role -->
<div class="px-5 py-4"> <div class="flex items-center px-5 py-4">
<span class="text-ui text-[var(--color-text-secondary)]">{user.is_admin ? 'Admin' : 'User'}</span> <button
onclick={() => { adminError = null; adminTarget = user; }}
disabled={user.status !== 'active'}
aria-label="{user.is_admin ? 'Revoke admin for' : 'Grant admin to'} {user.name || user.email}"
class="rounded-[var(--radius-button)] border px-3 py-1.5 text-meta font-medium transition-all duration-150 disabled:opacity-50
{user.is_admin
? 'border-[var(--color-amber)]/30 bg-[var(--color-amber)]/10 text-[var(--color-amber)] hover:bg-[var(--color-amber)]/20 hover:border-[var(--color-amber)]/50'
: 'border-[var(--color-border)] text-[var(--color-text-tertiary)] hover:border-[var(--color-border-mid)] hover:text-[var(--color-text-secondary)]'}"
>
{user.is_admin ? 'Admin' : 'User'}
</button>
</div> </div>
<!-- Joined --> <!-- Joined -->
@ -292,3 +326,72 @@
</div> </div>
</footer> </footer>
</main> </main>
<!-- Admin confirmation dialog -->
{#if adminTarget}
<div class="fixed inset-0 z-50 flex items-center justify-center">
<!-- svelte-ignore a11y_no_static_element_interactions -->
<div
class="absolute inset-0 bg-black/60"
onclick={() => { if (!togglingAdmin) adminTarget = null; }}
onkeydown={(e) => { if (e.key === 'Escape' && !togglingAdmin) adminTarget = null; }}
></div>
<div
class="relative w-full max-w-[420px] rounded-[var(--radius-card)] border border-[var(--color-border-mid)] bg-[var(--color-bg-2)]"
style="animation: fadeUp 0.2s ease both; box-shadow: var(--shadow-dialog)"
>
<div class="p-6">
<h2 class="font-serif text-heading leading-tight text-[var(--color-text-bright)]">
{adminTarget.is_admin ? 'Revoke Admin' : 'Grant Admin'}
</h2>
<p class="mt-1.5 text-ui text-[var(--color-text-tertiary)]">
{adminTarget.is_admin ? 'Remove admin access from' : 'Grant admin access to'}
<code class="rounded bg-[var(--color-bg-4)] px-1.5 py-0.5 font-mono text-[0.8rem] text-[var(--color-text-primary)]">{adminTarget.email}</code>.
{adminTarget.is_admin
? 'They will lose access to the admin panel immediately.'
: 'They will be able to manage all platform resources.'}
</p>
{#if adminTarget.is_admin && adminTarget.id === auth.userId}
<div class="mt-3 flex items-start gap-2.5 rounded-[var(--radius-input)] border border-[var(--color-amber)]/30 bg-[var(--color-amber)]/5 px-3 py-2.5">
<svg class="mt-0.5 shrink-0 text-[var(--color-amber)]" width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<path d="M10.29 3.86L1.82 18a2 2 0 0 0 1.71 3h16.94a2 2 0 0 0 1.71-3L13.71 3.86a2 2 0 0 0-3.42 0z" /><line x1="12" y1="9" x2="12" y2="13" /><line x1="12" y1="17" x2="12.01" y2="17" />
</svg>
<span class="text-meta text-[var(--color-amber)]">
You are removing your own admin access. You will lose access to this panel.
</span>
</div>
{/if}
{#if adminError}
<div class="mt-3 rounded-[var(--radius-input)] border border-[var(--color-red)]/30 bg-[var(--color-red)]/5 px-3 py-2 text-meta text-[var(--color-red)]">
{adminError}
</div>
{/if}
<div class="mt-6 flex justify-end gap-3">
<button
onclick={() => (adminTarget = null)}
disabled={togglingAdmin}
class="rounded-[var(--radius-button)] border border-[var(--color-border)] px-4 py-2 text-ui text-[var(--color-text-secondary)] transition-colors duration-150 hover:border-[var(--color-border-mid)] hover:text-[var(--color-text-primary)] disabled:opacity-50"
>
Cancel
</button>
<button
onclick={handleConfirmAdminToggle}
disabled={togglingAdmin}
class="flex items-center gap-2 rounded-[var(--radius-button)] px-5 py-2 text-ui font-semibold text-white transition-all duration-150 hover:brightness-110 hover:-translate-y-px active:translate-y-0 disabled:opacity-50 disabled:hover:translate-y-0
{adminTarget.is_admin ? 'bg-[var(--color-red)]' : 'bg-[var(--color-accent)]'}"
>
{#if togglingAdmin}
<svg class="animate-spin" width="13" height="13" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M21 12a9 9 0 1 1-6.219-8.56"/></svg>
{adminTarget.is_admin ? 'Revoking...' : 'Granting...'}
{:else}
{adminTarget.is_admin ? 'Revoke admin' : 'Grant admin'}
{/if}
</button>
</div>
</div>
</div>
</div>
{/if}

View File

@ -120,6 +120,25 @@
} }
} }
function mergeCapsuleData(incoming: Capsule[]) {
const existingMap = new Map(capsules.map((c) => [c.id, c]));
const merged: Capsule[] = [];
for (const fresh of incoming) {
const existing = existingMap.get(fresh.id);
if (existing) {
for (const key of Object.keys(fresh) as (keyof Capsule)[]) {
if (existing[key] !== fresh[key]) {
(existing as any)[key] = fresh[key];
}
}
merged.push(existing);
} else {
merged.push(fresh);
}
}
capsules = merged;
}
async function fetchCapsules(manual = false) { async function fetchCapsules(manual = false) {
const wasEmpty = capsules.length === 0; const wasEmpty = capsules.length === 0;
if (wasEmpty) loading = true; if (wasEmpty) loading = true;
@ -131,7 +150,11 @@
const result = await listCapsules(); const result = await listCapsules();
if (result.ok) { if (result.ok) {
if (wasEmpty) {
capsules = result.data; capsules = result.data;
} else {
mergeCapsuleData(result.data);
}
error = null; error = null;
} else { } else {
error = result.error; error = result.error;

View File

@ -333,6 +333,7 @@
}, },
}); });
lastDataKey = '';
updateCharts(); updateCharts();
} }
@ -376,6 +377,7 @@
if (!metricsAvailable) return; if (!metricsAvailable) return;
loadMetrics();
const mod = await import('chart.js/auto'); const mod = await import('chart.js/auto');
ChartJS = mod.Chart; ChartJS = mod.Chart;

View File

@ -350,9 +350,23 @@ func runPtyLoop(
defer wg.Done() defer wg.Done()
defer cancel() defer cancel()
for msg := range inputCh { // pending holds a non-input message dequeued during coalescing
// Use a background context for unary RPCs so they complete // that must be processed on the next iteration.
// even if the stream context is being cancelled. var pending *wsPtyIn
for {
var msg wsPtyIn
if pending != nil {
msg = *pending
pending = nil
} else {
var ok bool
msg, ok = <-inputCh
if !ok {
break
}
}
rpcCtx, rpcCancel := context.WithTimeout(context.Background(), 5*time.Second) rpcCtx, rpcCancel := context.WithTimeout(context.Background(), 5*time.Second)
switch msg.Type { switch msg.Type {
@ -364,7 +378,7 @@ func runPtyLoop(
} }
// Coalesce: drain any queued input messages into a single RPC. // Coalesce: drain any queued input messages into a single RPC.
data = coalescePtyInput(inputCh, data) data, pending = coalescePtyInput(inputCh, data)
if _, err := agent.PtySendInput(rpcCtx, connect.NewRequest(&pb.PtySendInputRequest{ if _, err := agent.PtySendInput(rpcCtx, connect.NewRequest(&pb.PtySendInputRequest{
SandboxId: sandboxID, SandboxId: sandboxID,
@ -418,24 +432,29 @@ func runPtyLoop(
} }
}() }()
// When any pump cancels the context, close the websocket to unblock
// the reader goroutine stuck in ReadMessage.
go func() {
<-ctx.Done()
ws.conn.Close()
}()
wg.Wait() wg.Wait()
} }
// coalescePtyInput drains any immediately-available "input" messages from the // coalescePtyInput drains any immediately-available "input" messages from the
// channel and appends their decoded data to buf, reducing RPC call volume // channel and appends their decoded data to buf, reducing RPC call volume
// during bursts of fast typing. // during bursts of fast typing. Returns the coalesced buffer and any
func coalescePtyInput(ch <-chan wsPtyIn, buf []byte) []byte { // non-input message that was dequeued (must be processed by the caller).
func coalescePtyInput(ch <-chan wsPtyIn, buf []byte) ([]byte, *wsPtyIn) {
for { for {
select { select {
case msg, ok := <-ch: case msg, ok := <-ch:
if !ok { if !ok {
return buf return buf, nil
} }
if msg.Type != "input" { if msg.Type != "input" {
// Non-input message — can't coalesce. Put-back isn't possible return buf, &msg
// with channels, but resize/kill during a typing burst is rare
// enough that dropping one is acceptable.
return buf
} }
data, err := base64.StdEncoding.DecodeString(msg.Data) data, err := base64.StdEncoding.DecodeString(msg.Data)
if err != nil { if err != nil {
@ -443,7 +462,7 @@ func coalescePtyInput(ch <-chan wsPtyIn, buf []byte) []byte {
} }
buf = append(buf, data...) buf = append(buf, data...)
default: default:
return buf return buf, nil
} }
} }
} }

View File

@ -162,3 +162,58 @@ func (h *usersHandler) SetUserActive(w http.ResponseWriter, r *http.Request) {
} }
w.WriteHeader(http.StatusNoContent) w.WriteHeader(http.StatusNoContent)
} }
// SetUserAdmin handles PUT /v1/admin/users/{id}/admin
// Grants or revokes platform admin status. Cannot remove the last admin.
func (h *usersHandler) SetUserAdmin(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
userIDStr := chi.URLParam(r, "id")
userID, err := id.ParseUserID(userIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid user ID")
return
}
var req struct {
Admin bool `json:"admin"`
}
if err := decodeJSON(r, &req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
return
}
user, err := h.db.GetUserByID(r.Context(), userID)
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "user not found")
return
}
if user.IsAdmin == req.Admin {
w.WriteHeader(http.StatusNoContent)
return
}
if req.Admin {
if err := h.db.SetUserAdmin(r.Context(), db.SetUserAdminParams{
ID: userID,
IsAdmin: true,
}); err != nil {
writeError(w, http.StatusInternalServerError, "internal", "failed to update admin status")
return
}
h.audit.LogUserGrantAdmin(r.Context(), ac, userID, user.Email)
} else {
affected, err := h.db.RevokeUserAdmin(r.Context(), userID)
if err != nil {
writeError(w, http.StatusInternalServerError, "internal", "failed to update admin status")
return
}
if affected == 0 {
writeError(w, http.StatusBadRequest, "invalid_request", "cannot remove the last admin")
return
}
h.audit.LogUserRevokeAdmin(r.Context(), ac, userID, user.Email)
}
w.WriteHeader(http.StatusNoContent)
}

View File

@ -2346,6 +2346,54 @@ paths:
schema: schema:
$ref: "#/components/schemas/Error" $ref: "#/components/schemas/Error"
/v1/admin/users/{id}/admin:
put:
summary: Grant or revoke platform admin
operationId: setUserAdmin
tags: [admin]
description: |
Sets the platform admin flag on a user. Cannot remove the last admin.
Requires platform admin access (JWT + is_admin).
The target user's JWT is not re-issued — their frontend will reflect the
change on next login or team switch.
security:
- bearerAuth: []
parameters:
- name: id
in: path
required: true
schema:
type: string
example: "usr-a1b2c3d4"
requestBody:
required: true
content:
application/json:
schema:
type: object
required: [admin]
properties:
admin:
type: boolean
description: true to grant admin, false to revoke.
responses:
"204":
description: Admin status updated
"400":
$ref: "#/components/responses/BadRequest"
"403":
description: Caller is not a platform admin
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"404":
description: User not found
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
components: components:
securitySchemes: securitySchemes:
apiKeyAuth: apiKeyAuth:

View File

@ -269,6 +269,7 @@ func New(
r.Delete("/teams/{id}", teamH.AdminDeleteTeam) r.Delete("/teams/{id}", teamH.AdminDeleteTeam)
r.Get("/users", usersH.AdminListUsers) r.Get("/users", usersH.AdminListUsers)
r.Put("/users/{id}/active", usersH.SetUserActive) r.Put("/users/{id}/active", usersH.SetUserActive)
r.Put("/users/{id}/admin", usersH.SetUserAdmin)
r.Get("/audit-logs", auditH.AdminList) r.Get("/audit-logs", auditH.AdminList)
r.Get("/templates", buildH.ListTemplates) r.Get("/templates", buildH.ListTemplates)
r.Delete("/templates/{name}", buildH.DeleteTemplate) r.Delete("/templates/{name}", buildH.DeleteTemplate)

View File

@ -135,6 +135,20 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
defer tracker.Release() defer tracker.Release()
// Derive request context from the tracker's context so ForceClose()
// during pause aborts this proxied request.
trackerCtx := tracker.Context()
reqCtx, reqCancel := context.WithCancel(r.Context())
defer reqCancel()
go func() {
select {
case <-trackerCtx.Done():
reqCancel()
case <-reqCtx.Done():
}
}()
r = r.WithContext(reqCtx)
proxy := h.getOrCreateProxy(sandboxID, port, fmt.Sprintf("%s:%d", hostIP, portNum)) proxy := h.getOrCreateProxy(sandboxID, port, fmt.Sprintf("%s:%d", hostIP, portNum))
proxy.ServeHTTP(w, r) proxy.ServeHTTP(w, r)
} }

View File

@ -11,6 +11,7 @@ type SandboxStatus string
const ( const (
StatusPending SandboxStatus = "pending" StatusPending SandboxStatus = "pending"
StatusRunning SandboxStatus = "running" StatusRunning SandboxStatus = "running"
StatusPausing SandboxStatus = "pausing"
StatusPaused SandboxStatus = "paused" StatusPaused SandboxStatus = "paused"
StatusStopped SandboxStatus = "stopped" StatusStopped SandboxStatus = "stopped"
StatusError SandboxStatus = "error" StatusError SandboxStatus = "error"

View File

@ -1,6 +1,7 @@
package sandbox package sandbox
import ( import (
"context"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -17,6 +18,20 @@ type ConnTracker struct {
// goroutine to exit, preventing goroutine leaks on repeated pause failures. // goroutine to exit, preventing goroutine leaks on repeated pause failures.
cancelMu sync.Mutex cancelMu sync.Mutex
cancelDrain chan struct{} cancelDrain chan struct{}
// ctx is cancelled by ForceClose to abort all in-flight proxy requests.
// Initialized lazily on first Acquire; replaced by Reset after a failed
// pause so new connections get a fresh, non-cancelled context.
ctxMu sync.Mutex
ctx context.Context
cancel context.CancelFunc
}
// ensureCtx lazily initializes the cancellable context.
func (t *ConnTracker) ensureCtx() {
if t.ctx == nil {
t.ctx, t.cancel = context.WithCancel(context.Background())
}
} }
// Acquire registers one in-flight connection. Returns false if the tracker // Acquire registers one in-flight connection. Returns false if the tracker
@ -35,6 +50,16 @@ func (t *ConnTracker) Acquire() bool {
return true return true
} }
// Context returns a context that is cancelled when ForceClose is called.
// Proxy handlers should derive their request context from this so that
// force-close during pause aborts in-flight proxied requests.
func (t *ConnTracker) Context() context.Context {
t.ctxMu.Lock()
defer t.ctxMu.Unlock()
t.ensureCtx()
return t.ctx
}
// Release marks one connection as complete. Must be called exactly once // Release marks one connection as complete. Must be called exactly once
// per successful Acquire. // per successful Acquire.
func (t *ConnTracker) Release() { func (t *ConnTracker) Release() {
@ -65,9 +90,33 @@ func (t *ConnTracker) Drain(timeout time.Duration) {
} }
} }
// ForceClose cancels all in-flight proxy connections by cancelling the
// shared context. Connections whose request context derives from Context()
// will see their requests aborted, causing the proxy handler to return
// and call Release(). Waits briefly for connections to actually release.
func (t *ConnTracker) ForceClose() {
t.ctxMu.Lock()
if t.cancel != nil {
t.cancel()
}
t.ctxMu.Unlock()
// Wait briefly for force-closed connections to call Release().
done := make(chan struct{})
go func() {
t.wg.Wait()
close(done)
}()
select {
case <-done:
case <-time.After(2 * time.Second):
}
}
// Reset re-enables the tracker after a failed drain. This allows the // Reset re-enables the tracker after a failed drain. This allows the
// sandbox to accept proxy connections again if the pause operation fails // sandbox to accept proxy connections again if the pause operation fails
// and the VM is resumed. It also cancels any lingering Drain goroutine. // and the VM is resumed. It also cancels any lingering Drain goroutine
// and creates a fresh context for new connections.
func (t *ConnTracker) Reset() { func (t *ConnTracker) Reset() {
t.cancelMu.Lock() t.cancelMu.Lock()
if t.cancelDrain != nil { if t.cancelDrain != nil {
@ -81,5 +130,10 @@ func (t *ConnTracker) Reset() {
} }
t.cancelMu.Unlock() t.cancelMu.Unlock()
// Replace the cancelled context with a fresh one.
t.ctxMu.Lock()
t.ctx, t.cancel = context.WithCancel(context.Background())
t.ctxMu.Unlock()
t.draining.Store(false) t.draining.Store(false)
} }

View File

@ -95,10 +95,10 @@ type snapshotParent struct {
} }
// maxDiffGenerations caps how many incremental diff generations we chain // maxDiffGenerations caps how many incremental diff generations we chain
// before falling back to a Full snapshot to collapse the chain. Long diff // before merging diffs into a single file. Since UFFD lazy-loads memory
// chains increase restore latency and snapshot directory size; a periodic // anyway, we merge on every re-pause to keep exactly 1 diff file per
// Full snapshot resets the counter and produces a clean base. // snapshot — no accumulated chain, no extra restore overhead.
const maxDiffGenerations = 8 const maxDiffGenerations = 1
// buildMetadata constructs the metadata map with version information. // buildMetadata constructs the metadata map with version information.
func (m *Manager) buildMetadata(envdVersion string) map[string]string { func (m *Manager) buildMetadata(envdVersion string) map[string]string {
@ -186,9 +186,12 @@ func (m *Manager) Create(ctx context.Context, sandboxID string, teamID, template
} }
// Create dm-snapshot with per-sandbox CoW file. // Create dm-snapshot with per-sandbox CoW file.
// CoW must be at least as large as the origin — if every block is
// rewritten, the CoW stores a full copy. Undersized CoW causes
// dm-snapshot invalidation → EIO on all guest I/O.
dmName := "wrenn-" + sandboxID dmName := "wrenn-" + sandboxID
cowPath := filepath.Join(layout.SandboxesDir(m.cfg.WrennDir), fmt.Sprintf("%s.cow", sandboxID)) cowPath := filepath.Join(layout.SandboxesDir(m.cfg.WrennDir), fmt.Sprintf("%s.cow", sandboxID))
cowSize := int64(diskSizeMB) * 1024 * 1024 cowSize := max(int64(diskSizeMB)*1024*1024, originSize)
dmDev, err := devicemapper.CreateSnapshot(dmName, originLoop, cowPath, originSize, cowSize) dmDev, err := devicemapper.CreateSnapshot(dmName, originLoop, cowPath, originSize, cowSize)
if err != nil { if err != nil {
m.loops.Release(baseRootfs) m.loops.Release(baseRootfs)
@ -374,28 +377,43 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error {
return fmt.Errorf("sandbox %s is not running (status: %s)", sandboxID, sb.Status) return fmt.Errorf("sandbox %s is not running (status: %s)", sandboxID, sb.Status)
} }
// Mark sandbox as pausing to block new exec/file/PTY operations.
m.mu.Lock()
sb.Status = models.StatusPausing
m.mu.Unlock()
// restoreRunning reverts state if any pre-freeze step fails.
restoreRunning := func() {
_ = m.vm.UpdateBalloon(context.Background(), sandboxID, 0)
sb.connTracker.Reset()
m.mu.Lock()
sb.Status = models.StatusRunning
m.mu.Unlock()
m.startSampler(sb)
}
// Stop the metrics sampler goroutine before tearing down any resources // Stop the metrics sampler goroutine before tearing down any resources
// it reads (dm device, Firecracker PID). Without this, the sampler // it reads (dm device, Firecracker PID). Without this, the sampler
// leaks on every successful pause. // leaks on every successful pause.
m.stopSampler(sb) m.stopSampler(sb)
// Step 0: Drain in-flight proxy connections before freezing vCPUs. // ── Step 1: Isolate from external traffic ─────────────────────────
// Stale TCP state from mid-flight connections causes issues on restore. // Drain in-flight proxy connections (grace period for clean shutdown).
sb.connTracker.Drain(2 * time.Second) sb.connTracker.Drain(5 * time.Second)
slog.Debug("pause: proxy connections drained", "id", sandboxID) // Force-close any connections that didn't finish during grace period.
sb.connTracker.ForceClose()
slog.Debug("pause: external connections closed", "id", sandboxID)
// Step 0b: Close host-side idle connections to envd. Done before // Close host-side idle connections to envd so FIN packets propagate
// PrepareSnapshot so FIN packets propagate to the guest during the // to the guest kernel before snapshot.
// PrepareSnapshot window (no extra sleep needed).
sb.client.CloseIdleConnections() sb.client.CloseIdleConnections()
slog.Debug("pause: envd client idle connections closed", "id", sandboxID)
// Step 0c: Signal envd to quiesce (stop port scanner/forwarder, mark // ── Step 2: Drop page cache ──────────────────────────────────────
// connections for post-restore cleanup). The 3s timeout also gives time // Signal envd to quiesce: drops page cache, stops port subsystem,
// for the FINs from Step 0b to be processed by the guest kernel. // marks connections for post-restore cleanup. Page cache drop can
// Best-effort: a failure is logged but does not abort the pause. // take significant time on large-memory VMs (20GB+).
func() { func() {
prepCtx, prepCancel := context.WithTimeout(ctx, 3*time.Second) prepCtx, prepCancel := context.WithTimeout(ctx, 30*time.Second)
defer prepCancel() defer prepCancel()
if err := sb.client.PrepareSnapshot(prepCtx); err != nil { if err := sb.client.PrepareSnapshot(prepCtx); err != nil {
slog.Warn("pause: pre-snapshot quiesce failed (best-effort)", "id", sandboxID, "error", err) slog.Warn("pause: pre-snapshot quiesce failed (best-effort)", "id", sandboxID, "error", err)
@ -404,11 +422,37 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error {
} }
}() }()
// ── Step 3: Inflate balloon to reclaim free guest memory ─────────
// Freed pages become zero from FC's perspective, so ProcessMemfile
// skips them → dramatically smaller memfile (e.g. 20GB → 1GB).
func() {
memUsed, err := readEnvdMemUsed(sb.client)
if err != nil {
slog.Debug("pause: could not read guest memory, skipping balloon inflate", "id", sandboxID, "error", err)
return
}
usedMiB := int(memUsed / (1024 * 1024))
keepMiB := max(usedMiB*2, 256) + 128
inflateMiB := sb.MemoryMB - keepMiB
if inflateMiB <= 0 {
slog.Debug("pause: not enough free memory for balloon inflate", "id", sandboxID, "used_mib", usedMiB, "total_mib", sb.MemoryMB)
return
}
balloonCtx, balloonCancel := context.WithTimeout(ctx, 10*time.Second)
defer balloonCancel()
if err := m.vm.UpdateBalloon(balloonCtx, sandboxID, inflateMiB); err != nil {
slog.Debug("pause: balloon inflate failed (non-fatal)", "id", sandboxID, "error", err)
return
}
time.Sleep(2 * time.Second)
slog.Info("pause: balloon inflated", "id", sandboxID, "inflate_mib", inflateMiB, "guest_used_mib", usedMiB)
}()
// ── Step 4: Freeze vCPUs ─────────────────────────────────────────
pauseStart := time.Now() pauseStart := time.Now()
// Step 1: Pause the VM (freeze vCPUs).
if err := m.vm.Pause(ctx, sandboxID); err != nil { if err := m.vm.Pause(ctx, sandboxID); err != nil {
sb.connTracker.Reset() restoreRunning()
return fmt.Errorf("pause VM: %w", err) return fmt.Errorf("pause VM: %w", err)
} }
slog.Debug("pause: VM paused", "id", sandboxID, "elapsed", time.Since(pauseStart)) slog.Debug("pause: VM paused", "id", sandboxID, "elapsed", time.Since(pauseStart))
@ -423,13 +467,23 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error {
// resumeOnError unpauses the VM so the sandbox stays usable when a // resumeOnError unpauses the VM so the sandbox stays usable when a
// post-freeze step fails. If the resume itself fails, the sandbox is // post-freeze step fails. If the resume itself fails, the sandbox is
// left frozen — the caller should destroy it. It also resets the // frozen and unrecoverable — destroy it to avoid a zombie.
// connection tracker so the sandbox can accept proxy connections again.
resumeOnError := func() { resumeOnError := func() {
sb.connTracker.Reset() // Use a fresh context — the caller's ctx may already be cancelled.
if err := m.vm.Resume(ctx, sandboxID); err != nil { resumeCtx, resumeCancel := context.WithTimeout(context.Background(), 30*time.Second)
slog.Error("failed to resume VM after pause error — sandbox is frozen", "id", sandboxID, "error", err) defer resumeCancel()
if err := m.vm.Resume(resumeCtx, sandboxID); err != nil {
slog.Error("failed to resume VM after pause error — destroying frozen sandbox", "id", sandboxID, "error", err)
m.cleanup(context.Background(), sb)
m.mu.Lock()
delete(m.boxes, sandboxID)
m.mu.Unlock()
if m.onDestroy != nil {
m.onDestroy(sandboxID)
} }
return
}
restoreRunning()
} }
// Step 2: Take VM state snapshot (snapfile + memfile). // Step 2: Take VM state snapshot (snapfile + memfile).
@ -444,6 +498,7 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error {
snapshotStart := time.Now() snapshotStart := time.Now()
if err := m.vm.Snapshot(ctx, sandboxID, snapPath, rawMemPath, snapshotType); err != nil { if err := m.vm.Snapshot(ctx, sandboxID, snapPath, rawMemPath, snapshotType); err != nil {
slog.Error("pause: snapshot failed", "id", sandboxID, "type", snapshotType, "elapsed", time.Since(snapshotStart), "error", err)
warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir)) warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir))
resumeOnError() resumeOnError()
return fmt.Errorf("create VM snapshot: %w", err) return fmt.Errorf("create VM snapshot: %w", err)
@ -795,6 +850,12 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int,
slog.Warn("post-init failed after resume, metadata files may be stale", "sandbox", sandboxID, "error", err) slog.Warn("post-init failed after resume, metadata files may be stale", "sandbox", sandboxID, "error", err)
} }
// Deflate balloon — the snapshot was taken with an inflated balloon to
// reduce memfile size, so restore the guest's full memory allocation.
if err := m.vm.UpdateBalloon(ctx, sandboxID, 0); err != nil {
slog.Debug("resume: balloon deflate failed (non-fatal)", "id", sandboxID, "error", err)
}
// Fetch envd version (best-effort). // Fetch envd version (best-effort).
envdVersion, _ := client.FetchVersion(ctx) envdVersion, _ := client.FetchVersion(ctx)
@ -1134,7 +1195,7 @@ func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID string, team
dmName := "wrenn-" + sandboxID dmName := "wrenn-" + sandboxID
cowPath := filepath.Join(layout.SandboxesDir(m.cfg.WrennDir), fmt.Sprintf("%s.cow", sandboxID)) cowPath := filepath.Join(layout.SandboxesDir(m.cfg.WrennDir), fmt.Sprintf("%s.cow", sandboxID))
cowSize := int64(diskSizeMB) * 1024 * 1024 cowSize := max(int64(diskSizeMB)*1024*1024, originSize)
dmDev, err := devicemapper.CreateSnapshot(dmName, originLoop, cowPath, originSize, cowSize) dmDev, err := devicemapper.CreateSnapshot(dmName, originLoop, cowPath, originSize, cowSize)
if err != nil { if err != nil {
source.Close() source.Close()
@ -1235,6 +1296,11 @@ func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID string, team
slog.Warn("post-init failed after template restore, metadata files may be stale", "sandbox", sandboxID, "error", err) slog.Warn("post-init failed after template restore, metadata files may be stale", "sandbox", sandboxID, "error", err)
} }
// Deflate balloon — template snapshot was taken with an inflated balloon.
if err := m.vm.UpdateBalloon(ctx, sandboxID, 0); err != nil {
slog.Debug("create-from-snapshot: balloon deflate failed (non-fatal)", "id", sandboxID, "error", err)
}
// Fetch envd version (best-effort). // Fetch envd version (best-effort).
envdVersion, _ := client.FetchVersion(ctx) envdVersion, _ := client.FetchVersion(ctx)
@ -1720,12 +1786,12 @@ func (m *Manager) startSampler(sb *sandboxState) {
go m.samplerLoop(ctx, sb, fcPID, sb.VCPUs, initialCPU) go m.samplerLoop(ctx, sb, fcPID, sb.VCPUs, initialCPU)
} }
// samplerLoop samples /proc metrics at 500ms intervals. // samplerLoop samples metrics at 1s intervals.
// lastCPU is goroutine-local to avoid shared-state races. // lastCPU is goroutine-local to avoid shared-state races.
func (m *Manager) samplerLoop(ctx context.Context, sb *sandboxState, fcPID, vcpus int, lastCPU cpuStat) { func (m *Manager) samplerLoop(ctx context.Context, sb *sandboxState, fcPID, vcpus int, lastCPU cpuStat) {
defer close(sb.samplerDone) defer close(sb.samplerDone)
ticker := time.NewTicker(500 * time.Millisecond) ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop() defer ticker.Stop()
clkTck := 100.0 // sysconf(_SC_CLK_TCK), almost always 100 on Linux clkTck := 100.0 // sysconf(_SC_CLK_TCK), almost always 100 on Linux
@ -1758,8 +1824,11 @@ func (m *Manager) samplerLoop(ctx context.Context, sb *sandboxState, fcPID, vcpu
cpuInitialized = true cpuInitialized = true
} }
// Memory: VmRSS of the Firecracker process. // Memory: guest-reported used memory from envd /metrics.
memBytes, _ := readMemRSS(fcPID) // VmRSS of the Firecracker process includes guest page cache
// and never decreases, so we use the guest's own view which
// reports total - available (actual process memory).
memBytes, _ := readEnvdMemUsed(sb.client)
// Disk: allocated bytes of the CoW sparse file. // Disk: allocated bytes of the CoW sparse file.
var diskBytes int64 var diskBytes int64

View File

@ -15,11 +15,11 @@ type MetricPoint struct {
// Ring buffer capacity constants. // Ring buffer capacity constants.
const ( const (
ring10mCap = 1200 // 500ms × 1200 = 10 min ring10mCap = 600 // 1s × 600 = 10 min
ring2hCap = 240 // 30s × 240 = 2 h ring2hCap = 240 // 30s × 240 = 2 h
ring24hCap = 288 // 5min × 288 = 24 h ring24hCap = 288 // 5min × 288 = 24 h
downsample2hEvery = 60 // 60 × 500ms = 30s downsample2hEvery = 30 // 30 × 1s = 30s
downsample24hEvery = 10 // 10 × 30s = 5min downsample24hEvery = 10 // 10 × 30s = 5min
) )
@ -44,8 +44,8 @@ type metricsRing struct {
count24h int count24h int
// Accumulators for downsampling. // Accumulators for downsampling.
acc500ms [downsample2hEvery]MetricPoint acc1s [downsample2hEvery]MetricPoint
acc500msN int acc1sN int
acc30s [downsample24hEvery]MetricPoint acc30s [downsample24hEvery]MetricPoint
acc30sN int acc30sN int
@ -56,7 +56,7 @@ func newMetricsRing() *metricsRing {
return &metricsRing{} return &metricsRing{}
} }
// Push adds a 500ms sample to the finest tier and triggers downsampling // Push adds a 1s sample to the finest tier and triggers downsampling
// into coarser tiers when enough samples have accumulated. // into coarser tiers when enough samples have accumulated.
func (r *metricsRing) Push(p MetricPoint) { func (r *metricsRing) Push(p MetricPoint) {
r.mu.Lock() r.mu.Lock()
@ -70,12 +70,12 @@ func (r *metricsRing) Push(p MetricPoint) {
} }
// Accumulate for 2h downsample. // Accumulate for 2h downsample.
r.acc500ms[r.acc500msN] = p r.acc1s[r.acc1sN] = p
r.acc500msN++ r.acc1sN++
if r.acc500msN == downsample2hEvery { if r.acc1sN == downsample2hEvery {
avg := averagePoints(r.acc500ms[:downsample2hEvery]) avg := averagePoints(r.acc1s[:downsample2hEvery])
r.push2h(avg) r.push2h(avg)
r.acc500msN = 0 r.acc1sN = 0
} }
} }
@ -138,7 +138,7 @@ func (r *metricsRing) Flush() (pts10m, pts2h, pts24h []MetricPoint) {
r.idx10m, r.count10m = 0, 0 r.idx10m, r.count10m = 0, 0
r.idx2h, r.count2h = 0, 0 r.idx2h, r.count2h = 0, 0
r.idx24h, r.count24h = 0, 0 r.idx24h, r.count24h = 0, 0
r.acc500msN = 0 r.acc1sN = 0
r.acc30sN = 0 r.acc30sN = 0
return pts10m, pts2h, pts24h return pts10m, pts2h, pts24h

View File

@ -1,11 +1,15 @@
package sandbox package sandbox
import ( import (
"encoding/json"
"fmt" "fmt"
"io"
"os" "os"
"strconv" "strconv"
"strings" "strings"
"syscall" "syscall"
"git.omukk.dev/wrenn/wrenn/internal/envdclient"
) )
// cpuStat holds raw CPU jiffies read from /proc/{pid}/stat. // cpuStat holds raw CPU jiffies read from /proc/{pid}/stat.
@ -24,16 +28,11 @@ func readCPUStat(pid int) (cpuStat, error) {
return cpuStat{}, fmt.Errorf("read stat: %w", err) return cpuStat{}, fmt.Errorf("read stat: %w", err)
} }
// /proc/{pid}/stat format: pid (comm) state fields...
// The comm field may contain spaces and parens, so find the last ')' first.
content := string(data) content := string(data)
idx := strings.LastIndex(content, ")") idx := strings.LastIndex(content, ")")
if idx < 0 { if idx < 0 {
return cpuStat{}, fmt.Errorf("malformed /proc/%d/stat: no closing paren", pid) return cpuStat{}, fmt.Errorf("malformed /proc/%d/stat: no closing paren", pid)
} }
// After ")" there is " state field3 field4 ... fieldN"
// field1 after ')' is state (index 0), utime is field 11, stime is field 12
// (0-indexed from after the closing paren).
fields := strings.Fields(content[idx+2:]) fields := strings.Fields(content[idx+2:])
if len(fields) < 13 { if len(fields) < 13 {
return cpuStat{}, fmt.Errorf("malformed /proc/%d/stat: too few fields (%d)", pid, len(fields)) return cpuStat{}, fmt.Errorf("malformed /proc/%d/stat: too few fields (%d)", pid, len(fields))
@ -49,27 +48,34 @@ func readCPUStat(pid int) (cpuStat, error) {
return cpuStat{utime: utime, stime: stime}, nil return cpuStat{utime: utime, stime: stime}, nil
} }
// readMemRSS reads VmRSS from /proc/{pid}/status and returns bytes. // readEnvdMemUsed fetches mem_used from envd's /metrics endpoint. Returns
func readMemRSS(pid int) (int64, error) { // guest-side total - MemAvailable (actual process memory, excluding reclaimable
path := fmt.Sprintf("/proc/%d/status", pid) // page cache). VmRSS of the Firecracker process includes guest page cache and
data, err := os.ReadFile(path) // never decreases, so this is the accurate metric for dashboard display.
func readEnvdMemUsed(client *envdclient.Client) (int64, error) {
resp, err := client.HTTPClient().Get(client.BaseURL() + "/metrics")
if err != nil { if err != nil {
return 0, fmt.Errorf("read status: %w", err) return 0, fmt.Errorf("fetch envd metrics: %w", err)
} }
for _, line := range strings.Split(string(data), "\n") { defer resp.Body.Close()
if strings.HasPrefix(line, "VmRSS:") {
fields := strings.Fields(line) if resp.StatusCode != 200 {
if len(fields) < 2 { return 0, fmt.Errorf("envd metrics: status %d", resp.StatusCode)
return 0, fmt.Errorf("malformed VmRSS line")
} }
kb, err := strconv.ParseInt(fields[1], 10, 64)
body, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return 0, fmt.Errorf("parse VmRSS: %w", err) return 0, fmt.Errorf("read envd metrics body: %w", err)
} }
return kb * 1024, nil
var m struct {
MemUsed int64 `json:"mem_used"`
} }
if err := json.Unmarshal(body, &m); err != nil {
return 0, fmt.Errorf("decode envd metrics: %w", err)
} }
return 0, fmt.Errorf("VmRSS not found in /proc/%d/status", pid)
return m.MemUsed, nil
} }
// readDiskAllocated returns the actual allocated bytes (not apparent size) // readDiskAllocated returns the actual allocated bytes (not apparent size)

View File

@ -29,6 +29,10 @@ import (
const ( const (
UFFD_EVENT_PAGEFAULT = C.UFFD_EVENT_PAGEFAULT UFFD_EVENT_PAGEFAULT = C.UFFD_EVENT_PAGEFAULT
UFFD_EVENT_FORK = C.UFFD_EVENT_FORK
UFFD_EVENT_REMAP = C.UFFD_EVENT_REMAP
UFFD_EVENT_REMOVE = C.UFFD_EVENT_REMOVE
UFFD_EVENT_UNMAP = C.UFFD_EVENT_UNMAP
UFFD_PAGEFAULT_FLAG_WRITE = C.UFFD_PAGEFAULT_FLAG_WRITE UFFD_PAGEFAULT_FLAG_WRITE = C.UFFD_PAGEFAULT_FLAG_WRITE
UFFDIO_COPY = C.UFFDIO_COPY UFFDIO_COPY = C.UFFDIO_COPY
UFFDIO_COPY_MODE_WP = C.UFFDIO_COPY_MODE_WP UFFDIO_COPY_MODE_WP = C.UFFDIO_COPY_MODE_WP

View File

@ -253,8 +253,17 @@ func (s *Server) serve(ctx context.Context, uffdFd fd, mapping *Mapping) error {
} }
msg := *(*uffdMsg)(unsafe.Pointer(&buf[0])) msg := *(*uffdMsg)(unsafe.Pointer(&buf[0]))
if getMsgEvent(&msg) != UFFD_EVENT_PAGEFAULT { event := getMsgEvent(&msg)
return fmt.Errorf("unexpected uffd event type: %d", getMsgEvent(&msg))
switch event {
case UFFD_EVENT_PAGEFAULT:
// Handled below.
case UFFD_EVENT_REMOVE, UFFD_EVENT_UNMAP, UFFD_EVENT_REMAP, UFFD_EVENT_FORK:
// Non-fatal lifecycle events from the guest kernel (e.g. balloon
// deflation, mmap/munmap). No action needed — continue polling.
continue
default:
return fmt.Errorf("unexpected uffd event type: %d", event)
} }
arg := getMsgArg(&msg) arg := getMsgArg(&msg)

View File

@ -8,7 +8,6 @@ import (
"io" "io"
"net" "net"
"net/http" "net/http"
"time"
) )
// fcClient talks to the Firecracker HTTP API over a Unix socket. // fcClient talks to the Firecracker HTTP API over a Unix socket.
@ -27,7 +26,9 @@ func newFCClient(socketPath string) *fcClient {
return d.DialContext(ctx, "unix", socketPath) return d.DialContext(ctx, "unix", socketPath)
}, },
}, },
Timeout: 10 * time.Second, // No global timeout — callers pass context.Context with appropriate
// deadlines. A fixed 10s timeout was too short for snapshot/resume
// operations on large-memory VMs (20GB+ memfiles).
}, },
} }
} }
@ -136,6 +137,25 @@ func (c *fcClient) setMMDS(ctx context.Context, sandboxID, templateID string) er
}) })
} }
// setBalloon configures the Firecracker balloon device for dynamic memory
// management. deflateOnOom lets the guest reclaim balloon pages under memory
// pressure. statsInterval enables periodic stats via GET /balloon/statistics.
// Must be called before startVM.
func (c *fcClient) setBalloon(ctx context.Context, amountMiB int, deflateOnOom bool, statsIntervalS int) error {
return c.do(ctx, http.MethodPut, "/balloon", map[string]any{
"amount_mib": amountMiB,
"deflate_on_oom": deflateOnOom,
"stats_polling_interval_s": statsIntervalS,
})
}
// updateBalloon adjusts the balloon target at runtime.
func (c *fcClient) updateBalloon(ctx context.Context, amountMiB int) error {
return c.do(ctx, http.MethodPatch, "/balloon", map[string]any{
"amount_mib": amountMiB,
})
}
// startVM issues the InstanceStart action. // startVM issues the InstanceStart action.
func (c *fcClient) startVM(ctx context.Context) error { func (c *fcClient) startVM(ctx context.Context) error {
return c.do(ctx, http.MethodPut, "/actions", map[string]string{ return c.do(ctx, http.MethodPut, "/actions", map[string]string{

View File

@ -119,6 +119,13 @@ func configureVM(ctx context.Context, client *fcClient, cfg *VMConfig) error {
return fmt.Errorf("set machine config: %w", err) return fmt.Errorf("set machine config: %w", err)
} }
// Balloon device — allows the host to reclaim unused guest memory.
// Start with 0 (no inflation). deflate_on_oom lets the guest reclaim
// balloon pages under memory pressure. Stats interval enables monitoring.
if err := client.setBalloon(ctx, 0, true, 5); err != nil {
slog.Warn("set balloon failed (non-fatal, VM will run without memory reclaim)", "error", err)
}
// MMDS config — enable V2 token access on eth0 so that envd can read // MMDS config — enable V2 token access on eth0 so that envd can read
// WRENN_SANDBOX_ID and WRENN_TEMPLATE_ID from inside the guest. // WRENN_SANDBOX_ID and WRENN_TEMPLATE_ID from inside the guest.
if err := client.setMMDSConfig(ctx, "eth0"); err != nil { if err := client.setMMDSConfig(ctx, "eth0"); err != nil {
@ -162,6 +169,19 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string) error {
return nil return nil
} }
// UpdateBalloon adjusts the balloon target for a running VM.
// amountMiB is memory to take FROM the guest (0 = give all back).
func (m *Manager) UpdateBalloon(ctx context.Context, sandboxID string, amountMiB int) error {
m.mu.RLock()
vm, ok := m.vms[sandboxID]
m.mu.RUnlock()
if !ok {
return fmt.Errorf("VM not found: %s", sandboxID)
}
return vm.client.updateBalloon(ctx, amountMiB)
}
// Destroy stops and cleans up a VM. // Destroy stops and cleans up a VM.
func (m *Manager) Destroy(ctx context.Context, sandboxID string) error { func (m *Manager) Destroy(ctx context.Context, sandboxID string) error {
m.mu.Lock() m.mu.Lock()

View File

@ -365,6 +365,14 @@ func (l *AuditLogger) LogUserDeactivate(ctx context.Context, ac auth.AuthContext
l.Log(ctx, newAdminEntry(ac, "user", id.FormatUserID(userID), "deactivate", "warning", map[string]any{"email": email})) l.Log(ctx, newAdminEntry(ac, "user", id.FormatUserID(userID), "deactivate", "warning", map[string]any{"email": email}))
} }
func (l *AuditLogger) LogUserGrantAdmin(ctx context.Context, ac auth.AuthContext, userID pgtype.UUID, email string) {
l.Log(ctx, newAdminEntry(ac, "user", id.FormatUserID(userID), "grant_admin", "success", map[string]any{"email": email}))
}
func (l *AuditLogger) LogUserRevokeAdmin(ctx context.Context, ac auth.AuthContext, userID pgtype.UUID, email string) {
l.Log(ctx, newAdminEntry(ac, "user", id.FormatUserID(userID), "revoke_admin", "warning", map[string]any{"email": email}))
}
// --- Team admin events (scope: admin) --- // --- Team admin events (scope: admin) ---
func (l *AuditLogger) LogTeamSetBYOC(ctx context.Context, ac auth.AuthContext, teamID pgtype.UUID, enabled bool) { func (l *AuditLogger) LogTeamSetBYOC(ctx context.Context, ac auth.AuthContext, teamID pgtype.UUID, enabled bool) {

View File

@ -415,6 +415,21 @@ func (q *Queries) ListUsersAdmin(ctx context.Context, arg ListUsersAdminParams)
return items, nil return items, nil
} }
const revokeUserAdmin = `-- name: RevokeUserAdmin :execrows
UPDATE users u SET is_admin = false, updated_at = NOW()
WHERE u.id = $1
AND u.is_admin = true
AND (SELECT COUNT(*) FROM users WHERE is_admin = true AND status != 'deleted') > 1
`
func (q *Queries) RevokeUserAdmin(ctx context.Context, id pgtype.UUID) (int64, error) {
result, err := q.db.Exec(ctx, revokeUserAdmin, id)
if err != nil {
return 0, err
}
return result.RowsAffected(), nil
}
const searchUsersByEmailPrefix = `-- name: SearchUsersByEmailPrefix :many const searchUsersByEmailPrefix = `-- name: SearchUsersByEmailPrefix :many
SELECT id, email FROM users WHERE email LIKE $1 || '%' ORDER BY email LIMIT 10 SELECT id, email FROM users WHERE email LIKE $1 || '%' ORDER BY email LIMIT 10
` `

View File

@ -47,7 +47,7 @@ func NewHostClientPoolTLS(tlsCfg *tls.Config) *HostClientPool {
TLSNextProto: make(map[string]func(authority string, c *tls.Conn) http.RoundTripper), TLSNextProto: make(map[string]func(authority string, c *tls.Conn) http.RoundTripper),
MaxIdleConnsPerHost: 20, MaxIdleConnsPerHost: 20,
IdleConnTimeout: 90 * time.Second, IdleConnTimeout: 90 * time.Second,
ResponseHeaderTimeout: 45 * time.Second, ResponseHeaderTimeout: 5 * time.Minute,
DialContext: (&net.Dialer{ DialContext: (&net.Dialer{
Timeout: 10 * time.Second, Timeout: 10 * time.Second,
KeepAlive: 30 * time.Second, KeepAlive: 30 * time.Second,

View File

@ -239,12 +239,26 @@ func (s *SandboxService) Pause(ctx context.Context, sandboxID, teamID pgtype.UUI
if _, err := agent.PauseSandbox(ctx, connect.NewRequest(&pb.PauseSandboxRequest{ if _, err := agent.PauseSandbox(ctx, connect.NewRequest(&pb.PauseSandboxRequest{
SandboxId: sandboxIDStr, SandboxId: sandboxIDStr,
})); err != nil { })); err != nil {
// Revert status on failure. // Check if the agent still has this sandbox. If it was destroyed
if _, dbErr := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{ // (e.g. frozen VM couldn't be resumed), mark as "error" instead of
ID: sandboxID, Status: "running", // reverting to "running" — which would create a ghost record.
// Use a fresh context since the original ctx may already be expired.
revertStatus := "running"
pingCtx, pingCancel := context.WithTimeout(context.Background(), 10*time.Second)
if _, pingErr := agent.PingSandbox(pingCtx, connect.NewRequest(&pb.PingSandboxRequest{
SandboxId: sandboxIDStr,
})); pingErr != nil {
revertStatus = "error"
slog.Warn("sandbox gone from agent after failed pause, marking as error", "sandbox_id", sandboxIDStr)
}
pingCancel()
dbCtx, dbCancel := context.WithTimeout(context.Background(), 5*time.Second)
if _, dbErr := s.DB.UpdateSandboxStatus(dbCtx, db.UpdateSandboxStatusParams{
ID: sandboxID, Status: revertStatus,
}); dbErr != nil { }); dbErr != nil {
slog.Warn("failed to revert sandbox status after pause error", "sandbox_id", sandboxIDStr, "error", dbErr) slog.Warn("failed to revert sandbox status after pause error", "sandbox_id", sandboxIDStr, "error", dbErr)
} }
dbCancel()
return db.Sandbox{}, fmt.Errorf("agent pause: %w", err) return db.Sandbox{}, fmt.Errorf("agent pause: %w", err)
} }

View File

@ -57,7 +57,7 @@ if [ ! -f "${ENVD_BIN}" ]; then
exit 1 exit 1
fi fi
if ! file "${ENVD_BIN}" | grep -q "statically linked"; then if ! ldd "${ENVD_BIN}" | grep -q "statically linked"; then
echo "ERROR: envd is not statically linked!" echo "ERROR: envd is not statically linked!"
exit 1 exit 1
fi fi

View File

@ -17,7 +17,8 @@ set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
PROJECT_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" PROJECT_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
ROOTFS="${1:-/var/lib/wrenn/images/minimal/rootfs.ext4}" WRENN_DIR="${WRENN_DIR:-/var/lib/wrenn}"
ROOTFS="${1:-${WRENN_DIR}/images/minimal/rootfs.ext4}"
MOUNT_DIR="/tmp/wrenn-rootfs-update" MOUNT_DIR="/tmp/wrenn-rootfs-update"
if [ ! -f "${ROOTFS}" ]; then if [ ! -f "${ROOTFS}" ]; then
@ -36,6 +37,11 @@ if [ ! -f "${ENVD_BIN}" ]; then
exit 1 exit 1
fi fi
if ! ldd "${ENVD_BIN}" | grep -q "statically linked"; then
echo "ERROR: envd is not statically linked!"
exit 1
fi
# Step 2: Mount the rootfs. # Step 2: Mount the rootfs.
echo "==> Mounting rootfs at ${MOUNT_DIR}..." echo "==> Mounting rootfs at ${MOUNT_DIR}..."
mkdir -p "${MOUNT_DIR}" mkdir -p "${MOUNT_DIR}"