diff --git a/.env.example b/.env.example index dee152cf..62e9b4ce 100644 --- a/.env.example +++ b/.env.example @@ -5,13 +5,13 @@ DATABASE_URL=postgres://wrenn:wrenn@localhost:5432/wrenn?sslmode=disable REDIS_URL=redis://localhost:6379/0 # Control Plane -CP_LISTEN_ADDR=:8000 +WRENN_CP_LISTEN_ADDR=:8080 # Host Agent -AGENT_LISTEN_ADDR=:50051 -AGENT_FILES_ROOTDIR=/var/lib/wrenn -AGENT_HOST_INTERFACE=eth0 -AGENT_CP_URL=http://localhost:8000 +WRENN_HOST_LISTEN_ADDR=:50051 +WRENN_DIR=/var/lib/wrenn +WRENN_HOST_INTERFACE=eth0 +WRENN_CP_URL=http://localhost:8080 # Lago (billing — external service) LAGO_API_URL=http://localhost:3000 @@ -27,6 +27,14 @@ AWS_SECRET_ACCESS_KEY= # Auth JWT_SECRET= +# mTLS — CP→Agent channel +# Generate a self-signed CA with: +# openssl ecparam -genkey -name P-256 -noout -out ca.key +# openssl req -new -x509 -key ca.key -days 3650 -out ca.crt -subj "/CN=wrenn-internal-ca" +# Then set these to the file contents (newlines replaced with \n or use multiline env). +WRENN_CA_CERT=-----BEGIN CERTIFICATE-----\nMIIBjTCCATOgAwIBAgIUJ61AjKri7lTAEIpmCXA+B/Gm0pwwCgYIKoZIzj0EAwIw\nHDEaMBgGA1UEAwwRd3Jlbm4taW50ZXJuYWwtY2EwHhcNMjYwMzMwMTIwNDI5WhcN\nMzYwMzI3MTIwNDI5WjAcMRowGAYDVQQDDBF3cmVubi1pbnRlcm5hbC1jYTBZMBMG\nByqGSM49AgEGCCqGSM49AwEHA0IABDkwv8a1Y7Xx7a5yUDLwDUUBn1fSfUlq6sGr\nVociS2Za+vo1353K61IFMNF9A3wvLXpsEAGZKbaw1iEfRs6LERijUzBRMB0GA1Ud\nDgQWBBQkuWu9flN+C/e4wPFtbWEDVWNjFjAfBgNVHSMEGDAWgBQkuWu9flN+C/e4\nwPFtbWEDVWNjFjAPBgNVHRMBAf8EBTADAQH/MAoGCCqGSM49BAMCA0gAMEUCIBL0\nHmdBQy/76eLKM/X/Qtsrt2yktfxIrWQBbrXOlBd2AiEAzx8n5O0r/ebxwmAxL3y7\nVM7hllXxL6AdxJtU2vsEoA0=\n-----END CERTIFICATE----- +WRENN_CA_KEY=-----BEGIN EC PRIVATE KEY-----\nMHcCAQEEIOjpTSFMhhR9Yi2mWtrzJ/FINEImtmz32GkwZ9eYUbDkoAoGCCqGSM49\nAwEHoUQDQgAEOTC/xrVjtfHtrnJQMvANRQGfV9J9SWrqwatWhyJLZlr6+jXfncrr\nUgUw0X0DfC8temwQAZkptrDWIR9GzosRGA==\n-----END EC PRIVATE KEY----- + # OAuth OAUTH_GITHUB_CLIENT_ID= OAUTH_GITHUB_CLIENT_SECRET= diff --git a/Makefile b/Makefile index 4a2e0b61..80fbd3a7 100644 --- 
a/Makefile +++ b/Makefile @@ -39,7 +39,7 @@ dev: dev-infra migrate-up dev-cp dev-infra: docker compose -f deploy/docker-compose.dev.yml up -d @echo "Waiting for PostgreSQL..." - @until pg_isready -h localhost -p 5432 -q; do sleep 0.5; done + @until docker compose -f deploy/docker-compose.dev.yml exec -T postgres pg_isready -q 2>/dev/null; do sleep 0.5; done @echo "Dev infrastructure ready." dev-down: @@ -53,7 +53,7 @@ dev-agent: sudo go run ./cmd/host-agent dev-frontend: - cd frontend && pnpm dev --port 5173 + cd frontend && pnpm dev --port 5173 --host 0.0.0.0 dev-envd: cd $(ENVD_DIR) && go run . --debug --listen-tcp :3002 diff --git a/README.md b/README.md index dff19325..e2b290f0 100644 --- a/README.md +++ b/README.md @@ -51,12 +51,12 @@ Copy `.env.example` to `.env` and edit: DATABASE_URL=postgres://wrenn:wrenn@localhost:5432/wrenn?sslmode=disable # Control plane -CP_LISTEN_ADDR=:8000 +WRENN_CP_LISTEN_ADDR=:8000 CP_HOST_AGENT_ADDR=http://localhost:50051 # Host agent -AGENT_LISTEN_ADDR=:50051 -AGENT_FILES_ROOTDIR=/var/lib/wrenn +WRENN_HOST_LISTEN_ADDR=:50051 +WRENN_DIR=/var/lib/wrenn ``` ### Run @@ -69,7 +69,7 @@ make migrate-up ./builds/wrenn-cp ``` -Control plane listens on `CP_LISTEN_ADDR` (default `:8000`). +Control plane listens on `WRENN_CP_LISTEN_ADDR` (default `:8000`). ### Host registration @@ -87,16 +87,16 @@ Hosts must be registered with the control plane before they can serve sandboxes. 2. **Start the host agent** with the registration token and its externally-reachable address: ```bash - sudo AGENT_CP_URL=http://cp-host:8000 \ + sudo WRENN_CP_URL=http://cp-host:8000 \ ./builds/wrenn-agent \ --register \ --address 10.0.1.5:50051 ``` - On first startup the agent sends its specs (arch, CPU, memory, disk) to the control plane, receives a long-lived host JWT, and saves it to `$AGENT_FILES_ROOTDIR/host-token`. 
+ On first startup the agent sends its specs (arch, CPU, memory, disk) to the control plane, receives a long-lived host JWT, and saves it to `$WRENN_DIR/host-credentials.json`. 3. **Subsequent startups** don't need `--register` — the agent loads the saved JWT automatically: ```bash - sudo AGENT_CP_URL=http://cp-host:8000 \ + sudo WRENN_CP_URL=http://cp-host:8000 \ ./builds/wrenn-agent --address 10.0.1.5:50051 ``` @@ -107,7 +107,7 @@ Hosts must be registered with the control plane before they can serve sandboxes. ``` Then restart the agent with the new token. -The agent sends heartbeats to the control plane every 30 seconds. Host agent listens on `AGENT_LISTEN_ADDR` (default `:50051`). +The agent sends heartbeats to the control plane every 30 seconds. Host agent listens on `WRENN_HOST_LISTEN_ADDR` (default `:50051`). ### Rootfs images diff --git a/cmd/control-plane/main.go b/cmd/control-plane/main.go index 9f84edcc..419dc875 100644 --- a/cmd/control-plane/main.go +++ b/cmd/control-plane/main.go @@ -15,6 +15,7 @@ import ( "git.omukk.dev/wrenn/sandbox/internal/api" "git.omukk.dev/wrenn/sandbox/internal/audit" + "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/auth/oauth" "git.omukk.dev/wrenn/sandbox/internal/config" "git.omukk.dev/wrenn/sandbox/internal/db" @@ -68,8 +69,52 @@ func main() { } slog.Info("connected to redis") + // mTLS: parse internal CA and build a TLS-capable host client pool. + // When CA env vars are absent the pool falls back to plain HTTP (dev mode). + var ca *auth.CA + if cfg.CACert != "" && cfg.CAKey != "" { + var err error + ca, err = auth.ParseCA(cfg.CACert, cfg.CAKey) + if err != nil { + slog.Error("failed to parse mTLS CA from environment", "error", err) + os.Exit(1) + } + slog.Info("mTLS enabled: CA loaded") + } else { + slog.Warn("mTLS disabled: WRENN_CA_CERT/WRENN_CA_KEY not set — host agent connections are unencrypted") + } + // Host client pool — manages Connect RPC clients to host agents. 
- hostPool := lifecycle.NewHostClientPool() + var hostPool *lifecycle.HostClientPool + if ca != nil { + cpCertStore, err := auth.NewCPCertStore(ca) + if err != nil { + slog.Error("failed to issue CP client certificate", "error", err) + os.Exit(1) + } + // Renew the CP client certificate periodically so it never expires + // while the control plane is running (TTL = 24h, renewal = every 12h). + go func() { + ticker := time.NewTicker(auth.CPCertRenewInterval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := cpCertStore.Refresh(); err != nil { + slog.Error("failed to renew CP client certificate", "error", err) + } else { + slog.Info("CP client certificate renewed") + } + } + } + }() + hostPool = lifecycle.NewHostClientPoolTLS(auth.CPClientTLSConfig(ca, cpCertStore)) + slog.Info("host client pool: mTLS enabled") + } else { + hostPool = lifecycle.NewHostClientPool() + } // Scheduler — picks a host for each new sandbox (round-robin for now). hostScheduler := scheduler.NewRoundRobinScheduler(queries) @@ -88,7 +133,11 @@ func main() { } // API server. - srv := api.New(queries, hostPool, hostScheduler, pool, rdb, []byte(cfg.JWTSecret), oauthRegistry, cfg.OAuthRedirectURL) + srv := api.New(queries, hostPool, hostScheduler, pool, rdb, []byte(cfg.JWTSecret), oauthRegistry, cfg.OAuthRedirectURL, ca) + + // Start template build workers (2 concurrent). + stopBuildWorkers := srv.BuildSvc.StartWorkers(ctx, 2) + defer stopBuildWorkers() // Start host monitor (passive + active reconciliation every 30s). monitor := api.NewHostMonitor(queries, hostPool, audit.New(queries), 30*time.Second) @@ -98,9 +147,14 @@ func main() { sampler := api.NewMetricsSampler(queries, 10*time.Second) sampler.Start(ctx) + // Wrap the API handler with the sandbox proxy so that requests with + // {port}-{sandbox_id}.{domain} Host headers are routed to the sandbox's + // host agent. All other requests pass through to the normal API router. 
+ proxyWrapper := api.NewSandboxProxyWrapper(srv.Handler(), queries, hostPool) + httpServer := &http.Server{ Addr: cfg.ListenAddr, - Handler: srv.Handler(), + Handler: proxyWrapper, } // Graceful shutdown on signal. diff --git a/cmd/host-agent/main.go b/cmd/host-agent/main.go index 2d34cd1d..de48a664 100644 --- a/cmd/host-agent/main.go +++ b/cmd/host-agent/main.go @@ -2,8 +2,10 @@ package main import ( "context" + "crypto/tls" "flag" "log/slog" + "net" "net/http" "os" "os/signal" @@ -14,8 +16,10 @@ import ( "github.com/joho/godotenv" + "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/devicemapper" "git.omukk.dev/wrenn/sandbox/internal/hostagent" + "git.omukk.dev/wrenn/sandbox/internal/network" "git.omukk.dev/wrenn/sandbox/internal/sandbox" "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect" ) @@ -42,16 +46,17 @@ func main() { slog.Warn("failed to enable ip_forward", "error", err) } - // Clean up any stale dm-snapshot devices from a previous crash. + // Clean up stale resources from a previous crash. devicemapper.CleanupStaleDevices() + network.CleanupStaleNamespaces() - listenAddr := envOrDefault("AGENT_LISTEN_ADDR", ":50051") - rootDir := envOrDefault("AGENT_FILES_ROOTDIR", "/var/lib/wrenn") - cpURL := os.Getenv("AGENT_CP_URL") - tokenFile := filepath.Join(rootDir, "host.jwt") + listenAddr := envOrDefault("WRENN_HOST_LISTEN_ADDR", ":50051") + rootDir := envOrDefault("WRENN_DIR", "/var/lib/wrenn") + cpURL := os.Getenv("WRENN_CP_URL") + credsFile := filepath.Join(rootDir, "host-credentials.json") if cpURL == "" { - slog.Error("AGENT_CP_URL environment variable is required") + slog.Error("WRENN_CP_URL environment variable is required") os.Exit(1) } if *advertiseAddr == "" { @@ -59,11 +64,15 @@ func main() { os.Exit(1) } + // Expand base images to the standard disk size (sparse, no extra physical + // disk). This ensures dm-snapshot sandboxes see the full size from boot. 
+ if err := sandbox.EnsureImageSizes(rootDir, sandbox.DefaultDiskSizeMB); err != nil { + slog.Error("failed to expand base images", "error", err) + os.Exit(1) + } + cfg := sandbox.Config{ - KernelPath: filepath.Join(rootDir, "kernels", "vmlinux"), - ImagesDir: filepath.Join(rootDir, "images"), - SandboxesDir: filepath.Join(rootDir, "sandboxes"), - SnapshotsDir: filepath.Join(rootDir, "snapshots"), + WrennDir: rootDir, } mgr := sandbox.New(cfg) @@ -74,10 +83,10 @@ func main() { mgr.StartTTLReaper(ctx) // Register with the control plane and start heartbeating. - hostToken, err := hostagent.Register(ctx, hostagent.RegistrationConfig{ + creds, err := hostagent.Register(ctx, hostagent.RegistrationConfig{ CPURL: cpURL, RegistrationToken: *registrationToken, - TokenFile: tokenFile, + TokenFile: credsFile, Address: *advertiseAddr, }) if err != nil { @@ -85,17 +94,29 @@ func main() { os.Exit(1) } - hostID, err := hostagent.HostIDFromToken(hostToken) - if err != nil { - slog.Error("failed to extract host ID from token", "error", err) - os.Exit(1) - } - - slog.Info("host registered", "host_id", hostID) + slog.Info("host registered", "host_id", creds.HostID) // httpServer is declared here so the shutdown func can reference it. httpServer := &http.Server{Addr: listenAddr} + // Set up mTLS if the CP issued a certificate during registration. 
+ var certStore hostagent.CertStore + if creds.CertPEM != "" && creds.KeyPEM != "" && creds.CACertPEM != "" { + if err := certStore.ParseAndStore(creds.CertPEM, creds.KeyPEM); err != nil { + slog.Error("failed to load host TLS certificate", "error", err) + os.Exit(1) + } + tlsCfg := auth.AgentTLSConfigFromPEM(creds.CACertPEM, certStore.GetCert) + if tlsCfg == nil { + slog.Error("failed to build agent TLS config: invalid CA certificate PEM") + os.Exit(1) + } + httpServer.TLSConfig = tlsCfg + slog.Info("mTLS enabled on agent server") + } else { + slog.Warn("mTLS disabled: no certificate received from CP — agent serving plain HTTP") + } + // doShutdown is the single shutdown path. sync.Once ensures mgr.Shutdown // and httpServer.Shutdown are each called exactly once regardless of // whether shutdown is triggered by a signal, a heartbeat 404, or the @@ -119,13 +140,16 @@ func main() { }) path, handler := hostagentv1connect.NewHostAgentServiceHandler(srv) + proxyHandler := hostagent.NewProxyHandler(mgr) + mux := http.NewServeMux() mux.Handle(path, handler) + mux.Handle("/proxy/", proxyHandler) httpServer.Handler = mux // Start heartbeat loop. Handler must be set before this because the // immediate beat can trigger doShutdown → httpServer.Shutdown synchronously. - hostagent.StartHeartbeat(ctx, cpURL, tokenFile, hostID, 30*time.Second, + hostagent.StartHeartbeat(ctx, cpURL, credsFile, creds.HostID, 30*time.Second, // pauseAll: called on 3 consecutive network failures. func() { pauseCtx, pauseCancel := context.WithTimeout(context.Background(), 2*time.Minute) @@ -136,6 +160,17 @@ func main() { func() { doShutdown("host deleted from CP") }, + // onCredsRefreshed: hot-swap the TLS certificate after a JWT refresh. 
+ func(tf *hostagent.TokenFile) { + if tf.CertPEM == "" || tf.KeyPEM == "" { + return + } + if err := certStore.ParseAndStore(tf.CertPEM, tf.KeyPEM); err != nil { + slog.Error("failed to hot-swap TLS cert after credentials refresh", "error", err) + } else { + slog.Info("TLS cert hot-swapped after credentials refresh") + } + }, ) // Graceful shutdown on SIGINT/SIGTERM. @@ -146,10 +181,30 @@ func main() { doShutdown("signal: " + sig.String()) }() - slog.Info("host agent starting", "addr", listenAddr, "host_id", hostID) - if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { - slog.Error("http server error", "error", err) - os.Exit(1) + slog.Info("host agent starting", "addr", listenAddr, "host_id", creds.HostID) + if httpServer.TLSConfig != nil { + // When TLSConfig is pre-populated (cert via GetCertificate callback), + // ListenAndServeTLS does not work because it requires on-disk cert/key paths. + // Instead, create the TLS listener manually and call Serve. + ln, err := tls.Listen("tcp", listenAddr, httpServer.TLSConfig) + if err != nil { + slog.Error("failed to start TLS listener", "error", err) + os.Exit(1) + } + if err := httpServer.Serve(ln); err != nil && err != http.ErrServerClosed { + slog.Error("https server error", "error", err) + os.Exit(1) + } + } else { + ln, err := net.Listen("tcp", listenAddr) + if err != nil { + slog.Error("failed to start listener", "error", err) + os.Exit(1) + } + if err := httpServer.Serve(ln); err != nil && err != http.ErrServerClosed { + slog.Error("http server error", "error", err) + os.Exit(1) + } } slog.Info("host agent stopped") diff --git a/db/migrations/20260310094104_initial.sql b/db/migrations/20260310094104_initial.sql index c291815a..6c8afc47 100644 --- a/db/migrations/20260310094104_initial.sql +++ b/db/migrations/20260310094104_initial.sql @@ -1,25 +1,237 @@ -- +goose Up -CREATE TABLE sandboxes ( - id TEXT PRIMARY KEY, - owner_id TEXT NOT NULL DEFAULT '', - host_id TEXT NOT NULL DEFAULT 
'default', - template TEXT NOT NULL DEFAULT 'minimal', - status TEXT NOT NULL DEFAULT 'pending', - vcpus INTEGER NOT NULL DEFAULT 1, - memory_mb INTEGER NOT NULL DEFAULT 512, - timeout_sec INTEGER NOT NULL DEFAULT 0, - guest_ip TEXT NOT NULL DEFAULT '', - host_ip TEXT NOT NULL DEFAULT '', - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - started_at TIMESTAMPTZ, - last_active_at TIMESTAMPTZ, - last_updated TIMESTAMPTZ NOT NULL DEFAULT NOW() +-- teams +CREATE TABLE teams ( + id UUID PRIMARY KEY, + name TEXT NOT NULL, + slug TEXT NOT NULL UNIQUE, + is_byoc BOOLEAN NOT NULL DEFAULT FALSE, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + deleted_at TIMESTAMPTZ +); +CREATE INDEX idx_teams_slug ON teams(slug); + +-- users +CREATE TABLE users ( + id UUID PRIMARY KEY, + email TEXT NOT NULL UNIQUE, + password_hash TEXT, + name TEXT NOT NULL DEFAULT '', + is_admin BOOLEAN NOT NULL DEFAULT FALSE, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ); +-- users_teams (junction) +CREATE TABLE users_teams ( + user_id UUID NOT NULL REFERENCES users(id), + team_id UUID NOT NULL REFERENCES teams(id), + is_default BOOLEAN NOT NULL DEFAULT FALSE, + role TEXT NOT NULL DEFAULT 'member', + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + PRIMARY KEY (team_id, user_id) +); +CREATE INDEX idx_users_teams_user ON users_teams(user_id); + +-- team_api_keys +CREATE TABLE team_api_keys ( + id UUID PRIMARY KEY, + team_id UUID NOT NULL REFERENCES teams(id), + name TEXT NOT NULL, + key_hash TEXT NOT NULL UNIQUE, + key_prefix TEXT NOT NULL, + created_by UUID NOT NULL REFERENCES users(id), + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + last_used TIMESTAMPTZ +); +CREATE INDEX idx_team_api_keys_team ON team_api_keys(team_id); + +-- oauth_providers +CREATE TABLE oauth_providers ( + provider TEXT NOT NULL, + provider_id TEXT NOT NULL, + user_id UUID NOT NULL REFERENCES users(id), + email TEXT NOT NULL DEFAULT '', + created_at TIMESTAMPTZ NOT NULL 
DEFAULT NOW(), + PRIMARY KEY (provider, provider_id) +); +CREATE INDEX idx_oauth_providers_user ON oauth_providers(user_id); + +-- admin_permissions +CREATE TABLE admin_permissions ( + id UUID PRIMARY KEY, + user_id UUID NOT NULL REFERENCES users(id), + permission TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + UNIQUE (user_id, permission) +); +CREATE INDEX idx_admin_permissions_user ON admin_permissions(user_id); + +-- hosts +CREATE TABLE hosts ( + id UUID PRIMARY KEY, + type TEXT NOT NULL DEFAULT 'regular', + team_id UUID REFERENCES teams(id), + provider TEXT NOT NULL DEFAULT '', + availability_zone TEXT NOT NULL DEFAULT '', + arch TEXT NOT NULL DEFAULT '', + cpu_cores INTEGER NOT NULL DEFAULT 0, + memory_mb INTEGER NOT NULL DEFAULT 0, + disk_gb INTEGER NOT NULL DEFAULT 0, + address TEXT NOT NULL DEFAULT '', + status TEXT NOT NULL DEFAULT 'pending', + last_heartbeat_at TIMESTAMPTZ, + metadata JSONB NOT NULL DEFAULT '{}', + created_by UUID NOT NULL REFERENCES users(id), + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + cert_fingerprint TEXT NOT NULL DEFAULT '', + mtls_enabled BOOLEAN NOT NULL DEFAULT FALSE +); +CREATE INDEX idx_hosts_type ON hosts(type); +CREATE INDEX idx_hosts_team ON hosts(team_id); +CREATE INDEX idx_hosts_status ON hosts(status); + +-- host_tokens +CREATE TABLE host_tokens ( + id UUID PRIMARY KEY, + host_id UUID NOT NULL REFERENCES hosts(id) ON DELETE CASCADE, + created_by UUID NOT NULL REFERENCES users(id), + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + expires_at TIMESTAMPTZ NOT NULL, + used_at TIMESTAMPTZ +); +CREATE INDEX idx_host_tokens_host ON host_tokens(host_id); + +-- host_tags +CREATE TABLE host_tags ( + host_id UUID NOT NULL REFERENCES hosts(id) ON DELETE CASCADE, + tag TEXT NOT NULL, + PRIMARY KEY (host_id, tag) +); +CREATE INDEX idx_host_tags_tag ON host_tags(tag); + +-- host_refresh_tokens +CREATE TABLE host_refresh_tokens ( + id UUID PRIMARY KEY, + host_id 
UUID NOT NULL REFERENCES hosts(id) ON DELETE CASCADE, + token_hash TEXT NOT NULL UNIQUE, + expires_at TIMESTAMPTZ NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + revoked_at TIMESTAMPTZ +); +CREATE INDEX idx_host_refresh_tokens_host ON host_refresh_tokens(host_id); + +-- templates (TEXT primary key — not UUID) +CREATE TABLE templates ( + name TEXT PRIMARY KEY, + type TEXT NOT NULL DEFAULT 'base', + vcpus INTEGER NOT NULL DEFAULT 1, + memory_mb INTEGER NOT NULL DEFAULT 512, + size_bytes BIGINT NOT NULL DEFAULT 0, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + team_id UUID NOT NULL +); +CREATE INDEX idx_templates_team ON templates(team_id); + +-- sandboxes +CREATE TABLE sandboxes ( + id UUID PRIMARY KEY, + team_id UUID NOT NULL REFERENCES teams(id), + host_id UUID NOT NULL, + template TEXT NOT NULL DEFAULT 'minimal', + status TEXT NOT NULL DEFAULT 'pending', + vcpus INTEGER NOT NULL DEFAULT 1, + memory_mb INTEGER NOT NULL DEFAULT 512, + timeout_sec INTEGER NOT NULL DEFAULT 300, + disk_size_mb INTEGER NOT NULL DEFAULT 5120, + guest_ip TEXT NOT NULL DEFAULT '', + host_ip TEXT NOT NULL DEFAULT '', + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + started_at TIMESTAMPTZ, + last_active_at TIMESTAMPTZ, + last_updated TIMESTAMPTZ NOT NULL DEFAULT NOW() +); CREATE INDEX idx_sandboxes_status ON sandboxes(status); CREATE INDEX idx_sandboxes_host_status ON sandboxes(host_id, status); +CREATE INDEX idx_sandboxes_team ON sandboxes(team_id); + +-- audit_logs (id and team_id are UUID; actor_id and resource_id are TEXT for polymorphism) +CREATE TABLE audit_logs ( + id UUID PRIMARY KEY, + team_id UUID NOT NULL, + actor_type TEXT NOT NULL, + actor_id TEXT, + actor_name TEXT NOT NULL DEFAULT '', + resource_type TEXT NOT NULL, + resource_id TEXT, + action TEXT NOT NULL, + scope TEXT NOT NULL DEFAULT 'team', + status TEXT NOT NULL DEFAULT 'success', + metadata JSONB NOT NULL DEFAULT '{}', + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); +CREATE INDEX 
idx_audit_logs_team_time ON audit_logs(team_id, created_at DESC); +CREATE INDEX idx_audit_logs_team_resource ON audit_logs(team_id, resource_type, created_at DESC); + +-- sandbox_metrics_snapshots +CREATE TABLE sandbox_metrics_snapshots ( + id BIGSERIAL PRIMARY KEY, + team_id UUID NOT NULL, + sampled_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + running_count INTEGER NOT NULL DEFAULT 0, + vcpus_reserved INTEGER NOT NULL DEFAULT 0, + memory_mb_reserved INTEGER NOT NULL DEFAULT 0 +); +CREATE INDEX idx_metrics_snapshots_team_time ON sandbox_metrics_snapshots(team_id, sampled_at DESC); + +-- sandbox_metric_points +CREATE TABLE sandbox_metric_points ( + sandbox_id UUID NOT NULL, + tier TEXT NOT NULL CHECK (tier IN ('10m', '2h', '24h')), + ts BIGINT NOT NULL, + cpu_pct FLOAT8 NOT NULL DEFAULT 0, + mem_bytes BIGINT NOT NULL DEFAULT 0, + disk_bytes BIGINT NOT NULL DEFAULT 0, + PRIMARY KEY (sandbox_id, tier, ts) +); +CREATE INDEX idx_sandbox_metric_points_sandbox_tier ON sandbox_metric_points(sandbox_id, tier); + +-- template_builds +CREATE TABLE template_builds ( + id UUID PRIMARY KEY, + name TEXT NOT NULL, + base_template TEXT NOT NULL, + recipe JSONB NOT NULL DEFAULT '[]', + healthcheck TEXT NOT NULL DEFAULT '', + vcpus INTEGER NOT NULL DEFAULT 1, + memory_mb INTEGER NOT NULL DEFAULT 512, + status TEXT NOT NULL DEFAULT 'pending', + current_step INTEGER NOT NULL DEFAULT 0, + total_steps INTEGER NOT NULL DEFAULT 0, + logs JSONB NOT NULL DEFAULT '[]', + error TEXT NOT NULL DEFAULT '', + sandbox_id UUID, + host_id UUID, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + started_at TIMESTAMPTZ, + completed_at TIMESTAMPTZ +); -- +goose Down - -DROP TABLE sandboxes; +DROP TABLE IF EXISTS template_builds; +DROP TABLE IF EXISTS sandbox_metric_points; +DROP TABLE IF EXISTS sandbox_metrics_snapshots; +DROP TABLE IF EXISTS audit_logs; +DROP TABLE IF EXISTS sandboxes; +DROP TABLE IF EXISTS templates; +DROP TABLE IF EXISTS host_refresh_tokens; +DROP TABLE IF EXISTS host_tags; +DROP TABLE 
IF EXISTS host_tokens; +DROP TABLE IF EXISTS hosts; +DROP TABLE IF EXISTS admin_permissions; +DROP TABLE IF EXISTS oauth_providers; +DROP TABLE IF EXISTS team_api_keys; +DROP TABLE IF EXISTS users_teams; +DROP TABLE IF EXISTS users; +DROP TABLE IF EXISTS teams; diff --git a/db/migrations/20260311224925_snapshots.sql b/db/migrations/20260311224925_snapshots.sql deleted file mode 100644 index 8a0427c2..00000000 --- a/db/migrations/20260311224925_snapshots.sql +++ /dev/null @@ -1,14 +0,0 @@ --- +goose Up - -CREATE TABLE templates ( - name TEXT PRIMARY KEY, - type TEXT NOT NULL DEFAULT 'base', -- 'base' or 'snapshot' - vcpus INTEGER, - memory_mb INTEGER, - size_bytes BIGINT NOT NULL DEFAULT 0, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() -); - --- +goose Down - -DROP TABLE templates; diff --git a/db/migrations/20260313210608_auth.sql b/db/migrations/20260313210608_auth.sql deleted file mode 100644 index 03970a8f..00000000 --- a/db/migrations/20260313210608_auth.sql +++ /dev/null @@ -1,46 +0,0 @@ --- +goose Up - -CREATE TABLE users ( - id TEXT PRIMARY KEY, - email TEXT NOT NULL UNIQUE, - password_hash TEXT NOT NULL, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() -); - -CREATE TABLE teams ( - id TEXT PRIMARY KEY, - name TEXT NOT NULL, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() -); - -CREATE TABLE users_teams ( - user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE, - team_id TEXT NOT NULL REFERENCES teams(id) ON DELETE CASCADE, - is_default BOOLEAN NOT NULL DEFAULT TRUE, - role TEXT NOT NULL DEFAULT 'owner', - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - PRIMARY KEY (team_id, user_id) -); - -CREATE INDEX idx_users_teams_user ON users_teams(user_id); - -CREATE TABLE team_api_keys ( - id TEXT PRIMARY KEY, - team_id TEXT NOT NULL REFERENCES teams(id) ON DELETE CASCADE, - name TEXT NOT NULL DEFAULT '', - key_hash TEXT NOT NULL UNIQUE, - key_prefix TEXT NOT NULL DEFAULT '', - created_by TEXT NOT NULL 
REFERENCES users(id), - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - last_used TIMESTAMPTZ -); - -CREATE INDEX idx_team_api_keys_team ON team_api_keys(team_id); - --- +goose Down - -DROP TABLE team_api_keys; -DROP TABLE users_teams; -DROP TABLE teams; -DROP TABLE users; diff --git a/db/migrations/20260313210611_team_ownership.sql b/db/migrations/20260313210611_team_ownership.sql deleted file mode 100644 index 849e781e..00000000 --- a/db/migrations/20260313210611_team_ownership.sql +++ /dev/null @@ -1,31 +0,0 @@ --- +goose Up - -ALTER TABLE sandboxes - ADD COLUMN team_id TEXT NOT NULL DEFAULT ''; - -UPDATE sandboxes SET team_id = owner_id WHERE owner_id != ''; - -ALTER TABLE sandboxes - DROP COLUMN owner_id; - -ALTER TABLE templates - ADD COLUMN team_id TEXT NOT NULL DEFAULT ''; - -CREATE INDEX idx_sandboxes_team ON sandboxes(team_id); -CREATE INDEX idx_templates_team ON templates(team_id); - --- +goose Down - -ALTER TABLE sandboxes - ADD COLUMN owner_id TEXT NOT NULL DEFAULT ''; - -UPDATE sandboxes SET owner_id = team_id WHERE team_id != ''; - -ALTER TABLE sandboxes - DROP COLUMN team_id; - -ALTER TABLE templates - DROP COLUMN team_id; - -DROP INDEX IF EXISTS idx_sandboxes_team; -DROP INDEX IF EXISTS idx_templates_team; diff --git a/db/migrations/20260315001514_oauth.sql b/db/migrations/20260315001514_oauth.sql deleted file mode 100644 index c3c33e9b..00000000 --- a/db/migrations/20260315001514_oauth.sql +++ /dev/null @@ -1,22 +0,0 @@ --- +goose Up - -ALTER TABLE users - ALTER COLUMN password_hash DROP NOT NULL; - -CREATE TABLE oauth_providers ( - provider TEXT NOT NULL, - provider_id TEXT NOT NULL, - user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE, - email TEXT NOT NULL, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - PRIMARY KEY (provider, provider_id) -); - -CREATE INDEX idx_oauth_providers_user ON oauth_providers(user_id); - --- +goose Down - -DROP TABLE oauth_providers; - -UPDATE users SET password_hash = '' WHERE password_hash IS NULL; 
-ALTER TABLE users ALTER COLUMN password_hash SET NOT NULL; diff --git a/db/migrations/20260316203135_admin_users.sql b/db/migrations/20260316203135_admin_users.sql deleted file mode 100644 index eff669b5..00000000 --- a/db/migrations/20260316203135_admin_users.sql +++ /dev/null @@ -1,21 +0,0 @@ --- +goose Up - -ALTER TABLE users - ADD COLUMN is_admin BOOLEAN NOT NULL DEFAULT FALSE; - -CREATE TABLE admin_permissions ( - id TEXT PRIMARY KEY, - user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE, - permission TEXT NOT NULL, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - UNIQUE (user_id, permission) -); - -CREATE INDEX idx_admin_permissions_user ON admin_permissions(user_id); - --- +goose Down - -DROP TABLE admin_permissions; - -ALTER TABLE users - DROP COLUMN is_admin; diff --git a/db/migrations/20260316203138_byoc_teams.sql b/db/migrations/20260316203138_byoc_teams.sql deleted file mode 100644 index bb2c8ec2..00000000 --- a/db/migrations/20260316203138_byoc_teams.sql +++ /dev/null @@ -1,9 +0,0 @@ --- +goose Up - -ALTER TABLE teams - ADD COLUMN is_byoc BOOLEAN NOT NULL DEFAULT FALSE; - --- +goose Down - -ALTER TABLE teams - DROP COLUMN is_byoc; diff --git a/db/migrations/20260316203142_hosts.sql b/db/migrations/20260316203142_hosts.sql deleted file mode 100644 index 372b3802..00000000 --- a/db/migrations/20260316203142_hosts.sql +++ /dev/null @@ -1,47 +0,0 @@ --- +goose Up - -CREATE TABLE hosts ( - id TEXT PRIMARY KEY, - type TEXT NOT NULL DEFAULT 'regular', -- 'regular' or 'byoc' - team_id TEXT REFERENCES teams(id) ON DELETE SET NULL, - provider TEXT, - availability_zone TEXT, - arch TEXT, - cpu_cores INTEGER, - memory_mb INTEGER, - disk_gb INTEGER, - address TEXT, -- ip:port of host agent - status TEXT NOT NULL DEFAULT 'pending', -- 'pending', 'online', 'offline', 'draining' - last_heartbeat_at TIMESTAMPTZ, - metadata JSONB NOT NULL DEFAULT '{}', - created_by TEXT NOT NULL REFERENCES users(id), - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - 
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() -); - -CREATE TABLE host_tokens ( - id TEXT PRIMARY KEY, - host_id TEXT NOT NULL REFERENCES hosts(id) ON DELETE CASCADE, - created_by TEXT NOT NULL REFERENCES users(id), - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - expires_at TIMESTAMPTZ NOT NULL, - used_at TIMESTAMPTZ -); - -CREATE TABLE host_tags ( - host_id TEXT NOT NULL REFERENCES hosts(id) ON DELETE CASCADE, - tag TEXT NOT NULL, - PRIMARY KEY (host_id, tag) -); - -CREATE INDEX idx_hosts_type ON hosts(type); -CREATE INDEX idx_hosts_team ON hosts(team_id); -CREATE INDEX idx_hosts_status ON hosts(status); -CREATE INDEX idx_host_tokens_host ON host_tokens(host_id); -CREATE INDEX idx_host_tags_tag ON host_tags(tag); - --- +goose Down - -DROP TABLE host_tags; -DROP TABLE host_tokens; -DROP TABLE hosts; diff --git a/db/migrations/20260316223629_host_mtls.sql b/db/migrations/20260316223629_host_mtls.sql deleted file mode 100644 index f56b923e..00000000 --- a/db/migrations/20260316223629_host_mtls.sql +++ /dev/null @@ -1,11 +0,0 @@ --- +goose Up - -ALTER TABLE hosts - ADD COLUMN cert_fingerprint TEXT, - ADD COLUMN mtls_enabled BOOLEAN NOT NULL DEFAULT FALSE; - --- +goose Down - -ALTER TABLE hosts - DROP COLUMN cert_fingerprint, - DROP COLUMN mtls_enabled; diff --git a/db/migrations/20260324071453_team_management.sql b/db/migrations/20260324071453_team_management.sql deleted file mode 100644 index 1495d6dd..00000000 --- a/db/migrations/20260324071453_team_management.sql +++ /dev/null @@ -1,17 +0,0 @@ --- +goose Up - -ALTER TABLE teams ADD COLUMN slug TEXT; -ALTER TABLE teams ADD COLUMN deleted_at TIMESTAMPTZ; - --- Backfill slugs for existing teams using MD5 of their ID. --- MD5 returns 32 hex chars; take chars 1-6 and 7-12 to form a 6-6 slug. 
-UPDATE teams SET slug = LEFT(MD5(id), 6) || '-' || SUBSTRING(MD5(id), 7, 6); - -ALTER TABLE teams ALTER COLUMN slug SET NOT NULL; -CREATE UNIQUE INDEX idx_teams_slug ON teams(slug); - --- +goose Down - -DROP INDEX idx_teams_slug; -ALTER TABLE teams DROP COLUMN deleted_at; -ALTER TABLE teams DROP COLUMN slug; diff --git a/db/migrations/20260324100234_user_names.sql b/db/migrations/20260324100234_user_names.sql deleted file mode 100644 index 2775d12d..00000000 --- a/db/migrations/20260324100234_user_names.sql +++ /dev/null @@ -1,5 +0,0 @@ --- +goose Up -ALTER TABLE users ADD COLUMN name TEXT NOT NULL DEFAULT ''; - --- +goose Down -ALTER TABLE users DROP COLUMN name; diff --git a/db/migrations/20260324120214_host_refresh_tokens.sql b/db/migrations/20260324120214_host_refresh_tokens.sql deleted file mode 100644 index 02a13f74..00000000 --- a/db/migrations/20260324120214_host_refresh_tokens.sql +++ /dev/null @@ -1,19 +0,0 @@ --- +goose Up - --- Refresh tokens for host agent JWT rotation. --- Hosts exchange a refresh token for a new short-lived JWT + new refresh token (rotation). --- Refresh tokens expire after 60 days; hosts must re-register with a new one-time token after that. 
-CREATE TABLE host_refresh_tokens ( - id TEXT PRIMARY KEY, - host_id TEXT NOT NULL REFERENCES hosts(id) ON DELETE CASCADE, - token_hash TEXT NOT NULL UNIQUE, -- SHA-256 hex of the opaque token - expires_at TIMESTAMPTZ NOT NULL, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - revoked_at TIMESTAMPTZ -- NULL = active; set on rotation or host delete -); - -CREATE INDEX idx_host_refresh_tokens_host ON host_refresh_tokens(host_id); - --- +goose Down - -DROP TABLE host_refresh_tokens; diff --git a/db/migrations/20260324220743_audit_logs.sql b/db/migrations/20260324220743_audit_logs.sql deleted file mode 100644 index 91b73754..00000000 --- a/db/migrations/20260324220743_audit_logs.sql +++ /dev/null @@ -1,28 +0,0 @@ --- +goose Up - -CREATE TABLE audit_logs ( - id TEXT PRIMARY KEY, - team_id TEXT NOT NULL, - actor_type TEXT NOT NULL, -- 'user', 'api_key', 'system' - actor_id TEXT, -- user_id or api_key_id; NULL for system - actor_name TEXT, -- display name snapshotted at write time; NULL for system - resource_type TEXT NOT NULL, -- 'sandbox', 'snapshot', 'team', 'api_key', 'member', 'host' - resource_id TEXT, -- primary ID of the affected resource; NULL when not applicable - action TEXT NOT NULL, -- 'create', 'pause', 'resume', 'destroy', 'delete', 'rename', - -- 'revoke', 'add', 'remove', 'leave', 'role_update', - -- 'marked_down', 'marked_up' - scope TEXT NOT NULL, -- 'team' or 'admin' - status TEXT NOT NULL, -- 'success', 'info', 'warning', 'error' - metadata JSONB NOT NULL DEFAULT '{}', - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() -); - --- Primary access pattern: team feed sorted newest-first with cursor pagination. -CREATE INDEX idx_audit_logs_team_time ON audit_logs (team_id, created_at DESC); - --- Secondary index: filtered by resource_type and action within a team. 
-CREATE INDEX idx_audit_logs_team_resource ON audit_logs (team_id, resource_type, action, created_at DESC); - --- +goose Down - -DROP TABLE audit_logs; diff --git a/db/migrations/20260325074949_metrics_snapshots.sql b/db/migrations/20260325074949_metrics_snapshots.sql deleted file mode 100644 index 7d373e86..00000000 --- a/db/migrations/20260325074949_metrics_snapshots.sql +++ /dev/null @@ -1,18 +0,0 @@ --- +goose Up - -CREATE TABLE sandbox_metrics_snapshots ( - id BIGSERIAL PRIMARY KEY, - team_id TEXT NOT NULL, - sampled_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - running_count INTEGER NOT NULL, - vcpus_reserved INTEGER NOT NULL, - memory_mb_reserved INTEGER NOT NULL -); - --- All queries filter on team_id first then range-scan sampled_at. -CREATE INDEX idx_metrics_snapshots_team_time - ON sandbox_metrics_snapshots (team_id, sampled_at DESC); - --- +goose Down - -DROP TABLE sandbox_metrics_snapshots; diff --git a/db/migrations/20260325135035_add_sandbox_metric_points.sql b/db/migrations/20260325135035_add_sandbox_metric_points.sql deleted file mode 100644 index 08e86835..00000000 --- a/db/migrations/20260325135035_add_sandbox_metric_points.sql +++ /dev/null @@ -1,16 +0,0 @@ --- +goose Up -CREATE TABLE sandbox_metric_points ( - sandbox_id TEXT NOT NULL, - tier TEXT NOT NULL CHECK (tier IN ('10m', '2h', '24h')), - ts BIGINT NOT NULL, - cpu_pct FLOAT8 NOT NULL DEFAULT 0, - mem_bytes BIGINT NOT NULL DEFAULT 0, - disk_bytes BIGINT NOT NULL DEFAULT 0, - PRIMARY KEY (sandbox_id, tier, ts) -); - -CREATE INDEX idx_sandbox_metric_points_sandbox_tier - ON sandbox_metric_points (sandbox_id, tier); - --- +goose Down -DROP TABLE IF EXISTS sandbox_metric_points; diff --git a/db/migrations/20260328162803_template_uuid_pk.sql b/db/migrations/20260328162803_template_uuid_pk.sql new file mode 100644 index 00000000..0bb65668 --- /dev/null +++ b/db/migrations/20260328162803_template_uuid_pk.sql @@ -0,0 +1,82 @@ +-- +goose Up + +-- 1. 
Add UUID id column to templates and make it the primary key. +ALTER TABLE templates ADD COLUMN id UUID DEFAULT gen_random_uuid(); +UPDATE templates SET id = gen_random_uuid() WHERE id IS NULL; +ALTER TABLE templates ALTER COLUMN id SET NOT NULL; +ALTER TABLE templates DROP CONSTRAINT templates_pkey; +ALTER TABLE templates ADD PRIMARY KEY (id); + +-- 2. Name becomes a display field with team-scoped uniqueness. +ALTER TABLE templates ADD CONSTRAINT uq_templates_team_name UNIQUE (team_id, name); + +-- 3. Prevent team templates from using names that belong to global (platform) templates. +-- A team template insert/update with a name matching any platform template is rejected. +-- +goose StatementBegin +CREATE OR REPLACE FUNCTION check_global_template_name_collision() +RETURNS TRIGGER AS $$ +BEGIN + IF NEW.team_id != '00000000-0000-0000-0000-000000000000' THEN + IF EXISTS ( + SELECT 1 FROM templates + WHERE name = NEW.name + AND team_id = '00000000-0000-0000-0000-000000000000' + ) THEN + RAISE EXCEPTION 'template name "%" is reserved by a global template', NEW.name + USING ERRCODE = 'unique_violation'; + END IF; + END IF; + RETURN NEW; +END; +$$ LANGUAGE plpgsql; +-- +goose StatementEnd + +CREATE TRIGGER trg_check_global_template_name + BEFORE INSERT OR UPDATE ON templates + FOR EACH ROW + EXECUTE FUNCTION check_global_template_name_collision(); + +-- 4. Seed the built-in "minimal" template so it appears in all listings. +-- Both id and team_id are the all-zeros UUID (platform sentinel). +INSERT INTO templates (id, name, type, vcpus, memory_mb, size_bytes, team_id) +VALUES ( + '00000000-0000-0000-0000-000000000000', + 'minimal', + 'base', + 1, + 512, + 0, + '00000000-0000-0000-0000-000000000000' +) ON CONFLICT DO NOTHING; + +-- 5. Add template UUID references to template_builds. +ALTER TABLE template_builds + ADD COLUMN template_id UUID, + ADD COLUMN team_id UUID; + +-- 6. Add template UUID references to sandboxes. 
+ALTER TABLE sandboxes + ADD COLUMN template_id UUID, + ADD COLUMN template_team_id UUID; + +-- +goose Down + +ALTER TABLE sandboxes + DROP COLUMN IF EXISTS template_team_id, + DROP COLUMN IF EXISTS template_id; + +ALTER TABLE template_builds + DROP COLUMN IF EXISTS team_id, + DROP COLUMN IF EXISTS template_id; + +-- Remove the seeded minimal template. +DELETE FROM templates WHERE id = '00000000-0000-0000-0000-000000000000'; + +DROP TRIGGER IF EXISTS trg_check_global_template_name ON templates; +DROP FUNCTION IF EXISTS check_global_template_name_collision(); + +ALTER TABLE templates DROP CONSTRAINT IF EXISTS uq_templates_team_name; + +ALTER TABLE templates DROP CONSTRAINT IF EXISTS templates_pkey; +ALTER TABLE templates ADD PRIMARY KEY (name); +ALTER TABLE templates DROP COLUMN IF EXISTS id; diff --git a/db/migrations/20260330112050_mtls_cert_expiry.sql b/db/migrations/20260330112050_mtls_cert_expiry.sql new file mode 100644 index 00000000..e7245d27 --- /dev/null +++ b/db/migrations/20260330112050_mtls_cert_expiry.sql @@ -0,0 +1,7 @@ +-- +goose Up +ALTER TABLE hosts DROP COLUMN mtls_enabled; +ALTER TABLE hosts ADD COLUMN cert_expires_at TIMESTAMPTZ; + +-- +goose Down +ALTER TABLE hosts DROP COLUMN cert_expires_at; +ALTER TABLE hosts ADD COLUMN mtls_enabled BOOLEAN NOT NULL DEFAULT FALSE; diff --git a/db/migrations/20260330150223_build_options.sql b/db/migrations/20260330150223_build_options.sql new file mode 100644 index 00000000..981ad065 --- /dev/null +++ b/db/migrations/20260330150223_build_options.sql @@ -0,0 +1,11 @@ +-- +goose Up + +-- Allow completed_at to be set when a build is cancelled. +-- (The UpdateBuildStatus query is updated in sqlc; no schema change needed for that.) + +-- Add skip_pre_post flag: when true, the pre-build and post-build command phases +-- are skipped for this build. 
+ALTER TABLE template_builds ADD COLUMN skip_pre_post BOOLEAN NOT NULL DEFAULT FALSE; + +-- +goose Down +ALTER TABLE template_builds DROP COLUMN skip_pre_post; diff --git a/db/queries/hosts.sql b/db/queries/hosts.sql index 27ece000..0a5a150b 100644 --- a/db/queries/hosts.sql +++ b/db/queries/hosts.sql @@ -20,16 +20,25 @@ SELECT * FROM hosts WHERE status = $1 ORDER BY created_at DESC; -- name: RegisterHost :execrows UPDATE hosts -SET arch = $2, - cpu_cores = $3, - memory_mb = $4, - disk_gb = $5, - address = $6, - status = 'online', +SET arch = $2, + cpu_cores = $3, + memory_mb = $4, + disk_gb = $5, + address = $6, + cert_fingerprint = $7, + cert_expires_at = $8, + status = 'online', last_heartbeat_at = NOW(), - updated_at = NOW() + updated_at = NOW() WHERE id = $1 AND status = 'pending'; +-- name: UpdateHostCert :exec +UPDATE hosts +SET cert_fingerprint = $2, + cert_expires_at = $3, + updated_at = NOW() +WHERE id = $1; + -- name: UpdateHostStatus :exec UPDATE hosts SET status = $2, updated_at = NOW() WHERE id = $1; diff --git a/db/queries/sandboxes.sql b/db/queries/sandboxes.sql index 131fe1ed..2b195744 100644 --- a/db/queries/sandboxes.sql +++ b/db/queries/sandboxes.sql @@ -1,6 +1,6 @@ -- name: InsertSandbox :one -INSERT INTO sandboxes (id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec) -VALUES ($1, $2, $3, $4, $5, $6, $7, $8) +INSERT INTO sandboxes (id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, template_id, template_team_id) +VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) RETURNING *; -- name: GetSandbox :one @@ -9,6 +9,14 @@ SELECT * FROM sandboxes WHERE id = $1; -- name: GetSandboxByTeam :one SELECT * FROM sandboxes WHERE id = $1 AND team_id = $2; +-- name: GetSandboxProxyTarget :one +-- Returns the sandbox status and its host's address in one query. +-- Used by SandboxProxyWrapper to avoid two round-trips. 
+SELECT s.status, h.address AS host_address +FROM sandboxes s +JOIN hosts h ON h.id = s.host_id +WHERE s.id = $1 AND s.team_id = $2; + -- name: ListSandboxes :many SELECT * FROM sandboxes ORDER BY created_at DESC; @@ -50,7 +58,7 @@ WHERE id = $1; UPDATE sandboxes SET status = $2, last_updated = NOW() -WHERE id = ANY($1::text[]); +WHERE id = ANY($1::uuid[]); -- name: ListActiveSandboxesByTeam :many SELECT * FROM sandboxes @@ -72,4 +80,4 @@ WHERE host_id = $1 AND status IN ('running', 'starting', 'pending'); UPDATE sandboxes SET status = 'running', last_updated = NOW() -WHERE id = ANY($1::text[]) AND status = 'missing'; +WHERE id = ANY($1::uuid[]) AND status = 'missing'; diff --git a/db/queries/template_builds.sql b/db/queries/template_builds.sql new file mode 100644 index 00000000..1fb07be3 --- /dev/null +++ b/db/queries/template_builds.sql @@ -0,0 +1,33 @@ +-- name: InsertTemplateBuild :one +INSERT INTO template_builds (id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, total_steps, template_id, team_id, skip_pre_post) +VALUES ($1, $2, $3, $4, $5, $6, $7, 'pending', $8, $9, $10, $11) +RETURNING *; + +-- name: GetTemplateBuild :one +SELECT * FROM template_builds WHERE id = $1; + +-- name: ListTemplateBuilds :many +SELECT * FROM template_builds ORDER BY created_at DESC; + +-- name: UpdateBuildStatus :one +UPDATE template_builds +SET status = $2, + started_at = CASE WHEN $2 = 'running' AND started_at IS NULL THEN NOW() ELSE started_at END, + completed_at = CASE WHEN $2 IN ('success', 'failed', 'cancelled') THEN NOW() ELSE completed_at END +WHERE id = $1 +RETURNING *; + +-- name: UpdateBuildProgress :exec +UPDATE template_builds +SET current_step = $2, logs = $3 +WHERE id = $1; + +-- name: UpdateBuildSandbox :exec +UPDATE template_builds +SET sandbox_id = $2, host_id = $3 +WHERE id = $1; + +-- name: UpdateBuildError :exec +UPDATE template_builds +SET error = $2, status = 'failed', completed_at = NOW() +WHERE id = $1; diff --git 
a/db/queries/templates.sql b/db/queries/templates.sql index b17abc3a..de4d6f2a 100644 --- a/db/queries/templates.sql +++ b/db/queries/templates.sql @@ -1,13 +1,22 @@ -- name: InsertTemplate :one -INSERT INTO templates (name, type, vcpus, memory_mb, size_bytes, team_id) -VALUES ($1, $2, $3, $4, $5, $6) +INSERT INTO templates (id, name, type, vcpus, memory_mb, size_bytes, team_id) +VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING *; -- name: GetTemplate :one -SELECT * FROM templates WHERE name = $1; +SELECT * FROM templates WHERE id = $1; -- name: GetTemplateByTeam :one -SELECT * FROM templates WHERE name = $1 AND team_id = $2; +-- Platform templates (team_id = 00000000-...) are visible to all teams. +SELECT * FROM templates WHERE name = $1 AND (team_id = $2 OR team_id = '00000000-0000-0000-0000-000000000000'); + +-- name: GetTemplateByName :one +-- Look up a template by team_id and name (exact team match, no global fallback). +SELECT * FROM templates WHERE team_id = $1 AND name = $2; + +-- name: GetPlatformTemplateByName :one +-- Check if a global (platform) template exists with the given name. +SELECT * FROM templates WHERE team_id = '00000000-0000-0000-0000-000000000000' AND name = $1; -- name: ListTemplates :many SELECT * FROM templates ORDER BY created_at DESC; @@ -16,13 +25,23 @@ SELECT * FROM templates ORDER BY created_at DESC; SELECT * FROM templates WHERE type = $1 ORDER BY created_at DESC; -- name: ListTemplatesByTeam :many -SELECT * FROM templates WHERE team_id = $1 ORDER BY created_at DESC; +-- Platform templates are visible to all teams. +SELECT * FROM templates WHERE (team_id = $1 OR team_id = '00000000-0000-0000-0000-000000000000') ORDER BY created_at DESC; -- name: ListTemplatesByTeamAndType :many -SELECT * FROM templates WHERE team_id = $1 AND type = $2 ORDER BY created_at DESC; +-- Platform templates are visible to all teams. 
+SELECT * FROM templates WHERE (team_id = $1 OR team_id = '00000000-0000-0000-0000-000000000000') AND type = $2 ORDER BY created_at DESC; -- name: DeleteTemplate :exec -DELETE FROM templates WHERE name = $1; +DELETE FROM templates WHERE id = $1; -- name: DeleteTemplateByTeam :exec DELETE FROM templates WHERE name = $1 AND team_id = $2; + +-- name: DeleteTemplatesByTeam :exec +-- Bulk delete all templates owned by a team (for team soft-delete cleanup). +DELETE FROM templates WHERE team_id = $1; + +-- name: ListTemplatesByTeamOnly :many +-- List templates owned by a specific team (NOT including platform templates). +SELECT * FROM templates WHERE team_id = $1 ORDER BY created_at DESC; diff --git a/deploy/Caddyfile.dev b/deploy/Caddyfile.dev new file mode 100644 index 00000000..789f8dfd --- /dev/null +++ b/deploy/Caddyfile.dev @@ -0,0 +1,41 @@ +# Sandbox port forwarding: {port}-{sandbox_id}.localhost +# Matches subdomains like 49999-sb-abcd1234.localhost and proxies them +# to the control plane, which inspects the Host header and routes to +# the correct host agent. +# +# NOTE: Wildcard *.localhost DNS resolution requires local setup. +# Option 1: Add entries to /etc/hosts for each sandbox +# Option 2: Use dnsmasq: address=/.localhost/127.0.0.1 +# Option 3: Use systemd-resolved (Ubuntu default — *.localhost resolves to 127.0.0.1) +http://*.localhost { + reverse_proxy host.docker.internal:8080 +} + +# Main entry point: API + frontend +http://localhost { + # API routes — strip /api prefix and proxy to the control plane. + # The frontend calls /api/v1/... which becomes /v1/... at the CP. 
+ handle_path /api/* { + reverse_proxy host.docker.internal:8080 + } + + # Backend routes served directly (SDK clients, OAuth initiation) + handle /v1/* { + reverse_proxy host.docker.internal:8080 + } + handle /openapi.yaml { + reverse_proxy host.docker.internal:8080 + } + handle /docs { + reverse_proxy host.docker.internal:8080 + } + handle /auth/oauth/* { + reverse_proxy host.docker.internal:8080 + } + + # Everything else — proxy to the frontend dev server + # This includes: /login, /dashboard/*, /admin/*, /auth/github/callback + handle { + reverse_proxy host.docker.internal:5173 + } +} diff --git a/deploy/docker-compose.dev.yml b/deploy/docker-compose.dev.yml index ebcd3087..28f401f6 100644 --- a/deploy/docker-compose.dev.yml +++ b/deploy/docker-compose.dev.yml @@ -15,19 +15,14 @@ services: ports: - "6379:6379" - prometheus: - image: prom/prometheus:latest + caddy: + image: caddy:2-alpine ports: - - "9090:9090" + - "8000:80" volumes: - - ./deploy/prometheus.yml:/etc/prometheus/prometheus.yml - - grafana: - image: grafana/grafana:latest - ports: - - "3001:3000" - environment: - GF_SECURITY_ADMIN_PASSWORD: admin + - ./Caddyfile.dev:/etc/caddy/Caddyfile:ro + extra_hosts: + - "host.docker.internal:host-gateway" volumes: pgdata: diff --git a/envd/internal/api/init.go b/envd/internal/api/init.go index a4894599..301400cb 100644 --- a/envd/internal/api/init.go +++ b/envd/internal/api/init.go @@ -14,14 +14,12 @@ import ( "os/exec" "time" - "github.com/awnumar/memguard" - "github.com/rs/zerolog" - "github.com/txn2/txeh" - "golang.org/x/sys/unix" - "git.omukk.dev/wrenn/sandbox/envd/internal/host" "git.omukk.dev/wrenn/sandbox/envd/internal/logs" "git.omukk.dev/wrenn/sandbox/envd/internal/shared/keys" + "github.com/awnumar/memguard" + "github.com/rs/zerolog" + "github.com/txn2/txeh" ) var ( @@ -29,11 +27,6 @@ var ( ErrAccessTokenResetNotAuthorized = errors.New("access token reset not authorized") ) -const ( - maxTimeInPast = 50 * time.Millisecond - maxTimeInFuture = 5 * 
time.Second -) - // validateInitAccessToken validates the access token for /init requests. // Token is valid if it matches the existing token OR the MMDS hash. // If neither exists, first-time setup is allowed. @@ -172,20 +165,6 @@ func (a *API) SetData(ctx context.Context, logger zerolog.Logger, data PostInitJ return err } - if data.Timestamp != nil { - // Check if current time differs significantly from the received timestamp - if shouldSetSystemTime(time.Now(), *data.Timestamp) { - logger.Debug().Msgf("Setting sandbox start time to: %v", *data.Timestamp) - ts := unix.NsecToTimespec(data.Timestamp.UnixNano()) - err := unix.ClockSettime(unix.CLOCK_REALTIME, &ts) - if err != nil { - logger.Error().Msgf("Failed to set system time: %v", err) - } - } else { - logger.Debug().Msgf("Current time is within acceptable range of timestamp %v, not setting system time", *data.Timestamp) - } - } - if data.EnvVars != nil { logger.Debug().Msg(fmt.Sprintf("Setting %d env vars", len(*data.EnvVars))) @@ -308,10 +287,3 @@ func getIPFamily(address string) (txeh.IPFamily, error) { return txeh.IPFamilyV4, fmt.Errorf("%w: %s", ErrUnknownAddressFormat, address) } } - -// shouldSetSystemTime returns true if the current time differs significantly from the received timestamp, -// indicating the system clock should be adjusted. Returns true when the sandboxTime is more than -// maxTimeInPast before the hostTime or more than maxTimeInFuture after the hostTime. 
-func shouldSetSystemTime(sandboxTime, hostTime time.Time) bool { - return sandboxTime.Before(hostTime.Add(-maxTimeInPast)) || sandboxTime.After(hostTime.Add(maxTimeInFuture)) -} diff --git a/envd/internal/api/init_test.go b/envd/internal/api/init_test.go index c4b6f4b5..f3db3615 100644 --- a/envd/internal/api/init_test.go +++ b/envd/internal/api/init_test.go @@ -9,7 +9,6 @@ import ( "path/filepath" "strings" "testing" - "time" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" @@ -59,71 +58,6 @@ func TestSimpleCases(t *testing.T) { } } -func TestShouldSetSystemTime(t *testing.T) { - t.Parallel() - sandboxTime := time.Now() - - tests := []struct { - name string - hostTime time.Time - want bool - }{ - { - name: "sandbox time far ahead of host time (should set)", - hostTime: sandboxTime.Add(-10 * time.Second), - want: true, - }, - { - name: "sandbox time at maxTimeInPast boundary ahead of host time (should not set)", - hostTime: sandboxTime.Add(-50 * time.Millisecond), - want: false, - }, - { - name: "sandbox time just within maxTimeInPast ahead of host time (should not set)", - hostTime: sandboxTime.Add(-40 * time.Millisecond), - want: false, - }, - { - name: "sandbox time slightly ahead of host time (should not set)", - hostTime: sandboxTime.Add(-10 * time.Millisecond), - want: false, - }, - { - name: "sandbox time equals host time (should not set)", - hostTime: sandboxTime, - want: false, - }, - { - name: "sandbox time slightly behind host time (should not set)", - hostTime: sandboxTime.Add(1 * time.Second), - want: false, - }, - { - name: "sandbox time just within maxTimeInFuture behind host time (should not set)", - hostTime: sandboxTime.Add(4 * time.Second), - want: false, - }, - { - name: "sandbox time at maxTimeInFuture boundary behind host time (should not set)", - hostTime: sandboxTime.Add(5 * time.Second), - want: false, - }, - { - name: "sandbox time far behind host time (should set)", - hostTime: sandboxTime.Add(1 * time.Minute), - want: true, 
- }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - got := shouldSetSystemTime(tt.hostTime, sandboxTime) - assert.Equal(t, tt.want, got) - }) - } -} - func secureTokenPtr(s string) *SecureToken { token := &SecureToken{} _ = token.Set([]byte(s)) diff --git a/envd/internal/port/conn.go b/envd/internal/port/conn.go new file mode 100644 index 00000000..8a8c032c --- /dev/null +++ b/envd/internal/port/conn.go @@ -0,0 +1,165 @@ +// SPDX-License-Identifier: Apache-2.0 + +package port + +import ( + "bufio" + "encoding/hex" + "fmt" + "net" + "os" + "strconv" + "strings" + "syscall" +) + +// ConnStat represents a single TCP connection read from /proc/net/tcp(6). +// It contains only the fields needed by the port scanner and forwarder. +type ConnStat struct { + LocalIP string + LocalPort uint32 + Status string + Family uint32 // syscall.AF_INET or syscall.AF_INET6 + Inode uint64 // socket inode, unique per connection +} + +// tcpStates maps the hex state values from /proc/net/tcp to string names +// matching the gopsutil convention used by ScannerFilter. +var tcpStates = map[string]string{ + "01": "ESTABLISHED", + "02": "SYN_SENT", + "03": "SYN_RECV", + "04": "FIN_WAIT1", + "05": "FIN_WAIT2", + "06": "TIME_WAIT", + "07": "CLOSE", + "08": "CLOSE_WAIT", + "09": "LAST_ACK", + "0A": "LISTEN", + "0B": "CLOSING", +} + +// ReadTCPConnections reads /proc/net/tcp and /proc/net/tcp6 and returns +// all TCP connections. This avoids the /proc/{pid}/fd walk that gopsutil +// performs, which is unsafe across Firecracker snapshot/restore boundaries. +func ReadTCPConnections() ([]ConnStat, error) { + var conns []ConnStat + + tcp4, err := parseProcNetTCP("/proc/net/tcp", syscall.AF_INET) + if err != nil { + return nil, fmt.Errorf("parse /proc/net/tcp: %w", err) + } + conns = append(conns, tcp4...) 
+ + tcp6, err := parseProcNetTCP("/proc/net/tcp6", syscall.AF_INET6) + if err != nil && !os.IsNotExist(err) { // tcp6 file is absent when IPv6 is disabled; keep the IPv4 results + return nil, fmt.Errorf("parse /proc/net/tcp6: %w", err) + } + conns = append(conns, tcp6...) + + return conns, nil +} + +// parseProcNetTCP reads a single /proc/net/tcp or /proc/net/tcp6 file. +// +// Format (fields are whitespace-separated): +// +// sl local_address rem_address st tx_queue:rx_queue tr:tm->when retrnsmt uid timeout inode +// 0: 0100007F:1F90 00000000:0000 0A 00000000:00000000 00:00000000 00000000 1000 0 12345 +func parseProcNetTCP(path string, family uint32) ([]ConnStat, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + + var conns []ConnStat + scanner := bufio.NewScanner(f) + + // Skip header line. + scanner.Scan() + + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + + fields := strings.Fields(line) + if len(fields) < 10 { + continue + } + + // fields[1] = local_address (hex_ip:hex_port) + ip, port, err := parseHexAddr(fields[1], family) + if err != nil { + continue + } + + // fields[3] = state (hex) + state, ok := tcpStates[fields[3]] + if !ok { + state = "UNKNOWN" + } + + // fields[9] = inode + inode, err := strconv.ParseUint(fields[9], 10, 64) + if err != nil { + continue + } + + conns = append(conns, ConnStat{ + LocalIP: ip, + LocalPort: port, + Status: state, + Family: family, + Inode: inode, + }) + } + + return conns, scanner.Err() +} + +// parseHexAddr parses "HEXIP:HEXPORT" from /proc/net/tcp. +// IPv4 addresses are 8 hex chars (4 bytes, little-endian per 32-bit word). +// IPv6 addresses are 32 hex chars (16 bytes, little-endian per 32-bit word). 
+func parseHexAddr(s string, family uint32) (string, uint32, error) { + parts := strings.SplitN(s, ":", 2) + if len(parts) != 2 { + return "", 0, fmt.Errorf("invalid address: %s", s) + } + + port64, err := strconv.ParseUint(parts[1], 16, 32) + if err != nil { + return "", 0, err + } + + ipHex := parts[0] + ipBytes, err := hex.DecodeString(ipHex) + if err != nil { + return "", 0, err + } + + var ip net.IP + if family == syscall.AF_INET { + if len(ipBytes) != 4 { + return "", 0, fmt.Errorf("invalid IPv4 length: %d", len(ipBytes)) + } + // /proc/net/tcp stores IPv4 as a single little-endian 32-bit word. + ip = net.IPv4(ipBytes[3], ipBytes[2], ipBytes[1], ipBytes[0]) + } else { + if len(ipBytes) != 16 { + return "", 0, fmt.Errorf("invalid IPv6 length: %d", len(ipBytes)) + } + // /proc/net/tcp6 stores IPv6 as four little-endian 32-bit words. + ip = make(net.IP, 16) + for i := 0; i < 4; i++ { + ip[i*4+0] = ipBytes[i*4+3] + ip[i*4+1] = ipBytes[i*4+2] + ip[i*4+2] = ipBytes[i*4+1] + ip[i*4+3] = ipBytes[i*4+0] + } + } + + return ip.String(), uint32(port64), nil +} diff --git a/envd/internal/port/forward.go b/envd/internal/port/forward.go index e8365196..bf516ff9 100644 --- a/envd/internal/port/forward.go +++ b/envd/internal/port/forward.go @@ -31,8 +31,8 @@ var defaultGatewayIP = net.IPv4(169, 254, 0, 21) type PortToForward struct { socat *exec.Cmd - // Process ID of the process that's listening on port. - pid int32 + // Socket inode of the listening socket (unique per connection). + inode uint64 // family version of the ip. family uint32 state PortState @@ -94,7 +94,7 @@ func (f *Forwarder) StartForwarding(ctx context.Context) { // Let's refresh our map of currently forwarded ports and mark the currently opened ones with the "FORWARD" state. // This will make sure we won't delete them later. 
for _, p := range procs { - key := fmt.Sprintf("%d-%d", p.Pid, p.Laddr.Port) + key := fmt.Sprintf("%d-%d", p.Inode, p.LocalPort) // We check if the opened port is in our map of forwarded ports. val, portOk := f.ports[key] @@ -104,16 +104,16 @@ func (f *Forwarder) StartForwarding(ctx context.Context) { val.state = PortStateForward } else { f.logger.Debug(). - Str("ip", p.Laddr.IP). - Uint32("port", p.Laddr.Port). + Str("ip", p.LocalIP). + Uint32("port", p.LocalPort). Uint32("family", familyToIPVersion(p.Family)). Str("state", p.Status). Msg("Detected new opened port on localhost that is not forwarded") // The opened port wasn't in the map so we create a new PortToForward and start forwarding. ptf := &PortToForward{ - pid: p.Pid, - port: p.Laddr.Port, + inode: p.Inode, + port: p.LocalPort, state: PortStateForward, family: familyToIPVersion(p.Family), } @@ -153,7 +153,7 @@ func (f *Forwarder) startPortForwarding(ctx context.Context, p *PortToForward) { f.logger.Debug(). Str("socatCmd", cmd.String()). - Int32("pid", p.pid). + Uint64("inode", p.inode). Uint32("family", p.family). IPAddr("sourceIP", f.sourceIP.To4()). Uint32("port", p.port). @@ -191,7 +191,7 @@ func (f *Forwarder) stopPortForwarding(p *PortToForward) { logger := f.logger.With(). Str("socatCmd", p.socat.String()). - Int32("pid", p.pid). + Uint64("inode", p.inode). Uint32("family", p.family). IPAddr("sourceIP", f.sourceIP.To4()). Uint32("port", p.port). 
diff --git a/envd/internal/port/scan.go b/envd/internal/port/scan.go index 766202a6..2b155233 100644 --- a/envd/internal/port/scan.go +++ b/envd/internal/port/scan.go @@ -3,19 +3,21 @@ package port import ( + "sync" "time" "github.com/rs/zerolog" - "github.com/shirou/gopsutil/v4/net" - - "git.omukk.dev/wrenn/sandbox/envd/internal/shared/smap" ) type Scanner struct { - Processes chan net.ConnectionStat - scanExit chan struct{} - subs *smap.Map[*ScannerSubscriber] - period time.Duration + scanExit chan struct{} + period time.Duration + + // Plain mutex-protected map instead of concurrent-map. The concurrent-map + // library's Items() spawns goroutines and uses a WaitGroup internally, + // which corrupts Go runtime semaphore state across Firecracker snapshot/restore. + mu sync.RWMutex + subs map[string]*ScannerSubscriber } func (s *Scanner) Destroy() { @@ -24,33 +26,44 @@ func (s *Scanner) Destroy() { func NewScanner(period time.Duration) *Scanner { return &Scanner{ - period: period, - subs: smap.New[*ScannerSubscriber](), - scanExit: make(chan struct{}), - Processes: make(chan net.ConnectionStat), + period: period, + subs: make(map[string]*ScannerSubscriber), + scanExit: make(chan struct{}), } } func (s *Scanner) AddSubscriber(logger *zerolog.Logger, id string, filter *ScannerFilter) *ScannerSubscriber { subscriber := NewScannerSubscriber(logger, id, filter) - s.subs.Insert(id, subscriber) + + s.mu.Lock() + s.subs[id] = subscriber + s.mu.Unlock() return subscriber } func (s *Scanner) Unsubscribe(sub *ScannerSubscriber) { - s.subs.Remove(sub.ID()) + s.mu.Lock() + delete(s.subs, sub.ID()) + s.mu.Unlock() + sub.Destroy() } // ScanAndBroadcast starts scanning open TCP ports and broadcasts every open port to all subscribers. func (s *Scanner) ScanAndBroadcast() { for { - // tcp monitors both ipv4 and ipv6 connections. 
- processes, _ := net.Connections("tcp") - for _, sub := range s.subs.Items() { - sub.Signal(processes) + // Read directly from /proc/net/tcp and /proc/net/tcp6 instead of + // using gopsutil's net.Connections(), which walks /proc/{pid}/fd + // and causes Go runtime corruption after Firecracker snapshot/restore. + conns, _ := ReadTCPConnections() + + s.mu.RLock() + for _, sub := range s.subs { + sub.Signal(conns) } + s.mu.RUnlock() + select { case <-s.scanExit: return diff --git a/envd/internal/port/scanSubscriber.go b/envd/internal/port/scanSubscriber.go index 6a4f5b01..bad99083 100644 --- a/envd/internal/port/scanSubscriber.go +++ b/envd/internal/port/scanSubscriber.go @@ -4,7 +4,6 @@ package port import ( "github.com/rs/zerolog" - "github.com/shirou/gopsutil/v4/net" ) // If we want to create a listener/subscriber pattern somewhere else we should move @@ -13,7 +12,7 @@ import ( type ScannerSubscriber struct { logger *zerolog.Logger filter *ScannerFilter - Messages chan ([]net.ConnectionStat) + Messages chan ([]ConnStat) id string } @@ -22,7 +21,7 @@ func NewScannerSubscriber(logger *zerolog.Logger, id string, filter *ScannerFilt logger: logger, id: id, filter: filter, - Messages: make(chan []net.ConnectionStat), + Messages: make(chan []ConnStat), } } @@ -34,17 +33,17 @@ func (ss *ScannerSubscriber) Destroy() { close(ss.Messages) } -func (ss *ScannerSubscriber) Signal(proc []net.ConnectionStat) { +func (ss *ScannerSubscriber) Signal(conns []ConnStat) { // Filter isn't specified. Accept everything. if ss.filter == nil { - ss.Messages <- proc + ss.Messages <- conns } else { - filtered := []net.ConnectionStat{} - for i := range proc { + filtered := []ConnStat{} + for i := range conns { // We need to access the list directly otherwise there will be implicit memory aliasing - // If the filter matched a process, we will send it to a channel. 
- if ss.filter.Match(&proc[i]) { - filtered = append(filtered, proc[i]) + // If the filter matched a connection, we will send it to a channel. + if ss.filter.Match(&conns[i]) { + filtered = append(filtered, conns[i]) } } ss.Messages <- filtered diff --git a/envd/internal/port/scanfilter.go b/envd/internal/port/scanfilter.go index 941023d9..f87667f2 100644 --- a/envd/internal/port/scanfilter.go +++ b/envd/internal/port/scanfilter.go @@ -4,8 +4,6 @@ package port import ( "slices" - - "github.com/shirou/gopsutil/v4/net" ) type ScannerFilter struct { @@ -13,15 +11,15 @@ type ScannerFilter struct { IPs []string } -func (sf *ScannerFilter) Match(proc *net.ConnectionStat) bool { +func (sf *ScannerFilter) Match(conn *ConnStat) bool { // Filter is an empty struct. if sf.State == "" && len(sf.IPs) == 0 { return false } - ipMatch := slices.Contains(sf.IPs, proc.Laddr.IP) + ipMatch := slices.Contains(sf.IPs, conn.LocalIP) - if ipMatch && sf.State == proc.Status { + if ipMatch && sf.State == conn.Status { return true } diff --git a/frontend/src/lib/api/builds.ts b/frontend/src/lib/api/builds.ts new file mode 100644 index 00000000..1de23b8d --- /dev/null +++ b/frontend/src/lib/api/builds.ts @@ -0,0 +1,76 @@ +import { apiFetch, type ApiResult } from '$lib/api/client'; + +export type BuildLogEntry = { + step: number; + phase: string; // "pre-build", "recipe", or "post-build" + cmd: string; + stdout: string; + stderr: string; + exit: number; + ok: boolean; + elapsed_ms: number; +}; + +export type Build = { + id: string; + name: string; + base_template: string; + recipe: string[]; + healthcheck?: string; + vcpus: number; + memory_mb: number; + status: string; + current_step: number; + total_steps: number; + logs: BuildLogEntry[]; + error?: string; + sandbox_id?: string; + host_id?: string; + created_at: string; + started_at?: string; + completed_at?: string; +}; + +export type CreateBuildParams = { + name: string; + base_template?: string; + recipe: string[]; + healthcheck?: string; 
+ vcpus?: number; + memory_mb?: number; + skip_pre_post?: boolean; +}; + +export async function createBuild(params: CreateBuildParams): Promise> { + return apiFetch('POST', '/api/v1/admin/builds', params); +} + +export async function listBuilds(): Promise> { + return apiFetch('GET', '/api/v1/admin/builds'); +} + +export async function getBuild(id: string): Promise> { + return apiFetch('GET', `/api/v1/admin/builds/${id}`); +} + +export type AdminTemplate = { + name: string; + type: string; + vcpus: number; + memory_mb: number; + size_bytes: number; + team_id: string; + created_at: string; +}; + +export async function listAdminTemplates(): Promise> { + return apiFetch('GET', '/api/v1/admin/templates'); +} + +export async function deleteAdminTemplate(name: string): Promise> { + return apiFetch('DELETE', `/api/v1/admin/templates/${name}`); +} + +export async function cancelBuild(id: string): Promise> { + return apiFetch('POST', `/api/v1/admin/builds/${id}/cancel`); +} diff --git a/frontend/src/lib/api/capsules.ts b/frontend/src/lib/api/capsules.ts index cc4ad79b..565f14f2 100644 --- a/frontend/src/lib/api/capsules.ts +++ b/frontend/src/lib/api/capsules.ts @@ -54,6 +54,7 @@ export type Snapshot = { memory_mb?: number; size_bytes: number; created_at: string; + platform: boolean; }; export async function createSnapshot(sandboxId: string, name?: string): Promise> { diff --git a/frontend/src/lib/components/AdminSidebar.svelte b/frontend/src/lib/components/AdminSidebar.svelte index 4bed5cc9..ebf4b646 100644 --- a/frontend/src/lib/components/AdminSidebar.svelte +++ b/frontend/src/lib/components/AdminSidebar.svelte @@ -3,6 +3,7 @@ import { auth } from '$lib/auth.svelte'; import { IconServer, + IconTemplate, IconSettings, IconLogout, IconSidebar, @@ -21,7 +22,8 @@ }; const managementItems: NavItem[] = [ - { label: 'Hosts', icon: IconServer, href: '/admin/hosts' } + { label: 'Hosts', icon: IconServer, href: '/admin/hosts' }, + { label: 'Templates', icon: IconTemplate, href: 
'/admin/templates' } ]; function isActive(href: string): boolean { diff --git a/frontend/src/lib/components/CreateCapsuleDialog.svelte b/frontend/src/lib/components/CreateCapsuleDialog.svelte index b570f2b4..2bd027d6 100644 --- a/frontend/src/lib/components/CreateCapsuleDialog.svelte +++ b/frontend/src/lib/components/CreateCapsuleDialog.svelte @@ -56,6 +56,7 @@ class="w-full rounded-[var(--radius-input)] border border-[var(--color-border)] bg-[var(--color-bg-4)] px-3 py-2 font-mono text-ui text-[var(--color-text-bright)] outline-none placeholder:text-[var(--color-text-muted)] transition-colors duration-150 focus:border-[var(--color-accent)]" placeholder="minimal" /> +

Name of a snapshot or base image to boot from.

@@ -85,14 +86,16 @@
- + +

Seconds of inactivity before the capsule pauses. Set to 0 to keep it running indefinitely.

diff --git a/frontend/src/routes/admin/templates/+page.svelte b/frontend/src/routes/admin/templates/+page.svelte new file mode 100644 index 00000000..4619e7b9 --- /dev/null +++ b/frontend/src/routes/admin/templates/+page.svelte @@ -0,0 +1,902 @@ + + +
+ + +
+ +
+
+
+

+ Templates +

+

+ Build and manage global templates available to all teams. +

+
+ +
+ + + {#if !templatesLoading && !templatesError} +
+
+ {templateCount} + templates +
+
+ {baseCount} + base +
+
+ {snapshotCount} + snapshots +
+ {#if runningBuilds > 0} +
+ + + + + {runningBuilds} + building +
+ {/if} +
+ {/if} +
+ + +
+ {#each [['templates', 'Templates', templateCount], ['builds', 'Builds', builds.length]] as [id, label, count] (id)} + + {/each} +
+ + +
+ {#if activeTab === 'templates'} + {#if templatesLoading} + {@render skeletonRows(5, ['Name', 'Type', 'Specs', 'Size', 'Created', ''])} + {:else if templatesError} +
+ {templatesError} +
+ {:else if templates.length === 0} + {@render emptyState('templates')} + {:else} + {@render templatesTable()} + {/if} + {:else} + {#if buildsLoading} + {@render skeletonRows(4, ['Build', 'Name', 'Status', 'Progress', 'Started', 'Duration'])} + {:else if buildsError} +
+ {buildsError} +
+ {:else if builds.length === 0} + {@render emptyState('builds')} + {:else} + {@render buildsTable()} + {/if} + {/if} +
+
+
+ + + +{#snippet skeletonRows(count: number, headers: string[])} +
+ + + + {#each headers as h} + + {/each} + + + + {#each Array(count) as _, i} + + {#each headers as _h, j} + + {/each} + + {/each} + +
{h}
+
+
+
+{/snippet} + +{#snippet emptyState(type: 'templates' | 'builds')} +
+
+ {#if type === 'templates'} + + {:else} + + {/if} +
+

+ {type === 'templates' ? 'No templates yet.' : 'No builds yet.'} +

+

+ {type === 'templates' + ? 'Create a template to provide pre-configured environments for all teams.' + : 'Start a template build to see progress and logs here.'} +

+
+{/snippet} + +{#snippet templatesTable()} +
+ + + + + + + + + + + + + {#each templates as tmpl (tmpl.name)} + + + + + + + + + {/each} + +
NameType
+ {tmpl.name} + + {#if tmpl.type === 'snapshot'} + + snapshot + + {:else} + + base + + {/if} + + +
+
+{/snippet} + +{#snippet buildsTable()} +
+ + + + + + + + + + + + + + {#each builds as build (build.id)} + toggleBuildExpand(build.id)} + > + + + + + + + + + + {#if expandedBuildId === build.id} + + + + {/if} + {/each} + +
BuildNameStatus
+
+ + + + {build.id} +
+
+ {build.name} + + + {#if build.status === 'running'} + + + + + {:else if build.status === 'success'} + + {:else if build.status === 'failed'} + + {:else} + + {/if} + {build.status} + +
+
+ {#if build.status === 'pending' || build.status === 'running'} +
+ +
+ {/if} + {#if build.error} +
+ {build.error} +
+ {/if} + + {#if build.logs && build.logs.length > 0} +
+ {#each build.logs as log, i (i)} + {@const isInternal = log.phase === 'pre-build' || log.phase === 'post-build'} + {@const recipeIdx = log.phase === 'recipe' ? build.logs.filter(l => l.phase === 'recipe' && l.step <= log.step).length : 0} + {@const phaseLabel = isInternal ? (log.phase === 'pre-build' ? 'Pre-build' : 'Post-build') : `Step ${recipeIdx}`} + {@const [kw, kwRest] = splitInstruction(log.cmd)} +
+ + + + + {#if expandedSteps.has(log.step)} +
+ {#if log.stdout} +
+ stdout +
{log.stdout}
+
+ {/if} + {#if log.stderr} +
+ stderr +
{log.stderr}
+
+ {/if} + {#if !log.stdout && !log.stderr} + No output + {/if} +
+ {/if} +
+ {/each} +
+ {:else} +
+ {#if build.status === 'pending' || build.status === 'running'} + + {build.status === 'pending' ? 'Waiting for worker…' : 'Running…'} + {:else} + No build logs recorded. + {/if} +
+ {/if} + + + {#if build.recipe && build.recipe.length > 0} +
+ Recipe +
+ {#each build.recipe as cmd, i} + {@const [kw, kwRest] = splitInstruction(cmd)} +
+ {i + 1}. + {kw}{#if kwRest} {kwRest}{/if} +
+ {/each} +
+
+ {/if} + + {#if build.healthcheck} +
+ Healthcheck + {build.healthcheck} +
+ {/if} +
+
+
+{/snippet} + + +{#if showCreate} +
+
{ if (!creating) showCreate = false; }} + onkeydown={(e) => { if (e.key === 'Escape' && !creating) showCreate = false; }} + >
+
+

+ Create Template +

+

+ Build a new global template by running commands on a base image. +

+ + {#if createError} +
+ {createError} +
+ {/if} + +
+
+ + +
+ +
+
+ + +
+
+ + +
+
+ + +
+
+ +
+ + +

+ Supports RUN, START, WORKDIR, ENV key=value. RUN steps have a 30s timeout; override with RUN --timeout=5m. +

+
+ +
+ + +

+ If set, the build will poll this command every 1s (up to 60s) after the recipe completes. On success, a full snapshot (with memory state) is created. Without a healthcheck, only the rootfs is saved. +

+
+ + +
+ +
+ + +
+
+
+{/if} + + +{#if deleteTarget} +
+
{ if (!deleting) deleteTarget = null; }} + onkeydown={(e) => { if (e.key === 'Escape' && !deleting) deleteTarget = null; }} + >
+
+

+ Delete Template +

+

+ Permanently remove {deleteTarget.name} from all hosts. +

+ + {#if deleteError} +
+ {deleteError} +
+ {/if} + +
+ + +
+
+
+{/if} + + diff --git a/frontend/src/routes/dashboard/capsules/+layout.svelte b/frontend/src/routes/dashboard/capsules/+layout.svelte index d9f93f82..5c400b9a 100644 --- a/frontend/src/routes/dashboard/capsules/+layout.svelte +++ b/frontend/src/routes/dashboard/capsules/+layout.svelte @@ -47,7 +47,7 @@ Capsules

- Isolated VMs. Start cold in under a second — pause, snapshot, or destroy at will. + All active and recent capsules across your team.

diff --git a/frontend/src/routes/dashboard/capsules/+page.svelte b/frontend/src/routes/dashboard/capsules/+page.svelte index 4f270035..afb9de0c 100644 --- a/frontend/src/routes/dashboard/capsules/+page.svelte +++ b/frontend/src/routes/dashboard/capsules/+page.svelte @@ -247,6 +247,13 @@ return `${Math.floor(seconds / 86400)}d ago`; } + function fmtTimeout(sec: number): string { + if (!sec) return 'None'; + if (sec < 60) return `${sec}s`; + if (sec < 3600) return `${Math.round(sec / 60)}m`; + return `${Math.round(sec / 3600)}h`; + } + function handleClickOutside(event: MouseEvent) { if (openMenuId && !(event.target as Element)?.closest('.status-menu-container')) { openMenuId = null; @@ -300,7 +307,7 @@ class="w-full rounded-[var(--radius-input)] border border-[var(--color-border)] bg-[var(--color-bg-2)] py-2 pl-9 pr-3 font-mono text-ui text-[var(--color-text-bright)] outline-none placeholder:text-[var(--color-text-muted)] transition-colors duration-150 focus:border-[var(--color-accent)]" /> - {filteredCapsules.length} total + {filteredCapsules.length} capsule{filteredCapsules.length !== 1 ? 's' : ''}
@@ -363,8 +370,11 @@ {#if error} -
- {error} +
+ + + + {error}. Try refreshing the page.
{/if} @@ -466,14 +476,14 @@
- {capsule.timeout_sec ? `${capsule.timeout_sec}s` : '—'} + {fmtTimeout(capsule.timeout_sec)}
{formatTime(capsule.started_at)} {#if capsule.last_active_at} - {timeAgo(capsule.last_active_at)} + active {timeAgo(capsule.last_active_at)} {/if}
@@ -612,7 +622,7 @@ -

This capsule will be paused first — memory state is captured at rest.

+

This capsule will be paused first, then its full state (memory + disk) will be captured.

{:else}

The capsule's current memory state will be captured and stored as a reusable snapshot.

diff --git a/frontend/src/routes/dashboard/capsules/[id]/+page.js b/frontend/src/routes/dashboard/capsules/[id]/+page.js new file mode 100644 index 00000000..d43d0cd2 --- /dev/null +++ b/frontend/src/routes/dashboard/capsules/[id]/+page.js @@ -0,0 +1 @@ +export const prerender = false; diff --git a/frontend/src/routes/dashboard/capsules/[id]/+page.svelte b/frontend/src/routes/dashboard/capsules/[id]/+page.svelte index e932209a..ed26426d 100644 --- a/frontend/src/routes/dashboard/capsules/[id]/+page.svelte +++ b/frontend/src/routes/dashboard/capsules/[id]/+page.svelte @@ -512,7 +512,7 @@ - Failed to load metrics: {metricsError} + Could not load metrics: {metricsError}. Will retry automatically. {/if} diff --git a/frontend/src/routes/dashboard/snapshots/+page.svelte b/frontend/src/routes/dashboard/snapshots/+page.svelte index e39bf3b6..2ae201ab 100644 --- a/frontend/src/routes/dashboard/snapshots/+page.svelte +++ b/frontend/src/routes/dashboard/snapshots/+page.svelte @@ -114,15 +114,15 @@ } function emptyHeading(f: TypeFilter): string { - if (f === 'snapshot') return 'No snapshots'; - if (f === 'base') return 'No images'; - return 'No templates yet'; + if (f === 'snapshot') return 'No snapshots yet'; + if (f === 'base') return 'No base images'; + return 'No snapshots yet'; } function emptyDescription(f: TypeFilter): string { - if (f === 'snapshot') return 'Pause a capsule from the Capsules page, then snapshot it to capture its state.'; - if (f === 'base') return 'Base images are added by the Wrenn team. Contact support to request a custom image.'; - return 'To create a snapshot, go to Capsules, pause a running capsule, then choose Snapshot.'; + if (f === 'snapshot') return 'Pause a running capsule, then choose Snapshot to save its state.'; + if (f === 'base') return 'Base images are provided by the Wrenn team. Contact support to request a custom one.'; + return 'Pause a running capsule, then choose Snapshot to save its state. 
You can launch new capsules from any snapshot.'; } onMount(fetchSnapshots); @@ -162,7 +162,7 @@ Templates

- Snapshots capture a live capsule state. Base images are the rootfs every capsule starts from. Launch a full VM from any template. + Snapshots capture a running capsule's state. Base images are the starting point for every new capsule. Launch from either.

@@ -206,8 +206,11 @@ {#if pageTab === 'snapshots'}
{#if error} -
- {error} +
+ + + + {error}. Try refreshing the page.
{/if} @@ -274,11 +277,11 @@
{filteredSnapshots.length} - {typeFilter === 'all' - ? filteredSnapshots.length === 1 ? 'template' : 'templates' - : typeFilter === 'snapshot' - ? filteredSnapshots.length === 1 ? 'snapshot' : 'snapshots' - : filteredSnapshots.length === 1 ? 'image' : 'images'} + {typeFilter === 'snapshot' + ? filteredSnapshots.length === 1 ? 'snapshot' : 'snapshots' + : typeFilter === 'base' + ? filteredSnapshots.length === 1 ? 'image' : 'images' + : filteredSnapshots.length === 1 ? 'item' : 'total'}
@@ -420,6 +423,7 @@
- - - -
-
-
- Team ID -
- {team.id} + Team ID
-
+ diff --git a/frontend/vite.config.ts b/frontend/vite.config.ts index 89afaba3..350070b5 100644 --- a/frontend/vite.config.ts +++ b/frontend/vite.config.ts @@ -7,7 +7,7 @@ export default defineConfig({ server: { proxy: { '/api': { - target: 'http://localhost:8000', + target: 'http://localhost:8080', rewrite: (path) => path.replace(/^\/api/, '') } } diff --git a/images/wrenn-init.sh b/images/wrenn-init.sh index 32285ea7..266e516a 100644 --- a/images/wrenn-init.sh +++ b/images/wrenn-init.sh @@ -1,6 +1,6 @@ #!/bin/sh # wrenn-init: minimal PID 1 init for Firecracker microVMs. -# Mounts virtual filesystems then execs envd. +# Mounts virtual filesystems, starts chronyd for time sync, then execs tini + envd. set -e @@ -25,7 +25,19 @@ echo "nameserver 8.8.8.8" > /etc/resolv.conf echo "nameserver 8.8.4.4" >> /etc/resolv.conf # Set a standard PATH so envd and all child processes can find common binaries. -export PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin +export PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/local/games + +# Write chrony config to sync time from the KVM PTP hardware clock. +# /dev/ptp0 is a paravirtual clock exposed by KVM — no network required. +mkdir -p /etc/chrony /run/chrony +cat > /etc/chrony/chrony.conf </dev/null || true # Exec tini as PID 1 — it reaps zombie processes and forwards signals to envd. exec /sbin/tini -- /usr/local/bin/envd diff --git a/internal/api/agent_helper.go b/internal/api/agent_helper.go index ac5b38e0..98a881d6 100644 --- a/internal/api/agent_helper.go +++ b/internal/api/agent_helper.go @@ -4,6 +4,8 @@ import ( "context" "fmt" + "github.com/jackc/pgx/v5/pgtype" + "git.omukk.dev/wrenn/sandbox/internal/db" "git.omukk.dev/wrenn/sandbox/internal/lifecycle" "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect" @@ -11,7 +13,7 @@ import ( // agentForHost looks up the host record and returns a Connect RPC client for it. 
// Returns an error if the host is not found or has no address. -func agentForHost(ctx context.Context, queries *db.Queries, pool *lifecycle.HostClientPool, hostID string) (hostagentv1connect.HostAgentServiceClient, error) { +func agentForHost(ctx context.Context, queries *db.Queries, pool *lifecycle.HostClientPool, hostID pgtype.UUID) (hostagentv1connect.HostAgentServiceClient, error) { host, err := queries.GetHost(ctx, hostID) if err != nil { return nil, fmt.Errorf("host not found: %w", err) diff --git a/internal/api/handler_sandbox_proxy.go b/internal/api/handler_sandbox_proxy.go new file mode 100644 index 00000000..963dff69 --- /dev/null +++ b/internal/api/handler_sandbox_proxy.go @@ -0,0 +1,229 @@ +package api + +import ( + "context" + "errors" + "fmt" + "log/slog" + "net/http" + "net/http/httputil" + "net/url" + "regexp" + "strconv" + "strings" + "sync" + "time" + + "github.com/jackc/pgx/v5/pgtype" + + "git.omukk.dev/wrenn/sandbox/internal/auth" + "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/id" + "git.omukk.dev/wrenn/sandbox/internal/lifecycle" +) + +// Sentinel errors returned by proxyTarget, used to map to HTTP status codes +// without relying on error message text. +var ( + errProxySandboxNotFound = errors.New("sandbox not found") + errProxyNoHostAddress = errors.New("host agent has no address") +) + +const proxyCacheTTL = 120 * time.Second + +// sandboxHostPattern matches hostnames like "49999-cl-abcd1234.localhost" or +// "49999-cl-abcd1234.example.com". Captures: port, sandbox ID. +var sandboxHostPattern = regexp.MustCompile(`^(\d+)-(cl-[0-9a-z]+)\.`) + +// errProxySandboxNotRunning carries the sandbox status so callers can include +// it in the HTTP response without parsing error strings. 
+type errProxySandboxNotRunning struct{ status string } + +func (e errProxySandboxNotRunning) Error() string { + return fmt.Sprintf("sandbox is not running (status: %s)", e.status) +} + +// proxyCacheEntry caches the resolved agent URL for a (sandbox, team) pair. +// The *httputil.ReverseProxy is built per-request (cheap) so the Director closure +// can capture the correct port without the cache key needing to include it. +type proxyCacheEntry struct { + agentURL *url.URL + expiresAt time.Time +} + +// proxyCacheKey is a fixed-size key from two UUIDs, avoids string allocation. +type proxyCacheKey [32]byte + +func makeProxyCacheKey(sandboxID, teamID pgtype.UUID) proxyCacheKey { + var k proxyCacheKey + copy(k[:16], sandboxID.Bytes[:]) + copy(k[16:], teamID.Bytes[:]) + return k +} + +// SandboxProxyWrapper wraps an existing HTTP handler and intercepts requests +// whose Host header matches the {port}-{sandbox_id}.{domain} pattern. Matching +// requests are reverse-proxied through the host agent that owns the sandbox. +// All other requests are passed through to the inner handler. +// +// Authentication is via X-API-Key header only (no JWT). The API key's team +// must own the sandbox. +type SandboxProxyWrapper struct { + inner http.Handler + db *db.Queries + pool *lifecycle.HostClientPool + transport http.RoundTripper + + cacheMu sync.Mutex + cache map[proxyCacheKey]proxyCacheEntry +} + +// NewSandboxProxyWrapper creates a new proxy wrapper. +func NewSandboxProxyWrapper(inner http.Handler, queries *db.Queries, pool *lifecycle.HostClientPool) *SandboxProxyWrapper { + return &SandboxProxyWrapper{ + inner: inner, + db: queries, + pool: pool, + transport: pool.Transport(), + cache: make(map[proxyCacheKey]proxyCacheEntry), + } +} + +// proxyTarget looks up the cached agent URL for (sandboxID, teamID). +// On a miss it queries the DB, resolves the address, and populates the cache. 
+// The *httputil.ReverseProxy is built by the caller so the Director closure +// captures the correct port without the cache key needing to include it. +func (h *SandboxProxyWrapper) proxyTarget(ctx context.Context, sandboxID, teamID pgtype.UUID) (*url.URL, error) { + cacheKey := makeProxyCacheKey(sandboxID, teamID) + + h.cacheMu.Lock() + entry, ok := h.cache[cacheKey] + h.cacheMu.Unlock() + + if ok && time.Now().Before(entry.expiresAt) { + return entry.agentURL, nil + } + + // Cache miss or expired — query DB. + target, err := h.db.GetSandboxProxyTarget(ctx, db.GetSandboxProxyTargetParams{ + ID: sandboxID, + TeamID: teamID, + }) + if err != nil { + return nil, errProxySandboxNotFound + } + if target.Status != "running" { + return nil, errProxySandboxNotRunning{status: target.Status} + } + if target.HostAddress == "" { + return nil, errProxyNoHostAddress + } + + agentURL, err := url.Parse(h.pool.ResolveAddr(target.HostAddress)) + if err != nil { + return nil, fmt.Errorf("invalid host agent address: %w", err) + } + + h.cacheMu.Lock() + h.cache[cacheKey] = proxyCacheEntry{ + agentURL: agentURL, + expiresAt: time.Now().Add(proxyCacheTTL), + } + h.cacheMu.Unlock() + + return agentURL, nil +} + +// evictProxyCache removes the cached entry for a (sandbox, team) pair. +// Called on 502 so a stopped/moved sandbox is re-resolved on the next request. +func (h *SandboxProxyWrapper) evictProxyCache(sandboxID, teamID pgtype.UUID) { + h.cacheMu.Lock() + delete(h.cache, makeProxyCacheKey(sandboxID, teamID)) + h.cacheMu.Unlock() +} + +func (h *SandboxProxyWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request) { + host := r.Host + // Strip port from Host header (e.g. 
"49999-cl-abcd1234.localhost:8000" → "49999-cl-abcd1234.localhost") + if colonIdx := strings.LastIndex(host, ":"); colonIdx != -1 { + host = host[:colonIdx] + } + + matches := sandboxHostPattern.FindStringSubmatch(host) + if matches == nil { + h.inner.ServeHTTP(w, r) + return + } + + port := matches[1] + sandboxIDStr := matches[2] + + // Validate port. + portNum, err := strconv.Atoi(port) + if err != nil || portNum < 1 || portNum > 65535 { + http.Error(w, "invalid port", http.StatusBadRequest) + return + } + + // Authenticate: require API key or JWT, extract team ID. + teamID, err := h.authenticateRequest(r) + if err != nil { + writeError(w, http.StatusUnauthorized, "unauthorized", err.Error()) + return + } + + sandboxID, err := id.ParseSandboxID(sandboxIDStr) + if err != nil { + http.Error(w, "invalid sandbox ID", http.StatusBadRequest) + return + } + + agentURL, err := h.proxyTarget(r.Context(), sandboxID, teamID) + if err != nil { + switch { + case errors.Is(err, errProxySandboxNotFound): + http.Error(w, err.Error(), http.StatusNotFound) + case errors.As(err, new(errProxySandboxNotRunning)): + http.Error(w, err.Error(), http.StatusConflict) + default: + http.Error(w, err.Error(), http.StatusServiceUnavailable) + } + return + } + + proxy := &httputil.ReverseProxy{ + Transport: h.transport, + Director: func(req *http.Request) { + req.URL.Scheme = agentURL.Scheme + req.URL.Host = agentURL.Host + req.URL.Path = "/proxy/" + sandboxIDStr + "/" + port + req.URL.Path + req.Host = agentURL.Host + }, + ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) { + slog.Debug("sandbox proxy error", + "sandbox_id", sandboxIDStr, + "port", port, + "error", err, + ) + h.evictProxyCache(sandboxID, teamID) + http.Error(w, "proxy error: "+err.Error(), http.StatusBadGateway) + }, + } + proxy.ServeHTTP(w, r) +} + +// authenticateRequest validates the request's API key and returns the team ID. 
+// Only API key authentication is supported for sandbox proxy requests (not JWT). +func (h *SandboxProxyWrapper) authenticateRequest(r *http.Request) (pgtype.UUID, error) { + key := r.Header.Get("X-API-Key") + if key == "" { + return pgtype.UUID{}, fmt.Errorf("X-API-Key header required") + } + + hash := auth.HashAPIKey(key) + row, err := h.db.GetAPIKeyByHash(r.Context(), hash) + if err != nil { + return pgtype.UUID{}, fmt.Errorf("invalid API key") + } + return row.TeamID, nil +} diff --git a/internal/api/handlers_apikeys.go b/internal/api/handlers_apikeys.go index 2637181d..700ddc52 100644 --- a/internal/api/handlers_apikeys.go +++ b/internal/api/handlers_apikeys.go @@ -9,6 +9,7 @@ import ( "git.omukk.dev/wrenn/sandbox/internal/audit" "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/id" "git.omukk.dev/wrenn/sandbox/internal/service" ) @@ -39,11 +40,11 @@ type apiKeyResponse struct { func apiKeyToResponse(k db.TeamApiKey) apiKeyResponse { resp := apiKeyResponse{ - ID: k.ID, - TeamID: k.TeamID, + ID: id.FormatAPIKeyID(k.ID), + TeamID: id.FormatTeamID(k.TeamID), Name: k.Name, KeyPrefix: k.KeyPrefix, - CreatedBy: k.CreatedBy, + CreatedBy: id.FormatUserID(k.CreatedBy), } if k.CreatedAt.Valid { resp.CreatedAt = k.CreatedAt.Time.Format(time.RFC3339) @@ -57,11 +58,11 @@ func apiKeyToResponse(k db.TeamApiKey) apiKeyResponse { func apiKeyWithCreatorToResponse(k db.ListAPIKeysByTeamWithCreatorRow) apiKeyResponse { resp := apiKeyResponse{ - ID: k.ID, - TeamID: k.TeamID, + ID: id.FormatAPIKeyID(k.ID), + TeamID: id.FormatTeamID(k.TeamID), Name: k.Name, KeyPrefix: k.KeyPrefix, - CreatedBy: k.CreatedBy, + CreatedBy: id.FormatUserID(k.CreatedBy), CreatorEmail: k.CreatorEmail, } if k.CreatedAt.Valid { @@ -118,7 +119,13 @@ func (h *apiKeyHandler) List(w http.ResponseWriter, r *http.Request) { // Delete handles DELETE /v1/api-keys/{id}. 
func (h *apiKeyHandler) Delete(w http.ResponseWriter, r *http.Request) { ac := auth.MustFromContext(r.Context()) - keyID := chi.URLParam(r, "id") + keyIDStr := chi.URLParam(r, "id") + + keyID, err := id.ParseAPIKeyID(keyIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid API key ID") + return + } if err := h.svc.Delete(r.Context(), keyID, ac.TeamID); err != nil { writeError(w, http.StatusInternalServerError, "db_error", "failed to delete API key") diff --git a/internal/api/handlers_audit.go b/internal/api/handlers_audit.go index 7812309d..a19ab1dc 100644 --- a/internal/api/handlers_audit.go +++ b/internal/api/handlers_audit.go @@ -6,7 +6,10 @@ import ( "strings" "time" + "github.com/jackc/pgx/v5/pgtype" + "git.omukk.dev/wrenn/sandbox/internal/auth" + "git.omukk.dev/wrenn/sandbox/internal/id" "git.omukk.dev/wrenn/sandbox/internal/service" ) @@ -65,13 +68,24 @@ func (h *auditHandler) List(w http.ResponseWriter, r *http.Request) { limit = n } + // Parse ?before_id cursor (UUID). + var beforeID pgtype.UUID + if s := r.URL.Query().Get("before_id"); s != "" { + parsed, err := id.ParseAuditLogID(s) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "before_id must be a valid audit log ID") + return + } + beforeID = parsed + } + entries, err := h.svc.List(r.Context(), service.AuditListParams{ TeamID: ac.TeamID, AdminScoped: ac.Role == "owner" || ac.Role == "admin", ResourceTypes: parseMultiParam(r.URL.Query()["resource_type"]), Actions: parseMultiParam(r.URL.Query()["action"]), Before: before, - BeforeID: r.URL.Query().Get("before_id"), + BeforeID: beforeID, Limit: limit, }) if err != nil { diff --git a/internal/api/handlers_auth.go b/internal/api/handlers_auth.go index ba60d8e1..b1d4915f 100644 --- a/internal/api/handlers_auth.go +++ b/internal/api/handlers_auth.go @@ -20,7 +20,7 @@ import ( // It prefers the user's default team; if none is flagged as default it falls // back to the earliest-joined team. 
Returns pgx.ErrNoRows when the user has // no team memberships at all. -func loginTeam(ctx context.Context, q *db.Queries, userID string) (db.Team, string, error) { +func loginTeam(ctx context.Context, q *db.Queries, userID pgtype.UUID) (db.Team, string, error) { team, err := q.GetDefaultTeamForUser(ctx, userID) if err == nil { membership, err := q.GetTeamMembership(ctx, db.GetTeamMembershipParams{UserID: userID, TeamID: team.ID}) @@ -176,8 +176,8 @@ func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusCreated, authResponse{ Token: token, - UserID: userID, - TeamID: teamID, + UserID: id.FormatUserID(userID), + TeamID: id.FormatTeamID(teamID), Email: req.Email, Name: req.Name, }) @@ -236,8 +236,8 @@ func (h *authHandler) Login(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusOK, authResponse{ Token: token, - UserID: user.ID, - TeamID: team.ID, + UserID: id.FormatUserID(user.ID), + TeamID: id.FormatTeamID(team.ID), Email: user.Email, Name: user.Name, }) @@ -260,10 +260,16 @@ func (h *authHandler) SwitchTeam(w http.ResponseWriter, r *http.Request) { return } + teamID, err := id.ParseTeamID(req.TeamID) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid team_id") + return + } + ctx := r.Context() // Verify team exists and is not deleted. - team, err := h.db.GetTeam(ctx, req.TeamID) + team, err := h.db.GetTeam(ctx, teamID) if err != nil { if errors.Is(err, pgx.ErrNoRows) { writeError(w, http.StatusNotFound, "not_found", "team not found") @@ -280,7 +286,7 @@ func (h *authHandler) SwitchTeam(w http.ResponseWriter, r *http.Request) { // Verify membership from DB — JWT role is not trusted here. 
membership, err := h.db.GetTeamMembership(ctx, db.GetTeamMembershipParams{ UserID: ac.UserID, - TeamID: req.TeamID, + TeamID: teamID, }) if err != nil { if errors.Is(err, pgx.ErrNoRows) { @@ -298,7 +304,7 @@ func (h *authHandler) SwitchTeam(w http.ResponseWriter, r *http.Request) { return } - token, err := auth.SignJWT(h.jwtSecret, ac.UserID, req.TeamID, ac.Email, user.Name, membership.Role, user.IsAdmin) + token, err := auth.SignJWT(h.jwtSecret, ac.UserID, teamID, ac.Email, user.Name, membership.Role, user.IsAdmin) if err != nil { writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token") return @@ -306,8 +312,8 @@ func (h *authHandler) SwitchTeam(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusOK, authResponse{ Token: token, - UserID: ac.UserID, - TeamID: req.TeamID, + UserID: id.FormatUserID(ac.UserID), + TeamID: id.FormatTeamID(teamID), Email: ac.Email, Name: user.Name, }) diff --git a/internal/api/handlers_builds.go b/internal/api/handlers_builds.go new file mode 100644 index 00000000..282c3f48 --- /dev/null +++ b/internal/api/handlers_builds.go @@ -0,0 +1,276 @@ +package api + +import ( + "encoding/json" + "fmt" + "log/slog" + "net/http" + "time" + + "connectrpc.com/connect" + "github.com/go-chi/chi/v5" + + "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/id" + "git.omukk.dev/wrenn/sandbox/internal/layout" + "git.omukk.dev/wrenn/sandbox/internal/lifecycle" + "git.omukk.dev/wrenn/sandbox/internal/service" + "git.omukk.dev/wrenn/sandbox/internal/validate" + pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen" +) + +type buildHandler struct { + svc *service.BuildService + db *db.Queries + pool *lifecycle.HostClientPool +} + +func newBuildHandler(svc *service.BuildService, db *db.Queries, pool *lifecycle.HostClientPool) *buildHandler { + return &buildHandler{svc: svc, db: db, pool: pool} +} + +type createBuildRequest struct { + Name string `json:"name"` + BaseTemplate string 
`json:"base_template"` + Recipe []string `json:"recipe"` + Healthcheck string `json:"healthcheck"` + VCPUs int32 `json:"vcpus"` + MemoryMB int32 `json:"memory_mb"` + SkipPrePost bool `json:"skip_pre_post"` +} + +type buildResponse struct { + ID string `json:"id"` + Name string `json:"name"` + BaseTemplate string `json:"base_template"` + Recipe json.RawMessage `json:"recipe"` + Healthcheck *string `json:"healthcheck,omitempty"` + VCPUs int32 `json:"vcpus"` + MemoryMB int32 `json:"memory_mb"` + Status string `json:"status"` + CurrentStep int32 `json:"current_step"` + TotalSteps int32 `json:"total_steps"` + Logs json.RawMessage `json:"logs"` + Error *string `json:"error,omitempty"` + SandboxID *string `json:"sandbox_id,omitempty"` + HostID *string `json:"host_id,omitempty"` + CreatedAt string `json:"created_at"` + StartedAt *string `json:"started_at,omitempty"` + CompletedAt *string `json:"completed_at,omitempty"` +} + +func buildToResponse(b db.TemplateBuild) buildResponse { + resp := buildResponse{ + ID: id.FormatBuildID(b.ID), + Name: b.Name, + BaseTemplate: b.BaseTemplate, + Recipe: b.Recipe, + VCPUs: b.Vcpus, + MemoryMB: b.MemoryMb, + Status: b.Status, + CurrentStep: b.CurrentStep, + TotalSteps: b.TotalSteps, + Logs: b.Logs, + } + if b.Healthcheck != "" { + resp.Healthcheck = &b.Healthcheck + } + if b.Error != "" { + resp.Error = &b.Error + } + if b.SandboxID.Valid { + s := id.FormatSandboxID(b.SandboxID) + resp.SandboxID = &s + } + if b.HostID.Valid { + s := id.FormatHostID(b.HostID) + resp.HostID = &s + } + if b.CreatedAt.Valid { + resp.CreatedAt = b.CreatedAt.Time.Format(time.RFC3339) + } + if b.StartedAt.Valid { + s := b.StartedAt.Time.Format(time.RFC3339) + resp.StartedAt = &s + } + if b.CompletedAt.Valid { + s := b.CompletedAt.Time.Format(time.RFC3339) + resp.CompletedAt = &s + } + return resp +} + +// Create handles POST /v1/admin/builds. 
+func (h *buildHandler) Create(w http.ResponseWriter, r *http.Request) { + var req createBuildRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body") + return + } + + if req.Name == "" { + writeError(w, http.StatusBadRequest, "invalid_request", "name is required") + return + } + if err := validate.SafeName(req.Name); err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", fmt.Sprintf("invalid template name: %s", err)) + return + } + if len(req.Recipe) == 0 { + writeError(w, http.StatusBadRequest, "invalid_request", "recipe must contain at least one command") + return + } + + build, err := h.svc.Create(r.Context(), service.BuildCreateParams{ + Name: req.Name, + BaseTemplate: req.BaseTemplate, + Recipe: req.Recipe, + Healthcheck: req.Healthcheck, + VCPUs: req.VCPUs, + MemoryMB: req.MemoryMB, + SkipPrePost: req.SkipPrePost, + }) + if err != nil { + slog.Error("failed to create build", "error", err) + writeError(w, http.StatusInternalServerError, "build_error", "failed to create build") + return + } + + writeJSON(w, http.StatusCreated, buildToResponse(build)) +} + +// List handles GET /v1/admin/builds. +func (h *buildHandler) List(w http.ResponseWriter, r *http.Request) { + builds, err := h.svc.List(r.Context()) + if err != nil { + writeError(w, http.StatusInternalServerError, "db_error", "failed to list builds") + return + } + + resp := make([]buildResponse, len(builds)) + for i, b := range builds { + resp[i] = buildToResponse(b) + } + + writeJSON(w, http.StatusOK, resp) +} + +// Get handles GET /v1/admin/builds/{id}. 
+func (h *buildHandler) Get(w http.ResponseWriter, r *http.Request) { + buildIDStr := chi.URLParam(r, "id") + + buildID, err := id.ParseBuildID(buildIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid build ID") + return + } + + build, err := h.svc.Get(r.Context(), buildID) + if err != nil { + writeError(w, http.StatusNotFound, "not_found", "build not found") + return + } + + writeJSON(w, http.StatusOK, buildToResponse(build)) +} + +// ListTemplates handles GET /v1/admin/templates — returns all templates across all teams. +func (h *buildHandler) ListTemplates(w http.ResponseWriter, r *http.Request) { + templates, err := h.db.ListTemplates(r.Context()) + if err != nil { + writeError(w, http.StatusInternalServerError, "db_error", "failed to list templates") + return + } + + type templateResponse struct { + Name string `json:"name"` + Type string `json:"type"` + VCPUs int32 `json:"vcpus"` + MemoryMB int32 `json:"memory_mb"` + SizeBytes int64 `json:"size_bytes"` + TeamID string `json:"team_id"` + CreatedAt string `json:"created_at"` + } + + resp := make([]templateResponse, len(templates)) + for i, t := range templates { + resp[i] = templateResponse{ + Name: t.Name, + Type: t.Type, + VCPUs: t.Vcpus, + MemoryMB: t.MemoryMb, + SizeBytes: t.SizeBytes, + TeamID: id.FormatTeamID(t.TeamID), + } + if t.CreatedAt.Valid { + resp[i].CreatedAt = t.CreatedAt.Time.Format(time.RFC3339) + } + } + + writeJSON(w, http.StatusOK, resp) +} + +// DeleteTemplate handles DELETE /v1/admin/templates/{name}. 
+// The {name} URL parameter must pass validate.SafeName (400 otherwise). The
+// template is looked up with GetPlatformTemplateByName (404 on error) and the
+// minimal template is refused with 403. Snapshot deletion is then broadcast to
+// the online hosts on a best-effort basis, the template record is deleted, and
+// the handler responds with 204.
+func (h *buildHandler) DeleteTemplate(w http.ResponseWriter, r *http.Request) {
+	name := chi.URLParam(r, "name")
+	if err := validate.SafeName(name); err != nil {
+		writeError(w, http.StatusBadRequest, "invalid_request", fmt.Sprintf("invalid template name: %s", err))
+		return
+	}
+	ctx := r.Context()
+
+	tmpl, err := h.db.GetPlatformTemplateByName(ctx, name)
+	if err != nil {
+		writeError(w, http.StatusNotFound, "not_found", "template not found")
+		return
+	}
+	if layout.IsMinimal(tmpl.TeamID, tmpl.ID) {
+		writeError(w, http.StatusForbidden, "forbidden", "the minimal template cannot be deleted")
+		return
+	}
+
+	// Broadcast delete to all online hosts.
+	// The broadcast is best-effort: failures are logged and never abort the
+	// request, but they must not be silent, because the record is deleted
+	// below regardless.
+	hosts, err := h.db.ListActiveHosts(ctx)
+	if err != nil {
+		slog.Warn("admin: failed to list hosts for template delete broadcast", "name", name, "error", err)
+	}
+	for _, host := range hosts {
+		if host.Status != "online" {
+			continue
+		}
+		agent, err := h.pool.GetForHost(host)
+		if err != nil {
+			slog.Warn("admin: no agent client for host, skipping template delete", "host_id", id.FormatHostID(host.ID), "name", name, "error", err)
+			continue
+		}
+		if _, err := agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{
+			TeamId:     formatUUIDForRPC(tmpl.TeamID),
+			TemplateId: formatUUIDForRPC(tmpl.ID),
+		})); err != nil {
+			// A host reporting CodeNotFound is not worth a warning.
+			if connect.CodeOf(err) != connect.CodeNotFound {
+				slog.Warn("admin: failed to delete template on host", "host_id", id.FormatHostID(host.ID), "name", name, "error", err)
+			}
+		}
+	}
+
+	if err := h.db.DeleteTemplate(ctx, tmpl.ID); err != nil {
+		writeError(w, http.StatusInternalServerError, "db_error", "failed to delete template record")
+		return
+	}
+
+	w.WriteHeader(http.StatusNoContent)
+}
+
+// Cancel handles POST /v1/admin/builds/{id}/cancel.
+// The {id} URL parameter must parse with id.ParseBuildID, otherwise the
+// request is rejected with 400 invalid_request. On success the handler
+// responds with 204 and no body.
+func (h *buildHandler) Cancel(w http.ResponseWriter, r *http.Request) {
+	buildIDStr := chi.URLParam(r, "id")
+
+	buildID, err := id.ParseBuildID(buildIDStr)
+	if err != nil {
+		writeError(w, http.StatusBadRequest, "invalid_request", "invalid build ID")
+		return
+	}
+
+	if err := h.svc.Cancel(r.Context(), buildID); err != nil {
+		// NOTE(review): every service error is returned as 400 with the raw
+		// err.Error() text, unlike the generic messages used by the other
+		// handlers in this file — confirm the service only returns
+		// client-safe errors here.
+		writeError(w, http.StatusBadRequest, "invalid_request", err.Error())
+		return
+	}
+
+	w.WriteHeader(http.StatusNoContent)
+}
diff --git a/internal/api/handlers_exec.go b/internal/api/handlers_exec.go index 84b38330..596457b2 100644 --- a/internal/api/handlers_exec.go +++ b/internal/api/handlers_exec.go @@ -14,6 +14,7 @@ import ( "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/id" "git.omukk.dev/wrenn/sandbox/internal/lifecycle" pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen" ) @@ -46,10 +47,16 @@ type execResponse struct { // Exec handles POST /v1/sandboxes/{id}/exec. func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) { - sandboxID := chi.URLParam(r, "id") + sandboxIDStr := chi.URLParam(r, "id") ctx := r.Context() ac := auth.MustFromContext(ctx) + sandboxID, err := id.ParseSandboxID(sandboxIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID") + return + } + sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID}) if err != nil { writeError(w, http.StatusNotFound, "not_found", "sandbox not found") @@ -80,7 +87,7 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) { } resp, err := agent.Exec(ctx, connect.NewRequest(&pb.ExecRequest{ - SandboxId: sandboxID, + SandboxId: sandboxIDStr, Cmd: req.Cmd, Args: req.Args, TimeoutSec: req.TimeoutSec, @@ -101,7 +108,7 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) { Valid: true, }, }); err != nil { - slog.Warn("failed to update last_active_at", "id", sandboxID, "error", err)
+ slog.Warn("failed to update last_active_at", "id", sandboxIDStr, "error", err) } // Use base64 encoding if output contains non-UTF-8 bytes. @@ -112,7 +119,7 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) { if !utf8.Valid(stdout) || !utf8.Valid(stderr) { encoding = "base64" writeJSON(w, http.StatusOK, execResponse{ - SandboxID: sandboxID, + SandboxID: sandboxIDStr, Cmd: req.Cmd, Stdout: base64.StdEncoding.EncodeToString(stdout), Stderr: base64.StdEncoding.EncodeToString(stderr), @@ -124,7 +131,7 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) { } writeJSON(w, http.StatusOK, execResponse{ - SandboxID: sandboxID, + SandboxID: sandboxIDStr, Cmd: req.Cmd, Stdout: string(stdout), Stderr: string(stderr), diff --git a/internal/api/handlers_exec_stream.go b/internal/api/handlers_exec_stream.go index 3ecfdfe7..52dfd17e 100644 --- a/internal/api/handlers_exec_stream.go +++ b/internal/api/handlers_exec_stream.go @@ -14,6 +14,7 @@ import ( "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/id" "git.omukk.dev/wrenn/sandbox/internal/lifecycle" pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen" ) @@ -48,10 +49,16 @@ type wsOutMsg struct { // ExecStream handles WS /v1/sandboxes/{id}/exec/stream. 
func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) { - sandboxID := chi.URLParam(r, "id") + sandboxIDStr := chi.URLParam(r, "id") ctx := r.Context() ac := auth.MustFromContext(ctx) + sandboxID, err := id.ParseSandboxID(sandboxIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID") + return + } + sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID}) if err != nil { writeError(w, http.StatusNotFound, "not_found", "sandbox not found") @@ -91,7 +98,7 @@ func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) { defer cancel() stream, err := agent.ExecStream(streamCtx, connect.NewRequest(&pb.ExecStreamRequest{ - SandboxId: sandboxID, + SandboxId: sandboxIDStr, Cmd: startMsg.Cmd, Args: startMsg.Args, })) @@ -157,7 +164,7 @@ func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) { Valid: true, }, }); err != nil { - slog.Warn("failed to update last active after stream exec", "sandbox_id", sandboxID, "error", err) + slog.Warn("failed to update last active after stream exec", "sandbox_id", sandboxIDStr, "error", err) } } diff --git a/internal/api/handlers_files.go b/internal/api/handlers_files.go index c5fff70d..a2e9936c 100644 --- a/internal/api/handlers_files.go +++ b/internal/api/handlers_files.go @@ -11,6 +11,7 @@ import ( "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/id" "git.omukk.dev/wrenn/sandbox/internal/lifecycle" pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen" ) @@ -29,10 +30,16 @@ func newFilesHandler(db *db.Queries, pool *lifecycle.HostClientPool) *filesHandl // - "path" text field: absolute destination path inside the sandbox // - "file" file field: binary content to write func (h *filesHandler) Upload(w http.ResponseWriter, r *http.Request) { - sandboxID := chi.URLParam(r, "id") + sandboxIDStr := chi.URLParam(r, 
"id") ctx := r.Context() ac := auth.MustFromContext(ctx) + sandboxID, err := id.ParseSandboxID(sandboxIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID") + return + } + sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID}) if err != nil { writeError(w, http.StatusNotFound, "not_found", "sandbox not found") @@ -82,7 +89,7 @@ func (h *filesHandler) Upload(w http.ResponseWriter, r *http.Request) { } if _, err := agent.WriteFile(ctx, connect.NewRequest(&pb.WriteFileRequest{ - SandboxId: sandboxID, + SandboxId: sandboxIDStr, Path: filePath, Content: content, })); err != nil { @@ -101,10 +108,16 @@ type readFileRequest struct { // Download handles POST /v1/sandboxes/{id}/files/read. // Accepts JSON body with path, returns raw file content with Content-Disposition. func (h *filesHandler) Download(w http.ResponseWriter, r *http.Request) { - sandboxID := chi.URLParam(r, "id") + sandboxIDStr := chi.URLParam(r, "id") ctx := r.Context() ac := auth.MustFromContext(ctx) + sandboxID, err := id.ParseSandboxID(sandboxIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID") + return + } + sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID}) if err != nil { writeError(w, http.StatusNotFound, "not_found", "sandbox not found") @@ -133,7 +146,7 @@ func (h *filesHandler) Download(w http.ResponseWriter, r *http.Request) { } resp, err := agent.ReadFile(ctx, connect.NewRequest(&pb.ReadFileRequest{ - SandboxId: sandboxID, + SandboxId: sandboxIDStr, Path: req.Path, })) if err != nil { diff --git a/internal/api/handlers_files_stream.go b/internal/api/handlers_files_stream.go index 66e89c73..e6c040f2 100644 --- a/internal/api/handlers_files_stream.go +++ b/internal/api/handlers_files_stream.go @@ -12,6 +12,7 @@ import ( "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/db" + 
"git.omukk.dev/wrenn/sandbox/internal/id" "git.omukk.dev/wrenn/sandbox/internal/lifecycle" pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen" ) @@ -29,10 +30,16 @@ func newFilesStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool) *file // Expects multipart/form-data with "path" text field and "file" file field. // Streams file content directly from the request body to the host agent without buffering. func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request) { - sandboxID := chi.URLParam(r, "id") + sandboxIDStr := chi.URLParam(r, "id") ctx := r.Context() ac := auth.MustFromContext(ctx) + sandboxID, err := id.ParseSandboxID(sandboxIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID") + return + } + sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID}) if err != nil { writeError(w, http.StatusNotFound, "not_found", "sandbox not found") @@ -101,7 +108,7 @@ func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request if err := stream.Send(&pb.WriteFileStreamRequest{ Content: &pb.WriteFileStreamRequest_Meta{ Meta: &pb.WriteFileStreamMeta{ - SandboxId: sandboxID, + SandboxId: sandboxIDStr, Path: filePath, }, }, @@ -146,10 +153,16 @@ func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request // StreamDownload handles POST /v1/sandboxes/{id}/files/stream/read. // Accepts JSON body with path, streams file content back without buffering. 
func (h *filesStreamHandler) StreamDownload(w http.ResponseWriter, r *http.Request) { - sandboxID := chi.URLParam(r, "id") + sandboxIDStr := chi.URLParam(r, "id") ctx := r.Context() ac := auth.MustFromContext(ctx) + sandboxID, err := id.ParseSandboxID(sandboxIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID") + return + } + sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID}) if err != nil { writeError(w, http.StatusNotFound, "not_found", "sandbox not found") @@ -178,7 +191,7 @@ func (h *filesStreamHandler) StreamDownload(w http.ResponseWriter, r *http.Reque // Open server-streaming RPC to host agent. stream, err := agent.ReadFileStream(ctx, connect.NewRequest(&pb.ReadFileStreamRequest{ - SandboxId: sandboxID, + SandboxId: sandboxIDStr, Path: req.Path, })) if err != nil { diff --git a/internal/api/handlers_hosts.go b/internal/api/handlers_hosts.go index f4f79173..50652a00 100644 --- a/internal/api/handlers_hosts.go +++ b/internal/api/handlers_hosts.go @@ -8,9 +8,12 @@ import ( "github.com/go-chi/chi/v5" + "github.com/jackc/pgx/v5/pgtype" + "git.omukk.dev/wrenn/sandbox/internal/audit" "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/id" "git.omukk.dev/wrenn/sandbox/internal/service" ) @@ -46,6 +49,9 @@ type refreshTokenResponse struct { Host hostResponse `json:"host"` Token string `json:"token"` RefreshToken string `json:"refresh_token"` + CertPEM string `json:"cert_pem,omitempty"` + KeyPEM string `json:"key_pem,omitempty"` + CACertPEM string `json:"ca_cert_pem,omitempty"` } type deletePreviewResponse struct { @@ -66,6 +72,9 @@ type registerHostResponse struct { Host hostResponse `json:"host"` Token string `json:"token"` RefreshToken string `json:"refresh_token"` + CertPEM string `json:"cert_pem,omitempty"` + KeyPEM string `json:"key_pem,omitempty"` + CACertPEM string 
`json:"ca_cert_pem,omitempty"` } type addTagRequest struct { @@ -93,34 +102,35 @@ type hostResponse struct { func hostToResponse(h db.Host) hostResponse { resp := hostResponse{ - ID: h.ID, + ID: id.FormatHostID(h.ID), Type: h.Type, Status: h.Status, - CreatedBy: h.CreatedBy, + CreatedBy: id.FormatUserID(h.CreatedBy), } if h.TeamID.Valid { - resp.TeamID = &h.TeamID.String + s := id.FormatTeamID(h.TeamID) + resp.TeamID = &s } - if h.Provider.Valid { - resp.Provider = &h.Provider.String + if h.Provider != "" { + resp.Provider = &h.Provider } - if h.AvailabilityZone.Valid { - resp.AvailabilityZone = &h.AvailabilityZone.String + if h.AvailabilityZone != "" { + resp.AvailabilityZone = &h.AvailabilityZone } - if h.Arch.Valid { - resp.Arch = &h.Arch.String + if h.Arch != "" { + resp.Arch = &h.Arch } - if h.CpuCores.Valid { - resp.CPUCores = &h.CpuCores.Int32 + if h.CpuCores != 0 { + resp.CPUCores = &h.CpuCores } - if h.MemoryMb.Valid { - resp.MemoryMB = &h.MemoryMb.Int32 + if h.MemoryMb != 0 { + resp.MemoryMB = &h.MemoryMb } - if h.DiskGb.Valid { - resp.DiskGB = &h.DiskGb.Int32 + if h.DiskGb != 0 { + resp.DiskGB = &h.DiskGb } - if h.Address.Valid { - resp.Address = &h.Address.String + if h.Address != "" { + resp.Address = &h.Address } if h.LastHeartbeatAt.Valid { s := h.LastHeartbeatAt.Time.Format(time.RFC3339) @@ -133,7 +143,7 @@ func hostToResponse(h db.Host) hostResponse { } // isAdmin fetches the user record and returns whether they are an admin. 
-func (h *hostHandler) isAdmin(r *http.Request, userID string) bool { +func (h *hostHandler) isAdmin(r *http.Request, userID pgtype.UUID) bool { user, err := h.queries.GetUserByID(r.Context(), userID) if err != nil { return false @@ -151,14 +161,23 @@ func (h *hostHandler) Create(w http.ResponseWriter, r *http.Request) { ac := auth.MustFromContext(r.Context()) - result, err := h.svc.Create(r.Context(), service.HostCreateParams{ - Type: req.Type, - TeamID: req.TeamID, - Provider: req.Provider, - AvailabilityZone: req.AvailabilityZone, - RequestingUserID: ac.UserID, - IsRequestorAdmin: h.isAdmin(r, ac.UserID), - }) + // Parse optional team ID from request body. + var params service.HostCreateParams + params.Type = req.Type + params.Provider = req.Provider + params.AvailabilityZone = req.AvailabilityZone + params.RequestingUserID = ac.UserID + params.IsRequestorAdmin = h.isAdmin(r, ac.UserID) + if req.TeamID != "" { + teamID, err := id.ParseTeamID(req.TeamID) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid team_id") + return + } + params.TeamID = teamID + } + + result, err := h.svc.Create(r.Context(), params) if err != nil { status, code, msg := serviceErrToHTTP(err) writeError(w, status, code, msg) @@ -166,8 +185,7 @@ func (h *hostHandler) Create(w http.ResponseWriter, r *http.Request) { } // Log audit for the owning team (BYOC hosts have a team; shared hosts use caller's team). 
- hostTeamID := result.Host.TeamID.String - h.audit.LogHostCreate(r.Context(), ac, result.Host.ID, hostTeamID) + h.audit.LogHostCreate(r.Context(), ac, result.Host.ID, result.Host.TeamID) writeJSON(w, http.StatusCreated, createHostResponse{ Host: hostToResponse(result.Host), @@ -192,14 +210,22 @@ func (h *hostHandler) List(w http.ResponseWriter, r *http.Request) { seen := make(map[string]struct{}) for _, host := range hosts { if host.TeamID.Valid { - seen[host.TeamID.String] = struct{}{} + key := id.FormatTeamID(host.TeamID) + seen[key] = struct{}{} } } if len(seen) > 0 { teamNames = make(map[string]string, len(seen)) - for id := range seen { - if team, err := h.queries.GetTeam(r.Context(), id); err == nil { - teamNames[id] = team.Name + for _, host := range hosts { + if !host.TeamID.Valid { + continue + } + key := id.FormatTeamID(host.TeamID) + if _, ok := teamNames[key]; ok { + continue + } + if team, err := h.queries.GetTeam(r.Context(), host.TeamID); err == nil { + teamNames[key] = team.Name } } } @@ -209,7 +235,8 @@ func (h *hostHandler) List(w http.ResponseWriter, r *http.Request) { for i, host := range hosts { resp[i] = hostToResponse(host) if host.TeamID.Valid { - if name, ok := teamNames[host.TeamID.String]; ok { + key := id.FormatTeamID(host.TeamID) + if name, ok := teamNames[key]; ok { resp[i].TeamName = &name } } @@ -220,9 +247,15 @@ func (h *hostHandler) List(w http.ResponseWriter, r *http.Request) { // Get handles GET /v1/hosts/{id}. 
func (h *hostHandler) Get(w http.ResponseWriter, r *http.Request) { - hostID := chi.URLParam(r, "id") + hostIDStr := chi.URLParam(r, "id") ac := auth.MustFromContext(r.Context()) + hostID, err := id.ParseHostID(hostIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID") + return + } + host, err := h.svc.Get(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID)) if err != nil { status, code, msg := serviceErrToHTTP(err) @@ -236,9 +269,15 @@ func (h *hostHandler) Get(w http.ResponseWriter, r *http.Request) { // DeletePreview handles GET /v1/hosts/{id}/delete-preview. // Returns what would be affected without making changes, for confirmation UI. func (h *hostHandler) DeletePreview(w http.ResponseWriter, r *http.Request) { - hostID := chi.URLParam(r, "id") + hostIDStr := chi.URLParam(r, "id") ac := auth.MustFromContext(r.Context()) + hostID, err := id.ParseHostID(hostIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID") + return + } + preview, err := h.svc.DeletePreview(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID)) if err != nil { status, code, msg := serviceErrToHTTP(err) @@ -256,19 +295,25 @@ func (h *hostHandler) DeletePreview(w http.ResponseWriter, r *http.Request) { // Without ?force=true: returns 409 with affected sandbox IDs if any are active. // With ?force=true: gracefully stops all sandboxes then deletes the host. func (h *hostHandler) Delete(w http.ResponseWriter, r *http.Request) { - hostID := chi.URLParam(r, "id") + hostIDStr := chi.URLParam(r, "id") ac := auth.MustFromContext(r.Context()) force := r.URL.Query().Get("force") == "true" + hostID, err := id.ParseHostID(hostIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID") + return + } + // Fetch host before deletion to capture team_id for audit. 
deletedHost, hostErr := h.queries.GetHost(r.Context(), hostID) if hostErr != nil { - slog.Warn("audit: could not fetch host before delete", "host_id", hostID, "error", hostErr) + slog.Warn("audit: could not fetch host before delete", "host_id", hostIDStr, "error", hostErr) } - err := h.svc.Delete(r.Context(), hostID, ac.UserID, ac.TeamID, h.isAdmin(r, ac.UserID), force) + err = h.svc.Delete(r.Context(), hostID, ac.UserID, ac.TeamID, h.isAdmin(r, ac.UserID), force) if err == nil { - h.audit.LogHostDelete(r.Context(), ac, hostID, deletedHost.TeamID.String) + h.audit.LogHostDelete(r.Context(), ac, hostID, deletedHost.TeamID) w.WriteHeader(http.StatusNoContent) return } @@ -292,9 +337,15 @@ func (h *hostHandler) Delete(w http.ResponseWriter, r *http.Request) { // RegenerateToken handles POST /v1/hosts/{id}/token. func (h *hostHandler) RegenerateToken(w http.ResponseWriter, r *http.Request) { - hostID := chi.URLParam(r, "id") + hostIDStr := chi.URLParam(r, "id") ac := auth.MustFromContext(r.Context()) + hostID, err := id.ParseHostID(hostIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID") + return + } + result, err := h.svc.RegenerateToken(r.Context(), hostID, ac.UserID, ac.TeamID, h.isAdmin(r, ac.UserID)) if err != nil { status, code, msg := serviceErrToHTTP(err) @@ -343,14 +394,23 @@ func (h *hostHandler) Register(w http.ResponseWriter, r *http.Request) { Host: hostToResponse(result.Host), Token: result.JWT, RefreshToken: result.RefreshToken, + CertPEM: result.CertPEM, + KeyPEM: result.KeyPEM, + CACertPEM: result.CACertPEM, }) } // Heartbeat handles POST /v1/hosts/{id}/heartbeat (host-token-authenticated). 
func (h *hostHandler) Heartbeat(w http.ResponseWriter, r *http.Request) { - hostID := chi.URLParam(r, "id") + hostIDStr := chi.URLParam(r, "id") hc := auth.MustHostFromContext(r.Context()) + hostID, err := id.ParseHostID(hostIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID") + return + } + // Prevent a host from heartbeating for a different host. if hostID != hc.HostID { writeError(w, http.StatusForbidden, "forbidden", "host ID mismatch") @@ -368,7 +428,7 @@ func (h *hostHandler) Heartbeat(w http.ResponseWriter, r *http.Request) { // Log marked_up if the host just recovered from unreachable. if prevHost.Status == "unreachable" { - h.audit.LogHostMarkedUp(r.Context(), prevHost.TeamID.String, hc.HostID) + h.audit.LogHostMarkedUp(r.Context(), prevHost.TeamID, hc.HostID) } w.WriteHeader(http.StatusNoContent) @@ -376,10 +436,16 @@ func (h *hostHandler) Heartbeat(w http.ResponseWriter, r *http.Request) { // AddTag handles POST /v1/hosts/{id}/tags. func (h *hostHandler) AddTag(w http.ResponseWriter, r *http.Request) { - hostID := chi.URLParam(r, "id") + hostIDStr := chi.URLParam(r, "id") ac := auth.MustFromContext(r.Context()) admin := h.isAdmin(r, ac.UserID) + hostID, err := id.ParseHostID(hostIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID") + return + } + var req addTagRequest if err := decodeJSON(r, &req); err != nil { writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body") @@ -401,10 +467,16 @@ func (h *hostHandler) AddTag(w http.ResponseWriter, r *http.Request) { // RemoveTag handles DELETE /v1/hosts/{id}/tags/{tag}. 
func (h *hostHandler) RemoveTag(w http.ResponseWriter, r *http.Request) { - hostID := chi.URLParam(r, "id") + hostIDStr := chi.URLParam(r, "id") tag := chi.URLParam(r, "tag") ac := auth.MustFromContext(r.Context()) + hostID, err := id.ParseHostID(hostIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID") + return + } + if err := h.svc.RemoveTag(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID), tag); err != nil { status, code, msg := serviceErrToHTTP(err) writeError(w, status, code, msg) @@ -438,14 +510,23 @@ func (h *hostHandler) RefreshToken(w http.ResponseWriter, r *http.Request) { Host: hostToResponse(result.Host), Token: result.JWT, RefreshToken: result.RefreshToken, + CertPEM: result.CertPEM, + KeyPEM: result.KeyPEM, + CACertPEM: result.CACertPEM, }) } // ListTags handles GET /v1/hosts/{id}/tags. func (h *hostHandler) ListTags(w http.ResponseWriter, r *http.Request) { - hostID := chi.URLParam(r, "id") + hostIDStr := chi.URLParam(r, "id") ac := auth.MustFromContext(r.Context()) + hostID, err := id.ParseHostID(hostIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID") + return + } + tags, err := h.svc.ListTags(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID)) if err != nil { status, code, msg := serviceErrToHTTP(err) diff --git a/internal/api/handlers_metrics.go b/internal/api/handlers_metrics.go index 793349e5..25f485c6 100644 --- a/internal/api/handlers_metrics.go +++ b/internal/api/handlers_metrics.go @@ -7,9 +7,11 @@ import ( "connectrpc.com/connect" "github.com/go-chi/chi/v5" + "github.com/jackc/pgx/v5/pgtype" "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/id" "git.omukk.dev/wrenn/sandbox/internal/lifecycle" pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen" ) @@ -38,10 +40,16 @@ type metricsResponse struct { // GetMetrics handles GET 
/v1/sandboxes/{id}/metrics?range=10m|2h|24h. func (h *sandboxMetricsHandler) GetMetrics(w http.ResponseWriter, r *http.Request) { - sandboxID := chi.URLParam(r, "id") + sandboxIDStr := chi.URLParam(r, "id") ctx := r.Context() ac := auth.MustFromContext(ctx) + sandboxID, err := id.ParseSandboxID(sandboxIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID") + return + } + rangeTier := r.URL.Query().Get("range") if rangeTier == "" { rangeTier = "10m" @@ -60,15 +68,15 @@ func (h *sandboxMetricsHandler) GetMetrics(w http.ResponseWriter, r *http.Reques switch sb.Status { case "running": - h.getFromAgent(w, r, sandboxID, rangeTier, sb.HostID) + h.getFromAgent(w, r, sandboxIDStr, rangeTier, sb.HostID) case "paused": - h.getFromDB(ctx, w, sandboxID, rangeTier) + h.getFromDB(ctx, w, sandboxIDStr, sandboxID, rangeTier) default: writeError(w, http.StatusNotFound, "not_found", "metrics not available for sandbox in state: "+sb.Status) } } -func (h *sandboxMetricsHandler) getFromAgent(w http.ResponseWriter, r *http.Request, sandboxID, rangeTier, hostID string) { +func (h *sandboxMetricsHandler) getFromAgent(w http.ResponseWriter, r *http.Request, sandboxIDStr, rangeTier string, hostID pgtype.UUID) { ctx := r.Context() agent, err := agentForHost(ctx, h.db, h.pool, hostID) @@ -78,7 +86,7 @@ func (h *sandboxMetricsHandler) getFromAgent(w http.ResponseWriter, r *http.Requ } resp, err := agent.GetSandboxMetrics(ctx, connect.NewRequest(&pb.GetSandboxMetricsRequest{ - SandboxId: sandboxID, + SandboxId: sandboxIDStr, Range: rangeTier, })) if err != nil { @@ -98,7 +106,7 @@ func (h *sandboxMetricsHandler) getFromAgent(w http.ResponseWriter, r *http.Requ } writeJSON(w, http.StatusOK, metricsResponse{ - SandboxID: sandboxID, + SandboxID: sandboxIDStr, Range: rangeTier, Points: points, }) @@ -118,7 +126,7 @@ var rangeToDB = map[string]struct { "24h": {"24h", 24 * time.Hour}, } -func (h *sandboxMetricsHandler) getFromDB(ctx context.Context, w 
http.ResponseWriter, sandboxID, rangeTier string) { +func (h *sandboxMetricsHandler) getFromDB(ctx context.Context, w http.ResponseWriter, sandboxIDStr string, sandboxID pgtype.UUID, rangeTier string) { mapping := rangeToDB[rangeTier] rows, err := h.db.GetSandboxMetricPoints(ctx, db.GetSandboxMetricPointsParams{ SandboxID: sandboxID, @@ -141,7 +149,7 @@ func (h *sandboxMetricsHandler) getFromDB(ctx context.Context, w http.ResponseWr } writeJSON(w, http.StatusOK, metricsResponse{ - SandboxID: sandboxID, + SandboxID: sandboxIDStr, Range: rangeTier, Points: points, }) diff --git a/internal/api/handlers_oauth.go b/internal/api/handlers_oauth.go index 348dd859..a9c448ac 100644 --- a/internal/api/handlers_oauth.go +++ b/internal/api/handlers_oauth.go @@ -162,7 +162,7 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) { redirectWithError(w, r, redirectBase, "internal_error") return } - redirectWithToken(w, r, redirectBase, token, user.ID, team.ID, user.Email, user.Name) + redirectWithToken(w, r, redirectBase, token, id.FormatUserID(user.ID), id.FormatTeamID(team.ID), user.Email, user.Name) return } if !errors.Is(err, pgx.ErrNoRows) { @@ -262,7 +262,7 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) { return } - redirectWithToken(w, r, redirectBase, token, userID, teamID, email, profile.Name) + redirectWithToken(w, r, redirectBase, token, id.FormatUserID(userID), id.FormatTeamID(teamID), email, profile.Name) } // retryAsLogin handles the race where a concurrent request already created the user. 
@@ -296,7 +296,7 @@ func (h *oauthHandler) retryAsLogin(w http.ResponseWriter, r *http.Request, prov redirectWithError(w, r, redirectBase, "internal_error") return } - redirectWithToken(w, r, redirectBase, token, user.ID, team.ID, user.Email, user.Name) + redirectWithToken(w, r, redirectBase, token, id.FormatUserID(user.ID), id.FormatTeamID(team.ID), user.Email, user.Name) } func redirectWithToken(w http.ResponseWriter, r *http.Request, base, token, userID, teamID, email, name string) { diff --git a/internal/api/handlers_sandbox.go b/internal/api/handlers_sandbox.go index b2709a5d..a19a7cc3 100644 --- a/internal/api/handlers_sandbox.go +++ b/internal/api/handlers_sandbox.go @@ -10,6 +10,7 @@ import ( "git.omukk.dev/wrenn/sandbox/internal/audit" "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/id" "git.omukk.dev/wrenn/sandbox/internal/service" ) @@ -46,7 +47,7 @@ type sandboxResponse struct { func sandboxToResponse(sb db.Sandbox) sandboxResponse { resp := sandboxResponse{ - ID: sb.ID, + ID: id.FormatSandboxID(sb.ID), Status: sb.Status, Template: sb.Template, VCPUs: sb.Vcpus, @@ -81,7 +82,7 @@ func (h *sandboxHandler) Create(w http.ResponseWriter, r *http.Request) { } ac := auth.MustFromContext(r.Context()) - if ac.TeamID == "" { + if !ac.TeamID.Valid { writeError(w, http.StatusForbidden, "no_team", "no active team context; re-authenticate") return } @@ -122,9 +123,15 @@ func (h *sandboxHandler) List(w http.ResponseWriter, r *http.Request) { // Get handles GET /v1/sandboxes/{id}. 
func (h *sandboxHandler) Get(w http.ResponseWriter, r *http.Request) { - sandboxID := chi.URLParam(r, "id") + sandboxIDStr := chi.URLParam(r, "id") ac := auth.MustFromContext(r.Context()) + sandboxID, err := id.ParseSandboxID(sandboxIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID") + return + } + sb, err := h.svc.Get(r.Context(), sandboxID, ac.TeamID) if err != nil { writeError(w, http.StatusNotFound, "not_found", "sandbox not found") @@ -136,9 +143,15 @@ func (h *sandboxHandler) Get(w http.ResponseWriter, r *http.Request) { // Pause handles POST /v1/sandboxes/{id}/pause. func (h *sandboxHandler) Pause(w http.ResponseWriter, r *http.Request) { - sandboxID := chi.URLParam(r, "id") + sandboxIDStr := chi.URLParam(r, "id") ac := auth.MustFromContext(r.Context()) + sandboxID, err := id.ParseSandboxID(sandboxIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID") + return + } + sb, err := h.svc.Pause(r.Context(), sandboxID, ac.TeamID) if err != nil { status, code, msg := serviceErrToHTTP(err) @@ -152,9 +165,15 @@ func (h *sandboxHandler) Pause(w http.ResponseWriter, r *http.Request) { // Resume handles POST /v1/sandboxes/{id}/resume. func (h *sandboxHandler) Resume(w http.ResponseWriter, r *http.Request) { - sandboxID := chi.URLParam(r, "id") + sandboxIDStr := chi.URLParam(r, "id") ac := auth.MustFromContext(r.Context()) + sandboxID, err := id.ParseSandboxID(sandboxIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID") + return + } + sb, err := h.svc.Resume(r.Context(), sandboxID, ac.TeamID) if err != nil { status, code, msg := serviceErrToHTTP(err) @@ -168,9 +187,15 @@ func (h *sandboxHandler) Resume(w http.ResponseWriter, r *http.Request) { // Ping handles POST /v1/sandboxes/{id}/ping. 
func (h *sandboxHandler) Ping(w http.ResponseWriter, r *http.Request) { - sandboxID := chi.URLParam(r, "id") + sandboxIDStr := chi.URLParam(r, "id") ac := auth.MustFromContext(r.Context()) + sandboxID, err := id.ParseSandboxID(sandboxIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID") + return + } + if err := h.svc.Ping(r.Context(), sandboxID, ac.TeamID); err != nil { status, code, msg := serviceErrToHTTP(err) writeError(w, status, code, msg) @@ -182,9 +207,15 @@ func (h *sandboxHandler) Ping(w http.ResponseWriter, r *http.Request) { // Destroy handles DELETE /v1/sandboxes/{id}. func (h *sandboxHandler) Destroy(w http.ResponseWriter, r *http.Request) { - sandboxID := chi.URLParam(r, "id") + sandboxIDStr := chi.URLParam(r, "id") ac := auth.MustFromContext(r.Context()) + sandboxID, err := id.ParseSandboxID(sandboxIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID") + return + } + if err := h.svc.Destroy(r.Context(), sandboxID, ac.TeamID); err != nil { status, code, msg := serviceErrToHTTP(err) writeError(w, status, code, msg) diff --git a/internal/api/handlers_snapshots.go b/internal/api/handlers_snapshots.go index fbfcdc18..f7d05f2f 100644 --- a/internal/api/handlers_snapshots.go +++ b/internal/api/handlers_snapshots.go @@ -10,12 +10,14 @@ import ( "connectrpc.com/connect" "github.com/go-chi/chi/v5" + "github.com/jackc/pgx/v5/pgtype" "git.omukk.dev/wrenn/sandbox/internal/audit" "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/db" "git.omukk.dev/wrenn/sandbox/internal/id" + "git.omukk.dev/wrenn/sandbox/internal/layout" "git.omukk.dev/wrenn/sandbox/internal/lifecycle" "git.omukk.dev/wrenn/sandbox/internal/service" "git.omukk.dev/wrenn/sandbox/internal/validate" @@ -35,8 +37,8 @@ func newSnapshotHandler(svc *service.TemplateService, db *db.Queries, pool *life // deleteSnapshotBroadcast attempts to delete snapshot files on all online 
hosts. // Snapshots aren't currently host-tracked in the DB, so we broadcast to all hosts -// and ignore NotFound errors. TODO: add host_id to templates table. -func (h *snapshotHandler) deleteSnapshotBroadcast(ctx context.Context, name string) error { +// and ignore NotFound errors. +func (h *snapshotHandler) deleteSnapshotBroadcast(ctx context.Context, teamID, templateID pgtype.UUID) error { hosts, err := h.db.ListActiveHosts(ctx) if err != nil { return fmt.Errorf("list hosts: %w", err) @@ -49,9 +51,12 @@ func (h *snapshotHandler) deleteSnapshotBroadcast(ctx context.Context, name stri if err != nil { continue } - if _, err := agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{Name: name})); err != nil { + if _, err := agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{ + TeamId: formatUUIDForRPC(teamID), + TemplateId: formatUUIDForRPC(templateID), + })); err != nil { if connect.CodeOf(err) != connect.CodeNotFound { - slog.Warn("snapshot: failed to delete on host", "host_id", host.ID, "name", name, "error", err) + slog.Warn("snapshot: failed to delete on host", "host_id", id.FormatHostID(host.ID), "error", err) } } } @@ -70,6 +75,7 @@ type snapshotResponse struct { MemoryMB *int32 `json:"memory_mb,omitempty"` SizeBytes int64 `json:"size_bytes"` CreatedAt string `json:"created_at"` + Platform bool `json:"platform"` } func templateToResponse(t db.Template) snapshotResponse { @@ -77,12 +83,13 @@ func templateToResponse(t db.Template) snapshotResponse { Name: t.Name, Type: t.Type, SizeBytes: t.SizeBytes, + Platform: t.TeamID == id.PlatformTeamID, } - if t.Vcpus.Valid { - resp.VCPUs = &t.Vcpus.Int32 + if t.Vcpus != 0 { + resp.VCPUs = &t.Vcpus } - if t.MemoryMb.Valid { - resp.MemoryMB = &t.MemoryMb.Int32 + if t.MemoryMb != 0 { + resp.MemoryMB = &t.MemoryMb } if t.CreatedAt.Valid { resp.CreatedAt = t.CreatedAt.Time.Format(time.RFC3339) @@ -103,6 +110,12 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) { 
return } + sandboxID, err := id.ParseSandboxID(req.SandboxID) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox_id") + return + } + if req.Name == "" { req.Name = id.NewSnapshotName() } @@ -115,14 +128,20 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) { ac := auth.MustFromContext(ctx) overwrite := r.URL.Query().Get("overwrite") == "true" + // Check for global name collision. + if _, err := h.db.GetPlatformTemplateByName(ctx, req.Name); err == nil { + writeError(w, http.StatusConflict, "name_reserved", "template name is reserved by a global template") + return + } + // Check if name already exists for this team. - if _, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err == nil { + if existing, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err == nil { if !overwrite { writeError(w, http.StatusConflict, "already_exists", "snapshot name already exists; use ?overwrite=true to replace") return } // Delete old snapshot files from all hosts before removing the DB record. - if err := h.deleteSnapshotBroadcast(ctx, req.Name); err != nil { + if err := h.deleteSnapshotBroadcast(ctx, existing.TeamID, existing.ID); err != nil { writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete existing snapshot files") return } @@ -133,7 +152,7 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) { } // Verify sandbox exists, belongs to team, and is running or paused. 
- sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: req.SandboxID, TeamID: ac.TeamID}) + sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID}) if err != nil { writeError(w, http.StatusNotFound, "not_found", "sandbox not found") return @@ -149,30 +168,53 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) { return } - resp, err := agent.CreateSnapshot(ctx, connect.NewRequest(&pb.CreateSnapshotRequest{ - SandboxId: req.SandboxID, - Name: req.Name, + // Pre-mark sandbox as "paused" in DB BEFORE issuing the snapshot RPC. + // The host agent's CreateSnapshot removes the sandbox from its in-memory + // map immediately; if the reconciler fires during the flatten window and + // the DB still says "running", it will mark the sandbox "stopped". + if sb.Status == "running" { + if _, err := h.db.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{ + ID: sandboxID, Status: "paused", + }); err != nil { + writeError(w, http.StatusInternalServerError, "db_error", "failed to update sandbox status") + return + } + } + + // Use a detached context with a generous timeout so the snapshot completes + // even if the client disconnects (the flatten step can take 10-20s). + snapCtx, snapCancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer snapCancel() + + // Generate the new template ID upfront so the host agent knows where to store files. + newTemplateID := id.NewTemplateID() + + resp, err := agent.CreateSnapshot(snapCtx, connect.NewRequest(&pb.CreateSnapshotRequest{ + SandboxId: req.SandboxID, + Name: req.Name, + TeamId: formatUUIDForRPC(ac.TeamID), + TemplateId: formatUUIDForRPC(newTemplateID), })) if err != nil { + // Snapshot failed — revert status back to what it was. 
+ if sb.Status == "running" { + if _, dbErr := h.db.UpdateSandboxStatus(snapCtx, db.UpdateSandboxStatusParams{ + ID: sandboxID, Status: "running", + }); dbErr != nil { + slog.Error("failed to revert sandbox status after snapshot error", "sandbox_id", req.SandboxID, "error", dbErr) + } + } status, code, msg := agentErrToHTTP(err) writeError(w, status, code, msg) return } - // Mark sandbox as paused (if it was running, it got paused by the snapshot). - if sb.Status != "paused" { - if _, err := h.db.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{ - ID: req.SandboxID, Status: "paused", - }); err != nil { - slog.Error("failed to update sandbox status after snapshot", "sandbox_id", req.SandboxID, "error", err) - } - } - - tmpl, err := h.db.InsertTemplate(ctx, db.InsertTemplateParams{ + tmpl, err := h.db.InsertTemplate(snapCtx, db.InsertTemplateParams{ + ID: newTemplateID, Name: req.Name, Type: "snapshot", - Vcpus: pgtype.Int4{Int32: sb.Vcpus, Valid: true}, - MemoryMb: pgtype.Int4{Int32: sb.MemoryMb, Valid: true}, + Vcpus: sb.Vcpus, + MemoryMb: sb.MemoryMb, SizeBytes: resp.Msg.SizeBytes, TeamID: ac.TeamID, }) @@ -182,7 +224,12 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) { return } - h.audit.LogSnapshotCreate(r.Context(), ac, req.Name) + h.audit.LogSnapshotCreate(snapCtx, ac, req.Name) + + if ctx.Err() != nil { + slog.Info("snapshot created but client disconnected before response", "name", req.Name) + return + } writeJSON(w, http.StatusCreated, templateToResponse(tmpl)) } @@ -215,12 +262,22 @@ func (h *snapshotHandler) Delete(w http.ResponseWriter, r *http.Request) { ctx := r.Context() ac := auth.MustFromContext(ctx) - if _, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: name, TeamID: ac.TeamID}); err != nil { + tmpl, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: name, TeamID: ac.TeamID}) + if err != nil { writeError(w, http.StatusNotFound, "not_found", "template not found") return } + // 
Platform templates can only be deleted by admins via /v1/admin/templates. + if tmpl.TeamID == id.PlatformTeamID { + writeError(w, http.StatusForbidden, "forbidden", "platform templates cannot be deleted here") + return + } + if layout.IsMinimal(tmpl.TeamID, tmpl.ID) { + writeError(w, http.StatusForbidden, "forbidden", "the minimal template cannot be deleted") + return + } - if err := h.deleteSnapshotBroadcast(ctx, name); err != nil { + if err := h.deleteSnapshotBroadcast(ctx, tmpl.TeamID, tmpl.ID); err != nil { writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete snapshot files") return } diff --git a/internal/api/handlers_team.go b/internal/api/handlers_team.go index 3950ab73..2bf99f9a 100644 --- a/internal/api/handlers_team.go +++ b/internal/api/handlers_team.go @@ -7,10 +7,12 @@ import ( "time" "github.com/go-chi/chi/v5" + "github.com/jackc/pgx/v5/pgtype" "git.omukk.dev/wrenn/sandbox/internal/audit" "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/id" "git.omukk.dev/wrenn/sandbox/internal/service" ) @@ -48,7 +50,7 @@ type memberResponse struct { func teamToResponse(t db.Team) teamResponse { resp := teamResponse{ - ID: t.ID, + ID: id.FormatTeamID(t.ID), Name: t.Name, Slug: t.Slug, IsByoc: t.IsByoc, @@ -72,11 +74,16 @@ func memberInfoToResponse(m service.MemberInfo) memberResponse { // requireTeamAccess is an inline check used by every team-scoped handler: // the JWT team_id must match the URL {id} before any DB call is made. // Returns false and writes 403 if they don't match. 
-func requireTeamAccess(w http.ResponseWriter, r *http.Request, ac auth.AuthContext) (string, bool) { - teamID := chi.URLParam(r, "id") +func requireTeamAccess(w http.ResponseWriter, r *http.Request, ac auth.AuthContext) (pgtype.UUID, bool) { + teamIDStr := chi.URLParam(r, "id") + teamID, err := id.ParseTeamID(teamIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid team ID") + return pgtype.UUID{}, false + } if ac.TeamID != teamID { writeError(w, http.StatusForbidden, "forbidden", "JWT team does not match requested team; use switch-team first") - return "", false + return pgtype.UUID{}, false } return teamID, true } @@ -185,7 +192,7 @@ func (h *teamHandler) Rename(w http.ResponseWriter, r *http.Request) { // Fetch old name for audit log before renaming. oldTeam, err := h.svc.GetTeam(r.Context(), teamID) if err != nil { - slog.Warn("audit: could not fetch old team name for rename log", "team_id", teamID, "error", err) + slog.Warn("audit: could not fetch old team name for rename log", "team_id", id.FormatTeamID(teamID), "error", err) } if err := h.svc.RenameTeam(r.Context(), teamID, ac.UserID, req.Name); err != nil { @@ -267,7 +274,11 @@ func (h *teamHandler) AddMember(w http.ResponseWriter, r *http.Request) { return } - h.audit.LogMemberAdd(r.Context(), ac, member.UserID, member.Email, member.Role) + // member.UserID is already formatted with prefix; parse it back for the audit logger. 
+ targetUserID, parseErr := id.ParseUserID(member.UserID) + if parseErr == nil { + h.audit.LogMemberAdd(r.Context(), ac, targetUserID, member.Email, member.Role) + } writeJSON(w, http.StatusCreated, memberInfoToResponse(member)) } @@ -279,7 +290,13 @@ func (h *teamHandler) RemoveMember(w http.ResponseWriter, r *http.Request) { if !ok { return } - targetUserID := chi.URLParam(r, "uid") + targetUserIDStr := chi.URLParam(r, "uid") + + targetUserID, err := id.ParseUserID(targetUserIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid user ID") + return + } if err := h.svc.RemoveMember(r.Context(), teamID, ac.UserID, targetUserID); err != nil { status, code, msg := serviceErrToHTTP(err) @@ -299,7 +316,13 @@ func (h *teamHandler) UpdateMemberRole(w http.ResponseWriter, r *http.Request) { if !ok { return } - targetUserID := chi.URLParam(r, "uid") + targetUserIDStr := chi.URLParam(r, "uid") + + targetUserID, err := id.ParseUserID(targetUserIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid user ID") + return + } var req struct { Role string `json:"role"` @@ -341,7 +364,13 @@ func (h *teamHandler) Leave(w http.ResponseWriter, r *http.Request) { // SetBYOC handles PUT /v1/admin/teams/{id}/byoc (admin only). // Enables or disables the BYOC feature flag for a team. 
func (h *teamHandler) SetBYOC(w http.ResponseWriter, r *http.Request) { - teamID := chi.URLParam(r, "id") + teamIDStr := chi.URLParam(r, "id") + + teamID, err := id.ParseTeamID(teamIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid team ID") + return + } var req struct { Enabled bool `json:"enabled"` diff --git a/internal/api/handlers_users.go b/internal/api/handlers_users.go index 8269d3c0..17050641 100644 --- a/internal/api/handlers_users.go +++ b/internal/api/handlers_users.go @@ -8,6 +8,7 @@ import ( "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/id" ) type usersHandler struct { @@ -45,7 +46,7 @@ func (h *usersHandler) Search(w http.ResponseWriter, r *http.Request) { } resp := make([]userResult, len(results)) for i, u := range results { - resp[i] = userResult{UserID: u.ID, Email: u.Email} + resp[i] = userResult{UserID: id.FormatUserID(u.ID), Email: u.Email} } writeJSON(w, http.StatusOK, resp) } diff --git a/internal/api/host_monitor.go b/internal/api/host_monitor.go index 4bf19d87..95fde106 100644 --- a/internal/api/host_monitor.go +++ b/internal/api/host_monitor.go @@ -6,9 +6,11 @@ import ( "time" "connectrpc.com/connect" + "github.com/jackc/pgx/v5/pgtype" "git.omukk.dev/wrenn/sandbox/internal/audit" "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/id" "git.omukk.dev/wrenn/sandbox/internal/lifecycle" pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen" ) @@ -82,15 +84,15 @@ func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) { time.Since(host.LastHeartbeatAt.Time) > unreachableThreshold if stale && host.Status != "unreachable" { - slog.Info("host monitor: marking host unreachable", "host_id", host.ID, + slog.Info("host monitor: marking host unreachable", "host_id", id.FormatHostID(host.ID), "last_heartbeat", host.LastHeartbeatAt.Time) if err := m.db.MarkHostUnreachable(ctx, host.ID); err != nil 
{ - slog.Warn("host monitor: failed to mark host unreachable", "host_id", host.ID, "error", err) + slog.Warn("host monitor: failed to mark host unreachable", "host_id", id.FormatHostID(host.ID), "error", err) } if err := m.db.MarkSandboxesMissingByHost(ctx, host.ID); err != nil { - slog.Warn("host monitor: failed to mark sandboxes missing", "host_id", host.ID, "error", err) + slog.Warn("host monitor: failed to mark sandboxes missing", "host_id", id.FormatHostID(host.ID), "error", err) } - m.audit.LogHostMarkedDown(ctx, host.TeamID.String, host.ID) + m.audit.LogHostMarkedDown(ctx, host.TeamID, host.ID) return } @@ -110,19 +112,20 @@ func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) { if err != nil { // RPC failure is a transient condition; the passive phase will catch it // if heartbeats stop arriving. - slog.Debug("host monitor: ListSandboxes failed (transient)", "host_id", host.ID, "error", err) + slog.Debug("host monitor: ListSandboxes failed (transient)", "host_id", id.FormatHostID(host.ID), "error", err) return } // Build set of sandbox IDs alive on the host. + // The host agent returns sandbox IDs as strings (formatted with prefix). 
alive := make(map[string]struct{}, len(resp.Msg.Sandboxes)) for _, sb := range resp.Msg.Sandboxes { alive[sb.SandboxId] = struct{}{} } autoPaused := make(map[string]struct{}, len(resp.Msg.AutoPausedSandboxIds)) - for _, id := range resp.Msg.AutoPausedSandboxIds { - autoPaused[id] = struct{}{} + for _, apID := range resp.Msg.AutoPausedSandboxIds { + autoPaused[apID] = struct{}{} } // --- Restore sandboxes that are "missing" in DB but alive on host --- @@ -134,30 +137,31 @@ func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) { Column2: []string{"missing"}, }) if err != nil { - slog.Warn("host monitor: failed to list missing sandboxes", "host_id", host.ID, "error", err) + slog.Warn("host monitor: failed to list missing sandboxes", "host_id", id.FormatHostID(host.ID), "error", err) } else { - var toRestore []string - var toStop []string + var toRestore []pgtype.UUID + var toStop []pgtype.UUID for _, sb := range missingSandboxes { - if _, ok := alive[sb.ID]; ok { + sbIDStr := id.FormatSandboxID(sb.ID) + if _, ok := alive[sbIDStr]; ok { toRestore = append(toRestore, sb.ID) } else { toStop = append(toStop, sb.ID) } } if len(toRestore) > 0 { - slog.Info("host monitor: restoring missing sandboxes", "host_id", host.ID, "count", len(toRestore)) + slog.Info("host monitor: restoring missing sandboxes", "host_id", id.FormatHostID(host.ID), "count", len(toRestore)) if err := m.db.BulkRestoreRunning(ctx, toRestore); err != nil { - slog.Warn("host monitor: failed to restore missing sandboxes", "host_id", host.ID, "error", err) + slog.Warn("host monitor: failed to restore missing sandboxes", "host_id", id.FormatHostID(host.ID), "error", err) } } if len(toStop) > 0 { - slog.Info("host monitor: stopping confirmed-dead missing sandboxes", "host_id", host.ID, "count", len(toStop)) + slog.Info("host monitor: stopping confirmed-dead missing sandboxes", "host_id", id.FormatHostID(host.ID), "count", len(toStop)) if err := m.db.BulkUpdateStatusByIDs(ctx, 
db.BulkUpdateStatusByIDsParams{ Column1: toStop, Status: "stopped", }); err != nil { - slog.Warn("host monitor: failed to stop missing sandboxes", "host_id", host.ID, "error", err) + slog.Warn("host monitor: failed to stop missing sandboxes", "host_id", id.FormatHostID(host.ID), "error", err) } } } @@ -169,18 +173,19 @@ func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) { Column2: []string{"running"}, }) if err != nil { - slog.Warn("host monitor: failed to list running sandboxes", "host_id", host.ID, "error", err) + slog.Warn("host monitor: failed to list running sandboxes", "host_id", id.FormatHostID(host.ID), "error", err) return } - var toPause, toStop []string - sbTeamID := make(map[string]string, len(runningSandboxes)) + var toPause, toStop []pgtype.UUID + sbTeamID := make(map[pgtype.UUID]pgtype.UUID, len(runningSandboxes)) for _, sb := range runningSandboxes { + sbIDStr := id.FormatSandboxID(sb.ID) sbTeamID[sb.ID] = sb.TeamID - if _, ok := alive[sb.ID]; ok { + if _, ok := alive[sbIDStr]; ok { continue } - if _, ok := autoPaused[sb.ID]; ok { + if _, ok := autoPaused[sbIDStr]; ok { toPause = append(toPause, sb.ID) } else { toStop = append(toStop, sb.ID) @@ -188,24 +193,24 @@ func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) { } if len(toPause) > 0 { - slog.Info("host monitor: marking auto-paused sandboxes", "host_id", host.ID, "count", len(toPause)) + slog.Info("host monitor: marking auto-paused sandboxes", "host_id", id.FormatHostID(host.ID), "count", len(toPause)) if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{ Column1: toPause, Status: "paused", }); err != nil { - slog.Warn("host monitor: failed to mark paused", "host_id", host.ID, "error", err) + slog.Warn("host monitor: failed to mark paused", "host_id", id.FormatHostID(host.ID), "error", err) } for _, sbID := range toPause { m.audit.LogSandboxAutoPause(ctx, sbTeamID[sbID], sbID) } } if len(toStop) > 0 { - slog.Info("host monitor: marking orphaned 
sandboxes stopped", "host_id", host.ID, "count", len(toStop)) + slog.Info("host monitor: marking orphaned sandboxes stopped", "host_id", id.FormatHostID(host.ID), "count", len(toStop)) if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{ Column1: toStop, Status: "stopped", }); err != nil { - slog.Warn("host monitor: failed to mark stopped", "host_id", host.ID, "error", err) + slog.Warn("host monitor: failed to mark stopped", "host_id", id.FormatHostID(host.ID), "error", err) } } } diff --git a/internal/api/middleware.go b/internal/api/middleware.go index 6a56293c..5c9d8cdf 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -12,6 +12,9 @@ import ( "time" "connectrpc.com/connect" + "github.com/jackc/pgx/v5/pgtype" + + "git.omukk.dev/wrenn/sandbox/internal/id" ) type errorResponse struct { @@ -35,6 +38,11 @@ func writeError(w http.ResponseWriter, status int, code, message string) { }) } +// formatUUIDForRPC converts a pgtype.UUID to a hex string for RPC messages. +func formatUUIDForRPC(u pgtype.UUID) string { + return id.UUIDString(u) +} + // agentErrToHTTP maps a Connect RPC error to an HTTP status, error code, and message. func agentErrToHTTP(err error) (int, string, string) { switch connect.CodeOf(err) { diff --git a/internal/api/middleware_auth.go b/internal/api/middleware_auth.go index dee4240f..985b2899 100644 --- a/internal/api/middleware_auth.go +++ b/internal/api/middleware_auth.go @@ -7,6 +7,7 @@ import ( "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/id" ) // requireAPIKeyOrJWT accepts either X-API-Key header or Authorization: Bearer JWT. 
@@ -24,7 +25,7 @@ func requireAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler } if err := queries.UpdateAPIKeyLastUsed(r.Context(), row.ID); err != nil { - slog.Warn("failed to update api key last_used", "key_id", row.ID, "error", err) + slog.Warn("failed to update api key last_used", "key_id", id.FormatAPIKeyID(row.ID), "error", err) } ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{ @@ -45,9 +46,20 @@ func requireAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler return } + teamID, err := id.ParseTeamID(claims.TeamID) + if err != nil { + writeError(w, http.StatusUnauthorized, "unauthorized", "invalid team ID in token") + return + } + userID, err := id.ParseUserID(claims.Subject) + if err != nil { + writeError(w, http.StatusUnauthorized, "unauthorized", "invalid user ID in token") + return + } + ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{ - TeamID: claims.TeamID, - UserID: claims.Subject, + TeamID: teamID, + UserID: userID, Email: claims.Email, Name: claims.Name, Role: claims.Role, diff --git a/internal/api/middleware_hosttoken.go b/internal/api/middleware_hosttoken.go index a5c5e6f1..b926e41f 100644 --- a/internal/api/middleware_hosttoken.go +++ b/internal/api/middleware_hosttoken.go @@ -4,6 +4,7 @@ import ( "net/http" "git.omukk.dev/wrenn/sandbox/internal/auth" + "git.omukk.dev/wrenn/sandbox/internal/id" ) // requireHostToken validates the X-Host-Token header containing a host JWT, @@ -23,7 +24,13 @@ func requireHostToken(secret []byte) func(http.Handler) http.Handler { return } - ctx := auth.WithHostContext(r.Context(), auth.HostContext{HostID: claims.HostID}) + hostID, err := id.ParseHostID(claims.HostID) + if err != nil { + writeError(w, http.StatusUnauthorized, "unauthorized", "invalid host ID in token") + return + } + + ctx := auth.WithHostContext(r.Context(), auth.HostContext{HostID: hostID}) next.ServeHTTP(w, r.WithContext(ctx)) }) } diff --git a/internal/api/middleware_jwt.go 
b/internal/api/middleware_jwt.go index 96b1c68a..37215388 100644 --- a/internal/api/middleware_jwt.go +++ b/internal/api/middleware_jwt.go @@ -5,6 +5,7 @@ import ( "strings" "git.omukk.dev/wrenn/sandbox/internal/auth" + "git.omukk.dev/wrenn/sandbox/internal/id" ) // requireJWT validates the Authorization: Bearer header, verifies the JWT @@ -25,9 +26,20 @@ func requireJWT(secret []byte) func(http.Handler) http.Handler { return } + teamID, err := id.ParseTeamID(claims.TeamID) + if err != nil { + writeError(w, http.StatusUnauthorized, "unauthorized", "invalid team ID in token") + return + } + userID, err := id.ParseUserID(claims.Subject) + if err != nil { + writeError(w, http.StatusUnauthorized, "unauthorized", "invalid user ID in token") + return + } + ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{ - TeamID: claims.TeamID, - UserID: claims.Subject, + TeamID: teamID, + UserID: userID, Email: claims.Email, Name: claims.Name, Role: claims.Role, diff --git a/internal/api/server.go b/internal/api/server.go index 918476bf..5d854b90 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -10,6 +10,7 @@ import ( "github.com/redis/go-redis/v9" "git.omukk.dev/wrenn/sandbox/internal/audit" + "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/auth/oauth" "git.omukk.dev/wrenn/sandbox/internal/db" "git.omukk.dev/wrenn/sandbox/internal/lifecycle" @@ -22,7 +23,8 @@ var openapiYAML []byte // Server is the control plane HTTP server. type Server struct { - router chi.Router + router chi.Router + BuildSvc *service.BuildService } // New constructs the chi router and registers all routes. 
@@ -35,6 +37,7 @@ func New( jwtSecret []byte, oauthRegistry *oauth.Registry, oauthRedirectURL string, + ca *auth.CA, ) *Server { r := chi.NewRouter() r.Use(requestLogger()) @@ -43,10 +46,11 @@ func New( sandboxSvc := &service.SandboxService{DB: queries, Pool: pool, Scheduler: sched} apiKeySvc := &service.APIKeyService{DB: queries} templateSvc := &service.TemplateService{DB: queries} - hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret, Pool: pool} + hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret, Pool: pool, CA: ca} teamSvc := &service.TeamService{DB: queries, Pool: pgPool, HostPool: pool} auditSvc := &service.AuditService{DB: queries} statsSvc := &service.StatsService{DB: queries, Pool: pgPool} + buildSvc := &service.BuildService{DB: queries, Redis: rdb, Pool: pool, Scheduler: sched} al := audit.New(queries) @@ -65,6 +69,7 @@ func New( auditH := newAuditHandler(auditSvc) statsH := newStatsHandler(statsSvc) metricsH := newSandboxMetricsHandler(queries, pool) + buildH := newBuildHandler(buildSvc, queries, pool) // OpenAPI spec and docs. r.Get("/openapi.yaml", serveOpenAPI) @@ -174,9 +179,15 @@ func New( r.Use(requireJWT(jwtSecret)) r.Use(requireAdmin(queries)) r.Put("/teams/{id}/byoc", teamH.SetBYOC) + r.Get("/templates", buildH.ListTemplates) + r.Delete("/templates/{name}", buildH.DeleteTemplate) + r.Post("/builds", buildH.Create) + r.Get("/builds", buildH.List) + r.Get("/builds/{id}", buildH.Get) + r.Post("/builds/{id}/cancel", buildH.Cancel) }) - return &Server{router: r} + return &Server{router: r, BuildSvc: buildSvc} } // Handler returns the HTTP handler. diff --git a/internal/audit/logger.go b/internal/audit/logger.go index 8f44059f..e60d8a6c 100644 --- a/internal/audit/logger.go +++ b/internal/audit/logger.go @@ -25,18 +25,15 @@ func New(queries *db.Queries) *AuditLogger { } // actorFields extracts actor_type, actor_id, and actor_name from an AuthContext. 
-func actorFields(ac auth.AuthContext) (actorType string, actorID pgtype.Text, actorName pgtype.Text) { - if ac.UserID != "" { - return "user", - pgtype.Text{String: ac.UserID, Valid: true}, - pgtype.Text{String: ac.Name, Valid: ac.Name != ""} +// actor_id is stored as a prefixed string in the TEXT column. +func actorFields(ac auth.AuthContext) (actorType, actorID, actorName string) { + if ac.UserID.Valid { + return "user", id.FormatUserID(ac.UserID), ac.Name } - if ac.APIKeyID != "" { - return "api_key", - pgtype.Text{String: ac.APIKeyID, Valid: true}, - pgtype.Text{String: ac.APIKeyName, Valid: true} + if ac.APIKeyID.Valid { + return "api_key", id.FormatAPIKeyID(ac.APIKeyID), ac.APIKeyName } - return "system", pgtype.Text{}, pgtype.Text{} + return "system", "", "" } func (l *AuditLogger) write(ctx context.Context, p db.InsertAuditLogParams) { @@ -44,7 +41,6 @@ func (l *AuditLogger) write(ctx context.Context, p db.InsertAuditLogParams) { slog.Warn("audit: failed to write log entry", "action", p.Action, "resource_type", p.ResourceType, - "team_id", p.TeamID, "error", err, ) } @@ -61,18 +57,26 @@ func marshalMeta(meta map[string]any) []byte { return b } +// optText returns a valid pgtype.Text if s is non-empty, otherwise an invalid (NULL) one. 
+func optText(s string) pgtype.Text { + if s == "" { + return pgtype.Text{} + } + return pgtype.Text{String: s, Valid: true} +} + // --- Sandbox events (scope: team) --- -func (l *AuditLogger) LogSandboxCreate(ctx context.Context, ac auth.AuthContext, sandboxID, template string) { +func (l *AuditLogger) LogSandboxCreate(ctx context.Context, ac auth.AuthContext, sandboxID pgtype.UUID, template string) { actorType, actorID, actorName := actorFields(ac) l.write(ctx, db.InsertAuditLogParams{ ID: id.NewAuditLogID(), TeamID: ac.TeamID, ActorType: actorType, - ActorID: actorID, + ActorID: optText(actorID), ActorName: actorName, ResourceType: "sandbox", - ResourceID: pgtype.Text{String: sandboxID, Valid: true}, + ResourceID: optText(id.FormatSandboxID(sandboxID)), Action: "create", Scope: "team", Status: "success", @@ -80,16 +84,16 @@ func (l *AuditLogger) LogSandboxCreate(ctx context.Context, ac auth.AuthContext, }) } -func (l *AuditLogger) LogSandboxPause(ctx context.Context, ac auth.AuthContext, sandboxID string) { +func (l *AuditLogger) LogSandboxPause(ctx context.Context, ac auth.AuthContext, sandboxID pgtype.UUID) { actorType, actorID, actorName := actorFields(ac) l.write(ctx, db.InsertAuditLogParams{ ID: id.NewAuditLogID(), TeamID: ac.TeamID, ActorType: actorType, - ActorID: actorID, + ActorID: optText(actorID), ActorName: actorName, ResourceType: "sandbox", - ResourceID: pgtype.Text{String: sandboxID, Valid: true}, + ResourceID: optText(id.FormatSandboxID(sandboxID)), Action: "pause", Scope: "team", Status: "success", @@ -98,15 +102,15 @@ func (l *AuditLogger) LogSandboxPause(ctx context.Context, ac auth.AuthContext, } // LogSandboxAutoPause records a system-initiated auto-pause (TTL or host reconciler). 
-func (l *AuditLogger) LogSandboxAutoPause(ctx context.Context, teamID, sandboxID string) { +func (l *AuditLogger) LogSandboxAutoPause(ctx context.Context, teamID, sandboxID pgtype.UUID) { l.write(ctx, db.InsertAuditLogParams{ ID: id.NewAuditLogID(), TeamID: teamID, ActorType: "system", ActorID: pgtype.Text{}, - ActorName: pgtype.Text{}, + ActorName: "", ResourceType: "sandbox", - ResourceID: pgtype.Text{String: sandboxID, Valid: true}, + ResourceID: optText(id.FormatSandboxID(sandboxID)), Action: "pause", Scope: "team", Status: "info", @@ -114,16 +118,16 @@ func (l *AuditLogger) LogSandboxAutoPause(ctx context.Context, teamID, sandboxID }) } -func (l *AuditLogger) LogSandboxResume(ctx context.Context, ac auth.AuthContext, sandboxID string) { +func (l *AuditLogger) LogSandboxResume(ctx context.Context, ac auth.AuthContext, sandboxID pgtype.UUID) { actorType, actorID, actorName := actorFields(ac) l.write(ctx, db.InsertAuditLogParams{ ID: id.NewAuditLogID(), TeamID: ac.TeamID, ActorType: actorType, - ActorID: actorID, + ActorID: optText(actorID), ActorName: actorName, ResourceType: "sandbox", - ResourceID: pgtype.Text{String: sandboxID, Valid: true}, + ResourceID: optText(id.FormatSandboxID(sandboxID)), Action: "resume", Scope: "team", Status: "success", @@ -131,16 +135,16 @@ func (l *AuditLogger) LogSandboxResume(ctx context.Context, ac auth.AuthContext, }) } -func (l *AuditLogger) LogSandboxDestroy(ctx context.Context, ac auth.AuthContext, sandboxID string) { +func (l *AuditLogger) LogSandboxDestroy(ctx context.Context, ac auth.AuthContext, sandboxID pgtype.UUID) { actorType, actorID, actorName := actorFields(ac) l.write(ctx, db.InsertAuditLogParams{ ID: id.NewAuditLogID(), TeamID: ac.TeamID, ActorType: actorType, - ActorID: actorID, + ActorID: optText(actorID), ActorName: actorName, ResourceType: "sandbox", - ResourceID: pgtype.Text{String: sandboxID, Valid: true}, + ResourceID: optText(id.FormatSandboxID(sandboxID)), Action: "destroy", Scope: "team", Status: 
"warning", @@ -156,10 +160,10 @@ func (l *AuditLogger) LogSnapshotCreate(ctx context.Context, ac auth.AuthContext ID: id.NewAuditLogID(), TeamID: ac.TeamID, ActorType: actorType, - ActorID: actorID, + ActorID: optText(actorID), ActorName: actorName, ResourceType: "snapshot", - ResourceID: pgtype.Text{String: name, Valid: true}, + ResourceID: optText(name), Action: "create", Scope: "team", Status: "success", @@ -173,10 +177,10 @@ func (l *AuditLogger) LogSnapshotDelete(ctx context.Context, ac auth.AuthContext ID: id.NewAuditLogID(), TeamID: ac.TeamID, ActorType: actorType, - ActorID: actorID, + ActorID: optText(actorID), ActorName: actorName, ResourceType: "snapshot", - ResourceID: pgtype.Text{String: name, Valid: true}, + ResourceID: optText(name), Action: "delete", Scope: "team", Status: "warning", @@ -186,16 +190,16 @@ func (l *AuditLogger) LogSnapshotDelete(ctx context.Context, ac auth.AuthContext // --- Team events (scope: team) --- -func (l *AuditLogger) LogTeamRename(ctx context.Context, ac auth.AuthContext, teamID, oldName, newName string) { +func (l *AuditLogger) LogTeamRename(ctx context.Context, ac auth.AuthContext, teamID pgtype.UUID, oldName, newName string) { actorType, actorID, actorName := actorFields(ac) l.write(ctx, db.InsertAuditLogParams{ ID: id.NewAuditLogID(), TeamID: ac.TeamID, ActorType: actorType, - ActorID: actorID, + ActorID: optText(actorID), ActorName: actorName, ResourceType: "team", - ResourceID: pgtype.Text{String: teamID, Valid: true}, + ResourceID: optText(id.FormatTeamID(teamID)), Action: "rename", Scope: "team", Status: "info", @@ -205,16 +209,16 @@ func (l *AuditLogger) LogTeamRename(ctx context.Context, ac auth.AuthContext, te // --- API key events (scope: team) --- -func (l *AuditLogger) LogAPIKeyCreate(ctx context.Context, ac auth.AuthContext, keyID, keyName string) { +func (l *AuditLogger) LogAPIKeyCreate(ctx context.Context, ac auth.AuthContext, keyID pgtype.UUID, keyName string) { actorType, actorID, actorName := 
actorFields(ac) l.write(ctx, db.InsertAuditLogParams{ ID: id.NewAuditLogID(), TeamID: ac.TeamID, ActorType: actorType, - ActorID: actorID, + ActorID: optText(actorID), ActorName: actorName, ResourceType: "api_key", - ResourceID: pgtype.Text{String: keyID, Valid: true}, + ResourceID: optText(id.FormatAPIKeyID(keyID)), Action: "create", Scope: "team", Status: "success", @@ -222,16 +226,16 @@ func (l *AuditLogger) LogAPIKeyCreate(ctx context.Context, ac auth.AuthContext, }) } -func (l *AuditLogger) LogAPIKeyRevoke(ctx context.Context, ac auth.AuthContext, keyID string) { +func (l *AuditLogger) LogAPIKeyRevoke(ctx context.Context, ac auth.AuthContext, keyID pgtype.UUID) { actorType, actorID, actorName := actorFields(ac) l.write(ctx, db.InsertAuditLogParams{ ID: id.NewAuditLogID(), TeamID: ac.TeamID, ActorType: actorType, - ActorID: actorID, + ActorID: optText(actorID), ActorName: actorName, ResourceType: "api_key", - ResourceID: pgtype.Text{String: keyID, Valid: true}, + ResourceID: optText(id.FormatAPIKeyID(keyID)), Action: "revoke", Scope: "team", Status: "warning", @@ -241,16 +245,16 @@ func (l *AuditLogger) LogAPIKeyRevoke(ctx context.Context, ac auth.AuthContext, // --- Member events (scope: admin) --- -func (l *AuditLogger) LogMemberAdd(ctx context.Context, ac auth.AuthContext, targetUserID, targetEmail, role string) { +func (l *AuditLogger) LogMemberAdd(ctx context.Context, ac auth.AuthContext, targetUserID pgtype.UUID, targetEmail, role string) { actorType, actorID, actorName := actorFields(ac) l.write(ctx, db.InsertAuditLogParams{ ID: id.NewAuditLogID(), TeamID: ac.TeamID, ActorType: actorType, - ActorID: actorID, + ActorID: optText(actorID), ActorName: actorName, ResourceType: "member", - ResourceID: pgtype.Text{String: targetUserID, Valid: true}, + ResourceID: optText(id.FormatUserID(targetUserID)), Action: "add", Scope: "admin", Status: "success", @@ -258,16 +262,16 @@ func (l *AuditLogger) LogMemberAdd(ctx context.Context, ac auth.AuthContext, tar }) } 
-func (l *AuditLogger) LogMemberRemove(ctx context.Context, ac auth.AuthContext, targetUserID string) { +func (l *AuditLogger) LogMemberRemove(ctx context.Context, ac auth.AuthContext, targetUserID pgtype.UUID) { actorType, actorID, actorName := actorFields(ac) l.write(ctx, db.InsertAuditLogParams{ ID: id.NewAuditLogID(), TeamID: ac.TeamID, ActorType: actorType, - ActorID: actorID, + ActorID: optText(actorID), ActorName: actorName, ResourceType: "member", - ResourceID: pgtype.Text{String: targetUserID, Valid: true}, + ResourceID: optText(id.FormatUserID(targetUserID)), Action: "remove", Scope: "admin", Status: "warning", @@ -277,14 +281,18 @@ func (l *AuditLogger) LogMemberRemove(ctx context.Context, ac auth.AuthContext, func (l *AuditLogger) LogMemberLeave(ctx context.Context, ac auth.AuthContext) { actorType, actorID, actorName := actorFields(ac) + resourceID := "" + if ac.UserID.Valid { + resourceID = id.FormatUserID(ac.UserID) + } l.write(ctx, db.InsertAuditLogParams{ ID: id.NewAuditLogID(), TeamID: ac.TeamID, ActorType: actorType, - ActorID: actorID, + ActorID: optText(actorID), ActorName: actorName, ResourceType: "member", - ResourceID: pgtype.Text{String: ac.UserID, Valid: ac.UserID != ""}, + ResourceID: optText(resourceID), Action: "leave", Scope: "admin", Status: "info", @@ -292,16 +300,16 @@ func (l *AuditLogger) LogMemberLeave(ctx context.Context, ac auth.AuthContext) { }) } -func (l *AuditLogger) LogMemberRoleUpdate(ctx context.Context, ac auth.AuthContext, targetUserID, newRole string) { +func (l *AuditLogger) LogMemberRoleUpdate(ctx context.Context, ac auth.AuthContext, targetUserID pgtype.UUID, newRole string) { actorType, actorID, actorName := actorFields(ac) l.write(ctx, db.InsertAuditLogParams{ ID: id.NewAuditLogID(), TeamID: ac.TeamID, ActorType: actorType, - ActorID: actorID, + ActorID: optText(actorID), ActorName: actorName, ResourceType: "member", - ResourceID: pgtype.Text{String: targetUserID, Valid: true}, + ResourceID: 
optText(id.FormatUserID(targetUserID)), Action: "role_update", Scope: "admin", Status: "info", @@ -311,24 +319,24 @@ func (l *AuditLogger) LogMemberRoleUpdate(ctx context.Context, ac auth.AuthConte // --- Host events (scope: admin) --- -func (l *AuditLogger) LogHostCreate(ctx context.Context, ac auth.AuthContext, hostID, teamID string) { +func (l *AuditLogger) LogHostCreate(ctx context.Context, ac auth.AuthContext, hostID, teamID pgtype.UUID) { actorType, actorID, actorName := actorFields(ac) // For shared hosts with no owning team, use the caller's team. logTeamID := teamID - if logTeamID == "" { + if !logTeamID.Valid { logTeamID = ac.TeamID } - if logTeamID == "" { + if !logTeamID.Valid { return } l.write(ctx, db.InsertAuditLogParams{ ID: id.NewAuditLogID(), TeamID: logTeamID, ActorType: actorType, - ActorID: actorID, + ActorID: optText(actorID), ActorName: actorName, ResourceType: "host", - ResourceID: pgtype.Text{String: hostID, Valid: true}, + ResourceID: optText(id.FormatHostID(hostID)), Action: "create", Scope: "admin", Status: "success", @@ -336,23 +344,23 @@ func (l *AuditLogger) LogHostCreate(ctx context.Context, ac auth.AuthContext, ho }) } -func (l *AuditLogger) LogHostDelete(ctx context.Context, ac auth.AuthContext, hostID, teamID string) { +func (l *AuditLogger) LogHostDelete(ctx context.Context, ac auth.AuthContext, hostID, teamID pgtype.UUID) { actorType, actorID, actorName := actorFields(ac) logTeamID := teamID - if logTeamID == "" { + if !logTeamID.Valid { logTeamID = ac.TeamID } - if logTeamID == "" { + if !logTeamID.Valid { return } l.write(ctx, db.InsertAuditLogParams{ ID: id.NewAuditLogID(), TeamID: logTeamID, ActorType: actorType, - ActorID: actorID, + ActorID: optText(actorID), ActorName: actorName, ResourceType: "host", - ResourceID: pgtype.Text{String: hostID, Valid: true}, + ResourceID: optText(id.FormatHostID(hostID)), Action: "delete", Scope: "admin", Status: "warning", @@ -361,9 +369,8 @@ func (l *AuditLogger) LogHostDelete(ctx 
context.Context, ac auth.AuthContext, ho } // LogHostMarkedDown records a system-initiated host status transition to unreachable. -// teamID must be non-empty (BYOC hosts only); shared hosts are not logged. -func (l *AuditLogger) LogHostMarkedDown(ctx context.Context, teamID, hostID string) { - if teamID == "" { +func (l *AuditLogger) LogHostMarkedDown(ctx context.Context, teamID, hostID pgtype.UUID) { + if !teamID.Valid { return } l.write(ctx, db.InsertAuditLogParams{ @@ -371,9 +378,9 @@ func (l *AuditLogger) LogHostMarkedDown(ctx context.Context, teamID, hostID stri TeamID: teamID, ActorType: "system", ActorID: pgtype.Text{}, - ActorName: pgtype.Text{}, + ActorName: "", ResourceType: "host", - ResourceID: pgtype.Text{String: hostID, Valid: true}, + ResourceID: optText(id.FormatHostID(hostID)), Action: "marked_down", Scope: "admin", Status: "error", @@ -382,9 +389,8 @@ func (l *AuditLogger) LogHostMarkedDown(ctx context.Context, teamID, hostID stri } // LogHostMarkedUp records a system-initiated host status transition back to online. -// teamID must be non-empty (BYOC hosts only); shared hosts are not logged. 
-func (l *AuditLogger) LogHostMarkedUp(ctx context.Context, teamID, hostID string) { - if teamID == "" { +func (l *AuditLogger) LogHostMarkedUp(ctx context.Context, teamID, hostID pgtype.UUID) { + if !teamID.Valid { return } l.write(ctx, db.InsertAuditLogParams{ @@ -392,9 +398,9 @@ func (l *AuditLogger) LogHostMarkedUp(ctx context.Context, teamID, hostID string TeamID: teamID, ActorType: "system", ActorID: pgtype.Text{}, - ActorName: pgtype.Text{}, + ActorName: "", ResourceType: "host", - ResourceID: pgtype.Text{String: hostID, Valid: true}, + ResourceID: optText(id.FormatHostID(hostID)), Action: "marked_up", Scope: "admin", Status: "success", diff --git a/internal/auth/cert.go b/internal/auth/cert.go new file mode 100644 index 00000000..d76f1de3 --- /dev/null +++ b/internal/auth/cert.go @@ -0,0 +1,251 @@ +package auth + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "net" + "sync/atomic" + "time" +) + +// CPCertRenewInterval is how often the control plane should renew its client +// certificate. It is set to half the cert TTL so there is always a wide safety +// margin before expiry. +const CPCertRenewInterval = cpCertTTL / 2 + +const ( + hostCertTTL = 7 * 24 * time.Hour + cpCertTTL = 24 * time.Hour +) + +// CA holds a parsed certificate authority ready to issue leaf certificates. +type CA struct { + Cert *x509.Certificate + Key *ecdsa.PrivateKey + PEM string // PEM-encoded certificate for embedding in register/refresh responses +} + +// ParseCA parses PEM-encoded CA certificate and private key strings. +// The cert and key are expected to be ECDSA P-256. 
+func ParseCA(certPEM, keyPEM string) (*CA, error) { + certBlock, _ := pem.Decode([]byte(certPEM)) + if certBlock == nil { + return nil, fmt.Errorf("failed to decode CA certificate PEM") + } + cert, err := x509.ParseCertificate(certBlock.Bytes) + if err != nil { + return nil, fmt.Errorf("parse CA certificate: %w", err) + } + + keyBlock, _ := pem.Decode([]byte(keyPEM)) + if keyBlock == nil { + return nil, fmt.Errorf("failed to decode CA key PEM") + } + keyIface, err := x509.ParseECPrivateKey(keyBlock.Bytes) + if err != nil { + return nil, fmt.Errorf("parse CA private key: %w", err) + } + + return &CA{Cert: cert, Key: keyIface, PEM: certPEM}, nil +} + +// HostCert holds all material returned when issuing a leaf cert for a host agent. +type HostCert struct { + CertPEM string + KeyPEM string + Fingerprint string // hex-encoded SHA-256 of DER bytes, stored in hosts.cert_fingerprint + ExpiresAt time.Time // stored in hosts.cert_expires_at + TLSCert tls.Certificate +} + +// IssueHostCert generates an ECDSA P-256 key pair and issues a 7-day server +// certificate for the host agent. hostID becomes the common name; the host's +// IP address (parsed from hostAddr) is added as an IP SAN so Go's TLS +// stack can verify the connection without disabling hostname checking. 
+func IssueHostCert(ca *CA, hostID, hostAddr string) (HostCert, error) { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return HostCert{}, fmt.Errorf("generate host key: %w", err) + } + + serial, err := randomSerial() + if err != nil { + return HostCert{}, err + } + + now := time.Now() + expires := now.Add(hostCertTTL) + + tmpl := &x509.Certificate{ + SerialNumber: serial, + Subject: pkix.Name{CommonName: hostID}, + NotBefore: now.Add(-time.Minute), // small clock-skew tolerance + NotAfter: expires, + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + } + + // Extract IP from "ip:port" address; fall back to DNS SAN if not parseable. + host, _, err := net.SplitHostPort(hostAddr) + if err != nil { + host = hostAddr + } + if ip := net.ParseIP(host); ip != nil { + tmpl.IPAddresses = []net.IP{ip} + } else { + tmpl.DNSNames = []string{host} + } + + derBytes, err := x509.CreateCertificate(rand.Reader, tmpl, ca.Cert, &key.PublicKey, ca.Key) + if err != nil { + return HostCert{}, fmt.Errorf("create host certificate: %w", err) + } + + certPEM := string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})) + keyDER, err := x509.MarshalECPrivateKey(key) + if err != nil { + return HostCert{}, fmt.Errorf("marshal host key: %w", err) + } + keyPEM := string(pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})) + + tlsCert, err := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM)) + if err != nil { + return HostCert{}, fmt.Errorf("build TLS certificate: %w", err) + } + + fp := fmt.Sprintf("%x", sha256.Sum256(derBytes)) + + return HostCert{ + CertPEM: certPEM, + KeyPEM: keyPEM, + Fingerprint: fp, + ExpiresAt: expires, + TLSCert: tlsCert, + }, nil +} + +// IssueCPClientCert generates a short-lived (24h) ECDSA client certificate for +// the control plane to present during mTLS handshakes with host agents. 
+// Called once at CP startup; the result is embedded into the shared HTTP client. +func IssueCPClientCert(ca *CA) (tls.Certificate, error) { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return tls.Certificate{}, fmt.Errorf("generate CP client key: %w", err) + } + + serial, err := randomSerial() + if err != nil { + return tls.Certificate{}, err + } + + now := time.Now() + tmpl := &x509.Certificate{ + SerialNumber: serial, + Subject: pkix.Name{CommonName: "wrenn-cp"}, + NotBefore: now.Add(-time.Minute), + NotAfter: now.Add(cpCertTTL), + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + } + + derBytes, err := x509.CreateCertificate(rand.Reader, tmpl, ca.Cert, &key.PublicKey, ca.Key) + if err != nil { + return tls.Certificate{}, fmt.Errorf("create CP client certificate: %w", err) + } + + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + keyDER, err := x509.MarshalECPrivateKey(key) + if err != nil { + return tls.Certificate{}, fmt.Errorf("marshal CP client key: %w", err) + } + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}) + + return tls.X509KeyPair(certPEM, keyPEM) +} + +// AgentTLSConfigFromPEM returns a tls.Config for the host agent using the +// PEM-encoded CA certificate. This is used on the agent side where only the +// CA certificate (not the private key) is available. +func AgentTLSConfigFromPEM(caCertPEM string, getCert func(*tls.ClientHelloInfo) (*tls.Certificate, error)) *tls.Config { + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM([]byte(caCertPEM)) { + return nil + } + return &tls.Config{ + ClientAuth: tls.RequireAndVerifyClientCert, + ClientCAs: pool, + GetCertificate: getCert, + MinVersion: tls.VersionTLS13, + } +} + +// CPCertStore provides lock-free read/write access to the control plane's +// current client TLS certificate. 
It is used with tls.Config.GetClientCertificate +// to enable hot-swap without restarting the HTTP client. +// +// The zero value is not usable; use NewCPCertStore to create one. +type CPCertStore struct { + ptr atomic.Pointer[tls.Certificate] + ca *CA +} + +// NewCPCertStore issues an initial CP client certificate from ca and returns a +// store that can renew it in place. Returns an error if the initial issuance fails. +func NewCPCertStore(ca *CA) (*CPCertStore, error) { + s := &CPCertStore{ca: ca} + if err := s.Refresh(); err != nil { + return nil, err + } + return s, nil +} + +// Refresh issues a fresh CP client certificate and atomically stores it. +// If issuance fails the existing cert is unchanged. +func (s *CPCertStore) Refresh() error { + cert, err := IssueCPClientCert(s.ca) + if err != nil { + return fmt.Errorf("renew CP client certificate: %w", err) + } + s.ptr.Store(&cert) + return nil +} + +// GetClientCertificate satisfies tls.Config.GetClientCertificate. It is called +// per-handshake and always returns the most recently stored certificate. +func (s *CPCertStore) GetClientCertificate(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) { + cert := s.ptr.Load() + if cert == nil { + return nil, fmt.Errorf("no CP client certificate available") + } + return cert, nil +} + +// CPClientTLSConfig returns a tls.Config for the CP's outbound HTTP client. +// It uses certStore.GetClientCertificate so the certificate can be renewed +// without replacing the config or transport. +func CPClientTLSConfig(ca *CA, certStore *CPCertStore) *tls.Config { + pool := x509.NewCertPool() + pool.AddCert(ca.Cert) + return &tls.Config{ + RootCAs: pool, + GetClientCertificate: certStore.GetClientCertificate, + MinVersion: tls.VersionTLS13, + } +} + +// randomSerial returns a random 128-bit certificate serial number. 
+func randomSerial() (*big.Int, error) { + serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + if err != nil { + return nil, fmt.Errorf("generate serial number: %w", err) + } + return serial, nil +} diff --git a/internal/auth/context.go b/internal/auth/context.go index 22bf7957..762227cf 100644 --- a/internal/auth/context.go +++ b/internal/auth/context.go @@ -1,6 +1,10 @@ package auth -import "context" +import ( + "context" + + "github.com/jackc/pgx/v5/pgtype" +) type contextKey int @@ -8,14 +12,14 @@ const authCtxKey contextKey = 0 // AuthContext is stamped into request context by auth middleware. type AuthContext struct { - TeamID string - UserID string // empty when authenticated via API key - Email string // empty when authenticated via API key - Name string // empty when authenticated via API key - Role string // owner, admin, or member; empty when authenticated via API key - IsAdmin bool // platform-level admin; always false when authenticated via API key - APIKeyID string // populated when authenticated via API key; empty for JWT auth - APIKeyName string // display name of the key, snapshotted at auth time; empty for JWT auth + TeamID pgtype.UUID + UserID pgtype.UUID // zero value (Valid=false) when authenticated via API key + Email string // empty when authenticated via API key + Name string // empty when authenticated via API key + Role string // owner, admin, or member; empty when authenticated via API key + IsAdmin bool // platform-level admin; always false when authenticated via API key + APIKeyID pgtype.UUID // populated when authenticated via API key; zero value for JWT auth + APIKeyName string // display name of the key, snapshotted at auth time; empty for JWT auth } // WithAuthContext returns a new context with the given AuthContext. @@ -43,7 +47,7 @@ const hostCtxKey contextKey = 1 // HostContext is stamped into request context by host token middleware. 
type HostContext struct { - HostID string + HostID pgtype.UUID } // WithHostContext returns a new context with the given HostContext. diff --git a/internal/auth/jwt.go b/internal/auth/jwt.go index fd1bc02d..840cd3bb 100644 --- a/internal/auth/jwt.go +++ b/internal/auth/jwt.go @@ -5,6 +5,9 @@ import ( "time" "github.com/golang-jwt/jwt/v5" + "github.com/jackc/pgx/v5/pgtype" + + "git.omukk.dev/wrenn/sandbox/internal/id" ) const jwtExpiry = 6 * time.Hour @@ -23,16 +26,16 @@ type Claims struct { } // SignJWT signs a new 6-hour JWT for the given user. -func SignJWT(secret []byte, userID, teamID, email, name, role string, isAdmin bool) (string, error) { +func SignJWT(secret []byte, userID, teamID pgtype.UUID, email, name, role string, isAdmin bool) (string, error) { now := time.Now() claims := Claims{ - TeamID: teamID, + TeamID: id.FormatTeamID(teamID), Role: role, Email: email, Name: name, IsAdmin: isAdmin, RegisteredClaims: jwt.RegisteredClaims{ - Subject: userID, + Subject: id.FormatUserID(userID), IssuedAt: jwt.NewNumericDate(now), ExpiresAt: jwt.NewNumericDate(now.Add(jwtExpiry)), }, @@ -70,14 +73,15 @@ type HostClaims struct { jwt.RegisteredClaims } -// SignHostJWT signs a long-lived (1 year) JWT for a registered host agent. -func SignHostJWT(secret []byte, hostID string) (string, error) { +// SignHostJWT signs a long-lived (7-day) JWT for a registered host agent. 
+func SignHostJWT(secret []byte, hostID pgtype.UUID) (string, error) { + formatted := id.FormatHostID(hostID) now := time.Now() claims := HostClaims{ Type: "host", - HostID: hostID, + HostID: formatted, RegisteredClaims: jwt.RegisteredClaims{ - Subject: hostID, + Subject: formatted, IssuedAt: jwt.NewNumericDate(now), ExpiresAt: jwt.NewNumericDate(now.Add(hostJWTExpiry)), }, diff --git a/internal/config/config.go b/internal/config/config.go index 7ef0aa69..e4e67403 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -13,6 +13,11 @@ type Config struct { ListenAddr string JWTSecret string + // mTLS — CP→Agent channel. Both must be set to enable mTLS; omitting either + // disables cert issuance and leaves agent connections on plain HTTP (dev mode). + CACert string // WRENN_CA_CERT — PEM-encoded internal CA certificate + CAKey string // WRENN_CA_KEY — PEM-encoded internal CA private key + OAuthGitHubClientID string OAuthGitHubClientSecret string OAuthRedirectURL string @@ -28,9 +33,12 @@ func Load() Config { return Config{ DatabaseURL: envOrDefault("DATABASE_URL", "postgres://wrenn:wrenn@localhost:5432/wrenn?sslmode=disable"), RedisURL: envOrDefault("REDIS_URL", "redis://localhost:6379/0"), - ListenAddr: envOrDefault("CP_LISTEN_ADDR", ":8080"), + ListenAddr: envOrDefault("WRENN_CP_LISTEN_ADDR", ":8080"), JWTSecret: os.Getenv("JWT_SECRET"), + CACert: os.Getenv("WRENN_CA_CERT"), + CAKey: os.Getenv("WRENN_CA_KEY"), + OAuthGitHubClientID: os.Getenv("OAUTH_GITHUB_CLIENT_ID"), OAuthGitHubClientSecret: os.Getenv("OAUTH_GITHUB_CLIENT_SECRET"), OAuthRedirectURL: envOrDefault("OAUTH_REDIRECT_URL", "https://app.wrenn.dev"), diff --git a/internal/db/api_keys.sql.go b/internal/db/api_keys.sql.go index b4f0ffcc..4b8d3699 100644 --- a/internal/db/api_keys.sql.go +++ b/internal/db/api_keys.sql.go @@ -16,8 +16,8 @@ DELETE FROM team_api_keys WHERE id = $1 AND team_id = $2 ` type DeleteAPIKeyParams struct { - ID string `json:"id"` - TeamID string `json:"team_id"` + 
ID pgtype.UUID `json:"id"` + TeamID pgtype.UUID `json:"team_id"` } func (q *Queries) DeleteAPIKey(ctx context.Context, arg DeleteAPIKeyParams) error { @@ -52,12 +52,12 @@ RETURNING id, team_id, name, key_hash, key_prefix, created_by, created_at, last_ ` type InsertAPIKeyParams struct { - ID string `json:"id"` - TeamID string `json:"team_id"` - Name string `json:"name"` - KeyHash string `json:"key_hash"` - KeyPrefix string `json:"key_prefix"` - CreatedBy string `json:"created_by"` + ID pgtype.UUID `json:"id"` + TeamID pgtype.UUID `json:"team_id"` + Name string `json:"name"` + KeyHash string `json:"key_hash"` + KeyPrefix string `json:"key_prefix"` + CreatedBy pgtype.UUID `json:"created_by"` } func (q *Queries) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (TeamApiKey, error) { @@ -87,7 +87,7 @@ const listAPIKeysByTeam = `-- name: ListAPIKeysByTeam :many SELECT id, team_id, name, key_hash, key_prefix, created_by, created_at, last_used FROM team_api_keys WHERE team_id = $1 ORDER BY created_at DESC ` -func (q *Queries) ListAPIKeysByTeam(ctx context.Context, teamID string) ([]TeamApiKey, error) { +func (q *Queries) ListAPIKeysByTeam(ctx context.Context, teamID pgtype.UUID) ([]TeamApiKey, error) { rows, err := q.db.Query(ctx, listAPIKeysByTeam, teamID) if err != nil { return nil, err @@ -126,18 +126,18 @@ ORDER BY k.created_at DESC ` type ListAPIKeysByTeamWithCreatorRow struct { - ID string `json:"id"` - TeamID string `json:"team_id"` + ID pgtype.UUID `json:"id"` + TeamID pgtype.UUID `json:"team_id"` Name string `json:"name"` KeyHash string `json:"key_hash"` KeyPrefix string `json:"key_prefix"` - CreatedBy string `json:"created_by"` + CreatedBy pgtype.UUID `json:"created_by"` CreatedAt pgtype.Timestamptz `json:"created_at"` LastUsed pgtype.Timestamptz `json:"last_used"` CreatorEmail string `json:"creator_email"` } -func (q *Queries) ListAPIKeysByTeamWithCreator(ctx context.Context, teamID string) ([]ListAPIKeysByTeamWithCreatorRow, error) { +func (q *Queries) 
ListAPIKeysByTeamWithCreator(ctx context.Context, teamID pgtype.UUID) ([]ListAPIKeysByTeamWithCreatorRow, error) { rows, err := q.db.Query(ctx, listAPIKeysByTeamWithCreator, teamID) if err != nil { return nil, err @@ -171,7 +171,7 @@ const updateAPIKeyLastUsed = `-- name: UpdateAPIKeyLastUsed :exec UPDATE team_api_keys SET last_used = NOW() WHERE id = $1 ` -func (q *Queries) UpdateAPIKeyLastUsed(ctx context.Context, id string) error { +func (q *Queries) UpdateAPIKeyLastUsed(ctx context.Context, id pgtype.UUID) error { _, err := q.db.Exec(ctx, updateAPIKeyLastUsed, id) return err } diff --git a/internal/db/audit.sql.go b/internal/db/audit.sql.go index 9370eca9..69b2b8ce 100644 --- a/internal/db/audit.sql.go +++ b/internal/db/audit.sql.go @@ -17,11 +17,11 @@ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) ` type InsertAuditLogParams struct { - ID string `json:"id"` - TeamID string `json:"team_id"` + ID pgtype.UUID `json:"id"` + TeamID pgtype.UUID `json:"team_id"` ActorType string `json:"actor_type"` ActorID pgtype.Text `json:"actor_id"` - ActorName pgtype.Text `json:"actor_name"` + ActorName string `json:"actor_name"` ResourceType string `json:"resource_type"` ResourceID pgtype.Text `json:"resource_id"` Action string `json:"action"` @@ -60,12 +60,12 @@ LIMIT $7 ` type ListAuditLogsParams struct { - TeamID string `json:"team_id"` + TeamID pgtype.UUID `json:"team_id"` Column2 []string `json:"column_2"` Column3 []string `json:"column_3"` Column4 []string `json:"column_4"` Column5 pgtype.Timestamptz `json:"column_5"` - ID string `json:"id"` + ID pgtype.UUID `json:"id"` Limit int32 `json:"limit"` } diff --git a/internal/db/host_refresh_tokens.sql.go b/internal/db/host_refresh_tokens.sql.go index d02a0e7a..0ec162d9 100644 --- a/internal/db/host_refresh_tokens.sql.go +++ b/internal/db/host_refresh_tokens.sql.go @@ -47,8 +47,8 @@ RETURNING id, host_id, token_hash, expires_at, created_at, revoked_at ` type InsertHostRefreshTokenParams struct { - ID string `json:"id"` - 
HostID string `json:"host_id"` + ID pgtype.UUID `json:"id"` + HostID pgtype.UUID `json:"host_id"` TokenHash string `json:"token_hash"` ExpiresAt pgtype.Timestamptz `json:"expires_at"` } @@ -76,7 +76,7 @@ const revokeHostRefreshToken = `-- name: RevokeHostRefreshToken :exec UPDATE host_refresh_tokens SET revoked_at = NOW() WHERE id = $1 ` -func (q *Queries) RevokeHostRefreshToken(ctx context.Context, id string) error { +func (q *Queries) RevokeHostRefreshToken(ctx context.Context, id pgtype.UUID) error { _, err := q.db.Exec(ctx, revokeHostRefreshToken, id) return err } @@ -86,7 +86,7 @@ UPDATE host_refresh_tokens SET revoked_at = NOW() WHERE host_id = $1 AND revoked_at IS NULL ` -func (q *Queries) RevokeHostRefreshTokensByHost(ctx context.Context, hostID string) error { +func (q *Queries) RevokeHostRefreshTokensByHost(ctx context.Context, hostID pgtype.UUID) error { _, err := q.db.Exec(ctx, revokeHostRefreshTokensByHost, hostID) return err } diff --git a/internal/db/hosts.sql.go b/internal/db/hosts.sql.go index 2d7b8e0c..2e3962b5 100644 --- a/internal/db/hosts.sql.go +++ b/internal/db/hosts.sql.go @@ -16,8 +16,8 @@ INSERT INTO host_tags (host_id, tag) VALUES ($1, $2) ON CONFLICT DO NOTHING ` type AddHostTagParams struct { - HostID string `json:"host_id"` - Tag string `json:"tag"` + HostID pgtype.UUID `json:"host_id"` + Tag string `json:"tag"` } func (q *Queries) AddHostTag(ctx context.Context, arg AddHostTagParams) error { @@ -29,16 +29,16 @@ const deleteHost = `-- name: DeleteHost :exec DELETE FROM hosts WHERE id = $1 ` -func (q *Queries) DeleteHost(ctx context.Context, id string) error { +func (q *Queries) DeleteHost(ctx context.Context, id pgtype.UUID) error { _, err := q.db.Exec(ctx, deleteHost, id) return err } const getHost = `-- name: GetHost :one -SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM 
hosts WHERE id = $1 +SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE id = $1 ` -func (q *Queries) GetHost(ctx context.Context, id string) (Host, error) { +func (q *Queries) GetHost(ctx context.Context, id pgtype.UUID) (Host, error) { row := q.db.QueryRow(ctx, getHost, id) var i Host err := row.Scan( @@ -59,18 +59,18 @@ func (q *Queries) GetHost(ctx context.Context, id string) (Host, error) { &i.CreatedAt, &i.UpdatedAt, &i.CertFingerprint, - &i.MtlsEnabled, + &i.CertExpiresAt, ) return i, err } const getHostByTeam = `-- name: GetHostByTeam :one -SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE id = $1 AND team_id = $2 +SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE id = $1 AND team_id = $2 ` type GetHostByTeamParams struct { - ID string `json:"id"` - TeamID pgtype.Text `json:"team_id"` + ID pgtype.UUID `json:"id"` + TeamID pgtype.UUID `json:"team_id"` } func (q *Queries) GetHostByTeam(ctx context.Context, arg GetHostByTeamParams) (Host, error) { @@ -94,7 +94,7 @@ func (q *Queries) GetHostByTeam(ctx context.Context, arg GetHostByTeamParams) (H &i.CreatedAt, &i.UpdatedAt, &i.CertFingerprint, - &i.MtlsEnabled, + &i.CertExpiresAt, ) return i, err } @@ -103,7 +103,7 @@ const getHostTags = `-- name: GetHostTags :many SELECT tag FROM host_tags WHERE host_id = $1 ORDER BY tag ` -func (q *Queries) GetHostTags(ctx context.Context, hostID string) ([]string, error) { +func (q *Queries) GetHostTags(ctx context.Context, hostID pgtype.UUID) ([]string, error) { 
rows, err := q.db.Query(ctx, getHostTags, hostID) if err != nil { return nil, err @@ -127,7 +127,7 @@ const getHostTokensByHost = `-- name: GetHostTokensByHost :many SELECT id, host_id, created_by, created_at, expires_at, used_at FROM host_tokens WHERE host_id = $1 ORDER BY created_at DESC ` -func (q *Queries) GetHostTokensByHost(ctx context.Context, hostID string) ([]HostToken, error) { +func (q *Queries) GetHostTokensByHost(ctx context.Context, hostID pgtype.UUID) ([]HostToken, error) { rows, err := q.db.Query(ctx, getHostTokensByHost, hostID) if err != nil { return nil, err @@ -157,16 +157,16 @@ func (q *Queries) GetHostTokensByHost(ctx context.Context, hostID string) ([]Hos const insertHost = `-- name: InsertHost :one INSERT INTO hosts (id, type, team_id, provider, availability_zone, created_by) VALUES ($1, $2, $3, $4, $5, $6) -RETURNING id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled +RETURNING id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at ` type InsertHostParams struct { - ID string `json:"id"` + ID pgtype.UUID `json:"id"` Type string `json:"type"` - TeamID pgtype.Text `json:"team_id"` - Provider pgtype.Text `json:"provider"` - AvailabilityZone pgtype.Text `json:"availability_zone"` - CreatedBy string `json:"created_by"` + TeamID pgtype.UUID `json:"team_id"` + Provider string `json:"provider"` + AvailabilityZone string `json:"availability_zone"` + CreatedBy pgtype.UUID `json:"created_by"` } func (q *Queries) InsertHost(ctx context.Context, arg InsertHostParams) (Host, error) { @@ -197,7 +197,7 @@ func (q *Queries) InsertHost(ctx context.Context, arg InsertHostParams) (Host, e &i.CreatedAt, &i.UpdatedAt, &i.CertFingerprint, - &i.MtlsEnabled, + &i.CertExpiresAt, 
) return i, err } @@ -209,9 +209,9 @@ RETURNING id, host_id, created_by, created_at, expires_at, used_at ` type InsertHostTokenParams struct { - ID string `json:"id"` - HostID string `json:"host_id"` - CreatedBy string `json:"created_by"` + ID pgtype.UUID `json:"id"` + HostID pgtype.UUID `json:"host_id"` + CreatedBy pgtype.UUID `json:"created_by"` ExpiresAt pgtype.Timestamptz `json:"expires_at"` } @@ -235,7 +235,7 @@ func (q *Queries) InsertHostToken(ctx context.Context, arg InsertHostTokenParams } const listActiveHosts = `-- name: ListActiveHosts :many -SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE status NOT IN ('pending', 'offline') ORDER BY created_at +SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE status NOT IN ('pending', 'offline') ORDER BY created_at ` // Returns all hosts that have completed registration (not pending/offline). 
@@ -266,7 +266,7 @@ func (q *Queries) ListActiveHosts(ctx context.Context) ([]Host, error) { &i.CreatedAt, &i.UpdatedAt, &i.CertFingerprint, - &i.MtlsEnabled, + &i.CertExpiresAt, ); err != nil { return nil, err } @@ -279,7 +279,7 @@ func (q *Queries) ListActiveHosts(ctx context.Context) ([]Host, error) { } const listHosts = `-- name: ListHosts :many -SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts ORDER BY created_at DESC +SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts ORDER BY created_at DESC ` func (q *Queries) ListHosts(ctx context.Context) ([]Host, error) { @@ -309,7 +309,7 @@ func (q *Queries) ListHosts(ctx context.Context) ([]Host, error) { &i.CreatedAt, &i.UpdatedAt, &i.CertFingerprint, - &i.MtlsEnabled, + &i.CertExpiresAt, ); err != nil { return nil, err } @@ -322,7 +322,7 @@ func (q *Queries) ListHosts(ctx context.Context) ([]Host, error) { } const listHostsByStatus = `-- name: ListHostsByStatus :many -SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE status = $1 ORDER BY created_at DESC +SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE status = $1 ORDER BY created_at DESC ` func (q *Queries) ListHostsByStatus(ctx context.Context, status string) ([]Host, error) { @@ -352,7 +352,7 @@ func (q *Queries) ListHostsByStatus(ctx context.Context, status string) ([]Host, &i.CreatedAt, 
&i.UpdatedAt, &i.CertFingerprint, - &i.MtlsEnabled, + &i.CertExpiresAt, ); err != nil { return nil, err } @@ -365,7 +365,7 @@ func (q *Queries) ListHostsByStatus(ctx context.Context, status string) ([]Host, } const listHostsByTag = `-- name: ListHostsByTag :many -SELECT h.id, h.type, h.team_id, h.provider, h.availability_zone, h.arch, h.cpu_cores, h.memory_mb, h.disk_gb, h.address, h.status, h.last_heartbeat_at, h.metadata, h.created_by, h.created_at, h.updated_at, h.cert_fingerprint, h.mtls_enabled FROM hosts h +SELECT h.id, h.type, h.team_id, h.provider, h.availability_zone, h.arch, h.cpu_cores, h.memory_mb, h.disk_gb, h.address, h.status, h.last_heartbeat_at, h.metadata, h.created_by, h.created_at, h.updated_at, h.cert_fingerprint, h.cert_expires_at FROM hosts h JOIN host_tags ht ON ht.host_id = h.id WHERE ht.tag = $1 ORDER BY h.created_at DESC @@ -398,7 +398,7 @@ func (q *Queries) ListHostsByTag(ctx context.Context, tag string) ([]Host, error &i.CreatedAt, &i.UpdatedAt, &i.CertFingerprint, - &i.MtlsEnabled, + &i.CertExpiresAt, ); err != nil { return nil, err } @@ -411,10 +411,10 @@ func (q *Queries) ListHostsByTag(ctx context.Context, tag string) ([]Host, error } const listHostsByTeam = `-- name: ListHostsByTeam :many -SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE team_id = $1 AND type = 'byoc' ORDER BY created_at DESC +SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE team_id = $1 AND type = 'byoc' ORDER BY created_at DESC ` -func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.Text) ([]Host, error) { +func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.UUID) ([]Host, error) { rows, 
err := q.db.Query(ctx, listHostsByTeam, teamID) if err != nil { return nil, err @@ -441,7 +441,7 @@ func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.Text) ([]Ho &i.CreatedAt, &i.UpdatedAt, &i.CertFingerprint, - &i.MtlsEnabled, + &i.CertExpiresAt, ); err != nil { return nil, err } @@ -454,7 +454,7 @@ func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.Text) ([]Ho } const listHostsByType = `-- name: ListHostsByType :many -SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE type = $1 ORDER BY created_at DESC +SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE type = $1 ORDER BY created_at DESC ` func (q *Queries) ListHostsByType(ctx context.Context, type_ string) ([]Host, error) { @@ -484,7 +484,7 @@ func (q *Queries) ListHostsByType(ctx context.Context, type_ string) ([]Host, er &i.CreatedAt, &i.UpdatedAt, &i.CertFingerprint, - &i.MtlsEnabled, + &i.CertExpiresAt, ); err != nil { return nil, err } @@ -500,7 +500,7 @@ const markHostTokenUsed = `-- name: MarkHostTokenUsed :exec UPDATE host_tokens SET used_at = NOW() WHERE id = $1 ` -func (q *Queries) MarkHostTokenUsed(ctx context.Context, id string) error { +func (q *Queries) MarkHostTokenUsed(ctx context.Context, id pgtype.UUID) error { _, err := q.db.Exec(ctx, markHostTokenUsed, id) return err } @@ -509,31 +509,35 @@ const markHostUnreachable = `-- name: MarkHostUnreachable :exec UPDATE hosts SET status = 'unreachable', updated_at = NOW() WHERE id = $1 ` -func (q *Queries) MarkHostUnreachable(ctx context.Context, id string) error { +func (q *Queries) MarkHostUnreachable(ctx context.Context, id pgtype.UUID) error { _, err := q.db.Exec(ctx, 
markHostUnreachable, id) return err } const registerHost = `-- name: RegisterHost :execrows UPDATE hosts -SET arch = $2, - cpu_cores = $3, - memory_mb = $4, - disk_gb = $5, - address = $6, - status = 'online', +SET arch = $2, + cpu_cores = $3, + memory_mb = $4, + disk_gb = $5, + address = $6, + cert_fingerprint = $7, + cert_expires_at = $8, + status = 'online', last_heartbeat_at = NOW(), - updated_at = NOW() + updated_at = NOW() WHERE id = $1 AND status = 'pending' ` type RegisterHostParams struct { - ID string `json:"id"` - Arch pgtype.Text `json:"arch"` - CpuCores pgtype.Int4 `json:"cpu_cores"` - MemoryMb pgtype.Int4 `json:"memory_mb"` - DiskGb pgtype.Int4 `json:"disk_gb"` - Address pgtype.Text `json:"address"` + ID pgtype.UUID `json:"id"` + Arch string `json:"arch"` + CpuCores int32 `json:"cpu_cores"` + MemoryMb int32 `json:"memory_mb"` + DiskGb int32 `json:"disk_gb"` + Address string `json:"address"` + CertFingerprint string `json:"cert_fingerprint"` + CertExpiresAt pgtype.Timestamptz `json:"cert_expires_at"` } func (q *Queries) RegisterHost(ctx context.Context, arg RegisterHostParams) (int64, error) { @@ -544,6 +548,8 @@ func (q *Queries) RegisterHost(ctx context.Context, arg RegisterHostParams) (int arg.MemoryMb, arg.DiskGb, arg.Address, + arg.CertFingerprint, + arg.CertExpiresAt, ) if err != nil { return 0, err @@ -556,8 +562,8 @@ DELETE FROM host_tags WHERE host_id = $1 AND tag = $2 ` type RemoveHostTagParams struct { - HostID string `json:"host_id"` - Tag string `json:"tag"` + HostID pgtype.UUID `json:"host_id"` + Tag string `json:"tag"` } func (q *Queries) RemoveHostTag(ctx context.Context, arg RemoveHostTagParams) error { @@ -565,11 +571,30 @@ func (q *Queries) RemoveHostTag(ctx context.Context, arg RemoveHostTagParams) er return err } +const updateHostCert = `-- name: UpdateHostCert :exec +UPDATE hosts +SET cert_fingerprint = $2, + cert_expires_at = $3, + updated_at = NOW() +WHERE id = $1 +` + +type UpdateHostCertParams struct { + ID pgtype.UUID 
`json:"id"` + CertFingerprint string `json:"cert_fingerprint"` + CertExpiresAt pgtype.Timestamptz `json:"cert_expires_at"` +} + +func (q *Queries) UpdateHostCert(ctx context.Context, arg UpdateHostCertParams) error { + _, err := q.db.Exec(ctx, updateHostCert, arg.ID, arg.CertFingerprint, arg.CertExpiresAt) + return err +} + const updateHostHeartbeat = `-- name: UpdateHostHeartbeat :exec UPDATE hosts SET last_heartbeat_at = NOW(), updated_at = NOW() WHERE id = $1 ` -func (q *Queries) UpdateHostHeartbeat(ctx context.Context, id string) error { +func (q *Queries) UpdateHostHeartbeat(ctx context.Context, id pgtype.UUID) error { _, err := q.db.Exec(ctx, updateHostHeartbeat, id) return err } @@ -584,7 +609,7 @@ WHERE id = $1 // Updates last_heartbeat_at and transitions unreachable hosts back to online. // Returns 0 if no host was found (deleted), which the caller treats as 404. -func (q *Queries) UpdateHostHeartbeatAndStatus(ctx context.Context, id string) (int64, error) { +func (q *Queries) UpdateHostHeartbeatAndStatus(ctx context.Context, id pgtype.UUID) (int64, error) { result, err := q.db.Exec(ctx, updateHostHeartbeatAndStatus, id) if err != nil { return 0, err @@ -597,8 +622,8 @@ UPDATE hosts SET status = $2, updated_at = NOW() WHERE id = $1 ` type UpdateHostStatusParams struct { - ID string `json:"id"` - Status string `json:"status"` + ID pgtype.UUID `json:"id"` + Status string `json:"status"` } func (q *Queries) UpdateHostStatus(ctx context.Context, arg UpdateHostStatusParams) error { diff --git a/internal/db/metrics.sql.go b/internal/db/metrics.sql.go index 80501559..f522dc2d 100644 --- a/internal/db/metrics.sql.go +++ b/internal/db/metrics.sql.go @@ -7,6 +7,8 @@ package db import ( "context" + + "github.com/jackc/pgx/v5/pgtype" ) const deleteSandboxMetricPoints = `-- name: DeleteSandboxMetricPoints :exec @@ -14,7 +16,7 @@ DELETE FROM sandbox_metric_points WHERE sandbox_id = $1 ` -func (q *Queries) DeleteSandboxMetricPoints(ctx context.Context, sandboxID string) 
error { +func (q *Queries) DeleteSandboxMetricPoints(ctx context.Context, sandboxID pgtype.UUID) error { _, err := q.db.Exec(ctx, deleteSandboxMetricPoints, sandboxID) return err } @@ -25,8 +27,8 @@ WHERE sandbox_id = $1 AND tier = $2 ` type DeleteSandboxMetricPointsByTierParams struct { - SandboxID string `json:"sandbox_id"` - Tier string `json:"tier"` + SandboxID pgtype.UUID `json:"sandbox_id"` + Tier string `json:"tier"` } func (q *Queries) DeleteSandboxMetricPointsByTier(ctx context.Context, arg DeleteSandboxMetricPointsByTierParams) error { @@ -53,7 +55,7 @@ type GetLiveMetricsRow struct { // Reads directly from sandboxes for accurate real-time current values. // CPU reserved = running + starting only (paused VMs release CPU). // RAM reserved = running + starting + sum(ceil(each_paused/2)) (per-VM ceiling). -func (q *Queries) GetLiveMetrics(ctx context.Context, teamID string) (GetLiveMetricsRow, error) { +func (q *Queries) GetLiveMetrics(ctx context.Context, teamID pgtype.UUID) (GetLiveMetricsRow, error) { row := q.db.QueryRow(ctx, getLiveMetrics, teamID) var i GetLiveMetricsRow err := row.Scan(&i.RunningCount, &i.VcpusReserved, &i.MemoryMbReserved) @@ -76,7 +78,7 @@ type GetPeakMetricsRow struct { PeakMemoryMb int32 `json:"peak_memory_mb"` } -func (q *Queries) GetPeakMetrics(ctx context.Context, teamID string) (GetPeakMetricsRow, error) { +func (q *Queries) GetPeakMetrics(ctx context.Context, teamID pgtype.UUID) (GetPeakMetricsRow, error) { row := q.db.QueryRow(ctx, getPeakMetrics, teamID) var i GetPeakMetricsRow err := row.Scan(&i.PeakRunningCount, &i.PeakVcpus, &i.PeakMemoryMb) @@ -91,9 +93,9 @@ ORDER BY ts ASC ` type GetSandboxMetricPointsParams struct { - SandboxID string `json:"sandbox_id"` - Tier string `json:"tier"` - Ts int64 `json:"ts"` + SandboxID pgtype.UUID `json:"sandbox_id"` + Tier string `json:"tier"` + Ts int64 `json:"ts"` } type GetSandboxMetricPointsRow struct { @@ -134,10 +136,10 @@ VALUES ($1, $2, $3, $4) ` type InsertMetricsSnapshotParams 
struct { - TeamID string `json:"team_id"` - RunningCount int32 `json:"running_count"` - VcpusReserved int32 `json:"vcpus_reserved"` - MemoryMbReserved int32 `json:"memory_mb_reserved"` + TeamID pgtype.UUID `json:"team_id"` + RunningCount int32 `json:"running_count"` + VcpusReserved int32 `json:"vcpus_reserved"` + MemoryMbReserved int32 `json:"memory_mb_reserved"` } func (q *Queries) InsertMetricsSnapshot(ctx context.Context, arg InsertMetricsSnapshotParams) error { @@ -157,12 +159,12 @@ ON CONFLICT (sandbox_id, tier, ts) DO NOTHING ` type InsertSandboxMetricPointParams struct { - SandboxID string `json:"sandbox_id"` - Tier string `json:"tier"` - Ts int64 `json:"ts"` - CpuPct float64 `json:"cpu_pct"` - MemBytes int64 `json:"mem_bytes"` - DiskBytes int64 `json:"disk_bytes"` + SandboxID pgtype.UUID `json:"sandbox_id"` + Tier string `json:"tier"` + Ts int64 `json:"ts"` + CpuPct float64 `json:"cpu_pct"` + MemBytes int64 `json:"mem_bytes"` + DiskBytes int64 `json:"disk_bytes"` } func (q *Queries) InsertSandboxMetricPoint(ctx context.Context, arg InsertSandboxMetricPointParams) error { @@ -210,10 +212,10 @@ GROUP BY team_id ` type SampleSandboxMetricsRow struct { - TeamID string `json:"team_id"` - RunningCount int32 `json:"running_count"` - VcpusReserved int32 `json:"vcpus_reserved"` - MemoryMbReserved int32 `json:"memory_mb_reserved"` + TeamID pgtype.UUID `json:"team_id"` + RunningCount int32 `json:"running_count"` + VcpusReserved int32 `json:"vcpus_reserved"` + MemoryMbReserved int32 `json:"memory_mb_reserved"` } // Aggregates per-team resource usage from the live sandboxes table. 
diff --git a/internal/db/models.go b/internal/db/models.go index 0128f4a8..1e9a5d00 100644 --- a/internal/db/models.go +++ b/internal/db/models.go @@ -9,18 +9,18 @@ import ( ) type AdminPermission struct { - ID string `json:"id"` - UserID string `json:"user_id"` + ID pgtype.UUID `json:"id"` + UserID pgtype.UUID `json:"user_id"` Permission string `json:"permission"` CreatedAt pgtype.Timestamptz `json:"created_at"` } type AuditLog struct { - ID string `json:"id"` - TeamID string `json:"team_id"` + ID pgtype.UUID `json:"id"` + TeamID pgtype.UUID `json:"team_id"` ActorType string `json:"actor_type"` ActorID pgtype.Text `json:"actor_id"` - ActorName pgtype.Text `json:"actor_name"` + ActorName string `json:"actor_name"` ResourceType string `json:"resource_type"` ResourceID pgtype.Text `json:"resource_id"` Action string `json:"action"` @@ -31,29 +31,29 @@ type AuditLog struct { } type Host struct { - ID string `json:"id"` + ID pgtype.UUID `json:"id"` Type string `json:"type"` - TeamID pgtype.Text `json:"team_id"` - Provider pgtype.Text `json:"provider"` - AvailabilityZone pgtype.Text `json:"availability_zone"` - Arch pgtype.Text `json:"arch"` - CpuCores pgtype.Int4 `json:"cpu_cores"` - MemoryMb pgtype.Int4 `json:"memory_mb"` - DiskGb pgtype.Int4 `json:"disk_gb"` - Address pgtype.Text `json:"address"` + TeamID pgtype.UUID `json:"team_id"` + Provider string `json:"provider"` + AvailabilityZone string `json:"availability_zone"` + Arch string `json:"arch"` + CpuCores int32 `json:"cpu_cores"` + MemoryMb int32 `json:"memory_mb"` + DiskGb int32 `json:"disk_gb"` + Address string `json:"address"` Status string `json:"status"` LastHeartbeatAt pgtype.Timestamptz `json:"last_heartbeat_at"` Metadata []byte `json:"metadata"` - CreatedBy string `json:"created_by"` + CreatedBy pgtype.UUID `json:"created_by"` CreatedAt pgtype.Timestamptz `json:"created_at"` UpdatedAt pgtype.Timestamptz `json:"updated_at"` - CertFingerprint pgtype.Text `json:"cert_fingerprint"` - MtlsEnabled bool 
`json:"mtls_enabled"` + CertFingerprint string `json:"cert_fingerprint"` + CertExpiresAt pgtype.Timestamptz `json:"cert_expires_at"` } type HostRefreshToken struct { - ID string `json:"id"` - HostID string `json:"host_id"` + ID pgtype.UUID `json:"id"` + HostID pgtype.UUID `json:"host_id"` TokenHash string `json:"token_hash"` ExpiresAt pgtype.Timestamptz `json:"expires_at"` CreatedAt pgtype.Timestamptz `json:"created_at"` @@ -61,14 +61,14 @@ type HostRefreshToken struct { } type HostTag struct { - HostID string `json:"host_id"` - Tag string `json:"tag"` + HostID pgtype.UUID `json:"host_id"` + Tag string `json:"tag"` } type HostToken struct { - ID string `json:"id"` - HostID string `json:"host_id"` - CreatedBy string `json:"created_by"` + ID pgtype.UUID `json:"id"` + HostID pgtype.UUID `json:"host_id"` + CreatedBy pgtype.UUID `json:"created_by"` CreatedAt pgtype.Timestamptz `json:"created_at"` ExpiresAt pgtype.Timestamptz `json:"expires_at"` UsedAt pgtype.Timestamptz `json:"used_at"` @@ -77,40 +77,43 @@ type HostToken struct { type OauthProvider struct { Provider string `json:"provider"` ProviderID string `json:"provider_id"` - UserID string `json:"user_id"` + UserID pgtype.UUID `json:"user_id"` Email string `json:"email"` CreatedAt pgtype.Timestamptz `json:"created_at"` } type Sandbox struct { - ID string `json:"id"` - HostID string `json:"host_id"` - Template string `json:"template"` - Status string `json:"status"` - Vcpus int32 `json:"vcpus"` - MemoryMb int32 `json:"memory_mb"` - TimeoutSec int32 `json:"timeout_sec"` - GuestIp string `json:"guest_ip"` - HostIp string `json:"host_ip"` - CreatedAt pgtype.Timestamptz `json:"created_at"` - StartedAt pgtype.Timestamptz `json:"started_at"` - LastActiveAt pgtype.Timestamptz `json:"last_active_at"` - LastUpdated pgtype.Timestamptz `json:"last_updated"` - TeamID string `json:"team_id"` + ID pgtype.UUID `json:"id"` + TeamID pgtype.UUID `json:"team_id"` + HostID pgtype.UUID `json:"host_id"` + Template string 
`json:"template"` + Status string `json:"status"` + Vcpus int32 `json:"vcpus"` + MemoryMb int32 `json:"memory_mb"` + TimeoutSec int32 `json:"timeout_sec"` + DiskSizeMb int32 `json:"disk_size_mb"` + GuestIp string `json:"guest_ip"` + HostIp string `json:"host_ip"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + StartedAt pgtype.Timestamptz `json:"started_at"` + LastActiveAt pgtype.Timestamptz `json:"last_active_at"` + LastUpdated pgtype.Timestamptz `json:"last_updated"` + TemplateID pgtype.UUID `json:"template_id"` + TemplateTeamID pgtype.UUID `json:"template_team_id"` } type SandboxMetricPoint struct { - SandboxID string `json:"sandbox_id"` - Tier string `json:"tier"` - Ts int64 `json:"ts"` - CpuPct float64 `json:"cpu_pct"` - MemBytes int64 `json:"mem_bytes"` - DiskBytes int64 `json:"disk_bytes"` + SandboxID pgtype.UUID `json:"sandbox_id"` + Tier string `json:"tier"` + Ts int64 `json:"ts"` + CpuPct float64 `json:"cpu_pct"` + MemBytes int64 `json:"mem_bytes"` + DiskBytes int64 `json:"disk_bytes"` } type SandboxMetricsSnapshot struct { ID int64 `json:"id"` - TeamID string `json:"team_id"` + TeamID pgtype.UUID `json:"team_id"` SampledAt pgtype.Timestamptz `json:"sampled_at"` RunningCount int32 `json:"running_count"` VcpusReserved int32 `json:"vcpus_reserved"` @@ -118,21 +121,21 @@ type SandboxMetricsSnapshot struct { } type Team struct { - ID string `json:"id"` + ID pgtype.UUID `json:"id"` Name string `json:"name"` - CreatedAt pgtype.Timestamptz `json:"created_at"` - IsByoc bool `json:"is_byoc"` Slug string `json:"slug"` + IsByoc bool `json:"is_byoc"` + CreatedAt pgtype.Timestamptz `json:"created_at"` DeletedAt pgtype.Timestamptz `json:"deleted_at"` } type TeamApiKey struct { - ID string `json:"id"` - TeamID string `json:"team_id"` + ID pgtype.UUID `json:"id"` + TeamID pgtype.UUID `json:"team_id"` Name string `json:"name"` KeyHash string `json:"key_hash"` KeyPrefix string `json:"key_prefix"` - CreatedBy string `json:"created_by"` + CreatedBy pgtype.UUID 
`json:"created_by"` CreatedAt pgtype.Timestamptz `json:"created_at"` LastUsed pgtype.Timestamptz `json:"last_used"` } @@ -140,26 +143,50 @@ type TeamApiKey struct { type Template struct { Name string `json:"name"` Type string `json:"type"` - Vcpus pgtype.Int4 `json:"vcpus"` - MemoryMb pgtype.Int4 `json:"memory_mb"` + Vcpus int32 `json:"vcpus"` + MemoryMb int32 `json:"memory_mb"` SizeBytes int64 `json:"size_bytes"` CreatedAt pgtype.Timestamptz `json:"created_at"` - TeamID string `json:"team_id"` + TeamID pgtype.UUID `json:"team_id"` + ID pgtype.UUID `json:"id"` +} + +type TemplateBuild struct { + ID pgtype.UUID `json:"id"` + Name string `json:"name"` + BaseTemplate string `json:"base_template"` + Recipe []byte `json:"recipe"` + Healthcheck string `json:"healthcheck"` + Vcpus int32 `json:"vcpus"` + MemoryMb int32 `json:"memory_mb"` + Status string `json:"status"` + CurrentStep int32 `json:"current_step"` + TotalSteps int32 `json:"total_steps"` + Logs []byte `json:"logs"` + Error string `json:"error"` + SandboxID pgtype.UUID `json:"sandbox_id"` + HostID pgtype.UUID `json:"host_id"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + StartedAt pgtype.Timestamptz `json:"started_at"` + CompletedAt pgtype.Timestamptz `json:"completed_at"` + TemplateID pgtype.UUID `json:"template_id"` + TeamID pgtype.UUID `json:"team_id"` + SkipPrePost bool `json:"skip_pre_post"` } type User struct { - ID string `json:"id"` + ID pgtype.UUID `json:"id"` Email string `json:"email"` PasswordHash pgtype.Text `json:"password_hash"` + Name string `json:"name"` + IsAdmin bool `json:"is_admin"` CreatedAt pgtype.Timestamptz `json:"created_at"` UpdatedAt pgtype.Timestamptz `json:"updated_at"` - IsAdmin bool `json:"is_admin"` - Name string `json:"name"` } type UsersTeam struct { - UserID string `json:"user_id"` - TeamID string `json:"team_id"` + UserID pgtype.UUID `json:"user_id"` + TeamID pgtype.UUID `json:"team_id"` IsDefault bool `json:"is_default"` Role string `json:"role"` CreatedAt 
pgtype.Timestamptz `json:"created_at"` diff --git a/internal/db/oauth.sql.go b/internal/db/oauth.sql.go index ab79eec5..0270defa 100644 --- a/internal/db/oauth.sql.go +++ b/internal/db/oauth.sql.go @@ -7,6 +7,8 @@ package db import ( "context" + + "github.com/jackc/pgx/v5/pgtype" ) const getOAuthProvider = `-- name: GetOAuthProvider :one @@ -38,10 +40,10 @@ VALUES ($1, $2, $3, $4) ` type InsertOAuthProviderParams struct { - Provider string `json:"provider"` - ProviderID string `json:"provider_id"` - UserID string `json:"user_id"` - Email string `json:"email"` + Provider string `json:"provider"` + ProviderID string `json:"provider_id"` + UserID pgtype.UUID `json:"user_id"` + Email string `json:"email"` } func (q *Queries) InsertOAuthProvider(ctx context.Context, arg InsertOAuthProviderParams) error { diff --git a/internal/db/sandboxes.sql.go b/internal/db/sandboxes.sql.go index 620f77e2..3ce16443 100644 --- a/internal/db/sandboxes.sql.go +++ b/internal/db/sandboxes.sql.go @@ -15,12 +15,12 @@ const bulkRestoreRunning = `-- name: BulkRestoreRunning :exec UPDATE sandboxes SET status = 'running', last_updated = NOW() -WHERE id = ANY($1::text[]) AND status = 'missing' +WHERE id = ANY($1::uuid[]) AND status = 'missing' ` // Called by the reconciler when a host comes back online and its sandboxes are // confirmed alive. Restores only sandboxes that are in 'missing' state. 
-func (q *Queries) BulkRestoreRunning(ctx context.Context, dollar_1 []string) error { +func (q *Queries) BulkRestoreRunning(ctx context.Context, dollar_1 []pgtype.UUID) error { _, err := q.db.Exec(ctx, bulkRestoreRunning, dollar_1) return err } @@ -29,12 +29,12 @@ const bulkUpdateStatusByIDs = `-- name: BulkUpdateStatusByIDs :exec UPDATE sandboxes SET status = $2, last_updated = NOW() -WHERE id = ANY($1::text[]) +WHERE id = ANY($1::uuid[]) ` type BulkUpdateStatusByIDsParams struct { - Column1 []string `json:"column_1"` - Status string `json:"status"` + Column1 []pgtype.UUID `json:"column_1"` + Status string `json:"status"` } func (q *Queries) BulkUpdateStatusByIDs(ctx context.Context, arg BulkUpdateStatusByIDsParams) error { @@ -43,38 +43,41 @@ func (q *Queries) BulkUpdateStatusByIDs(ctx context.Context, arg BulkUpdateStatu } const getSandbox = `-- name: GetSandbox :one -SELECT id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id FROM sandboxes WHERE id = $1 +SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes WHERE id = $1 ` -func (q *Queries) GetSandbox(ctx context.Context, id string) (Sandbox, error) { +func (q *Queries) GetSandbox(ctx context.Context, id pgtype.UUID) (Sandbox, error) { row := q.db.QueryRow(ctx, getSandbox, id) var i Sandbox err := row.Scan( &i.ID, + &i.TeamID, &i.HostID, &i.Template, &i.Status, &i.Vcpus, &i.MemoryMb, &i.TimeoutSec, + &i.DiskSizeMb, &i.GuestIp, &i.HostIp, &i.CreatedAt, &i.StartedAt, &i.LastActiveAt, &i.LastUpdated, - &i.TeamID, + &i.TemplateID, + &i.TemplateTeamID, ) return i, err } const getSandboxByTeam = `-- name: GetSandboxByTeam :one -SELECT id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, 
team_id FROM sandboxes WHERE id = $1 AND team_id = $2 +SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes WHERE id = $1 AND team_id = $2 ` type GetSandboxByTeamParams struct { - ID string `json:"id"` - TeamID string `json:"team_id"` + ID pgtype.UUID `json:"id"` + TeamID pgtype.UUID `json:"team_id"` } func (q *Queries) GetSandboxByTeam(ctx context.Context, arg GetSandboxByTeamParams) (Sandbox, error) { @@ -82,38 +85,70 @@ func (q *Queries) GetSandboxByTeam(ctx context.Context, arg GetSandboxByTeamPara var i Sandbox err := row.Scan( &i.ID, + &i.TeamID, &i.HostID, &i.Template, &i.Status, &i.Vcpus, &i.MemoryMb, &i.TimeoutSec, + &i.DiskSizeMb, &i.GuestIp, &i.HostIp, &i.CreatedAt, &i.StartedAt, &i.LastActiveAt, &i.LastUpdated, - &i.TeamID, + &i.TemplateID, + &i.TemplateTeamID, ) return i, err } +const getSandboxProxyTarget = `-- name: GetSandboxProxyTarget :one +SELECT s.status, h.address AS host_address +FROM sandboxes s +JOIN hosts h ON h.id = s.host_id +WHERE s.id = $1 AND s.team_id = $2 +` + +type GetSandboxProxyTargetParams struct { + ID pgtype.UUID `json:"id"` + TeamID pgtype.UUID `json:"team_id"` +} + +type GetSandboxProxyTargetRow struct { + Status string `json:"status"` + HostAddress string `json:"host_address"` +} + +// Returns the sandbox status and its host's address in one query. +// Used by SandboxProxyWrapper to avoid two round-trips. 
+func (q *Queries) GetSandboxProxyTarget(ctx context.Context, arg GetSandboxProxyTargetParams) (GetSandboxProxyTargetRow, error) { + row := q.db.QueryRow(ctx, getSandboxProxyTarget, arg.ID, arg.TeamID) + var i GetSandboxProxyTargetRow + err := row.Scan(&i.Status, &i.HostAddress) + return i, err +} + const insertSandbox = `-- name: InsertSandbox :one -INSERT INTO sandboxes (id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec) -VALUES ($1, $2, $3, $4, $5, $6, $7, $8) -RETURNING id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id +INSERT INTO sandboxes (id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, template_id, template_team_id) +VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) +RETURNING id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id ` type InsertSandboxParams struct { - ID string `json:"id"` - TeamID string `json:"team_id"` - HostID string `json:"host_id"` - Template string `json:"template"` - Status string `json:"status"` - Vcpus int32 `json:"vcpus"` - MemoryMb int32 `json:"memory_mb"` - TimeoutSec int32 `json:"timeout_sec"` + ID pgtype.UUID `json:"id"` + TeamID pgtype.UUID `json:"team_id"` + HostID pgtype.UUID `json:"host_id"` + Template string `json:"template"` + Status string `json:"status"` + Vcpus int32 `json:"vcpus"` + MemoryMb int32 `json:"memory_mb"` + TimeoutSec int32 `json:"timeout_sec"` + DiskSizeMb int32 `json:"disk_size_mb"` + TemplateID pgtype.UUID `json:"template_id"` + TemplateTeamID pgtype.UUID `json:"template_team_id"` } func (q *Queries) InsertSandbox(ctx context.Context, arg InsertSandboxParams) (Sandbox, error) { @@ -126,34 +161,40 @@ func (q *Queries) InsertSandbox(ctx context.Context, arg InsertSandboxParams) (S arg.Vcpus, arg.MemoryMb, arg.TimeoutSec, + 
arg.DiskSizeMb, + arg.TemplateID, + arg.TemplateTeamID, ) var i Sandbox err := row.Scan( &i.ID, + &i.TeamID, &i.HostID, &i.Template, &i.Status, &i.Vcpus, &i.MemoryMb, &i.TimeoutSec, + &i.DiskSizeMb, &i.GuestIp, &i.HostIp, &i.CreatedAt, &i.StartedAt, &i.LastActiveAt, &i.LastUpdated, - &i.TeamID, + &i.TemplateID, + &i.TemplateTeamID, ) return i, err } const listActiveSandboxesByTeam = `-- name: ListActiveSandboxesByTeam :many -SELECT id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id FROM sandboxes +SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes WHERE team_id = $1 AND status IN ('running', 'paused', 'starting') ORDER BY created_at DESC ` -func (q *Queries) ListActiveSandboxesByTeam(ctx context.Context, teamID string) ([]Sandbox, error) { +func (q *Queries) ListActiveSandboxesByTeam(ctx context.Context, teamID pgtype.UUID) ([]Sandbox, error) { rows, err := q.db.Query(ctx, listActiveSandboxesByTeam, teamID) if err != nil { return nil, err @@ -164,19 +205,22 @@ func (q *Queries) ListActiveSandboxesByTeam(ctx context.Context, teamID string) var i Sandbox if err := rows.Scan( &i.ID, + &i.TeamID, &i.HostID, &i.Template, &i.Status, &i.Vcpus, &i.MemoryMb, &i.TimeoutSec, + &i.DiskSizeMb, &i.GuestIp, &i.HostIp, &i.CreatedAt, &i.StartedAt, &i.LastActiveAt, &i.LastUpdated, - &i.TeamID, + &i.TemplateID, + &i.TemplateTeamID, ); err != nil { return nil, err } @@ -189,7 +233,7 @@ func (q *Queries) ListActiveSandboxesByTeam(ctx context.Context, teamID string) } const listSandboxes = `-- name: ListSandboxes :many -SELECT id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id FROM sandboxes ORDER BY created_at DESC +SELECT id, team_id, host_id, 
template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes ORDER BY created_at DESC ` func (q *Queries) ListSandboxes(ctx context.Context) ([]Sandbox, error) { @@ -203,19 +247,22 @@ func (q *Queries) ListSandboxes(ctx context.Context) ([]Sandbox, error) { var i Sandbox if err := rows.Scan( &i.ID, + &i.TeamID, &i.HostID, &i.Template, &i.Status, &i.Vcpus, &i.MemoryMb, &i.TimeoutSec, + &i.DiskSizeMb, &i.GuestIp, &i.HostIp, &i.CreatedAt, &i.StartedAt, &i.LastActiveAt, &i.LastUpdated, - &i.TeamID, + &i.TemplateID, + &i.TemplateTeamID, ); err != nil { return nil, err } @@ -228,14 +275,14 @@ func (q *Queries) ListSandboxes(ctx context.Context) ([]Sandbox, error) { } const listSandboxesByHostAndStatus = `-- name: ListSandboxesByHostAndStatus :many -SELECT id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id FROM sandboxes +SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes WHERE host_id = $1 AND status = ANY($2::text[]) ORDER BY created_at DESC ` type ListSandboxesByHostAndStatusParams struct { - HostID string `json:"host_id"` - Column2 []string `json:"column_2"` + HostID pgtype.UUID `json:"host_id"` + Column2 []string `json:"column_2"` } func (q *Queries) ListSandboxesByHostAndStatus(ctx context.Context, arg ListSandboxesByHostAndStatusParams) ([]Sandbox, error) { @@ -249,19 +296,22 @@ func (q *Queries) ListSandboxesByHostAndStatus(ctx context.Context, arg ListSand var i Sandbox if err := rows.Scan( &i.ID, + &i.TeamID, &i.HostID, &i.Template, &i.Status, &i.Vcpus, &i.MemoryMb, &i.TimeoutSec, + &i.DiskSizeMb, &i.GuestIp, &i.HostIp, &i.CreatedAt, &i.StartedAt, &i.LastActiveAt, &i.LastUpdated, - &i.TeamID, + 
&i.TemplateID, + &i.TemplateTeamID, ); err != nil { return nil, err } @@ -274,12 +324,12 @@ func (q *Queries) ListSandboxesByHostAndStatus(ctx context.Context, arg ListSand } const listSandboxesByTeam = `-- name: ListSandboxesByTeam :many -SELECT id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id FROM sandboxes +SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes WHERE team_id = $1 AND status NOT IN ('stopped', 'error') ORDER BY created_at DESC ` -func (q *Queries) ListSandboxesByTeam(ctx context.Context, teamID string) ([]Sandbox, error) { +func (q *Queries) ListSandboxesByTeam(ctx context.Context, teamID pgtype.UUID) ([]Sandbox, error) { rows, err := q.db.Query(ctx, listSandboxesByTeam, teamID) if err != nil { return nil, err @@ -290,19 +340,22 @@ func (q *Queries) ListSandboxesByTeam(ctx context.Context, teamID string) ([]San var i Sandbox if err := rows.Scan( &i.ID, + &i.TeamID, &i.HostID, &i.Template, &i.Status, &i.Vcpus, &i.MemoryMb, &i.TimeoutSec, + &i.DiskSizeMb, &i.GuestIp, &i.HostIp, &i.CreatedAt, &i.StartedAt, &i.LastActiveAt, &i.LastUpdated, - &i.TeamID, + &i.TemplateID, + &i.TemplateTeamID, ); err != nil { return nil, err } @@ -324,7 +377,7 @@ WHERE host_id = $1 AND status IN ('running', 'starting', 'pending') // Called when the host monitor marks a host unreachable. // Marks running/starting/pending sandboxes on that host as 'missing' so users see // the sandbox is not currently reachable, without permanently losing the record. 
-func (q *Queries) MarkSandboxesMissingByHost(ctx context.Context, hostID string) error { +func (q *Queries) MarkSandboxesMissingByHost(ctx context.Context, hostID pgtype.UUID) error { _, err := q.db.Exec(ctx, markSandboxesMissingByHost, hostID) return err } @@ -337,7 +390,7 @@ WHERE id = $1 ` type UpdateLastActiveParams struct { - ID string `json:"id"` + ID pgtype.UUID `json:"id"` LastActiveAt pgtype.Timestamptz `json:"last_active_at"` } @@ -355,11 +408,11 @@ SET status = 'running', last_active_at = $4, last_updated = NOW() WHERE id = $1 -RETURNING id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id +RETURNING id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id ` type UpdateSandboxRunningParams struct { - ID string `json:"id"` + ID pgtype.UUID `json:"id"` HostIp string `json:"host_ip"` GuestIp string `json:"guest_ip"` StartedAt pgtype.Timestamptz `json:"started_at"` @@ -375,19 +428,22 @@ func (q *Queries) UpdateSandboxRunning(ctx context.Context, arg UpdateSandboxRun var i Sandbox err := row.Scan( &i.ID, + &i.TeamID, &i.HostID, &i.Template, &i.Status, &i.Vcpus, &i.MemoryMb, &i.TimeoutSec, + &i.DiskSizeMb, &i.GuestIp, &i.HostIp, &i.CreatedAt, &i.StartedAt, &i.LastActiveAt, &i.LastUpdated, - &i.TeamID, + &i.TemplateID, + &i.TemplateTeamID, ) return i, err } @@ -397,12 +453,12 @@ UPDATE sandboxes SET status = $2, last_updated = NOW() WHERE id = $1 -RETURNING id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id +RETURNING id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id ` type UpdateSandboxStatusParams struct { - ID 
string `json:"id"` - Status string `json:"status"` + ID pgtype.UUID `json:"id"` + Status string `json:"status"` } func (q *Queries) UpdateSandboxStatus(ctx context.Context, arg UpdateSandboxStatusParams) (Sandbox, error) { @@ -410,19 +466,22 @@ func (q *Queries) UpdateSandboxStatus(ctx context.Context, arg UpdateSandboxStat var i Sandbox err := row.Scan( &i.ID, + &i.TeamID, &i.HostID, &i.Template, &i.Status, &i.Vcpus, &i.MemoryMb, &i.TimeoutSec, + &i.DiskSizeMb, &i.GuestIp, &i.HostIp, &i.CreatedAt, &i.StartedAt, &i.LastActiveAt, &i.LastUpdated, - &i.TeamID, + &i.TemplateID, + &i.TemplateTeamID, ) return i, err } diff --git a/internal/db/teams.sql.go b/internal/db/teams.sql.go index a00f5efc..334141ff 100644 --- a/internal/db/teams.sql.go +++ b/internal/db/teams.sql.go @@ -16,8 +16,8 @@ DELETE FROM users_teams WHERE team_id = $1 AND user_id = $2 ` type DeleteTeamMemberParams struct { - TeamID string `json:"team_id"` - UserID string `json:"user_id"` + TeamID pgtype.UUID `json:"team_id"` + UserID pgtype.UUID `json:"user_id"` } func (q *Queries) DeleteTeamMember(ctx context.Context, arg DeleteTeamMemberParams) error { @@ -26,7 +26,7 @@ func (q *Queries) DeleteTeamMember(ctx context.Context, arg DeleteTeamMemberPara } const getBYOCTeams = `-- name: GetBYOCTeams :many -SELECT id, name, created_at, is_byoc, slug, deleted_at FROM teams WHERE is_byoc = TRUE AND deleted_at IS NULL ORDER BY created_at +SELECT id, name, slug, is_byoc, created_at, deleted_at FROM teams WHERE is_byoc = TRUE AND deleted_at IS NULL ORDER BY created_at ` func (q *Queries) GetBYOCTeams(ctx context.Context) ([]Team, error) { @@ -41,9 +41,9 @@ func (q *Queries) GetBYOCTeams(ctx context.Context) ([]Team, error) { if err := rows.Scan( &i.ID, &i.Name, - &i.CreatedAt, - &i.IsByoc, &i.Slug, + &i.IsByoc, + &i.CreatedAt, &i.DeletedAt, ); err != nil { return nil, err @@ -57,46 +57,46 @@ func (q *Queries) GetBYOCTeams(ctx context.Context) ([]Team, error) { } const getDefaultTeamForUser = `-- name: 
GetDefaultTeamForUser :one -SELECT t.id, t.name, t.created_at, t.is_byoc, t.slug, t.deleted_at FROM teams t +SELECT t.id, t.name, t.slug, t.is_byoc, t.created_at, t.deleted_at FROM teams t JOIN users_teams ut ON ut.team_id = t.id WHERE ut.user_id = $1 AND ut.is_default = TRUE AND t.deleted_at IS NULL LIMIT 1 ` -func (q *Queries) GetDefaultTeamForUser(ctx context.Context, userID string) (Team, error) { +func (q *Queries) GetDefaultTeamForUser(ctx context.Context, userID pgtype.UUID) (Team, error) { row := q.db.QueryRow(ctx, getDefaultTeamForUser, userID) var i Team err := row.Scan( &i.ID, &i.Name, - &i.CreatedAt, - &i.IsByoc, &i.Slug, + &i.IsByoc, + &i.CreatedAt, &i.DeletedAt, ) return i, err } const getTeam = `-- name: GetTeam :one -SELECT id, name, created_at, is_byoc, slug, deleted_at FROM teams WHERE id = $1 +SELECT id, name, slug, is_byoc, created_at, deleted_at FROM teams WHERE id = $1 ` -func (q *Queries) GetTeam(ctx context.Context, id string) (Team, error) { +func (q *Queries) GetTeam(ctx context.Context, id pgtype.UUID) (Team, error) { row := q.db.QueryRow(ctx, getTeam, id) var i Team err := row.Scan( &i.ID, &i.Name, - &i.CreatedAt, - &i.IsByoc, &i.Slug, + &i.IsByoc, + &i.CreatedAt, &i.DeletedAt, ) return i, err } const getTeamBySlug = `-- name: GetTeamBySlug :one -SELECT id, name, created_at, is_byoc, slug, deleted_at FROM teams WHERE slug = $1 AND deleted_at IS NULL +SELECT id, name, slug, is_byoc, created_at, deleted_at FROM teams WHERE slug = $1 AND deleted_at IS NULL ` func (q *Queries) GetTeamBySlug(ctx context.Context, slug string) (Team, error) { @@ -105,9 +105,9 @@ func (q *Queries) GetTeamBySlug(ctx context.Context, slug string) (Team, error) err := row.Scan( &i.ID, &i.Name, - &i.CreatedAt, - &i.IsByoc, &i.Slug, + &i.IsByoc, + &i.CreatedAt, &i.DeletedAt, ) return i, err @@ -122,14 +122,14 @@ ORDER BY ut.created_at ` type GetTeamMembersRow struct { - ID string `json:"id"` + ID pgtype.UUID `json:"id"` Name string `json:"name"` Email string 
`json:"email"` Role string `json:"role"` JoinedAt pgtype.Timestamptz `json:"joined_at"` } -func (q *Queries) GetTeamMembers(ctx context.Context, teamID string) ([]GetTeamMembersRow, error) { +func (q *Queries) GetTeamMembers(ctx context.Context, teamID pgtype.UUID) ([]GetTeamMembersRow, error) { rows, err := q.db.Query(ctx, getTeamMembers, teamID) if err != nil { return nil, err @@ -160,8 +160,8 @@ SELECT user_id, team_id, is_default, role, created_at FROM users_teams WHERE use ` type GetTeamMembershipParams struct { - UserID string `json:"user_id"` - TeamID string `json:"team_id"` + UserID pgtype.UUID `json:"user_id"` + TeamID pgtype.UUID `json:"team_id"` } func (q *Queries) GetTeamMembership(ctx context.Context, arg GetTeamMembershipParams) (UsersTeam, error) { @@ -186,7 +186,7 @@ ORDER BY ut.created_at ` type GetTeamsForUserRow struct { - ID string `json:"id"` + ID pgtype.UUID `json:"id"` Name string `json:"name"` Slug string `json:"slug"` IsByoc bool `json:"is_byoc"` @@ -195,7 +195,7 @@ type GetTeamsForUserRow struct { Role string `json:"role"` } -func (q *Queries) GetTeamsForUser(ctx context.Context, userID string) ([]GetTeamsForUserRow, error) { +func (q *Queries) GetTeamsForUser(ctx context.Context, userID pgtype.UUID) ([]GetTeamsForUserRow, error) { rows, err := q.db.Query(ctx, getTeamsForUser, userID) if err != nil { return nil, err @@ -226,13 +226,13 @@ func (q *Queries) GetTeamsForUser(ctx context.Context, userID string) ([]GetTeam const insertTeam = `-- name: InsertTeam :one INSERT INTO teams (id, name, slug) VALUES ($1, $2, $3) -RETURNING id, name, created_at, is_byoc, slug, deleted_at +RETURNING id, name, slug, is_byoc, created_at, deleted_at ` type InsertTeamParams struct { - ID string `json:"id"` - Name string `json:"name"` - Slug string `json:"slug"` + ID pgtype.UUID `json:"id"` + Name string `json:"name"` + Slug string `json:"slug"` } func (q *Queries) InsertTeam(ctx context.Context, arg InsertTeamParams) (Team, error) { @@ -241,9 +241,9 @@ func 
(q *Queries) InsertTeam(ctx context.Context, arg InsertTeamParams) (Team, e err := row.Scan( &i.ID, &i.Name, - &i.CreatedAt, - &i.IsByoc, &i.Slug, + &i.IsByoc, + &i.CreatedAt, &i.DeletedAt, ) return i, err @@ -255,10 +255,10 @@ VALUES ($1, $2, $3, $4) ` type InsertTeamMemberParams struct { - UserID string `json:"user_id"` - TeamID string `json:"team_id"` - IsDefault bool `json:"is_default"` - Role string `json:"role"` + UserID pgtype.UUID `json:"user_id"` + TeamID pgtype.UUID `json:"team_id"` + IsDefault bool `json:"is_default"` + Role string `json:"role"` } func (q *Queries) InsertTeamMember(ctx context.Context, arg InsertTeamMemberParams) error { @@ -276,8 +276,8 @@ UPDATE teams SET is_byoc = $2 WHERE id = $1 ` type SetTeamBYOCParams struct { - ID string `json:"id"` - IsByoc bool `json:"is_byoc"` + ID pgtype.UUID `json:"id"` + IsByoc bool `json:"is_byoc"` } func (q *Queries) SetTeamBYOC(ctx context.Context, arg SetTeamBYOCParams) error { @@ -289,7 +289,7 @@ const softDeleteTeam = `-- name: SoftDeleteTeam :exec UPDATE teams SET deleted_at = NOW() WHERE id = $1 ` -func (q *Queries) SoftDeleteTeam(ctx context.Context, id string) error { +func (q *Queries) SoftDeleteTeam(ctx context.Context, id pgtype.UUID) error { _, err := q.db.Exec(ctx, softDeleteTeam, id) return err } @@ -299,9 +299,9 @@ UPDATE users_teams SET role = $3 WHERE team_id = $1 AND user_id = $2 ` type UpdateMemberRoleParams struct { - TeamID string `json:"team_id"` - UserID string `json:"user_id"` - Role string `json:"role"` + TeamID pgtype.UUID `json:"team_id"` + UserID pgtype.UUID `json:"user_id"` + Role string `json:"role"` } func (q *Queries) UpdateMemberRole(ctx context.Context, arg UpdateMemberRoleParams) error { @@ -314,8 +314,8 @@ UPDATE teams SET name = $2 WHERE id = $1 AND deleted_at IS NULL ` type UpdateTeamNameParams struct { - ID string `json:"id"` - Name string `json:"name"` + ID pgtype.UUID `json:"id"` + Name string `json:"name"` } func (q *Queries) UpdateTeamName(ctx context.Context, 
arg UpdateTeamNameParams) error { diff --git a/internal/db/template_builds.sql.go b/internal/db/template_builds.sql.go new file mode 100644 index 00000000..facfb199 --- /dev/null +++ b/internal/db/template_builds.sql.go @@ -0,0 +1,241 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: template_builds.sql + +package db + +import ( + "context" + + "github.com/jackc/pgx/v5/pgtype" +) + +const getTemplateBuild = `-- name: GetTemplateBuild :one +SELECT id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at, template_id, team_id, skip_pre_post FROM template_builds WHERE id = $1 +` + +func (q *Queries) GetTemplateBuild(ctx context.Context, id pgtype.UUID) (TemplateBuild, error) { + row := q.db.QueryRow(ctx, getTemplateBuild, id) + var i TemplateBuild + err := row.Scan( + &i.ID, + &i.Name, + &i.BaseTemplate, + &i.Recipe, + &i.Healthcheck, + &i.Vcpus, + &i.MemoryMb, + &i.Status, + &i.CurrentStep, + &i.TotalSteps, + &i.Logs, + &i.Error, + &i.SandboxID, + &i.HostID, + &i.CreatedAt, + &i.StartedAt, + &i.CompletedAt, + &i.TemplateID, + &i.TeamID, + &i.SkipPrePost, + ) + return i, err +} + +const insertTemplateBuild = `-- name: InsertTemplateBuild :one +INSERT INTO template_builds (id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, total_steps, template_id, team_id, skip_pre_post) +VALUES ($1, $2, $3, $4, $5, $6, $7, 'pending', $8, $9, $10, $11) +RETURNING id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at, template_id, team_id, skip_pre_post +` + +type InsertTemplateBuildParams struct { + ID pgtype.UUID `json:"id"` + Name string `json:"name"` + BaseTemplate string `json:"base_template"` + Recipe []byte `json:"recipe"` + Healthcheck string `json:"healthcheck"` + Vcpus int32 `json:"vcpus"` + 
MemoryMb int32 `json:"memory_mb"` + TotalSteps int32 `json:"total_steps"` + TemplateID pgtype.UUID `json:"template_id"` + TeamID pgtype.UUID `json:"team_id"` + SkipPrePost bool `json:"skip_pre_post"` +} + +func (q *Queries) InsertTemplateBuild(ctx context.Context, arg InsertTemplateBuildParams) (TemplateBuild, error) { + row := q.db.QueryRow(ctx, insertTemplateBuild, + arg.ID, + arg.Name, + arg.BaseTemplate, + arg.Recipe, + arg.Healthcheck, + arg.Vcpus, + arg.MemoryMb, + arg.TotalSteps, + arg.TemplateID, + arg.TeamID, + arg.SkipPrePost, + ) + var i TemplateBuild + err := row.Scan( + &i.ID, + &i.Name, + &i.BaseTemplate, + &i.Recipe, + &i.Healthcheck, + &i.Vcpus, + &i.MemoryMb, + &i.Status, + &i.CurrentStep, + &i.TotalSteps, + &i.Logs, + &i.Error, + &i.SandboxID, + &i.HostID, + &i.CreatedAt, + &i.StartedAt, + &i.CompletedAt, + &i.TemplateID, + &i.TeamID, + &i.SkipPrePost, + ) + return i, err +} + +const listTemplateBuilds = `-- name: ListTemplateBuilds :many +SELECT id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at, template_id, team_id, skip_pre_post FROM template_builds ORDER BY created_at DESC +` + +func (q *Queries) ListTemplateBuilds(ctx context.Context) ([]TemplateBuild, error) { + rows, err := q.db.Query(ctx, listTemplateBuilds) + if err != nil { + return nil, err + } + defer rows.Close() + var items []TemplateBuild + for rows.Next() { + var i TemplateBuild + if err := rows.Scan( + &i.ID, + &i.Name, + &i.BaseTemplate, + &i.Recipe, + &i.Healthcheck, + &i.Vcpus, + &i.MemoryMb, + &i.Status, + &i.CurrentStep, + &i.TotalSteps, + &i.Logs, + &i.Error, + &i.SandboxID, + &i.HostID, + &i.CreatedAt, + &i.StartedAt, + &i.CompletedAt, + &i.TemplateID, + &i.TeamID, + &i.SkipPrePost, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const 
updateBuildError = `-- name: UpdateBuildError :exec +UPDATE template_builds +SET error = $2, status = 'failed', completed_at = NOW() +WHERE id = $1 +` + +type UpdateBuildErrorParams struct { + ID pgtype.UUID `json:"id"` + Error string `json:"error"` +} + +func (q *Queries) UpdateBuildError(ctx context.Context, arg UpdateBuildErrorParams) error { + _, err := q.db.Exec(ctx, updateBuildError, arg.ID, arg.Error) + return err +} + +const updateBuildProgress = `-- name: UpdateBuildProgress :exec +UPDATE template_builds +SET current_step = $2, logs = $3 +WHERE id = $1 +` + +type UpdateBuildProgressParams struct { + ID pgtype.UUID `json:"id"` + CurrentStep int32 `json:"current_step"` + Logs []byte `json:"logs"` +} + +func (q *Queries) UpdateBuildProgress(ctx context.Context, arg UpdateBuildProgressParams) error { + _, err := q.db.Exec(ctx, updateBuildProgress, arg.ID, arg.CurrentStep, arg.Logs) + return err +} + +const updateBuildSandbox = `-- name: UpdateBuildSandbox :exec +UPDATE template_builds +SET sandbox_id = $2, host_id = $3 +WHERE id = $1 +` + +type UpdateBuildSandboxParams struct { + ID pgtype.UUID `json:"id"` + SandboxID pgtype.UUID `json:"sandbox_id"` + HostID pgtype.UUID `json:"host_id"` +} + +func (q *Queries) UpdateBuildSandbox(ctx context.Context, arg UpdateBuildSandboxParams) error { + _, err := q.db.Exec(ctx, updateBuildSandbox, arg.ID, arg.SandboxID, arg.HostID) + return err +} + +const updateBuildStatus = `-- name: UpdateBuildStatus :one +UPDATE template_builds +SET status = $2, + started_at = CASE WHEN $2 = 'running' AND started_at IS NULL THEN NOW() ELSE started_at END, + completed_at = CASE WHEN $2 IN ('success', 'failed', 'cancelled') THEN NOW() ELSE completed_at END +WHERE id = $1 +RETURNING id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at, template_id, team_id, skip_pre_post +` + +type UpdateBuildStatusParams struct { + ID 
pgtype.UUID `json:"id"` + Status string `json:"status"` +} + +func (q *Queries) UpdateBuildStatus(ctx context.Context, arg UpdateBuildStatusParams) (TemplateBuild, error) { + row := q.db.QueryRow(ctx, updateBuildStatus, arg.ID, arg.Status) + var i TemplateBuild + err := row.Scan( + &i.ID, + &i.Name, + &i.BaseTemplate, + &i.Recipe, + &i.Healthcheck, + &i.Vcpus, + &i.MemoryMb, + &i.Status, + &i.CurrentStep, + &i.TotalSteps, + &i.Logs, + &i.Error, + &i.SandboxID, + &i.HostID, + &i.CreatedAt, + &i.StartedAt, + &i.CompletedAt, + &i.TemplateID, + &i.TeamID, + &i.SkipPrePost, + ) + return i, err +} diff --git a/internal/db/templates.sql.go b/internal/db/templates.sql.go index cafae692..7d37808d 100644 --- a/internal/db/templates.sql.go +++ b/internal/db/templates.sql.go @@ -12,11 +12,11 @@ import ( ) const deleteTemplate = `-- name: DeleteTemplate :exec -DELETE FROM templates WHERE name = $1 +DELETE FROM templates WHERE id = $1 ` -func (q *Queries) DeleteTemplate(ctx context.Context, name string) error { - _, err := q.db.Exec(ctx, deleteTemplate, name) +func (q *Queries) DeleteTemplate(ctx context.Context, id pgtype.UUID) error { + _, err := q.db.Exec(ctx, deleteTemplate, id) return err } @@ -25,8 +25,8 @@ DELETE FROM templates WHERE name = $1 AND team_id = $2 ` type DeleteTemplateByTeamParams struct { - Name string `json:"name"` - TeamID string `json:"team_id"` + Name string `json:"name"` + TeamID pgtype.UUID `json:"team_id"` } func (q *Queries) DeleteTemplateByTeam(ctx context.Context, arg DeleteTemplateByTeamParams) error { @@ -34,12 +34,23 @@ func (q *Queries) DeleteTemplateByTeam(ctx context.Context, arg DeleteTemplateBy return err } -const getTemplate = `-- name: GetTemplate :one -SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE name = $1 +const deleteTemplatesByTeam = `-- name: DeleteTemplatesByTeam :exec +DELETE FROM templates WHERE team_id = $1 ` -func (q *Queries) GetTemplate(ctx context.Context, name string) (Template, 
error) { - row := q.db.QueryRow(ctx, getTemplate, name) +// Bulk delete all templates owned by a team (for team soft-delete cleanup). +func (q *Queries) DeleteTemplatesByTeam(ctx context.Context, teamID pgtype.UUID) error { + _, err := q.db.Exec(ctx, deleteTemplatesByTeam, teamID) + return err +} + +const getPlatformTemplateByName = `-- name: GetPlatformTemplateByName :one +SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE team_id = '00000000-0000-0000-0000-000000000000' AND name = $1 +` + +// Check if a global (platform) template exists with the given name. +func (q *Queries) GetPlatformTemplateByName(ctx context.Context, name string) (Template, error) { + row := q.db.QueryRow(ctx, getPlatformTemplateByName, name) var i Template err := row.Scan( &i.Name, @@ -49,19 +60,67 @@ func (q *Queries) GetTemplate(ctx context.Context, name string) (Template, error &i.SizeBytes, &i.CreatedAt, &i.TeamID, + &i.ID, + ) + return i, err +} + +const getTemplate = `-- name: GetTemplate :one +SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE id = $1 +` + +func (q *Queries) GetTemplate(ctx context.Context, id pgtype.UUID) (Template, error) { + row := q.db.QueryRow(ctx, getTemplate, id) + var i Template + err := row.Scan( + &i.Name, + &i.Type, + &i.Vcpus, + &i.MemoryMb, + &i.SizeBytes, + &i.CreatedAt, + &i.TeamID, + &i.ID, + ) + return i, err +} + +const getTemplateByName = `-- name: GetTemplateByName :one +SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE team_id = $1 AND name = $2 +` + +type GetTemplateByNameParams struct { + TeamID pgtype.UUID `json:"team_id"` + Name string `json:"name"` +} + +// Look up a template by team_id and name (exact team match, no global fallback). 
+func (q *Queries) GetTemplateByName(ctx context.Context, arg GetTemplateByNameParams) (Template, error) { + row := q.db.QueryRow(ctx, getTemplateByName, arg.TeamID, arg.Name) + var i Template + err := row.Scan( + &i.Name, + &i.Type, + &i.Vcpus, + &i.MemoryMb, + &i.SizeBytes, + &i.CreatedAt, + &i.TeamID, + &i.ID, ) return i, err } const getTemplateByTeam = `-- name: GetTemplateByTeam :one -SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE name = $1 AND team_id = $2 +SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE name = $1 AND (team_id = $2 OR team_id = '00000000-0000-0000-0000-000000000000') ` type GetTemplateByTeamParams struct { - Name string `json:"name"` - TeamID string `json:"team_id"` + Name string `json:"name"` + TeamID pgtype.UUID `json:"team_id"` } +// Platform templates (team_id = 00000000-...) are visible to all teams. func (q *Queries) GetTemplateByTeam(ctx context.Context, arg GetTemplateByTeamParams) (Template, error) { row := q.db.QueryRow(ctx, getTemplateByTeam, arg.Name, arg.TeamID) var i Template @@ -73,27 +132,30 @@ func (q *Queries) GetTemplateByTeam(ctx context.Context, arg GetTemplateByTeamPa &i.SizeBytes, &i.CreatedAt, &i.TeamID, + &i.ID, ) return i, err } const insertTemplate = `-- name: InsertTemplate :one -INSERT INTO templates (name, type, vcpus, memory_mb, size_bytes, team_id) -VALUES ($1, $2, $3, $4, $5, $6) -RETURNING name, type, vcpus, memory_mb, size_bytes, created_at, team_id +INSERT INTO templates (id, name, type, vcpus, memory_mb, size_bytes, team_id) +VALUES ($1, $2, $3, $4, $5, $6, $7) +RETURNING name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id ` type InsertTemplateParams struct { + ID pgtype.UUID `json:"id"` Name string `json:"name"` Type string `json:"type"` - Vcpus pgtype.Int4 `json:"vcpus"` - MemoryMb pgtype.Int4 `json:"memory_mb"` + Vcpus int32 `json:"vcpus"` + MemoryMb int32 `json:"memory_mb"` SizeBytes int64 
`json:"size_bytes"` - TeamID string `json:"team_id"` + TeamID pgtype.UUID `json:"team_id"` } func (q *Queries) InsertTemplate(ctx context.Context, arg InsertTemplateParams) (Template, error) { row := q.db.QueryRow(ctx, insertTemplate, + arg.ID, arg.Name, arg.Type, arg.Vcpus, @@ -110,12 +172,13 @@ func (q *Queries) InsertTemplate(ctx context.Context, arg InsertTemplateParams) &i.SizeBytes, &i.CreatedAt, &i.TeamID, + &i.ID, ) return i, err } const listTemplates = `-- name: ListTemplates :many -SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates ORDER BY created_at DESC +SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates ORDER BY created_at DESC ` func (q *Queries) ListTemplates(ctx context.Context) ([]Template, error) { @@ -135,6 +198,7 @@ func (q *Queries) ListTemplates(ctx context.Context) ([]Template, error) { &i.SizeBytes, &i.CreatedAt, &i.TeamID, + &i.ID, ); err != nil { return nil, err } @@ -147,10 +211,11 @@ func (q *Queries) ListTemplates(ctx context.Context) ([]Template, error) { } const listTemplatesByTeam = `-- name: ListTemplatesByTeam :many -SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE team_id = $1 ORDER BY created_at DESC +SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE (team_id = $1 OR team_id = '00000000-0000-0000-0000-000000000000') ORDER BY created_at DESC ` -func (q *Queries) ListTemplatesByTeam(ctx context.Context, teamID string) ([]Template, error) { +// Platform templates are visible to all teams. 
+func (q *Queries) ListTemplatesByTeam(ctx context.Context, teamID pgtype.UUID) ([]Template, error) { rows, err := q.db.Query(ctx, listTemplatesByTeam, teamID) if err != nil { return nil, err @@ -167,6 +232,7 @@ func (q *Queries) ListTemplatesByTeam(ctx context.Context, teamID string) ([]Tem &i.SizeBytes, &i.CreatedAt, &i.TeamID, + &i.ID, ); err != nil { return nil, err } @@ -179,14 +245,15 @@ func (q *Queries) ListTemplatesByTeam(ctx context.Context, teamID string) ([]Tem } const listTemplatesByTeamAndType = `-- name: ListTemplatesByTeamAndType :many -SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE team_id = $1 AND type = $2 ORDER BY created_at DESC +SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE (team_id = $1 OR team_id = '00000000-0000-0000-0000-000000000000') AND type = $2 ORDER BY created_at DESC ` type ListTemplatesByTeamAndTypeParams struct { - TeamID string `json:"team_id"` - Type string `json:"type"` + TeamID pgtype.UUID `json:"team_id"` + Type string `json:"type"` } +// Platform templates are visible to all teams. func (q *Queries) ListTemplatesByTeamAndType(ctx context.Context, arg ListTemplatesByTeamAndTypeParams) ([]Template, error) { rows, err := q.db.Query(ctx, listTemplatesByTeamAndType, arg.TeamID, arg.Type) if err != nil { @@ -204,6 +271,41 @@ func (q *Queries) ListTemplatesByTeamAndType(ctx context.Context, arg ListTempla &i.SizeBytes, &i.CreatedAt, &i.TeamID, + &i.ID, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listTemplatesByTeamOnly = `-- name: ListTemplatesByTeamOnly :many +SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE team_id = $1 ORDER BY created_at DESC +` + +// List templates owned by a specific team (NOT including platform templates). 
+func (q *Queries) ListTemplatesByTeamOnly(ctx context.Context, teamID pgtype.UUID) ([]Template, error) { + rows, err := q.db.Query(ctx, listTemplatesByTeamOnly, teamID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Template + for rows.Next() { + var i Template + if err := rows.Scan( + &i.Name, + &i.Type, + &i.Vcpus, + &i.MemoryMb, + &i.SizeBytes, + &i.CreatedAt, + &i.TeamID, + &i.ID, ); err != nil { return nil, err } @@ -216,7 +318,7 @@ func (q *Queries) ListTemplatesByTeamAndType(ctx context.Context, arg ListTempla } const listTemplatesByType = `-- name: ListTemplatesByType :many -SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE type = $1 ORDER BY created_at DESC +SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE type = $1 ORDER BY created_at DESC ` func (q *Queries) ListTemplatesByType(ctx context.Context, type_ string) ([]Template, error) { @@ -236,6 +338,7 @@ func (q *Queries) ListTemplatesByType(ctx context.Context, type_ string) ([]Temp &i.SizeBytes, &i.CreatedAt, &i.TeamID, + &i.ID, ); err != nil { return nil, err } diff --git a/internal/db/users.sql.go b/internal/db/users.sql.go index 50ba287e..9de866bb 100644 --- a/internal/db/users.sql.go +++ b/internal/db/users.sql.go @@ -16,8 +16,8 @@ DELETE FROM admin_permissions WHERE user_id = $1 AND permission = $2 ` type DeleteAdminPermissionParams struct { - UserID string `json:"user_id"` - Permission string `json:"permission"` + UserID pgtype.UUID `json:"user_id"` + Permission string `json:"permission"` } func (q *Queries) DeleteAdminPermission(ctx context.Context, arg DeleteAdminPermissionParams) error { @@ -29,7 +29,7 @@ const getAdminPermissions = `-- name: GetAdminPermissions :many SELECT id, user_id, permission, created_at FROM admin_permissions WHERE user_id = $1 ORDER BY permission ` -func (q *Queries) GetAdminPermissions(ctx context.Context, userID string) ([]AdminPermission, error) { +func 
(q *Queries) GetAdminPermissions(ctx context.Context, userID pgtype.UUID) ([]AdminPermission, error) { rows, err := q.db.Query(ctx, getAdminPermissions, userID) if err != nil { return nil, err @@ -55,7 +55,7 @@ func (q *Queries) GetAdminPermissions(ctx context.Context, userID string) ([]Adm } const getAdminUsers = `-- name: GetAdminUsers :many -SELECT id, email, password_hash, created_at, updated_at, is_admin, name FROM users WHERE is_admin = TRUE ORDER BY created_at +SELECT id, email, password_hash, name, is_admin, created_at, updated_at FROM users WHERE is_admin = TRUE ORDER BY created_at ` func (q *Queries) GetAdminUsers(ctx context.Context) ([]User, error) { @@ -71,10 +71,10 @@ func (q *Queries) GetAdminUsers(ctx context.Context) ([]User, error) { &i.ID, &i.Email, &i.PasswordHash, + &i.Name, + &i.IsAdmin, &i.CreatedAt, &i.UpdatedAt, - &i.IsAdmin, - &i.Name, ); err != nil { return nil, err } @@ -87,7 +87,7 @@ func (q *Queries) GetAdminUsers(ctx context.Context) ([]User, error) { } const getUserByEmail = `-- name: GetUserByEmail :one -SELECT id, email, password_hash, created_at, updated_at, is_admin, name FROM users WHERE email = $1 +SELECT id, email, password_hash, name, is_admin, created_at, updated_at FROM users WHERE email = $1 ` func (q *Queries) GetUserByEmail(ctx context.Context, email string) (User, error) { @@ -97,29 +97,29 @@ func (q *Queries) GetUserByEmail(ctx context.Context, email string) (User, error &i.ID, &i.Email, &i.PasswordHash, + &i.Name, + &i.IsAdmin, &i.CreatedAt, &i.UpdatedAt, - &i.IsAdmin, - &i.Name, ) return i, err } const getUserByID = `-- name: GetUserByID :one -SELECT id, email, password_hash, created_at, updated_at, is_admin, name FROM users WHERE id = $1 +SELECT id, email, password_hash, name, is_admin, created_at, updated_at FROM users WHERE id = $1 ` -func (q *Queries) GetUserByID(ctx context.Context, id string) (User, error) { +func (q *Queries) GetUserByID(ctx context.Context, id pgtype.UUID) (User, error) { row := 
q.db.QueryRow(ctx, getUserByID, id) var i User err := row.Scan( &i.ID, &i.Email, &i.PasswordHash, + &i.Name, + &i.IsAdmin, &i.CreatedAt, &i.UpdatedAt, - &i.IsAdmin, - &i.Name, ) return i, err } @@ -131,8 +131,8 @@ SELECT EXISTS( ` type HasAdminPermissionParams struct { - UserID string `json:"user_id"` - Permission string `json:"permission"` + UserID pgtype.UUID `json:"user_id"` + Permission string `json:"permission"` } func (q *Queries) HasAdminPermission(ctx context.Context, arg HasAdminPermissionParams) (bool, error) { @@ -148,9 +148,9 @@ VALUES ($1, $2, $3) ` type InsertAdminPermissionParams struct { - ID string `json:"id"` - UserID string `json:"user_id"` - Permission string `json:"permission"` + ID pgtype.UUID `json:"id"` + UserID pgtype.UUID `json:"user_id"` + Permission string `json:"permission"` } func (q *Queries) InsertAdminPermission(ctx context.Context, arg InsertAdminPermissionParams) error { @@ -161,11 +161,11 @@ func (q *Queries) InsertAdminPermission(ctx context.Context, arg InsertAdminPerm const insertUser = `-- name: InsertUser :one INSERT INTO users (id, email, password_hash, name) VALUES ($1, $2, $3, $4) -RETURNING id, email, password_hash, created_at, updated_at, is_admin, name +RETURNING id, email, password_hash, name, is_admin, created_at, updated_at ` type InsertUserParams struct { - ID string `json:"id"` + ID pgtype.UUID `json:"id"` Email string `json:"email"` PasswordHash pgtype.Text `json:"password_hash"` Name string `json:"name"` @@ -183,10 +183,10 @@ func (q *Queries) InsertUser(ctx context.Context, arg InsertUserParams) (User, e &i.ID, &i.Email, &i.PasswordHash, + &i.Name, + &i.IsAdmin, &i.CreatedAt, &i.UpdatedAt, - &i.IsAdmin, - &i.Name, ) return i, err } @@ -194,13 +194,13 @@ func (q *Queries) InsertUser(ctx context.Context, arg InsertUserParams) (User, e const insertUserOAuth = `-- name: InsertUserOAuth :one INSERT INTO users (id, email, name) VALUES ($1, $2, $3) -RETURNING id, email, password_hash, created_at, updated_at, is_admin, 
name +RETURNING id, email, password_hash, name, is_admin, created_at, updated_at ` type InsertUserOAuthParams struct { - ID string `json:"id"` - Email string `json:"email"` - Name string `json:"name"` + ID pgtype.UUID `json:"id"` + Email string `json:"email"` + Name string `json:"name"` } func (q *Queries) InsertUserOAuth(ctx context.Context, arg InsertUserOAuthParams) (User, error) { @@ -210,10 +210,10 @@ func (q *Queries) InsertUserOAuth(ctx context.Context, arg InsertUserOAuthParams &i.ID, &i.Email, &i.PasswordHash, + &i.Name, + &i.IsAdmin, &i.CreatedAt, &i.UpdatedAt, - &i.IsAdmin, - &i.Name, ) return i, err } @@ -223,8 +223,8 @@ SELECT id, email FROM users WHERE email LIKE $1 || '%' ORDER BY email LIMIT 10 ` type SearchUsersByEmailPrefixRow struct { - ID string `json:"id"` - Email string `json:"email"` + ID pgtype.UUID `json:"id"` + Email string `json:"email"` } func (q *Queries) SearchUsersByEmailPrefix(ctx context.Context, dollar_1 pgtype.Text) ([]SearchUsersByEmailPrefixRow, error) { @@ -252,8 +252,8 @@ UPDATE users SET is_admin = $2, updated_at = NOW() WHERE id = $1 ` type SetUserAdminParams struct { - ID string `json:"id"` - IsAdmin bool `json:"is_admin"` + ID pgtype.UUID `json:"id"` + IsAdmin bool `json:"is_admin"` } func (q *Queries) SetUserAdmin(ctx context.Context, arg SetUserAdminParams) error { @@ -266,8 +266,8 @@ UPDATE users SET name = $2, updated_at = NOW() WHERE id = $1 ` type UpdateUserNameParams struct { - ID string `json:"id"` - Name string `json:"name"` + ID pgtype.UUID `json:"id"` + Name string `json:"name"` } func (q *Queries) UpdateUserName(ctx context.Context, arg UpdateUserNameParams) error { diff --git a/internal/devicemapper/devicemapper.go b/internal/devicemapper/devicemapper.go index ea14fcd8..9fa08332 100644 --- a/internal/devicemapper/devicemapper.go +++ b/internal/devicemapper/devicemapper.go @@ -116,9 +116,10 @@ type SnapshotDevice struct { // writable CoW layer. 
// // The origin loop device must already exist (from LoopRegistry.Acquire). -func CreateSnapshot(name, originLoopDev, cowPath string, originSizeBytes int64) (*SnapshotDevice, error) { - // Create sparse CoW file sized to match the origin. - if err := createSparseFile(cowPath, originSizeBytes); err != nil { +func CreateSnapshot(name, originLoopDev, cowPath string, originSizeBytes, cowSizeBytes int64) (*SnapshotDevice, error) { + // Create sparse CoW file. The logical size limits how many blocks can be + // modified; because the file is sparse, only written blocks use real disk. + if err := createSparseFile(cowPath, cowSizeBytes); err != nil { return nil, fmt.Errorf("create cow file: %w", err) } @@ -128,6 +129,9 @@ func CreateSnapshot(name, originLoopDev, cowPath string, originSizeBytes int64) return nil, fmt.Errorf("losetup cow: %w", err) } + // The dm-snapshot virtual device size must match the origin — the snapshot + // target maps 1:1 onto origin sectors. The CoW file just needs enough + // space to store all modified blocks (it's sparse, so 20GB costs nothing). sectors := originSizeBytes / 512 if err := dmsetupCreate(name, originLoopDev, cowLoopDev, sectors); err != nil { if detachErr := losetupDetach(cowLoopDev); detachErr != nil { @@ -220,6 +224,7 @@ func FlattenSnapshot(dmDevPath, outputPath string) error { "if="+dmDevPath, "of="+outputPath, "bs=4M", + "conv=sparse", "status=none", ) if out, err := cmd.CombinedOutput(); err != nil { diff --git a/internal/envdclient/client.go b/internal/envdclient/client.go index 04a1dc2e..49765690 100644 --- a/internal/envdclient/client.go +++ b/internal/envdclient/client.go @@ -3,14 +3,12 @@ package envdclient import ( "bytes" "context" - "encoding/json" "fmt" "io" "log/slog" "mime/multipart" "net/http" "net/url" - "time" "connectrpc.com/connect" @@ -49,35 +47,6 @@ func (c *Client) BaseURL() string { return c.base } -// Init calls POST /init on envd to sync the guest clock with the host. 
-// This is important after snapshot resume where the guest clock is frozen. -func (c *Client) Init(ctx context.Context) error { - now := time.Now().UTC() - body, err := json.Marshal(map[string]any{"timestamp": now}) - if err != nil { - return fmt.Errorf("marshal init body: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.base+"/init", bytes.NewReader(body)) - if err != nil { - return fmt.Errorf("create init request: %w", err) - } - req.Header.Set("Content-Type", "application/json") - - resp, err := c.httpClient.Do(req) - if err != nil { - return fmt.Errorf("init request: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusNoContent { - respBody, _ := io.ReadAll(resp.Body) - return fmt.Errorf("init: status %d: %s", resp.StatusCode, string(respBody)) - } - - return nil -} - // ExecResult holds the output of a command execution. type ExecResult struct { Stdout []byte diff --git a/internal/hostagent/certstore.go b/internal/hostagent/certstore.go new file mode 100644 index 00000000..4260ba2f --- /dev/null +++ b/internal/hostagent/certstore.go @@ -0,0 +1,42 @@ +package hostagent + +import ( + "crypto/tls" + "fmt" + "sync/atomic" +) + +// CertStore provides lock-free read/write access to the agent's current TLS +// certificate. It is used with tls.Config.GetCertificate to enable hot-swap +// of the agent's cert on JWT refresh without restarting the server. +// +// The zero value is usable; GetCert returns an error until a cert is stored. +type CertStore struct { + ptr atomic.Pointer[tls.Certificate] +} + +// Store atomically replaces the current certificate. +func (s *CertStore) Store(cert *tls.Certificate) { + s.ptr.Store(cert) +} + +// ParseAndStore parses certPEM+keyPEM and atomically replaces the stored cert. +// If parsing fails the existing cert is unchanged. 
+func (s *CertStore) ParseAndStore(certPEM, keyPEM string) error { + cert, err := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM)) + if err != nil { + return fmt.Errorf("parse TLS key pair: %w", err) + } + s.ptr.Store(&cert) + return nil +} + +// GetCert satisfies tls.Config.GetCertificate. Returns an error if no cert has +// been stored yet. +func (s *CertStore) GetCert(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { + cert := s.ptr.Load() + if cert == nil { + return nil, fmt.Errorf("no TLS certificate available") + } + return cert, nil +} diff --git a/internal/hostagent/proxy.go b/internal/hostagent/proxy.go new file mode 100644 index 00000000..bbee4741 --- /dev/null +++ b/internal/hostagent/proxy.go @@ -0,0 +1,89 @@ +package hostagent + +import ( + "fmt" + "log/slog" + "net/http" + "net/http/httputil" + "strconv" + "strings" + + "git.omukk.dev/wrenn/sandbox/internal/sandbox" +) + +// ProxyHandler reverse-proxies HTTP requests to services running inside +// sandboxes. It handles requests of the form: +// +// /proxy/{sandbox_id}/{port}/{path...} +// +// The sandbox's HostIP (routable on this machine) is used as the upstream. +// This supports any protocol that rides on HTTP, including WebSocket upgrades. +type ProxyHandler struct { + mgr *sandbox.Manager + transport http.RoundTripper +} + +// NewProxyHandler creates a new sandbox proxy handler. +func NewProxyHandler(mgr *sandbox.Manager) *ProxyHandler { + return &ProxyHandler{ + mgr: mgr, + transport: http.DefaultTransport, + } +} + +// ServeHTTP implements http.Handler. +func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Expected path: /proxy/{sandbox_id}/{port}/... + // After trimming "/proxy/", we get "{sandbox_id}/{port}/..." 
+ trimmed := strings.TrimPrefix(r.URL.Path, "/proxy/") + if trimmed == r.URL.Path { + http.Error(w, "invalid proxy path", http.StatusBadRequest) + return + } + + parts := strings.SplitN(trimmed, "/", 3) + if len(parts) < 2 { + http.Error(w, "expected /proxy/{sandbox_id}/{port}/...", http.StatusBadRequest) + return + } + + sandboxID := parts[0] + port := parts[1] + remainder := "" + if len(parts) == 3 { + remainder = parts[2] + } + + // Validate port is a number in the valid range. + portNum, err := strconv.Atoi(port) + if err != nil || portNum < 1 || portNum > 65535 { + http.Error(w, "invalid port", http.StatusBadRequest) + return + } + + hostIP, tracker, ok := h.mgr.AcquireProxyConn(sandboxID) + if !ok { + http.Error(w, "sandbox is not available", http.StatusServiceUnavailable) + return + } + defer tracker.Release() + + targetHost := fmt.Sprintf("%s:%d", hostIP, portNum) + + proxy := &httputil.ReverseProxy{ + Transport: h.transport, + Director: func(req *http.Request) { + req.URL.Scheme = "http" + req.URL.Host = targetHost + req.URL.Path = "/" + remainder + req.URL.RawQuery = r.URL.RawQuery + req.Host = targetHost + }, + ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) { + slog.Debug("proxy error", "sandbox_id", sandboxID, "port", port, "error", err) + http.Error(w, "proxy error: "+err.Error(), http.StatusBadGateway) + }, + } + + proxy.ServeHTTP(w, r) +} diff --git a/internal/hostagent/registration.go b/internal/hostagent/registration.go index 9f39c3b0..07909ee1 100644 --- a/internal/hostagent/registration.go +++ b/internal/hostagent/registration.go @@ -17,18 +17,24 @@ import ( "golang.org/x/sys/unix" ) -// tokenFile is the JSON format persisted to AGENT_FILES_ROOTDIR/host.jwt. -type tokenFile struct { +// TokenFile is the JSON format persisted to WRENN_DIR/host-credentials.json. +// It holds all credentials the agent needs: the host JWT, refresh token, and +// (when mTLS is enabled) the TLS certificate material for the agent's server. 
+type TokenFile struct { HostID string `json:"host_id"` JWT string `json:"jwt"` RefreshToken string `json:"refresh_token"` + // mTLS fields — empty when the CP has no CA configured. + CertPEM string `json:"cert_pem,omitempty"` + KeyPEM string `json:"key_pem,omitempty"` + CACertPEM string `json:"ca_cert_pem,omitempty"` } // RegistrationConfig holds the configuration for host registration. type RegistrationConfig struct { CPURL string // Control plane base URL (e.g., http://localhost:8000) RegistrationToken string // One-time registration token from the control plane - TokenFile string // Path to persist the host JWT after registration + TokenFile string // Path to persist the credentials after registration Address string // Externally-reachable address (ip:port) for this host } @@ -41,22 +47,20 @@ type registerRequest struct { Address string `json:"address"` } -type registerResponse struct { +// authResponse is the shared JSON shape for both register and refresh responses. +type authResponse struct { Host json.RawMessage `json:"host"` Token string `json:"token"` RefreshToken string `json:"refresh_token"` + CertPEM string `json:"cert_pem,omitempty"` + KeyPEM string `json:"key_pem,omitempty"` + CACertPEM string `json:"ca_cert_pem,omitempty"` } type refreshRequest struct { RefreshToken string `json:"refresh_token"` } -type refreshResponse struct { - Host json.RawMessage `json:"host"` - Token string `json:"token"` - RefreshToken string `json:"refresh_token"` -} - type errorResponse struct { Error struct { Code string `json:"code"` @@ -64,8 +68,8 @@ type errorResponse struct { } `json:"error"` } -// loadTokenFile reads and parses the persisted token file. -func loadTokenFile(path string) (*tokenFile, error) { +// LoadTokenFile reads and parses the persisted credentials file. 
+func LoadTokenFile(path string) (*TokenFile, error) { data, err := os.ReadFile(path) if err != nil { return nil, err @@ -75,36 +79,36 @@ func loadTokenFile(path string) (*tokenFile, error) { if !strings.HasPrefix(trimmed, "{") { // Old format: just the JWT, no refresh token. hostID, _ := hostIDFromJWT(trimmed) - return &tokenFile{HostID: hostID, JWT: trimmed}, nil + return &TokenFile{HostID: hostID, JWT: trimmed}, nil } - var tf tokenFile + var tf TokenFile if err := json.Unmarshal(data, &tf); err != nil { - return nil, fmt.Errorf("parse token file: %w", err) + return nil, fmt.Errorf("parse credentials file: %w", err) } return &tf, nil } -// saveTokenFile writes the token file as JSON with 0600 permissions. -func saveTokenFile(path string, tf tokenFile) error { +// saveTokenFile writes the credentials file as JSON with 0600 permissions. +func saveTokenFile(path string, tf TokenFile) error { data, err := json.MarshalIndent(tf, "", " ") if err != nil { - return fmt.Errorf("marshal token file: %w", err) + return fmt.Errorf("marshal credentials file: %w", err) } return os.WriteFile(path, data, 0600) } // Register calls the control plane to register this host agent and persists -// the returned JWT and refresh token to disk. Returns the host JWT token string. -func Register(ctx context.Context, cfg RegistrationConfig) (string, error) { - // If no explicit registration token was given, reuse the saved JWT. +// the returned credentials to disk. Returns the full TokenFile on success. +func Register(ctx context.Context, cfg RegistrationConfig) (*TokenFile, error) { + // If no explicit registration token was given, reuse the saved credentials. // A --register flag always overrides the local file so operators can - // force re-registration without manually deleting host.jwt. + // force re-registration without manually deleting the credentials file. 
if cfg.RegistrationToken == "" { - if tf, err := loadTokenFile(cfg.TokenFile); err == nil && tf.JWT != "" { - slog.Info("loaded existing host token", "file", cfg.TokenFile, "host_id", tf.HostID) - return tf.JWT, nil + if tf, err := LoadTokenFile(cfg.TokenFile); err == nil && tf.JWT != "" { + slog.Info("loaded existing host credentials", "file", cfg.TokenFile, "host_id", tf.HostID) + return tf, nil } - return "", fmt.Errorf("no saved host token and no registration token provided (use --register flag)") + return nil, fmt.Errorf("no saved host credentials and no registration token provided (use --register flag)") } arch := runtime.GOARCH @@ -123,87 +127,90 @@ func Register(ctx context.Context, cfg RegistrationConfig) (string, error) { body, err := json.Marshal(reqBody) if err != nil { - return "", fmt.Errorf("marshal registration request: %w", err) + return nil, fmt.Errorf("marshal registration request: %w", err) } url := strings.TrimRight(cfg.CPURL, "/") + "/v1/hosts/register" req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) if err != nil { - return "", fmt.Errorf("create registration request: %w", err) + return nil, fmt.Errorf("create registration request: %w", err) } req.Header.Set("Content-Type", "application/json") client := &http.Client{Timeout: 30 * time.Second} resp, err := client.Do(req) if err != nil { - return "", fmt.Errorf("registration request failed: %w", err) + return nil, fmt.Errorf("registration request failed: %w", err) } defer resp.Body.Close() respBody, err := io.ReadAll(resp.Body) if err != nil { - return "", fmt.Errorf("read registration response: %w", err) + return nil, fmt.Errorf("read registration response: %w", err) } if resp.StatusCode != http.StatusCreated { var errResp errorResponse if err := json.Unmarshal(respBody, &errResp); err == nil { - return "", fmt.Errorf("registration failed (%d): %s", resp.StatusCode, errResp.Error.Message) + return nil, fmt.Errorf("registration failed (%d): %s", 
resp.StatusCode, errResp.Error.Message) } - return "", fmt.Errorf("registration failed (%d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("registration failed (%d): %s", resp.StatusCode, string(respBody)) } - var regResp registerResponse + var regResp authResponse if err := json.Unmarshal(respBody, &regResp); err != nil { - return "", fmt.Errorf("parse registration response: %w", err) + return nil, fmt.Errorf("parse registration response: %w", err) } if regResp.Token == "" { - return "", fmt.Errorf("registration response missing token") + return nil, fmt.Errorf("registration response missing token") } hostID, err := hostIDFromJWT(regResp.Token) if err != nil { - return "", fmt.Errorf("extract host ID from JWT: %w", err) + return nil, fmt.Errorf("extract host ID from JWT: %w", err) } - // Persist JWT + refresh token. - tf := tokenFile{ + tf := TokenFile{ HostID: hostID, JWT: regResp.Token, RefreshToken: regResp.RefreshToken, + CertPEM: regResp.CertPEM, + KeyPEM: regResp.KeyPEM, + CACertPEM: regResp.CACertPEM, } if err := saveTokenFile(cfg.TokenFile, tf); err != nil { - return "", fmt.Errorf("save host token: %w", err) + return nil, fmt.Errorf("save host credentials: %w", err) } - slog.Info("host registered and token saved", "file", cfg.TokenFile, "host_id", hostID) + slog.Info("host registered and credentials saved", "file", cfg.TokenFile, "host_id", hostID) - return regResp.Token, nil + return &tf, nil } -// RefreshJWT exchanges the refresh token for a new JWT + rotated refresh token. -// It reads and updates the token file in place. -func RefreshJWT(ctx context.Context, cpURL, tokenFilePath string) (string, error) { - tf, err := loadTokenFile(tokenFilePath) +// RefreshCredentials exchanges the refresh token for a new JWT, rotated refresh +// token, and (when mTLS is enabled) a new TLS certificate. The credentials file +// is updated in place. Returns the updated TokenFile.
+func RefreshCredentials(ctx context.Context, cpURL, credentialsFilePath string) (*TokenFile, error) { + tf, err := LoadTokenFile(credentialsFilePath) if err != nil { - return "", fmt.Errorf("load token file: %w", err) + return nil, fmt.Errorf("load credentials file: %w", err) } if tf.RefreshToken == "" { - return "", fmt.Errorf("no refresh token available; host must re-register") + return nil, fmt.Errorf("no refresh token available; host must re-register") } body, _ := json.Marshal(refreshRequest{RefreshToken: tf.RefreshToken}) url := strings.TrimRight(cpURL, "/") + "/v1/hosts/auth/refresh" req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) if err != nil { - return "", fmt.Errorf("create refresh request: %w", err) + return nil, fmt.Errorf("create refresh request: %w", err) } req.Header.Set("Content-Type", "application/json") client := &http.Client{Timeout: 15 * time.Second} resp, err := client.Do(req) if err != nil { - return "", fmt.Errorf("refresh request failed: %w", err) + return nil, fmt.Errorf("refresh request failed: %w", err) } defer resp.Body.Close() @@ -212,39 +219,47 @@ func RefreshJWT(ctx context.Context, cpURL, tokenFilePath string) (string, error if resp.StatusCode != http.StatusOK { var errResp errorResponse if json.Unmarshal(respBody, &errResp) == nil { - return "", fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, errResp.Error.Message) + return nil, fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, errResp.Error.Message) } - return "", fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, string(respBody)) } - var refResp refreshResponse + var refResp authResponse if err := json.Unmarshal(respBody, &refResp); err != nil { - return "", fmt.Errorf("parse refresh response: %w", err) + return nil, fmt.Errorf("parse refresh response: %w", err) } tf.JWT = refResp.Token tf.RefreshToken = refResp.RefreshToken - if err := 
saveTokenFile(tokenFilePath, *tf); err != nil { - return "", fmt.Errorf("save refreshed token: %w", err) + if refResp.CertPEM != "" { + tf.CertPEM = refResp.CertPEM + tf.KeyPEM = refResp.KeyPEM + tf.CACertPEM = refResp.CACertPEM + } + if err := saveTokenFile(credentialsFilePath, *tf); err != nil { + return nil, fmt.Errorf("save refreshed credentials: %w", err) } - slog.Info("host JWT refreshed", "host_id", tf.HostID) - return refResp.Token, nil + slog.Info("host credentials refreshed", "host_id", tf.HostID) + return tf, nil } // StartHeartbeat launches a background goroutine that sends periodic heartbeats // to the control plane. It runs until the context is cancelled. // -// On 401/403: the heartbeat loop attempts to refresh the JWT. If the refresh +// On 401/403: the heartbeat loop attempts to refresh credentials. If the refresh // also fails (expired refresh token), it calls pauseAll and stops. // // On repeated network failures (3 consecutive), it calls pauseAll but keeps // retrying — the connection may recover and the host should resume heartbeating. // // onDeleted is called when CP returns 404, meaning this host record was deleted. -// The token file is removed before calling onDeleted so subsequent starts prompt -// for a new registration token. -func StartHeartbeat(ctx context.Context, cpURL, tokenFilePath, hostID string, interval time.Duration, pauseAll func(), onDeleted func()) { +// The credentials file is removed before calling onDeleted so subsequent starts +// prompt for a new registration token. +// +// onCredsRefreshed is called after a successful credential refresh (JWT + cert). +// It may be nil. The caller uses it to hot-swap the agent's TLS certificate. 
+func StartHeartbeat(ctx context.Context, cpURL, credentialsFilePath, hostID string, interval time.Duration, pauseAll func(), onDeleted func(), onCredsRefreshed func(*TokenFile)) { client := &http.Client{Timeout: 10 * time.Second} go func() { @@ -255,8 +270,8 @@ func StartHeartbeat(ctx context.Context, cpURL, tokenFilePath, hostID string, in pausedDueToFailure := false currentJWT := "" - // Load the current JWT from disk. - if tf, err := loadTokenFile(tokenFilePath); err == nil { + // Load the current JWT from the credentials file. + if tf, err := LoadTokenFile(credentialsFilePath); err == nil { currentJWT = tf.JWT } @@ -294,10 +309,10 @@ func StartHeartbeat(ctx context.Context, cpURL, tokenFilePath, hostID string, in pausedDueToFailure = false case http.StatusUnauthorized, http.StatusForbidden: - slog.Warn("heartbeat: JWT rejected — attempting token refresh") - newJWT, refreshErr := RefreshJWT(ctx, cpURL, tokenFilePath) + slog.Warn("heartbeat: JWT rejected — attempting credentials refresh") + newCreds, refreshErr := RefreshCredentials(ctx, cpURL, credentialsFilePath) if refreshErr != nil { - slog.Error("heartbeat: JWT refresh failed — pausing all sandboxes; manual re-registration required", + slog.Error("heartbeat: credentials refresh failed — pausing all sandboxes; manual re-registration required", "error", refreshErr) if pauseAll != nil && !pausedDueToFailure { pauseAll() @@ -306,13 +321,16 @@ func StartHeartbeat(ctx context.Context, cpURL, tokenFilePath, hostID string, in // Stop the heartbeat loop — operator must re-register. 
return true } - currentJWT = newJWT - slog.Info("heartbeat: JWT refreshed successfully") + currentJWT = newCreds.JWT + slog.Info("heartbeat: credentials refreshed successfully") + if onCredsRefreshed != nil { + onCredsRefreshed(newCreds) + } case http.StatusNotFound: - slog.Error("heartbeat: host no longer exists in CP — host was deleted; removing token file and exiting") - if err := os.Remove(tokenFilePath); err != nil && !os.IsNotExist(err) { - slog.Warn("heartbeat: failed to remove token file", "error", err) + slog.Error("heartbeat: host no longer exists in CP — host was deleted; removing credentials file and exiting") + if err := os.Remove(credentialsFilePath); err != nil && !os.IsNotExist(err) { + slog.Warn("heartbeat: failed to remove credentials file", "error", err) } if onDeleted != nil { onDeleted() @@ -351,7 +369,7 @@ func HostIDFromToken(token string) (string, error) { } // hostIDFromJWT is the internal implementation used by both HostIDFromToken and -// the token file loader. +// the credentials file loader. func hostIDFromJWT(token string) (string, error) { parts := strings.Split(token, ".") if len(parts) != 3 { diff --git a/internal/hostagent/server.go b/internal/hostagent/server.go index fb7fb664..7cd78f4c 100644 --- a/internal/hostagent/server.go +++ b/internal/hostagent/server.go @@ -12,6 +12,8 @@ import ( "time" "connectrpc.com/connect" + "github.com/google/uuid" + "github.com/jackc/pgx/v5/pgtype" pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen" "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect" @@ -33,13 +35,35 @@ func NewServer(mgr *sandbox.Manager, terminate func()) *Server { return &Server{mgr: mgr, terminate: terminate} } +// parseUUIDString parses a UUID hex string into a pgtype.UUID. +// An empty string yields an all-zeros UUID (valid). 
+func parseUUIDString(s string) (pgtype.UUID, error) { + if s == "" { + return pgtype.UUID{Bytes: [16]byte{}, Valid: true}, nil + } + parsed, err := uuid.Parse(s) + if err != nil { + return pgtype.UUID{}, fmt.Errorf("invalid UUID %q: %w", s, err) + } + return pgtype.UUID{Bytes: parsed, Valid: true}, nil +} + func (s *Server) CreateSandbox( ctx context.Context, req *connect.Request[pb.CreateSandboxRequest], ) (*connect.Response[pb.CreateSandboxResponse], error) { msg := req.Msg - sb, err := s.mgr.Create(ctx, msg.SandboxId, msg.Template, int(msg.Vcpus), int(msg.MemoryMb), int(msg.TimeoutSec)) + teamID, err := parseUUIDString(msg.TeamId) + if err != nil { + return nil, connect.NewError(connect.CodeInvalidArgument, err) + } + templateID, err := parseUUIDString(msg.TemplateId) + if err != nil { + return nil, connect.NewError(connect.CodeInvalidArgument, err) + } + + sb, err := s.mgr.Create(ctx, msg.SandboxId, teamID, templateID, int(msg.Vcpus), int(msg.MemoryMb), int(msg.TimeoutSec), int(msg.DiskSizeMb)) if err != nil { return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("create sandbox: %w", err)) } @@ -90,12 +114,21 @@ func (s *Server) CreateSnapshot( ctx context.Context, req *connect.Request[pb.CreateSnapshotRequest], ) (*connect.Response[pb.CreateSnapshotResponse], error) { - sizeBytes, err := s.mgr.CreateSnapshot(ctx, req.Msg.SandboxId, req.Msg.Name) + msg := req.Msg + teamID, err := parseUUIDString(msg.TeamId) + if err != nil { + return nil, connect.NewError(connect.CodeInvalidArgument, err) + } + templateID, err := parseUUIDString(msg.TemplateId) + if err != nil { + return nil, connect.NewError(connect.CodeInvalidArgument, err) + } + + sizeBytes, err := s.mgr.CreateSnapshot(ctx, msg.SandboxId, teamID, templateID) if err != nil { return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("create snapshot: %w", err)) } return connect.NewResponse(&pb.CreateSnapshotResponse{ - Name: req.Msg.Name, SizeBytes: sizeBytes, }), nil } @@ -104,12 +137,45 @@ func 
(s *Server) DeleteSnapshot( ctx context.Context, req *connect.Request[pb.DeleteSnapshotRequest], ) (*connect.Response[pb.DeleteSnapshotResponse], error) { - if err := s.mgr.DeleteSnapshot(req.Msg.Name); err != nil { + msg := req.Msg + teamID, err := parseUUIDString(msg.TeamId) + if err != nil { + return nil, connect.NewError(connect.CodeInvalidArgument, err) + } + templateID, err := parseUUIDString(msg.TemplateId) + if err != nil { + return nil, connect.NewError(connect.CodeInvalidArgument, err) + } + + if err := s.mgr.DeleteSnapshot(teamID, templateID); err != nil { return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("delete snapshot: %w", err)) } return connect.NewResponse(&pb.DeleteSnapshotResponse{}), nil } +func (s *Server) FlattenRootfs( + ctx context.Context, + req *connect.Request[pb.FlattenRootfsRequest], +) (*connect.Response[pb.FlattenRootfsResponse], error) { + msg := req.Msg + teamID, err := parseUUIDString(msg.TeamId) + if err != nil { + return nil, connect.NewError(connect.CodeInvalidArgument, err) + } + templateID, err := parseUUIDString(msg.TemplateId) + if err != nil { + return nil, connect.NewError(connect.CodeInvalidArgument, err) + } + + sizeBytes, err := s.mgr.FlattenRootfs(ctx, msg.SandboxId, teamID, templateID) + if err != nil { + return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("flatten rootfs: %w", err)) + } + return connect.NewResponse(&pb.FlattenRootfsResponse{ + SizeBytes: sizeBytes, + }), nil +} + func (s *Server) PingSandbox( ctx context.Context, req *connect.Request[pb.PingSandboxRequest], @@ -400,7 +466,8 @@ func (s *Server) ListSandboxes( infos[i] = &pb.SandboxInfo{ SandboxId: sb.ID, Status: string(sb.Status), - Template: sb.Template, + TeamId: uuid.UUID(sb.TemplateTeamID).String(), + TemplateId: uuid.UUID(sb.TemplateID).String(), Vcpus: int32(sb.VCPUs), MemoryMb: int32(sb.MemoryMB), HostIp: sb.HostIP.String(), diff --git a/internal/id/id.go b/internal/id/id.go index bbda47c7..f4b6cdb6 100644 --- 
a/internal/id/id.go +++ b/internal/id/id.go @@ -4,8 +4,167 @@ import ( "crypto/rand" "encoding/hex" "fmt" + "math/big" + "strings" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5/pgtype" ) +const ( + base36Alphabet = "0123456789abcdefghijklmnopqrstuvwxyz" + base36IDLen = 25 // ceil(128 * log2 / log36) = 25 chars for a full UUID +) + +var base36Base = big.NewInt(36) + +// --- Generation --- + +// newUUID returns a new random (v4) UUID wrapped in pgtype.UUID for direct DB use. +func newUUID() pgtype.UUID { + return pgtype.UUID{Bytes: uuid.New(), Valid: true} +} + +func NewSandboxID() pgtype.UUID { return newUUID() } +func NewUserID() pgtype.UUID { return newUUID() } +func NewTeamID() pgtype.UUID { return newUUID() } +func NewAPIKeyID() pgtype.UUID { return newUUID() } +func NewHostID() pgtype.UUID { return newUUID() } +func NewHostTokenID() pgtype.UUID { return newUUID() } +func NewRefreshTokenID() pgtype.UUID { return newUUID() } +func NewAuditLogID() pgtype.UUID { return newUUID() } +func NewBuildID() pgtype.UUID { return newUUID() } +func NewAdminPermissionID() pgtype.UUID { return newUUID() } + +func NewTemplateID() pgtype.UUID { return newUUID() } + +// NewSnapshotName generates a snapshot name: "template-" + 8 hex chars. +func NewSnapshotName() string { + return "template-" + hex8() +} + +// NewTeamSlug generates a unique team slug in the format "xxxxxx-yyyyyy". +func NewTeamSlug() string { + b := make([]byte, 6) + if _, err := rand.Read(b); err != nil { + panic(fmt.Sprintf("crypto/rand failed: %v", err)) + } + return hex.EncodeToString(b[:3]) + "-" + hex.EncodeToString(b[3:]) +} + +// NewRegistrationToken generates a 64-char hex token (32 bytes of entropy). +func NewRegistrationToken() string { + return hexToken(32) +} + +// NewRefreshToken generates a 64-char hex token (32 bytes of entropy). 
+func NewRefreshToken() string { + return hexToken(32) +} + +// --- Formatting (pgtype.UUID → prefixed string for API/RPC output) --- + +const ( + PrefixSandbox = "cl-" + PrefixUser = "usr-" + PrefixTeam = "team-" + PrefixAPIKey = "key-" + PrefixHost = "host-" + PrefixHostToken = "htok-" + PrefixRefreshToken = "hrt-" + PrefixAuditLog = "log-" + PrefixBuild = "bld-" + PrefixAdminPermission = "perm-" +) + +// UUIDToBase36 encodes 16 UUID bytes as a 25-char base36 string (0-9a-z). +func UUIDToBase36(b [16]byte) string { + n := new(big.Int).SetBytes(b[:]) + buf := make([]byte, base36IDLen) + mod := new(big.Int) + for i := base36IDLen - 1; i >= 0; i-- { + n.DivMod(n, base36Base, mod) + buf[i] = base36Alphabet[mod.Int64()] + } + return string(buf) +} + +// base36ToUUID decodes a 25-char base36 string back to 16 UUID bytes. +func base36ToUUID(s string) ([16]byte, error) { + if len(s) != base36IDLen { + return [16]byte{}, fmt.Errorf("expected %d-char base36 ID, got %d", base36IDLen, len(s)) + } + n := new(big.Int) + for _, c := range s { + idx := strings.IndexRune(base36Alphabet, c) + if idx < 0 { + return [16]byte{}, fmt.Errorf("invalid base36 character: %c", c) + } + n.Mul(n, base36Base) + n.Add(n, big.NewInt(int64(idx))) + } + b := n.Bytes() + var out [16]byte + // big.Int.Bytes() strips leading zeros; right-align into 16-byte array. 
+ copy(out[16-len(b):], b) + return out, nil +} + +func formatUUID(prefix string, id pgtype.UUID) string { + return prefix + UUIDToBase36(id.Bytes) +} + +func FormatSandboxID(id pgtype.UUID) string { return formatUUID(PrefixSandbox, id) } +func FormatUserID(id pgtype.UUID) string { return formatUUID(PrefixUser, id) } +func FormatTeamID(id pgtype.UUID) string { return formatUUID(PrefixTeam, id) } +func FormatAPIKeyID(id pgtype.UUID) string { return formatUUID(PrefixAPIKey, id) } +func FormatHostID(id pgtype.UUID) string { return formatUUID(PrefixHost, id) } +func FormatHostTokenID(id pgtype.UUID) string { return formatUUID(PrefixHostToken, id) } +func FormatRefreshTokenID(id pgtype.UUID) string { return formatUUID(PrefixRefreshToken, id) } +func FormatAuditLogID(id pgtype.UUID) string { return formatUUID(PrefixAuditLog, id) } +func FormatBuildID(id pgtype.UUID) string { return formatUUID(PrefixBuild, id) } + +// --- Parsing (prefixed string from API/RPC input → pgtype.UUID) --- + +func parseUUID(prefix, s string) (pgtype.UUID, error) { + if !strings.HasPrefix(s, prefix) { + return pgtype.UUID{}, fmt.Errorf("invalid ID: expected %q prefix, got %q", prefix, s) + } + b, err := base36ToUUID(strings.TrimPrefix(s, prefix)) + if err != nil { + return pgtype.UUID{}, fmt.Errorf("invalid ID %q: %w", s, err) + } + return pgtype.UUID{Bytes: b, Valid: true}, nil +} + +func ParseSandboxID(s string) (pgtype.UUID, error) { return parseUUID(PrefixSandbox, s) } +func ParseUserID(s string) (pgtype.UUID, error) { return parseUUID(PrefixUser, s) } +func ParseTeamID(s string) (pgtype.UUID, error) { return parseUUID(PrefixTeam, s) } +func ParseAPIKeyID(s string) (pgtype.UUID, error) { return parseUUID(PrefixAPIKey, s) } +func ParseHostID(s string) (pgtype.UUID, error) { return parseUUID(PrefixHost, s) } +func ParseHostTokenID(s string) (pgtype.UUID, error) { return parseUUID(PrefixHostToken, s) } +func ParseAuditLogID(s string) (pgtype.UUID, error) { return parseUUID(PrefixAuditLog, s) } 
+func ParseBuildID(s string) (pgtype.UUID, error) { return parseUUID(PrefixBuild, s) } + +// --- Well-known IDs --- + +// PlatformTeamID is the all-zeros UUID reserved for platform-owned resources +// (e.g. base templates, shared infrastructure). +var PlatformTeamID = pgtype.UUID{Bytes: [16]byte{}, Valid: true} + +// MinimalTemplateID is the all-zeros UUID sentinel for the built-in "minimal" +// template. When both team_id and template_id are zero, the host agent uses +// the minimal rootfs at WRENN_DIR/images/minimal/. +var MinimalTemplateID = pgtype.UUID{Bytes: [16]byte{}, Valid: true} + +// UUIDString converts a pgtype.UUID to a standard hyphenated UUID string +// (e.g., "6ba7b810-9dad-11d1-80b4-00c04fd430c8"). Used for RPC wire format. +func UUIDString(id pgtype.UUID) string { + return uuid.UUID(id.Bytes).String() +} + +// --- Helpers --- + func hex8() string { b := make([]byte, 4) if _, err := rand.Read(b); err != nil { @@ -14,73 +173,8 @@ func hex8() string { return hex.EncodeToString(b) } -// NewSandboxID generates a new sandbox ID in the format "sb-" + 8 hex chars. -func NewSandboxID() string { - return "sb-" + hex8() -} - -// NewSnapshotName generates a snapshot name in the format "template-" + 8 hex chars. -func NewSnapshotName() string { - return "template-" + hex8() -} - -// NewUserID generates a new user ID in the format "usr-" + 8 hex chars. -func NewUserID() string { - return "usr-" + hex8() -} - -// NewTeamID generates a new team ID in the format "team-" + 8 hex chars. -func NewTeamID() string { - return "team-" + hex8() -} - -// NewTeamSlug generates a unique team slug in the format "xxxxxx-yyyyyy" -// where each part is 3 random bytes encoded as hex (6 hex chars each). 
-func NewTeamSlug() string { - b := make([]byte, 6) - if _, err := rand.Read(b); err != nil { - panic(fmt.Sprintf("crypto/rand failed: %v", err)) - } - return hex.EncodeToString(b[:3]) + "-" + hex.EncodeToString(b[3:]) -} - -// NewAPIKeyID generates a new API key ID in the format "key-" + 8 hex chars. -func NewAPIKeyID() string { - return "key-" + hex8() -} - -// NewHostID generates a new host ID in the format "host-" + 8 hex chars. -func NewHostID() string { - return "host-" + hex8() -} - -// NewHostTokenID generates a new host token audit ID in the format "htok-" + 8 hex chars. -func NewHostTokenID() string { - return "htok-" + hex8() -} - -// NewRegistrationToken generates a 64-char hex token (32 bytes of entropy). -func NewRegistrationToken() string { - b := make([]byte, 32) - if _, err := rand.Read(b); err != nil { - panic(fmt.Sprintf("crypto/rand failed: %v", err)) - } - return hex.EncodeToString(b) -} - -// NewRefreshTokenID generates a new refresh token record ID in the format "hrt-" + 8 hex chars. -func NewRefreshTokenID() string { - return "hrt-" + hex8() -} - -// NewAuditLogID generates a new audit log ID in the format "log-" + 8 hex chars. -func NewAuditLogID() string { - return "log-" + hex8() -} - -// NewRefreshToken generates a 64-char hex token (32 bytes of entropy) for use as a host refresh token. 
-func NewRefreshToken() string { - b := make([]byte, 32) +func hexToken(nBytes int) string { + b := make([]byte, nBytes) if _, err := rand.Read(b); err != nil { panic(fmt.Sprintf("crypto/rand failed: %v", err)) } diff --git a/internal/id/id_test.go b/internal/id/id_test.go new file mode 100644 index 00000000..2000e9cb --- /dev/null +++ b/internal/id/id_test.go @@ -0,0 +1,118 @@ +package id + +import ( + "testing" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5/pgtype" +) + +func TestBase36RoundTrip(t *testing.T) { + for i := 0; i < 1000; i++ { + orig := uuid.New() + encoded := UUIDToBase36(orig) + + if len(encoded) != base36IDLen { + t.Fatalf("expected %d chars, got %d: %s", base36IDLen, len(encoded), encoded) + } + + decoded, err := base36ToUUID(encoded) + if err != nil { + t.Fatalf("decode failed: %v", err) + } + + if decoded != orig { + t.Fatalf("round-trip failed: %v → %s → %v", orig, encoded, decoded) + } + } +} + +func TestBase36ZeroUUID(t *testing.T) { + var zero [16]byte + encoded := UUIDToBase36(zero) + if encoded != "0000000000000000000000000" { + t.Fatalf("zero UUID should encode to all zeros, got %s", encoded) + } + decoded, err := base36ToUUID(encoded) + if err != nil { + t.Fatalf("decode failed: %v", err) + } + if decoded != zero { + t.Fatalf("round-trip failed for zero UUID") + } +} + +func TestFormatParseRoundTrip(t *testing.T) { + id := NewSandboxID() + formatted := FormatSandboxID(id) + + if formatted[:3] != "cl-" { + t.Fatalf("expected cl- prefix, got %s", formatted) + } + if len(formatted) != 3+base36IDLen { + t.Fatalf("expected %d chars total, got %d: %s", 3+base36IDLen, len(formatted), formatted) + } + + parsed, err := ParseSandboxID(formatted) + if err != nil { + t.Fatalf("parse failed: %v", err) + } + if parsed != id { + t.Fatalf("round-trip failed: %v → %s → %v", id, formatted, parsed) + } +} + +func TestBase36InvalidInput(t *testing.T) { + // Wrong length. 
+ if _, err := base36ToUUID("abc"); err == nil { + t.Fatal("expected error for short input") + } + // Invalid character. + if _, err := base36ToUUID("000000000000000000000000!"); err == nil { + t.Fatal("expected error for invalid character") + } +} + +func TestPlatformTeamIDFormats(t *testing.T) { + formatted := FormatTeamID(PlatformTeamID) + parsed, err := ParseTeamID(formatted) + if err != nil { + t.Fatalf("parse failed: %v", err) + } + if parsed != PlatformTeamID { + t.Fatalf("platform team ID round-trip failed") + } +} + +func TestMaxUUID(t *testing.T) { + max := [16]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff} + encoded := UUIDToBase36(max) + if len(encoded) != base36IDLen { + t.Fatalf("max UUID encoding wrong length: %d", len(encoded)) + } + decoded, err := base36ToUUID(encoded) + if err != nil { + t.Fatalf("decode failed: %v", err) + } + if decoded != max { + t.Fatalf("round-trip failed for max UUID") + } +} + +func BenchmarkFormatSandboxID(b *testing.B) { + id := pgtype.UUID{Bytes: uuid.New(), Valid: true} + b.ResetTimer() + for i := 0; i < b.N; i++ { + FormatSandboxID(id) + } +} + +func BenchmarkParseSandboxID(b *testing.B) { + id := pgtype.UUID{Bytes: uuid.New(), Valid: true} + s := FormatSandboxID(id) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = ParseSandboxID(s) + } +} diff --git a/internal/layout/layout.go b/internal/layout/layout.go new file mode 100644 index 00000000..d4084f5f --- /dev/null +++ b/internal/layout/layout.go @@ -0,0 +1,58 @@ +package layout + +import ( + "path/filepath" + + "github.com/jackc/pgx/v5/pgtype" + + "git.omukk.dev/wrenn/sandbox/internal/id" +) + +// IsMinimal reports whether the given team and template IDs represent the +// built-in "minimal" template (both all-zeros). 
+func IsMinimal(teamID, templateID pgtype.UUID) bool { + return teamID.Bytes == id.PlatformTeamID.Bytes && templateID.Bytes == id.MinimalTemplateID.Bytes +} + +// TemplateDir returns the on-disk directory for a template. +// +// minimal (zeros, zeros): {wrennDir}/images/minimal +// all others: {wrennDir}/images/teams/{base36(teamID)}/{base36(templateID)} +func TemplateDir(wrennDir string, teamID, templateID pgtype.UUID) string { + if IsMinimal(teamID, templateID) { + return filepath.Join(wrennDir, "images", "minimal") + } + return filepath.Join(wrennDir, "images", "teams", + id.UUIDToBase36(teamID.Bytes), + id.UUIDToBase36(templateID.Bytes)) +} + +// TemplateRootfs returns the path to a template's rootfs.ext4. +func TemplateRootfs(wrennDir string, teamID, templateID pgtype.UUID) string { + return filepath.Join(TemplateDir(wrennDir, teamID, templateID), "rootfs.ext4") +} + +// PauseSnapshotDir returns the directory for a paused sandbox's snapshot files. +func PauseSnapshotDir(wrennDir, sandboxID string) string { + return filepath.Join(wrennDir, "snapshots", sandboxID) +} + +// SandboxesDir returns the directory for running sandbox CoW files. +func SandboxesDir(wrennDir string) string { + return filepath.Join(wrennDir, "sandboxes") +} + +// KernelPath returns the path to the Firecracker kernel. +func KernelPath(wrennDir string) string { + return filepath.Join(wrennDir, "kernels", "vmlinux") +} + +// ImagesRoot returns the root images directory. +func ImagesRoot(wrennDir string) string { + return filepath.Join(wrennDir, "images") +} + +// TeamsDir returns the directory containing all team template subdirectories. 
+func TeamsDir(wrennDir string) string { + return filepath.Join(wrennDir, "images", "teams") +} diff --git a/internal/layout/layout_test.go b/internal/layout/layout_test.go new file mode 100644 index 00000000..f7b9afd3 --- /dev/null +++ b/internal/layout/layout_test.go @@ -0,0 +1,120 @@ +package layout + +import ( + "path/filepath" + "testing" + + "github.com/jackc/pgx/v5/pgtype" + + "git.omukk.dev/wrenn/sandbox/internal/id" +) + +func TestIsMinimal(t *testing.T) { + tests := []struct { + name string + teamID pgtype.UUID + templateID pgtype.UUID + want bool + }{ + { + name: "both zeros", + teamID: id.PlatformTeamID, + templateID: id.MinimalTemplateID, + want: true, + }, + { + name: "non-zero team", + teamID: pgtype.UUID{Bytes: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, Valid: true}, + templateID: id.MinimalTemplateID, + want: false, + }, + { + name: "non-zero template", + teamID: id.PlatformTeamID, + templateID: pgtype.UUID{Bytes: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, Valid: true}, + want: false, + }, + { + name: "both non-zero", + teamID: pgtype.UUID{Bytes: [16]byte{1}, Valid: true}, + templateID: pgtype.UUID{Bytes: [16]byte{2}, Valid: true}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := IsMinimal(tt.teamID, tt.templateID); got != tt.want { + t.Errorf("IsMinimal() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestTemplateDir(t *testing.T) { + wrennDir := "/var/lib/wrenn" + + t.Run("minimal", func(t *testing.T) { + got := TemplateDir(wrennDir, id.PlatformTeamID, id.MinimalTemplateID) + want := filepath.Join(wrennDir, "images", "minimal") + if got != want { + t.Errorf("TemplateDir() = %q, want %q", got, want) + } + }) + + t.Run("team template", func(t *testing.T) { + teamID := pgtype.UUID{Bytes: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, Valid: true} + tmplID := pgtype.UUID{Bytes: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}, Valid: true} 
+ got := TemplateDir(wrennDir, teamID, tmplID) + want := filepath.Join(wrennDir, "images", "teams", + id.UUIDToBase36(teamID.Bytes), + id.UUIDToBase36(tmplID.Bytes)) + if got != want { + t.Errorf("TemplateDir() = %q, want %q", got, want) + } + }) + + t.Run("global template (platform team, non-zero template)", func(t *testing.T) { + tmplID := pgtype.UUID{Bytes: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5}, Valid: true} + got := TemplateDir(wrennDir, id.PlatformTeamID, tmplID) + want := filepath.Join(wrennDir, "images", "teams", + id.UUIDToBase36(id.PlatformTeamID.Bytes), + id.UUIDToBase36(tmplID.Bytes)) + if got != want { + t.Errorf("TemplateDir() = %q, want %q", got, want) + } + }) +} + +func TestTemplateRootfs(t *testing.T) { + wrennDir := "/var/lib/wrenn" + got := TemplateRootfs(wrennDir, id.PlatformTeamID, id.MinimalTemplateID) + want := filepath.Join(wrennDir, "images", "minimal", "rootfs.ext4") + if got != want { + t.Errorf("TemplateRootfs() = %q, want %q", got, want) + } +} + +func TestPauseSnapshotDir(t *testing.T) { + got := PauseSnapshotDir("/var/lib/wrenn", "cl-abc123") + want := "/var/lib/wrenn/snapshots/cl-abc123" + if got != want { + t.Errorf("PauseSnapshotDir() = %q, want %q", got, want) + } +} + +func TestSandboxesDir(t *testing.T) { + got := SandboxesDir("/var/lib/wrenn") + want := "/var/lib/wrenn/sandboxes" + if got != want { + t.Errorf("SandboxesDir() = %q, want %q", got, want) + } +} + +func TestKernelPath(t *testing.T) { + got := KernelPath("/var/lib/wrenn") + want := "/var/lib/wrenn/kernels/vmlinux" + if got != want { + t.Errorf("KernelPath() = %q, want %q", got, want) + } +} diff --git a/internal/lifecycle/hostpool.go b/internal/lifecycle/hostpool.go index 0caf5ece..f5784898 100644 --- a/internal/lifecycle/hostpool.go +++ b/internal/lifecycle/hostpool.go @@ -1,6 +1,7 @@ package lifecycle import ( + "crypto/tls" "fmt" "net/http" "strings" @@ -8,6 +9,7 @@ import ( "time" "git.omukk.dev/wrenn/sandbox/internal/db" + 
"git.omukk.dev/wrenn/sandbox/internal/id" "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect" ) @@ -18,14 +20,33 @@ type HostClientPool struct { mu sync.RWMutex clients map[string]hostagentv1connect.HostAgentServiceClient httpClient *http.Client + scheme string // "http://" or "https://" } -// NewHostClientPool creates a new pool. The underlying HTTP client uses a -// 10-minute timeout to support long-running streaming operations. +// NewHostClientPool creates a pool that connects to agents over plain HTTP. +// Use NewHostClientPoolTLS when mTLS is required. func NewHostClientPool() *HostClientPool { return &HostClientPool{ clients: make(map[string]hostagentv1connect.HostAgentServiceClient), httpClient: &http.Client{Timeout: 10 * time.Minute}, + scheme: "http://", + } +} + +// NewHostClientPoolTLS creates a pool that connects to agents over mTLS. +// tlsCfg should already carry the CP client cert and CA trust anchor +// (use auth.CPClientTLSConfig to construct it). +func NewHostClientPoolTLS(tlsCfg *tls.Config) *HostClientPool { + transport := &http.Transport{ + TLSClientConfig: tlsCfg, + } + return &HostClientPool{ + clients: make(map[string]hostagentv1connect.HostAgentServiceClient), + httpClient: &http.Client{ + Timeout: 10 * time.Minute, + Transport: transport, + }, + scheme: "https://", } } @@ -45,7 +66,7 @@ func (p *HostClientPool) Get(hostID, address string) hostagentv1connect.HostAgen if c, ok = p.clients[hostID]; ok { return c } - c = hostagentv1connect.NewHostAgentServiceClient(p.httpClient, ensureScheme(address)) + c = hostagentv1connect.NewHostAgentServiceClient(p.httpClient, p.ensureScheme(address)) p.clients[hostID] = c return c } @@ -53,10 +74,10 @@ func (p *HostClientPool) Get(hostID, address string) hostagentv1connect.HostAgen // GetForHost is a convenience wrapper that extracts the address from a db.Host // and returns an error if the host has no address recorded yet. 
func (p *HostClientPool) GetForHost(h db.Host) (hostagentv1connect.HostAgentServiceClient, error) { - if !h.Address.Valid || h.Address.String == "" { - return nil, fmt.Errorf("host %s has no address", h.ID) + if h.Address == "" { + return nil, fmt.Errorf("host %s has no address", id.FormatHostID(h.ID)) } - return p.Get(h.ID, h.Address.String), nil + return p.Get(id.FormatHostID(h.ID), h.Address), nil } // Evict removes the cached client for the given host, forcing a new client to be @@ -68,8 +89,35 @@ func (p *HostClientPool) Evict(hostID string) { p.mu.Unlock() } -// ensureScheme adds "http://" if the address has no scheme. -func ensureScheme(addr string) string { +// ensureScheme prepends the pool's configured scheme if the address has none. +func (p *HostClientPool) ensureScheme(addr string) string { + if strings.HasPrefix(addr, "http://") || strings.HasPrefix(addr, "https://") { + return addr + } + return p.scheme + addr +} + +// Transport returns the http.RoundTripper used by this pool. Use this when you +// need to make raw HTTP requests to agent addresses with the same TLS settings +// as the pool's Connect RPC clients (e.g., the sandbox reverse proxy). +func (p *HostClientPool) Transport() http.RoundTripper { + if p.httpClient.Transport != nil { + return p.httpClient.Transport + } + return http.DefaultTransport +} + +// ResolveAddr prepends the pool's configured scheme to addr if it has none. +// Use this when constructing URLs that must use the same transport as the pool +// (e.g., the sandbox proxy handler). Calling Get/GetForHost internally does +// the same thing, but ResolveAddr exposes it for callers that only need the URL. +func (p *HostClientPool) ResolveAddr(addr string) string { + return p.ensureScheme(addr) +} + +// EnsureScheme adds "http://" if the address has no scheme. +// Deprecated: use pool.ResolveAddr which respects the pool's TLS setting. 
// nsPrefix is the name prefix that marks a network namespace as wrenn-owned.
const nsPrefix = "wrenn-ns-"

// CleanupStaleNamespaces removes leftover wrenn network namespaces from a
// previous crash. Called once at agent startup.
//
// Every entry under /run/netns whose name starts with nsPrefix is treated as
// stale: the host-side veth derived from the same suffix is deleted
// (best-effort), then the namespace itself. Failures are logged and never
// returned. Afterwards host iptables rules and routes that mention wrenn veth
// interfaces are removed as well.
func CleanupStaleNamespaces() {
	entries, err := os.ReadDir("/run/netns")
	if err != nil {
		return // no /run/netns or unreadable — nothing to clean
	}
	for _, e := range entries {
		name := e.Name()
		if !strings.HasPrefix(name, nsPrefix) {
			continue
		}
		// Also remove the associated veth from the host side. The veth name
		// reuses the namespace suffix: "wrenn-ns-<x>" pairs with "wrenn-veth-<x>".
		vethName := "wrenn-veth-" + strings.TrimPrefix(name, nsPrefix)
		if link, err := netlink.LinkByName(vethName); err == nil {
			// Best-effort: a failed delete is ignored here.
			_ = netlink.LinkDel(link)
		}
		if err := netns.DeleteNamed(name); err != nil {
			slog.Warn("failed to remove stale namespace", "ns", name, "error", err)
		} else {
			slog.Info("removed stale namespace", "ns", name)
		}
	}

	// Clean up any stale wrenn iptables rules referencing old veth interfaces.
	cleanupStaleIptablesRules()
}

// cleanupStaleIptablesRules removes host iptables rules and IPv4 routes that
// reference wrenn-veth interfaces.
//
// NOTE(review): the code does not check whether the referenced interface still
// exists — every "-A" rule in the filter and nat tables that contains
// "wrenn-veth-" is deleted, and every IPv4 route whose link name starts with
// "wrenn-veth-" is removed. That is only safe while no sandboxes are running,
// e.g. at agent startup.
func cleanupStaleIptablesRules() {
	for _, table := range []string{"filter", "nat"} {
		cmd := exec.Command("iptables-save", "-t", table)
		out, err := cmd.Output()
		if err != nil {
			// iptables-save missing or failing for this table: skip the table.
			continue
		}
		for _, line := range strings.Split(string(out), "\n") {
			if !strings.Contains(line, "wrenn-veth-") {
				continue
			}
			// Lines look like "-A FORWARD -i wrenn-veth-1 -o wlo1 -j ACCEPT"
			// Convert -A to -D to delete the rule.
			if !strings.HasPrefix(line, "-A ") {
				continue
			}
			delRule := "-D " + line[3:]
			// NOTE(review): splitting on whitespace does not honour quoting, so a
			// rule with a quoted argument containing spaces would be split into
			// the wrong arguments — confirm no such rules are created.
			args := strings.Fields(delRule)
			delCmd := exec.Command("iptables", append([]string{"-t", table}, args...)...)
			if err := delCmd.Run(); err != nil {
				slog.Debug("failed to remove stale iptables rule", "rule", line, "error", err)
			}
		}
	}

	// Also remove stale host routes via wrenn-veth interfaces. The route
	// destination is not inspected; only the link name is matched.
	routes, err := netlink.RouteList(nil, netlink.FAMILY_V4)
	if err != nil {
		return
	}
	for _, r := range routes {
		if r.LinkIndex == 0 {
			continue
		}
		link, err := netlink.LinkByIndex(r.LinkIndex)
		if err != nil {
			continue
		}
		if strings.HasPrefix(link.Attrs().Name, "wrenn-veth-") {
			// Best-effort: a failed delete is ignored.
			_ = netlink.RouteDel(&r)
		}
	}
}
// ExecContext holds mutable state that persists across recipe steps.
// It is initialized empty and updated by ENV and WORKDIR steps.
type ExecContext struct {
	WorkDir string            // working directory applied via "cd"; empty means none
	EnvVars map[string]string // variables prepended as KEY='value' assignments
}

// WrappedCommand returns the full shell command for a RUN step with context
// applied. The result is passed as the argument to /bin/sh -c.
//
// If WORKDIR and/or ENV are set, they are prepended as a shell preamble:
//
//	cd '/the/dir' && KEY='val' /bin/sh -c 'original command'
//
// With no context set (or a nil receiver) cmd is returned unchanged.
func (c *ExecContext) WrappedCommand(cmd string) string {
	prefix := c.shellPrefix()
	if prefix == "" {
		return cmd
	}
	return prefix + "/bin/sh -c " + shellescape(cmd)
}

// StartCommand returns the shell command for a START step. The process is
// launched in the background via nohup so that the outer shell exits
// immediately, allowing the build to continue. stdout/stderr of the
// background process are discarded (the process keeps running in the VM).
//
// Multiple START steps can be issued to run several background processes
// simultaneously before a healthcheck is evaluated.
func (c *ExecContext) StartCommand(cmd string) string {
	prefix := c.shellPrefix()
	return prefix + "nohup /bin/sh -c " + shellescape(cmd) + " >/dev/null 2>&1 &"
}

// shellPrefix builds the "cd ... && KEY=val " preamble for a shell command.
// Returns an empty string when no context is set or the receiver is nil.
//
// Variables are emitted in ascending key order. Ranging over the map directly
// would make the generated command differ from run to run as soon as two
// variables are set, which breaks reproducible build logs and exact-match tests.
func (c *ExecContext) shellPrefix() string {
	if c == nil || (c.WorkDir == "" && len(c.EnvVars) == 0) {
		return ""
	}
	var sb strings.Builder
	if c.WorkDir != "" {
		sb.WriteString("cd ")
		sb.WriteString(shellescape(c.WorkDir))
		sb.WriteString(" && ")
	}
	for _, k := range sortedEnvKeys(c.EnvVars) {
		// The key is written verbatim: it must already be a valid shell name.
		sb.WriteString(k)
		sb.WriteByte('=')
		sb.WriteString(shellescape(c.EnvVars[k]))
		sb.WriteByte(' ')
	}
	return sb.String()
}

// sortedEnvKeys returns the keys of env in ascending byte order. It uses a
// small insertion sort so this file keeps needing only the "strings" import;
// env maps hold a handful of entries, so the quadratic worst case is harmless.
func sortedEnvKeys(env map[string]string) []string {
	keys := make([]string, 0, len(env))
	for k := range env {
		keys = append(keys, k)
	}
	for i := 1; i < len(keys); i++ {
		for j := i; j > 0 && keys[j] < keys[j-1]; j-- {
			keys[j], keys[j-1] = keys[j-1], keys[j]
		}
	}
	return keys
}

// shellescape wraps s in single quotes, escaping any embedded single quotes.
// This is POSIX-safe for paths, env values, and shell commands.
func shellescape(s string) string {
	return "'" + strings.ReplaceAll(s, "'", `'\''`) + "'"
}
t.Run(tc.name, func(t *testing.T) { + got := tc.ctx.WrappedCommand(tc.cmd) + if got != tc.want { + t.Errorf("WrappedCommand(%q)\n got %q\n want %q", tc.cmd, got, tc.want) + } + }) + } +} + +func TestExecContext_StartCommand(t *testing.T) { + tests := []struct { + name string + ctx ExecContext + cmd string + want string + }{ + { + name: "no context", + ctx: ExecContext{}, + cmd: "python3 app.py", + want: "nohup /bin/sh -c 'python3 app.py' >/dev/null 2>&1 &", + }, + { + name: "with workdir", + ctx: ExecContext{WorkDir: "/app"}, + cmd: "python3 server.py", + want: "cd '/app' && nohup /bin/sh -c 'python3 server.py' >/dev/null 2>&1 &", + }, + { + name: "with env", + ctx: ExecContext{EnvVars: map[string]string{"PORT": "9000"}}, + cmd: "node index.js", + want: "PORT='9000' nohup /bin/sh -c 'node index.js' >/dev/null 2>&1 &", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := tc.ctx.StartCommand(tc.cmd) + if got != tc.want { + t.Errorf("StartCommand(%q)\n got %q\n want %q", tc.cmd, got, tc.want) + } + }) + } +} + +func TestShellescape(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"simple", "'simple'"}, + {"/path/to/dir", "'/path/to/dir'"}, + {"it's fine", "'it'\\''s fine'"}, + {"", "''"}, + {"a'b'c", "'a'\\''b'\\''c'"}, + } + for _, tc := range tests { + got := shellescape(tc.input) + if got != tc.want { + t.Errorf("shellescape(%q) = %q, want %q", tc.input, got, tc.want) + } + } +} diff --git a/internal/recipe/executor.go b/internal/recipe/executor.go new file mode 100644 index 00000000..3df45dc5 --- /dev/null +++ b/internal/recipe/executor.go @@ -0,0 +1,185 @@ +package recipe + +import ( + "context" + "fmt" + "log/slog" + "strings" + "time" + + "connectrpc.com/connect" + + pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen" +) + +// DefaultStepTimeout is the fallback timeout for RUN steps that carry no +// explicit --timeout flag. 
+const DefaultStepTimeout = 30 * time.Second + +// BuildLogEntry is the per-step record stored in template_builds.logs (JSONB). +type BuildLogEntry struct { + Step int `json:"step"` + Phase string `json:"phase"` + Cmd string `json:"cmd"` + Stdout string `json:"stdout"` + Stderr string `json:"stderr"` + Exit int32 `json:"exit"` + Ok bool `json:"ok"` + Elapsed int64 `json:"elapsed_ms"` +} + +// ExecFunc is the agent.Exec call signature used by the executor. It matches +// the method on the hostagent Connect RPC client. +type ExecFunc func(ctx context.Context, req *connect.Request[pb.ExecRequest]) (*connect.Response[pb.ExecResponse], error) + +// Execute runs steps sequentially against sandboxID using execFn. +// +// - phase labels the log entries (e.g., "pre-build", "recipe", "post-build"). +// - startStep is the 1-based offset so entries are globally numbered across phases. +// - defaultTimeout applies to RUN steps with no per-step --timeout; 0 → 10 minutes. +// - bctx is mutated in place as ENV/WORKDIR steps execute, and carries forward +// into subsequent phases when the caller passes the same pointer. +// +// Returns all log entries appended during this call, the next step counter +// value, and whether all steps succeeded. On false the last entry contains +// failure details; the caller is responsible for destroying the sandbox and +// recording the build error. 
+func Execute( + ctx context.Context, + phase string, + steps []Step, + sandboxID string, + startStep int, + defaultTimeout time.Duration, + bctx *ExecContext, + execFn ExecFunc, +) (entries []BuildLogEntry, nextStep int, ok bool) { + if defaultTimeout <= 0 { + defaultTimeout = 10 * time.Minute + } + + step := startStep + for _, st := range steps { + step++ + slog.Info("executing build step", "phase", phase, "step", step, "instruction", st.Raw) + + switch st.Kind { + case KindENV: + if bctx.EnvVars == nil { + bctx.EnvVars = make(map[string]string) + } + bctx.EnvVars[st.Key] = st.Value + entries = append(entries, BuildLogEntry{Step: step, Phase: phase, Cmd: st.Raw, Ok: true}) + + case KindWORKDIR: + bctx.WorkDir = st.Path + entries = append(entries, BuildLogEntry{Step: step, Phase: phase, Cmd: st.Raw, Ok: true}) + + case KindUSER, KindCOPY: + verb := strings.ToUpper(strings.Fields(st.Raw)[0]) + entries = append(entries, BuildLogEntry{ + Step: step, + Phase: phase, + Cmd: st.Raw, + Stderr: verb + " is not yet supported", + Ok: false, + }) + return entries, step, false + + case KindSTART: + entry, succeeded := execStart(ctx, st, sandboxID, phase, step, bctx, execFn) + entries = append(entries, entry) + if !succeeded { + return entries, step, false + } + + case KindRUN: + timeout := defaultTimeout + if st.Timeout > 0 { + timeout = st.Timeout + } + entry, succeeded := execRun(ctx, st, sandboxID, phase, step, timeout, bctx, execFn) + entries = append(entries, entry) + if !succeeded { + return entries, step, false + } + } + } + return entries, step, true +} + +func execRun( + ctx context.Context, + st Step, + sandboxID, phase string, + step int, + timeout time.Duration, + bctx *ExecContext, + execFn ExecFunc, +) (BuildLogEntry, bool) { + execCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + start := time.Now() + resp, err := execFn(execCtx, connect.NewRequest(&pb.ExecRequest{ + SandboxId: sandboxID, + Cmd: "/bin/sh", + Args: []string{"-c", 
bctx.WrappedCommand(st.Shell)}, + TimeoutSec: int32(timeout.Seconds()), + })) + + entry := BuildLogEntry{ + Step: step, + Phase: phase, + Cmd: st.Raw, + Elapsed: time.Since(start).Milliseconds(), + } + if err != nil { + entry.Stderr = fmt.Sprintf("exec error: %v", err) + return entry, false + } + entry.Stdout = string(resp.Msg.Stdout) + entry.Stderr = string(resp.Msg.Stderr) + entry.Exit = resp.Msg.ExitCode + entry.Ok = resp.Msg.ExitCode == 0 + return entry, entry.Ok +} + +func execStart( + ctx context.Context, + st Step, + sandboxID, phase string, + step int, + bctx *ExecContext, + execFn ExecFunc, +) (BuildLogEntry, bool) { + // START uses a short timeout: just long enough for the shell to fork and + // return. The background process itself runs indefinitely inside the VM. + execCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + start := time.Now() + resp, err := execFn(execCtx, connect.NewRequest(&pb.ExecRequest{ + SandboxId: sandboxID, + Cmd: "/bin/sh", + Args: []string{"-c", bctx.StartCommand(st.Shell)}, + TimeoutSec: 10, + })) + + entry := BuildLogEntry{ + Step: step, + Phase: phase, + Cmd: st.Raw, + Elapsed: time.Since(start).Milliseconds(), + } + if err != nil { + entry.Stderr = fmt.Sprintf("start error: %v", err) + return entry, false + } + entry.Exit = resp.Msg.ExitCode + entry.Ok = resp.Msg.ExitCode == 0 + if !entry.Ok { + entry.Stderr = fmt.Sprintf("start failed with exit code %d: %s", resp.Msg.ExitCode, string(resp.Msg.Stderr)) + } + return entry, entry.Ok +} diff --git a/internal/recipe/step.go b/internal/recipe/step.go new file mode 100644 index 00000000..7d510362 --- /dev/null +++ b/internal/recipe/step.go @@ -0,0 +1,129 @@ +package recipe + +import ( + "fmt" + "strings" + "time" +) + +// Kind identifies the instruction type in a recipe line. +type Kind int + +const ( + KindRUN Kind = iota // Execute a command and wait for it to exit. + KindSTART // Start a command in the background (non-blocking). 
// Kind identifies the instruction type in a recipe line.
type Kind int

const (
	KindRUN     Kind = iota // Execute a command and wait for it to exit.
	KindSTART               // Start a command in the background (non-blocking).
	KindENV                 // Set an environment variable for subsequent steps.
	KindWORKDIR             // Set the working directory for subsequent steps.
	KindUSER                // Switch the unix user for subsequent steps. (stub)
	KindCOPY                // Copy files into the sandbox. (stub)
)

// Step is the parsed representation of one recipe instruction.
type Step struct {
	Kind    Kind
	Raw     string        // original string, preserved for logging
	Shell   string        // KindRUN, KindSTART: the shell command text
	Timeout time.Duration // KindRUN: 0 means use caller's default
	Key     string        // KindENV: variable name
	Value   string        // KindENV: variable value
	Path    string        // KindWORKDIR: directory path
}

// ParseStep parses a single recipe instruction string into a Step.
// Instructions are Dockerfile-like: a keyword followed by arguments. The
// keyword is case-insensitive and may be separated from its arguments by any
// run of spaces or tabs.
//
// Supported syntax:
//
//	RUN command                — run command, wait for exit
//	RUN --timeout=DUR command  — run command with explicit timeout (e.g. --timeout=5m)
//	START command              — start command in background, return immediately
//	ENV KEY=VALUE              — set environment variable
//	WORKDIR path               — set working directory
//	USER name                  — not yet supported
//	COPY src dst               — not yet supported
//
// An empty input, an unknown keyword, or malformed arguments yield an error
// and the zero Step.
func ParseStep(s string) (Step, error) {
	s = strings.TrimSpace(s)
	if s == "" {
		return Step{}, fmt.Errorf("empty step")
	}

	keyword, rest := splitFirstField(s)

	switch strings.ToUpper(keyword) {
	case "RUN":
		return parseRUN(s, rest)
	case "START":
		return parseSTART(s, rest)
	case "ENV":
		return parseENV(s, rest)
	case "WORKDIR":
		return parseWORKDIR(s, rest)
	case "USER":
		return Step{Kind: KindUSER, Raw: s}, nil
	case "COPY":
		return Step{Kind: KindCOPY, Raw: s}, nil
	default:
		return Step{}, fmt.Errorf("unknown instruction %q (expected RUN, START, ENV, WORKDIR, USER, or COPY)", keyword)
	}
}

// ParseRecipe parses all recipe lines, returning on the first error. The error
// is prefixed with the 1-based index of the offending line.
func ParseRecipe(lines []string) ([]Step, error) {
	steps := make([]Step, 0, len(lines))
	for i, line := range lines {
		st, err := ParseStep(line)
		if err != nil {
			return nil, fmt.Errorf("recipe line %d: %w", i+1, err)
		}
		steps = append(steps, st)
	}
	return steps, nil
}

// splitFirstField splits s at its first whitespace character and returns the
// leading field plus the trimmed remainder. With no whitespace, tail is "".
func splitFirstField(s string) (head, tail string) {
	i := strings.IndexAny(s, " \t\r\n")
	if i < 0 {
		return s, ""
	}
	return s[:i], strings.TrimSpace(s[i+1:])
}

// parseRUN handles "RUN [--timeout=DUR] command". A negative duration is
// rejected: Execute would silently ignore it and use its default instead.
func parseRUN(raw, rest string) (Step, error) {
	var timeout time.Duration
	if strings.HasPrefix(rest, "--timeout=") {
		flag, cmd := splitFirstField(rest[len("--timeout="):])
		if cmd == "" {
			return Step{}, fmt.Errorf("RUN --timeout= flag has no command: %q", raw)
		}
		d, err := time.ParseDuration(flag)
		if err != nil {
			return Step{}, fmt.Errorf("RUN --timeout= invalid duration %q: %w", flag, err)
		}
		if d < 0 {
			return Step{}, fmt.Errorf("RUN --timeout= must not be negative: %q", flag)
		}
		timeout = d
		rest = cmd
	}
	if rest == "" {
		return Step{}, fmt.Errorf("RUN requires a command: %q", raw)
	}
	return Step{Kind: KindRUN, Raw: raw, Shell: rest, Timeout: timeout}, nil
}

// parseSTART handles "START command".
func parseSTART(raw, rest string) (Step, error) {
	if rest == "" {
		return Step{}, fmt.Errorf("START requires a command: %q", raw)
	}
	return Step{Kind: KindSTART, Raw: raw, Shell: rest}, nil
}

// parseENV handles "ENV KEY=VALUE". The value is everything after the first
// '=' and may be empty or contain spaces and further '=' characters.
//
// The key must be a valid shell variable name. It is later written unquoted in
// front of the command (KEY='value' /bin/sh -c ...), so a key containing
// spaces or shell metacharacters would corrupt the command or inject into it.
func parseENV(raw, rest string) (Step, error) {
	key, value, found := strings.Cut(rest, "=")
	if !found {
		return Step{}, fmt.Errorf("ENV requires KEY=VALUE format: %q", raw)
	}
	if key == "" {
		return Step{}, fmt.Errorf("ENV key is empty: %q", raw)
	}
	if !isShellName(key) {
		return Step{}, fmt.Errorf("ENV key %q is not a valid variable name: %q", key, raw)
	}
	return Step{Kind: KindENV, Raw: raw, Key: key, Value: value}, nil
}

// isShellName reports whether s is a non-empty name made of ASCII letters,
// digits and underscores that does not start with a digit.
func isShellName(s string) bool {
	if s == "" {
		return false
	}
	for i := 0; i < len(s); i++ {
		c := s[i]
		switch {
		case c == '_', c >= 'a' && c <= 'z', c >= 'A' && c <= 'Z':
		case c >= '0' && c <= '9' && i > 0:
		default:
			return false
		}
	}
	return true
}

// parseWORKDIR handles "WORKDIR path"; the path may contain spaces.
func parseWORKDIR(raw, path string) (Step, error) {
	if path == "" {
		return Step{}, fmt.Errorf("WORKDIR requires a path: %q", raw)
	}
	return Step{Kind: KindWORKDIR, Raw: raw, Path: path}, nil
}
[]struct { + name string + input string + want Step + wantErr bool + }{ + // RUN + { + name: "RUN basic", + input: "RUN apt install -y curl", + want: Step{Kind: KindRUN, Raw: "RUN apt install -y curl", Shell: "apt install -y curl"}, + }, + { + name: "RUN lowercase", + input: "run echo hello", + want: Step{Kind: KindRUN, Raw: "run echo hello", Shell: "echo hello"}, + }, + { + name: "RUN with timeout", + input: "RUN --timeout=5m npm install", + want: Step{Kind: KindRUN, Raw: "RUN --timeout=5m npm install", Shell: "npm install", Timeout: 5 * time.Minute}, + }, + { + name: "RUN with timeout seconds", + input: "RUN --timeout=30s make build", + want: Step{Kind: KindRUN, Raw: "RUN --timeout=30s make build", Shell: "make build", Timeout: 30 * time.Second}, + }, + { + name: "RUN no command", + input: "RUN", + wantErr: true, + }, + { + name: "RUN timeout no command", + input: "RUN --timeout=5m", + wantErr: true, + }, + { + name: "RUN invalid timeout", + input: "RUN --timeout=notaduration echo hi", + wantErr: true, + }, + // START + { + name: "START basic", + input: "START python3 app.py", + want: Step{Kind: KindSTART, Raw: "START python3 app.py", Shell: "python3 app.py"}, + }, + { + name: "START uppercase", + input: "START node server.js --port=8080", + want: Step{Kind: KindSTART, Raw: "START node server.js --port=8080", Shell: "node server.js --port=8080"}, + }, + { + name: "START no command", + input: "START", + wantErr: true, + }, + // ENV + { + name: "ENV basic", + input: "ENV FOO=bar", + want: Step{Kind: KindENV, Raw: "ENV FOO=bar", Key: "FOO", Value: "bar"}, + }, + { + name: "ENV value with spaces", + input: "ENV GREETING=hello world", + want: Step{Kind: KindENV, Raw: "ENV GREETING=hello world", Key: "GREETING", Value: "hello world"}, + }, + { + name: "ENV value with equals sign", + input: "ENV URL=http://example.com?a=1", + want: Step{Kind: KindENV, Raw: "ENV URL=http://example.com?a=1", Key: "URL", Value: "http://example.com?a=1"}, + }, + { + name: "ENV empty value", 
+ input: "ENV FOO=", + want: Step{Kind: KindENV, Raw: "ENV FOO=", Key: "FOO", Value: ""}, + }, + { + name: "ENV missing equals", + input: "ENV FOO", + wantErr: true, + }, + { + name: "ENV empty key", + input: "ENV =value", + wantErr: true, + }, + // WORKDIR + { + name: "WORKDIR basic", + input: "WORKDIR /app", + want: Step{Kind: KindWORKDIR, Raw: "WORKDIR /app", Path: "/app"}, + }, + { + name: "WORKDIR with spaces in path", + input: "WORKDIR /my project", + want: Step{Kind: KindWORKDIR, Raw: "WORKDIR /my project", Path: "/my project"}, + }, + { + name: "WORKDIR empty", + input: "WORKDIR", + wantErr: true, + }, + // USER and COPY stubs + { + name: "USER stub", + input: "USER www-data", + want: Step{Kind: KindUSER, Raw: "USER www-data"}, + }, + { + name: "COPY stub", + input: "COPY config.yaml /etc/app/config.yaml", + want: Step{Kind: KindCOPY, Raw: "COPY config.yaml /etc/app/config.yaml"}, + }, + // Unknown keyword + { + name: "unknown keyword", + input: "FROBNICATE something", + wantErr: true, + }, + // Empty input + { + name: "empty string", + input: "", + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got, err := ParseStep(tc.input) + if tc.wantErr { + if err == nil { + t.Fatalf("ParseStep(%q) expected error, got %+v", tc.input, got) + } + return + } + if err != nil { + t.Fatalf("ParseStep(%q) unexpected error: %v", tc.input, err) + } + if got != tc.want { + t.Errorf("ParseStep(%q)\n got %+v\n want %+v", tc.input, got, tc.want) + } + }) + } +} + +func TestParseRecipe(t *testing.T) { + t.Run("valid recipe", func(t *testing.T) { + lines := []string{ + "RUN apt update", + "WORKDIR /app", + "ENV PORT=8080", + "START python3 server.py", + "RUN --timeout=2m pip install -r requirements.txt", + } + steps, err := ParseRecipe(lines) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(steps) != 5 { + t.Fatalf("expected 5 steps, got %d", len(steps)) + } + if steps[0].Kind != KindRUN { + t.Errorf("step 0: want 
// ConnTracker tracks active proxy connections for a single sandbox and
// provides a drain mechanism for pre-pause graceful shutdown.
//
// The zero value is ready to use. It is safe for concurrent use and must not
// be copied after first use.
//
// The in-flight count is kept under a mutex rather than in a sync.WaitGroup:
// WaitGroup forbids calling Add while the counter is zero concurrently with
// Wait, which is exactly what an Acquire racing a Drain would do.
type ConnTracker struct {
	// draining is written only while mu is held; it is atomic so Acquire can
	// reject connections without taking the lock once a drain has started.
	draining atomic.Bool

	mu     sync.Mutex
	active int           // connections acquired and not yet released
	idle   chan struct{} // non-nil while a Drain is waiting; closed when active hits zero
}

// Acquire registers one in-flight connection. Returns false if the tracker
// is draining; the caller must not call Release in that case.
func (t *ConnTracker) Acquire() bool {
	if t.draining.Load() {
		return false
	}
	t.mu.Lock()
	defer t.mu.Unlock()
	// Re-check under the lock: Drain sets the flag while holding mu, so a
	// connection is either counted before the drain starts or rejected.
	if t.draining.Load() {
		return false
	}
	t.active++
	return true
}

// Release marks one connection as complete. Must be called exactly once
// per successful Acquire; an unmatched call panics.
func (t *ConnTracker) Release() {
	t.mu.Lock()
	defer t.mu.Unlock()
	if t.active <= 0 {
		panic("sandbox: ConnTracker.Release called without a matching Acquire")
	}
	t.active--
	if t.active == 0 && t.idle != nil {
		// Wake every Drain currently waiting for the tracker to go idle.
		close(t.idle)
		t.idle = nil
	}
}

// Drain marks the tracker as draining (all future Acquire calls return
// false until Reset) and waits up to timeout for in-flight connections to
// finish. It returns immediately when nothing is in flight. No goroutine is
// left behind when the timeout expires.
func (t *ConnTracker) Drain(timeout time.Duration) {
	t.mu.Lock()
	t.draining.Store(true)
	if t.active == 0 {
		t.mu.Unlock()
		return
	}
	if t.idle == nil {
		t.idle = make(chan struct{})
	}
	idle := t.idle
	t.mu.Unlock()

	timer := time.NewTimer(timeout)
	defer timer.Stop()
	select {
	case <-idle:
	case <-timer.C:
	}
}

// Reset re-enables the tracker after a failed drain. This allows the
// sandbox to accept proxy connections again if the pause operation fails
// and the VM is resumed.
func (t *ConnTracker) Reset() {
	t.mu.Lock()
	t.draining.Store(false)
	t.mu.Unlock()
}
// expandImage grows a single rootfs image to targetBytes if it is currently
// smaller. targetMB is used only for log output.
//
// A missing image is not an error (not every template dir has a rootfs.ext4);
// any other stat failure is logged and the image is skipped. An image already
// at or above targetBytes is left untouched.
//
// The file-size comparison is the only idempotency guard, so the required
// tools are looked up before the file is modified: a missing binary must not
// leave behind an enlarged image whose filesystem was never grown, because
// later runs would then skip it.
//
// Returns an error if a required tool is unavailable, the truncate fails,
// e2fsck cannot be run or exits with a status above 1, or resize2fs fails.
func expandImage(rootfs string, targetBytes int64, targetMB int) error {
	info, err := os.Stat(rootfs)
	if err != nil {
		if !os.IsNotExist(err) {
			slog.Warn("skipping base image: stat failed", "path", rootfs, "error", err)
		}
		return nil // not every template dir has a rootfs.ext4
	}

	if info.Size() >= targetBytes {
		return nil // already large enough
	}

	// Preflight: fail before touching the image if either tool is missing.
	for _, tool := range []string{"e2fsck", "resize2fs"} {
		if _, err := exec.LookPath(tool); err != nil {
			return fmt.Errorf("expand %s: %w", rootfs, err)
		}
	}

	slog.Info("expanding base image",
		"path", rootfs,
		"from_mb", info.Size()/(1024*1024),
		"to_mb", targetMB,
	)

	// Expand the file (sparse — instant, no physical disk used).
	if err := os.Truncate(rootfs, targetBytes); err != nil {
		return fmt.Errorf("truncate %s: %w", rootfs, err)
	}

	// Check filesystem before resize.
	if out, err := exec.Command("e2fsck", "-fy", rootfs).CombinedOutput(); err != nil {
		// e2fsck returns 1 if it fixed errors, which is fine. Anything else —
		// a higher exit status, or a failure to run the command at all — is
		// reported rather than ignored.
		exitErr, ok := err.(*exec.ExitError)
		if !ok || exitErr.ExitCode() > 1 {
			return fmt.Errorf("e2fsck %s: %s: %w", rootfs, string(out), err)
		}
	}

	// Grow the ext4 filesystem to fill the new file size.
	if out, err := exec.Command("resize2fs", rootfs).CombinedOutput(); err != nil {
		return fmt.Errorf("resize2fs %s: %s: %w", rootfs, string(out), err)
	}

	slog.Info("base image expanded", "path", rootfs, "size_mb", targetMB)
	return nil
}
/var/lib/wrenn); all sub-paths derived via layout package + EnvdTimeout time.Duration } // Manager orchestrates sandbox lifecycle: VM, network, filesystem, envd. @@ -50,7 +51,8 @@ type sandboxState struct { models.Sandbox slot *network.Slot client *envdclient.Client - uffdSocketPath string // non-empty for sandboxes restored from snapshot + connTracker *ConnTracker // tracks in-flight proxy connections for pre-pause drain + uffdSocketPath string // non-empty for sandboxes restored from snapshot dmDevice *devicemapper.SnapshotDevice baseImagePath string // path to the base template rootfs (for loop registry release) @@ -74,8 +76,12 @@ type snapshotParent struct { } // maxDiffGenerations caps how many incremental diff generations we chain -// before falling back to a Full snapshot to collapse the chain. -const maxDiffGenerations = 10 +// before falling back to a Full snapshot to collapse the chain. Firecracker +// snapshot/restore of a Go process (envd) accumulates runtime memory state +// drift; empirically, ~10 diff-based cycles corrupt the Go page allocator. +// A Full snapshot resets the generation counter and produces a clean base, +// preventing the crash. +const maxDiffGenerations = 8 // New creates a new sandbox manager. func New(cfg Config) *Manager { @@ -94,9 +100,9 @@ func New(cfg Config) *Manager { // Create boots a new sandbox: clone rootfs, set up network, start VM, wait for envd. // If sandboxID is empty, a new ID is generated. 
-func (m *Manager) Create(ctx context.Context, sandboxID, template string, vcpus, memoryMB, timeoutSec int) (*models.Sandbox, error) { +func (m *Manager) Create(ctx context.Context, sandboxID string, teamID, templateID pgtype.UUID, vcpus, memoryMB, timeoutSec, diskSizeMB int) (*models.Sandbox, error) { if sandboxID == "" { - sandboxID = id.NewSandboxID() + sandboxID = id.FormatSandboxID(id.NewSandboxID()) } if vcpus <= 0 { @@ -105,21 +111,18 @@ func (m *Manager) Create(ctx context.Context, sandboxID, template string, vcpus, if memoryMB <= 0 { memoryMB = 512 } - - if template == "" { - template = "minimal" - } - if err := validate.SafeName(template); err != nil { - return nil, fmt.Errorf("invalid template name: %w", err) + if diskSizeMB <= 0 { + diskSizeMB = 5120 // 5 GB default } // Check if template refers to a snapshot (has snapfile + memfile + header + rootfs). - if snapshot.IsSnapshot(m.cfg.ImagesDir, template) { - return m.createFromSnapshot(ctx, sandboxID, template, vcpus, memoryMB, timeoutSec) + tmplDir := layout.TemplateDir(m.cfg.WrennDir, teamID, templateID) + if _, err := os.Stat(filepath.Join(tmplDir, snapshot.SnapFileName)); err == nil { + return m.createFromSnapshot(ctx, sandboxID, teamID, templateID, vcpus, memoryMB, timeoutSec, diskSizeMB) } - // Resolve base rootfs image: /var/lib/wrenn/images/{template}/rootfs.ext4 - baseRootfs := filepath.Join(m.cfg.ImagesDir, template, "rootfs.ext4") + // Resolve base rootfs image. + baseRootfs := layout.TemplateRootfs(m.cfg.WrennDir, teamID, templateID) if _, err := os.Stat(baseRootfs); err != nil { return nil, fmt.Errorf("base rootfs not found at %s: %w", baseRootfs, err) } @@ -138,8 +141,9 @@ func (m *Manager) Create(ctx context.Context, sandboxID, template string, vcpus, // Create dm-snapshot with per-sandbox CoW file. 
dmName := "wrenn-" + sandboxID - cowPath := filepath.Join(m.cfg.SandboxesDir, fmt.Sprintf("%s.cow", sandboxID)) - dmDev, err := devicemapper.CreateSnapshot(dmName, originLoop, cowPath, originSize) + cowPath := filepath.Join(layout.SandboxesDir(m.cfg.WrennDir), fmt.Sprintf("%s.cow", sandboxID)) + cowSize := int64(diskSizeMB) * 1024 * 1024 + dmDev, err := devicemapper.CreateSnapshot(dmName, originLoop, cowPath, originSize, cowSize) if err != nil { m.loops.Release(baseRootfs) return nil, fmt.Errorf("create dm-snapshot: %w", err) @@ -167,7 +171,8 @@ func (m *Manager) Create(ctx context.Context, sandboxID, template string, vcpus, // Boot VM — Firecracker gets the dm device path. vmCfg := vm.VMConfig{ SandboxID: sandboxID, - KernelPath: m.cfg.KernelPath, + TemplateID: id.UUIDString(templateID), + KernelPath: layout.KernelPath(m.cfg.WrennDir), RootfsPath: dmDev.DevicePath, VCPUs: vcpus, MemoryMB: memoryMB, @@ -203,33 +208,25 @@ func (m *Manager) Create(ctx context.Context, sandboxID, template string, vcpus, return nil, fmt.Errorf("wait for envd: %w", err) } - // Sync guest clock in background. Non-fatal — sandbox is usable before this completes. - // Run in a goroutine so Init latency doesn't block the RPC response back to the control plane. 
- go func() { - initCtx, initCancel := context.WithTimeout(context.Background(), 10*time.Second) - defer initCancel() - if err := client.Init(initCtx); err != nil { - slog.Warn("envd init (clock sync) failed", "sandbox", sandboxID, "error", err) - } - }() - now := time.Now() sb := &sandboxState{ Sandbox: models.Sandbox{ - ID: sandboxID, - Status: models.StatusRunning, - Template: template, - VCPUs: vcpus, - MemoryMB: memoryMB, - TimeoutSec: timeoutSec, - SlotIndex: slotIdx, - HostIP: slot.HostIP, - RootfsPath: dmDev.DevicePath, - CreatedAt: now, - LastActiveAt: now, + ID: sandboxID, + Status: models.StatusRunning, + TemplateTeamID: teamID.Bytes, + TemplateID: templateID.Bytes, + VCPUs: vcpus, + MemoryMB: memoryMB, + TimeoutSec: timeoutSec, + SlotIndex: slotIdx, + HostIP: slot.HostIP, + RootfsPath: dmDev.DevicePath, + CreatedAt: now, + LastActiveAt: now, }, slot: slot, client: client, + connTracker: &ConnTracker{}, dmDevice: dmDev, baseImagePath: baseRootfs, } @@ -242,7 +239,8 @@ func (m *Manager) Create(ctx context.Context, sandboxID, template string, vcpus, slog.Info("sandbox created", "id", sandboxID, - "template", template, + "team_id", teamID, + "template_id", templateID, "host_ip", slot.HostIP.String(), "dm_device", dmDev.DevicePath, ) @@ -265,7 +263,9 @@ func (m *Manager) Destroy(ctx context.Context, sandboxID string) error { } // Always clean up pause snapshot files (may exist if sandbox was paused). - warnErr("snapshot cleanup error", sandboxID, snapshot.Remove(m.cfg.SnapshotsDir, sandboxID)) + if err := os.RemoveAll(layout.PauseSnapshotDir(m.cfg.WrennDir, sandboxID)); err != nil { + slog.Warn("snapshot cleanup error", "id", sandboxID, "error", err) + } slog.Info("sandbox destroyed", "id", sandboxID) return nil @@ -311,43 +311,53 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error { return fmt.Errorf("sandbox %s is not running (status: %s)", sandboxID, sb.Status) } + // Step 0: Drain in-flight proxy connections before freezing vCPUs. 
+ // This prevents Go runtime corruption inside the guest caused by stale + // TCP state from connections that were alive when the VM was snapshotted. + sb.connTracker.Drain(2 * time.Second) + slog.Debug("pause: proxy connections drained", "id", sandboxID) + pauseStart := time.Now() // Step 1: Pause the VM (freeze vCPUs). if err := m.vm.Pause(ctx, sandboxID); err != nil { + sb.connTracker.Reset() return fmt.Errorf("pause VM: %w", err) } slog.Debug("pause: VM paused", "id", sandboxID, "elapsed", time.Since(pauseStart)) - // Determine snapshot type: Diff if resumed from snapshot (avoids UFFD - // fault-in storm), Full otherwise or if generation cap is reached. + // Always use Diff when we have a parent snapshot — Diff only captures + // changed pages and is much faster than Full (which dumps all memory). + // For first-time pauses (no parent) we must use Full. snapshotType := "Full" - if sb.parent != nil && sb.parent.header.Metadata.Generation < maxDiffGenerations { + if sb.parent != nil { snapshotType = "Diff" } // resumeOnError unpauses the VM so the sandbox stays usable when a // post-freeze step fails. If the resume itself fails, the sandbox is - // left frozen — the caller should destroy it. + // left frozen — the caller should destroy it. It also resets the + // connection tracker so the sandbox can accept proxy connections again. resumeOnError := func() { + sb.connTracker.Reset() if err := m.vm.Resume(ctx, sandboxID); err != nil { slog.Error("failed to resume VM after pause error — sandbox is frozen", "id", sandboxID, "error", err) } } // Step 2: Take VM state snapshot (snapfile + memfile). 
- if err := snapshot.EnsureDir(m.cfg.SnapshotsDir, sandboxID); err != nil { + pauseDir := layout.PauseSnapshotDir(m.cfg.WrennDir, sandboxID) + if err := os.MkdirAll(pauseDir, 0755); err != nil { resumeOnError() return fmt.Errorf("create snapshot dir: %w", err) } - snapDir := snapshot.DirPath(m.cfg.SnapshotsDir, sandboxID) - rawMemPath := filepath.Join(snapDir, "memfile.raw") - snapPath := snapshot.SnapPath(m.cfg.SnapshotsDir, sandboxID) + rawMemPath := filepath.Join(pauseDir, "memfile.raw") + snapPath := filepath.Join(pauseDir, snapshot.SnapFileName) snapshotStart := time.Now() if err := m.vm.Snapshot(ctx, sandboxID, snapPath, rawMemPath, snapshotType); err != nil { - warnErr("snapshot dir cleanup error", sandboxID, snapshot.Remove(m.cfg.SnapshotsDir, sandboxID)) + warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir)) resumeOnError() return fmt.Errorf("create VM snapshot: %w", err) } @@ -355,34 +365,76 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error { // Step 3: Process the raw memfile into a compact diff + header. buildID := uuid.New() - headerPath := snapshot.MemHeaderPath(m.cfg.SnapshotsDir, sandboxID) + headerPath := filepath.Join(pauseDir, snapshot.MemHeaderName) processStart := time.Now() - if sb.parent != nil && snapshotType == "Diff" { + if sb.parent != nil { // Diff: process against parent header, producing only changed blocks. - diffPath := snapshot.MemDiffPathForBuild(m.cfg.SnapshotsDir, sandboxID, buildID) + diffPath := snapshot.MemDiffPathForBuild(pauseDir, "", buildID) if _, err := snapshot.ProcessMemfileWithParent(rawMemPath, diffPath, headerPath, sb.parent.header, buildID); err != nil { - warnErr("snapshot dir cleanup error", sandboxID, snapshot.Remove(m.cfg.SnapshotsDir, sandboxID)) + warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir)) resumeOnError() return fmt.Errorf("process memfile with parent: %w", err) } // Copy previous generation diff files into the snapshot directory. 
for prevBuildID, prevPath := range sb.parent.diffPaths { - dstPath := snapshot.MemDiffPathForBuild(m.cfg.SnapshotsDir, sandboxID, uuid.MustParse(prevBuildID)) + dstPath := snapshot.MemDiffPathForBuild(pauseDir, "", uuid.MustParse(prevBuildID)) if prevPath != dstPath { if err := copyFile(prevPath, dstPath); err != nil { - warnErr("snapshot dir cleanup error", sandboxID, snapshot.Remove(m.cfg.SnapshotsDir, sandboxID)) + warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir)) resumeOnError() return fmt.Errorf("copy parent diff file: %w", err) } } } + + // If the generation cap is reached, merge all diff files into a + // single file to collapse the chain. This is a file-level operation + // (no Firecracker involvement) so it's fast and reliable. + generation := sb.parent.header.Metadata.Generation + 1 + if generation >= maxDiffGenerations { + slog.Debug("pause: merging diff generations", "id", sandboxID, "generation", generation) + + // Load the header we just wrote (it references all generations). + headerData, err := os.ReadFile(headerPath) + if err != nil { + warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir)) + resumeOnError() + return fmt.Errorf("read header for merge: %w", err) + } + currentHeader, err := snapshot.Deserialize(headerData) + if err != nil { + warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir)) + resumeOnError() + return fmt.Errorf("deserialize header for merge: %w", err) + } + + // Locate all diff files referenced by the header. + diffFiles, err := snapshot.ListDiffFiles(pauseDir, "", currentHeader) + if err != nil { + warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir)) + resumeOnError() + return fmt.Errorf("list diff files for merge: %w", err) + } + + // Merge into a single new diff file. 
+ mergedPath := snapshot.MemDiffPath(pauseDir, "") + if _, err := snapshot.MergeDiffs(currentHeader, diffFiles, mergedPath, headerPath); err != nil { + warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir)) + resumeOnError() + return fmt.Errorf("merge diff files: %w", err) + } + + // Remove the old per-generation diff files. + removeStaleMemDiffs(pauseDir) + slog.Debug("pause: diff merge complete", "id", sandboxID) + } } else { - // Full: first generation or generation cap reached — single diff file. - diffPath := snapshot.MemDiffPath(m.cfg.SnapshotsDir, sandboxID) + // Full: first pause — no parent to diff against. + diffPath := snapshot.MemDiffPath(pauseDir, "") if _, err := snapshot.ProcessMemfile(rawMemPath, diffPath, headerPath, buildID); err != nil { - warnErr("snapshot dir cleanup error", sandboxID, snapshot.Remove(m.cfg.SnapshotsDir, sandboxID)) + warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir)) resumeOnError() return fmt.Errorf("process memfile: %w", err) } @@ -412,7 +464,7 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error { if sb.uffdSocketPath != "" { os.Remove(sb.uffdSocketPath) } - warnErr("snapshot dir cleanup error", sandboxID, snapshot.Remove(m.cfg.SnapshotsDir, sandboxID)) + warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir)) m.mu.Lock() delete(m.boxes, sandboxID) m.mu.Unlock() @@ -420,9 +472,9 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error { } // Move (not copy) the CoW file into the snapshot directory. - snapshotCow := snapshot.CowPath(m.cfg.SnapshotsDir, sandboxID) + snapshotCow := snapshot.CowPath(pauseDir, "") if err := os.Rename(sb.dmDevice.CowPath, snapshotCow); err != nil { - warnErr("snapshot dir cleanup error", sandboxID, snapshot.Remove(m.cfg.SnapshotsDir, sandboxID)) + warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir)) // VM and dm-snapshot are already gone — clean up remaining resources. 
warnErr("network cleanup error during pause", sandboxID, network.RemoveNetwork(sb.slot)) m.slots.Release(sb.SlotIndex) @@ -439,10 +491,10 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error { } // Record which base template this CoW was built against. - if err := snapshot.WriteMeta(m.cfg.SnapshotsDir, sandboxID, &snapshot.RootfsMeta{ + if err := snapshot.WriteMeta(pauseDir, "", &snapshot.RootfsMeta{ BaseTemplate: sb.baseImagePath, }); err != nil { - warnErr("snapshot dir cleanup error", sandboxID, snapshot.Remove(m.cfg.SnapshotsDir, sandboxID)) + warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir)) // VM and dm-snapshot are already gone — clean up remaining resources. warnErr("network cleanup error during pause", sandboxID, network.RemoveNetwork(sb.slot)) m.slots.Release(sb.SlotIndex) @@ -482,13 +534,13 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error { // Resume restores a paused sandbox from its snapshot using UFFD for // lazy memory loading. The sandbox gets a new network slot. func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int) (*models.Sandbox, error) { - snapDir := m.cfg.SnapshotsDir - if !snapshot.Exists(snapDir, sandboxID) { + pauseDir := layout.PauseSnapshotDir(m.cfg.WrennDir, sandboxID) + if _, err := os.Stat(pauseDir); err != nil { return nil, fmt.Errorf("no snapshot found for sandbox %s", sandboxID) } // Read the header to set up the UFFD memory source. - headerData, err := os.ReadFile(snapshot.MemHeaderPath(snapDir, sandboxID)) + headerData, err := os.ReadFile(filepath.Join(pauseDir, snapshot.MemHeaderName)) if err != nil { return nil, fmt.Errorf("read header: %w", err) } @@ -499,7 +551,7 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int) } // Build diff file map — supports both single-generation and multi-generation. 
- diffPaths, err := snapshot.ListDiffFiles(snapDir, sandboxID, header) + diffPaths, err := snapshot.ListDiffFiles(pauseDir, "", header) if err != nil { return nil, fmt.Errorf("list diff files: %w", err) } @@ -510,7 +562,7 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int) } // Read rootfs metadata to find the base template image. - meta, err := snapshot.ReadMeta(snapDir, sandboxID) + meta, err := snapshot.ReadMeta(pauseDir, "") if err != nil { source.Close() return nil, fmt.Errorf("read rootfs meta: %w", err) @@ -532,8 +584,8 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int) } // Move CoW file from snapshot dir to sandboxes dir for the running sandbox. - savedCow := snapshot.CowPath(snapDir, sandboxID) - cowPath := filepath.Join(m.cfg.SandboxesDir, fmt.Sprintf("%s.cow", sandboxID)) + savedCow := snapshot.CowPath(pauseDir, "") + cowPath := filepath.Join(layout.SandboxesDir(m.cfg.WrennDir), fmt.Sprintf("%s.cow", sandboxID)) if err := os.Rename(savedCow, cowPath); err != nil { source.Close() m.loops.Release(baseImagePath) @@ -579,7 +631,7 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int) } // Start UFFD server. - uffdSocketPath := filepath.Join(m.cfg.SandboxesDir, fmt.Sprintf("%s-uffd.sock", sandboxID)) + uffdSocketPath := filepath.Join(layout.SandboxesDir(m.cfg.WrennDir), fmt.Sprintf("%s-uffd.sock", sandboxID)) os.Remove(uffdSocketPath) // Clean stale socket. uffdServer := uffd.NewServer(uffdSocketPath, source) if err := uffdServer.Start(ctx); err != nil { @@ -595,7 +647,7 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int) // Restore VM from snapshot. vmCfg := vm.VMConfig{ SandboxID: sandboxID, - KernelPath: m.cfg.KernelPath, + KernelPath: layout.KernelPath(m.cfg.WrennDir), RootfsPath: dmDev.DevicePath, VCPUs: 1, // Placeholder; overridden by snapshot. 
MemoryMB: int(header.Metadata.Size / (1024 * 1024)), // Placeholder; overridden by snapshot. @@ -607,8 +659,8 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int) NetMask: slot.GuestNetMask, } - snapPath := snapshot.SnapPath(snapDir, sandboxID) - if _, err := m.vm.CreateFromSnapshot(ctx, vmCfg, snapPath, uffdSocketPath); err != nil { + resumeSnapPath := filepath.Join(pauseDir, snapshot.SnapFileName) + if _, err := m.vm.CreateFromSnapshot(ctx, vmCfg, resumeSnapPath, uffdSocketPath); err != nil { warnErr("uffd server stop error", sandboxID, uffdServer.Stop()) source.Close() warnErr("network cleanup error", sandboxID, network.RemoveNetwork(slot)) @@ -636,22 +688,11 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int) return nil, fmt.Errorf("wait for envd: %w", err) } - // Sync guest clock in background. Non-fatal — sandbox is usable before this completes. - // Run in a goroutine so Init latency doesn't block the RPC response back to the control plane. - go func() { - initCtx, initCancel := context.WithTimeout(context.Background(), 10*time.Second) - defer initCancel() - if err := client.Init(initCtx); err != nil { - slog.Warn("envd init (clock sync) failed", "sandbox", sandboxID, "error", err) - } - }() - now := time.Now() sb := &sandboxState{ Sandbox: models.Sandbox{ ID: sandboxID, Status: models.StatusRunning, - Template: "", VCPUs: vmCfg.VCPUs, MemoryMB: vmCfg.MemoryMB, TimeoutSec: timeoutSec, @@ -663,6 +704,7 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int) }, slot: slot, client: client, + connTracker: &ConnTracker{}, uffdSocketPath: uffdSocketPath, dmDevice: dmDev, baseImagePath: baseImagePath, @@ -700,11 +742,7 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int) // The rootfs is flattened (base + CoW merged) into a new standalone rootfs.ext4 // so the template has no dependency on the original base image. 
Memory state // and VM snapshot files are copied as-is. -func (m *Manager) CreateSnapshot(ctx context.Context, sandboxID, name string) (int64, error) { - if err := validate.SafeName(name); err != nil { - return 0, fmt.Errorf("invalid snapshot name: %w", err) - } - +func (m *Manager) CreateSnapshot(ctx context.Context, sandboxID string, teamID, templateID pgtype.UUID) (int64, error) { // If the sandbox is running, pause it first. if _, err := m.get(sandboxID); err == nil { if err := m.Pause(ctx, sandboxID); err != nil { @@ -712,25 +750,26 @@ func (m *Manager) CreateSnapshot(ctx context.Context, sandboxID, name string) (i } } - // At this point, pause snapshot files must exist in SnapshotsDir/{sandboxID}/. - if !snapshot.Exists(m.cfg.SnapshotsDir, sandboxID) { + // At this point, pause snapshot files must exist. + pauseDir := layout.PauseSnapshotDir(m.cfg.WrennDir, sandboxID) + if _, err := os.Stat(pauseDir); err != nil { return 0, fmt.Errorf("no snapshot found for sandbox %s", sandboxID) } // Create template directory. - if err := snapshot.EnsureDir(m.cfg.ImagesDir, name); err != nil { + dstDir := layout.TemplateDir(m.cfg.WrennDir, teamID, templateID) + if err := os.MkdirAll(dstDir, 0755); err != nil { return 0, fmt.Errorf("create template dir: %w", err) } // Copy VM snapshot file and memory header. 
- srcDir := snapshot.DirPath(m.cfg.SnapshotsDir, sandboxID) - dstDir := snapshot.DirPath(m.cfg.ImagesDir, name) + srcDir := pauseDir for _, fname := range []string{snapshot.SnapFileName, snapshot.MemHeaderName} { src := filepath.Join(srcDir, fname) dst := filepath.Join(dstDir, fname) if err := copyFile(src, dst); err != nil { - warnErr("template dir cleanup error", name, snapshot.Remove(m.cfg.ImagesDir, name)) + warnErr("template dir cleanup error", dstDir, os.RemoveAll(dstDir)) return 0, fmt.Errorf("copy %s: %w", fname, err) } } @@ -738,59 +777,59 @@ func (m *Manager) CreateSnapshot(ctx context.Context, sandboxID, name string) (i // Copy all memory diff files referenced by the header (supports multi-generation). headerData, err := os.ReadFile(filepath.Join(srcDir, snapshot.MemHeaderName)) if err != nil { - warnErr("template dir cleanup error", name, snapshot.Remove(m.cfg.ImagesDir, name)) + warnErr("template dir cleanup error", dstDir, os.RemoveAll(dstDir)) return 0, fmt.Errorf("read header for template: %w", err) } srcHeader, err := snapshot.Deserialize(headerData) if err != nil { - warnErr("template dir cleanup error", name, snapshot.Remove(m.cfg.ImagesDir, name)) + warnErr("template dir cleanup error", dstDir, os.RemoveAll(dstDir)) return 0, fmt.Errorf("deserialize header for template: %w", err) } - srcDiffPaths, err := snapshot.ListDiffFiles(m.cfg.SnapshotsDir, sandboxID, srcHeader) + srcDiffPaths, err := snapshot.ListDiffFiles(pauseDir, "", srcHeader) if err != nil { - warnErr("template dir cleanup error", name, snapshot.Remove(m.cfg.ImagesDir, name)) + warnErr("template dir cleanup error", dstDir, os.RemoveAll(dstDir)) return 0, fmt.Errorf("list diff files for template: %w", err) } for _, srcPath := range srcDiffPaths { dstPath := filepath.Join(dstDir, filepath.Base(srcPath)) if err := copyFile(srcPath, dstPath); err != nil { - warnErr("template dir cleanup error", name, snapshot.Remove(m.cfg.ImagesDir, name)) + warnErr("template dir cleanup error", dstDir, 
os.RemoveAll(dstDir)) return 0, fmt.Errorf("copy diff file %s: %w", filepath.Base(srcPath), err) } } // Flatten rootfs: temporarily set up dm device from base + CoW, dd to new image. - meta, err := snapshot.ReadMeta(m.cfg.SnapshotsDir, sandboxID) + meta, err := snapshot.ReadMeta(pauseDir, "") if err != nil { - warnErr("template dir cleanup error", name, snapshot.Remove(m.cfg.ImagesDir, name)) + warnErr("template dir cleanup error", dstDir, os.RemoveAll(dstDir)) return 0, fmt.Errorf("read rootfs meta: %w", err) } originLoop, err := m.loops.Acquire(meta.BaseTemplate) if err != nil { - warnErr("template dir cleanup error", name, snapshot.Remove(m.cfg.ImagesDir, name)) + warnErr("template dir cleanup error", dstDir, os.RemoveAll(dstDir)) return 0, fmt.Errorf("acquire loop device for flatten: %w", err) } originSize, err := devicemapper.OriginSizeBytes(originLoop) if err != nil { m.loops.Release(meta.BaseTemplate) - warnErr("template dir cleanup error", name, snapshot.Remove(m.cfg.ImagesDir, name)) + warnErr("template dir cleanup error", dstDir, os.RemoveAll(dstDir)) return 0, fmt.Errorf("get origin size: %w", err) } // Temporarily restore the dm-snapshot to read the merged view. - cowPath := snapshot.CowPath(m.cfg.SnapshotsDir, sandboxID) + cowPath := snapshot.CowPath(pauseDir, "") tmpDmName := "wrenn-flatten-" + sandboxID tmpDev, err := devicemapper.RestoreSnapshot(ctx, tmpDmName, originLoop, cowPath, originSize) if err != nil { m.loops.Release(meta.BaseTemplate) - warnErr("template dir cleanup error", name, snapshot.Remove(m.cfg.ImagesDir, name)) + warnErr("template dir cleanup error", dstDir, os.RemoveAll(dstDir)) return 0, fmt.Errorf("restore dm-snapshot for flatten: %w", err) } // Flatten to new standalone rootfs. - flattenedPath := snapshot.RootfsPath(m.cfg.ImagesDir, name) + flattenedPath := filepath.Join(dstDir, snapshot.RootfsFileName) flattenErr := devicemapper.FlattenSnapshot(tmpDev.DevicePath, flattenedPath) // Always clean up the temporary dm device. 
@@ -798,40 +837,131 @@ func (m *Manager) CreateSnapshot(ctx context.Context, sandboxID, name string) (i m.loops.Release(meta.BaseTemplate) if flattenErr != nil { - warnErr("template dir cleanup error", name, snapshot.Remove(m.cfg.ImagesDir, name)) + warnErr("template dir cleanup error", dstDir, os.RemoveAll(dstDir)) return 0, fmt.Errorf("flatten rootfs: %w", flattenErr) } - sizeBytes, err := snapshot.DirSize(m.cfg.ImagesDir, name) + sizeBytes, err := snapshot.DirSize(dstDir, "") if err != nil { slog.Warn("failed to calculate snapshot size", "error", err) } slog.Info("template snapshot created (rootfs flattened)", "sandbox", sandboxID, - "name", name, + "team_id", teamID, + "template_id", templateID, "size_bytes", sizeBytes, ) return sizeBytes, nil } -// DeleteSnapshot removes a snapshot template from disk. -func (m *Manager) DeleteSnapshot(name string) error { - if err := validate.SafeName(name); err != nil { - return fmt.Errorf("invalid snapshot name: %w", err) +// FlattenRootfs stops a running sandbox, flattens its device-mapper CoW +// rootfs into a standalone rootfs.ext4, and cleans up all resources. +// The result is an image-only template (no VM memory/CPU state) stored in +// ImagesDir/{name}/rootfs.ext4. +func (m *Manager) FlattenRootfs(ctx context.Context, sandboxID string, teamID, templateID pgtype.UUID) (int64, error) { + m.mu.Lock() + sb, ok := m.boxes[sandboxID] + if ok { + delete(m.boxes, sandboxID) } - return snapshot.Remove(m.cfg.ImagesDir, name) + m.mu.Unlock() + + if !ok { + return 0, fmt.Errorf("sandbox %s not found", sandboxID) + } + + // Stop the VM but keep the dm device alive for flattening. + m.stopSampler(sb) + if err := m.vm.Destroy(ctx, sb.ID); err != nil { + slog.Warn("vm destroy error during flatten", "id", sb.ID, "error", err) + } + + // Release network resources — not needed after VM is stopped. 
+ if err := network.RemoveNetwork(sb.slot); err != nil { + slog.Warn("network cleanup error during flatten", "id", sb.ID, "error", err) + } + m.slots.Release(sb.SlotIndex) + + if sb.uffdSocketPath != "" { + os.Remove(sb.uffdSocketPath) + } + + // Create template directory and flatten the dm-snapshot. + flattenDstDir := layout.TemplateDir(m.cfg.WrennDir, teamID, templateID) + if err := os.MkdirAll(flattenDstDir, 0755); err != nil { + m.cleanupDM(sb) + return 0, fmt.Errorf("create template dir: %w", err) + } + + outputPath := filepath.Join(flattenDstDir, snapshot.RootfsFileName) + if sb.dmDevice == nil { + m.cleanupDM(sb) + warnErr("template dir cleanup error", flattenDstDir, os.RemoveAll(flattenDstDir)) + return 0, fmt.Errorf("sandbox %s has no dm device", sandboxID) + } + + if err := devicemapper.FlattenSnapshot(sb.dmDevice.DevicePath, outputPath); err != nil { + m.cleanupDM(sb) + warnErr("template dir cleanup error", flattenDstDir, os.RemoveAll(flattenDstDir)) + return 0, fmt.Errorf("flatten rootfs: %w", err) + } + + // Clean up dm device and loop device now that flatten is complete. + m.cleanupDM(sb) + + // Shrink the flattened image to its minimum size so stored templates are + // compact. EnsureImageSizes will re-expand them on the next agent startup. 
+ if out, err := exec.Command("e2fsck", "-fy", outputPath).CombinedOutput(); err != nil { + if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() > 1 { + slog.Warn("e2fsck before shrink failed (non-fatal)", "output", string(out), "error", err) + } + } + if out, err := exec.Command("resize2fs", "-M", outputPath).CombinedOutput(); err != nil { + slog.Warn("resize2fs -M failed (non-fatal)", "output", string(out), "error", err) + } + + sizeBytes, err := snapshot.DirSize(flattenDstDir, "") + if err != nil { + slog.Warn("failed to calculate template size", "error", err) + } + + slog.Info("rootfs flattened to image-only template", + "sandbox", sandboxID, + "team_id", teamID, + "template_id", templateID, + "size_bytes", sizeBytes, + ) + return sizeBytes, nil +} + +// cleanupDM tears down the dm-snapshot device and releases the base image loop device. +func (m *Manager) cleanupDM(sb *sandboxState) { + if sb.dmDevice != nil { + if err := devicemapper.RemoveSnapshot(context.Background(), sb.dmDevice); err != nil { + slog.Warn("dm-snapshot remove error", "id", sb.ID, "error", err) + } + os.Remove(sb.dmDevice.CowPath) + } + if sb.baseImagePath != "" { + m.loops.Release(sb.baseImagePath) + } +} + +// DeleteSnapshot removes a snapshot template from disk. +func (m *Manager) DeleteSnapshot(teamID, templateID pgtype.UUID) error { + return os.RemoveAll(layout.TemplateDir(m.cfg.WrennDir, teamID, templateID)) } // createFromSnapshot creates a new sandbox by restoring from a snapshot template // in ImagesDir/{snapshotName}/. Uses UFFD for lazy memory loading. // The template's rootfs.ext4 is a flattened standalone image — we create a // dm-snapshot on top of it just like a normal Create. 
-func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID, snapshotName string, vcpus, _, timeoutSec int) (*models.Sandbox, error) { - imagesDir := m.cfg.ImagesDir +func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID string, teamID, templateID pgtype.UUID, vcpus, _, timeoutSec, diskSizeMB int) (*models.Sandbox, error) { + tmplDir := layout.TemplateDir(m.cfg.WrennDir, teamID, templateID) // Read the header. - headerData, err := os.ReadFile(snapshot.MemHeaderPath(imagesDir, snapshotName)) + headerData, err := os.ReadFile(filepath.Join(tmplDir, snapshot.MemHeaderName)) if err != nil { return nil, fmt.Errorf("read snapshot header: %w", err) } @@ -845,7 +975,7 @@ func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID, snapshotNam memoryMB := int(header.Metadata.Size / (1024 * 1024)) // Build diff file map — supports multi-generation templates. - diffPaths, err := snapshot.ListDiffFiles(imagesDir, snapshotName, header) + diffPaths, err := snapshot.ListDiffFiles(tmplDir, "", header) if err != nil { return nil, fmt.Errorf("list diff files: %w", err) } @@ -856,7 +986,7 @@ func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID, snapshotNam } // Set up dm-snapshot on the template's flattened rootfs. 
- baseRootfs := snapshot.RootfsPath(imagesDir, snapshotName) + baseRootfs := filepath.Join(tmplDir, snapshot.RootfsFileName) originLoop, err := m.loops.Acquire(baseRootfs) if err != nil { source.Close() @@ -871,8 +1001,9 @@ func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID, snapshotNam } dmName := "wrenn-" + sandboxID - cowPath := filepath.Join(m.cfg.SandboxesDir, fmt.Sprintf("%s.cow", sandboxID)) - dmDev, err := devicemapper.CreateSnapshot(dmName, originLoop, cowPath, originSize) + cowPath := filepath.Join(layout.SandboxesDir(m.cfg.WrennDir), fmt.Sprintf("%s.cow", sandboxID)) + cowSize := int64(diskSizeMB) * 1024 * 1024 + dmDev, err := devicemapper.CreateSnapshot(dmName, originLoop, cowPath, originSize, cowSize) if err != nil { source.Close() m.loops.Release(baseRootfs) @@ -900,7 +1031,7 @@ func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID, snapshotNam } // Start UFFD server. - uffdSocketPath := filepath.Join(m.cfg.SandboxesDir, fmt.Sprintf("%s-uffd.sock", sandboxID)) + uffdSocketPath := filepath.Join(layout.SandboxesDir(m.cfg.WrennDir), fmt.Sprintf("%s-uffd.sock", sandboxID)) os.Remove(uffdSocketPath) uffdServer := uffd.NewServer(uffdSocketPath, source) if err := uffdServer.Start(ctx); err != nil { @@ -916,7 +1047,8 @@ func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID, snapshotNam // Restore VM. 
vmCfg := vm.VMConfig{ SandboxID: sandboxID, - KernelPath: m.cfg.KernelPath, + TemplateID: id.UUIDString(templateID), + KernelPath: layout.KernelPath(m.cfg.WrennDir), RootfsPath: dmDev.DevicePath, VCPUs: vcpus, MemoryMB: memoryMB, @@ -928,7 +1060,7 @@ func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID, snapshotNam NetMask: slot.GuestNetMask, } - snapPath := snapshot.SnapPath(imagesDir, snapshotName) + snapPath := filepath.Join(tmplDir, snapshot.SnapFileName) if _, err := m.vm.CreateFromSnapshot(ctx, vmCfg, snapPath, uffdSocketPath); err != nil { warnErr("uffd server stop error", sandboxID, uffdServer.Stop()) source.Close() @@ -957,33 +1089,25 @@ func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID, snapshotNam return nil, fmt.Errorf("wait for envd: %w", err) } - // Sync guest clock in background. Non-fatal — sandbox is usable before this completes. - // Run in a goroutine so Init latency doesn't block the RPC response back to the control plane. - go func() { - initCtx, initCancel := context.WithTimeout(context.Background(), 10*time.Second) - defer initCancel() - if err := client.Init(initCtx); err != nil { - slog.Warn("envd init (clock sync) failed", "sandbox", sandboxID, "error", err) - } - }() - now := time.Now() sb := &sandboxState{ Sandbox: models.Sandbox{ - ID: sandboxID, - Status: models.StatusRunning, - Template: snapshotName, - VCPUs: vcpus, - MemoryMB: memoryMB, - TimeoutSec: timeoutSec, - SlotIndex: slotIdx, - HostIP: slot.HostIP, - RootfsPath: dmDev.DevicePath, - CreatedAt: now, - LastActiveAt: now, + ID: sandboxID, + Status: models.StatusRunning, + TemplateTeamID: teamID.Bytes, + TemplateID: templateID.Bytes, + VCPUs: vcpus, + MemoryMB: memoryMB, + TimeoutSec: timeoutSec, + SlotIndex: slotIdx, + HostIP: slot.HostIP, + RootfsPath: dmDev.DevicePath, + CreatedAt: now, + LastActiveAt: now, }, slot: slot, client: client, + connTracker: &ConnTracker{}, uffdSocketPath: uffdSocketPath, dmDevice: dmDev, baseImagePath: baseRootfs, 
@@ -1002,7 +1126,8 @@ func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID, snapshotNam slog.Info("sandbox created from snapshot", "id", sandboxID, - "snapshot", snapshotName, + "team_id", teamID, + "template_id", templateID, "host_ip", slot.HostIP.String(), "dm_device", dmDev.DevicePath, ) @@ -1079,6 +1204,25 @@ func (m *Manager) GetClient(sandboxID string) (*envdclient.Client, error) { return sb.client, nil } +// AcquireProxyConn atomically looks up a sandbox by ID and registers an +// in-flight proxy connection. Returns the sandbox's host-reachable IP, the +// connection tracker, and true on success. The caller must call +// tracker.Release() when the request completes. Returns zero values and +// false if the sandbox is not found, not running, or is draining for a pause. +func (m *Manager) AcquireProxyConn(sandboxID string) (net.IP, *ConnTracker, bool) { + m.mu.RLock() + sb, ok := m.boxes[sandboxID] + m.mu.RUnlock() + + if !ok || sb.Status != models.StatusRunning { + return nil, nil, false + } + if !sb.connTracker.Acquire() { + return nil, nil, false + } + return sb.HostIP, sb.connTracker, true +} + // Ping resets the inactivity timer for a running sandbox. func (m *Manager) Ping(sandboxID string) error { m.mu.Lock() @@ -1218,6 +1362,23 @@ func (m *Manager) PauseAll(ctx context.Context) { } } +// removeStaleMemDiffs removes memfile.{uuid} diff files from a snapshot +// directory. Called before writing a Full snapshot to prevent orphaned diffs +// from accumulating across generation resets. +func removeStaleMemDiffs(dir string) { + entries, err := os.ReadDir(dir) + if err != nil { + return + } + for _, e := range entries { + name := e.Name() + // Match "memfile.{uuid}" but not "memfile", "memfile.header", or "memfile.raw". + if strings.HasPrefix(name, "memfile.") && name != snapshot.MemHeaderName && name != "memfile.raw" { + os.Remove(filepath.Join(dir, name)) + } + } +} + // warnErr logs a warning if err is non-nil. 
Used for best-effort cleanup // in error paths where the primary error has already been captured. func warnErr(msg string, id string, err error) { diff --git a/internal/scheduler/round_robin.go b/internal/scheduler/round_robin.go index 31433a0b..c2ab0f40 100644 --- a/internal/scheduler/round_robin.go +++ b/internal/scheduler/round_robin.go @@ -5,6 +5,8 @@ import ( "fmt" "sync/atomic" + "github.com/jackc/pgx/v5/pgtype" + "git.omukk.dev/wrenn/sandbox/internal/db" ) @@ -15,7 +17,7 @@ type HostScheduler interface { // For BYOC teams (isByoc=true), only online BYOC hosts belonging to teamID // are considered. For non-BYOC teams, only online regular (platform) hosts // are considered. Returns an error if no suitable host is available. - SelectHost(ctx context.Context, teamID string, isByoc bool) (db.Host, error) + SelectHost(ctx context.Context, teamID pgtype.UUID, isByoc bool) (db.Host, error) } // RoundRobinScheduler cycles through eligible online hosts in round-robin order. @@ -32,7 +34,7 @@ func NewRoundRobinScheduler(queries *db.Queries) *RoundRobinScheduler { } // SelectHost returns the next eligible online host in round-robin order. -func (s *RoundRobinScheduler) SelectHost(ctx context.Context, teamID string, isByoc bool) (db.Host, error) { +func (s *RoundRobinScheduler) SelectHost(ctx context.Context, teamID pgtype.UUID, isByoc bool) (db.Host, error) { hosts, err := s.db.ListActiveHosts(ctx) if err != nil { return db.Host{}, fmt.Errorf("list hosts: %w", err) @@ -40,12 +42,12 @@ func (s *RoundRobinScheduler) SelectHost(ctx context.Context, teamID string, isB var eligible []db.Host for _, h := range hosts { - if h.Status != "online" || !h.Address.Valid || h.Address.String == "" { + if h.Status != "online" || h.Address == "" { continue } if isByoc { // BYOC team: only use hosts belonging to this team. 
- if h.Type != "byoc" || !h.TeamID.Valid || h.TeamID.String != teamID { + if h.Type != "byoc" || !h.TeamID.Valid || h.TeamID != teamID { continue } } else { diff --git a/internal/service/apikey.go b/internal/service/apikey.go index c49ddcaa..7a2b073a 100644 --- a/internal/service/apikey.go +++ b/internal/service/apikey.go @@ -4,6 +4,8 @@ import ( "context" "fmt" + "github.com/jackc/pgx/v5/pgtype" + "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/db" "git.omukk.dev/wrenn/sandbox/internal/id" @@ -22,7 +24,7 @@ type APIKeyCreateResult struct { } // Create generates a new API key for the given team. -func (s *APIKeyService) Create(ctx context.Context, teamID, userID, name string) (APIKeyCreateResult, error) { +func (s *APIKeyService) Create(ctx context.Context, teamID, userID pgtype.UUID, name string) (APIKeyCreateResult, error) { if name == "" { name = "Unnamed API Key" } @@ -48,16 +50,16 @@ func (s *APIKeyService) Create(ctx context.Context, teamID, userID, name string) } // List returns all API keys belonging to the given team. -func (s *APIKeyService) List(ctx context.Context, teamID string) ([]db.TeamApiKey, error) { +func (s *APIKeyService) List(ctx context.Context, teamID pgtype.UUID) ([]db.TeamApiKey, error) { return s.DB.ListAPIKeysByTeam(ctx, teamID) } // ListWithCreator returns all API keys for the team, joined with the creator's email. -func (s *APIKeyService) ListWithCreator(ctx context.Context, teamID string) ([]db.ListAPIKeysByTeamWithCreatorRow, error) { +func (s *APIKeyService) ListWithCreator(ctx context.Context, teamID pgtype.UUID) ([]db.ListAPIKeysByTeamWithCreatorRow, error) { return s.DB.ListAPIKeysByTeamWithCreator(ctx, teamID) } // Delete removes an API key by ID, scoped to the given team. 
-func (s *APIKeyService) Delete(ctx context.Context, keyID, teamID string) error { +func (s *APIKeyService) Delete(ctx context.Context, keyID, teamID pgtype.UUID) error { return s.DB.DeleteAPIKey(ctx, db.DeleteAPIKeyParams{ID: keyID, TeamID: teamID}) } diff --git a/internal/service/audit.go b/internal/service/audit.go index 53061421..67faafa2 100644 --- a/internal/service/audit.go +++ b/internal/service/audit.go @@ -9,6 +9,7 @@ import ( "github.com/jackc/pgx/v5/pgtype" "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/id" ) const auditMaxLimit = 200 @@ -31,13 +32,13 @@ type AuditEntry struct { // AuditListParams controls the ListAuditLogs query. type AuditListParams struct { - TeamID string - AdminScoped bool // true → include admin-scoped events; false → team-scoped only - ResourceTypes []string // empty = no filter; multiple values = OR match - Actions []string // empty = no filter; multiple values = OR match - Before time.Time // zero = no cursor (start from latest) - BeforeID string // tie-breaker: id of the last item at the Before timestamp; empty = no tie-break - Limit int // clamped to auditMaxLimit by the handler + TeamID pgtype.UUID + AdminScoped bool // true → include admin-scoped events; false → team-scoped only + ResourceTypes []string // empty = no filter; multiple values = OR match + Actions []string // empty = no filter; multiple values = OR match + Before time.Time // zero = no cursor (start from latest) + BeforeID pgtype.UUID // tie-breaker: id of the last item at the Before timestamp; zero = no tie-break + Limit int // clamped to auditMaxLimit by the handler } // AuditService provides the read side of the audit log. 
@@ -94,11 +95,11 @@ func (s *AuditService) List(ctx context.Context, p AuditListParams) ([]AuditEntr _ = json.Unmarshal(row.Metadata, &meta) } entries[i] = AuditEntry{ - ID: row.ID, - TeamID: row.TeamID, + ID: id.FormatAuditLogID(row.ID), + TeamID: id.FormatTeamID(row.TeamID), ActorType: row.ActorType, ActorID: row.ActorID.String, - ActorName: row.ActorName.String, + ActorName: row.ActorName, ResourceType: row.ResourceType, ResourceID: row.ResourceID.String, Action: row.Action, diff --git a/internal/service/build.go b/internal/service/build.go new file mode 100644 index 00000000..1108044d --- /dev/null +++ b/internal/service/build.go @@ -0,0 +1,519 @@ +package service + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "sync" + "time" + + "connectrpc.com/connect" + "github.com/jackc/pgx/v5/pgtype" + "github.com/redis/go-redis/v9" + + "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/id" + "git.omukk.dev/wrenn/sandbox/internal/lifecycle" + "git.omukk.dev/wrenn/sandbox/internal/recipe" + "git.omukk.dev/wrenn/sandbox/internal/scheduler" + pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen" +) + +const ( + buildQueueKey = "wrenn:build_queue" + buildCommandTimeout = 30 * time.Second + healthcheckInterval = 1 * time.Second + healthcheckTimeout = 60 * time.Second +) + +// preBuildCmds run before the user recipe to prepare the build environment. +var preBuildCmds = []string{ + "RUN apt update", +} + +// postBuildCmds run after the user recipe to clean up caches and reduce image size. +var postBuildCmds = []string{ + "RUN apt clean", + "RUN apt autoremove -y", + "RUN rm -rf /var/lib/apt/lists/*", +} + +// buildAgentClient is the subset of the host agent client used by the build worker. 
+type buildAgentClient interface { + CreateSandbox(ctx context.Context, req *connect.Request[pb.CreateSandboxRequest]) (*connect.Response[pb.CreateSandboxResponse], error) + DestroySandbox(ctx context.Context, req *connect.Request[pb.DestroySandboxRequest]) (*connect.Response[pb.DestroySandboxResponse], error) + Exec(ctx context.Context, req *connect.Request[pb.ExecRequest]) (*connect.Response[pb.ExecResponse], error) + CreateSnapshot(ctx context.Context, req *connect.Request[pb.CreateSnapshotRequest]) (*connect.Response[pb.CreateSnapshotResponse], error) + FlattenRootfs(ctx context.Context, req *connect.Request[pb.FlattenRootfsRequest]) (*connect.Response[pb.FlattenRootfsResponse], error) +} + +// BuildService handles template build orchestration. +type BuildService struct { + DB *db.Queries + Redis *redis.Client + Pool *lifecycle.HostClientPool + Scheduler scheduler.HostScheduler + + mu sync.Mutex + cancelMap map[string]context.CancelFunc // buildID → per-build cancel func +} + +// BuildCreateParams holds the parameters for creating a template build. +type BuildCreateParams struct { + Name string + BaseTemplate string + Recipe []string + Healthcheck string + VCPUs int32 + MemoryMB int32 + SkipPrePost bool +} + +// Create inserts a new build record and enqueues it to Redis. 
+func (s *BuildService) Create(ctx context.Context, p BuildCreateParams) (db.TemplateBuild, error) { + if p.BaseTemplate == "" { + p.BaseTemplate = "minimal" + } + if p.VCPUs <= 0 { + p.VCPUs = 1 + } + if p.MemoryMB <= 0 { + p.MemoryMB = 512 + } + + recipeJSON, err := json.Marshal(p.Recipe) + if err != nil { + return db.TemplateBuild{}, fmt.Errorf("marshal recipe: %w", err) + } + + buildID := id.NewBuildID() + buildIDStr := id.FormatBuildID(buildID) + newTemplateID := id.NewTemplateID() + + defaultSteps := len(preBuildCmds) + len(postBuildCmds) + if p.SkipPrePost { + defaultSteps = 0 + } + + build, err := s.DB.InsertTemplateBuild(ctx, db.InsertTemplateBuildParams{ + ID: buildID, + Name: p.Name, + BaseTemplate: p.BaseTemplate, + Recipe: recipeJSON, + Healthcheck: p.Healthcheck, + Vcpus: p.VCPUs, + MemoryMb: p.MemoryMB, + TotalSteps: int32(len(p.Recipe) + defaultSteps), + TemplateID: newTemplateID, + TeamID: id.PlatformTeamID, + SkipPrePost: p.SkipPrePost, + }) + if err != nil { + return db.TemplateBuild{}, fmt.Errorf("insert build: %w", err) + } + + // Enqueue build ID (as formatted string) to Redis for workers to pick up. + if err := s.Redis.RPush(ctx, buildQueueKey, buildIDStr).Err(); err != nil { + return db.TemplateBuild{}, fmt.Errorf("enqueue build: %w", err) + } + + return build, nil +} + +// Get returns a single build by ID. +func (s *BuildService) Get(ctx context.Context, buildID pgtype.UUID) (db.TemplateBuild, error) { + return s.DB.GetTemplateBuild(ctx, buildID) +} + +// List returns all builds ordered by creation time. +func (s *BuildService) List(ctx context.Context) ([]db.TemplateBuild, error) { + return s.DB.ListTemplateBuilds(ctx) +} + +// Cancel cancels a pending or running build. For pending builds the status is +// updated in the DB and the worker skips it when dequeued. For running builds +// the per-build context is cancelled, which causes the current exec step to +// abort; executeBuild then detects the cancellation and records the status. 
+func (s *BuildService) Cancel(ctx context.Context, buildID pgtype.UUID) error { + build, err := s.DB.GetTemplateBuild(ctx, buildID) + if err != nil { + return fmt.Errorf("get build: %w", err) + } + switch build.Status { + case "success", "failed", "cancelled": + return fmt.Errorf("build is already %s", build.Status) + } + + // Mark cancelled in DB first. This handles both pending builds (which haven't + // been picked up yet) and acts as a flag for executeBuild to check on start. + if _, err := s.DB.UpdateBuildStatus(ctx, db.UpdateBuildStatusParams{ + ID: buildID, Status: "cancelled", + }); err != nil { + return fmt.Errorf("update build status: %w", err) + } + + // If the build is currently running, signal its context. + buildIDStr := id.FormatBuildID(buildID) + s.mu.Lock() + cancel, running := s.cancelMap[buildIDStr] + s.mu.Unlock() + if running { + cancel() + } + + return nil +} + +// StartWorkers launches n goroutines that consume from the Redis build queue. +// The returned cancel function stops all workers. +func (s *BuildService) StartWorkers(ctx context.Context, n int) context.CancelFunc { + ctx, cancel := context.WithCancel(ctx) + for i := range n { + go s.worker(ctx, i) + } + slog.Info("build workers started", "count", n) + return cancel +} + +func (s *BuildService) worker(ctx context.Context, workerID int) { + log := slog.With("worker", workerID) + for { + // BLPOP blocks until a build ID is available or context is cancelled. + result, err := s.Redis.BLPop(ctx, 0, buildQueueKey).Result() + if err != nil { + if ctx.Err() != nil { + log.Info("build worker shutting down") + return + } + log.Error("redis BLPOP error", "error", err) + time.Sleep(time.Second) + continue + } + // result[0] is the key, result[1] is the build ID (formatted string). 
+ buildIDStr := result[1] + log.Info("picked up build", "build_id", buildIDStr) + s.executeBuild(ctx, buildIDStr) + } +} + +func (s *BuildService) executeBuild(ctx context.Context, buildIDStr string) { + log := slog.With("build_id", buildIDStr) + + buildID, err := id.ParseBuildID(buildIDStr) + if err != nil { + log.Error("invalid build ID from queue", "error", err) + return + } + + // Create a per-build context so this build can be cancelled independently of + // the worker. Register in cancelMap before fetching the build so that a + // concurrent Cancel call can always find and signal it. + buildCtx, buildCancel := context.WithCancel(ctx) + defer buildCancel() + + s.mu.Lock() + if s.cancelMap == nil { + s.cancelMap = make(map[string]context.CancelFunc) + } + s.cancelMap[buildIDStr] = buildCancel + s.mu.Unlock() + defer func() { + s.mu.Lock() + delete(s.cancelMap, buildIDStr) + s.mu.Unlock() + }() + + build, err := s.DB.GetTemplateBuild(buildCtx, buildID) + if err != nil { + log.Error("failed to fetch build", "error", err) + return + } + + // Skip if already cancelled (Cancel was called before we dequeued). + if build.Status == "cancelled" { + log.Info("build already cancelled, skipping") + return + } + + // Mark as running. + if _, err := s.DB.UpdateBuildStatus(buildCtx, db.UpdateBuildStatusParams{ + ID: buildID, Status: "running", + }); err != nil { + log.Error("failed to update build status", "error", err) + return + } + + // Parse user recipe. + var userRecipe []string + if err := json.Unmarshal(build.Recipe, &userRecipe); err != nil { + s.failBuild(buildCtx, buildID, fmt.Sprintf("invalid recipe JSON: %v", err)) + return + } + + // Pick a platform host and create a sandbox. 
+ host, err := s.Scheduler.SelectHost(buildCtx, id.PlatformTeamID, false) + if err != nil { + s.failBuild(buildCtx, buildID, fmt.Sprintf("no host available: %v", err)) + return + } + + agent, err := s.Pool.GetForHost(host) + if err != nil { + s.failBuild(buildCtx, buildID, fmt.Sprintf("agent client error: %v", err)) + return + } + + sandboxID := id.NewSandboxID() + sandboxIDStr := id.FormatSandboxID(sandboxID) + log = log.With("sandbox_id", sandboxIDStr, "host_id", id.FormatHostID(host.ID)) + + // Resolve the base template to UUIDs. "minimal" is the zero sentinel. + baseTeamID := id.PlatformTeamID + baseTemplateID := id.MinimalTemplateID + if build.BaseTemplate != "minimal" { + baseTmpl, err := s.DB.GetPlatformTemplateByName(buildCtx, build.BaseTemplate) + if err != nil { + s.failBuild(buildCtx, buildID, fmt.Sprintf("base template %q not found: %v", build.BaseTemplate, err)) + return + } + baseTeamID = baseTmpl.TeamID + baseTemplateID = baseTmpl.ID + } + + resp, err := agent.CreateSandbox(buildCtx, connect.NewRequest(&pb.CreateSandboxRequest{ + SandboxId: sandboxIDStr, + Template: build.BaseTemplate, + TeamId: id.UUIDString(baseTeamID), + TemplateId: id.UUIDString(baseTemplateID), + Vcpus: build.Vcpus, + MemoryMb: build.MemoryMb, + TimeoutSec: 0, // no auto-pause for builds + DiskSizeMb: 5120, // 5 GB for template builds + })) + if err != nil { + s.failBuild(buildCtx, buildID, fmt.Sprintf("create sandbox failed: %v", err)) + return + } + _ = resp + + // Record sandbox/host association. + _ = s.DB.UpdateBuildSandbox(buildCtx, db.UpdateBuildSandboxParams{ + ID: buildID, + SandboxID: sandboxID, + HostID: host.ID, + }) + + // Parse recipe steps. preBuildCmds and postBuildCmds are hardcoded and always + // valid; panic on error is appropriate here since it would be a programmer mistake. 
+ preBuildSteps, err := recipe.ParseRecipe(preBuildCmds) + if err != nil { + panic(fmt.Sprintf("invalid pre-build recipe: %v", err)) + } + userRecipeSteps, err := recipe.ParseRecipe(userRecipe) + if err != nil { + s.destroySandbox(buildCtx, agent, sandboxIDStr) + s.failBuild(buildCtx, buildID, fmt.Sprintf("recipe parse error: %v", err)) + return + } + postBuildSteps, err := recipe.ParseRecipe(postBuildCmds) + if err != nil { + panic(fmt.Sprintf("invalid post-build recipe: %v", err)) + } + + // Execute build phases: pre-build → user recipe → post-build. + // bctx carries working directory and env vars across all phases. + var logs []recipe.BuildLogEntry + step := 0 + bctx := &recipe.ExecContext{} + + runPhase := func(phase string, steps []recipe.Step, defaultTimeout time.Duration) bool { + newEntries, nextStep, ok := recipe.Execute(buildCtx, phase, steps, sandboxIDStr, step, defaultTimeout, bctx, agent.Exec) + logs = append(logs, newEntries...) + step = nextStep + s.updateLogs(buildCtx, buildID, step, logs) + if !ok { + s.destroySandbox(buildCtx, agent, sandboxIDStr) + // If the build was cancelled, status is already set — don't overwrite with "failed". + if buildCtx.Err() != nil { + return false + } + last := newEntries[len(newEntries)-1] + reason := last.Stderr + if reason == "" { + reason = fmt.Sprintf("exit code %d", last.Exit) + } + s.failBuild(buildCtx, buildID, fmt.Sprintf("%s step %d failed: %s", phase, step, reason)) + } + return ok + } + + if !build.SkipPrePost { + if !runPhase("pre-build", preBuildSteps, 0) { + return + } + } + if !runPhase("recipe", userRecipeSteps, buildCommandTimeout) { + return + } + if !build.SkipPrePost { + if !runPhase("post-build", postBuildSteps, 0) { + return + } + } + + // Healthcheck or direct snapshot. 
+ var sizeBytes int64 + if build.Healthcheck != "" { + log.Info("running healthcheck", "cmd", build.Healthcheck) + if err := s.waitForHealthcheck(buildCtx, agent, sandboxIDStr, build.Healthcheck); err != nil { + s.destroySandbox(buildCtx, agent, sandboxIDStr) + if buildCtx.Err() != nil { + return + } + s.failBuild(buildCtx, buildID, fmt.Sprintf("healthcheck failed: %v", err)) + return + } + + // Healthcheck passed → full snapshot (with memory/CPU state). + log.Info("healthcheck passed, creating snapshot") + snapResp, err := agent.CreateSnapshot(buildCtx, connect.NewRequest(&pb.CreateSnapshotRequest{ + SandboxId: sandboxIDStr, + Name: build.Name, + TeamId: id.UUIDString(build.TeamID), + TemplateId: id.UUIDString(build.TemplateID), + })) + if err != nil { + s.destroySandbox(buildCtx, agent, sandboxIDStr) + if buildCtx.Err() != nil { + return + } + s.failBuild(buildCtx, buildID, fmt.Sprintf("create snapshot failed: %v", err)) + return + } + sizeBytes = snapResp.Msg.SizeBytes + } else { + // No healthcheck → image-only template (rootfs only). + log.Info("no healthcheck, flattening rootfs") + flatResp, err := agent.FlattenRootfs(buildCtx, connect.NewRequest(&pb.FlattenRootfsRequest{ + SandboxId: sandboxIDStr, + Name: build.Name, + TeamId: id.UUIDString(build.TeamID), + TemplateId: id.UUIDString(build.TemplateID), + })) + if err != nil { + s.destroySandbox(buildCtx, agent, sandboxIDStr) + if buildCtx.Err() != nil { + return + } + s.failBuild(buildCtx, buildID, fmt.Sprintf("flatten rootfs failed: %v", err)) + return + } + sizeBytes = flatResp.Msg.SizeBytes + } + + // Insert into templates table as a global (platform) template. 
+ templateType := "base" + if build.Healthcheck != "" { + templateType = "snapshot" + } + + if _, err := s.DB.InsertTemplate(buildCtx, db.InsertTemplateParams{ + ID: build.TemplateID, + Name: build.Name, + Type: templateType, + Vcpus: build.Vcpus, + MemoryMb: build.MemoryMb, + SizeBytes: sizeBytes, + TeamID: id.PlatformTeamID, + }); err != nil { + log.Error("failed to insert template record", "error", err) + // Build succeeded on disk, just DB record failed — don't mark as failed. + } + + // For CreateSnapshot, the sandbox is already destroyed by the snapshot process. + // For FlattenRootfs, the sandbox is already destroyed by the flatten process. + // No additional destroy needed. + + // Mark build as success. + if _, err := s.DB.UpdateBuildStatus(buildCtx, db.UpdateBuildStatusParams{ + ID: buildID, Status: "success", + }); err != nil { + log.Error("failed to mark build as success", "error", err) + } + + log.Info("template build completed successfully", "name", build.Name) +} + +func (s *BuildService) waitForHealthcheck(ctx context.Context, agent buildAgentClient, sandboxIDStr, cmd string) error { + deadline := time.NewTimer(healthcheckTimeout) + defer deadline.Stop() + ticker := time.NewTicker(healthcheckInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-deadline.C: + return fmt.Errorf("healthcheck timed out after %s", healthcheckTimeout) + case <-ticker.C: + execCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + resp, err := agent.Exec(execCtx, connect.NewRequest(&pb.ExecRequest{ + SandboxId: sandboxIDStr, + Cmd: "/bin/sh", + Args: []string{"-c", cmd}, + TimeoutSec: 10, + })) + cancel() + + if err != nil { + slog.Debug("healthcheck exec error (retrying)", "error", err) + continue + } + if resp.Msg.ExitCode == 0 { + return nil + } + slog.Debug("healthcheck failed (retrying)", "exit_code", resp.Msg.ExitCode) + } + } +} + +func (s *BuildService) updateLogs(ctx context.Context, buildID pgtype.UUID, step int, 
logs []recipe.BuildLogEntry) { + logsJSON, err := json.Marshal(logs) + if err != nil { + slog.Warn("failed to marshal build logs", "error", err) + return + } + if err := s.DB.UpdateBuildProgress(ctx, db.UpdateBuildProgressParams{ + ID: buildID, + CurrentStep: int32(step), + Logs: logsJSON, + }); err != nil { + slog.Warn("failed to update build progress", "error", err) + } +} + +func (s *BuildService) failBuild(_ context.Context, buildID pgtype.UUID, errMsg string) { + slog.Error("build failed", "build_id", id.FormatBuildID(buildID), "error", errMsg) + // Use a detached context so DB writes survive parent context cancellation (e.g. shutdown). + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := s.DB.UpdateBuildError(ctx, db.UpdateBuildErrorParams{ + ID: buildID, + Error: errMsg, + }); err != nil { + slog.Error("failed to update build error", "build_id", id.FormatBuildID(buildID), "error", err) + } +} + +func (s *BuildService) destroySandbox(_ context.Context, agent buildAgentClient, sandboxIDStr string) { + // Use a detached context so cleanup succeeds even during shutdown. + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + if _, err := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{ + SandboxId: sandboxIDStr, + })); err != nil { + slog.Warn("failed to destroy build sandbox", "sandbox_id", sandboxIDStr, "error", err) + } +} diff --git a/internal/service/host.go b/internal/service/host.go index b3538dfb..74018ebe 100644 --- a/internal/service/host.go +++ b/internal/service/host.go @@ -27,15 +27,16 @@ type HostService struct { Redis *redis.Client JWT []byte Pool *lifecycle.HostClientPool + CA *auth.CA // nil disables mTLS cert issuance (dev/test environments) } // HostCreateParams holds the parameters for creating a host. 
type HostCreateParams struct { Type string - TeamID string // required for BYOC, empty for regular + TeamID pgtype.UUID // required for BYOC, zero value for regular Provider string AvailabilityZone string - RequestingUserID string + RequestingUserID pgtype.UUID IsRequestorAdmin bool } @@ -55,18 +56,28 @@ type HostRegisterParams struct { Address string } -// HostRegisterResult holds the registered host, its short-lived JWT, and a long-lived refresh token. +// HostRegisterResult holds the registered host, its short-lived JWT, a long-lived +// refresh token, and optionally the host's mTLS certificate material. type HostRegisterResult struct { Host db.Host JWT string RefreshToken string + // mTLS cert material — empty when CA is not configured. + CertPEM string + KeyPEM string + CACertPEM string } -// HostRefreshResult holds a new JWT and rotated refresh token after a successful refresh. +// HostRefreshResult holds a new JWT and rotated refresh token after a successful +// refresh, plus refreshed mTLS certificate material when CA is configured. type HostRefreshResult struct { Host db.Host JWT string RefreshToken string + // mTLS cert material — empty when CA is not configured. + CertPEM string + KeyPEM string + CACertPEM string } // HostDeletePreview describes what will be affected by deleting a host. @@ -103,7 +114,7 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat } } else { // BYOC: platform admin, or team owner/admin. - if p.TeamID == "" { + if !p.TeamID.Valid { return HostCreateResult{}, fmt.Errorf("invalid request: team_id is required for BYOC hosts") } if !p.IsRequestorAdmin { @@ -124,7 +135,7 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat } // Validate team exists, is not deleted, and has BYOC enabled. 
- if p.TeamID != "" { + if p.TeamID.Valid { team, err := s.DB.GetTeam(ctx, p.TeamID) if err != nil || team.DeletedAt.Valid { return HostCreateResult{}, fmt.Errorf("invalid request: team not found") @@ -136,25 +147,12 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat hostID := id.NewHostID() - var teamID pgtype.Text - if p.TeamID != "" { - teamID = pgtype.Text{String: p.TeamID, Valid: true} - } - var provider pgtype.Text - if p.Provider != "" { - provider = pgtype.Text{String: p.Provider, Valid: true} - } - var az pgtype.Text - if p.AvailabilityZone != "" { - az = pgtype.Text{String: p.AvailabilityZone, Valid: true} - } - host, err := s.DB.InsertHost(ctx, db.InsertHostParams{ ID: hostID, Type: p.Type, - TeamID: teamID, - Provider: provider, - AvailabilityZone: az, + TeamID: p.TeamID, + Provider: p.Provider, + AvailabilityZone: p.AvailabilityZone, CreatedBy: p.RequestingUserID, }) if err != nil { @@ -166,8 +164,8 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat tokenID := id.NewHostTokenID() payload, _ := json.Marshal(regTokenPayload{ - HostID: hostID, - TokenID: tokenID, + HostID: id.FormatHostID(hostID), + TokenID: id.FormatHostTokenID(tokenID), }) if err := s.Redis.Set(ctx, "host:reg:"+token, payload, regTokenTTL).Err(); err != nil { return HostCreateResult{}, fmt.Errorf("store registration token: %w", err) @@ -180,7 +178,7 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat CreatedBy: p.RequestingUserID, ExpiresAt: pgtype.Timestamptz{Time: now.Add(regTokenTTL), Valid: true}, }); err != nil { - slog.Warn("failed to insert host token audit record", "host_id", hostID, "error", err) + slog.Warn("failed to insert host token audit record", "host_id", id.FormatHostID(hostID), "error", err) } return HostCreateResult{Host: host, RegistrationToken: token}, nil @@ -189,7 +187,7 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat // 
RegenerateToken issues a new registration token for a host still in "pending" // status. This allows retry when a previous registration attempt failed after // the original token was consumed. -func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamID string, isAdmin bool) (HostCreateResult, error) { +func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamID pgtype.UUID, isAdmin bool) (HostCreateResult, error) { host, err := s.DB.GetHost(ctx, hostID) if err != nil { return HostCreateResult{}, fmt.Errorf("host not found: %w", err) @@ -202,7 +200,7 @@ func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamI if host.Type != "byoc" { return HostCreateResult{}, fmt.Errorf("forbidden: only admins can manage regular hosts") } - if !host.TeamID.Valid || host.TeamID.String != teamID { + if !host.TeamID.Valid || host.TeamID != teamID { return HostCreateResult{}, fmt.Errorf("forbidden: host does not belong to your team") } membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{ @@ -224,8 +222,8 @@ func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamI tokenID := id.NewHostTokenID() payload, _ := json.Marshal(regTokenPayload{ - HostID: hostID, - TokenID: tokenID, + HostID: id.FormatHostID(hostID), + TokenID: id.FormatHostTokenID(tokenID), }) if err := s.Redis.Set(ctx, "host:reg:"+token, payload, regTokenTTL).Err(); err != nil { return HostCreateResult{}, fmt.Errorf("store registration token: %w", err) @@ -238,7 +236,7 @@ func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamI CreatedBy: userID, ExpiresAt: pgtype.Timestamptz{Time: now.Add(regTokenTTL), Valid: true}, }); err != nil { - slog.Warn("failed to insert host token audit record", "host_id", hostID, "error", err) + slog.Warn("failed to insert host token audit record", "host_id", id.FormatHostID(hostID), "error", err) } return HostCreateResult{Host: host, RegistrationToken: token}, 
nil @@ -262,24 +260,44 @@ func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostR return HostRegisterResult{}, fmt.Errorf("corrupted registration token") } - if _, err := s.DB.GetHost(ctx, payload.HostID); err != nil { + hostID, err := id.ParseHostID(payload.HostID) + if err != nil { + return HostRegisterResult{}, fmt.Errorf("corrupted registration token: %w", err) + } + tokenID, err := id.ParseHostTokenID(payload.TokenID) + if err != nil { + return HostRegisterResult{}, fmt.Errorf("corrupted registration token: %w", err) + } + + if _, err := s.DB.GetHost(ctx, hostID); err != nil { return HostRegisterResult{}, fmt.Errorf("host not found: %w", err) } // Sign JWT before mutating DB — if signing fails, the host stays pending. - hostJWT, err := auth.SignHostJWT(s.JWT, payload.HostID) + hostJWT, err := auth.SignHostJWT(s.JWT, hostID) if err != nil { return HostRegisterResult{}, fmt.Errorf("sign host token: %w", err) } + // Issue mTLS certificate if CA is configured. + var hc auth.HostCert + if s.CA != nil { + hc, err = auth.IssueHostCert(s.CA, id.FormatHostID(hostID), p.Address) + if err != nil { + return HostRegisterResult{}, fmt.Errorf("issue host cert: %w", err) + } + } + // Atomically update only if still pending (defense-in-depth against races). 
rowsAffected, err := s.DB.RegisterHost(ctx, db.RegisterHostParams{ - ID: payload.HostID, - Arch: pgtype.Text{String: p.Arch, Valid: p.Arch != ""}, - CpuCores: pgtype.Int4{Int32: p.CPUCores, Valid: p.CPUCores > 0}, - MemoryMb: pgtype.Int4{Int32: p.MemoryMB, Valid: p.MemoryMB > 0}, - DiskGb: pgtype.Int4{Int32: p.DiskGB, Valid: p.DiskGB > 0}, - Address: pgtype.Text{String: p.Address, Valid: p.Address != ""}, + ID: hostID, + Arch: p.Arch, + CpuCores: p.CPUCores, + MemoryMb: p.MemoryMB, + DiskGb: p.DiskGB, + Address: p.Address, + CertFingerprint: hc.Fingerprint, + CertExpiresAt: pgtype.Timestamptz{Time: hc.ExpiresAt, Valid: s.CA != nil}, }) if err != nil { return HostRegisterResult{}, fmt.Errorf("register host: %w", err) @@ -289,23 +307,29 @@ func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostR } // Mark audit trail. - if err := s.DB.MarkHostTokenUsed(ctx, payload.TokenID); err != nil { + if err := s.DB.MarkHostTokenUsed(ctx, tokenID); err != nil { slog.Warn("failed to mark host token used", "token_id", payload.TokenID, "error", err) } // Issue a long-lived refresh token. - refreshToken, err := s.issueRefreshToken(ctx, payload.HostID) + refreshToken, err := s.issueRefreshToken(ctx, hostID) if err != nil { return HostRegisterResult{}, fmt.Errorf("issue refresh token: %w", err) } // Re-fetch the host to get the updated state. 
- host, err := s.DB.GetHost(ctx, payload.HostID) + host, err := s.DB.GetHost(ctx, hostID) if err != nil { return HostRegisterResult{}, fmt.Errorf("fetch updated host: %w", err) } - return HostRegisterResult{Host: host, JWT: hostJWT, RefreshToken: refreshToken}, nil + result := HostRegisterResult{Host: host, JWT: hostJWT, RefreshToken: refreshToken} + if s.CA != nil { + result.CertPEM = hc.CertPEM + result.KeyPEM = hc.KeyPEM + result.CACertPEM = s.CA.PEM + } + return result, nil } // Refresh validates a refresh token, rotates it (revokes old, issues new), @@ -332,6 +356,22 @@ func (s *HostService) Refresh(ctx context.Context, refreshToken string) (HostRef return HostRefreshResult{}, fmt.Errorf("sign host JWT: %w", err) } + // Renew mTLS certificate if CA is configured. + var hc auth.HostCert + if s.CA != nil { + hc, err = auth.IssueHostCert(s.CA, id.FormatHostID(host.ID), host.Address) + if err != nil { + return HostRefreshResult{}, fmt.Errorf("renew host cert: %w", err) + } + if err := s.DB.UpdateHostCert(ctx, db.UpdateHostCertParams{ + ID: host.ID, + CertFingerprint: hc.Fingerprint, + CertExpiresAt: pgtype.Timestamptz{Time: hc.ExpiresAt, Valid: true}, + }); err != nil { + return HostRefreshResult{}, fmt.Errorf("update host cert: %w", err) + } + } + // Issue-then-revoke rotation: insert new token first so a crash between // the two DB calls leaves the host with two valid tokens rather than zero. 
newRefreshToken, err := s.issueRefreshToken(ctx, host.ID) @@ -344,12 +384,18 @@ func (s *HostService) Refresh(ctx context.Context, refreshToken string) (HostRef return HostRefreshResult{}, fmt.Errorf("revoke old refresh token: %w", err) } - return HostRefreshResult{Host: host, JWT: hostJWT, RefreshToken: newRefreshToken}, nil + result := HostRefreshResult{Host: host, JWT: hostJWT, RefreshToken: newRefreshToken} + if s.CA != nil { + result.CertPEM = hc.CertPEM + result.KeyPEM = hc.KeyPEM + result.CACertPEM = s.CA.PEM + } + return result, nil } // issueRefreshToken creates a new refresh token record in the DB and returns // the opaque token string. -func (s *HostService) issueRefreshToken(ctx context.Context, hostID string) (string, error) { +func (s *HostService) issueRefreshToken(ctx context.Context, hostID pgtype.UUID) (string, error) { token := id.NewRefreshToken() hash := hashToken(token) now := time.Now() @@ -375,7 +421,7 @@ func hashToken(token string) string { // Heartbeat updates the last heartbeat timestamp for a host and transitions // any 'unreachable' host back to 'online'. Returns a "host not found" error // (which becomes 404) if the host record no longer exists (e.g., was deleted). -func (s *HostService) Heartbeat(ctx context.Context, hostID string) error { +func (s *HostService) Heartbeat(ctx context.Context, hostID pgtype.UUID) error { n, err := s.DB.UpdateHostHeartbeatAndStatus(ctx, hostID) if err != nil { return err @@ -388,21 +434,21 @@ func (s *HostService) Heartbeat(ctx context.Context, hostID string) error { // List returns hosts visible to the caller. // Admins see all hosts; non-admins see only BYOC hosts belonging to their team. 
-func (s *HostService) List(ctx context.Context, teamID string, isAdmin bool) ([]db.Host, error) { +func (s *HostService) List(ctx context.Context, teamID pgtype.UUID, isAdmin bool) ([]db.Host, error) { if isAdmin { return s.DB.ListHosts(ctx) } - return s.DB.ListHostsByTeam(ctx, pgtype.Text{String: teamID, Valid: true}) + return s.DB.ListHostsByTeam(ctx, teamID) } // Get returns a single host, enforcing access control. -func (s *HostService) Get(ctx context.Context, hostID, teamID string, isAdmin bool) (db.Host, error) { +func (s *HostService) Get(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool) (db.Host, error) { host, err := s.DB.GetHost(ctx, hostID) if err != nil { return db.Host{}, fmt.Errorf("host not found: %w", err) } if !isAdmin { - if !host.TeamID.Valid || host.TeamID.String != teamID { + if !host.TeamID.Valid || host.TeamID != teamID { return db.Host{}, fmt.Errorf("host not found") } } @@ -411,8 +457,8 @@ func (s *HostService) Get(ctx context.Context, hostID, teamID string, isAdmin bo // DeletePreview returns what would be affected by deleting the host, without // making any changes. Use this to show the user a confirmation prompt. 
-func (s *HostService) DeletePreview(ctx context.Context, hostID, teamID string, isAdmin bool) (HostDeletePreview, error) { - host, err := s.checkDeletePermission(ctx, hostID, "", teamID, isAdmin) +func (s *HostService) DeletePreview(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool) (HostDeletePreview, error) { + host, err := s.checkDeletePermission(ctx, hostID, pgtype.UUID{}, teamID, isAdmin) if err != nil { return HostDeletePreview{}, err } @@ -427,7 +473,7 @@ func (s *HostService) DeletePreview(ctx context.Context, hostID, teamID string, ids := make([]string, len(sandboxes)) for i, sb := range sandboxes { - ids[i] = sb.ID + ids[i] = id.FormatSandboxID(sb.ID) } return HostDeletePreview{Host: host, SandboxIDs: ids}, nil @@ -436,7 +482,7 @@ func (s *HostService) DeletePreview(ctx context.Context, hostID, teamID string, // Delete removes a host. Without force it returns an error listing active // sandboxes so the caller can present a confirmation. With force it gracefully // destroys all running sandboxes before deleting the host record. -func (s *HostService) Delete(ctx context.Context, hostID, userID, teamID string, isAdmin, force bool) error { +func (s *HostService) Delete(ctx context.Context, hostID, userID, teamID pgtype.UUID, isAdmin, force bool) error { host, err := s.checkDeletePermission(ctx, hostID, userID, teamID, isAdmin) if err != nil { return err @@ -453,35 +499,37 @@ func (s *HostService) Delete(ctx context.Context, hostID, userID, teamID string, if len(sandboxes) > 0 && !force { ids := make([]string, len(sandboxes)) for i, sb := range sandboxes { - ids[i] = sb.ID + ids[i] = id.FormatSandboxID(sb.ID) } return &HostHasSandboxesError{SandboxIDs: ids} } + hostIDStr := id.FormatHostID(hostID) + // Gracefully destroy running sandboxes and terminate the agent (best-effort). 
- if host.Address.Valid && host.Address.String != "" { + if host.Address != "" { agent, err := s.Pool.GetForHost(host) if err == nil { for _, sb := range sandboxes { if sb.Status == "running" || sb.Status == "starting" { _, rpcErr := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{ - SandboxId: sb.ID, + SandboxId: id.FormatSandboxID(sb.ID), })) if rpcErr != nil && connect.CodeOf(rpcErr) != connect.CodeNotFound { - slog.Warn("delete host: failed to destroy sandbox on agent", "sandbox_id", sb.ID, "error", rpcErr) + slog.Warn("delete host: failed to destroy sandbox on agent", "sandbox_id", id.FormatSandboxID(sb.ID), "error", rpcErr) } } } // Tell the agent to shut itself down immediately. if _, rpcErr := agent.Terminate(ctx, connect.NewRequest(&pb.TerminateRequest{})); rpcErr != nil { - slog.Warn("delete host: failed to send Terminate to agent", "host_id", hostID, "error", rpcErr) + slog.Warn("delete host: failed to send Terminate to agent", "host_id", hostIDStr, "error", rpcErr) } } } // Mark all affected sandboxes as stopped in DB. if len(sandboxes) > 0 { - sbIDs := make([]string, len(sandboxes)) + sbIDs := make([]pgtype.UUID, len(sandboxes)) for i, sb := range sandboxes { sbIDs[i] = sb.ID } @@ -489,18 +537,18 @@ func (s *HostService) Delete(ctx context.Context, hostID, userID, teamID string, Column1: sbIDs, Status: "stopped", }); err != nil { - slog.Warn("delete host: failed to mark sandboxes stopped", "host_id", hostID, "error", err) + slog.Warn("delete host: failed to mark sandboxes stopped", "host_id", hostIDStr, "error", err) } } // Revoke all refresh tokens for this host. if err := s.DB.RevokeHostRefreshTokensByHost(ctx, hostID); err != nil { - slog.Warn("delete host: failed to revoke refresh tokens", "host_id", hostID, "error", err) + slog.Warn("delete host: failed to revoke refresh tokens", "host_id", hostIDStr, "error", err) } // Evict the client from the pool so no further RPCs are sent. 
if s.Pool != nil { - s.Pool.Evict(hostID) + s.Pool.Evict(id.FormatHostID(hostID)) } return s.DB.DeleteHost(ctx, hostID) @@ -508,7 +556,7 @@ func (s *HostService) Delete(ctx context.Context, hostID, userID, teamID string, // checkDeletePermission verifies the caller has permission to delete the given // host and returns the host record on success. -func (s *HostService) checkDeletePermission(ctx context.Context, hostID, userID, teamID string, isAdmin bool) (db.Host, error) { +func (s *HostService) checkDeletePermission(ctx context.Context, hostID, userID, teamID pgtype.UUID, isAdmin bool) (db.Host, error) { host, err := s.DB.GetHost(ctx, hostID) if err != nil { return db.Host{}, fmt.Errorf("host not found: %w", err) @@ -521,11 +569,11 @@ func (s *HostService) checkDeletePermission(ctx context.Context, hostID, userID, if host.Type != "byoc" { return db.Host{}, fmt.Errorf("forbidden: only admins can delete regular hosts") } - if !host.TeamID.Valid || host.TeamID.String != teamID { + if !host.TeamID.Valid || host.TeamID != teamID { return db.Host{}, fmt.Errorf("forbidden: host does not belong to your team") } - if userID != "" { + if userID.Valid { membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{ UserID: userID, TeamID: teamID, @@ -545,7 +593,7 @@ func (s *HostService) checkDeletePermission(ctx context.Context, hostID, userID, } // AddTag adds a tag to a host. -func (s *HostService) AddTag(ctx context.Context, hostID, teamID string, isAdmin bool, tag string) error { +func (s *HostService) AddTag(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool, tag string) error { if _, err := s.Get(ctx, hostID, teamID, isAdmin); err != nil { return err } @@ -553,7 +601,7 @@ func (s *HostService) AddTag(ctx context.Context, hostID, teamID string, isAdmin } // RemoveTag removes a tag from a host. 
-func (s *HostService) RemoveTag(ctx context.Context, hostID, teamID string, isAdmin bool, tag string) error { +func (s *HostService) RemoveTag(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool, tag string) error { if _, err := s.Get(ctx, hostID, teamID, isAdmin); err != nil { return err } @@ -561,7 +609,7 @@ func (s *HostService) RemoveTag(ctx context.Context, hostID, teamID string, isAd } // ListTags returns all tags for a host. -func (s *HostService) ListTags(ctx context.Context, hostID, teamID string, isAdmin bool) ([]string, error) { +func (s *HostService) ListTags(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool) ([]string, error) { if _, err := s.Get(ctx, hostID, teamID, isAdmin); err != nil { return nil, err } diff --git a/internal/service/sandbox.go b/internal/service/sandbox.go index 142b9bd0..2d1f68c4 100644 --- a/internal/service/sandbox.go +++ b/internal/service/sandbox.go @@ -27,15 +27,16 @@ type SandboxService struct { // SandboxCreateParams holds the parameters for creating a sandbox. type SandboxCreateParams struct { - TeamID string + TeamID pgtype.UUID Template string VCPUs int32 MemoryMB int32 TimeoutSec int32 + DiskSizeMB int32 } // agentForSandbox looks up the host for the given sandbox and returns a client. -func (s *SandboxService) agentForSandbox(ctx context.Context, sandboxID string) (hostagentClient, db.Sandbox, error) { +func (s *SandboxService) agentForSandbox(ctx context.Context, sandboxID pgtype.UUID) (hostagentClient, db.Sandbox, error) { sb, err := s.DB.GetSandbox(ctx, sandboxID) if err != nil { return nil, db.Sandbox{}, fmt.Errorf("sandbox not found: %w", err) @@ -77,18 +78,28 @@ func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db. if p.MemoryMB <= 0 { p.MemoryMB = 512 } + if p.DiskSizeMB <= 0 { + p.DiskSizeMB = 5120 // 5 GB default + } - // If the template is a snapshot, use its baked-in vcpus/memory. 
- if tmpl, err := s.DB.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: p.Template, TeamID: p.TeamID}); err == nil && tmpl.Type == "snapshot" { - if tmpl.Vcpus.Valid { - p.VCPUs = tmpl.Vcpus.Int32 + // Resolve template name → (teamID, templateID). + templateTeamID := id.PlatformTeamID + templateID := id.MinimalTemplateID + if p.Template != "minimal" { + tmpl, err := s.DB.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: p.Template, TeamID: p.TeamID}) + if err != nil { + return db.Sandbox{}, fmt.Errorf("template %q not found: %w", p.Template, err) } - if tmpl.MemoryMb.Valid { - p.MemoryMB = tmpl.MemoryMb.Int32 + templateTeamID = tmpl.TeamID + templateID = tmpl.ID + // If the template is a snapshot, use its baked-in vcpus/memory. + if tmpl.Type == "snapshot" { + p.VCPUs = tmpl.Vcpus + p.MemoryMB = tmpl.MemoryMb } } - if p.TeamID == "" { + if !p.TeamID.Valid { return db.Sandbox{}, fmt.Errorf("invalid request: team_id is required") } @@ -110,32 +121,39 @@ func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db. 
} sandboxID := id.NewSandboxID() + sandboxIDStr := id.FormatSandboxID(sandboxID) if _, err := s.DB.InsertSandbox(ctx, db.InsertSandboxParams{ - ID: sandboxID, - TeamID: p.TeamID, - HostID: host.ID, - Template: p.Template, - Status: "pending", - Vcpus: p.VCPUs, - MemoryMb: p.MemoryMB, - TimeoutSec: p.TimeoutSec, + ID: sandboxID, + TeamID: p.TeamID, + HostID: host.ID, + Template: p.Template, + Status: "pending", + Vcpus: p.VCPUs, + MemoryMb: p.MemoryMB, + TimeoutSec: p.TimeoutSec, + DiskSizeMb: p.DiskSizeMB, + TemplateID: templateID, + TemplateTeamID: templateTeamID, }); err != nil { return db.Sandbox{}, fmt.Errorf("insert sandbox: %w", err) } resp, err := agent.CreateSandbox(ctx, connect.NewRequest(&pb.CreateSandboxRequest{ - SandboxId: sandboxID, + SandboxId: sandboxIDStr, Template: p.Template, + TeamId: id.UUIDString(templateTeamID), + TemplateId: id.UUIDString(templateID), Vcpus: p.VCPUs, MemoryMb: p.MemoryMB, TimeoutSec: p.TimeoutSec, + DiskSizeMb: p.DiskSizeMB, })) if err != nil { if _, dbErr := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{ ID: sandboxID, Status: "error", }); dbErr != nil { - slog.Warn("failed to update sandbox status to error", "id", sandboxID, "error", dbErr) + slog.Warn("failed to update sandbox status to error", "id", sandboxIDStr, "error", dbErr) } return db.Sandbox{}, fmt.Errorf("agent create: %w", err) } @@ -158,17 +176,17 @@ func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db. } // List returns active sandboxes (excludes stopped/error) belonging to the given team. -func (s *SandboxService) List(ctx context.Context, teamID string) ([]db.Sandbox, error) { +func (s *SandboxService) List(ctx context.Context, teamID pgtype.UUID) ([]db.Sandbox, error) { return s.DB.ListSandboxesByTeam(ctx, teamID) } // Get returns a single sandbox by ID, scoped to the given team. 
-func (s *SandboxService) Get(ctx context.Context, sandboxID, teamID string) (db.Sandbox, error) { +func (s *SandboxService) Get(ctx context.Context, sandboxID, teamID pgtype.UUID) (db.Sandbox, error) { return s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID}) } // Pause snapshots and freezes a running sandbox to disk. -func (s *SandboxService) Pause(ctx context.Context, sandboxID, teamID string) (db.Sandbox, error) { +func (s *SandboxService) Pause(ctx context.Context, sandboxID, teamID pgtype.UUID) (db.Sandbox, error) { sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID}) if err != nil { return db.Sandbox{}, fmt.Errorf("sandbox not found: %w", err) @@ -182,26 +200,40 @@ func (s *SandboxService) Pause(ctx context.Context, sandboxID, teamID string) (d return db.Sandbox{}, err } + sandboxIDStr := id.FormatSandboxID(sandboxID) + + // Pre-mark as "paused" in DB before the RPC so the reconciler does not + // mark the sandbox "stopped" while the host agent processes the pause. + if _, err := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{ + ID: sandboxID, Status: "paused", + }); err != nil { + return db.Sandbox{}, fmt.Errorf("pre-mark paused: %w", err) + } + // Flush all metrics tiers before pausing so data survives in DB. s.flushAndPersistMetrics(ctx, agent, sandboxID, true) if _, err := agent.PauseSandbox(ctx, connect.NewRequest(&pb.PauseSandboxRequest{ - SandboxId: sandboxID, + SandboxId: sandboxIDStr, })); err != nil { + // Revert status on failure. 
+ if _, dbErr := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{ + ID: sandboxID, Status: "running", + }); dbErr != nil { + slog.Warn("failed to revert sandbox status after pause error", "sandbox_id", sandboxIDStr, "error", dbErr) + } return db.Sandbox{}, fmt.Errorf("agent pause: %w", err) } - sb, err = s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{ - ID: sandboxID, Status: "paused", - }) + sb, err = s.DB.GetSandbox(ctx, sandboxID) if err != nil { - return db.Sandbox{}, fmt.Errorf("update status: %w", err) + return db.Sandbox{}, fmt.Errorf("get sandbox after pause: %w", err) } return sb, nil } // Resume restores a paused sandbox from snapshot. -func (s *SandboxService) Resume(ctx context.Context, sandboxID, teamID string) (db.Sandbox, error) { +func (s *SandboxService) Resume(ctx context.Context, sandboxID, teamID pgtype.UUID) (db.Sandbox, error) { sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID}) if err != nil { return db.Sandbox{}, fmt.Errorf("sandbox not found: %w", err) @@ -215,8 +247,10 @@ func (s *SandboxService) Resume(ctx context.Context, sandboxID, teamID string) ( return db.Sandbox{}, err } + sandboxIDStr := id.FormatSandboxID(sandboxID) + resp, err := agent.ResumeSandbox(ctx, connect.NewRequest(&pb.ResumeSandboxRequest{ - SandboxId: sandboxID, + SandboxId: sandboxIDStr, TimeoutSec: sb.TimeoutSec, })) if err != nil { @@ -240,7 +274,7 @@ func (s *SandboxService) Resume(ctx context.Context, sandboxID, teamID string) ( } // Destroy stops a sandbox and marks it as stopped. 
-func (s *SandboxService) Destroy(ctx context.Context, sandboxID, teamID string) error { +func (s *SandboxService) Destroy(ctx context.Context, sandboxID, teamID pgtype.UUID) error { sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID}) if err != nil { return fmt.Errorf("sandbox not found: %w", err) @@ -251,6 +285,8 @@ func (s *SandboxService) Destroy(ctx context.Context, sandboxID, teamID string) return err } + sandboxIDStr := id.FormatSandboxID(sandboxID) + // If running, flush 24h tier metrics for analytics before destroying. if sb.Status == "running" { s.flushAndPersistMetrics(ctx, agent, sandboxID, false) @@ -258,7 +294,7 @@ func (s *SandboxService) Destroy(ctx context.Context, sandboxID, teamID string) // Destroy on host agent. A not-found response is fine — sandbox is already gone. if _, err := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{ - SandboxId: sandboxID, + SandboxId: sandboxIDStr, })); err != nil && connect.CodeOf(err) != connect.CodeNotFound { return fmt.Errorf("agent destroy: %w", err) } @@ -284,12 +320,13 @@ func (s *SandboxService) Destroy(ctx context.Context, sandboxID, teamID string) // flushAndPersistMetrics calls FlushSandboxMetrics on the agent and stores // the returned data to DB. If allTiers is true, all three tiers are saved; // otherwise only the 24h tier (for post-destroy analytics). 
-func (s *SandboxService) flushAndPersistMetrics(ctx context.Context, agent hostagentClient, sandboxID string, allTiers bool) { +func (s *SandboxService) flushAndPersistMetrics(ctx context.Context, agent hostagentClient, sandboxID pgtype.UUID, allTiers bool) { + sandboxIDStr := id.FormatSandboxID(sandboxID) resp, err := agent.FlushSandboxMetrics(ctx, connect.NewRequest(&pb.FlushSandboxMetricsRequest{ - SandboxId: sandboxID, + SandboxId: sandboxIDStr, })) if err != nil { - slog.Warn("flush metrics failed (best-effort)", "sandbox_id", sandboxID, "error", err) + slog.Warn("flush metrics failed (best-effort)", "sandbox_id", sandboxIDStr, "error", err) return } msg := resp.Msg @@ -301,7 +338,8 @@ func (s *SandboxService) flushAndPersistMetrics(ctx context.Context, agent hosta s.persistMetricPoints(ctx, sandboxID, "24h", msg.Points_24H) } -func (s *SandboxService) persistMetricPoints(ctx context.Context, sandboxID, tier string, points []*pb.MetricPoint) { +func (s *SandboxService) persistMetricPoints(ctx context.Context, sandboxID pgtype.UUID, tier string, points []*pb.MetricPoint) { + sandboxIDStr := id.FormatSandboxID(sandboxID) for _, p := range points { if err := s.DB.InsertSandboxMetricPoint(ctx, db.InsertSandboxMetricPointParams{ SandboxID: sandboxID, @@ -311,13 +349,13 @@ func (s *SandboxService) persistMetricPoints(ctx context.Context, sandboxID, tie MemBytes: p.MemBytes, DiskBytes: p.DiskBytes, }); err != nil { - slog.Warn("persist metric point failed", "sandbox_id", sandboxID, "tier", tier, "error", err) + slog.Warn("persist metric point failed", "sandbox_id", sandboxIDStr, "tier", tier, "error", err) } } } // Ping resets the inactivity timer for a running sandbox. 
-func (s *SandboxService) Ping(ctx context.Context, sandboxID, teamID string) error { +func (s *SandboxService) Ping(ctx context.Context, sandboxID, teamID pgtype.UUID) error { sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID}) if err != nil { return fmt.Errorf("sandbox not found: %w", err) @@ -331,8 +369,10 @@ func (s *SandboxService) Ping(ctx context.Context, sandboxID, teamID string) err return err } + sandboxIDStr := id.FormatSandboxID(sandboxID) + if _, err := agent.PingSandbox(ctx, connect.NewRequest(&pb.PingSandboxRequest{ - SandboxId: sandboxID, + SandboxId: sandboxIDStr, })); err != nil { return fmt.Errorf("agent ping: %w", err) } @@ -344,7 +384,7 @@ func (s *SandboxService) Ping(ctx context.Context, sandboxID, teamID string) err Valid: true, }, }); err != nil { - slog.Warn("ping: failed to update last_active_at", "sandbox_id", sandboxID, "error", err) + slog.Warn("ping: failed to update last_active_at", "sandbox_id", sandboxIDStr, "error", err) } return nil } diff --git a/internal/service/stats.go b/internal/service/stats.go index 1a075aa0..88cace72 100644 --- a/internal/service/stats.go +++ b/internal/service/stats.go @@ -7,6 +7,7 @@ import ( "time" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgxpool" "git.omukk.dev/wrenn/sandbox/internal/db" @@ -72,7 +73,7 @@ type StatsService struct { // GetStats returns current stats, 30-day peaks, and a time-series for the // given team and time range. If no snapshots exist yet, zeros are returned. 
-func (s *StatsService) GetStats(ctx context.Context, teamID string, r TimeRange) (CurrentStats, PeakStats, []StatPoint, error) { +func (s *StatsService) GetStats(ctx context.Context, teamID pgtype.UUID, r TimeRange) (CurrentStats, PeakStats, []StatPoint, error) { cfg, ok := rangeConfigs[r] if !ok { return CurrentStats{}, PeakStats{}, nil, fmt.Errorf("unknown range: %s", r) @@ -132,7 +133,7 @@ GROUP BY bucket ORDER BY bucket ASC ` -func (s *StatsService) queryTimeSeries(ctx context.Context, teamID string, cfg rangeConfig) ([]StatPoint, error) { +func (s *StatsService) queryTimeSeries(ctx context.Context, teamID pgtype.UUID, cfg rangeConfig) ([]StatPoint, error) { rows, err := s.Pool.Query(ctx, timeSeriesSQL, cfg.bucketSec, teamID, cfg.intervalLiteral) if err != nil { return nil, err diff --git a/internal/service/team.go b/internal/service/team.go index d4c911c2..a7acbac1 100644 --- a/internal/service/team.go +++ b/internal/service/team.go @@ -9,6 +9,7 @@ import ( "connectrpc.com/connect" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgxpool" "git.omukk.dev/wrenn/sandbox/internal/db" @@ -43,7 +44,7 @@ type MemberInfo struct { // callerRole fetches the calling user's role in the given team from DB. // Returns an error wrapping "forbidden" if the caller is not a member. -func (s *TeamService) callerRole(ctx context.Context, teamID, callerUserID string) (string, error) { +func (s *TeamService) callerRole(ctx context.Context, teamID, callerUserID pgtype.UUID) (string, error) { m, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{ UserID: callerUserID, TeamID: teamID, @@ -66,7 +67,7 @@ func requireAdmin(role string) error { } // GetTeam returns the team by ID. Returns an error if the team is deleted or not found. 
-func (s *TeamService) GetTeam(ctx context.Context, teamID string) (db.Team, error) { +func (s *TeamService) GetTeam(ctx context.Context, teamID pgtype.UUID) (db.Team, error) { team, err := s.DB.GetTeam(ctx, teamID) if err != nil { if err == pgx.ErrNoRows { @@ -81,7 +82,7 @@ func (s *TeamService) GetTeam(ctx context.Context, teamID string) (db.Team, erro } // ListTeamsForUser returns all active teams the user belongs to, with their role in each. -func (s *TeamService) ListTeamsForUser(ctx context.Context, userID string) ([]TeamWithRole, error) { +func (s *TeamService) ListTeamsForUser(ctx context.Context, userID pgtype.UUID) ([]TeamWithRole, error) { rows, err := s.DB.GetTeamsForUser(ctx, userID) if err != nil { return nil, fmt.Errorf("list teams: %w", err) @@ -97,7 +98,7 @@ func (s *TeamService) ListTeamsForUser(ctx context.Context, userID string) ([]Te } // CreateTeam creates a new team owned by the given user. -func (s *TeamService) CreateTeam(ctx context.Context, ownerUserID, name string) (TeamWithRole, error) { +func (s *TeamService) CreateTeam(ctx context.Context, ownerUserID pgtype.UUID, name string) (TeamWithRole, error) { if !teamNameRE.MatchString(name) { return TeamWithRole{}, fmt.Errorf("invalid team name: must be 1-128 characters, A-Z a-z 0-9 space _") } @@ -137,7 +138,7 @@ func (s *TeamService) CreateTeam(ctx context.Context, ownerUserID, name string) } // RenameTeam updates the team name. Caller must be admin or owner (verified from DB). 
-func (s *TeamService) RenameTeam(ctx context.Context, teamID, callerUserID, newName string) error { +func (s *TeamService) RenameTeam(ctx context.Context, teamID, callerUserID pgtype.UUID, newName string) error { if !teamNameRE.MatchString(newName) { return fmt.Errorf("invalid team name: must be 1-128 characters, A-Z a-z 0-9 space _") } @@ -159,7 +160,7 @@ func (s *TeamService) RenameTeam(ctx context.Context, teamID, callerUserID, newN // DeleteTeam soft-deletes the team and destroys all running/paused/starting sandboxes. // Caller must be owner (verified from DB). All DB records (sandboxes, keys, templates) // are preserved; only the team's deleted_at is set and active VMs are stopped. -func (s *TeamService) DeleteTeam(ctx context.Context, teamID, callerUserID string) error { +func (s *TeamService) DeleteTeam(ctx context.Context, teamID, callerUserID pgtype.UUID) error { role, err := s.callerRole(ctx, teamID, callerUserID) if err != nil { return err @@ -174,16 +175,16 @@ func (s *TeamService) DeleteTeam(ctx context.Context, teamID, callerUserID strin return fmt.Errorf("list active sandboxes: %w", err) } - var stopIDs []string + var stopIDs []pgtype.UUID for _, sb := range sandboxes { host, hostErr := s.DB.GetHost(ctx, sb.HostID) if hostErr == nil { agent, agentErr := s.HostPool.GetForHost(host) if agentErr == nil { if _, err := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{ - SandboxId: sb.ID, + SandboxId: id.FormatSandboxID(sb.ID), })); err != nil && connect.CodeOf(err) != connect.CodeNotFound { - slog.Warn("team delete: failed to destroy sandbox", "sandbox_id", sb.ID, "error", err) + slog.Warn("team delete: failed to destroy sandbox", "sandbox_id", id.FormatSandboxID(sb.ID), "error", err) } } } @@ -201,14 +202,63 @@ func (s *TeamService) DeleteTeam(ctx context.Context, teamID, callerUserID strin } } + // Clean up team-owned templates from all hosts in the background. 
+ go s.cleanupTeamTemplates(context.Background(), teamID) + if err := s.DB.SoftDeleteTeam(ctx, teamID); err != nil { return fmt.Errorf("soft delete team: %w", err) } return nil } +// cleanupTeamTemplates deletes all template files for a team from all online hosts, +// then removes the DB records. Called asynchronously during team deletion. +func (s *TeamService) cleanupTeamTemplates(ctx context.Context, teamID pgtype.UUID) { + templates, err := s.DB.ListTemplatesByTeamOnly(ctx, teamID) + if err != nil { + slog.Warn("team delete: failed to list templates for cleanup", "team_id", id.FormatTeamID(teamID), "error", err) + return + } + if len(templates) == 0 { + return + } + + hosts, err := s.DB.ListActiveHosts(ctx) + if err != nil { + slog.Warn("team delete: failed to list hosts for template cleanup", "error", err) + return + } + + for _, tmpl := range templates { + for _, host := range hosts { + if host.Status != "online" { + continue + } + agent, err := s.HostPool.GetForHost(host) + if err != nil { + continue + } + if _, err := agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{ + TeamId: id.UUIDString(tmpl.TeamID), + TemplateId: id.UUIDString(tmpl.ID), + })); err != nil && connect.CodeOf(err) != connect.CodeNotFound { + slog.Warn("team delete: failed to delete template on host", + "host_id", id.FormatHostID(host.ID), + "template", tmpl.Name, + "error", err, + ) + } + } + } + + // Remove DB records. + if err := s.DB.DeleteTemplatesByTeam(ctx, teamID); err != nil { + slog.Warn("team delete: failed to delete template records", "team_id", id.FormatTeamID(teamID), "error", err) + } +} + // GetMembers returns all members of the team with their emails and roles. 
-func (s *TeamService) GetMembers(ctx context.Context, teamID string) ([]MemberInfo, error) { +func (s *TeamService) GetMembers(ctx context.Context, teamID pgtype.UUID) ([]MemberInfo, error) { rows, err := s.DB.GetTeamMembers(ctx, teamID) if err != nil { return nil, fmt.Errorf("get members: %w", err) @@ -220,7 +270,7 @@ func (s *TeamService) GetMembers(ctx context.Context, teamID string) ([]MemberIn joinedAt = r.JoinedAt.Time } members[i] = MemberInfo{ - UserID: r.ID, + UserID: id.FormatUserID(r.ID), Name: r.Name, Email: r.Email, Role: r.Role, @@ -232,7 +282,7 @@ func (s *TeamService) GetMembers(ctx context.Context, teamID string) ([]MemberIn // AddMember adds an existing user (looked up by email) to the team as a member. // Caller must be admin or owner (verified from DB). -func (s *TeamService) AddMember(ctx context.Context, teamID, callerUserID, email string) (MemberInfo, error) { +func (s *TeamService) AddMember(ctx context.Context, teamID, callerUserID pgtype.UUID, email string) (MemberInfo, error) { role, err := s.callerRole(ctx, teamID, callerUserID) if err != nil { return MemberInfo{}, err @@ -269,12 +319,12 @@ func (s *TeamService) AddMember(ctx context.Context, teamID, callerUserID, email return MemberInfo{}, fmt.Errorf("insert member: %w", err) } - return MemberInfo{UserID: target.ID, Name: target.Name, Email: target.Email, Role: "member"}, nil + return MemberInfo{UserID: id.FormatUserID(target.ID), Name: target.Name, Email: target.Email, Role: "member"}, nil } // RemoveMember removes a user from the team. // Caller must be admin or owner (verified from DB). Owner cannot be removed. 
-func (s *TeamService) RemoveMember(ctx context.Context, teamID, callerUserID, targetUserID string) error { +func (s *TeamService) RemoveMember(ctx context.Context, teamID, callerUserID, targetUserID pgtype.UUID) error { callerRole, err := s.callerRole(ctx, teamID, callerUserID) if err != nil { return err @@ -310,7 +360,7 @@ func (s *TeamService) RemoveMember(ctx context.Context, teamID, callerUserID, ta // UpdateMemberRole changes a member's role to admin or member. // Caller must be admin or owner (verified from DB). Owner's role cannot be changed. // Valid target roles: "admin", "member". -func (s *TeamService) UpdateMemberRole(ctx context.Context, teamID, callerUserID, targetUserID, newRole string) error { +func (s *TeamService) UpdateMemberRole(ctx context.Context, teamID, callerUserID, targetUserID pgtype.UUID, newRole string) error { if newRole != "admin" && newRole != "member" { return fmt.Errorf("invalid: role must be admin or member") } @@ -350,7 +400,7 @@ func (s *TeamService) UpdateMemberRole(ctx context.Context, teamID, callerUserID // LeaveTeam removes the calling user from the team. // The owner cannot leave; they must delete the team instead. -func (s *TeamService) LeaveTeam(ctx context.Context, teamID, callerUserID string) error { +func (s *TeamService) LeaveTeam(ctx context.Context, teamID, callerUserID pgtype.UUID) error { role, err := s.callerRole(ctx, teamID, callerUserID) if err != nil { return err @@ -371,7 +421,7 @@ func (s *TeamService) LeaveTeam(ctx context.Context, teamID, callerUserID string // SetBYOC enables the BYOC feature flag for a team. Once enabled, BYOC cannot // be disabled — it is a one-way transition. // Admin-only — the caller must verify admin status before invoking this. 
-func (s *TeamService) SetBYOC(ctx context.Context, teamID string, enabled bool) error { +func (s *TeamService) SetBYOC(ctx context.Context, teamID pgtype.UUID, enabled bool) error { team, err := s.DB.GetTeam(ctx, teamID) if err != nil { return fmt.Errorf("team not found: %w", err) diff --git a/internal/service/template.go b/internal/service/template.go index d669e455..22bc4d60 100644 --- a/internal/service/template.go +++ b/internal/service/template.go @@ -3,6 +3,8 @@ package service import ( "context" + "github.com/jackc/pgx/v5/pgtype" + "git.omukk.dev/wrenn/sandbox/internal/db" ) @@ -14,7 +16,7 @@ type TemplateService struct { // List returns all templates belonging to the given team. If typeFilter is // non-empty, only templates of that type ("base" or "snapshot") are returned. -func (s *TemplateService) List(ctx context.Context, teamID, typeFilter string) ([]db.Template, error) { +func (s *TemplateService) List(ctx context.Context, teamID pgtype.UUID, typeFilter string) ([]db.Template, error) { if typeFilter != "" { return s.DB.ListTemplatesByTeamAndType(ctx, db.ListTemplatesByTeamAndTypeParams{ TeamID: teamID, diff --git a/internal/snapshot/memfile.go b/internal/snapshot/memfile.go index aabe8851..f7b14f9a 100644 --- a/internal/snapshot/memfile.go +++ b/internal/snapshot/memfile.go @@ -4,6 +4,7 @@ package snapshot import ( + "context" "fmt" "io" "os" @@ -172,6 +173,99 @@ func ProcessMemfileWithParent(memfilePath, diffPath, headerPath string, parentHe return header, nil } +// MergeDiffs consolidates multiple generation diff files into a single diff +// file and resets the generation counter to 0. This is a pure file-level +// operation — no Firecracker involvement. +// +// It reads each non-nil block from the appropriate diff file (as mapped by +// the header), writes them all sequentially into a single new diff file, +// and produces a fresh header pointing only at that file. 
+// +// diffFiles maps build ID (string) → open file path for each generation's diff. +func MergeDiffs(header *Header, diffFiles map[string]string, mergedDiffPath, headerPath string) (*Header, error) { + blockSize := int64(header.Metadata.BlockSize) + mergedBuildID := uuid.New() + + // Open all source diff files. + sources := make(map[string]*os.File, len(diffFiles)) + for id, path := range diffFiles { + f, err := os.Open(path) + if err != nil { + // Close already opened files. + for _, sf := range sources { + sf.Close() + } + return nil, fmt.Errorf("open diff file for build %s: %w", id, err) + } + sources[id] = f + } + defer func() { + for _, f := range sources { + f.Close() + } + }() + + dst, err := os.Create(mergedDiffPath) + if err != nil { + return nil, fmt.Errorf("create merged diff file: %w", err) + } + defer dst.Close() + + totalBlocks := TotalBlocks(int64(header.Metadata.Size), blockSize) + dirty := make([]bool, totalBlocks) + empty := make([]bool, totalBlocks) + buf := make([]byte, blockSize) + + for i := int64(0); i < totalBlocks; i++ { + offset := i * blockSize + mappedOffset, _, buildID, err := header.GetShiftedMapping(context.Background(), offset) + if err != nil { + return nil, fmt.Errorf("lookup block %d: %w", i, err) + } + + if *buildID == uuid.Nil { + empty[i] = true + continue + } + + src, ok := sources[buildID.String()] + if !ok { + return nil, fmt.Errorf("no diff file for build %s (block %d)", buildID, i) + } + + if _, err := src.ReadAt(buf, mappedOffset); err != nil { + return nil, fmt.Errorf("read block %d from build %s: %w", i, buildID, err) + } + + dirty[i] = true + if _, err := dst.Write(buf); err != nil { + return nil, fmt.Errorf("write merged block %d: %w", i, err) + } + } + + // Build fresh header with generation 0. 
+ dirtyMappings := CreateMapping(mergedBuildID, dirty, blockSize) + emptyMappings := CreateMapping(uuid.Nil, empty, blockSize) + merged := MergeMappings(dirtyMappings, emptyMappings) + normalized := NormalizeMappings(merged) + + metadata := NewMetadata(mergedBuildID, uint64(blockSize), header.Metadata.Size) + newHeader, err := NewHeader(metadata, normalized) + if err != nil { + return nil, fmt.Errorf("create merged header: %w", err) + } + + headerData, err := Serialize(metadata, normalized) + if err != nil { + return nil, fmt.Errorf("serialize merged header: %w", err) + } + if err := os.WriteFile(headerPath, headerData, 0644); err != nil { + return nil, fmt.Errorf("write merged header: %w", err) + } + + return newHeader, nil +} + // isZeroBlock checks if a block is entirely zero bytes. func isZeroBlock(block []byte) bool { // Fast path: compare 8 bytes at a time. diff --git a/internal/validate/name_test.go b/internal/validate/name_test.go index 4b7769e2..f17e210a 100644 --- a/internal/validate/name_test.go +++ b/internal/validate/name_test.go @@ -11,7 +11,7 @@ func TestSafeName(t *testing.T) { {"simple", "minimal", false}, {"with-dash", "template-abc123", false}, {"with-dot", "my-snapshot.v2", false}, - {"sandbox-id", "sb-12345678", false}, + {"sandbox-id", "cl-12345678", false}, {"single-char", "a", false}, {"numbers", "123", false}, {"max-length", "abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz01", false}, diff --git a/internal/vm/config.go b/internal/vm/config.go index 35bc2939..0c1f2582 100644 --- a/internal/vm/config.go +++ b/internal/vm/config.go @@ -4,9 +4,13 @@ import "fmt" // VMConfig holds the configuration for creating a Firecracker microVM. type VMConfig struct { - // SandboxID is the unique identifier for this sandbox (e.g., "sb-a1b2c3d4"). + // SandboxID is the unique identifier for this sandbox (e.g., "cl-a1b2c3d4"). 
SandboxID string + // TemplateID is the template UUID string used to populate MMDS metadata + // so that envd can read WRENN_TEMPLATE_ID from inside the guest. + TemplateID string + // KernelPath is the path to the uncompressed Linux kernel (vmlinux). KernelPath string @@ -91,7 +95,7 @@ func (c *VMConfig) kernelArgs() string { ) return fmt.Sprintf( - "console=ttyS0 reboot=k panic=1 pci=off quiet loglevel=1 init=%s %s", + "console=ttyS0 reboot=k panic=1 pci=off quiet loglevel=1 clocksource=kvm-clock init=%s %s", c.InitPath, ipArg, ) } diff --git a/internal/vm/fc.go b/internal/vm/fc.go index b5af5dbb..3d0f246d 100644 --- a/internal/vm/fc.go +++ b/internal/vm/fc.go @@ -101,6 +101,31 @@ func (c *fcClient) setMachineConfig(ctx context.Context, vcpus, memMB int) error }) } +// setMMDSConfig enables MMDS V2 token-based access on the given network interface. +// Must be called before startVM. +func (c *fcClient) setMMDSConfig(ctx context.Context, ifaceID string) error { + return c.do(ctx, http.MethodPut, "/mmds/config", map[string]any{ + "version": "V2", + "network_interfaces": []string{ifaceID}, + }) +} + +// mmdsMetadata is the metadata payload written to the Firecracker MMDS store. +// envd reads this via PollForMMDSOpts to populate WRENN_SANDBOX_ID and WRENN_TEMPLATE_ID. +type mmdsMetadata struct { + SandboxID string `json:"instanceID"` + TemplateID string `json:"envID"` +} + +// setMMDS writes sandbox metadata to the Firecracker MMDS store. +// Can be called after the VM has started. +func (c *fcClient) setMMDS(ctx context.Context, sandboxID, templateID string) error { + return c.do(ctx, http.MethodPut, "/mmds", mmdsMetadata{ + SandboxID: sandboxID, + TemplateID: templateID, + }) +} + // startVM issues the InstanceStart action. 
func (c *fcClient) startVM(ctx context.Context) error { return c.do(ctx, http.MethodPut, "/actions", map[string]string{ diff --git a/internal/vm/manager.go b/internal/vm/manager.go index c7e34797..9e9466f5 100644 --- a/internal/vm/manager.go +++ b/internal/vm/manager.go @@ -71,6 +71,13 @@ func (m *Manager) Create(ctx context.Context, cfg VMConfig) (*VM, error) { return nil, fmt.Errorf("start VM: %w", err) } + // Step 5: Push sandbox metadata into MMDS so envd can read + // WRENN_SANDBOX_ID and WRENN_TEMPLATE_ID from inside the guest. + if err := client.setMMDS(ctx, cfg.SandboxID, cfg.TemplateID); err != nil { + _ = proc.stop() + return nil, fmt.Errorf("set MMDS metadata: %w", err) + } + vm := &VM{ Config: cfg, process: proc, @@ -108,6 +115,12 @@ func configureVM(ctx context.Context, client *fcClient, cfg *VMConfig) error { return fmt.Errorf("set machine config: %w", err) } + // MMDS config — enable V2 token access on eth0 so that envd can read + // WRENN_SANDBOX_ID and WRENN_TEMPLATE_ID from inside the guest. + if err := client.setMMDSConfig(ctx, "eth0"); err != nil { + return fmt.Errorf("set MMDS config: %w", err) + } + return nil } @@ -238,6 +251,12 @@ func (m *Manager) CreateFromSnapshot(ctx context.Context, cfg VMConfig, snapPath return nil, fmt.Errorf("resume VM: %w", err) } + // Step 5: Push sandbox metadata into MMDS. + if err := client.setMMDS(ctx, cfg.SandboxID, cfg.TemplateID); err != nil { + _ = proc.stop() + return nil, fmt.Errorf("set MMDS metadata: %w", err) + } + vm := &VM{ Config: cfg, process: proc, diff --git a/proto/hostagent/gen/hostagent.pb.go b/proto/hostagent/gen/hostagent.pb.go index f496b2cb..516e4d25 100644 --- a/proto/hostagent/gen/hostagent.pb.go +++ b/proto/hostagent/gen/hostagent.pb.go @@ -25,7 +25,7 @@ type CreateSandboxRequest struct { state protoimpl.MessageState `protogen:"open.v1"` // Sandbox ID assigned by the control plane. If empty, the host agent generates one. 
SandboxId string `protobuf:"bytes,5,opt,name=sandbox_id,json=sandboxId,proto3" json:"sandbox_id,omitempty"` - // Template name (e.g., "minimal", "python311"). Determines base rootfs. + // Deprecated: use team_id + template_id instead. Template string `protobuf:"bytes,1,opt,name=template,proto3" json:"template,omitempty"` // Number of virtual CPUs (default: 1). Vcpus int32 `protobuf:"varint,2,opt,name=vcpus,proto3" json:"vcpus,omitempty"` @@ -33,7 +33,14 @@ type CreateSandboxRequest struct { MemoryMb int32 `protobuf:"varint,3,opt,name=memory_mb,json=memoryMb,proto3" json:"memory_mb,omitempty"` // TTL in seconds. Sandbox is auto-paused after this duration of // inactivity. 0 means no auto-pause. - TimeoutSec int32 `protobuf:"varint,4,opt,name=timeout_sec,json=timeoutSec,proto3" json:"timeout_sec,omitempty"` + TimeoutSec int32 `protobuf:"varint,4,opt,name=timeout_sec,json=timeoutSec,proto3" json:"timeout_sec,omitempty"` + // Disk size in MB for the rootfs. Base images are expanded to this size + // at host agent startup. Default: 5120 (5 GB). + DiskSizeMb int32 `protobuf:"varint,6,opt,name=disk_size_mb,json=diskSizeMb,proto3" json:"disk_size_mb,omitempty"` + // Team UUID that owns the template (hex string). All-zeros = platform. + TeamId string `protobuf:"bytes,7,opt,name=team_id,json=teamId,proto3" json:"team_id,omitempty"` + // Template UUID (hex string). Both zeros + team zeros = "minimal" sentinel. 
+ TemplateId string `protobuf:"bytes,8,opt,name=template_id,json=templateId,proto3" json:"template_id,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -103,6 +110,27 @@ func (x *CreateSandboxRequest) GetTimeoutSec() int32 { return 0 } +func (x *CreateSandboxRequest) GetDiskSizeMb() int32 { + if x != nil { + return x.DiskSizeMb + } + return 0 +} + +func (x *CreateSandboxRequest) GetTeamId() string { + if x != nil { + return x.TeamId + } + return "" +} + +func (x *CreateSandboxRequest) GetTemplateId() string { + if x != nil { + return x.TemplateId + } + return "" +} + type CreateSandboxResponse struct { state protoimpl.MessageState `protogen:"open.v1"` SandboxId string `protobuf:"bytes,1,opt,name=sandbox_id,json=sandboxId,proto3" json:"sandbox_id,omitempty"` @@ -438,9 +466,14 @@ func (x *ResumeSandboxResponse) GetHostIp() string { } type CreateSnapshotRequest struct { - state protoimpl.MessageState `protogen:"open.v1"` - SandboxId string `protobuf:"bytes,1,opt,name=sandbox_id,json=sandboxId,proto3" json:"sandbox_id,omitempty"` - Name string `protobuf:"bytes,2,opt,name=name,proto3" json:"name,omitempty"` + state protoimpl.MessageState `protogen:"open.v1"` + SandboxId string `protobuf:"bytes,1,opt,name=sandbox_id,json=sandboxId,proto3" json:"sandbox_id,omitempty"` + // Deprecated: use team_id + template_id instead. + Name string `protobuf:"bytes,2,opt,name=name,proto3" json:"name,omitempty"` + // Team UUID that will own the new template. + TeamId string `protobuf:"bytes,3,opt,name=team_id,json=teamId,proto3" json:"team_id,omitempty"` + // Template UUID for the new snapshot template. 
+ TemplateId string `protobuf:"bytes,4,opt,name=template_id,json=templateId,proto3" json:"template_id,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -489,6 +522,20 @@ func (x *CreateSnapshotRequest) GetName() string { return "" } +func (x *CreateSnapshotRequest) GetTeamId() string { + if x != nil { + return x.TeamId + } + return "" +} + +func (x *CreateSnapshotRequest) GetTemplateId() string { + if x != nil { + return x.TemplateId + } + return "" +} + type CreateSnapshotResponse struct { state protoimpl.MessageState `protogen:"open.v1"` Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` @@ -542,8 +589,13 @@ func (x *CreateSnapshotResponse) GetSizeBytes() int64 { } type DeleteSnapshotRequest struct { - state protoimpl.MessageState `protogen:"open.v1"` - Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + state protoimpl.MessageState `protogen:"open.v1"` + // Deprecated: use team_id + template_id instead. + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + // Team UUID that owns the template. + TeamId string `protobuf:"bytes,2,opt,name=team_id,json=teamId,proto3" json:"team_id,omitempty"` + // Template UUID to delete. 
+ TemplateId string `protobuf:"bytes,3,opt,name=template_id,json=templateId,proto3" json:"template_id,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -585,6 +637,20 @@ func (x *DeleteSnapshotRequest) GetName() string { return "" } +func (x *DeleteSnapshotRequest) GetTeamId() string { + if x != nil { + return x.TeamId + } + return "" +} + +func (x *DeleteSnapshotRequest) GetTemplateId() string { + if x != nil { + return x.TemplateId + } + return "" +} + type DeleteSnapshotResponse struct { state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields @@ -841,16 +907,19 @@ func (x *ListSandboxesResponse) GetAutoPausedSandboxIds() []string { } type SandboxInfo struct { - state protoimpl.MessageState `protogen:"open.v1"` - SandboxId string `protobuf:"bytes,1,opt,name=sandbox_id,json=sandboxId,proto3" json:"sandbox_id,omitempty"` - Status string `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"` - Template string `protobuf:"bytes,3,opt,name=template,proto3" json:"template,omitempty"` - Vcpus int32 `protobuf:"varint,4,opt,name=vcpus,proto3" json:"vcpus,omitempty"` - MemoryMb int32 `protobuf:"varint,5,opt,name=memory_mb,json=memoryMb,proto3" json:"memory_mb,omitempty"` - HostIp string `protobuf:"bytes,6,opt,name=host_ip,json=hostIp,proto3" json:"host_ip,omitempty"` - CreatedAtUnix int64 `protobuf:"varint,7,opt,name=created_at_unix,json=createdAtUnix,proto3" json:"created_at_unix,omitempty"` - LastActiveAtUnix int64 `protobuf:"varint,8,opt,name=last_active_at_unix,json=lastActiveAtUnix,proto3" json:"last_active_at_unix,omitempty"` - TimeoutSec int32 `protobuf:"varint,9,opt,name=timeout_sec,json=timeoutSec,proto3" json:"timeout_sec,omitempty"` + state protoimpl.MessageState `protogen:"open.v1"` + SandboxId string `protobuf:"bytes,1,opt,name=sandbox_id,json=sandboxId,proto3" json:"sandbox_id,omitempty"` + Status string `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"` + // 
Deprecated: use team_id + template_id instead. + Template string `protobuf:"bytes,3,opt,name=template,proto3" json:"template,omitempty"` + Vcpus int32 `protobuf:"varint,4,opt,name=vcpus,proto3" json:"vcpus,omitempty"` + MemoryMb int32 `protobuf:"varint,5,opt,name=memory_mb,json=memoryMb,proto3" json:"memory_mb,omitempty"` + HostIp string `protobuf:"bytes,6,opt,name=host_ip,json=hostIp,proto3" json:"host_ip,omitempty"` + CreatedAtUnix int64 `protobuf:"varint,7,opt,name=created_at_unix,json=createdAtUnix,proto3" json:"created_at_unix,omitempty"` + LastActiveAtUnix int64 `protobuf:"varint,8,opt,name=last_active_at_unix,json=lastActiveAtUnix,proto3" json:"last_active_at_unix,omitempty"` + TimeoutSec int32 `protobuf:"varint,9,opt,name=timeout_sec,json=timeoutSec,proto3" json:"timeout_sec,omitempty"` + TeamId string `protobuf:"bytes,10,opt,name=team_id,json=teamId,proto3" json:"team_id,omitempty"` + TemplateId string `protobuf:"bytes,11,opt,name=template_id,json=templateId,proto3" json:"template_id,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -948,6 +1017,20 @@ func (x *SandboxInfo) GetTimeoutSec() int32 { return 0 } +func (x *SandboxInfo) GetTeamId() string { + if x != nil { + return x.TeamId + } + return "" +} + +func (x *SandboxInfo) GetTemplateId() string { + if x != nil { + return x.TemplateId + } + return "" +} + type WriteFileRequest struct { state protoimpl.MessageState `protogen:"open.v1"` SandboxId string `protobuf:"bytes,1,opt,name=sandbox_id,json=sandboxId,proto3" json:"sandbox_id,omitempty"` @@ -2171,11 +2254,126 @@ func (x *FlushSandboxMetricsResponse) GetPoints_24H() []*MetricPoint { return nil } +type FlattenRootfsRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + SandboxId string `protobuf:"bytes,1,opt,name=sandbox_id,json=sandboxId,proto3" json:"sandbox_id,omitempty"` + // Deprecated: use team_id + template_id instead. 
+ Name string `protobuf:"bytes,2,opt,name=name,proto3" json:"name,omitempty"` + // Team UUID that will own the resulting template. + TeamId string `protobuf:"bytes,3,opt,name=team_id,json=teamId,proto3" json:"team_id,omitempty"` + // Template UUID for the output. + TemplateId string `protobuf:"bytes,4,opt,name=template_id,json=templateId,proto3" json:"template_id,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *FlattenRootfsRequest) Reset() { + *x = FlattenRootfsRequest{} + mi := &file_hostagent_proto_msgTypes[40] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *FlattenRootfsRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*FlattenRootfsRequest) ProtoMessage() {} + +func (x *FlattenRootfsRequest) ProtoReflect() protoreflect.Message { + mi := &file_hostagent_proto_msgTypes[40] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use FlattenRootfsRequest.ProtoReflect.Descriptor instead. 
+func (*FlattenRootfsRequest) Descriptor() ([]byte, []int) { + return file_hostagent_proto_rawDescGZIP(), []int{40} +} + +func (x *FlattenRootfsRequest) GetSandboxId() string { + if x != nil { + return x.SandboxId + } + return "" +} + +func (x *FlattenRootfsRequest) GetName() string { + if x != nil { + return x.Name + } + return "" +} + +func (x *FlattenRootfsRequest) GetTeamId() string { + if x != nil { + return x.TeamId + } + return "" +} + +func (x *FlattenRootfsRequest) GetTemplateId() string { + if x != nil { + return x.TemplateId + } + return "" +} + +type FlattenRootfsResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + SizeBytes int64 `protobuf:"varint,1,opt,name=size_bytes,json=sizeBytes,proto3" json:"size_bytes,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *FlattenRootfsResponse) Reset() { + *x = FlattenRootfsResponse{} + mi := &file_hostagent_proto_msgTypes[41] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *FlattenRootfsResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*FlattenRootfsResponse) ProtoMessage() {} + +func (x *FlattenRootfsResponse) ProtoReflect() protoreflect.Message { + mi := &file_hostagent_proto_msgTypes[41] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use FlattenRootfsResponse.ProtoReflect.Descriptor instead. 
+func (*FlattenRootfsResponse) Descriptor() ([]byte, []int) { + return file_hostagent_proto_rawDescGZIP(), []int{41} +} + +func (x *FlattenRootfsResponse) GetSizeBytes() int64 { + if x != nil { + return x.SizeBytes + } + return 0 +} + var File_hostagent_proto protoreflect.FileDescriptor const file_hostagent_proto_rawDesc = "" + "\n" + - "\x0fhostagent.proto\x12\fhostagent.v1\"\xa5\x01\n" + + "\x0fhostagent.proto\x12\fhostagent.v1\"\x81\x02\n" + "\x14CreateSandboxRequest\x12\x1d\n" + "\n" + "sandbox_id\x18\x05 \x01(\tR\tsandboxId\x12\x1a\n" + @@ -2183,7 +2381,12 @@ const file_hostagent_proto_rawDesc = "" + "\x05vcpus\x18\x02 \x01(\x05R\x05vcpus\x12\x1b\n" + "\tmemory_mb\x18\x03 \x01(\x05R\bmemoryMb\x12\x1f\n" + "\vtimeout_sec\x18\x04 \x01(\x05R\n" + - "timeoutSec\"g\n" + + "timeoutSec\x12 \n" + + "\fdisk_size_mb\x18\x06 \x01(\x05R\n" + + "diskSizeMb\x12\x17\n" + + "\ateam_id\x18\a \x01(\tR\x06teamId\x12\x1f\n" + + "\vtemplate_id\x18\b \x01(\tR\n" + + "templateId\"g\n" + "\x15CreateSandboxResponse\x12\x1d\n" + "\n" + "sandbox_id\x18\x01 \x01(\tR\tsandboxId\x12\x16\n" + @@ -2206,17 +2409,23 @@ const file_hostagent_proto_rawDesc = "" + "\n" + "sandbox_id\x18\x01 \x01(\tR\tsandboxId\x12\x16\n" + "\x06status\x18\x02 \x01(\tR\x06status\x12\x17\n" + - "\ahost_ip\x18\x03 \x01(\tR\x06hostIp\"J\n" + + "\ahost_ip\x18\x03 \x01(\tR\x06hostIp\"\x84\x01\n" + "\x15CreateSnapshotRequest\x12\x1d\n" + "\n" + "sandbox_id\x18\x01 \x01(\tR\tsandboxId\x12\x12\n" + - "\x04name\x18\x02 \x01(\tR\x04name\"K\n" + + "\x04name\x18\x02 \x01(\tR\x04name\x12\x17\n" + + "\ateam_id\x18\x03 \x01(\tR\x06teamId\x12\x1f\n" + + "\vtemplate_id\x18\x04 \x01(\tR\n" + + "templateId\"K\n" + "\x16CreateSnapshotResponse\x12\x12\n" + "\x04name\x18\x01 \x01(\tR\x04name\x12\x1d\n" + "\n" + - "size_bytes\x18\x02 \x01(\x03R\tsizeBytes\"+\n" + + "size_bytes\x18\x02 \x01(\x03R\tsizeBytes\"e\n" + "\x15DeleteSnapshotRequest\x12\x12\n" + - "\x04name\x18\x01 \x01(\tR\x04name\"\x18\n" + + "\x04name\x18\x01 
\x01(\tR\x04name\x12\x17\n" + + "\ateam_id\x18\x02 \x01(\tR\x06teamId\x12\x1f\n" + + "\vtemplate_id\x18\x03 \x01(\tR\n" + + "templateId\"\x18\n" + "\x16DeleteSnapshotResponse\"s\n" + "\vExecRequest\x12\x1d\n" + "\n" + @@ -2232,7 +2441,7 @@ const file_hostagent_proto_rawDesc = "" + "\x14ListSandboxesRequest\"\x87\x01\n" + "\x15ListSandboxesResponse\x127\n" + "\tsandboxes\x18\x01 \x03(\v2\x19.hostagent.v1.SandboxInfoR\tsandboxes\x125\n" + - "\x17auto_paused_sandbox_ids\x18\x02 \x03(\tR\x14autoPausedSandboxIds\"\xa4\x02\n" + + "\x17auto_paused_sandbox_ids\x18\x02 \x03(\tR\x14autoPausedSandboxIds\"\xde\x02\n" + "\vSandboxInfo\x12\x1d\n" + "\n" + "sandbox_id\x18\x01 \x01(\tR\tsandboxId\x12\x16\n" + @@ -2244,7 +2453,11 @@ const file_hostagent_proto_rawDesc = "" + "\x0fcreated_at_unix\x18\a \x01(\x03R\rcreatedAtUnix\x12-\n" + "\x13last_active_at_unix\x18\b \x01(\x03R\x10lastActiveAtUnix\x12\x1f\n" + "\vtimeout_sec\x18\t \x01(\x05R\n" + - "timeoutSec\"_\n" + + "timeoutSec\x12\x17\n" + + "\ateam_id\x18\n" + + " \x01(\tR\x06teamId\x12\x1f\n" + + "\vtemplate_id\x18\v \x01(\tR\n" + + "templateId\"_\n" + "\x10WriteFileRequest\x12\x1d\n" + "\n" + "sandbox_id\x18\x01 \x01(\tR\tsandboxId\x12\x12\n" + @@ -2319,7 +2532,17 @@ const file_hostagent_proto_rawDesc = "" + "points_10m\x18\x01 \x03(\v2\x19.hostagent.v1.MetricPointR\tpoints10m\x126\n" + "\tpoints_2h\x18\x02 \x03(\v2\x19.hostagent.v1.MetricPointR\bpoints2h\x128\n" + "\n" + - "points_24h\x18\x03 \x03(\v2\x19.hostagent.v1.MetricPointR\tpoints24h2\xee\v\n" + + "points_24h\x18\x03 \x03(\v2\x19.hostagent.v1.MetricPointR\tpoints24h\"\x83\x01\n" + + "\x14FlattenRootfsRequest\x12\x1d\n" + + "\n" + + "sandbox_id\x18\x01 \x01(\tR\tsandboxId\x12\x12\n" + + "\x04name\x18\x02 \x01(\tR\x04name\x12\x17\n" + + "\ateam_id\x18\x03 \x01(\tR\x06teamId\x12\x1f\n" + + "\vtemplate_id\x18\x04 \x01(\tR\n" + + "templateId\"6\n" + + "\x15FlattenRootfsResponse\x12\x1d\n" + + "\n" + + "size_bytes\x18\x01 \x01(\x03R\tsizeBytes2\xc8\f\n" + 
"\x10HostAgentService\x12X\n" + "\rCreateSandbox\x12\".hostagent.v1.CreateSandboxRequest\x1a#.hostagent.v1.CreateSandboxResponse\x12[\n" + "\x0eDestroySandbox\x12#.hostagent.v1.DestroySandboxRequest\x1a$.hostagent.v1.DestroySandboxResponse\x12U\n" + @@ -2338,7 +2561,8 @@ const file_hostagent_proto_rawDesc = "" + "\vPingSandbox\x12 .hostagent.v1.PingSandboxRequest\x1a!.hostagent.v1.PingSandboxResponse\x12L\n" + "\tTerminate\x12\x1e.hostagent.v1.TerminateRequest\x1a\x1f.hostagent.v1.TerminateResponse\x12d\n" + "\x11GetSandboxMetrics\x12&.hostagent.v1.GetSandboxMetricsRequest\x1a'.hostagent.v1.GetSandboxMetricsResponse\x12j\n" + - "\x13FlushSandboxMetrics\x12(.hostagent.v1.FlushSandboxMetricsRequest\x1a).hostagent.v1.FlushSandboxMetricsResponseB\xb0\x01\n" + + "\x13FlushSandboxMetrics\x12(.hostagent.v1.FlushSandboxMetricsRequest\x1a).hostagent.v1.FlushSandboxMetricsResponse\x12X\n" + + "\rFlattenRootfs\x12\".hostagent.v1.FlattenRootfsRequest\x1a#.hostagent.v1.FlattenRootfsResponseB\xb0\x01\n" + "\x10com.hostagent.v1B\x0eHostagentProtoP\x01Z;git.omukk.dev/wrenn/sandbox/proto/hostagent/gen;hostagentv1\xa2\x02\x03HXX\xaa\x02\fHostagent.V1\xca\x02\fHostagent\\V1\xe2\x02\x18Hostagent\\V1\\GPBMetadata\xea\x02\rHostagent::V1b\x06proto3" var ( @@ -2353,7 +2577,7 @@ func file_hostagent_proto_rawDescGZIP() []byte { return file_hostagent_proto_rawDescData } -var file_hostagent_proto_msgTypes = make([]protoimpl.MessageInfo, 40) +var file_hostagent_proto_msgTypes = make([]protoimpl.MessageInfo, 42) var file_hostagent_proto_goTypes = []any{ (*CreateSandboxRequest)(nil), // 0: hostagent.v1.CreateSandboxRequest (*CreateSandboxResponse)(nil), // 1: hostagent.v1.CreateSandboxResponse @@ -2395,6 +2619,8 @@ var file_hostagent_proto_goTypes = []any{ (*GetSandboxMetricsResponse)(nil), // 37: hostagent.v1.GetSandboxMetricsResponse (*FlushSandboxMetricsRequest)(nil), // 38: hostagent.v1.FlushSandboxMetricsRequest (*FlushSandboxMetricsResponse)(nil), // 39: 
hostagent.v1.FlushSandboxMetricsResponse + (*FlattenRootfsRequest)(nil), // 40: hostagent.v1.FlattenRootfsRequest + (*FlattenRootfsResponse)(nil), // 41: hostagent.v1.FlattenRootfsResponse } var file_hostagent_proto_depIdxs = []int32{ 16, // 0: hostagent.v1.ListSandboxesResponse.sandboxes:type_name -> hostagent.v1.SandboxInfo @@ -2423,25 +2649,27 @@ var file_hostagent_proto_depIdxs = []int32{ 33, // 23: hostagent.v1.HostAgentService.Terminate:input_type -> hostagent.v1.TerminateRequest 36, // 24: hostagent.v1.HostAgentService.GetSandboxMetrics:input_type -> hostagent.v1.GetSandboxMetricsRequest 38, // 25: hostagent.v1.HostAgentService.FlushSandboxMetrics:input_type -> hostagent.v1.FlushSandboxMetricsRequest - 1, // 26: hostagent.v1.HostAgentService.CreateSandbox:output_type -> hostagent.v1.CreateSandboxResponse - 3, // 27: hostagent.v1.HostAgentService.DestroySandbox:output_type -> hostagent.v1.DestroySandboxResponse - 5, // 28: hostagent.v1.HostAgentService.PauseSandbox:output_type -> hostagent.v1.PauseSandboxResponse - 7, // 29: hostagent.v1.HostAgentService.ResumeSandbox:output_type -> hostagent.v1.ResumeSandboxResponse - 13, // 30: hostagent.v1.HostAgentService.Exec:output_type -> hostagent.v1.ExecResponse - 15, // 31: hostagent.v1.HostAgentService.ListSandboxes:output_type -> hostagent.v1.ListSandboxesResponse - 18, // 32: hostagent.v1.HostAgentService.WriteFile:output_type -> hostagent.v1.WriteFileResponse - 20, // 33: hostagent.v1.HostAgentService.ReadFile:output_type -> hostagent.v1.ReadFileResponse - 9, // 34: hostagent.v1.HostAgentService.CreateSnapshot:output_type -> hostagent.v1.CreateSnapshotResponse - 11, // 35: hostagent.v1.HostAgentService.DeleteSnapshot:output_type -> hostagent.v1.DeleteSnapshotResponse - 22, // 36: hostagent.v1.HostAgentService.ExecStream:output_type -> hostagent.v1.ExecStreamResponse - 28, // 37: hostagent.v1.HostAgentService.WriteFileStream:output_type -> hostagent.v1.WriteFileStreamResponse - 30, // 38: 
hostagent.v1.HostAgentService.ReadFileStream:output_type -> hostagent.v1.ReadFileStreamResponse - 32, // 39: hostagent.v1.HostAgentService.PingSandbox:output_type -> hostagent.v1.PingSandboxResponse - 34, // 40: hostagent.v1.HostAgentService.Terminate:output_type -> hostagent.v1.TerminateResponse - 37, // 41: hostagent.v1.HostAgentService.GetSandboxMetrics:output_type -> hostagent.v1.GetSandboxMetricsResponse - 39, // 42: hostagent.v1.HostAgentService.FlushSandboxMetrics:output_type -> hostagent.v1.FlushSandboxMetricsResponse - 26, // [26:43] is the sub-list for method output_type - 9, // [9:26] is the sub-list for method input_type + 40, // 26: hostagent.v1.HostAgentService.FlattenRootfs:input_type -> hostagent.v1.FlattenRootfsRequest + 1, // 27: hostagent.v1.HostAgentService.CreateSandbox:output_type -> hostagent.v1.CreateSandboxResponse + 3, // 28: hostagent.v1.HostAgentService.DestroySandbox:output_type -> hostagent.v1.DestroySandboxResponse + 5, // 29: hostagent.v1.HostAgentService.PauseSandbox:output_type -> hostagent.v1.PauseSandboxResponse + 7, // 30: hostagent.v1.HostAgentService.ResumeSandbox:output_type -> hostagent.v1.ResumeSandboxResponse + 13, // 31: hostagent.v1.HostAgentService.Exec:output_type -> hostagent.v1.ExecResponse + 15, // 32: hostagent.v1.HostAgentService.ListSandboxes:output_type -> hostagent.v1.ListSandboxesResponse + 18, // 33: hostagent.v1.HostAgentService.WriteFile:output_type -> hostagent.v1.WriteFileResponse + 20, // 34: hostagent.v1.HostAgentService.ReadFile:output_type -> hostagent.v1.ReadFileResponse + 9, // 35: hostagent.v1.HostAgentService.CreateSnapshot:output_type -> hostagent.v1.CreateSnapshotResponse + 11, // 36: hostagent.v1.HostAgentService.DeleteSnapshot:output_type -> hostagent.v1.DeleteSnapshotResponse + 22, // 37: hostagent.v1.HostAgentService.ExecStream:output_type -> hostagent.v1.ExecStreamResponse + 28, // 38: hostagent.v1.HostAgentService.WriteFileStream:output_type -> hostagent.v1.WriteFileStreamResponse + 30, // 
39: hostagent.v1.HostAgentService.ReadFileStream:output_type -> hostagent.v1.ReadFileStreamResponse + 32, // 40: hostagent.v1.HostAgentService.PingSandbox:output_type -> hostagent.v1.PingSandboxResponse + 34, // 41: hostagent.v1.HostAgentService.Terminate:output_type -> hostagent.v1.TerminateResponse + 37, // 42: hostagent.v1.HostAgentService.GetSandboxMetrics:output_type -> hostagent.v1.GetSandboxMetricsResponse + 39, // 43: hostagent.v1.HostAgentService.FlushSandboxMetrics:output_type -> hostagent.v1.FlushSandboxMetricsResponse + 41, // 44: hostagent.v1.HostAgentService.FlattenRootfs:output_type -> hostagent.v1.FlattenRootfsResponse + 27, // [27:45] is the sub-list for method output_type + 9, // [9:27] is the sub-list for method input_type 9, // [9:9] is the sub-list for extension type_name 9, // [9:9] is the sub-list for extension extendee 0, // [0:9] is the sub-list for field type_name @@ -2471,7 +2699,7 @@ func file_hostagent_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_hostagent_proto_rawDesc), len(file_hostagent_proto_rawDesc)), NumEnums: 0, - NumMessages: 40, + NumMessages: 42, NumExtensions: 0, NumServices: 1, }, diff --git a/proto/hostagent/gen/hostagentv1connect/hostagent.connect.go b/proto/hostagent/gen/hostagentv1connect/hostagent.connect.go index 7f0fa707..02f4ecc6 100644 --- a/proto/hostagent/gen/hostagentv1connect/hostagent.connect.go +++ b/proto/hostagent/gen/hostagentv1connect/hostagent.connect.go @@ -83,6 +83,9 @@ const ( // HostAgentServiceFlushSandboxMetricsProcedure is the fully-qualified name of the // HostAgentService's FlushSandboxMetrics RPC. HostAgentServiceFlushSandboxMetricsProcedure = "/hostagent.v1.HostAgentService/FlushSandboxMetrics" + // HostAgentServiceFlattenRootfsProcedure is the fully-qualified name of the HostAgentService's + // FlattenRootfs RPC. 
+ HostAgentServiceFlattenRootfsProcedure = "/hostagent.v1.HostAgentService/FlattenRootfs" ) // HostAgentServiceClient is a client for the hostagent.v1.HostAgentService service. @@ -126,6 +129,11 @@ type HostAgentServiceClient interface { // FlushSandboxMetrics returns all ring buffer tiers and clears them. // Called by the control plane before pause/destroy to persist metrics to DB. FlushSandboxMetrics(context.Context, *connect.Request[gen.FlushSandboxMetricsRequest]) (*connect.Response[gen.FlushSandboxMetricsResponse], error) + // FlattenRootfs stops the sandbox VM, flattens the device-mapper CoW + // snapshot into a standalone rootfs.ext4 in the images directory, then + // cleans up all sandbox resources. Used by the template build system to + // produce image-only templates (no memory/CPU state). + FlattenRootfs(context.Context, *connect.Request[gen.FlattenRootfsRequest]) (*connect.Response[gen.FlattenRootfsResponse], error) } // NewHostAgentServiceClient constructs a client for the hostagent.v1.HostAgentService service. 
By @@ -241,6 +249,12 @@ func NewHostAgentServiceClient(httpClient connect.HTTPClient, baseURL string, op connect.WithSchema(hostAgentServiceMethods.ByName("FlushSandboxMetrics")), connect.WithClientOptions(opts...), ), + flattenRootfs: connect.NewClient[gen.FlattenRootfsRequest, gen.FlattenRootfsResponse]( + httpClient, + baseURL+HostAgentServiceFlattenRootfsProcedure, + connect.WithSchema(hostAgentServiceMethods.ByName("FlattenRootfs")), + connect.WithClientOptions(opts...), + ), } } @@ -263,6 +277,7 @@ type hostAgentServiceClient struct { terminate *connect.Client[gen.TerminateRequest, gen.TerminateResponse] getSandboxMetrics *connect.Client[gen.GetSandboxMetricsRequest, gen.GetSandboxMetricsResponse] flushSandboxMetrics *connect.Client[gen.FlushSandboxMetricsRequest, gen.FlushSandboxMetricsResponse] + flattenRootfs *connect.Client[gen.FlattenRootfsRequest, gen.FlattenRootfsResponse] } // CreateSandbox calls hostagent.v1.HostAgentService.CreateSandbox. @@ -350,6 +365,11 @@ func (c *hostAgentServiceClient) FlushSandboxMetrics(ctx context.Context, req *c return c.flushSandboxMetrics.CallUnary(ctx, req) } +// FlattenRootfs calls hostagent.v1.HostAgentService.FlattenRootfs. +func (c *hostAgentServiceClient) FlattenRootfs(ctx context.Context, req *connect.Request[gen.FlattenRootfsRequest]) (*connect.Response[gen.FlattenRootfsResponse], error) { + return c.flattenRootfs.CallUnary(ctx, req) +} + // HostAgentServiceHandler is an implementation of the hostagent.v1.HostAgentService service. type HostAgentServiceHandler interface { // CreateSandbox boots a new microVM with the given configuration. @@ -391,6 +411,11 @@ type HostAgentServiceHandler interface { // FlushSandboxMetrics returns all ring buffer tiers and clears them. // Called by the control plane before pause/destroy to persist metrics to DB. 
FlushSandboxMetrics(context.Context, *connect.Request[gen.FlushSandboxMetricsRequest]) (*connect.Response[gen.FlushSandboxMetricsResponse], error) + // FlattenRootfs stops the sandbox VM, flattens the device-mapper CoW + // snapshot into a standalone rootfs.ext4 in the images directory, then + // cleans up all sandbox resources. Used by the template build system to + // produce image-only templates (no memory/CPU state). + FlattenRootfs(context.Context, *connect.Request[gen.FlattenRootfsRequest]) (*connect.Response[gen.FlattenRootfsResponse], error) } // NewHostAgentServiceHandler builds an HTTP handler from the service implementation. It returns the @@ -502,6 +527,12 @@ func NewHostAgentServiceHandler(svc HostAgentServiceHandler, opts ...connect.Han connect.WithSchema(hostAgentServiceMethods.ByName("FlushSandboxMetrics")), connect.WithHandlerOptions(opts...), ) + hostAgentServiceFlattenRootfsHandler := connect.NewUnaryHandler( + HostAgentServiceFlattenRootfsProcedure, + svc.FlattenRootfs, + connect.WithSchema(hostAgentServiceMethods.ByName("FlattenRootfs")), + connect.WithHandlerOptions(opts...), + ) return "/hostagent.v1.HostAgentService/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case HostAgentServiceCreateSandboxProcedure: @@ -538,6 +569,8 @@ func NewHostAgentServiceHandler(svc HostAgentServiceHandler, opts ...connect.Han hostAgentServiceGetSandboxMetricsHandler.ServeHTTP(w, r) case HostAgentServiceFlushSandboxMetricsProcedure: hostAgentServiceFlushSandboxMetricsHandler.ServeHTTP(w, r) + case HostAgentServiceFlattenRootfsProcedure: + hostAgentServiceFlattenRootfsHandler.ServeHTTP(w, r) default: http.NotFound(w, r) } @@ -614,3 +647,7 @@ func (UnimplementedHostAgentServiceHandler) GetSandboxMetrics(context.Context, * func (UnimplementedHostAgentServiceHandler) FlushSandboxMetrics(context.Context, *connect.Request[gen.FlushSandboxMetricsRequest]) (*connect.Response[gen.FlushSandboxMetricsResponse], error) { return nil, 
connect.NewError(connect.CodeUnimplemented, errors.New("hostagent.v1.HostAgentService.FlushSandboxMetrics is not implemented")) } + +func (UnimplementedHostAgentServiceHandler) FlattenRootfs(context.Context, *connect.Request[gen.FlattenRootfsRequest]) (*connect.Response[gen.FlattenRootfsResponse], error) { + return nil, connect.NewError(connect.CodeUnimplemented, errors.New("hostagent.v1.HostAgentService.FlattenRootfs is not implemented")) +} diff --git a/proto/hostagent/hostagent.proto b/proto/hostagent/hostagent.proto index 214a84e0..817d5359 100644 --- a/proto/hostagent/hostagent.proto +++ b/proto/hostagent/hostagent.proto @@ -61,13 +61,19 @@ service HostAgentService { // Called by the control plane before pause/destroy to persist metrics to DB. rpc FlushSandboxMetrics(FlushSandboxMetricsRequest) returns (FlushSandboxMetricsResponse); + // FlattenRootfs stops the sandbox VM, flattens the device-mapper CoW + // snapshot into a standalone rootfs.ext4 in the images directory, then + // cleans up all sandbox resources. Used by the template build system to + // produce image-only templates (no memory/CPU state). + rpc FlattenRootfs(FlattenRootfsRequest) returns (FlattenRootfsResponse); + } message CreateSandboxRequest { // Sandbox ID assigned by the control plane. If empty, the host agent generates one. string sandbox_id = 5; - // Template name (e.g., "minimal", "python311"). Determines base rootfs. + // Deprecated: use team_id + template_id instead. string template = 1; // Number of virtual CPUs (default: 1). @@ -79,6 +85,16 @@ message CreateSandboxRequest { // TTL in seconds. Sandbox is auto-paused after this duration of // inactivity. 0 means no auto-pause. int32 timeout_sec = 4; + + // Disk size in MB for the rootfs. Base images are expanded to this size + // at host agent startup. Default: 5120 (5 GB). + int32 disk_size_mb = 6; + + // Team UUID that owns the template (hex string). All-zeros = platform. + string team_id = 7; + + // Template UUID (hex string). 
Both zeros + team zeros = "minimal" sentinel. + string template_id = 8; } message CreateSandboxResponse { @@ -115,7 +131,12 @@ message ResumeSandboxResponse { message CreateSnapshotRequest { string sandbox_id = 1; + // Deprecated: use team_id + template_id instead. string name = 2; + // Team UUID that will own the new template. + string team_id = 3; + // Template UUID for the new snapshot template. + string template_id = 4; } message CreateSnapshotResponse { @@ -124,7 +145,12 @@ message CreateSnapshotResponse { } message DeleteSnapshotRequest { + // Deprecated: use team_id + template_id instead. string name = 1; + // Team UUID that owns the template. + string team_id = 2; + // Template UUID to delete. + string template_id = 3; } message DeleteSnapshotResponse {} @@ -156,6 +182,7 @@ message ListSandboxesResponse { message SandboxInfo { string sandbox_id = 1; string status = 2; + // Deprecated: use team_id + template_id instead. string template = 3; int32 vcpus = 4; int32 memory_mb = 5; @@ -163,6 +190,8 @@ message SandboxInfo { int64 created_at_unix = 7; int64 last_active_at_unix = 8; int32 timeout_sec = 9; + string team_id = 10; + string template_id = 11; } message WriteFileRequest { @@ -284,3 +313,19 @@ message FlushSandboxMetricsResponse { repeated MetricPoint points_2h = 2; repeated MetricPoint points_24h = 3; } + +// ── FlattenRootfs ──────────────────────────────────────────────────── + +message FlattenRootfsRequest { + string sandbox_id = 1; + // Deprecated: use team_id + template_id instead. + string name = 2; + // Team UUID that will own the resulting template. + string team_id = 3; + // Template UUID for the output. 
+ string template_id = 4; +} + +message FlattenRootfsResponse { + int64 size_bytes = 1; +} diff --git a/scripts/rootfs-from-container.sh b/scripts/rootfs-from-container.sh index ce1dd526..2f96a3a7 100755 --- a/scripts/rootfs-from-container.sh +++ b/scripts/rootfs-from-container.sh @@ -3,7 +3,10 @@ # rootfs-from-container.sh — Create a bootable Wrenn rootfs from a Docker container. # # Exports a container's filesystem, writes it into an ext4 image, injects -# envd + wrenn-init, and shrinks the image to minimum size. +# envd + wrenn-init + tini, and shrinks the image to minimum size. +# +# The container image must already include: socat, chrony, curl, ca-certificates, git. +# These are installed via apt in the container before export, not injected here. # # Usage: # bash scripts/rootfs-from-container.sh @@ -13,7 +16,7 @@ # image_name — Directory name under images dir (e.g. "waitlist") # # Output: -# ${AGENT_FILES_ROOTDIR}/images//rootfs.ext4 +# ${WRENN_DIR}/images//rootfs.ext4 # # Requires: docker, mkfs.ext4, resize2fs, e2fsck, make (for building envd), curl (for tini download) # Sudo is used only for mount/umount/copy-into-image operations. @@ -22,8 +25,8 @@ set -euo pipefail SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" PROJECT_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" -AGENT_FILES_ROOTDIR="${AGENT_FILES_ROOTDIR:-/var/lib/wrenn}" -AGENT_IMAGES_PATH="${AGENT_FILES_ROOTDIR}/images" +WRENN_DIR="${WRENN_DIR:-/var/lib/wrenn}" +WRENN_IMAGES_PATH="${WRENN_DIR}/images" if [ $# -lt 2 ]; then echo "Usage: $0 " @@ -32,7 +35,7 @@ fi CONTAINER="$1" IMAGE_NAME="$2" -OUTPUT_DIR="${AGENT_IMAGES_PATH}/${IMAGE_NAME}" +OUTPUT_DIR="${WRENN_IMAGES_PATH}/${IMAGE_NAME}" OUTPUT_FILE="${OUTPUT_DIR}/rootfs.ext4" MOUNT_DIR="/tmp/wrenn-rootfs-build" TAR_FILE="/tmp/wrenn-rootfs-export-${IMAGE_NAME}.tar" @@ -130,16 +133,29 @@ sudo mkdir -p "${MOUNT_DIR}/sbin" sudo cp "${TINI_BIN}" "${MOUNT_DIR}/sbin/tini" sudo chmod 755 "${MOUNT_DIR}/sbin/tini" -# Step 6: Verify. 
+# Step 6: Verify injected binaries and required container packages. echo "" echo "==> Installed guest binaries:" ls -la "${MOUNT_DIR}/usr/local/bin/envd" "${MOUNT_DIR}/usr/local/bin/wrenn-init" "${MOUNT_DIR}/sbin/tini" +echo "" +echo "==> Checking required container packages..." +MISSING_PKGS="" +for bin in socat chronyd curl git; do + if ! find "${MOUNT_DIR}" -name "${bin}" -type f 2>/dev/null | head -1 | grep -q .; then + MISSING_PKGS="${MISSING_PKGS} ${bin}" + fi +done +if [ -n "${MISSING_PKGS}" ]; then + echo "WARNING: The following binaries were not found in the container image:${MISSING_PKGS}" + echo " Install them in the container (via apt) before exporting." +fi + # Unmount before shrinking. sudo umount "${MOUNT_DIR}" rmdir "${MOUNT_DIR}" 2>/dev/null || true -# Step 7: Shrink the image to minimum size. +# Step 8: Shrink the image to minimum size. echo "" echo "==> Shrinking image..." sudo e2fsck -fy "${OUTPUT_FILE}" diff --git a/scripts/update-debug-rootfs.sh b/scripts/update-debug-rootfs.sh index 7d0544e1..bdedded2 100755 --- a/scripts/update-debug-rootfs.sh +++ b/scripts/update-debug-rootfs.sh @@ -1,11 +1,11 @@ #!/usr/bin/env bash # -# update-debug-rootfs.sh — Build envd and inject it (plus wrenn-init) into the debug rootfs. +# update-debug-rootfs.sh — Build envd and inject it (plus wrenn-init + tini) into the debug rootfs. # # This script: # 1. Builds a fresh envd static binary via make # 2. Mounts the rootfs image -# 3. Copies envd and wrenn-init into the image +# 3. Copies envd, wrenn-init, and tini into the image # 4. Unmounts cleanly # # Usage: