Files
wrenn/internal/api/middleware.go
pptx704 9bf67aa7f7 Implement host registration, JWT refresh tokens, and multi-host scheduling
Replaces the hardcoded CP_HOST_AGENT_ADDR single-agent setup with a
DB-driven registration system supporting multiple host agents (BYOC).

Key changes:
- Host agents register via one-time token, receive a 7-day JWT + 60-day
  refresh token; heartbeat loop auto-refreshes on 401/403 and pauses all
  sandboxes if refresh fails
- HostClientPool: lazy Connect RPC client cache keyed by host ID, replacing
  the single static agent client throughout the API and service layers
- RoundRobinScheduler: picks an online host for each new sandbox via
  ListActiveHosts; extensible for future scheduling strategies
- HostMonitor (replaces Reconciler): passive heartbeat staleness check marks
  hosts unreachable and sandboxes missing after 90s; active reconciliation
  per online host restores missing-but-alive sandboxes and stops orphans
- Graceful host delete: returns 409 with the list of affected sandboxes unless
  ?force=true is passed; force-delete destroys sandboxes, then evicts the pool client
- Snapshot delete broadcasts to all online hosts (templates have no host_id)
- sandbox.Manager.PauseAll: pauses all running VMs on CP connectivity loss
- New migration: host_refresh_tokens table with token rotation (issue-then-
  revoke ordering to prevent lockout on mid-rotation crash)
- New sandbox status 'missing' (reversible, unlike 'stopped') and host
  status 'unreachable'; both reflected in OpenAPI spec
- Fix: refresh token auth failure now returns 401 (was 400 via generic
  'invalid' substring match in serviceErrToHTTP)
2026-03-24 18:32:05 +06:00

125 lines
3.4 KiB
Go

package api
import (
"bufio"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net"
"net/http"
"strings"
"time"
"connectrpc.com/connect"
)
// JSON error envelope written by the API, shaped as
// {"error": {"code": "...", "message": "..."}}.
type (
	// errorDetail carries a machine-readable code and a human-readable message.
	errorDetail struct {
		Code    string `json:"code"`
		Message string `json:"message"`
	}

	// errorResponse is the top-level object; it nests the detail under "error".
	errorResponse struct {
		Error errorDetail `json:"error"`
	}
)
// writeJSON sends v as a JSON response body with the given HTTP status and a
// Content-Type of application/json.
//
// The value is encoded before anything is sent. If encoding fails, the failure
// is logged and a 500 response with an "internal_error" envelope is sent
// instead, so the client never receives the requested status with an empty
// body. On success the bytes written are the JSON encoding of v followed by a
// newline.
func writeJSON(w http.ResponseWriter, status int, v any) {
	w.Header().Set("Content-Type", "application/json")

	body, err := json.Marshal(v)
	if err != nil {
		slog.Error("encoding JSON response", "error", err)
		status = http.StatusInternalServerError
		body = []byte(`{"error":{"code":"internal_error","message":"failed to encode response"}}`)
	}
	body = append(body, '\n')

	w.WriteHeader(status)
	if _, err := w.Write(body); err != nil {
		// The status line is already sent; nothing more can be reported to the
		// client, so the write failure is only logged.
		slog.Debug("writing JSON response", "error", err)
	}
}
func writeError(w http.ResponseWriter, status int, code, message string) {
writeJSON(w, status, errorResponse{
Error: errorDetail{Code: code, Message: message},
})
}
// agentErrToHTTP converts a Connect RPC error into the HTTP status, error code,
// and message to send to the API client. NotFound, InvalidArgument, and
// FailedPrecondition map to 404, 400, and 409; every other Connect code is
// reported as 502 "agent_error". The message is always err.Error().
func agentErrToHTTP(err error) (int, string, string) {
	status, code := http.StatusBadGateway, "agent_error"

	switch connect.CodeOf(err) {
	case connect.CodeNotFound:
		status, code = http.StatusNotFound, "not_found"
	case connect.CodeInvalidArgument:
		status, code = http.StatusBadRequest, "invalid_request"
	case connect.CodeFailedPrecondition:
		status, code = http.StatusConflict, "conflict"
	}

	return status, code, err.Error()
}
// requestLogger builds middleware that emits one "request" log entry per
// handled request, carrying the method, URL path, response status, and the
// time the wrapped handler took.
func requestLogger() func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		logged := func(w http.ResponseWriter, r *http.Request) {
			began := time.Now()

			// Default to 200 for handlers that never call WriteHeader.
			rec := &statusWriter{ResponseWriter: w, status: http.StatusOK}
			next.ServeHTTP(rec, r)

			elapsed := time.Since(began)
			slog.Info("request",
				"method", r.Method,
				"path", r.URL.Path,
				"status", rec.status,
				"duration", elapsed,
			)
		}
		return http.HandlerFunc(logged)
	}
}
// decodeJSON reads one JSON value from the request body into v, which must be
// a pointer. It returns the decoder's error unchanged.
func decodeJSON(r *http.Request, v any) error {
	dec := json.NewDecoder(r.Body)
	return dec.Decode(v)
}
// serviceErrToHTTP translates a service-layer error into an HTTP status, a
// machine-readable error code, and a message.
//
// If a Connect RPC error is found anywhere in err's chain, the mapping is
// delegated to agentErrToHTTP. Otherwise the error text is checked against a
// list of known substrings in order; the first match decides the result, and
// an error matching none of them is reported as 500 "internal_error".
func serviceErrToHTTP(err error) (int, string, string) {
	text := err.Error()

	var rpcErr *connect.Error
	if errors.As(err, &rpcErr) {
		return agentErrToHTTP(rpcErr)
	}

	// Order matters: "invalid or expired" has to be tested before the broader
	// "invalid", otherwise it would be reported as a 400 instead of a 401.
	rules := []struct {
		needles []string
		status  int
		code    string
	}{
		{[]string{"not found"}, http.StatusNotFound, "not_found"},
		{[]string{"not running", "not paused"}, http.StatusConflict, "invalid_state"},
		{[]string{"forbidden"}, http.StatusForbidden, "forbidden"},
		{[]string{"invalid or expired"}, http.StatusUnauthorized, "unauthorized"},
		{[]string{"invalid"}, http.StatusBadRequest, "invalid_request"},
	}
	for _, rule := range rules {
		for _, needle := range rule.needles {
			if strings.Contains(text, needle) {
				return rule.status, rule.code, text
			}
		}
	}

	return http.StatusInternalServerError, "internal_error", text
}
// statusWriter wraps an http.ResponseWriter and records the status code of the
// response so that it can be read after the handler returns.
//
// status holds the recorded code. Its initial value is chosen by whoever
// constructs the writer and is what remains if the handler never writes
// anything. wroteHeader is set once a final status has been recorded; later
// WriteHeader calls are still forwarded but no longer change status, which
// keeps the recorded value equal to the first status the handler committed.
type statusWriter struct {
	http.ResponseWriter
	status      int
	wroteHeader bool
}

// recordStatus stores status as the response's final code unless one has
// already been recorded.
func (w *statusWriter) recordStatus(status int) {
	if w.wroteHeader {
		return
	}
	w.status = status
	w.wroteHeader = true
}

// WriteHeader records status and forwards the call to the wrapped writer.
// Informational 1xx codes other than 101 are forwarded without being recorded,
// because a final status is expected to follow them.
func (w *statusWriter) WriteHeader(status int) {
	informational := status >= 100 && status < 200 && status != http.StatusSwitchingProtocols
	if !informational {
		w.recordStatus(status)
	}
	w.ResponseWriter.WriteHeader(status)
}

// Write forwards p to the wrapped writer. A Write that happens before any
// WriteHeader implies a 200 response, so that status is recorded first; a
// WriteHeader arriving afterwards can then no longer alter the recorded value.
func (w *statusWriter) Write(p []byte) (int, error) {
	w.recordStatus(http.StatusOK)
	return w.ResponseWriter.Write(p)
}

// Hijack implements http.Hijacker, required for WebSocket upgrade. It returns
// an error when the wrapped writer does not support hijacking.
func (w *statusWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	hj, ok := w.ResponseWriter.(http.Hijacker)
	if !ok {
		return nil, nil, fmt.Errorf("underlying ResponseWriter does not implement http.Hijacker")
	}
	return hj.Hijack()
}

// Flush implements http.Flusher, required for streaming responses. It is a
// no-op when the wrapped writer cannot flush.
func (w *statusWriter) Flush() {
	fl, ok := w.ResponseWriter.(http.Flusher)
	if !ok {
		return
	}
	// The standard library's writers send the header with an implicit 200 on
	// the first flush if none was written yet, so record that here as well.
	w.recordStatus(http.StatusOK)
	fl.Flush()
}

// Unwrap returns the wrapped writer so that http.ResponseController can reach
// optional features this wrapper does not expose itself.
func (w *statusWriter) Unwrap() http.ResponseWriter {
	return w.ResponseWriter
}