diff --git a/.gitignore b/.gitignore index 4be2db8..f827fd2 100644 --- a/.gitignore +++ b/.gitignore @@ -49,3 +49,4 @@ frontend/build/ internal/dashboard/static/* !internal/dashboard/static/.gitkeep.dual-graph/ .dual-graph/ +__pycache__ diff --git a/.woodpecker/pipeline.yml b/.woodpecker/pipeline.yml new file mode 100644 index 0000000..a223614 --- /dev/null +++ b/.woodpecker/pipeline.yml @@ -0,0 +1,45 @@ +when: + - event: push + branch: main + +steps: + sandbox-1: + image: python:3.13 + environment: + WRENN_API_KEY: + from_secret: wrenn_api_key + GITEA_TOKEN: + from_secret: gitea_token + commands: + - pip install wrenn + - export GO_VERSION=$$(grep '^go ' go.mod | cut -d' ' -f2) + - python .woodpecker/scripts/build.py + - VERSION=$$(cat VERSION_CP) + - git config user.name "R3dRum92" + - git config user.email "tksadik@omukk.dev" + - git tag "v$${VERSION}" + - git push "https://tksadik92:$${GITEA_TOKEN}@git.omukk.dev/tksadik92/wrenn-releases.git" "v$${VERSION}" + + sandbox-2: + image: python:3.13 + environment: + WRENN_API_KEY: + from_secret: wrenn_api_key + GITEA_TOKEN: + from_secret: gitea_token + ZHIPU_API_KEY: + from_secret: zhipu_api_key + commands: + - pip install wrenn + - python .woodpecker/scripts/release_notes.py + depends_on: [sandbox-1] + + sandbox-3: + image: python:3.13 + environment: + GITHUB_TOKEN: + from_secret: github_token + commands: + - pip install httpx + - python .woodpecker/scripts/publish_github.py + depends_on: [sandbox-2] diff --git a/.woodpecker/scripts/build.py b/.woodpecker/scripts/build.py new file mode 100644 index 0000000..6bcf22f --- /dev/null +++ b/.woodpecker/scripts/build.py @@ -0,0 +1,126 @@ +import os +import sys + +from wrenn import Capsule, StreamExitEvent, StreamStderrEvent, StreamStdoutEvent +from wrenn._git import GitCommandError +from wrenn.models import FileEntry + +GO_VERSION = os.getenv("GO_VERSION", "1.25.8") +REPO_URL = "https://git.omukk.dev/wrenn/wrenn.git" +REPO_DIR = "/opt/wrenn" +BUILDS_DIR = 
def run(capsule: Capsule, cmd: str, timeout: int = 30) -> int:
    """Run *cmd* inside the capsule and report the outcome.

    Prints ``OK [<program>]`` on success. On failure prints
    ``FAIL [<program>]: exit=<code>`` followed by any captured stderr, both
    to ``sys.stderr``.

    Args:
        capsule: Capsule whose ``commands.run`` executes the command.
        cmd: Shell command line to execute.
        timeout: Timeout forwarded unchanged to ``capsule.commands.run``.

    Returns:
        The command's exit code (0 on success).
    """
    # Log lines are labelled with the program name only. A blank command has
    # no first word, so fall back to the raw string instead of raising
    # IndexError while formatting the log line.
    words = cmd.split()
    label = words[0] if words else cmd

    result = capsule.commands.run(cmd, timeout=timeout)
    if result.exit_code != 0:
        print(f"FAIL [{label}]: exit={result.exit_code}", file=sys.stderr)
        if result.stderr:
            print(result.stderr.strip(), file=sys.stderr)
        return result.exit_code
    print(f"OK [{label}]")
    return 0


def install_go(capsule: Capsule) -> bool:
    """Install build tooling and the Go toolchain ``GO_VERSION`` in the capsule.

    Runs the provisioning steps in order and stops at the first one that
    fails, then verifies the install with ``go version``.

    Returns:
        True when every step and the final ``go version`` check succeed.
    """
    tarball = f"go{GO_VERSION}.linux-amd64.tar.gz"
    url = f"https://go.dev/dl/{tarball}"

    # (command, timeout) pairs, executed in order.
    steps = (
        # apt steps get more than the 30 s default: they download package
        # indexes and a full compiler toolchain over the network.
        ("apt update", 60),
        ("apt install -y make build-essential file", 300),
        # -f makes curl exit non-zero on an HTTP error (e.g. a 404 for an
        # unknown Go version) instead of saving the error page as the tarball
        # and only failing later, confusingly, in tar.
        (f"curl -fLO {url}", 120),
        (f"tar -C /usr/local -xzf {tarball}", 60),
        ('echo "export PATH=$PATH:/usr/local/go/bin" >> ~/.profile', 30),
        (f"rm -f {tarball}", 30),
    )
    # all() over a generator short-circuits, so later steps are skipped after
    # the first failure.
    if not all(run(capsule, cmd, timeout=limit) == 0 for cmd, limit in steps):
        return False

    result = capsule.commands.run("/usr/local/go/bin/go version")
    print(result.stdout.strip())
    return result.exit_code == 0


def clone_repo(capsule: Capsule) -> bool:
    """Clone ``REPO_URL`` into ``REPO_DIR`` inside the capsule.

    Returns:
        True on success; False (after logging to stderr) when the clone
        raises ``GitCommandError``.
    """
    try:
        capsule.git.clone(REPO_URL, REPO_DIR)
    except GitCommandError as e:
        print(f"FAIL [git clone]: {e}", file=sys.stderr)
        return False
    print("OK [git clone]")
    return True
def _release_already_exists(payload: object) -> bool:
    """Return True when a 422 response body reports an ``already_exists`` error."""
    if not isinstance(payload, dict):
        return False
    return any(
        isinstance(err, dict) and err.get("code") == "already_exists"
        for err in payload.get("errors") or []
    )


def main() -> None:
    """Create the GitHub release for the current version and upload builds.

    Reads the version from ``VERSION_FILE`` and optional release notes from
    ``NOTES_FILE``, creates the release ``v<version>`` in ``GITHUB_REPO``,
    then uploads every regular file found in ``BUILDS_DIR`` as an asset.

    Exits with status 1 when the release cannot be created. An already
    existing release is reported as a warning and left untouched. Individual
    asset upload failures are logged as warnings and do not abort the run.

    Raises:
        KeyError: if ``GITHUB_TOKEN`` is not set in the environment.
    """
    token = os.environ["GITHUB_TOKEN"]

    with open(VERSION_FILE) as f:
        version = f.read().strip()
    tag = f"v{version}"

    release_notes = ""
    if os.path.exists(NOTES_FILE):
        with open(NOTES_FILE) as f:
            release_notes = f.read()

    headers = {
        "Authorization": f"token {token}",
        "Accept": "application/vnd.github+json",
        "X-GitHub-Api-Version": "2022-11-28",
    }

    # The context manager closes the connection pool on every exit path,
    # including sys.exit() and unexpected exceptions.
    with httpx.Client(headers=headers, timeout=60) as client:
        print(f"Creating GitHub release for {tag}...")
        resp = client.post(
            f"{GITHUB_API}/repos/{GITHUB_REPO}/releases",
            json={
                "tag_name": tag,
                "name": tag,
                "body": release_notes,
                "draft": False,
                "prerelease": False,
            },
        )

        if resp.status_code == 422:
            # 422 is GitHub's generic validation failure. Only the
            # "already_exists" error code means the release is already there;
            # any other 422 falls through to the failure branch below instead
            # of being reported as a harmless skip.
            try:
                payload = resp.json()
            except ValueError:
                payload = {}
            if _release_already_exists(payload):
                print(f"WARN [create release]: release for {tag} already exists, skipping")
                # documentation_url is a top-level field of the error body,
                # not a field of the individual error entries.
                doc_url = payload.get("documentation_url", "")
                if doc_url:
                    print(f" See: {doc_url}")
                return

        if resp.status_code != 201:
            print(f"FAIL [create release]: {resp.status_code} {resp.text}", file=sys.stderr)
            sys.exit(1)

        release_data = resp.json()
        release_id = release_data["id"]
        release_url = release_data.get("html_url", "")
        print(f"OK [create release] id={release_id}")

        builds_path = Path(BUILDS_DIR)
        if not builds_path.exists():
            print(f"No {BUILDS_DIR}/ directory found, skipping asset upload")
            print(f"Release published: {release_url}")
            return

        upload_headers = {
            **headers,
            "Content-Type": "application/octet-stream",
        }

        for artifact in sorted(builds_path.iterdir()):
            if artifact.is_dir():
                continue
            print(f"Uploading {artifact.name}...")

            resp = client.post(
                f"{GITHUB_UPLOADS}/repos/{GITHUB_REPO}/releases/{release_id}/assets",
                params={"name": artifact.name},
                headers=upload_headers,
                content=artifact.read_bytes(),
            )
            if resp.status_code != 201:
                # Best effort: a failed asset is reported but does not stop
                # the remaining uploads.
                print(
                    f"WARN [upload {artifact.name}]: {resp.status_code} {resp.text}",
                    file=sys.stderr,
                )
            else:
                print(f"OK [upload {artifact.name}]")

    print(f"Release published: {release_url}")
def run(capsule: Capsule, cmd: str, cwd: str | None = None, timeout: int = 30) -> int:
    """Execute *cmd* in the capsule and log whether it succeeded.

    Args:
        capsule: Capsule whose ``commands.run`` executes the command.
        cmd: Shell command line; its first word is used as the log label.
        cwd: Working directory forwarded to ``capsule.commands.run``.
        timeout: Timeout forwarded to ``capsule.commands.run``.

    Returns:
        The command's exit code; 0 means success.
    """
    outcome = capsule.commands.run(cmd, cwd=cwd, timeout=timeout)
    code = outcome.exit_code

    # Success path first; everything below handles the failure.
    if code == 0:
        print(f"OK [{cmd.split()[0]}]")
        return 0

    print(f"FAIL [{cmd.split()[0]}]: exit={code}", file=sys.stderr)
    if outcome.stderr:
        print(outcome.stderr.strip(), file=sys.stderr)
    return code


def install_opencode(capsule: Capsule) -> None:
    """Install the pinned OpenCode CLI inside the capsule.

    Terminates the process with exit status 1 as soon as a step fails.
    """
    print("Installing OpenCode...")

    # (command, timeout) pairs, executed in order.
    steps = (
        ("apt update", 60),
        (
            "curl -fsSL https://opencode.ai/install | bash -s -- --version 1.14.31",
            120,
        ),
    )
    for command, limit in steps:
        if run(capsule, command, timeout=limit) != 0:
            sys.exit(1)

    print("OK [opencode installed]")
def get_git_context(
    capsule: Capsule, current_tag: str, previous_tag: str | None
) -> tuple[str, str]:
    """Collect the commit subjects and diff stat for the release being noted.

    With a previous tag, every commit and the tree diff between the two tags
    are used. Without one (first tag in the repo) the log is capped at 50
    commits and ``git show --stat`` of the tag stands in for the diff.

    Args:
        capsule: Capsule holding the clone at ``REPO_DIR``.
        current_tag: Tag being released.
        previous_tag: Preceding tag, or None when there is none.

    Returns:
        ``(git_log, git_diff)`` as stripped text. Exits the process with
        status 1 when either git command fails.
    """
    import shlex  # local import: the file's import block is outside this block

    # Tag names come from the repository and may contain shell metacharacters
    # such as ';' or '$', so they are quoted before being placed in a command.
    if previous_tag:
        rev_range = shlex.quote(f"{previous_tag}..{current_tag}")
        log_cmd = f"cd {REPO_DIR} && git log {rev_range} --pretty=format:'%s (%h)'"
        diff_cmd = f"cd {REPO_DIR} && git diff {rev_range} --stat"
    else:
        tag = shlex.quote(current_tag)
        log_cmd = f"cd {REPO_DIR} && git log {tag} --pretty=format:'%s (%h)' -n 50"
        diff_cmd = f"cd {REPO_DIR} && git show {tag} --stat"

    log_result = capsule.commands.run(log_cmd, cwd=REPO_DIR, timeout=30)
    if log_result.exit_code != 0:
        print(f"FAIL [git log]: {log_result.stderr}", file=sys.stderr)
        sys.exit(1)

    diff_result = capsule.commands.run(diff_cmd, cwd=REPO_DIR, timeout=30)
    if diff_result.exit_code != 0:
        print(f"FAIL [git diff]: {diff_result.stderr}", file=sys.stderr)
        sys.exit(1)

    return log_result.stdout.strip(), diff_result.stdout.strip()


def generate_release_notes(
    capsule: Capsule,
    current_tag: str,
    git_log: str,
    git_diff: str,
    output_path: str,
    model: str,
) -> None:
    """Ask OpenCode to write release notes into *output_path* in the capsule.

    The prompt is shipped to the capsule base64-encoded and written to
    ``/tmp/oc_prompt.txt``. OpenCode is run with *model*; if a Zhipu model
    fails, one retry is made with ``opencode/minimax-m2.5-free``.

    Exits the process with status 1 when the prompt cannot be written, when
    generation fails, or when the generated file is empty or unreadable.

    Args:
        capsule: Capsule with OpenCode installed and the repo at ``REPO_DIR``.
        current_tag: Version the notes are written for.
        git_log: Commit subjects to summarise.
        git_diff: Diff stat describing the touched files.
        output_path: Path inside the capsule receiving the markdown.
        model: OpenCode model identifier to try first.
    """
    import shlex  # local import: the file's import block is outside this block

    prompt = (
        f"You are writing release notes for version {current_tag} of a software project.\n\n"
        f"Here is what changed between the previous version and this one:\n\n"
        f"Commit messages:\n{git_log}\n\n"
        f"Files and areas that changed:\n{git_diff}\n\n"
        f"Write the release notes in plain, friendly language that any developer can understand "
        f"without deep knowledge of the codebase. Avoid jargon like 'goroutine', 'PTY', 'envd', "
        f"or internal function names — describe what the change means for the user instead. "
        f"Group related changes under headings that reflect what actually changed. "
        f"Only include sections that are relevant to these specific changes. "
        f"Start with a short one-line summary of what this release is about. "
        f"Keep each bullet point to one clear sentence.\n\n"
        f"Here is an example of the style to aim for — not a template to copy:\n\n"
        f"{RELEASE_NOTES_EXAMPLE}\n\n"
        f"You MUST start the document with `## What's New`\n"
        f"The very next line MUST be a single short summary sentence.\n"
        # Trailing newline added: without it the next sentence was glued on
        # as "explanation.CRITICAL:".
        f"Output only the markdown. No intro, no explanation.\n"
        f"CRITICAL: Do not output any conversational filler, acknowledgments, or thoughts "
        f"like 'Let me look at the changes'. Output absolutely nothing except the final markdown."
    )

    # Base64 keeps quotes and newlines in the prompt away from the shell; the
    # encoded text contains no single quote, so the literal below is safe.
    prompt_b64 = base64.b64encode(prompt.encode("utf-8")).decode("utf-8")
    write_prompt_cmd = f"echo '{prompt_b64}' | base64 -d > /tmp/oc_prompt.txt"

    result = capsule.commands.run(write_prompt_cmd, cwd=REPO_DIR, timeout=10)
    if result.exit_code != 0:
        print(f"FAIL [write prompt]: {result.stderr}", file=sys.stderr)
        sys.exit(1)

    quoted_output = shlex.quote(output_path)

    def run_opencode_with_model(target_model: str) -> int:
        """Run OpenCode once with *target_model*; return its exit code."""
        env_prefix = ""
        if "zhipu" in target_model.lower():
            # Quote the key so characters special to the shell cannot break
            # or extend the command line.
            api_key = os.environ.get("ZHIPU_API_KEY", "")
            env_prefix = f"ZHIPU_API_KEY={shlex.quote(api_key)} "

        cmd = (
            f"{env_prefix}"
            f"~/.opencode/bin/opencode run "
            f'"Read the attached file and generate the release notes. Output ONLY markdown." '
            f"--model {shlex.quote(target_model)} "
            f"--file /tmp/oc_prompt.txt "
            f"> {quoted_output}"
        )

        cmd_result = capsule.commands.run(cmd, cwd=REPO_DIR, timeout=120)
        if cmd_result.exit_code != 0:
            print(
                f"FAIL [opencode via {target_model}]: exit={cmd_result.exit_code}",
                file=sys.stderr,
            )
            print(f"STDOUT:\n{cmd_result.stdout}", file=sys.stderr)
            print(f"STDERR:\n{cmd_result.stderr}", file=sys.stderr)
        return cmd_result.exit_code

    exit_status = run_opencode_with_model(model)

    if exit_status != 0:
        # Only a Zhipu failure has a fallback; anything else is fatal.
        if "zhipu" not in model.lower():
            sys.exit(1)
        print(
            "\n[!] Zhipu AI failed (likely rate-limited). Falling back to MiniMax...",
            file=sys.stderr,
        )
        fallback_model = "opencode/minimax-m2.5-free"
        exit_status = run_opencode_with_model(fallback_model)
        if exit_status != 0:
            print("FAIL: Fallback model also failed. Exiting.", file=sys.stderr)
            sys.exit(1)

    result = capsule.commands.run(f"cat {quoted_output}")
    print(result.stdout)
    if result.stderr:
        print(result.stderr)

    # A zero exit status with an empty file would otherwise be reported as
    # success and published as blank release notes.
    if result.exit_code != 0 or not (result.stdout or "").strip():
        print("FAIL [opencode]: generated release notes are empty", file=sys.stderr)
        sys.exit(1)

    print(f"OK [opencode] release notes written to {output_path}")


def download_release_notes(capsule: Capsule) -> None:
    """Copy the generated notes from ``CAPSULE_OUTPUT`` to ``LOCAL_OUTPUT``.

    Creates the local parent directory when needed and echoes the downloaded
    content to stdout.
    """
    local_path = os.path.normpath(LOCAL_OUTPUT)
    os.makedirs(os.path.dirname(local_path), exist_ok=True)

    print("Downloading release notes from capsule...")
    content = capsule.files.read_bytes(CAPSULE_OUTPUT)
    with open(local_path, "wb") as f:
        f.write(content)

    print(f"OK [download] release notes → {local_path}")
    print(content.decode("utf-8", errors="replace"))
git_log, git_diff = get_git_context(capsule, current_tag, previous_tag) + + # Note: This simply creates the directory string safely + output_path = os.path.normpath(CAPSULE_OUTPUT) + + generate_release_notes( + capsule, + current_tag, + git_log, + git_diff, + output_path, + model, + ) + + download_release_notes(capsule) + + +if __name__ == "__main__": + main() diff --git a/VERSION_AGENT b/VERSION_AGENT index 6e8bf73..17e51c3 100644 --- a/VERSION_AGENT +++ b/VERSION_AGENT @@ -1 +1 @@ -0.1.0 +0.1.1 diff --git a/VERSION_CP b/VERSION_CP index b1e80bb..845639e 100644 --- a/VERSION_CP +++ b/VERSION_CP @@ -1 +1 @@ -0.1.3 +0.1.4 diff --git a/cmd/host-agent/main.go b/cmd/host-agent/main.go index 5896c2c..89d65da 100644 --- a/cmd/host-agent/main.go +++ b/cmd/host-agent/main.go @@ -148,7 +148,13 @@ func main() { slog.Info("host registered", "host_id", creds.HostID) // httpServer is declared here so the shutdown func can reference it. - httpServer := &http.Server{Addr: listenAddr} + // ReadTimeout/WriteTimeout are intentionally omitted — they would kill + // long-lived Connect RPC streams and WebSocket proxy connections. + httpServer := &http.Server{ + Addr: listenAddr, + ReadHeaderTimeout: 10 * time.Second, + IdleTimeout: 620 * time.Second, // > typical LB upstream timeout (600s) + } // mTLS is mandatory — refuse to start without a valid certificate. 
var certStore hostagent.CertStore @@ -193,6 +199,7 @@ func main() { path, handler := hostagentv1connect.NewHostAgentServiceHandler(srv) proxyHandler := hostagent.NewProxyHandler(mgr) + mgr.SetOnDestroy(proxyHandler.EvictProxy) mux := http.NewServeMux() mux.Handle(path, handler) diff --git a/envd/VERSION b/envd/VERSION index 6e8bf73..17e51c3 100644 --- a/envd/VERSION +++ b/envd/VERSION @@ -1 +1 @@ -0.1.0 +0.1.1 diff --git a/envd/internal/services/process/handler/handler.go b/envd/internal/services/process/handler/handler.go index dc5a8dd..9a73103 100644 --- a/envd/internal/services/process/handler/handler.go +++ b/envd/internal/services/process/handler/handler.go @@ -446,7 +446,9 @@ func (p *Handler) Wait() { err := p.cmd.Wait() - p.tty.Close() + if p.tty != nil { + p.tty.Close() + } var errMsg *string diff --git a/internal/api/handler_sandbox_proxy.go b/internal/api/handler_sandbox_proxy.go index 5e3754d..523513c 100644 --- a/internal/api/handler_sandbox_proxy.go +++ b/internal/api/handler_sandbox_proxy.go @@ -8,7 +8,6 @@ import ( "net/http" "net/http/httputil" "net/url" - "path" "regexp" "strconv" "strings" @@ -74,7 +73,7 @@ func NewSandboxProxyWrapper(inner http.Handler, queries *db.Queries, pool *lifec inner: inner, db: queries, pool: pool, - transport: pool.Transport(), + transport: pool.NewProxyTransport(), cache: make(map[pgtype.UUID]proxyCacheEntry), } } @@ -167,14 +166,29 @@ func (h *SandboxProxyWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request) return } + // The host agent's proxy adds a /proxy/{id}/{port} prefix to Location + // headers for path-based routing. For subdomain routing the browser is at + // {port}-{id}.domain, so we strip the prefix back out. 
+ agentProxyPrefix := "/proxy/" + sandboxIDStr + "/" + port + proxy := &httputil.ReverseProxy{ Transport: h.transport, Director: func(req *http.Request) { req.URL.Scheme = agentURL.Scheme req.URL.Host = agentURL.Host - req.URL.Path = path.Join("/proxy", sandboxIDStr, port, path.Clean("/"+req.URL.Path)) + // Use string concatenation instead of path.Join to preserve trailing + // slashes. path.Join strips them, causing redirect loops for directory + // listings in apps like python http.server and Jupyter. + req.URL.Path = "/proxy/" + sandboxIDStr + "/" + port + req.URL.Path req.Host = agentURL.Host }, + ModifyResponse: func(resp *http.Response) error { + if loc := resp.Header.Get("Location"); loc != "" { + loc = strings.TrimPrefix(loc, agentProxyPrefix) + resp.Header.Set("Location", loc) + } + return nil + }, ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) { slog.Debug("sandbox proxy error", "sandbox_id", sandboxIDStr, diff --git a/internal/api/handlers_me.go b/internal/api/handlers_me.go index fefd041..194087c 100644 --- a/internal/api/handlers_me.go +++ b/internal/api/handlers_me.go @@ -404,10 +404,10 @@ func (h *meHandler) ConnectProvider(w http.ResponseWriter, r *http.Request) { return } - mac := computeHMAC(h.jwtSecret, state) + mac := computeHMAC(h.jwtSecret, state+":"+"login") http.SetCookie(w, &http.Cookie{ Name: "oauth_state", - Value: state + ":" + mac, + Value: state + ":" + mac + ":" + "login", Path: "/", MaxAge: 600, HttpOnly: true, diff --git a/internal/api/handlers_pty.go b/internal/api/handlers_pty.go index 181fc9d..f23954d 100644 --- a/internal/api/handlers_pty.go +++ b/internal/api/handlers_pty.go @@ -311,10 +311,17 @@ func runPtyLoop( } }() - // Input pump: read from WebSocket, dispatch to host agent. + // Input pump: decouple WebSocket reads from RPC dispatch. + // Reader goroutine drains the WebSocket into a buffered channel; + // sender goroutine dispatches RPCs at its own pace. 
This prevents + // slow RPCs from stalling WebSocket reads and causing proxy timeouts. + inputCh := make(chan wsPtyIn, 64) + + // Reader: drain WebSocket as fast as possible. wg.Add(1) go func() { defer wg.Done() + defer close(inputCh) defer cancel() for { @@ -328,6 +335,22 @@ func runPtyLoop( continue } + select { + case inputCh <- msg: + default: + // Buffer full — drop frame to keep reader unblocked. + slog.Debug("pty input buffer full, dropping frame", "type", msg.Type) + } + } + }() + + // Sender: dispatch RPCs from channel, coalescing consecutive input messages. + wg.Add(1) + go func() { + defer wg.Done() + defer cancel() + + for msg := range inputCh { // Use a background context for unary RPCs so they complete // even if the stream context is being cancelled. rpcCtx, rpcCancel := context.WithTimeout(context.Background(), 5*time.Second) @@ -339,6 +362,10 @@ func runPtyLoop( rpcCancel() continue } + + // Coalesce: drain any queued input messages into a single RPC. + data = coalescePtyInput(inputCh, data) + if _, err := agent.PtySendInput(rpcCtx, connect.NewRequest(&pb.PtySendInputRequest{ SandboxId: sandboxID, Tag: tag, @@ -394,6 +421,33 @@ func runPtyLoop( wg.Wait() } +// coalescePtyInput drains any immediately-available "input" messages from the +// channel and appends their decoded data to buf, reducing RPC call volume +// during bursts of fast typing. +func coalescePtyInput(ch <-chan wsPtyIn, buf []byte) []byte { + for { + select { + case msg, ok := <-ch: + if !ok { + return buf + } + if msg.Type != "input" { + // Non-input message — can't coalesce. Put-back isn't possible + // with channels, but resize/kill during a typing burst is rare + // enough that dropping one is acceptable. + return buf + } + data, err := base64.StdEncoding.DecodeString(msg.Data) + if err != nil { + continue + } + buf = append(buf, data...) + default: + return buf + } + } +} + // newPtyTag returns a PTY session tag: "pty-" + 8 random hex chars. 
func newPtyTag() string { return "pty-" + id.NewPtyTag() diff --git a/internal/api/helpers_ws.go b/internal/api/helpers_ws.go index 8488cbd..f34a1df 100644 --- a/internal/api/helpers_ws.go +++ b/internal/api/helpers_ws.go @@ -3,8 +3,6 @@ package api import ( "context" "fmt" - "net/http" - "strings" "time" "github.com/gorilla/websocket" @@ -14,11 +12,6 @@ import ( "git.omukk.dev/wrenn/wrenn/pkg/id" ) -// isWebSocketUpgrade returns true if the request is a WebSocket upgrade. -func isWebSocketUpgrade(r *http.Request) bool { - return strings.EqualFold(r.Header.Get("Upgrade"), "websocket") -} - // ctxKeyAdminWS is a context key for flagging admin WS routes. type ctxKeyAdminWS struct{} diff --git a/internal/api/middleware_admin.go b/internal/api/middleware_admin.go index 670c586..e850435 100644 --- a/internal/api/middleware_admin.go +++ b/internal/api/middleware_admin.go @@ -15,7 +15,6 @@ func injectPlatformTeam() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if _, ok := auth.FromContext(r.Context()); !ok { - // No auth context yet (WS upgrade); handler will inject platform team after WS auth. next.ServeHTTP(w, r) return } @@ -27,23 +26,24 @@ func injectPlatformTeam() func(http.Handler) http.Handler { } } +// markAdminWS flags the request context as an admin WebSocket route. +// Applied to admin WS endpoints that sit outside the requireJWT/requireAdmin +// middleware group. Handlers use isAdminWSRoute(ctx) to pick wsAuthenticateAdmin. +func markAdminWS(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + next.ServeHTTP(w, r.WithContext(setAdminWSFlag(r.Context()))) + }) +} + // requireAdmin validates that the authenticated user is a platform admin. // Must run after requireJWT (depends on AuthContext being present). 
// Re-validates against the DB — the JWT is_admin claim is for UI only; // the DB is the source of truth for admin access. -// WebSocket upgrade requests without auth context are passed through — -// admin WS handlers verify admin status after upgrade via wsAuthenticateAdmin. func requireAdmin(queries *db.Queries) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ac, ok := auth.FromContext(r.Context()) if !ok { - if isWebSocketUpgrade(r) { - ctx := r.Context() - ctx = setAdminWSFlag(ctx) - next.ServeHTTP(w, r.WithContext(ctx)) - return - } writeError(w, http.StatusUnauthorized, "unauthorized", "authentication required") return } diff --git a/internal/api/middleware_auth.go b/internal/api/middleware_auth.go index 580c8c0..0b3e571 100644 --- a/internal/api/middleware_auth.go +++ b/internal/api/middleware_auth.go @@ -85,15 +85,61 @@ func requireAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler return } - // WebSocket upgrade requests may not carry auth headers (browsers - // cannot set custom headers on WS connections). Pass through — - // the WS handler authenticates via the first message after upgrade. - if isWebSocketUpgrade(r) { - next.ServeHTTP(w, r) - return - } - writeError(w, http.StatusUnauthorized, "unauthorized", "X-API-Key or Authorization: Bearer required") }) } } + +// optionalAPIKeyOrJWT is like requireAPIKeyOrJWT but does not reject +// unauthenticated requests. It injects auth context when valid credentials +// are present (supporting SDK clients that set X-API-Key on WebSocket +// upgrades) and passes through otherwise so the handler can authenticate +// after the WebSocket upgrade via the first message. 
+func optionalAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Try API key. + if key := r.Header.Get("X-API-Key"); key != "" { + hash := auth.HashAPIKey(key) + row, err := queries.GetAPIKeyByHash(r.Context(), hash) + if err == nil { + if err := queries.UpdateAPIKeyLastUsed(r.Context(), row.ID); err != nil { + slog.Warn("failed to update api key last_used", "key_id", id.FormatAPIKeyID(row.ID), "error", err) + } + ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{ + TeamID: row.TeamID, + APIKeyID: row.ID, + APIKeyName: row.Name, + }) + next.ServeHTTP(w, r.WithContext(ctx)) + return + } + } + + // Try JWT bearer token. + if header := r.Header.Get("Authorization"); strings.HasPrefix(header, "Bearer ") { + tokenStr := strings.TrimPrefix(header, "Bearer ") + if claims, err := auth.VerifyJWT(jwtSecret, tokenStr); err == nil { + if teamID, err := id.ParseTeamID(claims.TeamID); err == nil { + if userID, err := id.ParseUserID(claims.Subject); err == nil { + if user, err := queries.GetUserByID(r.Context(), userID); err == nil && user.Status == "active" { + ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{ + TeamID: teamID, + UserID: userID, + Email: claims.Email, + Name: claims.Name, + Role: claims.Role, + }) + next.ServeHTTP(w, r.WithContext(ctx)) + return + } + } + } + } + } + + // No valid credentials — pass through for handler to authenticate. 
+ next.ServeHTTP(w, r) + }) + } +} diff --git a/internal/api/middleware_jwt.go b/internal/api/middleware_jwt.go index b19c838..00649c6 100644 --- a/internal/api/middleware_jwt.go +++ b/internal/api/middleware_jwt.go @@ -22,13 +22,6 @@ func requireJWT(secret []byte, queries *db.Queries) func(http.Handler) http.Hand tokenStr = strings.TrimPrefix(header, "Bearer ") } if tokenStr == "" { - // WebSocket upgrade requests may not have an Authorization header - // (browsers cannot set custom headers on WS connections). Let them - // through — the handler authenticates via the first WS message. - if isWebSocketUpgrade(r) { - next.ServeHTTP(w, r) - return - } writeError(w, http.StatusUnauthorized, "unauthorized", "Authorization: Bearer required") return } diff --git a/internal/api/openapi.yaml b/internal/api/openapi.yaml index 8d3861c..c18c575 100644 --- a/internal/api/openapi.yaml +++ b/internal/api/openapi.yaml @@ -2,7 +2,7 @@ openapi: "3.1.0" info: title: Wrenn API description: MicroVM-based code execution platform API. - version: "0.1.3" + version: "0.1.4" servers: - url: http://localhost:8080 diff --git a/internal/api/server.go b/internal/api/server.go index ced39a5..11b6fbb 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -161,35 +161,47 @@ func New( r.With(requireJWT(jwtSecret, queries)).Get("/v1/users/search", usersH.Search) // Capsule lifecycle: accepts API key or JWT bearer token. - // WebSocket upgrade requests without auth headers are passed through by - // requireAPIKeyOrJWT — the WS handlers authenticate via first message. r.Route("/v1/capsules", func(r chi.Router) { - r.Use(requireAPIKeyOrJWT(queries, jwtSecret)) - r.Post("/", sandbox.Create) - r.Get("/", sandbox.List) - r.Get("/stats", statsH.GetStats) - r.Get("/usage", usageH.GetUsage) + // Auth-required routes. 
+ r.Group(func(r chi.Router) { + r.Use(requireAPIKeyOrJWT(queries, jwtSecret)) + r.Post("/", sandbox.Create) + r.Get("/", sandbox.List) + r.Get("/stats", statsH.GetStats) + r.Get("/usage", usageH.GetUsage) + }) r.Route("/{id}", func(r chi.Router) { - r.Get("/", sandbox.Get) - r.Delete("/", sandbox.Destroy) - r.Post("/exec", exec.Exec) - r.Get("/exec/stream", execStream.ExecStream) - r.Post("/ping", sandbox.Ping) - r.Post("/pause", sandbox.Pause) - r.Post("/resume", sandbox.Resume) - r.Post("/files/write", files.Upload) - r.Post("/files/read", files.Download) - r.Post("/files/stream/write", filesStream.StreamUpload) - r.Post("/files/stream/read", filesStream.StreamDownload) - r.Post("/files/list", fsH.ListDir) - r.Post("/files/mkdir", fsH.MakeDir) - r.Post("/files/remove", fsH.Remove) - r.Get("/metrics", metricsH.GetMetrics) - r.Get("/pty", ptyH.PtySession) - r.Get("/processes", processH.ListProcesses) - r.Delete("/processes/{selector}", processH.KillProcess) - r.Get("/processes/{selector}/stream", processH.ConnectProcess) + // Auth-required non-WS routes. + r.Group(func(r chi.Router) { + r.Use(requireAPIKeyOrJWT(queries, jwtSecret)) + r.Get("/", sandbox.Get) + r.Delete("/", sandbox.Destroy) + r.Post("/exec", exec.Exec) + r.Post("/ping", sandbox.Ping) + r.Post("/pause", sandbox.Pause) + r.Post("/resume", sandbox.Resume) + r.Post("/files/write", files.Upload) + r.Post("/files/read", files.Download) + r.Post("/files/stream/write", filesStream.StreamUpload) + r.Post("/files/stream/read", filesStream.StreamDownload) + r.Post("/files/list", fsH.ListDir) + r.Post("/files/mkdir", fsH.MakeDir) + r.Post("/files/remove", fsH.Remove) + r.Get("/metrics", metricsH.GetMetrics) + r.Get("/processes", processH.ListProcesses) + r.Delete("/processes/{selector}", processH.KillProcess) + }) + + // WebSocket endpoints — handlers authenticate after upgrade. 
+ // optionalAPIKeyOrJWT injects auth context from headers when + // present (SDK clients) but does not reject when absent (browsers). + r.Group(func(r chi.Router) { + r.Use(optionalAPIKeyOrJWT(queries, jwtSecret)) + r.Get("/exec/stream", execStream.ExecStream) + r.Get("/pty", ptyH.PtySession) + r.Get("/processes/{selector}/stream", processH.ConnectProcess) + }) }) }) @@ -248,39 +260,55 @@ func New( // Platform admin routes — require JWT + DB-validated admin status. r.Route("/v1/admin", func(r chi.Router) { - r.Use(requireJWT(jwtSecret, queries)) - r.Use(requireAdmin(queries)) - r.Get("/teams", teamH.AdminListTeams) - r.Put("/teams/{id}/byoc", teamH.SetBYOC) - r.Delete("/teams/{id}", teamH.AdminDeleteTeam) - r.Get("/users", usersH.AdminListUsers) - r.Put("/users/{id}/active", usersH.SetUserActive) - r.Get("/audit-logs", auditH.AdminList) - r.Get("/templates", buildH.ListTemplates) - r.Delete("/templates/{name}", buildH.DeleteTemplate) - r.Post("/builds", buildH.Create) - r.Get("/builds", buildH.List) - r.Get("/builds/{id}", buildH.Get) - r.Post("/builds/{id}/cancel", buildH.Cancel) - r.Post("/capsules", adminCapsules.Create) - r.Get("/capsules", adminCapsules.List) + // Auth-required admin routes (non-capsule + capsule list/create). 
+ r.Group(func(r chi.Router) { + r.Use(requireJWT(jwtSecret, queries)) + r.Use(requireAdmin(queries)) + r.Get("/teams", teamH.AdminListTeams) + r.Put("/teams/{id}/byoc", teamH.SetBYOC) + r.Delete("/teams/{id}", teamH.AdminDeleteTeam) + r.Get("/users", usersH.AdminListUsers) + r.Put("/users/{id}/active", usersH.SetUserActive) + r.Get("/audit-logs", auditH.AdminList) + r.Get("/templates", buildH.ListTemplates) + r.Delete("/templates/{name}", buildH.DeleteTemplate) + r.Post("/builds", buildH.Create) + r.Get("/builds", buildH.List) + r.Get("/builds/{id}", buildH.Get) + r.Post("/builds/{id}/cancel", buildH.Cancel) + r.Post("/capsules", adminCapsules.Create) + r.Get("/capsules", adminCapsules.List) + }) + r.Route("/capsules/{id}", func(r chi.Router) { - r.Use(injectPlatformTeam()) - r.Get("/", adminCapsules.Get) - r.Delete("/", adminCapsules.Destroy) - r.Post("/snapshot", adminCapsules.Snapshot) - r.Post("/exec", exec.Exec) - r.Get("/exec/stream", execStream.ExecStream) - r.Post("/files/write", files.Upload) - r.Post("/files/read", files.Download) - r.Post("/files/list", fsH.ListDir) - r.Post("/files/mkdir", fsH.MakeDir) - r.Post("/files/remove", fsH.Remove) - r.Get("/metrics", metricsH.GetMetrics) - r.Get("/pty", ptyH.PtySession) - r.Get("/processes", processH.ListProcesses) - r.Delete("/processes/{selector}", processH.KillProcess) - r.Get("/processes/{selector}/stream", processH.ConnectProcess) + // Auth-required non-WS admin capsule routes. 
+ r.Group(func(r chi.Router) { + r.Use(requireJWT(jwtSecret, queries)) + r.Use(requireAdmin(queries)) + r.Use(injectPlatformTeam()) + r.Get("/", adminCapsules.Get) + r.Delete("/", adminCapsules.Destroy) + r.Post("/snapshot", adminCapsules.Snapshot) + r.Post("/exec", exec.Exec) + r.Post("/files/write", files.Upload) + r.Post("/files/read", files.Download) + r.Post("/files/list", fsH.ListDir) + r.Post("/files/mkdir", fsH.MakeDir) + r.Post("/files/remove", fsH.Remove) + r.Get("/metrics", metricsH.GetMetrics) + r.Get("/processes", processH.ListProcesses) + r.Delete("/processes/{selector}", processH.KillProcess) + }) + + // Admin WebSocket endpoints — handlers authenticate after upgrade + // via wsAuthenticateAdmin. markAdminWS sets the context flag so + // handlers know to use admin auth instead of regular auth. + r.Group(func(r chi.Router) { + r.Use(markAdminWS) + r.Get("/exec/stream", execStream.ExecStream) + r.Get("/pty", ptyH.PtySession) + r.Get("/processes/{selector}/stream", processH.ConnectProcess) + }) }) }) diff --git a/internal/envdclient/client.go b/internal/envdclient/client.go index 03994b2..294a37e 100644 --- a/internal/envdclient/client.go +++ b/internal/envdclient/client.go @@ -48,6 +48,13 @@ func (c *Client) BaseURL() string { return c.base } +// HTTPClient returns the underlying http.Client used for envd requests. +// Use this instead of http.DefaultClient when making direct HTTP calls to envd +// (e.g. file streaming) to avoid sharing the global transport with proxy traffic. +func (c *Client) HTTPClient() *http.Client { + return c.httpClient +} + // ExecResult holds the output of a command execution. 
type ExecResult struct { Stdout []byte @@ -142,7 +149,7 @@ func (c *Client) ExecStream(ctx context.Context, cmd string, args ...string) (<- return nil, fmt.Errorf("start process: %w", err) } - ch := make(chan ExecStreamEvent, 16) + ch := make(chan ExecStreamEvent, 256) go func() { defer close(ch) defer stream.Close() diff --git a/internal/envdclient/dialer.go b/internal/envdclient/dialer.go index ea6492d..1813ceb 100644 --- a/internal/envdclient/dialer.go +++ b/internal/envdclient/dialer.go @@ -2,7 +2,9 @@ package envdclient import ( "fmt" + "net" "net/http" + "time" ) // envdPort is the default port envd listens on inside the guest. @@ -13,9 +15,19 @@ func baseURL(hostIP string) string { return fmt.Sprintf("http://%s:%d", hostIP, envdPort) } -// newHTTPClient returns an http.Client suitable for talking to envd. -// No special transport is needed — envd is reachable via the host IP -// through the veth/TAP network path. +// newHTTPClient returns an http.Client with a dedicated transport for talking +// to envd. The transport is intentionally separate from http.DefaultTransport +// so that proxy traffic to user services inside the sandbox cannot interfere +// with envd RPC connections (PTY streams, exec, file ops). func newHTTPClient() *http.Client { - return &http.Client{} + return &http.Client{ + Transport: &http.Transport{ + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 90 * time.Second, + DialContext: (&net.Dialer{ + Timeout: 10 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + }, + } } diff --git a/internal/envdclient/pty.go b/internal/envdclient/pty.go index 7a625fb..f94a1b0 100644 --- a/internal/envdclient/pty.go +++ b/internal/envdclient/pty.go @@ -162,7 +162,7 @@ type eventProvider interface { // drainPtyStream reads events from either a Start or Connect stream and maps // them into PtyEvent values on a channel. 
func drainPtyStream(ctx context.Context, stream eventProvider, expectStart bool) <-chan PtyEvent { - ch := make(chan PtyEvent, 16) + ch := make(chan PtyEvent, 256) go func() { defer close(ch) defer stream.Close() diff --git a/internal/hostagent/proxy.go b/internal/hostagent/proxy.go index 7a5097d..d7c875f 100644 --- a/internal/hostagent/proxy.go +++ b/internal/hostagent/proxy.go @@ -1,16 +1,28 @@ package hostagent import ( + "context" "fmt" "log/slog" + "net" "net/http" "net/http/httputil" + "net/url" "strconv" "strings" + "sync" + "time" "git.omukk.dev/wrenn/wrenn/internal/sandbox" ) +const ( + // proxyDialAttempts is the number of connection attempts for the proxy + // transport. Retries handle the delay between a process binding to a port + // inside the guest and socat/Go-proxy starting to forward on the TAP IP. + proxyDialAttempts = 3 +) + // ProxyHandler reverse-proxies HTTP requests to services running inside // sandboxes. It handles requests of the form: // @@ -21,16 +33,75 @@ import ( type ProxyHandler struct { mgr *sandbox.Manager transport http.RoundTripper + + // proxies caches ReverseProxy instances per sandbox+port to avoid + // per-request allocation under high-frequency REST polling. + proxies sync.Map // key: "sandboxID/port" → *httputil.ReverseProxy +} + +// newProxyTransport returns an HTTP transport dedicated to proxying user +// traffic into sandboxes. It is intentionally separate from the envdclient +// transport and http.DefaultTransport to prevent proxy traffic from +// interfering with Connect RPC streams (PTY, exec). 
+func newProxyTransport() http.RoundTripper { + dialer := &net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 20 * time.Second, + } + + return &http.Transport{ + ForceAttemptHTTP2: false, // HTTP/1.1 only — avoids HTTP/2 HOL blocking + MaxIdleConnsPerHost: 20, + MaxIdleConns: 100, + IdleConnTimeout: 120 * time.Second, + DisableCompression: true, + // Retry with linear backoff to handle the delay between a process + // binding inside the guest and the port forwarder making it reachable. + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + var conn net.Conn + var err error + for attempt := range proxyDialAttempts { + conn, err = dialer.DialContext(ctx, network, addr) + if err == nil { + return conn, nil + } + if ctx.Err() != nil { + return nil, ctx.Err() + } + // Don't sleep on the last attempt. + if attempt < proxyDialAttempts-1 { + backoff := time.Duration(100*(attempt+1)) * time.Millisecond + select { + case <-time.After(backoff): + case <-ctx.Done(): + return nil, ctx.Err() + } + } + } + return nil, err + }, + } } // NewProxyHandler creates a new sandbox proxy handler. func NewProxyHandler(mgr *sandbox.Manager) *ProxyHandler { return &ProxyHandler{ mgr: mgr, - transport: http.DefaultTransport, + transport: newProxyTransport(), } } +// EvictProxy removes cached reverse proxy instances for a sandbox. +// Call this when a sandbox is destroyed. +func (h *ProxyHandler) EvictProxy(sandboxID string) { + h.proxies.Range(func(key, _ any) bool { + if k, ok := key.(string); ok && strings.HasPrefix(k, sandboxID+"/") { + h.proxies.Delete(key) + } + return true + }) +} + // ServeHTTP implements http.Handler. func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Expected path: /proxy/{sandbox_id}/{port}/... 
@@ -49,10 +120,6 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { sandboxID := parts[0] port := parts[1] - remainder := "" - if len(parts) == 3 { - remainder = parts[2] - } // Validate port is a number in the valid range. portNum, err := strconv.Atoi(port) @@ -68,22 +135,61 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } defer tracker.Release() - targetHost := fmt.Sprintf("%s:%d", hostIP, portNum) + proxy := h.getOrCreateProxy(sandboxID, port, fmt.Sprintf("%s:%d", hostIP, portNum)) + proxy.ServeHTTP(w, r) +} + +// getOrCreateProxy returns a cached ReverseProxy for the given sandbox+port+host, +// creating one if it doesn't exist. The targetHost is included in the key so +// that an IP change after pause/resume naturally misses the old entry. +func (h *ProxyHandler) getOrCreateProxy(sandboxID, port, targetHost string) *httputil.ReverseProxy { + cacheKey := sandboxID + "/" + port + "/" + targetHost + + if v, ok := h.proxies.Load(cacheKey); ok { + return v.(*httputil.ReverseProxy) + } + + proxyPrefix := "/proxy/" + sandboxID + "/" + port proxy := &httputil.ReverseProxy{ Transport: h.transport, Director: func(req *http.Request) { + // Extract remainder from the original path: /proxy/{id}/{port}/{remainder} + remainder := "" + if trimmed := strings.TrimPrefix(req.URL.Path, proxyPrefix); trimmed != req.URL.Path { + remainder = strings.TrimPrefix(trimmed, "/") + } + req.URL.Scheme = "http" req.URL.Host = targetHost req.URL.Path = "/" + remainder - req.URL.RawQuery = r.URL.RawQuery req.Host = targetHost }, + // Rewrite redirect Location headers so they include the /proxy/{id}/{port} + // prefix. Handles both root-relative (/path) and absolute-URL redirects + // (http://internal-ip:port/path) that would otherwise leak internal IPs + // or break directory navigation. 
+ ModifyResponse: func(resp *http.Response) error { + loc := resp.Header.Get("Location") + if loc == "" { + return nil + } + if strings.HasPrefix(loc, "/") { + resp.Header.Set("Location", proxyPrefix+loc) + return nil + } + // Rewrite absolute URLs pointing to the internal target host. + if u, err := url.Parse(loc); err == nil && u.Host == targetHost { + resp.Header.Set("Location", proxyPrefix+u.RequestURI()) + } + return nil + }, ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) { slog.Debug("proxy error", "sandbox_id", sandboxID, "port", port, "error", err) http.Error(w, "proxy error: "+err.Error(), http.StatusBadGateway) }, } - proxy.ServeHTTP(w, r) + actual, _ := h.proxies.LoadOrStore(cacheKey, proxy) + return actual.(*httputil.ReverseProxy) } diff --git a/internal/hostagent/server.go b/internal/hostagent/server.go index 663d2cb..e15ef0b 100644 --- a/internal/hostagent/server.go +++ b/internal/hostagent/server.go @@ -459,7 +459,7 @@ func (s *Server) WriteFileStream( } httpReq.Header.Set("Content-Type", mpWriter.FormDataContentType()) - resp, err := http.DefaultClient.Do(httpReq) + resp, err := client.HTTPClient().Do(httpReq) if err != nil { pw.CloseWithError(err) <-errCh @@ -504,7 +504,7 @@ func (s *Server) ReadFileStream( return connect.NewError(connect.CodeInternal, fmt.Errorf("create request: %w", err)) } - resp, err := http.DefaultClient.Do(httpReq) + resp, err := client.HTTPClient().Do(httpReq) if err != nil { return connect.NewError(connect.CodeInternal, fmt.Errorf("read file stream: %w", err)) } diff --git a/internal/network/setup.go b/internal/network/setup.go index 3874c79..d68da89 100644 --- a/internal/network/setup.go +++ b/internal/network/setup.go @@ -269,6 +269,7 @@ func CreateNetwork(slot *Slot) error { // Create TAP device inside namespace. tapAttrs := netlink.NewLinkAttrs() tapAttrs.Name = tapName + tapAttrs.TxQLen = 5000 // Up from default 1000 to reduce drops under bursty traffic. 
tap := &netlink.Tuntap{ LinkAttrs: tapAttrs, Mode: netlink.TUNTAP_MODE_TAP, diff --git a/internal/sandbox/manager.go b/internal/sandbox/manager.go index 524631d..daa1dba 100644 --- a/internal/sandbox/manager.go +++ b/internal/sandbox/manager.go @@ -53,6 +53,15 @@ type Manager struct { autoPausedMu sync.Mutex autoPausedIDs []string + + // onDestroy is called with the sandbox ID after cleanup completes. + // Used by ProxyHandler to evict cached reverse proxies. + onDestroy func(sandboxID string) +} + +// SetOnDestroy registers a callback invoked after each sandbox is cleaned up. +func (m *Manager) SetOnDestroy(fn func(sandboxID string)) { + m.onDestroy = fn } // sandboxState holds the runtime state for a single sandbox. @@ -314,6 +323,10 @@ func (m *Manager) Destroy(ctx context.Context, sandboxID string) error { slog.Warn("snapshot cleanup error", "id", sandboxID, "error", err) } + if m.onDestroy != nil { + m.onDestroy(sandboxID) + } + slog.Info("sandbox destroyed", "id", sandboxID) return nil } @@ -363,6 +376,11 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error { return fmt.Errorf("sandbox %s is not running (status: %s)", sandboxID, sb.Status) } + // Stop the metrics sampler goroutine before tearing down any resources + // it reads (dm device, Firecracker PID). Without this, the sampler + // leaks on every successful pause. + m.stopSampler(sb) + // Step 0: Drain in-flight proxy connections before freezing vCPUs. // This prevents Go runtime corruption inside the guest caused by stale // TCP state from connections that were alive when the VM was snapshotted. diff --git a/internal/vm/fc.go b/internal/vm/fc.go index 3d0f246..5a131a4 100644 --- a/internal/vm/fc.go +++ b/internal/vm/fc.go @@ -84,11 +84,21 @@ func (c *fcClient) setRootfsDrive(ctx context.Context, driveID, path string, rea } // setNetworkInterface configures a network interface attached to a TAP device. 
+// A tx_rate_limiter caps sustained guest→host throughput to prevent user +// application traffic from completely saturating the TAP device and starving +// envd control traffic (PTY, exec, file ops). func (c *fcClient) setNetworkInterface(ctx context.Context, ifaceID, tapName, macAddr string) error { return c.do(ctx, http.MethodPut, "/network-interfaces/"+ifaceID, map[string]any{ "iface_id": ifaceID, "host_dev_name": tapName, "guest_mac": macAddr, + "tx_rate_limiter": map[string]any{ + "bandwidth": map[string]any{ + "size": 209715200, // 200 MB/s sustained + "refill_time": 1000, // refill period: 1 second + "one_time_burst": 104857600, // 100 MB initial burst + }, + }, }) } diff --git a/pkg/lifecycle/hostpool.go b/pkg/lifecycle/hostpool.go index 3931d7b..48ed6c9 100644 --- a/pkg/lifecycle/hostpool.go +++ b/pkg/lifecycle/hostpool.go @@ -3,6 +3,7 @@ package lifecycle import ( "crypto/tls" "fmt" + "net" "net/http" "strings" "sync" @@ -115,6 +116,34 @@ func (p *HostClientPool) ResolveAddr(addr string) string { return p.ensureScheme(addr) } +// NewProxyTransport returns a new http.RoundTripper configured for proxying +// user traffic to sandbox services. It is intentionally separate from the RPC +// transport returned by Transport() so that heavy proxy traffic (Jupyter +// WebSocket, REST API polling) cannot interfere with Connect RPC streams (PTY, +// exec) via HTTP/2 flow control or connection pool contention. +func (p *HostClientPool) NewProxyTransport() http.RoundTripper { + t := &http.Transport{ + ForceAttemptHTTP2: false, // HTTP/1.1 only — avoids HTTP/2 HOL blocking + MaxIdleConnsPerHost: 20, + MaxIdleConns: 100, + IdleConnTimeout: 120 * time.Second, + DisableCompression: true, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 20 * time.Second, + }).DialContext, + } + + // If the pool uses TLS, the proxy transport must too. 
+ if p.httpClient.Transport != nil { + if ht, ok := p.httpClient.Transport.(*http.Transport); ok && ht.TLSClientConfig != nil { + t.TLSClientConfig = ht.TLSClientConfig.Clone() + } + } + + return t +} + // EnsureScheme adds "http://" if the address has no scheme. // Deprecated: use pool.ResolveAddr which respects the pool's TLS setting. func EnsureScheme(addr string) string {