1
0
forked from wrenn/wrenn

refactor: eliminate DRY violations across control plane and host agent

Extract shared helpers to consolidate repeated patterns:
- requireRunningSandbox: sandbox lookup + running check (10 call sites)
- upgradeAndAuthenticate: WS upgrade + JWT/API-key auth (3 handlers)
- updateLastActive: last_active_at update with background context (5 sites)
- attachCowAndCreate: cow loop attach + dmsetup create (devicemapper)
- issueRegistrationToken: token gen + Redis + audit (host service)
- ErrNotFound sentinel: replaces string matching in hostagent server

Also merges duplicate wsProcessOut/wsOutMsg types into one.

Net: -208 lines, zero behavior change.
This commit is contained in:
2026-05-17 02:03:06 +06:00
parent a5425969ed
commit 124e097e23
12 changed files with 207 additions and 415 deletions

View File

@ -3,10 +3,17 @@ package api
import (
"context"
"fmt"
"log/slog"
"net/http"
"time"
"github.com/go-chi/chi/v5"
"github.com/gorilla/websocket"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
"git.omukk.dev/wrenn/wrenn/proto/hostagent/gen/hostagentv1connect"
)
@ -20,3 +27,82 @@ func agentForHost(ctx context.Context, queries *db.Queries, pool *lifecycle.Host
}
return pool.GetForHost(host)
}
// requireRunningSandbox parses the sandbox ID from the URL, looks it up by team,
// and verifies it is running. On failure it writes the appropriate HTTP error and
// returns false.
func requireRunningSandbox(w http.ResponseWriter, r *http.Request, queries *db.Queries, teamID pgtype.UUID) (db.Sandbox, pgtype.UUID, string, bool) {
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return db.Sandbox{}, pgtype.UUID{}, "", false
}
sb, err := queries.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return db.Sandbox{}, pgtype.UUID{}, "", false
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
return db.Sandbox{}, pgtype.UUID{}, "", false
}
return sb, sandboxID, sandboxIDStr, true
}
// upgradeAndAuthenticate upgrades the HTTP connection to WebSocket and resolves
// the auth context — either from middleware (API key) or from the first WS message (JWT).
// Returns the connection and auth context, or an error if authentication fails.
// The caller is responsible for closing the returned connection.
func upgradeAndAuthenticate(w http.ResponseWriter, r *http.Request, jwtSecret []byte, queries *db.Queries) (*websocket.Conn, auth.AuthContext, error) {
ctx := r.Context()
ac, hasAuth := auth.FromContext(ctx)
if hasAuth {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return nil, auth.AuthContext{}, fmt.Errorf("websocket upgrade: %w", err)
}
return conn, ac, nil
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return nil, auth.AuthContext{}, fmt.Errorf("websocket upgrade: %w", err)
}
var wsAC auth.AuthContext
var authErr error
if isAdminWSRoute(ctx) {
wsAC, authErr = wsAuthenticateAdmin(ctx, conn, jwtSecret, queries)
} else {
wsAC, authErr = wsAuthenticate(ctx, conn, jwtSecret, queries)
}
if authErr != nil {
conn.Close()
return nil, auth.AuthContext{}, fmt.Errorf("authentication failed")
}
return conn, wsAC, nil
}
// updateLastActive updates the sandbox last_active_at timestamp.
// Uses a background context with timeout for streaming handlers where
// the request context may already be cancelled.
func updateLastActive(queries *db.Queries, sandboxID pgtype.UUID, sandboxIDStr string) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := queries.UpdateLastActive(ctx, db.UpdateLastActiveParams{
ID: sandboxID,
LastActiveAt: pgtype.Timestamptz{
Time: time.Now(),
Valid: true,
},
}); err != nil {
slog.Warn("failed to update last_active_at", "id", sandboxIDStr, "error", err)
}
}

View File

@ -3,14 +3,11 @@ package api
import (
"encoding/base64"
"encoding/json"
"log/slog"
"net/http"
"time"
"unicode/utf8"
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
@ -58,23 +55,11 @@ type backgroundExecResponse struct {
// Exec handles POST /v1/capsules/{id}/exec.
func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
sb, sandboxID, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
if !ok {
return
}
@ -116,15 +101,7 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
return
}
if err := h.db.UpdateLastActive(ctx, db.UpdateLastActiveParams{
ID: sandboxID,
LastActiveAt: pgtype.Timestamptz{
Time: time.Now(),
Valid: true,
},
}); err != nil {
slog.Warn("failed to update last_active_at", "id", sandboxIDStr, "error", err)
}
updateLastActive(h.db, sandboxID, sandboxIDStr)
writeJSON(w, http.StatusAccepted, backgroundExecResponse{
SandboxID: sandboxIDStr,
@ -151,16 +128,7 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
duration := time.Since(start)
// Update last active.
if err := h.db.UpdateLastActive(ctx, db.UpdateLastActiveParams{
ID: sandboxID,
LastActiveAt: pgtype.Timestamptz{
Time: time.Now(),
Valid: true,
},
}); err != nil {
slog.Warn("failed to update last_active_at", "id", sandboxIDStr, "error", err)
}
updateLastActive(h.db, sandboxID, sandboxIDStr)
// Use base64 encoding if output contains non-UTF-8 bytes.
stdout := resp.Msg.Stdout

View File

@ -5,7 +5,6 @@ import (
"encoding/json"
"log/slog"
"net/http"
"time"
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
@ -59,37 +58,9 @@ func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
return
}
// Authenticate: use context from middleware (API key) or WS first message (JWT).
ac, hasAuth := auth.FromContext(ctx)
if !hasAuth {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
slog.Error("websocket upgrade failed", "error", err)
return
}
defer conn.Close()
var wsAC auth.AuthContext
var authErr error
if isAdminWSRoute(ctx) {
wsAC, authErr = wsAuthenticateAdmin(ctx, conn, h.jwtSecret, h.db)
} else {
wsAC, authErr = wsAuthenticate(ctx, conn, h.jwtSecret, h.db)
}
if authErr != nil {
sendWSError(conn, "authentication failed")
return
}
ac = wsAC
h.runExecStream(ctx, conn, ac, sandboxID, sandboxIDStr)
return
}
conn, err := upgrader.Upgrade(w, r, nil)
conn, ac, err := upgradeAndAuthenticate(w, r, h.jwtSecret, h.db)
if err != nil {
slog.Error("websocket upgrade failed", "error", err)
slog.Error("websocket upgrade/auth failed", "error", err)
return
}
defer conn.Close()
@ -186,18 +157,7 @@ func (h *execStreamHandler) runExecStream(ctx context.Context, conn *websocket.C
}
}
// Update last active using a fresh context (the request context may be cancelled).
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer updateCancel()
if err := h.db.UpdateLastActive(updateCtx, db.UpdateLastActiveParams{
ID: sandboxID,
LastActiveAt: pgtype.Timestamptz{
Time: time.Now(),
Valid: true,
},
}); err != nil {
slog.Warn("failed to update last active after stream exec", "sandbox_id", sandboxIDStr, "error", err)
}
updateLastActive(h.db, sandboxID, sandboxIDStr)
}
func sendWSError(conn *websocket.Conn, msg string) {

View File

@ -7,11 +7,9 @@ import (
"net/http"
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
)
@ -30,23 +28,11 @@ func newFilesHandler(db *db.Queries, pool *lifecycle.HostClientPool) *filesHandl
// - "path" text field: absolute destination path inside the sandbox
// - "file" file field: binary content to write
func (h *filesHandler) Upload(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
if !ok {
return
}
@ -108,23 +94,11 @@ type readFileRequest struct {
// Download handles POST /v1/capsules/{id}/files/read.
// Accepts JSON body with path, returns raw file content with Content-Disposition.
func (h *filesHandler) Download(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
if !ok {
return
}

View File

@ -8,11 +8,9 @@ import (
"net/http"
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
)
@ -30,23 +28,11 @@ func newFilesStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool) *file
// Expects multipart/form-data with "path" text field and "file" file field.
// Streams file content directly from the request body to the host agent without buffering.
func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
if !ok {
return
}
@ -153,23 +139,11 @@ func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request
// StreamDownload handles POST /v1/capsules/{id}/files/stream/read.
// Accepts JSON body with path, streams file content back without buffering.
func (h *filesStreamHandler) StreamDownload(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
if !ok {
return
}

View File

@ -4,11 +4,9 @@ import (
"net/http"
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
)
@ -58,23 +56,11 @@ type removeRequest struct {
// ListDir handles POST /v1/capsules/{id}/files/list.
func (h *fsHandler) ListDir(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
if !ok {
return
}
@ -115,23 +101,11 @@ func (h *fsHandler) ListDir(w http.ResponseWriter, r *http.Request) {
// MakeDir handles POST /v1/capsules/{id}/files/mkdir.
func (h *fsHandler) MakeDir(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
if !ok {
return
}
@ -166,23 +140,11 @@ func (h *fsHandler) MakeDir(w http.ResponseWriter, r *http.Request) {
// Remove handles POST /v1/capsules/{id}/files/remove.
func (h *fsHandler) Remove(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
if !ok {
return
}

View File

@ -5,7 +5,6 @@ import (
"log/slog"
"net/http"
"strconv"
"time"
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
@ -44,23 +43,11 @@ type processListResponse struct {
// ListProcesses handles GET /v1/capsules/{id}/processes.
func (h *processHandler) ListProcesses(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
if !ok {
return
}
@ -95,24 +82,12 @@ func (h *processHandler) ListProcesses(w http.ResponseWriter, r *http.Request) {
// KillProcess handles DELETE /v1/capsules/{id}/processes/{selector}.
// The selector can be a numeric PID or a string tag.
func (h *processHandler) KillProcess(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
selectorStr := chi.URLParam(r, "selector")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
if !ok {
return
}
@ -146,14 +121,6 @@ func (h *processHandler) KillProcess(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}
// wsProcessOut is the JSON message sent to the WebSocket client.
type wsProcessOut struct {
Type string `json:"type"` // "start", "stdout", "stderr", "exit", "error"
PID uint32 `json:"pid,omitempty"` // only for "start"
Data string `json:"data,omitempty"` // only for "stdout", "stderr", "error"
ExitCode *int32 `json:"exit_code,omitempty"` // only for "exit"
}
// ConnectProcess handles WS /v1/capsules/{id}/processes/{selector}/stream.
func (h *processHandler) ConnectProcess(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
@ -166,37 +133,9 @@ func (h *processHandler) ConnectProcess(w http.ResponseWriter, r *http.Request)
return
}
// Authenticate: use context from middleware (API key) or WS first message (JWT).
ac, hasAuth := auth.FromContext(ctx)
if !hasAuth {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
slog.Error("process stream websocket upgrade failed", "error", err)
return
}
defer conn.Close()
var wsAC auth.AuthContext
var authErr error
if isAdminWSRoute(ctx) {
wsAC, authErr = wsAuthenticateAdmin(ctx, conn, h.jwtSecret, h.db)
} else {
wsAC, authErr = wsAuthenticate(ctx, conn, h.jwtSecret, h.db)
}
if authErr != nil {
sendProcessWSError(conn, "authentication failed")
return
}
ac = wsAC
h.runConnectProcess(ctx, conn, ac, sandboxID, sandboxIDStr, selectorStr)
return
}
conn, err := upgrader.Upgrade(w, r, nil)
conn, ac, err := upgradeAndAuthenticate(w, r, h.jwtSecret, h.db)
if err != nil {
slog.Error("process stream websocket upgrade failed", "error", err)
slog.Error("process stream websocket upgrade/auth failed", "error", err)
return
}
defer conn.Close()
@ -207,17 +146,17 @@ func (h *processHandler) ConnectProcess(w http.ResponseWriter, r *http.Request)
func (h *processHandler) runConnectProcess(ctx context.Context, conn *websocket.Conn, ac auth.AuthContext, sandboxID pgtype.UUID, sandboxIDStr, selectorStr string) {
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
sendProcessWSError(conn, "sandbox not found")
sendWSError(conn, "sandbox not found")
return
}
if sb.Status != "running" {
sendProcessWSError(conn, "sandbox is not running (status: "+sb.Status+")")
sendWSError(conn, "sandbox is not running (status: "+sb.Status+")")
return
}
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
sendProcessWSError(conn, "sandbox host is not reachable")
sendWSError(conn, "sandbox host is not reachable")
return
}
@ -236,7 +175,7 @@ func (h *processHandler) runConnectProcess(ctx context.Context, conn *websocket.
stream, err := agent.ConnectProcess(streamCtx, connect.NewRequest(connectReq))
if err != nil {
sendProcessWSError(conn, "failed to connect to process: "+err.Error())
sendWSError(conn, "failed to connect to process: "+err.Error())
return
}
defer stream.Close()
@ -257,42 +196,27 @@ func (h *processHandler) runConnectProcess(ctx context.Context, conn *websocket.
resp := stream.Msg()
switch ev := resp.Event.(type) {
case *pb.ConnectProcessResponse_Start:
writeWSJSON(conn, wsProcessOut{Type: "start", PID: ev.Start.Pid})
writeWSJSON(conn, wsOutMsg{Type: "start", PID: ev.Start.Pid})
case *pb.ConnectProcessResponse_Data:
switch o := ev.Data.Output.(type) {
case *pb.ExecStreamData_Stdout:
writeWSJSON(conn, wsProcessOut{Type: "stdout", Data: string(o.Stdout)})
writeWSJSON(conn, wsOutMsg{Type: "stdout", Data: string(o.Stdout)})
case *pb.ExecStreamData_Stderr:
writeWSJSON(conn, wsProcessOut{Type: "stderr", Data: string(o.Stderr)})
writeWSJSON(conn, wsOutMsg{Type: "stderr", Data: string(o.Stderr)})
}
case *pb.ConnectProcessResponse_End:
exitCode := ev.End.ExitCode
writeWSJSON(conn, wsProcessOut{Type: "exit", ExitCode: &exitCode})
writeWSJSON(conn, wsOutMsg{Type: "exit", ExitCode: &exitCode})
}
}
if err := stream.Err(); err != nil {
if streamCtx.Err() == nil {
sendProcessWSError(conn, err.Error())
sendWSError(conn, err.Error())
}
}
// Update last active using a fresh context.
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer updateCancel()
if err := h.db.UpdateLastActive(updateCtx, db.UpdateLastActiveParams{
ID: sandboxID,
LastActiveAt: pgtype.Timestamptz{
Time: time.Now(),
Valid: true,
},
}); err != nil {
slog.Warn("failed to update last active after process stream", "sandbox_id", sandboxIDStr, "error", err)
}
}
func sendProcessWSError(conn *websocket.Conn, msg string) {
writeWSJSON(conn, wsProcessOut{Type: "error", Data: msg})
updateLastActive(h.db, sandboxID, sandboxIDStr)
}

View File

@ -90,40 +90,9 @@ func (h *ptyHandler) PtySession(w http.ResponseWriter, r *http.Request) {
return
}
// API key auth is handled by middleware (sets context).
// For browser JWT auth, we authenticate after upgrade via first WS message.
ac, hasAuth := auth.FromContext(ctx)
if !hasAuth {
// No pre-upgrade auth — upgrade first, then authenticate via WS message.
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
slog.Error("pty websocket upgrade failed", "error", err)
return
}
defer conn.Close()
ws := &wsWriter{conn: conn}
var wsAC auth.AuthContext
if isAdminWSRoute(ctx) {
wsAC, err = wsAuthenticateAdmin(ctx, conn, h.jwtSecret, h.db)
} else {
wsAC, err = wsAuthenticate(ctx, conn, h.jwtSecret, h.db)
}
if err != nil {
ws.writeJSON(wsPtyOut{Type: "error", Data: "authentication failed", Fatal: true})
return
}
ac = wsAC
h.runPtySession(ctx, ws, conn, ac, sandboxID, sandboxIDStr)
return
}
conn, err := upgrader.Upgrade(w, r, nil)
conn, ac, err := upgradeAndAuthenticate(w, r, h.jwtSecret, h.db)
if err != nil {
slog.Error("pty websocket upgrade failed", "error", err)
slog.Error("pty websocket upgrade/auth failed", "error", err)
return
}
defer conn.Close()
@ -168,18 +137,7 @@ func (h *ptyHandler) runPtySession(ctx context.Context, ws *wsWriter, conn *webs
ws.writeJSON(wsPtyOut{Type: "error", Data: "first message must be type 'start' or 'connect'", Fatal: true})
}
// Update last active using a fresh context.
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer updateCancel()
if err := h.db.UpdateLastActive(updateCtx, db.UpdateLastActiveParams{
ID: sandboxID,
LastActiveAt: pgtype.Timestamptz{
Time: time.Now(),
Valid: true,
},
}); err != nil {
slog.Warn("failed to update last active after pty session", "sandbox_id", sandboxIDStr, "error", err)
}
updateLastActive(h.db, sandboxID, sandboxIDStr)
}
func (h *ptyHandler) handleStart(