From 339cd7bee13df655041694dcf276358ec17c1731 Mon Sep 17 00:00:00 2001 From: pptx704 Date: Fri, 24 Apr 2026 15:48:38 +0600 Subject: [PATCH] fix: security and stability fixes from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Scope WebSocket auth bypass to only WS endpoints by restructuring routes into separate chi Groups. Non-WS routes no longer pass through unauthenticated requests with spoofed Upgrade headers. Added optionalAPIKeyOrJWT middleware for WS routes (injects auth context from API key/JWT if present, passes through otherwise) and markAdminWS middleware for admin WS routes. - Fix nil pointer dereference in envd Handler.Wait() — p.tty.Close() was called unconditionally but p.tty is nil for non-PTY processes, crashing on every non-PTY process exit. - Fix goroutine leak in sandbox Pause — stopSampler was never called, leaking one sampler goroutine per successful pause operation. - Decouple PTY WebSocket reads from RPC dispatch using a buffered channel to prevent backpressure-induced connection drops under fast typing. Includes input coalescing to reduce RPC call volume. 
--- VERSION_CP | 2 +- .../services/process/handler/handler.go | 4 +- internal/api/handlers_pty.go | 56 ++++++- internal/api/helpers_ws.go | 7 - internal/api/middleware_admin.go | 18 +-- internal/api/middleware_auth.go | 62 +++++++- internal/api/middleware_jwt.go | 7 - internal/api/openapi.yaml | 2 +- internal/api/server.go | 144 +++++++++++------- internal/sandbox/manager.go | 5 + 10 files changed, 214 insertions(+), 93 deletions(-) diff --git a/VERSION_CP b/VERSION_CP index b1e80bb..845639e 100644 --- a/VERSION_CP +++ b/VERSION_CP @@ -1 +1 @@ -0.1.3 +0.1.4 diff --git a/envd/internal/services/process/handler/handler.go b/envd/internal/services/process/handler/handler.go index dc5a8dd..9a73103 100644 --- a/envd/internal/services/process/handler/handler.go +++ b/envd/internal/services/process/handler/handler.go @@ -446,7 +446,9 @@ func (p *Handler) Wait() { err := p.cmd.Wait() - p.tty.Close() + if p.tty != nil { + p.tty.Close() + } var errMsg *string diff --git a/internal/api/handlers_pty.go b/internal/api/handlers_pty.go index 181fc9d..f23954d 100644 --- a/internal/api/handlers_pty.go +++ b/internal/api/handlers_pty.go @@ -311,10 +311,17 @@ func runPtyLoop( } }() - // Input pump: read from WebSocket, dispatch to host agent. + // Input pump: decouple WebSocket reads from RPC dispatch. + // Reader goroutine drains the WebSocket into a buffered channel; + // sender goroutine dispatches RPCs at its own pace. This prevents + // slow RPCs from stalling WebSocket reads and causing proxy timeouts. + inputCh := make(chan wsPtyIn, 64) + + // Reader: drain WebSocket as fast as possible. wg.Add(1) go func() { defer wg.Done() + defer close(inputCh) defer cancel() for { @@ -328,6 +335,22 @@ func runPtyLoop( continue } + select { + case inputCh <- msg: + default: + // Buffer full — drop frame to keep reader unblocked. 
+ slog.Debug("pty input buffer full, dropping frame", "type", msg.Type) + } + } + }() + + // Sender: dispatch RPCs from channel, coalescing consecutive input messages. + wg.Add(1) + go func() { + defer wg.Done() + defer cancel() + + for msg := range inputCh { // Use a background context for unary RPCs so they complete // even if the stream context is being cancelled. rpcCtx, rpcCancel := context.WithTimeout(context.Background(), 5*time.Second) @@ -339,6 +362,10 @@ func runPtyLoop( rpcCancel() continue } + + // Coalesce: drain any queued input messages into a single RPC. + data = coalescePtyInput(inputCh, data) + if _, err := agent.PtySendInput(rpcCtx, connect.NewRequest(&pb.PtySendInputRequest{ SandboxId: sandboxID, Tag: tag, @@ -394,6 +421,33 @@ func runPtyLoop( wg.Wait() } +// coalescePtyInput drains any immediately-available "input" messages from the +// channel and appends their decoded data to buf, reducing RPC call volume +// during bursts of fast typing. +func coalescePtyInput(ch <-chan wsPtyIn, buf []byte) []byte { + for { + select { + case msg, ok := <-ch: + if !ok { + return buf + } + if msg.Type != "input" { + // Non-input message — can't coalesce. Put-back isn't possible + // with channels, but resize/kill during a typing burst is rare + // enough that dropping one is acceptable. + return buf + } + data, err := base64.StdEncoding.DecodeString(msg.Data) + if err != nil { + continue + } + buf = append(buf, data...) + default: + return buf + } + } +} + // newPtyTag returns a PTY session tag: "pty-" + 8 random hex chars. 
func newPtyTag() string { return "pty-" + id.NewPtyTag() diff --git a/internal/api/helpers_ws.go b/internal/api/helpers_ws.go index 8488cbd..f34a1df 100644 --- a/internal/api/helpers_ws.go +++ b/internal/api/helpers_ws.go @@ -3,8 +3,6 @@ package api import ( "context" "fmt" - "net/http" - "strings" "time" "github.com/gorilla/websocket" @@ -14,11 +12,6 @@ import ( "git.omukk.dev/wrenn/wrenn/pkg/id" ) -// isWebSocketUpgrade returns true if the request is a WebSocket upgrade. -func isWebSocketUpgrade(r *http.Request) bool { - return strings.EqualFold(r.Header.Get("Upgrade"), "websocket") -} - // ctxKeyAdminWS is a context key for flagging admin WS routes. type ctxKeyAdminWS struct{} diff --git a/internal/api/middleware_admin.go b/internal/api/middleware_admin.go index 670c586..e850435 100644 --- a/internal/api/middleware_admin.go +++ b/internal/api/middleware_admin.go @@ -15,7 +15,6 @@ func injectPlatformTeam() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if _, ok := auth.FromContext(r.Context()); !ok { - // No auth context yet (WS upgrade); handler will inject platform team after WS auth. next.ServeHTTP(w, r) return } @@ -27,23 +26,24 @@ func injectPlatformTeam() func(http.Handler) http.Handler { } } +// markAdminWS flags the request context as an admin WebSocket route. +// Applied to admin WS endpoints that sit outside the requireJWT/requireAdmin +// middleware group. Handlers use isAdminWSRoute(ctx) to pick wsAuthenticateAdmin. +func markAdminWS(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + next.ServeHTTP(w, r.WithContext(setAdminWSFlag(r.Context()))) + }) +} + // requireAdmin validates that the authenticated user is a platform admin. // Must run after requireJWT (depends on AuthContext being present). 
// Re-validates against the DB — the JWT is_admin claim is for UI only; // the DB is the source of truth for admin access. -// WebSocket upgrade requests without auth context are passed through — -// admin WS handlers verify admin status after upgrade via wsAuthenticateAdmin. func requireAdmin(queries *db.Queries) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ac, ok := auth.FromContext(r.Context()) if !ok { - if isWebSocketUpgrade(r) { - ctx := r.Context() - ctx = setAdminWSFlag(ctx) - next.ServeHTTP(w, r.WithContext(ctx)) - return - } writeError(w, http.StatusUnauthorized, "unauthorized", "authentication required") return } diff --git a/internal/api/middleware_auth.go b/internal/api/middleware_auth.go index 580c8c0..0b3e571 100644 --- a/internal/api/middleware_auth.go +++ b/internal/api/middleware_auth.go @@ -85,15 +85,61 @@ func requireAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler return } - // WebSocket upgrade requests may not carry auth headers (browsers - // cannot set custom headers on WS connections). Pass through — - // the WS handler authenticates via the first message after upgrade. - if isWebSocketUpgrade(r) { - next.ServeHTTP(w, r) - return - } - writeError(w, http.StatusUnauthorized, "unauthorized", "X-API-Key or Authorization: Bearer required") }) } } + +// optionalAPIKeyOrJWT is like requireAPIKeyOrJWT but does not reject +// unauthenticated requests. It injects auth context when valid credentials +// are present (supporting SDK clients that set X-API-Key on WebSocket +// upgrades) and passes through otherwise so the handler can authenticate +// after the WebSocket upgrade via the first message. 
+func optionalAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Try API key. + if key := r.Header.Get("X-API-Key"); key != "" { + hash := auth.HashAPIKey(key) + row, err := queries.GetAPIKeyByHash(r.Context(), hash) + if err == nil { + if err := queries.UpdateAPIKeyLastUsed(r.Context(), row.ID); err != nil { + slog.Warn("failed to update api key last_used", "key_id", id.FormatAPIKeyID(row.ID), "error", err) + } + ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{ + TeamID: row.TeamID, + APIKeyID: row.ID, + APIKeyName: row.Name, + }) + next.ServeHTTP(w, r.WithContext(ctx)) + return + } + } + + // Try JWT bearer token. + if header := r.Header.Get("Authorization"); strings.HasPrefix(header, "Bearer ") { + tokenStr := strings.TrimPrefix(header, "Bearer ") + if claims, err := auth.VerifyJWT(jwtSecret, tokenStr); err == nil { + if teamID, err := id.ParseTeamID(claims.TeamID); err == nil { + if userID, err := id.ParseUserID(claims.Subject); err == nil { + if user, err := queries.GetUserByID(r.Context(), userID); err == nil && user.Status == "active" { + ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{ + TeamID: teamID, + UserID: userID, + Email: claims.Email, + Name: claims.Name, + Role: claims.Role, + }) + next.ServeHTTP(w, r.WithContext(ctx)) + return + } + } + } + } + } + + // No valid credentials — pass through for handler to authenticate. 
+ next.ServeHTTP(w, r) + }) + } +} diff --git a/internal/api/middleware_jwt.go b/internal/api/middleware_jwt.go index b19c838..00649c6 100644 --- a/internal/api/middleware_jwt.go +++ b/internal/api/middleware_jwt.go @@ -22,13 +22,6 @@ func requireJWT(secret []byte, queries *db.Queries) func(http.Handler) http.Hand tokenStr = strings.TrimPrefix(header, "Bearer ") } if tokenStr == "" { - // WebSocket upgrade requests may not have an Authorization header - // (browsers cannot set custom headers on WS connections). Let them - // through — the handler authenticates via the first WS message. - if isWebSocketUpgrade(r) { - next.ServeHTTP(w, r) - return - } writeError(w, http.StatusUnauthorized, "unauthorized", "Authorization: Bearer required") return } diff --git a/internal/api/openapi.yaml b/internal/api/openapi.yaml index 8d3861c..c18c575 100644 --- a/internal/api/openapi.yaml +++ b/internal/api/openapi.yaml @@ -2,7 +2,7 @@ openapi: "3.1.0" info: title: Wrenn API description: MicroVM-based code execution platform API. - version: "0.1.3" + version: "0.1.4" servers: - url: http://localhost:8080 diff --git a/internal/api/server.go b/internal/api/server.go index ced39a5..11b6fbb 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -161,35 +161,47 @@ func New( r.With(requireJWT(jwtSecret, queries)).Get("/v1/users/search", usersH.Search) // Capsule lifecycle: accepts API key or JWT bearer token. - // WebSocket upgrade requests without auth headers are passed through by - // requireAPIKeyOrJWT — the WS handlers authenticate via first message. r.Route("/v1/capsules", func(r chi.Router) { - r.Use(requireAPIKeyOrJWT(queries, jwtSecret)) - r.Post("/", sandbox.Create) - r.Get("/", sandbox.List) - r.Get("/stats", statsH.GetStats) - r.Get("/usage", usageH.GetUsage) + // Auth-required routes. 
+ r.Group(func(r chi.Router) { + r.Use(requireAPIKeyOrJWT(queries, jwtSecret)) + r.Post("/", sandbox.Create) + r.Get("/", sandbox.List) + r.Get("/stats", statsH.GetStats) + r.Get("/usage", usageH.GetUsage) + }) r.Route("/{id}", func(r chi.Router) { - r.Get("/", sandbox.Get) - r.Delete("/", sandbox.Destroy) - r.Post("/exec", exec.Exec) - r.Get("/exec/stream", execStream.ExecStream) - r.Post("/ping", sandbox.Ping) - r.Post("/pause", sandbox.Pause) - r.Post("/resume", sandbox.Resume) - r.Post("/files/write", files.Upload) - r.Post("/files/read", files.Download) - r.Post("/files/stream/write", filesStream.StreamUpload) - r.Post("/files/stream/read", filesStream.StreamDownload) - r.Post("/files/list", fsH.ListDir) - r.Post("/files/mkdir", fsH.MakeDir) - r.Post("/files/remove", fsH.Remove) - r.Get("/metrics", metricsH.GetMetrics) - r.Get("/pty", ptyH.PtySession) - r.Get("/processes", processH.ListProcesses) - r.Delete("/processes/{selector}", processH.KillProcess) - r.Get("/processes/{selector}/stream", processH.ConnectProcess) + // Auth-required non-WS routes. + r.Group(func(r chi.Router) { + r.Use(requireAPIKeyOrJWT(queries, jwtSecret)) + r.Get("/", sandbox.Get) + r.Delete("/", sandbox.Destroy) + r.Post("/exec", exec.Exec) + r.Post("/ping", sandbox.Ping) + r.Post("/pause", sandbox.Pause) + r.Post("/resume", sandbox.Resume) + r.Post("/files/write", files.Upload) + r.Post("/files/read", files.Download) + r.Post("/files/stream/write", filesStream.StreamUpload) + r.Post("/files/stream/read", filesStream.StreamDownload) + r.Post("/files/list", fsH.ListDir) + r.Post("/files/mkdir", fsH.MakeDir) + r.Post("/files/remove", fsH.Remove) + r.Get("/metrics", metricsH.GetMetrics) + r.Get("/processes", processH.ListProcesses) + r.Delete("/processes/{selector}", processH.KillProcess) + }) + + // WebSocket endpoints — handlers authenticate after upgrade. 
+ // optionalAPIKeyOrJWT injects auth context from headers when + // present (SDK clients) but does not reject when absent (browsers). + r.Group(func(r chi.Router) { + r.Use(optionalAPIKeyOrJWT(queries, jwtSecret)) + r.Get("/exec/stream", execStream.ExecStream) + r.Get("/pty", ptyH.PtySession) + r.Get("/processes/{selector}/stream", processH.ConnectProcess) + }) }) }) @@ -248,39 +260,55 @@ func New( // Platform admin routes — require JWT + DB-validated admin status. r.Route("/v1/admin", func(r chi.Router) { - r.Use(requireJWT(jwtSecret, queries)) - r.Use(requireAdmin(queries)) - r.Get("/teams", teamH.AdminListTeams) - r.Put("/teams/{id}/byoc", teamH.SetBYOC) - r.Delete("/teams/{id}", teamH.AdminDeleteTeam) - r.Get("/users", usersH.AdminListUsers) - r.Put("/users/{id}/active", usersH.SetUserActive) - r.Get("/audit-logs", auditH.AdminList) - r.Get("/templates", buildH.ListTemplates) - r.Delete("/templates/{name}", buildH.DeleteTemplate) - r.Post("/builds", buildH.Create) - r.Get("/builds", buildH.List) - r.Get("/builds/{id}", buildH.Get) - r.Post("/builds/{id}/cancel", buildH.Cancel) - r.Post("/capsules", adminCapsules.Create) - r.Get("/capsules", adminCapsules.List) + // Auth-required admin routes (non-capsule + capsule list/create). 
+ r.Group(func(r chi.Router) { + r.Use(requireJWT(jwtSecret, queries)) + r.Use(requireAdmin(queries)) + r.Get("/teams", teamH.AdminListTeams) + r.Put("/teams/{id}/byoc", teamH.SetBYOC) + r.Delete("/teams/{id}", teamH.AdminDeleteTeam) + r.Get("/users", usersH.AdminListUsers) + r.Put("/users/{id}/active", usersH.SetUserActive) + r.Get("/audit-logs", auditH.AdminList) + r.Get("/templates", buildH.ListTemplates) + r.Delete("/templates/{name}", buildH.DeleteTemplate) + r.Post("/builds", buildH.Create) + r.Get("/builds", buildH.List) + r.Get("/builds/{id}", buildH.Get) + r.Post("/builds/{id}/cancel", buildH.Cancel) + r.Post("/capsules", adminCapsules.Create) + r.Get("/capsules", adminCapsules.List) + }) + r.Route("/capsules/{id}", func(r chi.Router) { - r.Use(injectPlatformTeam()) - r.Get("/", adminCapsules.Get) - r.Delete("/", adminCapsules.Destroy) - r.Post("/snapshot", adminCapsules.Snapshot) - r.Post("/exec", exec.Exec) - r.Get("/exec/stream", execStream.ExecStream) - r.Post("/files/write", files.Upload) - r.Post("/files/read", files.Download) - r.Post("/files/list", fsH.ListDir) - r.Post("/files/mkdir", fsH.MakeDir) - r.Post("/files/remove", fsH.Remove) - r.Get("/metrics", metricsH.GetMetrics) - r.Get("/pty", ptyH.PtySession) - r.Get("/processes", processH.ListProcesses) - r.Delete("/processes/{selector}", processH.KillProcess) - r.Get("/processes/{selector}/stream", processH.ConnectProcess) + // Auth-required non-WS admin capsule routes. 
+ r.Group(func(r chi.Router) { + r.Use(requireJWT(jwtSecret, queries)) + r.Use(requireAdmin(queries)) + r.Use(injectPlatformTeam()) + r.Get("/", adminCapsules.Get) + r.Delete("/", adminCapsules.Destroy) + r.Post("/snapshot", adminCapsules.Snapshot) + r.Post("/exec", exec.Exec) + r.Post("/files/write", files.Upload) + r.Post("/files/read", files.Download) + r.Post("/files/list", fsH.ListDir) + r.Post("/files/mkdir", fsH.MakeDir) + r.Post("/files/remove", fsH.Remove) + r.Get("/metrics", metricsH.GetMetrics) + r.Get("/processes", processH.ListProcesses) + r.Delete("/processes/{selector}", processH.KillProcess) + }) + + // Admin WebSocket endpoints — handlers authenticate after upgrade + // via wsAuthenticateAdmin. markAdminWS sets the context flag so + // handlers know to use admin auth instead of regular auth. + r.Group(func(r chi.Router) { + r.Use(markAdminWS) + r.Get("/exec/stream", execStream.ExecStream) + r.Get("/pty", ptyH.PtySession) + r.Get("/processes/{selector}/stream", processH.ConnectProcess) + }) }) }) diff --git a/internal/sandbox/manager.go b/internal/sandbox/manager.go index 524631d..a7ff69d 100644 --- a/internal/sandbox/manager.go +++ b/internal/sandbox/manager.go @@ -363,6 +363,11 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error { return fmt.Errorf("sandbox %s is not running (status: %s)", sandboxID, sb.Status) } + // Stop the metrics sampler goroutine before tearing down any resources + // it reads (dm device, Firecracker PID). Without this, the sampler + // leaks on every successful pause. + m.stopSampler(sb) + // Step 0: Drain in-flight proxy connections before freezing vCPUs. // This prevents Go runtime corruption inside the guest caused by stale // TCP state from connections that were alive when the VM was snapshotted.