From 339cd7bee13df655041694dcf276358ec17c1731 Mon Sep 17 00:00:00 2001 From: pptx704 Date: Fri, 24 Apr 2026 15:48:38 +0600 Subject: [PATCH 1/3] fix: security and stability fixes from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Scope WebSocket auth bypass to only WS endpoints by restructuring routes into separate chi Groups. Non-WS routes no longer passthrough unauthenticated requests with spoofed Upgrade headers. Added optionalAPIKeyOrJWT middleware for WS routes (injects auth context from API key/JWT if present, passes through otherwise) and markAdminWS middleware for admin WS routes. - Fix nil pointer dereference in envd Handler.Wait() — p.tty.Close() was called unconditionally but p.tty is nil for non-PTY processes, crashing every non-PTY process exit. - Fix goroutine leak in sandbox Pause — stopSampler was never called, leaking one sampler goroutine per successful pause operation. - Decouple PTY WebSocket reads from RPC dispatch using a buffered channel to prevent backpressure-induced connection drops under fast typing. Includes input coalescing to reduce RPC call volume. --- VERSION_CP | 2 +- .../services/process/handler/handler.go | 4 +- internal/api/handlers_pty.go | 56 ++++++- internal/api/helpers_ws.go | 7 - internal/api/middleware_admin.go | 18 +-- internal/api/middleware_auth.go | 62 +++++++- internal/api/middleware_jwt.go | 7 - internal/api/openapi.yaml | 2 +- internal/api/server.go | 144 +++++++++++------- internal/sandbox/manager.go | 5 + 10 files changed, 214 insertions(+), 93 deletions(-) diff --git a/VERSION_CP b/VERSION_CP index b1e80bb..845639e 100644 --- a/VERSION_CP +++ b/VERSION_CP @@ -1 +1 @@ -0.1.3 +0.1.4 diff --git a/envd/internal/services/process/handler/handler.go b/envd/internal/services/process/handler/handler.go index dc5a8dd..9a73103 100644 --- a/envd/internal/services/process/handler/handler.go +++ b/envd/internal/services/process/handler/handler.go @@ -446,7 +446,9 @@ func (p *Handler) Wait() { err := p.cmd.Wait() - p.tty.Close() + if p.tty != nil { + p.tty.Close() + } var errMsg *string diff --git a/internal/api/handlers_pty.go b/internal/api/handlers_pty.go index 181fc9d..f23954d 100644 --- a/internal/api/handlers_pty.go +++ b/internal/api/handlers_pty.go @@ -311,10 +311,17 @@ func runPtyLoop( } }() - // Input pump: read from WebSocket, dispatch to host agent. + // Input pump: decouple WebSocket reads from RPC dispatch. + // Reader goroutine drains the WebSocket into a buffered channel; + // sender goroutine dispatches RPCs at its own pace. This prevents + // slow RPCs from stalling WebSocket reads and causing proxy timeouts. + inputCh := make(chan wsPtyIn, 64) + + // Reader: drain WebSocket as fast as possible. wg.Add(1) go func() { defer wg.Done() + defer close(inputCh) defer cancel() for { @@ -328,6 +335,22 @@ func runPtyLoop( continue } + select { + case inputCh <- msg: + default: + // Buffer full — drop frame to keep reader unblocked. + slog.Debug("pty input buffer full, dropping frame", "type", msg.Type) + } + } + }() + + // Sender: dispatch RPCs from channel, coalescing consecutive input messages. + wg.Add(1) + go func() { + defer wg.Done() + defer cancel() + + for msg := range inputCh { // Use a background context for unary RPCs so they complete // even if the stream context is being cancelled. rpcCtx, rpcCancel := context.WithTimeout(context.Background(), 5*time.Second) @@ -339,6 +362,10 @@ func runPtyLoop( rpcCancel() continue } + + // Coalesce: drain any queued input messages into a single RPC. + data = coalescePtyInput(inputCh, data) + if _, err := agent.PtySendInput(rpcCtx, connect.NewRequest(&pb.PtySendInputRequest{ SandboxId: sandboxID, Tag: tag, @@ -394,6 +421,33 @@ func runPtyLoop( wg.Wait() } +// coalescePtyInput drains any immediately-available "input" messages from the +// channel and appends their decoded data to buf, reducing RPC call volume +// during bursts of fast typing. +func coalescePtyInput(ch <-chan wsPtyIn, buf []byte) []byte { + for { + select { + case msg, ok := <-ch: + if !ok { + return buf + } + if msg.Type != "input" { + // Non-input message — can't coalesce. Put-back isn't possible + // with channels, but resize/kill during a typing burst is rare + // enough that dropping one is acceptable. + return buf + } + data, err := base64.StdEncoding.DecodeString(msg.Data) + if err != nil { + continue + } + buf = append(buf, data...) + default: + return buf + } + } +} + // newPtyTag returns a PTY session tag: "pty-" + 8 random hex chars. func newPtyTag() string { return "pty-" + id.NewPtyTag() diff --git a/internal/api/helpers_ws.go b/internal/api/helpers_ws.go index 8488cbd..f34a1df 100644 --- a/internal/api/helpers_ws.go +++ b/internal/api/helpers_ws.go @@ -3,8 +3,6 @@ package api import ( "context" "fmt" - "net/http" - "strings" "time" "github.com/gorilla/websocket" @@ -14,11 +12,6 @@ import ( "git.omukk.dev/wrenn/wrenn/pkg/id" ) -// isWebSocketUpgrade returns true if the request is a WebSocket upgrade. -func isWebSocketUpgrade(r *http.Request) bool { - return strings.EqualFold(r.Header.Get("Upgrade"), "websocket") -} - // ctxKeyAdminWS is a context key for flagging admin WS routes. type ctxKeyAdminWS struct{} diff --git a/internal/api/middleware_admin.go b/internal/api/middleware_admin.go index 670c586..e850435 100644 --- a/internal/api/middleware_admin.go +++ b/internal/api/middleware_admin.go @@ -15,7 +15,6 @@ func injectPlatformTeam() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if _, ok := auth.FromContext(r.Context()); !ok { - // No auth context yet (WS upgrade); handler will inject platform team after WS auth. next.ServeHTTP(w, r) return } @@ -27,23 +26,24 @@ func injectPlatformTeam() func(http.Handler) http.Handler { } } +// markAdminWS flags the request context as an admin WebSocket route. +// Applied to admin WS endpoints that sit outside the requireJWT/requireAdmin +// middleware group. Handlers use isAdminWSRoute(ctx) to pick wsAuthenticateAdmin. +func markAdminWS(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + next.ServeHTTP(w, r.WithContext(setAdminWSFlag(r.Context()))) + }) +} + // requireAdmin validates that the authenticated user is a platform admin. // Must run after requireJWT (depends on AuthContext being present). // Re-validates against the DB — the JWT is_admin claim is for UI only; // the DB is the source of truth for admin access. -// WebSocket upgrade requests without auth context are passed through — -// admin WS handlers verify admin status after upgrade via wsAuthenticateAdmin. func requireAdmin(queries *db.Queries) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ac, ok := auth.FromContext(r.Context()) if !ok { - if isWebSocketUpgrade(r) { - ctx := r.Context() - ctx = setAdminWSFlag(ctx) - next.ServeHTTP(w, r.WithContext(ctx)) - return - } writeError(w, http.StatusUnauthorized, "unauthorized", "authentication required") return } diff --git a/internal/api/middleware_auth.go b/internal/api/middleware_auth.go index 580c8c0..0b3e571 100644 --- a/internal/api/middleware_auth.go +++ b/internal/api/middleware_auth.go @@ -85,15 +85,61 @@ func requireAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler return } - // WebSocket upgrade requests may not carry auth headers (browsers - // cannot set custom headers on WS connections). Pass through — - // the WS handler authenticates via the first message after upgrade. - if isWebSocketUpgrade(r) { - next.ServeHTTP(w, r) - return - } - writeError(w, http.StatusUnauthorized, "unauthorized", "X-API-Key or Authorization: Bearer required") }) } } + +// optionalAPIKeyOrJWT is like requireAPIKeyOrJWT but does not reject +// unauthenticated requests. It injects auth context when valid credentials +// are present (supporting SDK clients that set X-API-Key on WebSocket +// upgrades) and passes through otherwise so the handler can authenticate +// after the WebSocket upgrade via the first message. +func optionalAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Try API key. + if key := r.Header.Get("X-API-Key"); key != "" { + hash := auth.HashAPIKey(key) + row, err := queries.GetAPIKeyByHash(r.Context(), hash) + if err == nil { + if err := queries.UpdateAPIKeyLastUsed(r.Context(), row.ID); err != nil { + slog.Warn("failed to update api key last_used", "key_id", id.FormatAPIKeyID(row.ID), "error", err) + } + ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{ + TeamID: row.TeamID, + APIKeyID: row.ID, + APIKeyName: row.Name, + }) + next.ServeHTTP(w, r.WithContext(ctx)) + return + } + } + + // Try JWT bearer token. + if header := r.Header.Get("Authorization"); strings.HasPrefix(header, "Bearer ") { + tokenStr := strings.TrimPrefix(header, "Bearer ") + if claims, err := auth.VerifyJWT(jwtSecret, tokenStr); err == nil { + if teamID, err := id.ParseTeamID(claims.TeamID); err == nil { + if userID, err := id.ParseUserID(claims.Subject); err == nil { + if user, err := queries.GetUserByID(r.Context(), userID); err == nil && user.Status == "active" { + ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{ + TeamID: teamID, + UserID: userID, + Email: claims.Email, + Name: claims.Name, + Role: claims.Role, + }) + next.ServeHTTP(w, r.WithContext(ctx)) + return + } + } + } + } + } + + // No valid credentials — pass through for handler to authenticate. + next.ServeHTTP(w, r) + }) + } +} diff --git a/internal/api/middleware_jwt.go b/internal/api/middleware_jwt.go index b19c838..00649c6 100644 --- a/internal/api/middleware_jwt.go +++ b/internal/api/middleware_jwt.go @@ -22,13 +22,6 @@ func requireJWT(secret []byte, queries *db.Queries) func(http.Handler) http.Hand tokenStr = strings.TrimPrefix(header, "Bearer ") } if tokenStr == "" { - // WebSocket upgrade requests may not have an Authorization header - // (browsers cannot set custom headers on WS connections). Let them - // through — the handler authenticates via the first WS message. - if isWebSocketUpgrade(r) { - next.ServeHTTP(w, r) - return - } writeError(w, http.StatusUnauthorized, "unauthorized", "Authorization: Bearer required") return } diff --git a/internal/api/openapi.yaml b/internal/api/openapi.yaml index 8d3861c..c18c575 100644 --- a/internal/api/openapi.yaml +++ b/internal/api/openapi.yaml @@ -2,7 +2,7 @@ openapi: "3.1.0" info: title: Wrenn API description: MicroVM-based code execution platform API. - version: "0.1.3" + version: "0.1.4" servers: - url: http://localhost:8080 diff --git a/internal/api/server.go b/internal/api/server.go index ced39a5..11b6fbb 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -161,35 +161,47 @@ func New( r.With(requireJWT(jwtSecret, queries)).Get("/v1/users/search", usersH.Search) // Capsule lifecycle: accepts API key or JWT bearer token. - // WebSocket upgrade requests without auth headers are passed through by - // requireAPIKeyOrJWT — the WS handlers authenticate via first message. r.Route("/v1/capsules", func(r chi.Router) { - r.Use(requireAPIKeyOrJWT(queries, jwtSecret)) - r.Post("/", sandbox.Create) - r.Get("/", sandbox.List) - r.Get("/stats", statsH.GetStats) - r.Get("/usage", usageH.GetUsage) + // Auth-required routes. + r.Group(func(r chi.Router) { + r.Use(requireAPIKeyOrJWT(queries, jwtSecret)) + r.Post("/", sandbox.Create) + r.Get("/", sandbox.List) + r.Get("/stats", statsH.GetStats) + r.Get("/usage", usageH.GetUsage) + }) r.Route("/{id}", func(r chi.Router) { - r.Get("/", sandbox.Get) - r.Delete("/", sandbox.Destroy) - r.Post("/exec", exec.Exec) - r.Get("/exec/stream", execStream.ExecStream) - r.Post("/ping", sandbox.Ping) - r.Post("/pause", sandbox.Pause) - r.Post("/resume", sandbox.Resume) - r.Post("/files/write", files.Upload) - r.Post("/files/read", files.Download) - r.Post("/files/stream/write", filesStream.StreamUpload) - r.Post("/files/stream/read", filesStream.StreamDownload) - r.Post("/files/list", fsH.ListDir) - r.Post("/files/mkdir", fsH.MakeDir) - r.Post("/files/remove", fsH.Remove) - r.Get("/metrics", metricsH.GetMetrics) - r.Get("/pty", ptyH.PtySession) - r.Get("/processes", processH.ListProcesses) - r.Delete("/processes/{selector}", processH.KillProcess) - r.Get("/processes/{selector}/stream", processH.ConnectProcess) + // Auth-required non-WS routes. + r.Group(func(r chi.Router) { + r.Use(requireAPIKeyOrJWT(queries, jwtSecret)) + r.Get("/", sandbox.Get) + r.Delete("/", sandbox.Destroy) + r.Post("/exec", exec.Exec) + r.Post("/ping", sandbox.Ping) + r.Post("/pause", sandbox.Pause) + r.Post("/resume", sandbox.Resume) + r.Post("/files/write", files.Upload) + r.Post("/files/read", files.Download) + r.Post("/files/stream/write", filesStream.StreamUpload) + r.Post("/files/stream/read", filesStream.StreamDownload) + r.Post("/files/list", fsH.ListDir) + r.Post("/files/mkdir", fsH.MakeDir) + r.Post("/files/remove", fsH.Remove) + r.Get("/metrics", metricsH.GetMetrics) + r.Get("/processes", processH.ListProcesses) + r.Delete("/processes/{selector}", processH.KillProcess) + }) + + // WebSocket endpoints — handlers authenticate after upgrade. + // optionalAPIKeyOrJWT injects auth context from headers when + // present (SDK clients) but does not reject when absent (browsers). + r.Group(func(r chi.Router) { + r.Use(optionalAPIKeyOrJWT(queries, jwtSecret)) + r.Get("/exec/stream", execStream.ExecStream) + r.Get("/pty", ptyH.PtySession) + r.Get("/processes/{selector}/stream", processH.ConnectProcess) + }) }) }) @@ -248,39 +260,55 @@ func New( // Platform admin routes — require JWT + DB-validated admin status. r.Route("/v1/admin", func(r chi.Router) { - r.Use(requireJWT(jwtSecret, queries)) - r.Use(requireAdmin(queries)) - r.Get("/teams", teamH.AdminListTeams) - r.Put("/teams/{id}/byoc", teamH.SetBYOC) - r.Delete("/teams/{id}", teamH.AdminDeleteTeam) - r.Get("/users", usersH.AdminListUsers) - r.Put("/users/{id}/active", usersH.SetUserActive) - r.Get("/audit-logs", auditH.AdminList) - r.Get("/templates", buildH.ListTemplates) - r.Delete("/templates/{name}", buildH.DeleteTemplate) - r.Post("/builds", buildH.Create) - r.Get("/builds", buildH.List) - r.Get("/builds/{id}", buildH.Get) - r.Post("/builds/{id}/cancel", buildH.Cancel) - r.Post("/capsules", adminCapsules.Create) - r.Get("/capsules", adminCapsules.List) + // Auth-required admin routes (non-capsule + capsule list/create). + r.Group(func(r chi.Router) { + r.Use(requireJWT(jwtSecret, queries)) + r.Use(requireAdmin(queries)) + r.Get("/teams", teamH.AdminListTeams) + r.Put("/teams/{id}/byoc", teamH.SetBYOC) + r.Delete("/teams/{id}", teamH.AdminDeleteTeam) + r.Get("/users", usersH.AdminListUsers) + r.Put("/users/{id}/active", usersH.SetUserActive) + r.Get("/audit-logs", auditH.AdminList) + r.Get("/templates", buildH.ListTemplates) + r.Delete("/templates/{name}", buildH.DeleteTemplate) + r.Post("/builds", buildH.Create) + r.Get("/builds", buildH.List) + r.Get("/builds/{id}", buildH.Get) + r.Post("/builds/{id}/cancel", buildH.Cancel) + r.Post("/capsules", adminCapsules.Create) + r.Get("/capsules", adminCapsules.List) + }) + r.Route("/capsules/{id}", func(r chi.Router) { - r.Use(injectPlatformTeam()) - r.Get("/", adminCapsules.Get) - r.Delete("/", adminCapsules.Destroy) - r.Post("/snapshot", adminCapsules.Snapshot) - r.Post("/exec", exec.Exec) - r.Get("/exec/stream", execStream.ExecStream) - r.Post("/files/write", files.Upload) - r.Post("/files/read", files.Download) - r.Post("/files/list", fsH.ListDir) - r.Post("/files/mkdir", fsH.MakeDir) - r.Post("/files/remove", fsH.Remove) - r.Get("/metrics", metricsH.GetMetrics) - r.Get("/pty", ptyH.PtySession) - r.Get("/processes", processH.ListProcesses) - r.Delete("/processes/{selector}", processH.KillProcess) - r.Get("/processes/{selector}/stream", processH.ConnectProcess) + // Auth-required non-WS admin capsule routes. + r.Group(func(r chi.Router) { + r.Use(requireJWT(jwtSecret, queries)) + r.Use(requireAdmin(queries)) + r.Use(injectPlatformTeam()) + r.Get("/", adminCapsules.Get) + r.Delete("/", adminCapsules.Destroy) + r.Post("/snapshot", adminCapsules.Snapshot) + r.Post("/exec", exec.Exec) + r.Post("/files/write", files.Upload) + r.Post("/files/read", files.Download) + r.Post("/files/list", fsH.ListDir) + r.Post("/files/mkdir", fsH.MakeDir) + r.Post("/files/remove", fsH.Remove) + r.Get("/metrics", metricsH.GetMetrics) + r.Get("/processes", processH.ListProcesses) + r.Delete("/processes/{selector}", processH.KillProcess) + }) + + // Admin WebSocket endpoints — handlers authenticate after upgrade + // via wsAuthenticateAdmin. markAdminWS sets the context flag so + // handlers know to use admin auth instead of regular auth. + r.Group(func(r chi.Router) { + r.Use(markAdminWS) + r.Get("/exec/stream", execStream.ExecStream) + r.Get("/pty", ptyH.PtySession) + r.Get("/processes/{selector}/stream", processH.ConnectProcess) + }) }) }) diff --git a/internal/sandbox/manager.go b/internal/sandbox/manager.go index 524631d..a7ff69d 100644 --- a/internal/sandbox/manager.go +++ b/internal/sandbox/manager.go @@ -363,6 +363,11 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error { return fmt.Errorf("sandbox %s is not running (status: %s)", sandboxID, sb.Status) } + // Stop the metrics sampler goroutine before tearing down any resources + // it reads (dm device, Firecracker PID). Without this, the sampler + // leaks on every successful pause. + m.stopSampler(sb) + // Step 0: Drain in-flight proxy connections before freezing vCPUs. // This prevents Go runtime corruption inside the guest caused by stale // TCP state from connections that were alive when the VM was snapshotted. From 5e13879954a5b3cf9b0381ea1a6c2e210b8ee489 Mon Sep 17 00:00:00 2001 From: pptx704 Date: Sat, 25 Apr 2026 02:00:39 +0600 Subject: [PATCH 2/3] fix: OAuth ConnectProvider state HMAC format mismatch ConnectProvider computed HMAC over bare state, but Callback always verifies HMAC(state+":"+intent). This caused the account-linking flow to always fail with invalid_state. --- internal/api/handlers_me.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/api/handlers_me.go b/internal/api/handlers_me.go index fefd041..194087c 100644 --- a/internal/api/handlers_me.go +++ b/internal/api/handlers_me.go @@ -404,10 +404,10 @@ func (h *meHandler) ConnectProvider(w http.ResponseWriter, r *http.Request) { return } - mac := computeHMAC(h.jwtSecret, state) + mac := computeHMAC(h.jwtSecret, state+":"+"login") http.SetCookie(w, &http.Cookie{ Name: "oauth_state", - Value: state + ":" + mac, + Value: state + ":" + mac + ":" + "login", Path: "/", MaxAge: 600, HttpOnly: true, From bd986101538e44dbd57889f0e3b0fde532dcf3a7 Mon Sep 17 00:00:00 2001 From: pptx704 Date: Sat, 25 Apr 2026 04:21:55 +0600 Subject: [PATCH 3/3] fix: sandbox network responsiveness under port-binding apps Running port-binding applications (Jupyter, http.server, NextJS) inside sandboxes caused severe PTY sluggishness and proxy navigation errors. Root cause: the CP sandbox proxy and Connect RPC pool shared a single HTTP transport. Heavy proxy traffic (Jupyter WebSocket, REST polling) interfered with PTY RPC streams via HTTP/2 flow control contention. Transport isolation (main fix): - Add dedicated proxy transport on CP (NewProxyTransport) with HTTP/2 disabled, separate from the RPC pool transport - Add dedicated proxy transport on host agent, replacing http.DefaultTransport - Add dedicated envdclient transport with tuned connection pooling - Replace http.DefaultClient in file streaming RPCs with per-sandbox envd client Proxy path rewriting (navigation fix): - Add ModifyResponse to rewrite Location headers with /proxy/{id}/{port} prefix, handling both root-relative and absolute-URL redirects - Strip prefix back out in CP subdomain proxy for correct browser behavior - Replace path.Join with string concat in CP Director to preserve trailing slashes (prevents redirect loops on directory listings) Proxy resilience: - Add dial retry with linear backoff (3 attempts) to handle socat startup delay when ports are first detected - Cache ReverseProxy instances per sandbox+port+host in sync.Map - Add EvictProxy callback wired into sandbox Manager.Destroy Buffer and server hardening: - Increase PTY and exec stream channel buffers from 16 to 256 - Add ReadHeaderTimeout (10s) and IdleTimeout (620s) to host agent HTTP server Network tuning: - Set TAP device TxQueueLen to 5000 (up from default 1000) - Add Firecracker tx_rate_limiter (200 MB/s sustained, 100 MB burst) to prevent guest traffic from saturating the TAP --- cmd/host-agent/main.go | 9 +- internal/api/handler_sandbox_proxy.go | 20 ++++- internal/envdclient/client.go | 9 +- internal/envdclient/dialer.go | 20 ++++- internal/envdclient/pty.go | 2 +- internal/hostagent/proxy.go | 122 ++++++++++++++++++++++++-- internal/hostagent/server.go | 4 +- internal/network/setup.go | 1 + internal/sandbox/manager.go | 13 +++ internal/vm/fc.go | 10 +++ pkg/lifecycle/hostpool.go | 29 ++++++ 11 files changed, 219 insertions(+), 20 deletions(-) diff --git a/cmd/host-agent/main.go b/cmd/host-agent/main.go index 5896c2c..89d65da 100644 --- a/cmd/host-agent/main.go +++ b/cmd/host-agent/main.go @@ -148,7 +148,13 @@ func main() { slog.Info("host registered", "host_id", creds.HostID) // httpServer is declared here so the shutdown func can reference it. - httpServer := &http.Server{Addr: listenAddr} + // ReadTimeout/WriteTimeout are intentionally omitted — they would kill + // long-lived Connect RPC streams and WebSocket proxy connections. + httpServer := &http.Server{ + Addr: listenAddr, + ReadHeaderTimeout: 10 * time.Second, + IdleTimeout: 620 * time.Second, // > typical LB upstream timeout (600s) + } // mTLS is mandatory — refuse to start without a valid certificate. var certStore hostagent.CertStore @@ -193,6 +199,7 @@ func main() { path, handler := hostagentv1connect.NewHostAgentServiceHandler(srv) proxyHandler := hostagent.NewProxyHandler(mgr) + mgr.SetOnDestroy(proxyHandler.EvictProxy) mux := http.NewServeMux() mux.Handle(path, handler) diff --git a/internal/api/handler_sandbox_proxy.go b/internal/api/handler_sandbox_proxy.go index 5e3754d..523513c 100644 --- a/internal/api/handler_sandbox_proxy.go +++ b/internal/api/handler_sandbox_proxy.go @@ -8,7 +8,6 @@ import ( "net/http" "net/http/httputil" "net/url" - "path" "regexp" "strconv" "strings" @@ -74,7 +73,7 @@ func NewSandboxProxyWrapper(inner http.Handler, queries *db.Queries, pool *lifec inner: inner, db: queries, pool: pool, - transport: pool.Transport(), + transport: pool.NewProxyTransport(), cache: make(map[pgtype.UUID]proxyCacheEntry), } } @@ -167,14 +166,29 @@ func (h *SandboxProxyWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request) return } + // The host agent's proxy adds a /proxy/{id}/{port} prefix to Location + // headers for path-based routing. For subdomain routing the browser is at + // {port}-{id}.domain, so we strip the prefix back out. + agentProxyPrefix := "/proxy/" + sandboxIDStr + "/" + port + proxy := &httputil.ReverseProxy{ Transport: h.transport, Director: func(req *http.Request) { req.URL.Scheme = agentURL.Scheme req.URL.Host = agentURL.Host - req.URL.Path = path.Join("/proxy", sandboxIDStr, port, path.Clean("/"+req.URL.Path)) + // Use string concatenation instead of path.Join to preserve trailing + // slashes. path.Join strips them, causing redirect loops for directory + // listings in apps like python http.server and Jupyter. + req.URL.Path = "/proxy/" + sandboxIDStr + "/" + port + req.URL.Path req.Host = agentURL.Host }, + ModifyResponse: func(resp *http.Response) error { + if loc := resp.Header.Get("Location"); loc != "" { + loc = strings.TrimPrefix(loc, agentProxyPrefix) + resp.Header.Set("Location", loc) + } + return nil + }, ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) { slog.Debug("sandbox proxy error", "sandbox_id", sandboxIDStr, diff --git a/internal/envdclient/client.go b/internal/envdclient/client.go index 03994b2..294a37e 100644 --- a/internal/envdclient/client.go +++ b/internal/envdclient/client.go @@ -48,6 +48,13 @@ func (c *Client) BaseURL() string { return c.base } +// HTTPClient returns the underlying http.Client used for envd requests. +// Use this instead of http.DefaultClient when making direct HTTP calls to envd +// (e.g. file streaming) to avoid sharing the global transport with proxy traffic. +func (c *Client) HTTPClient() *http.Client { + return c.httpClient +} + // ExecResult holds the output of a command execution. type ExecResult struct { Stdout []byte @@ -142,7 +149,7 @@ func (c *Client) ExecStream(ctx context.Context, cmd string, args ...string) (<- return nil, fmt.Errorf("start process: %w", err) } - ch := make(chan ExecStreamEvent, 16) + ch := make(chan ExecStreamEvent, 256) go func() { defer close(ch) defer stream.Close() diff --git a/internal/envdclient/dialer.go b/internal/envdclient/dialer.go index ea6492d..1813ceb 100644 --- a/internal/envdclient/dialer.go +++ b/internal/envdclient/dialer.go @@ -2,7 +2,9 @@ package envdclient import ( "fmt" + "net" "net/http" + "time" ) // envdPort is the default port envd listens on inside the guest. @@ -13,9 +15,19 @@ func baseURL(hostIP string) string { return fmt.Sprintf("http://%s:%d", hostIP, envdPort) } -// newHTTPClient returns an http.Client suitable for talking to envd. -// No special transport is needed — envd is reachable via the host IP -// through the veth/TAP network path. +// newHTTPClient returns an http.Client with a dedicated transport for talking +// to envd. The transport is intentionally separate from http.DefaultTransport +// so that proxy traffic to user services inside the sandbox cannot interfere +// with envd RPC connections (PTY streams, exec, file ops). func newHTTPClient() *http.Client { - return &http.Client{} + return &http.Client{ + Transport: &http.Transport{ + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 90 * time.Second, + DialContext: (&net.Dialer{ + Timeout: 10 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + }, + } } diff --git a/internal/envdclient/pty.go b/internal/envdclient/pty.go index 7a625fb..f94a1b0 100644 --- a/internal/envdclient/pty.go +++ b/internal/envdclient/pty.go @@ -162,7 +162,7 @@ type eventProvider interface { // drainPtyStream reads events from either a Start or Connect stream and maps // them into PtyEvent values on a channel. func drainPtyStream(ctx context.Context, stream eventProvider, expectStart bool) <-chan PtyEvent { - ch := make(chan PtyEvent, 16) + ch := make(chan PtyEvent, 256) go func() { defer close(ch) defer stream.Close() diff --git a/internal/hostagent/proxy.go b/internal/hostagent/proxy.go index 7a5097d..d7c875f 100644 --- a/internal/hostagent/proxy.go +++ b/internal/hostagent/proxy.go @@ -1,16 +1,28 @@ package hostagent import ( + "context" "fmt" "log/slog" + "net" "net/http" "net/http/httputil" + "net/url" "strconv" "strings" + "sync" + "time" "git.omukk.dev/wrenn/wrenn/internal/sandbox" ) +const ( + // proxyDialAttempts is the number of connection attempts for the proxy + // transport. Retries handle the delay between a process binding to a port + // inside the guest and socat/Go-proxy starting to forward on the TAP IP. + proxyDialAttempts = 3 +) + // ProxyHandler reverse-proxies HTTP requests to services running inside // sandboxes. It handles requests of the form: // @@ -21,16 +33,75 @@ import ( type ProxyHandler struct { mgr *sandbox.Manager transport http.RoundTripper + + // proxies caches ReverseProxy instances per sandbox+port to avoid + // per-request allocation under high-frequency REST polling. + proxies sync.Map // key: "sandboxID/port" → *httputil.ReverseProxy +} + +// newProxyTransport returns an HTTP transport dedicated to proxying user +// traffic into sandboxes. It is intentionally separate from the envdclient +// transport and http.DefaultTransport to prevent proxy traffic from +// interfering with Connect RPC streams (PTY, exec). +func newProxyTransport() http.RoundTripper { + dialer := &net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 20 * time.Second, + } + + return &http.Transport{ + ForceAttemptHTTP2: false, // HTTP/1.1 only — avoids HTTP/2 HOL blocking + MaxIdleConnsPerHost: 20, + MaxIdleConns: 100, + IdleConnTimeout: 120 * time.Second, + DisableCompression: true, + // Retry with linear backoff to handle the delay between a process + // binding inside the guest and the port forwarder making it reachable. + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + var conn net.Conn + var err error + for attempt := range proxyDialAttempts { + conn, err = dialer.DialContext(ctx, network, addr) + if err == nil { + return conn, nil + } + if ctx.Err() != nil { + return nil, ctx.Err() + } + // Don't sleep on the last attempt. + if attempt < proxyDialAttempts-1 { + backoff := time.Duration(100*(attempt+1)) * time.Millisecond + select { + case <-time.After(backoff): + case <-ctx.Done(): + return nil, ctx.Err() + } + } + } + return nil, err + }, + } } // NewProxyHandler creates a new sandbox proxy handler. func NewProxyHandler(mgr *sandbox.Manager) *ProxyHandler { return &ProxyHandler{ mgr: mgr, - transport: http.DefaultTransport, + transport: newProxyTransport(), } } +// EvictProxy removes cached reverse proxy instances for a sandbox. +// Call this when a sandbox is destroyed. +func (h *ProxyHandler) EvictProxy(sandboxID string) { + h.proxies.Range(func(key, _ any) bool { + if k, ok := key.(string); ok && strings.HasPrefix(k, sandboxID+"/") { + h.proxies.Delete(key) + } + return true + }) +} + // ServeHTTP implements http.Handler. func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Expected path: /proxy/{sandbox_id}/{port}/... @@ -49,10 +120,6 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { sandboxID := parts[0] port := parts[1] - remainder := "" - if len(parts) == 3 { - remainder = parts[2] - } // Validate port is a number in the valid range. portNum, err := strconv.Atoi(port) @@ -68,22 +135,61 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } defer tracker.Release() - targetHost := fmt.Sprintf("%s:%d", hostIP, portNum) + proxy := h.getOrCreateProxy(sandboxID, port, fmt.Sprintf("%s:%d", hostIP, portNum)) + proxy.ServeHTTP(w, r) +} + +// getOrCreateProxy returns a cached ReverseProxy for the given sandbox+port+host, +// creating one if it doesn't exist. The targetHost is included in the key so +// that an IP change after pause/resume naturally misses the old entry. +func (h *ProxyHandler) getOrCreateProxy(sandboxID, port, targetHost string) *httputil.ReverseProxy { + cacheKey := sandboxID + "/" + port + "/" + targetHost + + if v, ok := h.proxies.Load(cacheKey); ok { + return v.(*httputil.ReverseProxy) + } + + proxyPrefix := "/proxy/" + sandboxID + "/" + port proxy := &httputil.ReverseProxy{ Transport: h.transport, Director: func(req *http.Request) { + // Extract remainder from the original path: /proxy/{id}/{port}/{remainder} + remainder := "" + if trimmed := strings.TrimPrefix(req.URL.Path, proxyPrefix); trimmed != req.URL.Path { + remainder = strings.TrimPrefix(trimmed, "/") + } + req.URL.Scheme = "http" req.URL.Host = targetHost req.URL.Path = "/" + remainder - req.URL.RawQuery = r.URL.RawQuery req.Host = targetHost }, + // Rewrite redirect Location headers so they include the /proxy/{id}/{port} + // prefix. Handles both root-relative (/path) and absolute-URL redirects + // (http://internal-ip:port/path) that would otherwise leak internal IPs + // or break directory navigation. + ModifyResponse: func(resp *http.Response) error { + loc := resp.Header.Get("Location") + if loc == "" { + return nil + } + if strings.HasPrefix(loc, "/") { + resp.Header.Set("Location", proxyPrefix+loc) + return nil + } + // Rewrite absolute URLs pointing to the internal target host. + if u, err := url.Parse(loc); err == nil && u.Host == targetHost { + resp.Header.Set("Location", proxyPrefix+u.RequestURI()) + } + return nil + }, ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) { slog.Debug("proxy error", "sandbox_id", sandboxID, "port", port, "error", err) http.Error(w, "proxy error: "+err.Error(), http.StatusBadGateway) }, } - proxy.ServeHTTP(w, r) + actual, _ := h.proxies.LoadOrStore(cacheKey, proxy) + return actual.(*httputil.ReverseProxy) } diff --git a/internal/hostagent/server.go b/internal/hostagent/server.go index 663d2cb..e15ef0b 100644 --- a/internal/hostagent/server.go +++ b/internal/hostagent/server.go @@ -459,7 +459,7 @@ func (s *Server) WriteFileStream( } httpReq.Header.Set("Content-Type", mpWriter.FormDataContentType()) - resp, err := http.DefaultClient.Do(httpReq) + resp, err := client.HTTPClient().Do(httpReq) if err != nil { pw.CloseWithError(err) <-errCh @@ -504,7 +504,7 @@ func (s *Server) ReadFileStream( return connect.NewError(connect.CodeInternal, fmt.Errorf("create request: %w", err)) } - resp, err := http.DefaultClient.Do(httpReq) + resp, err := client.HTTPClient().Do(httpReq) if err != nil { return connect.NewError(connect.CodeInternal, fmt.Errorf("read file stream: %w", err)) } diff --git a/internal/network/setup.go b/internal/network/setup.go index 3874c79..d68da89 100644 --- a/internal/network/setup.go +++ b/internal/network/setup.go @@ -269,6 +269,7 @@ func CreateNetwork(slot *Slot) error { // Create TAP device inside namespace. tapAttrs := netlink.NewLinkAttrs() tapAttrs.Name = tapName + tapAttrs.TxQLen = 5000 // Up from default 1000 to reduce drops under bursty traffic. tap := &netlink.Tuntap{ LinkAttrs: tapAttrs, Mode: netlink.TUNTAP_MODE_TAP, diff --git a/internal/sandbox/manager.go b/internal/sandbox/manager.go index a7ff69d..daa1dba 100644 --- a/internal/sandbox/manager.go +++ b/internal/sandbox/manager.go @@ -53,6 +53,15 @@ type Manager struct { autoPausedMu sync.Mutex autoPausedIDs []string + + // onDestroy is called with the sandbox ID after cleanup completes. + // Used by ProxyHandler to evict cached reverse proxies. + onDestroy func(sandboxID string) +} + +// SetOnDestroy registers a callback invoked after each sandbox is cleaned up. +func (m *Manager) SetOnDestroy(fn func(sandboxID string)) { + m.onDestroy = fn } // sandboxState holds the runtime state for a single sandbox. @@ -314,6 +323,10 @@ func (m *Manager) Destroy(ctx context.Context, sandboxID string) error { slog.Warn("snapshot cleanup error", "id", sandboxID, "error", err) } + if m.onDestroy != nil { + m.onDestroy(sandboxID) + } + slog.Info("sandbox destroyed", "id", sandboxID) return nil } diff --git a/internal/vm/fc.go b/internal/vm/fc.go index 3d0f246..5a131a4 100644 --- a/internal/vm/fc.go +++ b/internal/vm/fc.go @@ -84,11 +84,21 @@ func (c *fcClient) setRootfsDrive(ctx context.Context, driveID, path string, rea } // setNetworkInterface configures a network interface attached to a TAP device. +// A tx_rate_limiter caps sustained guest→host throughput to prevent user +// application traffic from completely saturating the TAP device and starving +// envd control traffic (PTY, exec, file ops). func (c *fcClient) setNetworkInterface(ctx context.Context, ifaceID, tapName, macAddr string) error { return c.do(ctx, http.MethodPut, "/network-interfaces/"+ifaceID, map[string]any{ "iface_id": ifaceID, "host_dev_name": tapName, "guest_mac": macAddr, + "tx_rate_limiter": map[string]any{ + "bandwidth": map[string]any{ + "size": 209715200, // 200 MB/s sustained + "refill_time": 1000, // refill period: 1 second + "one_time_burst": 104857600, // 100 MB initial burst + }, + }, }) } diff --git a/pkg/lifecycle/hostpool.go b/pkg/lifecycle/hostpool.go index 3931d7b..48ed6c9 100644 --- a/pkg/lifecycle/hostpool.go +++ b/pkg/lifecycle/hostpool.go @@ -3,6 +3,7 @@ package lifecycle import ( "crypto/tls" "fmt" + "net" "net/http" "strings" "sync" @@ -115,6 +116,34 @@ func (p *HostClientPool) ResolveAddr(addr string) string { return p.ensureScheme(addr) } +// NewProxyTransport returns a new http.RoundTripper configured for proxying +// user traffic to sandbox services. It is intentionally separate from the RPC +// transport returned by Transport() so that heavy proxy traffic (Jupyter +// WebSocket, REST API polling) cannot interfere with Connect RPC streams (PTY, +// exec) via HTTP/2 flow control or connection pool contention. +func (p *HostClientPool) NewProxyTransport() http.RoundTripper { + t := &http.Transport{ + ForceAttemptHTTP2: false, // HTTP/1.1 only — avoids HTTP/2 HOL blocking + MaxIdleConnsPerHost: 20, + MaxIdleConns: 100, + IdleConnTimeout: 120 * time.Second, + DisableCompression: true, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 20 * time.Second, + }).DialContext, + } + + // If the pool uses TLS, the proxy transport must too. + if p.httpClient.Transport != nil { + if ht, ok := p.httpClient.Transport.(*http.Transport); ok && ht.TLSClientConfig != nil { + t.TLSClientConfig = ht.TLSClientConfig.Clone() + } + } + + return t +} + // EnsureScheme adds "http://" if the address has no scheme. // Deprecated: use pool.ResolveAddr which respects the pool's TLS setting. func EnsureScheme(addr string) string {