From 124e097e23995db4e369ca3119555e85683d0854 Mon Sep 17 00:00:00 2001 From: pptx704 Date: Sun, 17 May 2026 02:03:06 +0600 Subject: [PATCH] refactor: eliminate DRY violations across control plane and host agent Extract shared helpers to consolidate repeated patterns: - requireRunningSandbox: sandbox lookup + running check (10 call sites) - upgradeAndAuthenticate: WS upgrade + JWT/API-key auth (3 handlers) - updateLastActive: last_active_at update with background context (5 sites) - attachCowAndCreate: cow loop attach + dmsetup create (devicemapper) - issueRegistrationToken: token gen + Redis + audit (host service) - ErrNotFound sentinel: replaces string matching in hostagent server Also merges duplicate wsProcessOut/wsOutMsg types into one. Net: -208 lines, zero behavior change. --- internal/api/agent_helper.go | 86 ++++++++++++++++++++ internal/api/handlers_exec.go | 40 +--------- internal/api/handlers_exec_stream.go | 46 +---------- internal/api/handlers_files.go | 34 +------- internal/api/handlers_files_stream.go | 34 +------- internal/api/handlers_fs.go | 50 ++---------- internal/api/handlers_process.go | 108 ++++---------------------- internal/api/handlers_pty.go | 48 +----------- internal/devicemapper/devicemapper.go | 77 ++++++++---------- internal/hostagent/server.go | 15 ++-- internal/sandbox/manager.go | 14 ++-- pkg/service/host.go | 70 ++++++++--------- 12 files changed, 207 insertions(+), 415 deletions(-) diff --git a/internal/api/agent_helper.go b/internal/api/agent_helper.go index 6a7acf5..c52e8a9 100644 --- a/internal/api/agent_helper.go +++ b/internal/api/agent_helper.go @@ -3,10 +3,17 @@ package api import ( "context" "fmt" + "log/slog" + "net/http" + "time" + "github.com/go-chi/chi/v5" + "github.com/gorilla/websocket" "github.com/jackc/pgx/v5/pgtype" + "git.omukk.dev/wrenn/wrenn/pkg/auth" "git.omukk.dev/wrenn/wrenn/pkg/db" + "git.omukk.dev/wrenn/wrenn/pkg/id" "git.omukk.dev/wrenn/wrenn/pkg/lifecycle" "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen/hostagentv1connect" ) @@ -20,3 +27,82 @@ func agentForHost(ctx context.Context, queries *db.Queries, pool *lifecycle.Host } return pool.GetForHost(host) } + +// requireRunningSandbox parses the sandbox ID from the URL, looks it up by team, +// and verifies it is running. On failure it writes the appropriate HTTP error and +// returns false. +func requireRunningSandbox(w http.ResponseWriter, r *http.Request, queries *db.Queries, teamID pgtype.UUID) (db.Sandbox, pgtype.UUID, string, bool) { + sandboxIDStr := chi.URLParam(r, "id") + ctx := r.Context() + + sandboxID, err := id.ParseSandboxID(sandboxIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID") + return db.Sandbox{}, pgtype.UUID{}, "", false + } + + sb, err := queries.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID}) + if err != nil { + writeError(w, http.StatusNotFound, "not_found", "sandbox not found") + return db.Sandbox{}, pgtype.UUID{}, "", false + } + if sb.Status != "running" { + writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")") + return db.Sandbox{}, pgtype.UUID{}, "", false + } + + return sb, sandboxID, sandboxIDStr, true +} + +// upgradeAndAuthenticate upgrades the HTTP connection to WebSocket and resolves +// the auth context — either from middleware (API key) or from the first WS message (JWT). +// Returns the connection and auth context, or an error if authentication fails. +// The caller is responsible for closing the returned connection. +func upgradeAndAuthenticate(w http.ResponseWriter, r *http.Request, jwtSecret []byte, queries *db.Queries) (*websocket.Conn, auth.AuthContext, error) { + ctx := r.Context() + ac, hasAuth := auth.FromContext(ctx) + + if hasAuth { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return nil, auth.AuthContext{}, fmt.Errorf("websocket upgrade: %w", err) + } + return conn, ac, nil + } + + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return nil, auth.AuthContext{}, fmt.Errorf("websocket upgrade: %w", err) + } + + var wsAC auth.AuthContext + var authErr error + if isAdminWSRoute(ctx) { + wsAC, authErr = wsAuthenticateAdmin(ctx, conn, jwtSecret, queries) + } else { + wsAC, authErr = wsAuthenticate(ctx, conn, jwtSecret, queries) + } + if authErr != nil { + conn.Close() + return nil, auth.AuthContext{}, fmt.Errorf("authentication failed") + } + + return conn, wsAC, nil +} + +// updateLastActive updates the sandbox last_active_at timestamp. +// Uses a background context with timeout for streaming handlers where +// the request context may already be cancelled. +func updateLastActive(queries *db.Queries, sandboxID pgtype.UUID, sandboxIDStr string) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := queries.UpdateLastActive(ctx, db.UpdateLastActiveParams{ + ID: sandboxID, + LastActiveAt: pgtype.Timestamptz{ + Time: time.Now(), + Valid: true, + }, + }); err != nil { + slog.Warn("failed to update last_active_at", "id", sandboxIDStr, "error", err) + } +} diff --git a/internal/api/handlers_exec.go b/internal/api/handlers_exec.go index 8e94da7..7a94388 100644 --- a/internal/api/handlers_exec.go +++ b/internal/api/handlers_exec.go @@ -3,14 +3,11 @@ package api import ( "encoding/base64" "encoding/json" - "log/slog" "net/http" "time" "unicode/utf8" "connectrpc.com/connect" - "github.com/go-chi/chi/v5" - "github.com/jackc/pgx/v5/pgtype" "git.omukk.dev/wrenn/wrenn/pkg/auth" "git.omukk.dev/wrenn/wrenn/pkg/db" @@ -58,23 +55,11 @@ type backgroundExecResponse struct { // Exec handles POST /v1/capsules/{id}/exec. func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) { - sandboxIDStr := chi.URLParam(r, "id") ctx := r.Context() ac := auth.MustFromContext(ctx) - sandboxID, err := id.ParseSandboxID(sandboxIDStr) - if err != nil { - writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID") - return - } - - sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID}) - if err != nil { - writeError(w, http.StatusNotFound, "not_found", "sandbox not found") - return - } - if sb.Status != "running" { - writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")") + sb, sandboxID, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID) + if !ok { return } @@ -116,15 +101,7 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) { return } - if err := h.db.UpdateLastActive(ctx, db.UpdateLastActiveParams{ - ID: sandboxID, - LastActiveAt: pgtype.Timestamptz{ - Time: time.Now(), - Valid: true, - }, - }); err != nil { - slog.Warn("failed to update last_active_at", "id", sandboxIDStr, "error", err) - } + updateLastActive(h.db, sandboxID, sandboxIDStr) writeJSON(w, http.StatusAccepted, backgroundExecResponse{ SandboxID: sandboxIDStr, @@ -151,16 +128,7 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) { duration := time.Since(start) - // Update last active. - if err := h.db.UpdateLastActive(ctx, db.UpdateLastActiveParams{ - ID: sandboxID, - LastActiveAt: pgtype.Timestamptz{ - Time: time.Now(), - Valid: true, - }, - }); err != nil { - slog.Warn("failed to update last_active_at", "id", sandboxIDStr, "error", err) - } + updateLastActive(h.db, sandboxID, sandboxIDStr) // Use base64 encoding if output contains non-UTF-8 bytes. stdout := resp.Msg.Stdout diff --git a/internal/api/handlers_exec_stream.go b/internal/api/handlers_exec_stream.go index c8b101f..7ad6f18 100644 --- a/internal/api/handlers_exec_stream.go +++ b/internal/api/handlers_exec_stream.go @@ -5,7 +5,6 @@ import ( "encoding/json" "log/slog" "net/http" - "time" "connectrpc.com/connect" "github.com/go-chi/chi/v5" @@ -59,37 +58,9 @@ func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) { return } - // Authenticate: use context from middleware (API key) or WS first message (JWT). - ac, hasAuth := auth.FromContext(ctx) - - if !hasAuth { - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - slog.Error("websocket upgrade failed", "error", err) - return - } - defer conn.Close() - - var wsAC auth.AuthContext - var authErr error - if isAdminWSRoute(ctx) { - wsAC, authErr = wsAuthenticateAdmin(ctx, conn, h.jwtSecret, h.db) - } else { - wsAC, authErr = wsAuthenticate(ctx, conn, h.jwtSecret, h.db) - } - if authErr != nil { - sendWSError(conn, "authentication failed") - return - } - ac = wsAC - - h.runExecStream(ctx, conn, ac, sandboxID, sandboxIDStr) - return - } - - conn, err := upgrader.Upgrade(w, r, nil) + conn, ac, err := upgradeAndAuthenticate(w, r, h.jwtSecret, h.db) if err != nil { - slog.Error("websocket upgrade failed", "error", err) + slog.Error("websocket upgrade/auth failed", "error", err) return } defer conn.Close() @@ -186,18 +157,7 @@ func (h *execStreamHandler) runExecStream(ctx context.Context, conn *websocket.C } } - // Update last active using a fresh context (the request context may be cancelled). - updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second) - defer updateCancel() - if err := h.db.UpdateLastActive(updateCtx, db.UpdateLastActiveParams{ - ID: sandboxID, - LastActiveAt: pgtype.Timestamptz{ - Time: time.Now(), - Valid: true, - }, - }); err != nil { - slog.Warn("failed to update last active after stream exec", "sandbox_id", sandboxIDStr, "error", err) - } + updateLastActive(h.db, sandboxID, sandboxIDStr) } func sendWSError(conn *websocket.Conn, msg string) { diff --git a/internal/api/handlers_files.go b/internal/api/handlers_files.go index f69c8f1..d0ca30c 100644 --- a/internal/api/handlers_files.go +++ b/internal/api/handlers_files.go @@ -7,11 +7,9 @@ import ( "net/http" "connectrpc.com/connect" - "github.com/go-chi/chi/v5" "git.omukk.dev/wrenn/wrenn/pkg/auth" "git.omukk.dev/wrenn/wrenn/pkg/db" - "git.omukk.dev/wrenn/wrenn/pkg/id" "git.omukk.dev/wrenn/wrenn/pkg/lifecycle" pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen" ) @@ -30,23 +28,11 @@ func newFilesHandler(db *db.Queries, pool *lifecycle.HostClientPool) *filesHandl // - "path" text field: absolute destination path inside the sandbox // - "file" file field: binary content to write func (h *filesHandler) Upload(w http.ResponseWriter, r *http.Request) { - sandboxIDStr := chi.URLParam(r, "id") ctx := r.Context() ac := auth.MustFromContext(ctx) - sandboxID, err := id.ParseSandboxID(sandboxIDStr) - if err != nil { - writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID") - return - } - - sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID}) - if err != nil { - writeError(w, http.StatusNotFound, "not_found", "sandbox not found") - return - } - if sb.Status != "running" { - writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running") + sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID) + if !ok { return } @@ -108,23 +94,11 @@ type readFileRequest struct { // Download handles POST /v1/capsules/{id}/files/read. // Accepts JSON body with path, returns raw file content with Content-Disposition. func (h *filesHandler) Download(w http.ResponseWriter, r *http.Request) { - sandboxIDStr := chi.URLParam(r, "id") ctx := r.Context() ac := auth.MustFromContext(ctx) - sandboxID, err := id.ParseSandboxID(sandboxIDStr) - if err != nil { - writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID") - return - } - - sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID}) - if err != nil { - writeError(w, http.StatusNotFound, "not_found", "sandbox not found") - return - } - if sb.Status != "running" { - writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running") + sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID) + if !ok { return } diff --git a/internal/api/handlers_files_stream.go b/internal/api/handlers_files_stream.go index 88377ae..35dfb1f 100644 --- a/internal/api/handlers_files_stream.go +++ b/internal/api/handlers_files_stream.go @@ -8,11 +8,9 @@ import ( "net/http" "connectrpc.com/connect" - "github.com/go-chi/chi/v5" "git.omukk.dev/wrenn/wrenn/pkg/auth" "git.omukk.dev/wrenn/wrenn/pkg/db" - "git.omukk.dev/wrenn/wrenn/pkg/id" "git.omukk.dev/wrenn/wrenn/pkg/lifecycle" pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen" ) @@ -30,23 +28,11 @@ func newFilesStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool) *file // Expects multipart/form-data with "path" text field and "file" file field. // Streams file content directly from the request body to the host agent without buffering. func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request) { - sandboxIDStr := chi.URLParam(r, "id") ctx := r.Context() ac := auth.MustFromContext(ctx) - sandboxID, err := id.ParseSandboxID(sandboxIDStr) - if err != nil { - writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID") - return - } - - sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID}) - if err != nil { - writeError(w, http.StatusNotFound, "not_found", "sandbox not found") - return - } - if sb.Status != "running" { - writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running") + sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID) + if !ok { return } @@ -153,23 +139,11 @@ func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request // StreamDownload handles POST /v1/capsules/{id}/files/stream/read. // Accepts JSON body with path, streams file content back without buffering. func (h *filesStreamHandler) StreamDownload(w http.ResponseWriter, r *http.Request) { - sandboxIDStr := chi.URLParam(r, "id") ctx := r.Context() ac := auth.MustFromContext(ctx) - sandboxID, err := id.ParseSandboxID(sandboxIDStr) - if err != nil { - writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID") - return - } - - sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID}) - if err != nil { - writeError(w, http.StatusNotFound, "not_found", "sandbox not found") - return - } - if sb.Status != "running" { - writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running") + sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID) + if !ok { return } diff --git a/internal/api/handlers_fs.go b/internal/api/handlers_fs.go index cfdd6a7..a408978 100644 --- a/internal/api/handlers_fs.go +++ b/internal/api/handlers_fs.go @@ -4,11 +4,9 @@ import ( "net/http" "connectrpc.com/connect" - "github.com/go-chi/chi/v5" "git.omukk.dev/wrenn/wrenn/pkg/auth" "git.omukk.dev/wrenn/wrenn/pkg/db" - "git.omukk.dev/wrenn/wrenn/pkg/id" "git.omukk.dev/wrenn/wrenn/pkg/lifecycle" pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen" ) @@ -58,23 +56,11 @@ type removeRequest struct { // ListDir handles POST /v1/capsules/{id}/files/list. func (h *fsHandler) ListDir(w http.ResponseWriter, r *http.Request) { - sandboxIDStr := chi.URLParam(r, "id") ctx := r.Context() ac := auth.MustFromContext(ctx) - sandboxID, err := id.ParseSandboxID(sandboxIDStr) - if err != nil { - writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID") - return - } - - sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID}) - if err != nil { - writeError(w, http.StatusNotFound, "not_found", "sandbox not found") - return - } - if sb.Status != "running" { - writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running") + sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID) + if !ok { return } @@ -115,23 +101,11 @@ func (h *fsHandler) ListDir(w http.ResponseWriter, r *http.Request) { // MakeDir handles POST /v1/capsules/{id}/files/mkdir. func (h *fsHandler) MakeDir(w http.ResponseWriter, r *http.Request) { - sandboxIDStr := chi.URLParam(r, "id") ctx := r.Context() ac := auth.MustFromContext(ctx) - sandboxID, err := id.ParseSandboxID(sandboxIDStr) - if err != nil { - writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID") - return - } - - sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID}) - if err != nil { - writeError(w, http.StatusNotFound, "not_found", "sandbox not found") - return - } - if sb.Status != "running" { - writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running") + sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID) + if !ok { return } @@ -166,23 +140,11 @@ func (h *fsHandler) MakeDir(w http.ResponseWriter, r *http.Request) { // Remove handles POST /v1/capsules/{id}/files/remove. func (h *fsHandler) Remove(w http.ResponseWriter, r *http.Request) { - sandboxIDStr := chi.URLParam(r, "id") ctx := r.Context() ac := auth.MustFromContext(ctx) - sandboxID, err := id.ParseSandboxID(sandboxIDStr) - if err != nil { - writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID") - return - } - - sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID}) - if err != nil { - writeError(w, http.StatusNotFound, "not_found", "sandbox not found") - return - } - if sb.Status != "running" { - writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running") + sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID) + if !ok { return } diff --git a/internal/api/handlers_process.go b/internal/api/handlers_process.go index 2bb6bb9..be08334 100644 --- a/internal/api/handlers_process.go +++ b/internal/api/handlers_process.go @@ -5,7 +5,6 @@ import ( "log/slog" "net/http" "strconv" - "time" "connectrpc.com/connect" "github.com/go-chi/chi/v5" @@ -44,23 +43,11 @@ type processListResponse struct { // ListProcesses handles GET /v1/capsules/{id}/processes. func (h *processHandler) ListProcesses(w http.ResponseWriter, r *http.Request) { - sandboxIDStr := chi.URLParam(r, "id") ctx := r.Context() ac := auth.MustFromContext(ctx) - sandboxID, err := id.ParseSandboxID(sandboxIDStr) - if err != nil { - writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID") - return - } - - sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID}) - if err != nil { - writeError(w, http.StatusNotFound, "not_found", "sandbox not found") - return - } - if sb.Status != "running" { - writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")") + sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID) + if !ok { return } @@ -95,24 +82,12 @@ func (h *processHandler) ListProcesses(w http.ResponseWriter, r *http.Request) { // KillProcess handles DELETE /v1/capsules/{id}/processes/{selector}. // The selector can be a numeric PID or a string tag. func (h *processHandler) KillProcess(w http.ResponseWriter, r *http.Request) { - sandboxIDStr := chi.URLParam(r, "id") selectorStr := chi.URLParam(r, "selector") ctx := r.Context() ac := auth.MustFromContext(ctx) - sandboxID, err := id.ParseSandboxID(sandboxIDStr) - if err != nil { - writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID") - return - } - - sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID}) - if err != nil { - writeError(w, http.StatusNotFound, "not_found", "sandbox not found") - return - } - if sb.Status != "running" { - writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")") + sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID) + if !ok { return } @@ -146,14 +121,6 @@ func (h *processHandler) KillProcess(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNoContent) } -// wsProcessOut is the JSON message sent to the WebSocket client. -type wsProcessOut struct { - Type string `json:"type"` // "start", "stdout", "stderr", "exit", "error" - PID uint32 `json:"pid,omitempty"` // only for "start" - Data string `json:"data,omitempty"` // only for "stdout", "stderr", "error" - ExitCode *int32 `json:"exit_code,omitempty"` // only for "exit" -} - // ConnectProcess handles WS /v1/capsules/{id}/processes/{selector}/stream. func (h *processHandler) ConnectProcess(w http.ResponseWriter, r *http.Request) { sandboxIDStr := chi.URLParam(r, "id") @@ -166,37 +133,9 @@ func (h *processHandler) ConnectProcess(w http.ResponseWriter, r *http.Request) return } - // Authenticate: use context from middleware (API key) or WS first message (JWT). - ac, hasAuth := auth.FromContext(ctx) - - if !hasAuth { - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - slog.Error("process stream websocket upgrade failed", "error", err) - return - } - defer conn.Close() - - var wsAC auth.AuthContext - var authErr error - if isAdminWSRoute(ctx) { - wsAC, authErr = wsAuthenticateAdmin(ctx, conn, h.jwtSecret, h.db) - } else { - wsAC, authErr = wsAuthenticate(ctx, conn, h.jwtSecret, h.db) - } - if authErr != nil { - sendProcessWSError(conn, "authentication failed") - return - } - ac = wsAC - - h.runConnectProcess(ctx, conn, ac, sandboxID, sandboxIDStr, selectorStr) - return - } - - conn, err := upgrader.Upgrade(w, r, nil) + conn, ac, err := upgradeAndAuthenticate(w, r, h.jwtSecret, h.db) if err != nil { - slog.Error("process stream websocket upgrade failed", "error", err) + slog.Error("process stream websocket upgrade/auth failed", "error", err) return } defer conn.Close() @@ -207,17 +146,17 @@ func (h *processHandler) ConnectProcess(w http.ResponseWriter, r *http.Request) func (h *processHandler) runConnectProcess(ctx context.Context, conn *websocket.Conn, ac auth.AuthContext, sandboxID pgtype.UUID, sandboxIDStr, selectorStr string) { sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID}) if err != nil { - sendProcessWSError(conn, "sandbox not found") + sendWSError(conn, "sandbox not found") return } if sb.Status != "running" { - sendProcessWSError(conn, "sandbox is not running (status: "+sb.Status+")") + sendWSError(conn, "sandbox is not running (status: "+sb.Status+")") return } agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID) if err != nil { - sendProcessWSError(conn, "sandbox host is not reachable") + sendWSError(conn, "sandbox host is not reachable") return } @@ -236,7 +175,7 @@ func (h *processHandler) runConnectProcess(ctx context.Context, conn *websocket. stream, err := agent.ConnectProcess(streamCtx, connect.NewRequest(connectReq)) if err != nil { - sendProcessWSError(conn, "failed to connect to process: "+err.Error()) + sendWSError(conn, "failed to connect to process: "+err.Error()) return } defer stream.Close() @@ -257,42 +196,27 @@ func (h *processHandler) runConnectProcess(ctx context.Context, conn *websocket. resp := stream.Msg() switch ev := resp.Event.(type) { case *pb.ConnectProcessResponse_Start: - writeWSJSON(conn, wsProcessOut{Type: "start", PID: ev.Start.Pid}) + writeWSJSON(conn, wsOutMsg{Type: "start", PID: ev.Start.Pid}) case *pb.ConnectProcessResponse_Data: switch o := ev.Data.Output.(type) { case *pb.ExecStreamData_Stdout: - writeWSJSON(conn, wsProcessOut{Type: "stdout", Data: string(o.Stdout)}) + writeWSJSON(conn, wsOutMsg{Type: "stdout", Data: string(o.Stdout)}) case *pb.ExecStreamData_Stderr: - writeWSJSON(conn, wsProcessOut{Type: "stderr", Data: string(o.Stderr)}) + writeWSJSON(conn, wsOutMsg{Type: "stderr", Data: string(o.Stderr)}) } case *pb.ConnectProcessResponse_End: exitCode := ev.End.ExitCode - writeWSJSON(conn, wsProcessOut{Type: "exit", ExitCode: &exitCode}) + writeWSJSON(conn, wsOutMsg{Type: "exit", ExitCode: &exitCode}) } } if err := stream.Err(); err != nil { if streamCtx.Err() == nil { - sendProcessWSError(conn, err.Error()) + sendWSError(conn, err.Error()) } } - // Update last active using a fresh context. - updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second) - defer updateCancel() - if err := h.db.UpdateLastActive(updateCtx, db.UpdateLastActiveParams{ - ID: sandboxID, - LastActiveAt: pgtype.Timestamptz{ - Time: time.Now(), - Valid: true, - }, - }); err != nil { - slog.Warn("failed to update last active after process stream", "sandbox_id", sandboxIDStr, "error", err) - } -} - -func sendProcessWSError(conn *websocket.Conn, msg string) { - writeWSJSON(conn, wsProcessOut{Type: "error", Data: msg}) + updateLastActive(h.db, sandboxID, sandboxIDStr) } diff --git a/internal/api/handlers_pty.go b/internal/api/handlers_pty.go index d0db965..fdf9c4d 100644 --- a/internal/api/handlers_pty.go +++ b/internal/api/handlers_pty.go @@ -90,40 +90,9 @@ func (h *ptyHandler) PtySession(w http.ResponseWriter, r *http.Request) { return } - // API key auth is handled by middleware (sets context). - // For browser JWT auth, we authenticate after upgrade via first WS message. - ac, hasAuth := auth.FromContext(ctx) - - if !hasAuth { - // No pre-upgrade auth — upgrade first, then authenticate via WS message. - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - slog.Error("pty websocket upgrade failed", "error", err) - return - } - defer conn.Close() - - ws := &wsWriter{conn: conn} - - var wsAC auth.AuthContext - if isAdminWSRoute(ctx) { - wsAC, err = wsAuthenticateAdmin(ctx, conn, h.jwtSecret, h.db) - } else { - wsAC, err = wsAuthenticate(ctx, conn, h.jwtSecret, h.db) - } - if err != nil { - ws.writeJSON(wsPtyOut{Type: "error", Data: "authentication failed", Fatal: true}) - return - } - ac = wsAC - - h.runPtySession(ctx, ws, conn, ac, sandboxID, sandboxIDStr) - return - } - - conn, err := upgrader.Upgrade(w, r, nil) + conn, ac, err := upgradeAndAuthenticate(w, r, h.jwtSecret, h.db) if err != nil { - slog.Error("pty websocket upgrade failed", "error", err) + slog.Error("pty websocket upgrade/auth failed", "error", err) return } defer conn.Close() @@ -168,18 +137,7 @@ func (h *ptyHandler) runPtySession(ctx context.Context, ws *wsWriter, conn *webs ws.writeJSON(wsPtyOut{Type: "error", Data: "first message must be type 'start' or 'connect'", Fatal: true}) } - // Update last active using a fresh context. - updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second) - defer updateCancel() - if err := h.db.UpdateLastActive(updateCtx, db.UpdateLastActiveParams{ - ID: sandboxID, - LastActiveAt: pgtype.Timestamptz{ - Time: time.Now(), - Valid: true, - }, - }); err != nil { - slog.Warn("failed to update last active after pty session", "sandbox_id", sandboxIDStr, "error", err) - } + updateLastActive(h.db, sandboxID, sandboxIDStr) } func (h *ptyHandler) handleStart( diff --git a/internal/devicemapper/devicemapper.go b/internal/devicemapper/devicemapper.go index f53b109..71c7104 100644 --- a/internal/devicemapper/devicemapper.go +++ b/internal/devicemapper/devicemapper.go @@ -109,6 +109,31 @@ type SnapshotDevice struct { CowLoopDev string // loop device for the CoW file } +// attachCowAndCreate attaches a CoW file as a loop device, creates the +// dm-snapshot target, and returns the assembled SnapshotDevice. On failure +// it detaches the CoW loop device before returning. +func attachCowAndCreate(name, originLoopDev, cowPath string, originSizeBytes int64) (*SnapshotDevice, error) { + cowLoopDev, err := losetupCreateRW(cowPath) + if err != nil { + return nil, fmt.Errorf("losetup cow: %w", err) + } + + sectors := originSizeBytes / 512 + if err := dmsetupCreate(name, originLoopDev, cowLoopDev, sectors); err != nil { + if detachErr := losetupDetachRetry(cowLoopDev); detachErr != nil { + slog.Error("cow losetup detach failed during cleanup, loop device leaked", "device", cowLoopDev, "error", detachErr) + } + return nil, fmt.Errorf("dmsetup create: %w", err) + } + + return &SnapshotDevice{ + Name: name, + DevicePath: "/dev/mapper/" + name, + CowPath: cowPath, + CowLoopDev: cowLoopDev, + }, nil +} + // CreateSnapshot sets up a new dm-snapshot device. // // It creates a sparse CoW file, attaches it as a loop device, and creates @@ -117,45 +142,24 @@ type SnapshotDevice struct { // // The origin loop device must already exist (from LoopRegistry.Acquire). func CreateSnapshot(name, originLoopDev, cowPath string, originSizeBytes, cowSizeBytes int64) (*SnapshotDevice, error) { - // Create sparse CoW file. The logical size limits how many blocks can be - // modified; because the file is sparse, only written blocks use real disk. if err := createSparseFile(cowPath, cowSizeBytes); err != nil { return nil, fmt.Errorf("create cow file: %w", err) } - cowLoopDev, err := losetupCreateRW(cowPath) + dev, err := attachCowAndCreate(name, originLoopDev, cowPath, originSizeBytes) if err != nil { os.Remove(cowPath) - return nil, fmt.Errorf("losetup cow: %w", err) + return nil, err } - // The dm-snapshot virtual device size must match the origin — the snapshot - // target maps 1:1 onto origin sectors. The CoW file just needs enough - // space to store all modified blocks (it's sparse, so 20GB costs nothing). - sectors := originSizeBytes / 512 - if err := dmsetupCreate(name, originLoopDev, cowLoopDev, sectors); err != nil { - if detachErr := losetupDetachRetry(cowLoopDev); detachErr != nil { - slog.Error("cow losetup detach failed during cleanup, loop device leaked", "device", cowLoopDev, "error", detachErr) - } - os.Remove(cowPath) - return nil, fmt.Errorf("dmsetup create: %w", err) - } - - devPath := "/dev/mapper/" + name - slog.Info("dm-snapshot created", "name", name, - "device", devPath, + "device", dev.DevicePath, "origin", originLoopDev, "cow", cowPath, ) - return &SnapshotDevice{ - Name: name, - DevicePath: devPath, - CowPath: cowPath, - CowLoopDev: cowLoopDev, - }, nil + return dev, nil } // RestoreSnapshot re-attaches a dm-snapshot from an existing persistent CoW file. @@ -171,34 +175,19 @@ func RestoreSnapshot(ctx context.Context, name, originLoopDev, cowPath string, o } } - cowLoopDev, err := losetupCreateRW(cowPath) + dev, err := attachCowAndCreate(name, originLoopDev, cowPath, originSizeBytes) if err != nil { - return nil, fmt.Errorf("losetup cow: %w", err) + return nil, err } - sectors := originSizeBytes / 512 - if err := dmsetupCreate(name, originLoopDev, cowLoopDev, sectors); err != nil { - if detachErr := losetupDetachRetry(cowLoopDev); detachErr != nil { - slog.Error("cow losetup detach failed during cleanup, loop device leaked", "device", cowLoopDev, "error", detachErr) - } - return nil, fmt.Errorf("dmsetup create: %w", err) - } - - devPath := "/dev/mapper/" + name - slog.Info("dm-snapshot restored", "name", name, - "device", devPath, + "device", dev.DevicePath, "origin", originLoopDev, "cow", cowPath, ) - return &SnapshotDevice{ - Name: name, - DevicePath: devPath, - CowPath: cowPath, - CowLoopDev: cowLoopDev, - }, nil + return dev, nil } // RemoveSnapshot tears down a dm-snapshot device and its CoW loop device. diff --git a/internal/hostagent/server.go b/internal/hostagent/server.go index faf6424..a935a16 100644 --- a/internal/hostagent/server.go +++ b/internal/hostagent/server.go @@ -2,6 +2,7 @@ package hostagent import ( "context" + "errors" "fmt" "io" "log/slog" @@ -193,7 +194,7 @@ func (s *Server) PingSandbox( req *connect.Request[pb.PingSandboxRequest], ) (*connect.Response[pb.PingSandboxResponse], error) { if err := s.mgr.Ping(req.Msg.SandboxId); err != nil { - if strings.Contains(err.Error(), "not found") { + if errors.Is(err, sandbox.ErrNotFound) { return nil, connect.NewError(connect.CodeNotFound, err) } return nil, connect.NewError(connect.CodeFailedPrecondition, err) @@ -590,7 +591,7 @@ func (s *Server) GetSandboxMetrics( points, err := s.mgr.GetMetrics(msg.SandboxId, msg.Range) if err != nil { - if strings.Contains(err.Error(), "not found") { + if errors.Is(err, sandbox.ErrNotFound) { return nil, connect.NewError(connect.CodeNotFound, err) } if strings.Contains(err.Error(), "invalid range") { @@ -608,7 +609,7 @@ func (s *Server) FlushSandboxMetrics( ) (*connect.Response[pb.FlushSandboxMetricsResponse], error) { pts10m, pts2h, pts24h, err := s.mgr.FlushMetrics(req.Msg.SandboxId) if err != nil { - if strings.Contains(err.Error(), "not found") { + if errors.Is(err, sandbox.ErrNotFound) { return nil, connect.NewError(connect.CodeNotFound, err) } return nil, connect.NewError(connect.CodeInternal, err) @@ -761,7 +762,7 @@ func (s *Server) StartBackground( pid, err := s.mgr.StartBackground(ctx, msg.SandboxId, msg.Tag, msg.Cmd, msg.Args, msg.Envs, msg.Cwd) if err != nil { - if strings.Contains(err.Error(), "not found") { + if errors.Is(err, sandbox.ErrNotFound) { return nil, connect.NewError(connect.CodeNotFound, err) } return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("start background: %w", err)) @@ -779,7 +780,7 @@ func (s *Server) ListProcesses( ) (*connect.Response[pb.ListProcessesResponse], error) { procs, err := s.mgr.ListProcesses(ctx, req.Msg.SandboxId) if err != nil { - if strings.Contains(err.Error(), "not found") { + if errors.Is(err, sandbox.ErrNotFound) { return nil, connect.NewError(connect.CodeNotFound, err) } return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("list processes: %w", err)) @@ -830,7 +831,7 @@ func (s *Server) KillProcess( } if err := s.mgr.KillProcess(ctx, msg.SandboxId, pid, tag, signal); err != nil { - if strings.Contains(err.Error(), "not found") { + if errors.Is(err, sandbox.ErrNotFound) { return nil, connect.NewError(connect.CodeNotFound, err) } return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("kill process: %w", err)) @@ -859,7 +860,7 @@ func (s *Server) ConnectProcess( events, err := s.mgr.ConnectProcess(ctx, msg.SandboxId, pid, tag) if err != nil { - if strings.Contains(err.Error(), "not found") { + if errors.Is(err, sandbox.ErrNotFound) { return connect.NewError(connect.CodeNotFound, err) } return connect.NewError(connect.CodeInternal, fmt.Errorf("connect process: %w", err)) diff --git a/internal/sandbox/manager.go b/internal/sandbox/manager.go index 9aed83f..9698e64 100644 --- a/internal/sandbox/manager.go +++ b/internal/sandbox/manager.go @@ -2,6 +2,7 @@ package sandbox import ( "context" + "errors" "fmt" "log/slog" "net" @@ -24,6 +25,9 @@ import ( envdpb "git.omukk.dev/wrenn/wrenn/proto/envd/gen" ) +// ErrNotFound is returned when a sandbox is not present in the in-memory map. +var ErrNotFound = errors.New("sandbox not found") + // Config holds the paths and defaults for the sandbox manager. type Config struct { WrennDir string // root directory (e.g. /var/lib/wrenn); all sub-paths derived via layout package @@ -904,7 +908,7 @@ func (m *Manager) FlattenRootfs(ctx context.Context, sandboxID string, teamID, t m.mu.Unlock() if !ok { - return 0, fmt.Errorf("sandbox %s not found", sandboxID) + return 0, fmt.Errorf("%w: %s", ErrNotFound, sandboxID) } // Flush guest page cache to disk before stopping the VM. Without this, @@ -1395,7 +1399,7 @@ func (m *Manager) Ping(sandboxID string) error { sb, ok := m.boxes[sandboxID] if !ok { - return fmt.Errorf("sandbox not found: %s", sandboxID) + return fmt.Errorf("%w: %s", ErrNotFound, sandboxID) } if sb.Status != models.StatusRunning { return fmt.Errorf("sandbox %s is not running (status: %s)", sandboxID, sb.Status) @@ -1421,7 +1425,7 @@ func (m *Manager) get(sandboxID string) (*sandboxState, error) { sb, ok := m.boxes[sandboxID] if !ok { - return nil, fmt.Errorf("sandbox not found: %s", sandboxID) + return nil, fmt.Errorf("%w: %s", ErrNotFound, sandboxID) } return sb, nil } @@ -1731,7 +1735,7 @@ func (m *Manager) GetMetrics(sandboxID, rangeTier string) ([]MetricPoint, error) sb, ok := m.boxes[sandboxID] m.mu.RUnlock() if !ok { - return nil, fmt.Errorf("sandbox not found: %s", sandboxID) + return nil, fmt.Errorf("%w: %s", ErrNotFound, sandboxID) } if sb.ring == nil { return nil, nil @@ -1784,7 +1788,7 @@ func (m *Manager) FlushMetrics(sandboxID string) (pts10m, pts2h, pts24h []Metric sb, ok := m.boxes[sandboxID] m.mu.RUnlock() if !ok { - return nil, nil, nil, fmt.Errorf("sandbox not found: %s", sandboxID) + return nil, nil, nil, fmt.Errorf("%w: %s", ErrNotFound, sandboxID) } m.stopSampler(sb) diff --git a/pkg/service/host.go b/pkg/service/host.go index 9f5b5c8..49e3be6 100644 --- a/pkg/service/host.go +++ b/pkg/service/host.go @@ -94,6 +94,31 @@ type regTokenPayload struct { const regTokenTTL = time.Hour +func (s *HostService) issueRegistrationToken(ctx context.Context, hostID, createdBy pgtype.UUID) (string, error) { + token := id.NewRegistrationToken() + tokenID := id.NewHostTokenID() + + payload, _ := json.Marshal(regTokenPayload{ + HostID: id.FormatHostID(hostID), + TokenID: id.FormatHostTokenID(tokenID), + }) + if err := s.Redis.Set(ctx, "host:reg:"+token, payload, regTokenTTL).Err(); err != nil { + return "", fmt.Errorf("store registration token: %w", err) + } + + now := time.Now() + if _, err := s.DB.InsertHostToken(ctx, db.InsertHostTokenParams{ + ID: tokenID, + HostID: hostID, + CreatedBy: createdBy, + ExpiresAt: pgtype.Timestamptz{Time: now.Add(regTokenTTL), Valid: true}, + }); err != nil { + slog.Warn("failed to insert host token audit record", "host_id", id.FormatHostID(hostID), "error", err) + } + + return token, nil +} + // requireAdminOrOwner returns nil iff the role is "owner" or "admin". func requireAdminOrOwner(role string) error { if role == "owner" || role == "admin" { @@ -159,26 +184,9 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat return HostCreateResult{}, fmt.Errorf("insert host: %w", err) } - // Generate registration token and store in Redis + Postgres audit trail. - token := id.NewRegistrationToken() - tokenID := id.NewHostTokenID() - - payload, _ := json.Marshal(regTokenPayload{ - HostID: id.FormatHostID(hostID), - TokenID: id.FormatHostTokenID(tokenID), - }) - if err := s.Redis.Set(ctx, "host:reg:"+token, payload, regTokenTTL).Err(); err != nil { - return HostCreateResult{}, fmt.Errorf("store registration token: %w", err) - } - - now := time.Now() - if _, err := s.DB.InsertHostToken(ctx, db.InsertHostTokenParams{ - ID: tokenID, - HostID: hostID, - CreatedBy: p.RequestingUserID, - ExpiresAt: pgtype.Timestamptz{Time: now.Add(regTokenTTL), Valid: true}, - }); err != nil { - slog.Warn("failed to insert host token audit record", "host_id", id.FormatHostID(hostID), "error", err) + token, err := s.issueRegistrationToken(ctx, hostID, p.RequestingUserID) + if err != nil { + return HostCreateResult{}, err } return HostCreateResult{Host: host, RegistrationToken: token}, nil @@ -218,25 +226,9 @@ func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamI } } - token := id.NewRegistrationToken() - tokenID := id.NewHostTokenID() - - payload, _ := json.Marshal(regTokenPayload{ - HostID: id.FormatHostID(hostID), - TokenID: id.FormatHostTokenID(tokenID), - }) - if err := s.Redis.Set(ctx, "host:reg:"+token, payload, regTokenTTL).Err(); err != nil { - return HostCreateResult{}, fmt.Errorf("store registration token: %w", err) - } - - now := time.Now() - if _, err := s.DB.InsertHostToken(ctx, db.InsertHostTokenParams{ - ID: tokenID, - HostID: hostID, - CreatedBy: userID, - ExpiresAt: pgtype.Timestamptz{Time: now.Add(regTokenTTL), Valid: true}, - }); err != nil { - slog.Warn("failed to insert host token audit record", "host_id", id.FormatHostID(hostID), "error", err) + token, err := s.issueRegistrationToken(ctx, hostID, userID) + if err != nil { + return HostCreateResult{}, err } return HostCreateResult{Host: host, RegistrationToken: token}, nil