1
0
forked from wrenn/wrenn

refactor: eliminate DRY violations across control plane and host agent

Extract shared helpers to consolidate repeated patterns:
- requireRunningSandbox: sandbox lookup + running check (10 call sites)
- upgradeAndAuthenticate: WS upgrade + JWT/API-key auth (3 handlers)
- updateLastActive: last_active_at update with background context (5 sites)
- attachCowAndCreate: cow loop attach + dmsetup create (devicemapper)
- issueRegistrationToken: token gen + Redis + audit (host service)
- ErrNotFound sentinel: replaces string matching in hostagent server

Also merges duplicate wsProcessOut/wsOutMsg types into one.

Net: -208 lines, zero behavior change.
This commit is contained in:
2026-05-17 02:03:06 +06:00
parent a5425969ed
commit 124e097e23
12 changed files with 207 additions and 415 deletions

View File

@ -5,7 +5,6 @@ import (
"log/slog"
"net/http"
"strconv"
"time"
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
@ -44,23 +43,11 @@ type processListResponse struct {
// ListProcesses handles GET /v1/capsules/{id}/processes.
func (h *processHandler) ListProcesses(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
if !ok {
return
}
@ -95,24 +82,12 @@ func (h *processHandler) ListProcesses(w http.ResponseWriter, r *http.Request) {
// KillProcess handles DELETE /v1/capsules/{id}/processes/{selector}.
// The selector can be a numeric PID or a string tag.
func (h *processHandler) KillProcess(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
selectorStr := chi.URLParam(r, "selector")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
if !ok {
return
}
@ -146,14 +121,6 @@ func (h *processHandler) KillProcess(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}
// wsProcessOut is the JSON message sent to the WebSocket client.
type wsProcessOut struct {
Type string `json:"type"` // "start", "stdout", "stderr", "exit", "error"
PID uint32 `json:"pid,omitempty"` // only for "start"
Data string `json:"data,omitempty"` // only for "stdout", "stderr", "error"
ExitCode *int32 `json:"exit_code,omitempty"` // only for "exit"
}
// ConnectProcess handles WS /v1/capsules/{id}/processes/{selector}/stream.
func (h *processHandler) ConnectProcess(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
@ -166,37 +133,9 @@ func (h *processHandler) ConnectProcess(w http.ResponseWriter, r *http.Request)
return
}
// Authenticate: use context from middleware (API key) or WS first message (JWT).
ac, hasAuth := auth.FromContext(ctx)
if !hasAuth {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
slog.Error("process stream websocket upgrade failed", "error", err)
return
}
defer conn.Close()
var wsAC auth.AuthContext
var authErr error
if isAdminWSRoute(ctx) {
wsAC, authErr = wsAuthenticateAdmin(ctx, conn, h.jwtSecret, h.db)
} else {
wsAC, authErr = wsAuthenticate(ctx, conn, h.jwtSecret, h.db)
}
if authErr != nil {
sendProcessWSError(conn, "authentication failed")
return
}
ac = wsAC
h.runConnectProcess(ctx, conn, ac, sandboxID, sandboxIDStr, selectorStr)
return
}
conn, err := upgrader.Upgrade(w, r, nil)
conn, ac, err := upgradeAndAuthenticate(w, r, h.jwtSecret, h.db)
if err != nil {
slog.Error("process stream websocket upgrade failed", "error", err)
slog.Error("process stream websocket upgrade/auth failed", "error", err)
return
}
defer conn.Close()
@ -207,17 +146,17 @@ func (h *processHandler) ConnectProcess(w http.ResponseWriter, r *http.Request)
func (h *processHandler) runConnectProcess(ctx context.Context, conn *websocket.Conn, ac auth.AuthContext, sandboxID pgtype.UUID, sandboxIDStr, selectorStr string) {
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
sendProcessWSError(conn, "sandbox not found")
sendWSError(conn, "sandbox not found")
return
}
if sb.Status != "running" {
sendProcessWSError(conn, "sandbox is not running (status: "+sb.Status+")")
sendWSError(conn, "sandbox is not running (status: "+sb.Status+")")
return
}
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
sendProcessWSError(conn, "sandbox host is not reachable")
sendWSError(conn, "sandbox host is not reachable")
return
}
@ -236,7 +175,7 @@ func (h *processHandler) runConnectProcess(ctx context.Context, conn *websocket.
stream, err := agent.ConnectProcess(streamCtx, connect.NewRequest(connectReq))
if err != nil {
sendProcessWSError(conn, "failed to connect to process: "+err.Error())
sendWSError(conn, "failed to connect to process: "+err.Error())
return
}
defer stream.Close()
@ -257,42 +196,27 @@ func (h *processHandler) runConnectProcess(ctx context.Context, conn *websocket.
resp := stream.Msg()
switch ev := resp.Event.(type) {
case *pb.ConnectProcessResponse_Start:
writeWSJSON(conn, wsProcessOut{Type: "start", PID: ev.Start.Pid})
writeWSJSON(conn, wsOutMsg{Type: "start", PID: ev.Start.Pid})
case *pb.ConnectProcessResponse_Data:
switch o := ev.Data.Output.(type) {
case *pb.ExecStreamData_Stdout:
writeWSJSON(conn, wsProcessOut{Type: "stdout", Data: string(o.Stdout)})
writeWSJSON(conn, wsOutMsg{Type: "stdout", Data: string(o.Stdout)})
case *pb.ExecStreamData_Stderr:
writeWSJSON(conn, wsProcessOut{Type: "stderr", Data: string(o.Stderr)})
writeWSJSON(conn, wsOutMsg{Type: "stderr", Data: string(o.Stderr)})
}
case *pb.ConnectProcessResponse_End:
exitCode := ev.End.ExitCode
writeWSJSON(conn, wsProcessOut{Type: "exit", ExitCode: &exitCode})
writeWSJSON(conn, wsOutMsg{Type: "exit", ExitCode: &exitCode})
}
}
if err := stream.Err(); err != nil {
if streamCtx.Err() == nil {
sendProcessWSError(conn, err.Error())
sendWSError(conn, err.Error())
}
}
// Update last active using a fresh context.
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer updateCancel()
if err := h.db.UpdateLastActive(updateCtx, db.UpdateLastActiveParams{
ID: sandboxID,
LastActiveAt: pgtype.Timestamptz{
Time: time.Now(),
Valid: true,
},
}); err != nil {
slog.Warn("failed to update last active after process stream", "sandbox_id", sandboxIDStr, "error", err)
}
}
func sendProcessWSError(conn *websocket.Conn, msg string) {
writeWSJSON(conn, wsProcessOut{Type: "error", Data: msg})
updateLastActive(h.db, sandboxID, sandboxIDStr)
}