1
0
forked from wrenn/wrenn

Implement host registration, JWT refresh tokens, and multi-host scheduling

Replaces the hardcoded CP_HOST_AGENT_ADDR single-agent setup with a
DB-driven registration system supporting multiple host agents (BYOC).

Key changes:
- Host agents register via one-time token, receive a 7-day JWT + 60-day
  refresh token; heartbeat loop auto-refreshes on 401/403 and pauses all
  sandboxes if refresh fails
- HostClientPool: lazy Connect RPC client cache keyed by host ID, replacing
  the single static agent client throughout the API and service layers
- RoundRobinScheduler: picks an online host for each new sandbox via
  ListActiveHosts; extensible for future scheduling strategies
- HostMonitor (replaces Reconciler): passive heartbeat staleness check marks
  hosts unreachable and sandboxes missing after 90s; active reconciliation
  per online host restores missing-but-alive sandboxes and stops orphans
- Graceful host delete: returns 409 with affected sandbox list without
  ?force=true; force-delete destroys sandboxes then evicts pool client
- Snapshot delete broadcasts to all online hosts (templates have no host_id)
- sandbox.Manager.PauseAll: pauses all running VMs on CP connectivity loss
- New migration: host_refresh_tokens table with token rotation (issue-then-
  revoke ordering to prevent lockout on mid-rotation crash)
- New sandbox status 'missing' (reversible, unlike 'stopped') and host
  status 'unreachable'; both reflected in OpenAPI spec
- Fix: refresh token auth failure now returns 401 (was 400 via generic
  'invalid' substring match in serviceErrToHTTP)
This commit is contained in:
2026-03-24 18:32:05 +06:00
parent f968da9768
commit 9bf67aa7f7
33 changed files with 1567 additions and 318 deletions

View File

@ -6,7 +6,6 @@ REDIS_URL=redis://localhost:6379/0
# Control Plane # Control Plane
CP_LISTEN_ADDR=:8000 CP_LISTEN_ADDR=:8000
CP_HOST_AGENT_ADDR=localhost:50051
# Host Agent # Host Agent
AGENT_LISTEN_ADDR=:50051 AGENT_LISTEN_ADDR=:50051

View File

@ -17,7 +17,8 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/auth/oauth" "git.omukk.dev/wrenn/sandbox/internal/auth/oauth"
"git.omukk.dev/wrenn/sandbox/internal/config" "git.omukk.dev/wrenn/sandbox/internal/config"
"git.omukk.dev/wrenn/sandbox/internal/db" "git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect" "git.omukk.dev/wrenn/sandbox/internal/lifecycle"
"git.omukk.dev/wrenn/sandbox/internal/scheduler"
) )
func main() { func main() {
@ -66,12 +67,11 @@ func main() {
} }
slog.Info("connected to redis") slog.Info("connected to redis")
// Connect RPC client for the host agent. // Host client pool — manages Connect RPC clients to host agents.
agentHTTP := &http.Client{Timeout: 10 * time.Minute} hostPool := lifecycle.NewHostClientPool()
agentClient := hostagentv1connect.NewHostAgentServiceClient(
agentHTTP, // Scheduler — picks a host for each new sandbox (round-robin for now).
cfg.HostAgentAddr, hostScheduler := scheduler.NewRoundRobinScheduler(queries)
)
// OAuth provider registry. // OAuth provider registry.
oauthRegistry := oauth.NewRegistry() oauthRegistry := oauth.NewRegistry()
@ -87,11 +87,11 @@ func main() {
} }
// API server. // API server.
srv := api.New(queries, agentClient, pool, rdb, []byte(cfg.JWTSecret), oauthRegistry, cfg.OAuthRedirectURL) srv := api.New(queries, hostPool, hostScheduler, pool, rdb, []byte(cfg.JWTSecret), oauthRegistry, cfg.OAuthRedirectURL)
// Start reconciler. // Start host monitor (passive + active reconciliation every 30s).
reconciler := api.NewReconciler(queries, agentClient, "default", 5*time.Second) monitor := api.NewHostMonitor(queries, hostPool, 30*time.Second)
reconciler.Start(ctx) monitor.Start(ctx)
httpServer := &http.Server{ httpServer := &http.Server{
Addr: cfg.ListenAddr, Addr: cfg.ListenAddr,
@ -114,7 +114,7 @@ func main() {
} }
}() }()
slog.Info("control plane starting", "addr", cfg.ListenAddr, "agent", cfg.HostAgentAddr) slog.Info("control plane starting", "addr", cfg.ListenAddr)
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
slog.Error("http server error", "error", err) slog.Error("http server error", "error", err)
os.Exit(1) os.Exit(1)

View File

@ -18,7 +18,7 @@ import (
) )
func main() { func main() {
registrationToken := flag.String("register", "", "One-time registration token from the control plane") registrationToken := flag.String("register", "", "One-time registration token from the control plane (required on first run)")
advertiseAddr := flag.String("address", "", "Externally-reachable address (ip:port) for this host agent") advertiseAddr := flag.String("address", "", "Externally-reachable address (ip:port) for this host agent")
flag.Parse() flag.Parse()
@ -42,7 +42,16 @@ func main() {
listenAddr := envOrDefault("AGENT_LISTEN_ADDR", ":50051") listenAddr := envOrDefault("AGENT_LISTEN_ADDR", ":50051")
rootDir := envOrDefault("AGENT_FILES_ROOTDIR", "/var/lib/wrenn") rootDir := envOrDefault("AGENT_FILES_ROOTDIR", "/var/lib/wrenn")
cpURL := os.Getenv("AGENT_CP_URL") cpURL := os.Getenv("AGENT_CP_URL")
tokenFile := filepath.Join(rootDir, "host-token") tokenFile := filepath.Join(rootDir, "host.jwt")
if cpURL == "" {
slog.Error("AGENT_CP_URL environment variable is required")
os.Exit(1)
}
if *advertiseAddr == "" {
slog.Error("--address flag is required (externally-reachable ip:port)")
os.Exit(1)
}
cfg := sandbox.Config{ cfg := sandbox.Config{
KernelPath: filepath.Join(rootDir, "kernels", "vmlinux"), KernelPath: filepath.Join(rootDir, "kernels", "vmlinux"),
@ -58,13 +67,7 @@ func main() {
mgr.StartTTLReaper(ctx) mgr.StartTTLReaper(ctx)
if *advertiseAddr == "" { // Register with the control plane and start heartbeating.
slog.Error("--address flag is required (externally-reachable ip:port)")
os.Exit(1)
}
// Register with the control plane (if configured).
if cpURL != "" {
hostToken, err := hostagent.Register(ctx, hostagent.RegistrationConfig{ hostToken, err := hostagent.Register(ctx, hostagent.RegistrationConfig{
CPURL: cpURL, CPURL: cpURL,
RegistrationToken: *registrationToken, RegistrationToken: *registrationToken,
@ -83,8 +86,14 @@ func main() {
} }
slog.Info("host registered", "host_id", hostID) slog.Info("host registered", "host_id", hostID)
hostagent.StartHeartbeat(ctx, cpURL, hostID, hostToken, 30*time.Second)
} // Start heartbeat loop. On CP rejection: try JWT refresh. If that fails,
// pause all running sandboxes to ensure they're not left orphaned.
hostagent.StartHeartbeat(ctx, cpURL, tokenFile, hostID, 30*time.Second, func() {
pauseCtx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
mgr.PauseAll(pauseCtx)
})
srv := hostagent.NewServer(mgr) srv := hostagent.NewServer(mgr)
path, handler := hostagentv1connect.NewHostAgentServiceHandler(srv) path, handler := hostagentv1connect.NewHostAgentServiceHandler(srv)
@ -115,7 +124,7 @@ func main() {
} }
}() }()
slog.Info("host agent starting", "addr", listenAddr) slog.Info("host agent starting", "addr", listenAddr, "host_id", hostID)
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
slog.Error("http server error", "error", err) slog.Error("http server error", "error", err)
os.Exit(1) os.Exit(1)

View File

@ -0,0 +1,19 @@
-- +goose Up
-- Refresh tokens for host agent JWT rotation.
-- Hosts exchange a refresh token for a new short-lived JWT + new refresh token (rotation).
-- Refresh tokens expire after 60 days; hosts must re-register with a new one-time token after that.
CREATE TABLE host_refresh_tokens (
id TEXT PRIMARY KEY,
host_id TEXT NOT NULL REFERENCES hosts(id) ON DELETE CASCADE,
token_hash TEXT NOT NULL UNIQUE, -- SHA-256 hex of the opaque token
expires_at TIMESTAMPTZ NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
revoked_at TIMESTAMPTZ -- NULL = active; set on rotation or host delete
);
CREATE INDEX idx_host_refresh_tokens_host ON host_refresh_tokens(host_id);
-- +goose Down
DROP TABLE host_refresh_tokens;

View File

@ -0,0 +1,19 @@
-- name: InsertHostRefreshToken :one
INSERT INTO host_refresh_tokens (id, host_id, token_hash, expires_at)
VALUES ($1, $2, $3, $4)
RETURNING *;
-- name: GetHostRefreshTokenByHash :one
SELECT * FROM host_refresh_tokens
WHERE token_hash = $1 AND revoked_at IS NULL AND expires_at > NOW();
-- name: RevokeHostRefreshToken :exec
UPDATE host_refresh_tokens SET revoked_at = NOW() WHERE id = $1;
-- name: RevokeHostRefreshTokensByHost :exec
UPDATE host_refresh_tokens SET revoked_at = NOW()
WHERE host_id = $1 AND revoked_at IS NULL;
-- name: DeleteExpiredHostRefreshTokens :exec
DELETE FROM host_refresh_tokens
WHERE expires_at < NOW() OR revoked_at IS NOT NULL;

View File

@ -67,3 +67,18 @@ SELECT * FROM host_tokens WHERE host_id = $1 ORDER BY created_at DESC;
-- name: GetHostByTeam :one -- name: GetHostByTeam :one
SELECT * FROM hosts WHERE id = $1 AND team_id = $2; SELECT * FROM hosts WHERE id = $1 AND team_id = $2;
-- name: ListActiveHosts :many
-- Returns all hosts that have completed registration (not pending/offline).
SELECT * FROM hosts WHERE status NOT IN ('pending', 'offline') ORDER BY created_at;
-- name: UpdateHostHeartbeatAndStatus :exec
-- Updates last_heartbeat_at and transitions unreachable hosts back to online.
UPDATE hosts
SET last_heartbeat_at = NOW(),
status = CASE WHEN status = 'unreachable' THEN 'online' ELSE status END,
updated_at = NOW()
WHERE id = $1;
-- name: MarkHostUnreachable :exec
UPDATE hosts SET status = 'unreachable', updated_at = NOW() WHERE id = $1;

View File

@ -56,3 +56,20 @@ WHERE id = ANY($1::text[]);
SELECT * FROM sandboxes SELECT * FROM sandboxes
WHERE team_id = $1 AND status IN ('running', 'paused', 'starting') WHERE team_id = $1 AND status IN ('running', 'paused', 'starting')
ORDER BY created_at DESC; ORDER BY created_at DESC;
-- name: MarkSandboxesMissingByHost :exec
-- Called when the host monitor marks a host unreachable.
-- Marks running/starting/pending sandboxes on that host as 'missing' so users see
-- the sandbox is not currently reachable, without permanently losing the record.
UPDATE sandboxes
SET status = 'missing',
last_updated = NOW()
WHERE host_id = $1 AND status IN ('running', 'starting', 'pending');
-- name: BulkRestoreRunning :exec
-- Called by the reconciler when a host comes back online and its sandboxes are
-- confirmed alive. Restores only sandboxes that are in 'missing' state.
UPDATE sandboxes
SET status = 'running',
last_updated = NOW()
WHERE id = ANY($1::text[]) AND status = 'missing';

View File

@ -0,0 +1,20 @@
package api
import (
"context"
"fmt"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
)
// agentForHost looks up the host record and returns a Connect RPC client for it.
// Returns an error if the host is not found or has no address.
func agentForHost(ctx context.Context, queries *db.Queries, pool *lifecycle.HostClientPool, hostID string) (hostagentv1connect.HostAgentServiceClient, error) {
host, err := queries.GetHost(ctx, hostID)
if err != nil {
return nil, fmt.Errorf("host not found: %w", err)
}
return pool.GetForHost(host)
}

View File

@ -14,17 +14,17 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db" "git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen" pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
) )
type execHandler struct { type execHandler struct {
db *db.Queries db *db.Queries
agent hostagentv1connect.HostAgentServiceClient pool *lifecycle.HostClientPool
} }
func newExecHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *execHandler { func newExecHandler(db *db.Queries, pool *lifecycle.HostClientPool) *execHandler {
return &execHandler{db: db, agent: agent} return &execHandler{db: db, pool: pool}
} }
type execRequest struct { type execRequest struct {
@ -73,7 +73,13 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
start := time.Now() start := time.Now()
resp, err := h.agent.Exec(ctx, connect.NewRequest(&pb.ExecRequest{ agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
return
}
resp, err := agent.Exec(ctx, connect.NewRequest(&pb.ExecRequest{
SandboxId: sandboxID, SandboxId: sandboxID,
Cmd: req.Cmd, Cmd: req.Cmd,
Args: req.Args, Args: req.Args,

View File

@ -14,17 +14,17 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db" "git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen" pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
) )
type execStreamHandler struct { type execStreamHandler struct {
db *db.Queries db *db.Queries
agent hostagentv1connect.HostAgentServiceClient pool *lifecycle.HostClientPool
} }
func newExecStreamHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *execStreamHandler { func newExecStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool) *execStreamHandler {
return &execStreamHandler{db: db, agent: agent} return &execStreamHandler{db: db, pool: pool}
} }
var upgrader = websocket.Upgrader{ var upgrader = websocket.Upgrader{
@ -80,11 +80,17 @@ func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
return return
} }
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
sendWSError(conn, "sandbox host is not reachable")
return
}
// Open streaming exec to host agent. // Open streaming exec to host agent.
streamCtx, cancel := context.WithCancel(ctx) streamCtx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
stream, err := h.agent.ExecStream(streamCtx, connect.NewRequest(&pb.ExecStreamRequest{ stream, err := agent.ExecStream(streamCtx, connect.NewRequest(&pb.ExecStreamRequest{
SandboxId: sandboxID, SandboxId: sandboxID,
Cmd: startMsg.Cmd, Cmd: startMsg.Cmd,
Args: startMsg.Args, Args: startMsg.Args,

View File

@ -11,17 +11,17 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db" "git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen" pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
) )
type filesHandler struct { type filesHandler struct {
db *db.Queries db *db.Queries
agent hostagentv1connect.HostAgentServiceClient pool *lifecycle.HostClientPool
} }
func newFilesHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *filesHandler { func newFilesHandler(db *db.Queries, pool *lifecycle.HostClientPool) *filesHandler {
return &filesHandler{db: db, agent: agent} return &filesHandler{db: db, pool: pool}
} }
// Upload handles POST /v1/sandboxes/{id}/files/write. // Upload handles POST /v1/sandboxes/{id}/files/write.
@ -75,7 +75,13 @@ func (h *filesHandler) Upload(w http.ResponseWriter, r *http.Request) {
return return
} }
if _, err := h.agent.WriteFile(ctx, connect.NewRequest(&pb.WriteFileRequest{ agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
return
}
if _, err := agent.WriteFile(ctx, connect.NewRequest(&pb.WriteFileRequest{
SandboxId: sandboxID, SandboxId: sandboxID,
Path: filePath, Path: filePath,
Content: content, Content: content,
@ -120,7 +126,13 @@ func (h *filesHandler) Download(w http.ResponseWriter, r *http.Request) {
return return
} }
resp, err := h.agent.ReadFile(ctx, connect.NewRequest(&pb.ReadFileRequest{ agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
return
}
resp, err := agent.ReadFile(ctx, connect.NewRequest(&pb.ReadFileRequest{
SandboxId: sandboxID, SandboxId: sandboxID,
Path: req.Path, Path: req.Path,
})) }))

View File

@ -12,17 +12,17 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db" "git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen" pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
) )
type filesStreamHandler struct { type filesStreamHandler struct {
db *db.Queries db *db.Queries
agent hostagentv1connect.HostAgentServiceClient pool *lifecycle.HostClientPool
} }
func newFilesStreamHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *filesStreamHandler { func newFilesStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool) *filesStreamHandler {
return &filesStreamHandler{db: db, agent: agent} return &filesStreamHandler{db: db, pool: pool}
} }
// StreamUpload handles POST /v1/sandboxes/{id}/files/stream/write. // StreamUpload handles POST /v1/sandboxes/{id}/files/stream/write.
@ -88,8 +88,14 @@ func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request
} }
defer filePart.Close() defer filePart.Close()
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
return
}
// Open client-streaming RPC to host agent. // Open client-streaming RPC to host agent.
stream := h.agent.WriteFileStream(ctx) stream := agent.WriteFileStream(ctx)
// Send metadata first. // Send metadata first.
if err := stream.Send(&pb.WriteFileStreamRequest{ if err := stream.Send(&pb.WriteFileStreamRequest{
@ -164,8 +170,14 @@ func (h *filesStreamHandler) StreamDownload(w http.ResponseWriter, r *http.Reque
return return
} }
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
return
}
// Open server-streaming RPC to host agent. // Open server-streaming RPC to host agent.
stream, err := h.agent.ReadFileStream(ctx, connect.NewRequest(&pb.ReadFileStreamRequest{ stream, err := agent.ReadFileStream(ctx, connect.NewRequest(&pb.ReadFileStreamRequest{
SandboxId: sandboxID, SandboxId: sandboxID,
Path: req.Path, Path: req.Path,
})) }))

View File

@ -1,6 +1,7 @@
package api package api
import ( import (
"errors"
"net/http" "net/http"
"time" "time"
@ -34,6 +35,25 @@ type createHostResponse struct {
RegistrationToken string `json:"registration_token"` RegistrationToken string `json:"registration_token"`
} }
type refreshTokenRequest struct {
RefreshToken string `json:"refresh_token"`
}
type refreshTokenResponse struct {
Host hostResponse `json:"host"`
Token string `json:"token"`
RefreshToken string `json:"refresh_token"`
}
type deletePreviewResponse struct {
Host hostResponse `json:"host"`
SandboxIDs []string `json:"sandbox_ids"`
}
type hasSandboxesErrorResponse struct {
SandboxIDs []string `json:"sandbox_ids"`
}
type registerHostRequest struct { type registerHostRequest struct {
Token string `json:"token"` Token string `json:"token"`
Arch string `json:"arch,omitempty"` Arch string `json:"arch,omitempty"`
@ -46,6 +66,7 @@ type registerHostRequest struct {
type registerHostResponse struct { type registerHostResponse struct {
Host hostResponse `json:"host"` Host hostResponse `json:"host"`
Token string `json:"token"` Token string `json:"token"`
RefreshToken string `json:"refresh_token"`
} }
type addTagRequest struct { type addTagRequest struct {
@ -183,18 +204,54 @@ func (h *hostHandler) Get(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, hostToResponse(host)) writeJSON(w, http.StatusOK, hostToResponse(host))
} }
// Delete handles DELETE /v1/hosts/{id}. // DeletePreview handles GET /v1/hosts/{id}/delete-preview.
func (h *hostHandler) Delete(w http.ResponseWriter, r *http.Request) { // Returns what would be affected without making changes, for confirmation UI.
func (h *hostHandler) DeletePreview(w http.ResponseWriter, r *http.Request) {
hostID := chi.URLParam(r, "id") hostID := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context()) ac := auth.MustFromContext(r.Context())
if err := h.svc.Delete(r.Context(), hostID, ac.UserID, ac.TeamID, h.isAdmin(r, ac.UserID)); err != nil { preview, err := h.svc.DeletePreview(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID))
if err != nil {
status, code, msg := serviceErrToHTTP(err) status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg) writeError(w, status, code, msg)
return return
} }
writeJSON(w, http.StatusOK, deletePreviewResponse{
Host: hostToResponse(preview.Host),
SandboxIDs: preview.SandboxIDs,
})
}
// Delete handles DELETE /v1/hosts/{id}.
// Without ?force=true: returns 409 with affected sandbox IDs if any are active.
// With ?force=true: gracefully stops all sandboxes then deletes the host.
func (h *hostHandler) Delete(w http.ResponseWriter, r *http.Request) {
hostID := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
force := r.URL.Query().Get("force") == "true"
err := h.svc.Delete(r.Context(), hostID, ac.UserID, ac.TeamID, h.isAdmin(r, ac.UserID), force)
if err == nil {
w.WriteHeader(http.StatusNoContent) w.WriteHeader(http.StatusNoContent)
return
}
// Check if it's a "has running sandboxes" error and return a structured 409.
var hasSandboxes *service.HostHasSandboxesError
if errors.As(err, &hasSandboxes) {
writeJSON(w, http.StatusConflict, map[string]any{
"error": map[string]any{
"code": "has_active_sandboxes",
"message": "host has active sandboxes; use ?force=true to destroy them and delete the host",
"sandbox_ids": hasSandboxes.SandboxIDs,
},
})
return
}
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
} }
// RegenerateToken handles POST /v1/hosts/{id}/token. // RegenerateToken handles POST /v1/hosts/{id}/token.
@ -249,6 +306,7 @@ func (h *hostHandler) Register(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusCreated, registerHostResponse{ writeJSON(w, http.StatusCreated, registerHostResponse{
Host: hostToResponse(result.Host), Host: hostToResponse(result.Host),
Token: result.JWT, Token: result.JWT,
RefreshToken: result.RefreshToken,
}) })
} }
@ -311,6 +369,33 @@ func (h *hostHandler) RemoveTag(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent) w.WriteHeader(http.StatusNoContent)
} }
// RefreshToken handles POST /v1/hosts/auth/refresh (unauthenticated).
// The host agent sends its refresh token to receive a new JWT and rotated refresh token.
func (h *hostHandler) RefreshToken(w http.ResponseWriter, r *http.Request) {
var req refreshTokenRequest
if err := decodeJSON(r, &req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
return
}
if req.RefreshToken == "" {
writeError(w, http.StatusBadRequest, "invalid_request", "refresh_token is required")
return
}
result, err := h.svc.Refresh(r.Context(), req.RefreshToken)
if err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
writeJSON(w, http.StatusOK, refreshTokenResponse{
Host: hostToResponse(result.Host),
Token: result.JWT,
RefreshToken: result.RefreshToken,
})
}
// ListTags handles GET /v1/hosts/{id}/tags. // ListTags handles GET /v1/hosts/{id}/tags.
func (h *hostHandler) ListTags(w http.ResponseWriter, r *http.Request) { func (h *hostHandler) ListTags(w http.ResponseWriter, r *http.Request) {
hostID := chi.URLParam(r, "id") hostID := chi.URLParam(r, "id")

View File

@ -1,6 +1,7 @@
package api package api
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"log/slog" "log/slog"
@ -14,20 +15,45 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db" "git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id" "git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
"git.omukk.dev/wrenn/sandbox/internal/service" "git.omukk.dev/wrenn/sandbox/internal/service"
"git.omukk.dev/wrenn/sandbox/internal/validate" "git.omukk.dev/wrenn/sandbox/internal/validate"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen" pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
) )
type snapshotHandler struct { type snapshotHandler struct {
svc *service.TemplateService svc *service.TemplateService
db *db.Queries db *db.Queries
agent hostagentv1connect.HostAgentServiceClient pool *lifecycle.HostClientPool
} }
func newSnapshotHandler(svc *service.TemplateService, db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *snapshotHandler { func newSnapshotHandler(svc *service.TemplateService, db *db.Queries, pool *lifecycle.HostClientPool) *snapshotHandler {
return &snapshotHandler{svc: svc, db: db, agent: agent} return &snapshotHandler{svc: svc, db: db, pool: pool}
}
// deleteSnapshotBroadcast attempts to delete snapshot files on all online hosts.
// Snapshots aren't currently host-tracked in the DB, so we broadcast to all hosts
// and ignore NotFound errors. TODO: add host_id to templates table.
func (h *snapshotHandler) deleteSnapshotBroadcast(ctx context.Context, name string) error {
hosts, err := h.db.ListActiveHosts(ctx)
if err != nil {
return fmt.Errorf("list hosts: %w", err)
}
for _, host := range hosts {
if host.Status != "online" {
continue
}
agent, err := h.pool.GetForHost(host)
if err != nil {
continue
}
if _, err := agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{Name: name})); err != nil {
if connect.CodeOf(err) != connect.CodeNotFound {
slog.Warn("snapshot: failed to delete on host", "host_id", host.ID, "name", name, "error", err)
}
}
}
return nil
} }
type createSnapshotRequest struct { type createSnapshotRequest struct {
@ -93,10 +119,9 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusConflict, "already_exists", "snapshot name already exists; use ?overwrite=true to replace") writeError(w, http.StatusConflict, "already_exists", "snapshot name already exists; use ?overwrite=true to replace")
return return
} }
// Delete old files from the agent before removing the DB record. // Delete old snapshot files from all hosts before removing the DB record.
if _, err := h.agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{Name: req.Name})); err != nil { if err := h.deleteSnapshotBroadcast(ctx, req.Name); err != nil {
status, code, msg := agentErrToHTTP(err) writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete existing snapshot files")
writeError(w, status, code, "failed to delete existing snapshot files: "+msg)
return return
} }
if err := h.db.DeleteTemplateByTeam(ctx, db.DeleteTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err != nil { if err := h.db.DeleteTemplateByTeam(ctx, db.DeleteTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err != nil {
@ -116,7 +141,13 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
return return
} }
resp, err := h.agent.CreateSnapshot(ctx, connect.NewRequest(&pb.CreateSnapshotRequest{ agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
return
}
resp, err := agent.CreateSnapshot(ctx, connect.NewRequest(&pb.CreateSnapshotRequest{
SandboxId: req.SandboxID, SandboxId: req.SandboxID,
Name: req.Name, Name: req.Name,
})) }))
@ -186,11 +217,8 @@ func (h *snapshotHandler) Delete(w http.ResponseWriter, r *http.Request) {
return return
} }
if _, err := h.agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{ if err := h.deleteSnapshotBroadcast(ctx, name); err != nil {
Name: name, writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete snapshot files")
})); err != nil {
status, code, msg := agentErrToHTTP(err)
writeError(w, status, code, "failed to delete snapshot files: "+msg)
return return
} }

View File

@ -0,0 +1,198 @@
package api
import (
"context"
"log/slog"
"time"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
"connectrpc.com/connect"
)
// unreachableThreshold is how long a host can go without a heartbeat before
// it is considered unreachable (3 missed 30-second heartbeats).
const unreachableThreshold = 90 * time.Second
// HostMonitor runs on a fixed interval and performs two duties:
//
// 1. Passive check: marks hosts whose last_heartbeat_at is stale as
// "unreachable" and marks their active sandboxes as "missing".
//
// 2. Active reconciliation: for each online host, calls ListSandboxes and
// reconciles DB state against live host state — restoring "missing"
// sandboxes that are actually alive, and stopping orphaned ones.
type HostMonitor struct {
db *db.Queries
pool *lifecycle.HostClientPool
interval time.Duration
}
// NewHostMonitor creates a HostMonitor.
func NewHostMonitor(queries *db.Queries, pool *lifecycle.HostClientPool, interval time.Duration) *HostMonitor {
return &HostMonitor{
db: queries,
pool: pool,
interval: interval,
}
}
// Start runs the monitor loop until the context is cancelled.
func (m *HostMonitor) Start(ctx context.Context) {
go func() {
ticker := time.NewTicker(m.interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
m.run(ctx)
}
}
}()
}
func (m *HostMonitor) run(ctx context.Context) {
hosts, err := m.db.ListActiveHosts(ctx)
if err != nil {
slog.Warn("host monitor: failed to list hosts", "error", err)
return
}
for _, host := range hosts {
m.checkHost(ctx, host)
}
}
func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) {
// --- Passive phase: check heartbeat staleness ---
stale := !host.LastHeartbeatAt.Valid ||
time.Since(host.LastHeartbeatAt.Time) > unreachableThreshold
if stale && host.Status != "unreachable" {
slog.Info("host monitor: marking host unreachable", "host_id", host.ID,
"last_heartbeat", host.LastHeartbeatAt.Time)
if err := m.db.MarkHostUnreachable(ctx, host.ID); err != nil {
slog.Warn("host monitor: failed to mark host unreachable", "host_id", host.ID, "error", err)
}
if err := m.db.MarkSandboxesMissingByHost(ctx, host.ID); err != nil {
slog.Warn("host monitor: failed to mark sandboxes missing", "host_id", host.ID, "error", err)
}
return
}
// --- Active reconciliation: only for online hosts ---
if host.Status != "online" {
return
}
agent, err := m.pool.GetForHost(host)
if err != nil {
// Host has no address yet (e.g., just registered) — skip.
return
}
resp, err := agent.ListSandboxes(ctx, connect.NewRequest(&pb.ListSandboxesRequest{}))
if err != nil {
// RPC failure is a transient condition; the passive phase will catch it
// if heartbeats stop arriving.
slog.Debug("host monitor: ListSandboxes failed (transient)", "host_id", host.ID, "error", err)
return
}
// Build set of sandbox IDs alive on the host.
alive := make(map[string]struct{}, len(resp.Msg.Sandboxes))
for _, sb := range resp.Msg.Sandboxes {
alive[sb.SandboxId] = struct{}{}
}
autoPaused := make(map[string]struct{}, len(resp.Msg.AutoPausedSandboxIds))
for _, id := range resp.Msg.AutoPausedSandboxIds {
autoPaused[id] = struct{}{}
}
// --- Restore sandboxes that are "missing" in DB but alive on host ---
// This handles the case where CP marked them missing due to a transient
// heartbeat gap, but the host was actually fine.
missingSandboxes, err := m.db.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
HostID: host.ID,
Column2: []string{"missing"},
})
if err != nil {
slog.Warn("host monitor: failed to list missing sandboxes", "host_id", host.ID, "error", err)
} else {
var toRestore []string
var toStop []string
for _, sb := range missingSandboxes {
if _, ok := alive[sb.ID]; ok {
toRestore = append(toRestore, sb.ID)
} else {
toStop = append(toStop, sb.ID)
}
}
if len(toRestore) > 0 {
slog.Info("host monitor: restoring missing sandboxes", "host_id", host.ID, "count", len(toRestore))
if err := m.db.BulkRestoreRunning(ctx, toRestore); err != nil {
slog.Warn("host monitor: failed to restore missing sandboxes", "host_id", host.ID, "error", err)
}
}
if len(toStop) > 0 {
slog.Info("host monitor: stopping confirmed-dead missing sandboxes", "host_id", host.ID, "count", len(toStop))
if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
Column1: toStop,
Status: "stopped",
}); err != nil {
slog.Warn("host monitor: failed to stop missing sandboxes", "host_id", host.ID, "error", err)
}
}
}
// --- Find running sandboxes in DB that are no longer alive on the host ---
runningSandboxes, err := m.db.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
HostID: host.ID,
Column2: []string{"running"},
})
if err != nil {
slog.Warn("host monitor: failed to list running sandboxes", "host_id", host.ID, "error", err)
return
}
var toPause, toStop []string
for _, sb := range runningSandboxes {
if _, ok := alive[sb.ID]; ok {
continue
}
if _, ok := autoPaused[sb.ID]; ok {
toPause = append(toPause, sb.ID)
} else {
toStop = append(toStop, sb.ID)
}
}
if len(toPause) > 0 {
slog.Info("host monitor: marking auto-paused sandboxes", "host_id", host.ID, "count", len(toPause))
if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
Column1: toPause,
Status: "paused",
}); err != nil {
slog.Warn("host monitor: failed to mark paused", "host_id", host.ID, "error", err)
}
}
if len(toStop) > 0 {
slog.Info("host monitor: marking orphaned sandboxes stopped", "host_id", host.ID, "count", len(toStop))
if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
Column1: toStop,
Status: "stopped",
}); err != nil {
slog.Warn("host monitor: failed to mark stopped", "host_id", host.ID, "error", err)
}
}
}

View File

@ -89,6 +89,8 @@ func serviceErrToHTTP(err error) (int, string, string) {
return http.StatusConflict, "invalid_state", msg return http.StatusConflict, "invalid_state", msg
case strings.Contains(msg, "forbidden"): case strings.Contains(msg, "forbidden"):
return http.StatusForbidden, "forbidden", msg return http.StatusForbidden, "forbidden", msg
case strings.Contains(msg, "invalid or expired"):
return http.StatusUnauthorized, "unauthorized", msg
case strings.Contains(msg, "invalid"): case strings.Contains(msg, "invalid"):
return http.StatusBadRequest, "invalid_request", msg return http.StatusBadRequest, "invalid_request", msg
default: default:

View File

@ -1193,8 +1193,16 @@ paths:
security: security:
- bearerAuth: [] - bearerAuth: []
description: | description: |
Admins can delete any host. Team owners can delete BYOC hosts Admins can delete any host. Team owners and admins can delete BYOC hosts
belonging to their team. belonging to their team. Without `?force=true`, returns 409 if the host
has active sandboxes. With `?force=true`, destroys all sandboxes first.
parameters:
- name: force
in: query
required: false
schema:
type: boolean
description: If true, destroy all sandboxes on the host before deleting.
responses: responses:
"204": "204":
description: Host deleted description: Host deleted
@ -1204,6 +1212,12 @@ paths:
application/json: application/json:
schema: schema:
$ref: "#/components/schemas/Error" $ref: "#/components/schemas/Error"
"409":
description: Host has active sandboxes (only when force is not set)
content:
application/json:
schema:
$ref: "#/components/schemas/HostHasSandboxesError"
/v1/hosts/{id}/token: /v1/hosts/{id}/token:
parameters: parameters:
@ -1312,6 +1326,72 @@ paths:
schema: schema:
$ref: "#/components/schemas/Error" $ref: "#/components/schemas/Error"
/v1/hosts/auth/refresh:
post:
summary: Refresh host JWT
operationId: refreshHostToken
tags: [hosts]
description: |
Exchanges a refresh token for a new JWT and rotated refresh token.
The old refresh token is immediately revoked. No authentication required —
the refresh token itself is the credential.
requestBody:
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/RefreshHostTokenRequest"
responses:
"200":
description: New JWT and rotated refresh token
content:
application/json:
schema:
$ref: "#/components/schemas/RefreshHostTokenResponse"
"401":
description: Invalid, expired, or revoked refresh token
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
/v1/hosts/{id}/delete-preview:
parameters:
- name: id
in: path
required: true
schema:
type: string
get:
summary: Preview host deletion
operationId: getHostDeletePreview
tags: [hosts]
security:
- bearerAuth: []
description: |
Returns the list of sandbox IDs that would be destroyed if the host
were deleted with `?force=true`. No state is modified.
responses:
"200":
description: Deletion preview
content:
application/json:
schema:
$ref: "#/components/schemas/HostDeletePreview"
"403":
description: Insufficient permissions
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"404":
description: Host not found
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
/v1/hosts/{id}/tags: /v1/hosts/{id}/tags:
parameters: parameters:
- name: id - name: id
@ -1405,7 +1485,7 @@ components:
type: apiKey type: apiKey
in: header in: header
name: X-Host-Token name: X-Host-Token
description: Long-lived host JWT returned from POST /v1/hosts/register. Valid for 1 year. description: Host JWT returned from POST /v1/hosts/register or POST /v1/hosts/auth/refresh. Valid for 7 days.
schemas: schemas:
SignupRequest: SignupRequest:
@ -1505,7 +1585,7 @@ components:
type: string type: string
status: status:
type: string type: string
enum: [pending, running, paused, stopped, error] enum: [pending, starting, running, paused, hibernated, stopped, missing, error]
template: template:
type: string type: string
vcpus: vcpus:
@ -1661,7 +1741,10 @@ components:
$ref: "#/components/schemas/Host" $ref: "#/components/schemas/Host"
token: token:
type: string type: string
description: Long-lived host JWT for X-Host-Token header. Valid for 1 year. description: Host JWT for X-Host-Token header. Valid for 7 days.
refresh_token:
type: string
description: Refresh token for obtaining new JWTs. Valid for 60 days; rotated on each use.
Host: Host:
type: object type: object
@ -1697,7 +1780,7 @@ components:
nullable: true nullable: true
status: status:
type: string type: string
enum: [pending, online, offline, draining] enum: [pending, online, offline, draining, unreachable]
last_heartbeat_at: last_heartbeat_at:
type: string type: string
format: date-time format: date-time
@ -1711,6 +1794,54 @@ components:
type: string type: string
format: date-time format: date-time
RefreshHostTokenRequest:
type: object
required: [refresh_token]
properties:
refresh_token:
type: string
description: Refresh token obtained from registration or a previous refresh.
RefreshHostTokenResponse:
type: object
properties:
host:
$ref: "#/components/schemas/Host"
token:
type: string
description: New host JWT. Valid for 7 days.
refresh_token:
type: string
description: New refresh token. Valid for 60 days; old token is revoked.
HostDeletePreview:
type: object
properties:
host:
$ref: "#/components/schemas/Host"
sandbox_ids:
type: array
items:
type: string
description: IDs of sandboxes that would be destroyed on force-delete.
HostHasSandboxesError:
type: object
properties:
error:
type: object
properties:
code:
type: string
example: host_has_sandboxes
message:
type: string
sandbox_ids:
type: array
items:
type: string
description: IDs of active sandboxes blocking deletion.
AddTagRequest: AddTagRequest:
type: object type: object
required: [tag] required: [tag]

View File

@ -1,126 +0,0 @@
package api
import (
"context"
"log/slog"
"time"
"connectrpc.com/connect"
"git.omukk.dev/wrenn/sandbox/internal/db"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
)
// Reconciler periodically compares the host agent's sandbox list with the DB
// and marks sandboxes that no longer exist on the host as stopped.
type Reconciler struct {
db *db.Queries
agent hostagentv1connect.HostAgentServiceClient
hostID string
interval time.Duration
}
// NewReconciler creates a new reconciler.
func NewReconciler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient, hostID string, interval time.Duration) *Reconciler {
return &Reconciler{
db: db,
agent: agent,
hostID: hostID,
interval: interval,
}
}
// Start runs the reconciliation loop until the context is cancelled.
func (rc *Reconciler) Start(ctx context.Context) {
go func() {
ticker := time.NewTicker(rc.interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
rc.reconcile(ctx)
}
}
}()
}
func (rc *Reconciler) reconcile(ctx context.Context) {
// Single RPC returns both the running sandbox list and any IDs that
// were auto-paused by the TTL reaper since the last call.
resp, err := rc.agent.ListSandboxes(ctx, connect.NewRequest(&pb.ListSandboxesRequest{}))
if err != nil {
slog.Warn("reconciler: failed to list sandboxes from host agent", "error", err)
return
}
// Build a set of sandbox IDs that are alive on the host.
alive := make(map[string]struct{}, len(resp.Msg.Sandboxes))
for _, sb := range resp.Msg.Sandboxes {
alive[sb.SandboxId] = struct{}{}
}
// Build auto-paused set from the same response.
autoPausedSet := make(map[string]struct{}, len(resp.Msg.AutoPausedSandboxIds))
for _, id := range resp.Msg.AutoPausedSandboxIds {
autoPausedSet[id] = struct{}{}
}
// Get all DB sandboxes for this host that are running.
// Paused sandboxes are excluded: they are expected to not exist on the
// host agent because pause = snapshot + destroy resources.
dbSandboxes, err := rc.db.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
HostID: rc.hostID,
Column2: []string{"running"},
})
if err != nil {
slog.Warn("reconciler: failed to list DB sandboxes", "error", err)
return
}
// Find sandboxes in DB that are no longer on the host.
var stale []string
for _, sb := range dbSandboxes {
if _, ok := alive[sb.ID]; !ok {
stale = append(stale, sb.ID)
}
}
if len(stale) == 0 {
return
}
// Split stale sandboxes into those auto-paused by the TTL reaper vs
// those that crashed/were orphaned.
var toPause, toStop []string
for _, id := range stale {
if _, ok := autoPausedSet[id]; ok {
toPause = append(toPause, id)
} else {
toStop = append(toStop, id)
}
}
if len(toPause) > 0 {
slog.Info("reconciler: marking auto-paused sandboxes", "count", len(toPause), "ids", toPause)
if err := rc.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
Column1: toPause,
Status: "paused",
}); err != nil {
slog.Warn("reconciler: failed to mark auto-paused sandboxes", "error", err)
}
}
if len(toStop) > 0 {
slog.Info("reconciler: marking stale sandboxes as stopped", "count", len(toStop), "ids", toStop)
if err := rc.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
Column1: toStop,
Status: "stopped",
}); err != nil {
slog.Warn("reconciler: failed to update stale sandboxes", "error", err)
}
}
}

View File

@ -11,8 +11,9 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/auth/oauth" "git.omukk.dev/wrenn/sandbox/internal/auth/oauth"
"git.omukk.dev/wrenn/sandbox/internal/db" "git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
"git.omukk.dev/wrenn/sandbox/internal/scheduler"
"git.omukk.dev/wrenn/sandbox/internal/service" "git.omukk.dev/wrenn/sandbox/internal/service"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
) )
//go:embed openapi.yaml //go:embed openapi.yaml
@ -24,25 +25,34 @@ type Server struct {
} }
// New constructs the chi router and registers all routes. // New constructs the chi router and registers all routes.
func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, pool *pgxpool.Pool, rdb *redis.Client, jwtSecret []byte, oauthRegistry *oauth.Registry, oauthRedirectURL string) *Server { func New(
queries *db.Queries,
pool *lifecycle.HostClientPool,
sched scheduler.HostScheduler,
pgPool *pgxpool.Pool,
rdb *redis.Client,
jwtSecret []byte,
oauthRegistry *oauth.Registry,
oauthRedirectURL string,
) *Server {
r := chi.NewRouter() r := chi.NewRouter()
r.Use(requestLogger()) r.Use(requestLogger())
// Shared service layer. // Shared service layer.
sandboxSvc := &service.SandboxService{DB: queries, Agent: agent} sandboxSvc := &service.SandboxService{DB: queries, Pool: pool, Scheduler: sched}
apiKeySvc := &service.APIKeyService{DB: queries} apiKeySvc := &service.APIKeyService{DB: queries}
templateSvc := &service.TemplateService{DB: queries} templateSvc := &service.TemplateService{DB: queries}
hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret} hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret, Pool: pool}
teamSvc := &service.TeamService{DB: queries, Pool: pool, Agent: agent} teamSvc := &service.TeamService{DB: queries, Pool: pgPool, HostPool: pool}
sandbox := newSandboxHandler(sandboxSvc) sandbox := newSandboxHandler(sandboxSvc)
exec := newExecHandler(queries, agent) exec := newExecHandler(queries, pool)
execStream := newExecStreamHandler(queries, agent) execStream := newExecStreamHandler(queries, pool)
files := newFilesHandler(queries, agent) files := newFilesHandler(queries, pool)
filesStream := newFilesStreamHandler(queries, agent) filesStream := newFilesStreamHandler(queries, pool)
snapshots := newSnapshotHandler(templateSvc, queries, agent) snapshots := newSnapshotHandler(templateSvc, queries, pool)
authH := newAuthHandler(queries, pool, jwtSecret) authH := newAuthHandler(queries, pgPool, jwtSecret)
oauthH := newOAuthHandler(queries, pool, jwtSecret, oauthRegistry, oauthRedirectURL) oauthH := newOAuthHandler(queries, pgPool, jwtSecret, oauthRegistry, oauthRedirectURL)
apiKeys := newAPIKeyHandler(apiKeySvc) apiKeys := newAPIKeyHandler(apiKeySvc)
hostH := newHostHandler(hostSvc, queries) hostH := newHostHandler(hostSvc, queries)
teamH := newTeamHandler(teamSvc) teamH := newTeamHandler(teamSvc)
@ -123,6 +133,9 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p
// Unauthenticated: one-time registration token. // Unauthenticated: one-time registration token.
r.Post("/register", hostH.Register) r.Post("/register", hostH.Register)
// Unauthenticated: refresh token exchange.
r.Post("/auth/refresh", hostH.RefreshToken)
// Host-token-authenticated: heartbeat. // Host-token-authenticated: heartbeat.
r.With(requireHostToken(jwtSecret)).Post("/{id}/heartbeat", hostH.Heartbeat) r.With(requireHostToken(jwtSecret)).Post("/{id}/heartbeat", hostH.Heartbeat)
@ -134,6 +147,7 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p
r.Route("/{id}", func(r chi.Router) { r.Route("/{id}", func(r chi.Router) {
r.Get("/", hostH.Get) r.Get("/", hostH.Get)
r.Delete("/", hostH.Delete) r.Delete("/", hostH.Delete)
r.Get("/delete-preview", hostH.DeletePreview)
r.Post("/token", hostH.RegenerateToken) r.Post("/token", hostH.RegenerateToken)
r.Get("/tags", hostH.ListTags) r.Get("/tags", hostH.ListTags)
r.Post("/tags", hostH.AddTag) r.Post("/tags", hostH.AddTag)

View File

@ -8,7 +8,8 @@ import (
) )
const jwtExpiry = 6 * time.Hour const jwtExpiry = 6 * time.Hour
const hostJWTExpiry = 8760 * time.Hour // 1 year const hostJWTExpiry = 7 * 24 * time.Hour // 7 days; host refreshes via refresh token
const HostRefreshTokenExpiry = 60 * 24 * time.Hour // 60 days; exported for service layer
// Claims are the JWT payload for user tokens. // Claims are the JWT payload for user tokens.
type Claims struct { type Claims struct {

View File

@ -2,7 +2,6 @@ package config
import ( import (
"os" "os"
"strings"
"github.com/joho/godotenv" "github.com/joho/godotenv"
) )
@ -12,7 +11,6 @@ type Config struct {
DatabaseURL string DatabaseURL string
RedisURL string RedisURL string
ListenAddr string ListenAddr string
HostAgentAddr string
JWTSecret string JWTSecret string
OAuthGitHubClientID string OAuthGitHubClientID string
@ -27,11 +25,10 @@ func Load() Config {
// Best-effort load — missing .env file is fine. // Best-effort load — missing .env file is fine.
_ = godotenv.Load() _ = godotenv.Load()
cfg := Config{ return Config{
DatabaseURL: envOrDefault("DATABASE_URL", "postgres://wrenn:wrenn@localhost:5432/wrenn?sslmode=disable"), DatabaseURL: envOrDefault("DATABASE_URL", "postgres://wrenn:wrenn@localhost:5432/wrenn?sslmode=disable"),
RedisURL: envOrDefault("REDIS_URL", "redis://localhost:6379/0"), RedisURL: envOrDefault("REDIS_URL", "redis://localhost:6379/0"),
ListenAddr: envOrDefault("CP_LISTEN_ADDR", ":8080"), ListenAddr: envOrDefault("CP_LISTEN_ADDR", ":8080"),
HostAgentAddr: envOrDefault("CP_HOST_AGENT_ADDR", "http://localhost:50051"),
JWTSecret: os.Getenv("JWT_SECRET"), JWTSecret: os.Getenv("JWT_SECRET"),
OAuthGitHubClientID: os.Getenv("OAUTH_GITHUB_CLIENT_ID"), OAuthGitHubClientID: os.Getenv("OAUTH_GITHUB_CLIENT_ID"),
@ -39,13 +36,6 @@ func Load() Config {
OAuthRedirectURL: envOrDefault("OAUTH_REDIRECT_URL", "https://app.wrenn.dev"), OAuthRedirectURL: envOrDefault("OAUTH_REDIRECT_URL", "https://app.wrenn.dev"),
CPPublicURL: os.Getenv("CP_PUBLIC_URL"), CPPublicURL: os.Getenv("CP_PUBLIC_URL"),
} }
// Ensure the host agent address has a scheme.
if !strings.HasPrefix(cfg.HostAgentAddr, "http://") && !strings.HasPrefix(cfg.HostAgentAddr, "https://") {
cfg.HostAgentAddr = "http://" + cfg.HostAgentAddr
}
return cfg
} }
func envOrDefault(key, def string) string { func envOrDefault(key, def string) string {

View File

@ -0,0 +1,92 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: host_refresh_tokens.sql
package db
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
const deleteExpiredHostRefreshTokens = `-- name: DeleteExpiredHostRefreshTokens :exec
DELETE FROM host_refresh_tokens
WHERE expires_at < NOW() OR revoked_at IS NOT NULL
`
func (q *Queries) DeleteExpiredHostRefreshTokens(ctx context.Context) error {
_, err := q.db.Exec(ctx, deleteExpiredHostRefreshTokens)
return err
}
const getHostRefreshTokenByHash = `-- name: GetHostRefreshTokenByHash :one
SELECT id, host_id, token_hash, expires_at, created_at, revoked_at FROM host_refresh_tokens
WHERE token_hash = $1 AND revoked_at IS NULL AND expires_at > NOW()
`
func (q *Queries) GetHostRefreshTokenByHash(ctx context.Context, tokenHash string) (HostRefreshToken, error) {
row := q.db.QueryRow(ctx, getHostRefreshTokenByHash, tokenHash)
var i HostRefreshToken
err := row.Scan(
&i.ID,
&i.HostID,
&i.TokenHash,
&i.ExpiresAt,
&i.CreatedAt,
&i.RevokedAt,
)
return i, err
}
const insertHostRefreshToken = `-- name: InsertHostRefreshToken :one
INSERT INTO host_refresh_tokens (id, host_id, token_hash, expires_at)
VALUES ($1, $2, $3, $4)
RETURNING id, host_id, token_hash, expires_at, created_at, revoked_at
`
type InsertHostRefreshTokenParams struct {
ID string `json:"id"`
HostID string `json:"host_id"`
TokenHash string `json:"token_hash"`
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
}
func (q *Queries) InsertHostRefreshToken(ctx context.Context, arg InsertHostRefreshTokenParams) (HostRefreshToken, error) {
row := q.db.QueryRow(ctx, insertHostRefreshToken,
arg.ID,
arg.HostID,
arg.TokenHash,
arg.ExpiresAt,
)
var i HostRefreshToken
err := row.Scan(
&i.ID,
&i.HostID,
&i.TokenHash,
&i.ExpiresAt,
&i.CreatedAt,
&i.RevokedAt,
)
return i, err
}
const revokeHostRefreshToken = `-- name: RevokeHostRefreshToken :exec
UPDATE host_refresh_tokens SET revoked_at = NOW() WHERE id = $1
`
func (q *Queries) RevokeHostRefreshToken(ctx context.Context, id string) error {
_, err := q.db.Exec(ctx, revokeHostRefreshToken, id)
return err
}
const revokeHostRefreshTokensByHost = `-- name: RevokeHostRefreshTokensByHost :exec
UPDATE host_refresh_tokens SET revoked_at = NOW()
WHERE host_id = $1 AND revoked_at IS NULL
`
func (q *Queries) RevokeHostRefreshTokensByHost(ctx context.Context, hostID string) error {
_, err := q.db.Exec(ctx, revokeHostRefreshTokensByHost, hostID)
return err
}

View File

@ -234,6 +234,50 @@ func (q *Queries) InsertHostToken(ctx context.Context, arg InsertHostTokenParams
return i, err return i, err
} }
const listActiveHosts = `-- name: ListActiveHosts :many
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE status NOT IN ('pending', 'offline') ORDER BY created_at
`
// Returns all hosts that have completed registration (not pending/offline).
func (q *Queries) ListActiveHosts(ctx context.Context) ([]Host, error) {
rows, err := q.db.Query(ctx, listActiveHosts)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Host
for rows.Next() {
var i Host
if err := rows.Scan(
&i.ID,
&i.Type,
&i.TeamID,
&i.Provider,
&i.AvailabilityZone,
&i.Arch,
&i.CpuCores,
&i.MemoryMb,
&i.DiskGb,
&i.Address,
&i.Status,
&i.LastHeartbeatAt,
&i.Metadata,
&i.CreatedBy,
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
&i.MtlsEnabled,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listHosts = `-- name: ListHosts :many const listHosts = `-- name: ListHosts :many
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts ORDER BY created_at DESC SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts ORDER BY created_at DESC
` `
@ -461,6 +505,15 @@ func (q *Queries) MarkHostTokenUsed(ctx context.Context, id string) error {
return err return err
} }
const markHostUnreachable = `-- name: MarkHostUnreachable :exec
UPDATE hosts SET status = 'unreachable', updated_at = NOW() WHERE id = $1
`
func (q *Queries) MarkHostUnreachable(ctx context.Context, id string) error {
_, err := q.db.Exec(ctx, markHostUnreachable, id)
return err
}
const registerHost = `-- name: RegisterHost :execrows const registerHost = `-- name: RegisterHost :execrows
UPDATE hosts UPDATE hosts
SET arch = $2, SET arch = $2,
@ -521,6 +574,20 @@ func (q *Queries) UpdateHostHeartbeat(ctx context.Context, id string) error {
return err return err
} }
const updateHostHeartbeatAndStatus = `-- name: UpdateHostHeartbeatAndStatus :exec
UPDATE hosts
SET last_heartbeat_at = NOW(),
status = CASE WHEN status = 'unreachable' THEN 'online' ELSE status END,
updated_at = NOW()
WHERE id = $1
`
// Updates last_heartbeat_at and transitions unreachable hosts back to online.
func (q *Queries) UpdateHostHeartbeatAndStatus(ctx context.Context, id string) error {
_, err := q.db.Exec(ctx, updateHostHeartbeatAndStatus, id)
return err
}
const updateHostStatus = `-- name: UpdateHostStatus :exec const updateHostStatus = `-- name: UpdateHostStatus :exec
UPDATE hosts SET status = $2, updated_at = NOW() WHERE id = $1 UPDATE hosts SET status = $2, updated_at = NOW() WHERE id = $1
` `

View File

@ -36,6 +36,15 @@ type Host struct {
MtlsEnabled bool `json:"mtls_enabled"` MtlsEnabled bool `json:"mtls_enabled"`
} }
type HostRefreshToken struct {
ID string `json:"id"`
HostID string `json:"host_id"`
TokenHash string `json:"token_hash"`
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
RevokedAt pgtype.Timestamptz `json:"revoked_at"`
}
type HostTag struct { type HostTag struct {
HostID string `json:"host_id"` HostID string `json:"host_id"`
Tag string `json:"tag"` Tag string `json:"tag"`

View File

@ -11,6 +11,20 @@ import (
"github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgtype"
) )
const bulkRestoreRunning = `-- name: BulkRestoreRunning :exec
UPDATE sandboxes
SET status = 'running',
last_updated = NOW()
WHERE id = ANY($1::text[]) AND status = 'missing'
`
// Called by the reconciler when a host comes back online and its sandboxes are
// confirmed alive. Restores only sandboxes that are in 'missing' state.
func (q *Queries) BulkRestoreRunning(ctx context.Context, dollar_1 []string) error {
_, err := q.db.Exec(ctx, bulkRestoreRunning, dollar_1)
return err
}
const bulkUpdateStatusByIDs = `-- name: BulkUpdateStatusByIDs :exec const bulkUpdateStatusByIDs = `-- name: BulkUpdateStatusByIDs :exec
UPDATE sandboxes UPDATE sandboxes
SET status = $2, SET status = $2,
@ -300,6 +314,21 @@ func (q *Queries) ListSandboxesByTeam(ctx context.Context, teamID string) ([]San
return items, nil return items, nil
} }
const markSandboxesMissingByHost = `-- name: MarkSandboxesMissingByHost :exec
UPDATE sandboxes
SET status = 'missing',
last_updated = NOW()
WHERE host_id = $1 AND status IN ('running', 'starting', 'pending')
`
// Called when the host monitor marks a host unreachable.
// Marks running/starting/pending sandboxes on that host as 'missing' so users see
// the sandbox is not currently reachable, without permanently losing the record.
func (q *Queries) MarkSandboxesMissingByHost(ctx context.Context, hostID string) error {
_, err := q.db.Exec(ctx, markSandboxesMissingByHost, hostID)
return err
}
const updateLastActive = `-- name: UpdateLastActive :exec const updateLastActive = `-- name: UpdateLastActive :exec
UPDATE sandboxes UPDATE sandboxes
SET last_active_at = $2, SET last_active_at = $2,

View File

@ -17,6 +17,13 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
// tokenFile is the JSON format persisted to AGENT_FILES_ROOTDIR/host.jwt.
type tokenFile struct {
HostID string `json:"host_id"`
JWT string `json:"jwt"`
RefreshToken string `json:"refresh_token"`
}
// RegistrationConfig holds the configuration for host registration. // RegistrationConfig holds the configuration for host registration.
type RegistrationConfig struct { type RegistrationConfig struct {
CPURL string // Control plane base URL (e.g., http://localhost:8000) CPURL string // Control plane base URL (e.g., http://localhost:8000)
@ -37,6 +44,17 @@ type registerRequest struct {
type registerResponse struct { type registerResponse struct {
Host json.RawMessage `json:"host"` Host json.RawMessage `json:"host"`
Token string `json:"token"` Token string `json:"token"`
RefreshToken string `json:"refresh_token"`
}
type refreshRequest struct {
RefreshToken string `json:"refresh_token"`
}
type refreshResponse struct {
Host json.RawMessage `json:"host"`
Token string `json:"token"`
RefreshToken string `json:"refresh_token"`
} }
type errorResponse struct { type errorResponse struct {
@ -46,20 +64,46 @@ type errorResponse struct {
} `json:"error"` } `json:"error"`
} }
// loadTokenFile reads and parses the persisted token file.
func loadTokenFile(path string) (*tokenFile, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
// Support legacy format (raw JWT string) for backwards compatibility.
trimmed := strings.TrimSpace(string(data))
if !strings.HasPrefix(trimmed, "{") {
// Old format: just the JWT, no refresh token.
hostID, _ := hostIDFromJWT(trimmed)
return &tokenFile{HostID: hostID, JWT: trimmed}, nil
}
var tf tokenFile
if err := json.Unmarshal(data, &tf); err != nil {
return nil, fmt.Errorf("parse token file: %w", err)
}
return &tf, nil
}
// saveTokenFile writes the token file as JSON with 0600 permissions.
func saveTokenFile(path string, tf tokenFile) error {
data, err := json.MarshalIndent(tf, "", " ")
if err != nil {
return fmt.Errorf("marshal token file: %w", err)
}
return os.WriteFile(path, data, 0600)
}
// Register calls the control plane to register this host agent and persists // Register calls the control plane to register this host agent and persists
// the returned JWT to disk. Returns the host JWT token string. // the returned JWT and refresh token to disk. Returns the host JWT token string.
func Register(ctx context.Context, cfg RegistrationConfig) (string, error) { func Register(ctx context.Context, cfg RegistrationConfig) (string, error) {
// Check if we already have a saved token. // Check if we already have a saved token.
if data, err := os.ReadFile(cfg.TokenFile); err == nil { if tf, err := loadTokenFile(cfg.TokenFile); err == nil && tf.JWT != "" {
token := strings.TrimSpace(string(data)) slog.Info("loaded existing host token", "file", cfg.TokenFile, "host_id", tf.HostID)
if token != "" { return tf.JWT, nil
slog.Info("loaded existing host token", "file", cfg.TokenFile)
return token, nil
}
} }
if cfg.RegistrationToken == "" { if cfg.RegistrationToken == "" {
return "", fmt.Errorf("no saved host token and no registration token provided") return "", fmt.Errorf("no saved host token and no registration token provided (use --register flag)")
} }
arch := runtime.GOARCH arch := runtime.GOARCH
@ -117,45 +161,155 @@ func Register(ctx context.Context, cfg RegistrationConfig) (string, error) {
return "", fmt.Errorf("registration response missing token") return "", fmt.Errorf("registration response missing token")
} }
// Persist the token to disk for subsequent startups. hostID, err := hostIDFromJWT(regResp.Token)
if err := os.WriteFile(cfg.TokenFile, []byte(regResp.Token), 0600); err != nil { if err != nil {
return "", fmt.Errorf("extract host ID from JWT: %w", err)
}
// Persist JWT + refresh token.
tf := tokenFile{
HostID: hostID,
JWT: regResp.Token,
RefreshToken: regResp.RefreshToken,
}
if err := saveTokenFile(cfg.TokenFile, tf); err != nil {
return "", fmt.Errorf("save host token: %w", err) return "", fmt.Errorf("save host token: %w", err)
} }
slog.Info("host registered and token saved", "file", cfg.TokenFile) slog.Info("host registered and token saved", "file", cfg.TokenFile, "host_id", hostID)
return regResp.Token, nil return regResp.Token, nil
} }
// RefreshJWT exchanges the refresh token for a new JWT + rotated refresh token.
// It reads and updates the token file in place.
func RefreshJWT(ctx context.Context, cpURL, tokenFilePath string) (string, error) {
tf, err := loadTokenFile(tokenFilePath)
if err != nil {
return "", fmt.Errorf("load token file: %w", err)
}
if tf.RefreshToken == "" {
return "", fmt.Errorf("no refresh token available; host must re-register")
}
body, _ := json.Marshal(refreshRequest{RefreshToken: tf.RefreshToken})
url := strings.TrimRight(cpURL, "/") + "/v1/hosts/auth/refresh"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return "", fmt.Errorf("create refresh request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 15 * time.Second}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("refresh request failed: %w", err)
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
var errResp errorResponse
if json.Unmarshal(respBody, &errResp) == nil {
return "", fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, errResp.Error.Message)
}
return "", fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, string(respBody))
}
var refResp refreshResponse
if err := json.Unmarshal(respBody, &refResp); err != nil {
return "", fmt.Errorf("parse refresh response: %w", err)
}
tf.JWT = refResp.Token
tf.RefreshToken = refResp.RefreshToken
if err := saveTokenFile(tokenFilePath, *tf); err != nil {
return "", fmt.Errorf("save refreshed token: %w", err)
}
slog.Info("host JWT refreshed", "host_id", tf.HostID)
return refResp.Token, nil
}
// StartHeartbeat launches a background goroutine that sends periodic heartbeats // StartHeartbeat launches a background goroutine that sends periodic heartbeats
// to the control plane. It runs until the context is cancelled. // to the control plane. It runs until the context is cancelled.
func StartHeartbeat(ctx context.Context, cpURL, hostID, hostToken string, interval time.Duration) { //
url := strings.TrimRight(cpURL, "/") + "/v1/hosts/" + hostID + "/heartbeat" // On 401/403: the heartbeat loop attempts to refresh the JWT. If the refresh
// also fails (expired refresh token), it calls pauseAll and stops.
//
// On repeated network failures (3 consecutive), it calls pauseAll but keeps
// retrying — the connection may recover and the host should resume heartbeating.
func StartHeartbeat(ctx context.Context, cpURL, tokenFilePath, hostID string, interval time.Duration, pauseAll func()) {
client := &http.Client{Timeout: 10 * time.Second} client := &http.Client{Timeout: 10 * time.Second}
go func() { go func() {
ticker := time.NewTicker(interval) ticker := time.NewTicker(interval)
defer ticker.Stop() defer ticker.Stop()
consecutiveFailures := 0
pausedDueToFailure := false
currentJWT := ""
// Load the current JWT from disk.
if tf, err := loadTokenFile(tokenFilePath); err == nil {
currentJWT = tf.JWT
}
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return return
case <-ticker.C: case <-ticker.C:
url := strings.TrimRight(cpURL, "/") + "/v1/hosts/" + hostID + "/heartbeat"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil) req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil)
if err != nil { if err != nil {
slog.Warn("heartbeat: failed to create request", "error", err) slog.Warn("heartbeat: failed to create request", "error", err)
continue continue
} }
req.Header.Set("X-Host-Token", hostToken) req.Header.Set("X-Host-Token", currentJWT)
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
slog.Warn("heartbeat: request failed", "error", err) consecutiveFailures++
slog.Warn("heartbeat: request failed", "error", err, "consecutive_failures", consecutiveFailures)
if consecutiveFailures >= 3 && !pausedDueToFailure {
slog.Error("heartbeat: CP unreachable after 3 failures — pausing all sandboxes")
if pauseAll != nil {
pauseAll()
}
pausedDueToFailure = true
}
continue continue
} }
resp.Body.Close() resp.Body.Close()
if resp.StatusCode != http.StatusNoContent { switch resp.StatusCode {
case http.StatusNoContent:
// Success.
if consecutiveFailures > 0 || pausedDueToFailure {
slog.Info("heartbeat: CP connection restored")
}
consecutiveFailures = 0
pausedDueToFailure = false
case http.StatusUnauthorized, http.StatusForbidden:
slog.Warn("heartbeat: JWT rejected — attempting token refresh")
newJWT, refreshErr := RefreshJWT(ctx, cpURL, tokenFilePath)
if refreshErr != nil {
slog.Error("heartbeat: JWT refresh failed — pausing all sandboxes; manual re-registration required",
"error", refreshErr)
if pauseAll != nil && !pausedDueToFailure {
pauseAll()
pausedDueToFailure = true
}
// Stop the heartbeat loop — operator must re-register.
return
}
currentJWT = newJWT
slog.Info("heartbeat: JWT refreshed successfully")
default:
slog.Warn("heartbeat: unexpected status", "status", resp.StatusCode) slog.Warn("heartbeat: unexpected status", "status", resp.StatusCode)
} }
} }
@ -166,6 +320,12 @@ func StartHeartbeat(ctx context.Context, cpURL, hostID, hostToken string, interv
// HostIDFromToken extracts the host_id claim from a host JWT without // HostIDFromToken extracts the host_id claim from a host JWT without
// verifying the signature (the agent doesn't have the signing secret). // verifying the signature (the agent doesn't have the signing secret).
func HostIDFromToken(token string) (string, error) { func HostIDFromToken(token string) (string, error) {
return hostIDFromJWT(token)
}
// hostIDFromJWT is the internal implementation used by both HostIDFromToken and
// the token file loader.
func hostIDFromJWT(token string) (string, error) {
parts := strings.Split(token, ".") parts := strings.Split(token, ".")
if len(parts) != 3 { if len(parts) != 3 {
return "", fmt.Errorf("invalid JWT format") return "", fmt.Errorf("invalid JWT format")

View File

@ -67,3 +67,17 @@ func NewRegistrationToken() string {
} }
return hex.EncodeToString(b) return hex.EncodeToString(b)
} }
// NewRefreshTokenID generates a new refresh token record ID in the format "hrt-" + 8 hex chars.
func NewRefreshTokenID() string {
return "hrt-" + hex8()
}
// NewRefreshToken generates a 64-char hex token (32 bytes of entropy) for use as a host refresh token.
func NewRefreshToken() string {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
panic(fmt.Sprintf("crypto/rand failed: %v", err))
}
return hex.EncodeToString(b)
}

View File

@ -0,0 +1,77 @@
package lifecycle
import (
"fmt"
"net/http"
"strings"
"sync"
"time"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
)
// HostClientPool maintains a cache of Connect RPC clients keyed by host ID.
// Clients are created lazily on first access and evicted when a host is removed
// or goes unreachable. The pool is safe for concurrent use.
type HostClientPool struct {
mu sync.RWMutex
clients map[string]hostagentv1connect.HostAgentServiceClient
httpClient *http.Client
}
// NewHostClientPool creates a new pool. The underlying HTTP client uses a
// 10-minute timeout to support long-running streaming operations.
func NewHostClientPool() *HostClientPool {
return &HostClientPool{
clients: make(map[string]hostagentv1connect.HostAgentServiceClient),
httpClient: &http.Client{Timeout: 10 * time.Minute},
}
}
// Get returns a Connect RPC client for the given host, creating one if necessary.
// address is the host agent address (ip:port or full URL). The scheme is added if absent.
func (p *HostClientPool) Get(hostID, address string) hostagentv1connect.HostAgentServiceClient {
p.mu.RLock()
c, ok := p.clients[hostID]
p.mu.RUnlock()
if ok {
return c
}
p.mu.Lock()
defer p.mu.Unlock()
// Double-check after acquiring write lock.
if c, ok = p.clients[hostID]; ok {
return c
}
c = hostagentv1connect.NewHostAgentServiceClient(p.httpClient, ensureScheme(address))
p.clients[hostID] = c
return c
}
// GetForHost is a convenience wrapper that extracts the address from a db.Host
// and returns an error if the host has no address recorded yet.
func (p *HostClientPool) GetForHost(h db.Host) (hostagentv1connect.HostAgentServiceClient, error) {
if !h.Address.Valid || h.Address.String == "" {
return nil, fmt.Errorf("host %s has no address", h.ID)
}
return p.Get(h.ID, h.Address.String), nil
}
// Evict removes the cached client for the given host, forcing a new client to be
// created on the next call to Get. Call this when a host's address changes or when
// a host is deleted.
func (p *HostClientPool) Evict(hostID string) {
p.mu.Lock()
delete(p.clients, hostID)
p.mu.Unlock()
}
// ensureScheme adds "http://" if the address has no scheme.
func ensureScheme(addr string) string {
if strings.HasPrefix(addr, "http://") || strings.HasPrefix(addr, "https://") {
return addr
}
return "http://" + addr
}

View File

@ -1183,6 +1183,28 @@ func (m *Manager) Shutdown(ctx context.Context) {
m.loops.ReleaseAll() m.loops.ReleaseAll()
} }
// PauseAll pauses every running sandbox managed by this host agent.
// Called when the host loses connectivity to the control plane to avoid
// leaving running VMs unmanaged. It is best-effort: failures for individual
// sandboxes are logged but do not stop the rest.
func (m *Manager) PauseAll(ctx context.Context) {
m.mu.RLock()
ids := make([]string, 0, len(m.boxes))
for id, sb := range m.boxes {
if sb.Status == models.StatusRunning {
ids = append(ids, id)
}
}
m.mu.RUnlock()
slog.Info("pausing all running sandboxes due to CP connection loss", "count", len(ids))
for _, sbID := range ids {
if err := m.Pause(ctx, sbID); err != nil {
slog.Warn("PauseAll: failed to pause sandbox", "id", sbID, "error", err)
}
}
}
// warnErr logs a warning if err is non-nil. Used for best-effort cleanup // warnErr logs a warning if err is non-nil. Used for best-effort cleanup
// in error paths where the primary error has already been captured. // in error paths where the primary error has already been captured.
func warnErr(msg string, id string, err error) { func warnErr(msg string, id string, err error) {

View File

@ -0,0 +1,51 @@
package scheduler
import (
"context"
"fmt"
"sync/atomic"
"git.omukk.dev/wrenn/sandbox/internal/db"
)
// HostScheduler selects a host for a new sandbox. Implementations may use
// different strategies (round-robin, least-loaded, tag-based, etc.).
type HostScheduler interface {
// SelectHost returns a host that can accept a new sandbox.
// Returns an error if no suitable host is available.
SelectHost(ctx context.Context) (db.Host, error)
}
// RoundRobinScheduler cycles through online hosts in round-robin order.
// It re-fetches the host list on every call so that newly registered or
// recovered hosts are considered immediately.
type RoundRobinScheduler struct {
db *db.Queries
counter atomic.Int64
}
// NewRoundRobinScheduler creates a RoundRobinScheduler backed by the given DB.
func NewRoundRobinScheduler(queries *db.Queries) *RoundRobinScheduler {
return &RoundRobinScheduler{db: queries}
}
// SelectHost returns the next online host in round-robin order.
func (s *RoundRobinScheduler) SelectHost(ctx context.Context) (db.Host, error) {
hosts, err := s.db.ListActiveHosts(ctx)
if err != nil {
return db.Host{}, fmt.Errorf("list hosts: %w", err)
}
var online []db.Host
for _, h := range hosts {
if h.Status == "online" && h.Address.Valid && h.Address.String != "" {
online = append(online, h)
}
}
if len(online) == 0 {
return db.Host{}, fmt.Errorf("no online hosts available")
}
idx := s.counter.Add(1) - 1
return online[int(idx%int64(len(online)))], nil
}

View File

@ -2,12 +2,14 @@ package service
import ( import (
"context" "context"
"crypto/sha256"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"log/slog" "log/slog"
"time" "time"
"connectrpc.com/connect"
"github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgtype"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
@ -15,6 +17,8 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db" "git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id" "git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
) )
// HostService provides host management operations. // HostService provides host management operations.
@ -22,6 +26,7 @@ type HostService struct {
DB *db.Queries DB *db.Queries
Redis *redis.Client Redis *redis.Client
JWT []byte JWT []byte
Pool *lifecycle.HostClientPool
} }
// HostCreateParams holds the parameters for creating a host. // HostCreateParams holds the parameters for creating a host.
@ -50,10 +55,24 @@ type HostRegisterParams struct {
Address string Address string
} }
// HostRegisterResult holds the registered host and its long-lived JWT. // HostRegisterResult holds the registered host, its short-lived JWT, and a long-lived refresh token.
type HostRegisterResult struct { type HostRegisterResult struct {
Host db.Host Host db.Host
JWT string JWT string
RefreshToken string
}
// HostRefreshResult holds a new JWT and rotated refresh token after a successful refresh.
type HostRefreshResult struct {
Host db.Host
JWT string
RefreshToken string
}
// HostDeletePreview describes what will be affected by deleting a host.
type HostDeletePreview struct {
Host db.Host
SandboxIDs []string
} }
// regTokenPayload is the JSON stored in Redis for registration tokens. // regTokenPayload is the JSON stored in Redis for registration tokens.
@ -64,6 +83,14 @@ type regTokenPayload struct {
const regTokenTTL = time.Hour const regTokenTTL = time.Hour
// requireAdminOrOwner returns nil iff the role is "owner" or "admin".
func requireAdminOrOwner(role string) error {
if role == "owner" || role == "admin" {
return nil
}
return fmt.Errorf("forbidden: only team owners and admins can manage BYOC hosts")
}
// Create creates a new host record and generates a one-time registration token. // Create creates a new host record and generates a one-time registration token.
func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreateResult, error) { func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreateResult, error) {
if p.Type != "regular" && p.Type != "byoc" { if p.Type != "regular" && p.Type != "byoc" {
@ -75,7 +102,7 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat
return HostCreateResult{}, fmt.Errorf("forbidden: only admins can create regular hosts") return HostCreateResult{}, fmt.Errorf("forbidden: only admins can create regular hosts")
} }
} else { } else {
// BYOC: admin or team owner. // BYOC: platform admin, or team owner/admin.
if p.TeamID == "" { if p.TeamID == "" {
return HostCreateResult{}, fmt.Errorf("invalid request: team_id is required for BYOC hosts") return HostCreateResult{}, fmt.Errorf("invalid request: team_id is required for BYOC hosts")
} }
@ -90,8 +117,8 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat
if err != nil { if err != nil {
return HostCreateResult{}, fmt.Errorf("check team membership: %w", err) return HostCreateResult{}, fmt.Errorf("check team membership: %w", err)
} }
if membership.Role != "owner" { if err := requireAdminOrOwner(membership.Role); err != nil {
return HostCreateResult{}, fmt.Errorf("forbidden: only team owners can create BYOC hosts") return HostCreateResult{}, err
} }
} }
} }
@ -168,7 +195,6 @@ func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamI
return HostCreateResult{}, fmt.Errorf("invalid state: can only regenerate token for pending hosts (status: %s)", host.Status) return HostCreateResult{}, fmt.Errorf("invalid state: can only regenerate token for pending hosts (status: %s)", host.Status)
} }
// Same permission model as Create/Delete.
if !isAdmin { if !isAdmin {
if host.Type != "byoc" { if host.Type != "byoc" {
return HostCreateResult{}, fmt.Errorf("forbidden: only admins can manage regular hosts") return HostCreateResult{}, fmt.Errorf("forbidden: only admins can manage regular hosts")
@ -186,8 +212,8 @@ func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamI
if err != nil { if err != nil {
return HostCreateResult{}, fmt.Errorf("check team membership: %w", err) return HostCreateResult{}, fmt.Errorf("check team membership: %w", err)
} }
if membership.Role != "owner" { if err := requireAdminOrOwner(membership.Role); err != nil {
return HostCreateResult{}, fmt.Errorf("forbidden: only team owners can regenerate tokens") return HostCreateResult{}, err
} }
} }
@ -216,7 +242,7 @@ func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamI
} }
// Register validates a one-time registration token, updates the host with // Register validates a one-time registration token, updates the host with
// machine specs, and returns a long-lived host JWT. // machine specs, and returns a short-lived host JWT plus a long-lived refresh token.
func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostRegisterResult, error) { func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostRegisterResult, error) {
// Atomic consume: GetDel returns the value and deletes in one operation, // Atomic consume: GetDel returns the value and deletes in one operation,
// preventing concurrent requests from consuming the same token. // preventing concurrent requests from consuming the same token.
@ -264,18 +290,89 @@ func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostR
slog.Warn("failed to mark host token used", "token_id", payload.TokenID, "error", err) slog.Warn("failed to mark host token used", "token_id", payload.TokenID, "error", err)
} }
// Issue a long-lived refresh token.
refreshToken, err := s.issueRefreshToken(ctx, payload.HostID)
if err != nil {
return HostRegisterResult{}, fmt.Errorf("issue refresh token: %w", err)
}
// Re-fetch the host to get the updated state. // Re-fetch the host to get the updated state.
host, err := s.DB.GetHost(ctx, payload.HostID) host, err := s.DB.GetHost(ctx, payload.HostID)
if err != nil { if err != nil {
return HostRegisterResult{}, fmt.Errorf("fetch updated host: %w", err) return HostRegisterResult{}, fmt.Errorf("fetch updated host: %w", err)
} }
return HostRegisterResult{Host: host, JWT: hostJWT}, nil return HostRegisterResult{Host: host, JWT: hostJWT, RefreshToken: refreshToken}, nil
} }
// Heartbeat updates the last heartbeat timestamp for a host. // Refresh validates a refresh token, rotates it (revokes old, issues new),
// and returns a fresh JWT plus the new refresh token.
func (s *HostService) Refresh(ctx context.Context, refreshToken string) (HostRefreshResult, error) {
hash := hashToken(refreshToken)
row, err := s.DB.GetHostRefreshTokenByHash(ctx, hash)
if errors.Is(err, pgx.ErrNoRows) {
return HostRefreshResult{}, fmt.Errorf("invalid or expired refresh token")
}
if err != nil {
return HostRefreshResult{}, fmt.Errorf("lookup refresh token: %w", err)
}
host, err := s.DB.GetHost(ctx, row.HostID)
if err != nil {
return HostRefreshResult{}, fmt.Errorf("host not found: %w", err)
}
// Sign new JWT.
hostJWT, err := auth.SignHostJWT(s.JWT, host.ID)
if err != nil {
return HostRefreshResult{}, fmt.Errorf("sign host JWT: %w", err)
}
// Issue-then-revoke rotation: insert new token first so a crash between
// the two DB calls leaves the host with two valid tokens rather than zero.
newRefreshToken, err := s.issueRefreshToken(ctx, host.ID)
if err != nil {
return HostRefreshResult{}, fmt.Errorf("issue new refresh token: %w", err)
}
// Revoke old refresh token after the new one is safely persisted.
if err := s.DB.RevokeHostRefreshToken(ctx, row.ID); err != nil {
return HostRefreshResult{}, fmt.Errorf("revoke old refresh token: %w", err)
}
return HostRefreshResult{Host: host, JWT: hostJWT, RefreshToken: newRefreshToken}, nil
}
// issueRefreshToken creates a new refresh token record in the DB and returns
// the opaque token string.
func (s *HostService) issueRefreshToken(ctx context.Context, hostID string) (string, error) {
token := id.NewRefreshToken()
hash := hashToken(token)
now := time.Now()
if _, err := s.DB.InsertHostRefreshToken(ctx, db.InsertHostRefreshTokenParams{
ID: id.NewRefreshTokenID(),
HostID: hostID,
TokenHash: hash,
ExpiresAt: pgtype.Timestamptz{Time: now.Add(auth.HostRefreshTokenExpiry), Valid: true},
}); err != nil {
return "", fmt.Errorf("insert refresh token: %w", err)
}
return token, nil
}
// hashToken returns the hex-encoded SHA-256 hash of the token.
func hashToken(token string) string {
h := sha256.Sum256([]byte(token))
return fmt.Sprintf("%x", h)
}
// Heartbeat updates the last heartbeat timestamp for a host and transitions
// any 'unreachable' host back to 'online'.
func (s *HostService) Heartbeat(ctx context.Context, hostID string) error { func (s *HostService) Heartbeat(ctx context.Context, hostID string) error {
return s.DB.UpdateHostHeartbeat(ctx, hostID) return s.DB.UpdateHostHeartbeatAndStatus(ctx, hostID)
} }
// List returns hosts visible to the caller. // List returns hosts visible to the caller.
@ -301,37 +398,135 @@ func (s *HostService) Get(ctx context.Context, hostID, teamID string, isAdmin bo
return host, nil return host, nil
} }
// Delete removes a host. Admins can delete any host. Team owners can delete // DeletePreview returns what would be affected by deleting the host, without
// BYOC hosts belonging to their team. // making any changes. Use this to show the user a confirmation prompt.
func (s *HostService) Delete(ctx context.Context, hostID, userID, teamID string, isAdmin bool) error { func (s *HostService) DeletePreview(ctx context.Context, hostID, teamID string, isAdmin bool) (HostDeletePreview, error) {
host, err := s.DB.GetHost(ctx, hostID) host, err := s.checkDeletePermission(ctx, hostID, "", teamID, isAdmin)
if err != nil { if err != nil {
return fmt.Errorf("host not found: %w", err) return HostDeletePreview{}, err
}
sandboxes, err := s.DB.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
HostID: hostID,
Column2: []string{"pending", "starting", "running", "missing"},
})
if err != nil {
return HostDeletePreview{}, fmt.Errorf("list sandboxes: %w", err)
}
ids := make([]string, len(sandboxes))
for i, sb := range sandboxes {
ids[i] = sb.ID
}
return HostDeletePreview{Host: host, SandboxIDs: ids}, nil
}
// Delete removes a host. Without force it returns an error listing active
// sandboxes so the caller can present a confirmation. With force it gracefully
// destroys all running sandboxes before deleting the host record.
func (s *HostService) Delete(ctx context.Context, hostID, userID, teamID string, isAdmin, force bool) error {
host, err := s.checkDeletePermission(ctx, hostID, userID, teamID, isAdmin)
if err != nil {
return err
}
sandboxes, err := s.DB.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
HostID: hostID,
Column2: []string{"pending", "starting", "running", "missing"},
})
if err != nil {
return fmt.Errorf("list sandboxes: %w", err)
}
if len(sandboxes) > 0 && !force {
ids := make([]string, len(sandboxes))
for i, sb := range sandboxes {
ids[i] = sb.ID
}
return &HostHasSandboxesError{SandboxIDs: ids}
}
// Gracefully destroy running sandboxes on the host agent (best-effort).
if len(sandboxes) > 0 && host.Address.Valid && host.Address.String != "" {
agent, err := s.Pool.GetForHost(host)
if err == nil {
for _, sb := range sandboxes {
if sb.Status == "running" || sb.Status == "starting" {
_, rpcErr := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
SandboxId: sb.ID,
}))
if rpcErr != nil && connect.CodeOf(rpcErr) != connect.CodeNotFound {
slog.Warn("delete host: failed to destroy sandbox on agent", "sandbox_id", sb.ID, "error", rpcErr)
}
}
}
}
}
// Mark all affected sandboxes as stopped in DB.
if len(sandboxes) > 0 {
sbIDs := make([]string, len(sandboxes))
for i, sb := range sandboxes {
sbIDs[i] = sb.ID
}
if err := s.DB.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
Column1: sbIDs,
Status: "stopped",
}); err != nil {
slog.Warn("delete host: failed to mark sandboxes stopped", "host_id", hostID, "error", err)
}
}
// Revoke all refresh tokens for this host.
if err := s.DB.RevokeHostRefreshTokensByHost(ctx, hostID); err != nil {
slog.Warn("delete host: failed to revoke refresh tokens", "host_id", hostID, "error", err)
}
// Evict the client from the pool so no further RPCs are sent.
if s.Pool != nil {
s.Pool.Evict(hostID)
}
return s.DB.DeleteHost(ctx, hostID)
}
// checkDeletePermission verifies the caller has permission to delete the given
// host and returns the host record on success.
func (s *HostService) checkDeletePermission(ctx context.Context, hostID, userID, teamID string, isAdmin bool) (db.Host, error) {
host, err := s.DB.GetHost(ctx, hostID)
if err != nil {
return db.Host{}, fmt.Errorf("host not found: %w", err)
}
if isAdmin {
return host, nil
} }
if !isAdmin {
if host.Type != "byoc" { if host.Type != "byoc" {
return fmt.Errorf("forbidden: only admins can delete regular hosts") return db.Host{}, fmt.Errorf("forbidden: only admins can delete regular hosts")
} }
if !host.TeamID.Valid || host.TeamID.String != teamID { if !host.TeamID.Valid || host.TeamID.String != teamID {
return fmt.Errorf("forbidden: host does not belong to your team") return db.Host{}, fmt.Errorf("forbidden: host does not belong to your team")
} }
if userID != "" {
membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{ membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
UserID: userID, UserID: userID,
TeamID: teamID, TeamID: teamID,
}) })
if errors.Is(err, pgx.ErrNoRows) { if errors.Is(err, pgx.ErrNoRows) {
return fmt.Errorf("forbidden: not a member of the specified team") return db.Host{}, fmt.Errorf("forbidden: not a member of the specified team")
} }
if err != nil { if err != nil {
return fmt.Errorf("check team membership: %w", err) return db.Host{}, fmt.Errorf("check team membership: %w", err)
} }
if membership.Role != "owner" { if err := requireAdminOrOwner(membership.Role); err != nil {
return fmt.Errorf("forbidden: only team owners can delete BYOC hosts") return db.Host{}, err
} }
} }
return s.DB.DeleteHost(ctx, hostID) return host, nil
} }
// AddTag adds a tag to a host. // AddTag adds a tag to a host.
@ -357,3 +552,14 @@ func (s *HostService) ListTags(ctx context.Context, hostID, teamID string, isAdm
} }
return s.DB.GetHostTags(ctx, hostID) return s.DB.GetHostTags(ctx, hostID)
} }
// HostHasSandboxesError is returned by Delete when the host has active sandboxes
// and force was not set. The caller should present the list to the user and
// re-call Delete with force=true if they confirm.
type HostHasSandboxesError struct {
SandboxIDs []string
}
func (e *HostHasSandboxesError) Error() string {
return fmt.Sprintf("host has %d active sandbox(es): %v", len(e.SandboxIDs), e.SandboxIDs)
}

View File

@ -11,16 +11,18 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/db" "git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id" "git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
"git.omukk.dev/wrenn/sandbox/internal/scheduler"
"git.omukk.dev/wrenn/sandbox/internal/validate" "git.omukk.dev/wrenn/sandbox/internal/validate"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen" pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
) )
// SandboxService provides sandbox lifecycle operations shared between the // SandboxService provides sandbox lifecycle operations shared between the
// REST API and the dashboard. // REST API and the dashboard.
type SandboxService struct { type SandboxService struct {
DB *db.Queries DB *db.Queries
Agent hostagentv1connect.HostAgentServiceClient Pool *lifecycle.HostClientPool
Scheduler scheduler.HostScheduler
} }
// SandboxCreateParams holds the parameters for creating a sandbox. // SandboxCreateParams holds the parameters for creating a sandbox.
@ -32,8 +34,34 @@ type SandboxCreateParams struct {
TimeoutSec int32 TimeoutSec int32
} }
// Create creates a new sandbox: inserts a pending DB record, calls the host agent, // agentForSandbox looks up the host for the given sandbox and returns a client.
// and updates the record to running. Returns the sandbox DB row. func (s *SandboxService) agentForSandbox(ctx context.Context, sandboxID string) (hostagentClient, db.Sandbox, error) {
sb, err := s.DB.GetSandbox(ctx, sandboxID)
if err != nil {
return nil, db.Sandbox{}, fmt.Errorf("sandbox not found: %w", err)
}
host, err := s.DB.GetHost(ctx, sb.HostID)
if err != nil {
return nil, db.Sandbox{}, fmt.Errorf("host not found for sandbox: %w", err)
}
agent, err := s.Pool.GetForHost(host)
if err != nil {
return nil, db.Sandbox{}, fmt.Errorf("get agent client: %w", err)
}
return agent, sb, nil
}
// hostagentClient is a local alias to avoid the full package path in signatures.
type hostagentClient = interface {
CreateSandbox(ctx context.Context, req *connect.Request[pb.CreateSandboxRequest]) (*connect.Response[pb.CreateSandboxResponse], error)
DestroySandbox(ctx context.Context, req *connect.Request[pb.DestroySandboxRequest]) (*connect.Response[pb.DestroySandboxResponse], error)
PauseSandbox(ctx context.Context, req *connect.Request[pb.PauseSandboxRequest]) (*connect.Response[pb.PauseSandboxResponse], error)
ResumeSandbox(ctx context.Context, req *connect.Request[pb.ResumeSandboxRequest]) (*connect.Response[pb.ResumeSandboxResponse], error)
PingSandbox(ctx context.Context, req *connect.Request[pb.PingSandboxRequest]) (*connect.Response[pb.PingSandboxResponse], error)
}
// Create creates a new sandbox: picks a host via the scheduler, inserts a pending
// DB record, calls the host agent, and updates the record to running.
func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db.Sandbox, error) { func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db.Sandbox, error) {
if p.Template == "" { if p.Template == "" {
p.Template = "minimal" p.Template = "minimal"
@ -58,12 +86,23 @@ func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db.
} }
} }
// Pick a host for this sandbox.
host, err := s.Scheduler.SelectHost(ctx)
if err != nil {
return db.Sandbox{}, fmt.Errorf("select host: %w", err)
}
agent, err := s.Pool.GetForHost(host)
if err != nil {
return db.Sandbox{}, fmt.Errorf("get agent client: %w", err)
}
sandboxID := id.NewSandboxID() sandboxID := id.NewSandboxID()
if _, err := s.DB.InsertSandbox(ctx, db.InsertSandboxParams{ if _, err := s.DB.InsertSandbox(ctx, db.InsertSandboxParams{
ID: sandboxID, ID: sandboxID,
TeamID: p.TeamID, TeamID: p.TeamID,
HostID: "default", HostID: host.ID,
Template: p.Template, Template: p.Template,
Status: "pending", Status: "pending",
Vcpus: p.VCPUs, Vcpus: p.VCPUs,
@ -73,7 +112,7 @@ func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db.
return db.Sandbox{}, fmt.Errorf("insert sandbox: %w", err) return db.Sandbox{}, fmt.Errorf("insert sandbox: %w", err)
} }
resp, err := s.Agent.CreateSandbox(ctx, connect.NewRequest(&pb.CreateSandboxRequest{ resp, err := agent.CreateSandbox(ctx, connect.NewRequest(&pb.CreateSandboxRequest{
SandboxId: sandboxID, SandboxId: sandboxID,
Template: p.Template, Template: p.Template,
Vcpus: p.VCPUs, Vcpus: p.VCPUs,
@ -126,7 +165,12 @@ func (s *SandboxService) Pause(ctx context.Context, sandboxID, teamID string) (d
return db.Sandbox{}, fmt.Errorf("sandbox is not running (status: %s)", sb.Status) return db.Sandbox{}, fmt.Errorf("sandbox is not running (status: %s)", sb.Status)
} }
if _, err := s.Agent.PauseSandbox(ctx, connect.NewRequest(&pb.PauseSandboxRequest{ agent, _, err := s.agentForSandbox(ctx, sandboxID)
if err != nil {
return db.Sandbox{}, err
}
if _, err := agent.PauseSandbox(ctx, connect.NewRequest(&pb.PauseSandboxRequest{
SandboxId: sandboxID, SandboxId: sandboxID,
})); err != nil { })); err != nil {
return db.Sandbox{}, fmt.Errorf("agent pause: %w", err) return db.Sandbox{}, fmt.Errorf("agent pause: %w", err)
@ -151,7 +195,12 @@ func (s *SandboxService) Resume(ctx context.Context, sandboxID, teamID string) (
return db.Sandbox{}, fmt.Errorf("sandbox is not paused (status: %s)", sb.Status) return db.Sandbox{}, fmt.Errorf("sandbox is not paused (status: %s)", sb.Status)
} }
resp, err := s.Agent.ResumeSandbox(ctx, connect.NewRequest(&pb.ResumeSandboxRequest{ agent, _, err := s.agentForSandbox(ctx, sandboxID)
if err != nil {
return db.Sandbox{}, err
}
resp, err := agent.ResumeSandbox(ctx, connect.NewRequest(&pb.ResumeSandboxRequest{
SandboxId: sandboxID, SandboxId: sandboxID,
TimeoutSec: sb.TimeoutSec, TimeoutSec: sb.TimeoutSec,
})) }))
@ -181,8 +230,13 @@ func (s *SandboxService) Destroy(ctx context.Context, sandboxID, teamID string)
return fmt.Errorf("sandbox not found: %w", err) return fmt.Errorf("sandbox not found: %w", err)
} }
agent, _, err := s.agentForSandbox(ctx, sandboxID)
if err != nil {
return err
}
// Destroy on host agent. A not-found response is fine — sandbox is already gone. // Destroy on host agent. A not-found response is fine — sandbox is already gone.
if _, err := s.Agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{ if _, err := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
SandboxId: sandboxID, SandboxId: sandboxID,
})); err != nil && connect.CodeOf(err) != connect.CodeNotFound { })); err != nil && connect.CodeOf(err) != connect.CodeNotFound {
return fmt.Errorf("agent destroy: %w", err) return fmt.Errorf("agent destroy: %w", err)
@ -206,7 +260,12 @@ func (s *SandboxService) Ping(ctx context.Context, sandboxID, teamID string) err
return fmt.Errorf("sandbox is not running (status: %s)", sb.Status) return fmt.Errorf("sandbox is not running (status: %s)", sb.Status)
} }
if _, err := s.Agent.PingSandbox(ctx, connect.NewRequest(&pb.PingSandboxRequest{ agent, _, err := s.agentForSandbox(ctx, sandboxID)
if err != nil {
return err
}
if _, err := agent.PingSandbox(ctx, connect.NewRequest(&pb.PingSandboxRequest{
SandboxId: sandboxID, SandboxId: sandboxID,
})); err != nil { })); err != nil {
return fmt.Errorf("agent ping: %w", err) return fmt.Errorf("agent ping: %w", err)

View File

@ -14,8 +14,8 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/db" "git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id" "git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen" pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
) )
var teamNameRE = regexp.MustCompile(`^[A-Za-z0-9 _\-@']{1,128}$`) var teamNameRE = regexp.MustCompile(`^[A-Za-z0-9 _\-@']{1,128}$`)
@ -24,7 +24,7 @@ var teamNameRE = regexp.MustCompile(`^[A-Za-z0-9 _\-@']{1,128}$`)
type TeamService struct { type TeamService struct {
DB *db.Queries DB *db.Queries
Pool *pgxpool.Pool Pool *pgxpool.Pool
Agent hostagentv1connect.HostAgentServiceClient HostPool *lifecycle.HostClientPool
} }
// TeamWithRole pairs a team with the calling user's role in it. // TeamWithRole pairs a team with the calling user's role in it.
@ -177,11 +177,17 @@ func (s *TeamService) DeleteTeam(ctx context.Context, teamID, callerUserID strin
var stopIDs []string var stopIDs []string
for _, sb := range sandboxes { for _, sb := range sandboxes {
if _, err := s.Agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{ host, hostErr := s.DB.GetHost(ctx, sb.HostID)
if hostErr == nil {
agent, agentErr := s.HostPool.GetForHost(host)
if agentErr == nil {
if _, err := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
SandboxId: sb.ID, SandboxId: sb.ID,
})); err != nil && connect.CodeOf(err) != connect.CodeNotFound { })); err != nil && connect.CodeOf(err) != connect.CodeNotFound {
slog.Warn("team delete: failed to destroy sandbox", "sandbox_id", sb.ID, "error", err) slog.Warn("team delete: failed to destroy sandbox", "sandbox_id", sb.ID, "error", err)
} }
}
}
stopIDs = append(stopIDs, sb.ID) stopIDs = append(stopIDs, sb.ID)
} }