diff --git a/.env.example b/.env.example index f128de7..dee152c 100644 --- a/.env.example +++ b/.env.example @@ -6,7 +6,6 @@ REDIS_URL=redis://localhost:6379/0 # Control Plane CP_LISTEN_ADDR=:8000 -CP_HOST_AGENT_ADDR=localhost:50051 # Host Agent AGENT_LISTEN_ADDR=:50051 diff --git a/cmd/control-plane/main.go b/cmd/control-plane/main.go index 3f52b41..b11051e 100644 --- a/cmd/control-plane/main.go +++ b/cmd/control-plane/main.go @@ -17,7 +17,8 @@ import ( "git.omukk.dev/wrenn/sandbox/internal/auth/oauth" "git.omukk.dev/wrenn/sandbox/internal/config" "git.omukk.dev/wrenn/sandbox/internal/db" - "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect" + "git.omukk.dev/wrenn/sandbox/internal/lifecycle" + "git.omukk.dev/wrenn/sandbox/internal/scheduler" ) func main() { @@ -66,12 +67,11 @@ func main() { } slog.Info("connected to redis") - // Connect RPC client for the host agent. - agentHTTP := &http.Client{Timeout: 10 * time.Minute} - agentClient := hostagentv1connect.NewHostAgentServiceClient( - agentHTTP, - cfg.HostAgentAddr, - ) + // Host client pool — manages Connect RPC clients to host agents. + hostPool := lifecycle.NewHostClientPool() + + // Scheduler — picks a host for each new sandbox (round-robin for now). + hostScheduler := scheduler.NewRoundRobinScheduler(queries) // OAuth provider registry. oauthRegistry := oauth.NewRegistry() @@ -87,11 +87,11 @@ func main() { } // API server. - srv := api.New(queries, agentClient, pool, rdb, []byte(cfg.JWTSecret), oauthRegistry, cfg.OAuthRedirectURL) + srv := api.New(queries, hostPool, hostScheduler, pool, rdb, []byte(cfg.JWTSecret), oauthRegistry, cfg.OAuthRedirectURL) - // Start reconciler. - reconciler := api.NewReconciler(queries, agentClient, "default", 5*time.Second) - reconciler.Start(ctx) + // Start host monitor (passive + active reconciliation every 30s). 
+ monitor := api.NewHostMonitor(queries, hostPool, 30*time.Second) + monitor.Start(ctx) httpServer := &http.Server{ Addr: cfg.ListenAddr, @@ -114,7 +114,7 @@ func main() { } }() - slog.Info("control plane starting", "addr", cfg.ListenAddr, "agent", cfg.HostAgentAddr) + slog.Info("control plane starting", "addr", cfg.ListenAddr) if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { slog.Error("http server error", "error", err) os.Exit(1) diff --git a/cmd/host-agent/main.go b/cmd/host-agent/main.go index c426a81..2d34cd1 100644 --- a/cmd/host-agent/main.go +++ b/cmd/host-agent/main.go @@ -8,9 +8,12 @@ import ( "os" "os/signal" "path/filepath" + "sync" "syscall" "time" + "github.com/joho/godotenv" + "git.omukk.dev/wrenn/sandbox/internal/devicemapper" "git.omukk.dev/wrenn/sandbox/internal/hostagent" "git.omukk.dev/wrenn/sandbox/internal/sandbox" @@ -18,7 +21,10 @@ import ( ) func main() { - registrationToken := flag.String("register", "", "One-time registration token from the control plane") + // Best-effort load — missing .env file is fine. 
+ _ = godotenv.Load() + + registrationToken := flag.String("register", "", "One-time registration token from the control plane (required on first run)") advertiseAddr := flag.String("address", "", "Externally-reachable address (ip:port) for this host agent") flag.Parse() @@ -42,7 +48,16 @@ func main() { listenAddr := envOrDefault("AGENT_LISTEN_ADDR", ":50051") rootDir := envOrDefault("AGENT_FILES_ROOTDIR", "/var/lib/wrenn") cpURL := os.Getenv("AGENT_CP_URL") - tokenFile := filepath.Join(rootDir, "host-token") + tokenFile := filepath.Join(rootDir, "host.jwt") + + if cpURL == "" { + slog.Error("AGENT_CP_URL environment variable is required") + os.Exit(1) + } + if *advertiseAddr == "" { + slog.Error("--address flag is required (externally-reachable ip:port)") + os.Exit(1) + } cfg := sandbox.Config{ KernelPath: filepath.Join(rootDir, "kernels", "vmlinux"), @@ -58,64 +73,80 @@ func main() { mgr.StartTTLReaper(ctx) - if *advertiseAddr == "" { - slog.Error("--address flag is required (externally-reachable ip:port)") + // Register with the control plane and start heartbeating. + hostToken, err := hostagent.Register(ctx, hostagent.RegistrationConfig{ + CPURL: cpURL, + RegistrationToken: *registrationToken, + TokenFile: tokenFile, + Address: *advertiseAddr, + }) + if err != nil { + slog.Error("host registration failed", "error", err) os.Exit(1) } - // Register with the control plane (if configured). 
- if cpURL != "" { - hostToken, err := hostagent.Register(ctx, hostagent.RegistrationConfig{ - CPURL: cpURL, - RegistrationToken: *registrationToken, - TokenFile: tokenFile, - Address: *advertiseAddr, - }) - if err != nil { - slog.Error("host registration failed", "error", err) - os.Exit(1) - } - - hostID, err := hostagent.HostIDFromToken(hostToken) - if err != nil { - slog.Error("failed to extract host ID from token", "error", err) - os.Exit(1) - } - - slog.Info("host registered", "host_id", hostID) - hostagent.StartHeartbeat(ctx, cpURL, hostID, hostToken, 30*time.Second) + hostID, err := hostagent.HostIDFromToken(hostToken) + if err != nil { + slog.Error("failed to extract host ID from token", "error", err) + os.Exit(1) } - srv := hostagent.NewServer(mgr) + slog.Info("host registered", "host_id", hostID) + + // httpServer is declared here so the shutdown func can reference it. + httpServer := &http.Server{Addr: listenAddr} + + // doShutdown is the single shutdown path. sync.Once ensures mgr.Shutdown + // and httpServer.Shutdown are each called exactly once regardless of + // whether shutdown is triggered by a signal, a heartbeat 404, or the + // Terminate RPC. + var shutdownOnce sync.Once + doShutdown := func(reason string) { + shutdownOnce.Do(func() { + slog.Info("shutting down", "reason", reason) + cancel() + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer shutdownCancel() + mgr.Shutdown(shutdownCtx) + if err := httpServer.Shutdown(shutdownCtx); err != nil { + slog.Error("http server shutdown error", "error", err) + } + }) + } + + srv := hostagent.NewServer(mgr, func() { + doShutdown("Terminate RPC received") + }) path, handler := hostagentv1connect.NewHostAgentServiceHandler(srv) mux := http.NewServeMux() mux.Handle(path, handler) + httpServer.Handler = mux - httpServer := &http.Server{ - Addr: listenAddr, - Handler: mux, - } + // Start heartbeat loop. 
Handler must be set before this because the + // immediate beat can trigger doShutdown → httpServer.Shutdown synchronously. + hostagent.StartHeartbeat(ctx, cpURL, tokenFile, hostID, 30*time.Second, + // pauseAll: called on 3 consecutive network failures. + func() { + pauseCtx, pauseCancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer pauseCancel() + mgr.PauseAll(pauseCtx) + }, + // onDeleted: called when CP returns 404 (host was deleted). + func() { + doShutdown("host deleted from CP") + }, + ) - // Graceful shutdown on signal. + // Graceful shutdown on SIGINT/SIGTERM. sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) go func() { sig := <-sigCh - slog.Info("received signal, shutting down", "signal", sig) - cancel() - - shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second) - defer shutdownCancel() - - mgr.Shutdown(shutdownCtx) - - if err := httpServer.Shutdown(shutdownCtx); err != nil { - slog.Error("http server shutdown error", "error", err) - } + doShutdown("signal: " + sig.String()) }() - slog.Info("host agent starting", "addr", listenAddr) + slog.Info("host agent starting", "addr", listenAddr, "host_id", hostID) if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { slog.Error("http server error", "error", err) os.Exit(1) diff --git a/db/migrations/20260324120214_host_refresh_tokens.sql b/db/migrations/20260324120214_host_refresh_tokens.sql new file mode 100644 index 0000000..02a13f7 --- /dev/null +++ b/db/migrations/20260324120214_host_refresh_tokens.sql @@ -0,0 +1,19 @@ +-- +goose Up + +-- Refresh tokens for host agent JWT rotation. +-- Hosts exchange a refresh token for a new short-lived JWT + new refresh token (rotation). +-- Refresh tokens expire after 60 days; hosts must re-register with a new one-time token after that. 
+CREATE TABLE host_refresh_tokens ( + id TEXT PRIMARY KEY, + host_id TEXT NOT NULL REFERENCES hosts(id) ON DELETE CASCADE, + token_hash TEXT NOT NULL UNIQUE, -- SHA-256 hex of the opaque token + expires_at TIMESTAMPTZ NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + revoked_at TIMESTAMPTZ -- NULL = active; set on rotation or host delete +); + +CREATE INDEX idx_host_refresh_tokens_host ON host_refresh_tokens(host_id); + +-- +goose Down + +DROP TABLE host_refresh_tokens; diff --git a/db/queries/host_refresh_tokens.sql b/db/queries/host_refresh_tokens.sql new file mode 100644 index 0000000..5a41164 --- /dev/null +++ b/db/queries/host_refresh_tokens.sql @@ -0,0 +1,19 @@ +-- name: InsertHostRefreshToken :one +INSERT INTO host_refresh_tokens (id, host_id, token_hash, expires_at) +VALUES ($1, $2, $3, $4) +RETURNING *; + +-- name: GetHostRefreshTokenByHash :one +SELECT * FROM host_refresh_tokens +WHERE token_hash = $1 AND revoked_at IS NULL AND expires_at > NOW(); + +-- name: RevokeHostRefreshToken :exec +UPDATE host_refresh_tokens SET revoked_at = NOW() WHERE id = $1; + +-- name: RevokeHostRefreshTokensByHost :exec +UPDATE host_refresh_tokens SET revoked_at = NOW() +WHERE host_id = $1 AND revoked_at IS NULL; + +-- name: DeleteExpiredHostRefreshTokens :exec +DELETE FROM host_refresh_tokens +WHERE expires_at < NOW() OR revoked_at IS NOT NULL; diff --git a/db/queries/hosts.sql b/db/queries/hosts.sql index 7f8c9e4..27ece00 100644 --- a/db/queries/hosts.sql +++ b/db/queries/hosts.sql @@ -67,3 +67,19 @@ SELECT * FROM host_tokens WHERE host_id = $1 ORDER BY created_at DESC; -- name: GetHostByTeam :one SELECT * FROM hosts WHERE id = $1 AND team_id = $2; + +-- name: ListActiveHosts :many +-- Returns all hosts that have completed registration (not pending/offline). 
+SELECT * FROM hosts WHERE status NOT IN ('pending', 'offline') ORDER BY created_at; + +-- name: UpdateHostHeartbeatAndStatus :execrows +-- Updates last_heartbeat_at and transitions unreachable hosts back to online. +-- Returns 0 if no host was found (deleted), which the caller treats as 404. +UPDATE hosts +SET last_heartbeat_at = NOW(), + status = CASE WHEN status = 'unreachable' THEN 'online' ELSE status END, + updated_at = NOW() +WHERE id = $1; + +-- name: MarkHostUnreachable :exec +UPDATE hosts SET status = 'unreachable', updated_at = NOW() WHERE id = $1; diff --git a/db/queries/sandboxes.sql b/db/queries/sandboxes.sql index d897bff..131fe1e 100644 --- a/db/queries/sandboxes.sql +++ b/db/queries/sandboxes.sql @@ -56,3 +56,20 @@ WHERE id = ANY($1::text[]); SELECT * FROM sandboxes WHERE team_id = $1 AND status IN ('running', 'paused', 'starting') ORDER BY created_at DESC; + +-- name: MarkSandboxesMissingByHost :exec +-- Called when the host monitor marks a host unreachable. +-- Marks running/starting/pending sandboxes on that host as 'missing' so users see +-- the sandbox is not currently reachable, without permanently losing the record. +UPDATE sandboxes +SET status = 'missing', + last_updated = NOW() +WHERE host_id = $1 AND status IN ('running', 'starting', 'pending'); + +-- name: BulkRestoreRunning :exec +-- Called by the reconciler when a host comes back online and its sandboxes are +-- confirmed alive. Restores only sandboxes that are in 'missing' state. 
+UPDATE sandboxes +SET status = 'running', + last_updated = NOW() +WHERE id = ANY($1::text[]) AND status = 'missing'; diff --git a/frontend/src/lib/api/hosts.ts b/frontend/src/lib/api/hosts.ts new file mode 100644 index 0000000..031b7f0 --- /dev/null +++ b/frontend/src/lib/api/hosts.ts @@ -0,0 +1,84 @@ +import { apiFetch } from './client'; + +export type Host = { + id: string; + type: 'regular' | 'byoc'; + team_id?: string; + team_name?: string; + provider?: string; + availability_zone?: string; + arch?: string; + cpu_cores?: number; + memory_mb?: number; + disk_gb?: number; + address?: string; + status: 'pending' | 'online' | 'offline' | 'unreachable' | 'draining'; + last_heartbeat_at?: string; + created_by: string; + created_at: string; + updated_at: string; +}; + +export type CreateHostParams = { + type: 'regular' | 'byoc'; + team_id?: string; + provider?: string; + availability_zone?: string; +}; + +export type CreateHostResult = { + host: Host; + registration_token: string; +}; + +export async function listHosts(): Promise<{ ok: true; data: Host[] } | { ok: false; error: string }> { + return apiFetch('GET', '/api/v1/hosts'); +} + +export async function createHost( + params: CreateHostParams +): Promise<{ ok: true; data: CreateHostResult } | { ok: false; error: string }> { + return apiFetch('POST', '/api/v1/hosts', params); +} + +export async function deleteHost( + id: string, + force = false +): Promise<{ ok: true } | { ok: false; error: string; sandbox_ids?: string[] }> { + const url = `/api/v1/hosts/${id}${force ? 
'?force=true' : ''}`; + const res = await apiFetch('DELETE', url); + if (!res.ok) { + return res as { ok: false; error: string }; + } + return { ok: true }; +} + +export async function getDeletePreview( + id: string +): Promise<{ ok: true; data: { host: Host; sandbox_ids: string[] } } | { ok: false; error: string }> { + return apiFetch<{ host: Host; sandbox_ids: string[] }>('GET', `/api/v1/hosts/${id}/delete-preview`); +} + +export function statusColor(status: Host['status']): string { + switch (status) { + case 'online': + return 'var(--color-accent)'; + case 'pending': + return 'var(--color-amber)'; + case 'offline': + case 'unreachable': + return 'var(--color-red)'; + case 'draining': + return 'var(--color-blue)'; + default: + return 'var(--color-text-muted)'; + } +} + +export function formatSpecs(host: Host): string { + const parts: string[] = []; + if (host.cpu_cores) parts.push(`${host.cpu_cores} vCPU`); + if (host.memory_mb) parts.push(`${Math.round(host.memory_mb / 1024)}GB RAM`); + if (host.disk_gb) parts.push(`${host.disk_gb}GB disk`); + return parts.join(' · ') || '—'; +} diff --git a/frontend/src/lib/api/team.ts b/frontend/src/lib/api/team.ts index a1ff935..0ffc4ed 100644 --- a/frontend/src/lib/api/team.ts +++ b/frontend/src/lib/api/team.ts @@ -29,6 +29,7 @@ export type TeamWithRole = { id: string; name: string; slug: string; + is_byoc: boolean; created_at: string; role: string; }; diff --git a/frontend/src/lib/auth.svelte.ts b/frontend/src/lib/auth.svelte.ts index b42cf52..d39a0f4 100644 --- a/frontend/src/lib/auth.svelte.ts +++ b/frontend/src/lib/auth.svelte.ts @@ -19,12 +19,23 @@ function isTokenExpired(token: string): boolean { } } +function decodeJWTPayload(token: string): Record { + try { + const payload = token.split('.')[1]; + return JSON.parse(atob(payload.replace(/-/g, '+').replace(/_/g, '/'))); + } catch { + return {}; + } +} + function createAuth() { let token = $state(null); let userId = $state(null); let teamId = $state(null); let email = 
$state(null); let name = $state(null); + let isAdmin = $state(false); + let role = $state('member'); let initialized = $state(false); // Initialize from localStorage synchronously at module load. @@ -36,6 +47,9 @@ function createAuth() { teamId = localStorage.getItem(STORAGE_KEYS.teamId); email = localStorage.getItem(STORAGE_KEYS.email); name = localStorage.getItem(STORAGE_KEYS.name); + const payload = decodeJWTPayload(stored); + isAdmin = Boolean(payload.is_admin); + role = String(payload.role || 'member'); } else if (stored) { // Expired — clean up. for (const key of Object.values(STORAGE_KEYS)) { @@ -63,6 +77,12 @@ function createAuth() { get name() { return name; }, + get isAdmin() { + return isAdmin; + }, + get role() { + return role; + }, get isAuthenticated() { return isAuthenticated; }, @@ -76,6 +96,9 @@ function createAuth() { teamId = data.team_id; email = data.email; name = data.name; + const payload = decodeJWTPayload(data.token); + isAdmin = Boolean(payload.is_admin); + role = String(payload.role || 'member'); localStorage.setItem(STORAGE_KEYS.token, data.token); localStorage.setItem(STORAGE_KEYS.userId, data.user_id); @@ -90,6 +113,8 @@ function createAuth() { teamId = null; email = null; name = null; + isAdmin = false; + role = 'member'; for (const key of Object.values(STORAGE_KEYS)) { localStorage.removeItem(key); diff --git a/frontend/src/lib/components/AdminSidebar.svelte b/frontend/src/lib/components/AdminSidebar.svelte new file mode 100644 index 0000000..4bed5cc --- /dev/null +++ b/frontend/src/lib/components/AdminSidebar.svelte @@ -0,0 +1,184 @@ + + + diff --git a/frontend/src/lib/components/Sidebar.svelte b/frontend/src/lib/components/Sidebar.svelte index 47feb54..c9a7afc 100644 --- a/frontend/src/lib/components/Sidebar.svelte +++ b/frontend/src/lib/components/Sidebar.svelte @@ -19,7 +19,9 @@ IconSidebar, IconBell, IconDocs, - IconAudit + IconAudit, + IconServer, + IconShield } from './icons'; let { collapsed = $bindable(false) }: { collapsed: 
boolean } = $props(); @@ -39,6 +41,8 @@ label: string; icon: typeof IconMonitor; href: string; + disabled?: boolean; + disabledHint?: string; }; const platformItems: NavItem[] = [ @@ -46,11 +50,24 @@ { label: 'Templates', icon: IconBox, href: '/dashboard/snapshots' } ]; - const managementItems: NavItem[] = [ + let currentTeamIsByoc = $derived( + teamsStore.list.find((t) => t.id === auth.teamId)?.is_byoc ?? false + ); + + let managementItems = $derived([ { label: 'Keys', icon: IconKey, href: '/dashboard/keys' }, { label: 'Team', icon: IconMembers, href: '/dashboard/team' }, - { label: 'Audit Logs', icon: IconAudit, href: '/dashboard/audit' } - ]; + { label: 'Audit Logs', icon: IconAudit, href: '/dashboard/audit' }, + ...(currentTeamIsByoc + ? [{ + label: 'BYOC', + icon: IconServer, + href: '/dashboard/byoc', + disabled: auth.role === 'member', + disabledHint: 'Available to team owners and admins only' + }] + : []) + ]); const billingItems: NavItem[] = [ { label: 'Usage', icon: IconUsage, href: '/dashboard/usage' }, @@ -232,6 +249,16 @@
+ {#if auth.isAdmin} + + + {#if !collapsed}Admin{/if} + + {/if} {/if} {#each items as item} - {#if isActive(item.href)} + {#if item.disabled} +
+ + {#if !collapsed} + {item.label} + {/if} +
+ {:else if isActive(item.href)}
+ let { size = 18, class: className = '' }: { size?: number; class?: string } = $props(); + + + diff --git a/frontend/src/lib/components/icons/IconServer.svelte b/frontend/src/lib/components/icons/IconServer.svelte new file mode 100644 index 0000000..c1ae7b9 --- /dev/null +++ b/frontend/src/lib/components/icons/IconServer.svelte @@ -0,0 +1,21 @@ + + + diff --git a/frontend/src/lib/components/icons/IconShield.svelte b/frontend/src/lib/components/icons/IconShield.svelte new file mode 100644 index 0000000..056bc16 --- /dev/null +++ b/frontend/src/lib/components/icons/IconShield.svelte @@ -0,0 +1,18 @@ + + + diff --git a/frontend/src/lib/components/icons/index.ts b/frontend/src/lib/components/icons/index.ts index 6296641..fa90069 100644 --- a/frontend/src/lib/components/icons/index.ts +++ b/frontend/src/lib/components/icons/index.ts @@ -23,3 +23,6 @@ export { default as IconBell } from './IconBell.svelte'; export { default as IconDocs } from './IconDocs.svelte'; export { default as IconAudit } from './IconAudit.svelte'; export { default as IconBox } from './IconBox.svelte'; +export { default as IconServer } from './IconServer.svelte'; +export { default as IconGear } from './IconGear.svelte'; +export { default as IconShield } from './IconShield.svelte'; diff --git a/frontend/src/routes/admin/+layout.svelte b/frontend/src/routes/admin/+layout.svelte new file mode 100644 index 0000000..599de61 --- /dev/null +++ b/frontend/src/routes/admin/+layout.svelte @@ -0,0 +1,7 @@ + + + +{@render children()} diff --git a/frontend/src/routes/admin/+layout.ts b/frontend/src/routes/admin/+layout.ts new file mode 100644 index 0000000..0f46f49 --- /dev/null +++ b/frontend/src/routes/admin/+layout.ts @@ -0,0 +1,9 @@ +import { browser } from '$app/environment'; +import { redirect } from '@sveltejs/kit'; +import { auth } from '$lib/auth.svelte'; + +export const load = () => { + if (!browser) return; + if (!auth.isAuthenticated) redirect(302, '/login'); + if (!auth.isAdmin) redirect(302, 
'/dashboard'); +}; diff --git a/frontend/src/routes/admin/+page.svelte b/frontend/src/routes/admin/+page.svelte new file mode 100644 index 0000000..b5a56c1 --- /dev/null +++ b/frontend/src/routes/admin/+page.svelte @@ -0,0 +1,5 @@ + diff --git a/frontend/src/routes/admin/hosts/+page.svelte b/frontend/src/routes/admin/hosts/+page.svelte new file mode 100644 index 0000000..16c7476 --- /dev/null +++ b/frontend/src/routes/admin/hosts/+page.svelte @@ -0,0 +1,679 @@ + + +
+ + +
+ +
+
+
+

+ Hosts +

+

+ Platform and BYOC compute across all teams. +

+
+ {#if activeTab === 'platform'} + + {/if} +
+ + + {#if !loading && !error} +
+
+ {totalCount} + total +
+
+ + + + + {onlineCount} + online +
+ {#if pendingCount > 0} +
+ {pendingCount} + pending +
+ {/if} +
+ {/if} +
+ + +
+ {#each [['platform', 'Platform', platformHosts.length], ['byoc', 'BYOC', byocHosts.length]] as [id, label, count] (id)} + + {/each} +
+ + +
+ {#if loading} + {@render skeletonRows()} + {:else if error} +
+ {error} +
+ {:else if activeTab === 'platform'} + {@render hostsTable(platformHosts, false)} + {:else} + + {#if byocHosts.length === 0} + {@render emptyState('byoc')} + {:else} +
+ {#each byocGroups as group (group.teamId ?? '__none__')} + {@const groupPageHosts = byocPageHosts.filter(h => h.team_id === group.teamId || (group.teamId === null && !h.team_id))} + {#if groupPageHosts.length > 0} +
+
+ + {group.teamName} + + + {group.hosts.length} + +
+ {@render hostsTable(groupPageHosts, false)} +
+ {/if} + {/each} + + + {#if byocPageCount > 1} +
+ + Page {byocPage + 1} of {byocPageCount} · {byocHosts.length} hosts + +
+ + +
+
+ {/if} +
+ {/if} + {/if} +
+
+
+ +{#snippet skeletonRows()} +
+ + + + + + + + + + + + {#each Array(5) as _, i} + + + + + + + + {/each} + +
HostStatus
+
+
+
+
+
+
+
+
+{/snippet} + +{#snippet hostsTable(hosts: Host[], _showTeam: boolean)} + {#if hosts.length === 0} + {@render emptyState('platform')} + {:else} +
+ + + + + + + + + + + + {#each hosts as host (host.id)} + + + + + + + + {/each} + +
HostStatus
+
{host.id}
+ {#if host.address} +
{host.address}
+ {/if} + {#if host.provider || host.availability_zone} +
+ {[host.provider, host.availability_zone].filter(Boolean).join(' · ')} +
+ {/if} +
+ + {#if host.status === 'online'} + + + + + {:else} + + {/if} + {host.status} + + + +
+
+ {/if} +{/snippet} + +{#snippet emptyState(type: 'platform' | 'byoc')} +
+
+ +
+

+ {type === 'platform' ? 'No platform hosts yet.' : 'No BYOC hosts across any team.'} +

+

+ {type === 'platform' + ? 'Add a host to start scheduling capsules onto your own compute.' + : 'Teams that register their own compute will appear here.'} +

+
+{/snippet} + + +{#if showCreate} +
+
{ if (!creating) showCreate = false; }} + onkeydown={(e) => { if (e.key === 'Escape' && !creating) showCreate = false; }} + >
+
+

+ Add Platform Host +

+

+ Register a new platform-managed host. You'll receive a one-time registration token. +

+ + {#if createError} +
+ {createError} +
+ {/if} + +
+
+ + +
+
+ + +
+
+ +
+ + +
+
+
+{/if} + + +{#if createdResult} +
+
+
+ +
+ + + +
+ +

+ Host registered +

+

+ Pass this token to the host agent to complete registration. It expires in + 1 hour and is single-use. +

+ +
+
+ + {createdResult.registration_token} + + +
+
+ +
+ +

+ This token will not be shown again. Store it safely before closing. +

+
+ +
+ +
+
+
+{/if} + + +{#if deleteTarget} +
+
{ if (!deleting) deleteTarget = null; }} + onkeydown={(e) => { if (e.key === 'Escape' && !deleting) deleteTarget = null; }} + >
+
+

+ Delete Host +

+

+ Permanently remove {deleteTarget.id}. +

+ + {#if deletePreviewLoading} +
+ + Checking active capsules… +
+ {:else if deletePreviewSandboxes.length > 0} +
+

+ {deletePreviewSandboxes.length} active capsule{deletePreviewSandboxes.length === 1 ? '' : 's'} will be destroyed. +

+

+ All running workloads on this host will be terminated immediately. +

+
+ {/if} + + {#if deleteError} +
+ {deleteError} +
+ {/if} + +
+ + +
+
+
+{/if} + + diff --git a/frontend/src/routes/dashboard/byoc/+page.svelte b/frontend/src/routes/dashboard/byoc/+page.svelte new file mode 100644 index 0000000..acd682f --- /dev/null +++ b/frontend/src/routes/dashboard/byoc/+page.svelte @@ -0,0 +1,587 @@ + + +
+ + +
+ +
+
+
+

+ BYOC Hosts +

+

+ Your own compute, running Wrenn capsules. +

+
+ {#if canManage} + + {/if} +
+ + + {#if !loading && !error && hosts.length > 0} +
+
+ {hosts.length} + total +
+
+ + + + + {onlineCount} + online +
+
+ {/if} +
+ + +
+ {#if loading} + {@render skeletonRows()} + {:else if error} +
+ {error} +
+ {:else if hosts.length === 0} + {@render emptyState()} + {:else} +
+ + + + + + + + + {#if canManage} + + {/if} + + + + {#each hosts as host (host.id)} + + + + + + + {#if canManage} + + {/if} + + {/each} + +
HostStatus
+
{host.id}
+ {#if host.address} +
{host.address}
+ {/if} + {#if host.provider || host.availability_zone} +
+ {[host.provider, host.availability_zone].filter(Boolean).join(' · ')} +
+ {/if} +
+ + {#if host.status === 'online'} + + + + + {:else} + + {/if} + {host.status} + + + +
+
+ {/if} +
+
+
+ +{#snippet skeletonRows()} +
+ + + + + + + + + + + + {#each Array(4) as _, i} + + + + + + + + {/each} + +
HostStatus
+
+
+
+
+
+
+{/snippet} + +{#snippet emptyState()} +
+
+ +
+ {#if canManage} +

+ No hosts yet. +

+

+ Register a server and Wrenn will schedule capsules on your own infrastructure. +

+ + {:else} +

+ No hosts registered. +

+

+ Ask a team owner or admin to register a BYOC host for your team. +

+ {/if} +
+{/snippet} + + +{#if showCreate} +
+
{ if (!creating) showCreate = false; }} + onkeydown={(e) => { if (e.key === 'Escape' && !creating) showCreate = false; }} + >
+
+

+ Register Host +

+

+ Add a server to your team's BYOC pool. You'll receive a one-time registration token. +

+ + {#if createError} +
+ {createError} +
+ {/if} + +
+
+ + +
+
+ + +
+
+ +
+ + +
+
+
+{/if} + + +{#if createdResult} +
+
+
+ +
+ + + +
+ +

+ Host registered +

+

+ Pass this token to the host agent to complete registration. It expires in + 1 hour and is single-use. +

+ +
+
+ + {createdResult.registration_token} + + +
+
+ +
+ +

+ This token will not be shown again. Store it safely before closing. +

+
+ +
+ +
+
+
+{/if} + + +{#if deleteTarget} +
+
{ if (!deleting) deleteTarget = null; }} + onkeydown={(e) => { if (e.key === 'Escape' && !deleting) deleteTarget = null; }} + >
+
+

+ Delete Host +

+

+ Remove {deleteTarget.id} from your BYOC pool. +

+ + {#if deletePreviewLoading} +
+ + Checking active capsules… +
+ {:else if deletePreviewSandboxes.length > 0} +
+

+ {deletePreviewSandboxes.length} active capsule{deletePreviewSandboxes.length === 1 ? '' : 's'} will be destroyed. +

+

+ All running workloads on this host will be terminated immediately. +

+
+ {/if} + + {#if deleteError} +
+ {deleteError} +
+ {/if} + +
+ + +
+
+
+{/if} + + diff --git a/internal/api/agent_helper.go b/internal/api/agent_helper.go new file mode 100644 index 0000000..ac5b38e --- /dev/null +++ b/internal/api/agent_helper.go @@ -0,0 +1,20 @@ +package api + +import ( + "context" + "fmt" + + "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/lifecycle" + "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect" +) + +// agentForHost looks up the host record and returns a Connect RPC client for it. +// Returns an error if the host is not found or has no address. +func agentForHost(ctx context.Context, queries *db.Queries, pool *lifecycle.HostClientPool, hostID string) (hostagentv1connect.HostAgentServiceClient, error) { + host, err := queries.GetHost(ctx, hostID) + if err != nil { + return nil, fmt.Errorf("host not found: %w", err) + } + return pool.GetForHost(host) +} diff --git a/internal/api/handlers_auth.go b/internal/api/handlers_auth.go index ae63883..ba60d8e 100644 --- a/internal/api/handlers_auth.go +++ b/internal/api/handlers_auth.go @@ -168,7 +168,7 @@ func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) { return } - token, err := auth.SignJWT(h.jwtSecret, userID, teamID, req.Email, req.Name, "owner") + token, err := auth.SignJWT(h.jwtSecret, userID, teamID, req.Email, req.Name, "owner", false) if err != nil { writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token") return @@ -228,7 +228,7 @@ func (h *authHandler) Login(w http.ResponseWriter, r *http.Request) { return } - token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role) + token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, user.IsAdmin) if err != nil { writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token") return @@ -298,7 +298,7 @@ func (h *authHandler) SwitchTeam(w http.ResponseWriter, r *http.Request) { return } - token, err := auth.SignJWT(h.jwtSecret, 
ac.UserID, req.TeamID, ac.Email, user.Name, membership.Role) + token, err := auth.SignJWT(h.jwtSecret, ac.UserID, req.TeamID, ac.Email, user.Name, membership.Role, user.IsAdmin) if err != nil { writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token") return diff --git a/internal/api/handlers_exec.go b/internal/api/handlers_exec.go index 9307a67..84b3833 100644 --- a/internal/api/handlers_exec.go +++ b/internal/api/handlers_exec.go @@ -14,17 +14,17 @@ import ( "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/lifecycle" pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen" - "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect" ) type execHandler struct { - db *db.Queries - agent hostagentv1connect.HostAgentServiceClient + db *db.Queries + pool *lifecycle.HostClientPool } -func newExecHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *execHandler { - return &execHandler{db: db, agent: agent} +func newExecHandler(db *db.Queries, pool *lifecycle.HostClientPool) *execHandler { + return &execHandler{db: db, pool: pool} } type execRequest struct { @@ -73,7 +73,13 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) { start := time.Now() - resp, err := h.agent.Exec(ctx, connect.NewRequest(&pb.ExecRequest{ + agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID) + if err != nil { + writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable") + return + } + + resp, err := agent.Exec(ctx, connect.NewRequest(&pb.ExecRequest{ SandboxId: sandboxID, Cmd: req.Cmd, Args: req.Args, diff --git a/internal/api/handlers_exec_stream.go b/internal/api/handlers_exec_stream.go index 009f41b..3ecfdfe 100644 --- a/internal/api/handlers_exec_stream.go +++ b/internal/api/handlers_exec_stream.go @@ -14,17 +14,17 @@ import ( "git.omukk.dev/wrenn/sandbox/internal/auth" 
"git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/lifecycle" pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen" - "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect" ) type execStreamHandler struct { - db *db.Queries - agent hostagentv1connect.HostAgentServiceClient + db *db.Queries + pool *lifecycle.HostClientPool } -func newExecStreamHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *execStreamHandler { - return &execStreamHandler{db: db, agent: agent} +func newExecStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool) *execStreamHandler { + return &execStreamHandler{db: db, pool: pool} } var upgrader = websocket.Upgrader{ @@ -80,11 +80,17 @@ func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) { return } + agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID) + if err != nil { + sendWSError(conn, "sandbox host is not reachable") + return + } + // Open streaming exec to host agent. 
streamCtx, cancel := context.WithCancel(ctx) defer cancel() - stream, err := h.agent.ExecStream(streamCtx, connect.NewRequest(&pb.ExecStreamRequest{ + stream, err := agent.ExecStream(streamCtx, connect.NewRequest(&pb.ExecStreamRequest{ SandboxId: sandboxID, Cmd: startMsg.Cmd, Args: startMsg.Args, diff --git a/internal/api/handlers_files.go b/internal/api/handlers_files.go index c1c0291..c5fff70 100644 --- a/internal/api/handlers_files.go +++ b/internal/api/handlers_files.go @@ -11,17 +11,17 @@ import ( "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/lifecycle" pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen" - "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect" ) type filesHandler struct { - db *db.Queries - agent hostagentv1connect.HostAgentServiceClient + db *db.Queries + pool *lifecycle.HostClientPool } -func newFilesHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *filesHandler { - return &filesHandler{db: db, agent: agent} +func newFilesHandler(db *db.Queries, pool *lifecycle.HostClientPool) *filesHandler { + return &filesHandler{db: db, pool: pool} } // Upload handles POST /v1/sandboxes/{id}/files/write. 
@@ -75,7 +75,13 @@ func (h *filesHandler) Upload(w http.ResponseWriter, r *http.Request) { return } - if _, err := h.agent.WriteFile(ctx, connect.NewRequest(&pb.WriteFileRequest{ + agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID) + if err != nil { + writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable") + return + } + + if _, err := agent.WriteFile(ctx, connect.NewRequest(&pb.WriteFileRequest{ SandboxId: sandboxID, Path: filePath, Content: content, @@ -120,7 +126,13 @@ func (h *filesHandler) Download(w http.ResponseWriter, r *http.Request) { return } - resp, err := h.agent.ReadFile(ctx, connect.NewRequest(&pb.ReadFileRequest{ + agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID) + if err != nil { + writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable") + return + } + + resp, err := agent.ReadFile(ctx, connect.NewRequest(&pb.ReadFileRequest{ SandboxId: sandboxID, Path: req.Path, })) diff --git a/internal/api/handlers_files_stream.go b/internal/api/handlers_files_stream.go index 66a3c5b..66e89c7 100644 --- a/internal/api/handlers_files_stream.go +++ b/internal/api/handlers_files_stream.go @@ -12,17 +12,17 @@ import ( "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/lifecycle" pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen" - "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect" ) type filesStreamHandler struct { - db *db.Queries - agent hostagentv1connect.HostAgentServiceClient + db *db.Queries + pool *lifecycle.HostClientPool } -func newFilesStreamHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *filesStreamHandler { - return &filesStreamHandler{db: db, agent: agent} +func newFilesStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool) *filesStreamHandler { + return &filesStreamHandler{db: db, pool: pool} } // StreamUpload handles POST 
/v1/sandboxes/{id}/files/stream/write. @@ -88,8 +88,14 @@ func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request } defer filePart.Close() + agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID) + if err != nil { + writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable") + return + } + // Open client-streaming RPC to host agent. - stream := h.agent.WriteFileStream(ctx) + stream := agent.WriteFileStream(ctx) // Send metadata first. if err := stream.Send(&pb.WriteFileStreamRequest{ @@ -164,8 +170,14 @@ func (h *filesStreamHandler) StreamDownload(w http.ResponseWriter, r *http.Reque return } + agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID) + if err != nil { + writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable") + return + } + // Open server-streaming RPC to host agent. - stream, err := h.agent.ReadFileStream(ctx, connect.NewRequest(&pb.ReadFileStreamRequest{ + stream, err := agent.ReadFileStream(ctx, connect.NewRequest(&pb.ReadFileStreamRequest{ SandboxId: sandboxID, Path: req.Path, })) diff --git a/internal/api/handlers_hosts.go b/internal/api/handlers_hosts.go index a6484a3..762fc91 100644 --- a/internal/api/handlers_hosts.go +++ b/internal/api/handlers_hosts.go @@ -1,6 +1,7 @@ package api import ( + "errors" "net/http" "time" @@ -34,6 +35,25 @@ type createHostResponse struct { RegistrationToken string `json:"registration_token"` } +type refreshTokenRequest struct { + RefreshToken string `json:"refresh_token"` +} + +type refreshTokenResponse struct { + Host hostResponse `json:"host"` + Token string `json:"token"` + RefreshToken string `json:"refresh_token"` +} + +type deletePreviewResponse struct { + Host hostResponse `json:"host"` + SandboxIDs []string `json:"sandbox_ids"` +} + +type hasSandboxesErrorResponse struct { + SandboxIDs []string `json:"sandbox_ids"` +} + type registerHostRequest struct { Token string `json:"token"` Arch string 
`json:"arch,omitempty"` @@ -44,8 +64,9 @@ type registerHostRequest struct { } type registerHostResponse struct { - Host hostResponse `json:"host"` - Token string `json:"token"` + Host hostResponse `json:"host"` + Token string `json:"token"` + RefreshToken string `json:"refresh_token"` } type addTagRequest struct { @@ -56,6 +77,7 @@ type hostResponse struct { ID string `json:"id"` Type string `json:"type"` TeamID *string `json:"team_id,omitempty"` + TeamName *string `json:"team_name,omitempty"` Provider *string `json:"provider,omitempty"` AvailabilityZone *string `json:"availability_zone,omitempty"` Arch *string `json:"arch,omitempty"` @@ -153,16 +175,41 @@ func (h *hostHandler) Create(w http.ResponseWriter, r *http.Request) { // List handles GET /v1/hosts. func (h *hostHandler) List(w http.ResponseWriter, r *http.Request) { ac := auth.MustFromContext(r.Context()) + admin := h.isAdmin(r, ac.UserID) - hosts, err := h.svc.List(r.Context(), ac.TeamID, h.isAdmin(r, ac.UserID)) + hosts, err := h.svc.List(r.Context(), ac.TeamID, admin) if err != nil { writeError(w, http.StatusInternalServerError, "db_error", "failed to list hosts") return } + // Collect unique team IDs so each team's name is looked up only once. 
+ var teamNames map[string]string + if admin { + seen := make(map[string]struct{}) + for _, host := range hosts { + if host.TeamID.Valid { + seen[host.TeamID.String] = struct{}{} + } + } + if len(seen) > 0 { + teamNames = make(map[string]string, len(seen)) + for id := range seen { + if team, err := h.queries.GetTeam(r.Context(), id); err == nil { + teamNames[id] = team.Name + } + } + } + } + resp := make([]hostResponse, len(hosts)) for i, host := range hosts { resp[i] = hostToResponse(host) + if host.TeamID.Valid { + if name, ok := teamNames[host.TeamID.String]; ok { + resp[i].TeamName = &name + } + } } writeJSON(w, http.StatusOK, resp) @@ -183,18 +230,54 @@ func (h *hostHandler) Get(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusOK, hostToResponse(host)) } -// Delete handles DELETE /v1/hosts/{id}. -func (h *hostHandler) Delete(w http.ResponseWriter, r *http.Request) { +// DeletePreview handles GET /v1/hosts/{id}/delete-preview. +// Returns what would be affected without making changes, for confirmation UI. +func (h *hostHandler) DeletePreview(w http.ResponseWriter, r *http.Request) { hostID := chi.URLParam(r, "id") ac := auth.MustFromContext(r.Context()) - if err := h.svc.Delete(r.Context(), hostID, ac.UserID, ac.TeamID, h.isAdmin(r, ac.UserID)); err != nil { + preview, err := h.svc.DeletePreview(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID)) + if err != nil { status, code, msg := serviceErrToHTTP(err) writeError(w, status, code, msg) return } - w.WriteHeader(http.StatusNoContent) + writeJSON(w, http.StatusOK, deletePreviewResponse{ + Host: hostToResponse(preview.Host), + SandboxIDs: preview.SandboxIDs, + }) +} + +// Delete handles DELETE /v1/hosts/{id}. +// Without ?force=true: returns 409 with affected sandbox IDs if any are active. +// With ?force=true: gracefully stops all sandboxes then deletes the host. 
+func (h *hostHandler) Delete(w http.ResponseWriter, r *http.Request) { + hostID := chi.URLParam(r, "id") + ac := auth.MustFromContext(r.Context()) + force := r.URL.Query().Get("force") == "true" + + err := h.svc.Delete(r.Context(), hostID, ac.UserID, ac.TeamID, h.isAdmin(r, ac.UserID), force) + if err == nil { + w.WriteHeader(http.StatusNoContent) + return + } + + // Check if it's a "has running sandboxes" error and return a structured 409. + var hasSandboxes *service.HostHasSandboxesError + if errors.As(err, &hasSandboxes) { + writeJSON(w, http.StatusConflict, map[string]any{ + "error": map[string]any{ + "code": "has_active_sandboxes", + "message": "host has active sandboxes; use ?force=true to destroy them and delete the host", + "sandbox_ids": hasSandboxes.SandboxIDs, + }, + }) + return + } + + status, code, msg := serviceErrToHTTP(err) + writeError(w, status, code, msg) } // RegenerateToken handles POST /v1/hosts/{id}/token. @@ -247,8 +330,9 @@ func (h *hostHandler) Register(w http.ResponseWriter, r *http.Request) { } writeJSON(w, http.StatusCreated, registerHostResponse{ - Host: hostToResponse(result.Host), - Token: result.JWT, + Host: hostToResponse(result.Host), + Token: result.JWT, + RefreshToken: result.RefreshToken, }) } @@ -264,7 +348,8 @@ func (h *hostHandler) Heartbeat(w http.ResponseWriter, r *http.Request) { } if err := h.svc.Heartbeat(r.Context(), hc.HostID); err != nil { - writeError(w, http.StatusInternalServerError, "db_error", "failed to update heartbeat") + status, code, msg := serviceErrToHTTP(err) + writeError(w, status, code, msg) return } @@ -311,6 +396,33 @@ func (h *hostHandler) RemoveTag(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNoContent) } +// RefreshToken handles POST /v1/hosts/auth/refresh (unauthenticated). +// The host agent sends its refresh token to receive a new JWT and rotated refresh token. 
+func (h *hostHandler) RefreshToken(w http.ResponseWriter, r *http.Request) { + var req refreshTokenRequest + if err := decodeJSON(r, &req); err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body") + return + } + if req.RefreshToken == "" { + writeError(w, http.StatusBadRequest, "invalid_request", "refresh_token is required") + return + } + + result, err := h.svc.Refresh(r.Context(), req.RefreshToken) + if err != nil { + status, code, msg := serviceErrToHTTP(err) + writeError(w, status, code, msg) + return + } + + writeJSON(w, http.StatusOK, refreshTokenResponse{ + Host: hostToResponse(result.Host), + Token: result.JWT, + RefreshToken: result.RefreshToken, + }) +} + // ListTags handles GET /v1/hosts/{id}/tags. func (h *hostHandler) ListTags(w http.ResponseWriter, r *http.Request) { hostID := chi.URLParam(r, "id") diff --git a/internal/api/handlers_oauth.go b/internal/api/handlers_oauth.go index 1c72285..348dd85 100644 --- a/internal/api/handlers_oauth.go +++ b/internal/api/handlers_oauth.go @@ -156,7 +156,7 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) { redirectWithError(w, r, redirectBase, "db_error") return } - token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role) + token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, user.IsAdmin) if err != nil { slog.Error("oauth login: failed to sign jwt", "error", err) redirectWithError(w, r, redirectBase, "internal_error") @@ -255,7 +255,7 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) { return } - token, err := auth.SignJWT(h.jwtSecret, userID, teamID, email, profile.Name, "owner") + token, err := auth.SignJWT(h.jwtSecret, userID, teamID, email, profile.Name, "owner", false) if err != nil { slog.Error("oauth: failed to sign jwt", "error", err) redirectWithError(w, r, redirectBase, "internal_error") @@ -290,7 +290,7 @@ func (h *oauthHandler) retryAsLogin(w 
http.ResponseWriter, r *http.Request, prov redirectWithError(w, r, redirectBase, "db_error") return } - token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role) + token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, user.IsAdmin) if err != nil { slog.Error("oauth: retry login: failed to sign jwt", "error", err) redirectWithError(w, r, redirectBase, "internal_error") diff --git a/internal/api/handlers_snapshots.go b/internal/api/handlers_snapshots.go index f48539a..7d3e7fa 100644 --- a/internal/api/handlers_snapshots.go +++ b/internal/api/handlers_snapshots.go @@ -1,6 +1,7 @@ package api import ( + "context" "encoding/json" "fmt" "log/slog" @@ -14,20 +15,45 @@ import ( "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/db" "git.omukk.dev/wrenn/sandbox/internal/id" + "git.omukk.dev/wrenn/sandbox/internal/lifecycle" "git.omukk.dev/wrenn/sandbox/internal/service" "git.omukk.dev/wrenn/sandbox/internal/validate" pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen" - "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect" ) type snapshotHandler struct { - svc *service.TemplateService - db *db.Queries - agent hostagentv1connect.HostAgentServiceClient + svc *service.TemplateService + db *db.Queries + pool *lifecycle.HostClientPool } -func newSnapshotHandler(svc *service.TemplateService, db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *snapshotHandler { - return &snapshotHandler{svc: svc, db: db, agent: agent} +func newSnapshotHandler(svc *service.TemplateService, db *db.Queries, pool *lifecycle.HostClientPool) *snapshotHandler { + return &snapshotHandler{svc: svc, db: db, pool: pool} +} + +// deleteSnapshotBroadcast attempts to delete snapshot files on all online hosts. +// Snapshots aren't currently host-tracked in the DB, so we broadcast to all hosts +// and ignore NotFound errors. TODO: add host_id to templates table. 
+func (h *snapshotHandler) deleteSnapshotBroadcast(ctx context.Context, name string) error { + hosts, err := h.db.ListActiveHosts(ctx) + if err != nil { + return fmt.Errorf("list hosts: %w", err) + } + for _, host := range hosts { + if host.Status != "online" { + continue + } + agent, err := h.pool.GetForHost(host) + if err != nil { + continue + } + if _, err := agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{Name: name})); err != nil { + if connect.CodeOf(err) != connect.CodeNotFound { + slog.Warn("snapshot: failed to delete on host", "host_id", host.ID, "name", name, "error", err) + } + } + } + return nil } type createSnapshotRequest struct { @@ -93,10 +119,9 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) { writeError(w, http.StatusConflict, "already_exists", "snapshot name already exists; use ?overwrite=true to replace") return } - // Delete old files from the agent before removing the DB record. - if _, err := h.agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{Name: req.Name})); err != nil { - status, code, msg := agentErrToHTTP(err) - writeError(w, status, code, "failed to delete existing snapshot files: "+msg) + // Delete old snapshot files from all hosts before removing the DB record. 
+ if err := h.deleteSnapshotBroadcast(ctx, req.Name); err != nil { + writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete existing snapshot files") return } if err := h.db.DeleteTemplateByTeam(ctx, db.DeleteTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err != nil { @@ -116,7 +141,13 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) { return } - resp, err := h.agent.CreateSnapshot(ctx, connect.NewRequest(&pb.CreateSnapshotRequest{ + agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID) + if err != nil { + writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable") + return + } + + resp, err := agent.CreateSnapshot(ctx, connect.NewRequest(&pb.CreateSnapshotRequest{ SandboxId: req.SandboxID, Name: req.Name, })) @@ -186,11 +217,8 @@ func (h *snapshotHandler) Delete(w http.ResponseWriter, r *http.Request) { return } - if _, err := h.agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{ - Name: name, - })); err != nil { - status, code, msg := agentErrToHTTP(err) - writeError(w, status, code, "failed to delete snapshot files: "+msg) + if err := h.deleteSnapshotBroadcast(ctx, name); err != nil { + writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete snapshot files") return } diff --git a/internal/api/handlers_team.go b/internal/api/handlers_team.go index e852583..fcb5564 100644 --- a/internal/api/handlers_team.go +++ b/internal/api/handlers_team.go @@ -25,6 +25,7 @@ type teamResponse struct { ID string `json:"id"` Name string `json:"name"` Slug string `json:"slug"` + IsByoc bool `json:"is_byoc"` CreatedAt string `json:"created_at"` } @@ -44,9 +45,10 @@ type memberResponse struct { func teamToResponse(t db.Team) teamResponse { resp := teamResponse{ - ID: t.ID, - Name: t.Name, - Slug: t.Slug, + ID: t.ID, + Name: t.Name, + Slug: t.Slug, + IsByoc: t.IsByoc, } if t.CreatedAt.Valid { resp.CreatedAt = 
t.CreatedAt.Time.Format(time.RFC3339) @@ -321,3 +323,25 @@ func (h *teamHandler) Leave(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNoContent) } + +// SetBYOC handles PUT /v1/admin/teams/{id}/byoc (admin only). +// Enables or disables the BYOC feature flag for a team. +func (h *teamHandler) SetBYOC(w http.ResponseWriter, r *http.Request) { + teamID := chi.URLParam(r, "id") + + var req struct { + Enabled bool `json:"enabled"` + } + if err := decodeJSON(r, &req); err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body") + return + } + + if err := h.svc.SetBYOC(r.Context(), teamID, req.Enabled); err != nil { + status, code, msg := serviceErrToHTTP(err) + writeError(w, status, code, msg) + return + } + + w.WriteHeader(http.StatusNoContent) +} diff --git a/internal/api/host_monitor.go b/internal/api/host_monitor.go new file mode 100644 index 0000000..e2afca1 --- /dev/null +++ b/internal/api/host_monitor.go @@ -0,0 +1,202 @@ +package api + +import ( + "context" + "log/slog" + "time" + + "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/lifecycle" + pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen" + + "connectrpc.com/connect" +) + +// unreachableThreshold is how long a host can go without a heartbeat before +// it is considered unreachable (3 missed 30-second heartbeats). +const unreachableThreshold = 90 * time.Second + +// HostMonitor runs on a fixed interval and performs two duties: +// +// 1. Passive check: marks hosts whose last_heartbeat_at is stale as +// "unreachable" and marks their active sandboxes as "missing". +// +// 2. Active reconciliation: for each online host, calls ListSandboxes and +// reconciles DB state against live host state — restoring "missing" +// sandboxes that are actually alive, and stopping orphaned ones. +type HostMonitor struct { + db *db.Queries + pool *lifecycle.HostClientPool + interval time.Duration +} + +// NewHostMonitor creates a HostMonitor. 
+func NewHostMonitor(queries *db.Queries, pool *lifecycle.HostClientPool, interval time.Duration) *HostMonitor { + return &HostMonitor{ + db: queries, + pool: pool, + interval: interval, + } +} + +// Start runs the monitor loop until the context is cancelled. +func (m *HostMonitor) Start(ctx context.Context) { + go func() { + ticker := time.NewTicker(m.interval) + defer ticker.Stop() + + // Run immediately on startup so the CP doesn't wait one full interval + // before reconciling host and sandbox state. + m.run(ctx) + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + m.run(ctx) + } + } + }() +} + +func (m *HostMonitor) run(ctx context.Context) { + hosts, err := m.db.ListActiveHosts(ctx) + if err != nil { + slog.Warn("host monitor: failed to list hosts", "error", err) + return + } + + for _, host := range hosts { + m.checkHost(ctx, host) + } +} + +func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) { + // --- Passive phase: check heartbeat staleness --- + + stale := !host.LastHeartbeatAt.Valid || + time.Since(host.LastHeartbeatAt.Time) > unreachableThreshold + + if stale && host.Status != "unreachable" { + slog.Info("host monitor: marking host unreachable", "host_id", host.ID, + "last_heartbeat", host.LastHeartbeatAt.Time) + if err := m.db.MarkHostUnreachable(ctx, host.ID); err != nil { + slog.Warn("host monitor: failed to mark host unreachable", "host_id", host.ID, "error", err) + } + if err := m.db.MarkSandboxesMissingByHost(ctx, host.ID); err != nil { + slog.Warn("host monitor: failed to mark sandboxes missing", "host_id", host.ID, "error", err) + } + return + } + + // --- Active reconciliation: only for online hosts --- + + if host.Status != "online" { + return + } + + agent, err := m.pool.GetForHost(host) + if err != nil { + // Host has no address yet (e.g., just registered) — skip. 
+ return + } + + resp, err := agent.ListSandboxes(ctx, connect.NewRequest(&pb.ListSandboxesRequest{})) + if err != nil { + // RPC failure is a transient condition; the passive phase will catch it + // if heartbeats stop arriving. + slog.Debug("host monitor: ListSandboxes failed (transient)", "host_id", host.ID, "error", err) + return + } + + // Build set of sandbox IDs alive on the host. + alive := make(map[string]struct{}, len(resp.Msg.Sandboxes)) + for _, sb := range resp.Msg.Sandboxes { + alive[sb.SandboxId] = struct{}{} + } + + autoPaused := make(map[string]struct{}, len(resp.Msg.AutoPausedSandboxIds)) + for _, id := range resp.Msg.AutoPausedSandboxIds { + autoPaused[id] = struct{}{} + } + + // --- Restore sandboxes that are "missing" in DB but alive on host --- + // This handles the case where CP marked them missing due to a transient + // heartbeat gap, but the host was actually fine. + + missingSandboxes, err := m.db.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{ + HostID: host.ID, + Column2: []string{"missing"}, + }) + if err != nil { + slog.Warn("host monitor: failed to list missing sandboxes", "host_id", host.ID, "error", err) + } else { + var toRestore []string + var toStop []string + for _, sb := range missingSandboxes { + if _, ok := alive[sb.ID]; ok { + toRestore = append(toRestore, sb.ID) + } else { + toStop = append(toStop, sb.ID) + } + } + if len(toRestore) > 0 { + slog.Info("host monitor: restoring missing sandboxes", "host_id", host.ID, "count", len(toRestore)) + if err := m.db.BulkRestoreRunning(ctx, toRestore); err != nil { + slog.Warn("host monitor: failed to restore missing sandboxes", "host_id", host.ID, "error", err) + } + } + if len(toStop) > 0 { + slog.Info("host monitor: stopping confirmed-dead missing sandboxes", "host_id", host.ID, "count", len(toStop)) + if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{ + Column1: toStop, + Status: "stopped", + }); err != nil { + slog.Warn("host 
monitor: failed to stop missing sandboxes", "host_id", host.ID, "error", err) + } + } + } + + // --- Find running sandboxes in DB that are no longer alive on the host --- + + runningSandboxes, err := m.db.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{ + HostID: host.ID, + Column2: []string{"running"}, + }) + if err != nil { + slog.Warn("host monitor: failed to list running sandboxes", "host_id", host.ID, "error", err) + return + } + + var toPause, toStop []string + for _, sb := range runningSandboxes { + if _, ok := alive[sb.ID]; ok { + continue + } + if _, ok := autoPaused[sb.ID]; ok { + toPause = append(toPause, sb.ID) + } else { + toStop = append(toStop, sb.ID) + } + } + + if len(toPause) > 0 { + slog.Info("host monitor: marking auto-paused sandboxes", "host_id", host.ID, "count", len(toPause)) + if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{ + Column1: toPause, + Status: "paused", + }); err != nil { + slog.Warn("host monitor: failed to mark paused", "host_id", host.ID, "error", err) + } + } + if len(toStop) > 0 { + slog.Info("host monitor: marking orphaned sandboxes stopped", "host_id", host.ID, "count", len(toStop)) + if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{ + Column1: toStop, + Status: "stopped", + }); err != nil { + slog.Warn("host monitor: failed to mark stopped", "host_id", host.ID, "error", err) + } + } +} diff --git a/internal/api/middleware.go b/internal/api/middleware.go index b327dd6..6a56293 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -89,6 +89,8 @@ func serviceErrToHTTP(err error) (int, string, string) { return http.StatusConflict, "invalid_state", msg case strings.Contains(msg, "forbidden"): return http.StatusForbidden, "forbidden", msg + case strings.Contains(msg, "invalid or expired"): + return http.StatusUnauthorized, "unauthorized", msg case strings.Contains(msg, "invalid"): return http.StatusBadRequest, "invalid_request", msg 
default: diff --git a/internal/api/middleware_admin.go b/internal/api/middleware_admin.go new file mode 100644 index 0000000..0685896 --- /dev/null +++ b/internal/api/middleware_admin.go @@ -0,0 +1,30 @@ +package api + +import ( + "net/http" + + "git.omukk.dev/wrenn/sandbox/internal/auth" + "git.omukk.dev/wrenn/sandbox/internal/db" +) + +// requireAdmin validates that the authenticated user is a platform admin. +// Must run after requireJWT (depends on AuthContext being present). +// Re-validates against the DB — the JWT is_admin claim is for UI only; +// the DB is the source of truth for admin access. +func requireAdmin(queries *db.Queries) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ac, ok := auth.FromContext(r.Context()) + if !ok { + writeError(w, http.StatusUnauthorized, "unauthorized", "authentication required") + return + } + user, err := queries.GetUserByID(r.Context(), ac.UserID) + if err != nil || !user.IsAdmin { + writeError(w, http.StatusForbidden, "forbidden", "admin access required") + return + } + next.ServeHTTP(w, r) + }) + } +} diff --git a/internal/api/middleware_jwt.go b/internal/api/middleware_jwt.go index c0b17fa..96b1c68 100644 --- a/internal/api/middleware_jwt.go +++ b/internal/api/middleware_jwt.go @@ -26,12 +26,14 @@ func requireJWT(secret []byte) func(http.Handler) http.Handler { } ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{ - TeamID: claims.TeamID, - UserID: claims.Subject, - Email: claims.Email, - Name: claims.Name, - Role: claims.Role, + TeamID: claims.TeamID, + UserID: claims.Subject, + Email: claims.Email, + Name: claims.Name, + Role: claims.Role, + IsAdmin: claims.IsAdmin, }) + next.ServeHTTP(w, r.WithContext(ctx)) }) } diff --git a/internal/api/openapi.yaml b/internal/api/openapi.yaml index 5f8dff5..e46fabc 100644 --- a/internal/api/openapi.yaml +++ b/internal/api/openapi.yaml @@ -1193,8 +1193,16 @@ paths: 
security: - bearerAuth: [] description: | - Admins can delete any host. Team owners can delete BYOC hosts - belonging to their team. + Admins can delete any host. Team owners and admins can delete BYOC hosts + belonging to their team. Without `?force=true`, returns 409 if the host + has active sandboxes. With `?force=true`, destroys all sandboxes first. + parameters: + - name: force + in: query + required: false + schema: + type: boolean + description: If true, destroy all sandboxes on the host before deleting. responses: "204": description: Host deleted @@ -1204,6 +1212,12 @@ paths: application/json: schema: $ref: "#/components/schemas/Error" + "409": + description: Host has active sandboxes (only when force is not set) + content: + application/json: + schema: + $ref: "#/components/schemas/HostHasSandboxesError" /v1/hosts/{id}/token: parameters: @@ -1312,6 +1326,72 @@ paths: schema: $ref: "#/components/schemas/Error" + /v1/hosts/auth/refresh: + post: + summary: Refresh host JWT + operationId: refreshHostToken + tags: [hosts] + description: | + Exchanges a refresh token for a new JWT and rotated refresh token. + The old refresh token is immediately revoked. No authentication required — + the refresh token itself is the credential. 
+ requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/RefreshHostTokenRequest" + responses: + "200": + description: New JWT and rotated refresh token + content: + application/json: + schema: + $ref: "#/components/schemas/RefreshHostTokenResponse" + "401": + description: Invalid, expired, or revoked refresh token + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + /v1/hosts/{id}/delete-preview: + parameters: + - name: id + in: path + required: true + schema: + type: string + + get: + summary: Preview host deletion + operationId: getHostDeletePreview + tags: [hosts] + security: + - bearerAuth: [] + description: | + Returns the list of sandbox IDs that would be destroyed if the host + were deleted with `?force=true`. No state is modified. + responses: + "200": + description: Deletion preview + content: + application/json: + schema: + $ref: "#/components/schemas/HostDeletePreview" + "403": + description: Insufficient permissions + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "404": + description: Host not found + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + /v1/hosts/{id}/tags: parameters: - name: id @@ -1405,7 +1485,7 @@ components: type: apiKey in: header name: X-Host-Token - description: Long-lived host JWT returned from POST /v1/hosts/register. Valid for 1 year. + description: Host JWT returned from POST /v1/hosts/register or POST /v1/hosts/auth/refresh. Valid for 7 days. schemas: SignupRequest: @@ -1505,7 +1585,7 @@ components: type: string status: type: string - enum: [pending, running, paused, stopped, error] + enum: [pending, starting, running, paused, hibernated, stopped, missing, error] template: type: string vcpus: @@ -1661,7 +1741,10 @@ components: $ref: "#/components/schemas/Host" token: type: string - description: Long-lived host JWT for X-Host-Token header. Valid for 1 year. 
+ description: Host JWT for X-Host-Token header. Valid for 7 days. + refresh_token: + type: string + description: Refresh token for obtaining new JWTs. Valid for 60 days; rotated on each use. Host: type: object @@ -1697,7 +1780,7 @@ components: nullable: true status: type: string - enum: [pending, online, offline, draining] + enum: [pending, online, offline, draining, unreachable] last_heartbeat_at: type: string format: date-time @@ -1711,6 +1794,54 @@ components: type: string format: date-time + RefreshHostTokenRequest: + type: object + required: [refresh_token] + properties: + refresh_token: + type: string + description: Refresh token obtained from registration or a previous refresh. + + RefreshHostTokenResponse: + type: object + properties: + host: + $ref: "#/components/schemas/Host" + token: + type: string + description: New host JWT. Valid for 7 days. + refresh_token: + type: string + description: New refresh token. Valid for 60 days; old token is revoked. + + HostDeletePreview: + type: object + properties: + host: + $ref: "#/components/schemas/Host" + sandbox_ids: + type: array + items: + type: string + description: IDs of sandboxes that would be destroyed on force-delete. + + HostHasSandboxesError: + type: object + properties: + error: + type: object + properties: + code: + type: string + example: has_active_sandboxes + message: + type: string + sandbox_ids: + type: array + items: + type: string + description: IDs of active sandboxes blocking deletion. 
+ AddTagRequest: type: object required: [tag] diff --git a/internal/api/reconciler.go b/internal/api/reconciler.go deleted file mode 100644 index fcc2388..0000000 --- a/internal/api/reconciler.go +++ /dev/null @@ -1,126 +0,0 @@ -package api - -import ( - "context" - "log/slog" - "time" - - "connectrpc.com/connect" - - "git.omukk.dev/wrenn/sandbox/internal/db" - pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen" - "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect" -) - -// Reconciler periodically compares the host agent's sandbox list with the DB -// and marks sandboxes that no longer exist on the host as stopped. -type Reconciler struct { - db *db.Queries - agent hostagentv1connect.HostAgentServiceClient - hostID string - interval time.Duration -} - -// NewReconciler creates a new reconciler. -func NewReconciler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient, hostID string, interval time.Duration) *Reconciler { - return &Reconciler{ - db: db, - agent: agent, - hostID: hostID, - interval: interval, - } -} - -// Start runs the reconciliation loop until the context is cancelled. -func (rc *Reconciler) Start(ctx context.Context) { - go func() { - ticker := time.NewTicker(rc.interval) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - rc.reconcile(ctx) - } - } - }() -} - -func (rc *Reconciler) reconcile(ctx context.Context) { - // Single RPC returns both the running sandbox list and any IDs that - // were auto-paused by the TTL reaper since the last call. - resp, err := rc.agent.ListSandboxes(ctx, connect.NewRequest(&pb.ListSandboxesRequest{})) - if err != nil { - slog.Warn("reconciler: failed to list sandboxes from host agent", "error", err) - return - } - - // Build a set of sandbox IDs that are alive on the host. 
- alive := make(map[string]struct{}, len(resp.Msg.Sandboxes)) - for _, sb := range resp.Msg.Sandboxes { - alive[sb.SandboxId] = struct{}{} - } - - // Build auto-paused set from the same response. - autoPausedSet := make(map[string]struct{}, len(resp.Msg.AutoPausedSandboxIds)) - for _, id := range resp.Msg.AutoPausedSandboxIds { - autoPausedSet[id] = struct{}{} - } - - // Get all DB sandboxes for this host that are running. - // Paused sandboxes are excluded: they are expected to not exist on the - // host agent because pause = snapshot + destroy resources. - dbSandboxes, err := rc.db.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{ - HostID: rc.hostID, - Column2: []string{"running"}, - }) - if err != nil { - slog.Warn("reconciler: failed to list DB sandboxes", "error", err) - return - } - - // Find sandboxes in DB that are no longer on the host. - var stale []string - for _, sb := range dbSandboxes { - if _, ok := alive[sb.ID]; !ok { - stale = append(stale, sb.ID) - } - } - - if len(stale) == 0 { - return - } - - // Split stale sandboxes into those auto-paused by the TTL reaper vs - // those that crashed/were orphaned. 
- var toPause, toStop []string - for _, id := range stale { - if _, ok := autoPausedSet[id]; ok { - toPause = append(toPause, id) - } else { - toStop = append(toStop, id) - } - } - - if len(toPause) > 0 { - slog.Info("reconciler: marking auto-paused sandboxes", "count", len(toPause), "ids", toPause) - if err := rc.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{ - Column1: toPause, - Status: "paused", - }); err != nil { - slog.Warn("reconciler: failed to mark auto-paused sandboxes", "error", err) - } - } - - if len(toStop) > 0 { - slog.Info("reconciler: marking stale sandboxes as stopped", "count", len(toStop), "ids", toStop) - if err := rc.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{ - Column1: toStop, - Status: "stopped", - }); err != nil { - slog.Warn("reconciler: failed to update stale sandboxes", "error", err) - } - } -} diff --git a/internal/api/server.go b/internal/api/server.go index 46759a1..302aee3 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -11,8 +11,9 @@ import ( "git.omukk.dev/wrenn/sandbox/internal/auth/oauth" "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/lifecycle" + "git.omukk.dev/wrenn/sandbox/internal/scheduler" "git.omukk.dev/wrenn/sandbox/internal/service" - "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect" ) //go:embed openapi.yaml @@ -24,25 +25,34 @@ type Server struct { } // New constructs the chi router and registers all routes. -func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, pool *pgxpool.Pool, rdb *redis.Client, jwtSecret []byte, oauthRegistry *oauth.Registry, oauthRedirectURL string) *Server { +func New( + queries *db.Queries, + pool *lifecycle.HostClientPool, + sched scheduler.HostScheduler, + pgPool *pgxpool.Pool, + rdb *redis.Client, + jwtSecret []byte, + oauthRegistry *oauth.Registry, + oauthRedirectURL string, +) *Server { r := chi.NewRouter() r.Use(requestLogger()) // Shared service layer. 
- sandboxSvc := &service.SandboxService{DB: queries, Agent: agent} + sandboxSvc := &service.SandboxService{DB: queries, Pool: pool, Scheduler: sched} apiKeySvc := &service.APIKeyService{DB: queries} templateSvc := &service.TemplateService{DB: queries} - hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret} - teamSvc := &service.TeamService{DB: queries, Pool: pool, Agent: agent} + hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret, Pool: pool} + teamSvc := &service.TeamService{DB: queries, Pool: pgPool, HostPool: pool} sandbox := newSandboxHandler(sandboxSvc) - exec := newExecHandler(queries, agent) - execStream := newExecStreamHandler(queries, agent) - files := newFilesHandler(queries, agent) - filesStream := newFilesStreamHandler(queries, agent) - snapshots := newSnapshotHandler(templateSvc, queries, agent) - authH := newAuthHandler(queries, pool, jwtSecret) - oauthH := newOAuthHandler(queries, pool, jwtSecret, oauthRegistry, oauthRedirectURL) + exec := newExecHandler(queries, pool) + execStream := newExecStreamHandler(queries, pool) + files := newFilesHandler(queries, pool) + filesStream := newFilesStreamHandler(queries, pool) + snapshots := newSnapshotHandler(templateSvc, queries, pool) + authH := newAuthHandler(queries, pgPool, jwtSecret) + oauthH := newOAuthHandler(queries, pgPool, jwtSecret, oauthRegistry, oauthRedirectURL) apiKeys := newAPIKeyHandler(apiKeySvc) hostH := newHostHandler(hostSvc, queries) teamH := newTeamHandler(teamSvc) @@ -123,6 +133,9 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p // Unauthenticated: one-time registration token. r.Post("/register", hostH.Register) + // Unauthenticated: refresh token exchange. + r.Post("/auth/refresh", hostH.RefreshToken) + // Host-token-authenticated: heartbeat. 
r.With(requireHostToken(jwtSecret)).Post("/{id}/heartbeat", hostH.Heartbeat) @@ -134,6 +147,7 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p r.Route("/{id}", func(r chi.Router) { r.Get("/", hostH.Get) r.Delete("/", hostH.Delete) + r.Get("/delete-preview", hostH.DeletePreview) r.Post("/token", hostH.RegenerateToken) r.Get("/tags", hostH.ListTags) r.Post("/tags", hostH.AddTag) @@ -142,6 +156,13 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p }) }) + // Platform admin routes — require JWT + DB-validated admin status. + r.Route("/v1/admin", func(r chi.Router) { + r.Use(requireJWT(jwtSecret)) + r.Use(requireAdmin(queries)) + r.Put("/teams/{id}/byoc", teamH.SetBYOC) + }) + return &Server{router: r} } diff --git a/internal/auth/context.go b/internal/auth/context.go index 36dd06c..98db360 100644 --- a/internal/auth/context.go +++ b/internal/auth/context.go @@ -8,11 +8,12 @@ const authCtxKey contextKey = 0 // AuthContext is stamped into request context by auth middleware. type AuthContext struct { - TeamID string - UserID string // empty when authenticated via API key - Email string // empty when authenticated via API key - Name string // empty when authenticated via API key - Role string // owner, admin, or member; empty when authenticated via API key + TeamID string + UserID string // empty when authenticated via API key + Email string // empty when authenticated via API key + Name string // empty when authenticated via API key + Role string // owner, admin, or member; empty when authenticated via API key + IsAdmin bool // platform-level admin; always false when authenticated via API key } // WithAuthContext returns a new context with the given AuthContext. 
diff --git a/internal/auth/jwt.go b/internal/auth/jwt.go index eebba31..a40a032 100644 --- a/internal/auth/jwt.go +++ b/internal/auth/jwt.go @@ -8,26 +8,29 @@ import ( ) const jwtExpiry = 6 * time.Hour -const hostJWTExpiry = 8760 * time.Hour // 1 year +const hostJWTExpiry = 7 * 24 * time.Hour // 7 days; host refreshes via refresh token +const HostRefreshTokenExpiry = 60 * 24 * time.Hour // 60 days; exported for service layer // Claims are the JWT payload for user tokens. type Claims struct { - Type string `json:"typ,omitempty"` // empty for user tokens; used to reject host tokens - TeamID string `json:"team_id"` - Role string `json:"role"` // owner, admin, or member within TeamID - Email string `json:"email"` - Name string `json:"name"` + Type string `json:"typ,omitempty"` // empty for user tokens; used to reject host tokens + TeamID string `json:"team_id"` + Role string `json:"role"` // owner, admin, or member within TeamID + Email string `json:"email"` + Name string `json:"name"` + IsAdmin bool `json:"is_admin,omitempty"` // platform-level admin flag jwt.RegisteredClaims } // SignJWT signs a new 6-hour JWT for the given user. -func SignJWT(secret []byte, userID, teamID, email, name, role string) (string, error) { +func SignJWT(secret []byte, userID, teamID, email, name, role string, isAdmin bool) (string, error) { now := time.Now() claims := Claims{ - TeamID: teamID, - Role: role, - Email: email, - Name: name, + TeamID: teamID, + Role: role, + Email: email, + Name: name, + IsAdmin: isAdmin, RegisteredClaims: jwt.RegisteredClaims{ Subject: userID, IssuedAt: jwt.NewNumericDate(now), diff --git a/internal/config/config.go b/internal/config/config.go index b881afb..7ef0aa6 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -2,18 +2,16 @@ package config import ( "os" - "strings" "github.com/joho/godotenv" ) // Config holds the control plane configuration. 
type Config struct { - DatabaseURL string - RedisURL string - ListenAddr string - HostAgentAddr string - JWTSecret string + DatabaseURL string + RedisURL string + ListenAddr string + JWTSecret string OAuthGitHubClientID string OAuthGitHubClientSecret string @@ -27,25 +25,17 @@ func Load() Config { // Best-effort load — missing .env file is fine. _ = godotenv.Load() - cfg := Config{ - DatabaseURL: envOrDefault("DATABASE_URL", "postgres://wrenn:wrenn@localhost:5432/wrenn?sslmode=disable"), - RedisURL: envOrDefault("REDIS_URL", "redis://localhost:6379/0"), - ListenAddr: envOrDefault("CP_LISTEN_ADDR", ":8080"), - HostAgentAddr: envOrDefault("CP_HOST_AGENT_ADDR", "http://localhost:50051"), - JWTSecret: os.Getenv("JWT_SECRET"), + return Config{ + DatabaseURL: envOrDefault("DATABASE_URL", "postgres://wrenn:wrenn@localhost:5432/wrenn?sslmode=disable"), + RedisURL: envOrDefault("REDIS_URL", "redis://localhost:6379/0"), + ListenAddr: envOrDefault("CP_LISTEN_ADDR", ":8080"), + JWTSecret: os.Getenv("JWT_SECRET"), OAuthGitHubClientID: os.Getenv("OAUTH_GITHUB_CLIENT_ID"), OAuthGitHubClientSecret: os.Getenv("OAUTH_GITHUB_CLIENT_SECRET"), OAuthRedirectURL: envOrDefault("OAUTH_REDIRECT_URL", "https://app.wrenn.dev"), CPPublicURL: os.Getenv("CP_PUBLIC_URL"), } - - // Ensure the host agent address has a scheme. - if !strings.HasPrefix(cfg.HostAgentAddr, "http://") && !strings.HasPrefix(cfg.HostAgentAddr, "https://") { - cfg.HostAgentAddr = "http://" + cfg.HostAgentAddr - } - - return cfg } func envOrDefault(key, def string) string { diff --git a/internal/db/host_refresh_tokens.sql.go b/internal/db/host_refresh_tokens.sql.go new file mode 100644 index 0000000..d02a0e7 --- /dev/null +++ b/internal/db/host_refresh_tokens.sql.go @@ -0,0 +1,92 @@ +// Code generated by sqlc. DO NOT EDIT. 
+// versions: +// sqlc v1.30.0 +// source: host_refresh_tokens.sql + +package db + +import ( + "context" + + "github.com/jackc/pgx/v5/pgtype" +) + +const deleteExpiredHostRefreshTokens = `-- name: DeleteExpiredHostRefreshTokens :exec +DELETE FROM host_refresh_tokens +WHERE expires_at < NOW() OR revoked_at IS NOT NULL +` + +func (q *Queries) DeleteExpiredHostRefreshTokens(ctx context.Context) error { + _, err := q.db.Exec(ctx, deleteExpiredHostRefreshTokens) + return err +} + +const getHostRefreshTokenByHash = `-- name: GetHostRefreshTokenByHash :one +SELECT id, host_id, token_hash, expires_at, created_at, revoked_at FROM host_refresh_tokens +WHERE token_hash = $1 AND revoked_at IS NULL AND expires_at > NOW() +` + +func (q *Queries) GetHostRefreshTokenByHash(ctx context.Context, tokenHash string) (HostRefreshToken, error) { + row := q.db.QueryRow(ctx, getHostRefreshTokenByHash, tokenHash) + var i HostRefreshToken + err := row.Scan( + &i.ID, + &i.HostID, + &i.TokenHash, + &i.ExpiresAt, + &i.CreatedAt, + &i.RevokedAt, + ) + return i, err +} + +const insertHostRefreshToken = `-- name: InsertHostRefreshToken :one +INSERT INTO host_refresh_tokens (id, host_id, token_hash, expires_at) +VALUES ($1, $2, $3, $4) +RETURNING id, host_id, token_hash, expires_at, created_at, revoked_at +` + +type InsertHostRefreshTokenParams struct { + ID string `json:"id"` + HostID string `json:"host_id"` + TokenHash string `json:"token_hash"` + ExpiresAt pgtype.Timestamptz `json:"expires_at"` +} + +func (q *Queries) InsertHostRefreshToken(ctx context.Context, arg InsertHostRefreshTokenParams) (HostRefreshToken, error) { + row := q.db.QueryRow(ctx, insertHostRefreshToken, + arg.ID, + arg.HostID, + arg.TokenHash, + arg.ExpiresAt, + ) + var i HostRefreshToken + err := row.Scan( + &i.ID, + &i.HostID, + &i.TokenHash, + &i.ExpiresAt, + &i.CreatedAt, + &i.RevokedAt, + ) + return i, err +} + +const revokeHostRefreshToken = `-- name: RevokeHostRefreshToken :exec +UPDATE host_refresh_tokens SET 
revoked_at = NOW() WHERE id = $1 +` + +func (q *Queries) RevokeHostRefreshToken(ctx context.Context, id string) error { + _, err := q.db.Exec(ctx, revokeHostRefreshToken, id) + return err +} + +const revokeHostRefreshTokensByHost = `-- name: RevokeHostRefreshTokensByHost :exec +UPDATE host_refresh_tokens SET revoked_at = NOW() +WHERE host_id = $1 AND revoked_at IS NULL +` + +func (q *Queries) RevokeHostRefreshTokensByHost(ctx context.Context, hostID string) error { + _, err := q.db.Exec(ctx, revokeHostRefreshTokensByHost, hostID) + return err +} diff --git a/internal/db/hosts.sql.go b/internal/db/hosts.sql.go index ad15290..90d97ca 100644 --- a/internal/db/hosts.sql.go +++ b/internal/db/hosts.sql.go @@ -234,6 +234,50 @@ func (q *Queries) InsertHostToken(ctx context.Context, arg InsertHostTokenParams return i, err } +const listActiveHosts = `-- name: ListActiveHosts :many +SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE status NOT IN ('pending', 'offline') ORDER BY created_at +` + +// Returns all hosts that have completed registration (not pending/offline). 
+func (q *Queries) ListActiveHosts(ctx context.Context) ([]Host, error) { + rows, err := q.db.Query(ctx, listActiveHosts) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Host + for rows.Next() { + var i Host + if err := rows.Scan( + &i.ID, + &i.Type, + &i.TeamID, + &i.Provider, + &i.AvailabilityZone, + &i.Arch, + &i.CpuCores, + &i.MemoryMb, + &i.DiskGb, + &i.Address, + &i.Status, + &i.LastHeartbeatAt, + &i.Metadata, + &i.CreatedBy, + &i.CreatedAt, + &i.UpdatedAt, + &i.CertFingerprint, + &i.MtlsEnabled, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const listHosts = `-- name: ListHosts :many SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts ORDER BY created_at DESC ` @@ -461,6 +505,15 @@ func (q *Queries) MarkHostTokenUsed(ctx context.Context, id string) error { return err } +const markHostUnreachable = `-- name: MarkHostUnreachable :exec +UPDATE hosts SET status = 'unreachable', updated_at = NOW() WHERE id = $1 +` + +func (q *Queries) MarkHostUnreachable(ctx context.Context, id string) error { + _, err := q.db.Exec(ctx, markHostUnreachable, id) + return err +} + const registerHost = `-- name: RegisterHost :execrows UPDATE hosts SET arch = $2, @@ -521,6 +574,21 @@ func (q *Queries) UpdateHostHeartbeat(ctx context.Context, id string) error { return err } +const updateHostHeartbeatAndStatus = `-- name: UpdateHostHeartbeatAndStatus :execrows +UPDATE hosts +SET last_heartbeat_at = NOW(), + status = CASE WHEN status = 'unreachable' THEN 'online' ELSE status END, + updated_at = NOW() +WHERE id = $1 +` + +// Updates last_heartbeat_at and transitions unreachable hosts back to online. +// Returns 0 if no host was found (deleted). 
+func (q *Queries) UpdateHostHeartbeatAndStatus(ctx context.Context, id string) (int64, error) { + result, err := q.db.Exec(ctx, updateHostHeartbeatAndStatus, id) + return result.RowsAffected(), err +} + const updateHostStatus = `-- name: UpdateHostStatus :exec UPDATE hosts SET status = $2, updated_at = NOW() WHERE id = $1 ` diff --git a/internal/db/models.go b/internal/db/models.go index 158193b..546c1da 100644 --- a/internal/db/models.go +++ b/internal/db/models.go @@ -36,6 +36,15 @@ type Host struct { MtlsEnabled bool `json:"mtls_enabled"` } +type HostRefreshToken struct { + ID string `json:"id"` + HostID string `json:"host_id"` + TokenHash string `json:"token_hash"` + ExpiresAt pgtype.Timestamptz `json:"expires_at"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + RevokedAt pgtype.Timestamptz `json:"revoked_at"` +} + type HostTag struct { HostID string `json:"host_id"` Tag string `json:"tag"` diff --git a/internal/db/sandboxes.sql.go b/internal/db/sandboxes.sql.go index cf39a14..620f77e 100644 --- a/internal/db/sandboxes.sql.go +++ b/internal/db/sandboxes.sql.go @@ -11,6 +11,20 @@ import ( "github.com/jackc/pgx/v5/pgtype" ) +const bulkRestoreRunning = `-- name: BulkRestoreRunning :exec +UPDATE sandboxes +SET status = 'running', + last_updated = NOW() +WHERE id = ANY($1::text[]) AND status = 'missing' +` + +// Called by the reconciler when a host comes back online and its sandboxes are +// confirmed alive. Restores only sandboxes that are in 'missing' state. 
+func (q *Queries) BulkRestoreRunning(ctx context.Context, dollar_1 []string) error { + _, err := q.db.Exec(ctx, bulkRestoreRunning, dollar_1) + return err +} + const bulkUpdateStatusByIDs = `-- name: BulkUpdateStatusByIDs :exec UPDATE sandboxes SET status = $2, @@ -300,6 +314,21 @@ func (q *Queries) ListSandboxesByTeam(ctx context.Context, teamID string) ([]San return items, nil } +const markSandboxesMissingByHost = `-- name: MarkSandboxesMissingByHost :exec +UPDATE sandboxes +SET status = 'missing', + last_updated = NOW() +WHERE host_id = $1 AND status IN ('running', 'starting', 'pending') +` + +// Called when the host monitor marks a host unreachable. +// Marks running/starting/pending sandboxes on that host as 'missing' so users see +// the sandbox is not currently reachable, without permanently losing the record. +func (q *Queries) MarkSandboxesMissingByHost(ctx context.Context, hostID string) error { + _, err := q.db.Exec(ctx, markSandboxesMissingByHost, hostID) + return err +} + const updateLastActive = `-- name: UpdateLastActive :exec UPDATE sandboxes SET last_active_at = $2, diff --git a/internal/hostagent/registration.go b/internal/hostagent/registration.go index fc74d55..9f39c3b 100644 --- a/internal/hostagent/registration.go +++ b/internal/hostagent/registration.go @@ -17,6 +17,13 @@ import ( "golang.org/x/sys/unix" ) +// tokenFile is the JSON format persisted to AGENT_FILES_ROOTDIR/host.jwt. +type tokenFile struct { + HostID string `json:"host_id"` + JWT string `json:"jwt"` + RefreshToken string `json:"refresh_token"` +} + // RegistrationConfig holds the configuration for host registration. 
type RegistrationConfig struct { CPURL string // Control plane base URL (e.g., http://localhost:8000) @@ -35,8 +42,19 @@ type registerRequest struct { } type registerResponse struct { - Host json.RawMessage `json:"host"` - Token string `json:"token"` + Host json.RawMessage `json:"host"` + Token string `json:"token"` + RefreshToken string `json:"refresh_token"` +} + +type refreshRequest struct { + RefreshToken string `json:"refresh_token"` +} + +type refreshResponse struct { + Host json.RawMessage `json:"host"` + Token string `json:"token"` + RefreshToken string `json:"refresh_token"` } type errorResponse struct { @@ -46,20 +64,47 @@ type errorResponse struct { } `json:"error"` } -// Register calls the control plane to register this host agent and persists -// the returned JWT to disk. Returns the host JWT token string. -func Register(ctx context.Context, cfg RegistrationConfig) (string, error) { - // Check if we already have a saved token. - if data, err := os.ReadFile(cfg.TokenFile); err == nil { - token := strings.TrimSpace(string(data)) - if token != "" { - slog.Info("loaded existing host token", "file", cfg.TokenFile) - return token, nil - } +// loadTokenFile reads and parses the persisted token file. +func loadTokenFile(path string) (*tokenFile, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, err } + // Support legacy format (raw JWT string) for backwards compatibility. + trimmed := strings.TrimSpace(string(data)) + if !strings.HasPrefix(trimmed, "{") { + // Old format: just the JWT, no refresh token. + hostID, _ := hostIDFromJWT(trimmed) + return &tokenFile{HostID: hostID, JWT: trimmed}, nil + } + var tf tokenFile + if err := json.Unmarshal(data, &tf); err != nil { + return nil, fmt.Errorf("parse token file: %w", err) + } + return &tf, nil +} +// saveTokenFile writes the token file as JSON with 0600 permissions. 
+func saveTokenFile(path string, tf tokenFile) error { + data, err := json.MarshalIndent(tf, "", " ") + if err != nil { + return fmt.Errorf("marshal token file: %w", err) + } + return os.WriteFile(path, data, 0600) +} + +// Register calls the control plane to register this host agent and persists +// the returned JWT and refresh token to disk. Returns the host JWT token string. +func Register(ctx context.Context, cfg RegistrationConfig) (string, error) { + // If no explicit registration token was given, reuse the saved JWT. + // A --register flag always overrides the local file so operators can + // force re-registration without manually deleting host.jwt. if cfg.RegistrationToken == "" { - return "", fmt.Errorf("no saved host token and no registration token provided") + if tf, err := loadTokenFile(cfg.TokenFile); err == nil && tf.JWT != "" { + slog.Info("loaded existing host token", "file", cfg.TokenFile, "host_id", tf.HostID) + return tf.JWT, nil + } + return "", fmt.Errorf("no saved host token and no registration token provided (use --register flag)") } arch := runtime.GOARCH @@ -117,46 +162,182 @@ func Register(ctx context.Context, cfg RegistrationConfig) (string, error) { return "", fmt.Errorf("registration response missing token") } - // Persist the token to disk for subsequent startups. - if err := os.WriteFile(cfg.TokenFile, []byte(regResp.Token), 0600); err != nil { + hostID, err := hostIDFromJWT(regResp.Token) + if err != nil { + return "", fmt.Errorf("extract host ID from JWT: %w", err) + } + + // Persist JWT + refresh token. 
+ tf := tokenFile{ + HostID: hostID, + JWT: regResp.Token, + RefreshToken: regResp.RefreshToken, + } + if err := saveTokenFile(cfg.TokenFile, tf); err != nil { return "", fmt.Errorf("save host token: %w", err) } - slog.Info("host registered and token saved", "file", cfg.TokenFile) + slog.Info("host registered and token saved", "file", cfg.TokenFile, "host_id", hostID) return regResp.Token, nil } +// RefreshJWT exchanges the refresh token for a new JWT + rotated refresh token. +// It reads and updates the token file in place. +func RefreshJWT(ctx context.Context, cpURL, tokenFilePath string) (string, error) { + tf, err := loadTokenFile(tokenFilePath) + if err != nil { + return "", fmt.Errorf("load token file: %w", err) + } + if tf.RefreshToken == "" { + return "", fmt.Errorf("no refresh token available; host must re-register") + } + + body, _ := json.Marshal(refreshRequest{RefreshToken: tf.RefreshToken}) + url := strings.TrimRight(cpURL, "/") + "/v1/hosts/auth/refresh" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return "", fmt.Errorf("create refresh request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: 15 * time.Second} + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("refresh request failed: %w", err) + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + var errResp errorResponse + if json.Unmarshal(respBody, &errResp) == nil { + return "", fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, errResp.Error.Message) + } + return "", fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, string(respBody)) + } + + var refResp refreshResponse + if err := json.Unmarshal(respBody, &refResp); err != nil { + return "", fmt.Errorf("parse refresh response: %w", err) + } + + tf.JWT = refResp.Token + tf.RefreshToken = refResp.RefreshToken + if err := 
saveTokenFile(tokenFilePath, *tf); err != nil { + return "", fmt.Errorf("save refreshed token: %w", err) + } + + slog.Info("host JWT refreshed", "host_id", tf.HostID) + return refResp.Token, nil +} + // StartHeartbeat launches a background goroutine that sends periodic heartbeats // to the control plane. It runs until the context is cancelled. -func StartHeartbeat(ctx context.Context, cpURL, hostID, hostToken string, interval time.Duration) { - url := strings.TrimRight(cpURL, "/") + "/v1/hosts/" + hostID + "/heartbeat" +// +// On 401/403: the heartbeat loop attempts to refresh the JWT. If the refresh +// also fails (expired refresh token), it calls pauseAll and stops. +// +// On repeated network failures (3 consecutive), it calls pauseAll but keeps +// retrying — the connection may recover and the host should resume heartbeating. +// +// onDeleted is called when CP returns 404, meaning this host record was deleted. +// The token file is removed before calling onDeleted so subsequent starts prompt +// for a new registration token. +func StartHeartbeat(ctx context.Context, cpURL, tokenFilePath, hostID string, interval time.Duration, pauseAll func(), onDeleted func()) { client := &http.Client{Timeout: 10 * time.Second} go func() { ticker := time.NewTicker(interval) defer ticker.Stop() + consecutiveFailures := 0 + pausedDueToFailure := false + currentJWT := "" + + // Load the current JWT from disk. + if tf, err := loadTokenFile(tokenFilePath); err == nil { + currentJWT = tf.JWT + } + + // beat sends one heartbeat. Returns true if the loop should stop. 
+ beat := func() (stop bool) { + url := strings.TrimRight(cpURL, "/") + "/v1/hosts/" + hostID + "/heartbeat" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil) + if err != nil { + slog.Warn("heartbeat: failed to create request", "error", err) + return false + } + req.Header.Set("X-Host-Token", currentJWT) + + resp, err := client.Do(req) + if err != nil { + consecutiveFailures++ + slog.Warn("heartbeat: request failed", "error", err, "consecutive_failures", consecutiveFailures) + if consecutiveFailures >= 3 && !pausedDueToFailure { + slog.Error("heartbeat: CP unreachable after 3 failures — pausing all sandboxes") + if pauseAll != nil { + pauseAll() + } + pausedDueToFailure = true + } + return false + } + resp.Body.Close() + + switch resp.StatusCode { + case http.StatusNoContent: + if consecutiveFailures > 0 || pausedDueToFailure { + slog.Info("heartbeat: CP connection restored") + } + consecutiveFailures = 0 + pausedDueToFailure = false + + case http.StatusUnauthorized, http.StatusForbidden: + slog.Warn("heartbeat: JWT rejected — attempting token refresh") + newJWT, refreshErr := RefreshJWT(ctx, cpURL, tokenFilePath) + if refreshErr != nil { + slog.Error("heartbeat: JWT refresh failed — pausing all sandboxes; manual re-registration required", + "error", refreshErr) + if pauseAll != nil && !pausedDueToFailure { + pauseAll() + pausedDueToFailure = true + } + // Stop the heartbeat loop — operator must re-register. 
+ return true + } + currentJWT = newJWT + slog.Info("heartbeat: JWT refreshed successfully") + + case http.StatusNotFound: + slog.Error("heartbeat: host no longer exists in CP — host was deleted; removing token file and exiting") + if err := os.Remove(tokenFilePath); err != nil && !os.IsNotExist(err) { + slog.Warn("heartbeat: failed to remove token file", "error", err) + } + if onDeleted != nil { + onDeleted() + } + return true + + default: + slog.Warn("heartbeat: unexpected status", "status", resp.StatusCode) + } + return false + } + + // Send an immediate heartbeat on startup so the CP sees the host as + // online without waiting for the first ticker tick. + if beat() { + return + } + for { select { case <-ctx.Done(): return case <-ticker.C: - req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil) - if err != nil { - slog.Warn("heartbeat: failed to create request", "error", err) - continue - } - req.Header.Set("X-Host-Token", hostToken) - - resp, err := client.Do(req) - if err != nil { - slog.Warn("heartbeat: request failed", "error", err) - continue - } - resp.Body.Close() - - if resp.StatusCode != http.StatusNoContent { - slog.Warn("heartbeat: unexpected status", "status", resp.StatusCode) + if beat() { + return } } } @@ -166,6 +347,12 @@ func StartHeartbeat(ctx context.Context, cpURL, hostID, hostToken string, interv // HostIDFromToken extracts the host_id claim from a host JWT without // verifying the signature (the agent doesn't have the signing secret). func HostIDFromToken(token string) (string, error) { + return hostIDFromJWT(token) +} + +// hostIDFromJWT is the internal implementation used by both HostIDFromToken and +// the token file loader. 
+func hostIDFromJWT(token string) (string, error) { parts := strings.Split(token, ".") if len(parts) != 3 { return "", fmt.Errorf("invalid JWT format") diff --git a/internal/hostagent/server.go b/internal/hostagent/server.go index d545e59..c0a4cfd 100644 --- a/internal/hostagent/server.go +++ b/internal/hostagent/server.go @@ -22,12 +22,15 @@ import ( // Server implements the HostAgentService Connect RPC handler. type Server struct { hostagentv1connect.UnimplementedHostAgentServiceHandler - mgr *sandbox.Manager + mgr *sandbox.Manager + terminate func() // called when the CP requests agent termination } // NewServer creates a new host agent RPC server. -func NewServer(mgr *sandbox.Manager) *Server { - return &Server{mgr: mgr} +// terminate is invoked (in a goroutine) when the CP calls the Terminate RPC, +// allowing main to perform a clean shutdown. +func NewServer(mgr *sandbox.Manager, terminate func()) *Server { + return &Server{mgr: mgr, terminate: terminate} } func (s *Server) CreateSandbox( @@ -412,3 +415,14 @@ func (s *Server) ListSandboxes( AutoPausedSandboxIds: s.mgr.DrainAutoPausedIDs(), }), nil } + +func (s *Server) Terminate( + _ context.Context, + _ *connect.Request[pb.TerminateRequest], +) (*connect.Response[pb.TerminateResponse], error) { + slog.Info("terminate RPC received — scheduling shutdown") + if s.terminate != nil { + go s.terminate() + } + return connect.NewResponse(&pb.TerminateResponse{}), nil +} diff --git a/internal/id/id.go b/internal/id/id.go index 362096f..04a2506 100644 --- a/internal/id/id.go +++ b/internal/id/id.go @@ -67,3 +67,17 @@ func NewRegistrationToken() string { } return hex.EncodeToString(b) } + +// NewRefreshTokenID generates a new refresh token record ID in the format "hrt-" + 8 hex chars. +func NewRefreshTokenID() string { + return "hrt-" + hex8() +} + +// NewRefreshToken generates a 64-char hex token (32 bytes of entropy) for use as a host refresh token. 
+func NewRefreshToken() string { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + panic(fmt.Sprintf("crypto/rand failed: %v", err)) + } + return hex.EncodeToString(b) +} diff --git a/internal/lifecycle/hostpool.go b/internal/lifecycle/hostpool.go new file mode 100644 index 0000000..0caf5ec --- /dev/null +++ b/internal/lifecycle/hostpool.go @@ -0,0 +1,77 @@ +package lifecycle + +import ( + "fmt" + "net/http" + "strings" + "sync" + "time" + + "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect" +) + +// HostClientPool maintains a cache of Connect RPC clients keyed by host ID. +// Clients are created lazily on first access and evicted when a host is removed +// or goes unreachable. The pool is safe for concurrent use. +type HostClientPool struct { + mu sync.RWMutex + clients map[string]hostagentv1connect.HostAgentServiceClient + httpClient *http.Client +} + +// NewHostClientPool creates a new pool. The underlying HTTP client uses a +// 10-minute timeout to support long-running streaming operations. +func NewHostClientPool() *HostClientPool { + return &HostClientPool{ + clients: make(map[string]hostagentv1connect.HostAgentServiceClient), + httpClient: &http.Client{Timeout: 10 * time.Minute}, + } +} + +// Get returns a Connect RPC client for the given host, creating one if necessary. +// address is the host agent address (ip:port or full URL). The scheme is added if absent. +func (p *HostClientPool) Get(hostID, address string) hostagentv1connect.HostAgentServiceClient { + p.mu.RLock() + c, ok := p.clients[hostID] + p.mu.RUnlock() + if ok { + return c + } + + p.mu.Lock() + defer p.mu.Unlock() + // Double-check after acquiring write lock. 
+ if c, ok = p.clients[hostID]; ok { + return c + } + c = hostagentv1connect.NewHostAgentServiceClient(p.httpClient, ensureScheme(address)) + p.clients[hostID] = c + return c +} + +// GetForHost is a convenience wrapper that extracts the address from a db.Host +// and returns an error if the host has no address recorded yet. +func (p *HostClientPool) GetForHost(h db.Host) (hostagentv1connect.HostAgentServiceClient, error) { + if !h.Address.Valid || h.Address.String == "" { + return nil, fmt.Errorf("host %s has no address", h.ID) + } + return p.Get(h.ID, h.Address.String), nil +} + +// Evict removes the cached client for the given host, forcing a new client to be +// created on the next call to Get. Call this when a host's address changes or when +// a host is deleted. +func (p *HostClientPool) Evict(hostID string) { + p.mu.Lock() + delete(p.clients, hostID) + p.mu.Unlock() +} + +// ensureScheme adds "http://" if the address has no scheme. +func ensureScheme(addr string) string { + if strings.HasPrefix(addr, "http://") || strings.HasPrefix(addr, "https://") { + return addr + } + return "http://" + addr +} diff --git a/internal/sandbox/manager.go b/internal/sandbox/manager.go index ef4e092..bf7d057 100644 --- a/internal/sandbox/manager.go +++ b/internal/sandbox/manager.go @@ -1183,6 +1183,28 @@ func (m *Manager) Shutdown(ctx context.Context) { m.loops.ReleaseAll() } +// PauseAll pauses every running sandbox managed by this host agent. +// Called when the host loses connectivity to the control plane to avoid +// leaving running VMs unmanaged. It is best-effort: failures for individual +// sandboxes are logged but do not stop the rest. 
+func (m *Manager) PauseAll(ctx context.Context) { + m.mu.RLock() + ids := make([]string, 0, len(m.boxes)) + for id, sb := range m.boxes { + if sb.Status == models.StatusRunning { + ids = append(ids, id) + } + } + m.mu.RUnlock() + + slog.Info("pausing all running sandboxes due to CP connection loss", "count", len(ids)) + for _, sbID := range ids { + if err := m.Pause(ctx, sbID); err != nil { + slog.Warn("PauseAll: failed to pause sandbox", "id", sbID, "error", err) + } + } +} + // warnErr logs a warning if err is non-nil. Used for best-effort cleanup // in error paths where the primary error has already been captured. func warnErr(msg string, id string, err error) { diff --git a/internal/scheduler/round_robin.go b/internal/scheduler/round_robin.go new file mode 100644 index 0000000..31433a0 --- /dev/null +++ b/internal/scheduler/round_robin.go @@ -0,0 +1,69 @@ +package scheduler + +import ( + "context" + "fmt" + "sync/atomic" + + "git.omukk.dev/wrenn/sandbox/internal/db" +) + +// HostScheduler selects a host for a new sandbox. Implementations may use +// different strategies (round-robin, least-loaded, tag-based, etc.). +type HostScheduler interface { + // SelectHost returns a host that can accept a new sandbox. + // For BYOC teams (isByoc=true), only online BYOC hosts belonging to teamID + // are considered. For non-BYOC teams, only online regular (platform) hosts + // are considered. Returns an error if no suitable host is available. + SelectHost(ctx context.Context, teamID string, isByoc bool) (db.Host, error) +} + +// RoundRobinScheduler cycles through eligible online hosts in round-robin order. +// It re-fetches the host list on every call so that newly registered or +// recovered hosts are considered immediately. +type RoundRobinScheduler struct { + db *db.Queries + counter atomic.Int64 +} + +// NewRoundRobinScheduler creates a RoundRobinScheduler backed by the given DB. 
+func NewRoundRobinScheduler(queries *db.Queries) *RoundRobinScheduler { + return &RoundRobinScheduler{db: queries} +} + +// SelectHost returns the next eligible online host in round-robin order. +func (s *RoundRobinScheduler) SelectHost(ctx context.Context, teamID string, isByoc bool) (db.Host, error) { + hosts, err := s.db.ListActiveHosts(ctx) + if err != nil { + return db.Host{}, fmt.Errorf("list hosts: %w", err) + } + + var eligible []db.Host + for _, h := range hosts { + if h.Status != "online" || !h.Address.Valid || h.Address.String == "" { + continue + } + if isByoc { + // BYOC team: only use hosts belonging to this team. + if h.Type != "byoc" || !h.TeamID.Valid || h.TeamID.String != teamID { + continue + } + } else { + // Non-BYOC team: only use platform (regular) hosts. + if h.Type != "regular" { + continue + } + } + eligible = append(eligible, h) + } + + if len(eligible) == 0 { + if isByoc { + return db.Host{}, fmt.Errorf("no online BYOC hosts available for team") + } + return db.Host{}, fmt.Errorf("no online platform hosts available") + } + + idx := s.counter.Add(1) - 1 + return eligible[int(idx%int64(len(eligible)))], nil +} diff --git a/internal/service/host.go b/internal/service/host.go index a331a58..b3538df 100644 --- a/internal/service/host.go +++ b/internal/service/host.go @@ -2,12 +2,14 @@ package service import ( "context" + "crypto/sha256" "encoding/json" "errors" "fmt" "log/slog" "time" + "connectrpc.com/connect" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" "github.com/redis/go-redis/v9" @@ -15,6 +17,8 @@ import ( "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/db" "git.omukk.dev/wrenn/sandbox/internal/id" + "git.omukk.dev/wrenn/sandbox/internal/lifecycle" + pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen" ) // HostService provides host management operations. 
@@ -22,6 +26,7 @@ type HostService struct { DB *db.Queries Redis *redis.Client JWT []byte + Pool *lifecycle.HostClientPool } // HostCreateParams holds the parameters for creating a host. @@ -50,10 +55,24 @@ type HostRegisterParams struct { Address string } -// HostRegisterResult holds the registered host and its long-lived JWT. +// HostRegisterResult holds the registered host, its short-lived JWT, and a long-lived refresh token. type HostRegisterResult struct { - Host db.Host - JWT string + Host db.Host + JWT string + RefreshToken string +} + +// HostRefreshResult holds a new JWT and rotated refresh token after a successful refresh. +type HostRefreshResult struct { + Host db.Host + JWT string + RefreshToken string +} + +// HostDeletePreview describes what will be affected by deleting a host. +type HostDeletePreview struct { + Host db.Host + SandboxIDs []string } // regTokenPayload is the JSON stored in Redis for registration tokens. @@ -64,6 +83,14 @@ type regTokenPayload struct { const regTokenTTL = time.Hour +// requireAdminOrOwner returns nil iff the role is "owner" or "admin". +func requireAdminOrOwner(role string) error { + if role == "owner" || role == "admin" { + return nil + } + return fmt.Errorf("forbidden: only team owners and admins can manage BYOC hosts") +} + // Create creates a new host record and generates a one-time registration token. func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreateResult, error) { if p.Type != "regular" && p.Type != "byoc" { @@ -75,7 +102,7 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat return HostCreateResult{}, fmt.Errorf("forbidden: only admins can create regular hosts") } } else { - // BYOC: admin or team owner. + // BYOC: platform admin, or team owner/admin. 
if p.TeamID == "" { return HostCreateResult{}, fmt.Errorf("invalid request: team_id is required for BYOC hosts") } @@ -90,18 +117,21 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat if err != nil { return HostCreateResult{}, fmt.Errorf("check team membership: %w", err) } - if membership.Role != "owner" { - return HostCreateResult{}, fmt.Errorf("forbidden: only team owners can create BYOC hosts") + if err := requireAdminOrOwner(membership.Role); err != nil { + return HostCreateResult{}, err } } } - // Validate team exists and is not deleted for BYOC hosts. + // Validate team exists, is not deleted, and has BYOC enabled. if p.TeamID != "" { team, err := s.DB.GetTeam(ctx, p.TeamID) if err != nil || team.DeletedAt.Valid { return HostCreateResult{}, fmt.Errorf("invalid request: team not found") } + if !team.IsByoc { + return HostCreateResult{}, fmt.Errorf("forbidden: BYOC is not enabled for this team") + } } hostID := id.NewHostID() @@ -168,7 +198,6 @@ func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamI return HostCreateResult{}, fmt.Errorf("invalid state: can only regenerate token for pending hosts (status: %s)", host.Status) } - // Same permission model as Create/Delete. 
if !isAdmin { if host.Type != "byoc" { return HostCreateResult{}, fmt.Errorf("forbidden: only admins can manage regular hosts") @@ -186,8 +215,8 @@ func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamI if err != nil { return HostCreateResult{}, fmt.Errorf("check team membership: %w", err) } - if membership.Role != "owner" { - return HostCreateResult{}, fmt.Errorf("forbidden: only team owners can regenerate tokens") + if err := requireAdminOrOwner(membership.Role); err != nil { + return HostCreateResult{}, err } } @@ -216,7 +245,7 @@ func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamI } // Register validates a one-time registration token, updates the host with -// machine specs, and returns a long-lived host JWT. +// machine specs, and returns a short-lived host JWT plus a long-lived refresh token. func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostRegisterResult, error) { // Atomic consume: GetDel returns the value and deletes in one operation, // preventing concurrent requests from consuming the same token. @@ -264,18 +293,97 @@ func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostR slog.Warn("failed to mark host token used", "token_id", payload.TokenID, "error", err) } + // Issue a long-lived refresh token. + refreshToken, err := s.issueRefreshToken(ctx, payload.HostID) + if err != nil { + return HostRegisterResult{}, fmt.Errorf("issue refresh token: %w", err) + } + // Re-fetch the host to get the updated state. host, err := s.DB.GetHost(ctx, payload.HostID) if err != nil { return HostRegisterResult{}, fmt.Errorf("fetch updated host: %w", err) } - return HostRegisterResult{Host: host, JWT: hostJWT}, nil + return HostRegisterResult{Host: host, JWT: hostJWT, RefreshToken: refreshToken}, nil } -// Heartbeat updates the last heartbeat timestamp for a host. 
+// Refresh validates a refresh token, rotates it (revokes old, issues new), +// and returns a fresh JWT plus the new refresh token. +func (s *HostService) Refresh(ctx context.Context, refreshToken string) (HostRefreshResult, error) { + hash := hashToken(refreshToken) + + row, err := s.DB.GetHostRefreshTokenByHash(ctx, hash) + if errors.Is(err, pgx.ErrNoRows) { + return HostRefreshResult{}, fmt.Errorf("invalid or expired refresh token") + } + if err != nil { + return HostRefreshResult{}, fmt.Errorf("lookup refresh token: %w", err) + } + + host, err := s.DB.GetHost(ctx, row.HostID) + if err != nil { + return HostRefreshResult{}, fmt.Errorf("host not found: %w", err) + } + + // Sign new JWT. + hostJWT, err := auth.SignHostJWT(s.JWT, host.ID) + if err != nil { + return HostRefreshResult{}, fmt.Errorf("sign host JWT: %w", err) + } + + // Issue-then-revoke rotation: insert new token first so a crash between + // the two DB calls leaves the host with two valid tokens rather than zero. + newRefreshToken, err := s.issueRefreshToken(ctx, host.ID) + if err != nil { + return HostRefreshResult{}, fmt.Errorf("issue new refresh token: %w", err) + } + + // Revoke old refresh token after the new one is safely persisted. + if err := s.DB.RevokeHostRefreshToken(ctx, row.ID); err != nil { + return HostRefreshResult{}, fmt.Errorf("revoke old refresh token: %w", err) + } + + return HostRefreshResult{Host: host, JWT: hostJWT, RefreshToken: newRefreshToken}, nil +} + +// issueRefreshToken creates a new refresh token record in the DB and returns +// the opaque token string. 
+func (s *HostService) issueRefreshToken(ctx context.Context, hostID string) (string, error) { + token := id.NewRefreshToken() + hash := hashToken(token) + now := time.Now() + + if _, err := s.DB.InsertHostRefreshToken(ctx, db.InsertHostRefreshTokenParams{ + ID: id.NewRefreshTokenID(), + HostID: hostID, + TokenHash: hash, + ExpiresAt: pgtype.Timestamptz{Time: now.Add(auth.HostRefreshTokenExpiry), Valid: true}, + }); err != nil { + return "", fmt.Errorf("insert refresh token: %w", err) + } + + return token, nil +} + +// hashToken returns the hex-encoded SHA-256 hash of the token. +func hashToken(token string) string { + h := sha256.Sum256([]byte(token)) + return fmt.Sprintf("%x", h) +} + +// Heartbeat updates the last heartbeat timestamp for a host and transitions +// any 'unreachable' host back to 'online'. Returns a "host not found" error +// (which becomes 404) if the host record no longer exists (e.g., was deleted). func (s *HostService) Heartbeat(ctx context.Context, hostID string) error { - return s.DB.UpdateHostHeartbeat(ctx, hostID) + n, err := s.DB.UpdateHostHeartbeatAndStatus(ctx, hostID) + if err != nil { + return err + } + if n == 0 { + return fmt.Errorf("host not found") + } + return nil } // List returns hosts visible to the caller. @@ -301,37 +409,139 @@ func (s *HostService) Get(ctx context.Context, hostID, teamID string, isAdmin bo return host, nil } -// Delete removes a host. Admins can delete any host. Team owners can delete -// BYOC hosts belonging to their team. -func (s *HostService) Delete(ctx context.Context, hostID, userID, teamID string, isAdmin bool) error { - host, err := s.DB.GetHost(ctx, hostID) +// DeletePreview returns what would be affected by deleting the host, without +// making any changes. Use this to show the user a confirmation prompt. 
+func (s *HostService) DeletePreview(ctx context.Context, hostID, teamID string, isAdmin bool) (HostDeletePreview, error) { + host, err := s.checkDeletePermission(ctx, hostID, "", teamID, isAdmin) if err != nil { - return fmt.Errorf("host not found: %w", err) + return HostDeletePreview{}, err } - if !isAdmin { - if host.Type != "byoc" { - return fmt.Errorf("forbidden: only admins can delete regular hosts") + sandboxes, err := s.DB.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{ + HostID: hostID, + Column2: []string{"pending", "starting", "running", "missing"}, + }) + if err != nil { + return HostDeletePreview{}, fmt.Errorf("list sandboxes: %w", err) + } + + ids := make([]string, len(sandboxes)) + for i, sb := range sandboxes { + ids[i] = sb.ID + } + + return HostDeletePreview{Host: host, SandboxIDs: ids}, nil +} + +// Delete removes a host. Without force it returns an error listing active +// sandboxes so the caller can present a confirmation. With force it gracefully +// destroys all running sandboxes before deleting the host record. +func (s *HostService) Delete(ctx context.Context, hostID, userID, teamID string, isAdmin, force bool) error { + host, err := s.checkDeletePermission(ctx, hostID, userID, teamID, isAdmin) + if err != nil { + return err + } + + sandboxes, err := s.DB.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{ + HostID: hostID, + Column2: []string{"pending", "starting", "running", "missing"}, + }) + if err != nil { + return fmt.Errorf("list sandboxes: %w", err) + } + + if len(sandboxes) > 0 && !force { + ids := make([]string, len(sandboxes)) + for i, sb := range sandboxes { + ids[i] = sb.ID } - if !host.TeamID.Valid || host.TeamID.String != teamID { - return fmt.Errorf("forbidden: host does not belong to your team") + return &HostHasSandboxesError{SandboxIDs: ids} + } + + // Gracefully destroy running sandboxes and terminate the agent (best-effort). 
+ if host.Address.Valid && host.Address.String != "" { + agent, err := s.Pool.GetForHost(host) + if err == nil { + for _, sb := range sandboxes { + if sb.Status == "running" || sb.Status == "starting" { + _, rpcErr := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{ + SandboxId: sb.ID, + })) + if rpcErr != nil && connect.CodeOf(rpcErr) != connect.CodeNotFound { + slog.Warn("delete host: failed to destroy sandbox on agent", "sandbox_id", sb.ID, "error", rpcErr) + } + } + } + // Tell the agent to shut itself down immediately. + if _, rpcErr := agent.Terminate(ctx, connect.NewRequest(&pb.TerminateRequest{})); rpcErr != nil { + slog.Warn("delete host: failed to send Terminate to agent", "host_id", hostID, "error", rpcErr) + } } + } + + // Mark all affected sandboxes as stopped in DB. + if len(sandboxes) > 0 { + sbIDs := make([]string, len(sandboxes)) + for i, sb := range sandboxes { + sbIDs[i] = sb.ID + } + if err := s.DB.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{ + Column1: sbIDs, + Status: "stopped", + }); err != nil { + slog.Warn("delete host: failed to mark sandboxes stopped", "host_id", hostID, "error", err) + } + } + + // Revoke all refresh tokens for this host. + if err := s.DB.RevokeHostRefreshTokensByHost(ctx, hostID); err != nil { + slog.Warn("delete host: failed to revoke refresh tokens", "host_id", hostID, "error", err) + } + + // Evict the client from the pool so no further RPCs are sent. + if s.Pool != nil { + s.Pool.Evict(hostID) + } + + return s.DB.DeleteHost(ctx, hostID) +} + +// checkDeletePermission verifies the caller has permission to delete the given +// host and returns the host record on success. 
+func (s *HostService) checkDeletePermission(ctx context.Context, hostID, userID, teamID string, isAdmin bool) (db.Host, error) { + host, err := s.DB.GetHost(ctx, hostID) + if err != nil { + return db.Host{}, fmt.Errorf("host not found: %w", err) + } + + if isAdmin { + return host, nil + } + + if host.Type != "byoc" { + return db.Host{}, fmt.Errorf("forbidden: only admins can delete regular hosts") + } + if !host.TeamID.Valid || host.TeamID.String != teamID { + return db.Host{}, fmt.Errorf("forbidden: host does not belong to your team") + } + + if userID != "" { membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{ UserID: userID, TeamID: teamID, }) if errors.Is(err, pgx.ErrNoRows) { - return fmt.Errorf("forbidden: not a member of the specified team") + return db.Host{}, fmt.Errorf("forbidden: not a member of the specified team") } if err != nil { - return fmt.Errorf("check team membership: %w", err) + return db.Host{}, fmt.Errorf("check team membership: %w", err) } - if membership.Role != "owner" { - return fmt.Errorf("forbidden: only team owners can delete BYOC hosts") + if err := requireAdminOrOwner(membership.Role); err != nil { + return db.Host{}, err } } - return s.DB.DeleteHost(ctx, hostID) + return host, nil } // AddTag adds a tag to a host. @@ -357,3 +567,14 @@ func (s *HostService) ListTags(ctx context.Context, hostID, teamID string, isAdm } return s.DB.GetHostTags(ctx, hostID) } + +// HostHasSandboxesError is returned by Delete when the host has active sandboxes +// and force was not set. The caller should present the list to the user and +// re-call Delete with force=true if they confirm. 
+type HostHasSandboxesError struct { + SandboxIDs []string +} + +func (e *HostHasSandboxesError) Error() string { + return fmt.Sprintf("host has %d active sandbox(es): %v", len(e.SandboxIDs), e.SandboxIDs) +} diff --git a/internal/service/sandbox.go b/internal/service/sandbox.go index ae4bac3..f67eb0d 100644 --- a/internal/service/sandbox.go +++ b/internal/service/sandbox.go @@ -11,16 +11,18 @@ import ( "git.omukk.dev/wrenn/sandbox/internal/db" "git.omukk.dev/wrenn/sandbox/internal/id" + "git.omukk.dev/wrenn/sandbox/internal/lifecycle" + "git.omukk.dev/wrenn/sandbox/internal/scheduler" "git.omukk.dev/wrenn/sandbox/internal/validate" pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen" - "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect" ) // SandboxService provides sandbox lifecycle operations shared between the // REST API and the dashboard. type SandboxService struct { - DB *db.Queries - Agent hostagentv1connect.HostAgentServiceClient + DB *db.Queries + Pool *lifecycle.HostClientPool + Scheduler scheduler.HostScheduler } // SandboxCreateParams holds the parameters for creating a sandbox. @@ -32,8 +34,34 @@ type SandboxCreateParams struct { TimeoutSec int32 } -// Create creates a new sandbox: inserts a pending DB record, calls the host agent, -// and updates the record to running. Returns the sandbox DB row. +// agentForSandbox looks up the host for the given sandbox and returns a client. 
+func (s *SandboxService) agentForSandbox(ctx context.Context, sandboxID string) (hostagentClient, db.Sandbox, error) { + sb, err := s.DB.GetSandbox(ctx, sandboxID) + if err != nil { + return nil, db.Sandbox{}, fmt.Errorf("sandbox not found: %w", err) + } + host, err := s.DB.GetHost(ctx, sb.HostID) + if err != nil { + return nil, db.Sandbox{}, fmt.Errorf("host not found for sandbox: %w", err) + } + agent, err := s.Pool.GetForHost(host) + if err != nil { + return nil, db.Sandbox{}, fmt.Errorf("get agent client: %w", err) + } + return agent, sb, nil +} + +// hostagentClient is a local alias to avoid the full package path in signatures. +type hostagentClient = interface { + CreateSandbox(ctx context.Context, req *connect.Request[pb.CreateSandboxRequest]) (*connect.Response[pb.CreateSandboxResponse], error) + DestroySandbox(ctx context.Context, req *connect.Request[pb.DestroySandboxRequest]) (*connect.Response[pb.DestroySandboxResponse], error) + PauseSandbox(ctx context.Context, req *connect.Request[pb.PauseSandboxRequest]) (*connect.Response[pb.PauseSandboxResponse], error) + ResumeSandbox(ctx context.Context, req *connect.Request[pb.ResumeSandboxRequest]) (*connect.Response[pb.ResumeSandboxResponse], error) + PingSandbox(ctx context.Context, req *connect.Request[pb.PingSandboxRequest]) (*connect.Response[pb.PingSandboxResponse], error) +} + +// Create creates a new sandbox: picks a host via the scheduler, inserts a pending +// DB record, calls the host agent, and updates the record to running. func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db.Sandbox, error) { if p.Template == "" { p.Template = "minimal" @@ -58,12 +86,33 @@ func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db. } } + if p.TeamID == "" { + return db.Sandbox{}, fmt.Errorf("invalid request: team_id is required") + } + + // Determine whether this team uses BYOC hosts or platform hosts. 
+ team, err := s.DB.GetTeam(ctx, p.TeamID) + if err != nil { + return db.Sandbox{}, fmt.Errorf("team not found: %w", err) + } + + // Pick a host for this sandbox. + host, err := s.Scheduler.SelectHost(ctx, p.TeamID, team.IsByoc) + if err != nil { + return db.Sandbox{}, fmt.Errorf("select host: %w", err) + } + + agent, err := s.Pool.GetForHost(host) + if err != nil { + return db.Sandbox{}, fmt.Errorf("get agent client: %w", err) + } + sandboxID := id.NewSandboxID() if _, err := s.DB.InsertSandbox(ctx, db.InsertSandboxParams{ ID: sandboxID, TeamID: p.TeamID, - HostID: "default", + HostID: host.ID, Template: p.Template, Status: "pending", Vcpus: p.VCPUs, @@ -73,7 +122,7 @@ func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db. return db.Sandbox{}, fmt.Errorf("insert sandbox: %w", err) } - resp, err := s.Agent.CreateSandbox(ctx, connect.NewRequest(&pb.CreateSandboxRequest{ + resp, err := agent.CreateSandbox(ctx, connect.NewRequest(&pb.CreateSandboxRequest{ SandboxId: sandboxID, Template: p.Template, Vcpus: p.VCPUs, @@ -126,7 +175,12 @@ func (s *SandboxService) Pause(ctx context.Context, sandboxID, teamID string) (d return db.Sandbox{}, fmt.Errorf("sandbox is not running (status: %s)", sb.Status) } - if _, err := s.Agent.PauseSandbox(ctx, connect.NewRequest(&pb.PauseSandboxRequest{ + agent, _, err := s.agentForSandbox(ctx, sandboxID) + if err != nil { + return db.Sandbox{}, err + } + + if _, err := agent.PauseSandbox(ctx, connect.NewRequest(&pb.PauseSandboxRequest{ SandboxId: sandboxID, })); err != nil { return db.Sandbox{}, fmt.Errorf("agent pause: %w", err) @@ -151,7 +205,12 @@ func (s *SandboxService) Resume(ctx context.Context, sandboxID, teamID string) ( return db.Sandbox{}, fmt.Errorf("sandbox is not paused (status: %s)", sb.Status) } - resp, err := s.Agent.ResumeSandbox(ctx, connect.NewRequest(&pb.ResumeSandboxRequest{ + agent, _, err := s.agentForSandbox(ctx, sandboxID) + if err != nil { + return db.Sandbox{}, err + } + + resp, err := 
agent.ResumeSandbox(ctx, connect.NewRequest(&pb.ResumeSandboxRequest{ SandboxId: sandboxID, TimeoutSec: sb.TimeoutSec, })) @@ -181,8 +240,13 @@ func (s *SandboxService) Destroy(ctx context.Context, sandboxID, teamID string) return fmt.Errorf("sandbox not found: %w", err) } + agent, _, err := s.agentForSandbox(ctx, sandboxID) + if err != nil { + return err + } + // Destroy on host agent. A not-found response is fine — sandbox is already gone. - if _, err := s.Agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{ + if _, err := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{ SandboxId: sandboxID, })); err != nil && connect.CodeOf(err) != connect.CodeNotFound { return fmt.Errorf("agent destroy: %w", err) @@ -206,7 +270,12 @@ func (s *SandboxService) Ping(ctx context.Context, sandboxID, teamID string) err return fmt.Errorf("sandbox is not running (status: %s)", sb.Status) } - if _, err := s.Agent.PingSandbox(ctx, connect.NewRequest(&pb.PingSandboxRequest{ + agent, _, err := s.agentForSandbox(ctx, sandboxID) + if err != nil { + return err + } + + if _, err := agent.PingSandbox(ctx, connect.NewRequest(&pb.PingSandboxRequest{ SandboxId: sandboxID, })); err != nil { return fmt.Errorf("agent ping: %w", err) diff --git a/internal/service/team.go b/internal/service/team.go index 87baadc..859441e 100644 --- a/internal/service/team.go +++ b/internal/service/team.go @@ -14,17 +14,17 @@ import ( "git.omukk.dev/wrenn/sandbox/internal/db" "git.omukk.dev/wrenn/sandbox/internal/id" + "git.omukk.dev/wrenn/sandbox/internal/lifecycle" pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen" - "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect" ) var teamNameRE = regexp.MustCompile(`^[A-Za-z0-9 _\-@']{1,128}$`) // TeamService provides team management operations. 
type TeamService struct { - DB *db.Queries - Pool *pgxpool.Pool - Agent hostagentv1connect.HostAgentServiceClient + DB *db.Queries + Pool *pgxpool.Pool + HostPool *lifecycle.HostClientPool } // TeamWithRole pairs a team with the calling user's role in it. @@ -177,10 +177,16 @@ func (s *TeamService) DeleteTeam(ctx context.Context, teamID, callerUserID strin var stopIDs []string for _, sb := range sandboxes { - if _, err := s.Agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{ - SandboxId: sb.ID, - })); err != nil && connect.CodeOf(err) != connect.CodeNotFound { - slog.Warn("team delete: failed to destroy sandbox", "sandbox_id", sb.ID, "error", err) + host, hostErr := s.DB.GetHost(ctx, sb.HostID) + if hostErr == nil { + agent, agentErr := s.HostPool.GetForHost(host) + if agentErr == nil { + if _, err := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{ + SandboxId: sb.ID, + })); err != nil && connect.CodeOf(err) != connect.CodeNotFound { + slog.Warn("team delete: failed to destroy sandbox", "sandbox_id", sb.ID, "error", err) + } + } } stopIDs = append(stopIDs, sb.ID) } @@ -368,3 +374,27 @@ func (s *TeamService) LeaveTeam(ctx context.Context, teamID, callerUserID string func (s *TeamService) SearchUsersByEmailPrefix(ctx context.Context, prefix string) ([]db.SearchUsersByEmailPrefixRow, error) { return s.DB.SearchUsersByEmailPrefix(ctx, pgtype.Text{String: prefix, Valid: true}) } + +// SetBYOC enables the BYOC feature flag for a team. Once enabled, BYOC cannot +// be disabled — it is a one-way transition. +// Admin-only — the caller must verify admin status before invoking this. 
+func (s *TeamService) SetBYOC(ctx context.Context, teamID string, enabled bool) error { + team, err := s.DB.GetTeam(ctx, teamID) + if err != nil { + return fmt.Errorf("team not found: %w", err) + } + if team.DeletedAt.Valid { + return fmt.Errorf("team not found") + } + if !enabled { + return fmt.Errorf("invalid request: BYOC cannot be disabled once enabled") + } + if team.IsByoc { + // Already enabled — idempotent, no-op. + return nil + } + if err := s.DB.SetTeamBYOC(ctx, db.SetTeamBYOCParams{ID: teamID, IsByoc: true}); err != nil { + return fmt.Errorf("set byoc: %w", err) + } + return nil +} diff --git a/proto/hostagent/gen/hostagent.pb.go b/proto/hostagent/gen/hostagent.pb.go index 447f1f7..7afd4d1 100644 --- a/proto/hostagent/gen/hostagent.pb.go +++ b/proto/hostagent/gen/hostagent.pb.go @@ -1830,6 +1830,78 @@ func (*PingSandboxResponse) Descriptor() ([]byte, []int) { return file_hostagent_proto_rawDescGZIP(), []int{32} } +type TerminateRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *TerminateRequest) Reset() { + *x = TerminateRequest{} + mi := &file_hostagent_proto_msgTypes[33] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *TerminateRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TerminateRequest) ProtoMessage() {} + +func (x *TerminateRequest) ProtoReflect() protoreflect.Message { + mi := &file_hostagent_proto_msgTypes[33] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TerminateRequest.ProtoReflect.Descriptor instead. 
+func (*TerminateRequest) Descriptor() ([]byte, []int) { + return file_hostagent_proto_rawDescGZIP(), []int{33} +} + +type TerminateResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *TerminateResponse) Reset() { + *x = TerminateResponse{} + mi := &file_hostagent_proto_msgTypes[34] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *TerminateResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TerminateResponse) ProtoMessage() {} + +func (x *TerminateResponse) ProtoReflect() protoreflect.Message { + mi := &file_hostagent_proto_msgTypes[34] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TerminateResponse.ProtoReflect.Descriptor instead. +func (*TerminateResponse) Descriptor() ([]byte, []int) { + return file_hostagent_proto_rawDescGZIP(), []int{34} +} + var File_hostagent_proto protoreflect.FileDescriptor const file_hostagent_proto_rawDesc = "" + @@ -1955,7 +2027,10 @@ const file_hostagent_proto_rawDesc = "" + "\x12PingSandboxRequest\x12\x1d\n" + "\n" + "sandbox_id\x18\x01 \x01(\tR\tsandboxId\"\x15\n" + - "\x13PingSandboxResponse2\xce\t\n" + + "\x13PingSandboxResponse\"\x12\n" + + "\x10TerminateRequest\"\x13\n" + + "\x11TerminateResponse2\x9c\n" + + "\n" + "\x10HostAgentService\x12X\n" + "\rCreateSandbox\x12\".hostagent.v1.CreateSandboxRequest\x1a#.hostagent.v1.CreateSandboxResponse\x12[\n" + "\x0eDestroySandbox\x12#.hostagent.v1.DestroySandboxRequest\x1a$.hostagent.v1.DestroySandboxResponse\x12U\n" + @@ -1971,7 +2046,8 @@ const file_hostagent_proto_rawDesc = "" + "ExecStream\x12\x1f.hostagent.v1.ExecStreamRequest\x1a .hostagent.v1.ExecStreamResponse0\x01\x12`\n" + 
"\x0fWriteFileStream\x12$.hostagent.v1.WriteFileStreamRequest\x1a%.hostagent.v1.WriteFileStreamResponse(\x01\x12]\n" + "\x0eReadFileStream\x12#.hostagent.v1.ReadFileStreamRequest\x1a$.hostagent.v1.ReadFileStreamResponse0\x01\x12R\n" + - "\vPingSandbox\x12 .hostagent.v1.PingSandboxRequest\x1a!.hostagent.v1.PingSandboxResponseB\xb0\x01\n" + + "\vPingSandbox\x12 .hostagent.v1.PingSandboxRequest\x1a!.hostagent.v1.PingSandboxResponse\x12L\n" + + "\tTerminate\x12\x1e.hostagent.v1.TerminateRequest\x1a\x1f.hostagent.v1.TerminateResponseB\xb0\x01\n" + "\x10com.hostagent.v1B\x0eHostagentProtoP\x01Z;git.omukk.dev/wrenn/sandbox/proto/hostagent/gen;hostagentv1\xa2\x02\x03HXX\xaa\x02\fHostagent.V1\xca\x02\fHostagent\\V1\xe2\x02\x18Hostagent\\V1\\GPBMetadata\xea\x02\rHostagent::V1b\x06proto3" var ( @@ -1986,7 +2062,7 @@ func file_hostagent_proto_rawDescGZIP() []byte { return file_hostagent_proto_rawDescData } -var file_hostagent_proto_msgTypes = make([]protoimpl.MessageInfo, 33) +var file_hostagent_proto_msgTypes = make([]protoimpl.MessageInfo, 35) var file_hostagent_proto_goTypes = []any{ (*CreateSandboxRequest)(nil), // 0: hostagent.v1.CreateSandboxRequest (*CreateSandboxResponse)(nil), // 1: hostagent.v1.CreateSandboxResponse @@ -2021,6 +2097,8 @@ var file_hostagent_proto_goTypes = []any{ (*ReadFileStreamResponse)(nil), // 30: hostagent.v1.ReadFileStreamResponse (*PingSandboxRequest)(nil), // 31: hostagent.v1.PingSandboxRequest (*PingSandboxResponse)(nil), // 32: hostagent.v1.PingSandboxResponse + (*TerminateRequest)(nil), // 33: hostagent.v1.TerminateRequest + (*TerminateResponse)(nil), // 34: hostagent.v1.TerminateResponse } var file_hostagent_proto_depIdxs = []int32{ 16, // 0: hostagent.v1.ListSandboxesResponse.sandboxes:type_name -> hostagent.v1.SandboxInfo @@ -2042,22 +2120,24 @@ var file_hostagent_proto_depIdxs = []int32{ 26, // 16: hostagent.v1.HostAgentService.WriteFileStream:input_type -> hostagent.v1.WriteFileStreamRequest 29, // 17: 
hostagent.v1.HostAgentService.ReadFileStream:input_type -> hostagent.v1.ReadFileStreamRequest 31, // 18: hostagent.v1.HostAgentService.PingSandbox:input_type -> hostagent.v1.PingSandboxRequest - 1, // 19: hostagent.v1.HostAgentService.CreateSandbox:output_type -> hostagent.v1.CreateSandboxResponse - 3, // 20: hostagent.v1.HostAgentService.DestroySandbox:output_type -> hostagent.v1.DestroySandboxResponse - 5, // 21: hostagent.v1.HostAgentService.PauseSandbox:output_type -> hostagent.v1.PauseSandboxResponse - 7, // 22: hostagent.v1.HostAgentService.ResumeSandbox:output_type -> hostagent.v1.ResumeSandboxResponse - 13, // 23: hostagent.v1.HostAgentService.Exec:output_type -> hostagent.v1.ExecResponse - 15, // 24: hostagent.v1.HostAgentService.ListSandboxes:output_type -> hostagent.v1.ListSandboxesResponse - 18, // 25: hostagent.v1.HostAgentService.WriteFile:output_type -> hostagent.v1.WriteFileResponse - 20, // 26: hostagent.v1.HostAgentService.ReadFile:output_type -> hostagent.v1.ReadFileResponse - 9, // 27: hostagent.v1.HostAgentService.CreateSnapshot:output_type -> hostagent.v1.CreateSnapshotResponse - 11, // 28: hostagent.v1.HostAgentService.DeleteSnapshot:output_type -> hostagent.v1.DeleteSnapshotResponse - 22, // 29: hostagent.v1.HostAgentService.ExecStream:output_type -> hostagent.v1.ExecStreamResponse - 28, // 30: hostagent.v1.HostAgentService.WriteFileStream:output_type -> hostagent.v1.WriteFileStreamResponse - 30, // 31: hostagent.v1.HostAgentService.ReadFileStream:output_type -> hostagent.v1.ReadFileStreamResponse - 32, // 32: hostagent.v1.HostAgentService.PingSandbox:output_type -> hostagent.v1.PingSandboxResponse - 19, // [19:33] is the sub-list for method output_type - 5, // [5:19] is the sub-list for method input_type + 33, // 19: hostagent.v1.HostAgentService.Terminate:input_type -> hostagent.v1.TerminateRequest + 1, // 20: hostagent.v1.HostAgentService.CreateSandbox:output_type -> hostagent.v1.CreateSandboxResponse + 3, // 21: 
hostagent.v1.HostAgentService.DestroySandbox:output_type -> hostagent.v1.DestroySandboxResponse + 5, // 22: hostagent.v1.HostAgentService.PauseSandbox:output_type -> hostagent.v1.PauseSandboxResponse + 7, // 23: hostagent.v1.HostAgentService.ResumeSandbox:output_type -> hostagent.v1.ResumeSandboxResponse + 13, // 24: hostagent.v1.HostAgentService.Exec:output_type -> hostagent.v1.ExecResponse + 15, // 25: hostagent.v1.HostAgentService.ListSandboxes:output_type -> hostagent.v1.ListSandboxesResponse + 18, // 26: hostagent.v1.HostAgentService.WriteFile:output_type -> hostagent.v1.WriteFileResponse + 20, // 27: hostagent.v1.HostAgentService.ReadFile:output_type -> hostagent.v1.ReadFileResponse + 9, // 28: hostagent.v1.HostAgentService.CreateSnapshot:output_type -> hostagent.v1.CreateSnapshotResponse + 11, // 29: hostagent.v1.HostAgentService.DeleteSnapshot:output_type -> hostagent.v1.DeleteSnapshotResponse + 22, // 30: hostagent.v1.HostAgentService.ExecStream:output_type -> hostagent.v1.ExecStreamResponse + 28, // 31: hostagent.v1.HostAgentService.WriteFileStream:output_type -> hostagent.v1.WriteFileStreamResponse + 30, // 32: hostagent.v1.HostAgentService.ReadFileStream:output_type -> hostagent.v1.ReadFileStreamResponse + 32, // 33: hostagent.v1.HostAgentService.PingSandbox:output_type -> hostagent.v1.PingSandboxResponse + 34, // 34: hostagent.v1.HostAgentService.Terminate:output_type -> hostagent.v1.TerminateResponse + 20, // [20:35] is the sub-list for method output_type + 5, // [5:20] is the sub-list for method input_type 5, // [5:5] is the sub-list for extension type_name 5, // [5:5] is the sub-list for extension extendee 0, // [0:5] is the sub-list for field type_name @@ -2087,7 +2167,7 @@ func file_hostagent_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_hostagent_proto_rawDesc), len(file_hostagent_proto_rawDesc)), NumEnums: 0, - NumMessages: 33, + NumMessages: 35, NumExtensions: 0, NumServices: 1, 
}, diff --git a/proto/hostagent/gen/hostagentv1connect/hostagent.connect.go b/proto/hostagent/gen/hostagentv1connect/hostagent.connect.go index 6eb5d45..d144451 100644 --- a/proto/hostagent/gen/hostagentv1connect/hostagent.connect.go +++ b/proto/hostagent/gen/hostagentv1connect/hostagent.connect.go @@ -74,6 +74,9 @@ const ( // HostAgentServicePingSandboxProcedure is the fully-qualified name of the HostAgentService's // PingSandbox RPC. HostAgentServicePingSandboxProcedure = "/hostagent.v1.HostAgentService/PingSandbox" + // HostAgentServiceTerminateProcedure is the fully-qualified name of the HostAgentService's + // Terminate RPC. + HostAgentServiceTerminateProcedure = "/hostagent.v1.HostAgentService/Terminate" ) // HostAgentServiceClient is a client for the hostagent.v1.HostAgentService service. @@ -108,6 +111,10 @@ type HostAgentServiceClient interface { ReadFileStream(context.Context, *connect.Request[gen.ReadFileStreamRequest]) (*connect.ServerStreamForClient[gen.ReadFileStreamResponse], error) // PingSandbox resets the inactivity timer for a running sandbox. PingSandbox(context.Context, *connect.Request[gen.PingSandboxRequest]) (*connect.Response[gen.PingSandboxResponse], error) + // Terminate instructs the host agent to destroy all sandboxes and exit. + // Called by the control plane immediately when a host is deleted so the + // agent shuts down without waiting for the next heartbeat cycle. + Terminate(context.Context, *connect.Request[gen.TerminateRequest]) (*connect.Response[gen.TerminateResponse], error) } // NewHostAgentServiceClient constructs a client for the hostagent.v1.HostAgentService service. 
By @@ -205,6 +212,12 @@ func NewHostAgentServiceClient(httpClient connect.HTTPClient, baseURL string, op connect.WithSchema(hostAgentServiceMethods.ByName("PingSandbox")), connect.WithClientOptions(opts...), ), + terminate: connect.NewClient[gen.TerminateRequest, gen.TerminateResponse]( + httpClient, + baseURL+HostAgentServiceTerminateProcedure, + connect.WithSchema(hostAgentServiceMethods.ByName("Terminate")), + connect.WithClientOptions(opts...), + ), } } @@ -224,6 +237,7 @@ type hostAgentServiceClient struct { writeFileStream *connect.Client[gen.WriteFileStreamRequest, gen.WriteFileStreamResponse] readFileStream *connect.Client[gen.ReadFileStreamRequest, gen.ReadFileStreamResponse] pingSandbox *connect.Client[gen.PingSandboxRequest, gen.PingSandboxResponse] + terminate *connect.Client[gen.TerminateRequest, gen.TerminateResponse] } // CreateSandbox calls hostagent.v1.HostAgentService.CreateSandbox. @@ -296,6 +310,11 @@ func (c *hostAgentServiceClient) PingSandbox(ctx context.Context, req *connect.R return c.pingSandbox.CallUnary(ctx, req) } +// Terminate calls hostagent.v1.HostAgentService.Terminate. +func (c *hostAgentServiceClient) Terminate(ctx context.Context, req *connect.Request[gen.TerminateRequest]) (*connect.Response[gen.TerminateResponse], error) { + return c.terminate.CallUnary(ctx, req) +} + // HostAgentServiceHandler is an implementation of the hostagent.v1.HostAgentService service. type HostAgentServiceHandler interface { // CreateSandbox boots a new microVM with the given configuration. @@ -328,6 +347,10 @@ type HostAgentServiceHandler interface { ReadFileStream(context.Context, *connect.Request[gen.ReadFileStreamRequest], *connect.ServerStream[gen.ReadFileStreamResponse]) error // PingSandbox resets the inactivity timer for a running sandbox. PingSandbox(context.Context, *connect.Request[gen.PingSandboxRequest]) (*connect.Response[gen.PingSandboxResponse], error) + // Terminate instructs the host agent to destroy all sandboxes and exit. 
+ // Called by the control plane immediately when a host is deleted so the + // agent shuts down without waiting for the next heartbeat cycle. + Terminate(context.Context, *connect.Request[gen.TerminateRequest]) (*connect.Response[gen.TerminateResponse], error) } // NewHostAgentServiceHandler builds an HTTP handler from the service implementation. It returns the @@ -421,6 +444,12 @@ func NewHostAgentServiceHandler(svc HostAgentServiceHandler, opts ...connect.Han connect.WithSchema(hostAgentServiceMethods.ByName("PingSandbox")), connect.WithHandlerOptions(opts...), ) + hostAgentServiceTerminateHandler := connect.NewUnaryHandler( + HostAgentServiceTerminateProcedure, + svc.Terminate, + connect.WithSchema(hostAgentServiceMethods.ByName("Terminate")), + connect.WithHandlerOptions(opts...), + ) return "/hostagent.v1.HostAgentService/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case HostAgentServiceCreateSandboxProcedure: @@ -451,6 +480,8 @@ func NewHostAgentServiceHandler(svc HostAgentServiceHandler, opts ...connect.Han hostAgentServiceReadFileStreamHandler.ServeHTTP(w, r) case HostAgentServicePingSandboxProcedure: hostAgentServicePingSandboxHandler.ServeHTTP(w, r) + case HostAgentServiceTerminateProcedure: + hostAgentServiceTerminateHandler.ServeHTTP(w, r) default: http.NotFound(w, r) } @@ -515,3 +546,7 @@ func (UnimplementedHostAgentServiceHandler) ReadFileStream(context.Context, *con func (UnimplementedHostAgentServiceHandler) PingSandbox(context.Context, *connect.Request[gen.PingSandboxRequest]) (*connect.Response[gen.PingSandboxResponse], error) { return nil, connect.NewError(connect.CodeUnimplemented, errors.New("hostagent.v1.HostAgentService.PingSandbox is not implemented")) } + +func (UnimplementedHostAgentServiceHandler) Terminate(context.Context, *connect.Request[gen.TerminateRequest]) (*connect.Response[gen.TerminateResponse], error) { + return nil, connect.NewError(connect.CodeUnimplemented, 
errors.New("hostagent.v1.HostAgentService.Terminate is not implemented")) +} diff --git a/proto/hostagent/hostagent.proto b/proto/hostagent/hostagent.proto index b9ceccf..c9cfffa 100644 --- a/proto/hostagent/hostagent.proto +++ b/proto/hostagent/hostagent.proto @@ -49,6 +49,11 @@ service HostAgentService { // PingSandbox resets the inactivity timer for a running sandbox. rpc PingSandbox(PingSandboxRequest) returns (PingSandboxResponse); + // Terminate instructs the host agent to destroy all sandboxes and exit. + // Called by the control plane immediately when a host is deleted so the + // agent shuts down without waiting for the next heartbeat cycle. + rpc Terminate(TerminateRequest) returns (TerminateResponse); + } message CreateSandboxRequest { @@ -236,3 +241,10 @@ message PingSandboxRequest { message PingSandboxResponse {} + + +// ── Terminate ──────────────────────────────────────────────────────── + +message TerminateRequest {} + +message TerminateResponse {}