forked from wrenn/wrenn
Implement host registration, JWT refresh tokens, and multi-host scheduling
Replaces the hardcoded CP_HOST_AGENT_ADDR single-agent setup with a DB-driven registration system supporting multiple host agents (BYOC).

Key changes:
- Host agents register via a one-time token and receive a 7-day JWT + 60-day refresh token; the heartbeat loop auto-refreshes on 401/403 and pauses all sandboxes if the refresh fails
- HostClientPool: lazy Connect RPC client cache keyed by host ID, replacing the single static agent client throughout the API and service layers
- RoundRobinScheduler: picks an online host for each new sandbox via ListActiveHosts; extensible for future scheduling strategies
- HostMonitor (replaces Reconciler): passive heartbeat staleness check marks hosts unreachable and sandboxes missing after 90s; active reconciliation per online host restores missing-but-alive sandboxes and stops orphans
- Graceful host delete: without ?force=true, returns 409 with the list of affected sandboxes; with ?force=true, destroys the sandboxes and then evicts the pool client
- Snapshot delete broadcasts to all online hosts (templates have no host_id)
- sandbox.Manager.PauseAll: pauses all running VMs on CP connectivity loss
- New migration: host_refresh_tokens table with token rotation (issue-then-revoke ordering to prevent lockout on mid-rotation crash)
- New sandbox status 'missing' (reversible, unlike 'stopped') and host status 'unreachable'; both reflected in OpenAPI spec
- Fix: refresh token auth failure now returns 401 (was 400 via generic 'invalid' substring match in serviceErrToHTTP)
This commit is contained in:
@ -6,7 +6,6 @@ REDIS_URL=redis://localhost:6379/0
|
|||||||
|
|
||||||
# Control Plane
|
# Control Plane
|
||||||
CP_LISTEN_ADDR=:8000
|
CP_LISTEN_ADDR=:8000
|
||||||
CP_HOST_AGENT_ADDR=localhost:50051
|
|
||||||
|
|
||||||
# Host Agent
|
# Host Agent
|
||||||
AGENT_LISTEN_ADDR=:50051
|
AGENT_LISTEN_ADDR=:50051
|
||||||
|
|||||||
@ -17,7 +17,8 @@ import (
|
|||||||
"git.omukk.dev/wrenn/sandbox/internal/auth/oauth"
|
"git.omukk.dev/wrenn/sandbox/internal/auth/oauth"
|
||||||
"git.omukk.dev/wrenn/sandbox/internal/config"
|
"git.omukk.dev/wrenn/sandbox/internal/config"
|
||||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/internal/scheduler"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
@ -66,12 +67,11 @@ func main() {
|
|||||||
}
|
}
|
||||||
slog.Info("connected to redis")
|
slog.Info("connected to redis")
|
||||||
|
|
||||||
// Connect RPC client for the host agent.
|
// Host client pool — manages Connect RPC clients to host agents.
|
||||||
agentHTTP := &http.Client{Timeout: 10 * time.Minute}
|
hostPool := lifecycle.NewHostClientPool()
|
||||||
agentClient := hostagentv1connect.NewHostAgentServiceClient(
|
|
||||||
agentHTTP,
|
// Scheduler — picks a host for each new sandbox (round-robin for now).
|
||||||
cfg.HostAgentAddr,
|
hostScheduler := scheduler.NewRoundRobinScheduler(queries)
|
||||||
)
|
|
||||||
|
|
||||||
// OAuth provider registry.
|
// OAuth provider registry.
|
||||||
oauthRegistry := oauth.NewRegistry()
|
oauthRegistry := oauth.NewRegistry()
|
||||||
@ -87,11 +87,11 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// API server.
|
// API server.
|
||||||
srv := api.New(queries, agentClient, pool, rdb, []byte(cfg.JWTSecret), oauthRegistry, cfg.OAuthRedirectURL)
|
srv := api.New(queries, hostPool, hostScheduler, pool, rdb, []byte(cfg.JWTSecret), oauthRegistry, cfg.OAuthRedirectURL)
|
||||||
|
|
||||||
// Start reconciler.
|
// Start host monitor (passive + active reconciliation every 30s).
|
||||||
reconciler := api.NewReconciler(queries, agentClient, "default", 5*time.Second)
|
monitor := api.NewHostMonitor(queries, hostPool, 30*time.Second)
|
||||||
reconciler.Start(ctx)
|
monitor.Start(ctx)
|
||||||
|
|
||||||
httpServer := &http.Server{
|
httpServer := &http.Server{
|
||||||
Addr: cfg.ListenAddr,
|
Addr: cfg.ListenAddr,
|
||||||
@ -114,7 +114,7 @@ func main() {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
slog.Info("control plane starting", "addr", cfg.ListenAddr, "agent", cfg.HostAgentAddr)
|
slog.Info("control plane starting", "addr", cfg.ListenAddr)
|
||||||
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||||
slog.Error("http server error", "error", err)
|
slog.Error("http server error", "error", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
|
|||||||
@ -18,7 +18,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
registrationToken := flag.String("register", "", "One-time registration token from the control plane")
|
registrationToken := flag.String("register", "", "One-time registration token from the control plane (required on first run)")
|
||||||
advertiseAddr := flag.String("address", "", "Externally-reachable address (ip:port) for this host agent")
|
advertiseAddr := flag.String("address", "", "Externally-reachable address (ip:port) for this host agent")
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
@ -42,7 +42,16 @@ func main() {
|
|||||||
listenAddr := envOrDefault("AGENT_LISTEN_ADDR", ":50051")
|
listenAddr := envOrDefault("AGENT_LISTEN_ADDR", ":50051")
|
||||||
rootDir := envOrDefault("AGENT_FILES_ROOTDIR", "/var/lib/wrenn")
|
rootDir := envOrDefault("AGENT_FILES_ROOTDIR", "/var/lib/wrenn")
|
||||||
cpURL := os.Getenv("AGENT_CP_URL")
|
cpURL := os.Getenv("AGENT_CP_URL")
|
||||||
tokenFile := filepath.Join(rootDir, "host-token")
|
tokenFile := filepath.Join(rootDir, "host.jwt")
|
||||||
|
|
||||||
|
if cpURL == "" {
|
||||||
|
slog.Error("AGENT_CP_URL environment variable is required")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
if *advertiseAddr == "" {
|
||||||
|
slog.Error("--address flag is required (externally-reachable ip:port)")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
cfg := sandbox.Config{
|
cfg := sandbox.Config{
|
||||||
KernelPath: filepath.Join(rootDir, "kernels", "vmlinux"),
|
KernelPath: filepath.Join(rootDir, "kernels", "vmlinux"),
|
||||||
@ -58,13 +67,7 @@ func main() {
|
|||||||
|
|
||||||
mgr.StartTTLReaper(ctx)
|
mgr.StartTTLReaper(ctx)
|
||||||
|
|
||||||
if *advertiseAddr == "" {
|
// Register with the control plane and start heartbeating.
|
||||||
slog.Error("--address flag is required (externally-reachable ip:port)")
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Register with the control plane (if configured).
|
|
||||||
if cpURL != "" {
|
|
||||||
hostToken, err := hostagent.Register(ctx, hostagent.RegistrationConfig{
|
hostToken, err := hostagent.Register(ctx, hostagent.RegistrationConfig{
|
||||||
CPURL: cpURL,
|
CPURL: cpURL,
|
||||||
RegistrationToken: *registrationToken,
|
RegistrationToken: *registrationToken,
|
||||||
@ -83,8 +86,14 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
slog.Info("host registered", "host_id", hostID)
|
slog.Info("host registered", "host_id", hostID)
|
||||||
hostagent.StartHeartbeat(ctx, cpURL, hostID, hostToken, 30*time.Second)
|
|
||||||
}
|
// Start heartbeat loop. On CP rejection: try JWT refresh. If that fails,
|
||||||
|
// pause all running sandboxes to ensure they're not left orphaned.
|
||||||
|
hostagent.StartHeartbeat(ctx, cpURL, tokenFile, hostID, 30*time.Second, func() {
|
||||||
|
pauseCtx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
mgr.PauseAll(pauseCtx)
|
||||||
|
})
|
||||||
|
|
||||||
srv := hostagent.NewServer(mgr)
|
srv := hostagent.NewServer(mgr)
|
||||||
path, handler := hostagentv1connect.NewHostAgentServiceHandler(srv)
|
path, handler := hostagentv1connect.NewHostAgentServiceHandler(srv)
|
||||||
@ -115,7 +124,7 @@ func main() {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
slog.Info("host agent starting", "addr", listenAddr)
|
slog.Info("host agent starting", "addr", listenAddr, "host_id", hostID)
|
||||||
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||||
slog.Error("http server error", "error", err)
|
slog.Error("http server error", "error", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
|
|||||||
19
db/migrations/20260324120214_host_refresh_tokens.sql
Normal file
19
db/migrations/20260324120214_host_refresh_tokens.sql
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
-- +goose Up
|
||||||
|
|
||||||
|
-- Refresh tokens for host agent JWT rotation.
|
||||||
|
-- Hosts exchange a refresh token for a new short-lived JWT + new refresh token (rotation).
|
||||||
|
-- Refresh tokens expire after 60 days; hosts must re-register with a new one-time token after that.
|
||||||
|
CREATE TABLE host_refresh_tokens (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
host_id TEXT NOT NULL REFERENCES hosts(id) ON DELETE CASCADE,
|
||||||
|
token_hash TEXT NOT NULL UNIQUE, -- SHA-256 hex of the opaque token
|
||||||
|
expires_at TIMESTAMPTZ NOT NULL,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
revoked_at TIMESTAMPTZ -- NULL = active; set on rotation or host delete
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX idx_host_refresh_tokens_host ON host_refresh_tokens(host_id);
|
||||||
|
|
||||||
|
-- +goose Down
|
||||||
|
|
||||||
|
DROP TABLE host_refresh_tokens;
|
||||||
19
db/queries/host_refresh_tokens.sql
Normal file
19
db/queries/host_refresh_tokens.sql
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
-- name: InsertHostRefreshToken :one
|
||||||
|
INSERT INTO host_refresh_tokens (id, host_id, token_hash, expires_at)
|
||||||
|
VALUES ($1, $2, $3, $4)
|
||||||
|
RETURNING *;
|
||||||
|
|
||||||
|
-- name: GetHostRefreshTokenByHash :one
|
||||||
|
SELECT * FROM host_refresh_tokens
|
||||||
|
WHERE token_hash = $1 AND revoked_at IS NULL AND expires_at > NOW();
|
||||||
|
|
||||||
|
-- name: RevokeHostRefreshToken :exec
|
||||||
|
UPDATE host_refresh_tokens SET revoked_at = NOW() WHERE id = $1;
|
||||||
|
|
||||||
|
-- name: RevokeHostRefreshTokensByHost :exec
|
||||||
|
UPDATE host_refresh_tokens SET revoked_at = NOW()
|
||||||
|
WHERE host_id = $1 AND revoked_at IS NULL;
|
||||||
|
|
||||||
|
-- name: DeleteExpiredHostRefreshTokens :exec
|
||||||
|
DELETE FROM host_refresh_tokens
|
||||||
|
WHERE expires_at < NOW() OR revoked_at IS NOT NULL;
|
||||||
@ -67,3 +67,18 @@ SELECT * FROM host_tokens WHERE host_id = $1 ORDER BY created_at DESC;
|
|||||||
|
|
||||||
-- name: GetHostByTeam :one
|
-- name: GetHostByTeam :one
|
||||||
SELECT * FROM hosts WHERE id = $1 AND team_id = $2;
|
SELECT * FROM hosts WHERE id = $1 AND team_id = $2;
|
||||||
|
|
||||||
|
-- name: ListActiveHosts :many
|
||||||
|
-- Returns all hosts that have completed registration (not pending/offline).
|
||||||
|
SELECT * FROM hosts WHERE status NOT IN ('pending', 'offline') ORDER BY created_at;
|
||||||
|
|
||||||
|
-- name: UpdateHostHeartbeatAndStatus :exec
|
||||||
|
-- Updates last_heartbeat_at and transitions unreachable hosts back to online.
|
||||||
|
UPDATE hosts
|
||||||
|
SET last_heartbeat_at = NOW(),
|
||||||
|
status = CASE WHEN status = 'unreachable' THEN 'online' ELSE status END,
|
||||||
|
updated_at = NOW()
|
||||||
|
WHERE id = $1;
|
||||||
|
|
||||||
|
-- name: MarkHostUnreachable :exec
|
||||||
|
UPDATE hosts SET status = 'unreachable', updated_at = NOW() WHERE id = $1;
|
||||||
|
|||||||
@ -56,3 +56,20 @@ WHERE id = ANY($1::text[]);
|
|||||||
SELECT * FROM sandboxes
|
SELECT * FROM sandboxes
|
||||||
WHERE team_id = $1 AND status IN ('running', 'paused', 'starting')
|
WHERE team_id = $1 AND status IN ('running', 'paused', 'starting')
|
||||||
ORDER BY created_at DESC;
|
ORDER BY created_at DESC;
|
||||||
|
|
||||||
|
-- name: MarkSandboxesMissingByHost :exec
|
||||||
|
-- Called when the host monitor marks a host unreachable.
|
||||||
|
-- Marks running/starting/pending sandboxes on that host as 'missing' so users see
|
||||||
|
-- the sandbox is not currently reachable, without permanently losing the record.
|
||||||
|
UPDATE sandboxes
|
||||||
|
SET status = 'missing',
|
||||||
|
last_updated = NOW()
|
||||||
|
WHERE host_id = $1 AND status IN ('running', 'starting', 'pending');
|
||||||
|
|
||||||
|
-- name: BulkRestoreRunning :exec
|
||||||
|
-- Called by the reconciler when a host comes back online and its sandboxes are
|
||||||
|
-- confirmed alive. Restores only sandboxes that are in 'missing' state.
|
||||||
|
UPDATE sandboxes
|
||||||
|
SET status = 'running',
|
||||||
|
last_updated = NOW()
|
||||||
|
WHERE id = ANY($1::text[]) AND status = 'missing';
|
||||||
|
|||||||
20
internal/api/agent_helper.go
Normal file
20
internal/api/agent_helper.go
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||||
|
)
|
||||||
|
|
||||||
|
// agentForHost looks up the host record and returns a Connect RPC client for it.
|
||||||
|
// Returns an error if the host is not found or has no address.
|
||||||
|
func agentForHost(ctx context.Context, queries *db.Queries, pool *lifecycle.HostClientPool, hostID string) (hostagentv1connect.HostAgentServiceClient, error) {
|
||||||
|
host, err := queries.GetHost(ctx, hostID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("host not found: %w", err)
|
||||||
|
}
|
||||||
|
return pool.GetForHost(host)
|
||||||
|
}
|
||||||
@ -14,17 +14,17 @@ import (
|
|||||||
|
|
||||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
|
||||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type execHandler struct {
|
type execHandler struct {
|
||||||
db *db.Queries
|
db *db.Queries
|
||||||
agent hostagentv1connect.HostAgentServiceClient
|
pool *lifecycle.HostClientPool
|
||||||
}
|
}
|
||||||
|
|
||||||
func newExecHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *execHandler {
|
func newExecHandler(db *db.Queries, pool *lifecycle.HostClientPool) *execHandler {
|
||||||
return &execHandler{db: db, agent: agent}
|
return &execHandler{db: db, pool: pool}
|
||||||
}
|
}
|
||||||
|
|
||||||
type execRequest struct {
|
type execRequest struct {
|
||||||
@ -73,7 +73,13 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
|
|
||||||
resp, err := h.agent.Exec(ctx, connect.NewRequest(&pb.ExecRequest{
|
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||||
|
if err != nil {
|
||||||
|
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := agent.Exec(ctx, connect.NewRequest(&pb.ExecRequest{
|
||||||
SandboxId: sandboxID,
|
SandboxId: sandboxID,
|
||||||
Cmd: req.Cmd,
|
Cmd: req.Cmd,
|
||||||
Args: req.Args,
|
Args: req.Args,
|
||||||
|
|||||||
@ -14,17 +14,17 @@ import (
|
|||||||
|
|
||||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
|
||||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type execStreamHandler struct {
|
type execStreamHandler struct {
|
||||||
db *db.Queries
|
db *db.Queries
|
||||||
agent hostagentv1connect.HostAgentServiceClient
|
pool *lifecycle.HostClientPool
|
||||||
}
|
}
|
||||||
|
|
||||||
func newExecStreamHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *execStreamHandler {
|
func newExecStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool) *execStreamHandler {
|
||||||
return &execStreamHandler{db: db, agent: agent}
|
return &execStreamHandler{db: db, pool: pool}
|
||||||
}
|
}
|
||||||
|
|
||||||
var upgrader = websocket.Upgrader{
|
var upgrader = websocket.Upgrader{
|
||||||
@ -80,11 +80,17 @@ func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||||
|
if err != nil {
|
||||||
|
sendWSError(conn, "sandbox host is not reachable")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Open streaming exec to host agent.
|
// Open streaming exec to host agent.
|
||||||
streamCtx, cancel := context.WithCancel(ctx)
|
streamCtx, cancel := context.WithCancel(ctx)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
stream, err := h.agent.ExecStream(streamCtx, connect.NewRequest(&pb.ExecStreamRequest{
|
stream, err := agent.ExecStream(streamCtx, connect.NewRequest(&pb.ExecStreamRequest{
|
||||||
SandboxId: sandboxID,
|
SandboxId: sandboxID,
|
||||||
Cmd: startMsg.Cmd,
|
Cmd: startMsg.Cmd,
|
||||||
Args: startMsg.Args,
|
Args: startMsg.Args,
|
||||||
|
|||||||
@ -11,17 +11,17 @@ import (
|
|||||||
|
|
||||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
|
||||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type filesHandler struct {
|
type filesHandler struct {
|
||||||
db *db.Queries
|
db *db.Queries
|
||||||
agent hostagentv1connect.HostAgentServiceClient
|
pool *lifecycle.HostClientPool
|
||||||
}
|
}
|
||||||
|
|
||||||
func newFilesHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *filesHandler {
|
func newFilesHandler(db *db.Queries, pool *lifecycle.HostClientPool) *filesHandler {
|
||||||
return &filesHandler{db: db, agent: agent}
|
return &filesHandler{db: db, pool: pool}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Upload handles POST /v1/sandboxes/{id}/files/write.
|
// Upload handles POST /v1/sandboxes/{id}/files/write.
|
||||||
@ -75,7 +75,13 @@ func (h *filesHandler) Upload(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := h.agent.WriteFile(ctx, connect.NewRequest(&pb.WriteFileRequest{
|
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||||
|
if err != nil {
|
||||||
|
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := agent.WriteFile(ctx, connect.NewRequest(&pb.WriteFileRequest{
|
||||||
SandboxId: sandboxID,
|
SandboxId: sandboxID,
|
||||||
Path: filePath,
|
Path: filePath,
|
||||||
Content: content,
|
Content: content,
|
||||||
@ -120,7 +126,13 @@ func (h *filesHandler) Download(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := h.agent.ReadFile(ctx, connect.NewRequest(&pb.ReadFileRequest{
|
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||||
|
if err != nil {
|
||||||
|
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := agent.ReadFile(ctx, connect.NewRequest(&pb.ReadFileRequest{
|
||||||
SandboxId: sandboxID,
|
SandboxId: sandboxID,
|
||||||
Path: req.Path,
|
Path: req.Path,
|
||||||
}))
|
}))
|
||||||
|
|||||||
@ -12,17 +12,17 @@ import (
|
|||||||
|
|
||||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
|
||||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type filesStreamHandler struct {
|
type filesStreamHandler struct {
|
||||||
db *db.Queries
|
db *db.Queries
|
||||||
agent hostagentv1connect.HostAgentServiceClient
|
pool *lifecycle.HostClientPool
|
||||||
}
|
}
|
||||||
|
|
||||||
func newFilesStreamHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *filesStreamHandler {
|
func newFilesStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool) *filesStreamHandler {
|
||||||
return &filesStreamHandler{db: db, agent: agent}
|
return &filesStreamHandler{db: db, pool: pool}
|
||||||
}
|
}
|
||||||
|
|
||||||
// StreamUpload handles POST /v1/sandboxes/{id}/files/stream/write.
|
// StreamUpload handles POST /v1/sandboxes/{id}/files/stream/write.
|
||||||
@ -88,8 +88,14 @@ func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request
|
|||||||
}
|
}
|
||||||
defer filePart.Close()
|
defer filePart.Close()
|
||||||
|
|
||||||
|
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||||
|
if err != nil {
|
||||||
|
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Open client-streaming RPC to host agent.
|
// Open client-streaming RPC to host agent.
|
||||||
stream := h.agent.WriteFileStream(ctx)
|
stream := agent.WriteFileStream(ctx)
|
||||||
|
|
||||||
// Send metadata first.
|
// Send metadata first.
|
||||||
if err := stream.Send(&pb.WriteFileStreamRequest{
|
if err := stream.Send(&pb.WriteFileStreamRequest{
|
||||||
@ -164,8 +170,14 @@ func (h *filesStreamHandler) StreamDownload(w http.ResponseWriter, r *http.Reque
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||||
|
if err != nil {
|
||||||
|
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Open server-streaming RPC to host agent.
|
// Open server-streaming RPC to host agent.
|
||||||
stream, err := h.agent.ReadFileStream(ctx, connect.NewRequest(&pb.ReadFileStreamRequest{
|
stream, err := agent.ReadFileStream(ctx, connect.NewRequest(&pb.ReadFileStreamRequest{
|
||||||
SandboxId: sandboxID,
|
SandboxId: sandboxID,
|
||||||
Path: req.Path,
|
Path: req.Path,
|
||||||
}))
|
}))
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -34,6 +35,25 @@ type createHostResponse struct {
|
|||||||
RegistrationToken string `json:"registration_token"`
|
RegistrationToken string `json:"registration_token"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type refreshTokenRequest struct {
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type refreshTokenResponse struct {
|
||||||
|
Host hostResponse `json:"host"`
|
||||||
|
Token string `json:"token"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type deletePreviewResponse struct {
|
||||||
|
Host hostResponse `json:"host"`
|
||||||
|
SandboxIDs []string `json:"sandbox_ids"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type hasSandboxesErrorResponse struct {
|
||||||
|
SandboxIDs []string `json:"sandbox_ids"`
|
||||||
|
}
|
||||||
|
|
||||||
type registerHostRequest struct {
|
type registerHostRequest struct {
|
||||||
Token string `json:"token"`
|
Token string `json:"token"`
|
||||||
Arch string `json:"arch,omitempty"`
|
Arch string `json:"arch,omitempty"`
|
||||||
@ -46,6 +66,7 @@ type registerHostRequest struct {
|
|||||||
type registerHostResponse struct {
|
type registerHostResponse struct {
|
||||||
Host hostResponse `json:"host"`
|
Host hostResponse `json:"host"`
|
||||||
Token string `json:"token"`
|
Token string `json:"token"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type addTagRequest struct {
|
type addTagRequest struct {
|
||||||
@ -183,18 +204,54 @@ func (h *hostHandler) Get(w http.ResponseWriter, r *http.Request) {
|
|||||||
writeJSON(w, http.StatusOK, hostToResponse(host))
|
writeJSON(w, http.StatusOK, hostToResponse(host))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete handles DELETE /v1/hosts/{id}.
|
// DeletePreview handles GET /v1/hosts/{id}/delete-preview.
|
||||||
func (h *hostHandler) Delete(w http.ResponseWriter, r *http.Request) {
|
// Returns what would be affected without making changes, for confirmation UI.
|
||||||
|
func (h *hostHandler) DeletePreview(w http.ResponseWriter, r *http.Request) {
|
||||||
hostID := chi.URLParam(r, "id")
|
hostID := chi.URLParam(r, "id")
|
||||||
ac := auth.MustFromContext(r.Context())
|
ac := auth.MustFromContext(r.Context())
|
||||||
|
|
||||||
if err := h.svc.Delete(r.Context(), hostID, ac.UserID, ac.TeamID, h.isAdmin(r, ac.UserID)); err != nil {
|
preview, err := h.svc.DeletePreview(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID))
|
||||||
|
if err != nil {
|
||||||
status, code, msg := serviceErrToHTTP(err)
|
status, code, msg := serviceErrToHTTP(err)
|
||||||
writeError(w, status, code, msg)
|
writeError(w, status, code, msg)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
writeJSON(w, http.StatusOK, deletePreviewResponse{
|
||||||
|
Host: hostToResponse(preview.Host),
|
||||||
|
SandboxIDs: preview.SandboxIDs,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete handles DELETE /v1/hosts/{id}.
|
||||||
|
// Without ?force=true: returns 409 with affected sandbox IDs if any are active.
|
||||||
|
// With ?force=true: gracefully stops all sandboxes then deletes the host.
|
||||||
|
func (h *hostHandler) Delete(w http.ResponseWriter, r *http.Request) {
|
||||||
|
hostID := chi.URLParam(r, "id")
|
||||||
|
ac := auth.MustFromContext(r.Context())
|
||||||
|
force := r.URL.Query().Get("force") == "true"
|
||||||
|
|
||||||
|
err := h.svc.Delete(r.Context(), hostID, ac.UserID, ac.TeamID, h.isAdmin(r, ac.UserID), force)
|
||||||
|
if err == nil {
|
||||||
w.WriteHeader(http.StatusNoContent)
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if it's a "has running sandboxes" error and return a structured 409.
|
||||||
|
var hasSandboxes *service.HostHasSandboxesError
|
||||||
|
if errors.As(err, &hasSandboxes) {
|
||||||
|
writeJSON(w, http.StatusConflict, map[string]any{
|
||||||
|
"error": map[string]any{
|
||||||
|
"code": "has_active_sandboxes",
|
||||||
|
"message": "host has active sandboxes; use ?force=true to destroy them and delete the host",
|
||||||
|
"sandbox_ids": hasSandboxes.SandboxIDs,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
status, code, msg := serviceErrToHTTP(err)
|
||||||
|
writeError(w, status, code, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegenerateToken handles POST /v1/hosts/{id}/token.
|
// RegenerateToken handles POST /v1/hosts/{id}/token.
|
||||||
@ -249,6 +306,7 @@ func (h *hostHandler) Register(w http.ResponseWriter, r *http.Request) {
|
|||||||
writeJSON(w, http.StatusCreated, registerHostResponse{
|
writeJSON(w, http.StatusCreated, registerHostResponse{
|
||||||
Host: hostToResponse(result.Host),
|
Host: hostToResponse(result.Host),
|
||||||
Token: result.JWT,
|
Token: result.JWT,
|
||||||
|
RefreshToken: result.RefreshToken,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -311,6 +369,33 @@ func (h *hostHandler) RemoveTag(w http.ResponseWriter, r *http.Request) {
|
|||||||
w.WriteHeader(http.StatusNoContent)
|
w.WriteHeader(http.StatusNoContent)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RefreshToken handles POST /v1/hosts/auth/refresh (unauthenticated).
|
||||||
|
// The host agent sends its refresh token to receive a new JWT and rotated refresh token.
|
||||||
|
func (h *hostHandler) RefreshToken(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var req refreshTokenRequest
|
||||||
|
if err := decodeJSON(r, &req); err != nil {
|
||||||
|
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.RefreshToken == "" {
|
||||||
|
writeError(w, http.StatusBadRequest, "invalid_request", "refresh_token is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := h.svc.Refresh(r.Context(), req.RefreshToken)
|
||||||
|
if err != nil {
|
||||||
|
status, code, msg := serviceErrToHTTP(err)
|
||||||
|
writeError(w, status, code, msg)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
writeJSON(w, http.StatusOK, refreshTokenResponse{
|
||||||
|
Host: hostToResponse(result.Host),
|
||||||
|
Token: result.JWT,
|
||||||
|
RefreshToken: result.RefreshToken,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// ListTags handles GET /v1/hosts/{id}/tags.
|
// ListTags handles GET /v1/hosts/{id}/tags.
|
||||||
func (h *hostHandler) ListTags(w http.ResponseWriter, r *http.Request) {
|
func (h *hostHandler) ListTags(w http.ResponseWriter, r *http.Request) {
|
||||||
hostID := chi.URLParam(r, "id")
|
hostID := chi.URLParam(r, "id")
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
@ -14,20 +15,45 @@ import (
|
|||||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
|
||||||
"git.omukk.dev/wrenn/sandbox/internal/service"
|
"git.omukk.dev/wrenn/sandbox/internal/service"
|
||||||
"git.omukk.dev/wrenn/sandbox/internal/validate"
|
"git.omukk.dev/wrenn/sandbox/internal/validate"
|
||||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type snapshotHandler struct {
|
type snapshotHandler struct {
|
||||||
svc *service.TemplateService
|
svc *service.TemplateService
|
||||||
db *db.Queries
|
db *db.Queries
|
||||||
agent hostagentv1connect.HostAgentServiceClient
|
pool *lifecycle.HostClientPool
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSnapshotHandler(svc *service.TemplateService, db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *snapshotHandler {
|
func newSnapshotHandler(svc *service.TemplateService, db *db.Queries, pool *lifecycle.HostClientPool) *snapshotHandler {
|
||||||
return &snapshotHandler{svc: svc, db: db, agent: agent}
|
return &snapshotHandler{svc: svc, db: db, pool: pool}
|
||||||
|
}
|
||||||
|
|
||||||
|
// deleteSnapshotBroadcast attempts to delete snapshot files on all online hosts.
|
||||||
|
// Snapshots aren't currently host-tracked in the DB, so we broadcast to all hosts
|
||||||
|
// and ignore NotFound errors. TODO: add host_id to templates table.
|
||||||
|
func (h *snapshotHandler) deleteSnapshotBroadcast(ctx context.Context, name string) error {
|
||||||
|
hosts, err := h.db.ListActiveHosts(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("list hosts: %w", err)
|
||||||
|
}
|
||||||
|
for _, host := range hosts {
|
||||||
|
if host.Status != "online" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
agent, err := h.pool.GetForHost(host)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, err := agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{Name: name})); err != nil {
|
||||||
|
if connect.CodeOf(err) != connect.CodeNotFound {
|
||||||
|
slog.Warn("snapshot: failed to delete on host", "host_id", host.ID, "name", name, "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type createSnapshotRequest struct {
|
type createSnapshotRequest struct {
|
||||||
@ -93,10 +119,9 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
|
|||||||
writeError(w, http.StatusConflict, "already_exists", "snapshot name already exists; use ?overwrite=true to replace")
|
writeError(w, http.StatusConflict, "already_exists", "snapshot name already exists; use ?overwrite=true to replace")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Delete old files from the agent before removing the DB record.
|
// Delete old snapshot files from all hosts before removing the DB record.
|
||||||
if _, err := h.agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{Name: req.Name})); err != nil {
|
if err := h.deleteSnapshotBroadcast(ctx, req.Name); err != nil {
|
||||||
status, code, msg := agentErrToHTTP(err)
|
writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete existing snapshot files")
|
||||||
writeError(w, status, code, "failed to delete existing snapshot files: "+msg)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := h.db.DeleteTemplateByTeam(ctx, db.DeleteTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err != nil {
|
if err := h.db.DeleteTemplateByTeam(ctx, db.DeleteTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err != nil {
|
||||||
@ -116,7 +141,13 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := h.agent.CreateSnapshot(ctx, connect.NewRequest(&pb.CreateSnapshotRequest{
|
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||||
|
if err != nil {
|
||||||
|
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := agent.CreateSnapshot(ctx, connect.NewRequest(&pb.CreateSnapshotRequest{
|
||||||
SandboxId: req.SandboxID,
|
SandboxId: req.SandboxID,
|
||||||
Name: req.Name,
|
Name: req.Name,
|
||||||
}))
|
}))
|
||||||
@ -186,11 +217,8 @@ func (h *snapshotHandler) Delete(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := h.agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{
|
if err := h.deleteSnapshotBroadcast(ctx, name); err != nil {
|
||||||
Name: name,
|
writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete snapshot files")
|
||||||
})); err != nil {
|
|
||||||
status, code, msg := agentErrToHTTP(err)
|
|
||||||
writeError(w, status, code, "failed to delete snapshot files: "+msg)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
198
internal/api/host_monitor.go
Normal file
198
internal/api/host_monitor.go
Normal file
@ -0,0 +1,198 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log/slog"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
|
||||||
|
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||||
|
|
||||||
|
"connectrpc.com/connect"
|
||||||
|
)
|
||||||
|
|
||||||
|
// unreachableThreshold is the longest heartbeat silence tolerated before a
// host is considered unreachable: three missed heartbeats at the 30-second
// heartbeat cadence.
const unreachableThreshold = 3 * 30 * time.Second
|
||||||
|
|
||||||
|
// HostMonitor runs on a fixed interval and performs two duties:
|
||||||
|
//
|
||||||
|
// 1. Passive check: marks hosts whose last_heartbeat_at is stale as
|
||||||
|
// "unreachable" and marks their active sandboxes as "missing".
|
||||||
|
//
|
||||||
|
// 2. Active reconciliation: for each online host, calls ListSandboxes and
|
||||||
|
// reconciles DB state against live host state — restoring "missing"
|
||||||
|
// sandboxes that are actually alive, and stopping orphaned ones.
|
||||||
|
type HostMonitor struct {
|
||||||
|
db *db.Queries
|
||||||
|
pool *lifecycle.HostClientPool
|
||||||
|
interval time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHostMonitor creates a HostMonitor.
|
||||||
|
func NewHostMonitor(queries *db.Queries, pool *lifecycle.HostClientPool, interval time.Duration) *HostMonitor {
|
||||||
|
return &HostMonitor{
|
||||||
|
db: queries,
|
||||||
|
pool: pool,
|
||||||
|
interval: interval,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start runs the monitor loop until the context is cancelled.
|
||||||
|
func (m *HostMonitor) Start(ctx context.Context) {
|
||||||
|
go func() {
|
||||||
|
ticker := time.NewTicker(m.interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
m.run(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *HostMonitor) run(ctx context.Context) {
|
||||||
|
hosts, err := m.db.ListActiveHosts(ctx)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("host monitor: failed to list hosts", "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, host := range hosts {
|
||||||
|
m.checkHost(ctx, host)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) {
|
||||||
|
// --- Passive phase: check heartbeat staleness ---
|
||||||
|
|
||||||
|
stale := !host.LastHeartbeatAt.Valid ||
|
||||||
|
time.Since(host.LastHeartbeatAt.Time) > unreachableThreshold
|
||||||
|
|
||||||
|
if stale && host.Status != "unreachable" {
|
||||||
|
slog.Info("host monitor: marking host unreachable", "host_id", host.ID,
|
||||||
|
"last_heartbeat", host.LastHeartbeatAt.Time)
|
||||||
|
if err := m.db.MarkHostUnreachable(ctx, host.ID); err != nil {
|
||||||
|
slog.Warn("host monitor: failed to mark host unreachable", "host_id", host.ID, "error", err)
|
||||||
|
}
|
||||||
|
if err := m.db.MarkSandboxesMissingByHost(ctx, host.ID); err != nil {
|
||||||
|
slog.Warn("host monitor: failed to mark sandboxes missing", "host_id", host.ID, "error", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Active reconciliation: only for online hosts ---
|
||||||
|
|
||||||
|
if host.Status != "online" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
agent, err := m.pool.GetForHost(host)
|
||||||
|
if err != nil {
|
||||||
|
// Host has no address yet (e.g., just registered) — skip.
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := agent.ListSandboxes(ctx, connect.NewRequest(&pb.ListSandboxesRequest{}))
|
||||||
|
if err != nil {
|
||||||
|
// RPC failure is a transient condition; the passive phase will catch it
|
||||||
|
// if heartbeats stop arriving.
|
||||||
|
slog.Debug("host monitor: ListSandboxes failed (transient)", "host_id", host.ID, "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build set of sandbox IDs alive on the host.
|
||||||
|
alive := make(map[string]struct{}, len(resp.Msg.Sandboxes))
|
||||||
|
for _, sb := range resp.Msg.Sandboxes {
|
||||||
|
alive[sb.SandboxId] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
autoPaused := make(map[string]struct{}, len(resp.Msg.AutoPausedSandboxIds))
|
||||||
|
for _, id := range resp.Msg.AutoPausedSandboxIds {
|
||||||
|
autoPaused[id] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Restore sandboxes that are "missing" in DB but alive on host ---
|
||||||
|
// This handles the case where CP marked them missing due to a transient
|
||||||
|
// heartbeat gap, but the host was actually fine.
|
||||||
|
|
||||||
|
missingSandboxes, err := m.db.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
|
||||||
|
HostID: host.ID,
|
||||||
|
Column2: []string{"missing"},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("host monitor: failed to list missing sandboxes", "host_id", host.ID, "error", err)
|
||||||
|
} else {
|
||||||
|
var toRestore []string
|
||||||
|
var toStop []string
|
||||||
|
for _, sb := range missingSandboxes {
|
||||||
|
if _, ok := alive[sb.ID]; ok {
|
||||||
|
toRestore = append(toRestore, sb.ID)
|
||||||
|
} else {
|
||||||
|
toStop = append(toStop, sb.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(toRestore) > 0 {
|
||||||
|
slog.Info("host monitor: restoring missing sandboxes", "host_id", host.ID, "count", len(toRestore))
|
||||||
|
if err := m.db.BulkRestoreRunning(ctx, toRestore); err != nil {
|
||||||
|
slog.Warn("host monitor: failed to restore missing sandboxes", "host_id", host.ID, "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(toStop) > 0 {
|
||||||
|
slog.Info("host monitor: stopping confirmed-dead missing sandboxes", "host_id", host.ID, "count", len(toStop))
|
||||||
|
if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
|
||||||
|
Column1: toStop,
|
||||||
|
Status: "stopped",
|
||||||
|
}); err != nil {
|
||||||
|
slog.Warn("host monitor: failed to stop missing sandboxes", "host_id", host.ID, "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Find running sandboxes in DB that are no longer alive on the host ---
|
||||||
|
|
||||||
|
runningSandboxes, err := m.db.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
|
||||||
|
HostID: host.ID,
|
||||||
|
Column2: []string{"running"},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("host monitor: failed to list running sandboxes", "host_id", host.ID, "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var toPause, toStop []string
|
||||||
|
for _, sb := range runningSandboxes {
|
||||||
|
if _, ok := alive[sb.ID]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := autoPaused[sb.ID]; ok {
|
||||||
|
toPause = append(toPause, sb.ID)
|
||||||
|
} else {
|
||||||
|
toStop = append(toStop, sb.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(toPause) > 0 {
|
||||||
|
slog.Info("host monitor: marking auto-paused sandboxes", "host_id", host.ID, "count", len(toPause))
|
||||||
|
if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
|
||||||
|
Column1: toPause,
|
||||||
|
Status: "paused",
|
||||||
|
}); err != nil {
|
||||||
|
slog.Warn("host monitor: failed to mark paused", "host_id", host.ID, "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(toStop) > 0 {
|
||||||
|
slog.Info("host monitor: marking orphaned sandboxes stopped", "host_id", host.ID, "count", len(toStop))
|
||||||
|
if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
|
||||||
|
Column1: toStop,
|
||||||
|
Status: "stopped",
|
||||||
|
}); err != nil {
|
||||||
|
slog.Warn("host monitor: failed to mark stopped", "host_id", host.ID, "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -89,6 +89,8 @@ func serviceErrToHTTP(err error) (int, string, string) {
|
|||||||
return http.StatusConflict, "invalid_state", msg
|
return http.StatusConflict, "invalid_state", msg
|
||||||
case strings.Contains(msg, "forbidden"):
|
case strings.Contains(msg, "forbidden"):
|
||||||
return http.StatusForbidden, "forbidden", msg
|
return http.StatusForbidden, "forbidden", msg
|
||||||
|
case strings.Contains(msg, "invalid or expired"):
|
||||||
|
return http.StatusUnauthorized, "unauthorized", msg
|
||||||
case strings.Contains(msg, "invalid"):
|
case strings.Contains(msg, "invalid"):
|
||||||
return http.StatusBadRequest, "invalid_request", msg
|
return http.StatusBadRequest, "invalid_request", msg
|
||||||
default:
|
default:
|
||||||
|
|||||||
@ -1193,8 +1193,16 @@ paths:
|
|||||||
security:
|
security:
|
||||||
- bearerAuth: []
|
- bearerAuth: []
|
||||||
description: |
|
description: |
|
||||||
Admins can delete any host. Team owners can delete BYOC hosts
|
Admins can delete any host. Team owners and admins can delete BYOC hosts
|
||||||
belonging to their team.
|
belonging to their team. Without `?force=true`, returns 409 if the host
|
||||||
|
has active sandboxes. With `?force=true`, destroys all sandboxes first.
|
||||||
|
parameters:
|
||||||
|
- name: force
|
||||||
|
in: query
|
||||||
|
required: false
|
||||||
|
schema:
|
||||||
|
type: boolean
|
||||||
|
description: If true, destroy all sandboxes on the host before deleting.
|
||||||
responses:
|
responses:
|
||||||
"204":
|
"204":
|
||||||
description: Host deleted
|
description: Host deleted
|
||||||
@ -1204,6 +1212,12 @@ paths:
|
|||||||
application/json:
|
application/json:
|
||||||
schema:
|
schema:
|
||||||
$ref: "#/components/schemas/Error"
|
$ref: "#/components/schemas/Error"
|
||||||
|
"409":
|
||||||
|
description: Host has active sandboxes (only when force is not set)
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: "#/components/schemas/HostHasSandboxesError"
|
||||||
|
|
||||||
/v1/hosts/{id}/token:
|
/v1/hosts/{id}/token:
|
||||||
parameters:
|
parameters:
|
||||||
@ -1312,6 +1326,72 @@ paths:
|
|||||||
schema:
|
schema:
|
||||||
$ref: "#/components/schemas/Error"
|
$ref: "#/components/schemas/Error"
|
||||||
|
|
||||||
|
/v1/hosts/auth/refresh:
|
||||||
|
post:
|
||||||
|
summary: Refresh host JWT
|
||||||
|
operationId: refreshHostToken
|
||||||
|
tags: [hosts]
|
||||||
|
description: |
|
||||||
|
Exchanges a refresh token for a new JWT and rotated refresh token.
|
||||||
|
The old refresh token is immediately revoked. No authentication required —
|
||||||
|
the refresh token itself is the credential.
|
||||||
|
requestBody:
|
||||||
|
required: true
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: "#/components/schemas/RefreshHostTokenRequest"
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: New JWT and rotated refresh token
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: "#/components/schemas/RefreshHostTokenResponse"
|
||||||
|
"401":
|
||||||
|
description: Invalid, expired, or revoked refresh token
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: "#/components/schemas/Error"
|
||||||
|
|
||||||
|
/v1/hosts/{id}/delete-preview:
|
||||||
|
parameters:
|
||||||
|
- name: id
|
||||||
|
in: path
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
|
||||||
|
get:
|
||||||
|
summary: Preview host deletion
|
||||||
|
operationId: getHostDeletePreview
|
||||||
|
tags: [hosts]
|
||||||
|
security:
|
||||||
|
- bearerAuth: []
|
||||||
|
description: |
|
||||||
|
Returns the list of sandbox IDs that would be destroyed if the host
|
||||||
|
were deleted with `?force=true`. No state is modified.
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: Deletion preview
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: "#/components/schemas/HostDeletePreview"
|
||||||
|
"403":
|
||||||
|
description: Insufficient permissions
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: "#/components/schemas/Error"
|
||||||
|
"404":
|
||||||
|
description: Host not found
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: "#/components/schemas/Error"
|
||||||
|
|
||||||
/v1/hosts/{id}/tags:
|
/v1/hosts/{id}/tags:
|
||||||
parameters:
|
parameters:
|
||||||
- name: id
|
- name: id
|
||||||
@ -1405,7 +1485,7 @@ components:
|
|||||||
type: apiKey
|
type: apiKey
|
||||||
in: header
|
in: header
|
||||||
name: X-Host-Token
|
name: X-Host-Token
|
||||||
description: Long-lived host JWT returned from POST /v1/hosts/register. Valid for 1 year.
|
description: Host JWT returned from POST /v1/hosts/register or POST /v1/hosts/auth/refresh. Valid for 7 days.
|
||||||
|
|
||||||
schemas:
|
schemas:
|
||||||
SignupRequest:
|
SignupRequest:
|
||||||
@ -1505,7 +1585,7 @@ components:
|
|||||||
type: string
|
type: string
|
||||||
status:
|
status:
|
||||||
type: string
|
type: string
|
||||||
enum: [pending, running, paused, stopped, error]
|
enum: [pending, starting, running, paused, hibernated, stopped, missing, error]
|
||||||
template:
|
template:
|
||||||
type: string
|
type: string
|
||||||
vcpus:
|
vcpus:
|
||||||
@ -1661,7 +1741,10 @@ components:
|
|||||||
$ref: "#/components/schemas/Host"
|
$ref: "#/components/schemas/Host"
|
||||||
token:
|
token:
|
||||||
type: string
|
type: string
|
||||||
description: Long-lived host JWT for X-Host-Token header. Valid for 1 year.
|
description: Host JWT for X-Host-Token header. Valid for 7 days.
|
||||||
|
refresh_token:
|
||||||
|
type: string
|
||||||
|
description: Refresh token for obtaining new JWTs. Valid for 60 days; rotated on each use.
|
||||||
|
|
||||||
Host:
|
Host:
|
||||||
type: object
|
type: object
|
||||||
@ -1697,7 +1780,7 @@ components:
|
|||||||
nullable: true
|
nullable: true
|
||||||
status:
|
status:
|
||||||
type: string
|
type: string
|
||||||
enum: [pending, online, offline, draining]
|
enum: [pending, online, offline, draining, unreachable]
|
||||||
last_heartbeat_at:
|
last_heartbeat_at:
|
||||||
type: string
|
type: string
|
||||||
format: date-time
|
format: date-time
|
||||||
@ -1711,6 +1794,54 @@ components:
|
|||||||
type: string
|
type: string
|
||||||
format: date-time
|
format: date-time
|
||||||
|
|
||||||
|
RefreshHostTokenRequest:
|
||||||
|
type: object
|
||||||
|
required: [refresh_token]
|
||||||
|
properties:
|
||||||
|
refresh_token:
|
||||||
|
type: string
|
||||||
|
description: Refresh token obtained from registration or a previous refresh.
|
||||||
|
|
||||||
|
RefreshHostTokenResponse:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
host:
|
||||||
|
$ref: "#/components/schemas/Host"
|
||||||
|
token:
|
||||||
|
type: string
|
||||||
|
description: New host JWT. Valid for 7 days.
|
||||||
|
refresh_token:
|
||||||
|
type: string
|
||||||
|
description: New refresh token. Valid for 60 days; old token is revoked.
|
||||||
|
|
||||||
|
HostDeletePreview:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
host:
|
||||||
|
$ref: "#/components/schemas/Host"
|
||||||
|
sandbox_ids:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
type: string
|
||||||
|
description: IDs of sandboxes that would be destroyed on force-delete.
|
||||||
|
|
||||||
|
HostHasSandboxesError:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
error:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
code:
|
||||||
|
type: string
|
||||||
|
example: host_has_sandboxes
|
||||||
|
message:
|
||||||
|
type: string
|
||||||
|
sandbox_ids:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
type: string
|
||||||
|
description: IDs of active sandboxes blocking deletion.
|
||||||
|
|
||||||
AddTagRequest:
|
AddTagRequest:
|
||||||
type: object
|
type: object
|
||||||
required: [tag]
|
required: [tag]
|
||||||
|
|||||||
@ -1,126 +0,0 @@
|
|||||||
package api
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"log/slog"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"connectrpc.com/connect"
|
|
||||||
|
|
||||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
|
||||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
|
||||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Reconciler periodically compares the host agent's sandbox list with the DB
|
|
||||||
// and marks sandboxes that no longer exist on the host as stopped.
|
|
||||||
type Reconciler struct {
|
|
||||||
db *db.Queries
|
|
||||||
agent hostagentv1connect.HostAgentServiceClient
|
|
||||||
hostID string
|
|
||||||
interval time.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewReconciler creates a new reconciler.
|
|
||||||
func NewReconciler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient, hostID string, interval time.Duration) *Reconciler {
|
|
||||||
return &Reconciler{
|
|
||||||
db: db,
|
|
||||||
agent: agent,
|
|
||||||
hostID: hostID,
|
|
||||||
interval: interval,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start runs the reconciliation loop until the context is cancelled.
|
|
||||||
func (rc *Reconciler) Start(ctx context.Context) {
|
|
||||||
go func() {
|
|
||||||
ticker := time.NewTicker(rc.interval)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
case <-ticker.C:
|
|
||||||
rc.reconcile(ctx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rc *Reconciler) reconcile(ctx context.Context) {
|
|
||||||
// Single RPC returns both the running sandbox list and any IDs that
|
|
||||||
// were auto-paused by the TTL reaper since the last call.
|
|
||||||
resp, err := rc.agent.ListSandboxes(ctx, connect.NewRequest(&pb.ListSandboxesRequest{}))
|
|
||||||
if err != nil {
|
|
||||||
slog.Warn("reconciler: failed to list sandboxes from host agent", "error", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build a set of sandbox IDs that are alive on the host.
|
|
||||||
alive := make(map[string]struct{}, len(resp.Msg.Sandboxes))
|
|
||||||
for _, sb := range resp.Msg.Sandboxes {
|
|
||||||
alive[sb.SandboxId] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build auto-paused set from the same response.
|
|
||||||
autoPausedSet := make(map[string]struct{}, len(resp.Msg.AutoPausedSandboxIds))
|
|
||||||
for _, id := range resp.Msg.AutoPausedSandboxIds {
|
|
||||||
autoPausedSet[id] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get all DB sandboxes for this host that are running.
|
|
||||||
// Paused sandboxes are excluded: they are expected to not exist on the
|
|
||||||
// host agent because pause = snapshot + destroy resources.
|
|
||||||
dbSandboxes, err := rc.db.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
|
|
||||||
HostID: rc.hostID,
|
|
||||||
Column2: []string{"running"},
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
slog.Warn("reconciler: failed to list DB sandboxes", "error", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Find sandboxes in DB that are no longer on the host.
|
|
||||||
var stale []string
|
|
||||||
for _, sb := range dbSandboxes {
|
|
||||||
if _, ok := alive[sb.ID]; !ok {
|
|
||||||
stale = append(stale, sb.ID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(stale) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Split stale sandboxes into those auto-paused by the TTL reaper vs
|
|
||||||
// those that crashed/were orphaned.
|
|
||||||
var toPause, toStop []string
|
|
||||||
for _, id := range stale {
|
|
||||||
if _, ok := autoPausedSet[id]; ok {
|
|
||||||
toPause = append(toPause, id)
|
|
||||||
} else {
|
|
||||||
toStop = append(toStop, id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(toPause) > 0 {
|
|
||||||
slog.Info("reconciler: marking auto-paused sandboxes", "count", len(toPause), "ids", toPause)
|
|
||||||
if err := rc.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
|
|
||||||
Column1: toPause,
|
|
||||||
Status: "paused",
|
|
||||||
}); err != nil {
|
|
||||||
slog.Warn("reconciler: failed to mark auto-paused sandboxes", "error", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(toStop) > 0 {
|
|
||||||
slog.Info("reconciler: marking stale sandboxes as stopped", "count", len(toStop), "ids", toStop)
|
|
||||||
if err := rc.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
|
|
||||||
Column1: toStop,
|
|
||||||
Status: "stopped",
|
|
||||||
}); err != nil {
|
|
||||||
slog.Warn("reconciler: failed to update stale sandboxes", "error", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -11,8 +11,9 @@ import (
|
|||||||
|
|
||||||
"git.omukk.dev/wrenn/sandbox/internal/auth/oauth"
|
"git.omukk.dev/wrenn/sandbox/internal/auth/oauth"
|
||||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/internal/scheduler"
|
||||||
"git.omukk.dev/wrenn/sandbox/internal/service"
|
"git.omukk.dev/wrenn/sandbox/internal/service"
|
||||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
//go:embed openapi.yaml
|
//go:embed openapi.yaml
|
||||||
@ -24,25 +25,34 @@ type Server struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// New constructs the chi router and registers all routes.
|
// New constructs the chi router and registers all routes.
|
||||||
func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, pool *pgxpool.Pool, rdb *redis.Client, jwtSecret []byte, oauthRegistry *oauth.Registry, oauthRedirectURL string) *Server {
|
func New(
|
||||||
|
queries *db.Queries,
|
||||||
|
pool *lifecycle.HostClientPool,
|
||||||
|
sched scheduler.HostScheduler,
|
||||||
|
pgPool *pgxpool.Pool,
|
||||||
|
rdb *redis.Client,
|
||||||
|
jwtSecret []byte,
|
||||||
|
oauthRegistry *oauth.Registry,
|
||||||
|
oauthRedirectURL string,
|
||||||
|
) *Server {
|
||||||
r := chi.NewRouter()
|
r := chi.NewRouter()
|
||||||
r.Use(requestLogger())
|
r.Use(requestLogger())
|
||||||
|
|
||||||
// Shared service layer.
|
// Shared service layer.
|
||||||
sandboxSvc := &service.SandboxService{DB: queries, Agent: agent}
|
sandboxSvc := &service.SandboxService{DB: queries, Pool: pool, Scheduler: sched}
|
||||||
apiKeySvc := &service.APIKeyService{DB: queries}
|
apiKeySvc := &service.APIKeyService{DB: queries}
|
||||||
templateSvc := &service.TemplateService{DB: queries}
|
templateSvc := &service.TemplateService{DB: queries}
|
||||||
hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret}
|
hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret, Pool: pool}
|
||||||
teamSvc := &service.TeamService{DB: queries, Pool: pool, Agent: agent}
|
teamSvc := &service.TeamService{DB: queries, Pool: pgPool, HostPool: pool}
|
||||||
|
|
||||||
sandbox := newSandboxHandler(sandboxSvc)
|
sandbox := newSandboxHandler(sandboxSvc)
|
||||||
exec := newExecHandler(queries, agent)
|
exec := newExecHandler(queries, pool)
|
||||||
execStream := newExecStreamHandler(queries, agent)
|
execStream := newExecStreamHandler(queries, pool)
|
||||||
files := newFilesHandler(queries, agent)
|
files := newFilesHandler(queries, pool)
|
||||||
filesStream := newFilesStreamHandler(queries, agent)
|
filesStream := newFilesStreamHandler(queries, pool)
|
||||||
snapshots := newSnapshotHandler(templateSvc, queries, agent)
|
snapshots := newSnapshotHandler(templateSvc, queries, pool)
|
||||||
authH := newAuthHandler(queries, pool, jwtSecret)
|
authH := newAuthHandler(queries, pgPool, jwtSecret)
|
||||||
oauthH := newOAuthHandler(queries, pool, jwtSecret, oauthRegistry, oauthRedirectURL)
|
oauthH := newOAuthHandler(queries, pgPool, jwtSecret, oauthRegistry, oauthRedirectURL)
|
||||||
apiKeys := newAPIKeyHandler(apiKeySvc)
|
apiKeys := newAPIKeyHandler(apiKeySvc)
|
||||||
hostH := newHostHandler(hostSvc, queries)
|
hostH := newHostHandler(hostSvc, queries)
|
||||||
teamH := newTeamHandler(teamSvc)
|
teamH := newTeamHandler(teamSvc)
|
||||||
@ -123,6 +133,9 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p
|
|||||||
// Unauthenticated: one-time registration token.
|
// Unauthenticated: one-time registration token.
|
||||||
r.Post("/register", hostH.Register)
|
r.Post("/register", hostH.Register)
|
||||||
|
|
||||||
|
// Unauthenticated: refresh token exchange.
|
||||||
|
r.Post("/auth/refresh", hostH.RefreshToken)
|
||||||
|
|
||||||
// Host-token-authenticated: heartbeat.
|
// Host-token-authenticated: heartbeat.
|
||||||
r.With(requireHostToken(jwtSecret)).Post("/{id}/heartbeat", hostH.Heartbeat)
|
r.With(requireHostToken(jwtSecret)).Post("/{id}/heartbeat", hostH.Heartbeat)
|
||||||
|
|
||||||
@ -134,6 +147,7 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p
|
|||||||
r.Route("/{id}", func(r chi.Router) {
|
r.Route("/{id}", func(r chi.Router) {
|
||||||
r.Get("/", hostH.Get)
|
r.Get("/", hostH.Get)
|
||||||
r.Delete("/", hostH.Delete)
|
r.Delete("/", hostH.Delete)
|
||||||
|
r.Get("/delete-preview", hostH.DeletePreview)
|
||||||
r.Post("/token", hostH.RegenerateToken)
|
r.Post("/token", hostH.RegenerateToken)
|
||||||
r.Get("/tags", hostH.ListTags)
|
r.Get("/tags", hostH.ListTags)
|
||||||
r.Post("/tags", hostH.AddTag)
|
r.Post("/tags", hostH.AddTag)
|
||||||
|
|||||||
@ -8,7 +8,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const jwtExpiry = 6 * time.Hour
|
const jwtExpiry = 6 * time.Hour
|
||||||
const hostJWTExpiry = 8760 * time.Hour // 1 year
|
const hostJWTExpiry = 7 * 24 * time.Hour // 7 days; host refreshes via refresh token
|
||||||
|
const HostRefreshTokenExpiry = 60 * 24 * time.Hour // 60 days; exported for service layer
|
||||||
|
|
||||||
// Claims are the JWT payload for user tokens.
|
// Claims are the JWT payload for user tokens.
|
||||||
type Claims struct {
|
type Claims struct {
|
||||||
|
|||||||
@ -2,7 +2,6 @@ package config
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/joho/godotenv"
|
"github.com/joho/godotenv"
|
||||||
)
|
)
|
||||||
@ -12,7 +11,6 @@ type Config struct {
|
|||||||
DatabaseURL string
|
DatabaseURL string
|
||||||
RedisURL string
|
RedisURL string
|
||||||
ListenAddr string
|
ListenAddr string
|
||||||
HostAgentAddr string
|
|
||||||
JWTSecret string
|
JWTSecret string
|
||||||
|
|
||||||
OAuthGitHubClientID string
|
OAuthGitHubClientID string
|
||||||
@ -27,11 +25,10 @@ func Load() Config {
|
|||||||
// Best-effort load — missing .env file is fine.
|
// Best-effort load — missing .env file is fine.
|
||||||
_ = godotenv.Load()
|
_ = godotenv.Load()
|
||||||
|
|
||||||
cfg := Config{
|
return Config{
|
||||||
DatabaseURL: envOrDefault("DATABASE_URL", "postgres://wrenn:wrenn@localhost:5432/wrenn?sslmode=disable"),
|
DatabaseURL: envOrDefault("DATABASE_URL", "postgres://wrenn:wrenn@localhost:5432/wrenn?sslmode=disable"),
|
||||||
RedisURL: envOrDefault("REDIS_URL", "redis://localhost:6379/0"),
|
RedisURL: envOrDefault("REDIS_URL", "redis://localhost:6379/0"),
|
||||||
ListenAddr: envOrDefault("CP_LISTEN_ADDR", ":8080"),
|
ListenAddr: envOrDefault("CP_LISTEN_ADDR", ":8080"),
|
||||||
HostAgentAddr: envOrDefault("CP_HOST_AGENT_ADDR", "http://localhost:50051"),
|
|
||||||
JWTSecret: os.Getenv("JWT_SECRET"),
|
JWTSecret: os.Getenv("JWT_SECRET"),
|
||||||
|
|
||||||
OAuthGitHubClientID: os.Getenv("OAUTH_GITHUB_CLIENT_ID"),
|
OAuthGitHubClientID: os.Getenv("OAUTH_GITHUB_CLIENT_ID"),
|
||||||
@ -39,13 +36,6 @@ func Load() Config {
|
|||||||
OAuthRedirectURL: envOrDefault("OAUTH_REDIRECT_URL", "https://app.wrenn.dev"),
|
OAuthRedirectURL: envOrDefault("OAUTH_REDIRECT_URL", "https://app.wrenn.dev"),
|
||||||
CPPublicURL: os.Getenv("CP_PUBLIC_URL"),
|
CPPublicURL: os.Getenv("CP_PUBLIC_URL"),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure the host agent address has a scheme.
|
|
||||||
if !strings.HasPrefix(cfg.HostAgentAddr, "http://") && !strings.HasPrefix(cfg.HostAgentAddr, "https://") {
|
|
||||||
cfg.HostAgentAddr = "http://" + cfg.HostAgentAddr
|
|
||||||
}
|
|
||||||
|
|
||||||
return cfg
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func envOrDefault(key, def string) string {
|
func envOrDefault(key, def string) string {
|
||||||
|
|||||||
92
internal/db/host_refresh_tokens.sql.go
Normal file
92
internal/db/host_refresh_tokens.sql.go
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
// Code generated by sqlc. DO NOT EDIT.
|
||||||
|
// versions:
|
||||||
|
// sqlc v1.30.0
|
||||||
|
// source: host_refresh_tokens.sql
|
||||||
|
|
||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/pgtype"
|
||||||
|
)
|
||||||
|
|
||||||
|
const deleteExpiredHostRefreshTokens = `-- name: DeleteExpiredHostRefreshTokens :exec
|
||||||
|
DELETE FROM host_refresh_tokens
|
||||||
|
WHERE expires_at < NOW() OR revoked_at IS NOT NULL
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) DeleteExpiredHostRefreshTokens(ctx context.Context) error {
|
||||||
|
_, err := q.db.Exec(ctx, deleteExpiredHostRefreshTokens)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
const getHostRefreshTokenByHash = `-- name: GetHostRefreshTokenByHash :one
|
||||||
|
SELECT id, host_id, token_hash, expires_at, created_at, revoked_at FROM host_refresh_tokens
|
||||||
|
WHERE token_hash = $1 AND revoked_at IS NULL AND expires_at > NOW()
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) GetHostRefreshTokenByHash(ctx context.Context, tokenHash string) (HostRefreshToken, error) {
|
||||||
|
row := q.db.QueryRow(ctx, getHostRefreshTokenByHash, tokenHash)
|
||||||
|
var i HostRefreshToken
|
||||||
|
err := row.Scan(
|
||||||
|
&i.ID,
|
||||||
|
&i.HostID,
|
||||||
|
&i.TokenHash,
|
||||||
|
&i.ExpiresAt,
|
||||||
|
&i.CreatedAt,
|
||||||
|
&i.RevokedAt,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const insertHostRefreshToken = `-- name: InsertHostRefreshToken :one
|
||||||
|
INSERT INTO host_refresh_tokens (id, host_id, token_hash, expires_at)
|
||||||
|
VALUES ($1, $2, $3, $4)
|
||||||
|
RETURNING id, host_id, token_hash, expires_at, created_at, revoked_at
|
||||||
|
`
|
||||||
|
|
||||||
|
type InsertHostRefreshTokenParams struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
HostID string `json:"host_id"`
|
||||||
|
TokenHash string `json:"token_hash"`
|
||||||
|
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) InsertHostRefreshToken(ctx context.Context, arg InsertHostRefreshTokenParams) (HostRefreshToken, error) {
|
||||||
|
row := q.db.QueryRow(ctx, insertHostRefreshToken,
|
||||||
|
arg.ID,
|
||||||
|
arg.HostID,
|
||||||
|
arg.TokenHash,
|
||||||
|
arg.ExpiresAt,
|
||||||
|
)
|
||||||
|
var i HostRefreshToken
|
||||||
|
err := row.Scan(
|
||||||
|
&i.ID,
|
||||||
|
&i.HostID,
|
||||||
|
&i.TokenHash,
|
||||||
|
&i.ExpiresAt,
|
||||||
|
&i.CreatedAt,
|
||||||
|
&i.RevokedAt,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const revokeHostRefreshToken = `-- name: RevokeHostRefreshToken :exec
|
||||||
|
UPDATE host_refresh_tokens SET revoked_at = NOW() WHERE id = $1
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) RevokeHostRefreshToken(ctx context.Context, id string) error {
|
||||||
|
_, err := q.db.Exec(ctx, revokeHostRefreshToken, id)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
const revokeHostRefreshTokensByHost = `-- name: RevokeHostRefreshTokensByHost :exec
|
||||||
|
UPDATE host_refresh_tokens SET revoked_at = NOW()
|
||||||
|
WHERE host_id = $1 AND revoked_at IS NULL
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) RevokeHostRefreshTokensByHost(ctx context.Context, hostID string) error {
|
||||||
|
_, err := q.db.Exec(ctx, revokeHostRefreshTokensByHost, hostID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
@ -234,6 +234,50 @@ func (q *Queries) InsertHostToken(ctx context.Context, arg InsertHostTokenParams
|
|||||||
return i, err
|
return i, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const listActiveHosts = `-- name: ListActiveHosts :many
|
||||||
|
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE status NOT IN ('pending', 'offline') ORDER BY created_at
|
||||||
|
`
|
||||||
|
|
||||||
|
// Returns all hosts that have completed registration (not pending/offline).
|
||||||
|
func (q *Queries) ListActiveHosts(ctx context.Context) ([]Host, error) {
|
||||||
|
rows, err := q.db.Query(ctx, listActiveHosts)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
var items []Host
|
||||||
|
for rows.Next() {
|
||||||
|
var i Host
|
||||||
|
if err := rows.Scan(
|
||||||
|
&i.ID,
|
||||||
|
&i.Type,
|
||||||
|
&i.TeamID,
|
||||||
|
&i.Provider,
|
||||||
|
&i.AvailabilityZone,
|
||||||
|
&i.Arch,
|
||||||
|
&i.CpuCores,
|
||||||
|
&i.MemoryMb,
|
||||||
|
&i.DiskGb,
|
||||||
|
&i.Address,
|
||||||
|
&i.Status,
|
||||||
|
&i.LastHeartbeatAt,
|
||||||
|
&i.Metadata,
|
||||||
|
&i.CreatedBy,
|
||||||
|
&i.CreatedAt,
|
||||||
|
&i.UpdatedAt,
|
||||||
|
&i.CertFingerprint,
|
||||||
|
&i.MtlsEnabled,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
items = append(items, i)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return items, nil
|
||||||
|
}
|
||||||
|
|
||||||
const listHosts = `-- name: ListHosts :many
|
const listHosts = `-- name: ListHosts :many
|
||||||
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts ORDER BY created_at DESC
|
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts ORDER BY created_at DESC
|
||||||
`
|
`
|
||||||
@ -461,6 +505,15 @@ func (q *Queries) MarkHostTokenUsed(ctx context.Context, id string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const markHostUnreachable = `-- name: MarkHostUnreachable :exec
|
||||||
|
UPDATE hosts SET status = 'unreachable', updated_at = NOW() WHERE id = $1
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) MarkHostUnreachable(ctx context.Context, id string) error {
|
||||||
|
_, err := q.db.Exec(ctx, markHostUnreachable, id)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
const registerHost = `-- name: RegisterHost :execrows
|
const registerHost = `-- name: RegisterHost :execrows
|
||||||
UPDATE hosts
|
UPDATE hosts
|
||||||
SET arch = $2,
|
SET arch = $2,
|
||||||
@ -521,6 +574,20 @@ func (q *Queries) UpdateHostHeartbeat(ctx context.Context, id string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const updateHostHeartbeatAndStatus = `-- name: UpdateHostHeartbeatAndStatus :exec
|
||||||
|
UPDATE hosts
|
||||||
|
SET last_heartbeat_at = NOW(),
|
||||||
|
status = CASE WHEN status = 'unreachable' THEN 'online' ELSE status END,
|
||||||
|
updated_at = NOW()
|
||||||
|
WHERE id = $1
|
||||||
|
`
|
||||||
|
|
||||||
|
// Updates last_heartbeat_at and transitions unreachable hosts back to online.
|
||||||
|
func (q *Queries) UpdateHostHeartbeatAndStatus(ctx context.Context, id string) error {
|
||||||
|
_, err := q.db.Exec(ctx, updateHostHeartbeatAndStatus, id)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
const updateHostStatus = `-- name: UpdateHostStatus :exec
|
const updateHostStatus = `-- name: UpdateHostStatus :exec
|
||||||
UPDATE hosts SET status = $2, updated_at = NOW() WHERE id = $1
|
UPDATE hosts SET status = $2, updated_at = NOW() WHERE id = $1
|
||||||
`
|
`
|
||||||
|
|||||||
@ -36,6 +36,15 @@ type Host struct {
|
|||||||
MtlsEnabled bool `json:"mtls_enabled"`
|
MtlsEnabled bool `json:"mtls_enabled"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type HostRefreshToken struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
HostID string `json:"host_id"`
|
||||||
|
TokenHash string `json:"token_hash"`
|
||||||
|
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
|
||||||
|
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||||
|
RevokedAt pgtype.Timestamptz `json:"revoked_at"`
|
||||||
|
}
|
||||||
|
|
||||||
type HostTag struct {
|
type HostTag struct {
|
||||||
HostID string `json:"host_id"`
|
HostID string `json:"host_id"`
|
||||||
Tag string `json:"tag"`
|
Tag string `json:"tag"`
|
||||||
|
|||||||
@ -11,6 +11,20 @@ import (
|
|||||||
"github.com/jackc/pgx/v5/pgtype"
|
"github.com/jackc/pgx/v5/pgtype"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const bulkRestoreRunning = `-- name: BulkRestoreRunning :exec
|
||||||
|
UPDATE sandboxes
|
||||||
|
SET status = 'running',
|
||||||
|
last_updated = NOW()
|
||||||
|
WHERE id = ANY($1::text[]) AND status = 'missing'
|
||||||
|
`
|
||||||
|
|
||||||
|
// Called by the reconciler when a host comes back online and its sandboxes are
|
||||||
|
// confirmed alive. Restores only sandboxes that are in 'missing' state.
|
||||||
|
func (q *Queries) BulkRestoreRunning(ctx context.Context, dollar_1 []string) error {
|
||||||
|
_, err := q.db.Exec(ctx, bulkRestoreRunning, dollar_1)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
const bulkUpdateStatusByIDs = `-- name: BulkUpdateStatusByIDs :exec
|
const bulkUpdateStatusByIDs = `-- name: BulkUpdateStatusByIDs :exec
|
||||||
UPDATE sandboxes
|
UPDATE sandboxes
|
||||||
SET status = $2,
|
SET status = $2,
|
||||||
@ -300,6 +314,21 @@ func (q *Queries) ListSandboxesByTeam(ctx context.Context, teamID string) ([]San
|
|||||||
return items, nil
|
return items, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const markSandboxesMissingByHost = `-- name: MarkSandboxesMissingByHost :exec
|
||||||
|
UPDATE sandboxes
|
||||||
|
SET status = 'missing',
|
||||||
|
last_updated = NOW()
|
||||||
|
WHERE host_id = $1 AND status IN ('running', 'starting', 'pending')
|
||||||
|
`
|
||||||
|
|
||||||
|
// Called when the host monitor marks a host unreachable.
|
||||||
|
// Marks running/starting/pending sandboxes on that host as 'missing' so users see
|
||||||
|
// the sandbox is not currently reachable, without permanently losing the record.
|
||||||
|
func (q *Queries) MarkSandboxesMissingByHost(ctx context.Context, hostID string) error {
|
||||||
|
_, err := q.db.Exec(ctx, markSandboxesMissingByHost, hostID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
const updateLastActive = `-- name: UpdateLastActive :exec
|
const updateLastActive = `-- name: UpdateLastActive :exec
|
||||||
UPDATE sandboxes
|
UPDATE sandboxes
|
||||||
SET last_active_at = $2,
|
SET last_active_at = $2,
|
||||||
|
|||||||
@ -17,6 +17,13 @@ import (
|
|||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// tokenFile is the JSON format persisted to AGENT_FILES_ROOTDIR/host.jwt.
// It keeps the host's identity together with both credentials so they can be
// loaded and rewritten as one unit.
type tokenFile struct {
	// HostID is the host identifier, stored under "host_id".
	HostID string `json:"host_id"`
	// JWT is the host token, stored under "jwt".
	JWT string `json:"jwt"`
	// RefreshToken is the token exchanged for a new JWT, stored under
	// "refresh_token". It is empty for files loaded from the legacy raw-JWT format.
	RefreshToken string `json:"refresh_token"`
}
|
||||||
|
|
||||||
// RegistrationConfig holds the configuration for host registration.
|
// RegistrationConfig holds the configuration for host registration.
|
||||||
type RegistrationConfig struct {
|
type RegistrationConfig struct {
|
||||||
CPURL string // Control plane base URL (e.g., http://localhost:8000)
|
CPURL string // Control plane base URL (e.g., http://localhost:8000)
|
||||||
@ -37,6 +44,17 @@ type registerRequest struct {
|
|||||||
// Wire types for the host registration and token refresh endpoints.
type (
	// registerResponse is the body returned by the registration endpoint.
	registerResponse struct {
		// Host is kept as raw JSON; it is not decoded here.
		Host json.RawMessage `json:"host"`
		// Token is the host JWT.
		Token string `json:"token"`
		// RefreshToken is the token later exchanged for a new JWT.
		RefreshToken string `json:"refresh_token"`
	}

	// refreshRequest is the body sent to the refresh endpoint.
	refreshRequest struct {
		RefreshToken string `json:"refresh_token"`
	}

	// refreshResponse is the body returned by the refresh endpoint. It has the
	// same fields as registerResponse.
	refreshResponse struct {
		Host         json.RawMessage `json:"host"`
		Token        string          `json:"token"`
		RefreshToken string          `json:"refresh_token"`
	}
)
|
||||||
|
|
||||||
type errorResponse struct {
|
type errorResponse struct {
|
||||||
@ -46,20 +64,46 @@ type errorResponse struct {
|
|||||||
} `json:"error"`
|
} `json:"error"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// loadTokenFile reads and parses the persisted token file.
|
||||||
|
func loadTokenFile(path string) (*tokenFile, error) {
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// Support legacy format (raw JWT string) for backwards compatibility.
|
||||||
|
trimmed := strings.TrimSpace(string(data))
|
||||||
|
if !strings.HasPrefix(trimmed, "{") {
|
||||||
|
// Old format: just the JWT, no refresh token.
|
||||||
|
hostID, _ := hostIDFromJWT(trimmed)
|
||||||
|
return &tokenFile{HostID: hostID, JWT: trimmed}, nil
|
||||||
|
}
|
||||||
|
var tf tokenFile
|
||||||
|
if err := json.Unmarshal(data, &tf); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse token file: %w", err)
|
||||||
|
}
|
||||||
|
return &tf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// saveTokenFile writes the token file as JSON with 0600 permissions.
|
||||||
|
func saveTokenFile(path string, tf tokenFile) error {
|
||||||
|
data, err := json.MarshalIndent(tf, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("marshal token file: %w", err)
|
||||||
|
}
|
||||||
|
return os.WriteFile(path, data, 0600)
|
||||||
|
}
|
||||||
|
|
||||||
// Register calls the control plane to register this host agent and persists
|
// Register calls the control plane to register this host agent and persists
|
||||||
// the returned JWT to disk. Returns the host JWT token string.
|
// the returned JWT and refresh token to disk. Returns the host JWT token string.
|
||||||
func Register(ctx context.Context, cfg RegistrationConfig) (string, error) {
|
func Register(ctx context.Context, cfg RegistrationConfig) (string, error) {
|
||||||
// Check if we already have a saved token.
|
// Check if we already have a saved token.
|
||||||
if data, err := os.ReadFile(cfg.TokenFile); err == nil {
|
if tf, err := loadTokenFile(cfg.TokenFile); err == nil && tf.JWT != "" {
|
||||||
token := strings.TrimSpace(string(data))
|
slog.Info("loaded existing host token", "file", cfg.TokenFile, "host_id", tf.HostID)
|
||||||
if token != "" {
|
return tf.JWT, nil
|
||||||
slog.Info("loaded existing host token", "file", cfg.TokenFile)
|
|
||||||
return token, nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg.RegistrationToken == "" {
|
if cfg.RegistrationToken == "" {
|
||||||
return "", fmt.Errorf("no saved host token and no registration token provided")
|
return "", fmt.Errorf("no saved host token and no registration token provided (use --register flag)")
|
||||||
}
|
}
|
||||||
|
|
||||||
arch := runtime.GOARCH
|
arch := runtime.GOARCH
|
||||||
@ -117,45 +161,155 @@ func Register(ctx context.Context, cfg RegistrationConfig) (string, error) {
|
|||||||
return "", fmt.Errorf("registration response missing token")
|
return "", fmt.Errorf("registration response missing token")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Persist the token to disk for subsequent startups.
|
hostID, err := hostIDFromJWT(regResp.Token)
|
||||||
if err := os.WriteFile(cfg.TokenFile, []byte(regResp.Token), 0600); err != nil {
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("extract host ID from JWT: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Persist JWT + refresh token.
|
||||||
|
tf := tokenFile{
|
||||||
|
HostID: hostID,
|
||||||
|
JWT: regResp.Token,
|
||||||
|
RefreshToken: regResp.RefreshToken,
|
||||||
|
}
|
||||||
|
if err := saveTokenFile(cfg.TokenFile, tf); err != nil {
|
||||||
return "", fmt.Errorf("save host token: %w", err)
|
return "", fmt.Errorf("save host token: %w", err)
|
||||||
}
|
}
|
||||||
slog.Info("host registered and token saved", "file", cfg.TokenFile)
|
slog.Info("host registered and token saved", "file", cfg.TokenFile, "host_id", hostID)
|
||||||
|
|
||||||
return regResp.Token, nil
|
return regResp.Token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RefreshJWT exchanges the refresh token for a new JWT + rotated refresh token.
|
||||||
|
// It reads and updates the token file in place.
|
||||||
|
func RefreshJWT(ctx context.Context, cpURL, tokenFilePath string) (string, error) {
|
||||||
|
tf, err := loadTokenFile(tokenFilePath)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("load token file: %w", err)
|
||||||
|
}
|
||||||
|
if tf.RefreshToken == "" {
|
||||||
|
return "", fmt.Errorf("no refresh token available; host must re-register")
|
||||||
|
}
|
||||||
|
|
||||||
|
body, _ := json.Marshal(refreshRequest{RefreshToken: tf.RefreshToken})
|
||||||
|
url := strings.TrimRight(cpURL, "/") + "/v1/hosts/auth/refresh"
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("create refresh request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
client := &http.Client{Timeout: 15 * time.Second}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("refresh request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
respBody, _ := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
var errResp errorResponse
|
||||||
|
if json.Unmarshal(respBody, &errResp) == nil {
|
||||||
|
return "", fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, errResp.Error.Message)
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, string(respBody))
|
||||||
|
}
|
||||||
|
|
||||||
|
var refResp refreshResponse
|
||||||
|
if err := json.Unmarshal(respBody, &refResp); err != nil {
|
||||||
|
return "", fmt.Errorf("parse refresh response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tf.JWT = refResp.Token
|
||||||
|
tf.RefreshToken = refResp.RefreshToken
|
||||||
|
if err := saveTokenFile(tokenFilePath, *tf); err != nil {
|
||||||
|
return "", fmt.Errorf("save refreshed token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Info("host JWT refreshed", "host_id", tf.HostID)
|
||||||
|
return refResp.Token, nil
|
||||||
|
}
|
||||||
|
|
||||||
// StartHeartbeat launches a background goroutine that sends periodic heartbeats
|
// StartHeartbeat launches a background goroutine that sends periodic heartbeats
|
||||||
// to the control plane. It runs until the context is cancelled.
|
// to the control plane. It runs until the context is cancelled.
|
||||||
func StartHeartbeat(ctx context.Context, cpURL, hostID, hostToken string, interval time.Duration) {
|
//
|
||||||
url := strings.TrimRight(cpURL, "/") + "/v1/hosts/" + hostID + "/heartbeat"
|
// On 401/403: the heartbeat loop attempts to refresh the JWT. If the refresh
|
||||||
|
// also fails (expired refresh token), it calls pauseAll and stops.
|
||||||
|
//
|
||||||
|
// On repeated network failures (3 consecutive), it calls pauseAll but keeps
|
||||||
|
// retrying — the connection may recover and the host should resume heartbeating.
|
||||||
|
func StartHeartbeat(ctx context.Context, cpURL, tokenFilePath, hostID string, interval time.Duration, pauseAll func()) {
|
||||||
client := &http.Client{Timeout: 10 * time.Second}
|
client := &http.Client{Timeout: 10 * time.Second}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
ticker := time.NewTicker(interval)
|
ticker := time.NewTicker(interval)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
consecutiveFailures := 0
|
||||||
|
pausedDueToFailure := false
|
||||||
|
currentJWT := ""
|
||||||
|
|
||||||
|
// Load the current JWT from disk.
|
||||||
|
if tf, err := loadTokenFile(tokenFilePath); err == nil {
|
||||||
|
currentJWT = tf.JWT
|
||||||
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
|
url := strings.TrimRight(cpURL, "/") + "/v1/hosts/" + hostID + "/heartbeat"
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil)
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn("heartbeat: failed to create request", "error", err)
|
slog.Warn("heartbeat: failed to create request", "error", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
req.Header.Set("X-Host-Token", hostToken)
|
req.Header.Set("X-Host-Token", currentJWT)
|
||||||
|
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn("heartbeat: request failed", "error", err)
|
consecutiveFailures++
|
||||||
|
slog.Warn("heartbeat: request failed", "error", err, "consecutive_failures", consecutiveFailures)
|
||||||
|
|
||||||
|
if consecutiveFailures >= 3 && !pausedDueToFailure {
|
||||||
|
slog.Error("heartbeat: CP unreachable after 3 failures — pausing all sandboxes")
|
||||||
|
if pauseAll != nil {
|
||||||
|
pauseAll()
|
||||||
|
}
|
||||||
|
pausedDueToFailure = true
|
||||||
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
resp.Body.Close()
|
resp.Body.Close()
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusNoContent {
|
switch resp.StatusCode {
|
||||||
|
case http.StatusNoContent:
|
||||||
|
// Success.
|
||||||
|
if consecutiveFailures > 0 || pausedDueToFailure {
|
||||||
|
slog.Info("heartbeat: CP connection restored")
|
||||||
|
}
|
||||||
|
consecutiveFailures = 0
|
||||||
|
pausedDueToFailure = false
|
||||||
|
|
||||||
|
case http.StatusUnauthorized, http.StatusForbidden:
|
||||||
|
slog.Warn("heartbeat: JWT rejected — attempting token refresh")
|
||||||
|
newJWT, refreshErr := RefreshJWT(ctx, cpURL, tokenFilePath)
|
||||||
|
if refreshErr != nil {
|
||||||
|
slog.Error("heartbeat: JWT refresh failed — pausing all sandboxes; manual re-registration required",
|
||||||
|
"error", refreshErr)
|
||||||
|
if pauseAll != nil && !pausedDueToFailure {
|
||||||
|
pauseAll()
|
||||||
|
pausedDueToFailure = true
|
||||||
|
}
|
||||||
|
// Stop the heartbeat loop — operator must re-register.
|
||||||
|
return
|
||||||
|
}
|
||||||
|
currentJWT = newJWT
|
||||||
|
slog.Info("heartbeat: JWT refreshed successfully")
|
||||||
|
|
||||||
|
default:
|
||||||
slog.Warn("heartbeat: unexpected status", "status", resp.StatusCode)
|
slog.Warn("heartbeat: unexpected status", "status", resp.StatusCode)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -166,6 +320,12 @@ func StartHeartbeat(ctx context.Context, cpURL, hostID, hostToken string, interv
|
|||||||
// HostIDFromToken extracts the host_id claim from a host JWT without
|
// HostIDFromToken extracts the host_id claim from a host JWT without
|
||||||
// verifying the signature (the agent doesn't have the signing secret).
|
// verifying the signature (the agent doesn't have the signing secret).
|
||||||
func HostIDFromToken(token string) (string, error) {
|
func HostIDFromToken(token string) (string, error) {
|
||||||
|
return hostIDFromJWT(token)
|
||||||
|
}
|
||||||
|
|
||||||
|
// hostIDFromJWT is the internal implementation used by both HostIDFromToken and
|
||||||
|
// the token file loader.
|
||||||
|
func hostIDFromJWT(token string) (string, error) {
|
||||||
parts := strings.Split(token, ".")
|
parts := strings.Split(token, ".")
|
||||||
if len(parts) != 3 {
|
if len(parts) != 3 {
|
||||||
return "", fmt.Errorf("invalid JWT format")
|
return "", fmt.Errorf("invalid JWT format")
|
||||||
|
|||||||
@ -67,3 +67,17 @@ func NewRegistrationToken() string {
|
|||||||
}
|
}
|
||||||
return hex.EncodeToString(b)
|
return hex.EncodeToString(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewRefreshTokenID generates a new refresh token record ID in the format "hrt-" + 8 hex chars.
|
||||||
|
func NewRefreshTokenID() string {
|
||||||
|
return "hrt-" + hex8()
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRefreshToken returns a host refresh token: 32 bytes read from crypto/rand,
// hex-encoded into a 64-character string. It panics if crypto/rand fails.
func NewRefreshToken() string {
	var raw [32]byte
	if _, err := rand.Read(raw[:]); err != nil {
		panic(fmt.Sprintf("crypto/rand failed: %v", err))
	}
	return hex.EncodeToString(raw[:])
}
|
||||||
|
|||||||
77
internal/lifecycle/hostpool.go
Normal file
77
internal/lifecycle/hostpool.go
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
package lifecycle
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HostClientPool maintains a cache of Connect RPC clients keyed by host ID.
|
||||||
|
// Clients are created lazily on first access and evicted when a host is removed
|
||||||
|
// or goes unreachable. The pool is safe for concurrent use.
|
||||||
|
type HostClientPool struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
clients map[string]hostagentv1connect.HostAgentServiceClient
|
||||||
|
httpClient *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHostClientPool creates a new pool. The underlying HTTP client uses a
|
||||||
|
// 10-minute timeout to support long-running streaming operations.
|
||||||
|
func NewHostClientPool() *HostClientPool {
|
||||||
|
return &HostClientPool{
|
||||||
|
clients: make(map[string]hostagentv1connect.HostAgentServiceClient),
|
||||||
|
httpClient: &http.Client{Timeout: 10 * time.Minute},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns a Connect RPC client for the given host, creating one if necessary.
|
||||||
|
// address is the host agent address (ip:port or full URL). The scheme is added if absent.
|
||||||
|
func (p *HostClientPool) Get(hostID, address string) hostagentv1connect.HostAgentServiceClient {
|
||||||
|
p.mu.RLock()
|
||||||
|
c, ok := p.clients[hostID]
|
||||||
|
p.mu.RUnlock()
|
||||||
|
if ok {
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
// Double-check after acquiring write lock.
|
||||||
|
if c, ok = p.clients[hostID]; ok {
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
c = hostagentv1connect.NewHostAgentServiceClient(p.httpClient, ensureScheme(address))
|
||||||
|
p.clients[hostID] = c
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetForHost is a convenience wrapper that extracts the address from a db.Host
|
||||||
|
// and returns an error if the host has no address recorded yet.
|
||||||
|
func (p *HostClientPool) GetForHost(h db.Host) (hostagentv1connect.HostAgentServiceClient, error) {
|
||||||
|
if !h.Address.Valid || h.Address.String == "" {
|
||||||
|
return nil, fmt.Errorf("host %s has no address", h.ID)
|
||||||
|
}
|
||||||
|
return p.Get(h.ID, h.Address.String), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Evict removes the cached client for the given host, forcing a new client to be
|
||||||
|
// created on the next call to Get. Call this when a host's address changes or when
|
||||||
|
// a host is deleted.
|
||||||
|
func (p *HostClientPool) Evict(hostID string) {
|
||||||
|
p.mu.Lock()
|
||||||
|
delete(p.clients, hostID)
|
||||||
|
p.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensureScheme returns addr with "http://" prepended unless it already starts
// with an http or https scheme.
//
// Surrounding whitespace is trimmed first. The scheme is matched
// case-insensitively, so "HTTPS://host" is returned unchanged rather than
// being turned into "http://HTTPS://host". An empty addr yields "http://".
func ensureScheme(addr string) string {
	addr = strings.TrimSpace(addr)
	for _, scheme := range []string{"http://", "https://"} {
		if len(addr) >= len(scheme) && strings.EqualFold(addr[:len(scheme)], scheme) {
			return addr
		}
	}
	return "http://" + addr
}
|
||||||
@ -1183,6 +1183,28 @@ func (m *Manager) Shutdown(ctx context.Context) {
|
|||||||
m.loops.ReleaseAll()
|
m.loops.ReleaseAll()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PauseAll pauses every running sandbox managed by this host agent.
|
||||||
|
// Called when the host loses connectivity to the control plane to avoid
|
||||||
|
// leaving running VMs unmanaged. It is best-effort: failures for individual
|
||||||
|
// sandboxes are logged but do not stop the rest.
|
||||||
|
func (m *Manager) PauseAll(ctx context.Context) {
|
||||||
|
m.mu.RLock()
|
||||||
|
ids := make([]string, 0, len(m.boxes))
|
||||||
|
for id, sb := range m.boxes {
|
||||||
|
if sb.Status == models.StatusRunning {
|
||||||
|
ids = append(ids, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m.mu.RUnlock()
|
||||||
|
|
||||||
|
slog.Info("pausing all running sandboxes due to CP connection loss", "count", len(ids))
|
||||||
|
for _, sbID := range ids {
|
||||||
|
if err := m.Pause(ctx, sbID); err != nil {
|
||||||
|
slog.Warn("PauseAll: failed to pause sandbox", "id", sbID, "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// warnErr logs a warning if err is non-nil. Used for best-effort cleanup
|
// warnErr logs a warning if err is non-nil. Used for best-effort cleanup
|
||||||
// in error paths where the primary error has already been captured.
|
// in error paths where the primary error has already been captured.
|
||||||
func warnErr(msg string, id string, err error) {
|
func warnErr(msg string, id string, err error) {
|
||||||
|
|||||||
51
internal/scheduler/round_robin.go
Normal file
51
internal/scheduler/round_robin.go
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
package scheduler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HostScheduler selects a host for a new sandbox. Implementations may use
|
||||||
|
// different strategies (round-robin, least-loaded, tag-based, etc.).
|
||||||
|
type HostScheduler interface {
|
||||||
|
// SelectHost returns a host that can accept a new sandbox.
|
||||||
|
// Returns an error if no suitable host is available.
|
||||||
|
SelectHost(ctx context.Context) (db.Host, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RoundRobinScheduler cycles through online hosts in round-robin order.
|
||||||
|
// It re-fetches the host list on every call so that newly registered or
|
||||||
|
// recovered hosts are considered immediately.
|
||||||
|
type RoundRobinScheduler struct {
|
||||||
|
db *db.Queries
|
||||||
|
counter atomic.Int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRoundRobinScheduler creates a RoundRobinScheduler backed by the given DB.
|
||||||
|
func NewRoundRobinScheduler(queries *db.Queries) *RoundRobinScheduler {
|
||||||
|
return &RoundRobinScheduler{db: queries}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SelectHost returns the next online host in round-robin order.
|
||||||
|
func (s *RoundRobinScheduler) SelectHost(ctx context.Context) (db.Host, error) {
|
||||||
|
hosts, err := s.db.ListActiveHosts(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return db.Host{}, fmt.Errorf("list hosts: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var online []db.Host
|
||||||
|
for _, h := range hosts {
|
||||||
|
if h.Status == "online" && h.Address.Valid && h.Address.String != "" {
|
||||||
|
online = append(online, h)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(online) == 0 {
|
||||||
|
return db.Host{}, fmt.Errorf("no online hosts available")
|
||||||
|
}
|
||||||
|
|
||||||
|
idx := s.counter.Add(1) - 1
|
||||||
|
return online[int(idx%int64(len(online)))], nil
|
||||||
|
}
|
||||||
@ -2,12 +2,14 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"connectrpc.com/connect"
|
||||||
"github.com/jackc/pgx/v5"
|
"github.com/jackc/pgx/v5"
|
||||||
"github.com/jackc/pgx/v5/pgtype"
|
"github.com/jackc/pgx/v5/pgtype"
|
||||||
"github.com/redis/go-redis/v9"
|
"github.com/redis/go-redis/v9"
|
||||||
@ -15,6 +17,8 @@ import (
|
|||||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
|
||||||
|
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||||
)
|
)
|
||||||
|
|
||||||
// HostService provides host management operations.
|
// HostService provides host management operations.
|
||||||
@ -22,6 +26,7 @@ type HostService struct {
|
|||||||
DB *db.Queries
|
DB *db.Queries
|
||||||
Redis *redis.Client
|
Redis *redis.Client
|
||||||
JWT []byte
|
JWT []byte
|
||||||
|
Pool *lifecycle.HostClientPool
|
||||||
}
|
}
|
||||||
|
|
||||||
// HostCreateParams holds the parameters for creating a host.
|
// HostCreateParams holds the parameters for creating a host.
|
||||||
@ -50,10 +55,24 @@ type HostRegisterParams struct {
|
|||||||
Address string
|
Address string
|
||||||
}
|
}
|
||||||
|
|
||||||
// HostRegisterResult holds the registered host and its long-lived JWT.
|
// HostRegisterResult holds the registered host, its short-lived JWT, and a long-lived refresh token.
|
||||||
type HostRegisterResult struct {
|
type HostRegisterResult struct {
|
||||||
Host db.Host
|
Host db.Host
|
||||||
JWT string
|
JWT string
|
||||||
|
RefreshToken string
|
||||||
|
}
|
||||||
|
|
||||||
|
// HostRefreshResult holds a new JWT and rotated refresh token after a successful refresh.
|
||||||
|
type HostRefreshResult struct {
|
||||||
|
Host db.Host
|
||||||
|
JWT string
|
||||||
|
RefreshToken string
|
||||||
|
}
|
||||||
|
|
||||||
|
// HostDeletePreview describes what will be affected by deleting a host.
|
||||||
|
type HostDeletePreview struct {
|
||||||
|
Host db.Host
|
||||||
|
SandboxIDs []string
|
||||||
}
|
}
|
||||||
|
|
||||||
// regTokenPayload is the JSON stored in Redis for registration tokens.
|
// regTokenPayload is the JSON stored in Redis for registration tokens.
|
||||||
@ -64,6 +83,14 @@ type regTokenPayload struct {
|
|||||||
|
|
||||||
const regTokenTTL = time.Hour
|
const regTokenTTL = time.Hour
|
||||||
|
|
||||||
|
// requireAdminOrOwner checks whether a team role may manage BYOC hosts.
// It returns nil for exactly "owner" and "admin" (case-sensitive) and a
// "forbidden" error for every other role.
func requireAdminOrOwner(role string) error {
	switch role {
	case "owner", "admin":
		return nil
	default:
		return fmt.Errorf("forbidden: only team owners and admins can manage BYOC hosts")
	}
}
|
||||||
|
|
||||||
// Create creates a new host record and generates a one-time registration token.
|
// Create creates a new host record and generates a one-time registration token.
|
||||||
func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreateResult, error) {
|
func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreateResult, error) {
|
||||||
if p.Type != "regular" && p.Type != "byoc" {
|
if p.Type != "regular" && p.Type != "byoc" {
|
||||||
@ -75,7 +102,7 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat
|
|||||||
return HostCreateResult{}, fmt.Errorf("forbidden: only admins can create regular hosts")
|
return HostCreateResult{}, fmt.Errorf("forbidden: only admins can create regular hosts")
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// BYOC: admin or team owner.
|
// BYOC: platform admin, or team owner/admin.
|
||||||
if p.TeamID == "" {
|
if p.TeamID == "" {
|
||||||
return HostCreateResult{}, fmt.Errorf("invalid request: team_id is required for BYOC hosts")
|
return HostCreateResult{}, fmt.Errorf("invalid request: team_id is required for BYOC hosts")
|
||||||
}
|
}
|
||||||
@ -90,8 +117,8 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return HostCreateResult{}, fmt.Errorf("check team membership: %w", err)
|
return HostCreateResult{}, fmt.Errorf("check team membership: %w", err)
|
||||||
}
|
}
|
||||||
if membership.Role != "owner" {
|
if err := requireAdminOrOwner(membership.Role); err != nil {
|
||||||
return HostCreateResult{}, fmt.Errorf("forbidden: only team owners can create BYOC hosts")
|
return HostCreateResult{}, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -168,7 +195,6 @@ func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamI
|
|||||||
return HostCreateResult{}, fmt.Errorf("invalid state: can only regenerate token for pending hosts (status: %s)", host.Status)
|
return HostCreateResult{}, fmt.Errorf("invalid state: can only regenerate token for pending hosts (status: %s)", host.Status)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Same permission model as Create/Delete.
|
|
||||||
if !isAdmin {
|
if !isAdmin {
|
||||||
if host.Type != "byoc" {
|
if host.Type != "byoc" {
|
||||||
return HostCreateResult{}, fmt.Errorf("forbidden: only admins can manage regular hosts")
|
return HostCreateResult{}, fmt.Errorf("forbidden: only admins can manage regular hosts")
|
||||||
@ -186,8 +212,8 @@ func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamI
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return HostCreateResult{}, fmt.Errorf("check team membership: %w", err)
|
return HostCreateResult{}, fmt.Errorf("check team membership: %w", err)
|
||||||
}
|
}
|
||||||
if membership.Role != "owner" {
|
if err := requireAdminOrOwner(membership.Role); err != nil {
|
||||||
return HostCreateResult{}, fmt.Errorf("forbidden: only team owners can regenerate tokens")
|
return HostCreateResult{}, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -216,7 +242,7 @@ func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamI
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Register validates a one-time registration token, updates the host with
|
// Register validates a one-time registration token, updates the host with
|
||||||
// machine specs, and returns a long-lived host JWT.
|
// machine specs, and returns a short-lived host JWT plus a long-lived refresh token.
|
||||||
func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostRegisterResult, error) {
|
func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostRegisterResult, error) {
|
||||||
// Atomic consume: GetDel returns the value and deletes in one operation,
|
// Atomic consume: GetDel returns the value and deletes in one operation,
|
||||||
// preventing concurrent requests from consuming the same token.
|
// preventing concurrent requests from consuming the same token.
|
||||||
@ -264,18 +290,89 @@ func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostR
|
|||||||
slog.Warn("failed to mark host token used", "token_id", payload.TokenID, "error", err)
|
slog.Warn("failed to mark host token used", "token_id", payload.TokenID, "error", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Issue a long-lived refresh token.
|
||||||
|
refreshToken, err := s.issueRefreshToken(ctx, payload.HostID)
|
||||||
|
if err != nil {
|
||||||
|
return HostRegisterResult{}, fmt.Errorf("issue refresh token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Re-fetch the host to get the updated state.
|
// Re-fetch the host to get the updated state.
|
||||||
host, err := s.DB.GetHost(ctx, payload.HostID)
|
host, err := s.DB.GetHost(ctx, payload.HostID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return HostRegisterResult{}, fmt.Errorf("fetch updated host: %w", err)
|
return HostRegisterResult{}, fmt.Errorf("fetch updated host: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return HostRegisterResult{Host: host, JWT: hostJWT}, nil
|
return HostRegisterResult{Host: host, JWT: hostJWT, RefreshToken: refreshToken}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Heartbeat updates the last heartbeat timestamp for a host.
|
// Refresh validates a refresh token, rotates it (revokes old, issues new),
|
||||||
|
// and returns a fresh JWT plus the new refresh token.
|
||||||
|
func (s *HostService) Refresh(ctx context.Context, refreshToken string) (HostRefreshResult, error) {
|
||||||
|
hash := hashToken(refreshToken)
|
||||||
|
|
||||||
|
row, err := s.DB.GetHostRefreshTokenByHash(ctx, hash)
|
||||||
|
if errors.Is(err, pgx.ErrNoRows) {
|
||||||
|
return HostRefreshResult{}, fmt.Errorf("invalid or expired refresh token")
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return HostRefreshResult{}, fmt.Errorf("lookup refresh token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
host, err := s.DB.GetHost(ctx, row.HostID)
|
||||||
|
if err != nil {
|
||||||
|
return HostRefreshResult{}, fmt.Errorf("host not found: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sign new JWT.
|
||||||
|
hostJWT, err := auth.SignHostJWT(s.JWT, host.ID)
|
||||||
|
if err != nil {
|
||||||
|
return HostRefreshResult{}, fmt.Errorf("sign host JWT: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Issue-then-revoke rotation: insert new token first so a crash between
|
||||||
|
// the two DB calls leaves the host with two valid tokens rather than zero.
|
||||||
|
newRefreshToken, err := s.issueRefreshToken(ctx, host.ID)
|
||||||
|
if err != nil {
|
||||||
|
return HostRefreshResult{}, fmt.Errorf("issue new refresh token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Revoke old refresh token after the new one is safely persisted.
|
||||||
|
if err := s.DB.RevokeHostRefreshToken(ctx, row.ID); err != nil {
|
||||||
|
return HostRefreshResult{}, fmt.Errorf("revoke old refresh token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return HostRefreshResult{Host: host, JWT: hostJWT, RefreshToken: newRefreshToken}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// issueRefreshToken creates a new refresh token record in the DB and returns
|
||||||
|
// the opaque token string.
|
||||||
|
func (s *HostService) issueRefreshToken(ctx context.Context, hostID string) (string, error) {
|
||||||
|
token := id.NewRefreshToken()
|
||||||
|
hash := hashToken(token)
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
if _, err := s.DB.InsertHostRefreshToken(ctx, db.InsertHostRefreshTokenParams{
|
||||||
|
ID: id.NewRefreshTokenID(),
|
||||||
|
HostID: hostID,
|
||||||
|
TokenHash: hash,
|
||||||
|
ExpiresAt: pgtype.Timestamptz{Time: now.Add(auth.HostRefreshTokenExpiry), Valid: true},
|
||||||
|
}); err != nil {
|
||||||
|
return "", fmt.Errorf("insert refresh token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// hashToken computes the SHA-256 digest of token and returns it as a
// lowercase hex string.
func hashToken(token string) string {
	digest := sha256.Sum256([]byte(token))
	return fmt.Sprintf("%x", digest[:])
}
|
||||||
|
|
||||||
|
// Heartbeat updates the last heartbeat timestamp for a host and transitions
|
||||||
|
// any 'unreachable' host back to 'online'.
|
||||||
func (s *HostService) Heartbeat(ctx context.Context, hostID string) error {
|
func (s *HostService) Heartbeat(ctx context.Context, hostID string) error {
|
||||||
return s.DB.UpdateHostHeartbeat(ctx, hostID)
|
return s.DB.UpdateHostHeartbeatAndStatus(ctx, hostID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// List returns hosts visible to the caller.
|
// List returns hosts visible to the caller.
|
||||||
@ -301,37 +398,135 @@ func (s *HostService) Get(ctx context.Context, hostID, teamID string, isAdmin bo
|
|||||||
return host, nil
|
return host, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete removes a host. Admins can delete any host. Team owners can delete
|
// DeletePreview returns what would be affected by deleting the host, without
|
||||||
// BYOC hosts belonging to their team.
|
// making any changes. Use this to show the user a confirmation prompt.
|
||||||
func (s *HostService) Delete(ctx context.Context, hostID, userID, teamID string, isAdmin bool) error {
|
func (s *HostService) DeletePreview(ctx context.Context, hostID, teamID string, isAdmin bool) (HostDeletePreview, error) {
|
||||||
host, err := s.DB.GetHost(ctx, hostID)
|
host, err := s.checkDeletePermission(ctx, hostID, "", teamID, isAdmin)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("host not found: %w", err)
|
return HostDeletePreview{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
sandboxes, err := s.DB.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
|
||||||
|
HostID: hostID,
|
||||||
|
Column2: []string{"pending", "starting", "running", "missing"},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return HostDeletePreview{}, fmt.Errorf("list sandboxes: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ids := make([]string, len(sandboxes))
|
||||||
|
for i, sb := range sandboxes {
|
||||||
|
ids[i] = sb.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
return HostDeletePreview{Host: host, SandboxIDs: ids}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete removes a host. Without force it returns an error listing active
|
||||||
|
// sandboxes so the caller can present a confirmation. With force it gracefully
|
||||||
|
// destroys all running sandboxes before deleting the host record.
|
||||||
|
func (s *HostService) Delete(ctx context.Context, hostID, userID, teamID string, isAdmin, force bool) error {
|
||||||
|
host, err := s.checkDeletePermission(ctx, hostID, userID, teamID, isAdmin)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
sandboxes, err := s.DB.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
|
||||||
|
HostID: hostID,
|
||||||
|
Column2: []string{"pending", "starting", "running", "missing"},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("list sandboxes: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(sandboxes) > 0 && !force {
|
||||||
|
ids := make([]string, len(sandboxes))
|
||||||
|
for i, sb := range sandboxes {
|
||||||
|
ids[i] = sb.ID
|
||||||
|
}
|
||||||
|
return &HostHasSandboxesError{SandboxIDs: ids}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Gracefully destroy running sandboxes on the host agent (best-effort).
|
||||||
|
if len(sandboxes) > 0 && host.Address.Valid && host.Address.String != "" {
|
||||||
|
agent, err := s.Pool.GetForHost(host)
|
||||||
|
if err == nil {
|
||||||
|
for _, sb := range sandboxes {
|
||||||
|
if sb.Status == "running" || sb.Status == "starting" {
|
||||||
|
_, rpcErr := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
|
||||||
|
SandboxId: sb.ID,
|
||||||
|
}))
|
||||||
|
if rpcErr != nil && connect.CodeOf(rpcErr) != connect.CodeNotFound {
|
||||||
|
slog.Warn("delete host: failed to destroy sandbox on agent", "sandbox_id", sb.ID, "error", rpcErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mark all affected sandboxes as stopped in DB.
|
||||||
|
if len(sandboxes) > 0 {
|
||||||
|
sbIDs := make([]string, len(sandboxes))
|
||||||
|
for i, sb := range sandboxes {
|
||||||
|
sbIDs[i] = sb.ID
|
||||||
|
}
|
||||||
|
if err := s.DB.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
|
||||||
|
Column1: sbIDs,
|
||||||
|
Status: "stopped",
|
||||||
|
}); err != nil {
|
||||||
|
slog.Warn("delete host: failed to mark sandboxes stopped", "host_id", hostID, "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Revoke all refresh tokens for this host.
|
||||||
|
if err := s.DB.RevokeHostRefreshTokensByHost(ctx, hostID); err != nil {
|
||||||
|
slog.Warn("delete host: failed to revoke refresh tokens", "host_id", hostID, "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Evict the client from the pool so no further RPCs are sent.
|
||||||
|
if s.Pool != nil {
|
||||||
|
s.Pool.Evict(hostID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.DB.DeleteHost(ctx, hostID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkDeletePermission verifies the caller has permission to delete the given
|
||||||
|
// host and returns the host record on success.
|
||||||
|
func (s *HostService) checkDeletePermission(ctx context.Context, hostID, userID, teamID string, isAdmin bool) (db.Host, error) {
|
||||||
|
host, err := s.DB.GetHost(ctx, hostID)
|
||||||
|
if err != nil {
|
||||||
|
return db.Host{}, fmt.Errorf("host not found: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if isAdmin {
|
||||||
|
return host, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if !isAdmin {
|
|
||||||
if host.Type != "byoc" {
|
if host.Type != "byoc" {
|
||||||
return fmt.Errorf("forbidden: only admins can delete regular hosts")
|
return db.Host{}, fmt.Errorf("forbidden: only admins can delete regular hosts")
|
||||||
}
|
}
|
||||||
if !host.TeamID.Valid || host.TeamID.String != teamID {
|
if !host.TeamID.Valid || host.TeamID.String != teamID {
|
||||||
return fmt.Errorf("forbidden: host does not belong to your team")
|
return db.Host{}, fmt.Errorf("forbidden: host does not belong to your team")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if userID != "" {
|
||||||
membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
|
membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
TeamID: teamID,
|
TeamID: teamID,
|
||||||
})
|
})
|
||||||
if errors.Is(err, pgx.ErrNoRows) {
|
if errors.Is(err, pgx.ErrNoRows) {
|
||||||
return fmt.Errorf("forbidden: not a member of the specified team")
|
return db.Host{}, fmt.Errorf("forbidden: not a member of the specified team")
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("check team membership: %w", err)
|
return db.Host{}, fmt.Errorf("check team membership: %w", err)
|
||||||
}
|
}
|
||||||
if membership.Role != "owner" {
|
if err := requireAdminOrOwner(membership.Role); err != nil {
|
||||||
return fmt.Errorf("forbidden: only team owners can delete BYOC hosts")
|
return db.Host{}, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.DB.DeleteHost(ctx, hostID)
|
return host, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddTag adds a tag to a host.
|
// AddTag adds a tag to a host.
|
||||||
@ -357,3 +552,14 @@ func (s *HostService) ListTags(ctx context.Context, hostID, teamID string, isAdm
|
|||||||
}
|
}
|
||||||
return s.DB.GetHostTags(ctx, hostID)
|
return s.DB.GetHostTags(ctx, hostID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HostHasSandboxesError reports that a host could not be deleted because it
// still has active sandboxes and force was not set. Callers are expected to
// show SandboxIDs to the user and call Delete again with force=true once the
// user confirms.
type HostHasSandboxesError struct {
	// SandboxIDs lists the active sandboxes that block the delete.
	SandboxIDs []string
}

// Error implements the error interface, reporting the sandbox count and IDs.
func (e *HostHasSandboxesError) Error() string {
	count := len(e.SandboxIDs)
	return fmt.Sprintf("host has %d active sandbox(es): %v", count, e.SandboxIDs)
}
|
||||||
|
|||||||
@ -11,16 +11,18 @@ import (
|
|||||||
|
|
||||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/internal/scheduler"
|
||||||
"git.omukk.dev/wrenn/sandbox/internal/validate"
|
"git.omukk.dev/wrenn/sandbox/internal/validate"
|
||||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// SandboxService provides sandbox lifecycle operations shared between the
|
// SandboxService provides sandbox lifecycle operations shared between the
|
||||||
// REST API and the dashboard.
|
// REST API and the dashboard.
|
||||||
type SandboxService struct {
|
type SandboxService struct {
|
||||||
DB *db.Queries
|
DB *db.Queries
|
||||||
Agent hostagentv1connect.HostAgentServiceClient
|
Pool *lifecycle.HostClientPool
|
||||||
|
Scheduler scheduler.HostScheduler
|
||||||
}
|
}
|
||||||
|
|
||||||
// SandboxCreateParams holds the parameters for creating a sandbox.
|
// SandboxCreateParams holds the parameters for creating a sandbox.
|
||||||
@ -32,8 +34,34 @@ type SandboxCreateParams struct {
|
|||||||
TimeoutSec int32
|
TimeoutSec int32
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create creates a new sandbox: inserts a pending DB record, calls the host agent,
|
// agentForSandbox looks up the host for the given sandbox and returns a client.
|
||||||
// and updates the record to running. Returns the sandbox DB row.
|
func (s *SandboxService) agentForSandbox(ctx context.Context, sandboxID string) (hostagentClient, db.Sandbox, error) {
|
||||||
|
sb, err := s.DB.GetSandbox(ctx, sandboxID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, db.Sandbox{}, fmt.Errorf("sandbox not found: %w", err)
|
||||||
|
}
|
||||||
|
host, err := s.DB.GetHost(ctx, sb.HostID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, db.Sandbox{}, fmt.Errorf("host not found for sandbox: %w", err)
|
||||||
|
}
|
||||||
|
agent, err := s.Pool.GetForHost(host)
|
||||||
|
if err != nil {
|
||||||
|
return nil, db.Sandbox{}, fmt.Errorf("get agent client: %w", err)
|
||||||
|
}
|
||||||
|
return agent, sb, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// hostagentClient is a local alias to avoid the full package path in signatures.
|
||||||
|
type hostagentClient = interface {
|
||||||
|
CreateSandbox(ctx context.Context, req *connect.Request[pb.CreateSandboxRequest]) (*connect.Response[pb.CreateSandboxResponse], error)
|
||||||
|
DestroySandbox(ctx context.Context, req *connect.Request[pb.DestroySandboxRequest]) (*connect.Response[pb.DestroySandboxResponse], error)
|
||||||
|
PauseSandbox(ctx context.Context, req *connect.Request[pb.PauseSandboxRequest]) (*connect.Response[pb.PauseSandboxResponse], error)
|
||||||
|
ResumeSandbox(ctx context.Context, req *connect.Request[pb.ResumeSandboxRequest]) (*connect.Response[pb.ResumeSandboxResponse], error)
|
||||||
|
PingSandbox(ctx context.Context, req *connect.Request[pb.PingSandboxRequest]) (*connect.Response[pb.PingSandboxResponse], error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create creates a new sandbox: picks a host via the scheduler, inserts a pending
|
||||||
|
// DB record, calls the host agent, and updates the record to running.
|
||||||
func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db.Sandbox, error) {
|
func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db.Sandbox, error) {
|
||||||
if p.Template == "" {
|
if p.Template == "" {
|
||||||
p.Template = "minimal"
|
p.Template = "minimal"
|
||||||
@ -58,12 +86,23 @@ func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Pick a host for this sandbox.
|
||||||
|
host, err := s.Scheduler.SelectHost(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return db.Sandbox{}, fmt.Errorf("select host: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
agent, err := s.Pool.GetForHost(host)
|
||||||
|
if err != nil {
|
||||||
|
return db.Sandbox{}, fmt.Errorf("get agent client: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
sandboxID := id.NewSandboxID()
|
sandboxID := id.NewSandboxID()
|
||||||
|
|
||||||
if _, err := s.DB.InsertSandbox(ctx, db.InsertSandboxParams{
|
if _, err := s.DB.InsertSandbox(ctx, db.InsertSandboxParams{
|
||||||
ID: sandboxID,
|
ID: sandboxID,
|
||||||
TeamID: p.TeamID,
|
TeamID: p.TeamID,
|
||||||
HostID: "default",
|
HostID: host.ID,
|
||||||
Template: p.Template,
|
Template: p.Template,
|
||||||
Status: "pending",
|
Status: "pending",
|
||||||
Vcpus: p.VCPUs,
|
Vcpus: p.VCPUs,
|
||||||
@ -73,7 +112,7 @@ func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db.
|
|||||||
return db.Sandbox{}, fmt.Errorf("insert sandbox: %w", err)
|
return db.Sandbox{}, fmt.Errorf("insert sandbox: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := s.Agent.CreateSandbox(ctx, connect.NewRequest(&pb.CreateSandboxRequest{
|
resp, err := agent.CreateSandbox(ctx, connect.NewRequest(&pb.CreateSandboxRequest{
|
||||||
SandboxId: sandboxID,
|
SandboxId: sandboxID,
|
||||||
Template: p.Template,
|
Template: p.Template,
|
||||||
Vcpus: p.VCPUs,
|
Vcpus: p.VCPUs,
|
||||||
@ -126,7 +165,12 @@ func (s *SandboxService) Pause(ctx context.Context, sandboxID, teamID string) (d
|
|||||||
return db.Sandbox{}, fmt.Errorf("sandbox is not running (status: %s)", sb.Status)
|
return db.Sandbox{}, fmt.Errorf("sandbox is not running (status: %s)", sb.Status)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := s.Agent.PauseSandbox(ctx, connect.NewRequest(&pb.PauseSandboxRequest{
|
agent, _, err := s.agentForSandbox(ctx, sandboxID)
|
||||||
|
if err != nil {
|
||||||
|
return db.Sandbox{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := agent.PauseSandbox(ctx, connect.NewRequest(&pb.PauseSandboxRequest{
|
||||||
SandboxId: sandboxID,
|
SandboxId: sandboxID,
|
||||||
})); err != nil {
|
})); err != nil {
|
||||||
return db.Sandbox{}, fmt.Errorf("agent pause: %w", err)
|
return db.Sandbox{}, fmt.Errorf("agent pause: %w", err)
|
||||||
@ -151,7 +195,12 @@ func (s *SandboxService) Resume(ctx context.Context, sandboxID, teamID string) (
|
|||||||
return db.Sandbox{}, fmt.Errorf("sandbox is not paused (status: %s)", sb.Status)
|
return db.Sandbox{}, fmt.Errorf("sandbox is not paused (status: %s)", sb.Status)
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := s.Agent.ResumeSandbox(ctx, connect.NewRequest(&pb.ResumeSandboxRequest{
|
agent, _, err := s.agentForSandbox(ctx, sandboxID)
|
||||||
|
if err != nil {
|
||||||
|
return db.Sandbox{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := agent.ResumeSandbox(ctx, connect.NewRequest(&pb.ResumeSandboxRequest{
|
||||||
SandboxId: sandboxID,
|
SandboxId: sandboxID,
|
||||||
TimeoutSec: sb.TimeoutSec,
|
TimeoutSec: sb.TimeoutSec,
|
||||||
}))
|
}))
|
||||||
@ -181,8 +230,13 @@ func (s *SandboxService) Destroy(ctx context.Context, sandboxID, teamID string)
|
|||||||
return fmt.Errorf("sandbox not found: %w", err)
|
return fmt.Errorf("sandbox not found: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
agent, _, err := s.agentForSandbox(ctx, sandboxID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// Destroy on host agent. A not-found response is fine — sandbox is already gone.
|
// Destroy on host agent. A not-found response is fine — sandbox is already gone.
|
||||||
if _, err := s.Agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
|
if _, err := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
|
||||||
SandboxId: sandboxID,
|
SandboxId: sandboxID,
|
||||||
})); err != nil && connect.CodeOf(err) != connect.CodeNotFound {
|
})); err != nil && connect.CodeOf(err) != connect.CodeNotFound {
|
||||||
return fmt.Errorf("agent destroy: %w", err)
|
return fmt.Errorf("agent destroy: %w", err)
|
||||||
@ -206,7 +260,12 @@ func (s *SandboxService) Ping(ctx context.Context, sandboxID, teamID string) err
|
|||||||
return fmt.Errorf("sandbox is not running (status: %s)", sb.Status)
|
return fmt.Errorf("sandbox is not running (status: %s)", sb.Status)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := s.Agent.PingSandbox(ctx, connect.NewRequest(&pb.PingSandboxRequest{
|
agent, _, err := s.agentForSandbox(ctx, sandboxID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := agent.PingSandbox(ctx, connect.NewRequest(&pb.PingSandboxRequest{
|
||||||
SandboxId: sandboxID,
|
SandboxId: sandboxID,
|
||||||
})); err != nil {
|
})); err != nil {
|
||||||
return fmt.Errorf("agent ping: %w", err)
|
return fmt.Errorf("agent ping: %w", err)
|
||||||
|
|||||||
@ -14,8 +14,8 @@ import (
|
|||||||
|
|
||||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
|
||||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var teamNameRE = regexp.MustCompile(`^[A-Za-z0-9 _\-@']{1,128}$`)
|
var teamNameRE = regexp.MustCompile(`^[A-Za-z0-9 _\-@']{1,128}$`)
|
||||||
@ -24,7 +24,7 @@ var teamNameRE = regexp.MustCompile(`^[A-Za-z0-9 _\-@']{1,128}$`)
|
|||||||
type TeamService struct {
|
type TeamService struct {
|
||||||
DB *db.Queries
|
DB *db.Queries
|
||||||
Pool *pgxpool.Pool
|
Pool *pgxpool.Pool
|
||||||
Agent hostagentv1connect.HostAgentServiceClient
|
HostPool *lifecycle.HostClientPool
|
||||||
}
|
}
|
||||||
|
|
||||||
// TeamWithRole pairs a team with the calling user's role in it.
|
// TeamWithRole pairs a team with the calling user's role in it.
|
||||||
@ -177,11 +177,17 @@ func (s *TeamService) DeleteTeam(ctx context.Context, teamID, callerUserID strin
|
|||||||
|
|
||||||
var stopIDs []string
|
var stopIDs []string
|
||||||
for _, sb := range sandboxes {
|
for _, sb := range sandboxes {
|
||||||
if _, err := s.Agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
|
host, hostErr := s.DB.GetHost(ctx, sb.HostID)
|
||||||
|
if hostErr == nil {
|
||||||
|
agent, agentErr := s.HostPool.GetForHost(host)
|
||||||
|
if agentErr == nil {
|
||||||
|
if _, err := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
|
||||||
SandboxId: sb.ID,
|
SandboxId: sb.ID,
|
||||||
})); err != nil && connect.CodeOf(err) != connect.CodeNotFound {
|
})); err != nil && connect.CodeOf(err) != connect.CodeNotFound {
|
||||||
slog.Warn("team delete: failed to destroy sandbox", "sandbox_id", sb.ID, "error", err)
|
slog.Warn("team delete: failed to destroy sandbox", "sandbox_id", sb.ID, "error", err)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
stopIDs = append(stopIDs, sb.ID)
|
stopIDs = append(stopIDs, sb.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user