1
0
forked from wrenn/wrenn

2 Commits

Author SHA1 Message Date
e503afbc0f CI Test 3
All checks were successful
ci/woodpecker/push/pipeline Pipeline was successful
2026-05-22 01:31:40 +06:00
2f167f406c v0.1.6 (#45)
## What's New?
Performance updates for large capsules, admin panel enhancement and bug fixes

### Envd
- Fixed bug with sandbox metrics calculation
- Page cache drop and balloon inflation to reduce memfile snapshot
- Updated rpc timeout logic for better control
- Added tests

### Admin Panel
- Add/Remove platform admin
- Updated template deletion logic for fine grained permission

### Others
- Minor frontend visual improvement
- Minor bugfixes
- Version bump

Co-authored-by: Tasnim Kabir Sadik <tksadik92@gmail.com>
Reviewed-on: wrenn/wrenn#45
Co-authored-by: pptx704 <rafeed@omukk.dev>
Co-committed-by: pptx704 <rafeed@omukk.dev>
2026-05-22 01:27:00 +06:00
60 changed files with 2788 additions and 238 deletions

63
.woodpecker/pipeline.yml Normal file
View File

@ -0,0 +1,63 @@
when:
- event: push
branch: main
steps:
build-go:
image: python:3.13
environment:
WRENN_API_KEY:
from_secret: wrenn_api_key
commands:
- pip install wrenn
- export GO_VERSION=$$(grep '^go ' go.mod | cut -d' ' -f2)
- python .woodpecker/scripts/build_go.py
depends_on: []
build-rust:
image: python:3.13
environment:
WRENN_API_KEY:
from_secret: wrenn_api_key
commands:
- pip install wrenn
- export RUST_VERSION=$$(grep '^rust-version ' envd-rs/Cargo.toml | cut -d'"' -f2)
- python .woodpecker/scripts/build_rust.py
depends_on: []
tag-release:
image: python:3.13
environment:
GITEA_TOKEN:
from_secret: gitea_token
commands:
- VERSION=$$(cat VERSION_CP)
- git config user.name "R3dRum92"
- git config user.email "tksadik@omukk.dev"
- git tag "v$${VERSION}"
- git push "https://tksadik92:$${GITEA_TOKEN}@git.omukk.dev/tksadik92/wrenn-releases.git" "v$${VERSION}"
depends_on: [build-go, build-rust]
release-notes:
image: python:3.13
environment:
WRENN_API_KEY:
from_secret: wrenn_api_key
GITEA_TOKEN:
from_secret: gitea_token
ZHIPU_API_KEY:
from_secret: zhipu_api_key
commands:
- pip install wrenn
- python .woodpecker/scripts/release_notes.py
depends_on: [tag-release]
publish-github:
image: python:3.13
environment:
GITHUB_TOKEN:
from_secret: github_token
commands:
- pip install httpx
- python .woodpecker/scripts/publish_github.py
depends_on: [release-notes]

View File

@ -0,0 +1,136 @@
import os
import sys
from wrenn import Capsule, StreamExitEvent, StreamStderrEvent, StreamStdoutEvent
from wrenn._git import GitCommandError
GO_VERSION = os.getenv("GO_VERSION", "1.25.8")
REPO_URL = "https://git.omukk.dev/wrenn/wrenn.git"
REPO_DIR = "/opt/wrenn"
BUILDS_DIR = os.path.join(os.path.dirname(__file__), "..", "..", "builds")
def read_remote_version(capsule: Capsule, filename: str) -> str:
content = capsule.files.read_bytes(f"{REPO_DIR}/{filename}")
return content.decode("utf-8").strip()
def run(capsule: Capsule, cmd: str, timeout: int = 30) -> int:
result = capsule.commands.run(cmd, timeout=timeout)
if result.exit_code != 0:
print(f"FAIL [{cmd.split()[0]}]: exit={result.exit_code}", file=sys.stderr)
if result.stderr:
print(result.stderr.strip(), file=sys.stderr)
return result.exit_code
print(f"OK [{cmd.split()[0]}]")
return 0
def install_go(capsule: Capsule) -> bool:
tarball = f"go{GO_VERSION}.linux-amd64.tar.gz"
url = f"https://go.dev/dl/{tarball}"
if run(capsule, "apt update", timeout=120) != 0:
return False
if run(capsule, "apt install -y make build-essential file", timeout=300) != 0:
return False
if run(capsule, f"curl -LO {url}", timeout=120) != 0:
return False
if run(capsule, f"tar -C /usr/local -xzf {tarball}", timeout=300) != 0:
return False
if run(capsule, 'echo "export PATH=$PATH:/usr/local/go/bin" >> ~/.profile') != 0:
return False
if run(capsule, "rm -f " + tarball) != 0:
return False
result = capsule.commands.run("/usr/local/go/bin/go version")
print(result.stdout.strip())
return result.exit_code == 0
def clone_repo(capsule: Capsule) -> bool:
try:
capsule.git.clone(REPO_URL, REPO_DIR)
print("OK [git clone]")
return True
except GitCommandError as e:
print(f"FAIL [git clone]: {e}", file=sys.stderr)
return False
def build_go(capsule: Capsule) -> bool:
command = "CGO_ENABLED=1 make build-cp build-agent"
handle = capsule.commands.run(
command,
background=True,
cwd=REPO_DIR,
envs={
"PATH": "/usr/local/go/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
},
)
print(f"{command} started (pid={handle.pid}), streaming output...")
exit_code = 0
for event in capsule.commands.connect(handle.pid):
if isinstance(event, StreamStdoutEvent):
print(event.data, end="")
elif isinstance(event, StreamStderrEvent):
print(event.data, end="", file=sys.stderr)
elif isinstance(event, StreamExitEvent):
exit_code = event.exit_code
if exit_code != 0:
print(f"FAIL [go build]: exit={exit_code}", file=sys.stderr)
return False
print("OK [go build]")
return True
def download_artifacts(capsule: Capsule) -> bool:
remote_dir = f"{REPO_DIR}/builds"
entries = capsule.files.list(remote_dir, depth=1)
files = [e for e in entries if e.type != "directory"]
if not files:
print("FAIL [download]: no files found in builds/", file=sys.stderr)
return False
local_dir = os.path.normpath(BUILDS_DIR)
os.makedirs(local_dir, exist_ok=True)
versions = {
"wrenn-cp": read_remote_version(capsule, "VERSION_CP"),
"wrenn-agent": read_remote_version(capsule, "VERSION_AGENT"),
}
for entry in files:
name = entry.name or "unknown"
remote_path = f"{remote_dir}/{name}"
local_name = f"{name}-{versions[name]}" if name in versions else name
local_path = os.path.join(local_dir, local_name)
print(f"Downloading {name} as {local_name} ({entry.size or '?'} bytes)...")
with open(local_path, "wb") as f:
for chunk in capsule.files.download_stream(remote_path):
f.write(chunk)
print(f"OK [download {local_name}]")
return True
def main() -> None:
with Capsule(wait=True, vcpus=4, memory_mb=4096) as capsule:
print(f"Capsule: {capsule.capsule_id}")
if not install_go(capsule):
sys.exit(1)
if not clone_repo(capsule):
sys.exit(1)
if not build_go(capsule):
sys.exit(1)
if not download_artifacts(capsule):
sys.exit(1)
print("Done.")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,144 @@
import os
import sys
from wrenn import Capsule, StreamExitEvent, StreamStderrEvent, StreamStdoutEvent
from wrenn._git import GitCommandError
RUST_VERSION = os.getenv("RUST_VERSION", "1.95.0")
REPO_URL = "https://git.omukk.dev/wrenn/wrenn.git"
REPO_DIR = "/opt/wrenn"
BUILDS_DIR = os.path.join(os.path.dirname(__file__), "..", "..", "builds")
RUST_PATH = (
"/root/.cargo/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
)
def read_envd_version(capsule: Capsule) -> str:
content = capsule.files.read_bytes(f"{REPO_DIR}/envd-rs/Cargo.toml")
for line in content.decode("utf-8").splitlines():
stripped = line.strip()
if stripped.startswith("version ="):
return stripped.split("=", 1)[1].strip().strip('"')
print("FAIL [version]: envd-rs/Cargo.toml has no package version", file=sys.stderr)
sys.exit(1)
def run(capsule: Capsule, cmd: str, timeout: int = 30, envs={}) -> int:
result = capsule.commands.run(cmd, timeout=timeout, envs=envs)
if result.exit_code != 0:
print(f"FAIL [{cmd.split()[0]}]: exit={result.exit_code}", file=sys.stderr)
if result.stderr:
print(result.stderr.strip(), file=sys.stderr)
return result.exit_code
print(f"OK [{cmd.split()[0]}]")
return 0
def install_rust(capsule: Capsule) -> bool:
if run(capsule, "apt update", timeout=120) != 0:
return False
if (
run(
capsule,
"apt install -y make build-essential file curl musl-tools protobuf-compiler",
timeout=300,
)
!= 0
):
return False
if (
run(
capsule,
f"curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --profile minimal --default-toolchain {RUST_VERSION}",
timeout=300,
)
!= 0
):
return False
if (
run(
capsule,
"/root/.cargo/bin/rustup target add x86_64-unknown-linux-musl",
timeout=120,
)
!= 0
):
return False
result = capsule.commands.run("/root/.cargo/bin/rustc --version")
print(result.stdout.strip())
return result.exit_code == 0
def clone_repo(capsule: Capsule) -> bool:
try:
capsule.git.clone(REPO_URL, REPO_DIR)
print("OK [git clone]")
return True
except GitCommandError as e:
print(f"FAIL [git clone]: {e}", file=sys.stderr)
return False
def build_rust(capsule: Capsule) -> bool:
if run(capsule, f"mkdir -p {REPO_DIR}/builds") != 0:
return False
handle = capsule.commands.run(
"make build-envd",
background=True,
cwd=REPO_DIR,
envs={"PATH": RUST_PATH},
)
print(f"rust build started (pid={handle.pid}), streaming output...")
exit_code = 0
for event in capsule.commands.connect(handle.pid):
if isinstance(event, StreamStdoutEvent):
print(event.data, end="")
elif isinstance(event, StreamStderrEvent):
print(event.data, end="", file=sys.stderr)
elif isinstance(event, StreamExitEvent):
exit_code = event.exit_code
if exit_code != 0:
print(f"FAIL [rust build]: exit={exit_code}", file=sys.stderr)
return False
print("OK [rust build]")
return True
def download_artifacts(capsule: Capsule) -> bool:
version = read_envd_version(capsule)
remote_path = f"{REPO_DIR}/builds/envd"
local_dir = os.path.normpath(BUILDS_DIR)
local_name = f"envd-{version}"
local_path = os.path.join(local_dir, local_name)
os.makedirs(local_dir, exist_ok=True)
print(f"Downloading envd as {local_name}...")
with open(local_path, "wb") as f:
for chunk in capsule.files.download_stream(remote_path):
f.write(chunk)
print(f"OK [download {local_name}]")
return True
def main() -> None:
with Capsule(wait=True, vcpus=4, memory_mb=4096) as capsule:
print(f"Capsule: {capsule.capsule_id}")
if not install_rust(capsule):
sys.exit(1)
if not clone_repo(capsule):
sys.exit(1)
if not build_rust(capsule):
sys.exit(1)
if not download_artifacts(capsule):
sys.exit(1)
print("Done.")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,104 @@
import os
import sys
from pathlib import Path
import httpx
GITHUB_REPO = "R3dRum92/wrenn-releases"
GITHUB_API = "https://api.github.com"
GITHUB_UPLOADS = "https://uploads.github.com"
BUILDS_DIR = "builds"
VERSION_FILE = "VERSION_CP"
NOTES_FILE = os.path.join(".woodpecker", "release_notes.md")
def main() -> None:
token = os.environ["GITHUB_TOKEN"]
with open(VERSION_FILE) as f:
version = f.read().strip()
tag = f"v{version}"
release_notes = ""
if os.path.exists(NOTES_FILE):
with open(NOTES_FILE) as f:
release_notes = f.read()
headers = {
"Authorization": f"token {token}",
"Accept": "application/vnd.github+json",
"X-GitHub-Api-Version": "2022-11-28",
}
client = httpx.Client(headers=headers, timeout=60)
print(f"Creating GitHub release for {tag}...")
resp = client.post(
f"{GITHUB_API}/repos/{GITHUB_REPO}/releases",
json={
"tag_name": tag,
"name": tag,
"body": release_notes,
"draft": False,
"prerelease": False,
},
)
if resp.status_code == 422:
print(f"WARN [create release]: release for {tag} already exists, skipping")
data = resp.json()
errors = data.get("errors", [])
if errors:
existing_url = errors[0].get("documentation_url", "")
print(f" See: {existing_url}")
client.close()
return
if resp.status_code != 201:
print(f"FAIL [create release]: {resp.status_code} {resp.text}", file=sys.stderr)
client.close()
sys.exit(1)
release_data = resp.json()
release_id = release_data["id"]
release_url = release_data.get("html_url", "")
print(f"OK [create release] id={release_id}")
builds_path = Path(BUILDS_DIR)
if not builds_path.exists():
print(f"No {BUILDS_DIR}/ directory found, skipping asset upload")
client.close()
print(f"Release published: {release_url}")
return
upload_headers = {
**headers,
"Content-Type": "application/octet-stream",
}
for artifact in sorted(builds_path.iterdir()):
if artifact.is_dir():
continue
print(f"Uploading {artifact.name}...")
with open(artifact, "rb") as f:
data = f.read()
resp = client.post(
f"{GITHUB_UPLOADS}/repos/{GITHUB_REPO}/releases/{release_id}/assets",
params={"name": artifact.name},
headers=upload_headers,
content=data,
)
if resp.status_code != 201:
print(
f"WARN [upload {artifact.name}]: {resp.status_code} {resp.text}",
file=sys.stderr,
)
else:
print(f"OK [upload {artifact.name}]")
client.close()
print(f"Release published: {release_url}")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,299 @@
import base64
import os
import sys
from wrenn import Capsule
REPO_URL = "https://git.omukk.dev/tksadik92/wrenn-releases.git"
REPO_DIR = "/opt/wrenn-releases"
CAPSULE_OUTPUT = "/tmp/release_notes.md"
LOCAL_OUTPUT = os.path.join(os.path.dirname(__file__), "..", "release_notes.md")
# Default starting configuration
ZHIPU_API_KEY = os.environ.get("ZHIPU_API_KEY", "")
if ZHIPU_API_KEY:
DEFAULT_MODEL = "zhipuai-coding-plan/glm-5.1"
else:
DEFAULT_MODEL = "opencode/minimax-m2.5-free"
RELEASE_NOTES_EXAMPLE = """
## What's new
Sandbox HTTP proxying, terminal reliability, and auth robustness improvements.
### Proxy
- Fixed redirect loops for apps served inside sandboxes (Python HTTP server, Jupyter, etc.)
- Proxy traffic no longer interferes with terminal and exec connections
- Services that take a moment to start up inside a sandbox are now retried instead of immediately failing
### Terminal (PTY)
- Terminal input is no longer blocked by slow network conditions — fast typing no longer causes timeouts or disconnects
- Input bursts are coalesced into fewer round trips — lower latency under fast typing
### Authentication
- WebSocket connections now authenticate correctly for both SDK clients (header-based) and browser clients (message-based)
### Bug Fixes
- Fixed crash in envd when a process exits without a PTY
- Fixed goroutine leak on sandbox pause
### Others
- Version bump
""".strip()
def run(capsule: Capsule, cmd: str, cwd: str | None = None, timeout: int = 30) -> int:
result = capsule.commands.run(cmd, cwd=cwd, timeout=timeout)
if result.exit_code != 0:
print(f"FAIL [{cmd.split()[0]}]: exit={result.exit_code}", file=sys.stderr)
if result.stderr:
print(result.stderr.strip(), file=sys.stderr)
return result.exit_code
print(f"OK [{cmd.split()[0]}]")
return 0
def get_tags(capsule: Capsule) -> tuple[str, str | None]:
result = capsule.commands.run(
f"cd {REPO_DIR} && git tag --sort=-version:refname",
cwd=REPO_DIR,
timeout=30,
)
if result.exit_code != 0:
print(f"FAIL [git tag]: {result.stderr}", file=sys.stderr)
sys.exit(1)
tags = [t for t in result.stdout.strip().split("\n") if t]
if not tags:
print("No tags found", file=sys.stderr)
sys.exit(1)
current_tag = tags[0]
previous_tag = tags[1] if len(tags) > 1 else None
print(f"Current tag: {current_tag}")
print(f"Previous tag: {previous_tag}")
return current_tag, previous_tag
def get_git_context(
capsule: Capsule, current_tag: str, previous_tag: str | None
) -> tuple[str, str]:
if previous_tag:
# FIX: Removed '-n 2' to ensure we grab ALL commits between the two tags
log_cmd = f"cd {REPO_DIR} && git log {previous_tag}..{current_tag} --pretty=format:'%s (%h)'"
else:
# Fallback to limit log size if this is the very first tag in the repo
log_cmd = (
f"cd {REPO_DIR} && git log {current_tag} --pretty=format:'%s (%h)' -n 50"
)
log_result = capsule.commands.run(log_cmd, cwd=REPO_DIR, timeout=30)
if log_result.exit_code != 0:
print(f"FAIL [git log]: {log_result.stderr}", file=sys.stderr)
sys.exit(1)
# git diff natively compares the entire tree state between tags
if previous_tag:
diff_cmd = f"cd {REPO_DIR} && git diff {previous_tag}..{current_tag} --stat"
else:
diff_cmd = f"cd {REPO_DIR} && git show {current_tag} --stat"
diff_result = capsule.commands.run(diff_cmd, cwd=REPO_DIR, timeout=30)
if diff_result.exit_code != 0:
print(f"FAIL [git diff]: {diff_result.stderr}", file=sys.stderr)
sys.exit(1)
return log_result.stdout.strip(), diff_result.stdout.strip()
def generate_release_notes(
capsule: Capsule,
output_path: str,
model: str,
) -> None:
prompt = f"""
You are inside a cloned git repository at:
{REPO_DIR}
Generate release notes for the latest tagged version of this software project.
Before writing anything, inspect the repository yourself using git commands.
You MUST determine:
1. The latest version tag.
2. The previous version tag, if one exists.
3. The commits between the previous tag and the latest tag.
4. The files and areas changed between those tags.
Use commands like:
git tag --sort=-version:refname
If there are at least two tags, compare the newest tag against the previous tag:
git log PREVIOUS_TAG..LATEST_TAG --pretty=format:'%s (%h)'
git diff PREVIOUS_TAG..LATEST_TAG --stat
git diff PREVIOUS_TAG..LATEST_TAG --name-only
If there is only one tag, inspect the latest tag with:
git log LATEST_TAG --pretty=format:'%s (%h)' -n 50
git show LATEST_TAG --stat
git show LATEST_TAG --name-only
Do not rely on any pre-injected commit list or diff summary.
You must inspect the git history yourself.
Write the release notes in plain, friendly language that any developer can understand
without deep knowledge of the codebase.
Avoid jargon like "goroutine", "PTY", "envd", or internal function names.
Describe what the change means for the user instead.
Group related changes under headings that reflect what actually changed.
Only include sections that are relevant to the actual changes.
Do not include CI/CD-only changes.
Start with:
## What's New
The very next line must be a single short summary sentence.
Keep each bullet point to one clear sentence.
Here is an example of the style to aim for — not a template to copy:
{RELEASE_NOTES_EXAMPLE}
Output only the final markdown.
No intro.
No explanation.
No conversational filler.
No acknowledgments.
No "I checked the logs" text.
No thoughts.
""".strip()
prompt_b64 = base64.b64encode(prompt.encode("utf-8")).decode("utf-8")
write_prompt_cmd = f"echo '{prompt_b64}' | base64 -d > /tmp/oc_prompt.txt"
result = capsule.commands.run(
write_prompt_cmd,
cwd=REPO_DIR,
timeout=10,
)
if result.exit_code != 0:
print(f"FAIL [write prompt]: {result.stderr}", file=sys.stderr)
sys.exit(1)
def run_opencode_with_model(target_model: str) -> int:
env = ""
if "zhipu" in target_model.lower():
env = f"ZHIPU_API_KEY={os.environ.get('ZHIPU_API_KEY', '')}"
raw_output_path = "/tmp/opencode_raw.txt"
cmd = (
f"{env} "
f"~/.opencode/bin/opencode run "
f'"Read the attached file and generate the release notes. Output ONLY markdown." '
f"--model {target_model} "
f"--file /tmp/oc_prompt.txt "
f"> {raw_output_path}"
)
cmd_result = capsule.commands.run(cmd, cwd=REPO_DIR, timeout=300)
if cmd_result.exit_code != 0:
print(
f"FAIL [opencode via {target_model}]: exit={cmd_result.exit_code}",
file=sys.stderr,
)
print(f"STDOUT:\n{cmd_result.stdout}", file=sys.stderr)
print(f"STDERR:\n{cmd_result.stderr}", file=sys.stderr)
clean_cmd = (
f"awk 'found || /^## What.s [Nn]ew/ {{ found=1; print }}' "
f"{raw_output_path} > {output_path}"
)
clean_result = capsule.commands.run(clean_cmd, cwd=REPO_DIR, timeout=10)
if clean_result.exit_code != 0:
print(f"FAIL [clean output]: {clean_result.stderr}", file=sys.stderr)
return clean_result.exit_code
check_result = capsule.commands.run(
f"grep -q '^## What.s New' {output_path}",
cwd=REPO_DIR,
timeout=10,
)
if check_result.exit_code != 0:
print(
"FAIL: Could not find release notes heading in opencode output",
file=sys.stderr,
)
print(cmd_result.stdout, file=sys.stderr)
print(cmd_result.stderr, file=sys.stderr)
return 1
return cmd_result.exit_code
# First attempt with the target model
exit_status = run_opencode_with_model(model)
# FIX: Catch failures (like Zhipu rate limits) and fallback to MiniMax
if exit_status != 0:
if "zhipu" in model.lower():
print(
"\n[!] Zhipu AI failed (likely rate-limited). Falling back to MiniMax...",
file=sys.stderr,
)
fallback_model = "opencode/minimax-m2.5-free"
exit_status = run_opencode_with_model(fallback_model)
if exit_status != 0:
print("FAIL: Fallback model also failed. Exiting.", file=sys.stderr)
sys.exit(1)
else:
sys.exit(1)
result = capsule.commands.run(f"cat {output_path}")
print(result.stdout)
if result.stderr:
print(result.stderr)
print(f"OK [opencode] release notes written to {output_path}")
def download_release_notes(capsule: Capsule) -> None:
local_path = os.path.normpath(LOCAL_OUTPUT)
os.makedirs(os.path.dirname(local_path), exist_ok=True)
print(f"Downloading release notes from capsule...")
content = capsule.files.read_bytes(CAPSULE_OUTPUT)
with open(local_path, "wb") as f:
f.write(content)
print(f"OK [download] release notes → {local_path}")
print(content.decode("utf-8", errors="replace"))
def main() -> None:
model = os.environ.get("OPENCODE_MODEL", DEFAULT_MODEL)
with Capsule(template="opencode", wait=True, vcpus=2, memory_mb=2048) as capsule:
print(f"Capsule: {capsule.capsule_id}")
capsule.git.clone(
REPO_URL,
REPO_DIR,
)
print("OK [git clone]")
# Note: This simply creates the directory string safely
output_path = os.path.normpath(CAPSULE_OUTPUT)
generate_release_notes(capsule, output_path, model)
download_release_notes(capsule)
if __name__ == "__main__":
main()

View File

@ -27,8 +27,12 @@ build-agent:
build-envd:
cd envd-rs && ENVD_COMMIT=$(COMMIT) cargo build --release --target x86_64-unknown-linux-musl
@cp envd-rs/target/x86_64-unknown-linux-musl/release/envd $(BIN_DIR)/envd
@file $(BIN_DIR)/envd | grep -q "static-pie linked" || \
(echo "ERROR: envd is not statically linked!" && exit 1)
@readelf -h $(BIN_DIR)/envd | grep -q 'Type:.*DYN' && \
readelf -d $(BIN_DIR)/envd | grep -q 'FLAGS_1.*PIE' && \
! readelf -d $(BIN_DIR)/envd | grep -q '(NEEDED)' && \
{ ! readelf -lW $(BIN_DIR)/envd | grep -q 'Requesting program interpreter' || \
readelf -lW $(BIN_DIR)/envd | grep -Fq '[Requesting program interpreter: /lib/ld-musl-x86_64.so.1]'; } || \
(echo "ERROR: envd must be PIE, have no DT_NEEDED shared libs, and either have no interpreter or use /lib/ld-musl-x86_64.so.1" && exit 1)
# ═══════════════════════════════════════════════════
# Development
@ -111,6 +115,7 @@ vet:
test:
go test -race -v ./internal/...
cd envd-rs && cargo test
test-integration:
go test -race -v -tags=integration ./tests/integration/...

View File

@ -1 +1 @@
0.1.2
0.1.3

View File

@ -1 +1 @@
0.1.5
0.1.6

View File

@ -80,6 +80,25 @@ func main() {
os.Exit(1)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Register with the control plane before touching rootfs images. If the
// agent can't reach the CP there's no point inflating images (and crashing
// afterward would leave them in the expanded state).
creds, err := hostagent.Register(ctx, hostagent.RegistrationConfig{
CPURL: cpURL,
RegistrationToken: *registrationToken,
TokenFile: credsFile,
Address: *advertiseAddr,
})
if err != nil {
slog.Error("host registration failed", "error", err)
os.Exit(1)
}
slog.Info("host registered", "host_id", creds.HostID)
// Parse default rootfs size from env (e.g. "5G", "2Gi", "1000M").
defaultRootfsSizeMB := sandbox.DefaultDiskSizeMB
if sizeStr := os.Getenv("WRENN_DEFAULT_ROOTFS_SIZE"); sizeStr != "" {
@ -128,25 +147,8 @@ func main() {
mgr := sandbox.New(cfg)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
mgr.StartTTLReaper(ctx)
// Register with the control plane and start heartbeating.
creds, err := hostagent.Register(ctx, hostagent.RegistrationConfig{
CPURL: cpURL,
RegistrationToken: *registrationToken,
TokenFile: credsFile,
Address: *advertiseAddr,
})
if err != nil {
slog.Error("host registration failed", "error", err)
os.Exit(1)
}
slog.Info("host registered", "host_id", creds.HostID)
// httpServer is declared here so the shutdown func can reference it.
// ReadTimeout/WriteTimeout are intentionally omitted — they would kill
// long-lived Connect RPC streams and WebSocket proxy connections.

View File

@ -22,6 +22,12 @@ RETURNING *;
-- name: SetUserAdmin :exec
UPDATE users SET is_admin = $2, updated_at = NOW() WHERE id = $1;
-- name: RevokeUserAdmin :execrows
UPDATE users u SET is_admin = false, updated_at = NOW()
WHERE u.id = $1
AND u.is_admin = true
AND (SELECT COUNT(*) FROM users WHERE is_admin = true AND status != 'deleted') > 1;
-- name: GetAdminUsers :many
SELECT * FROM users WHERE is_admin = TRUE ORDER BY created_at;

3
envd-rs/Cargo.lock generated
View File

@ -514,7 +514,7 @@ dependencies = [
[[package]]
name = "envd"
version = "0.2.0"
version = "0.2.1"
dependencies = [
"async-stream",
"axum",
@ -543,6 +543,7 @@ dependencies = [
"sha2",
"subtle",
"sysinfo",
"tempfile",
"tokio",
"tokio-util",
"tower",

View File

@ -1,6 +1,6 @@
[package]
name = "envd"
version = "0.2.0"
version = "0.2.1"
edition = "2024"
rust-version = "1.88"
@ -72,6 +72,9 @@ buffa = "0.3"
async-stream = "0.3.6"
mime_guess = "2"
[dev-dependencies]
tempfile = "3"
[build-dependencies]
connectrpc-build = "0.3"

View File

@ -83,3 +83,128 @@ pub fn validate_signing(
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
fn test_token(val: &[u8]) -> SecureToken {
let t = SecureToken::new();
t.set(val).unwrap();
t
}
fn far_future() -> i64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs() as i64
+ 3600
}
#[test]
fn generate_starts_with_v1() {
let token = test_token(b"secret");
let sig = generate_signature(&token, "/file", "root", READ_OPERATION, None).unwrap();
assert!(sig.starts_with("v1_"));
}
#[test]
fn generate_deterministic() {
let token = test_token(b"secret");
let s1 = generate_signature(&token, "/file", "root", READ_OPERATION, None).unwrap();
let s2 = generate_signature(&token, "/file", "root", READ_OPERATION, None).unwrap();
assert_eq!(s1, s2);
}
#[test]
fn generate_with_expiration_differs() {
let token = test_token(b"secret");
let without = generate_signature(&token, "/f", "u", READ_OPERATION, None).unwrap();
let with = generate_signature(&token, "/f", "u", READ_OPERATION, Some(9999)).unwrap();
assert_ne!(without, with);
}
#[test]
fn generate_unset_token_errors() {
let token = SecureToken::new();
assert!(generate_signature(&token, "/f", "u", READ_OPERATION, None).is_err());
}
#[test]
fn validate_no_token_set_passes() {
let token = SecureToken::new();
assert!(validate_signing(&token, None, None, None, "root", "/f", READ_OPERATION).is_ok());
}
#[test]
fn validate_correct_header_token() {
let token = test_token(b"secret");
assert!(validate_signing(&token, Some("secret"), None, None, "root", "/f", READ_OPERATION).is_ok());
}
#[test]
fn validate_wrong_header_token() {
let token = test_token(b"secret");
let result = validate_signing(&token, Some("wrong"), None, None, "root", "/f", READ_OPERATION);
assert!(result.is_err());
assert!(result.unwrap_err().contains("does not match"));
}
#[test]
fn validate_valid_signature() {
let token = test_token(b"secret");
let exp = far_future();
let sig = generate_signature(&token, "/file", "root", READ_OPERATION, Some(exp)).unwrap();
assert!(validate_signing(&token, None, Some(&sig), Some(exp), "root", "/file", READ_OPERATION).is_ok());
}
#[test]
fn validate_invalid_signature() {
let token = test_token(b"secret");
let result = validate_signing(&token, None, Some("v1_bad"), Some(far_future()), "root", "/f", READ_OPERATION);
assert!(result.is_err());
assert!(result.unwrap_err().contains("invalid signature"));
}
#[test]
fn validate_expired_signature() {
let token = test_token(b"secret");
let expired: i64 = 1_000_000;
let sig = generate_signature(&token, "/f", "root", READ_OPERATION, Some(expired)).unwrap();
let result = validate_signing(&token, None, Some(&sig), Some(expired), "root", "/f", READ_OPERATION);
assert!(result.is_err());
assert!(result.unwrap_err().contains("expired"));
}
#[test]
fn validate_missing_signature() {
let token = test_token(b"secret");
let result = validate_signing(&token, None, None, None, "root", "/f", READ_OPERATION);
assert!(result.is_err());
assert!(result.unwrap_err().contains("missing signature"));
}
#[test]
fn validate_empty_header_token_falls_through_to_signature() {
let token = test_token(b"secret");
let result = validate_signing(&token, Some(""), None, None, "root", "/f", READ_OPERATION);
assert!(result.is_err());
assert!(result.unwrap_err().contains("missing signature"));
}
#[test]
fn validate_valid_signature_no_expiration() {
let token = test_token(b"secret");
let sig = generate_signature(&token, "/file", "root", READ_OPERATION, None).unwrap();
assert!(validate_signing(&token, None, Some(&sig), None, "root", "/file", READ_OPERATION).is_ok());
}
#[test]
fn different_operations_produce_different_signatures() {
let token = test_token(b"secret");
let r = generate_signature(&token, "/f", "root", READ_OPERATION, None).unwrap();
let w = generate_signature(&token, "/f", "root", WRITE_OPERATION, None).unwrap();
assert_ne!(r, w);
}
}

View File

@ -125,3 +125,132 @@ impl SecureToken {
Ok(token)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn new_is_unset() {
let t = SecureToken::new();
assert!(!t.is_set());
assert!(!t.equals("anything"));
}
#[test]
fn set_and_equals() {
let t = SecureToken::new();
t.set(b"secret").unwrap();
assert!(t.is_set());
assert!(t.equals("secret"));
assert!(!t.equals("wrong"));
}
#[test]
fn set_empty_errors() {
let t = SecureToken::new();
assert!(t.set(b"").is_err());
assert!(!t.is_set());
}
#[test]
fn set_overwrites_previous() {
let t = SecureToken::new();
t.set(b"first").unwrap();
t.set(b"second").unwrap();
assert!(!t.equals("first"));
assert!(t.equals("second"));
}
#[test]
fn destroy_clears() {
let t = SecureToken::new();
t.set(b"secret").unwrap();
t.destroy();
assert!(!t.is_set());
assert!(!t.equals("secret"));
}
#[test]
fn bytes_returns_copy() {
let t = SecureToken::new();
assert!(t.bytes().is_none());
t.set(b"hello").unwrap();
assert_eq!(t.bytes().unwrap(), b"hello");
}
#[test]
fn take_from_transfers_and_clears_source() {
let src = SecureToken::new();
src.set(b"token").unwrap();
let dst = SecureToken::new();
dst.take_from(&src);
assert!(!src.is_set());
assert!(dst.equals("token"));
}
#[test]
fn take_from_overwrites_existing() {
let src = SecureToken::new();
src.set(b"new").unwrap();
let dst = SecureToken::new();
dst.set(b"old").unwrap();
dst.take_from(&src);
assert!(dst.equals("new"));
assert!(!dst.equals("old"));
}
#[test]
fn equals_secure_matching() {
let a = SecureToken::new();
a.set(b"same").unwrap();
let b = SecureToken::new();
b.set(b"same").unwrap();
assert!(a.equals_secure(&b));
}
#[test]
fn equals_secure_different() {
let a = SecureToken::new();
a.set(b"one").unwrap();
let b = SecureToken::new();
b.set(b"two").unwrap();
assert!(!a.equals_secure(&b));
}
#[test]
fn equals_secure_unset() {
let a = SecureToken::new();
let b = SecureToken::new();
assert!(!a.equals_secure(&b));
}
#[test]
fn from_json_bytes_valid() {
let mut data = b"\"mysecret\"".to_vec();
let t = SecureToken::from_json_bytes(&mut data).unwrap();
assert!(t.equals("mysecret"));
assert!(data.iter().all(|&b| b == 0));
}
#[test]
fn from_json_bytes_rejects_missing_quotes() {
let mut data = b"noquotes".to_vec();
assert!(SecureToken::from_json_bytes(&mut data).is_err());
assert!(data.iter().all(|&b| b == 0));
}
#[test]
fn from_json_bytes_rejects_escape_sequences() {
let mut data = b"\"has\\nescapes\"".to_vec();
assert!(SecureToken::from_json_bytes(&mut data).is_err());
assert!(data.iter().all(|&b| b == 0));
}
#[test]
fn from_json_bytes_rejects_empty_content() {
let mut data = b"\"\"".to_vec();
assert!(SecureToken::from_json_bytes(&mut data).is_err());
assert!(data.iter().all(|&b| b == 0));
}
}

View File

@ -76,4 +76,125 @@ impl ConnTracker {
pub fn keepalives_enabled(&self) -> bool {
self.inner.lock().unwrap().keepalives_enabled
}
#[cfg(test)]
fn active_count(&self) -> usize {
self.inner.lock().unwrap().active.len()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn register_assigns_sequential_ids() {
let ct = ConnTracker::new();
assert_eq!(ct.register_connection(), 0);
assert_eq!(ct.register_connection(), 1);
assert_eq!(ct.register_connection(), 2);
}
#[test]
fn remove_clears_active() {
let ct = ConnTracker::new();
let id = ct.register_connection();
assert_eq!(ct.active_count(), 1);
ct.remove_connection(id);
assert_eq!(ct.active_count(), 0);
}
#[test]
fn remove_nonexistent_is_noop() {
let ct = ConnTracker::new();
ct.remove_connection(999);
assert_eq!(ct.active_count(), 0);
}
#[test]
fn prepare_disables_keepalives() {
let ct = ConnTracker::new();
assert!(ct.keepalives_enabled());
ct.register_connection();
ct.prepare_for_snapshot();
assert!(!ct.keepalives_enabled());
}
#[test]
fn restore_removes_zombies_and_reenables_keepalives() {
let ct = ConnTracker::new();
let id0 = ct.register_connection();
let id1 = ct.register_connection();
ct.prepare_for_snapshot();
ct.restore_after_snapshot();
assert!(ct.keepalives_enabled());
// Both pre-snapshot connections removed as zombies
assert_eq!(ct.active_count(), 0);
// IDs don't matter anymore, but remove shouldn't panic
ct.remove_connection(id0);
ct.remove_connection(id1);
}
#[test]
fn restore_without_prepare_is_noop() {
let ct = ConnTracker::new();
let _id = ct.register_connection();
ct.restore_after_snapshot();
assert!(ct.keepalives_enabled());
assert_eq!(ct.active_count(), 1);
}
#[test]
fn connection_closed_before_restore_not_zombie() {
let ct = ConnTracker::new();
let id0 = ct.register_connection();
let _id1 = ct.register_connection();
ct.prepare_for_snapshot();
// Close id0 during snapshot window
ct.remove_connection(id0);
assert_eq!(ct.active_count(), 1);
ct.restore_after_snapshot();
// id1 was zombie (still active at restore), id0 already gone
assert_eq!(ct.active_count(), 0);
}
#[test]
fn post_snapshot_connection_survives_restore() {
let ct = ConnTracker::new();
ct.register_connection();
ct.prepare_for_snapshot();
// New connection after snapshot
let _post = ct.register_connection();
ct.restore_after_snapshot();
// Pre-snapshot connection removed, post-snapshot survives
assert_eq!(ct.active_count(), 1);
}
#[test]
fn full_lifecycle() {
let ct = ConnTracker::new();
let _a = ct.register_connection();
let b = ct.register_connection();
let _c = ct.register_connection();
assert_eq!(ct.active_count(), 3);
assert!(ct.keepalives_enabled());
ct.prepare_for_snapshot();
assert!(!ct.keepalives_enabled());
let d = ct.register_connection();
ct.remove_connection(b);
ct.restore_after_snapshot();
assert!(ct.keepalives_enabled());
// a and c were zombies, b removed before restore, d is post-snapshot
assert_eq!(ct.active_count(), 1);
ct.remove_connection(d);
assert_eq!(ct.active_count(), 0);
// Can reuse tracker after restore
let e = ct.register_connection();
assert_eq!(ct.active_count(), 1);
assert!(e > d);
}
}

View File

@ -15,8 +15,29 @@ mod tests {
use super::*;
#[test]
fn test_hmac_sha256() {
let result = compute(b"key", b"message");
assert_eq!(result.len(), 64); // SHA-256 hex = 64 chars
fn rfc4231_tc1() {
let key = &[0x0b; 20];
let data = b"Hi There";
assert_eq!(
compute(key, data),
"b0344c61d8db38535ca8afceaf0bf12b881dc200c9833da726e9376c2e32cff7"
);
}
#[test]
fn rfc4231_tc2() {
let key = b"Jefe";
let data = b"what do ya want for nothing?";
assert_eq!(
compute(key, data),
"5bdcc146bf60754e6a042426089575c75a003f089d2739839dec58b964ec3843"
);
}
#[test]
fn output_is_64_hex_chars() {
let result = compute(b"key", b"data");
assert_eq!(result.len(), 64);
assert!(result.chars().all(|c| c.is_ascii_hexdigit()));
}
}

View File

@ -17,17 +17,38 @@ pub fn hash_without_prefix(data: &[u8]) -> String {
mod tests {
use super::*;
const VECTORS: &[(&[u8], &str)] = &[
(b"", "47DEQpj8HBSa+/TImW+5JCeuQeRkm5NMpJWZG3hSuFU"),
(b"abc", "ungWv48Bz+pBQUDeXa4iI7ADYaOWF3qctBD/YfIAFa0"),
(b"abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq", "JI1qYdIGOLjlwCaTDD5gOaM85Flk/yFn9uzt1BnbBsE"),
];
#[test]
fn test_hash_format() {
let result = hash(b"test");
assert!(result.starts_with("$sha256$"));
assert!(!result.contains('='));
fn known_answer_with_prefix() {
for (input, expected_b64) in VECTORS {
let result = hash(input);
assert_eq!(result, format!("$sha256${expected_b64}"), "input: {:?}", String::from_utf8_lossy(input));
}
}
#[test]
fn test_hash_without_prefix() {
let result = hash_without_prefix(b"test");
assert!(!result.starts_with("$sha256$"));
assert!(!result.contains('='));
fn known_answer_without_prefix() {
for (input, expected_b64) in VECTORS {
let result = hash_without_prefix(input);
assert_eq!(result, *expected_b64, "input: {:?}", String::from_utf8_lossy(input));
}
}
#[test]
fn no_base64_padding() {
for (input, _) in VECTORS {
assert!(!hash(input).contains('='));
assert!(!hash_without_prefix(input).contains('='));
}
}
#[test]
fn deterministic() {
assert_eq!(hash(b"test"), hash(b"test"));
}
}

View File

@ -14,11 +14,30 @@ pub fn hash_access_token_bytes(token: &[u8]) -> String {
mod tests {
use super::*;
const VECTORS: &[(&str, &str)] = &[
("", "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e"),
("abc", "ddaf35a193617abacc417349ae20413112e6fa4e89a97ea20a9eeee64b55d39a2192992a274fc1a836ba3c23a3feebbd454d4423643ce80e2a9ac94fa54ca49f"),
("abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq", "204a8fc6dda82f0a0ced7beb8e08a41657c16ef468b228a8279be331a703c33596fd15c13b1b07f9aa1d3bea57789ca031ad85c7a71dd70354ec631238ca3445"),
];
#[test]
fn test_hash_access_token() {
let h1 = hash_access_token("test");
let h2 = hash_access_token_bytes(b"test");
assert_eq!(h1, h2);
assert_eq!(h1.len(), 128); // SHA-512 hex = 128 chars
fn known_answer() {
for (input, expected) in VECTORS {
assert_eq!(hash_access_token(input), *expected, "input: {input:?}");
}
}
#[test]
fn str_and_bytes_agree() {
for (input, _) in VECTORS {
assert_eq!(hash_access_token(input), hash_access_token_bytes(input.as_bytes()));
}
}
#[test]
fn output_is_lowercase_hex_128_chars() {
let h = hash_access_token("anything");
assert_eq!(h.len(), 128);
assert!(h.chars().all(|c| c.is_ascii_hexdigit() && !c.is_ascii_uppercase()));
}
}

View File

@ -1,21 +1,36 @@
use dashmap::DashMap;
use std::sync::Arc;
use std::sync::{Arc, RwLock};
#[derive(Clone)]
pub struct Defaults {
pub env_vars: Arc<DashMap<String, String>>,
pub user: String,
pub workdir: Option<String>,
user: RwLock<String>,
workdir: RwLock<Option<String>>,
}
impl Defaults {
pub fn new(user: &str) -> Self {
Self {
env_vars: Arc::new(DashMap::new()),
user: user.to_string(),
workdir: None,
user: RwLock::new(user.to_string()),
workdir: RwLock::new(None),
}
}
pub fn user(&self) -> String {
self.user.read().unwrap().clone()
}
pub fn set_user(&self, user: String) {
*self.user.write().unwrap() = user;
}
pub fn workdir(&self) -> Option<String> {
self.workdir.read().unwrap().clone()
}
pub fn set_workdir(&self, workdir: Option<String>) {
*self.workdir.write().unwrap() = workdir;
}
}
pub fn resolve_default_workdir(workdir: &str, default_workdir: Option<&str>) -> String {
@ -40,3 +55,64 @@ pub fn resolve_default_username<'a>(
}
Err("username not provided")
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn workdir_explicit_overrides_default() {
assert_eq!(resolve_default_workdir("/explicit", Some("/default")), "/explicit");
}
#[test]
fn workdir_empty_uses_default() {
assert_eq!(resolve_default_workdir("", Some("/default")), "/default");
}
#[test]
fn workdir_empty_no_default_returns_empty() {
assert_eq!(resolve_default_workdir("", None), "");
}
#[test]
fn workdir_explicit_ignores_none_default() {
assert_eq!(resolve_default_workdir("/explicit", None), "/explicit");
}
#[test]
fn username_explicit_returns_explicit() {
assert_eq!(resolve_default_username(Some("root"), "wrenn").unwrap(), "root");
}
#[test]
fn username_none_uses_default() {
assert_eq!(resolve_default_username(None, "wrenn").unwrap(), "wrenn");
}
#[test]
fn username_none_empty_default_errors() {
assert!(resolve_default_username(None, "").is_err());
}
#[test]
fn username_some_overrides_empty_default() {
assert_eq!(resolve_default_username(Some("root"), "").unwrap(), "root");
}
#[test]
fn defaults_user_set_and_get() {
let d = Defaults::new("initial");
assert_eq!(d.user(), "initial");
d.set_user("changed".into());
assert_eq!(d.user(), "changed");
}
#[test]
fn defaults_workdir_initially_none() {
let d = Defaults::new("user");
assert!(d.workdir().is_none());
d.set_workdir(Some("/home".into()));
assert_eq!(d.workdir().unwrap(), "/home");
}
}

View File

@ -145,3 +145,192 @@ pub fn parse_content_encoding<B>(r: &Request<B>) -> Result<&'static str, String>
Err(format!("unsupported Content-Encoding: {header}, supported: {SUPPORTED_ENCODINGS:?}"))
}
#[cfg(test)]
mod tests {
use super::*;
use axum::http::Request;
fn req_with_accept(v: &str) -> Request<()> {
Request::builder()
.header("accept-encoding", v)
.body(())
.unwrap()
}
fn req_with_content(v: &str) -> Request<()> {
Request::builder()
.header("content-encoding", v)
.body(())
.unwrap()
}
fn req_no_headers() -> Request<()> {
Request::builder().body(()).unwrap()
}
// parse_encoding_with_quality
#[test]
fn encoding_quality_default_1() {
let eq = parse_encoding_with_quality("gzip");
assert_eq!(eq.encoding, "gzip");
assert_eq!(eq.quality, 1.0);
}
#[test]
fn encoding_quality_explicit() {
let eq = parse_encoding_with_quality("gzip;q=0.8");
assert_eq!(eq.encoding, "gzip");
assert_eq!(eq.quality, 0.8);
}
#[test]
fn encoding_quality_case_insensitive() {
let eq = parse_encoding_with_quality("GZIP;Q=0.5");
assert_eq!(eq.encoding, "gzip");
assert_eq!(eq.quality, 0.5);
}
#[test]
fn encoding_quality_zero() {
let eq = parse_encoding_with_quality("gzip;q=0");
assert_eq!(eq.quality, 0.0);
}
#[test]
fn encoding_quality_whitespace_trimmed() {
let eq = parse_encoding_with_quality(" gzip ; q=0.9 ");
assert_eq!(eq.encoding, "gzip");
assert_eq!(eq.quality, 0.9);
}
// parse_accept_encoding_header
#[test]
fn accept_header_empty() {
let (encs, rejected) = parse_accept_encoding_header("");
assert!(encs.is_empty());
assert!(!rejected);
}
#[test]
fn accept_header_identity_q0_rejects() {
let (_, rejected) = parse_accept_encoding_header("identity;q=0");
assert!(rejected);
}
#[test]
fn accept_header_wildcard_q0_rejects_identity() {
let (_, rejected) = parse_accept_encoding_header("*;q=0");
assert!(rejected);
}
#[test]
fn accept_header_wildcard_q0_but_identity_explicit_accepted() {
let (_, rejected) = parse_accept_encoding_header("*;q=0, identity");
assert!(!rejected);
}
// parse_accept_encoding (full)
#[test]
fn accept_encoding_no_header_returns_identity() {
assert_eq!(parse_accept_encoding(&req_no_headers()).unwrap(), "identity");
}
#[test]
fn accept_encoding_gzip() {
assert_eq!(parse_accept_encoding(&req_with_accept("gzip")).unwrap(), "gzip");
}
#[test]
fn accept_encoding_identity_explicit() {
assert_eq!(parse_accept_encoding(&req_with_accept("identity")).unwrap(), "identity");
}
#[test]
fn accept_encoding_gzip_higher_quality() {
assert_eq!(
parse_accept_encoding(&req_with_accept("identity;q=0.1, gzip;q=0.9")).unwrap(),
"gzip"
);
}
#[test]
fn accept_encoding_wildcard_returns_identity() {
assert_eq!(parse_accept_encoding(&req_with_accept("*")).unwrap(), "identity");
}
#[test]
fn accept_encoding_wildcard_identity_rejected_returns_gzip() {
assert_eq!(
parse_accept_encoding(&req_with_accept("identity;q=0, *")).unwrap(),
"gzip"
);
}
#[test]
fn accept_encoding_all_rejected_errors() {
assert!(parse_accept_encoding(&req_with_accept("identity;q=0, *;q=0")).is_err());
}
#[test]
fn accept_encoding_unsupported_only_falls_to_identity() {
assert_eq!(parse_accept_encoding(&req_with_accept("br")).unwrap(), "identity");
}
// is_identity_acceptable
#[test]
fn identity_acceptable_no_header() {
assert!(is_identity_acceptable(&req_no_headers()));
}
#[test]
fn identity_acceptable_gzip_only() {
assert!(is_identity_acceptable(&req_with_accept("gzip")));
}
#[test]
fn identity_not_acceptable_identity_q0() {
assert!(!is_identity_acceptable(&req_with_accept("identity;q=0")));
}
#[test]
fn identity_not_acceptable_wildcard_q0() {
assert!(!is_identity_acceptable(&req_with_accept("*;q=0")));
}
#[test]
fn identity_acceptable_wildcard_q0_but_identity_explicit() {
assert!(is_identity_acceptable(&req_with_accept("*;q=0, identity")));
}
// parse_content_encoding
#[test]
fn content_encoding_empty_returns_identity() {
assert_eq!(parse_content_encoding(&req_no_headers()).unwrap(), "identity");
}
#[test]
fn content_encoding_gzip() {
assert_eq!(parse_content_encoding(&req_with_content("gzip")).unwrap(), "gzip");
}
#[test]
fn content_encoding_identity_explicit() {
assert_eq!(parse_content_encoding(&req_with_content("identity")).unwrap(), "identity");
}
#[test]
fn content_encoding_unsupported_errors() {
assert!(parse_content_encoding(&req_with_content("br")).is_err());
}
#[test]
fn content_encoding_case_insensitive() {
assert_eq!(parse_content_encoding(&req_with_content("GZIP")).unwrap(), "gzip");
}
}

View File

@ -71,9 +71,10 @@ pub async fn get_files(
let path_str = params.path.as_deref().unwrap_or("");
let header_token = extract_header_token(&req);
let default_user = state.defaults.user();
let username = match execcontext::resolve_default_username(
params.username.as_deref(),
&state.defaults.user,
&default_user,
) {
Ok(u) => u.to_string(),
Err(e) => return json_error(StatusCode::BAD_REQUEST, e),
@ -96,7 +97,8 @@ pub async fn get_files(
};
let home_dir = user.dir.to_string_lossy().to_string();
let resolved = match expand_and_resolve(path_str, &home_dir, state.defaults.workdir.as_deref())
let default_workdir = state.defaults.workdir();
let resolved = match expand_and_resolve(path_str, &home_dir, default_workdir.as_deref())
{
Ok(p) => p,
Err(e) => return json_error(StatusCode::BAD_REQUEST, &e),
@ -222,9 +224,10 @@ pub async fn post_files(
let path_str = params.path.as_deref().unwrap_or("");
let header_token = extract_header_token(&req);
let default_user = state.defaults.user();
let username = match execcontext::resolve_default_username(
params.username.as_deref(),
&state.defaults.user,
&default_user,
) {
Ok(u) => u.to_string(),
Err(e) => return json_error(StatusCode::BAD_REQUEST, e),
@ -266,6 +269,7 @@ pub async fn post_files(
};
let mut uploaded: Vec<EntryInfo> = Vec::new();
let default_workdir = state.defaults.workdir();
while let Ok(Some(field)) = multipart.next_field().await {
let field_name = field.name().unwrap_or("").to_string();
@ -274,7 +278,7 @@ pub async fn post_files(
}
let file_path = if !path_str.is_empty() {
match expand_and_resolve(path_str, &home_dir, state.defaults.workdir.as_deref()) {
match expand_and_resolve(path_str, &home_dir, default_workdir.as_deref()) {
Ok(p) => p,
Err(e) => return json_error(StatusCode::BAD_REQUEST, &e),
}
@ -283,7 +287,7 @@ pub async fn post_files(
.file_name()
.unwrap_or("upload")
.to_string();
match expand_and_resolve(&fname, &home_dir, state.defaults.workdir.as_deref()) {
match expand_and_resolve(&fname, &home_dir, default_workdir.as_deref()) {
Ok(p) => p,
Err(e) => return json_error(StatusCode::BAD_REQUEST, &e),
}

View File

@ -29,6 +29,8 @@ pub async fn get_health(State(state): State<Arc<AppState>>) -> impl IntoResponse
fn post_restore_recovery(state: &AppState) {
tracing::info!("restore: post-restore recovery (no GC needed in Rust)");
state.snapshot_in_progress.store(false, std::sync::atomic::Ordering::Release);
state.conn_tracker.restore_after_snapshot();
tracing::info!("restore: zombie connections closed");

View File

@ -78,11 +78,15 @@ pub async fn post_init(
if let Some(ref user) = init_req.default_user {
if !user.is_empty() {
tracing::debug!(user = %user, "setting default user");
let mut defaults = state.defaults.clone();
defaults.user = user.clone();
// Note: In Rust we'd need interior mutability for this.
// For now, env_vars (DashMap) handles concurrent access.
// User/workdir mutation deferred to full state refactor.
state.defaults.set_user(user.clone());
}
}
// Set default workdir
if let Some(ref workdir) = init_req.default_workdir {
if !workdir.is_empty() {
tracing::debug!(workdir = %workdir, "setting default workdir");
state.defaults.set_workdir(Some(workdir.clone()));
}
}
@ -147,6 +151,9 @@ async fn trigger_restore_and_respond(state: &AppState) -> axum::response::Respon
fn post_restore_recovery(state: &AppState) {
tracing::info!("restore: post-restore recovery (no GC needed in Rust)");
state.snapshot_in_progress.store(false, std::sync::atomic::Ordering::Release);
state.conn_tracker.restore_after_snapshot();
if let Some(ref ps) = state.port_subsystem {

View File

@ -46,7 +46,8 @@ fn collect_metrics(state: &AppState) -> Result<Metrics, String> {
let mut sys = sysinfo::System::new();
sys.refresh_memory();
let mem_total = sys.total_memory();
let mem_used = sys.used_memory();
let mem_available = sys.available_memory();
let mem_used = mem_total.saturating_sub(mem_available);
let mem_total_mib = mem_total / 1024 / 1024;
let mem_used_mib = mem_used / 1024 / 1024;

View File

@ -10,10 +10,24 @@ use crate::state::AppState;
/// POST /snapshot/prepare — quiesce subsystems before Firecracker snapshot.
///
/// In Rust there is no GC dance. We just:
/// 1. Stop port subsystem
/// 2. Close idle connections via conntracker
/// 3. Set needs_restore flag
/// 1. Drop page cache to shrink snapshot size
/// 2. Stop port subsystem
/// 3. Close idle connections via conntracker
/// 4. Set needs_restore flag
pub async fn post_snapshot_prepare(State(state): State<Arc<AppState>>) -> impl IntoResponse {
// Drop page cache BEFORE blocking the reclaimer — avoids snapshotting
// gigabytes of stale cache that inflates the memory dump on disk.
// "1" = pagecache only (keep dentries/inodes for faster resume).
if let Err(e) = std::fs::write("/proc/sys/vm/drop_caches", "1") {
tracing::warn!(error = %e, "snapshot/prepare: drop_caches failed");
} else {
tracing::info!("snapshot/prepare: page cache dropped");
}
// Block memory reclaimer — prevents drop_caches from running mid-freeze
// which would corrupt kernel page table state.
state.snapshot_in_progress.store(true, Ordering::Release);
if let Some(ref ps) = state.port_subsystem {
ps.stop();
tracing::info!("snapshot/prepare: port subsystem stopped");
@ -22,6 +36,9 @@ pub async fn post_snapshot_prepare(State(state): State<Arc<AppState>>) -> impl I
state.conn_tracker.prepare_for_snapshot();
tracing::info!("snapshot/prepare: connections prepared");
// Sync filesystem buffers so dirty pages are flushed before freeze.
unsafe { libc::sync(); }
state.needs_restore.store(true, Ordering::Release);
tracing::info!("snapshot/prepare: ready for freeze");

View File

@ -147,6 +147,14 @@ async fn main() {
Some(Arc::clone(&port_subsystem)),
);
// Memory reclaimer — drop page cache when available memory is low.
// Firecracker balloon device can only reclaim pages the guest kernel freed.
// Pauses during snapshot/prepare to avoid corrupting kernel page table state.
if !cli.is_not_fc {
let state_for_reclaimer = Arc::clone(&state);
std::thread::spawn(move || memory_reclaimer(state_for_reclaimer));
}
// RPC services (Connect protocol — serves Connect + gRPC + gRPC-Web on same port)
let connect_router = rpc::rpc_router(Arc::clone(&state));
@ -188,7 +196,8 @@ fn spawn_initial_command(cmd: &str, state: &AppState) {
use crate::rpc::process_handler;
use std::collections::HashMap;
let user = match lookup_user(&state.defaults.user) {
let default_user = state.defaults.user();
let user = match lookup_user(&default_user) {
Ok(u) => u,
Err(e) => {
tracing::error!(error = %e, "cmd: failed to lookup user");
@ -197,9 +206,8 @@ fn spawn_initial_command(cmd: &str, state: &AppState) {
};
let home = user.dir.to_string_lossy().to_string();
let cwd = state
.defaults
.workdir
let default_workdir = state.defaults.workdir();
let cwd = default_workdir
.as_deref()
.unwrap_or(&home);
@ -214,11 +222,52 @@ fn spawn_initial_command(cmd: &str, state: &AppState) {
&user,
&state.defaults.env_vars,
) {
Ok(handle) => {
tracing::info!(pid = handle.pid, cmd, "initial command spawned");
Ok(spawned) => {
tracing::info!(pid = spawned.handle.pid, cmd, "initial command spawned");
}
Err(e) => {
tracing::error!(error = %e, cmd, "failed to spawn initial command");
}
}
}
fn memory_reclaimer(state: Arc<AppState>) {
use std::sync::atomic::Ordering;
const CHECK_INTERVAL: std::time::Duration = std::time::Duration::from_secs(10);
const DROP_THRESHOLD_PCT: u64 = 80;
loop {
std::thread::sleep(CHECK_INTERVAL);
if state.snapshot_in_progress.load(Ordering::Acquire) {
continue;
}
let mut sys = sysinfo::System::new();
sys.refresh_memory();
let total = sys.total_memory();
let available = sys.available_memory();
if total == 0 {
continue;
}
let used_pct = ((total - available) * 100) / total;
if used_pct >= DROP_THRESHOLD_PCT {
if state.snapshot_in_progress.load(Ordering::Acquire) {
continue;
}
if let Err(e) = std::fs::write("/proc/sys/vm/drop_caches", "3") {
tracing::debug!(error = %e, "drop_caches failed");
} else {
let mut sys2 = sysinfo::System::new();
sys2.refresh_memory();
let freed_mb =
sys2.available_memory().saturating_sub(available) / (1024 * 1024);
tracing::info!(used_pct, freed_mb, "page cache dropped");
}
}
}
}

View File

@ -70,3 +70,115 @@ pub fn ensure_dirs(path: &str, uid: Uid, gid: Gid) -> Result<(), String> {
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
// expand_tilde
#[test]
fn tilde_empty_passthrough() {
assert_eq!(expand_tilde("", "/home/u").unwrap(), "");
}
#[test]
fn tilde_no_tilde_passthrough() {
assert_eq!(expand_tilde("/absolute", "/home/u").unwrap(), "/absolute");
}
#[test]
fn tilde_bare() {
assert_eq!(expand_tilde("~", "/home/user").unwrap(), "/home/user");
}
#[test]
fn tilde_slash_path() {
assert_eq!(expand_tilde("~/docs", "/home/user").unwrap(), "/home/user/docs");
}
#[test]
fn tilde_nested() {
assert_eq!(expand_tilde("~/a/b/c", "/h").unwrap(), "/h/a/b/c");
}
#[test]
fn tilde_other_user_errors() {
assert!(expand_tilde("~bob/foo", "/home/user").is_err());
}
#[test]
fn tilde_relative_no_tilde() {
assert_eq!(expand_tilde("relative/path", "/home/u").unwrap(), "relative/path");
}
// expand_and_resolve
#[test]
fn resolve_absolute_passthrough() {
assert_eq!(expand_and_resolve("/abs/path", "/home", None).unwrap(), "/abs/path");
}
#[test]
fn resolve_empty_uses_default() {
assert_eq!(expand_and_resolve("", "/home", Some("/default")).unwrap(), "/default");
}
#[test]
fn resolve_empty_no_default_falls_back_to_home() {
// Empty path with no default → joins "" with home_dir → returns home_dir
let result = expand_and_resolve("", "/home", None).unwrap();
assert_eq!(result, "/home");
}
#[test]
fn resolve_tilde_expands() {
assert_eq!(expand_and_resolve("~/dir", "/home/u", None).unwrap(), "/home/u/dir");
}
#[test]
fn resolve_relative_joins_home() {
let result = expand_and_resolve("subdir", "/tmp", None).unwrap();
// Relative path joined with home and canonicalized (or raw join on missing)
assert!(result.starts_with("/tmp"));
assert!(result.contains("subdir"));
}
#[test]
fn resolve_tilde_other_user_errors() {
assert!(expand_and_resolve("~bob", "/home/u", None).is_err());
}
// ensure_dirs
#[test]
fn ensure_dirs_creates_nested() {
let tmp = tempfile::TempDir::new().unwrap();
let path = tmp.path().join("a/b/c");
let uid = nix::unistd::getuid();
let gid = nix::unistd::getgid();
ensure_dirs(path.to_str().unwrap(), uid, gid).unwrap();
assert!(path.is_dir());
}
#[test]
fn ensure_dirs_existing_is_ok() {
let tmp = tempfile::TempDir::new().unwrap();
let uid = nix::unistd::getuid();
let gid = nix::unistd::getgid();
ensure_dirs(tmp.path().to_str().unwrap(), uid, gid).unwrap();
}
#[test]
fn ensure_dirs_file_in_path_errors() {
let tmp = tempfile::TempDir::new().unwrap();
let file_path = tmp.path().join("afile");
std::fs::write(&file_path, "").unwrap();
let nested = file_path.join("subdir");
let uid = nix::unistd::getuid();
let gid = nix::unistd::getgid();
let result = ensure_dirs(nested.to_str().unwrap(), uid, gid);
assert!(result.is_err());
assert!(result.unwrap_err().contains("path is a file"));
}
}

View File

@ -110,3 +110,151 @@ fn parse_hex_addr(s: &str, family: u32) -> Option<(String, u32)> {
Some((ip_str, port))
}
#[cfg(test)]
mod tests {
use super::*;
use std::io::Write;
// tcp_state_name
#[test]
fn state_all_known_codes() {
assert_eq!(tcp_state_name("01"), "ESTABLISHED");
assert_eq!(tcp_state_name("02"), "SYN_SENT");
assert_eq!(tcp_state_name("03"), "SYN_RECV");
assert_eq!(tcp_state_name("04"), "FIN_WAIT1");
assert_eq!(tcp_state_name("05"), "FIN_WAIT2");
assert_eq!(tcp_state_name("06"), "TIME_WAIT");
assert_eq!(tcp_state_name("07"), "CLOSE");
assert_eq!(tcp_state_name("08"), "CLOSE_WAIT");
assert_eq!(tcp_state_name("09"), "LAST_ACK");
assert_eq!(tcp_state_name("0A"), "LISTEN");
assert_eq!(tcp_state_name("0B"), "CLOSING");
}
#[test]
fn state_unknown_code() {
assert_eq!(tcp_state_name("FF"), "UNKNOWN");
assert_eq!(tcp_state_name("00"), "UNKNOWN");
}
// parse_hex_addr
#[test]
fn ipv4_localhost() {
let (ip, port) = parse_hex_addr("0100007F:0050", libc::AF_INET as u32).unwrap();
assert_eq!(ip, "127.0.0.1");
assert_eq!(port, 80);
}
#[test]
fn ipv4_any() {
let (ip, port) = parse_hex_addr("00000000:0035", libc::AF_INET as u32).unwrap();
assert_eq!(ip, "0.0.0.0");
assert_eq!(port, 53);
}
#[test]
fn ipv4_real_address() {
// 192.168.1.1 in little-endian = 0101A8C0
let (ip, port) = parse_hex_addr("0101A8C0:01BB", libc::AF_INET as u32).unwrap();
assert_eq!(ip, "192.168.1.1");
assert_eq!(port, 443);
}
#[test]
fn ipv4_wrong_byte_count_returns_none() {
assert!(parse_hex_addr("0100:0050", libc::AF_INET as u32).is_none());
}
#[test]
fn invalid_hex_returns_none() {
assert!(parse_hex_addr("ZZZZZZZZ:0050", libc::AF_INET as u32).is_none());
}
#[test]
fn no_colon_returns_none() {
assert!(parse_hex_addr("0100007F0050", libc::AF_INET as u32).is_none());
}
#[test]
fn ipv6_loopback() {
// ::1 in /proc/net/tcp6 format: 00000000000000000000000001000000
let (ip, port) = parse_hex_addr(
"00000000000000000000000001000000:0035",
libc::AF_INET6 as u32,
)
.unwrap();
assert_eq!(ip, "::1");
assert_eq!(port, 53);
}
#[test]
fn ipv6_wrong_byte_count_returns_none() {
assert!(parse_hex_addr("0100007F:0050", libc::AF_INET6 as u32).is_none());
}
// parse_proc_net_tcp
fn write_tcp_file(content: &str) -> tempfile::NamedTempFile {
let mut f = tempfile::NamedTempFile::new().unwrap();
f.write_all(content.as_bytes()).unwrap();
f.flush().unwrap();
f
}
#[test]
fn parse_empty_file() {
let f = write_tcp_file(
" sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode\n",
);
let conns = parse_proc_net_tcp(f.path().to_str().unwrap(), libc::AF_INET as u32).unwrap();
assert!(conns.is_empty());
}
#[test]
fn parse_single_entry() {
let content = "\
sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode
0: 0100007F:0050 00000000:0000 0A 00000000:00000000 00:00000000 00000000 0 0 12345 1 00000000\n";
let f = write_tcp_file(content);
let conns = parse_proc_net_tcp(f.path().to_str().unwrap(), libc::AF_INET as u32).unwrap();
assert_eq!(conns.len(), 1);
assert_eq!(conns[0].local_ip, "127.0.0.1");
assert_eq!(conns[0].local_port, 80);
assert_eq!(conns[0].status, "LISTEN");
assert_eq!(conns[0].inode, 12345);
assert_eq!(conns[0].family, libc::AF_INET as u32);
}
#[test]
fn parse_skips_malformed_rows() {
let content = "\
sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode
0: 0100007F:0050 00000000:0000 0A 00000000:00000000 00:00000000 00000000 0 0 12345 1 00000000
bad line
1: short\n";
let f = write_tcp_file(content);
let conns = parse_proc_net_tcp(f.path().to_str().unwrap(), libc::AF_INET as u32).unwrap();
assert_eq!(conns.len(), 1);
}
#[test]
fn parse_multiple_entries() {
let content = "\
sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode
0: 0100007F:0050 00000000:0000 0A 00000000:00000000 00:00000000 00000000 0 0 100 1 00000000
1: 00000000:01BB 00000000:0000 0A 00000000:00000000 00:00000000 00000000 0 0 200 1 00000000\n";
let f = write_tcp_file(content);
let conns = parse_proc_net_tcp(f.path().to_str().unwrap(), libc::AF_INET as u32).unwrap();
assert_eq!(conns.len(), 2);
assert_eq!(conns[0].local_port, 80);
assert_eq!(conns[1].local_port, 443);
}
#[test]
fn parse_nonexistent_file_errors() {
assert!(parse_proc_net_tcp("/nonexistent/path", libc::AF_INET as u32).is_err());
}
}

View File

@ -140,3 +140,92 @@ fn format_permissions(mode: u32) -> String {
}
s
}
#[cfg(test)]
mod tests {
use super::*;
// format_permissions
#[test]
fn regular_file_755() {
assert_eq!(format_permissions(libc::S_IFREG | 0o755), "-rwxr-xr-x");
}
#[test]
fn directory_755() {
assert_eq!(format_permissions(libc::S_IFDIR | 0o755), "drwxr-xr-x");
}
#[test]
fn symlink_777() {
assert_eq!(format_permissions(libc::S_IFLNK | 0o777), "Lrwxrwxrwx");
}
#[test]
fn regular_file_000() {
assert_eq!(format_permissions(libc::S_IFREG | 0o000), "----------");
}
#[test]
fn regular_file_644() {
assert_eq!(format_permissions(libc::S_IFREG | 0o644), "-rw-r--r--");
}
#[test]
fn block_device() {
assert_eq!(format_permissions(libc::S_IFBLK | 0o660), "brw-rw----");
}
#[test]
fn char_device() {
assert_eq!(format_permissions(libc::S_IFCHR | 0o666), "crw-rw-rw-");
}
#[test]
fn fifo() {
assert_eq!(format_permissions(libc::S_IFIFO | 0o644), "prw-r--r--");
}
#[test]
fn socket() {
assert_eq!(format_permissions(libc::S_IFSOCK | 0o755), "Srwxr-xr-x");
}
#[test]
fn unknown_type() {
assert_eq!(format_permissions(0o755), "?rwxr-xr-x");
}
#[test]
fn setuid_in_mode_only_affects_lower_bits() {
// setuid (0o4755) — format_permissions masks with 0o777, so same as 0o755
assert_eq!(
format_permissions(libc::S_IFREG | 0o4755),
format_permissions(libc::S_IFREG | 0o755),
);
}
#[test]
fn output_always_10_chars() {
for mode in [0o000, 0o777, 0o644, 0o755, 0o4755] {
assert_eq!(format_permissions(libc::S_IFREG | mode).len(), 10);
}
}
// meta_to_file_type — needs real filesystem
#[test]
fn meta_regular_file() {
let f = tempfile::NamedTempFile::new().unwrap();
let meta = std::fs::metadata(f.path()).unwrap();
assert_eq!(meta_to_file_type(&meta), FileType::FILE_TYPE_FILE);
}
#[test]
fn meta_directory() {
let d = tempfile::TempDir::new().unwrap();
let meta = std::fs::metadata(d.path()).unwrap();
assert_eq!(meta_to_file_type(&meta), FileType::FILE_TYPE_DIRECTORY);
}
}

View File

@ -31,15 +31,15 @@ impl FilesystemServiceImpl {
}
fn resolve_path(&self, path: &str, ctx: &Context) -> Result<String, ConnectError> {
let username = extract_username(ctx).unwrap_or_else(|| self.state.defaults.user.clone());
let username = extract_username(ctx).unwrap_or_else(|| self.state.defaults.user());
let user = lookup_user(&username).map_err(|e| {
ConnectError::new(ErrorCode::Unauthenticated, format!("invalid user: {e}"))
})?;
let home_dir = user.dir.to_string_lossy().to_string();
let default_workdir = self.state.defaults.workdir.as_deref();
let default_workdir = self.state.defaults.workdir();
expand_and_resolve(path, &home_dir, default_workdir)
expand_and_resolve(path, &home_dir, default_workdir.as_deref())
.map_err(|e| ConnectError::new(ErrorCode::InvalidArgument, e))
}
}
@ -97,7 +97,7 @@ impl Filesystem for FilesystemServiceImpl {
}
}
let username = extract_username(&ctx).unwrap_or_else(|| self.state.defaults.user.clone());
let username = extract_username(&ctx).unwrap_or_else(|| self.state.defaults.user());
let user =
lookup_user(&username).map_err(|e| ConnectError::new(ErrorCode::Internal, e))?;
@ -122,7 +122,7 @@ impl Filesystem for FilesystemServiceImpl {
let source = self.resolve_path(request.source, &ctx)?;
let destination = self.resolve_path(request.destination, &ctx)?;
let username = extract_username(&ctx).unwrap_or_else(|| self.state.defaults.user.clone());
let username = extract_username(&ctx).unwrap_or_else(|| self.state.defaults.user());
let user =
lookup_user(&username).map_err(|e| ConnectError::new(ErrorCode::Internal, e))?;

View File

@ -37,6 +37,7 @@ pub struct ProcessHandle {
data_tx: broadcast::Sender<DataEvent>,
end_tx: broadcast::Sender<EndEvent>,
ended: Mutex<Option<EndEvent>>,
stdin: Mutex<Option<std::process::ChildStdin>>,
pty_master: Mutex<Option<std::fs::File>>,
@ -51,6 +52,10 @@ impl ProcessHandle {
self.end_tx.subscribe()
}
pub fn cached_end(&self) -> Option<EndEvent> {
self.ended.lock().unwrap().clone()
}
pub fn send_signal(&self, sig: Signal) -> Result<(), ConnectError> {
signal::kill(Pid::from_raw(self.pid as i32), sig).map_err(|e| {
ConnectError::new(ErrorCode::Internal, format!("error sending signal: {e}"))
@ -128,6 +133,12 @@ impl ProcessHandle {
}
}
pub struct SpawnedProcess {
pub handle: Arc<ProcessHandle>,
pub data_rx: broadcast::Receiver<DataEvent>,
pub end_rx: broadcast::Receiver<EndEvent>,
}
pub fn spawn_process(
cmd_str: &str,
args: &[String],
@ -138,7 +149,7 @@ pub fn spawn_process(
tag: Option<String>,
user: &nix::unistd::User,
default_env_vars: &dashmap::DashMap<String, String>,
) -> Result<Arc<ProcessHandle>, ConnectError> {
) -> Result<SpawnedProcess, ConnectError> {
let mut env: Vec<(String, String)> = Vec::new();
env.push(("PATH".into(), std::env::var("PATH").unwrap_or_default()));
let home = user.dir.to_string_lossy().to_string();
@ -244,10 +255,14 @@ pub fn spawn_process(
pid,
data_tx: data_tx.clone(),
end_tx: end_tx.clone(),
ended: Mutex::new(None),
stdin: Mutex::new(None),
pty_master: Mutex::new(Some(master_file)),
});
let data_rx = handle.subscribe_data();
let end_rx = handle.subscribe_end();
let data_tx_clone = data_tx.clone();
std::thread::spawn(move || {
let mut master = master_clone;
@ -264,30 +279,29 @@ pub fn spawn_process(
});
let end_tx_clone = end_tx.clone();
let handle_for_waiter = Arc::clone(&handle);
std::thread::spawn(move || {
let mut child = child;
match child.wait() {
Ok(s) => {
let _ = end_tx_clone.send(EndEvent {
exit_code: s.code().unwrap_or(-1),
exited: s.code().is_some(),
status: format!("{s}"),
error: None,
});
}
Err(e) => {
let _ = end_tx_clone.send(EndEvent {
exit_code: -1,
exited: false,
status: "error".into(),
error: Some(e.to_string()),
});
}
}
let end_event = match child.wait() {
Ok(s) => EndEvent {
exit_code: s.code().unwrap_or(-1),
exited: s.code().is_some(),
status: format!("{s}"),
error: None,
},
Err(e) => EndEvent {
exit_code: -1,
exited: false,
status: "error".into(),
error: Some(e.to_string()),
},
};
*handle_for_waiter.ended.lock().unwrap() = Some(end_event.clone());
let _ = end_tx_clone.send(end_event);
});
tracing::info!(pid, cmd = cmd_str, "process started (pty)");
Ok(handle)
Ok(SpawnedProcess { handle, data_rx, end_rx })
} else {
let mut command = std::process::Command::new("/bin/sh");
command
@ -327,10 +341,14 @@ pub fn spawn_process(
pid,
data_tx: data_tx.clone(),
end_tx: end_tx.clone(),
ended: Mutex::new(None),
stdin: Mutex::new(stdin),
pty_master: Mutex::new(None),
});
let data_rx = handle.subscribe_data();
let end_rx = handle.subscribe_end();
if let Some(mut out) = stdout {
let tx = data_tx.clone();
std::thread::spawn(move || {
@ -364,29 +382,28 @@ pub fn spawn_process(
}
let end_tx_clone = end_tx.clone();
let handle_for_waiter = Arc::clone(&handle);
std::thread::spawn(move || {
match child.wait() {
Ok(s) => {
let _ = end_tx_clone.send(EndEvent {
exit_code: s.code().unwrap_or(-1),
exited: s.code().is_some(),
status: format!("{s}"),
error: None,
});
}
Err(e) => {
let _ = end_tx_clone.send(EndEvent {
exit_code: -1,
exited: false,
status: "error".into(),
error: Some(e.to_string()),
});
}
}
let end_event = match child.wait() {
Ok(s) => EndEvent {
exit_code: s.code().unwrap_or(-1),
exited: s.code().is_some(),
status: format!("{s}"),
error: None,
},
Err(e) => EndEvent {
exit_code: -1,
exited: false,
status: "error".into(),
error: Some(e.to_string()),
},
};
*handle_for_waiter.ended.lock().unwrap() = Some(end_event.clone());
let _ = end_tx_clone.send(end_event);
});
tracing::info!(pid, cmd = cmd_str, "process started (pipe)");
Ok(handle)
Ok(SpawnedProcess { handle, data_rx, end_rx })
}
}

View File

@ -66,12 +66,12 @@ impl ProcessServiceImpl {
fn spawn_from_request(
&self,
request: &StartRequestView<'_>,
) -> Result<Arc<ProcessHandle>, ConnectError> {
) -> Result<process_handler::SpawnedProcess, ConnectError> {
let proc_config = request.process.as_option().ok_or_else(|| {
ConnectError::new(ErrorCode::InvalidArgument, "process config required")
})?;
let username = self.state.defaults.user.clone();
let username = self.state.defaults.user();
let user =
lookup_user(&username).map_err(|e| ConnectError::new(ErrorCode::Internal, e))?;
@ -85,7 +85,8 @@ impl ProcessServiceImpl {
let home_dir = user.dir.to_string_lossy().to_string();
let cwd_str: &str = proc_config.cwd.unwrap_or("");
let cwd = expand_and_resolve(cwd_str, &home_dir, self.state.defaults.workdir.as_deref())
let default_workdir = self.state.defaults.workdir();
let cwd = expand_and_resolve(cwd_str, &home_dir, default_workdir.as_deref())
.map_err(|e| ConnectError::new(ErrorCode::InvalidArgument, e))?;
let effective_cwd = if cwd.is_empty() { "/" } else { &cwd };
@ -116,7 +117,7 @@ impl ProcessServiceImpl {
"process.Start request"
);
let handle = process_handler::spawn_process(
let spawned = process_handler::spawn_process(
cmd,
&args,
&envs,
@ -128,17 +129,17 @@ impl ProcessServiceImpl {
&self.state.defaults.env_vars,
)?;
self.processes.insert(handle.pid, Arc::clone(&handle));
self.processes.insert(spawned.handle.pid, Arc::clone(&spawned.handle));
let processes = self.processes.clone();
let pid = handle.pid;
let mut end_rx = handle.subscribe_end();
let pid = spawned.handle.pid;
let mut cleanup_end_rx = spawned.handle.subscribe_end();
tokio::spawn(async move {
let _ = end_rx.recv().await;
let _ = cleanup_end_rx.recv().await;
processes.remove(&pid);
});
Ok(handle)
Ok(spawned)
}
}
@ -182,26 +183,36 @@ impl Process for ProcessServiceImpl {
),
ConnectError,
> {
let handle = self.spawn_from_request(&request)?;
let pid = handle.pid;
let spawned = self.spawn_from_request(&request)?;
let pid = spawned.handle.pid;
let mut data_rx = handle.subscribe_data();
let mut end_rx = handle.subscribe_end();
let mut data_rx = spawned.data_rx;
let mut end_rx = spawned.end_rx;
let stream = async_stream::stream! {
yield Ok(make_start_response(pid));
loop {
match data_rx.recv().await {
Ok(ev) => yield Ok(make_data_start_response(ev)),
Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
tokio::select! {
biased;
data = data_rx.recv() => {
match data {
Ok(ev) => yield Ok(make_data_start_response(ev)),
Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
}
}
end = end_rx.recv() => {
while let Ok(ev) = data_rx.try_recv() {
yield Ok(make_data_start_response(ev));
}
if let Ok(end) = end {
yield Ok(make_end_start_response(end));
}
break;
}
}
}
if let Ok(end) = end_rx.recv().await {
yield Ok(make_end_start_response(end));
}
};
Ok((Box::pin(stream), ctx))
@ -226,6 +237,7 @@ impl Process for ProcessServiceImpl {
let mut data_rx = handle.subscribe_data();
let mut end_rx = handle.subscribe_end();
let cached_end = handle.cached_end();
let stream = async_stream::stream! {
yield Ok(ConnectResponse {
@ -238,24 +250,44 @@ impl Process for ProcessServiceImpl {
..Default::default()
});
loop {
match data_rx.recv().await {
Ok(ev) => {
yield Ok(ConnectResponse {
event: buffa::MessageField::some(make_data_event(ev)),
..Default::default()
});
}
Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
}
}
if let Ok(end) = end_rx.recv().await {
if let Some(end) = cached_end {
yield Ok(ConnectResponse {
event: buffa::MessageField::some(make_end_event(end)),
..Default::default()
});
} else {
loop {
tokio::select! {
biased;
data = data_rx.recv() => {
match data {
Ok(ev) => {
yield Ok(ConnectResponse {
event: buffa::MessageField::some(make_data_event(ev)),
..Default::default()
});
}
Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
}
}
end = end_rx.recv() => {
while let Ok(ev) = data_rx.try_recv() {
yield Ok(ConnectResponse {
event: buffa::MessageField::some(make_data_event(ev)),
..Default::default()
});
}
if let Ok(end) = end {
yield Ok(ConnectResponse {
event: buffa::MessageField::some(make_end_event(end)),
..Default::default()
});
}
break;
}
}
}
}
};

View File

@ -19,6 +19,7 @@ pub struct AppState {
pub port_subsystem: Option<Arc<PortSubsystem>>,
pub cpu_used_pct: AtomicU32,
pub cpu_count: AtomicU32,
pub snapshot_in_progress: AtomicBool,
}
impl AppState {
@ -41,6 +42,7 @@ impl AppState {
port_subsystem,
cpu_used_pct: AtomicU32::new(0),
cpu_count: AtomicU32::new(0),
snapshot_in_progress: AtomicBool::new(false),
});
let state_clone = Arc::clone(&state);

View File

@ -11,6 +11,10 @@ impl AtomicMax {
}
}
pub fn get(&self) -> i64 {
self.val.load(Ordering::Acquire)
}
/// Sets the stored value to `new` if `new` is strictly greater than
/// the current value. Returns `true` if the value was updated.
pub fn set_to_greater(&self, new: i64) -> bool {
@ -31,3 +35,68 @@ impl AtomicMax {
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
#[test]
fn initial_value_is_i64_min() {
let m = AtomicMax::new();
assert_eq!(m.get(), i64::MIN);
}
#[test]
fn updates_when_larger() {
let m = AtomicMax::new();
assert!(m.set_to_greater(0));
assert_eq!(m.get(), 0);
assert!(m.set_to_greater(100));
assert_eq!(m.get(), 100);
}
#[test]
fn returns_false_when_equal() {
let m = AtomicMax::new();
m.set_to_greater(42);
assert!(!m.set_to_greater(42));
assert_eq!(m.get(), 42);
}
#[test]
fn returns_false_when_smaller() {
let m = AtomicMax::new();
m.set_to_greater(100);
assert!(!m.set_to_greater(50));
assert_eq!(m.get(), 100);
}
#[test]
fn concurrent_convergence() {
let m = Arc::new(AtomicMax::new());
let threads: Vec<_> = (0..8)
.map(|t| {
let m = Arc::clone(&m);
std::thread::spawn(move || {
for i in (t * 100)..((t + 1) * 100) {
m.set_to_greater(i);
}
})
})
.collect();
for t in threads {
t.join().unwrap();
}
assert_eq!(m.get(), 799);
}
#[test]
fn i64_max_boundary() {
let m = AtomicMax::new();
assert!(m.set_to_greater(i64::MAX));
assert!(!m.set_to_greater(i64::MAX));
assert!(!m.set_to_greater(0));
assert_eq!(m.get(), i64::MAX);
}
}

View File

@ -26,3 +26,7 @@ export async function listAdminUsers(page: number = 1): Promise<ApiResult<AdminU
export async function setUserActive(id: string, active: boolean): Promise<ApiResult<void>> {
return apiFetch('PUT', `/api/v1/admin/users/${id}/active`, { active });
}
export async function setUserAdmin(id: string, admin: boolean): Promise<ApiResult<void>> {
return apiFetch('PUT', `/api/v1/admin/users/${id}/admin`, { admin });
}

View File

@ -213,6 +213,7 @@
},
});
lastDataKey = '';
updateCharts();
}
@ -233,6 +234,7 @@
onMount(async () => {
if (!available) return;
loadMetrics();
const mod = await import('chart.js/auto');
ChartJS = mod.Chart;

View File

@ -2,6 +2,7 @@
import CopyButton from '$lib/components/CopyButton.svelte';
import { onMount, onDestroy } from 'svelte';
import { toast } from '$lib/toast.svelte';
import { auth } from '$lib/auth.svelte';
import { formatDate, timeAgo } from '$lib/utils/format';
import {
listBuilds,
@ -13,6 +14,7 @@
type BuildLogEntry,
type AdminTemplate
} from '$lib/api/builds';
import { listAdminTeams } from '$lib/api/team';
let activeTab = $state<'templates' | 'builds'>('templates');
@ -35,6 +37,9 @@
let expandedBuildId = $state<string | null>(null);
let expandedSteps = $state<Set<number>>(new Set());
// Team name lookup
let teamNames = $state<Map<string, string>>(new Map());
// Delete template state
let deleteTarget = $state<AdminTemplate | null>(null);
let deleting = $state(false);
@ -64,6 +69,28 @@
let baseCount = $derived(templates.filter((t) => t.type === 'base').length);
let runningBuilds = $derived(builds.filter((b) => b.status === 'running').length);
async function fetchTeamNames() {
const names = new Map<string, string>();
let page = 1;
while (true) {
const result = await listAdminTeams(page);
if (!result.ok) break;
for (const team of result.data.teams) {
names.set(team.id, team.name);
}
if (page >= result.data.total_pages) break;
page++;
}
teamNames = names;
}
const PLATFORM_TEAM_ID = 'team-0000000000000000000000000';
function canDeleteTemplate(tmpl: AdminTemplate): boolean {
if (tmpl.name === 'minimal') return false;
return tmpl.team_id === PLATFORM_TEAM_ID;
}
async function fetchTemplates() {
templatesLoading = true;
templatesError = null;
@ -238,6 +265,7 @@
}
onMount(() => {
fetchTeamNames();
fetchTemplates();
fetchBuilds().then(startPolling);
@ -339,7 +367,7 @@
<div class="flex-1 overflow-y-auto px-8 py-6">
{#if activeTab === 'templates'}
{#if templatesLoading}
{@render skeletonRows(5, ['Name', 'Type', 'Specs', 'Size', 'Created', ''])}
{@render skeletonRows(5, ['Name', 'Type', 'Owner', 'Specs', 'Size', 'Created', ''])}
{:else if templatesError}
<div class="rounded-[var(--radius-card)] border border-[var(--color-red)]/30 bg-[var(--color-red)]/5 px-4 py-3 text-ui text-[var(--color-red)]">
{templatesError}
@ -442,6 +470,7 @@
<tr class="border-b border-[var(--color-border)] bg-[var(--color-bg-0)]/40">
<th class="table-header">Name</th>
<th class="table-header">Type</th>
<th class="table-header hidden md:table-cell">Owner</th>
<th class="table-header hidden md:table-cell">Specs</th>
<th class="table-header hidden lg:table-cell">Size</th>
<th class="table-header hidden lg:table-cell">Created</th>
@ -473,6 +502,13 @@
</span>
{/if}
</td>
<td class="hidden px-5 py-3.5 md:table-cell">
{#if tmpl.team_id === PLATFORM_TEAM_ID}
<span class="text-meta text-[var(--color-text-muted)]">Platform</span>
{:else}
<span class="text-meta text-[var(--color-text-secondary)]">{teamNames.get(tmpl.team_id) ?? tmpl.team_id}</span>
{/if}
</td>
<td class="hidden px-5 py-3.5 md:table-cell">
{#if tmpl.vcpus && tmpl.memory_mb}
<span class="font-mono text-meta tabular-nums text-[var(--color-text-secondary)]">
@ -495,7 +531,11 @@
<td class="px-5 py-3.5 text-right">
<button
onclick={() => { deleteTarget = tmpl; deleteError = null; }}
class="rounded-[var(--radius-button)] px-3 py-1.5 text-meta text-[var(--color-text-tertiary)] transition-all duration-150 hover:bg-[var(--color-red)]/10 hover:text-[var(--color-red)]"
disabled={!canDeleteTemplate(tmpl)}
title={tmpl.name === 'minimal' ? 'The minimal template cannot be deleted' : !canDeleteTemplate(tmpl) ? 'Cannot delete templates owned by other teams' : undefined}
class="rounded-[var(--radius-button)] px-3 py-1.5 text-meta transition-all duration-150 {canDeleteTemplate(tmpl)
? 'text-[var(--color-text-tertiary)] hover:bg-[var(--color-red)]/10 hover:text-[var(--color-red)]'
: 'text-[var(--color-text-muted)] cursor-not-allowed opacity-40'}"
>
Delete
</button>

View File

@ -5,8 +5,10 @@
import {
listAdminUsers,
setUserActive,
setUserAdmin,
type AdminUser,
} from '$lib/api/admin-users';
import { auth } from '$lib/auth.svelte';
// Data state
let users = $state<AdminUser[]>([]);
@ -22,6 +24,11 @@
// Toggle state
let togglingId = $state<string | null>(null);
// Admin dialog state
let adminTarget = $state<AdminUser | null>(null);
let togglingAdmin = $state(false);
let adminError = $state<string | null>(null);
async function fetchUsers(page: number = 1) {
const wasEmpty = users.length === 0;
if (wasEmpty) loading = true;
@ -56,6 +63,23 @@
togglingId = null;
}
async function handleConfirmAdminToggle() {
if (!adminTarget) return;
togglingAdmin = true;
adminError = null;
const target = adminTarget;
const newAdmin = !target.is_admin;
const result = await setUserAdmin(target.id, newAdmin);
if (result.ok) {
adminTarget = null;
target.is_admin = newAdmin;
toast.success(`${target.email} ${newAdmin ? 'granted' : 'revoked'} admin`);
} else {
adminError = result.error;
}
togglingAdmin = false;
}
function goToPage(page: number) {
if (page < 1 || page > totalPages) return;
fetchUsers(page);
@ -222,8 +246,18 @@
</div>
<!-- Role -->
<div class="px-5 py-4">
<span class="text-ui text-[var(--color-text-secondary)]">{user.is_admin ? 'Admin' : 'User'}</span>
<div class="flex items-center px-5 py-4">
<button
onclick={() => { adminError = null; adminTarget = user; }}
disabled={user.status !== 'active'}
aria-label="{user.is_admin ? 'Revoke admin for' : 'Grant admin to'} {user.name || user.email}"
class="rounded-[var(--radius-button)] border px-3 py-1.5 text-meta font-medium transition-all duration-150 disabled:opacity-50
{user.is_admin
? 'border-[var(--color-amber)]/30 bg-[var(--color-amber)]/10 text-[var(--color-amber)] hover:bg-[var(--color-amber)]/20 hover:border-[var(--color-amber)]/50'
: 'border-[var(--color-border)] text-[var(--color-text-tertiary)] hover:border-[var(--color-border-mid)] hover:text-[var(--color-text-secondary)]'}"
>
{user.is_admin ? 'Admin' : 'User'}
</button>
</div>
<!-- Joined -->
@ -292,3 +326,72 @@
</div>
</footer>
</main>
<!-- Admin confirmation dialog -->
{#if adminTarget}
<div class="fixed inset-0 z-50 flex items-center justify-center">
<!-- svelte-ignore a11y_no_static_element_interactions -->
<div
class="absolute inset-0 bg-black/60"
onclick={() => { if (!togglingAdmin) adminTarget = null; }}
onkeydown={(e) => { if (e.key === 'Escape' && !togglingAdmin) adminTarget = null; }}
></div>
<div
class="relative w-full max-w-[420px] rounded-[var(--radius-card)] border border-[var(--color-border-mid)] bg-[var(--color-bg-2)]"
style="animation: fadeUp 0.2s ease both; box-shadow: var(--shadow-dialog)"
>
<div class="p-6">
<h2 class="font-serif text-heading leading-tight text-[var(--color-text-bright)]">
{adminTarget.is_admin ? 'Revoke Admin' : 'Grant Admin'}
</h2>
<p class="mt-1.5 text-ui text-[var(--color-text-tertiary)]">
{adminTarget.is_admin ? 'Remove admin access from' : 'Grant admin access to'}
<code class="rounded bg-[var(--color-bg-4)] px-1.5 py-0.5 font-mono text-[0.8rem] text-[var(--color-text-primary)]">{adminTarget.email}</code>.
{adminTarget.is_admin
? 'They will lose access to the admin panel immediately.'
: 'They will be able to manage all platform resources.'}
</p>
{#if adminTarget.is_admin && adminTarget.id === auth.userId}
<div class="mt-3 flex items-start gap-2.5 rounded-[var(--radius-input)] border border-[var(--color-amber)]/30 bg-[var(--color-amber)]/5 px-3 py-2.5">
<svg class="mt-0.5 shrink-0 text-[var(--color-amber)]" width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<path d="M10.29 3.86L1.82 18a2 2 0 0 0 1.71 3h16.94a2 2 0 0 0 1.71-3L13.71 3.86a2 2 0 0 0-3.42 0z" /><line x1="12" y1="9" x2="12" y2="13" /><line x1="12" y1="17" x2="12.01" y2="17" />
</svg>
<span class="text-meta text-[var(--color-amber)]">
You are removing your own admin access. You will lose access to this panel.
</span>
</div>
{/if}
{#if adminError}
<div class="mt-3 rounded-[var(--radius-input)] border border-[var(--color-red)]/30 bg-[var(--color-red)]/5 px-3 py-2 text-meta text-[var(--color-red)]">
{adminError}
</div>
{/if}
<div class="mt-6 flex justify-end gap-3">
<button
onclick={() => (adminTarget = null)}
disabled={togglingAdmin}
class="rounded-[var(--radius-button)] border border-[var(--color-border)] px-4 py-2 text-ui text-[var(--color-text-secondary)] transition-colors duration-150 hover:border-[var(--color-border-mid)] hover:text-[var(--color-text-primary)] disabled:opacity-50"
>
Cancel
</button>
<button
onclick={handleConfirmAdminToggle}
disabled={togglingAdmin}
class="flex items-center gap-2 rounded-[var(--radius-button)] px-5 py-2 text-ui font-semibold text-white transition-all duration-150 hover:brightness-110 hover:-translate-y-px active:translate-y-0 disabled:opacity-50 disabled:hover:translate-y-0
{adminTarget.is_admin ? 'bg-[var(--color-red)]' : 'bg-[var(--color-accent)]'}"
>
{#if togglingAdmin}
<svg class="animate-spin" width="13" height="13" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M21 12a9 9 0 1 1-6.219-8.56"/></svg>
{adminTarget.is_admin ? 'Revoking...' : 'Granting...'}
{:else}
{adminTarget.is_admin ? 'Revoke admin' : 'Grant admin'}
{/if}
</button>
</div>
</div>
</div>
</div>
{/if}

View File

@ -120,6 +120,25 @@
}
}
function mergeCapsuleData(incoming: Capsule[]) {
const existingMap = new Map(capsules.map((c) => [c.id, c]));
const merged: Capsule[] = [];
for (const fresh of incoming) {
const existing = existingMap.get(fresh.id);
if (existing) {
for (const key of Object.keys(fresh) as (keyof Capsule)[]) {
if (existing[key] !== fresh[key]) {
(existing as any)[key] = fresh[key];
}
}
merged.push(existing);
} else {
merged.push(fresh);
}
}
capsules = merged;
}
async function fetchCapsules(manual = false) {
const wasEmpty = capsules.length === 0;
if (wasEmpty) loading = true;
@ -131,7 +150,11 @@
const result = await listCapsules();
if (result.ok) {
capsules = result.data;
if (wasEmpty) {
capsules = result.data;
} else {
mergeCapsuleData(result.data);
}
error = null;
} else {
error = result.error;

View File

@ -333,6 +333,7 @@
},
});
lastDataKey = '';
updateCharts();
}
@ -376,6 +377,7 @@
if (!metricsAvailable) return;
loadMetrics();
const mod = await import('chart.js/auto');
ChartJS = mod.Chart;

View File

@ -350,9 +350,23 @@ func runPtyLoop(
defer wg.Done()
defer cancel()
for msg := range inputCh {
// Use a background context for unary RPCs so they complete
// even if the stream context is being cancelled.
// pending holds a non-input message dequeued during coalescing
// that must be processed on the next iteration.
var pending *wsPtyIn
for {
var msg wsPtyIn
if pending != nil {
msg = *pending
pending = nil
} else {
var ok bool
msg, ok = <-inputCh
if !ok {
break
}
}
rpcCtx, rpcCancel := context.WithTimeout(context.Background(), 5*time.Second)
switch msg.Type {
@ -364,7 +378,7 @@ func runPtyLoop(
}
// Coalesce: drain any queued input messages into a single RPC.
data = coalescePtyInput(inputCh, data)
data, pending = coalescePtyInput(inputCh, data)
if _, err := agent.PtySendInput(rpcCtx, connect.NewRequest(&pb.PtySendInputRequest{
SandboxId: sandboxID,
@ -418,24 +432,29 @@ func runPtyLoop(
}
}()
// When any pump cancels the context, close the websocket to unblock
// the reader goroutine stuck in ReadMessage.
go func() {
<-ctx.Done()
ws.conn.Close()
}()
wg.Wait()
}
// coalescePtyInput drains any immediately-available "input" messages from the
// channel and appends their decoded data to buf, reducing RPC call volume
// during bursts of fast typing.
func coalescePtyInput(ch <-chan wsPtyIn, buf []byte) []byte {
// during bursts of fast typing. Returns the coalesced buffer and any
// non-input message that was dequeued (must be processed by the caller).
func coalescePtyInput(ch <-chan wsPtyIn, buf []byte) ([]byte, *wsPtyIn) {
for {
select {
case msg, ok := <-ch:
if !ok {
return buf
return buf, nil
}
if msg.Type != "input" {
// Non-input message — can't coalesce. Put-back isn't possible
// with channels, but resize/kill during a typing burst is rare
// enough that dropping one is acceptable.
return buf
return buf, &msg
}
data, err := base64.StdEncoding.DecodeString(msg.Data)
if err != nil {
@ -443,7 +462,7 @@ func coalescePtyInput(ch <-chan wsPtyIn, buf []byte) []byte {
}
buf = append(buf, data...)
default:
return buf
return buf, nil
}
}
}

View File

@ -162,3 +162,58 @@ func (h *usersHandler) SetUserActive(w http.ResponseWriter, r *http.Request) {
}
w.WriteHeader(http.StatusNoContent)
}
// SetUserAdmin handles PUT /v1/admin/users/{id}/admin
// Grants or revokes platform admin status. Cannot remove the last admin.
func (h *usersHandler) SetUserAdmin(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
userIDStr := chi.URLParam(r, "id")
userID, err := id.ParseUserID(userIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid user ID")
return
}
var req struct {
Admin bool `json:"admin"`
}
if err := decodeJSON(r, &req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
return
}
user, err := h.db.GetUserByID(r.Context(), userID)
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "user not found")
return
}
if user.IsAdmin == req.Admin {
w.WriteHeader(http.StatusNoContent)
return
}
if req.Admin {
if err := h.db.SetUserAdmin(r.Context(), db.SetUserAdminParams{
ID: userID,
IsAdmin: true,
}); err != nil {
writeError(w, http.StatusInternalServerError, "internal", "failed to update admin status")
return
}
h.audit.LogUserGrantAdmin(r.Context(), ac, userID, user.Email)
} else {
affected, err := h.db.RevokeUserAdmin(r.Context(), userID)
if err != nil {
writeError(w, http.StatusInternalServerError, "internal", "failed to update admin status")
return
}
if affected == 0 {
writeError(w, http.StatusBadRequest, "invalid_request", "cannot remove the last admin")
return
}
h.audit.LogUserRevokeAdmin(r.Context(), ac, userID, user.Email)
}
w.WriteHeader(http.StatusNoContent)
}

View File

@ -2346,6 +2346,54 @@ paths:
schema:
$ref: "#/components/schemas/Error"
/v1/admin/users/{id}/admin:
put:
summary: Grant or revoke platform admin
operationId: setUserAdmin
tags: [admin]
description: |
Sets the platform admin flag on a user. Cannot remove the last admin.
Requires platform admin access (JWT + is_admin).
The target user's JWT is not re-issued — their frontend will reflect the
change on next login or team switch.
security:
- bearerAuth: []
parameters:
- name: id
in: path
required: true
schema:
type: string
example: "usr-a1b2c3d4"
requestBody:
required: true
content:
application/json:
schema:
type: object
required: [admin]
properties:
admin:
type: boolean
description: true to grant admin, false to revoke.
responses:
"204":
description: Admin status updated
"400":
$ref: "#/components/responses/BadRequest"
"403":
description: Caller is not a platform admin
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"404":
description: User not found
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
components:
securitySchemes:
apiKeyAuth:

View File

@ -269,6 +269,7 @@ func New(
r.Delete("/teams/{id}", teamH.AdminDeleteTeam)
r.Get("/users", usersH.AdminListUsers)
r.Put("/users/{id}/active", usersH.SetUserActive)
r.Put("/users/{id}/admin", usersH.SetUserAdmin)
r.Get("/audit-logs", auditH.AdminList)
r.Get("/templates", buildH.ListTemplates)
r.Delete("/templates/{name}", buildH.DeleteTemplate)

View File

@ -135,6 +135,20 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
defer tracker.Release()
// Derive request context from the tracker's context so ForceClose()
// during pause aborts this proxied request.
trackerCtx := tracker.Context()
reqCtx, reqCancel := context.WithCancel(r.Context())
defer reqCancel()
go func() {
select {
case <-trackerCtx.Done():
reqCancel()
case <-reqCtx.Done():
}
}()
r = r.WithContext(reqCtx)
proxy := h.getOrCreateProxy(sandboxID, port, fmt.Sprintf("%s:%d", hostIP, portNum))
proxy.ServeHTTP(w, r)
}

View File

@ -11,6 +11,7 @@ type SandboxStatus string
const (
StatusPending SandboxStatus = "pending"
StatusRunning SandboxStatus = "running"
StatusPausing SandboxStatus = "pausing"
StatusPaused SandboxStatus = "paused"
StatusStopped SandboxStatus = "stopped"
StatusError SandboxStatus = "error"

View File

@ -1,6 +1,7 @@
package sandbox
import (
"context"
"sync"
"sync/atomic"
"time"
@ -17,6 +18,20 @@ type ConnTracker struct {
// goroutine to exit, preventing goroutine leaks on repeated pause failures.
cancelMu sync.Mutex
cancelDrain chan struct{}
// ctx is cancelled by ForceClose to abort all in-flight proxy requests.
// Initialized lazily on first Acquire; replaced by Reset after a failed
// pause so new connections get a fresh, non-cancelled context.
ctxMu sync.Mutex
ctx context.Context
cancel context.CancelFunc
}
// ensureCtx lazily initializes the cancellable context.
func (t *ConnTracker) ensureCtx() {
if t.ctx == nil {
t.ctx, t.cancel = context.WithCancel(context.Background())
}
}
// Acquire registers one in-flight connection. Returns false if the tracker
@ -35,6 +50,16 @@ func (t *ConnTracker) Acquire() bool {
return true
}
// Context returns a context that is cancelled when ForceClose is called.
// Proxy handlers should derive their request context from this so that
// force-close during pause aborts in-flight proxied requests.
func (t *ConnTracker) Context() context.Context {
t.ctxMu.Lock()
defer t.ctxMu.Unlock()
t.ensureCtx()
return t.ctx
}
// Release marks one connection as complete. Must be called exactly once
// per successful Acquire.
func (t *ConnTracker) Release() {
@ -65,9 +90,33 @@ func (t *ConnTracker) Drain(timeout time.Duration) {
}
}
// ForceClose cancels all in-flight proxy connections by cancelling the
// shared context. Connections whose request context derives from Context()
// will see their requests aborted, causing the proxy handler to return
// and call Release(). Waits briefly for connections to actually release.
func (t *ConnTracker) ForceClose() {
t.ctxMu.Lock()
if t.cancel != nil {
t.cancel()
}
t.ctxMu.Unlock()
// Wait briefly for force-closed connections to call Release().
done := make(chan struct{})
go func() {
t.wg.Wait()
close(done)
}()
select {
case <-done:
case <-time.After(2 * time.Second):
}
}
// Reset re-enables the tracker after a failed drain. This allows the
// sandbox to accept proxy connections again if the pause operation fails
// and the VM is resumed. It also cancels any lingering Drain goroutine.
// and the VM is resumed. It also cancels any lingering Drain goroutine
// and creates a fresh context for new connections.
func (t *ConnTracker) Reset() {
t.cancelMu.Lock()
if t.cancelDrain != nil {
@ -81,5 +130,10 @@ func (t *ConnTracker) Reset() {
}
t.cancelMu.Unlock()
// Replace the cancelled context with a fresh one.
t.ctxMu.Lock()
t.ctx, t.cancel = context.WithCancel(context.Background())
t.ctxMu.Unlock()
t.draining.Store(false)
}

View File

@ -95,10 +95,10 @@ type snapshotParent struct {
}
// maxDiffGenerations caps how many incremental diff generations we chain
// before falling back to a Full snapshot to collapse the chain. Long diff
// chains increase restore latency and snapshot directory size; a periodic
// Full snapshot resets the counter and produces a clean base.
const maxDiffGenerations = 8
// before merging diffs into a single file. Since UFFD lazy-loads memory
// anyway, we merge on every re-pause to keep exactly 1 diff file per
// snapshot — no accumulated chain, no extra restore overhead.
const maxDiffGenerations = 1
// buildMetadata constructs the metadata map with version information.
func (m *Manager) buildMetadata(envdVersion string) map[string]string {
@ -186,9 +186,12 @@ func (m *Manager) Create(ctx context.Context, sandboxID string, teamID, template
}
// Create dm-snapshot with per-sandbox CoW file.
// CoW must be at least as large as the origin — if every block is
// rewritten, the CoW stores a full copy. Undersized CoW causes
// dm-snapshot invalidation → EIO on all guest I/O.
dmName := "wrenn-" + sandboxID
cowPath := filepath.Join(layout.SandboxesDir(m.cfg.WrennDir), fmt.Sprintf("%s.cow", sandboxID))
cowSize := int64(diskSizeMB) * 1024 * 1024
cowSize := max(int64(diskSizeMB)*1024*1024, originSize)
dmDev, err := devicemapper.CreateSnapshot(dmName, originLoop, cowPath, originSize, cowSize)
if err != nil {
m.loops.Release(baseRootfs)
@ -374,28 +377,43 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error {
return fmt.Errorf("sandbox %s is not running (status: %s)", sandboxID, sb.Status)
}
// Mark sandbox as pausing to block new exec/file/PTY operations.
m.mu.Lock()
sb.Status = models.StatusPausing
m.mu.Unlock()
// restoreRunning reverts state if any pre-freeze step fails.
restoreRunning := func() {
_ = m.vm.UpdateBalloon(context.Background(), sandboxID, 0)
sb.connTracker.Reset()
m.mu.Lock()
sb.Status = models.StatusRunning
m.mu.Unlock()
m.startSampler(sb)
}
// Stop the metrics sampler goroutine before tearing down any resources
// it reads (dm device, Firecracker PID). Without this, the sampler
// leaks on every successful pause.
m.stopSampler(sb)
// Step 0: Drain in-flight proxy connections before freezing vCPUs.
// Stale TCP state from mid-flight connections causes issues on restore.
sb.connTracker.Drain(2 * time.Second)
slog.Debug("pause: proxy connections drained", "id", sandboxID)
// ── Step 1: Isolate from external traffic ─────────────────────────
// Drain in-flight proxy connections (grace period for clean shutdown).
sb.connTracker.Drain(5 * time.Second)
// Force-close any connections that didn't finish during grace period.
sb.connTracker.ForceClose()
slog.Debug("pause: external connections closed", "id", sandboxID)
// Step 0b: Close host-side idle connections to envd. Done before
// PrepareSnapshot so FIN packets propagate to the guest during the
// PrepareSnapshot window (no extra sleep needed).
// Close host-side idle connections to envd so FIN packets propagate
// to the guest kernel before snapshot.
sb.client.CloseIdleConnections()
slog.Debug("pause: envd client idle connections closed", "id", sandboxID)
// Step 0c: Signal envd to quiesce (stop port scanner/forwarder, mark
// connections for post-restore cleanup). The 3s timeout also gives time
// for the FINs from Step 0b to be processed by the guest kernel.
// Best-effort: a failure is logged but does not abort the pause.
// ── Step 2: Drop page cache ──────────────────────────────────────
// Signal envd to quiesce: drops page cache, stops port subsystem,
// marks connections for post-restore cleanup. Page cache drop can
// take significant time on large-memory VMs (20GB+).
func() {
prepCtx, prepCancel := context.WithTimeout(ctx, 3*time.Second)
prepCtx, prepCancel := context.WithTimeout(ctx, 30*time.Second)
defer prepCancel()
if err := sb.client.PrepareSnapshot(prepCtx); err != nil {
slog.Warn("pause: pre-snapshot quiesce failed (best-effort)", "id", sandboxID, "error", err)
@ -404,11 +422,37 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error {
}
}()
// ── Step 3: Inflate balloon to reclaim free guest memory ─────────
// Freed pages become zero from FC's perspective, so ProcessMemfile
// skips them → dramatically smaller memfile (e.g. 20GB → 1GB).
func() {
memUsed, err := readEnvdMemUsed(sb.client)
if err != nil {
slog.Debug("pause: could not read guest memory, skipping balloon inflate", "id", sandboxID, "error", err)
return
}
usedMiB := int(memUsed / (1024 * 1024))
keepMiB := max(usedMiB*2, 256) + 128
inflateMiB := sb.MemoryMB - keepMiB
if inflateMiB <= 0 {
slog.Debug("pause: not enough free memory for balloon inflate", "id", sandboxID, "used_mib", usedMiB, "total_mib", sb.MemoryMB)
return
}
balloonCtx, balloonCancel := context.WithTimeout(ctx, 10*time.Second)
defer balloonCancel()
if err := m.vm.UpdateBalloon(balloonCtx, sandboxID, inflateMiB); err != nil {
slog.Debug("pause: balloon inflate failed (non-fatal)", "id", sandboxID, "error", err)
return
}
time.Sleep(2 * time.Second)
slog.Info("pause: balloon inflated", "id", sandboxID, "inflate_mib", inflateMiB, "guest_used_mib", usedMiB)
}()
// ── Step 4: Freeze vCPUs ─────────────────────────────────────────
pauseStart := time.Now()
// Step 1: Pause the VM (freeze vCPUs).
if err := m.vm.Pause(ctx, sandboxID); err != nil {
sb.connTracker.Reset()
restoreRunning()
return fmt.Errorf("pause VM: %w", err)
}
slog.Debug("pause: VM paused", "id", sandboxID, "elapsed", time.Since(pauseStart))
@ -423,13 +467,23 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error {
// resumeOnError unpauses the VM so the sandbox stays usable when a
// post-freeze step fails. If the resume itself fails, the sandbox is
// left frozen — the caller should destroy it. It also resets the
// connection tracker so the sandbox can accept proxy connections again.
// frozen and unrecoverable — destroy it to avoid a zombie.
resumeOnError := func() {
sb.connTracker.Reset()
if err := m.vm.Resume(ctx, sandboxID); err != nil {
slog.Error("failed to resume VM after pause error — sandbox is frozen", "id", sandboxID, "error", err)
// Use a fresh context — the caller's ctx may already be cancelled.
resumeCtx, resumeCancel := context.WithTimeout(context.Background(), 30*time.Second)
defer resumeCancel()
if err := m.vm.Resume(resumeCtx, sandboxID); err != nil {
slog.Error("failed to resume VM after pause error — destroying frozen sandbox", "id", sandboxID, "error", err)
m.cleanup(context.Background(), sb)
m.mu.Lock()
delete(m.boxes, sandboxID)
m.mu.Unlock()
if m.onDestroy != nil {
m.onDestroy(sandboxID)
}
return
}
restoreRunning()
}
// Step 2: Take VM state snapshot (snapfile + memfile).
@ -444,6 +498,7 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error {
snapshotStart := time.Now()
if err := m.vm.Snapshot(ctx, sandboxID, snapPath, rawMemPath, snapshotType); err != nil {
slog.Error("pause: snapshot failed", "id", sandboxID, "type", snapshotType, "elapsed", time.Since(snapshotStart), "error", err)
warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir))
resumeOnError()
return fmt.Errorf("create VM snapshot: %w", err)
@ -795,6 +850,12 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int,
slog.Warn("post-init failed after resume, metadata files may be stale", "sandbox", sandboxID, "error", err)
}
// Deflate balloon — the snapshot was taken with an inflated balloon to
// reduce memfile size, so restore the guest's full memory allocation.
if err := m.vm.UpdateBalloon(ctx, sandboxID, 0); err != nil {
slog.Debug("resume: balloon deflate failed (non-fatal)", "id", sandboxID, "error", err)
}
// Fetch envd version (best-effort).
envdVersion, _ := client.FetchVersion(ctx)
@ -1134,7 +1195,7 @@ func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID string, team
dmName := "wrenn-" + sandboxID
cowPath := filepath.Join(layout.SandboxesDir(m.cfg.WrennDir), fmt.Sprintf("%s.cow", sandboxID))
cowSize := int64(diskSizeMB) * 1024 * 1024
cowSize := max(int64(diskSizeMB)*1024*1024, originSize)
dmDev, err := devicemapper.CreateSnapshot(dmName, originLoop, cowPath, originSize, cowSize)
if err != nil {
source.Close()
@ -1235,6 +1296,11 @@ func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID string, team
slog.Warn("post-init failed after template restore, metadata files may be stale", "sandbox", sandboxID, "error", err)
}
// Deflate balloon — template snapshot was taken with an inflated balloon.
if err := m.vm.UpdateBalloon(ctx, sandboxID, 0); err != nil {
slog.Debug("create-from-snapshot: balloon deflate failed (non-fatal)", "id", sandboxID, "error", err)
}
// Fetch envd version (best-effort).
envdVersion, _ := client.FetchVersion(ctx)
@ -1720,12 +1786,12 @@ func (m *Manager) startSampler(sb *sandboxState) {
go m.samplerLoop(ctx, sb, fcPID, sb.VCPUs, initialCPU)
}
// samplerLoop samples /proc metrics at 500ms intervals.
// samplerLoop samples metrics at 1s intervals.
// lastCPU is goroutine-local to avoid shared-state races.
func (m *Manager) samplerLoop(ctx context.Context, sb *sandboxState, fcPID, vcpus int, lastCPU cpuStat) {
defer close(sb.samplerDone)
ticker := time.NewTicker(500 * time.Millisecond)
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
clkTck := 100.0 // sysconf(_SC_CLK_TCK), almost always 100 on Linux
@ -1758,8 +1824,11 @@ func (m *Manager) samplerLoop(ctx context.Context, sb *sandboxState, fcPID, vcpu
cpuInitialized = true
}
// Memory: VmRSS of the Firecracker process.
memBytes, _ := readMemRSS(fcPID)
// Memory: guest-reported used memory from envd /metrics.
// VmRSS of the Firecracker process includes guest page cache
// and never decreases, so we use the guest's own view which
// reports total - available (actual process memory).
memBytes, _ := readEnvdMemUsed(sb.client)
// Disk: allocated bytes of the CoW sparse file.
var diskBytes int64

View File

@ -15,11 +15,11 @@ type MetricPoint struct {
// Ring buffer capacity constants.
const (
ring10mCap = 1200 // 500ms × 1200 = 10 min
ring2hCap = 240 // 30s × 240 = 2 h
ring24hCap = 288 // 5min × 288 = 24 h
ring10mCap = 600 // 1s × 600 = 10 min
ring2hCap = 240 // 30s × 240 = 2 h
ring24hCap = 288 // 5min × 288 = 24 h
downsample2hEvery = 60 // 60 × 500ms = 30s
downsample2hEvery = 30 // 30 × 1s = 30s
downsample24hEvery = 10 // 10 × 30s = 5min
)
@ -44,8 +44,8 @@ type metricsRing struct {
count24h int
// Accumulators for downsampling.
acc500ms [downsample2hEvery]MetricPoint
acc500msN int
acc1s [downsample2hEvery]MetricPoint
acc1sN int
acc30s [downsample24hEvery]MetricPoint
acc30sN int
@ -56,7 +56,7 @@ func newMetricsRing() *metricsRing {
return &metricsRing{}
}
// Push adds a 500ms sample to the finest tier and triggers downsampling
// Push adds a 1s sample to the finest tier and triggers downsampling
// into coarser tiers when enough samples have accumulated.
func (r *metricsRing) Push(p MetricPoint) {
r.mu.Lock()
@ -70,12 +70,12 @@ func (r *metricsRing) Push(p MetricPoint) {
}
// Accumulate for 2h downsample.
r.acc500ms[r.acc500msN] = p
r.acc500msN++
if r.acc500msN == downsample2hEvery {
avg := averagePoints(r.acc500ms[:downsample2hEvery])
r.acc1s[r.acc1sN] = p
r.acc1sN++
if r.acc1sN == downsample2hEvery {
avg := averagePoints(r.acc1s[:downsample2hEvery])
r.push2h(avg)
r.acc500msN = 0
r.acc1sN = 0
}
}
@ -138,7 +138,7 @@ func (r *metricsRing) Flush() (pts10m, pts2h, pts24h []MetricPoint) {
r.idx10m, r.count10m = 0, 0
r.idx2h, r.count2h = 0, 0
r.idx24h, r.count24h = 0, 0
r.acc500msN = 0
r.acc1sN = 0
r.acc30sN = 0
return pts10m, pts2h, pts24h

View File

@ -1,11 +1,15 @@
package sandbox
import (
"encoding/json"
"fmt"
"io"
"os"
"strconv"
"strings"
"syscall"
"git.omukk.dev/wrenn/wrenn/internal/envdclient"
)
// cpuStat holds raw CPU jiffies read from /proc/{pid}/stat.
@ -24,16 +28,11 @@ func readCPUStat(pid int) (cpuStat, error) {
return cpuStat{}, fmt.Errorf("read stat: %w", err)
}
// /proc/{pid}/stat format: pid (comm) state fields...
// The comm field may contain spaces and parens, so find the last ')' first.
content := string(data)
idx := strings.LastIndex(content, ")")
if idx < 0 {
return cpuStat{}, fmt.Errorf("malformed /proc/%d/stat: no closing paren", pid)
}
// After ")" there is " state field3 field4 ... fieldN"
// field1 after ')' is state (index 0), utime is field 11, stime is field 12
// (0-indexed from after the closing paren).
fields := strings.Fields(content[idx+2:])
if len(fields) < 13 {
return cpuStat{}, fmt.Errorf("malformed /proc/%d/stat: too few fields (%d)", pid, len(fields))
@ -49,27 +48,34 @@ func readCPUStat(pid int) (cpuStat, error) {
return cpuStat{utime: utime, stime: stime}, nil
}
// readMemRSS reads VmRSS from /proc/{pid}/status and returns bytes.
func readMemRSS(pid int) (int64, error) {
path := fmt.Sprintf("/proc/%d/status", pid)
data, err := os.ReadFile(path)
// readEnvdMemUsed fetches mem_used from envd's /metrics endpoint. Returns
// guest-side total - MemAvailable (actual process memory, excluding reclaimable
// page cache). VmRSS of the Firecracker process includes guest page cache and
// never decreases, so this is the accurate metric for dashboard display.
func readEnvdMemUsed(client *envdclient.Client) (int64, error) {
resp, err := client.HTTPClient().Get(client.BaseURL() + "/metrics")
if err != nil {
return 0, fmt.Errorf("read status: %w", err)
return 0, fmt.Errorf("fetch envd metrics: %w", err)
}
for _, line := range strings.Split(string(data), "\n") {
if strings.HasPrefix(line, "VmRSS:") {
fields := strings.Fields(line)
if len(fields) < 2 {
return 0, fmt.Errorf("malformed VmRSS line")
}
kb, err := strconv.ParseInt(fields[1], 10, 64)
if err != nil {
return 0, fmt.Errorf("parse VmRSS: %w", err)
}
return kb * 1024, nil
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
return 0, fmt.Errorf("envd metrics: status %d", resp.StatusCode)
}
return 0, fmt.Errorf("VmRSS not found in /proc/%d/status", pid)
body, err := io.ReadAll(resp.Body)
if err != nil {
return 0, fmt.Errorf("read envd metrics body: %w", err)
}
var m struct {
MemUsed int64 `json:"mem_used"`
}
if err := json.Unmarshal(body, &m); err != nil {
return 0, fmt.Errorf("decode envd metrics: %w", err)
}
return m.MemUsed, nil
}
// readDiskAllocated returns the actual allocated bytes (not apparent size)

View File

@ -29,6 +29,10 @@ import (
const (
UFFD_EVENT_PAGEFAULT = C.UFFD_EVENT_PAGEFAULT
UFFD_EVENT_FORK = C.UFFD_EVENT_FORK
UFFD_EVENT_REMAP = C.UFFD_EVENT_REMAP
UFFD_EVENT_REMOVE = C.UFFD_EVENT_REMOVE
UFFD_EVENT_UNMAP = C.UFFD_EVENT_UNMAP
UFFD_PAGEFAULT_FLAG_WRITE = C.UFFD_PAGEFAULT_FLAG_WRITE
UFFDIO_COPY = C.UFFDIO_COPY
UFFDIO_COPY_MODE_WP = C.UFFDIO_COPY_MODE_WP

View File

@ -253,8 +253,17 @@ func (s *Server) serve(ctx context.Context, uffdFd fd, mapping *Mapping) error {
}
msg := *(*uffdMsg)(unsafe.Pointer(&buf[0]))
if getMsgEvent(&msg) != UFFD_EVENT_PAGEFAULT {
return fmt.Errorf("unexpected uffd event type: %d", getMsgEvent(&msg))
event := getMsgEvent(&msg)
switch event {
case UFFD_EVENT_PAGEFAULT:
// Handled below.
case UFFD_EVENT_REMOVE, UFFD_EVENT_UNMAP, UFFD_EVENT_REMAP, UFFD_EVENT_FORK:
// Non-fatal lifecycle events from the guest kernel (e.g. balloon
// deflation, mmap/munmap). No action needed — continue polling.
continue
default:
return fmt.Errorf("unexpected uffd event type: %d", event)
}
arg := getMsgArg(&msg)

View File

@ -8,7 +8,6 @@ import (
"io"
"net"
"net/http"
"time"
)
// fcClient talks to the Firecracker HTTP API over a Unix socket.
@ -27,7 +26,9 @@ func newFCClient(socketPath string) *fcClient {
return d.DialContext(ctx, "unix", socketPath)
},
},
Timeout: 10 * time.Second,
// No global timeout — callers pass context.Context with appropriate
// deadlines. A fixed 10s timeout was too short for snapshot/resume
// operations on large-memory VMs (20GB+ memfiles).
},
}
}
@ -136,6 +137,25 @@ func (c *fcClient) setMMDS(ctx context.Context, sandboxID, templateID string) er
})
}
// setBalloon configures the Firecracker balloon device for dynamic memory
// management. deflateOnOom lets the guest reclaim balloon pages under memory
// pressure. statsInterval enables periodic stats via GET /balloon/statistics.
// Must be called before startVM.
func (c *fcClient) setBalloon(ctx context.Context, amountMiB int, deflateOnOom bool, statsIntervalS int) error {
return c.do(ctx, http.MethodPut, "/balloon", map[string]any{
"amount_mib": amountMiB,
"deflate_on_oom": deflateOnOom,
"stats_polling_interval_s": statsIntervalS,
})
}
// updateBalloon adjusts the balloon target at runtime.
func (c *fcClient) updateBalloon(ctx context.Context, amountMiB int) error {
return c.do(ctx, http.MethodPatch, "/balloon", map[string]any{
"amount_mib": amountMiB,
})
}
// startVM issues the InstanceStart action.
func (c *fcClient) startVM(ctx context.Context) error {
return c.do(ctx, http.MethodPut, "/actions", map[string]string{

View File

@ -119,6 +119,13 @@ func configureVM(ctx context.Context, client *fcClient, cfg *VMConfig) error {
return fmt.Errorf("set machine config: %w", err)
}
// Balloon device — allows the host to reclaim unused guest memory.
// Start with 0 (no inflation). deflate_on_oom lets the guest reclaim
// balloon pages under memory pressure. Stats interval enables monitoring.
if err := client.setBalloon(ctx, 0, true, 5); err != nil {
slog.Warn("set balloon failed (non-fatal, VM will run without memory reclaim)", "error", err)
}
// MMDS config — enable V2 token access on eth0 so that envd can read
// WRENN_SANDBOX_ID and WRENN_TEMPLATE_ID from inside the guest.
if err := client.setMMDSConfig(ctx, "eth0"); err != nil {
@ -162,6 +169,19 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string) error {
return nil
}
// UpdateBalloon adjusts the balloon target for a running VM.
// amountMiB is memory to take FROM the guest (0 = give all back).
func (m *Manager) UpdateBalloon(ctx context.Context, sandboxID string, amountMiB int) error {
m.mu.RLock()
vm, ok := m.vms[sandboxID]
m.mu.RUnlock()
if !ok {
return fmt.Errorf("VM not found: %s", sandboxID)
}
return vm.client.updateBalloon(ctx, amountMiB)
}
// Destroy stops and cleans up a VM.
func (m *Manager) Destroy(ctx context.Context, sandboxID string) error {
m.mu.Lock()

View File

@ -365,6 +365,14 @@ func (l *AuditLogger) LogUserDeactivate(ctx context.Context, ac auth.AuthContext
l.Log(ctx, newAdminEntry(ac, "user", id.FormatUserID(userID), "deactivate", "warning", map[string]any{"email": email}))
}
func (l *AuditLogger) LogUserGrantAdmin(ctx context.Context, ac auth.AuthContext, userID pgtype.UUID, email string) {
l.Log(ctx, newAdminEntry(ac, "user", id.FormatUserID(userID), "grant_admin", "success", map[string]any{"email": email}))
}
func (l *AuditLogger) LogUserRevokeAdmin(ctx context.Context, ac auth.AuthContext, userID pgtype.UUID, email string) {
l.Log(ctx, newAdminEntry(ac, "user", id.FormatUserID(userID), "revoke_admin", "warning", map[string]any{"email": email}))
}
// --- Team admin events (scope: admin) ---
func (l *AuditLogger) LogTeamSetBYOC(ctx context.Context, ac auth.AuthContext, teamID pgtype.UUID, enabled bool) {

View File

@ -415,6 +415,21 @@ func (q *Queries) ListUsersAdmin(ctx context.Context, arg ListUsersAdminParams)
return items, nil
}
const revokeUserAdmin = `-- name: RevokeUserAdmin :execrows
UPDATE users u SET is_admin = false, updated_at = NOW()
WHERE u.id = $1
AND u.is_admin = true
AND (SELECT COUNT(*) FROM users WHERE is_admin = true AND status != 'deleted') > 1
`
func (q *Queries) RevokeUserAdmin(ctx context.Context, id pgtype.UUID) (int64, error) {
result, err := q.db.Exec(ctx, revokeUserAdmin, id)
if err != nil {
return 0, err
}
return result.RowsAffected(), nil
}
const searchUsersByEmailPrefix = `-- name: SearchUsersByEmailPrefix :many
SELECT id, email FROM users WHERE email LIKE $1 || '%' ORDER BY email LIMIT 10
`

View File

@ -47,7 +47,7 @@ func NewHostClientPoolTLS(tlsCfg *tls.Config) *HostClientPool {
TLSNextProto: make(map[string]func(authority string, c *tls.Conn) http.RoundTripper),
MaxIdleConnsPerHost: 20,
IdleConnTimeout: 90 * time.Second,
ResponseHeaderTimeout: 45 * time.Second,
ResponseHeaderTimeout: 5 * time.Minute,
DialContext: (&net.Dialer{
Timeout: 10 * time.Second,
KeepAlive: 30 * time.Second,

View File

@ -239,12 +239,26 @@ func (s *SandboxService) Pause(ctx context.Context, sandboxID, teamID pgtype.UUI
if _, err := agent.PauseSandbox(ctx, connect.NewRequest(&pb.PauseSandboxRequest{
SandboxId: sandboxIDStr,
})); err != nil {
// Revert status on failure.
if _, dbErr := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
ID: sandboxID, Status: "running",
// Check if the agent still has this sandbox. If it was destroyed
// (e.g. frozen VM couldn't be resumed), mark as "error" instead of
// reverting to "running" — which would create a ghost record.
// Use a fresh context since the original ctx may already be expired.
revertStatus := "running"
pingCtx, pingCancel := context.WithTimeout(context.Background(), 10*time.Second)
if _, pingErr := agent.PingSandbox(pingCtx, connect.NewRequest(&pb.PingSandboxRequest{
SandboxId: sandboxIDStr,
})); pingErr != nil {
revertStatus = "error"
slog.Warn("sandbox gone from agent after failed pause, marking as error", "sandbox_id", sandboxIDStr)
}
pingCancel()
dbCtx, dbCancel := context.WithTimeout(context.Background(), 5*time.Second)
if _, dbErr := s.DB.UpdateSandboxStatus(dbCtx, db.UpdateSandboxStatusParams{
ID: sandboxID, Status: revertStatus,
}); dbErr != nil {
slog.Warn("failed to revert sandbox status after pause error", "sandbox_id", sandboxIDStr, "error", dbErr)
}
dbCancel()
return db.Sandbox{}, fmt.Errorf("agent pause: %w", err)
}

View File

@ -57,7 +57,7 @@ if [ ! -f "${ENVD_BIN}" ]; then
exit 1
fi
if ! file "${ENVD_BIN}" | grep -q "statically linked"; then
if ! ldd "${ENVD_BIN}" | grep -q "statically linked"; then
echo "ERROR: envd is not statically linked!"
exit 1
fi

View File

@ -17,7 +17,8 @@ set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
PROJECT_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
ROOTFS="${1:-/var/lib/wrenn/images/minimal/rootfs.ext4}"
WRENN_DIR="${WRENN_DIR:-/var/lib/wrenn}"
ROOTFS="${1:-${WRENN_DIR}/images/minimal/rootfs.ext4}"
MOUNT_DIR="/tmp/wrenn-rootfs-update"
if [ ! -f "${ROOTFS}" ]; then
@ -36,6 +37,11 @@ if [ ! -f "${ENVD_BIN}" ]; then
exit 1
fi
if ! ldd "${ENVD_BIN}" | grep -q "statically linked"; then
echo "ERROR: envd is not statically linked!"
exit 1
fi
# Step 2: Mount the rootfs.
echo "==> Mounting rootfs at ${MOUNT_DIR}..."
mkdir -p "${MOUNT_DIR}"