forked from wrenn/wrenn
Extract shared helpers to consolidate repeated patterns:
- requireRunningSandbox: sandbox lookup + running check (10 call sites)
- upgradeAndAuthenticate: WS upgrade + JWT/API-key auth (3 handlers)
- updateLastActive: last_active_at update with background context (5 sites)
- attachCowAndCreate: cow loop attach + dmsetup create (devicemapper)
- issueRegistrationToken: token gen + Redis + audit (host service)
- ErrNotFound sentinel: replaces string matching in hostagent server

Also merges duplicate wsProcessOut/wsOutMsg types into one.
Net: -208 lines, zero behavior change.
109 lines
3.7 KiB
Go
109 lines
3.7 KiB
Go
package api
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log/slog"
|
|
"net/http"
|
|
"time"
|
|
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/gorilla/websocket"
|
|
"github.com/jackc/pgx/v5/pgtype"
|
|
|
|
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
|
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
|
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
|
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
|
"git.omukk.dev/wrenn/wrenn/proto/hostagent/gen/hostagentv1connect"
|
|
)
|
|
|
|
// agentForHost looks up the host record and returns a Connect RPC client for it.
|
|
// Returns an error if the host is not found or has no address.
|
|
func agentForHost(ctx context.Context, queries *db.Queries, pool *lifecycle.HostClientPool, hostID pgtype.UUID) (hostagentv1connect.HostAgentServiceClient, error) {
|
|
host, err := queries.GetHost(ctx, hostID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("host not found: %w", err)
|
|
}
|
|
return pool.GetForHost(host)
|
|
}
|
|
|
|
// requireRunningSandbox parses the sandbox ID from the URL, looks it up by team,
|
|
// and verifies it is running. On failure it writes the appropriate HTTP error and
|
|
// returns false.
|
|
func requireRunningSandbox(w http.ResponseWriter, r *http.Request, queries *db.Queries, teamID pgtype.UUID) (db.Sandbox, pgtype.UUID, string, bool) {
|
|
sandboxIDStr := chi.URLParam(r, "id")
|
|
ctx := r.Context()
|
|
|
|
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
|
if err != nil {
|
|
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
|
return db.Sandbox{}, pgtype.UUID{}, "", false
|
|
}
|
|
|
|
sb, err := queries.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
|
|
if err != nil {
|
|
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
|
return db.Sandbox{}, pgtype.UUID{}, "", false
|
|
}
|
|
if sb.Status != "running" {
|
|
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
|
|
return db.Sandbox{}, pgtype.UUID{}, "", false
|
|
}
|
|
|
|
return sb, sandboxID, sandboxIDStr, true
|
|
}
|
|
|
|
// upgradeAndAuthenticate upgrades the HTTP connection to WebSocket and resolves
|
|
// the auth context — either from middleware (API key) or from the first WS message (JWT).
|
|
// Returns the connection and auth context, or an error if authentication fails.
|
|
// The caller is responsible for closing the returned connection.
|
|
func upgradeAndAuthenticate(w http.ResponseWriter, r *http.Request, jwtSecret []byte, queries *db.Queries) (*websocket.Conn, auth.AuthContext, error) {
|
|
ctx := r.Context()
|
|
ac, hasAuth := auth.FromContext(ctx)
|
|
|
|
if hasAuth {
|
|
conn, err := upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
return nil, auth.AuthContext{}, fmt.Errorf("websocket upgrade: %w", err)
|
|
}
|
|
return conn, ac, nil
|
|
}
|
|
|
|
conn, err := upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
return nil, auth.AuthContext{}, fmt.Errorf("websocket upgrade: %w", err)
|
|
}
|
|
|
|
var wsAC auth.AuthContext
|
|
var authErr error
|
|
if isAdminWSRoute(ctx) {
|
|
wsAC, authErr = wsAuthenticateAdmin(ctx, conn, jwtSecret, queries)
|
|
} else {
|
|
wsAC, authErr = wsAuthenticate(ctx, conn, jwtSecret, queries)
|
|
}
|
|
if authErr != nil {
|
|
conn.Close()
|
|
return nil, auth.AuthContext{}, fmt.Errorf("authentication failed")
|
|
}
|
|
|
|
return conn, wsAC, nil
|
|
}
|
|
|
|
// updateLastActive updates the sandbox last_active_at timestamp.
|
|
// Uses a background context with timeout for streaming handlers where
|
|
// the request context may already be cancelled.
|
|
func updateLastActive(queries *db.Queries, sandboxID pgtype.UUID, sandboxIDStr string) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
if err := queries.UpdateLastActive(ctx, db.UpdateLastActiveParams{
|
|
ID: sandboxID,
|
|
LastActiveAt: pgtype.Timestamptz{
|
|
Time: time.Now(),
|
|
Valid: true,
|
|
},
|
|
}); err != nil {
|
|
slog.Warn("failed to update last_active_at", "id", sandboxIDStr, "error", err)
|
|
}
|
|
}
|