1
0
forked from wrenn/wrenn

Implement host registration, JWT refresh tokens, and multi-host scheduling

Replaces the hardcoded CP_HOST_AGENT_ADDR single-agent setup with a
DB-driven registration system supporting multiple host agents (BYOC).

Key changes:
- Host agents register via one-time token, receive a 7-day JWT + 60-day
  refresh token; heartbeat loop auto-refreshes on 401/403 and pauses all
  sandboxes if refresh fails
- HostClientPool: lazy Connect RPC client cache keyed by host ID, replacing
  the single static agent client throughout the API and service layers
- RoundRobinScheduler: picks an online host for each new sandbox via
  ListActiveHosts; extensible for future scheduling strategies
- HostMonitor (replaces Reconciler): passive heartbeat staleness check marks
  hosts unreachable and sandboxes missing after 90s; active reconciliation
  per online host restores missing-but-alive sandboxes and stops orphans
- Graceful host delete: returns 409 with affected sandbox list without
  ?force=true; force-delete destroys sandboxes then evicts pool client
- Snapshot delete broadcasts to all online hosts (templates have no host_id)
- sandbox.Manager.PauseAll: pauses all running VMs on CP connectivity loss
- New migration: host_refresh_tokens table with token rotation (issue-then-
  revoke ordering to prevent lockout on mid-rotation crash)
- New sandbox status 'missing' (reversible, unlike 'stopped') and host
  status 'unreachable'; both reflected in OpenAPI spec
- Fix: refresh token auth failure now returns 401 (was 400 via generic
  'invalid' substring match in serviceErrToHTTP)
This commit is contained in:
2026-03-24 18:32:05 +06:00
parent f968da9768
commit 9bf67aa7f7
33 changed files with 1567 additions and 318 deletions

View File

@ -0,0 +1,20 @@
package api
import (
"context"
"fmt"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
)
// agentForHost resolves hostID to its DB record and hands back a pooled
// Connect RPC client for that host. It fails when the host row does not
// exist or when the pool cannot build a client for it (e.g. no address).
func agentForHost(ctx context.Context, queries *db.Queries, pool *lifecycle.HostClientPool, hostID string) (hostagentv1connect.HostAgentServiceClient, error) {
	record, lookupErr := queries.GetHost(ctx, hostID)
	if lookupErr != nil {
		return nil, fmt.Errorf("host not found: %w", lookupErr)
	}
	return pool.GetForHost(record)
}

View File

@ -14,17 +14,17 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
)
type execHandler struct {
db *db.Queries
agent hostagentv1connect.HostAgentServiceClient
db *db.Queries
pool *lifecycle.HostClientPool
}
func newExecHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *execHandler {
return &execHandler{db: db, agent: agent}
func newExecHandler(db *db.Queries, pool *lifecycle.HostClientPool) *execHandler {
return &execHandler{db: db, pool: pool}
}
type execRequest struct {
@ -73,7 +73,13 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
start := time.Now()
resp, err := h.agent.Exec(ctx, connect.NewRequest(&pb.ExecRequest{
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
return
}
resp, err := agent.Exec(ctx, connect.NewRequest(&pb.ExecRequest{
SandboxId: sandboxID,
Cmd: req.Cmd,
Args: req.Args,

View File

@ -14,17 +14,17 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
)
type execStreamHandler struct {
db *db.Queries
agent hostagentv1connect.HostAgentServiceClient
db *db.Queries
pool *lifecycle.HostClientPool
}
func newExecStreamHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *execStreamHandler {
return &execStreamHandler{db: db, agent: agent}
func newExecStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool) *execStreamHandler {
return &execStreamHandler{db: db, pool: pool}
}
var upgrader = websocket.Upgrader{
@ -80,11 +80,17 @@ func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
return
}
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
sendWSError(conn, "sandbox host is not reachable")
return
}
// Open streaming exec to host agent.
streamCtx, cancel := context.WithCancel(ctx)
defer cancel()
stream, err := h.agent.ExecStream(streamCtx, connect.NewRequest(&pb.ExecStreamRequest{
stream, err := agent.ExecStream(streamCtx, connect.NewRequest(&pb.ExecStreamRequest{
SandboxId: sandboxID,
Cmd: startMsg.Cmd,
Args: startMsg.Args,

View File

@ -11,17 +11,17 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
)
type filesHandler struct {
db *db.Queries
agent hostagentv1connect.HostAgentServiceClient
db *db.Queries
pool *lifecycle.HostClientPool
}
func newFilesHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *filesHandler {
return &filesHandler{db: db, agent: agent}
func newFilesHandler(db *db.Queries, pool *lifecycle.HostClientPool) *filesHandler {
return &filesHandler{db: db, pool: pool}
}
// Upload handles POST /v1/sandboxes/{id}/files/write.
@ -75,7 +75,13 @@ func (h *filesHandler) Upload(w http.ResponseWriter, r *http.Request) {
return
}
if _, err := h.agent.WriteFile(ctx, connect.NewRequest(&pb.WriteFileRequest{
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
return
}
if _, err := agent.WriteFile(ctx, connect.NewRequest(&pb.WriteFileRequest{
SandboxId: sandboxID,
Path: filePath,
Content: content,
@ -120,7 +126,13 @@ func (h *filesHandler) Download(w http.ResponseWriter, r *http.Request) {
return
}
resp, err := h.agent.ReadFile(ctx, connect.NewRequest(&pb.ReadFileRequest{
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
return
}
resp, err := agent.ReadFile(ctx, connect.NewRequest(&pb.ReadFileRequest{
SandboxId: sandboxID,
Path: req.Path,
}))

View File

@ -12,17 +12,17 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
)
type filesStreamHandler struct {
db *db.Queries
agent hostagentv1connect.HostAgentServiceClient
db *db.Queries
pool *lifecycle.HostClientPool
}
func newFilesStreamHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *filesStreamHandler {
return &filesStreamHandler{db: db, agent: agent}
func newFilesStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool) *filesStreamHandler {
return &filesStreamHandler{db: db, pool: pool}
}
// StreamUpload handles POST /v1/sandboxes/{id}/files/stream/write.
@ -88,8 +88,14 @@ func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request
}
defer filePart.Close()
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
return
}
// Open client-streaming RPC to host agent.
stream := h.agent.WriteFileStream(ctx)
stream := agent.WriteFileStream(ctx)
// Send metadata first.
if err := stream.Send(&pb.WriteFileStreamRequest{
@ -164,8 +170,14 @@ func (h *filesStreamHandler) StreamDownload(w http.ResponseWriter, r *http.Reque
return
}
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
return
}
// Open server-streaming RPC to host agent.
stream, err := h.agent.ReadFileStream(ctx, connect.NewRequest(&pb.ReadFileStreamRequest{
stream, err := agent.ReadFileStream(ctx, connect.NewRequest(&pb.ReadFileStreamRequest{
SandboxId: sandboxID,
Path: req.Path,
}))

View File

@ -1,6 +1,7 @@
package api
import (
"errors"
"net/http"
"time"
@ -34,6 +35,25 @@ type createHostResponse struct {
RegistrationToken string `json:"registration_token"`
}
type refreshTokenRequest struct {
RefreshToken string `json:"refresh_token"`
}
type refreshTokenResponse struct {
Host hostResponse `json:"host"`
Token string `json:"token"`
RefreshToken string `json:"refresh_token"`
}
type deletePreviewResponse struct {
Host hostResponse `json:"host"`
SandboxIDs []string `json:"sandbox_ids"`
}
type hasSandboxesErrorResponse struct {
SandboxIDs []string `json:"sandbox_ids"`
}
type registerHostRequest struct {
Token string `json:"token"`
Arch string `json:"arch,omitempty"`
@ -44,8 +64,9 @@ type registerHostRequest struct {
}
type registerHostResponse struct {
Host hostResponse `json:"host"`
Token string `json:"token"`
Host hostResponse `json:"host"`
Token string `json:"token"`
RefreshToken string `json:"refresh_token"`
}
type addTagRequest struct {
@ -183,18 +204,54 @@ func (h *hostHandler) Get(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, hostToResponse(host))
}
// Delete handles DELETE /v1/hosts/{id}.
func (h *hostHandler) Delete(w http.ResponseWriter, r *http.Request) {
// DeletePreview handles GET /v1/hosts/{id}/delete-preview.
// Returns what would be affected without making changes, for confirmation UI.
func (h *hostHandler) DeletePreview(w http.ResponseWriter, r *http.Request) {
hostID := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
if err := h.svc.Delete(r.Context(), hostID, ac.UserID, ac.TeamID, h.isAdmin(r, ac.UserID)); err != nil {
preview, err := h.svc.DeletePreview(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID))
if err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
w.WriteHeader(http.StatusNoContent)
writeJSON(w, http.StatusOK, deletePreviewResponse{
Host: hostToResponse(preview.Host),
SandboxIDs: preview.SandboxIDs,
})
}
// Delete handles DELETE /v1/hosts/{id}.
// Without ?force=true: returns 409 with affected sandbox IDs if any are active.
// With ?force=true: gracefully stops all sandboxes then deletes the host.
func (h *hostHandler) Delete(w http.ResponseWriter, r *http.Request) {
	ac := auth.MustFromContext(r.Context())
	hostID := chi.URLParam(r, "id")
	force := r.URL.Query().Get("force") == "true"

	deleteErr := h.svc.Delete(r.Context(), hostID, ac.UserID, ac.TeamID, h.isAdmin(r, ac.UserID), force)
	if deleteErr == nil {
		w.WriteHeader(http.StatusNoContent)
		return
	}

	// A "has running sandboxes" failure gets a structured 409 so clients can
	// show the affected sandboxes and offer a force-delete.
	var hasSandboxes *service.HostHasSandboxesError
	if errors.As(deleteErr, &hasSandboxes) {
		writeJSON(w, http.StatusConflict, map[string]any{
			"error": map[string]any{
				"code":        "has_active_sandboxes",
				"message":     "host has active sandboxes; use ?force=true to destroy them and delete the host",
				"sandbox_ids": hasSandboxes.SandboxIDs,
			},
		})
		return
	}

	status, code, msg := serviceErrToHTTP(deleteErr)
	writeError(w, status, code, msg)
}
// RegenerateToken handles POST /v1/hosts/{id}/token.
@ -247,8 +304,9 @@ func (h *hostHandler) Register(w http.ResponseWriter, r *http.Request) {
}
writeJSON(w, http.StatusCreated, registerHostResponse{
Host: hostToResponse(result.Host),
Token: result.JWT,
Host: hostToResponse(result.Host),
Token: result.JWT,
RefreshToken: result.RefreshToken,
})
}
@ -311,6 +369,33 @@ func (h *hostHandler) RemoveTag(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}
// RefreshToken handles POST /v1/hosts/auth/refresh (unauthenticated).
// The host agent sends its refresh token to receive a new JWT and rotated refresh token.
func (h *hostHandler) RefreshToken(w http.ResponseWriter, r *http.Request) {
	var req refreshTokenRequest
	if decodeErr := decodeJSON(r, &req); decodeErr != nil {
		writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
		return
	}
	if req.RefreshToken == "" {
		writeError(w, http.StatusBadRequest, "invalid_request", "refresh_token is required")
		return
	}

	result, refreshErr := h.svc.Refresh(r.Context(), req.RefreshToken)
	if refreshErr != nil {
		status, code, msg := serviceErrToHTTP(refreshErr)
		writeError(w, status, code, msg)
		return
	}

	writeJSON(w, http.StatusOK, refreshTokenResponse{
		Host:         hostToResponse(result.Host),
		Token:        result.JWT,
		RefreshToken: result.RefreshToken,
	})
}
// ListTags handles GET /v1/hosts/{id}/tags.
func (h *hostHandler) ListTags(w http.ResponseWriter, r *http.Request) {
hostID := chi.URLParam(r, "id")

View File

@ -1,6 +1,7 @@
package api
import (
"context"
"encoding/json"
"fmt"
"log/slog"
@ -14,20 +15,45 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
"git.omukk.dev/wrenn/sandbox/internal/service"
"git.omukk.dev/wrenn/sandbox/internal/validate"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
)
type snapshotHandler struct {
svc *service.TemplateService
db *db.Queries
agent hostagentv1connect.HostAgentServiceClient
svc *service.TemplateService
db *db.Queries
pool *lifecycle.HostClientPool
}
func newSnapshotHandler(svc *service.TemplateService, db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *snapshotHandler {
return &snapshotHandler{svc: svc, db: db, agent: agent}
func newSnapshotHandler(svc *service.TemplateService, db *db.Queries, pool *lifecycle.HostClientPool) *snapshotHandler {
return &snapshotHandler{svc: svc, db: db, pool: pool}
}
// deleteSnapshotBroadcast attempts to delete snapshot files on all online hosts.
// Snapshots aren't currently host-tracked in the DB, so we broadcast to every
// online host and ignore NotFound errors (a host that never had the files).
// Per-host failures are logged but do not abort the broadcast — deletion is
// best-effort by design. Returns an error only when the host list itself
// cannot be loaded. TODO: add host_id to templates table so this can target
// a single host instead of broadcasting.
func (h *snapshotHandler) deleteSnapshotBroadcast(ctx context.Context, name string) error {
	hosts, err := h.db.ListActiveHosts(ctx)
	if err != nil {
		return fmt.Errorf("list hosts: %w", err)
	}
	for _, host := range hosts {
		if host.Status != "online" {
			continue
		}
		agent, err := h.pool.GetForHost(host)
		if err != nil {
			// Host has no usable client yet (e.g. freshly registered, no
			// address) — it cannot hold snapshot files. Skip, but leave a
			// trace instead of swallowing the error silently.
			slog.Debug("snapshot: skipping host without client", "host_id", host.ID, "error", err)
			continue
		}
		if _, err := agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{Name: name})); err != nil {
			// NotFound just means this host never stored the snapshot.
			if connect.CodeOf(err) != connect.CodeNotFound {
				slog.Warn("snapshot: failed to delete on host", "host_id", host.ID, "name", name, "error", err)
			}
		}
	}
	return nil
}
type createSnapshotRequest struct {
@ -93,10 +119,9 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusConflict, "already_exists", "snapshot name already exists; use ?overwrite=true to replace")
return
}
// Delete old files from the agent before removing the DB record.
if _, err := h.agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{Name: req.Name})); err != nil {
status, code, msg := agentErrToHTTP(err)
writeError(w, status, code, "failed to delete existing snapshot files: "+msg)
// Delete old snapshot files from all hosts before removing the DB record.
if err := h.deleteSnapshotBroadcast(ctx, req.Name); err != nil {
writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete existing snapshot files")
return
}
if err := h.db.DeleteTemplateByTeam(ctx, db.DeleteTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err != nil {
@ -116,7 +141,13 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
return
}
resp, err := h.agent.CreateSnapshot(ctx, connect.NewRequest(&pb.CreateSnapshotRequest{
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
return
}
resp, err := agent.CreateSnapshot(ctx, connect.NewRequest(&pb.CreateSnapshotRequest{
SandboxId: req.SandboxID,
Name: req.Name,
}))
@ -186,11 +217,8 @@ func (h *snapshotHandler) Delete(w http.ResponseWriter, r *http.Request) {
return
}
if _, err := h.agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{
Name: name,
})); err != nil {
status, code, msg := agentErrToHTTP(err)
writeError(w, status, code, "failed to delete snapshot files: "+msg)
if err := h.deleteSnapshotBroadcast(ctx, name); err != nil {
writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete snapshot files")
return
}

View File

@ -0,0 +1,198 @@
package api
import (
"context"
"log/slog"
"time"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
"connectrpc.com/connect"
)
// unreachableThreshold is how long a host can go without a heartbeat before
// it is considered unreachable (3 missed 30-second heartbeats).
const unreachableThreshold = 90 * time.Second
// HostMonitor runs on a fixed interval and performs two duties:
//
//  1. Passive check: marks hosts whose last_heartbeat_at is stale as
//     "unreachable" and marks their active sandboxes as "missing".
//
//  2. Active reconciliation: for each online host, calls ListSandboxes and
//     reconciles DB state against live host state — restoring "missing"
//     sandboxes that are actually alive, and stopping orphaned ones.
type HostMonitor struct {
	db       *db.Queries               // control-plane DB access (hosts + sandboxes)
	pool     *lifecycle.HostClientPool // lazy per-host Connect RPC clients
	interval time.Duration             // how often one monitoring pass runs
}
// NewHostMonitor creates a HostMonitor polling at the given interval.
func NewHostMonitor(queries *db.Queries, pool *lifecycle.HostClientPool, interval time.Duration) *HostMonitor {
	return &HostMonitor{db: queries, pool: pool, interval: interval}
}
// Start runs the monitor loop in a background goroutine until the context
// is cancelled.
func (m *HostMonitor) Start(ctx context.Context) {
	go m.loop(ctx)
}

// loop ticks at m.interval, running one monitoring pass per tick, and exits
// when ctx is cancelled.
func (m *HostMonitor) loop(ctx context.Context) {
	tick := time.NewTicker(m.interval)
	defer tick.Stop()
	for {
		select {
		case <-tick.C:
			m.run(ctx)
		case <-ctx.Done():
			return
		}
	}
}
// run performs a single monitoring pass over every active host.
func (m *HostMonitor) run(ctx context.Context) {
	hosts, listErr := m.db.ListActiveHosts(ctx)
	if listErr != nil {
		slog.Warn("host monitor: failed to list hosts", "error", listErr)
		return
	}
	for i := range hosts {
		m.checkHost(ctx, hosts[i])
	}
}
// checkHost applies both monitor duties to a single host.
//
// Passive phase: if the host's heartbeat is stale (or it never sent one),
// mark the host "unreachable" and its active sandboxes "missing", then stop.
// A host that is stale and already "unreachable" is left alone.
//
// Active phase (online hosts only): fetch the live sandbox list once, then
// (a) restore or stop sandboxes the DB thinks are "missing", and (b) pause
// or stop "running" DB sandboxes that no longer exist on the host.
func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) {
	stale := !host.LastHeartbeatAt.Valid ||
		time.Since(host.LastHeartbeatAt.Time) > unreachableThreshold
	if stale {
		if host.Status != "unreachable" {
			m.markUnreachable(ctx, host)
		}
		// Already-unreachable hosts are never "online", so there is nothing
		// further to reconcile either way.
		return
	}

	// --- Active reconciliation: only for online hosts ---
	if host.Status != "online" {
		return
	}
	agent, err := m.pool.GetForHost(host)
	if err != nil {
		// Host has no address yet (e.g., just registered) — skip.
		return
	}
	resp, err := agent.ListSandboxes(ctx, connect.NewRequest(&pb.ListSandboxesRequest{}))
	if err != nil {
		// RPC failure is a transient condition; the passive phase will catch it
		// if heartbeats stop arriving.
		slog.Debug("host monitor: ListSandboxes failed (transient)", "host_id", host.ID, "error", err)
		return
	}

	// Build set of sandbox IDs alive on the host, plus the IDs the host's
	// TTL reaper auto-paused since the last call.
	alive := make(map[string]struct{}, len(resp.Msg.Sandboxes))
	for _, sb := range resp.Msg.Sandboxes {
		alive[sb.SandboxId] = struct{}{}
	}
	autoPaused := make(map[string]struct{}, len(resp.Msg.AutoPausedSandboxIds))
	for _, id := range resp.Msg.AutoPausedSandboxIds {
		autoPaused[id] = struct{}{}
	}

	m.reconcileMissing(ctx, host, alive)
	m.reconcileRunning(ctx, host, alive, autoPaused)
}

// markUnreachable flags the host as unreachable and marks its active
// sandboxes "missing" (a reversible state, unlike "stopped"). Both DB
// updates are best-effort; failures are logged and retried on the next pass.
func (m *HostMonitor) markUnreachable(ctx context.Context, host db.Host) {
	slog.Info("host monitor: marking host unreachable", "host_id", host.ID,
		"last_heartbeat", host.LastHeartbeatAt.Time)
	if err := m.db.MarkHostUnreachable(ctx, host.ID); err != nil {
		slog.Warn("host monitor: failed to mark host unreachable", "host_id", host.ID, "error", err)
	}
	if err := m.db.MarkSandboxesMissingByHost(ctx, host.ID); err != nil {
		slog.Warn("host monitor: failed to mark sandboxes missing", "host_id", host.ID, "error", err)
	}
}

// reconcileMissing restores sandboxes that are "missing" in the DB but alive
// on the host (the CP marked them missing during a transient heartbeat gap),
// and stops those the host confirms are gone.
func (m *HostMonitor) reconcileMissing(ctx context.Context, host db.Host, alive map[string]struct{}) {
	missingSandboxes, err := m.db.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
		HostID:  host.ID,
		Column2: []string{"missing"},
	})
	if err != nil {
		slog.Warn("host monitor: failed to list missing sandboxes", "host_id", host.ID, "error", err)
		return
	}
	var toRestore []string
	var toStop []string
	for _, sb := range missingSandboxes {
		if _, ok := alive[sb.ID]; ok {
			toRestore = append(toRestore, sb.ID)
		} else {
			toStop = append(toStop, sb.ID)
		}
	}
	if len(toRestore) > 0 {
		slog.Info("host monitor: restoring missing sandboxes", "host_id", host.ID, "count", len(toRestore))
		if err := m.db.BulkRestoreRunning(ctx, toRestore); err != nil {
			slog.Warn("host monitor: failed to restore missing sandboxes", "host_id", host.ID, "error", err)
		}
	}
	if len(toStop) > 0 {
		slog.Info("host monitor: stopping confirmed-dead missing sandboxes", "host_id", host.ID, "count", len(toStop))
		if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
			Column1: toStop,
			Status:  "stopped",
		}); err != nil {
			slog.Warn("host monitor: failed to stop missing sandboxes", "host_id", host.ID, "error", err)
		}
	}
}

// reconcileRunning handles "running" DB sandboxes that are no longer alive on
// the host: those the host reports as auto-paused (TTL reaper) become
// "paused"; the rest are orphans and become "stopped".
func (m *HostMonitor) reconcileRunning(ctx context.Context, host db.Host, alive, autoPaused map[string]struct{}) {
	runningSandboxes, err := m.db.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
		HostID:  host.ID,
		Column2: []string{"running"},
	})
	if err != nil {
		slog.Warn("host monitor: failed to list running sandboxes", "host_id", host.ID, "error", err)
		return
	}
	var toPause, toStop []string
	for _, sb := range runningSandboxes {
		if _, ok := alive[sb.ID]; ok {
			continue
		}
		if _, ok := autoPaused[sb.ID]; ok {
			toPause = append(toPause, sb.ID)
		} else {
			toStop = append(toStop, sb.ID)
		}
	}
	if len(toPause) > 0 {
		slog.Info("host monitor: marking auto-paused sandboxes", "host_id", host.ID, "count", len(toPause))
		if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
			Column1: toPause,
			Status:  "paused",
		}); err != nil {
			slog.Warn("host monitor: failed to mark paused", "host_id", host.ID, "error", err)
		}
	}
	if len(toStop) > 0 {
		slog.Info("host monitor: marking orphaned sandboxes stopped", "host_id", host.ID, "count", len(toStop))
		if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
			Column1: toStop,
			Status:  "stopped",
		}); err != nil {
			slog.Warn("host monitor: failed to mark stopped", "host_id", host.ID, "error", err)
		}
	}
}

View File

@ -89,6 +89,8 @@ func serviceErrToHTTP(err error) (int, string, string) {
return http.StatusConflict, "invalid_state", msg
case strings.Contains(msg, "forbidden"):
return http.StatusForbidden, "forbidden", msg
case strings.Contains(msg, "invalid or expired"):
return http.StatusUnauthorized, "unauthorized", msg
case strings.Contains(msg, "invalid"):
return http.StatusBadRequest, "invalid_request", msg
default:

View File

@ -1193,8 +1193,16 @@ paths:
security:
- bearerAuth: []
description: |
Admins can delete any host. Team owners can delete BYOC hosts
belonging to their team.
Admins can delete any host. Team owners and admins can delete BYOC hosts
belonging to their team. Without `?force=true`, returns 409 if the host
has active sandboxes. With `?force=true`, destroys all sandboxes first.
parameters:
- name: force
in: query
required: false
schema:
type: boolean
description: If true, destroy all sandboxes on the host before deleting.
responses:
"204":
description: Host deleted
@ -1204,6 +1212,12 @@ paths:
application/json:
schema:
$ref: "#/components/schemas/Error"
"409":
description: Host has active sandboxes (only when force is not set)
content:
application/json:
schema:
$ref: "#/components/schemas/HostHasSandboxesError"
/v1/hosts/{id}/token:
parameters:
@ -1312,6 +1326,72 @@ paths:
schema:
$ref: "#/components/schemas/Error"
/v1/hosts/auth/refresh:
post:
summary: Refresh host JWT
operationId: refreshHostToken
tags: [hosts]
description: |
Exchanges a refresh token for a new JWT and rotated refresh token.
The old refresh token is immediately revoked. No authentication required —
the refresh token itself is the credential.
requestBody:
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/RefreshHostTokenRequest"
responses:
"200":
description: New JWT and rotated refresh token
content:
application/json:
schema:
$ref: "#/components/schemas/RefreshHostTokenResponse"
"401":
description: Invalid, expired, or revoked refresh token
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
/v1/hosts/{id}/delete-preview:
parameters:
- name: id
in: path
required: true
schema:
type: string
get:
summary: Preview host deletion
operationId: getHostDeletePreview
tags: [hosts]
security:
- bearerAuth: []
description: |
Returns the list of sandbox IDs that would be destroyed if the host
were deleted with `?force=true`. No state is modified.
responses:
"200":
description: Deletion preview
content:
application/json:
schema:
$ref: "#/components/schemas/HostDeletePreview"
"403":
description: Insufficient permissions
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"404":
description: Host not found
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
/v1/hosts/{id}/tags:
parameters:
- name: id
@ -1405,7 +1485,7 @@ components:
type: apiKey
in: header
name: X-Host-Token
description: Long-lived host JWT returned from POST /v1/hosts/register. Valid for 1 year.
description: Host JWT returned from POST /v1/hosts/register or POST /v1/hosts/auth/refresh. Valid for 7 days.
schemas:
SignupRequest:
@ -1505,7 +1585,7 @@ components:
type: string
status:
type: string
enum: [pending, running, paused, stopped, error]
enum: [pending, starting, running, paused, hibernated, stopped, missing, error]
template:
type: string
vcpus:
@ -1661,7 +1741,10 @@ components:
$ref: "#/components/schemas/Host"
token:
type: string
description: Long-lived host JWT for X-Host-Token header. Valid for 1 year.
description: Host JWT for X-Host-Token header. Valid for 7 days.
refresh_token:
type: string
description: Refresh token for obtaining new JWTs. Valid for 60 days; rotated on each use.
Host:
type: object
@ -1697,7 +1780,7 @@ components:
nullable: true
status:
type: string
enum: [pending, online, offline, draining]
enum: [pending, online, offline, draining, unreachable]
last_heartbeat_at:
type: string
format: date-time
@ -1711,6 +1794,54 @@ components:
type: string
format: date-time
RefreshHostTokenRequest:
type: object
required: [refresh_token]
properties:
refresh_token:
type: string
description: Refresh token obtained from registration or a previous refresh.
RefreshHostTokenResponse:
type: object
properties:
host:
$ref: "#/components/schemas/Host"
token:
type: string
description: New host JWT. Valid for 7 days.
refresh_token:
type: string
description: New refresh token. Valid for 60 days; old token is revoked.
HostDeletePreview:
type: object
properties:
host:
$ref: "#/components/schemas/Host"
sandbox_ids:
type: array
items:
type: string
description: IDs of sandboxes that would be destroyed on force-delete.
HostHasSandboxesError:
type: object
properties:
error:
type: object
properties:
code:
type: string
example: has_active_sandboxes
message:
type: string
sandbox_ids:
type: array
items:
type: string
description: IDs of active sandboxes blocking deletion.
AddTagRequest:
type: object
required: [tag]

View File

@ -1,126 +0,0 @@
package api
import (
"context"
"log/slog"
"time"
"connectrpc.com/connect"
"git.omukk.dev/wrenn/sandbox/internal/db"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
)
// Reconciler periodically compares the host agent's sandbox list with the DB
// and marks sandboxes that no longer exist on the host as stopped.
type Reconciler struct {
db *db.Queries
agent hostagentv1connect.HostAgentServiceClient
hostID string
interval time.Duration
}
// NewReconciler creates a new reconciler.
func NewReconciler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient, hostID string, interval time.Duration) *Reconciler {
return &Reconciler{
db: db,
agent: agent,
hostID: hostID,
interval: interval,
}
}
// Start runs the reconciliation loop until the context is cancelled.
func (rc *Reconciler) Start(ctx context.Context) {
go func() {
ticker := time.NewTicker(rc.interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
rc.reconcile(ctx)
}
}
}()
}
func (rc *Reconciler) reconcile(ctx context.Context) {
// Single RPC returns both the running sandbox list and any IDs that
// were auto-paused by the TTL reaper since the last call.
resp, err := rc.agent.ListSandboxes(ctx, connect.NewRequest(&pb.ListSandboxesRequest{}))
if err != nil {
slog.Warn("reconciler: failed to list sandboxes from host agent", "error", err)
return
}
// Build a set of sandbox IDs that are alive on the host.
alive := make(map[string]struct{}, len(resp.Msg.Sandboxes))
for _, sb := range resp.Msg.Sandboxes {
alive[sb.SandboxId] = struct{}{}
}
// Build auto-paused set from the same response.
autoPausedSet := make(map[string]struct{}, len(resp.Msg.AutoPausedSandboxIds))
for _, id := range resp.Msg.AutoPausedSandboxIds {
autoPausedSet[id] = struct{}{}
}
// Get all DB sandboxes for this host that are running.
// Paused sandboxes are excluded: they are expected to not exist on the
// host agent because pause = snapshot + destroy resources.
dbSandboxes, err := rc.db.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
HostID: rc.hostID,
Column2: []string{"running"},
})
if err != nil {
slog.Warn("reconciler: failed to list DB sandboxes", "error", err)
return
}
// Find sandboxes in DB that are no longer on the host.
var stale []string
for _, sb := range dbSandboxes {
if _, ok := alive[sb.ID]; !ok {
stale = append(stale, sb.ID)
}
}
if len(stale) == 0 {
return
}
// Split stale sandboxes into those auto-paused by the TTL reaper vs
// those that crashed/were orphaned.
var toPause, toStop []string
for _, id := range stale {
if _, ok := autoPausedSet[id]; ok {
toPause = append(toPause, id)
} else {
toStop = append(toStop, id)
}
}
if len(toPause) > 0 {
slog.Info("reconciler: marking auto-paused sandboxes", "count", len(toPause), "ids", toPause)
if err := rc.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
Column1: toPause,
Status: "paused",
}); err != nil {
slog.Warn("reconciler: failed to mark auto-paused sandboxes", "error", err)
}
}
if len(toStop) > 0 {
slog.Info("reconciler: marking stale sandboxes as stopped", "count", len(toStop), "ids", toStop)
if err := rc.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
Column1: toStop,
Status: "stopped",
}); err != nil {
slog.Warn("reconciler: failed to update stale sandboxes", "error", err)
}
}
}

View File

@ -11,8 +11,9 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/auth/oauth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
"git.omukk.dev/wrenn/sandbox/internal/scheduler"
"git.omukk.dev/wrenn/sandbox/internal/service"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
)
//go:embed openapi.yaml
@ -24,25 +25,34 @@ type Server struct {
}
// New constructs the chi router and registers all routes.
func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, pool *pgxpool.Pool, rdb *redis.Client, jwtSecret []byte, oauthRegistry *oauth.Registry, oauthRedirectURL string) *Server {
func New(
queries *db.Queries,
pool *lifecycle.HostClientPool,
sched scheduler.HostScheduler,
pgPool *pgxpool.Pool,
rdb *redis.Client,
jwtSecret []byte,
oauthRegistry *oauth.Registry,
oauthRedirectURL string,
) *Server {
r := chi.NewRouter()
r.Use(requestLogger())
// Shared service layer.
sandboxSvc := &service.SandboxService{DB: queries, Agent: agent}
sandboxSvc := &service.SandboxService{DB: queries, Pool: pool, Scheduler: sched}
apiKeySvc := &service.APIKeyService{DB: queries}
templateSvc := &service.TemplateService{DB: queries}
hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret}
teamSvc := &service.TeamService{DB: queries, Pool: pool, Agent: agent}
hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret, Pool: pool}
teamSvc := &service.TeamService{DB: queries, Pool: pgPool, HostPool: pool}
sandbox := newSandboxHandler(sandboxSvc)
exec := newExecHandler(queries, agent)
execStream := newExecStreamHandler(queries, agent)
files := newFilesHandler(queries, agent)
filesStream := newFilesStreamHandler(queries, agent)
snapshots := newSnapshotHandler(templateSvc, queries, agent)
authH := newAuthHandler(queries, pool, jwtSecret)
oauthH := newOAuthHandler(queries, pool, jwtSecret, oauthRegistry, oauthRedirectURL)
exec := newExecHandler(queries, pool)
execStream := newExecStreamHandler(queries, pool)
files := newFilesHandler(queries, pool)
filesStream := newFilesStreamHandler(queries, pool)
snapshots := newSnapshotHandler(templateSvc, queries, pool)
authH := newAuthHandler(queries, pgPool, jwtSecret)
oauthH := newOAuthHandler(queries, pgPool, jwtSecret, oauthRegistry, oauthRedirectURL)
apiKeys := newAPIKeyHandler(apiKeySvc)
hostH := newHostHandler(hostSvc, queries)
teamH := newTeamHandler(teamSvc)
@ -123,6 +133,9 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p
// Unauthenticated: one-time registration token.
r.Post("/register", hostH.Register)
// Unauthenticated: refresh token exchange.
r.Post("/auth/refresh", hostH.RefreshToken)
// Host-token-authenticated: heartbeat.
r.With(requireHostToken(jwtSecret)).Post("/{id}/heartbeat", hostH.Heartbeat)
@ -134,6 +147,7 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p
r.Route("/{id}", func(r chi.Router) {
r.Get("/", hostH.Get)
r.Delete("/", hostH.Delete)
r.Get("/delete-preview", hostH.DeletePreview)
r.Post("/token", hostH.RegenerateToken)
r.Get("/tags", hostH.ListTags)
r.Post("/tags", hostH.AddTag)

View File

@ -8,7 +8,8 @@ import (
)
const jwtExpiry = 6 * time.Hour
const hostJWTExpiry = 8760 * time.Hour // 1 year
const hostJWTExpiry = 7 * 24 * time.Hour // 7 days; host refreshes via refresh token
const HostRefreshTokenExpiry = 60 * 24 * time.Hour // 60 days; exported for service layer
// Claims are the JWT payload for user tokens.
type Claims struct {

View File

@ -2,18 +2,16 @@ package config
import (
"os"
"strings"
"github.com/joho/godotenv"
)
// Config holds the control plane configuration.
type Config struct {
DatabaseURL string
RedisURL string
ListenAddr string
HostAgentAddr string
JWTSecret string
DatabaseURL string
RedisURL string
ListenAddr string
JWTSecret string
OAuthGitHubClientID string
OAuthGitHubClientSecret string
@ -27,25 +25,17 @@ func Load() Config {
// Best-effort load — missing .env file is fine.
_ = godotenv.Load()
cfg := Config{
DatabaseURL: envOrDefault("DATABASE_URL", "postgres://wrenn:wrenn@localhost:5432/wrenn?sslmode=disable"),
RedisURL: envOrDefault("REDIS_URL", "redis://localhost:6379/0"),
ListenAddr: envOrDefault("CP_LISTEN_ADDR", ":8080"),
HostAgentAddr: envOrDefault("CP_HOST_AGENT_ADDR", "http://localhost:50051"),
JWTSecret: os.Getenv("JWT_SECRET"),
return Config{
DatabaseURL: envOrDefault("DATABASE_URL", "postgres://wrenn:wrenn@localhost:5432/wrenn?sslmode=disable"),
RedisURL: envOrDefault("REDIS_URL", "redis://localhost:6379/0"),
ListenAddr: envOrDefault("CP_LISTEN_ADDR", ":8080"),
JWTSecret: os.Getenv("JWT_SECRET"),
OAuthGitHubClientID: os.Getenv("OAUTH_GITHUB_CLIENT_ID"),
OAuthGitHubClientSecret: os.Getenv("OAUTH_GITHUB_CLIENT_SECRET"),
OAuthRedirectURL: envOrDefault("OAUTH_REDIRECT_URL", "https://app.wrenn.dev"),
CPPublicURL: os.Getenv("CP_PUBLIC_URL"),
}
// Ensure the host agent address has a scheme.
if !strings.HasPrefix(cfg.HostAgentAddr, "http://") && !strings.HasPrefix(cfg.HostAgentAddr, "https://") {
cfg.HostAgentAddr = "http://" + cfg.HostAgentAddr
}
return cfg
}
func envOrDefault(key, def string) string {

View File

@ -0,0 +1,92 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: host_refresh_tokens.sql
package db
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
const deleteExpiredHostRefreshTokens = `-- name: DeleteExpiredHostRefreshTokens :exec
DELETE FROM host_refresh_tokens
WHERE expires_at < NOW() OR revoked_at IS NOT NULL
`
// DeleteExpiredHostRefreshTokens purges refresh-token rows that are past
// their expiry or have been revoked. Intended for periodic cleanup.
// Generated by sqlc; regenerate instead of editing by hand.
func (q *Queries) DeleteExpiredHostRefreshTokens(ctx context.Context) error {
	_, err := q.db.Exec(ctx, deleteExpiredHostRefreshTokens)
	return err
}
const getHostRefreshTokenByHash = `-- name: GetHostRefreshTokenByHash :one
SELECT id, host_id, token_hash, expires_at, created_at, revoked_at FROM host_refresh_tokens
WHERE token_hash = $1 AND revoked_at IS NULL AND expires_at > NOW()
`
// GetHostRefreshTokenByHash looks up a refresh token by its stored hash
// (the service layer passes a hex SHA-256 digest). The SQL matches only
// rows that are unrevoked and unexpired, so callers see pgx.ErrNoRows for
// revoked/expired/unknown tokens alike.
// Generated by sqlc; regenerate instead of editing by hand.
func (q *Queries) GetHostRefreshTokenByHash(ctx context.Context, tokenHash string) (HostRefreshToken, error) {
	row := q.db.QueryRow(ctx, getHostRefreshTokenByHash, tokenHash)
	var i HostRefreshToken
	err := row.Scan(
		&i.ID,
		&i.HostID,
		&i.TokenHash,
		&i.ExpiresAt,
		&i.CreatedAt,
		&i.RevokedAt,
	)
	return i, err
}
const insertHostRefreshToken = `-- name: InsertHostRefreshToken :one
INSERT INTO host_refresh_tokens (id, host_id, token_hash, expires_at)
VALUES ($1, $2, $3, $4)
RETURNING id, host_id, token_hash, expires_at, created_at, revoked_at
`
// InsertHostRefreshTokenParams carries the values for InsertHostRefreshToken.
// ExpiresAt is computed by the service layer (now + auth.HostRefreshTokenExpiry).
type InsertHostRefreshTokenParams struct {
	ID        string             `json:"id"`
	HostID    string             `json:"host_id"`
	TokenHash string             `json:"token_hash"`
	ExpiresAt pgtype.Timestamptz `json:"expires_at"`
}

// InsertHostRefreshToken persists a new refresh-token row and returns the
// inserted row (including DB-assigned created_at / revoked_at columns).
// Generated by sqlc; regenerate instead of editing by hand.
func (q *Queries) InsertHostRefreshToken(ctx context.Context, arg InsertHostRefreshTokenParams) (HostRefreshToken, error) {
	row := q.db.QueryRow(ctx, insertHostRefreshToken,
		arg.ID,
		arg.HostID,
		arg.TokenHash,
		arg.ExpiresAt,
	)
	var i HostRefreshToken
	err := row.Scan(
		&i.ID,
		&i.HostID,
		&i.TokenHash,
		&i.ExpiresAt,
		&i.CreatedAt,
		&i.RevokedAt,
	)
	return i, err
}
const revokeHostRefreshToken = `-- name: RevokeHostRefreshToken :exec
UPDATE host_refresh_tokens SET revoked_at = NOW() WHERE id = $1
`
// RevokeHostRefreshToken marks a single refresh token (by row ID) as revoked
// by setting revoked_at. Generated by sqlc; regenerate instead of editing by hand.
func (q *Queries) RevokeHostRefreshToken(ctx context.Context, id string) error {
	_, err := q.db.Exec(ctx, revokeHostRefreshToken, id)
	return err
}
const revokeHostRefreshTokensByHost = `-- name: RevokeHostRefreshTokensByHost :exec
UPDATE host_refresh_tokens SET revoked_at = NOW()
WHERE host_id = $1 AND revoked_at IS NULL
`
// RevokeHostRefreshTokensByHost revokes every still-active refresh token
// belonging to the given host (e.g. on host deletion or re-registration).
// Generated by sqlc; regenerate instead of editing by hand.
func (q *Queries) RevokeHostRefreshTokensByHost(ctx context.Context, hostID string) error {
	_, err := q.db.Exec(ctx, revokeHostRefreshTokensByHost, hostID)
	return err
}

View File

@ -234,6 +234,50 @@ func (q *Queries) InsertHostToken(ctx context.Context, arg InsertHostTokenParams
return i, err
}
const listActiveHosts = `-- name: ListActiveHosts :many
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE status NOT IN ('pending', 'offline') ORDER BY created_at
`
// ListActiveHosts returns all hosts that have completed registration, i.e.
// every status except 'pending' and 'offline'. Note 'unreachable' hosts ARE
// included; callers that need online-only hosts (e.g. the scheduler) filter
// further by status. Generated by sqlc; regenerate instead of editing by hand.
func (q *Queries) ListActiveHosts(ctx context.Context) ([]Host, error) {
	rows, err := q.db.Query(ctx, listActiveHosts)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []Host
	for rows.Next() {
		var i Host
		if err := rows.Scan(
			&i.ID,
			&i.Type,
			&i.TeamID,
			&i.Provider,
			&i.AvailabilityZone,
			&i.Arch,
			&i.CpuCores,
			&i.MemoryMb,
			&i.DiskGb,
			&i.Address,
			&i.Status,
			&i.LastHeartbeatAt,
			&i.Metadata,
			&i.CreatedBy,
			&i.CreatedAt,
			&i.UpdatedAt,
			&i.CertFingerprint,
			&i.MtlsEnabled,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}
const listHosts = `-- name: ListHosts :many
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts ORDER BY created_at DESC
`
@ -461,6 +505,15 @@ func (q *Queries) MarkHostTokenUsed(ctx context.Context, id string) error {
return err
}
const markHostUnreachable = `-- name: MarkHostUnreachable :exec
UPDATE hosts SET status = 'unreachable', updated_at = NOW() WHERE id = $1
`
// MarkHostUnreachable sets the host's status to 'unreachable' and bumps
// updated_at. Used by the host monitor when heartbeats go stale.
// Generated by sqlc; regenerate instead of editing by hand.
func (q *Queries) MarkHostUnreachable(ctx context.Context, id string) error {
	_, err := q.db.Exec(ctx, markHostUnreachable, id)
	return err
}
const registerHost = `-- name: RegisterHost :execrows
UPDATE hosts
SET arch = $2,
@ -521,6 +574,20 @@ func (q *Queries) UpdateHostHeartbeat(ctx context.Context, id string) error {
return err
}
const updateHostHeartbeatAndStatus = `-- name: UpdateHostHeartbeatAndStatus :exec
UPDATE hosts
SET last_heartbeat_at = NOW(),
status = CASE WHEN status = 'unreachable' THEN 'online' ELSE status END,
updated_at = NOW()
WHERE id = $1
`
// UpdateHostHeartbeatAndStatus updates last_heartbeat_at and transitions
// 'unreachable' hosts back to 'online'; any other status (e.g. 'pending')
// is left untouched by the CASE expression.
// Generated by sqlc; regenerate instead of editing by hand.
func (q *Queries) UpdateHostHeartbeatAndStatus(ctx context.Context, id string) error {
	_, err := q.db.Exec(ctx, updateHostHeartbeatAndStatus, id)
	return err
}
const updateHostStatus = `-- name: UpdateHostStatus :exec
UPDATE hosts SET status = $2, updated_at = NOW() WHERE id = $1
`

View File

@ -36,6 +36,15 @@ type Host struct {
MtlsEnabled bool `json:"mtls_enabled"`
}
// HostRefreshToken mirrors a row of the host_refresh_tokens table.
// TokenHash holds the hex SHA-256 digest of the opaque token — the plaintext
// is never persisted. RevokedAt is NULL (Valid=false) while the token is
// still usable.
type HostRefreshToken struct {
	ID        string             `json:"id"`
	HostID    string             `json:"host_id"`
	TokenHash string             `json:"token_hash"`
	ExpiresAt pgtype.Timestamptz `json:"expires_at"`
	CreatedAt pgtype.Timestamptz `json:"created_at"`
	RevokedAt pgtype.Timestamptz `json:"revoked_at"`
}
type HostTag struct {
HostID string `json:"host_id"`
Tag string `json:"tag"`

View File

@ -11,6 +11,20 @@ import (
"github.com/jackc/pgx/v5/pgtype"
)
const bulkRestoreRunning = `-- name: BulkRestoreRunning :exec
UPDATE sandboxes
SET status = 'running',
last_updated = NOW()
WHERE id = ANY($1::text[]) AND status = 'missing'
`
// BulkRestoreRunning is called by the reconciler when a host comes back online
// and its sandboxes are confirmed alive. It restores only sandboxes currently
// in 'missing' state; dollar_1 is sqlc's name for the text[] of sandbox IDs.
// Generated by sqlc; regenerate instead of editing by hand.
func (q *Queries) BulkRestoreRunning(ctx context.Context, dollar_1 []string) error {
	_, err := q.db.Exec(ctx, bulkRestoreRunning, dollar_1)
	return err
}
const bulkUpdateStatusByIDs = `-- name: BulkUpdateStatusByIDs :exec
UPDATE sandboxes
SET status = $2,
@ -300,6 +314,21 @@ func (q *Queries) ListSandboxesByTeam(ctx context.Context, teamID string) ([]San
return items, nil
}
const markSandboxesMissingByHost = `-- name: MarkSandboxesMissingByHost :exec
UPDATE sandboxes
SET status = 'missing',
last_updated = NOW()
WHERE host_id = $1 AND status IN ('running', 'starting', 'pending')
`
// MarkSandboxesMissingByHost is called when the host monitor marks a host
// unreachable. It flips that host's running/starting/pending sandboxes to
// 'missing' so users see the sandbox is not currently reachable, without
// permanently losing the record (unlike 'stopped', 'missing' is reversible).
// Generated by sqlc; regenerate instead of editing by hand.
func (q *Queries) MarkSandboxesMissingByHost(ctx context.Context, hostID string) error {
	_, err := q.db.Exec(ctx, markSandboxesMissingByHost, hostID)
	return err
}
const updateLastActive = `-- name: UpdateLastActive :exec
UPDATE sandboxes
SET last_active_at = $2,

View File

@ -17,6 +17,13 @@ import (
"golang.org/x/sys/unix"
)
// tokenFile is the JSON format persisted to AGENT_FILES_ROOTDIR/host.jwt.
// It holds the host identity plus both credentials: the short-lived JWT sent
// on each request and the long-lived refresh token used to rotate the JWT.
type tokenFile struct {
	HostID       string `json:"host_id"`       // extracted from the JWT's host_id claim
	JWT          string `json:"jwt"`           // current bearer token for CP requests
	RefreshToken string `json:"refresh_token"` // empty for files written in the legacy format
}
// RegistrationConfig holds the configuration for host registration.
type RegistrationConfig struct {
CPURL string // Control plane base URL (e.g., http://localhost:8000)
@ -35,8 +42,19 @@ type registerRequest struct {
}
type registerResponse struct {
Host json.RawMessage `json:"host"`
Token string `json:"token"`
Host json.RawMessage `json:"host"`
Token string `json:"token"`
RefreshToken string `json:"refresh_token"`
}
type refreshRequest struct {
RefreshToken string `json:"refresh_token"`
}
type refreshResponse struct {
Host json.RawMessage `json:"host"`
Token string `json:"token"`
RefreshToken string `json:"refresh_token"`
}
type errorResponse struct {
@ -46,20 +64,46 @@ type errorResponse struct {
} `json:"error"`
}
// loadTokenFile reads and parses the persisted token file. It understands
// both the current JSON layout and the legacy layout where the file held a
// bare JWT string with no refresh token.
func loadTokenFile(path string) (*tokenFile, error) {
	raw, err := os.ReadFile(path)
	if err != nil {
		return nil, err
	}
	contents := strings.TrimSpace(string(raw))
	if !strings.HasPrefix(contents, "{") {
		// Legacy layout: the entire file is the JWT. Host-ID extraction is
		// best-effort; an unparsable JWT just leaves HostID empty.
		legacyID, _ := hostIDFromJWT(contents)
		return &tokenFile{HostID: legacyID, JWT: contents}, nil
	}
	var parsed tokenFile
	if err := json.Unmarshal(raw, &parsed); err != nil {
		return nil, fmt.Errorf("parse token file: %w", err)
	}
	return &parsed, nil
}
// saveTokenFile writes the token file as indented JSON with 0600 permissions
// (the file contains live credentials, so it must not be group/world readable).
// Note: os.WriteFile is not atomic; a crash mid-write can leave a truncated
// file, which loadTokenFile will then fail to parse.
func saveTokenFile(path string, tf tokenFile) error {
	data, err := json.MarshalIndent(tf, "", " ")
	if err != nil {
		return fmt.Errorf("marshal token file: %w", err)
	}
	return os.WriteFile(path, data, 0600)
}
// Register calls the control plane to register this host agent and persists
// the returned JWT to disk. Returns the host JWT token string.
// the returned JWT and refresh token to disk. Returns the host JWT token string.
func Register(ctx context.Context, cfg RegistrationConfig) (string, error) {
// Check if we already have a saved token.
if data, err := os.ReadFile(cfg.TokenFile); err == nil {
token := strings.TrimSpace(string(data))
if token != "" {
slog.Info("loaded existing host token", "file", cfg.TokenFile)
return token, nil
}
if tf, err := loadTokenFile(cfg.TokenFile); err == nil && tf.JWT != "" {
slog.Info("loaded existing host token", "file", cfg.TokenFile, "host_id", tf.HostID)
return tf.JWT, nil
}
if cfg.RegistrationToken == "" {
return "", fmt.Errorf("no saved host token and no registration token provided")
return "", fmt.Errorf("no saved host token and no registration token provided (use --register flag)")
}
arch := runtime.GOARCH
@ -117,45 +161,155 @@ func Register(ctx context.Context, cfg RegistrationConfig) (string, error) {
return "", fmt.Errorf("registration response missing token")
}
// Persist the token to disk for subsequent startups.
if err := os.WriteFile(cfg.TokenFile, []byte(regResp.Token), 0600); err != nil {
hostID, err := hostIDFromJWT(regResp.Token)
if err != nil {
return "", fmt.Errorf("extract host ID from JWT: %w", err)
}
// Persist JWT + refresh token.
tf := tokenFile{
HostID: hostID,
JWT: regResp.Token,
RefreshToken: regResp.RefreshToken,
}
if err := saveTokenFile(cfg.TokenFile, tf); err != nil {
return "", fmt.Errorf("save host token: %w", err)
}
slog.Info("host registered and token saved", "file", cfg.TokenFile)
slog.Info("host registered and token saved", "file", cfg.TokenFile, "host_id", hostID)
return regResp.Token, nil
}
// RefreshJWT exchanges the refresh token for a new JWT + rotated refresh token.
// It reads and updates the token file in place, so subsequent startups pick up
// the rotated credentials. Returns the new JWT on success.
func RefreshJWT(ctx context.Context, cpURL, tokenFilePath string) (string, error) {
	tf, err := loadTokenFile(tokenFilePath)
	if err != nil {
		return "", fmt.Errorf("load token file: %w", err)
	}
	// Legacy token files carry no refresh token — nothing to exchange.
	if tf.RefreshToken == "" {
		return "", fmt.Errorf("no refresh token available; host must re-register")
	}
	// Marshal of a plain struct of strings cannot fail; error deliberately ignored.
	body, _ := json.Marshal(refreshRequest{RefreshToken: tf.RefreshToken})
	url := strings.TrimRight(cpURL, "/") + "/v1/hosts/auth/refresh"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
	if err != nil {
		return "", fmt.Errorf("create refresh request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	client := &http.Client{Timeout: 15 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return "", fmt.Errorf("refresh request failed: %w", err)
	}
	defer resp.Body.Close()
	// Best-effort read: a failure here only degrades the error message below.
	respBody, _ := io.ReadAll(resp.Body)
	if resp.StatusCode != http.StatusOK {
		// Prefer the CP's structured error message; fall back to the raw body.
		var errResp errorResponse
		if json.Unmarshal(respBody, &errResp) == nil {
			return "", fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, errResp.Error.Message)
		}
		return "", fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, string(respBody))
	}
	var refResp refreshResponse
	if err := json.Unmarshal(respBody, &refResp); err != nil {
		return "", fmt.Errorf("parse refresh response: %w", err)
	}
	// Persist both rotated credentials before reporting success, so a restart
	// after this point uses the new pair.
	tf.JWT = refResp.Token
	tf.RefreshToken = refResp.RefreshToken
	if err := saveTokenFile(tokenFilePath, *tf); err != nil {
		return "", fmt.Errorf("save refreshed token: %w", err)
	}
	slog.Info("host JWT refreshed", "host_id", tf.HostID)
	return refResp.Token, nil
}
// StartHeartbeat launches a background goroutine that sends periodic heartbeats
// to the control plane. It runs until the context is cancelled.
func StartHeartbeat(ctx context.Context, cpURL, hostID, hostToken string, interval time.Duration) {
url := strings.TrimRight(cpURL, "/") + "/v1/hosts/" + hostID + "/heartbeat"
//
// On 401/403: the heartbeat loop attempts to refresh the JWT. If the refresh
// also fails (expired refresh token), it calls pauseAll and stops.
//
// On repeated network failures (3 consecutive), it calls pauseAll but keeps
// retrying — the connection may recover and the host should resume heartbeating.
func StartHeartbeat(ctx context.Context, cpURL, tokenFilePath, hostID string, interval time.Duration, pauseAll func()) {
client := &http.Client{Timeout: 10 * time.Second}
go func() {
ticker := time.NewTicker(interval)
defer ticker.Stop()
consecutiveFailures := 0
pausedDueToFailure := false
currentJWT := ""
// Load the current JWT from disk.
if tf, err := loadTokenFile(tokenFilePath); err == nil {
currentJWT = tf.JWT
}
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
url := strings.TrimRight(cpURL, "/") + "/v1/hosts/" + hostID + "/heartbeat"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil)
if err != nil {
slog.Warn("heartbeat: failed to create request", "error", err)
continue
}
req.Header.Set("X-Host-Token", hostToken)
req.Header.Set("X-Host-Token", currentJWT)
resp, err := client.Do(req)
if err != nil {
slog.Warn("heartbeat: request failed", "error", err)
consecutiveFailures++
slog.Warn("heartbeat: request failed", "error", err, "consecutive_failures", consecutiveFailures)
if consecutiveFailures >= 3 && !pausedDueToFailure {
slog.Error("heartbeat: CP unreachable after 3 failures — pausing all sandboxes")
if pauseAll != nil {
pauseAll()
}
pausedDueToFailure = true
}
continue
}
resp.Body.Close()
if resp.StatusCode != http.StatusNoContent {
switch resp.StatusCode {
case http.StatusNoContent:
// Success.
if consecutiveFailures > 0 || pausedDueToFailure {
slog.Info("heartbeat: CP connection restored")
}
consecutiveFailures = 0
pausedDueToFailure = false
case http.StatusUnauthorized, http.StatusForbidden:
slog.Warn("heartbeat: JWT rejected — attempting token refresh")
newJWT, refreshErr := RefreshJWT(ctx, cpURL, tokenFilePath)
if refreshErr != nil {
slog.Error("heartbeat: JWT refresh failed — pausing all sandboxes; manual re-registration required",
"error", refreshErr)
if pauseAll != nil && !pausedDueToFailure {
pauseAll()
pausedDueToFailure = true
}
// Stop the heartbeat loop — operator must re-register.
return
}
currentJWT = newJWT
slog.Info("heartbeat: JWT refreshed successfully")
default:
slog.Warn("heartbeat: unexpected status", "status", resp.StatusCode)
}
}
@ -166,6 +320,12 @@ func StartHeartbeat(ctx context.Context, cpURL, hostID, hostToken string, interv
// HostIDFromToken extracts the host_id claim from a host JWT without
// verifying the signature (the agent doesn't have the signing secret).
// It is a thin exported wrapper around hostIDFromJWT.
func HostIDFromToken(token string) (string, error) {
	return hostIDFromJWT(token)
}
// hostIDFromJWT is the internal implementation used by both HostIDFromToken and
// the token file loader.
func hostIDFromJWT(token string) (string, error) {
parts := strings.Split(token, ".")
if len(parts) != 3 {
return "", fmt.Errorf("invalid JWT format")

View File

@ -67,3 +67,17 @@ func NewRegistrationToken() string {
}
return hex.EncodeToString(b)
}
// NewRefreshTokenID generates a new refresh token record ID in the format
// "hrt-" + 8 hex chars. This is the primary key of a host_refresh_tokens
// row, distinct from the opaque token value itself (see NewRefreshToken).
func NewRefreshTokenID() string {
	return "hrt-" + hex8()
}
// NewRefreshToken generates a 64-char hex token (32 bytes of entropy) for use
// as a host refresh token. Panics only if the OS entropy source is broken.
func NewRefreshToken() string {
	var raw [32]byte
	if _, err := rand.Read(raw[:]); err != nil {
		panic(fmt.Sprintf("crypto/rand failed: %v", err))
	}
	return hex.EncodeToString(raw[:])
}

View File

@ -0,0 +1,77 @@
package lifecycle
import (
"fmt"
"net/http"
"strings"
"sync"
"time"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
)
// HostClientPool maintains a cache of Connect RPC clients keyed by host ID.
// Clients are created lazily on first access and evicted when a host is removed
// or goes unreachable. The pool is safe for concurrent use.
type HostClientPool struct {
	mu         sync.RWMutex // guards clients
	clients    map[string]hostagentv1connect.HostAgentServiceClient
	httpClient *http.Client // shared transport for every created client
}
// NewHostClientPool creates a new pool. A single underlying HTTP client is
// shared by all per-host RPC clients; its 10-minute timeout supports
// long-running streaming operations.
func NewHostClientPool() *HostClientPool {
	return &HostClientPool{
		clients:    make(map[string]hostagentv1connect.HostAgentServiceClient),
		httpClient: &http.Client{Timeout: 10 * time.Minute},
	}
}
// Get returns a Connect RPC client for the given host, creating one if necessary.
// address is the host agent address (ip:port or full URL). The scheme is added if absent.
//
// NOTE: address is only consulted when a client is first created; a cached
// client keeps its original address until Evict(hostID) is called.
func (p *HostClientPool) Get(hostID, address string) hostagentv1connect.HostAgentServiceClient {
	// Fast path: read lock for the common cache-hit case.
	p.mu.RLock()
	c, ok := p.clients[hostID]
	p.mu.RUnlock()
	if ok {
		return c
	}
	p.mu.Lock()
	defer p.mu.Unlock()
	// Double-check after acquiring write lock — another goroutine may have
	// created the client between the RUnlock and the Lock.
	if c, ok = p.clients[hostID]; ok {
		return c
	}
	c = hostagentv1connect.NewHostAgentServiceClient(p.httpClient, ensureScheme(address))
	p.clients[hostID] = c
	return c
}
// GetForHost resolves a client for h via its recorded address. It fails when
// the host row has no usable address yet (registration not completed).
func (p *HostClientPool) GetForHost(h db.Host) (hostagentv1connect.HostAgentServiceClient, error) {
	addr := h.Address
	if !addr.Valid || addr.String == "" {
		return nil, fmt.Errorf("host %s has no address", h.ID)
	}
	return p.Get(h.ID, addr.String), nil
}
// Evict drops the cached client for hostID so the next Get builds a fresh one.
// Call it when a host's address changes or when the host is deleted.
func (p *HostClientPool) Evict(hostID string) {
	p.mu.Lock()
	defer p.mu.Unlock()
	delete(p.clients, hostID)
}
// ensureScheme normalizes addr to a full URL, prepending "http://" when
// neither an http nor https scheme is already present.
func ensureScheme(addr string) string {
	for _, scheme := range []string{"http://", "https://"} {
		if strings.HasPrefix(addr, scheme) {
			return addr
		}
	}
	return "http://" + addr
}

View File

@ -1183,6 +1183,28 @@ func (m *Manager) Shutdown(ctx context.Context) {
m.loops.ReleaseAll()
}
// PauseAll pauses every running sandbox managed by this host agent.
// Called when the host loses connectivity to the control plane to avoid
// leaving running VMs unmanaged. It is best-effort: failures for individual
// sandboxes are logged but do not stop the rest.
func (m *Manager) PauseAll(ctx context.Context) {
	// Snapshot the running IDs under the read lock, then release it before
	// the (potentially slow) Pause calls.
	m.mu.RLock()
	ids := make([]string, 0, len(m.boxes))
	for id, sb := range m.boxes {
		if sb.Status == models.StatusRunning {
			ids = append(ids, id)
		}
	}
	m.mu.RUnlock()
	slog.Info("pausing all running sandboxes due to CP connection loss", "count", len(ids))
	for _, sbID := range ids {
		if err := m.Pause(ctx, sbID); err != nil {
			slog.Warn("PauseAll: failed to pause sandbox", "id", sbID, "error", err)
		}
	}
}
// warnErr logs a warning if err is non-nil. Used for best-effort cleanup
// in error paths where the primary error has already been captured.
func warnErr(msg string, id string, err error) {

View File

@ -0,0 +1,51 @@
package scheduler
import (
"context"
"fmt"
"sync/atomic"
"git.omukk.dev/wrenn/sandbox/internal/db"
)
// HostScheduler selects a host for a new sandbox. Implementations may use
// different strategies (round-robin, least-loaded, tag-based, etc.).
type HostScheduler interface {
	// SelectHost returns a host that can accept a new sandbox.
	// Returns an error if no suitable host is available.
	SelectHost(ctx context.Context) (db.Host, error)
}
// RoundRobinScheduler cycles through online hosts in round-robin order.
// It re-fetches the host list on every call so that newly registered or
// recovered hosts are considered immediately.
type RoundRobinScheduler struct {
	db      *db.Queries
	counter atomic.Int64 // monotonically increasing pick counter; safe for concurrent SelectHost calls
}

// NewRoundRobinScheduler creates a RoundRobinScheduler backed by the given DB.
func NewRoundRobinScheduler(queries *db.Queries) *RoundRobinScheduler {
	return &RoundRobinScheduler{db: queries}
}
// SelectHost returns the next online host in round-robin order. Only hosts
// with status "online" and a non-empty recorded address are candidates.
func (s *RoundRobinScheduler) SelectHost(ctx context.Context) (db.Host, error) {
	all, err := s.db.ListActiveHosts(ctx)
	if err != nil {
		return db.Host{}, fmt.Errorf("list hosts: %w", err)
	}
	candidates := make([]db.Host, 0, len(all))
	for _, h := range all {
		if h.Status != "online" || !h.Address.Valid || h.Address.String == "" {
			continue
		}
		candidates = append(candidates, h)
	}
	if len(candidates) == 0 {
		return db.Host{}, fmt.Errorf("no online hosts available")
	}
	next := s.counter.Add(1) - 1
	return candidates[int(next%int64(len(candidates)))], nil
}

View File

@ -2,12 +2,14 @@ package service
import (
"context"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"log/slog"
"time"
"connectrpc.com/connect"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/redis/go-redis/v9"
@ -15,6 +17,8 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
)
// HostService provides host management operations.
@ -22,6 +26,7 @@ type HostService struct {
DB *db.Queries
Redis *redis.Client
JWT []byte
Pool *lifecycle.HostClientPool
}
// HostCreateParams holds the parameters for creating a host.
@ -50,10 +55,24 @@ type HostRegisterParams struct {
Address string
}
// HostRegisterResult holds the registered host and its long-lived JWT.
// HostRegisterResult holds the registered host, its short-lived JWT, and a long-lived refresh token.
type HostRegisterResult struct {
Host db.Host
JWT string
Host db.Host
JWT string
RefreshToken string
}
// HostRefreshResult holds a new JWT and rotated refresh token after a successful refresh.
type HostRefreshResult struct {
Host db.Host
JWT string
RefreshToken string
}
// HostDeletePreview describes what will be affected by deleting a host.
type HostDeletePreview struct {
Host db.Host
SandboxIDs []string
}
// regTokenPayload is the JSON stored in Redis for registration tokens.
@ -64,6 +83,14 @@ type regTokenPayload struct {
const regTokenTTL = time.Hour
// requireAdminOrOwner returns nil iff role is "owner" or "admin"; any other
// role produces a forbidden error.
func requireAdminOrOwner(role string) error {
	switch role {
	case "owner", "admin":
		return nil
	default:
		return fmt.Errorf("forbidden: only team owners and admins can manage BYOC hosts")
	}
}
// Create creates a new host record and generates a one-time registration token.
func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreateResult, error) {
if p.Type != "regular" && p.Type != "byoc" {
@ -75,7 +102,7 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat
return HostCreateResult{}, fmt.Errorf("forbidden: only admins can create regular hosts")
}
} else {
// BYOC: admin or team owner.
// BYOC: platform admin, or team owner/admin.
if p.TeamID == "" {
return HostCreateResult{}, fmt.Errorf("invalid request: team_id is required for BYOC hosts")
}
@ -90,8 +117,8 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat
if err != nil {
return HostCreateResult{}, fmt.Errorf("check team membership: %w", err)
}
if membership.Role != "owner" {
return HostCreateResult{}, fmt.Errorf("forbidden: only team owners can create BYOC hosts")
if err := requireAdminOrOwner(membership.Role); err != nil {
return HostCreateResult{}, err
}
}
}
@ -168,7 +195,6 @@ func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamI
return HostCreateResult{}, fmt.Errorf("invalid state: can only regenerate token for pending hosts (status: %s)", host.Status)
}
// Same permission model as Create/Delete.
if !isAdmin {
if host.Type != "byoc" {
return HostCreateResult{}, fmt.Errorf("forbidden: only admins can manage regular hosts")
@ -186,8 +212,8 @@ func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamI
if err != nil {
return HostCreateResult{}, fmt.Errorf("check team membership: %w", err)
}
if membership.Role != "owner" {
return HostCreateResult{}, fmt.Errorf("forbidden: only team owners can regenerate tokens")
if err := requireAdminOrOwner(membership.Role); err != nil {
return HostCreateResult{}, err
}
}
@ -216,7 +242,7 @@ func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamI
}
// Register validates a one-time registration token, updates the host with
// machine specs, and returns a long-lived host JWT.
// machine specs, and returns a short-lived host JWT plus a long-lived refresh token.
func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostRegisterResult, error) {
// Atomic consume: GetDel returns the value and deletes in one operation,
// preventing concurrent requests from consuming the same token.
@ -264,18 +290,89 @@ func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostR
slog.Warn("failed to mark host token used", "token_id", payload.TokenID, "error", err)
}
// Issue a long-lived refresh token.
refreshToken, err := s.issueRefreshToken(ctx, payload.HostID)
if err != nil {
return HostRegisterResult{}, fmt.Errorf("issue refresh token: %w", err)
}
// Re-fetch the host to get the updated state.
host, err := s.DB.GetHost(ctx, payload.HostID)
if err != nil {
return HostRegisterResult{}, fmt.Errorf("fetch updated host: %w", err)
}
return HostRegisterResult{Host: host, JWT: hostJWT}, nil
return HostRegisterResult{Host: host, JWT: hostJWT, RefreshToken: refreshToken}, nil
}
// Heartbeat updates the last heartbeat timestamp for a host.
// Refresh validates a refresh token, rotates it (revokes old, issues new),
// and returns a fresh JWT plus the new refresh token. The error string
// "invalid or expired refresh token" is matched by the HTTP layer to produce
// a 401 rather than a generic 400.
func (s *HostService) Refresh(ctx context.Context, refreshToken string) (HostRefreshResult, error) {
	// The lookup query itself filters out revoked and expired rows, so all
	// failure modes collapse into pgx.ErrNoRows here.
	hash := hashToken(refreshToken)
	row, err := s.DB.GetHostRefreshTokenByHash(ctx, hash)
	if errors.Is(err, pgx.ErrNoRows) {
		return HostRefreshResult{}, fmt.Errorf("invalid or expired refresh token")
	}
	if err != nil {
		return HostRefreshResult{}, fmt.Errorf("lookup refresh token: %w", err)
	}
	host, err := s.DB.GetHost(ctx, row.HostID)
	if err != nil {
		return HostRefreshResult{}, fmt.Errorf("host not found: %w", err)
	}
	// Sign new JWT.
	hostJWT, err := auth.SignHostJWT(s.JWT, host.ID)
	if err != nil {
		return HostRefreshResult{}, fmt.Errorf("sign host JWT: %w", err)
	}
	// Issue-then-revoke rotation: insert new token first so a crash between
	// the two DB calls leaves the host with two valid tokens rather than zero.
	newRefreshToken, err := s.issueRefreshToken(ctx, host.ID)
	if err != nil {
		return HostRefreshResult{}, fmt.Errorf("issue new refresh token: %w", err)
	}
	// Revoke old refresh token after the new one is safely persisted.
	if err := s.DB.RevokeHostRefreshToken(ctx, row.ID); err != nil {
		return HostRefreshResult{}, fmt.Errorf("revoke old refresh token: %w", err)
	}
	return HostRefreshResult{Host: host, JWT: hostJWT, RefreshToken: newRefreshToken}, nil
}
// issueRefreshToken creates a new refresh token record in the DB and returns
// the opaque token string. Only the SHA-256 hash is persisted — the plaintext
// exists solely in the response handed to the host agent. Expiry is
// now + auth.HostRefreshTokenExpiry (60 days).
func (s *HostService) issueRefreshToken(ctx context.Context, hostID string) (string, error) {
	token := id.NewRefreshToken()
	hash := hashToken(token)
	now := time.Now()
	if _, err := s.DB.InsertHostRefreshToken(ctx, db.InsertHostRefreshTokenParams{
		ID:        id.NewRefreshTokenID(),
		HostID:    hostID,
		TokenHash: hash,
		ExpiresAt: pgtype.Timestamptz{Time: now.Add(auth.HostRefreshTokenExpiry), Valid: true},
	}); err != nil {
		return "", fmt.Errorf("insert refresh token: %w", err)
	}
	return token, nil
}
// hashToken returns the hex-encoded SHA-256 hash of the token.
func hashToken(token string) string {
	digest := sha256.Sum256([]byte(token))
	// %x on a byte slice emits lowercase hex, identical to hashing the array.
	return fmt.Sprintf("%x", digest[:])
}
// Heartbeat updates the last heartbeat timestamp for a host and transitions
// any 'unreachable' host back to 'online'.
func (s *HostService) Heartbeat(ctx context.Context, hostID string) error {
	// UpdateHostHeartbeatAndStatus both stamps the heartbeat and restores an
	// 'unreachable' host to 'online'; the plain UpdateHostHeartbeat call that
	// previously sat here made this line unreachable and skipped the status
	// transition promised by the doc comment above.
	return s.DB.UpdateHostHeartbeatAndStatus(ctx, hostID)
}
// List returns hosts visible to the caller.
@ -301,37 +398,135 @@ func (s *HostService) Get(ctx context.Context, hostID, teamID string, isAdmin bo
return host, nil
}
// Delete removes a host. Admins can delete any host. Team owners can delete
// BYOC hosts belonging to their team.
func (s *HostService) Delete(ctx context.Context, hostID, userID, teamID string, isAdmin bool) error {
host, err := s.DB.GetHost(ctx, hostID)
// DeletePreview returns what would be affected by deleting the host, without
// making any changes. Use this to show the user a confirmation prompt.
func (s *HostService) DeletePreview(ctx context.Context, hostID, teamID string, isAdmin bool) (HostDeletePreview, error) {
host, err := s.checkDeletePermission(ctx, hostID, "", teamID, isAdmin)
if err != nil {
return fmt.Errorf("host not found: %w", err)
return HostDeletePreview{}, err
}
if !isAdmin {
if host.Type != "byoc" {
return fmt.Errorf("forbidden: only admins can delete regular hosts")
sandboxes, err := s.DB.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
HostID: hostID,
Column2: []string{"pending", "starting", "running", "missing"},
})
if err != nil {
return HostDeletePreview{}, fmt.Errorf("list sandboxes: %w", err)
}
ids := make([]string, len(sandboxes))
for i, sb := range sandboxes {
ids[i] = sb.ID
}
return HostDeletePreview{Host: host, SandboxIDs: ids}, nil
}
// Delete removes a host. Without force it returns an error listing active
// sandboxes so the caller can present a confirmation. With force it gracefully
// destroys all running sandboxes before deleting the host record.
func (s *HostService) Delete(ctx context.Context, hostID, userID, teamID string, isAdmin, force bool) error {
host, err := s.checkDeletePermission(ctx, hostID, userID, teamID, isAdmin)
if err != nil {
return err
}
sandboxes, err := s.DB.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
HostID: hostID,
Column2: []string{"pending", "starting", "running", "missing"},
})
if err != nil {
return fmt.Errorf("list sandboxes: %w", err)
}
if len(sandboxes) > 0 && !force {
ids := make([]string, len(sandboxes))
for i, sb := range sandboxes {
ids[i] = sb.ID
}
if !host.TeamID.Valid || host.TeamID.String != teamID {
return fmt.Errorf("forbidden: host does not belong to your team")
return &HostHasSandboxesError{SandboxIDs: ids}
}
// Gracefully destroy running sandboxes on the host agent (best-effort).
if len(sandboxes) > 0 && host.Address.Valid && host.Address.String != "" {
agent, err := s.Pool.GetForHost(host)
if err == nil {
for _, sb := range sandboxes {
if sb.Status == "running" || sb.Status == "starting" {
_, rpcErr := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
SandboxId: sb.ID,
}))
if rpcErr != nil && connect.CodeOf(rpcErr) != connect.CodeNotFound {
slog.Warn("delete host: failed to destroy sandbox on agent", "sandbox_id", sb.ID, "error", rpcErr)
}
}
}
}
}
// Mark all affected sandboxes as stopped in DB.
if len(sandboxes) > 0 {
sbIDs := make([]string, len(sandboxes))
for i, sb := range sandboxes {
sbIDs[i] = sb.ID
}
if err := s.DB.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
Column1: sbIDs,
Status: "stopped",
}); err != nil {
slog.Warn("delete host: failed to mark sandboxes stopped", "host_id", hostID, "error", err)
}
}
// Revoke all refresh tokens for this host.
if err := s.DB.RevokeHostRefreshTokensByHost(ctx, hostID); err != nil {
slog.Warn("delete host: failed to revoke refresh tokens", "host_id", hostID, "error", err)
}
// Evict the client from the pool so no further RPCs are sent.
if s.Pool != nil {
s.Pool.Evict(hostID)
}
return s.DB.DeleteHost(ctx, hostID)
}
// checkDeletePermission verifies the caller has permission to delete the given
// host and returns the host record on success.
func (s *HostService) checkDeletePermission(ctx context.Context, hostID, userID, teamID string, isAdmin bool) (db.Host, error) {
host, err := s.DB.GetHost(ctx, hostID)
if err != nil {
return db.Host{}, fmt.Errorf("host not found: %w", err)
}
if isAdmin {
return host, nil
}
if host.Type != "byoc" {
return db.Host{}, fmt.Errorf("forbidden: only admins can delete regular hosts")
}
if !host.TeamID.Valid || host.TeamID.String != teamID {
return db.Host{}, fmt.Errorf("forbidden: host does not belong to your team")
}
if userID != "" {
membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
UserID: userID,
TeamID: teamID,
})
if errors.Is(err, pgx.ErrNoRows) {
return fmt.Errorf("forbidden: not a member of the specified team")
return db.Host{}, fmt.Errorf("forbidden: not a member of the specified team")
}
if err != nil {
return fmt.Errorf("check team membership: %w", err)
return db.Host{}, fmt.Errorf("check team membership: %w", err)
}
if membership.Role != "owner" {
return fmt.Errorf("forbidden: only team owners can delete BYOC hosts")
if err := requireAdminOrOwner(membership.Role); err != nil {
return db.Host{}, err
}
}
return s.DB.DeleteHost(ctx, hostID)
return host, nil
}
// AddTag adds a tag to a host.
@ -357,3 +552,14 @@ func (s *HostService) ListTags(ctx context.Context, hostID, teamID string, isAdm
}
return s.DB.GetHostTags(ctx, hostID)
}
// HostHasSandboxesError is returned by Delete when the host has active
// sandboxes and force was not set. The caller should present the list to the
// user and re-call Delete with force=true if they confirm.
type HostHasSandboxesError struct {
	// SandboxIDs lists every sandbox that would be destroyed by a forced delete.
	SandboxIDs []string
}

// Error implements the error interface.
func (e *HostHasSandboxesError) Error() string {
	count := len(e.SandboxIDs)
	return fmt.Sprintf("host has %d active sandbox(es): %v", count, e.SandboxIDs)
}

View File

@ -11,16 +11,18 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
"git.omukk.dev/wrenn/sandbox/internal/scheduler"
"git.omukk.dev/wrenn/sandbox/internal/validate"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
)
// SandboxService provides sandbox lifecycle operations shared between the
// REST API and the dashboard.
type SandboxService struct {
	// DB is the sqlc-generated query layer.
	DB *db.Queries
	// Pool lazily creates and caches per-host agent RPC clients; the former
	// single static Agent client field was removed in the multi-host rework.
	Pool *lifecycle.HostClientPool
	// Scheduler selects which online host receives each new sandbox.
	Scheduler scheduler.HostScheduler
}
// SandboxCreateParams holds the parameters for creating a sandbox.
@ -32,8 +34,34 @@ type SandboxCreateParams struct {
TimeoutSec int32
}
// Create creates a new sandbox: inserts a pending DB record, calls the host agent,
// and updates the record to running. Returns the sandbox DB row.
// agentForSandbox looks up the host for the given sandbox and returns a client.
func (s *SandboxService) agentForSandbox(ctx context.Context, sandboxID string) (hostagentClient, db.Sandbox, error) {
	sandbox, err := s.DB.GetSandbox(ctx, sandboxID)
	if err != nil {
		return nil, db.Sandbox{}, fmt.Errorf("sandbox not found: %w", err)
	}
	hostRow, err := s.DB.GetHost(ctx, sandbox.HostID)
	if err != nil {
		return nil, db.Sandbox{}, fmt.Errorf("host not found for sandbox: %w", err)
	}
	// The pool caches one Connect client per host.
	client, err := s.Pool.GetForHost(hostRow)
	if err != nil {
		return nil, db.Sandbox{}, fmt.Errorf("get agent client: %w", err)
	}
	return client, sandbox, nil
}
// hostagentClient is a local alias to avoid the full package path in signatures.
// It names exactly the subset of the generated host-agent client that this
// service calls (create/destroy/pause/resume/ping); the full generated client
// satisfies it, so any concrete client from the pool can be passed around.
type hostagentClient = interface {
	CreateSandbox(ctx context.Context, req *connect.Request[pb.CreateSandboxRequest]) (*connect.Response[pb.CreateSandboxResponse], error)
	DestroySandbox(ctx context.Context, req *connect.Request[pb.DestroySandboxRequest]) (*connect.Response[pb.DestroySandboxResponse], error)
	PauseSandbox(ctx context.Context, req *connect.Request[pb.PauseSandboxRequest]) (*connect.Response[pb.PauseSandboxResponse], error)
	ResumeSandbox(ctx context.Context, req *connect.Request[pb.ResumeSandboxRequest]) (*connect.Response[pb.ResumeSandboxResponse], error)
	PingSandbox(ctx context.Context, req *connect.Request[pb.PingSandboxRequest]) (*connect.Response[pb.PingSandboxResponse], error)
}
// Create creates a new sandbox: picks a host via the scheduler, inserts a pending
// DB record, calls the host agent, and updates the record to running.
func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db.Sandbox, error) {
if p.Template == "" {
p.Template = "minimal"
@ -58,12 +86,23 @@ func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db.
}
}
// Pick a host for this sandbox.
host, err := s.Scheduler.SelectHost(ctx)
if err != nil {
return db.Sandbox{}, fmt.Errorf("select host: %w", err)
}
agent, err := s.Pool.GetForHost(host)
if err != nil {
return db.Sandbox{}, fmt.Errorf("get agent client: %w", err)
}
sandboxID := id.NewSandboxID()
if _, err := s.DB.InsertSandbox(ctx, db.InsertSandboxParams{
ID: sandboxID,
TeamID: p.TeamID,
HostID: "default",
HostID: host.ID,
Template: p.Template,
Status: "pending",
Vcpus: p.VCPUs,
@ -73,7 +112,7 @@ func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db.
return db.Sandbox{}, fmt.Errorf("insert sandbox: %w", err)
}
resp, err := s.Agent.CreateSandbox(ctx, connect.NewRequest(&pb.CreateSandboxRequest{
resp, err := agent.CreateSandbox(ctx, connect.NewRequest(&pb.CreateSandboxRequest{
SandboxId: sandboxID,
Template: p.Template,
Vcpus: p.VCPUs,
@ -126,7 +165,12 @@ func (s *SandboxService) Pause(ctx context.Context, sandboxID, teamID string) (d
return db.Sandbox{}, fmt.Errorf("sandbox is not running (status: %s)", sb.Status)
}
if _, err := s.Agent.PauseSandbox(ctx, connect.NewRequest(&pb.PauseSandboxRequest{
agent, _, err := s.agentForSandbox(ctx, sandboxID)
if err != nil {
return db.Sandbox{}, err
}
if _, err := agent.PauseSandbox(ctx, connect.NewRequest(&pb.PauseSandboxRequest{
SandboxId: sandboxID,
})); err != nil {
return db.Sandbox{}, fmt.Errorf("agent pause: %w", err)
@ -151,7 +195,12 @@ func (s *SandboxService) Resume(ctx context.Context, sandboxID, teamID string) (
return db.Sandbox{}, fmt.Errorf("sandbox is not paused (status: %s)", sb.Status)
}
resp, err := s.Agent.ResumeSandbox(ctx, connect.NewRequest(&pb.ResumeSandboxRequest{
agent, _, err := s.agentForSandbox(ctx, sandboxID)
if err != nil {
return db.Sandbox{}, err
}
resp, err := agent.ResumeSandbox(ctx, connect.NewRequest(&pb.ResumeSandboxRequest{
SandboxId: sandboxID,
TimeoutSec: sb.TimeoutSec,
}))
@ -181,8 +230,13 @@ func (s *SandboxService) Destroy(ctx context.Context, sandboxID, teamID string)
return fmt.Errorf("sandbox not found: %w", err)
}
agent, _, err := s.agentForSandbox(ctx, sandboxID)
if err != nil {
return err
}
// Destroy on host agent. A not-found response is fine — sandbox is already gone.
if _, err := s.Agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
if _, err := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
SandboxId: sandboxID,
})); err != nil && connect.CodeOf(err) != connect.CodeNotFound {
return fmt.Errorf("agent destroy: %w", err)
@ -206,7 +260,12 @@ func (s *SandboxService) Ping(ctx context.Context, sandboxID, teamID string) err
return fmt.Errorf("sandbox is not running (status: %s)", sb.Status)
}
if _, err := s.Agent.PingSandbox(ctx, connect.NewRequest(&pb.PingSandboxRequest{
agent, _, err := s.agentForSandbox(ctx, sandboxID)
if err != nil {
return err
}
if _, err := agent.PingSandbox(ctx, connect.NewRequest(&pb.PingSandboxRequest{
SandboxId: sandboxID,
})); err != nil {
return fmt.Errorf("agent ping: %w", err)

View File

@ -14,17 +14,17 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
)
// teamNameRE restricts team names to 1-128 characters drawn from letters,
// digits, spaces, underscores, hyphens, '@', and apostrophes.
var teamNameRE = regexp.MustCompile(`^[A-Za-z0-9 _\-@']{1,128}$`)
// TeamService provides team management operations.
type TeamService struct {
	// DB is the sqlc-generated query layer.
	DB *db.Queries
	// Pool is the raw pgx connection pool for multi-statement operations.
	Pool *pgxpool.Pool
	// HostPool lazily creates and caches per-host agent RPC clients; it
	// replaces the former single static Agent client field.
	HostPool *lifecycle.HostClientPool
}
// TeamWithRole pairs a team with the calling user's role in it.
@ -177,10 +177,16 @@ func (s *TeamService) DeleteTeam(ctx context.Context, teamID, callerUserID strin
var stopIDs []string
for _, sb := range sandboxes {
if _, err := s.Agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
SandboxId: sb.ID,
})); err != nil && connect.CodeOf(err) != connect.CodeNotFound {
slog.Warn("team delete: failed to destroy sandbox", "sandbox_id", sb.ID, "error", err)
host, hostErr := s.DB.GetHost(ctx, sb.HostID)
if hostErr == nil {
agent, agentErr := s.HostPool.GetForHost(host)
if agentErr == nil {
if _, err := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
SandboxId: sb.ID,
})); err != nil && connect.CodeOf(err) != connect.CodeNotFound {
slog.Warn("team delete: failed to destroy sandbox", "sandbox_id", sb.ID, "error", err)
}
}
}
stopIDs = append(stopIDs, sb.ID)
}