diff --git a/.env.example b/.env.example
index dee152cf..62e9b4ce 100644
--- a/.env.example
+++ b/.env.example
@@ -5,13 +5,13 @@ DATABASE_URL=postgres://wrenn:wrenn@localhost:5432/wrenn?sslmode=disable
REDIS_URL=redis://localhost:6379/0
# Control Plane
-CP_LISTEN_ADDR=:8000
+WRENN_CP_LISTEN_ADDR=:8080
# Host Agent
-AGENT_LISTEN_ADDR=:50051
-AGENT_FILES_ROOTDIR=/var/lib/wrenn
-AGENT_HOST_INTERFACE=eth0
-AGENT_CP_URL=http://localhost:8000
+WRENN_HOST_LISTEN_ADDR=:50051
+WRENN_DIR=/var/lib/wrenn
+WRENN_HOST_INTERFACE=eth0
+WRENN_CP_URL=http://localhost:8080
# Lago (billing — external service)
LAGO_API_URL=http://localhost:3000
@@ -27,6 +27,14 @@ AWS_SECRET_ACCESS_KEY=
# Auth
JWT_SECRET=
+# mTLS — CP→Agent channel
+# Generate a self-signed CA with:
+# openssl ecparam -genkey -name P-256 -noout -out ca.key
+# openssl req -new -x509 -key ca.key -days 3650 -out ca.crt -subj "/CN=wrenn-internal-ca"
+# Then set these to the file contents (newlines replaced with \n or use multiline env). Leave both empty to run without mTLS (dev only). Never commit a real CA certificate/key pair: anyone holding the key can mint certificates this CA trusts.
+WRENN_CA_CERT=
+WRENN_CA_KEY=
+
# OAuth
OAUTH_GITHUB_CLIENT_ID=
OAUTH_GITHUB_CLIENT_SECRET=
diff --git a/Makefile b/Makefile
index 4a2e0b61..80fbd3a7 100644
--- a/Makefile
+++ b/Makefile
@@ -39,7 +39,7 @@ dev: dev-infra migrate-up dev-cp
dev-infra:
docker compose -f deploy/docker-compose.dev.yml up -d
@echo "Waiting for PostgreSQL..."
- @until pg_isready -h localhost -p 5432 -q; do sleep 0.5; done
+ @tries=0; until docker compose -f deploy/docker-compose.dev.yml exec -T postgres pg_isready -q 2>/dev/null; do tries=$$((tries + 1)); if [ $$tries -ge 120 ]; then echo "PostgreSQL did not become ready within 60s" >&2; exit 1; fi; sleep 0.5; done
@echo "Dev infrastructure ready."
dev-down:
@@ -53,7 +53,7 @@ dev-agent:
sudo go run ./cmd/host-agent
dev-frontend:
- cd frontend && pnpm dev --port 5173
+ cd frontend && pnpm dev --port 5173 --host 0.0.0.0
dev-envd:
cd $(ENVD_DIR) && go run . --debug --listen-tcp :3002
diff --git a/README.md b/README.md
index dff19325..e2b290f0 100644
--- a/README.md
+++ b/README.md
@@ -51,12 +51,12 @@ Copy `.env.example` to `.env` and edit:
DATABASE_URL=postgres://wrenn:wrenn@localhost:5432/wrenn?sslmode=disable
# Control plane
-CP_LISTEN_ADDR=:8000
+WRENN_CP_LISTEN_ADDR=:8000
CP_HOST_AGENT_ADDR=http://localhost:50051
# Host agent
-AGENT_LISTEN_ADDR=:50051
-AGENT_FILES_ROOTDIR=/var/lib/wrenn
+WRENN_HOST_LISTEN_ADDR=:50051
+WRENN_DIR=/var/lib/wrenn
```
### Run
@@ -69,7 +69,7 @@ make migrate-up
./builds/wrenn-cp
```
-Control plane listens on `CP_LISTEN_ADDR` (default `:8000`).
+Control plane listens on `WRENN_CP_LISTEN_ADDR` (default `:8000`).
### Host registration
@@ -87,16 +87,16 @@ Hosts must be registered with the control plane before they can serve sandboxes.
2. **Start the host agent** with the registration token and its externally-reachable address:
```bash
- sudo AGENT_CP_URL=http://cp-host:8000 \
+ sudo WRENN_CP_URL=http://cp-host:8000 \
./builds/wrenn-agent \
--register \
--address 10.0.1.5:50051
```
- On first startup the agent sends its specs (arch, CPU, memory, disk) to the control plane, receives a long-lived host JWT, and saves it to `$AGENT_FILES_ROOTDIR/host-token`.
+ On first startup the agent sends its specs (arch, CPU, memory, disk) to the control plane, receives a long-lived host JWT (plus an mTLS certificate when the control plane issues one), and saves the credentials to `$WRENN_DIR/host-credentials.json`.
3. **Subsequent startups** don't need `--register` — the agent loads the saved JWT automatically:
```bash
- sudo AGENT_CP_URL=http://cp-host:8000 \
+ sudo WRENN_CP_URL=http://cp-host:8000 \
./builds/wrenn-agent --address 10.0.1.5:50051
```
@@ -107,7 +107,7 @@ Hosts must be registered with the control plane before they can serve sandboxes.
```
Then restart the agent with the new token.
-The agent sends heartbeats to the control plane every 30 seconds. Host agent listens on `AGENT_LISTEN_ADDR` (default `:50051`).
+The agent sends heartbeats to the control plane every 30 seconds. Host agent listens on `WRENN_HOST_LISTEN_ADDR` (default `:50051`).
### Rootfs images
diff --git a/cmd/control-plane/main.go b/cmd/control-plane/main.go
index 9f84edcc..419dc875 100644
--- a/cmd/control-plane/main.go
+++ b/cmd/control-plane/main.go
@@ -15,6 +15,7 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/api"
"git.omukk.dev/wrenn/sandbox/internal/audit"
+ "git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/auth/oauth"
"git.omukk.dev/wrenn/sandbox/internal/config"
"git.omukk.dev/wrenn/sandbox/internal/db"
@@ -68,8 +69,52 @@ func main() {
}
slog.Info("connected to redis")
+ // mTLS: parse internal CA and build a TLS-capable host client pool.
+ // When CA env vars are absent the pool falls back to plain HTTP (dev mode).
+ var ca *auth.CA
+ if cfg.CACert != "" && cfg.CAKey != "" {
+ var err error
+ ca, err = auth.ParseCA(cfg.CACert, cfg.CAKey)
+ if err != nil {
+ slog.Error("failed to parse mTLS CA from environment", "error", err)
+ os.Exit(1)
+ }
+ slog.Info("mTLS enabled: CA loaded")
+ } else {
+ slog.Warn("mTLS disabled: WRENN_CA_CERT/WRENN_CA_KEY not set — host agent connections are unencrypted")
+ }
+
// Host client pool — manages Connect RPC clients to host agents.
- hostPool := lifecycle.NewHostClientPool()
+ var hostPool *lifecycle.HostClientPool
+ if ca != nil {
+ cpCertStore, err := auth.NewCPCertStore(ca)
+ if err != nil {
+ slog.Error("failed to issue CP client certificate", "error", err)
+ os.Exit(1)
+ }
+ // Renew the CP client certificate periodically so it never expires
+ // while the control plane is running (TTL = 24h, renewal = every 12h).
+ go func() {
+ ticker := time.NewTicker(auth.CPCertRenewInterval)
+ defer ticker.Stop()
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ case <-ticker.C:
+ if err := cpCertStore.Refresh(); err != nil {
+ slog.Error("failed to renew CP client certificate", "error", err)
+ } else {
+ slog.Info("CP client certificate renewed")
+ }
+ }
+ }
+ }()
+ hostPool = lifecycle.NewHostClientPoolTLS(auth.CPClientTLSConfig(ca, cpCertStore))
+ slog.Info("host client pool: mTLS enabled")
+ } else {
+ hostPool = lifecycle.NewHostClientPool()
+ }
// Scheduler — picks a host for each new sandbox (round-robin for now).
hostScheduler := scheduler.NewRoundRobinScheduler(queries)
@@ -88,7 +133,11 @@ func main() {
}
// API server.
- srv := api.New(queries, hostPool, hostScheduler, pool, rdb, []byte(cfg.JWTSecret), oauthRegistry, cfg.OAuthRedirectURL)
+ srv := api.New(queries, hostPool, hostScheduler, pool, rdb, []byte(cfg.JWTSecret), oauthRegistry, cfg.OAuthRedirectURL, ca)
+
+ // Start template build workers (2 concurrent).
+ stopBuildWorkers := srv.BuildSvc.StartWorkers(ctx, 2)
+ defer stopBuildWorkers()
// Start host monitor (passive + active reconciliation every 30s).
monitor := api.NewHostMonitor(queries, hostPool, audit.New(queries), 30*time.Second)
@@ -98,9 +147,14 @@ func main() {
sampler := api.NewMetricsSampler(queries, 10*time.Second)
sampler.Start(ctx)
+ // Wrap the API handler with the sandbox proxy so that requests with
+ // {port}-{sandbox_id}.{domain} Host headers are routed to the sandbox's
+ // host agent. All other requests pass through to the normal API router.
+ proxyWrapper := api.NewSandboxProxyWrapper(srv.Handler(), queries, hostPool)
+
httpServer := &http.Server{
Addr: cfg.ListenAddr,
- Handler: srv.Handler(),
+ Handler: proxyWrapper,
}
// Graceful shutdown on signal.
diff --git a/cmd/host-agent/main.go b/cmd/host-agent/main.go
index 2d34cd1d..de48a664 100644
--- a/cmd/host-agent/main.go
+++ b/cmd/host-agent/main.go
@@ -2,8 +2,10 @@ package main
import (
"context"
+ "crypto/tls"
"flag"
"log/slog"
+ "net"
"net/http"
"os"
"os/signal"
@@ -14,8 +16,10 @@ import (
"github.com/joho/godotenv"
+ "git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/devicemapper"
"git.omukk.dev/wrenn/sandbox/internal/hostagent"
+ "git.omukk.dev/wrenn/sandbox/internal/network"
"git.omukk.dev/wrenn/sandbox/internal/sandbox"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
)
@@ -42,16 +46,17 @@ func main() {
slog.Warn("failed to enable ip_forward", "error", err)
}
- // Clean up any stale dm-snapshot devices from a previous crash.
+ // Clean up stale resources from a previous crash.
devicemapper.CleanupStaleDevices()
+ network.CleanupStaleNamespaces()
- listenAddr := envOrDefault("AGENT_LISTEN_ADDR", ":50051")
- rootDir := envOrDefault("AGENT_FILES_ROOTDIR", "/var/lib/wrenn")
- cpURL := os.Getenv("AGENT_CP_URL")
- tokenFile := filepath.Join(rootDir, "host.jwt")
+ listenAddr := envOrDefault("WRENN_HOST_LISTEN_ADDR", ":50051")
+ rootDir := envOrDefault("WRENN_DIR", "/var/lib/wrenn")
+ cpURL := os.Getenv("WRENN_CP_URL")
+ credsFile := filepath.Join(rootDir, "host-credentials.json")
if cpURL == "" {
- slog.Error("AGENT_CP_URL environment variable is required")
+ slog.Error("WRENN_CP_URL environment variable is required")
os.Exit(1)
}
if *advertiseAddr == "" {
@@ -59,11 +64,15 @@ func main() {
os.Exit(1)
}
+ // Expand base images to the standard disk size (sparse, no extra physical
+ // disk). This ensures dm-snapshot sandboxes see the full size from boot.
+ if err := sandbox.EnsureImageSizes(rootDir, sandbox.DefaultDiskSizeMB); err != nil {
+ slog.Error("failed to expand base images", "error", err)
+ os.Exit(1)
+ }
+
cfg := sandbox.Config{
- KernelPath: filepath.Join(rootDir, "kernels", "vmlinux"),
- ImagesDir: filepath.Join(rootDir, "images"),
- SandboxesDir: filepath.Join(rootDir, "sandboxes"),
- SnapshotsDir: filepath.Join(rootDir, "snapshots"),
+ WrennDir: rootDir,
}
mgr := sandbox.New(cfg)
@@ -74,10 +83,10 @@ func main() {
mgr.StartTTLReaper(ctx)
// Register with the control plane and start heartbeating.
- hostToken, err := hostagent.Register(ctx, hostagent.RegistrationConfig{
+ creds, err := hostagent.Register(ctx, hostagent.RegistrationConfig{
CPURL: cpURL,
RegistrationToken: *registrationToken,
- TokenFile: tokenFile,
+ TokenFile: credsFile,
Address: *advertiseAddr,
})
if err != nil {
@@ -85,17 +94,29 @@ func main() {
os.Exit(1)
}
- hostID, err := hostagent.HostIDFromToken(hostToken)
- if err != nil {
- slog.Error("failed to extract host ID from token", "error", err)
- os.Exit(1)
- }
-
- slog.Info("host registered", "host_id", hostID)
+ slog.Info("host registered", "host_id", creds.HostID)
// httpServer is declared here so the shutdown func can reference it.
httpServer := &http.Server{Addr: listenAddr}
+ // Set up mTLS if the CP issued a certificate during registration.
+ var certStore hostagent.CertStore
+ if creds.CertPEM != "" && creds.KeyPEM != "" && creds.CACertPEM != "" {
+ if err := certStore.ParseAndStore(creds.CertPEM, creds.KeyPEM); err != nil {
+ slog.Error("failed to load host TLS certificate", "error", err)
+ os.Exit(1)
+ }
+ tlsCfg := auth.AgentTLSConfigFromPEM(creds.CACertPEM, certStore.GetCert)
+ if tlsCfg == nil {
+ slog.Error("failed to build agent TLS config: invalid CA certificate PEM")
+ os.Exit(1)
+ }
+ httpServer.TLSConfig = tlsCfg
+ slog.Info("mTLS enabled on agent server")
+ } else {
+ slog.Warn("mTLS disabled: no certificate received from CP — agent serving plain HTTP")
+ }
+
// doShutdown is the single shutdown path. sync.Once ensures mgr.Shutdown
// and httpServer.Shutdown are each called exactly once regardless of
// whether shutdown is triggered by a signal, a heartbeat 404, or the
@@ -119,13 +140,16 @@ func main() {
})
path, handler := hostagentv1connect.NewHostAgentServiceHandler(srv)
+ proxyHandler := hostagent.NewProxyHandler(mgr)
+
mux := http.NewServeMux()
mux.Handle(path, handler)
+ mux.Handle("/proxy/", proxyHandler)
httpServer.Handler = mux
// Start heartbeat loop. Handler must be set before this because the
// immediate beat can trigger doShutdown → httpServer.Shutdown synchronously.
- hostagent.StartHeartbeat(ctx, cpURL, tokenFile, hostID, 30*time.Second,
+ hostagent.StartHeartbeat(ctx, cpURL, credsFile, creds.HostID, 30*time.Second,
// pauseAll: called on 3 consecutive network failures.
func() {
pauseCtx, pauseCancel := context.WithTimeout(context.Background(), 2*time.Minute)
@@ -136,6 +160,17 @@ func main() {
func() {
doShutdown("host deleted from CP")
},
+ // onCredsRefreshed: hot-swap the TLS certificate after a JWT refresh.
+ func(tf *hostagent.TokenFile) {
+ if tf.CertPEM == "" || tf.KeyPEM == "" {
+ return
+ }
+ if err := certStore.ParseAndStore(tf.CertPEM, tf.KeyPEM); err != nil {
+ slog.Error("failed to hot-swap TLS cert after credentials refresh", "error", err)
+ } else {
+ slog.Info("TLS cert hot-swapped after credentials refresh")
+ }
+ },
)
// Graceful shutdown on SIGINT/SIGTERM.
@@ -146,10 +181,30 @@ func main() {
doShutdown("signal: " + sig.String())
}()
- slog.Info("host agent starting", "addr", listenAddr, "host_id", hostID)
- if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
- slog.Error("http server error", "error", err)
- os.Exit(1)
+ slog.Info("host agent starting", "addr", listenAddr, "host_id", creds.HostID)
+ if httpServer.TLSConfig != nil {
+ // The server certificate is supplied by the TLSConfig's GetCertificate callback
+ // rather than by cert/key files on disk, so build the TLS listener from that config
+ // and call Serve on it (ListenAndServeTLS("", "") also accepts a pre-populated TLSConfig).
+ ln, err := tls.Listen("tcp", listenAddr, httpServer.TLSConfig)
+ if err != nil {
+ slog.Error("failed to start TLS listener", "error", err)
+ os.Exit(1)
+ }
+ if err := httpServer.Serve(ln); err != nil && err != http.ErrServerClosed {
+ slog.Error("https server error", "error", err)
+ os.Exit(1)
+ }
+ } else {
+ ln, err := net.Listen("tcp", listenAddr)
+ if err != nil {
+ slog.Error("failed to start listener", "error", err)
+ os.Exit(1)
+ }
+ if err := httpServer.Serve(ln); err != nil && err != http.ErrServerClosed {
+ slog.Error("http server error", "error", err)
+ os.Exit(1)
+ }
}
slog.Info("host agent stopped")
diff --git a/db/migrations/20260310094104_initial.sql b/db/migrations/20260310094104_initial.sql
index c291815a..6c8afc47 100644
--- a/db/migrations/20260310094104_initial.sql
+++ b/db/migrations/20260310094104_initial.sql
@@ -1,25 +1,237 @@
-- +goose Up
-CREATE TABLE sandboxes (
- id TEXT PRIMARY KEY,
- owner_id TEXT NOT NULL DEFAULT '',
- host_id TEXT NOT NULL DEFAULT 'default',
- template TEXT NOT NULL DEFAULT 'minimal',
- status TEXT NOT NULL DEFAULT 'pending',
- vcpus INTEGER NOT NULL DEFAULT 1,
- memory_mb INTEGER NOT NULL DEFAULT 512,
- timeout_sec INTEGER NOT NULL DEFAULT 0,
- guest_ip TEXT NOT NULL DEFAULT '',
- host_ip TEXT NOT NULL DEFAULT '',
- created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
- started_at TIMESTAMPTZ,
- last_active_at TIMESTAMPTZ,
- last_updated TIMESTAMPTZ NOT NULL DEFAULT NOW()
+-- teams: tenant accounts referenced by sandboxes, hosts, API keys and memberships
+CREATE TABLE teams (
+ id UUID PRIMARY KEY,
+ name TEXT NOT NULL,
+ slug TEXT NOT NULL UNIQUE, -- the UNIQUE constraint creates the index used for slug lookups
+ is_byoc BOOLEAN NOT NULL DEFAULT FALSE,
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ deleted_at TIMESTAMPTZ -- nullable (NOTE(review): presumably a soft-delete marker, confirm)
+);
+-- No separate index on slug: it would only duplicate the index backing the UNIQUE constraint.
+
+-- users: user accounts (team membership is recorded in users_teams)
+CREATE TABLE users (
+ id UUID PRIMARY KEY,
+ email TEXT NOT NULL UNIQUE,
+ password_hash TEXT, -- nullable (NOTE(review): presumably NULL for OAuth-only accounts, confirm)
+ name TEXT NOT NULL DEFAULT '',
+ is_admin BOOLEAN NOT NULL DEFAULT FALSE,
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
+-- users_teams (junction): one row per (team, user) membership. The primary key leads with team_id, hence the extra user_id index.
+CREATE TABLE users_teams (
+ user_id UUID NOT NULL REFERENCES users(id),
+ team_id UUID NOT NULL REFERENCES teams(id),
+ is_default BOOLEAN NOT NULL DEFAULT FALSE,
+ role TEXT NOT NULL DEFAULT 'member',
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ PRIMARY KEY (team_id, user_id)
+);
+CREATE INDEX idx_users_teams_user ON users_teams(user_id);
+
+-- team_api_keys: API keys belonging to a team. The key itself is represented only by the key_hash and key_prefix columns.
+CREATE TABLE team_api_keys (
+ id UUID PRIMARY KEY,
+ team_id UUID NOT NULL REFERENCES teams(id),
+ name TEXT NOT NULL,
+ key_hash TEXT NOT NULL UNIQUE,
+ key_prefix TEXT NOT NULL,
+ created_by UUID NOT NULL REFERENCES users(id),
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ last_used TIMESTAMPTZ -- nullable (NOTE(review): presumably NULL until the key is first used, confirm)
+);
+CREATE INDEX idx_team_api_keys_team ON team_api_keys(team_id);
+
+-- oauth_providers: links an external identity (provider, provider_id) to a user, at most one row per external identity
+CREATE TABLE oauth_providers (
+ provider TEXT NOT NULL,
+ provider_id TEXT NOT NULL,
+ user_id UUID NOT NULL REFERENCES users(id),
+ email TEXT NOT NULL DEFAULT '',
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ PRIMARY KEY (provider, provider_id)
+);
+CREATE INDEX idx_oauth_providers_user ON oauth_providers(user_id);
+
+-- admin_permissions: permission strings granted to a user, at most one row per (user, permission)
+CREATE TABLE admin_permissions (
+ id UUID PRIMARY KEY,
+ user_id UUID NOT NULL REFERENCES users(id),
+ permission TEXT NOT NULL,
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ UNIQUE (user_id, permission)
+);
+CREATE INDEX idx_admin_permissions_user ON admin_permissions(user_id);
+
+-- hosts: machines running the host agent that are registered with the control plane
+CREATE TABLE hosts (
+ id UUID PRIMARY KEY,
+ type TEXT NOT NULL DEFAULT 'regular', -- e.g. 'regular' or 'byoc'
+ team_id UUID REFERENCES teams(id), -- nullable (NOTE(review): presumably set only for team-owned hosts, confirm)
+ provider TEXT NOT NULL DEFAULT '',
+ availability_zone TEXT NOT NULL DEFAULT '',
+ arch TEXT NOT NULL DEFAULT '', -- arch, cpu_cores, memory_mb and disk_gb default to empty/zero (NOTE(review): presumably filled from the specs the agent reports, confirm)
+ cpu_cores INTEGER NOT NULL DEFAULT 0,
+ memory_mb INTEGER NOT NULL DEFAULT 0,
+ disk_gb INTEGER NOT NULL DEFAULT 0,
+ address TEXT NOT NULL DEFAULT '', -- ip:port of the host agent
+ status TEXT NOT NULL DEFAULT 'pending', -- e.g. 'pending', 'online', 'offline', 'draining'
+ last_heartbeat_at TIMESTAMPTZ, -- nullable (NOTE(review): presumably NULL until the first heartbeat, confirm)
+ metadata JSONB NOT NULL DEFAULT '{}',
+ created_by UUID NOT NULL REFERENCES users(id),
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ cert_fingerprint TEXT NOT NULL DEFAULT '', -- '' when no certificate fingerprint is recorded (NOTE(review): fingerprint format not shown here)
+ mtls_enabled BOOLEAN NOT NULL DEFAULT FALSE
+);
+CREATE INDEX idx_hosts_type ON hosts(type);
+CREATE INDEX idx_hosts_team ON hosts(team_id);
+CREATE INDEX idx_hosts_status ON hosts(status);
+
+-- host_tokens: expiring tokens tied to a host, removed with the host (NOTE(review): presumably the one-time registration tokens, confirm)
+CREATE TABLE host_tokens (
+ id UUID PRIMARY KEY,
+ host_id UUID NOT NULL REFERENCES hosts(id) ON DELETE CASCADE,
+ created_by UUID NOT NULL REFERENCES users(id),
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ expires_at TIMESTAMPTZ NOT NULL,
+ used_at TIMESTAMPTZ -- nullable (NOTE(review): presumably NULL until the token is consumed, confirm)
+);
+CREATE INDEX idx_host_tokens_host ON host_tokens(host_id);
+
+-- host_tags: tags attached to a host, one row per (host, tag), removed with the host. The tag index serves lookups by tag alone.
+CREATE TABLE host_tags (
+ host_id UUID NOT NULL REFERENCES hosts(id) ON DELETE CASCADE,
+ tag TEXT NOT NULL,
+ PRIMARY KEY (host_id, tag)
+);
+CREATE INDEX idx_host_tags_tag ON host_tags(tag);
+
+-- host_refresh_tokens: refresh tokens for host agent JWT rotation (a host exchanges one for a new JWT plus a new refresh token)
+CREATE TABLE host_refresh_tokens (
+ id UUID PRIMARY KEY,
+ host_id UUID NOT NULL REFERENCES hosts(id) ON DELETE CASCADE,
+ token_hash TEXT NOT NULL UNIQUE, -- SHA-256 hex of the opaque token
+ expires_at TIMESTAMPTZ NOT NULL,
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ revoked_at TIMESTAMPTZ -- NULL = active, set on rotation or host delete
+);
+CREATE INDEX idx_host_refresh_tokens_host ON host_refresh_tokens(host_id);
+
+-- templates (TEXT primary key — not UUID): sandbox templates keyed by name
+CREATE TABLE templates (
+ name TEXT PRIMARY KEY,
+ type TEXT NOT NULL DEFAULT 'base', -- e.g. 'base' or 'snapshot'
+ vcpus INTEGER NOT NULL DEFAULT 1,
+ memory_mb INTEGER NOT NULL DEFAULT 512,
+ size_bytes BIGINT NOT NULL DEFAULT 0,
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ team_id UUID NOT NULL -- no foreign key to teams (NOTE(review): confirm this is intentional)
+);
+CREATE INDEX idx_templates_team ON templates(team_id);
+
+-- sandboxes: one row per sandbox with its owning team, host placement, resource sizing and lifecycle timestamps
+CREATE TABLE sandboxes (
+ id UUID PRIMARY KEY,
+ team_id UUID NOT NULL REFERENCES teams(id),
+ host_id UUID NOT NULL, -- no foreign key to hosts (NOTE(review): confirm this is intentional)
+ template TEXT NOT NULL DEFAULT 'minimal',
+ status TEXT NOT NULL DEFAULT 'pending',
+ vcpus INTEGER NOT NULL DEFAULT 1,
+ memory_mb INTEGER NOT NULL DEFAULT 512,
+ timeout_sec INTEGER NOT NULL DEFAULT 300,
+ disk_size_mb INTEGER NOT NULL DEFAULT 5120,
+ guest_ip TEXT NOT NULL DEFAULT '',
+ host_ip TEXT NOT NULL DEFAULT '',
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ started_at TIMESTAMPTZ,
+ last_active_at TIMESTAMPTZ,
+ last_updated TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
CREATE INDEX idx_sandboxes_status ON sandboxes(status);
CREATE INDEX idx_sandboxes_host_status ON sandboxes(host_id, status);
+CREATE INDEX idx_sandboxes_team ON sandboxes(team_id);
+
+-- audit_logs (id and team_id are UUID; actor_id and resource_id are TEXT for polymorphism). The (team_id, created_at DESC) index serves a newest-first team feed.
+CREATE TABLE audit_logs (
+ id UUID PRIMARY KEY,
+ team_id UUID NOT NULL,
+ actor_type TEXT NOT NULL, -- e.g. 'user', 'api_key', 'system'
+ actor_id TEXT, -- user_id or api_key_id, NULL for system
+ actor_name TEXT NOT NULL DEFAULT '', -- display name snapshotted at write time
+ resource_type TEXT NOT NULL, -- e.g. 'sandbox', 'snapshot', 'team', 'api_key', 'member', 'host'
+ resource_id TEXT, -- primary ID of the affected resource, NULL when not applicable
+ action TEXT NOT NULL, -- e.g. 'create', 'pause', 'resume', 'destroy', 'delete', 'rename', 'revoke'
+ scope TEXT NOT NULL DEFAULT 'team', -- e.g. 'team' or 'admin'
+ status TEXT NOT NULL DEFAULT 'success', -- e.g. 'success', 'info', 'warning', 'error'
+ metadata JSONB NOT NULL DEFAULT '{}',
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+CREATE INDEX idx_audit_logs_team_time ON audit_logs(team_id, created_at DESC);
+CREATE INDEX idx_audit_logs_team_resource ON audit_logs(team_id, resource_type, created_at DESC);
+
+-- sandbox_metrics_snapshots: per-team samples of the running sandbox count and the reserved vCPUs and memory
+CREATE TABLE sandbox_metrics_snapshots (
+ id BIGSERIAL PRIMARY KEY,
+ team_id UUID NOT NULL,
+ sampled_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ running_count INTEGER NOT NULL DEFAULT 0,
+ vcpus_reserved INTEGER NOT NULL DEFAULT 0,
+ memory_mb_reserved INTEGER NOT NULL DEFAULT 0
+);
+CREATE INDEX idx_metrics_snapshots_team_time ON sandbox_metrics_snapshots(team_id, sampled_at DESC);
+
+-- sandbox_metric_points: per-sandbox resource samples, one row per (sandbox, tier, ts)
+CREATE TABLE sandbox_metric_points (
+ sandbox_id UUID NOT NULL,
+ tier TEXT NOT NULL CHECK (tier IN ('10m', '2h', '24h')),
+ ts BIGINT NOT NULL, -- NOTE(review): unit not shown here (presumably a Unix timestamp), confirm
+ cpu_pct FLOAT8 NOT NULL DEFAULT 0,
+ mem_bytes BIGINT NOT NULL DEFAULT 0,
+ disk_bytes BIGINT NOT NULL DEFAULT 0,
+ PRIMARY KEY (sandbox_id, tier, ts)
+);
+-- No separate (sandbox_id, tier) index: the primary key's leading columns already serve those lookups.
+
+-- template_builds: one row per template build, holding its recipe, step progress, logs, error text and timestamps
+CREATE TABLE template_builds (
+ id UUID PRIMARY KEY,
+ name TEXT NOT NULL,
+ base_template TEXT NOT NULL,
+ recipe JSONB NOT NULL DEFAULT '[]', -- JSON array, empty by default
+ healthcheck TEXT NOT NULL DEFAULT '',
+ vcpus INTEGER NOT NULL DEFAULT 1,
+ memory_mb INTEGER NOT NULL DEFAULT 512,
+ status TEXT NOT NULL DEFAULT 'pending',
+ current_step INTEGER NOT NULL DEFAULT 0,
+ total_steps INTEGER NOT NULL DEFAULT 0,
+ logs JSONB NOT NULL DEFAULT '[]', -- JSON array, empty by default
+ error TEXT NOT NULL DEFAULT '',
+ sandbox_id UUID, -- nullable, no foreign key (NOTE(review): presumably the sandbox used to run the build, confirm)
+ host_id UUID, -- nullable, no foreign key
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ started_at TIMESTAMPTZ,
+ completed_at TIMESTAMPTZ
+);
-- +goose Down
-
-DROP TABLE sandboxes;
+DROP TABLE IF EXISTS template_builds; -- tables are dropped in exact reverse creation order, so referencing tables go before the tables they reference
+DROP TABLE IF EXISTS sandbox_metric_points;
+DROP TABLE IF EXISTS sandbox_metrics_snapshots;
+DROP TABLE IF EXISTS audit_logs;
+DROP TABLE IF EXISTS sandboxes;
+DROP TABLE IF EXISTS templates;
+DROP TABLE IF EXISTS host_refresh_tokens;
+DROP TABLE IF EXISTS host_tags;
+DROP TABLE IF EXISTS host_tokens;
+DROP TABLE IF EXISTS hosts;
+DROP TABLE IF EXISTS admin_permissions;
+DROP TABLE IF EXISTS oauth_providers;
+DROP TABLE IF EXISTS team_api_keys;
+DROP TABLE IF EXISTS users_teams;
+DROP TABLE IF EXISTS users;
+DROP TABLE IF EXISTS teams;
diff --git a/db/migrations/20260311224925_snapshots.sql b/db/migrations/20260311224925_snapshots.sql
deleted file mode 100644
index 8a0427c2..00000000
--- a/db/migrations/20260311224925_snapshots.sql
+++ /dev/null
@@ -1,14 +0,0 @@
--- +goose Up
-
-CREATE TABLE templates (
- name TEXT PRIMARY KEY,
- type TEXT NOT NULL DEFAULT 'base', -- 'base' or 'snapshot'
- vcpus INTEGER,
- memory_mb INTEGER,
- size_bytes BIGINT NOT NULL DEFAULT 0,
- created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
-);
-
--- +goose Down
-
-DROP TABLE templates;
diff --git a/db/migrations/20260313210608_auth.sql b/db/migrations/20260313210608_auth.sql
deleted file mode 100644
index 03970a8f..00000000
--- a/db/migrations/20260313210608_auth.sql
+++ /dev/null
@@ -1,46 +0,0 @@
--- +goose Up
-
-CREATE TABLE users (
- id TEXT PRIMARY KEY,
- email TEXT NOT NULL UNIQUE,
- password_hash TEXT NOT NULL,
- created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
- updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
-);
-
-CREATE TABLE teams (
- id TEXT PRIMARY KEY,
- name TEXT NOT NULL,
- created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
-);
-
-CREATE TABLE users_teams (
- user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
- team_id TEXT NOT NULL REFERENCES teams(id) ON DELETE CASCADE,
- is_default BOOLEAN NOT NULL DEFAULT TRUE,
- role TEXT NOT NULL DEFAULT 'owner',
- created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
- PRIMARY KEY (team_id, user_id)
-);
-
-CREATE INDEX idx_users_teams_user ON users_teams(user_id);
-
-CREATE TABLE team_api_keys (
- id TEXT PRIMARY KEY,
- team_id TEXT NOT NULL REFERENCES teams(id) ON DELETE CASCADE,
- name TEXT NOT NULL DEFAULT '',
- key_hash TEXT NOT NULL UNIQUE,
- key_prefix TEXT NOT NULL DEFAULT '',
- created_by TEXT NOT NULL REFERENCES users(id),
- created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
- last_used TIMESTAMPTZ
-);
-
-CREATE INDEX idx_team_api_keys_team ON team_api_keys(team_id);
-
--- +goose Down
-
-DROP TABLE team_api_keys;
-DROP TABLE users_teams;
-DROP TABLE teams;
-DROP TABLE users;
diff --git a/db/migrations/20260313210611_team_ownership.sql b/db/migrations/20260313210611_team_ownership.sql
deleted file mode 100644
index 849e781e..00000000
--- a/db/migrations/20260313210611_team_ownership.sql
+++ /dev/null
@@ -1,31 +0,0 @@
--- +goose Up
-
-ALTER TABLE sandboxes
- ADD COLUMN team_id TEXT NOT NULL DEFAULT '';
-
-UPDATE sandboxes SET team_id = owner_id WHERE owner_id != '';
-
-ALTER TABLE sandboxes
- DROP COLUMN owner_id;
-
-ALTER TABLE templates
- ADD COLUMN team_id TEXT NOT NULL DEFAULT '';
-
-CREATE INDEX idx_sandboxes_team ON sandboxes(team_id);
-CREATE INDEX idx_templates_team ON templates(team_id);
-
--- +goose Down
-
-ALTER TABLE sandboxes
- ADD COLUMN owner_id TEXT NOT NULL DEFAULT '';
-
-UPDATE sandboxes SET owner_id = team_id WHERE team_id != '';
-
-ALTER TABLE sandboxes
- DROP COLUMN team_id;
-
-ALTER TABLE templates
- DROP COLUMN team_id;
-
-DROP INDEX IF EXISTS idx_sandboxes_team;
-DROP INDEX IF EXISTS idx_templates_team;
diff --git a/db/migrations/20260315001514_oauth.sql b/db/migrations/20260315001514_oauth.sql
deleted file mode 100644
index c3c33e9b..00000000
--- a/db/migrations/20260315001514_oauth.sql
+++ /dev/null
@@ -1,22 +0,0 @@
--- +goose Up
-
-ALTER TABLE users
- ALTER COLUMN password_hash DROP NOT NULL;
-
-CREATE TABLE oauth_providers (
- provider TEXT NOT NULL,
- provider_id TEXT NOT NULL,
- user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
- email TEXT NOT NULL,
- created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
- PRIMARY KEY (provider, provider_id)
-);
-
-CREATE INDEX idx_oauth_providers_user ON oauth_providers(user_id);
-
--- +goose Down
-
-DROP TABLE oauth_providers;
-
-UPDATE users SET password_hash = '' WHERE password_hash IS NULL;
-ALTER TABLE users ALTER COLUMN password_hash SET NOT NULL;
diff --git a/db/migrations/20260316203135_admin_users.sql b/db/migrations/20260316203135_admin_users.sql
deleted file mode 100644
index eff669b5..00000000
--- a/db/migrations/20260316203135_admin_users.sql
+++ /dev/null
@@ -1,21 +0,0 @@
--- +goose Up
-
-ALTER TABLE users
- ADD COLUMN is_admin BOOLEAN NOT NULL DEFAULT FALSE;
-
-CREATE TABLE admin_permissions (
- id TEXT PRIMARY KEY,
- user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
- permission TEXT NOT NULL,
- created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
- UNIQUE (user_id, permission)
-);
-
-CREATE INDEX idx_admin_permissions_user ON admin_permissions(user_id);
-
--- +goose Down
-
-DROP TABLE admin_permissions;
-
-ALTER TABLE users
- DROP COLUMN is_admin;
diff --git a/db/migrations/20260316203138_byoc_teams.sql b/db/migrations/20260316203138_byoc_teams.sql
deleted file mode 100644
index bb2c8ec2..00000000
--- a/db/migrations/20260316203138_byoc_teams.sql
+++ /dev/null
@@ -1,9 +0,0 @@
--- +goose Up
-
-ALTER TABLE teams
- ADD COLUMN is_byoc BOOLEAN NOT NULL DEFAULT FALSE;
-
--- +goose Down
-
-ALTER TABLE teams
- DROP COLUMN is_byoc;
diff --git a/db/migrations/20260316203142_hosts.sql b/db/migrations/20260316203142_hosts.sql
deleted file mode 100644
index 372b3802..00000000
--- a/db/migrations/20260316203142_hosts.sql
+++ /dev/null
@@ -1,47 +0,0 @@
--- +goose Up
-
-CREATE TABLE hosts (
- id TEXT PRIMARY KEY,
- type TEXT NOT NULL DEFAULT 'regular', -- 'regular' or 'byoc'
- team_id TEXT REFERENCES teams(id) ON DELETE SET NULL,
- provider TEXT,
- availability_zone TEXT,
- arch TEXT,
- cpu_cores INTEGER,
- memory_mb INTEGER,
- disk_gb INTEGER,
- address TEXT, -- ip:port of host agent
- status TEXT NOT NULL DEFAULT 'pending', -- 'pending', 'online', 'offline', 'draining'
- last_heartbeat_at TIMESTAMPTZ,
- metadata JSONB NOT NULL DEFAULT '{}',
- created_by TEXT NOT NULL REFERENCES users(id),
- created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
- updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
-);
-
-CREATE TABLE host_tokens (
- id TEXT PRIMARY KEY,
- host_id TEXT NOT NULL REFERENCES hosts(id) ON DELETE CASCADE,
- created_by TEXT NOT NULL REFERENCES users(id),
- created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
- expires_at TIMESTAMPTZ NOT NULL,
- used_at TIMESTAMPTZ
-);
-
-CREATE TABLE host_tags (
- host_id TEXT NOT NULL REFERENCES hosts(id) ON DELETE CASCADE,
- tag TEXT NOT NULL,
- PRIMARY KEY (host_id, tag)
-);
-
-CREATE INDEX idx_hosts_type ON hosts(type);
-CREATE INDEX idx_hosts_team ON hosts(team_id);
-CREATE INDEX idx_hosts_status ON hosts(status);
-CREATE INDEX idx_host_tokens_host ON host_tokens(host_id);
-CREATE INDEX idx_host_tags_tag ON host_tags(tag);
-
--- +goose Down
-
-DROP TABLE host_tags;
-DROP TABLE host_tokens;
-DROP TABLE hosts;
diff --git a/db/migrations/20260316223629_host_mtls.sql b/db/migrations/20260316223629_host_mtls.sql
deleted file mode 100644
index f56b923e..00000000
--- a/db/migrations/20260316223629_host_mtls.sql
+++ /dev/null
@@ -1,11 +0,0 @@
--- +goose Up
-
-ALTER TABLE hosts
- ADD COLUMN cert_fingerprint TEXT,
- ADD COLUMN mtls_enabled BOOLEAN NOT NULL DEFAULT FALSE;
-
--- +goose Down
-
-ALTER TABLE hosts
- DROP COLUMN cert_fingerprint,
- DROP COLUMN mtls_enabled;
diff --git a/db/migrations/20260324071453_team_management.sql b/db/migrations/20260324071453_team_management.sql
deleted file mode 100644
index 1495d6dd..00000000
--- a/db/migrations/20260324071453_team_management.sql
+++ /dev/null
@@ -1,17 +0,0 @@
--- +goose Up
-
-ALTER TABLE teams ADD COLUMN slug TEXT;
-ALTER TABLE teams ADD COLUMN deleted_at TIMESTAMPTZ;
-
--- Backfill slugs for existing teams using MD5 of their ID.
--- MD5 returns 32 hex chars; take chars 1-6 and 7-12 to form a 6-6 slug.
-UPDATE teams SET slug = LEFT(MD5(id), 6) || '-' || SUBSTRING(MD5(id), 7, 6);
-
-ALTER TABLE teams ALTER COLUMN slug SET NOT NULL;
-CREATE UNIQUE INDEX idx_teams_slug ON teams(slug);
-
--- +goose Down
-
-DROP INDEX idx_teams_slug;
-ALTER TABLE teams DROP COLUMN deleted_at;
-ALTER TABLE teams DROP COLUMN slug;
diff --git a/db/migrations/20260324100234_user_names.sql b/db/migrations/20260324100234_user_names.sql
deleted file mode 100644
index 2775d12d..00000000
--- a/db/migrations/20260324100234_user_names.sql
+++ /dev/null
@@ -1,5 +0,0 @@
--- +goose Up
-ALTER TABLE users ADD COLUMN name TEXT NOT NULL DEFAULT '';
-
--- +goose Down
-ALTER TABLE users DROP COLUMN name;
diff --git a/db/migrations/20260324120214_host_refresh_tokens.sql b/db/migrations/20260324120214_host_refresh_tokens.sql
deleted file mode 100644
index 02a13f74..00000000
--- a/db/migrations/20260324120214_host_refresh_tokens.sql
+++ /dev/null
@@ -1,19 +0,0 @@
--- +goose Up
-
--- Refresh tokens for host agent JWT rotation.
--- Hosts exchange a refresh token for a new short-lived JWT + new refresh token (rotation).
--- Refresh tokens expire after 60 days; hosts must re-register with a new one-time token after that.
-CREATE TABLE host_refresh_tokens (
- id TEXT PRIMARY KEY,
- host_id TEXT NOT NULL REFERENCES hosts(id) ON DELETE CASCADE,
- token_hash TEXT NOT NULL UNIQUE, -- SHA-256 hex of the opaque token
- expires_at TIMESTAMPTZ NOT NULL,
- created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
- revoked_at TIMESTAMPTZ -- NULL = active; set on rotation or host delete
-);
-
-CREATE INDEX idx_host_refresh_tokens_host ON host_refresh_tokens(host_id);
-
--- +goose Down
-
-DROP TABLE host_refresh_tokens;
diff --git a/db/migrations/20260324220743_audit_logs.sql b/db/migrations/20260324220743_audit_logs.sql
deleted file mode 100644
index 91b73754..00000000
--- a/db/migrations/20260324220743_audit_logs.sql
+++ /dev/null
@@ -1,28 +0,0 @@
--- +goose Up
-
-CREATE TABLE audit_logs (
- id TEXT PRIMARY KEY,
- team_id TEXT NOT NULL,
- actor_type TEXT NOT NULL, -- 'user', 'api_key', 'system'
- actor_id TEXT, -- user_id or api_key_id; NULL for system
- actor_name TEXT, -- display name snapshotted at write time; NULL for system
- resource_type TEXT NOT NULL, -- 'sandbox', 'snapshot', 'team', 'api_key', 'member', 'host'
- resource_id TEXT, -- primary ID of the affected resource; NULL when not applicable
- action TEXT NOT NULL, -- 'create', 'pause', 'resume', 'destroy', 'delete', 'rename',
- -- 'revoke', 'add', 'remove', 'leave', 'role_update',
- -- 'marked_down', 'marked_up'
- scope TEXT NOT NULL, -- 'team' or 'admin'
- status TEXT NOT NULL, -- 'success', 'info', 'warning', 'error'
- metadata JSONB NOT NULL DEFAULT '{}',
- created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
-);
-
--- Primary access pattern: team feed sorted newest-first with cursor pagination.
-CREATE INDEX idx_audit_logs_team_time ON audit_logs (team_id, created_at DESC);
-
--- Secondary index: filtered by resource_type and action within a team.
-CREATE INDEX idx_audit_logs_team_resource ON audit_logs (team_id, resource_type, action, created_at DESC);
-
--- +goose Down
-
-DROP TABLE audit_logs;
diff --git a/db/migrations/20260325074949_metrics_snapshots.sql b/db/migrations/20260325074949_metrics_snapshots.sql
deleted file mode 100644
index 7d373e86..00000000
--- a/db/migrations/20260325074949_metrics_snapshots.sql
+++ /dev/null
@@ -1,18 +0,0 @@
--- +goose Up
-
-CREATE TABLE sandbox_metrics_snapshots (
- id BIGSERIAL PRIMARY KEY,
- team_id TEXT NOT NULL,
- sampled_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
- running_count INTEGER NOT NULL,
- vcpus_reserved INTEGER NOT NULL,
- memory_mb_reserved INTEGER NOT NULL
-);
-
--- All queries filter on team_id first then range-scan sampled_at.
-CREATE INDEX idx_metrics_snapshots_team_time
- ON sandbox_metrics_snapshots (team_id, sampled_at DESC);
-
--- +goose Down
-
-DROP TABLE sandbox_metrics_snapshots;
diff --git a/db/migrations/20260325135035_add_sandbox_metric_points.sql b/db/migrations/20260325135035_add_sandbox_metric_points.sql
deleted file mode 100644
index 08e86835..00000000
--- a/db/migrations/20260325135035_add_sandbox_metric_points.sql
+++ /dev/null
@@ -1,16 +0,0 @@
--- +goose Up
-CREATE TABLE sandbox_metric_points (
- sandbox_id TEXT NOT NULL,
- tier TEXT NOT NULL CHECK (tier IN ('10m', '2h', '24h')),
- ts BIGINT NOT NULL,
- cpu_pct FLOAT8 NOT NULL DEFAULT 0,
- mem_bytes BIGINT NOT NULL DEFAULT 0,
- disk_bytes BIGINT NOT NULL DEFAULT 0,
- PRIMARY KEY (sandbox_id, tier, ts)
-);
-
-CREATE INDEX idx_sandbox_metric_points_sandbox_tier
- ON sandbox_metric_points (sandbox_id, tier);
-
--- +goose Down
-DROP TABLE IF EXISTS sandbox_metric_points;
diff --git a/db/migrations/20260328162803_template_uuid_pk.sql b/db/migrations/20260328162803_template_uuid_pk.sql
new file mode 100644
index 00000000..0bb65668
--- /dev/null
+++ b/db/migrations/20260328162803_template_uuid_pk.sql
@@ -0,0 +1,82 @@
+-- +goose Up
+
+-- 1. Add UUID id column to templates and make it the primary key.
+ALTER TABLE templates ADD COLUMN id UUID DEFAULT gen_random_uuid();
+UPDATE templates SET id = gen_random_uuid() WHERE id IS NULL;
+ALTER TABLE templates ALTER COLUMN id SET NOT NULL;
+ALTER TABLE templates DROP CONSTRAINT templates_pkey;
+ALTER TABLE templates ADD PRIMARY KEY (id);
+
+-- 2. Name becomes a display field with team-scoped uniqueness.
+ALTER TABLE templates ADD CONSTRAINT uq_templates_team_name UNIQUE (team_id, name);
+
+-- 3. Prevent team templates from using names that belong to global (platform) templates.
+-- A team template insert/update with a name matching any platform template is rejected.
+-- +goose StatementBegin
+CREATE OR REPLACE FUNCTION check_global_template_name_collision()
+RETURNS TRIGGER AS $$
+BEGIN
+ IF NEW.team_id != '00000000-0000-0000-0000-000000000000' THEN
+ IF EXISTS (
+ SELECT 1 FROM templates
+ WHERE name = NEW.name
+ AND team_id = '00000000-0000-0000-0000-000000000000'
+ ) THEN
+ RAISE EXCEPTION 'template name "%" is reserved by a global template', NEW.name
+ USING ERRCODE = 'unique_violation';
+ END IF;
+ END IF;
+ RETURN NEW;
+END;
+$$ LANGUAGE plpgsql;
+-- +goose StatementEnd
+
+CREATE TRIGGER trg_check_global_template_name
+ BEFORE INSERT OR UPDATE ON templates
+ FOR EACH ROW
+ EXECUTE FUNCTION check_global_template_name_collision();
+
+-- 4. Seed the built-in "minimal" template so it appears in all listings.
+-- Both id and team_id are the all-zeros UUID (platform sentinel).
+INSERT INTO templates (id, name, type, vcpus, memory_mb, size_bytes, team_id)
+VALUES (
+ '00000000-0000-0000-0000-000000000000',
+ 'minimal',
+ 'base',
+ 1,
+ 512,
+ 0,
+ '00000000-0000-0000-0000-000000000000'
+) ON CONFLICT DO NOTHING;
+
+-- 5. Add template UUID references to template_builds.
+ALTER TABLE template_builds
+ ADD COLUMN template_id UUID,
+ ADD COLUMN team_id UUID;
+
+-- 6. Add template UUID references to sandboxes.
+ALTER TABLE sandboxes
+ ADD COLUMN template_id UUID,
+ ADD COLUMN template_team_id UUID;
+
+-- +goose Down
+
+ALTER TABLE sandboxes
+ DROP COLUMN IF EXISTS template_team_id,
+ DROP COLUMN IF EXISTS template_id;
+
+ALTER TABLE template_builds
+ DROP COLUMN IF EXISTS team_id,
+ DROP COLUMN IF EXISTS template_id;
+
+-- Remove the seeded minimal template.
+DELETE FROM templates WHERE id = '00000000-0000-0000-0000-000000000000';
+
+DROP TRIGGER IF EXISTS trg_check_global_template_name ON templates;
+DROP FUNCTION IF EXISTS check_global_template_name_collision();
+
+ALTER TABLE templates DROP CONSTRAINT IF EXISTS uq_templates_team_name;
+
+ALTER TABLE templates DROP CONSTRAINT IF EXISTS templates_pkey;
+ALTER TABLE templates ADD PRIMARY KEY (name);
+ALTER TABLE templates DROP COLUMN IF EXISTS id;
diff --git a/db/migrations/20260330112050_mtls_cert_expiry.sql b/db/migrations/20260330112050_mtls_cert_expiry.sql
new file mode 100644
index 00000000..e7245d27
--- /dev/null
+++ b/db/migrations/20260330112050_mtls_cert_expiry.sql
@@ -0,0 +1,7 @@
+-- +goose Up
+ALTER TABLE hosts DROP COLUMN mtls_enabled;
+ALTER TABLE hosts ADD COLUMN cert_expires_at TIMESTAMPTZ;
+
+-- +goose Down
+ALTER TABLE hosts DROP COLUMN cert_expires_at;
+ALTER TABLE hosts ADD COLUMN mtls_enabled BOOLEAN NOT NULL DEFAULT FALSE;
diff --git a/db/migrations/20260330150223_build_options.sql b/db/migrations/20260330150223_build_options.sql
new file mode 100644
index 00000000..981ad065
--- /dev/null
+++ b/db/migrations/20260330150223_build_options.sql
@@ -0,0 +1,11 @@
+-- +goose Up
+
+-- Allow completed_at to be set when a build is cancelled.
+-- (The UpdateBuildStatus query is updated in sqlc; no schema change needed for that.)
+
+-- Add skip_pre_post flag: when true, the pre-build and post-build command phases
+-- are skipped for this build.
+ALTER TABLE template_builds ADD COLUMN skip_pre_post BOOLEAN NOT NULL DEFAULT FALSE;
+
+-- +goose Down
+ALTER TABLE template_builds DROP COLUMN skip_pre_post;
diff --git a/db/queries/hosts.sql b/db/queries/hosts.sql
index 27ece000..0a5a150b 100644
--- a/db/queries/hosts.sql
+++ b/db/queries/hosts.sql
@@ -20,16 +20,25 @@ SELECT * FROM hosts WHERE status = $1 ORDER BY created_at DESC;
-- name: RegisterHost :execrows
UPDATE hosts
-SET arch = $2,
- cpu_cores = $3,
- memory_mb = $4,
- disk_gb = $5,
- address = $6,
- status = 'online',
+SET arch = $2,
+ cpu_cores = $3,
+ memory_mb = $4,
+ disk_gb = $5,
+ address = $6,
+ cert_fingerprint = $7,
+ cert_expires_at = $8,
+ status = 'online',
last_heartbeat_at = NOW(),
- updated_at = NOW()
+ updated_at = NOW()
WHERE id = $1 AND status = 'pending';
+-- name: UpdateHostCert :exec
+UPDATE hosts
+SET cert_fingerprint = $2,
+ cert_expires_at = $3,
+ updated_at = NOW()
+WHERE id = $1;
+
-- name: UpdateHostStatus :exec
UPDATE hosts SET status = $2, updated_at = NOW() WHERE id = $1;
diff --git a/db/queries/sandboxes.sql b/db/queries/sandboxes.sql
index 131fe1ed..2b195744 100644
--- a/db/queries/sandboxes.sql
+++ b/db/queries/sandboxes.sql
@@ -1,6 +1,6 @@
-- name: InsertSandbox :one
-INSERT INTO sandboxes (id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec)
-VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
+INSERT INTO sandboxes (id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, template_id, template_team_id)
+VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
RETURNING *;
-- name: GetSandbox :one
@@ -9,6 +9,14 @@ SELECT * FROM sandboxes WHERE id = $1;
-- name: GetSandboxByTeam :one
SELECT * FROM sandboxes WHERE id = $1 AND team_id = $2;
+-- name: GetSandboxProxyTarget :one
+-- Returns the sandbox status and its host's address in one query.
+-- Used by SandboxProxyWrapper to avoid two round-trips.
+SELECT s.status, h.address AS host_address
+FROM sandboxes s
+JOIN hosts h ON h.id = s.host_id
+WHERE s.id = $1 AND s.team_id = $2;
+
-- name: ListSandboxes :many
SELECT * FROM sandboxes ORDER BY created_at DESC;
@@ -50,7 +58,7 @@ WHERE id = $1;
UPDATE sandboxes
SET status = $2,
last_updated = NOW()
-WHERE id = ANY($1::text[]);
+WHERE id = ANY($1::uuid[]);
-- name: ListActiveSandboxesByTeam :many
SELECT * FROM sandboxes
@@ -72,4 +80,4 @@ WHERE host_id = $1 AND status IN ('running', 'starting', 'pending');
UPDATE sandboxes
SET status = 'running',
last_updated = NOW()
-WHERE id = ANY($1::text[]) AND status = 'missing';
+WHERE id = ANY($1::uuid[]) AND status = 'missing';
diff --git a/db/queries/template_builds.sql b/db/queries/template_builds.sql
new file mode 100644
index 00000000..1fb07be3
--- /dev/null
+++ b/db/queries/template_builds.sql
@@ -0,0 +1,33 @@
+-- name: InsertTemplateBuild :one
+INSERT INTO template_builds (id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, total_steps, template_id, team_id, skip_pre_post)
+VALUES ($1, $2, $3, $4, $5, $6, $7, 'pending', $8, $9, $10, $11)
+RETURNING *;
+
+-- name: GetTemplateBuild :one
+SELECT * FROM template_builds WHERE id = $1;
+
+-- name: ListTemplateBuilds :many
+SELECT * FROM template_builds ORDER BY created_at DESC;
+
+-- name: UpdateBuildStatus :one
+UPDATE template_builds
+SET status = $2,
+ started_at = CASE WHEN $2 = 'running' AND started_at IS NULL THEN NOW() ELSE started_at END,
+ completed_at = CASE WHEN $2 IN ('success', 'failed', 'cancelled') THEN NOW() ELSE completed_at END
+WHERE id = $1
+RETURNING *;
+
+-- name: UpdateBuildProgress :exec
+UPDATE template_builds
+SET current_step = $2, logs = $3
+WHERE id = $1;
+
+-- name: UpdateBuildSandbox :exec
+UPDATE template_builds
+SET sandbox_id = $2, host_id = $3
+WHERE id = $1;
+
+-- name: UpdateBuildError :exec
+UPDATE template_builds
+SET error = $2, status = 'failed', completed_at = NOW()
+WHERE id = $1;
diff --git a/db/queries/templates.sql b/db/queries/templates.sql
index b17abc3a..de4d6f2a 100644
--- a/db/queries/templates.sql
+++ b/db/queries/templates.sql
@@ -1,13 +1,22 @@
-- name: InsertTemplate :one
-INSERT INTO templates (name, type, vcpus, memory_mb, size_bytes, team_id)
-VALUES ($1, $2, $3, $4, $5, $6)
+INSERT INTO templates (id, name, type, vcpus, memory_mb, size_bytes, team_id)
+VALUES ($1, $2, $3, $4, $5, $6, $7)
RETURNING *;
-- name: GetTemplate :one
-SELECT * FROM templates WHERE name = $1;
+SELECT * FROM templates WHERE id = $1;
-- name: GetTemplateByTeam :one
-SELECT * FROM templates WHERE name = $1 AND team_id = $2;
+-- Platform templates (team_id = 00000000-...) are visible to all teams.
+SELECT * FROM templates WHERE name = $1 AND (team_id = $2 OR team_id = '00000000-0000-0000-0000-000000000000');
+
+-- name: GetTemplateByName :one
+-- Look up a template by team_id and name (exact team match, no global fallback).
+SELECT * FROM templates WHERE team_id = $1 AND name = $2;
+
+-- name: GetPlatformTemplateByName :one
+-- Check if a global (platform) template exists with the given name.
+SELECT * FROM templates WHERE team_id = '00000000-0000-0000-0000-000000000000' AND name = $1;
-- name: ListTemplates :many
SELECT * FROM templates ORDER BY created_at DESC;
@@ -16,13 +25,23 @@ SELECT * FROM templates ORDER BY created_at DESC;
SELECT * FROM templates WHERE type = $1 ORDER BY created_at DESC;
-- name: ListTemplatesByTeam :many
-SELECT * FROM templates WHERE team_id = $1 ORDER BY created_at DESC;
+-- Platform templates are visible to all teams.
+SELECT * FROM templates WHERE (team_id = $1 OR team_id = '00000000-0000-0000-0000-000000000000') ORDER BY created_at DESC;
-- name: ListTemplatesByTeamAndType :many
-SELECT * FROM templates WHERE team_id = $1 AND type = $2 ORDER BY created_at DESC;
+-- Platform templates are visible to all teams.
+SELECT * FROM templates WHERE (team_id = $1 OR team_id = '00000000-0000-0000-0000-000000000000') AND type = $2 ORDER BY created_at DESC;
-- name: DeleteTemplate :exec
-DELETE FROM templates WHERE name = $1;
+DELETE FROM templates WHERE id = $1;
-- name: DeleteTemplateByTeam :exec
DELETE FROM templates WHERE name = $1 AND team_id = $2;
+
+-- name: DeleteTemplatesByTeam :exec
+-- Bulk delete all templates owned by a team (for team soft-delete cleanup).
+DELETE FROM templates WHERE team_id = $1;
+
+-- name: ListTemplatesByTeamOnly :many
+-- List templates owned by a specific team (NOT including platform templates).
+SELECT * FROM templates WHERE team_id = $1 ORDER BY created_at DESC;
diff --git a/deploy/Caddyfile.dev b/deploy/Caddyfile.dev
new file mode 100644
index 00000000..789f8dfd
--- /dev/null
+++ b/deploy/Caddyfile.dev
@@ -0,0 +1,41 @@
+# Sandbox port forwarding: {port}-{sandbox_id}.localhost
+# Matches subdomains like 49999-sb-abcd1234.localhost and proxies them
+# to the control plane, which inspects the Host header and routes to
+# the correct host agent.
+#
+# NOTE: Wildcard *.localhost DNS resolution requires local setup.
+# Option 1: Add entries to /etc/hosts for each sandbox
+# Option 2: Use dnsmasq: address=/.localhost/127.0.0.1
+# Option 3: Use systemd-resolved (Ubuntu default — *.localhost resolves to 127.0.0.1)
+http://*.localhost {
+ reverse_proxy host.docker.internal:8080
+}
+
+# Main entry point: API + frontend
+http://localhost {
+ # API routes — strip /api prefix and proxy to the control plane.
+ # The frontend calls /api/v1/... which becomes /v1/... at the CP.
+ handle_path /api/* {
+ reverse_proxy host.docker.internal:8080
+ }
+
+ # Backend routes served directly (SDK clients, OAuth initiation)
+ handle /v1/* {
+ reverse_proxy host.docker.internal:8080
+ }
+ handle /openapi.yaml {
+ reverse_proxy host.docker.internal:8080
+ }
+ handle /docs {
+ reverse_proxy host.docker.internal:8080
+ }
+ handle /auth/oauth/* {
+ reverse_proxy host.docker.internal:8080
+ }
+
+ # Everything else — proxy to the frontend dev server
+ # This includes: /login, /dashboard/*, /admin/*, /auth/github/callback
+ handle {
+ reverse_proxy host.docker.internal:5173
+ }
+}
diff --git a/deploy/docker-compose.dev.yml b/deploy/docker-compose.dev.yml
index ebcd3087..28f401f6 100644
--- a/deploy/docker-compose.dev.yml
+++ b/deploy/docker-compose.dev.yml
@@ -15,19 +15,14 @@ services:
ports:
- "6379:6379"
- prometheus:
- image: prom/prometheus:latest
+ caddy:
+ image: caddy:2-alpine
ports:
- - "9090:9090"
+ - "8000:80"
volumes:
- - ./deploy/prometheus.yml:/etc/prometheus/prometheus.yml
-
- grafana:
- image: grafana/grafana:latest
- ports:
- - "3001:3000"
- environment:
- GF_SECURITY_ADMIN_PASSWORD: admin
+ - ./Caddyfile.dev:/etc/caddy/Caddyfile:ro
+ extra_hosts:
+ - "host.docker.internal:host-gateway"
volumes:
pgdata:
diff --git a/envd/internal/api/init.go b/envd/internal/api/init.go
index a4894599..301400cb 100644
--- a/envd/internal/api/init.go
+++ b/envd/internal/api/init.go
@@ -14,14 +14,12 @@ import (
"os/exec"
"time"
- "github.com/awnumar/memguard"
- "github.com/rs/zerolog"
- "github.com/txn2/txeh"
- "golang.org/x/sys/unix"
-
"git.omukk.dev/wrenn/sandbox/envd/internal/host"
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/keys"
+ "github.com/awnumar/memguard"
+ "github.com/rs/zerolog"
+ "github.com/txn2/txeh"
)
var (
@@ -29,11 +27,6 @@ var (
ErrAccessTokenResetNotAuthorized = errors.New("access token reset not authorized")
)
-const (
- maxTimeInPast = 50 * time.Millisecond
- maxTimeInFuture = 5 * time.Second
-)
-
// validateInitAccessToken validates the access token for /init requests.
// Token is valid if it matches the existing token OR the MMDS hash.
// If neither exists, first-time setup is allowed.
@@ -172,20 +165,6 @@ func (a *API) SetData(ctx context.Context, logger zerolog.Logger, data PostInitJ
return err
}
- if data.Timestamp != nil {
- // Check if current time differs significantly from the received timestamp
- if shouldSetSystemTime(time.Now(), *data.Timestamp) {
- logger.Debug().Msgf("Setting sandbox start time to: %v", *data.Timestamp)
- ts := unix.NsecToTimespec(data.Timestamp.UnixNano())
- err := unix.ClockSettime(unix.CLOCK_REALTIME, &ts)
- if err != nil {
- logger.Error().Msgf("Failed to set system time: %v", err)
- }
- } else {
- logger.Debug().Msgf("Current time is within acceptable range of timestamp %v, not setting system time", *data.Timestamp)
- }
- }
-
if data.EnvVars != nil {
logger.Debug().Msg(fmt.Sprintf("Setting %d env vars", len(*data.EnvVars)))
@@ -308,10 +287,3 @@ func getIPFamily(address string) (txeh.IPFamily, error) {
return txeh.IPFamilyV4, fmt.Errorf("%w: %s", ErrUnknownAddressFormat, address)
}
}
-
-// shouldSetSystemTime returns true if the current time differs significantly from the received timestamp,
-// indicating the system clock should be adjusted. Returns true when the sandboxTime is more than
-// maxTimeInPast before the hostTime or more than maxTimeInFuture after the hostTime.
-func shouldSetSystemTime(sandboxTime, hostTime time.Time) bool {
- return sandboxTime.Before(hostTime.Add(-maxTimeInPast)) || sandboxTime.After(hostTime.Add(maxTimeInFuture))
-}
diff --git a/envd/internal/api/init_test.go b/envd/internal/api/init_test.go
index c4b6f4b5..f3db3615 100644
--- a/envd/internal/api/init_test.go
+++ b/envd/internal/api/init_test.go
@@ -9,7 +9,6 @@ import (
"path/filepath"
"strings"
"testing"
- "time"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
@@ -59,71 +58,6 @@ func TestSimpleCases(t *testing.T) {
}
}
-func TestShouldSetSystemTime(t *testing.T) {
- t.Parallel()
- sandboxTime := time.Now()
-
- tests := []struct {
- name string
- hostTime time.Time
- want bool
- }{
- {
- name: "sandbox time far ahead of host time (should set)",
- hostTime: sandboxTime.Add(-10 * time.Second),
- want: true,
- },
- {
- name: "sandbox time at maxTimeInPast boundary ahead of host time (should not set)",
- hostTime: sandboxTime.Add(-50 * time.Millisecond),
- want: false,
- },
- {
- name: "sandbox time just within maxTimeInPast ahead of host time (should not set)",
- hostTime: sandboxTime.Add(-40 * time.Millisecond),
- want: false,
- },
- {
- name: "sandbox time slightly ahead of host time (should not set)",
- hostTime: sandboxTime.Add(-10 * time.Millisecond),
- want: false,
- },
- {
- name: "sandbox time equals host time (should not set)",
- hostTime: sandboxTime,
- want: false,
- },
- {
- name: "sandbox time slightly behind host time (should not set)",
- hostTime: sandboxTime.Add(1 * time.Second),
- want: false,
- },
- {
- name: "sandbox time just within maxTimeInFuture behind host time (should not set)",
- hostTime: sandboxTime.Add(4 * time.Second),
- want: false,
- },
- {
- name: "sandbox time at maxTimeInFuture boundary behind host time (should not set)",
- hostTime: sandboxTime.Add(5 * time.Second),
- want: false,
- },
- {
- name: "sandbox time far behind host time (should set)",
- hostTime: sandboxTime.Add(1 * time.Minute),
- want: true,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- t.Parallel()
- got := shouldSetSystemTime(tt.hostTime, sandboxTime)
- assert.Equal(t, tt.want, got)
- })
- }
-}
-
func secureTokenPtr(s string) *SecureToken {
token := &SecureToken{}
_ = token.Set([]byte(s))
diff --git a/envd/internal/port/conn.go b/envd/internal/port/conn.go
new file mode 100644
index 00000000..8a8c032c
--- /dev/null
+++ b/envd/internal/port/conn.go
@@ -0,0 +1,165 @@
+// SPDX-License-Identifier: Apache-2.0
+
+package port
+
+import (
+ "bufio"
+ "encoding/hex"
+ "fmt"
+ "net"
+ "os"
+ "strconv"
+ "strings"
+ "syscall"
+)
+
+// ConnStat represents a single TCP connection read from /proc/net/tcp(6).
+// It contains only the fields needed by the port scanner and forwarder.
+type ConnStat struct {
+ LocalIP string
+ LocalPort uint32
+ Status string
+ Family uint32 // syscall.AF_INET or syscall.AF_INET6
+ Inode uint64 // socket inode, unique per connection
+}
+
+// tcpStates maps the hex state values from /proc/net/tcp to string names
+// matching the gopsutil convention used by ScannerFilter.
+var tcpStates = map[string]string{
+ "01": "ESTABLISHED",
+ "02": "SYN_SENT",
+ "03": "SYN_RECV",
+ "04": "FIN_WAIT1",
+ "05": "FIN_WAIT2",
+ "06": "TIME_WAIT",
+ "07": "CLOSE",
+ "08": "CLOSE_WAIT",
+ "09": "LAST_ACK",
+ "0A": "LISTEN",
+ "0B": "CLOSING",
+}
+
+// ReadTCPConnections reads /proc/net/tcp and /proc/net/tcp6 and returns
+// all TCP connections. This avoids the /proc/{pid}/fd walk that gopsutil
+// performs, which is unsafe across Firecracker snapshot/restore boundaries.
+// A missing /proc/net/tcp6 (e.g. IPv6 disabled) is not treated as an error.
+func ReadTCPConnections() ([]ConnStat, error) {
+	conns, err := parseProcNetTCP("/proc/net/tcp", syscall.AF_INET)
+	if err != nil {
+		return nil, fmt.Errorf("parse /proc/net/tcp: %w", err)
+	}
+
+	tcp6, err := parseProcNetTCP("/proc/net/tcp6", syscall.AF_INET6)
+	// ScanAndBroadcast discards this error, so failing on an absent tcp6
+	// file would silently hide the IPv4 connections as well.
+	if err != nil && !os.IsNotExist(err) {
+		return nil, fmt.Errorf("parse /proc/net/tcp6: %w", err)
+	}
+	conns = append(conns, tcp6...)
+
+	return conns, nil
+}
+
+// parseProcNetTCP reads a single /proc/net/tcp or /proc/net/tcp6 file.
+//
+// Format (fields are whitespace-separated):
+//
+// sl local_address rem_address st tx_queue:rx_queue tr:tm->when retrnsmt uid timeout inode
+// 0: 0100007F:1F90 00000000:0000 0A 00000000:00000000 00:00000000 00000000 1000 0 12345
+func parseProcNetTCP(path string, family uint32) ([]ConnStat, error) {
+ f, err := os.Open(path)
+ if err != nil {
+ return nil, err
+ }
+ defer f.Close()
+
+ var conns []ConnStat
+ scanner := bufio.NewScanner(f)
+
+ // Skip header line.
+ scanner.Scan()
+
+ for scanner.Scan() {
+ line := strings.TrimSpace(scanner.Text())
+ if line == "" {
+ continue
+ }
+
+ fields := strings.Fields(line)
+ if len(fields) < 10 {
+ continue
+ }
+
+ // fields[1] = local_address (hex_ip:hex_port)
+ ip, port, err := parseHexAddr(fields[1], family)
+ if err != nil {
+ continue
+ }
+
+ // fields[3] = state (hex)
+ state, ok := tcpStates[fields[3]]
+ if !ok {
+ state = "UNKNOWN"
+ }
+
+ // fields[9] = inode
+ inode, err := strconv.ParseUint(fields[9], 10, 64)
+ if err != nil {
+ continue
+ }
+
+ conns = append(conns, ConnStat{
+ LocalIP: ip,
+ LocalPort: port,
+ Status: state,
+ Family: family,
+ Inode: inode,
+ })
+ }
+
+ return conns, scanner.Err()
+}
+
+// parseHexAddr parses "HEXIP:HEXPORT" from /proc/net/tcp.
+// IPv4 addresses are 8 hex chars (4 bytes, little-endian per 32-bit word).
+// IPv6 addresses are 32 hex chars (16 bytes, little-endian per 32-bit word).
+func parseHexAddr(s string, family uint32) (string, uint32, error) {
+ parts := strings.SplitN(s, ":", 2)
+ if len(parts) != 2 {
+ return "", 0, fmt.Errorf("invalid address: %s", s)
+ }
+
+ port64, err := strconv.ParseUint(parts[1], 16, 32)
+ if err != nil {
+ return "", 0, err
+ }
+
+ ipHex := parts[0]
+ ipBytes, err := hex.DecodeString(ipHex)
+ if err != nil {
+ return "", 0, err
+ }
+
+ var ip net.IP
+ if family == syscall.AF_INET {
+ if len(ipBytes) != 4 {
+ return "", 0, fmt.Errorf("invalid IPv4 length: %d", len(ipBytes))
+ }
+ // /proc/net/tcp stores IPv4 as a single little-endian 32-bit word.
+ ip = net.IPv4(ipBytes[3], ipBytes[2], ipBytes[1], ipBytes[0])
+ } else {
+ if len(ipBytes) != 16 {
+ return "", 0, fmt.Errorf("invalid IPv6 length: %d", len(ipBytes))
+ }
+ // /proc/net/tcp6 stores IPv6 as four little-endian 32-bit words.
+ ip = make(net.IP, 16)
+ for i := 0; i < 4; i++ {
+ ip[i*4+0] = ipBytes[i*4+3]
+ ip[i*4+1] = ipBytes[i*4+2]
+ ip[i*4+2] = ipBytes[i*4+1]
+ ip[i*4+3] = ipBytes[i*4+0]
+ }
+ }
+
+ return ip.String(), uint32(port64), nil
+}
diff --git a/envd/internal/port/forward.go b/envd/internal/port/forward.go
index e8365196..bf516ff9 100644
--- a/envd/internal/port/forward.go
+++ b/envd/internal/port/forward.go
@@ -31,8 +31,8 @@ var defaultGatewayIP = net.IPv4(169, 254, 0, 21)
type PortToForward struct {
socat *exec.Cmd
- // Process ID of the process that's listening on port.
- pid int32
+ // Socket inode of the listening socket (unique per connection).
+ inode uint64
// family version of the ip.
family uint32
state PortState
@@ -94,7 +94,7 @@ func (f *Forwarder) StartForwarding(ctx context.Context) {
// Let's refresh our map of currently forwarded ports and mark the currently opened ones with the "FORWARD" state.
// This will make sure we won't delete them later.
for _, p := range procs {
- key := fmt.Sprintf("%d-%d", p.Pid, p.Laddr.Port)
+ key := fmt.Sprintf("%d-%d", p.Inode, p.LocalPort)
// We check if the opened port is in our map of forwarded ports.
val, portOk := f.ports[key]
@@ -104,16 +104,16 @@ func (f *Forwarder) StartForwarding(ctx context.Context) {
val.state = PortStateForward
} else {
f.logger.Debug().
- Str("ip", p.Laddr.IP).
- Uint32("port", p.Laddr.Port).
+ Str("ip", p.LocalIP).
+ Uint32("port", p.LocalPort).
Uint32("family", familyToIPVersion(p.Family)).
Str("state", p.Status).
Msg("Detected new opened port on localhost that is not forwarded")
// The opened port wasn't in the map so we create a new PortToForward and start forwarding.
ptf := &PortToForward{
- pid: p.Pid,
- port: p.Laddr.Port,
+ inode: p.Inode,
+ port: p.LocalPort,
state: PortStateForward,
family: familyToIPVersion(p.Family),
}
@@ -153,7 +153,7 @@ func (f *Forwarder) startPortForwarding(ctx context.Context, p *PortToForward) {
f.logger.Debug().
Str("socatCmd", cmd.String()).
- Int32("pid", p.pid).
+ Uint64("inode", p.inode).
Uint32("family", p.family).
IPAddr("sourceIP", f.sourceIP.To4()).
Uint32("port", p.port).
@@ -191,7 +191,7 @@ func (f *Forwarder) stopPortForwarding(p *PortToForward) {
logger := f.logger.With().
Str("socatCmd", p.socat.String()).
- Int32("pid", p.pid).
+ Uint64("inode", p.inode).
Uint32("family", p.family).
IPAddr("sourceIP", f.sourceIP.To4()).
Uint32("port", p.port).
diff --git a/envd/internal/port/scan.go b/envd/internal/port/scan.go
index 766202a6..2b155233 100644
--- a/envd/internal/port/scan.go
+++ b/envd/internal/port/scan.go
@@ -3,19 +3,21 @@
package port
import (
+ "sync"
"time"
"github.com/rs/zerolog"
- "github.com/shirou/gopsutil/v4/net"
-
- "git.omukk.dev/wrenn/sandbox/envd/internal/shared/smap"
)
type Scanner struct {
- Processes chan net.ConnectionStat
- scanExit chan struct{}
- subs *smap.Map[*ScannerSubscriber]
- period time.Duration
+ scanExit chan struct{}
+ period time.Duration
+
+ // Plain mutex-protected map instead of concurrent-map. The concurrent-map
+ // library's Items() spawns goroutines and uses a WaitGroup internally,
+ // which corrupts Go runtime semaphore state across Firecracker snapshot/restore.
+ mu sync.RWMutex
+ subs map[string]*ScannerSubscriber
}
func (s *Scanner) Destroy() {
@@ -24,33 +26,44 @@ func (s *Scanner) Destroy() {
func NewScanner(period time.Duration) *Scanner {
return &Scanner{
- period: period,
- subs: smap.New[*ScannerSubscriber](),
- scanExit: make(chan struct{}),
- Processes: make(chan net.ConnectionStat),
+ period: period,
+ subs: make(map[string]*ScannerSubscriber),
+ scanExit: make(chan struct{}),
}
}
func (s *Scanner) AddSubscriber(logger *zerolog.Logger, id string, filter *ScannerFilter) *ScannerSubscriber {
subscriber := NewScannerSubscriber(logger, id, filter)
- s.subs.Insert(id, subscriber)
+
+ s.mu.Lock()
+ s.subs[id] = subscriber
+ s.mu.Unlock()
return subscriber
}
func (s *Scanner) Unsubscribe(sub *ScannerSubscriber) {
- s.subs.Remove(sub.ID())
+ s.mu.Lock()
+ delete(s.subs, sub.ID())
+ s.mu.Unlock()
+
sub.Destroy()
}
// ScanAndBroadcast starts scanning open TCP ports and broadcasts every open port to all subscribers.
func (s *Scanner) ScanAndBroadcast() {
for {
- // tcp monitors both ipv4 and ipv6 connections.
- processes, _ := net.Connections("tcp")
- for _, sub := range s.subs.Items() {
- sub.Signal(processes)
+ // Read directly from /proc/net/tcp and /proc/net/tcp6 instead of
+ // using gopsutil's net.Connections(), which walks /proc/{pid}/fd
+ // and causes Go runtime corruption after Firecracker snapshot/restore.
+ conns, _ := ReadTCPConnections()
+
+ s.mu.RLock()
+ for _, sub := range s.subs {
+ sub.Signal(conns)
}
+ s.mu.RUnlock()
+
select {
case <-s.scanExit:
return
diff --git a/envd/internal/port/scanSubscriber.go b/envd/internal/port/scanSubscriber.go
index 6a4f5b01..bad99083 100644
--- a/envd/internal/port/scanSubscriber.go
+++ b/envd/internal/port/scanSubscriber.go
@@ -4,7 +4,6 @@ package port
import (
"github.com/rs/zerolog"
- "github.com/shirou/gopsutil/v4/net"
)
// If we want to create a listener/subscriber pattern somewhere else we should move
@@ -13,7 +12,7 @@ import (
type ScannerSubscriber struct {
logger *zerolog.Logger
filter *ScannerFilter
- Messages chan ([]net.ConnectionStat)
+ Messages chan ([]ConnStat)
id string
}
@@ -22,7 +21,7 @@ func NewScannerSubscriber(logger *zerolog.Logger, id string, filter *ScannerFilt
logger: logger,
id: id,
filter: filter,
- Messages: make(chan []net.ConnectionStat),
+ Messages: make(chan []ConnStat),
}
}
@@ -34,17 +33,17 @@ func (ss *ScannerSubscriber) Destroy() {
close(ss.Messages)
}
-func (ss *ScannerSubscriber) Signal(proc []net.ConnectionStat) {
+func (ss *ScannerSubscriber) Signal(conns []ConnStat) {
// Filter isn't specified. Accept everything.
if ss.filter == nil {
- ss.Messages <- proc
+ ss.Messages <- conns
} else {
- filtered := []net.ConnectionStat{}
- for i := range proc {
+ filtered := []ConnStat{}
+ for i := range conns {
// We need to access the list directly otherwise there will be implicit memory aliasing
- // If the filter matched a process, we will send it to a channel.
- if ss.filter.Match(&proc[i]) {
- filtered = append(filtered, proc[i])
+ // If the filter matched a connection, we will send it to a channel.
+ if ss.filter.Match(&conns[i]) {
+ filtered = append(filtered, conns[i])
}
}
ss.Messages <- filtered
diff --git a/envd/internal/port/scanfilter.go b/envd/internal/port/scanfilter.go
index 941023d9..f87667f2 100644
--- a/envd/internal/port/scanfilter.go
+++ b/envd/internal/port/scanfilter.go
@@ -4,8 +4,6 @@ package port
import (
"slices"
-
- "github.com/shirou/gopsutil/v4/net"
)
type ScannerFilter struct {
@@ -13,15 +11,15 @@ type ScannerFilter struct {
IPs []string
}
-func (sf *ScannerFilter) Match(proc *net.ConnectionStat) bool {
+func (sf *ScannerFilter) Match(conn *ConnStat) bool {
// Filter is an empty struct.
if sf.State == "" && len(sf.IPs) == 0 {
return false
}
- ipMatch := slices.Contains(sf.IPs, proc.Laddr.IP)
+ ipMatch := slices.Contains(sf.IPs, conn.LocalIP)
- if ipMatch && sf.State == proc.Status {
+ if ipMatch && sf.State == conn.Status {
return true
}
diff --git a/frontend/src/lib/api/builds.ts b/frontend/src/lib/api/builds.ts
new file mode 100644
index 00000000..1de23b8d
--- /dev/null
+++ b/frontend/src/lib/api/builds.ts
@@ -0,0 +1,76 @@
+import { apiFetch, type ApiResult } from '$lib/api/client';
+
+export type BuildLogEntry = {
+ step: number;
+ phase: string; // "pre-build", "recipe", or "post-build"
+ cmd: string;
+ stdout: string;
+ stderr: string;
+ exit: number;
+ ok: boolean;
+ elapsed_ms: number;
+};
+
+export type Build = {
+ id: string;
+ name: string;
+ base_template: string;
+ recipe: string[];
+ healthcheck?: string;
+ vcpus: number;
+ memory_mb: number;
+ status: string;
+ current_step: number;
+ total_steps: number;
+ logs: BuildLogEntry[];
+ error?: string;
+ sandbox_id?: string;
+ host_id?: string;
+ created_at: string;
+ started_at?: string;
+ completed_at?: string;
+};
+
+export type CreateBuildParams = {
+ name: string;
+ base_template?: string;
+ recipe: string[];
+ healthcheck?: string;
+ vcpus?: number;
+ memory_mb?: number;
+ skip_pre_post?: boolean;
+};
+
+export async function createBuild(params: CreateBuildParams): Promise<ApiResult<Build>> {
+ return apiFetch('POST', '/api/v1/admin/builds', params);
+}
+
+export async function listBuilds(): Promise<ApiResult<Build[]>> {
+ return apiFetch('GET', '/api/v1/admin/builds');
+}
+
+export async function getBuild(id: string): Promise<ApiResult<Build>> {
+ return apiFetch('GET', `/api/v1/admin/builds/${id}`);
+}
+
+export type AdminTemplate = {
+ name: string;
+ type: string;
+ vcpus: number;
+ memory_mb: number;
+ size_bytes: number;
+ team_id: string;
+ created_at: string;
+};
+
+export async function listAdminTemplates(): Promise<ApiResult<AdminTemplate[]>> {
+ return apiFetch('GET', '/api/v1/admin/templates');
+}
+
+export async function deleteAdminTemplate(name: string): Promise<ApiResult<void>> {
+ return apiFetch('DELETE', `/api/v1/admin/templates/${name}`);
+}
+
+export async function cancelBuild(id: string): Promise<ApiResult<void>> {
+ return apiFetch('POST', `/api/v1/admin/builds/${id}/cancel`);
+}
diff --git a/frontend/src/lib/api/capsules.ts b/frontend/src/lib/api/capsules.ts
index cc4ad79b..565f14f2 100644
--- a/frontend/src/lib/api/capsules.ts
+++ b/frontend/src/lib/api/capsules.ts
@@ -54,6 +54,7 @@ export type Snapshot = {
memory_mb?: number;
size_bytes: number;
created_at: string;
+ platform: boolean;
};
export async function createSnapshot(sandboxId: string, name?: string): Promise> {
diff --git a/frontend/src/lib/components/AdminSidebar.svelte b/frontend/src/lib/components/AdminSidebar.svelte
index 4bed5cc9..ebf4b646 100644
--- a/frontend/src/lib/components/AdminSidebar.svelte
+++ b/frontend/src/lib/components/AdminSidebar.svelte
@@ -3,6 +3,7 @@
import { auth } from '$lib/auth.svelte';
import {
IconServer,
+ IconTemplate,
IconSettings,
IconLogout,
IconSidebar,
@@ -21,7 +22,8 @@
};
const managementItems: NavItem[] = [
- { label: 'Hosts', icon: IconServer, href: '/admin/hosts' }
+ { label: 'Hosts', icon: IconServer, href: '/admin/hosts' },
+ { label: 'Templates', icon: IconTemplate, href: '/admin/templates' }
];
function isActive(href: string): boolean {
diff --git a/frontend/src/lib/components/CreateCapsuleDialog.svelte b/frontend/src/lib/components/CreateCapsuleDialog.svelte
index b570f2b4..2bd027d6 100644
--- a/frontend/src/lib/components/CreateCapsuleDialog.svelte
+++ b/frontend/src/lib/components/CreateCapsuleDialog.svelte
@@ -56,6 +56,7 @@
class="w-full rounded-[var(--radius-input)] border border-[var(--color-border)] bg-[var(--color-bg-4)] px-3 py-2 font-mono text-ui text-[var(--color-text-bright)] outline-none placeholder:text-[var(--color-text-muted)] transition-colors duration-150 focus:border-[var(--color-accent)]"
placeholder="minimal"
/>
+ Name of a snapshot or base image to boot from.
@@ -85,14 +86,16 @@
-
Idle timeout (seconds — 0 = never pause)
+
Idle timeout
+
Seconds of inactivity before the capsule pauses. Set to 0 to keep it running indefinitely.
diff --git a/frontend/src/routes/admin/templates/+page.svelte b/frontend/src/routes/admin/templates/+page.svelte
new file mode 100644
index 00000000..4619e7b9
--- /dev/null
+++ b/frontend/src/routes/admin/templates/+page.svelte
@@ -0,0 +1,902 @@
+
+
+
+
+
+
+
+
+
+
+
+ {#each [['templates', 'Templates', templateCount], ['builds', 'Builds', builds.length]] as [id, label, count] (id)}
+ { activeTab = id as 'templates' | 'builds'; }}
+ class="relative py-3 pr-5 text-ui transition-colors duration-150 {activeTab === id
+ ? 'font-medium text-[var(--color-text-bright)]'
+ : 'text-[var(--color-text-tertiary)] hover:text-[var(--color-text-secondary)]'}"
+ >
+ {label}
+ {#if activeTab === id}
+
+ {/if}
+ {#if !templatesLoading}
+
+ {count}
+
+ {/if}
+
+ {/each}
+
+
+
+
+ {#if activeTab === 'templates'}
+ {#if templatesLoading}
+ {@render skeletonRows(5, ['Name', 'Type', 'Specs', 'Size', 'Created', ''])}
+ {:else if templatesError}
+
+ {templatesError}
+
+ {:else if templates.length === 0}
+ {@render emptyState('templates')}
+ {:else}
+ {@render templatesTable()}
+ {/if}
+ {:else}
+ {#if buildsLoading}
+ {@render skeletonRows(4, ['Build', 'Name', 'Status', 'Progress', 'Started', 'Duration'])}
+ {:else if buildsError}
+
+ {buildsError}
+
+ {:else if builds.length === 0}
+ {@render emptyState('builds')}
+ {:else}
+ {@render buildsTable()}
+ {/if}
+ {/if}
+
+
+
+
+
+
+{#snippet skeletonRows(count: number, headers: string[])}
+
+
+
+
+ {#each headers as h}
+ {h}
+ {/each}
+
+
+
+ {#each Array(count) as _, i}
+
+ {#each headers as _h, j}
+
+
+
+ {/each}
+
+ {/each}
+
+
+
+{/snippet}
+
+{#snippet emptyState(type: 'templates' | 'builds')}
+
+
+ {#if type === 'templates'}
+
+ {:else}
+
+ {/if}
+
+
+ {type === 'templates' ? 'No templates yet.' : 'No builds yet.'}
+
+
+ {type === 'templates'
+ ? 'Create a template to provide pre-configured environments for all teams.'
+ : 'Start a template build to see progress and logs here.'}
+
+
+{/snippet}
+
+{#snippet templatesTable()}
+
+
+
+
+ Name
+ Type
+ Specs
+ Size
+ Created
+
+
+
+
+ {#each templates as tmpl (tmpl.name)}
+
+
+ {tmpl.name}
+
+
+ {#if tmpl.type === 'snapshot'}
+
+ snapshot
+
+ {:else}
+
+ base
+
+ {/if}
+
+
+ {#if tmpl.vcpus && tmpl.memory_mb}
+
+ {tmpl.vcpus} vCPU · {tmpl.memory_mb} MB
+
+ {:else}
+ —
+ {/if}
+
+
+
+ {tmpl.size_bytes ? formatBytes(tmpl.size_bytes) : '—'}
+
+
+
+
+ {timeAgo(tmpl.created_at)}
+
+
+
+ { deleteTarget = tmpl; deleteError = null; }}
+ class="rounded-[var(--radius-button)] px-3 py-1.5 text-meta text-[var(--color-text-tertiary)] transition-colors duration-150 hover:bg-[var(--color-red)]/10 hover:text-[var(--color-red)]"
+ >
+ Delete
+
+
+
+ {/each}
+
+
+
+{/snippet}
+
+{#snippet buildsTable()}
+
+
+
+
+ Build
+ Name
+ Base
+ Status
+ Progress
+ Started
+ Duration
+
+
+
+ {#each builds as build (build.id)}
+ toggleBuildExpand(build.id)}
+ >
+
+
+
+
+ {build.name}
+
+
+ {build.base_template}
+
+
+
+ {#if build.status === 'running'}
+
+
+
+
+ {:else if build.status === 'success'}
+
+ {:else if build.status === 'failed'}
+
+ {:else}
+
+ {/if}
+ {build.status}
+
+
+
+
+ {build.current_step} / {build.total_steps}
+
+ {#if build.status === 'running' && build.total_steps > 0}
+
+ {/if}
+
+
+
+ {build.started_at ? timeAgo(build.started_at) : '—'}
+
+
+
+
+ {formatDuration(build.started_at, build.completed_at)}
+
+
+
+
+ {#if expandedBuildId === build.id}
+
+
+
+ {#if build.status === 'pending' || build.status === 'running'}
+
+
{ e.stopPropagation(); handleCancelBuild(build.id); }}
+ disabled={cancelingBuildId === build.id}
+ class="flex items-center gap-1.5 rounded-[var(--radius-button)] border border-[var(--color-red)]/30 bg-[var(--color-red)]/8 px-3 py-1.5 text-meta text-[var(--color-red)] transition-colors duration-150 hover:bg-[var(--color-red)]/15 disabled:opacity-50"
+ >
+ {#if cancelingBuildId === build.id}
+
+ {:else}
+
+ {/if}
+ Cancel build
+
+
+ {/if}
+ {#if build.error}
+
+ {build.error}
+
+ {/if}
+
+ {#if build.logs && build.logs.length > 0}
+
+ {#each build.logs as log, i (i)}
+ {@const isInternal = log.phase === 'pre-build' || log.phase === 'post-build'}
+ {@const recipeIdx = log.phase === 'recipe' ? build.logs.filter(l => l.phase === 'recipe' && l.step <= log.step).length : 0}
+ {@const phaseLabel = isInternal ? (log.phase === 'pre-build' ? 'Pre-build' : 'Post-build') : `Step ${recipeIdx}`}
+ {@const [kw, kwRest] = splitInstruction(log.cmd)}
+
+ {/each}
+
+ {:else}
+
+ {#if build.status === 'pending' || build.status === 'running'}
+
+ {build.status === 'pending' ? 'Waiting for worker…' : 'Running…'}
+ {:else}
+ No build logs recorded.
+ {/if}
+
+ {/if}
+
+
+ {#if build.recipe && build.recipe.length > 0}
+
+ {/if}
+
+ {#if build.healthcheck}
+
+ Healthcheck
+ {build.healthcheck}
+
+ {/if}
+
+
+
+ {/if}
+ {/each}
+
+
+
+{/snippet}
+
+
+{#if showCreate}
+
+
{ if (!creating) showCreate = false; }}
+ onkeydown={(e) => { if (e.key === 'Escape' && !creating) showCreate = false; }}
+ >
+
+
+ Create Template
+
+
+ Build a new global template by running commands on a base image.
+
+
+ {#if createError}
+
+ {createError}
+
+ {/if}
+
+
+
+
+ Template Name
+
+
+
+
+
+
+
+
+ Recipe (one instruction per line)
+
+
+
+ Supports RUN, START, WORKDIR, ENV key=value. RUN steps have a 30s timeout; override with RUN --timeout=5m.
+
+
+
+
+
+ Healthcheck (optional)
+
+
+
+ If set, the build will poll this command every 1s (up to 60s) after the recipe completes. On success, a full snapshot (with memory state) is created. Without a healthcheck, only the rootfs is saved.
+
+
+
+
+
+ Skip pre-build and post-build steps
+
+
+
+
+
(showCreate = false)}
+ disabled={creating}
+ class="rounded-[var(--radius-button)] border border-[var(--color-border)] px-4 py-2 text-ui text-[var(--color-text-secondary)] transition-colors duration-150 hover:border-[var(--color-border-mid)] hover:text-[var(--color-text-primary)] disabled:opacity-50"
+ >
+ Cancel
+
+
+ {#if creating}
+
+ Creating…
+ {:else}
+ Start Build
+ {/if}
+
+
+
+
+{/if}
+
+
+{#if deleteTarget}
+
+
{ if (!deleting) deleteTarget = null; }}
+ onkeydown={(e) => { if (e.key === 'Escape' && !deleting) deleteTarget = null; }}
+ >
+
+
+ Delete Template
+
+
+ Permanently remove {deleteTarget.name} from all hosts.
+
+
+ {#if deleteError}
+
+ {deleteError}
+
+ {/if}
+
+
+
(deleteTarget = null)}
+ disabled={deleting}
+ class="rounded-[var(--radius-button)] border border-[var(--color-border)] px-4 py-2 text-ui text-[var(--color-text-secondary)] transition-colors duration-150 hover:border-[var(--color-border-mid)] hover:text-[var(--color-text-primary)] disabled:opacity-50"
+ >
+ Cancel
+
+
+ {#if deleting}
+
+ Deleting…
+ {:else}
+ Delete
+ {/if}
+
+
+
+
+{/if}
+
+
diff --git a/frontend/src/routes/dashboard/capsules/+layout.svelte b/frontend/src/routes/dashboard/capsules/+layout.svelte
index d9f93f82..5c400b9a 100644
--- a/frontend/src/routes/dashboard/capsules/+layout.svelte
+++ b/frontend/src/routes/dashboard/capsules/+layout.svelte
@@ -47,7 +47,7 @@
Capsules
- Isolated VMs. Start cold in under a second — pause, snapshot, or destroy at will.
+ All active and recent capsules across your team.
diff --git a/frontend/src/routes/dashboard/capsules/+page.svelte b/frontend/src/routes/dashboard/capsules/+page.svelte
index 4f270035..afb9de0c 100644
--- a/frontend/src/routes/dashboard/capsules/+page.svelte
+++ b/frontend/src/routes/dashboard/capsules/+page.svelte
@@ -247,6 +247,13 @@
return `${Math.floor(seconds / 86400)}d ago`;
}
+ function fmtTimeout(sec: number): string {
+ if (!sec) return 'None';
+ // Exact h/m/s breakdown; rounding used to show 90s as "2m" and 3599s as "60m".
+ const parts = [[Math.floor(sec / 3600), 'h'], [Math.floor((sec % 3600) / 60), 'm'], [sec % 60, 's']];
+ return parts.filter(([n]) => n).map(([n, u]) => `${n}${u}`).join(' ');
+ }
+
function handleClickOutside(event: MouseEvent) {
if (openMenuId && !(event.target as Element)?.closest('.status-menu-container')) {
openMenuId = null;
@@ -300,7 +307,7 @@
class="w-full rounded-[var(--radius-input)] border border-[var(--color-border)] bg-[var(--color-bg-2)] py-2 pl-9 pr-3 font-mono text-ui text-[var(--color-text-bright)] outline-none placeholder:text-[var(--color-text-muted)] transition-colors duration-150 focus:border-[var(--color-accent)]"
/>
- {filteredCapsules.length} total
+ {filteredCapsules.length} capsule{filteredCapsules.length !== 1 ? 's' : ''}
@@ -363,8 +370,11 @@
{#if error}
-
- {error}
+
+
+
+
+ {error}. Try refreshing the page.
{/if}
@@ -466,14 +476,14 @@
- {capsule.timeout_sec ? `${capsule.timeout_sec}s` : '—'}
+ {fmtTimeout(capsule.timeout_sec)}
{formatTime(capsule.started_at)}
{#if capsule.last_active_at}
- {timeAgo(capsule.last_active_at)}
+ active {timeAgo(capsule.last_active_at)}
{/if}
@@ -612,7 +622,7 @@
-
This capsule will be paused first — memory state is captured at rest.
+
This capsule will be paused first , then its full state (memory + disk) will be captured.
{:else}
The capsule's current memory state will be captured and stored as a reusable snapshot.
diff --git a/frontend/src/routes/dashboard/capsules/[id]/+page.js b/frontend/src/routes/dashboard/capsules/[id]/+page.js
new file mode 100644
index 00000000..d43d0cd2
--- /dev/null
+++ b/frontend/src/routes/dashboard/capsules/[id]/+page.js
@@ -0,0 +1 @@
+export const prerender = false;
diff --git a/frontend/src/routes/dashboard/capsules/[id]/+page.svelte b/frontend/src/routes/dashboard/capsules/[id]/+page.svelte
index e932209a..ed26426d 100644
--- a/frontend/src/routes/dashboard/capsules/[id]/+page.svelte
+++ b/frontend/src/routes/dashboard/capsules/[id]/+page.svelte
@@ -512,7 +512,7 @@
- Failed to load metrics: {metricsError}
+ Could not load metrics: {metricsError}. Will retry automatically.
{/if}
diff --git a/frontend/src/routes/dashboard/snapshots/+page.svelte b/frontend/src/routes/dashboard/snapshots/+page.svelte
index e39bf3b6..2ae201ab 100644
--- a/frontend/src/routes/dashboard/snapshots/+page.svelte
+++ b/frontend/src/routes/dashboard/snapshots/+page.svelte
@@ -114,15 +114,15 @@
}
function emptyHeading(f: TypeFilter): string {
- if (f === 'snapshot') return 'No snapshots';
- if (f === 'base') return 'No images';
- return 'No templates yet';
+ if (f === 'snapshot') return 'No snapshots yet';
+ if (f === 'base') return 'No base images';
+ return 'No snapshots yet';
}
function emptyDescription(f: TypeFilter): string {
- if (f === 'snapshot') return 'Pause a capsule from the Capsules page, then snapshot it to capture its state.';
- if (f === 'base') return 'Base images are added by the Wrenn team. Contact support to request a custom image.';
- return 'To create a snapshot, go to Capsules, pause a running capsule, then choose Snapshot.';
+ if (f === 'snapshot') return 'Pause a running capsule, then choose Snapshot to save its state.';
+ if (f === 'base') return 'Base images are provided by the Wrenn team. Contact support to request a custom one.';
+ return 'Pause a running capsule, then choose Snapshot to save its state. You can launch new capsules from any snapshot.';
}
onMount(fetchSnapshots);
@@ -162,7 +162,7 @@
Templates
- Snapshots capture a live capsule state. Base images are the rootfs every capsule starts from. Launch a full VM from any template.
+ Snapshots capture a running capsule's state. Base images are the starting point for every new capsule. Launch from either.
@@ -206,8 +206,11 @@
{#if pageTab === 'snapshots'}
{#if error}
-
- {error}
+
+
+
+
+ {error}. Try refreshing the page.
{/if}
@@ -274,11 +277,11 @@
{filteredSnapshots.length}
- {typeFilter === 'all'
- ? filteredSnapshots.length === 1 ? 'template' : 'templates'
- : typeFilter === 'snapshot'
- ? filteredSnapshots.length === 1 ? 'snapshot' : 'snapshots'
- : filteredSnapshots.length === 1 ? 'image' : 'images'}
+ {typeFilter === 'snapshot'
+ ? filteredSnapshots.length === 1 ? 'snapshot' : 'snapshots'
+ : typeFilter === 'base'
+ ? filteredSnapshots.length === 1 ? 'image' : 'images'
+ : filteredSnapshots.length === 1 ? 'item' : 'items'}
@@ -420,6 +423,7 @@
{
e.stopPropagation();
if (openDropdownName === snapshot.name) {
@@ -430,7 +434,7 @@
openDropdownName = snapshot.name;
}
}}
- class="flex items-center px-2 py-1.5 text-[var(--color-text-secondary)] transition-colors duration-150 hover:bg-[var(--color-bg-4)] hover:text-[var(--color-text-bright)]"
+ class="flex items-center px-2 py-1.5 text-[var(--color-text-secondary)] transition-colors duration-150 hover:bg-[var(--color-bg-4)] hover:text-[var(--color-text-bright)] disabled:cursor-not-allowed disabled:opacity-30 disabled:hover:bg-transparent disabled:hover:text-[var(--color-text-secondary)]"
>
{filteredSnapshots.length}
- {typeFilter === 'all'
- ? filteredSnapshots.length === 1 ? 'template' : 'templates'
- : typeFilter === 'snapshot'
- ? filteredSnapshots.length === 1 ? 'snapshot' : 'snapshots'
- : filteredSnapshots.length === 1 ? 'image' : 'images'}
+ {typeFilter === 'snapshot'
+ ? filteredSnapshots.length === 1 ? 'snapshot' : 'snapshots'
+ : typeFilter === 'base'
+ ? filteredSnapshots.length === 1 ? 'image' : 'images'
+ : filteredSnapshots.length === 1 ? 'item' : 'items'}
{typeFilter !== 'all' ? '· filtered' : '· total'}
{/if}
@@ -481,21 +485,23 @@
class="fixed z-50 w-32 overflow-hidden rounded-[var(--radius-card)] border border-[var(--color-border-mid)] bg-[var(--color-bg-2)] py-1"
style="top: {dropdownPos.top}px; left: {dropdownPos.left}px; animation: fadeUp 0.15s ease both"
>
- {
- e.stopPropagation();
- const target = snapshots.find((s) => s.name === openDropdownName);
- openDropdownName = null;
- if (target) { deleteTarget = target; deleteError = null; }
- }}
- class="flex w-full items-center gap-2 px-3 py-2 text-meta text-[var(--color-red)] transition-colors duration-150 hover:bg-[var(--color-red)]/5"
- >
-
-
-
-
- Delete
-
+ {#if !dropdownSnapshot.platform}
+ {
+ e.stopPropagation();
+ const target = snapshots.find((s) => s.name === openDropdownName);
+ openDropdownName = null;
+ if (target) { deleteTarget = target; deleteError = null; }
+ }}
+ class="flex w-full items-center gap-2 px-3 py-2 text-meta text-[var(--color-red)] transition-colors duration-150 hover:bg-[var(--color-red)]/5"
+ >
+
+
+
+
+ Delete
+
+ {/if}
{/if}
{/if}
@@ -513,10 +519,10 @@
class="relative w-full max-w-[380px] rounded-[var(--radius-card)] border border-[var(--color-border-mid)] bg-[var(--color-bg-2)] p-6"
style="animation: fadeUp 0.2s ease both"
>
- Delete Snapshot
+ Delete snapshot
Permanently delete {deleteTarget.name} .
- Any capsule using this template will not be affected, but you won't be able to launch from it again.
+ Running capsules won't be affected, but you won't be able to launch new ones from it.
{#if deleteTarget.type === 'snapshot'}
@@ -526,7 +532,7 @@
- This live capture includes saved memory state. Any capsule relying on it will be unable to resume.
+ This snapshot includes memory state. Paused capsules that depend on it won't be able to resume.
{/if}
@@ -580,7 +586,7 @@
>
Launch Capsule
- Configure resources and launch. The VM will clone from this template and be ready in seconds.
+ Configure resources and launch a new capsule from this snapshot.
{#if launchError}
@@ -655,14 +661,16 @@
-
Auto-pause timeout (seconds, 0 = never)
+
Idle timeout
+
Seconds of inactivity before the capsule pauses. Set to 0 to keep it running indefinitely.
diff --git a/frontend/src/routes/dashboard/team/+page.svelte b/frontend/src/routes/dashboard/team/+page.svelte
index 773cdf4f..f17311ba 100644
--- a/frontend/src/routes/dashboard/team/+page.svelte
+++ b/frontend/src/routes/dashboard/team/+page.svelte
@@ -50,7 +50,6 @@
let nameInputEl = $state(null);
// Copy state
- let copiedSlug = $state(false);
let copiedId = $state(false);
// Add member dialog
@@ -139,16 +138,11 @@
savingName = false;
}
- async function copyToClipboard(text: string, which: 'slug' | 'id') {
+ async function copyToClipboard(text: string) {
try {
await navigator.clipboard.writeText(text);
- if (which === 'slug') {
- copiedSlug = true;
- setTimeout(() => (copiedSlug = false), 2000);
- } else {
- copiedId = true;
- setTimeout(() => (copiedId = false), 2000);
- }
+ copiedId = true;
+ setTimeout(() => (copiedId = false), 2000);
} catch {
toast.error('Copy failed — select the text and copy manually.');
}
@@ -514,115 +508,58 @@
-
-
-
-
-
-
- Slug
-
-
{team.slug}
-
-
copyToClipboard(team!.slug, 'slug')}
- title="Copy slug"
- class="flex shrink-0 items-center gap-1.5 rounded-[var(--radius-button)] border px-3 py-1.5 text-meta font-semibold transition-all duration-150
- {copiedSlug
- ? 'border-[var(--color-accent)]/40 bg-[var(--color-accent-glow-mid)] text-[var(--color-accent-mid)]'
- : 'border-[var(--color-border-mid)] text-[var(--color-text-secondary)] hover:text-[var(--color-text-primary)]'}"
+
+
+
+
- {#if copiedSlug}
-
-
-
- Copied
- {:else}
-
-
-
-
- Copy
- {/if}
-
-
-
-
-
-
-
- Team ID
-
-
{team.id}
+ Team ID
-
copyToClipboard(team!.id, 'id')}
- title="Copy team ID"
- class="flex shrink-0 items-center gap-1.5 rounded-[var(--radius-button)] border px-3 py-1.5 text-meta font-semibold transition-all duration-150
- {copiedId
- ? 'border-[var(--color-accent)]/40 bg-[var(--color-accent-glow-mid)] text-[var(--color-accent-mid)]'
- : 'border-[var(--color-border-mid)] text-[var(--color-text-secondary)] hover:text-[var(--color-text-primary)]'}"
+ {team.id}
- {#if copiedId}
-
-
-
- Copied
- {:else}
-
-
-
-
- Copy
- {/if}
-
+
copyToClipboard(team!.id)}
+ title="Copy team ID"
+ class="flex shrink-0 items-center gap-1.5 rounded-[var(--radius-button)] border px-3 py-1.5 text-meta font-semibold transition-all duration-150
+ {copiedId
+ ? 'border-[var(--color-accent)]/40 bg-[var(--color-accent-glow-mid)] text-[var(--color-accent-mid)]'
+ : 'border-[var(--color-border-mid)] text-[var(--color-text-secondary)] hover:text-[var(--color-text-primary)]'}"
+ >
+ {#if copiedId}
+
+
+
+ Copied
+ {:else}
+
+
+
+
+ Copy
+ {/if}
+
diff --git a/frontend/vite.config.ts b/frontend/vite.config.ts
index 89afaba3..350070b5 100644
--- a/frontend/vite.config.ts
+++ b/frontend/vite.config.ts
@@ -7,7 +7,7 @@ export default defineConfig({
server: {
proxy: {
'/api': {
- target: 'http://localhost:8000',
+ target: 'http://localhost:8080',
rewrite: (path) => path.replace(/^\/api/, '')
}
}
diff --git a/images/wrenn-init.sh b/images/wrenn-init.sh
index 32285ea7..266e516a 100644
--- a/images/wrenn-init.sh
+++ b/images/wrenn-init.sh
@@ -1,6 +1,6 @@
#!/bin/sh
# wrenn-init: minimal PID 1 init for Firecracker microVMs.
-# Mounts virtual filesystems then execs envd.
+# Mounts virtual filesystems, starts chronyd for time sync, then execs tini + envd.
set -e
@@ -25,7 +25,19 @@ echo "nameserver 8.8.8.8" > /etc/resolv.conf
echo "nameserver 8.8.4.4" >> /etc/resolv.conf
# Set a standard PATH so envd and all child processes can find common binaries.
-export PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
+export PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/local/games
+
+# Write chrony config to sync time from the KVM PTP hardware clock.
+# /dev/ptp0 is a paravirtual clock exposed by KVM — no network required.
+mkdir -p /etc/chrony /run/chrony
+cat > /etc/chrony/chrony.conf </dev/null || true
# Exec tini as PID 1 — it reaps zombie processes and forwards signals to envd.
exec /sbin/tini -- /usr/local/bin/envd
diff --git a/internal/api/agent_helper.go b/internal/api/agent_helper.go
index ac5b38e0..98a881d6 100644
--- a/internal/api/agent_helper.go
+++ b/internal/api/agent_helper.go
@@ -4,6 +4,8 @@ import (
"context"
"fmt"
+ "github.com/jackc/pgx/v5/pgtype"
+
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
@@ -11,7 +13,7 @@ import (
// agentForHost looks up the host record and returns a Connect RPC client for it.
// Returns an error if the host is not found or has no address.
-func agentForHost(ctx context.Context, queries *db.Queries, pool *lifecycle.HostClientPool, hostID string) (hostagentv1connect.HostAgentServiceClient, error) {
+func agentForHost(ctx context.Context, queries *db.Queries, pool *lifecycle.HostClientPool, hostID pgtype.UUID) (hostagentv1connect.HostAgentServiceClient, error) {
host, err := queries.GetHost(ctx, hostID)
if err != nil {
return nil, fmt.Errorf("host not found: %w", err)
diff --git a/internal/api/handler_sandbox_proxy.go b/internal/api/handler_sandbox_proxy.go
new file mode 100644
index 00000000..963dff69
--- /dev/null
+++ b/internal/api/handler_sandbox_proxy.go
@@ -0,0 +1,229 @@
+package api
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "log/slog"
+ "net/http"
+ "net/http/httputil"
+ "net/url"
+ "regexp"
+ "strconv"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/jackc/pgx/v5/pgtype"
+
+ "git.omukk.dev/wrenn/sandbox/internal/auth"
+ "git.omukk.dev/wrenn/sandbox/internal/db"
+ "git.omukk.dev/wrenn/sandbox/internal/id"
+ "git.omukk.dev/wrenn/sandbox/internal/lifecycle"
+)
+
+// Sentinel errors returned by proxyTarget, used to map to HTTP status codes
+// without relying on error message text.
+var (
+ errProxySandboxNotFound = errors.New("sandbox not found")
+ errProxyNoHostAddress = errors.New("host agent has no address")
+)
+
+const proxyCacheTTL = 120 * time.Second
+
+// sandboxHostPattern matches hostnames like "49999-cl-abcd1234.localhost" or
+// "49999-cl-abcd1234.example.com". Captures: port, sandbox ID.
+var sandboxHostPattern = regexp.MustCompile(`^(\d+)-(cl-[0-9a-z]+)\.`)
+
+// errProxySandboxNotRunning carries the sandbox status so callers can include
+// it in the HTTP response without parsing error strings.
+type errProxySandboxNotRunning struct{ status string }
+
+func (e errProxySandboxNotRunning) Error() string {
+ return fmt.Sprintf("sandbox is not running (status: %s)", e.status)
+}
+
+// proxyCacheEntry caches the resolved agent URL for a (sandbox, team) pair.
+// The *httputil.ReverseProxy is built per-request (cheap) so the Director closure
+// can capture the correct port without the cache key needing to include it.
+type proxyCacheEntry struct {
+ agentURL *url.URL
+ expiresAt time.Time
+}
+
+// proxyCacheKey is a fixed-size key from two UUIDs, avoids string allocation.
+type proxyCacheKey [32]byte
+
+func makeProxyCacheKey(sandboxID, teamID pgtype.UUID) proxyCacheKey {
+ var k proxyCacheKey
+ copy(k[:16], sandboxID.Bytes[:])
+ copy(k[16:], teamID.Bytes[:])
+ return k
+}
+
+// SandboxProxyWrapper wraps an existing HTTP handler and intercepts requests
+// whose Host header matches the {port}-{sandbox_id}.{domain} pattern. Matching
+// requests are reverse-proxied through the host agent that owns the sandbox.
+// All other requests are passed through to the inner handler.
+//
+// Authentication is via X-API-Key header only (no JWT). The API key's team
+// must own the sandbox.
+type SandboxProxyWrapper struct {
+ inner http.Handler
+ db *db.Queries
+ pool *lifecycle.HostClientPool
+ transport http.RoundTripper
+
+ cacheMu sync.Mutex
+ cache map[proxyCacheKey]proxyCacheEntry
+}
+
+// NewSandboxProxyWrapper creates a new proxy wrapper.
+func NewSandboxProxyWrapper(inner http.Handler, queries *db.Queries, pool *lifecycle.HostClientPool) *SandboxProxyWrapper {
+ return &SandboxProxyWrapper{
+ inner: inner,
+ db: queries,
+ pool: pool,
+ transport: pool.Transport(),
+ cache: make(map[proxyCacheKey]proxyCacheEntry),
+ }
+}
+
+// proxyTarget looks up the cached agent URL for (sandboxID, teamID).
+// On a miss it queries the DB, resolves the address, and populates the cache.
+// The *httputil.ReverseProxy is built by the caller so the Director closure
+// captures the correct port without the cache key needing to include it.
+func (h *SandboxProxyWrapper) proxyTarget(ctx context.Context, sandboxID, teamID pgtype.UUID) (*url.URL, error) {
+ cacheKey := makeProxyCacheKey(sandboxID, teamID)
+
+ h.cacheMu.Lock()
+ entry, ok := h.cache[cacheKey]
+ h.cacheMu.Unlock()
+
+ if ok && time.Now().Before(entry.expiresAt) {
+ return entry.agentURL, nil
+ }
+
+ // Cache miss or expired — query DB.
+ target, err := h.db.GetSandboxProxyTarget(ctx, db.GetSandboxProxyTargetParams{
+ ID: sandboxID,
+ TeamID: teamID,
+ })
+ if err != nil {
+ return nil, errProxySandboxNotFound
+ }
+ if target.Status != "running" {
+ return nil, errProxySandboxNotRunning{status: target.Status}
+ }
+ if target.HostAddress == "" {
+ return nil, errProxyNoHostAddress
+ }
+
+ agentURL, err := url.Parse(h.pool.ResolveAddr(target.HostAddress))
+ if err != nil {
+ return nil, fmt.Errorf("invalid host agent address: %w", err)
+ }
+
+ h.cacheMu.Lock()
+ h.cache[cacheKey] = proxyCacheEntry{
+ agentURL: agentURL,
+ expiresAt: time.Now().Add(proxyCacheTTL),
+ }
+ h.cacheMu.Unlock()
+
+ return agentURL, nil
+}
+
+// evictProxyCache removes the cached entry for a (sandbox, team) pair.
+// Called on 502 so a stopped/moved sandbox is re-resolved on the next request.
+func (h *SandboxProxyWrapper) evictProxyCache(sandboxID, teamID pgtype.UUID) {
+ h.cacheMu.Lock()
+ delete(h.cache, makeProxyCacheKey(sandboxID, teamID))
+ h.cacheMu.Unlock()
+}
+
+func (h *SandboxProxyWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ host := r.Host
+ // Strip port from Host header (e.g. "49999-cl-abcd1234.localhost:8080" → "49999-cl-abcd1234.localhost")
+ if colonIdx := strings.LastIndex(host, ":"); colonIdx != -1 {
+ host = host[:colonIdx]
+ }
+
+ matches := sandboxHostPattern.FindStringSubmatch(host)
+ if matches == nil {
+ h.inner.ServeHTTP(w, r)
+ return
+ }
+
+ port := matches[1]
+ sandboxIDStr := matches[2]
+
+ // Validate port.
+ portNum, err := strconv.Atoi(port)
+ if err != nil || portNum < 1 || portNum > 65535 {
+ http.Error(w, "invalid port", http.StatusBadRequest)
+ return
+ }
+
+ // Authenticate: require an X-API-Key header (JWT is not accepted) and resolve the key's team ID.
+ teamID, err := h.authenticateRequest(r)
+ if err != nil {
+ writeError(w, http.StatusUnauthorized, "unauthorized", err.Error())
+ return
+ }
+
+ sandboxID, err := id.ParseSandboxID(sandboxIDStr)
+ if err != nil {
+ http.Error(w, "invalid sandbox ID", http.StatusBadRequest)
+ return
+ }
+
+ agentURL, err := h.proxyTarget(r.Context(), sandboxID, teamID)
+ if err != nil {
+ switch {
+ case errors.Is(err, errProxySandboxNotFound):
+ http.Error(w, err.Error(), http.StatusNotFound)
+ case errors.As(err, new(errProxySandboxNotRunning)):
+ http.Error(w, err.Error(), http.StatusConflict)
+ default:
+ http.Error(w, err.Error(), http.StatusServiceUnavailable)
+ }
+ return
+ }
+
+ proxy := &httputil.ReverseProxy{
+ Transport: h.transport,
+ Director: func(req *http.Request) {
+ req.URL.Scheme = agentURL.Scheme
+ req.URL.Host = agentURL.Host
+ req.URL.Path = "/proxy/" + sandboxIDStr + "/" + port + req.URL.Path
+ req.Host = agentURL.Host
+ },
+ ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
+ slog.Debug("sandbox proxy error",
+ "sandbox_id", sandboxIDStr,
+ "port", port,
+ "error", err,
+ )
+ h.evictProxyCache(sandboxID, teamID)
+ http.Error(w, "proxy error: "+err.Error(), http.StatusBadGateway)
+ },
+ }
+ proxy.ServeHTTP(w, r)
+}
+
+// authenticateRequest validates the request's API key and returns the team ID.
+// Only API key authentication is supported for sandbox proxy requests (not JWT).
+func (h *SandboxProxyWrapper) authenticateRequest(r *http.Request) (pgtype.UUID, error) {
+ key := r.Header.Get("X-API-Key")
+ if key == "" {
+ return pgtype.UUID{}, fmt.Errorf("X-API-Key header required")
+ }
+
+ hash := auth.HashAPIKey(key)
+ row, err := h.db.GetAPIKeyByHash(r.Context(), hash)
+ if err != nil {
+ return pgtype.UUID{}, fmt.Errorf("invalid API key")
+ }
+ return row.TeamID, nil
+}
diff --git a/internal/api/handlers_apikeys.go b/internal/api/handlers_apikeys.go
index 2637181d..700ddc52 100644
--- a/internal/api/handlers_apikeys.go
+++ b/internal/api/handlers_apikeys.go
@@ -9,6 +9,7 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/audit"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
+ "git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/service"
)
@@ -39,11 +40,11 @@ type apiKeyResponse struct {
func apiKeyToResponse(k db.TeamApiKey) apiKeyResponse {
resp := apiKeyResponse{
- ID: k.ID,
- TeamID: k.TeamID,
+ ID: id.FormatAPIKeyID(k.ID),
+ TeamID: id.FormatTeamID(k.TeamID),
Name: k.Name,
KeyPrefix: k.KeyPrefix,
- CreatedBy: k.CreatedBy,
+ CreatedBy: id.FormatUserID(k.CreatedBy),
}
if k.CreatedAt.Valid {
resp.CreatedAt = k.CreatedAt.Time.Format(time.RFC3339)
@@ -57,11 +58,11 @@ func apiKeyToResponse(k db.TeamApiKey) apiKeyResponse {
func apiKeyWithCreatorToResponse(k db.ListAPIKeysByTeamWithCreatorRow) apiKeyResponse {
resp := apiKeyResponse{
- ID: k.ID,
- TeamID: k.TeamID,
+ ID: id.FormatAPIKeyID(k.ID),
+ TeamID: id.FormatTeamID(k.TeamID),
Name: k.Name,
KeyPrefix: k.KeyPrefix,
- CreatedBy: k.CreatedBy,
+ CreatedBy: id.FormatUserID(k.CreatedBy),
CreatorEmail: k.CreatorEmail,
}
if k.CreatedAt.Valid {
@@ -118,7 +119,13 @@ func (h *apiKeyHandler) List(w http.ResponseWriter, r *http.Request) {
// Delete handles DELETE /v1/api-keys/{id}.
func (h *apiKeyHandler) Delete(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
- keyID := chi.URLParam(r, "id")
+ keyIDStr := chi.URLParam(r, "id")
+
+ keyID, err := id.ParseAPIKeyID(keyIDStr)
+ if err != nil {
+ writeError(w, http.StatusBadRequest, "invalid_request", "invalid API key ID")
+ return
+ }
if err := h.svc.Delete(r.Context(), keyID, ac.TeamID); err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to delete API key")
diff --git a/internal/api/handlers_audit.go b/internal/api/handlers_audit.go
index 7812309d..a19ab1dc 100644
--- a/internal/api/handlers_audit.go
+++ b/internal/api/handlers_audit.go
@@ -6,7 +6,10 @@ import (
"strings"
"time"
+ "github.com/jackc/pgx/v5/pgtype"
+
"git.omukk.dev/wrenn/sandbox/internal/auth"
+ "git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/service"
)
@@ -65,13 +68,24 @@ func (h *auditHandler) List(w http.ResponseWriter, r *http.Request) {
limit = n
}
+ // Parse ?before_id cursor (UUID).
+ var beforeID pgtype.UUID
+ if s := r.URL.Query().Get("before_id"); s != "" {
+ parsed, err := id.ParseAuditLogID(s)
+ if err != nil {
+ writeError(w, http.StatusBadRequest, "invalid_request", "before_id must be a valid audit log ID")
+ return
+ }
+ beforeID = parsed
+ }
+
entries, err := h.svc.List(r.Context(), service.AuditListParams{
TeamID: ac.TeamID,
AdminScoped: ac.Role == "owner" || ac.Role == "admin",
ResourceTypes: parseMultiParam(r.URL.Query()["resource_type"]),
Actions: parseMultiParam(r.URL.Query()["action"]),
Before: before,
- BeforeID: r.URL.Query().Get("before_id"),
+ BeforeID: beforeID,
Limit: limit,
})
if err != nil {
diff --git a/internal/api/handlers_auth.go b/internal/api/handlers_auth.go
index ba60d8e1..b1d4915f 100644
--- a/internal/api/handlers_auth.go
+++ b/internal/api/handlers_auth.go
@@ -20,7 +20,7 @@ import (
// It prefers the user's default team; if none is flagged as default it falls
// back to the earliest-joined team. Returns pgx.ErrNoRows when the user has
// no team memberships at all.
-func loginTeam(ctx context.Context, q *db.Queries, userID string) (db.Team, string, error) {
+func loginTeam(ctx context.Context, q *db.Queries, userID pgtype.UUID) (db.Team, string, error) {
team, err := q.GetDefaultTeamForUser(ctx, userID)
if err == nil {
membership, err := q.GetTeamMembership(ctx, db.GetTeamMembershipParams{UserID: userID, TeamID: team.ID})
@@ -176,8 +176,8 @@ func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusCreated, authResponse{
Token: token,
- UserID: userID,
- TeamID: teamID,
+ UserID: id.FormatUserID(userID),
+ TeamID: id.FormatTeamID(teamID),
Email: req.Email,
Name: req.Name,
})
@@ -236,8 +236,8 @@ func (h *authHandler) Login(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, authResponse{
Token: token,
- UserID: user.ID,
- TeamID: team.ID,
+ UserID: id.FormatUserID(user.ID),
+ TeamID: id.FormatTeamID(team.ID),
Email: user.Email,
Name: user.Name,
})
@@ -260,10 +260,16 @@ func (h *authHandler) SwitchTeam(w http.ResponseWriter, r *http.Request) {
return
}
+ teamID, err := id.ParseTeamID(req.TeamID)
+ if err != nil {
+ writeError(w, http.StatusBadRequest, "invalid_request", "invalid team_id")
+ return
+ }
+
ctx := r.Context()
// Verify team exists and is not deleted.
- team, err := h.db.GetTeam(ctx, req.TeamID)
+ team, err := h.db.GetTeam(ctx, teamID)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
writeError(w, http.StatusNotFound, "not_found", "team not found")
@@ -280,7 +286,7 @@ func (h *authHandler) SwitchTeam(w http.ResponseWriter, r *http.Request) {
// Verify membership from DB — JWT role is not trusted here.
membership, err := h.db.GetTeamMembership(ctx, db.GetTeamMembershipParams{
UserID: ac.UserID,
- TeamID: req.TeamID,
+ TeamID: teamID,
})
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
@@ -298,7 +304,7 @@ func (h *authHandler) SwitchTeam(w http.ResponseWriter, r *http.Request) {
return
}
- token, err := auth.SignJWT(h.jwtSecret, ac.UserID, req.TeamID, ac.Email, user.Name, membership.Role, user.IsAdmin)
+ token, err := auth.SignJWT(h.jwtSecret, ac.UserID, teamID, ac.Email, user.Name, membership.Role, user.IsAdmin)
if err != nil {
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
return
@@ -306,8 +312,8 @@ func (h *authHandler) SwitchTeam(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, authResponse{
Token: token,
- UserID: ac.UserID,
- TeamID: req.TeamID,
+ UserID: id.FormatUserID(ac.UserID),
+ TeamID: id.FormatTeamID(teamID),
Email: ac.Email,
Name: user.Name,
})
diff --git a/internal/api/handlers_builds.go b/internal/api/handlers_builds.go
new file mode 100644
index 00000000..282c3f48
--- /dev/null
+++ b/internal/api/handlers_builds.go
@@ -0,0 +1,276 @@
+package api
+
+import (
+ "encoding/json"
+ "fmt"
+ "log/slog"
+ "net/http"
+ "time"
+
+ "connectrpc.com/connect"
+ "github.com/go-chi/chi/v5"
+
+ "git.omukk.dev/wrenn/sandbox/internal/db"
+ "git.omukk.dev/wrenn/sandbox/internal/id"
+ "git.omukk.dev/wrenn/sandbox/internal/layout"
+ "git.omukk.dev/wrenn/sandbox/internal/lifecycle"
+ "git.omukk.dev/wrenn/sandbox/internal/service"
+ "git.omukk.dev/wrenn/sandbox/internal/validate"
+ pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
+)
+
+// buildHandler serves the admin build and template endpoints defined in this file.
+type buildHandler struct {
+ // svc backs the build endpoints (Create, List, Get, Cancel).
+ svc *service.BuildService
+ // db is queried directly by the template endpoints (ListTemplates, DeleteTemplate).
+ db *db.Queries
+ // pool supplies the per-host agent client DeleteTemplate uses to remove snapshots.
+ pool *lifecycle.HostClientPool
+}
+
+// newBuildHandler returns a buildHandler wired to the given build service,
+// database queries and host client pool.
+func newBuildHandler(svc *service.BuildService, db *db.Queries, pool *lifecycle.HostClientPool) *buildHandler {
+	handler := new(buildHandler)
+	handler.svc = svc
+	handler.db = db
+	handler.pool = pool
+	return handler
+}
+
+// createBuildRequest is the JSON body decoded by Create (POST /v1/admin/builds).
+type createBuildRequest struct {
+ // Name is required and must pass validate.SafeName.
+ Name string `json:"name"`
+ // BaseTemplate is forwarded to the build service unchanged.
+ BaseTemplate string `json:"base_template"`
+ // Recipe must contain at least one command.
+ Recipe []string `json:"recipe"`
+ // Healthcheck, VCPUs and MemoryMB are forwarded to the build service as
+ // received; Create itself applies no validation or defaults to them.
+ Healthcheck string `json:"healthcheck"`
+ VCPUs int32 `json:"vcpus"`
+ MemoryMB int32 `json:"memory_mb"`
+ // SkipPrePost is forwarded unchanged.
+ // NOTE(review): presumably skips pre/post build steps — confirm against
+ // service.BuildCreateParams.
+ SkipPrePost bool `json:"skip_pre_post"`
+}
+
+// buildResponse is the JSON representation of a template build returned by the
+// build endpoints. It is produced by buildToResponse.
+//
+// Pointer fields are optional: they stay nil (and are omitted from the JSON)
+// when the underlying value is empty or not valid.
+type buildResponse struct {
+ // ID is the build ID rendered by id.FormatBuildID.
+ ID string `json:"id"`
+ Name string `json:"name"`
+ BaseTemplate string `json:"base_template"`
+ // Recipe is copied from the stored build as raw JSON.
+ Recipe json.RawMessage `json:"recipe"`
+ Healthcheck *string `json:"healthcheck,omitempty"`
+ VCPUs int32 `json:"vcpus"`
+ MemoryMB int32 `json:"memory_mb"`
+ Status string `json:"status"`
+ CurrentStep int32 `json:"current_step"`
+ TotalSteps int32 `json:"total_steps"`
+ // Logs is copied from the stored build as raw JSON.
+ Logs json.RawMessage `json:"logs"`
+ Error *string `json:"error,omitempty"`
+ // SandboxID and HostID are the formatted IDs, set only when the stored
+ // values are valid.
+ SandboxID *string `json:"sandbox_id,omitempty"`
+ HostID *string `json:"host_id,omitempty"`
+ // CreatedAt is RFC 3339; it is left as an empty string (not omitted) when
+ // the stored timestamp is not valid.
+ CreatedAt string `json:"created_at"`
+ // StartedAt and CompletedAt are RFC 3339 timestamps.
+ StartedAt *string `json:"started_at,omitempty"`
+ CompletedAt *string `json:"completed_at,omitempty"`
+}
+
+// buildToResponse converts a db.TemplateBuild record into the buildResponse
+// shape returned by the API.
+//
+// Optional values (healthcheck, error, sandbox/host IDs, started/completed
+// timestamps) are only populated when present on the record; timestamps are
+// rendered as RFC 3339.
+func buildToResponse(b db.TemplateBuild) buildResponse {
+	var out buildResponse
+
+	// Fields that are always present.
+	out.ID = id.FormatBuildID(b.ID)
+	out.Name = b.Name
+	out.BaseTemplate = b.BaseTemplate
+	out.Recipe = b.Recipe
+	out.VCPUs = b.Vcpus
+	out.MemoryMB = b.MemoryMb
+	out.Status = b.Status
+	out.CurrentStep = b.CurrentStep
+	out.TotalSteps = b.TotalSteps
+	out.Logs = b.Logs
+
+	// Optional strings: an empty value means "not set".
+	if b.Healthcheck != "" {
+		out.Healthcheck = &b.Healthcheck
+	}
+	if b.Error != "" {
+		out.Error = &b.Error
+	}
+
+	// Optional IDs: formatted only when the stored UUID is valid.
+	if b.SandboxID.Valid {
+		sandbox := id.FormatSandboxID(b.SandboxID)
+		out.SandboxID = &sandbox
+	}
+	if b.HostID.Valid {
+		host := id.FormatHostID(b.HostID)
+		out.HostID = &host
+	}
+
+	// Timestamps.
+	if b.CreatedAt.Valid {
+		out.CreatedAt = b.CreatedAt.Time.Format(time.RFC3339)
+	}
+	if b.StartedAt.Valid {
+		started := b.StartedAt.Time.Format(time.RFC3339)
+		out.StartedAt = &started
+	}
+	if b.CompletedAt.Valid {
+		completed := b.CompletedAt.Time.Format(time.RFC3339)
+		out.CompletedAt = &completed
+	}
+
+	return out
+}
+
+// Create handles POST /v1/admin/builds.
+//
+// The JSON body is decoded into createBuildRequest and checked (name present
+// and safe, recipe non-empty) before being handed to the build service. On
+// success the new build is returned with 201 Created; validation problems
+// yield 400 and a service failure yields 500.
+func (h *buildHandler) Create(w http.ResponseWriter, r *http.Request) {
+	var body createBuildRequest
+	if decodeErr := json.NewDecoder(r.Body).Decode(&body); decodeErr != nil {
+		writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
+		return
+	}
+
+	// Rules are checked in a fixed order; the first one that fails decides the message.
+	var problem string
+	if body.Name == "" {
+		problem = "name is required"
+	} else if nameErr := validate.SafeName(body.Name); nameErr != nil {
+		problem = fmt.Sprintf("invalid template name: %s", nameErr)
+	} else if len(body.Recipe) == 0 {
+		problem = "recipe must contain at least one command"
+	}
+	if problem != "" {
+		writeError(w, http.StatusBadRequest, "invalid_request", problem)
+		return
+	}
+
+	params := service.BuildCreateParams{
+		Name:         body.Name,
+		BaseTemplate: body.BaseTemplate,
+		Recipe:       body.Recipe,
+		Healthcheck:  body.Healthcheck,
+		VCPUs:        body.VCPUs,
+		MemoryMB:     body.MemoryMB,
+		SkipPrePost:  body.SkipPrePost,
+	}
+	created, createErr := h.svc.Create(r.Context(), params)
+	if createErr != nil {
+		// The underlying cause is logged; the client only sees a generic message.
+		slog.Error("failed to create build", "error", createErr)
+		writeError(w, http.StatusInternalServerError, "build_error", "failed to create build")
+		return
+	}
+
+	writeJSON(w, http.StatusCreated, buildToResponse(created))
+}
+
+// List handles GET /v1/admin/builds.
+//
+// Every build returned by the build service is converted with buildToResponse
+// and written as a JSON array; a service failure yields 500.
+func (h *buildHandler) List(w http.ResponseWriter, r *http.Request) {
+	all, listErr := h.svc.List(r.Context())
+	if listErr != nil {
+		writeError(w, http.StatusInternalServerError, "db_error", "failed to list builds")
+		return
+	}
+
+	// Start from a non-nil slice so an empty result still encodes as [].
+	out := make([]buildResponse, 0, len(all))
+	for _, build := range all {
+		out = append(out, buildToResponse(build))
+	}
+
+	writeJSON(w, http.StatusOK, out)
+}
+
+// Get handles GET /v1/admin/builds/{id}.
+//
+// A malformed ID yields 400. Any error from the build service is reported as
+// 404 "build not found".
+func (h *buildHandler) Get(w http.ResponseWriter, r *http.Request) {
+	buildID, parseErr := id.ParseBuildID(chi.URLParam(r, "id"))
+	if parseErr != nil {
+		writeError(w, http.StatusBadRequest, "invalid_request", "invalid build ID")
+		return
+	}
+
+	found, getErr := h.svc.Get(r.Context(), buildID)
+	if getErr != nil {
+		writeError(w, http.StatusNotFound, "not_found", "build not found")
+		return
+	}
+
+	writeJSON(w, http.StatusOK, buildToResponse(found))
+}
+
+// ListTemplates handles GET /v1/admin/templates — returns all templates across all teams.
+//
+// Each template is reduced to a small JSON object; created_at is RFC 3339 and
+// is left empty when the stored timestamp is not valid.
+func (h *buildHandler) ListTemplates(w http.ResponseWriter, r *http.Request) {
+	// Wire format of a single entry in the response array.
+	type templateItem struct {
+		Name      string `json:"name"`
+		Type      string `json:"type"`
+		VCPUs     int32  `json:"vcpus"`
+		MemoryMB  int32  `json:"memory_mb"`
+		SizeBytes int64  `json:"size_bytes"`
+		TeamID    string `json:"team_id"`
+		CreatedAt string `json:"created_at"`
+	}
+
+	rows, listErr := h.db.ListTemplates(r.Context())
+	if listErr != nil {
+		writeError(w, http.StatusInternalServerError, "db_error", "failed to list templates")
+		return
+	}
+
+	// Start from a non-nil slice so an empty result still encodes as [].
+	items := make([]templateItem, 0, len(rows))
+	for _, tpl := range rows {
+		item := templateItem{
+			Name:      tpl.Name,
+			Type:      tpl.Type,
+			VCPUs:     tpl.Vcpus,
+			MemoryMB:  tpl.MemoryMb,
+			SizeBytes: tpl.SizeBytes,
+			TeamID:    id.FormatTeamID(tpl.TeamID),
+		}
+		if tpl.CreatedAt.Valid {
+			item.CreatedAt = tpl.CreatedAt.Time.Format(time.RFC3339)
+		}
+		items = append(items, item)
+	}
+
+	writeJSON(w, http.StatusOK, items)
+}
+
+// DeleteTemplate handles DELETE /v1/admin/templates/{name}.
+//
+// The name is validated and resolved via GetPlatformTemplateByName; the minimal
+// template is protected and yields 403. Snapshot removal is then requested from
+// every host whose status is "online" on a best-effort basis — host-side
+// failures are logged but never abort the request — and finally the database
+// record is deleted. Responds 204 No Content on success.
+func (h *buildHandler) DeleteTemplate(w http.ResponseWriter, r *http.Request) {
+	name := chi.URLParam(r, "name")
+	if err := validate.SafeName(name); err != nil {
+		writeError(w, http.StatusBadRequest, "invalid_request", fmt.Sprintf("invalid template name: %s", err))
+		return
+	}
+	ctx := r.Context()
+
+	tmpl, err := h.db.GetPlatformTemplateByName(ctx, name)
+	if err != nil {
+		writeError(w, http.StatusNotFound, "not_found", "template not found")
+		return
+	}
+	if layout.IsMinimal(tmpl.TeamID, tmpl.ID) {
+		writeError(w, http.StatusForbidden, "forbidden", "the minimal template cannot be deleted")
+		return
+	}
+
+	// Broadcast delete to all online hosts. A failed host listing is logged
+	// instead of being discarded: when it fails no host is contacted, so the
+	// snapshot may be left behind on hosts after the record below is removed.
+	hosts, err := h.db.ListActiveHosts(ctx)
+	if err != nil {
+		slog.Warn("admin: failed to list hosts for template delete", "name", name, "error", err)
+	}
+	for _, host := range hosts {
+		if host.Status != "online" {
+			continue
+		}
+		agent, err := h.pool.GetForHost(host)
+		if err != nil {
+			// Still best-effort, but leave a trace of which host was skipped.
+			slog.Warn("admin: no agent client for host, skipping template delete", "host_id", id.FormatHostID(host.ID), "name", name, "error", err)
+			continue
+		}
+		if _, err := agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{
+			TeamId:     formatUUIDForRPC(tmpl.TeamID),
+			TemplateId: formatUUIDForRPC(tmpl.ID),
+		})); err != nil {
+			// NotFound means the host never had the snapshot; nothing to report.
+			if connect.CodeOf(err) != connect.CodeNotFound {
+				slog.Warn("admin: failed to delete template on host", "host_id", id.FormatHostID(host.ID), "name", name, "error", err)
+			}
+		}
+	}
+
+	if err := h.db.DeleteTemplate(ctx, tmpl.ID); err != nil {
+		writeError(w, http.StatusInternalServerError, "db_error", "failed to delete template record")
+		return
+	}
+
+	w.WriteHeader(http.StatusNoContent)
+}
+
+// Cancel handles POST /v1/admin/builds/{id}/cancel.
+//
+// A malformed ID yields 400. An error from the build service is also reported
+// as 400, with the service's error text as the message. Responds 204 No Content
+// on success.
+func (h *buildHandler) Cancel(w http.ResponseWriter, r *http.Request) {
+	buildID, parseErr := id.ParseBuildID(chi.URLParam(r, "id"))
+	if parseErr != nil {
+		writeError(w, http.StatusBadRequest, "invalid_request", "invalid build ID")
+		return
+	}
+
+	cancelErr := h.svc.Cancel(r.Context(), buildID)
+	if cancelErr != nil {
+		writeError(w, http.StatusBadRequest, "invalid_request", cancelErr.Error())
+		return
+	}
+
+	w.WriteHeader(http.StatusNoContent)
+}
diff --git a/internal/api/handlers_exec.go b/internal/api/handlers_exec.go
index 84b38330..596457b2 100644
--- a/internal/api/handlers_exec.go
+++ b/internal/api/handlers_exec.go
@@ -14,6 +14,7 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
+ "git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
)
@@ -46,10 +47,16 @@ type execResponse struct {
// Exec handles POST /v1/sandboxes/{id}/exec.
func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
- sandboxID := chi.URLParam(r, "id")
+ sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
+ sandboxID, err := id.ParseSandboxID(sandboxIDStr)
+ if err != nil {
+ writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
+ return
+ }
+
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
@@ -80,7 +87,7 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
}
resp, err := agent.Exec(ctx, connect.NewRequest(&pb.ExecRequest{
- SandboxId: sandboxID,
+ SandboxId: sandboxIDStr,
Cmd: req.Cmd,
Args: req.Args,
TimeoutSec: req.TimeoutSec,
@@ -101,7 +108,7 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
Valid: true,
},
}); err != nil {
- slog.Warn("failed to update last_active_at", "id", sandboxID, "error", err)
+ slog.Warn("failed to update last_active_at", "id", sandboxIDStr, "error", err)
}
// Use base64 encoding if output contains non-UTF-8 bytes.
@@ -112,7 +119,7 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
if !utf8.Valid(stdout) || !utf8.Valid(stderr) {
encoding = "base64"
writeJSON(w, http.StatusOK, execResponse{
- SandboxID: sandboxID,
+ SandboxID: sandboxIDStr,
Cmd: req.Cmd,
Stdout: base64.StdEncoding.EncodeToString(stdout),
Stderr: base64.StdEncoding.EncodeToString(stderr),
@@ -124,7 +131,7 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
}
writeJSON(w, http.StatusOK, execResponse{
- SandboxID: sandboxID,
+ SandboxID: sandboxIDStr,
Cmd: req.Cmd,
Stdout: string(stdout),
Stderr: string(stderr),
diff --git a/internal/api/handlers_exec_stream.go b/internal/api/handlers_exec_stream.go
index 3ecfdfe7..52dfd17e 100644
--- a/internal/api/handlers_exec_stream.go
+++ b/internal/api/handlers_exec_stream.go
@@ -14,6 +14,7 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
+ "git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
)
@@ -48,10 +49,16 @@ type wsOutMsg struct {
// ExecStream handles WS /v1/sandboxes/{id}/exec/stream.
func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
- sandboxID := chi.URLParam(r, "id")
+ sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
+ sandboxID, err := id.ParseSandboxID(sandboxIDStr)
+ if err != nil {
+ writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
+ return
+ }
+
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
@@ -91,7 +98,7 @@ func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
defer cancel()
stream, err := agent.ExecStream(streamCtx, connect.NewRequest(&pb.ExecStreamRequest{
- SandboxId: sandboxID,
+ SandboxId: sandboxIDStr,
Cmd: startMsg.Cmd,
Args: startMsg.Args,
}))
@@ -157,7 +164,7 @@ func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
Valid: true,
},
}); err != nil {
- slog.Warn("failed to update last active after stream exec", "sandbox_id", sandboxID, "error", err)
+ slog.Warn("failed to update last active after stream exec", "sandbox_id", sandboxIDStr, "error", err)
}
}
diff --git a/internal/api/handlers_files.go b/internal/api/handlers_files.go
index c5fff70d..a2e9936c 100644
--- a/internal/api/handlers_files.go
+++ b/internal/api/handlers_files.go
@@ -11,6 +11,7 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
+ "git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
)
@@ -29,10 +30,16 @@ func newFilesHandler(db *db.Queries, pool *lifecycle.HostClientPool) *filesHandl
// - "path" text field: absolute destination path inside the sandbox
// - "file" file field: binary content to write
func (h *filesHandler) Upload(w http.ResponseWriter, r *http.Request) {
- sandboxID := chi.URLParam(r, "id")
+ sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
+ sandboxID, err := id.ParseSandboxID(sandboxIDStr)
+ if err != nil {
+ writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
+ return
+ }
+
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
@@ -82,7 +89,7 @@ func (h *filesHandler) Upload(w http.ResponseWriter, r *http.Request) {
}
if _, err := agent.WriteFile(ctx, connect.NewRequest(&pb.WriteFileRequest{
- SandboxId: sandboxID,
+ SandboxId: sandboxIDStr,
Path: filePath,
Content: content,
})); err != nil {
@@ -101,10 +108,16 @@ type readFileRequest struct {
// Download handles POST /v1/sandboxes/{id}/files/read.
// Accepts JSON body with path, returns raw file content with Content-Disposition.
func (h *filesHandler) Download(w http.ResponseWriter, r *http.Request) {
- sandboxID := chi.URLParam(r, "id")
+ sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
+ sandboxID, err := id.ParseSandboxID(sandboxIDStr)
+ if err != nil {
+ writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
+ return
+ }
+
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
@@ -133,7 +146,7 @@ func (h *filesHandler) Download(w http.ResponseWriter, r *http.Request) {
}
resp, err := agent.ReadFile(ctx, connect.NewRequest(&pb.ReadFileRequest{
- SandboxId: sandboxID,
+ SandboxId: sandboxIDStr,
Path: req.Path,
}))
if err != nil {
diff --git a/internal/api/handlers_files_stream.go b/internal/api/handlers_files_stream.go
index 66e89c73..e6c040f2 100644
--- a/internal/api/handlers_files_stream.go
+++ b/internal/api/handlers_files_stream.go
@@ -12,6 +12,7 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
+ "git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
)
@@ -29,10 +30,16 @@ func newFilesStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool) *file
// Expects multipart/form-data with "path" text field and "file" file field.
// Streams file content directly from the request body to the host agent without buffering.
func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request) {
- sandboxID := chi.URLParam(r, "id")
+ sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
+ sandboxID, err := id.ParseSandboxID(sandboxIDStr)
+ if err != nil {
+ writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
+ return
+ }
+
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
@@ -101,7 +108,7 @@ func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request
if err := stream.Send(&pb.WriteFileStreamRequest{
Content: &pb.WriteFileStreamRequest_Meta{
Meta: &pb.WriteFileStreamMeta{
- SandboxId: sandboxID,
+ SandboxId: sandboxIDStr,
Path: filePath,
},
},
@@ -146,10 +153,16 @@ func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request
// StreamDownload handles POST /v1/sandboxes/{id}/files/stream/read.
// Accepts JSON body with path, streams file content back without buffering.
func (h *filesStreamHandler) StreamDownload(w http.ResponseWriter, r *http.Request) {
- sandboxID := chi.URLParam(r, "id")
+ sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
+ sandboxID, err := id.ParseSandboxID(sandboxIDStr)
+ if err != nil {
+ writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
+ return
+ }
+
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
@@ -178,7 +191,7 @@ func (h *filesStreamHandler) StreamDownload(w http.ResponseWriter, r *http.Reque
// Open server-streaming RPC to host agent.
stream, err := agent.ReadFileStream(ctx, connect.NewRequest(&pb.ReadFileStreamRequest{
- SandboxId: sandboxID,
+ SandboxId: sandboxIDStr,
Path: req.Path,
}))
if err != nil {
diff --git a/internal/api/handlers_hosts.go b/internal/api/handlers_hosts.go
index f4f79173..50652a00 100644
--- a/internal/api/handlers_hosts.go
+++ b/internal/api/handlers_hosts.go
@@ -8,9 +8,12 @@ import (
"github.com/go-chi/chi/v5"
+ "github.com/jackc/pgx/v5/pgtype"
+
"git.omukk.dev/wrenn/sandbox/internal/audit"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
+ "git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/service"
)
@@ -46,6 +49,9 @@ type refreshTokenResponse struct {
Host hostResponse `json:"host"`
Token string `json:"token"`
RefreshToken string `json:"refresh_token"`
+ CertPEM string `json:"cert_pem,omitempty"`
+ KeyPEM string `json:"key_pem,omitempty"`
+ CACertPEM string `json:"ca_cert_pem,omitempty"`
}
type deletePreviewResponse struct {
@@ -66,6 +72,9 @@ type registerHostResponse struct {
Host hostResponse `json:"host"`
Token string `json:"token"`
RefreshToken string `json:"refresh_token"`
+ CertPEM string `json:"cert_pem,omitempty"`
+ KeyPEM string `json:"key_pem,omitempty"`
+ CACertPEM string `json:"ca_cert_pem,omitempty"`
}
type addTagRequest struct {
@@ -93,34 +102,35 @@ type hostResponse struct {
func hostToResponse(h db.Host) hostResponse {
resp := hostResponse{
- ID: h.ID,
+ ID: id.FormatHostID(h.ID),
Type: h.Type,
Status: h.Status,
- CreatedBy: h.CreatedBy,
+ CreatedBy: id.FormatUserID(h.CreatedBy),
}
if h.TeamID.Valid {
- resp.TeamID = &h.TeamID.String
+ s := id.FormatTeamID(h.TeamID)
+ resp.TeamID = &s
}
- if h.Provider.Valid {
- resp.Provider = &h.Provider.String
+ if h.Provider != "" {
+ resp.Provider = &h.Provider
}
- if h.AvailabilityZone.Valid {
- resp.AvailabilityZone = &h.AvailabilityZone.String
+ if h.AvailabilityZone != "" {
+ resp.AvailabilityZone = &h.AvailabilityZone
}
- if h.Arch.Valid {
- resp.Arch = &h.Arch.String
+ if h.Arch != "" {
+ resp.Arch = &h.Arch
}
- if h.CpuCores.Valid {
- resp.CPUCores = &h.CpuCores.Int32
+ if h.CpuCores != 0 {
+ resp.CPUCores = &h.CpuCores
}
- if h.MemoryMb.Valid {
- resp.MemoryMB = &h.MemoryMb.Int32
+ if h.MemoryMb != 0 {
+ resp.MemoryMB = &h.MemoryMb
}
- if h.DiskGb.Valid {
- resp.DiskGB = &h.DiskGb.Int32
+ if h.DiskGb != 0 {
+ resp.DiskGB = &h.DiskGb
}
- if h.Address.Valid {
- resp.Address = &h.Address.String
+ if h.Address != "" {
+ resp.Address = &h.Address
}
if h.LastHeartbeatAt.Valid {
s := h.LastHeartbeatAt.Time.Format(time.RFC3339)
@@ -133,7 +143,7 @@ func hostToResponse(h db.Host) hostResponse {
}
// isAdmin fetches the user record and returns whether they are an admin.
-func (h *hostHandler) isAdmin(r *http.Request, userID string) bool {
+func (h *hostHandler) isAdmin(r *http.Request, userID pgtype.UUID) bool {
user, err := h.queries.GetUserByID(r.Context(), userID)
if err != nil {
return false
@@ -151,14 +161,23 @@ func (h *hostHandler) Create(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
- result, err := h.svc.Create(r.Context(), service.HostCreateParams{
- Type: req.Type,
- TeamID: req.TeamID,
- Provider: req.Provider,
- AvailabilityZone: req.AvailabilityZone,
- RequestingUserID: ac.UserID,
- IsRequestorAdmin: h.isAdmin(r, ac.UserID),
- })
+ // Parse optional team ID from request body.
+ var params service.HostCreateParams
+ params.Type = req.Type
+ params.Provider = req.Provider
+ params.AvailabilityZone = req.AvailabilityZone
+ params.RequestingUserID = ac.UserID
+ params.IsRequestorAdmin = h.isAdmin(r, ac.UserID)
+ if req.TeamID != "" {
+ teamID, err := id.ParseTeamID(req.TeamID)
+ if err != nil {
+ writeError(w, http.StatusBadRequest, "invalid_request", "invalid team_id")
+ return
+ }
+ params.TeamID = teamID
+ }
+
+ result, err := h.svc.Create(r.Context(), params)
if err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
@@ -166,8 +185,7 @@ func (h *hostHandler) Create(w http.ResponseWriter, r *http.Request) {
}
// Log audit for the owning team (BYOC hosts have a team; shared hosts use caller's team).
- hostTeamID := result.Host.TeamID.String
- h.audit.LogHostCreate(r.Context(), ac, result.Host.ID, hostTeamID)
+ h.audit.LogHostCreate(r.Context(), ac, result.Host.ID, result.Host.TeamID)
writeJSON(w, http.StatusCreated, createHostResponse{
Host: hostToResponse(result.Host),
@@ -192,14 +210,22 @@ func (h *hostHandler) List(w http.ResponseWriter, r *http.Request) {
seen := make(map[string]struct{})
for _, host := range hosts {
if host.TeamID.Valid {
- seen[host.TeamID.String] = struct{}{}
+ key := id.FormatTeamID(host.TeamID)
+ seen[key] = struct{}{}
}
}
if len(seen) > 0 {
teamNames = make(map[string]string, len(seen))
- for id := range seen {
- if team, err := h.queries.GetTeam(r.Context(), id); err == nil {
- teamNames[id] = team.Name
+ for _, host := range hosts {
+ if !host.TeamID.Valid {
+ continue
+ }
+ key := id.FormatTeamID(host.TeamID)
+ if _, ok := teamNames[key]; ok {
+ continue
+ }
+ if team, err := h.queries.GetTeam(r.Context(), host.TeamID); err == nil {
+ teamNames[key] = team.Name
}
}
}
@@ -209,7 +235,8 @@ func (h *hostHandler) List(w http.ResponseWriter, r *http.Request) {
for i, host := range hosts {
resp[i] = hostToResponse(host)
if host.TeamID.Valid {
- if name, ok := teamNames[host.TeamID.String]; ok {
+ key := id.FormatTeamID(host.TeamID)
+ if name, ok := teamNames[key]; ok {
resp[i].TeamName = &name
}
}
@@ -220,9 +247,15 @@ func (h *hostHandler) List(w http.ResponseWriter, r *http.Request) {
// Get handles GET /v1/hosts/{id}.
func (h *hostHandler) Get(w http.ResponseWriter, r *http.Request) {
- hostID := chi.URLParam(r, "id")
+ hostIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
+ hostID, err := id.ParseHostID(hostIDStr)
+ if err != nil {
+ writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
+ return
+ }
+
host, err := h.svc.Get(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID))
if err != nil {
status, code, msg := serviceErrToHTTP(err)
@@ -236,9 +269,15 @@ func (h *hostHandler) Get(w http.ResponseWriter, r *http.Request) {
// DeletePreview handles GET /v1/hosts/{id}/delete-preview.
// Returns what would be affected without making changes, for confirmation UI.
func (h *hostHandler) DeletePreview(w http.ResponseWriter, r *http.Request) {
- hostID := chi.URLParam(r, "id")
+ hostIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
+ hostID, err := id.ParseHostID(hostIDStr)
+ if err != nil {
+ writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
+ return
+ }
+
preview, err := h.svc.DeletePreview(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID))
if err != nil {
status, code, msg := serviceErrToHTTP(err)
@@ -256,19 +295,25 @@ func (h *hostHandler) DeletePreview(w http.ResponseWriter, r *http.Request) {
// Without ?force=true: returns 409 with affected sandbox IDs if any are active.
// With ?force=true: gracefully stops all sandboxes then deletes the host.
func (h *hostHandler) Delete(w http.ResponseWriter, r *http.Request) {
- hostID := chi.URLParam(r, "id")
+ hostIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
force := r.URL.Query().Get("force") == "true"
+ hostID, err := id.ParseHostID(hostIDStr)
+ if err != nil {
+ writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
+ return
+ }
+
// Fetch host before deletion to capture team_id for audit.
deletedHost, hostErr := h.queries.GetHost(r.Context(), hostID)
if hostErr != nil {
- slog.Warn("audit: could not fetch host before delete", "host_id", hostID, "error", hostErr)
+ slog.Warn("audit: could not fetch host before delete", "host_id", hostIDStr, "error", hostErr)
}
- err := h.svc.Delete(r.Context(), hostID, ac.UserID, ac.TeamID, h.isAdmin(r, ac.UserID), force)
+ err = h.svc.Delete(r.Context(), hostID, ac.UserID, ac.TeamID, h.isAdmin(r, ac.UserID), force)
if err == nil {
- h.audit.LogHostDelete(r.Context(), ac, hostID, deletedHost.TeamID.String)
+ h.audit.LogHostDelete(r.Context(), ac, hostID, deletedHost.TeamID)
w.WriteHeader(http.StatusNoContent)
return
}
@@ -292,9 +337,15 @@ func (h *hostHandler) Delete(w http.ResponseWriter, r *http.Request) {
// RegenerateToken handles POST /v1/hosts/{id}/token.
func (h *hostHandler) RegenerateToken(w http.ResponseWriter, r *http.Request) {
- hostID := chi.URLParam(r, "id")
+ hostIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
+ hostID, err := id.ParseHostID(hostIDStr)
+ if err != nil {
+ writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
+ return
+ }
+
result, err := h.svc.RegenerateToken(r.Context(), hostID, ac.UserID, ac.TeamID, h.isAdmin(r, ac.UserID))
if err != nil {
status, code, msg := serviceErrToHTTP(err)
@@ -343,14 +394,23 @@ func (h *hostHandler) Register(w http.ResponseWriter, r *http.Request) {
Host: hostToResponse(result.Host),
Token: result.JWT,
RefreshToken: result.RefreshToken,
+ CertPEM: result.CertPEM,
+ KeyPEM: result.KeyPEM,
+ CACertPEM: result.CACertPEM,
})
}
// Heartbeat handles POST /v1/hosts/{id}/heartbeat (host-token-authenticated).
func (h *hostHandler) Heartbeat(w http.ResponseWriter, r *http.Request) {
- hostID := chi.URLParam(r, "id")
+ hostIDStr := chi.URLParam(r, "id")
hc := auth.MustHostFromContext(r.Context())
+ hostID, err := id.ParseHostID(hostIDStr)
+ if err != nil {
+ writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
+ return
+ }
+
// Prevent a host from heartbeating for a different host.
if hostID != hc.HostID {
writeError(w, http.StatusForbidden, "forbidden", "host ID mismatch")
@@ -368,7 +428,7 @@ func (h *hostHandler) Heartbeat(w http.ResponseWriter, r *http.Request) {
// Log marked_up if the host just recovered from unreachable.
if prevHost.Status == "unreachable" {
- h.audit.LogHostMarkedUp(r.Context(), prevHost.TeamID.String, hc.HostID)
+ h.audit.LogHostMarkedUp(r.Context(), prevHost.TeamID, hc.HostID)
}
w.WriteHeader(http.StatusNoContent)
@@ -376,10 +436,16 @@ func (h *hostHandler) Heartbeat(w http.ResponseWriter, r *http.Request) {
// AddTag handles POST /v1/hosts/{id}/tags.
func (h *hostHandler) AddTag(w http.ResponseWriter, r *http.Request) {
- hostID := chi.URLParam(r, "id")
+ hostIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
admin := h.isAdmin(r, ac.UserID)
+ hostID, err := id.ParseHostID(hostIDStr)
+ if err != nil {
+ writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
+ return
+ }
+
var req addTagRequest
if err := decodeJSON(r, &req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
@@ -401,10 +467,16 @@ func (h *hostHandler) AddTag(w http.ResponseWriter, r *http.Request) {
// RemoveTag handles DELETE /v1/hosts/{id}/tags/{tag}.
func (h *hostHandler) RemoveTag(w http.ResponseWriter, r *http.Request) {
- hostID := chi.URLParam(r, "id")
+ hostIDStr := chi.URLParam(r, "id")
tag := chi.URLParam(r, "tag")
ac := auth.MustFromContext(r.Context())
+ hostID, err := id.ParseHostID(hostIDStr)
+ if err != nil {
+ writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
+ return
+ }
+
if err := h.svc.RemoveTag(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID), tag); err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
@@ -438,14 +510,23 @@ func (h *hostHandler) RefreshToken(w http.ResponseWriter, r *http.Request) {
Host: hostToResponse(result.Host),
Token: result.JWT,
RefreshToken: result.RefreshToken,
+ CertPEM: result.CertPEM,
+ KeyPEM: result.KeyPEM,
+ CACertPEM: result.CACertPEM,
})
}
// ListTags handles GET /v1/hosts/{id}/tags.
func (h *hostHandler) ListTags(w http.ResponseWriter, r *http.Request) {
- hostID := chi.URLParam(r, "id")
+ hostIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
+ hostID, err := id.ParseHostID(hostIDStr)
+ if err != nil {
+ writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
+ return
+ }
+
tags, err := h.svc.ListTags(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID))
if err != nil {
status, code, msg := serviceErrToHTTP(err)
diff --git a/internal/api/handlers_metrics.go b/internal/api/handlers_metrics.go
index 793349e5..25f485c6 100644
--- a/internal/api/handlers_metrics.go
+++ b/internal/api/handlers_metrics.go
@@ -7,9 +7,11 @@ import (
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
+ "github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
+ "git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
)
@@ -38,10 +40,16 @@ type metricsResponse struct {
// GetMetrics handles GET /v1/sandboxes/{id}/metrics?range=10m|2h|24h.
func (h *sandboxMetricsHandler) GetMetrics(w http.ResponseWriter, r *http.Request) {
- sandboxID := chi.URLParam(r, "id")
+ sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
+ sandboxID, err := id.ParseSandboxID(sandboxIDStr)
+ if err != nil {
+ writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
+ return
+ }
+
rangeTier := r.URL.Query().Get("range")
if rangeTier == "" {
rangeTier = "10m"
@@ -60,15 +68,15 @@ func (h *sandboxMetricsHandler) GetMetrics(w http.ResponseWriter, r *http.Reques
switch sb.Status {
case "running":
- h.getFromAgent(w, r, sandboxID, rangeTier, sb.HostID)
+ h.getFromAgent(w, r, sandboxIDStr, rangeTier, sb.HostID)
case "paused":
- h.getFromDB(ctx, w, sandboxID, rangeTier)
+ h.getFromDB(ctx, w, sandboxIDStr, sandboxID, rangeTier)
default:
writeError(w, http.StatusNotFound, "not_found", "metrics not available for sandbox in state: "+sb.Status)
}
}
-func (h *sandboxMetricsHandler) getFromAgent(w http.ResponseWriter, r *http.Request, sandboxID, rangeTier, hostID string) {
+func (h *sandboxMetricsHandler) getFromAgent(w http.ResponseWriter, r *http.Request, sandboxIDStr, rangeTier string, hostID pgtype.UUID) {
ctx := r.Context()
agent, err := agentForHost(ctx, h.db, h.pool, hostID)
@@ -78,7 +86,7 @@ func (h *sandboxMetricsHandler) getFromAgent(w http.ResponseWriter, r *http.Requ
}
resp, err := agent.GetSandboxMetrics(ctx, connect.NewRequest(&pb.GetSandboxMetricsRequest{
- SandboxId: sandboxID,
+ SandboxId: sandboxIDStr,
Range: rangeTier,
}))
if err != nil {
@@ -98,7 +106,7 @@ func (h *sandboxMetricsHandler) getFromAgent(w http.ResponseWriter, r *http.Requ
}
writeJSON(w, http.StatusOK, metricsResponse{
- SandboxID: sandboxID,
+ SandboxID: sandboxIDStr,
Range: rangeTier,
Points: points,
})
@@ -118,7 +126,7 @@ var rangeToDB = map[string]struct {
"24h": {"24h", 24 * time.Hour},
}
-func (h *sandboxMetricsHandler) getFromDB(ctx context.Context, w http.ResponseWriter, sandboxID, rangeTier string) {
+func (h *sandboxMetricsHandler) getFromDB(ctx context.Context, w http.ResponseWriter, sandboxIDStr string, sandboxID pgtype.UUID, rangeTier string) {
mapping := rangeToDB[rangeTier]
rows, err := h.db.GetSandboxMetricPoints(ctx, db.GetSandboxMetricPointsParams{
SandboxID: sandboxID,
@@ -141,7 +149,7 @@ func (h *sandboxMetricsHandler) getFromDB(ctx context.Context, w http.ResponseWr
}
writeJSON(w, http.StatusOK, metricsResponse{
- SandboxID: sandboxID,
+ SandboxID: sandboxIDStr,
Range: rangeTier,
Points: points,
})
diff --git a/internal/api/handlers_oauth.go b/internal/api/handlers_oauth.go
index 348dd859..a9c448ac 100644
--- a/internal/api/handlers_oauth.go
+++ b/internal/api/handlers_oauth.go
@@ -162,7 +162,7 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
redirectWithError(w, r, redirectBase, "internal_error")
return
}
- redirectWithToken(w, r, redirectBase, token, user.ID, team.ID, user.Email, user.Name)
+ redirectWithToken(w, r, redirectBase, token, id.FormatUserID(user.ID), id.FormatTeamID(team.ID), user.Email, user.Name)
return
}
if !errors.Is(err, pgx.ErrNoRows) {
@@ -262,7 +262,7 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
return
}
- redirectWithToken(w, r, redirectBase, token, userID, teamID, email, profile.Name)
+ redirectWithToken(w, r, redirectBase, token, id.FormatUserID(userID), id.FormatTeamID(teamID), email, profile.Name)
}
// retryAsLogin handles the race where a concurrent request already created the user.
@@ -296,7 +296,7 @@ func (h *oauthHandler) retryAsLogin(w http.ResponseWriter, r *http.Request, prov
redirectWithError(w, r, redirectBase, "internal_error")
return
}
- redirectWithToken(w, r, redirectBase, token, user.ID, team.ID, user.Email, user.Name)
+ redirectWithToken(w, r, redirectBase, token, id.FormatUserID(user.ID), id.FormatTeamID(team.ID), user.Email, user.Name)
}
func redirectWithToken(w http.ResponseWriter, r *http.Request, base, token, userID, teamID, email, name string) {
diff --git a/internal/api/handlers_sandbox.go b/internal/api/handlers_sandbox.go
index b2709a5d..a19a7cc3 100644
--- a/internal/api/handlers_sandbox.go
+++ b/internal/api/handlers_sandbox.go
@@ -10,6 +10,7 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/audit"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
+ "git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/service"
)
@@ -46,7 +47,7 @@ type sandboxResponse struct {
func sandboxToResponse(sb db.Sandbox) sandboxResponse {
resp := sandboxResponse{
- ID: sb.ID,
+ ID: id.FormatSandboxID(sb.ID),
Status: sb.Status,
Template: sb.Template,
VCPUs: sb.Vcpus,
@@ -81,7 +82,7 @@ func (h *sandboxHandler) Create(w http.ResponseWriter, r *http.Request) {
}
ac := auth.MustFromContext(r.Context())
- if ac.TeamID == "" {
+ if !ac.TeamID.Valid {
writeError(w, http.StatusForbidden, "no_team", "no active team context; re-authenticate")
return
}
@@ -122,9 +123,15 @@ func (h *sandboxHandler) List(w http.ResponseWriter, r *http.Request) {
// Get handles GET /v1/sandboxes/{id}.
func (h *sandboxHandler) Get(w http.ResponseWriter, r *http.Request) {
- sandboxID := chi.URLParam(r, "id")
+ sandboxIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
+ sandboxID, err := id.ParseSandboxID(sandboxIDStr)
+ if err != nil {
+ writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
+ return
+ }
+
sb, err := h.svc.Get(r.Context(), sandboxID, ac.TeamID)
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
@@ -136,9 +143,15 @@ func (h *sandboxHandler) Get(w http.ResponseWriter, r *http.Request) {
// Pause handles POST /v1/sandboxes/{id}/pause.
func (h *sandboxHandler) Pause(w http.ResponseWriter, r *http.Request) {
- sandboxID := chi.URLParam(r, "id")
+ sandboxIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
+ sandboxID, err := id.ParseSandboxID(sandboxIDStr)
+ if err != nil {
+ writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
+ return
+ }
+
sb, err := h.svc.Pause(r.Context(), sandboxID, ac.TeamID)
if err != nil {
status, code, msg := serviceErrToHTTP(err)
@@ -152,9 +165,15 @@ func (h *sandboxHandler) Pause(w http.ResponseWriter, r *http.Request) {
// Resume handles POST /v1/sandboxes/{id}/resume.
func (h *sandboxHandler) Resume(w http.ResponseWriter, r *http.Request) {
- sandboxID := chi.URLParam(r, "id")
+ sandboxIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
+ sandboxID, err := id.ParseSandboxID(sandboxIDStr)
+ if err != nil {
+ writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
+ return
+ }
+
sb, err := h.svc.Resume(r.Context(), sandboxID, ac.TeamID)
if err != nil {
status, code, msg := serviceErrToHTTP(err)
@@ -168,9 +187,15 @@ func (h *sandboxHandler) Resume(w http.ResponseWriter, r *http.Request) {
// Ping handles POST /v1/sandboxes/{id}/ping.
func (h *sandboxHandler) Ping(w http.ResponseWriter, r *http.Request) {
- sandboxID := chi.URLParam(r, "id")
+ sandboxIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
+ sandboxID, err := id.ParseSandboxID(sandboxIDStr)
+ if err != nil {
+ writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
+ return
+ }
+
if err := h.svc.Ping(r.Context(), sandboxID, ac.TeamID); err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
@@ -182,9 +207,15 @@ func (h *sandboxHandler) Ping(w http.ResponseWriter, r *http.Request) {
// Destroy handles DELETE /v1/sandboxes/{id}.
func (h *sandboxHandler) Destroy(w http.ResponseWriter, r *http.Request) {
- sandboxID := chi.URLParam(r, "id")
+ sandboxIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
+ sandboxID, err := id.ParseSandboxID(sandboxIDStr)
+ if err != nil {
+ writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
+ return
+ }
+
if err := h.svc.Destroy(r.Context(), sandboxID, ac.TeamID); err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
diff --git a/internal/api/handlers_snapshots.go b/internal/api/handlers_snapshots.go
index fbfcdc18..f7d05f2f 100644
--- a/internal/api/handlers_snapshots.go
+++ b/internal/api/handlers_snapshots.go
@@ -10,12 +10,14 @@ import (
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
+
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/audit"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
+ "git.omukk.dev/wrenn/sandbox/internal/layout"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
"git.omukk.dev/wrenn/sandbox/internal/service"
"git.omukk.dev/wrenn/sandbox/internal/validate"
@@ -35,8 +37,8 @@ func newSnapshotHandler(svc *service.TemplateService, db *db.Queries, pool *life
// deleteSnapshotBroadcast attempts to delete snapshot files on all online hosts.
// Snapshots aren't currently host-tracked in the DB, so we broadcast to all hosts
-// and ignore NotFound errors. TODO: add host_id to templates table.
-func (h *snapshotHandler) deleteSnapshotBroadcast(ctx context.Context, name string) error {
+// and ignore NotFound errors.
+func (h *snapshotHandler) deleteSnapshotBroadcast(ctx context.Context, teamID, templateID pgtype.UUID) error {
hosts, err := h.db.ListActiveHosts(ctx)
if err != nil {
return fmt.Errorf("list hosts: %w", err)
@@ -49,9 +51,12 @@ func (h *snapshotHandler) deleteSnapshotBroadcast(ctx context.Context, name stri
if err != nil {
continue
}
- if _, err := agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{Name: name})); err != nil {
+ if _, err := agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{
+ TeamId: formatUUIDForRPC(teamID),
+ TemplateId: formatUUIDForRPC(templateID),
+ })); err != nil {
if connect.CodeOf(err) != connect.CodeNotFound {
- slog.Warn("snapshot: failed to delete on host", "host_id", host.ID, "name", name, "error", err)
+ slog.Warn("snapshot: failed to delete on host", "host_id", id.FormatHostID(host.ID), "error", err)
}
}
}
@@ -70,6 +75,7 @@ type snapshotResponse struct {
MemoryMB *int32 `json:"memory_mb,omitempty"`
SizeBytes int64 `json:"size_bytes"`
CreatedAt string `json:"created_at"`
+ Platform bool `json:"platform"`
}
func templateToResponse(t db.Template) snapshotResponse {
@@ -77,12 +83,13 @@ func templateToResponse(t db.Template) snapshotResponse {
Name: t.Name,
Type: t.Type,
SizeBytes: t.SizeBytes,
+ Platform: t.TeamID == id.PlatformTeamID,
}
- if t.Vcpus.Valid {
- resp.VCPUs = &t.Vcpus.Int32
+ if t.Vcpus != 0 {
+ resp.VCPUs = &t.Vcpus
}
- if t.MemoryMb.Valid {
- resp.MemoryMB = &t.MemoryMb.Int32
+ if t.MemoryMb != 0 {
+ resp.MemoryMB = &t.MemoryMb
}
if t.CreatedAt.Valid {
resp.CreatedAt = t.CreatedAt.Time.Format(time.RFC3339)
@@ -103,6 +110,12 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
return
}
+ sandboxID, err := id.ParseSandboxID(req.SandboxID)
+ if err != nil {
+ writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox_id")
+ return
+ }
+
if req.Name == "" {
req.Name = id.NewSnapshotName()
}
@@ -115,14 +128,20 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(ctx)
overwrite := r.URL.Query().Get("overwrite") == "true"
+ // Reject names that collide with an existing platform (global) template.
+ if _, err := h.db.GetPlatformTemplateByName(ctx, req.Name); err == nil {
+ writeError(w, http.StatusConflict, "name_reserved", "template name is reserved by a global template")
+ return
+ }
+
// Check if name already exists for this team.
- if _, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err == nil {
+ if existing, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err == nil {
if !overwrite {
writeError(w, http.StatusConflict, "already_exists", "snapshot name already exists; use ?overwrite=true to replace")
return
}
// Delete old snapshot files from all hosts before removing the DB record.
- if err := h.deleteSnapshotBroadcast(ctx, req.Name); err != nil {
+ if err := h.deleteSnapshotBroadcast(ctx, existing.TeamID, existing.ID); err != nil {
writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete existing snapshot files")
return
}
@@ -133,7 +152,7 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
}
// Verify sandbox exists, belongs to team, and is running or paused.
- sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: req.SandboxID, TeamID: ac.TeamID})
+ sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
@@ -149,30 +168,53 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
return
}
- resp, err := agent.CreateSnapshot(ctx, connect.NewRequest(&pb.CreateSnapshotRequest{
- SandboxId: req.SandboxID,
- Name: req.Name,
+ // Pre-mark sandbox as "paused" in DB BEFORE issuing the snapshot RPC.
+ // The host agent's CreateSnapshot removes the sandbox from its in-memory
+ // map immediately; if the reconciler fires during the flatten window and
+ // the DB still says "running", it will mark the sandbox "stopped".
+ if sb.Status == "running" {
+ if _, err := h.db.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
+ ID: sandboxID, Status: "paused",
+ }); err != nil {
+ writeError(w, http.StatusInternalServerError, "db_error", "failed to update sandbox status")
+ return
+ }
+ }
+
+ // Use a detached context with a generous timeout so the snapshot completes
+ // even if the client disconnects (the flatten step can take 10-20s).
+ snapCtx, snapCancel := context.WithTimeout(context.Background(), 5*time.Minute)
+ defer snapCancel()
+
+ // Generate the new template ID upfront so the host agent knows where to store files.
+ newTemplateID := id.NewTemplateID()
+
+ resp, err := agent.CreateSnapshot(snapCtx, connect.NewRequest(&pb.CreateSnapshotRequest{
+ SandboxId: req.SandboxID,
+ Name: req.Name,
+ TeamId: formatUUIDForRPC(ac.TeamID),
+ TemplateId: formatUUIDForRPC(newTemplateID),
}))
if err != nil {
+ // Snapshot failed — if the sandbox was running, undo the pre-marked "paused" status.
+ if sb.Status == "running" {
+ if _, dbErr := h.db.UpdateSandboxStatus(snapCtx, db.UpdateSandboxStatusParams{
+ ID: sandboxID, Status: "running",
+ }); dbErr != nil {
+ slog.Error("failed to revert sandbox status after snapshot error", "sandbox_id", req.SandboxID, "error", dbErr)
+ }
+ }
status, code, msg := agentErrToHTTP(err)
writeError(w, status, code, msg)
return
}
- // Mark sandbox as paused (if it was running, it got paused by the snapshot).
- if sb.Status != "paused" {
- if _, err := h.db.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
- ID: req.SandboxID, Status: "paused",
- }); err != nil {
- slog.Error("failed to update sandbox status after snapshot", "sandbox_id", req.SandboxID, "error", err)
- }
- }
-
- tmpl, err := h.db.InsertTemplate(ctx, db.InsertTemplateParams{
+ tmpl, err := h.db.InsertTemplate(snapCtx, db.InsertTemplateParams{
+ ID: newTemplateID,
Name: req.Name,
Type: "snapshot",
- Vcpus: pgtype.Int4{Int32: sb.Vcpus, Valid: true},
- MemoryMb: pgtype.Int4{Int32: sb.MemoryMb, Valid: true},
+ Vcpus: sb.Vcpus,
+ MemoryMb: sb.MemoryMb,
SizeBytes: resp.Msg.SizeBytes,
TeamID: ac.TeamID,
})
@@ -182,7 +224,12 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
return
}
- h.audit.LogSnapshotCreate(r.Context(), ac, req.Name)
+ h.audit.LogSnapshotCreate(snapCtx, ac, req.Name)
+
+ if ctx.Err() != nil {
+ slog.Info("snapshot created but client disconnected before response", "name", req.Name)
+ return
+ }
writeJSON(w, http.StatusCreated, templateToResponse(tmpl))
}
@@ -215,12 +262,22 @@ func (h *snapshotHandler) Delete(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ac := auth.MustFromContext(ctx)
- if _, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: name, TeamID: ac.TeamID}); err != nil {
+ tmpl, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: name, TeamID: ac.TeamID})
+ if err != nil {
writeError(w, http.StatusNotFound, "not_found", "template not found")
return
}
+ // Platform templates can only be deleted by admins via /v1/admin/templates.
+ if tmpl.TeamID == id.PlatformTeamID {
+ writeError(w, http.StatusForbidden, "forbidden", "platform templates cannot be deleted here")
+ return
+ }
+ if layout.IsMinimal(tmpl.TeamID, tmpl.ID) {
+ writeError(w, http.StatusForbidden, "forbidden", "the minimal template cannot be deleted")
+ return
+ }
- if err := h.deleteSnapshotBroadcast(ctx, name); err != nil {
+ if err := h.deleteSnapshotBroadcast(ctx, tmpl.TeamID, tmpl.ID); err != nil {
writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete snapshot files")
return
}
diff --git a/internal/api/handlers_team.go b/internal/api/handlers_team.go
index 3950ab73..2bf99f9a 100644
--- a/internal/api/handlers_team.go
+++ b/internal/api/handlers_team.go
@@ -7,10 +7,12 @@ import (
"time"
"github.com/go-chi/chi/v5"
+ "github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/audit"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
+ "git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/service"
)
@@ -48,7 +50,7 @@ type memberResponse struct {
func teamToResponse(t db.Team) teamResponse {
resp := teamResponse{
- ID: t.ID,
+ ID: id.FormatTeamID(t.ID),
Name: t.Name,
Slug: t.Slug,
IsByoc: t.IsByoc,
@@ -72,11 +74,16 @@ func memberInfoToResponse(m service.MemberInfo) memberResponse {
// requireTeamAccess is an inline check used by every team-scoped handler:
// the JWT team_id must match the URL {id} before any DB call is made.
// Returns false and writes 403 if they don't match.
-func requireTeamAccess(w http.ResponseWriter, r *http.Request, ac auth.AuthContext) (string, bool) {
- teamID := chi.URLParam(r, "id")
+func requireTeamAccess(w http.ResponseWriter, r *http.Request, ac auth.AuthContext) (pgtype.UUID, bool) {
+ teamIDStr := chi.URLParam(r, "id")
+ teamID, err := id.ParseTeamID(teamIDStr)
+ if err != nil {
+ writeError(w, http.StatusBadRequest, "invalid_request", "invalid team ID")
+ return pgtype.UUID{}, false
+ }
if ac.TeamID != teamID {
writeError(w, http.StatusForbidden, "forbidden", "JWT team does not match requested team; use switch-team first")
- return "", false
+ return pgtype.UUID{}, false
}
return teamID, true
}
@@ -185,7 +192,7 @@ func (h *teamHandler) Rename(w http.ResponseWriter, r *http.Request) {
// Fetch old name for audit log before renaming.
oldTeam, err := h.svc.GetTeam(r.Context(), teamID)
if err != nil {
- slog.Warn("audit: could not fetch old team name for rename log", "team_id", teamID, "error", err)
+ slog.Warn("audit: could not fetch old team name for rename log", "team_id", id.FormatTeamID(teamID), "error", err)
}
if err := h.svc.RenameTeam(r.Context(), teamID, ac.UserID, req.Name); err != nil {
@@ -267,7 +274,11 @@ func (h *teamHandler) AddMember(w http.ResponseWriter, r *http.Request) {
return
}
- h.audit.LogMemberAdd(r.Context(), ac, member.UserID, member.Email, member.Role)
+ // member.UserID is already formatted with prefix; parse it back for the audit logger (the audit entry is skipped if parsing fails).
+ targetUserID, parseErr := id.ParseUserID(member.UserID)
+ if parseErr == nil {
+ h.audit.LogMemberAdd(r.Context(), ac, targetUserID, member.Email, member.Role)
+ }
writeJSON(w, http.StatusCreated, memberInfoToResponse(member))
}
@@ -279,7 +290,13 @@ func (h *teamHandler) RemoveMember(w http.ResponseWriter, r *http.Request) {
if !ok {
return
}
- targetUserID := chi.URLParam(r, "uid")
+ targetUserIDStr := chi.URLParam(r, "uid")
+
+ targetUserID, err := id.ParseUserID(targetUserIDStr)
+ if err != nil {
+ writeError(w, http.StatusBadRequest, "invalid_request", "invalid user ID")
+ return
+ }
if err := h.svc.RemoveMember(r.Context(), teamID, ac.UserID, targetUserID); err != nil {
status, code, msg := serviceErrToHTTP(err)
@@ -299,7 +316,13 @@ func (h *teamHandler) UpdateMemberRole(w http.ResponseWriter, r *http.Request) {
if !ok {
return
}
- targetUserID := chi.URLParam(r, "uid")
+ targetUserIDStr := chi.URLParam(r, "uid")
+
+ targetUserID, err := id.ParseUserID(targetUserIDStr)
+ if err != nil {
+ writeError(w, http.StatusBadRequest, "invalid_request", "invalid user ID")
+ return
+ }
var req struct {
Role string `json:"role"`
@@ -341,7 +364,13 @@ func (h *teamHandler) Leave(w http.ResponseWriter, r *http.Request) {
// SetBYOC handles PUT /v1/admin/teams/{id}/byoc (admin only).
// Enables or disables the BYOC feature flag for a team.
func (h *teamHandler) SetBYOC(w http.ResponseWriter, r *http.Request) {
- teamID := chi.URLParam(r, "id")
+ teamIDStr := chi.URLParam(r, "id")
+
+ teamID, err := id.ParseTeamID(teamIDStr)
+ if err != nil {
+ writeError(w, http.StatusBadRequest, "invalid_request", "invalid team ID")
+ return
+ }
var req struct {
Enabled bool `json:"enabled"`
diff --git a/internal/api/handlers_users.go b/internal/api/handlers_users.go
index 8269d3c0..17050641 100644
--- a/internal/api/handlers_users.go
+++ b/internal/api/handlers_users.go
@@ -8,6 +8,7 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
+ "git.omukk.dev/wrenn/sandbox/internal/id"
)
type usersHandler struct {
@@ -45,7 +46,7 @@ func (h *usersHandler) Search(w http.ResponseWriter, r *http.Request) {
}
resp := make([]userResult, len(results))
for i, u := range results {
- resp[i] = userResult{UserID: u.ID, Email: u.Email}
+ resp[i] = userResult{UserID: id.FormatUserID(u.ID), Email: u.Email}
}
writeJSON(w, http.StatusOK, resp)
}
diff --git a/internal/api/host_monitor.go b/internal/api/host_monitor.go
index 4bf19d87..95fde106 100644
--- a/internal/api/host_monitor.go
+++ b/internal/api/host_monitor.go
@@ -6,9 +6,11 @@ import (
"time"
"connectrpc.com/connect"
+ "github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/audit"
"git.omukk.dev/wrenn/sandbox/internal/db"
+ "git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
)
@@ -82,15 +84,15 @@ func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) {
time.Since(host.LastHeartbeatAt.Time) > unreachableThreshold
if stale && host.Status != "unreachable" {
- slog.Info("host monitor: marking host unreachable", "host_id", host.ID,
+ slog.Info("host monitor: marking host unreachable", "host_id", id.FormatHostID(host.ID),
"last_heartbeat", host.LastHeartbeatAt.Time)
if err := m.db.MarkHostUnreachable(ctx, host.ID); err != nil {
- slog.Warn("host monitor: failed to mark host unreachable", "host_id", host.ID, "error", err)
+ slog.Warn("host monitor: failed to mark host unreachable", "host_id", id.FormatHostID(host.ID), "error", err)
}
if err := m.db.MarkSandboxesMissingByHost(ctx, host.ID); err != nil {
- slog.Warn("host monitor: failed to mark sandboxes missing", "host_id", host.ID, "error", err)
+ slog.Warn("host monitor: failed to mark sandboxes missing", "host_id", id.FormatHostID(host.ID), "error", err)
}
- m.audit.LogHostMarkedDown(ctx, host.TeamID.String, host.ID)
+ m.audit.LogHostMarkedDown(ctx, host.TeamID, host.ID)
return
}
@@ -110,19 +112,20 @@ func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) {
if err != nil {
// RPC failure is a transient condition; the passive phase will catch it
// if heartbeats stop arriving.
- slog.Debug("host monitor: ListSandboxes failed (transient)", "host_id", host.ID, "error", err)
+ slog.Debug("host monitor: ListSandboxes failed (transient)", "host_id", id.FormatHostID(host.ID), "error", err)
return
}
// Build set of sandbox IDs alive on the host.
+ // The host agent returns sandbox IDs as strings (formatted with prefix), so DB UUIDs are formatted with id.FormatSandboxID before lookups in these sets.
alive := make(map[string]struct{}, len(resp.Msg.Sandboxes))
for _, sb := range resp.Msg.Sandboxes {
alive[sb.SandboxId] = struct{}{}
}
autoPaused := make(map[string]struct{}, len(resp.Msg.AutoPausedSandboxIds))
- for _, id := range resp.Msg.AutoPausedSandboxIds {
- autoPaused[id] = struct{}{}
+ for _, apID := range resp.Msg.AutoPausedSandboxIds {
+ autoPaused[apID] = struct{}{}
}
// --- Restore sandboxes that are "missing" in DB but alive on host ---
@@ -134,30 +137,31 @@ func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) {
Column2: []string{"missing"},
})
if err != nil {
- slog.Warn("host monitor: failed to list missing sandboxes", "host_id", host.ID, "error", err)
+ slog.Warn("host monitor: failed to list missing sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
} else {
- var toRestore []string
- var toStop []string
+ var toRestore []pgtype.UUID
+ var toStop []pgtype.UUID
for _, sb := range missingSandboxes {
- if _, ok := alive[sb.ID]; ok {
+ sbIDStr := id.FormatSandboxID(sb.ID)
+ if _, ok := alive[sbIDStr]; ok {
toRestore = append(toRestore, sb.ID)
} else {
toStop = append(toStop, sb.ID)
}
}
if len(toRestore) > 0 {
- slog.Info("host monitor: restoring missing sandboxes", "host_id", host.ID, "count", len(toRestore))
+ slog.Info("host monitor: restoring missing sandboxes", "host_id", id.FormatHostID(host.ID), "count", len(toRestore))
if err := m.db.BulkRestoreRunning(ctx, toRestore); err != nil {
- slog.Warn("host monitor: failed to restore missing sandboxes", "host_id", host.ID, "error", err)
+ slog.Warn("host monitor: failed to restore missing sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
}
}
if len(toStop) > 0 {
- slog.Info("host monitor: stopping confirmed-dead missing sandboxes", "host_id", host.ID, "count", len(toStop))
+ slog.Info("host monitor: stopping confirmed-dead missing sandboxes", "host_id", id.FormatHostID(host.ID), "count", len(toStop))
if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
Column1: toStop,
Status: "stopped",
}); err != nil {
- slog.Warn("host monitor: failed to stop missing sandboxes", "host_id", host.ID, "error", err)
+ slog.Warn("host monitor: failed to stop missing sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
}
}
}
@@ -169,18 +173,19 @@ func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) {
Column2: []string{"running"},
})
if err != nil {
- slog.Warn("host monitor: failed to list running sandboxes", "host_id", host.ID, "error", err)
+ slog.Warn("host monitor: failed to list running sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
return
}
- var toPause, toStop []string
- sbTeamID := make(map[string]string, len(runningSandboxes))
+ var toPause, toStop []pgtype.UUID
+ sbTeamID := make(map[pgtype.UUID]pgtype.UUID, len(runningSandboxes))
for _, sb := range runningSandboxes {
+ sbIDStr := id.FormatSandboxID(sb.ID)
sbTeamID[sb.ID] = sb.TeamID
- if _, ok := alive[sb.ID]; ok {
+ if _, ok := alive[sbIDStr]; ok {
continue
}
- if _, ok := autoPaused[sb.ID]; ok {
+ if _, ok := autoPaused[sbIDStr]; ok {
toPause = append(toPause, sb.ID)
} else {
toStop = append(toStop, sb.ID)
@@ -188,24 +193,24 @@ func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) {
}
if len(toPause) > 0 {
- slog.Info("host monitor: marking auto-paused sandboxes", "host_id", host.ID, "count", len(toPause))
+ slog.Info("host monitor: marking auto-paused sandboxes", "host_id", id.FormatHostID(host.ID), "count", len(toPause))
if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
Column1: toPause,
Status: "paused",
}); err != nil {
- slog.Warn("host monitor: failed to mark paused", "host_id", host.ID, "error", err)
+ slog.Warn("host monitor: failed to mark paused", "host_id", id.FormatHostID(host.ID), "error", err)
}
for _, sbID := range toPause {
m.audit.LogSandboxAutoPause(ctx, sbTeamID[sbID], sbID)
}
}
if len(toStop) > 0 {
- slog.Info("host monitor: marking orphaned sandboxes stopped", "host_id", host.ID, "count", len(toStop))
+ slog.Info("host monitor: marking orphaned sandboxes stopped", "host_id", id.FormatHostID(host.ID), "count", len(toStop))
if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
Column1: toStop,
Status: "stopped",
}); err != nil {
- slog.Warn("host monitor: failed to mark stopped", "host_id", host.ID, "error", err)
+ slog.Warn("host monitor: failed to mark stopped", "host_id", id.FormatHostID(host.ID), "error", err)
}
}
}
diff --git a/internal/api/middleware.go b/internal/api/middleware.go
index 6a56293c..5c9d8cdf 100644
--- a/internal/api/middleware.go
+++ b/internal/api/middleware.go
@@ -12,6 +12,9 @@ import (
"time"
"connectrpc.com/connect"
+ "github.com/jackc/pgx/v5/pgtype"
+
+ "git.omukk.dev/wrenn/sandbox/internal/id"
)
type errorResponse struct {
@@ -35,6 +38,11 @@ func writeError(w http.ResponseWriter, status int, code, message string) {
})
}
+// formatUUIDForRPC converts a pgtype.UUID to a hex string for RPC messages.
+func formatUUIDForRPC(u pgtype.UUID) string {
+ return id.UUIDString(u) // thin wrapper: the exact string form is whatever id.UUIDString produces
+}
+
// agentErrToHTTP maps a Connect RPC error to an HTTP status, error code, and message.
func agentErrToHTTP(err error) (int, string, string) {
switch connect.CodeOf(err) {
diff --git a/internal/api/middleware_auth.go b/internal/api/middleware_auth.go
index dee4240f..985b2899 100644
--- a/internal/api/middleware_auth.go
+++ b/internal/api/middleware_auth.go
@@ -7,6 +7,7 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
+ "git.omukk.dev/wrenn/sandbox/internal/id"
)
// requireAPIKeyOrJWT accepts either X-API-Key header or Authorization: Bearer JWT.
@@ -24,7 +25,7 @@ func requireAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler
}
if err := queries.UpdateAPIKeyLastUsed(r.Context(), row.ID); err != nil {
- slog.Warn("failed to update api key last_used", "key_id", row.ID, "error", err)
+ slog.Warn("failed to update api key last_used", "key_id", id.FormatAPIKeyID(row.ID), "error", err)
}
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
@@ -45,9 +46,20 @@ func requireAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler
return
}
+ teamID, err := id.ParseTeamID(claims.TeamID)
+ if err != nil {
+ writeError(w, http.StatusUnauthorized, "unauthorized", "invalid team ID in token")
+ return
+ }
+ userID, err := id.ParseUserID(claims.Subject)
+ if err != nil {
+ writeError(w, http.StatusUnauthorized, "unauthorized", "invalid user ID in token")
+ return
+ }
+
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
- TeamID: claims.TeamID,
- UserID: claims.Subject,
+ TeamID: teamID,
+ UserID: userID,
Email: claims.Email,
Name: claims.Name,
Role: claims.Role,
diff --git a/internal/api/middleware_hosttoken.go b/internal/api/middleware_hosttoken.go
index a5c5e6f1..b926e41f 100644
--- a/internal/api/middleware_hosttoken.go
+++ b/internal/api/middleware_hosttoken.go
@@ -4,6 +4,7 @@ import (
"net/http"
"git.omukk.dev/wrenn/sandbox/internal/auth"
+ "git.omukk.dev/wrenn/sandbox/internal/id"
)
// requireHostToken validates the X-Host-Token header containing a host JWT,
@@ -23,7 +24,13 @@ func requireHostToken(secret []byte) func(http.Handler) http.Handler {
return
}
- ctx := auth.WithHostContext(r.Context(), auth.HostContext{HostID: claims.HostID})
+ hostID, err := id.ParseHostID(claims.HostID)
+ if err != nil {
+ writeError(w, http.StatusUnauthorized, "unauthorized", "invalid host ID in token")
+ return
+ }
+
+ ctx := auth.WithHostContext(r.Context(), auth.HostContext{HostID: hostID})
next.ServeHTTP(w, r.WithContext(ctx))
})
}
diff --git a/internal/api/middleware_jwt.go b/internal/api/middleware_jwt.go
index 96b1c68a..37215388 100644
--- a/internal/api/middleware_jwt.go
+++ b/internal/api/middleware_jwt.go
@@ -5,6 +5,7 @@ import (
"strings"
"git.omukk.dev/wrenn/sandbox/internal/auth"
+ "git.omukk.dev/wrenn/sandbox/internal/id"
)
// requireJWT validates the Authorization: Bearer header, verifies the JWT
@@ -25,9 +26,20 @@ func requireJWT(secret []byte) func(http.Handler) http.Handler {
return
}
+ teamID, err := id.ParseTeamID(claims.TeamID)
+ if err != nil {
+ writeError(w, http.StatusUnauthorized, "unauthorized", "invalid team ID in token")
+ return
+ }
+ userID, err := id.ParseUserID(claims.Subject)
+ if err != nil {
+ writeError(w, http.StatusUnauthorized, "unauthorized", "invalid user ID in token")
+ return
+ }
+
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
- TeamID: claims.TeamID,
- UserID: claims.Subject,
+ TeamID: teamID,
+ UserID: userID,
Email: claims.Email,
Name: claims.Name,
Role: claims.Role,
diff --git a/internal/api/server.go b/internal/api/server.go
index 918476bf..5d854b90 100644
--- a/internal/api/server.go
+++ b/internal/api/server.go
@@ -10,6 +10,7 @@ import (
"github.com/redis/go-redis/v9"
"git.omukk.dev/wrenn/sandbox/internal/audit"
+ "git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/auth/oauth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
@@ -22,7 +23,8 @@ var openapiYAML []byte
// Server is the control plane HTTP server.
type Server struct {
- router chi.Router
+ router chi.Router
+ BuildSvc *service.BuildService
}
// New constructs the chi router and registers all routes.
@@ -35,6 +37,7 @@ func New(
jwtSecret []byte,
oauthRegistry *oauth.Registry,
oauthRedirectURL string,
+ ca *auth.CA,
) *Server {
r := chi.NewRouter()
r.Use(requestLogger())
@@ -43,10 +46,11 @@ func New(
sandboxSvc := &service.SandboxService{DB: queries, Pool: pool, Scheduler: sched}
apiKeySvc := &service.APIKeyService{DB: queries}
templateSvc := &service.TemplateService{DB: queries}
- hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret, Pool: pool}
+ hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret, Pool: pool, CA: ca}
teamSvc := &service.TeamService{DB: queries, Pool: pgPool, HostPool: pool}
auditSvc := &service.AuditService{DB: queries}
statsSvc := &service.StatsService{DB: queries, Pool: pgPool}
+ buildSvc := &service.BuildService{DB: queries, Redis: rdb, Pool: pool, Scheduler: sched}
al := audit.New(queries)
@@ -65,6 +69,7 @@ func New(
auditH := newAuditHandler(auditSvc)
statsH := newStatsHandler(statsSvc)
metricsH := newSandboxMetricsHandler(queries, pool)
+ buildH := newBuildHandler(buildSvc, queries, pool)
// OpenAPI spec and docs.
r.Get("/openapi.yaml", serveOpenAPI)
@@ -174,9 +179,15 @@ func New(
r.Use(requireJWT(jwtSecret))
r.Use(requireAdmin(queries))
r.Put("/teams/{id}/byoc", teamH.SetBYOC)
+ r.Get("/templates", buildH.ListTemplates)
+ r.Delete("/templates/{name}", buildH.DeleteTemplate)
+ r.Post("/builds", buildH.Create)
+ r.Get("/builds", buildH.List)
+ r.Get("/builds/{id}", buildH.Get)
+ r.Post("/builds/{id}/cancel", buildH.Cancel)
})
- return &Server{router: r}
+ return &Server{router: r, BuildSvc: buildSvc}
}
// Handler returns the HTTP handler.
diff --git a/internal/audit/logger.go b/internal/audit/logger.go
index 8f44059f..e60d8a6c 100644
--- a/internal/audit/logger.go
+++ b/internal/audit/logger.go
@@ -25,18 +25,15 @@ func New(queries *db.Queries) *AuditLogger {
}
// actorFields extracts actor_type, actor_id, and actor_name from an AuthContext.
-func actorFields(ac auth.AuthContext) (actorType string, actorID pgtype.Text, actorName pgtype.Text) {
- if ac.UserID != "" {
- return "user",
- pgtype.Text{String: ac.UserID, Valid: true},
- pgtype.Text{String: ac.Name, Valid: ac.Name != ""}
+// actor_id is stored as a prefixed string in the TEXT column.
+func actorFields(ac auth.AuthContext) (actorType, actorID, actorName string) {
+ if ac.UserID.Valid {
+ return "user", id.FormatUserID(ac.UserID), ac.Name
}
- if ac.APIKeyID != "" {
- return "api_key",
- pgtype.Text{String: ac.APIKeyID, Valid: true},
- pgtype.Text{String: ac.APIKeyName, Valid: true}
+ if ac.APIKeyID.Valid {
+ return "api_key", id.FormatAPIKeyID(ac.APIKeyID), ac.APIKeyName
}
- return "system", pgtype.Text{}, pgtype.Text{}
+ return "system", "", ""
}
func (l *AuditLogger) write(ctx context.Context, p db.InsertAuditLogParams) {
@@ -44,7 +41,6 @@ func (l *AuditLogger) write(ctx context.Context, p db.InsertAuditLogParams) {
slog.Warn("audit: failed to write log entry",
"action", p.Action,
"resource_type", p.ResourceType,
- "team_id", p.TeamID,
"error", err,
)
}
@@ -61,18 +57,26 @@ func marshalMeta(meta map[string]any) []byte {
return b
}
+// optText returns a valid pgtype.Text if s is non-empty, otherwise an invalid (NULL) one.
+func optText(s string) pgtype.Text {
+ if s == "" {
+ return pgtype.Text{} // zero value has Valid=false
+ }
+ return pgtype.Text{String: s, Valid: true}
+}
+
// --- Sandbox events (scope: team) ---
-func (l *AuditLogger) LogSandboxCreate(ctx context.Context, ac auth.AuthContext, sandboxID, template string) {
+func (l *AuditLogger) LogSandboxCreate(ctx context.Context, ac auth.AuthContext, sandboxID pgtype.UUID, template string) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
- ActorID: actorID,
+ ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "sandbox",
- ResourceID: pgtype.Text{String: sandboxID, Valid: true},
+ ResourceID: optText(id.FormatSandboxID(sandboxID)),
Action: "create",
Scope: "team",
Status: "success",
@@ -80,16 +84,16 @@ func (l *AuditLogger) LogSandboxCreate(ctx context.Context, ac auth.AuthContext,
})
}
-func (l *AuditLogger) LogSandboxPause(ctx context.Context, ac auth.AuthContext, sandboxID string) {
+func (l *AuditLogger) LogSandboxPause(ctx context.Context, ac auth.AuthContext, sandboxID pgtype.UUID) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
- ActorID: actorID,
+ ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "sandbox",
- ResourceID: pgtype.Text{String: sandboxID, Valid: true},
+ ResourceID: optText(id.FormatSandboxID(sandboxID)),
Action: "pause",
Scope: "team",
Status: "success",
@@ -98,15 +102,15 @@ func (l *AuditLogger) LogSandboxPause(ctx context.Context, ac auth.AuthContext,
}
// LogSandboxAutoPause records a system-initiated auto-pause (TTL or host reconciler).
-func (l *AuditLogger) LogSandboxAutoPause(ctx context.Context, teamID, sandboxID string) {
+func (l *AuditLogger) LogSandboxAutoPause(ctx context.Context, teamID, sandboxID pgtype.UUID) {
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: teamID,
ActorType: "system",
ActorID: pgtype.Text{},
- ActorName: pgtype.Text{},
+ ActorName: "",
ResourceType: "sandbox",
- ResourceID: pgtype.Text{String: sandboxID, Valid: true},
+ ResourceID: optText(id.FormatSandboxID(sandboxID)),
Action: "pause",
Scope: "team",
Status: "info",
@@ -114,16 +118,16 @@ func (l *AuditLogger) LogSandboxAutoPause(ctx context.Context, teamID, sandboxID
})
}
-func (l *AuditLogger) LogSandboxResume(ctx context.Context, ac auth.AuthContext, sandboxID string) {
+func (l *AuditLogger) LogSandboxResume(ctx context.Context, ac auth.AuthContext, sandboxID pgtype.UUID) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
- ActorID: actorID,
+ ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "sandbox",
- ResourceID: pgtype.Text{String: sandboxID, Valid: true},
+ ResourceID: optText(id.FormatSandboxID(sandboxID)),
Action: "resume",
Scope: "team",
Status: "success",
@@ -131,16 +135,16 @@ func (l *AuditLogger) LogSandboxResume(ctx context.Context, ac auth.AuthContext,
})
}
-func (l *AuditLogger) LogSandboxDestroy(ctx context.Context, ac auth.AuthContext, sandboxID string) {
+func (l *AuditLogger) LogSandboxDestroy(ctx context.Context, ac auth.AuthContext, sandboxID pgtype.UUID) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
- ActorID: actorID,
+ ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "sandbox",
- ResourceID: pgtype.Text{String: sandboxID, Valid: true},
+ ResourceID: optText(id.FormatSandboxID(sandboxID)),
Action: "destroy",
Scope: "team",
Status: "warning",
@@ -156,10 +160,10 @@ func (l *AuditLogger) LogSnapshotCreate(ctx context.Context, ac auth.AuthContext
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
- ActorID: actorID,
+ ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "snapshot",
- ResourceID: pgtype.Text{String: name, Valid: true},
+ ResourceID: optText(name),
Action: "create",
Scope: "team",
Status: "success",
@@ -173,10 +177,10 @@ func (l *AuditLogger) LogSnapshotDelete(ctx context.Context, ac auth.AuthContext
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
- ActorID: actorID,
+ ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "snapshot",
- ResourceID: pgtype.Text{String: name, Valid: true},
+ ResourceID: optText(name),
Action: "delete",
Scope: "team",
Status: "warning",
@@ -186,16 +190,16 @@ func (l *AuditLogger) LogSnapshotDelete(ctx context.Context, ac auth.AuthContext
// --- Team events (scope: team) ---
-func (l *AuditLogger) LogTeamRename(ctx context.Context, ac auth.AuthContext, teamID, oldName, newName string) {
+func (l *AuditLogger) LogTeamRename(ctx context.Context, ac auth.AuthContext, teamID pgtype.UUID, oldName, newName string) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
- ActorID: actorID,
+ ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "team",
- ResourceID: pgtype.Text{String: teamID, Valid: true},
+ ResourceID: optText(id.FormatTeamID(teamID)),
Action: "rename",
Scope: "team",
Status: "info",
@@ -205,16 +209,16 @@ func (l *AuditLogger) LogTeamRename(ctx context.Context, ac auth.AuthContext, te
// --- API key events (scope: team) ---
-func (l *AuditLogger) LogAPIKeyCreate(ctx context.Context, ac auth.AuthContext, keyID, keyName string) {
+func (l *AuditLogger) LogAPIKeyCreate(ctx context.Context, ac auth.AuthContext, keyID pgtype.UUID, keyName string) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
- ActorID: actorID,
+ ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "api_key",
- ResourceID: pgtype.Text{String: keyID, Valid: true},
+ ResourceID: optText(id.FormatAPIKeyID(keyID)),
Action: "create",
Scope: "team",
Status: "success",
@@ -222,16 +226,16 @@ func (l *AuditLogger) LogAPIKeyCreate(ctx context.Context, ac auth.AuthContext,
})
}
-func (l *AuditLogger) LogAPIKeyRevoke(ctx context.Context, ac auth.AuthContext, keyID string) {
+func (l *AuditLogger) LogAPIKeyRevoke(ctx context.Context, ac auth.AuthContext, keyID pgtype.UUID) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
- ActorID: actorID,
+ ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "api_key",
- ResourceID: pgtype.Text{String: keyID, Valid: true},
+ ResourceID: optText(id.FormatAPIKeyID(keyID)),
Action: "revoke",
Scope: "team",
Status: "warning",
@@ -241,16 +245,16 @@ func (l *AuditLogger) LogAPIKeyRevoke(ctx context.Context, ac auth.AuthContext,
// --- Member events (scope: admin) ---
-func (l *AuditLogger) LogMemberAdd(ctx context.Context, ac auth.AuthContext, targetUserID, targetEmail, role string) {
+func (l *AuditLogger) LogMemberAdd(ctx context.Context, ac auth.AuthContext, targetUserID pgtype.UUID, targetEmail, role string) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
- ActorID: actorID,
+ ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "member",
- ResourceID: pgtype.Text{String: targetUserID, Valid: true},
+ ResourceID: optText(id.FormatUserID(targetUserID)),
Action: "add",
Scope: "admin",
Status: "success",
@@ -258,16 +262,16 @@ func (l *AuditLogger) LogMemberAdd(ctx context.Context, ac auth.AuthContext, tar
})
}
-func (l *AuditLogger) LogMemberRemove(ctx context.Context, ac auth.AuthContext, targetUserID string) {
+func (l *AuditLogger) LogMemberRemove(ctx context.Context, ac auth.AuthContext, targetUserID pgtype.UUID) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
- ActorID: actorID,
+ ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "member",
- ResourceID: pgtype.Text{String: targetUserID, Valid: true},
+ ResourceID: optText(id.FormatUserID(targetUserID)),
Action: "remove",
Scope: "admin",
Status: "warning",
@@ -277,14 +281,18 @@ func (l *AuditLogger) LogMemberRemove(ctx context.Context, ac auth.AuthContext,
func (l *AuditLogger) LogMemberLeave(ctx context.Context, ac auth.AuthContext) {
actorType, actorID, actorName := actorFields(ac)
+ resourceID := ""
+ if ac.UserID.Valid {
+ resourceID = id.FormatUserID(ac.UserID)
+ }
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
- ActorID: actorID,
+ ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "member",
- ResourceID: pgtype.Text{String: ac.UserID, Valid: ac.UserID != ""},
+ ResourceID: optText(resourceID),
Action: "leave",
Scope: "admin",
Status: "info",
@@ -292,16 +300,16 @@ func (l *AuditLogger) LogMemberLeave(ctx context.Context, ac auth.AuthContext) {
})
}
-func (l *AuditLogger) LogMemberRoleUpdate(ctx context.Context, ac auth.AuthContext, targetUserID, newRole string) {
+func (l *AuditLogger) LogMemberRoleUpdate(ctx context.Context, ac auth.AuthContext, targetUserID pgtype.UUID, newRole string) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
- ActorID: actorID,
+ ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "member",
- ResourceID: pgtype.Text{String: targetUserID, Valid: true},
+ ResourceID: optText(id.FormatUserID(targetUserID)),
Action: "role_update",
Scope: "admin",
Status: "info",
@@ -311,24 +319,24 @@ func (l *AuditLogger) LogMemberRoleUpdate(ctx context.Context, ac auth.AuthConte
// --- Host events (scope: admin) ---
-func (l *AuditLogger) LogHostCreate(ctx context.Context, ac auth.AuthContext, hostID, teamID string) {
+func (l *AuditLogger) LogHostCreate(ctx context.Context, ac auth.AuthContext, hostID, teamID pgtype.UUID) {
actorType, actorID, actorName := actorFields(ac)
// For shared hosts with no owning team, use the caller's team.
logTeamID := teamID
- if logTeamID == "" {
+ if !logTeamID.Valid {
logTeamID = ac.TeamID
}
- if logTeamID == "" {
+ if !logTeamID.Valid {
return
}
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: logTeamID,
ActorType: actorType,
- ActorID: actorID,
+ ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "host",
- ResourceID: pgtype.Text{String: hostID, Valid: true},
+ ResourceID: optText(id.FormatHostID(hostID)),
Action: "create",
Scope: "admin",
Status: "success",
@@ -336,23 +344,23 @@ func (l *AuditLogger) LogHostCreate(ctx context.Context, ac auth.AuthContext, ho
})
}
-func (l *AuditLogger) LogHostDelete(ctx context.Context, ac auth.AuthContext, hostID, teamID string) {
+func (l *AuditLogger) LogHostDelete(ctx context.Context, ac auth.AuthContext, hostID, teamID pgtype.UUID) {
actorType, actorID, actorName := actorFields(ac)
logTeamID := teamID
- if logTeamID == "" {
+ if !logTeamID.Valid {
logTeamID = ac.TeamID
}
- if logTeamID == "" {
+ if !logTeamID.Valid {
return
}
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: logTeamID,
ActorType: actorType,
- ActorID: actorID,
+ ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "host",
- ResourceID: pgtype.Text{String: hostID, Valid: true},
+ ResourceID: optText(id.FormatHostID(hostID)),
Action: "delete",
Scope: "admin",
Status: "warning",
@@ -361,9 +369,8 @@ func (l *AuditLogger) LogHostDelete(ctx context.Context, ac auth.AuthContext, ho
}
// LogHostMarkedDown records a system-initiated host status transition to unreachable.
-// teamID must be non-empty (BYOC hosts only); shared hosts are not logged.
-func (l *AuditLogger) LogHostMarkedDown(ctx context.Context, teamID, hostID string) {
- if teamID == "" {
+func (l *AuditLogger) LogHostMarkedDown(ctx context.Context, teamID, hostID pgtype.UUID) {
+ if !teamID.Valid {
return
}
l.write(ctx, db.InsertAuditLogParams{
@@ -371,9 +378,9 @@ func (l *AuditLogger) LogHostMarkedDown(ctx context.Context, teamID, hostID stri
TeamID: teamID,
ActorType: "system",
ActorID: pgtype.Text{},
- ActorName: pgtype.Text{},
+ ActorName: "",
ResourceType: "host",
- ResourceID: pgtype.Text{String: hostID, Valid: true},
+ ResourceID: optText(id.FormatHostID(hostID)),
Action: "marked_down",
Scope: "admin",
Status: "error",
@@ -382,9 +389,8 @@ func (l *AuditLogger) LogHostMarkedDown(ctx context.Context, teamID, hostID stri
}
// LogHostMarkedUp records a system-initiated host status transition back to online.
-// teamID must be non-empty (BYOC hosts only); shared hosts are not logged.
-func (l *AuditLogger) LogHostMarkedUp(ctx context.Context, teamID, hostID string) {
- if teamID == "" {
+func (l *AuditLogger) LogHostMarkedUp(ctx context.Context, teamID, hostID pgtype.UUID) {
+ if !teamID.Valid {
return
}
l.write(ctx, db.InsertAuditLogParams{
@@ -392,9 +398,9 @@ func (l *AuditLogger) LogHostMarkedUp(ctx context.Context, teamID, hostID string
TeamID: teamID,
ActorType: "system",
ActorID: pgtype.Text{},
- ActorName: pgtype.Text{},
+ ActorName: "",
ResourceType: "host",
- ResourceID: pgtype.Text{String: hostID, Valid: true},
+ ResourceID: optText(id.FormatHostID(hostID)),
Action: "marked_up",
Scope: "admin",
Status: "success",
diff --git a/internal/auth/cert.go b/internal/auth/cert.go
new file mode 100644
index 00000000..d76f1de3
--- /dev/null
+++ b/internal/auth/cert.go
@@ -0,0 +1,251 @@
+package auth
+
+import (
+ "crypto/ecdsa"
+ "crypto/elliptic"
+ "crypto/rand"
+ "crypto/sha256"
+ "crypto/tls"
+ "crypto/x509"
+ "crypto/x509/pkix"
+ "encoding/pem"
+ "fmt"
+ "math/big"
+ "net"
+ "sync/atomic"
+ "time"
+)
+
+// CPCertRenewInterval is how often the control plane should renew its client
+// certificate. It is half of cpCertTTL (12h), so there is always a wide safety
+// margin before expiry.
+const CPCertRenewInterval = cpCertTTL / 2
+
+const (
+ hostCertTTL = 7 * 24 * time.Hour // validity of host agent server certs (IssueHostCert)
+ cpCertTTL = 24 * time.Hour // validity of CP client certs (IssueCPClientCert)
+)
+
+// CA holds a parsed certificate authority ready to issue leaf certificates.
+type CA struct {
+ Cert *x509.Certificate // parsed CA certificate; parent of every issued leaf
+ Key *ecdsa.PrivateKey // CA private key used to sign issued leaves
+ PEM string // PEM-encoded certificate for embedding in register/refresh responses
+}
+
+// ParseCA decodes the first PEM block of certPEM and keyPEM (expected: ECDSA P-256).
+// The key must be SEC 1 DER ("EC PRIVATE KEY"); it is not checked against the cert.
+func ParseCA(certPEM, keyPEM string) (*CA, error) {
+ certBlock, _ := pem.Decode([]byte(certPEM)) // data after the first block is ignored
+ if certBlock == nil {
+ return nil, fmt.Errorf("failed to decode CA certificate PEM")
+ }
+ cert, err := x509.ParseCertificate(certBlock.Bytes)
+ if err != nil {
+ return nil, fmt.Errorf("parse CA certificate: %w", err)
+ }
+
+ keyBlock, _ := pem.Decode([]byte(keyPEM))
+ if keyBlock == nil {
+ return nil, fmt.Errorf("failed to decode CA key PEM")
+ }
+ keyIface, err := x509.ParseECPrivateKey(keyBlock.Bytes) // already a concrete *ecdsa.PrivateKey
+ if err != nil {
+ return nil, fmt.Errorf("parse CA private key: %w", err)
+ }
+
+ return &CA{Cert: cert, Key: keyIface, PEM: certPEM}, nil // PEM keeps the caller's input verbatim
+}
+
+// HostCert holds all material returned when issuing a leaf cert for a host agent.
+type HostCert struct {
+ CertPEM string // PEM-encoded leaf certificate
+ KeyPEM string // PEM-encoded ("EC PRIVATE KEY") private key for the leaf
+ Fingerprint string // hex-encoded SHA-256 of DER bytes, stored in hosts.cert_fingerprint
+ ExpiresAt time.Time // stored in hosts.cert_expires_at
+ TLSCert tls.Certificate // CertPEM and KeyPEM parsed into a ready-to-use key pair
+}
+
+// IssueHostCert generates an ECDSA P-256 key pair and issues a 7-day server
+// certificate for the host agent, signed by ca. hostID becomes the common name.
+// The host part of hostAddr ("host:port" or bare host) becomes an IP SAN if it
+// parses as an IP and a DNS SAN otherwise, so hostname verification can stay on.
+func IssueHostCert(ca *CA, hostID, hostAddr string) (HostCert, error) {
+ key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+ if err != nil {
+ return HostCert{}, fmt.Errorf("generate host key: %w", err)
+ }
+
+ serial, err := randomSerial()
+ if err != nil {
+ return HostCert{}, err
+ }
+
+ now := time.Now()
+ expires := now.Add(hostCertTTL)
+
+ tmpl := &x509.Certificate{
+ SerialNumber: serial,
+ Subject: pkix.Name{CommonName: hostID},
+ NotBefore: now.Add(-time.Minute), // small clock-skew tolerance
+ NotAfter: expires,
+ KeyUsage: x509.KeyUsageDigitalSignature,
+ ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, // server-side use only
+ }
+
+ // Extract the host from a "host:port" address; use an IP SAN when it parses as an IP, else a DNS SAN.
+ host, _, err := net.SplitHostPort(hostAddr)
+ if err != nil {
+ host = hostAddr // not in host:port form: treat the whole string as the host
+ }
+ if ip := net.ParseIP(host); ip != nil {
+ tmpl.IPAddresses = []net.IP{ip}
+ } else {
+ tmpl.DNSNames = []string{host}
+ }
+
+ derBytes, err := x509.CreateCertificate(rand.Reader, tmpl, ca.Cert, &key.PublicKey, ca.Key) // signed by the CA key
+ if err != nil {
+ return HostCert{}, fmt.Errorf("create host certificate: %w", err)
+ }
+
+ certPEM := string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}))
+ keyDER, err := x509.MarshalECPrivateKey(key)
+ if err != nil {
+ return HostCert{}, fmt.Errorf("marshal host key: %w", err)
+ }
+ keyPEM := string(pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}))
+
+ tlsCert, err := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM))
+ if err != nil {
+ return HostCert{}, fmt.Errorf("build TLS certificate: %w", err)
+ }
+
+ fp := fmt.Sprintf("%x", sha256.Sum256(derBytes)) // lowercase hex SHA-256 of the DER certificate
+
+ return HostCert{
+ CertPEM: certPEM,
+ KeyPEM: keyPEM,
+ Fingerprint: fp,
+ ExpiresAt: expires,
+ TLSCert: tlsCert,
+ }, nil
+}
+
+// IssueCPClientCert generates a short-lived (24h) ECDSA client certificate for
+// the control plane to present during mTLS handshakes with host agents.
+// Called by CPCertStore.Refresh, both for the initial cert and for each renewal.
+func IssueCPClientCert(ca *CA) (tls.Certificate, error) {
+ key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+ if err != nil {
+ return tls.Certificate{}, fmt.Errorf("generate CP client key: %w", err)
+ }
+
+ serial, err := randomSerial()
+ if err != nil {
+ return tls.Certificate{}, err
+ }
+
+ now := time.Now()
+ tmpl := &x509.Certificate{
+ SerialNumber: serial,
+ Subject: pkix.Name{CommonName: "wrenn-cp"},
+ NotBefore: now.Add(-time.Minute), // same one-minute backdating as host certs
+ NotAfter: now.Add(cpCertTTL),
+ KeyUsage: x509.KeyUsageDigitalSignature,
+ ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, // client-side use only
+ }
+
+ derBytes, err := x509.CreateCertificate(rand.Reader, tmpl, ca.Cert, &key.PublicKey, ca.Key)
+ if err != nil {
+ return tls.Certificate{}, fmt.Errorf("create CP client certificate: %w", err)
+ }
+
+ certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
+ keyDER, err := x509.MarshalECPrivateKey(key)
+ if err != nil {
+ return tls.Certificate{}, fmt.Errorf("marshal CP client key: %w", err)
+ }
+ keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
+
+ return tls.X509KeyPair(certPEM, keyPEM) // parse error, if any, is returned unwrapped
+}
+
+// AgentTLSConfigFromPEM returns a server-side tls.Config for the host agent that
+// requires client certs signed by the PEM-encoded CA (only the CA cert, not its
+// key, is needed). Returns nil if caCertPEM contains no parseable certificate.
+func AgentTLSConfigFromPEM(caCertPEM string, getCert func(*tls.ClientHelloInfo) (*tls.Certificate, error)) *tls.Config {
+ pool := x509.NewCertPool()
+ if !pool.AppendCertsFromPEM([]byte(caCertPEM)) {
+ return nil // callers must nil-check: no error value is returned
+ }
+ return &tls.Config{
+ ClientAuth: tls.RequireAndVerifyClientCert, // reject peers without a CA-signed client cert
+ ClientCAs: pool,
+ GetCertificate: getCert, // server cert is looked up per handshake
+ MinVersion: tls.VersionTLS13,
+ }
+}
+
+// CPCertStore provides lock-free read/write access to the control plane's
+// current client TLS certificate. It is used with tls.Config.GetClientCertificate
+// to enable hot-swap without restarting the HTTP client.
+//
+// The zero value is not usable; use NewCPCertStore to create one.
+type CPCertStore struct {
+ ptr atomic.Pointer[tls.Certificate] // current certificate; replaced wholesale by Refresh
+ ca *CA // issuer used by Refresh to mint replacement certificates
+}
+
+// NewCPCertStore issues an initial CP client certificate from ca and returns a
+// store that can renew it in place. Returns an error if the initial issuance fails.
+func NewCPCertStore(ca *CA) (*CPCertStore, error) {
+ s := &CPCertStore{ca: ca}
+ if err := s.Refresh(); err != nil { // issue the first certificate before handing the store out
+ return nil, err
+ }
+ return s, nil
+}
+
+// Refresh issues a fresh CP client certificate and atomically stores it.
+// If issuance fails the existing cert is unchanged.
+func (s *CPCertStore) Refresh() error {
+ cert, err := IssueCPClientCert(s.ca)
+ if err != nil {
+ return fmt.Errorf("renew CP client certificate: %w", err)
+ }
+ s.ptr.Store(&cert) // later GetClientCertificate calls return this new certificate
+ return nil
+}
+
+// GetClientCertificate satisfies tls.Config.GetClientCertificate. It is called
+// per-handshake and always returns the most recently stored certificate.
+func (s *CPCertStore) GetClientCertificate(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
+ cert := s.ptr.Load()
+ if cert == nil { // nothing stored yet, e.g. a zero-value store not built via NewCPCertStore
+ return nil, fmt.Errorf("no CP client certificate available")
+ }
+ return cert, nil
+}
+
+// CPClientTLSConfig returns a tls.Config for the CP's outbound HTTP client.
+// It uses certStore.GetClientCertificate so the certificate can be renewed
+// without replacing the config or transport.
+func CPClientTLSConfig(ca *CA, certStore *CPCertStore) *tls.Config {
+ pool := x509.NewCertPool()
+ pool.AddCert(ca.Cert)
+ return &tls.Config{
+ RootCAs: pool, // server certs are verified against the internal CA only
+ GetClientCertificate: certStore.GetClientCertificate,
+ MinVersion: tls.VersionTLS13,
+ }
+}
+
+// randomSerial returns a uniformly random serial number in [0, 2^128) from crypto/rand.
+func randomSerial() (*big.Int, error) {
+ serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) // exclusive upper bound 1<<128
+ if err != nil {
+ return nil, fmt.Errorf("generate serial number: %w", err)
+ }
+ return serial, nil
+}
diff --git a/internal/auth/context.go b/internal/auth/context.go
index 22bf7957..762227cf 100644
--- a/internal/auth/context.go
+++ b/internal/auth/context.go
@@ -1,6 +1,10 @@
package auth
-import "context"
+import (
+ "context"
+
+ "github.com/jackc/pgx/v5/pgtype"
+)
type contextKey int
@@ -8,14 +12,14 @@ const authCtxKey contextKey = 0
// AuthContext is stamped into request context by auth middleware.
type AuthContext struct {
- TeamID string
- UserID string // empty when authenticated via API key
- Email string // empty when authenticated via API key
- Name string // empty when authenticated via API key
- Role string // owner, admin, or member; empty when authenticated via API key
- IsAdmin bool // platform-level admin; always false when authenticated via API key
- APIKeyID string // populated when authenticated via API key; empty for JWT auth
- APIKeyName string // display name of the key, snapshotted at auth time; empty for JWT auth
+ TeamID pgtype.UUID
+ UserID pgtype.UUID // zero value (Valid=false) when authenticated via API key
+ Email string // empty when authenticated via API key
+ Name string // empty when authenticated via API key
+ Role string // owner, admin, or member; empty when authenticated via API key
+ IsAdmin bool // platform-level admin; always false when authenticated via API key
+ APIKeyID pgtype.UUID // populated when authenticated via API key; zero value for JWT auth
+ APIKeyName string // display name of the key, snapshotted at auth time; empty for JWT auth
}
// WithAuthContext returns a new context with the given AuthContext.
@@ -43,7 +47,7 @@ const hostCtxKey contextKey = 1
// HostContext is stamped into request context by host token middleware.
type HostContext struct {
- HostID string
+ HostID pgtype.UUID
}
// WithHostContext returns a new context with the given HostContext.
diff --git a/internal/auth/jwt.go b/internal/auth/jwt.go
index fd1bc02d..840cd3bb 100644
--- a/internal/auth/jwt.go
+++ b/internal/auth/jwt.go
@@ -5,6 +5,9 @@ import (
"time"
"github.com/golang-jwt/jwt/v5"
+ "github.com/jackc/pgx/v5/pgtype"
+
+ "git.omukk.dev/wrenn/sandbox/internal/id"
)
const jwtExpiry = 6 * time.Hour
@@ -23,16 +26,16 @@ type Claims struct {
}
// SignJWT signs a new 6-hour JWT for the given user.
-func SignJWT(secret []byte, userID, teamID, email, name, role string, isAdmin bool) (string, error) {
+func SignJWT(secret []byte, userID, teamID pgtype.UUID, email, name, role string, isAdmin bool) (string, error) {
now := time.Now()
claims := Claims{
- TeamID: teamID,
+ TeamID: id.FormatTeamID(teamID),
Role: role,
Email: email,
Name: name,
IsAdmin: isAdmin,
RegisteredClaims: jwt.RegisteredClaims{
- Subject: userID,
+ Subject: id.FormatUserID(userID),
IssuedAt: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(now.Add(jwtExpiry)),
},
@@ -70,14 +73,15 @@ type HostClaims struct {
jwt.RegisteredClaims
}
-// SignHostJWT signs a long-lived (1 year) JWT for a registered host agent.
-func SignHostJWT(secret []byte, hostID string) (string, error) {
+// SignHostJWT signs a JWT for a registered host agent that expires after hostJWTExpiry (7 days); the formatted host ID is used as both the HostID claim and the subject.
+func SignHostJWT(secret []byte, hostID pgtype.UUID) (string, error) {
+ formatted := id.FormatHostID(hostID)
now := time.Now()
claims := HostClaims{
Type: "host",
- HostID: hostID,
+ HostID: formatted,
RegisteredClaims: jwt.RegisteredClaims{
- Subject: hostID,
+ Subject: formatted,
IssuedAt: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(now.Add(hostJWTExpiry)),
},
diff --git a/internal/config/config.go b/internal/config/config.go
index 7ef0aa69..e4e67403 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -13,6 +13,11 @@ type Config struct {
ListenAddr string
JWTSecret string
+ // mTLS — CP→Agent channel. Both must be set to enable mTLS; omitting either
+ // disables cert issuance and leaves agent connections on plain HTTP (dev mode).
+ CACert string // WRENN_CA_CERT — PEM-encoded internal CA certificate
+ CAKey string // WRENN_CA_KEY — PEM-encoded internal CA private key
+
OAuthGitHubClientID string
OAuthGitHubClientSecret string
OAuthRedirectURL string
@@ -28,9 +33,12 @@ func Load() Config {
return Config{
DatabaseURL: envOrDefault("DATABASE_URL", "postgres://wrenn:wrenn@localhost:5432/wrenn?sslmode=disable"),
RedisURL: envOrDefault("REDIS_URL", "redis://localhost:6379/0"),
- ListenAddr: envOrDefault("CP_LISTEN_ADDR", ":8080"),
+ ListenAddr: envOrDefault("WRENN_CP_LISTEN_ADDR", ":8080"),
JWTSecret: os.Getenv("JWT_SECRET"),
+ CACert: os.Getenv("WRENN_CA_CERT"),
+ CAKey: os.Getenv("WRENN_CA_KEY"),
+
OAuthGitHubClientID: os.Getenv("OAUTH_GITHUB_CLIENT_ID"),
OAuthGitHubClientSecret: os.Getenv("OAUTH_GITHUB_CLIENT_SECRET"),
OAuthRedirectURL: envOrDefault("OAUTH_REDIRECT_URL", "https://app.wrenn.dev"),
diff --git a/internal/db/api_keys.sql.go b/internal/db/api_keys.sql.go
index b4f0ffcc..4b8d3699 100644
--- a/internal/db/api_keys.sql.go
+++ b/internal/db/api_keys.sql.go
@@ -16,8 +16,8 @@ DELETE FROM team_api_keys WHERE id = $1 AND team_id = $2
`
type DeleteAPIKeyParams struct {
- ID string `json:"id"`
- TeamID string `json:"team_id"`
+ ID pgtype.UUID `json:"id"`
+ TeamID pgtype.UUID `json:"team_id"`
}
func (q *Queries) DeleteAPIKey(ctx context.Context, arg DeleteAPIKeyParams) error {
@@ -52,12 +52,12 @@ RETURNING id, team_id, name, key_hash, key_prefix, created_by, created_at, last_
`
type InsertAPIKeyParams struct {
- ID string `json:"id"`
- TeamID string `json:"team_id"`
- Name string `json:"name"`
- KeyHash string `json:"key_hash"`
- KeyPrefix string `json:"key_prefix"`
- CreatedBy string `json:"created_by"`
+ ID pgtype.UUID `json:"id"`
+ TeamID pgtype.UUID `json:"team_id"`
+ Name string `json:"name"`
+ KeyHash string `json:"key_hash"`
+ KeyPrefix string `json:"key_prefix"`
+ CreatedBy pgtype.UUID `json:"created_by"`
}
func (q *Queries) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (TeamApiKey, error) {
@@ -87,7 +87,7 @@ const listAPIKeysByTeam = `-- name: ListAPIKeysByTeam :many
SELECT id, team_id, name, key_hash, key_prefix, created_by, created_at, last_used FROM team_api_keys WHERE team_id = $1 ORDER BY created_at DESC
`
-func (q *Queries) ListAPIKeysByTeam(ctx context.Context, teamID string) ([]TeamApiKey, error) {
+func (q *Queries) ListAPIKeysByTeam(ctx context.Context, teamID pgtype.UUID) ([]TeamApiKey, error) {
rows, err := q.db.Query(ctx, listAPIKeysByTeam, teamID)
if err != nil {
return nil, err
@@ -126,18 +126,18 @@ ORDER BY k.created_at DESC
`
type ListAPIKeysByTeamWithCreatorRow struct {
- ID string `json:"id"`
- TeamID string `json:"team_id"`
+ ID pgtype.UUID `json:"id"`
+ TeamID pgtype.UUID `json:"team_id"`
Name string `json:"name"`
KeyHash string `json:"key_hash"`
KeyPrefix string `json:"key_prefix"`
- CreatedBy string `json:"created_by"`
+ CreatedBy pgtype.UUID `json:"created_by"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
LastUsed pgtype.Timestamptz `json:"last_used"`
CreatorEmail string `json:"creator_email"`
}
-func (q *Queries) ListAPIKeysByTeamWithCreator(ctx context.Context, teamID string) ([]ListAPIKeysByTeamWithCreatorRow, error) {
+func (q *Queries) ListAPIKeysByTeamWithCreator(ctx context.Context, teamID pgtype.UUID) ([]ListAPIKeysByTeamWithCreatorRow, error) {
rows, err := q.db.Query(ctx, listAPIKeysByTeamWithCreator, teamID)
if err != nil {
return nil, err
@@ -171,7 +171,7 @@ const updateAPIKeyLastUsed = `-- name: UpdateAPIKeyLastUsed :exec
UPDATE team_api_keys SET last_used = NOW() WHERE id = $1
`
-func (q *Queries) UpdateAPIKeyLastUsed(ctx context.Context, id string) error {
+func (q *Queries) UpdateAPIKeyLastUsed(ctx context.Context, id pgtype.UUID) error {
_, err := q.db.Exec(ctx, updateAPIKeyLastUsed, id)
return err
}
diff --git a/internal/db/audit.sql.go b/internal/db/audit.sql.go
index 9370eca9..69b2b8ce 100644
--- a/internal/db/audit.sql.go
+++ b/internal/db/audit.sql.go
@@ -17,11 +17,11 @@ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
`
type InsertAuditLogParams struct {
- ID string `json:"id"`
- TeamID string `json:"team_id"`
+ ID pgtype.UUID `json:"id"`
+ TeamID pgtype.UUID `json:"team_id"`
ActorType string `json:"actor_type"`
ActorID pgtype.Text `json:"actor_id"`
- ActorName pgtype.Text `json:"actor_name"`
+ ActorName string `json:"actor_name"`
ResourceType string `json:"resource_type"`
ResourceID pgtype.Text `json:"resource_id"`
Action string `json:"action"`
@@ -60,12 +60,12 @@ LIMIT $7
`
type ListAuditLogsParams struct {
- TeamID string `json:"team_id"`
+ TeamID pgtype.UUID `json:"team_id"`
Column2 []string `json:"column_2"`
Column3 []string `json:"column_3"`
Column4 []string `json:"column_4"`
Column5 pgtype.Timestamptz `json:"column_5"`
- ID string `json:"id"`
+ ID pgtype.UUID `json:"id"`
Limit int32 `json:"limit"`
}
diff --git a/internal/db/host_refresh_tokens.sql.go b/internal/db/host_refresh_tokens.sql.go
index d02a0e7a..0ec162d9 100644
--- a/internal/db/host_refresh_tokens.sql.go
+++ b/internal/db/host_refresh_tokens.sql.go
@@ -47,8 +47,8 @@ RETURNING id, host_id, token_hash, expires_at, created_at, revoked_at
`
type InsertHostRefreshTokenParams struct {
- ID string `json:"id"`
- HostID string `json:"host_id"`
+ ID pgtype.UUID `json:"id"`
+ HostID pgtype.UUID `json:"host_id"`
TokenHash string `json:"token_hash"`
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
}
@@ -76,7 +76,7 @@ const revokeHostRefreshToken = `-- name: RevokeHostRefreshToken :exec
UPDATE host_refresh_tokens SET revoked_at = NOW() WHERE id = $1
`
-func (q *Queries) RevokeHostRefreshToken(ctx context.Context, id string) error {
+func (q *Queries) RevokeHostRefreshToken(ctx context.Context, id pgtype.UUID) error {
_, err := q.db.Exec(ctx, revokeHostRefreshToken, id)
return err
}
@@ -86,7 +86,7 @@ UPDATE host_refresh_tokens SET revoked_at = NOW()
WHERE host_id = $1 AND revoked_at IS NULL
`
-func (q *Queries) RevokeHostRefreshTokensByHost(ctx context.Context, hostID string) error {
+func (q *Queries) RevokeHostRefreshTokensByHost(ctx context.Context, hostID pgtype.UUID) error {
_, err := q.db.Exec(ctx, revokeHostRefreshTokensByHost, hostID)
return err
}
diff --git a/internal/db/hosts.sql.go b/internal/db/hosts.sql.go
index 2d7b8e0c..2e3962b5 100644
--- a/internal/db/hosts.sql.go
+++ b/internal/db/hosts.sql.go
@@ -16,8 +16,8 @@ INSERT INTO host_tags (host_id, tag) VALUES ($1, $2) ON CONFLICT DO NOTHING
`
type AddHostTagParams struct {
- HostID string `json:"host_id"`
- Tag string `json:"tag"`
+ HostID pgtype.UUID `json:"host_id"`
+ Tag string `json:"tag"`
}
func (q *Queries) AddHostTag(ctx context.Context, arg AddHostTagParams) error {
@@ -29,16 +29,16 @@ const deleteHost = `-- name: DeleteHost :exec
DELETE FROM hosts WHERE id = $1
`
-func (q *Queries) DeleteHost(ctx context.Context, id string) error {
+func (q *Queries) DeleteHost(ctx context.Context, id pgtype.UUID) error {
_, err := q.db.Exec(ctx, deleteHost, id)
return err
}
const getHost = `-- name: GetHost :one
-SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE id = $1
+SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE id = $1
`
-func (q *Queries) GetHost(ctx context.Context, id string) (Host, error) {
+func (q *Queries) GetHost(ctx context.Context, id pgtype.UUID) (Host, error) {
row := q.db.QueryRow(ctx, getHost, id)
var i Host
err := row.Scan(
@@ -59,18 +59,18 @@ func (q *Queries) GetHost(ctx context.Context, id string) (Host, error) {
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
- &i.MtlsEnabled,
+ &i.CertExpiresAt,
)
return i, err
}
const getHostByTeam = `-- name: GetHostByTeam :one
-SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE id = $1 AND team_id = $2
+SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE id = $1 AND team_id = $2
`
type GetHostByTeamParams struct {
- ID string `json:"id"`
- TeamID pgtype.Text `json:"team_id"`
+ ID pgtype.UUID `json:"id"`
+ TeamID pgtype.UUID `json:"team_id"`
}
func (q *Queries) GetHostByTeam(ctx context.Context, arg GetHostByTeamParams) (Host, error) {
@@ -94,7 +94,7 @@ func (q *Queries) GetHostByTeam(ctx context.Context, arg GetHostByTeamParams) (H
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
- &i.MtlsEnabled,
+ &i.CertExpiresAt,
)
return i, err
}
@@ -103,7 +103,7 @@ const getHostTags = `-- name: GetHostTags :many
SELECT tag FROM host_tags WHERE host_id = $1 ORDER BY tag
`
-func (q *Queries) GetHostTags(ctx context.Context, hostID string) ([]string, error) {
+func (q *Queries) GetHostTags(ctx context.Context, hostID pgtype.UUID) ([]string, error) {
rows, err := q.db.Query(ctx, getHostTags, hostID)
if err != nil {
return nil, err
@@ -127,7 +127,7 @@ const getHostTokensByHost = `-- name: GetHostTokensByHost :many
SELECT id, host_id, created_by, created_at, expires_at, used_at FROM host_tokens WHERE host_id = $1 ORDER BY created_at DESC
`
-func (q *Queries) GetHostTokensByHost(ctx context.Context, hostID string) ([]HostToken, error) {
+func (q *Queries) GetHostTokensByHost(ctx context.Context, hostID pgtype.UUID) ([]HostToken, error) {
rows, err := q.db.Query(ctx, getHostTokensByHost, hostID)
if err != nil {
return nil, err
@@ -157,16 +157,16 @@ func (q *Queries) GetHostTokensByHost(ctx context.Context, hostID string) ([]Hos
const insertHost = `-- name: InsertHost :one
INSERT INTO hosts (id, type, team_id, provider, availability_zone, created_by)
VALUES ($1, $2, $3, $4, $5, $6)
-RETURNING id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled
+RETURNING id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at
`
type InsertHostParams struct {
- ID string `json:"id"`
+ ID pgtype.UUID `json:"id"`
Type string `json:"type"`
- TeamID pgtype.Text `json:"team_id"`
- Provider pgtype.Text `json:"provider"`
- AvailabilityZone pgtype.Text `json:"availability_zone"`
- CreatedBy string `json:"created_by"`
+ TeamID pgtype.UUID `json:"team_id"`
+ Provider string `json:"provider"`
+ AvailabilityZone string `json:"availability_zone"`
+ CreatedBy pgtype.UUID `json:"created_by"`
}
func (q *Queries) InsertHost(ctx context.Context, arg InsertHostParams) (Host, error) {
@@ -197,7 +197,7 @@ func (q *Queries) InsertHost(ctx context.Context, arg InsertHostParams) (Host, e
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
- &i.MtlsEnabled,
+ &i.CertExpiresAt,
)
return i, err
}
@@ -209,9 +209,9 @@ RETURNING id, host_id, created_by, created_at, expires_at, used_at
`
type InsertHostTokenParams struct {
- ID string `json:"id"`
- HostID string `json:"host_id"`
- CreatedBy string `json:"created_by"`
+ ID pgtype.UUID `json:"id"`
+ HostID pgtype.UUID `json:"host_id"`
+ CreatedBy pgtype.UUID `json:"created_by"`
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
}
@@ -235,7 +235,7 @@ func (q *Queries) InsertHostToken(ctx context.Context, arg InsertHostTokenParams
}
const listActiveHosts = `-- name: ListActiveHosts :many
-SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE status NOT IN ('pending', 'offline') ORDER BY created_at
+SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE status NOT IN ('pending', 'offline') ORDER BY created_at
`
// Returns all hosts that have completed registration (not pending/offline).
@@ -266,7 +266,7 @@ func (q *Queries) ListActiveHosts(ctx context.Context) ([]Host, error) {
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
- &i.MtlsEnabled,
+ &i.CertExpiresAt,
); err != nil {
return nil, err
}
@@ -279,7 +279,7 @@ func (q *Queries) ListActiveHosts(ctx context.Context) ([]Host, error) {
}
const listHosts = `-- name: ListHosts :many
-SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts ORDER BY created_at DESC
+SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts ORDER BY created_at DESC
`
func (q *Queries) ListHosts(ctx context.Context) ([]Host, error) {
@@ -309,7 +309,7 @@ func (q *Queries) ListHosts(ctx context.Context) ([]Host, error) {
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
- &i.MtlsEnabled,
+ &i.CertExpiresAt,
); err != nil {
return nil, err
}
@@ -322,7 +322,7 @@ func (q *Queries) ListHosts(ctx context.Context) ([]Host, error) {
}
const listHostsByStatus = `-- name: ListHostsByStatus :many
-SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE status = $1 ORDER BY created_at DESC
+SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE status = $1 ORDER BY created_at DESC
`
func (q *Queries) ListHostsByStatus(ctx context.Context, status string) ([]Host, error) {
@@ -352,7 +352,7 @@ func (q *Queries) ListHostsByStatus(ctx context.Context, status string) ([]Host,
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
- &i.MtlsEnabled,
+ &i.CertExpiresAt,
); err != nil {
return nil, err
}
@@ -365,7 +365,7 @@ func (q *Queries) ListHostsByStatus(ctx context.Context, status string) ([]Host,
}
const listHostsByTag = `-- name: ListHostsByTag :many
-SELECT h.id, h.type, h.team_id, h.provider, h.availability_zone, h.arch, h.cpu_cores, h.memory_mb, h.disk_gb, h.address, h.status, h.last_heartbeat_at, h.metadata, h.created_by, h.created_at, h.updated_at, h.cert_fingerprint, h.mtls_enabled FROM hosts h
+SELECT h.id, h.type, h.team_id, h.provider, h.availability_zone, h.arch, h.cpu_cores, h.memory_mb, h.disk_gb, h.address, h.status, h.last_heartbeat_at, h.metadata, h.created_by, h.created_at, h.updated_at, h.cert_fingerprint, h.cert_expires_at FROM hosts h
JOIN host_tags ht ON ht.host_id = h.id
WHERE ht.tag = $1
ORDER BY h.created_at DESC
@@ -398,7 +398,7 @@ func (q *Queries) ListHostsByTag(ctx context.Context, tag string) ([]Host, error
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
- &i.MtlsEnabled,
+ &i.CertExpiresAt,
); err != nil {
return nil, err
}
@@ -411,10 +411,10 @@ func (q *Queries) ListHostsByTag(ctx context.Context, tag string) ([]Host, error
}
const listHostsByTeam = `-- name: ListHostsByTeam :many
-SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE team_id = $1 AND type = 'byoc' ORDER BY created_at DESC
+SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE team_id = $1 AND type = 'byoc' ORDER BY created_at DESC
`
-func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.Text) ([]Host, error) {
+func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.UUID) ([]Host, error) {
rows, err := q.db.Query(ctx, listHostsByTeam, teamID)
if err != nil {
return nil, err
@@ -441,7 +441,7 @@ func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.Text) ([]Ho
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
- &i.MtlsEnabled,
+ &i.CertExpiresAt,
); err != nil {
return nil, err
}
@@ -454,7 +454,7 @@ func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.Text) ([]Ho
}
const listHostsByType = `-- name: ListHostsByType :many
-SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE type = $1 ORDER BY created_at DESC
+SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE type = $1 ORDER BY created_at DESC
`
func (q *Queries) ListHostsByType(ctx context.Context, type_ string) ([]Host, error) {
@@ -484,7 +484,7 @@ func (q *Queries) ListHostsByType(ctx context.Context, type_ string) ([]Host, er
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
- &i.MtlsEnabled,
+ &i.CertExpiresAt,
); err != nil {
return nil, err
}
@@ -500,7 +500,7 @@ const markHostTokenUsed = `-- name: MarkHostTokenUsed :exec
UPDATE host_tokens SET used_at = NOW() WHERE id = $1
`
-func (q *Queries) MarkHostTokenUsed(ctx context.Context, id string) error {
+func (q *Queries) MarkHostTokenUsed(ctx context.Context, id pgtype.UUID) error {
_, err := q.db.Exec(ctx, markHostTokenUsed, id)
return err
}
@@ -509,31 +509,35 @@ const markHostUnreachable = `-- name: MarkHostUnreachable :exec
UPDATE hosts SET status = 'unreachable', updated_at = NOW() WHERE id = $1
`
-func (q *Queries) MarkHostUnreachable(ctx context.Context, id string) error {
+func (q *Queries) MarkHostUnreachable(ctx context.Context, id pgtype.UUID) error {
_, err := q.db.Exec(ctx, markHostUnreachable, id)
return err
}
const registerHost = `-- name: RegisterHost :execrows
UPDATE hosts
-SET arch = $2,
- cpu_cores = $3,
- memory_mb = $4,
- disk_gb = $5,
- address = $6,
- status = 'online',
+SET arch = $2,
+ cpu_cores = $3,
+ memory_mb = $4,
+ disk_gb = $5,
+ address = $6,
+ cert_fingerprint = $7,
+ cert_expires_at = $8,
+ status = 'online',
last_heartbeat_at = NOW(),
- updated_at = NOW()
+ updated_at = NOW()
WHERE id = $1 AND status = 'pending'
`
type RegisterHostParams struct {
- ID string `json:"id"`
- Arch pgtype.Text `json:"arch"`
- CpuCores pgtype.Int4 `json:"cpu_cores"`
- MemoryMb pgtype.Int4 `json:"memory_mb"`
- DiskGb pgtype.Int4 `json:"disk_gb"`
- Address pgtype.Text `json:"address"`
+ ID pgtype.UUID `json:"id"`
+ Arch string `json:"arch"`
+ CpuCores int32 `json:"cpu_cores"`
+ MemoryMb int32 `json:"memory_mb"`
+ DiskGb int32 `json:"disk_gb"`
+ Address string `json:"address"`
+ CertFingerprint string `json:"cert_fingerprint"`
+ CertExpiresAt pgtype.Timestamptz `json:"cert_expires_at"`
}
func (q *Queries) RegisterHost(ctx context.Context, arg RegisterHostParams) (int64, error) {
@@ -544,6 +548,8 @@ func (q *Queries) RegisterHost(ctx context.Context, arg RegisterHostParams) (int
arg.MemoryMb,
arg.DiskGb,
arg.Address,
+ arg.CertFingerprint,
+ arg.CertExpiresAt,
)
if err != nil {
return 0, err
@@ -556,8 +562,8 @@ DELETE FROM host_tags WHERE host_id = $1 AND tag = $2
`
type RemoveHostTagParams struct {
- HostID string `json:"host_id"`
- Tag string `json:"tag"`
+ HostID pgtype.UUID `json:"host_id"`
+ Tag string `json:"tag"`
}
func (q *Queries) RemoveHostTag(ctx context.Context, arg RemoveHostTagParams) error {
@@ -565,11 +571,30 @@ func (q *Queries) RemoveHostTag(ctx context.Context, arg RemoveHostTagParams) er
return err
}
+const updateHostCert = `-- name: UpdateHostCert :exec
+UPDATE hosts
+SET cert_fingerprint = $2,
+ cert_expires_at = $3,
+ updated_at = NOW()
+WHERE id = $1
+`
+
+type UpdateHostCertParams struct {
+ ID pgtype.UUID `json:"id"`
+ CertFingerprint string `json:"cert_fingerprint"`
+ CertExpiresAt pgtype.Timestamptz `json:"cert_expires_at"`
+}
+
+func (q *Queries) UpdateHostCert(ctx context.Context, arg UpdateHostCertParams) error {
+ _, err := q.db.Exec(ctx, updateHostCert, arg.ID, arg.CertFingerprint, arg.CertExpiresAt)
+ return err
+}
+
const updateHostHeartbeat = `-- name: UpdateHostHeartbeat :exec
UPDATE hosts SET last_heartbeat_at = NOW(), updated_at = NOW() WHERE id = $1
`
-func (q *Queries) UpdateHostHeartbeat(ctx context.Context, id string) error {
+func (q *Queries) UpdateHostHeartbeat(ctx context.Context, id pgtype.UUID) error {
_, err := q.db.Exec(ctx, updateHostHeartbeat, id)
return err
}
@@ -584,7 +609,7 @@ WHERE id = $1
// Updates last_heartbeat_at and transitions unreachable hosts back to online.
// Returns 0 if no host was found (deleted), which the caller treats as 404.
-func (q *Queries) UpdateHostHeartbeatAndStatus(ctx context.Context, id string) (int64, error) {
+func (q *Queries) UpdateHostHeartbeatAndStatus(ctx context.Context, id pgtype.UUID) (int64, error) {
result, err := q.db.Exec(ctx, updateHostHeartbeatAndStatus, id)
if err != nil {
return 0, err
@@ -597,8 +622,8 @@ UPDATE hosts SET status = $2, updated_at = NOW() WHERE id = $1
`
type UpdateHostStatusParams struct {
- ID string `json:"id"`
- Status string `json:"status"`
+ ID pgtype.UUID `json:"id"`
+ Status string `json:"status"`
}
func (q *Queries) UpdateHostStatus(ctx context.Context, arg UpdateHostStatusParams) error {
diff --git a/internal/db/metrics.sql.go b/internal/db/metrics.sql.go
index 80501559..f522dc2d 100644
--- a/internal/db/metrics.sql.go
+++ b/internal/db/metrics.sql.go
@@ -7,6 +7,8 @@ package db
import (
"context"
+
+ "github.com/jackc/pgx/v5/pgtype"
)
const deleteSandboxMetricPoints = `-- name: DeleteSandboxMetricPoints :exec
@@ -14,7 +16,7 @@ DELETE FROM sandbox_metric_points
WHERE sandbox_id = $1
`
-func (q *Queries) DeleteSandboxMetricPoints(ctx context.Context, sandboxID string) error {
+func (q *Queries) DeleteSandboxMetricPoints(ctx context.Context, sandboxID pgtype.UUID) error {
_, err := q.db.Exec(ctx, deleteSandboxMetricPoints, sandboxID)
return err
}
@@ -25,8 +27,8 @@ WHERE sandbox_id = $1 AND tier = $2
`
type DeleteSandboxMetricPointsByTierParams struct {
- SandboxID string `json:"sandbox_id"`
- Tier string `json:"tier"`
+ SandboxID pgtype.UUID `json:"sandbox_id"`
+ Tier string `json:"tier"`
}
func (q *Queries) DeleteSandboxMetricPointsByTier(ctx context.Context, arg DeleteSandboxMetricPointsByTierParams) error {
@@ -53,7 +55,7 @@ type GetLiveMetricsRow struct {
// Reads directly from sandboxes for accurate real-time current values.
// CPU reserved = running + starting only (paused VMs release CPU).
// RAM reserved = running + starting + sum(ceil(each_paused/2)) (per-VM ceiling).
-func (q *Queries) GetLiveMetrics(ctx context.Context, teamID string) (GetLiveMetricsRow, error) {
+func (q *Queries) GetLiveMetrics(ctx context.Context, teamID pgtype.UUID) (GetLiveMetricsRow, error) {
row := q.db.QueryRow(ctx, getLiveMetrics, teamID)
var i GetLiveMetricsRow
err := row.Scan(&i.RunningCount, &i.VcpusReserved, &i.MemoryMbReserved)
@@ -76,7 +78,7 @@ type GetPeakMetricsRow struct {
PeakMemoryMb int32 `json:"peak_memory_mb"`
}
-func (q *Queries) GetPeakMetrics(ctx context.Context, teamID string) (GetPeakMetricsRow, error) {
+func (q *Queries) GetPeakMetrics(ctx context.Context, teamID pgtype.UUID) (GetPeakMetricsRow, error) {
row := q.db.QueryRow(ctx, getPeakMetrics, teamID)
var i GetPeakMetricsRow
err := row.Scan(&i.PeakRunningCount, &i.PeakVcpus, &i.PeakMemoryMb)
@@ -91,9 +93,9 @@ ORDER BY ts ASC
`
type GetSandboxMetricPointsParams struct {
- SandboxID string `json:"sandbox_id"`
- Tier string `json:"tier"`
- Ts int64 `json:"ts"`
+ SandboxID pgtype.UUID `json:"sandbox_id"`
+ Tier string `json:"tier"`
+ Ts int64 `json:"ts"`
}
type GetSandboxMetricPointsRow struct {
@@ -134,10 +136,10 @@ VALUES ($1, $2, $3, $4)
`
type InsertMetricsSnapshotParams struct {
- TeamID string `json:"team_id"`
- RunningCount int32 `json:"running_count"`
- VcpusReserved int32 `json:"vcpus_reserved"`
- MemoryMbReserved int32 `json:"memory_mb_reserved"`
+ TeamID pgtype.UUID `json:"team_id"`
+ RunningCount int32 `json:"running_count"`
+ VcpusReserved int32 `json:"vcpus_reserved"`
+ MemoryMbReserved int32 `json:"memory_mb_reserved"`
}
func (q *Queries) InsertMetricsSnapshot(ctx context.Context, arg InsertMetricsSnapshotParams) error {
@@ -157,12 +159,12 @@ ON CONFLICT (sandbox_id, tier, ts) DO NOTHING
`
type InsertSandboxMetricPointParams struct {
- SandboxID string `json:"sandbox_id"`
- Tier string `json:"tier"`
- Ts int64 `json:"ts"`
- CpuPct float64 `json:"cpu_pct"`
- MemBytes int64 `json:"mem_bytes"`
- DiskBytes int64 `json:"disk_bytes"`
+ SandboxID pgtype.UUID `json:"sandbox_id"`
+ Tier string `json:"tier"`
+ Ts int64 `json:"ts"`
+ CpuPct float64 `json:"cpu_pct"`
+ MemBytes int64 `json:"mem_bytes"`
+ DiskBytes int64 `json:"disk_bytes"`
}
func (q *Queries) InsertSandboxMetricPoint(ctx context.Context, arg InsertSandboxMetricPointParams) error {
@@ -210,10 +212,10 @@ GROUP BY team_id
`
type SampleSandboxMetricsRow struct {
- TeamID string `json:"team_id"`
- RunningCount int32 `json:"running_count"`
- VcpusReserved int32 `json:"vcpus_reserved"`
- MemoryMbReserved int32 `json:"memory_mb_reserved"`
+ TeamID pgtype.UUID `json:"team_id"`
+ RunningCount int32 `json:"running_count"`
+ VcpusReserved int32 `json:"vcpus_reserved"`
+ MemoryMbReserved int32 `json:"memory_mb_reserved"`
}
// Aggregates per-team resource usage from the live sandboxes table.
diff --git a/internal/db/models.go b/internal/db/models.go
index 0128f4a8..1e9a5d00 100644
--- a/internal/db/models.go
+++ b/internal/db/models.go
@@ -9,18 +9,18 @@ import (
)
type AdminPermission struct {
- ID string `json:"id"`
- UserID string `json:"user_id"`
+ ID pgtype.UUID `json:"id"`
+ UserID pgtype.UUID `json:"user_id"`
Permission string `json:"permission"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
}
type AuditLog struct {
- ID string `json:"id"`
- TeamID string `json:"team_id"`
+ ID pgtype.UUID `json:"id"`
+ TeamID pgtype.UUID `json:"team_id"`
ActorType string `json:"actor_type"`
ActorID pgtype.Text `json:"actor_id"`
- ActorName pgtype.Text `json:"actor_name"`
+ ActorName string `json:"actor_name"`
ResourceType string `json:"resource_type"`
ResourceID pgtype.Text `json:"resource_id"`
Action string `json:"action"`
@@ -31,29 +31,29 @@ type AuditLog struct {
}
type Host struct {
- ID string `json:"id"`
+ ID pgtype.UUID `json:"id"`
Type string `json:"type"`
- TeamID pgtype.Text `json:"team_id"`
- Provider pgtype.Text `json:"provider"`
- AvailabilityZone pgtype.Text `json:"availability_zone"`
- Arch pgtype.Text `json:"arch"`
- CpuCores pgtype.Int4 `json:"cpu_cores"`
- MemoryMb pgtype.Int4 `json:"memory_mb"`
- DiskGb pgtype.Int4 `json:"disk_gb"`
- Address pgtype.Text `json:"address"`
+ TeamID pgtype.UUID `json:"team_id"`
+ Provider string `json:"provider"`
+ AvailabilityZone string `json:"availability_zone"`
+ Arch string `json:"arch"`
+ CpuCores int32 `json:"cpu_cores"`
+ MemoryMb int32 `json:"memory_mb"`
+ DiskGb int32 `json:"disk_gb"`
+ Address string `json:"address"`
Status string `json:"status"`
LastHeartbeatAt pgtype.Timestamptz `json:"last_heartbeat_at"`
Metadata []byte `json:"metadata"`
- CreatedBy string `json:"created_by"`
+ CreatedBy pgtype.UUID `json:"created_by"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
- CertFingerprint pgtype.Text `json:"cert_fingerprint"`
- MtlsEnabled bool `json:"mtls_enabled"`
+ CertFingerprint string `json:"cert_fingerprint"`
+ CertExpiresAt pgtype.Timestamptz `json:"cert_expires_at"`
}
type HostRefreshToken struct {
- ID string `json:"id"`
- HostID string `json:"host_id"`
+ ID pgtype.UUID `json:"id"`
+ HostID pgtype.UUID `json:"host_id"`
TokenHash string `json:"token_hash"`
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
@@ -61,14 +61,14 @@ type HostRefreshToken struct {
}
type HostTag struct {
- HostID string `json:"host_id"`
- Tag string `json:"tag"`
+ HostID pgtype.UUID `json:"host_id"`
+ Tag string `json:"tag"`
}
type HostToken struct {
- ID string `json:"id"`
- HostID string `json:"host_id"`
- CreatedBy string `json:"created_by"`
+ ID pgtype.UUID `json:"id"`
+ HostID pgtype.UUID `json:"host_id"`
+ CreatedBy pgtype.UUID `json:"created_by"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
UsedAt pgtype.Timestamptz `json:"used_at"`
@@ -77,40 +77,43 @@ type HostToken struct {
type OauthProvider struct {
Provider string `json:"provider"`
ProviderID string `json:"provider_id"`
- UserID string `json:"user_id"`
+ UserID pgtype.UUID `json:"user_id"`
Email string `json:"email"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
}
type Sandbox struct {
- ID string `json:"id"`
- HostID string `json:"host_id"`
- Template string `json:"template"`
- Status string `json:"status"`
- Vcpus int32 `json:"vcpus"`
- MemoryMb int32 `json:"memory_mb"`
- TimeoutSec int32 `json:"timeout_sec"`
- GuestIp string `json:"guest_ip"`
- HostIp string `json:"host_ip"`
- CreatedAt pgtype.Timestamptz `json:"created_at"`
- StartedAt pgtype.Timestamptz `json:"started_at"`
- LastActiveAt pgtype.Timestamptz `json:"last_active_at"`
- LastUpdated pgtype.Timestamptz `json:"last_updated"`
- TeamID string `json:"team_id"`
+ ID pgtype.UUID `json:"id"`
+ TeamID pgtype.UUID `json:"team_id"`
+ HostID pgtype.UUID `json:"host_id"`
+ Template string `json:"template"`
+ Status string `json:"status"`
+ Vcpus int32 `json:"vcpus"`
+ MemoryMb int32 `json:"memory_mb"`
+ TimeoutSec int32 `json:"timeout_sec"`
+ DiskSizeMb int32 `json:"disk_size_mb"`
+ GuestIp string `json:"guest_ip"`
+ HostIp string `json:"host_ip"`
+ CreatedAt pgtype.Timestamptz `json:"created_at"`
+ StartedAt pgtype.Timestamptz `json:"started_at"`
+ LastActiveAt pgtype.Timestamptz `json:"last_active_at"`
+ LastUpdated pgtype.Timestamptz `json:"last_updated"`
+ TemplateID pgtype.UUID `json:"template_id"`
+ TemplateTeamID pgtype.UUID `json:"template_team_id"`
}
type SandboxMetricPoint struct {
- SandboxID string `json:"sandbox_id"`
- Tier string `json:"tier"`
- Ts int64 `json:"ts"`
- CpuPct float64 `json:"cpu_pct"`
- MemBytes int64 `json:"mem_bytes"`
- DiskBytes int64 `json:"disk_bytes"`
+ SandboxID pgtype.UUID `json:"sandbox_id"`
+ Tier string `json:"tier"`
+ Ts int64 `json:"ts"`
+ CpuPct float64 `json:"cpu_pct"`
+ MemBytes int64 `json:"mem_bytes"`
+ DiskBytes int64 `json:"disk_bytes"`
}
type SandboxMetricsSnapshot struct {
ID int64 `json:"id"`
- TeamID string `json:"team_id"`
+ TeamID pgtype.UUID `json:"team_id"`
SampledAt pgtype.Timestamptz `json:"sampled_at"`
RunningCount int32 `json:"running_count"`
VcpusReserved int32 `json:"vcpus_reserved"`
@@ -118,21 +121,21 @@ type SandboxMetricsSnapshot struct {
}
type Team struct {
- ID string `json:"id"`
+ ID pgtype.UUID `json:"id"`
Name string `json:"name"`
- CreatedAt pgtype.Timestamptz `json:"created_at"`
- IsByoc bool `json:"is_byoc"`
Slug string `json:"slug"`
+ IsByoc bool `json:"is_byoc"`
+ CreatedAt pgtype.Timestamptz `json:"created_at"`
DeletedAt pgtype.Timestamptz `json:"deleted_at"`
}
type TeamApiKey struct {
- ID string `json:"id"`
- TeamID string `json:"team_id"`
+ ID pgtype.UUID `json:"id"`
+ TeamID pgtype.UUID `json:"team_id"`
Name string `json:"name"`
KeyHash string `json:"key_hash"`
KeyPrefix string `json:"key_prefix"`
- CreatedBy string `json:"created_by"`
+ CreatedBy pgtype.UUID `json:"created_by"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
LastUsed pgtype.Timestamptz `json:"last_used"`
}
@@ -140,26 +143,50 @@ type TeamApiKey struct {
type Template struct {
Name string `json:"name"`
Type string `json:"type"`
- Vcpus pgtype.Int4 `json:"vcpus"`
- MemoryMb pgtype.Int4 `json:"memory_mb"`
+ Vcpus int32 `json:"vcpus"`
+ MemoryMb int32 `json:"memory_mb"`
SizeBytes int64 `json:"size_bytes"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
- TeamID string `json:"team_id"`
+ TeamID pgtype.UUID `json:"team_id"`
+ ID pgtype.UUID `json:"id"`
+}
+
+type TemplateBuild struct {
+ ID pgtype.UUID `json:"id"`
+ Name string `json:"name"`
+ BaseTemplate string `json:"base_template"`
+ Recipe []byte `json:"recipe"`
+ Healthcheck string `json:"healthcheck"`
+ Vcpus int32 `json:"vcpus"`
+ MemoryMb int32 `json:"memory_mb"`
+ Status string `json:"status"`
+ CurrentStep int32 `json:"current_step"`
+ TotalSteps int32 `json:"total_steps"`
+ Logs []byte `json:"logs"`
+ Error string `json:"error"`
+ SandboxID pgtype.UUID `json:"sandbox_id"`
+ HostID pgtype.UUID `json:"host_id"`
+ CreatedAt pgtype.Timestamptz `json:"created_at"`
+ StartedAt pgtype.Timestamptz `json:"started_at"`
+ CompletedAt pgtype.Timestamptz `json:"completed_at"`
+ TemplateID pgtype.UUID `json:"template_id"`
+ TeamID pgtype.UUID `json:"team_id"`
+ SkipPrePost bool `json:"skip_pre_post"`
}
type User struct {
- ID string `json:"id"`
+ ID pgtype.UUID `json:"id"`
Email string `json:"email"`
PasswordHash pgtype.Text `json:"password_hash"`
+ Name string `json:"name"`
+ IsAdmin bool `json:"is_admin"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
- IsAdmin bool `json:"is_admin"`
- Name string `json:"name"`
}
type UsersTeam struct {
- UserID string `json:"user_id"`
- TeamID string `json:"team_id"`
+ UserID pgtype.UUID `json:"user_id"`
+ TeamID pgtype.UUID `json:"team_id"`
IsDefault bool `json:"is_default"`
Role string `json:"role"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
diff --git a/internal/db/oauth.sql.go b/internal/db/oauth.sql.go
index ab79eec5..0270defa 100644
--- a/internal/db/oauth.sql.go
+++ b/internal/db/oauth.sql.go
@@ -7,6 +7,8 @@ package db
import (
"context"
+
+ "github.com/jackc/pgx/v5/pgtype"
)
const getOAuthProvider = `-- name: GetOAuthProvider :one
@@ -38,10 +40,10 @@ VALUES ($1, $2, $3, $4)
`
type InsertOAuthProviderParams struct {
- Provider string `json:"provider"`
- ProviderID string `json:"provider_id"`
- UserID string `json:"user_id"`
- Email string `json:"email"`
+ Provider string `json:"provider"`
+ ProviderID string `json:"provider_id"`
+ UserID pgtype.UUID `json:"user_id"`
+ Email string `json:"email"`
}
func (q *Queries) InsertOAuthProvider(ctx context.Context, arg InsertOAuthProviderParams) error {
diff --git a/internal/db/sandboxes.sql.go b/internal/db/sandboxes.sql.go
index 620f77e2..3ce16443 100644
--- a/internal/db/sandboxes.sql.go
+++ b/internal/db/sandboxes.sql.go
@@ -15,12 +15,12 @@ const bulkRestoreRunning = `-- name: BulkRestoreRunning :exec
UPDATE sandboxes
SET status = 'running',
last_updated = NOW()
-WHERE id = ANY($1::text[]) AND status = 'missing'
+WHERE id = ANY($1::uuid[]) AND status = 'missing'
`
// Called by the reconciler when a host comes back online and its sandboxes are
// confirmed alive. Restores only sandboxes that are in 'missing' state.
-func (q *Queries) BulkRestoreRunning(ctx context.Context, dollar_1 []string) error {
+func (q *Queries) BulkRestoreRunning(ctx context.Context, dollar_1 []pgtype.UUID) error {
_, err := q.db.Exec(ctx, bulkRestoreRunning, dollar_1)
return err
}
@@ -29,12 +29,12 @@ const bulkUpdateStatusByIDs = `-- name: BulkUpdateStatusByIDs :exec
UPDATE sandboxes
SET status = $2,
last_updated = NOW()
-WHERE id = ANY($1::text[])
+WHERE id = ANY($1::uuid[])
`
type BulkUpdateStatusByIDsParams struct {
- Column1 []string `json:"column_1"`
- Status string `json:"status"`
+ Column1 []pgtype.UUID `json:"column_1"`
+ Status string `json:"status"`
}
func (q *Queries) BulkUpdateStatusByIDs(ctx context.Context, arg BulkUpdateStatusByIDsParams) error {
@@ -43,38 +43,41 @@ func (q *Queries) BulkUpdateStatusByIDs(ctx context.Context, arg BulkUpdateStatu
}
const getSandbox = `-- name: GetSandbox :one
-SELECT id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id FROM sandboxes WHERE id = $1
+SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes WHERE id = $1
`
-func (q *Queries) GetSandbox(ctx context.Context, id string) (Sandbox, error) {
+func (q *Queries) GetSandbox(ctx context.Context, id pgtype.UUID) (Sandbox, error) {
row := q.db.QueryRow(ctx, getSandbox, id)
var i Sandbox
err := row.Scan(
&i.ID,
+ &i.TeamID,
&i.HostID,
&i.Template,
&i.Status,
&i.Vcpus,
&i.MemoryMb,
&i.TimeoutSec,
+ &i.DiskSizeMb,
&i.GuestIp,
&i.HostIp,
&i.CreatedAt,
&i.StartedAt,
&i.LastActiveAt,
&i.LastUpdated,
- &i.TeamID,
+ &i.TemplateID,
+ &i.TemplateTeamID,
)
return i, err
}
const getSandboxByTeam = `-- name: GetSandboxByTeam :one
-SELECT id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id FROM sandboxes WHERE id = $1 AND team_id = $2
+SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes WHERE id = $1 AND team_id = $2
`
type GetSandboxByTeamParams struct {
- ID string `json:"id"`
- TeamID string `json:"team_id"`
+ ID pgtype.UUID `json:"id"`
+ TeamID pgtype.UUID `json:"team_id"`
}
func (q *Queries) GetSandboxByTeam(ctx context.Context, arg GetSandboxByTeamParams) (Sandbox, error) {
@@ -82,38 +85,70 @@ func (q *Queries) GetSandboxByTeam(ctx context.Context, arg GetSandboxByTeamPara
var i Sandbox
err := row.Scan(
&i.ID,
+ &i.TeamID,
&i.HostID,
&i.Template,
&i.Status,
&i.Vcpus,
&i.MemoryMb,
&i.TimeoutSec,
+ &i.DiskSizeMb,
&i.GuestIp,
&i.HostIp,
&i.CreatedAt,
&i.StartedAt,
&i.LastActiveAt,
&i.LastUpdated,
- &i.TeamID,
+ &i.TemplateID,
+ &i.TemplateTeamID,
)
return i, err
}
+const getSandboxProxyTarget = `-- name: GetSandboxProxyTarget :one
+SELECT s.status, h.address AS host_address
+FROM sandboxes s
+JOIN hosts h ON h.id = s.host_id
+WHERE s.id = $1 AND s.team_id = $2
+`
+
+type GetSandboxProxyTargetParams struct {
+ ID pgtype.UUID `json:"id"`
+ TeamID pgtype.UUID `json:"team_id"`
+}
+
+type GetSandboxProxyTargetRow struct {
+ Status string `json:"status"`
+ HostAddress string `json:"host_address"`
+}
+
+// Returns the sandbox status and its host's address in one query.
+// Used by SandboxProxyWrapper to avoid two round-trips.
+func (q *Queries) GetSandboxProxyTarget(ctx context.Context, arg GetSandboxProxyTargetParams) (GetSandboxProxyTargetRow, error) {
+ row := q.db.QueryRow(ctx, getSandboxProxyTarget, arg.ID, arg.TeamID)
+ var i GetSandboxProxyTargetRow
+ err := row.Scan(&i.Status, &i.HostAddress)
+ return i, err
+}
+
const insertSandbox = `-- name: InsertSandbox :one
-INSERT INTO sandboxes (id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec)
-VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
-RETURNING id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id
+INSERT INTO sandboxes (id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, template_id, template_team_id)
+VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
+RETURNING id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id
`
type InsertSandboxParams struct {
- ID string `json:"id"`
- TeamID string `json:"team_id"`
- HostID string `json:"host_id"`
- Template string `json:"template"`
- Status string `json:"status"`
- Vcpus int32 `json:"vcpus"`
- MemoryMb int32 `json:"memory_mb"`
- TimeoutSec int32 `json:"timeout_sec"`
+ ID pgtype.UUID `json:"id"`
+ TeamID pgtype.UUID `json:"team_id"`
+ HostID pgtype.UUID `json:"host_id"`
+ Template string `json:"template"`
+ Status string `json:"status"`
+ Vcpus int32 `json:"vcpus"`
+ MemoryMb int32 `json:"memory_mb"`
+ TimeoutSec int32 `json:"timeout_sec"`
+ DiskSizeMb int32 `json:"disk_size_mb"`
+ TemplateID pgtype.UUID `json:"template_id"`
+ TemplateTeamID pgtype.UUID `json:"template_team_id"`
}
func (q *Queries) InsertSandbox(ctx context.Context, arg InsertSandboxParams) (Sandbox, error) {
@@ -126,34 +161,40 @@ func (q *Queries) InsertSandbox(ctx context.Context, arg InsertSandboxParams) (S
arg.Vcpus,
arg.MemoryMb,
arg.TimeoutSec,
+ arg.DiskSizeMb,
+ arg.TemplateID,
+ arg.TemplateTeamID,
)
var i Sandbox
err := row.Scan(
&i.ID,
+ &i.TeamID,
&i.HostID,
&i.Template,
&i.Status,
&i.Vcpus,
&i.MemoryMb,
&i.TimeoutSec,
+ &i.DiskSizeMb,
&i.GuestIp,
&i.HostIp,
&i.CreatedAt,
&i.StartedAt,
&i.LastActiveAt,
&i.LastUpdated,
- &i.TeamID,
+ &i.TemplateID,
+ &i.TemplateTeamID,
)
return i, err
}
const listActiveSandboxesByTeam = `-- name: ListActiveSandboxesByTeam :many
-SELECT id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id FROM sandboxes
+SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes
WHERE team_id = $1 AND status IN ('running', 'paused', 'starting')
ORDER BY created_at DESC
`
-func (q *Queries) ListActiveSandboxesByTeam(ctx context.Context, teamID string) ([]Sandbox, error) {
+func (q *Queries) ListActiveSandboxesByTeam(ctx context.Context, teamID pgtype.UUID) ([]Sandbox, error) {
rows, err := q.db.Query(ctx, listActiveSandboxesByTeam, teamID)
if err != nil {
return nil, err
@@ -164,19 +205,22 @@ func (q *Queries) ListActiveSandboxesByTeam(ctx context.Context, teamID string)
var i Sandbox
if err := rows.Scan(
&i.ID,
+ &i.TeamID,
&i.HostID,
&i.Template,
&i.Status,
&i.Vcpus,
&i.MemoryMb,
&i.TimeoutSec,
+ &i.DiskSizeMb,
&i.GuestIp,
&i.HostIp,
&i.CreatedAt,
&i.StartedAt,
&i.LastActiveAt,
&i.LastUpdated,
- &i.TeamID,
+ &i.TemplateID,
+ &i.TemplateTeamID,
); err != nil {
return nil, err
}
@@ -189,7 +233,7 @@ func (q *Queries) ListActiveSandboxesByTeam(ctx context.Context, teamID string)
}
const listSandboxes = `-- name: ListSandboxes :many
-SELECT id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id FROM sandboxes ORDER BY created_at DESC
+SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes ORDER BY created_at DESC
`
func (q *Queries) ListSandboxes(ctx context.Context) ([]Sandbox, error) {
@@ -203,19 +247,22 @@ func (q *Queries) ListSandboxes(ctx context.Context) ([]Sandbox, error) {
var i Sandbox
if err := rows.Scan(
&i.ID,
+ &i.TeamID,
&i.HostID,
&i.Template,
&i.Status,
&i.Vcpus,
&i.MemoryMb,
&i.TimeoutSec,
+ &i.DiskSizeMb,
&i.GuestIp,
&i.HostIp,
&i.CreatedAt,
&i.StartedAt,
&i.LastActiveAt,
&i.LastUpdated,
- &i.TeamID,
+ &i.TemplateID,
+ &i.TemplateTeamID,
); err != nil {
return nil, err
}
@@ -228,14 +275,14 @@ func (q *Queries) ListSandboxes(ctx context.Context) ([]Sandbox, error) {
}
const listSandboxesByHostAndStatus = `-- name: ListSandboxesByHostAndStatus :many
-SELECT id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id FROM sandboxes
+SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes
WHERE host_id = $1 AND status = ANY($2::text[])
ORDER BY created_at DESC
`
type ListSandboxesByHostAndStatusParams struct {
- HostID string `json:"host_id"`
- Column2 []string `json:"column_2"`
+ HostID pgtype.UUID `json:"host_id"`
+ Column2 []string `json:"column_2"`
}
func (q *Queries) ListSandboxesByHostAndStatus(ctx context.Context, arg ListSandboxesByHostAndStatusParams) ([]Sandbox, error) {
@@ -249,19 +296,22 @@ func (q *Queries) ListSandboxesByHostAndStatus(ctx context.Context, arg ListSand
var i Sandbox
if err := rows.Scan(
&i.ID,
+ &i.TeamID,
&i.HostID,
&i.Template,
&i.Status,
&i.Vcpus,
&i.MemoryMb,
&i.TimeoutSec,
+ &i.DiskSizeMb,
&i.GuestIp,
&i.HostIp,
&i.CreatedAt,
&i.StartedAt,
&i.LastActiveAt,
&i.LastUpdated,
- &i.TeamID,
+ &i.TemplateID,
+ &i.TemplateTeamID,
); err != nil {
return nil, err
}
@@ -274,12 +324,12 @@ func (q *Queries) ListSandboxesByHostAndStatus(ctx context.Context, arg ListSand
}
const listSandboxesByTeam = `-- name: ListSandboxesByTeam :many
-SELECT id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id FROM sandboxes
+SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes
WHERE team_id = $1 AND status NOT IN ('stopped', 'error')
ORDER BY created_at DESC
`
-func (q *Queries) ListSandboxesByTeam(ctx context.Context, teamID string) ([]Sandbox, error) {
+func (q *Queries) ListSandboxesByTeam(ctx context.Context, teamID pgtype.UUID) ([]Sandbox, error) {
rows, err := q.db.Query(ctx, listSandboxesByTeam, teamID)
if err != nil {
return nil, err
@@ -290,19 +340,22 @@ func (q *Queries) ListSandboxesByTeam(ctx context.Context, teamID string) ([]San
var i Sandbox
if err := rows.Scan(
&i.ID,
+ &i.TeamID,
&i.HostID,
&i.Template,
&i.Status,
&i.Vcpus,
&i.MemoryMb,
&i.TimeoutSec,
+ &i.DiskSizeMb,
&i.GuestIp,
&i.HostIp,
&i.CreatedAt,
&i.StartedAt,
&i.LastActiveAt,
&i.LastUpdated,
- &i.TeamID,
+ &i.TemplateID,
+ &i.TemplateTeamID,
); err != nil {
return nil, err
}
@@ -324,7 +377,7 @@ WHERE host_id = $1 AND status IN ('running', 'starting', 'pending')
// Called when the host monitor marks a host unreachable.
// Marks running/starting/pending sandboxes on that host as 'missing' so users see
// the sandbox is not currently reachable, without permanently losing the record.
-func (q *Queries) MarkSandboxesMissingByHost(ctx context.Context, hostID string) error {
+func (q *Queries) MarkSandboxesMissingByHost(ctx context.Context, hostID pgtype.UUID) error {
_, err := q.db.Exec(ctx, markSandboxesMissingByHost, hostID)
return err
}
@@ -337,7 +390,7 @@ WHERE id = $1
`
type UpdateLastActiveParams struct {
- ID string `json:"id"`
+ ID pgtype.UUID `json:"id"`
LastActiveAt pgtype.Timestamptz `json:"last_active_at"`
}
@@ -355,11 +408,11 @@ SET status = 'running',
last_active_at = $4,
last_updated = NOW()
WHERE id = $1
-RETURNING id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id
+RETURNING id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id
`
type UpdateSandboxRunningParams struct {
- ID string `json:"id"`
+ ID pgtype.UUID `json:"id"`
HostIp string `json:"host_ip"`
GuestIp string `json:"guest_ip"`
StartedAt pgtype.Timestamptz `json:"started_at"`
@@ -375,19 +428,22 @@ func (q *Queries) UpdateSandboxRunning(ctx context.Context, arg UpdateSandboxRun
var i Sandbox
err := row.Scan(
&i.ID,
+ &i.TeamID,
&i.HostID,
&i.Template,
&i.Status,
&i.Vcpus,
&i.MemoryMb,
&i.TimeoutSec,
+ &i.DiskSizeMb,
&i.GuestIp,
&i.HostIp,
&i.CreatedAt,
&i.StartedAt,
&i.LastActiveAt,
&i.LastUpdated,
- &i.TeamID,
+ &i.TemplateID,
+ &i.TemplateTeamID,
)
return i, err
}
@@ -397,12 +453,12 @@ UPDATE sandboxes
SET status = $2,
last_updated = NOW()
WHERE id = $1
-RETURNING id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id
+RETURNING id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id
`
type UpdateSandboxStatusParams struct {
- ID string `json:"id"`
- Status string `json:"status"`
+ ID pgtype.UUID `json:"id"`
+ Status string `json:"status"`
}
func (q *Queries) UpdateSandboxStatus(ctx context.Context, arg UpdateSandboxStatusParams) (Sandbox, error) {
@@ -410,19 +466,22 @@ func (q *Queries) UpdateSandboxStatus(ctx context.Context, arg UpdateSandboxStat
var i Sandbox
err := row.Scan(
&i.ID,
+ &i.TeamID,
&i.HostID,
&i.Template,
&i.Status,
&i.Vcpus,
&i.MemoryMb,
&i.TimeoutSec,
+ &i.DiskSizeMb,
&i.GuestIp,
&i.HostIp,
&i.CreatedAt,
&i.StartedAt,
&i.LastActiveAt,
&i.LastUpdated,
- &i.TeamID,
+ &i.TemplateID,
+ &i.TemplateTeamID,
)
return i, err
}
diff --git a/internal/db/teams.sql.go b/internal/db/teams.sql.go
index a00f5efc..334141ff 100644
--- a/internal/db/teams.sql.go
+++ b/internal/db/teams.sql.go
@@ -16,8 +16,8 @@ DELETE FROM users_teams WHERE team_id = $1 AND user_id = $2
`
type DeleteTeamMemberParams struct {
- TeamID string `json:"team_id"`
- UserID string `json:"user_id"`
+ TeamID pgtype.UUID `json:"team_id"`
+ UserID pgtype.UUID `json:"user_id"`
}
func (q *Queries) DeleteTeamMember(ctx context.Context, arg DeleteTeamMemberParams) error {
@@ -26,7 +26,7 @@ func (q *Queries) DeleteTeamMember(ctx context.Context, arg DeleteTeamMemberPara
}
const getBYOCTeams = `-- name: GetBYOCTeams :many
-SELECT id, name, created_at, is_byoc, slug, deleted_at FROM teams WHERE is_byoc = TRUE AND deleted_at IS NULL ORDER BY created_at
+SELECT id, name, slug, is_byoc, created_at, deleted_at FROM teams WHERE is_byoc = TRUE AND deleted_at IS NULL ORDER BY created_at
`
func (q *Queries) GetBYOCTeams(ctx context.Context) ([]Team, error) {
@@ -41,9 +41,9 @@ func (q *Queries) GetBYOCTeams(ctx context.Context) ([]Team, error) {
if err := rows.Scan(
&i.ID,
&i.Name,
- &i.CreatedAt,
- &i.IsByoc,
&i.Slug,
+ &i.IsByoc,
+ &i.CreatedAt,
&i.DeletedAt,
); err != nil {
return nil, err
@@ -57,46 +57,46 @@ func (q *Queries) GetBYOCTeams(ctx context.Context) ([]Team, error) {
}
const getDefaultTeamForUser = `-- name: GetDefaultTeamForUser :one
-SELECT t.id, t.name, t.created_at, t.is_byoc, t.slug, t.deleted_at FROM teams t
+SELECT t.id, t.name, t.slug, t.is_byoc, t.created_at, t.deleted_at FROM teams t
JOIN users_teams ut ON ut.team_id = t.id
WHERE ut.user_id = $1 AND ut.is_default = TRUE AND t.deleted_at IS NULL
LIMIT 1
`
-func (q *Queries) GetDefaultTeamForUser(ctx context.Context, userID string) (Team, error) {
+func (q *Queries) GetDefaultTeamForUser(ctx context.Context, userID pgtype.UUID) (Team, error) {
row := q.db.QueryRow(ctx, getDefaultTeamForUser, userID)
var i Team
err := row.Scan(
&i.ID,
&i.Name,
- &i.CreatedAt,
- &i.IsByoc,
&i.Slug,
+ &i.IsByoc,
+ &i.CreatedAt,
&i.DeletedAt,
)
return i, err
}
const getTeam = `-- name: GetTeam :one
-SELECT id, name, created_at, is_byoc, slug, deleted_at FROM teams WHERE id = $1
+SELECT id, name, slug, is_byoc, created_at, deleted_at FROM teams WHERE id = $1
`
-func (q *Queries) GetTeam(ctx context.Context, id string) (Team, error) {
+func (q *Queries) GetTeam(ctx context.Context, id pgtype.UUID) (Team, error) {
row := q.db.QueryRow(ctx, getTeam, id)
var i Team
err := row.Scan(
&i.ID,
&i.Name,
- &i.CreatedAt,
- &i.IsByoc,
&i.Slug,
+ &i.IsByoc,
+ &i.CreatedAt,
&i.DeletedAt,
)
return i, err
}
const getTeamBySlug = `-- name: GetTeamBySlug :one
-SELECT id, name, created_at, is_byoc, slug, deleted_at FROM teams WHERE slug = $1 AND deleted_at IS NULL
+SELECT id, name, slug, is_byoc, created_at, deleted_at FROM teams WHERE slug = $1 AND deleted_at IS NULL
`
func (q *Queries) GetTeamBySlug(ctx context.Context, slug string) (Team, error) {
@@ -105,9 +105,9 @@ func (q *Queries) GetTeamBySlug(ctx context.Context, slug string) (Team, error)
err := row.Scan(
&i.ID,
&i.Name,
- &i.CreatedAt,
- &i.IsByoc,
&i.Slug,
+ &i.IsByoc,
+ &i.CreatedAt,
&i.DeletedAt,
)
return i, err
@@ -122,14 +122,14 @@ ORDER BY ut.created_at
`
type GetTeamMembersRow struct {
- ID string `json:"id"`
+ ID pgtype.UUID `json:"id"`
Name string `json:"name"`
Email string `json:"email"`
Role string `json:"role"`
JoinedAt pgtype.Timestamptz `json:"joined_at"`
}
-func (q *Queries) GetTeamMembers(ctx context.Context, teamID string) ([]GetTeamMembersRow, error) {
+func (q *Queries) GetTeamMembers(ctx context.Context, teamID pgtype.UUID) ([]GetTeamMembersRow, error) {
rows, err := q.db.Query(ctx, getTeamMembers, teamID)
if err != nil {
return nil, err
@@ -160,8 +160,8 @@ SELECT user_id, team_id, is_default, role, created_at FROM users_teams WHERE use
`
type GetTeamMembershipParams struct {
- UserID string `json:"user_id"`
- TeamID string `json:"team_id"`
+ UserID pgtype.UUID `json:"user_id"`
+ TeamID pgtype.UUID `json:"team_id"`
}
func (q *Queries) GetTeamMembership(ctx context.Context, arg GetTeamMembershipParams) (UsersTeam, error) {
@@ -186,7 +186,7 @@ ORDER BY ut.created_at
`
type GetTeamsForUserRow struct {
- ID string `json:"id"`
+ ID pgtype.UUID `json:"id"`
Name string `json:"name"`
Slug string `json:"slug"`
IsByoc bool `json:"is_byoc"`
@@ -195,7 +195,7 @@ type GetTeamsForUserRow struct {
Role string `json:"role"`
}
-func (q *Queries) GetTeamsForUser(ctx context.Context, userID string) ([]GetTeamsForUserRow, error) {
+func (q *Queries) GetTeamsForUser(ctx context.Context, userID pgtype.UUID) ([]GetTeamsForUserRow, error) {
rows, err := q.db.Query(ctx, getTeamsForUser, userID)
if err != nil {
return nil, err
@@ -226,13 +226,13 @@ func (q *Queries) GetTeamsForUser(ctx context.Context, userID string) ([]GetTeam
const insertTeam = `-- name: InsertTeam :one
INSERT INTO teams (id, name, slug)
VALUES ($1, $2, $3)
-RETURNING id, name, created_at, is_byoc, slug, deleted_at
+RETURNING id, name, slug, is_byoc, created_at, deleted_at
`
type InsertTeamParams struct {
- ID string `json:"id"`
- Name string `json:"name"`
- Slug string `json:"slug"`
+ ID pgtype.UUID `json:"id"`
+ Name string `json:"name"`
+ Slug string `json:"slug"`
}
func (q *Queries) InsertTeam(ctx context.Context, arg InsertTeamParams) (Team, error) {
@@ -241,9 +241,9 @@ func (q *Queries) InsertTeam(ctx context.Context, arg InsertTeamParams) (Team, e
err := row.Scan(
&i.ID,
&i.Name,
- &i.CreatedAt,
- &i.IsByoc,
&i.Slug,
+ &i.IsByoc,
+ &i.CreatedAt,
&i.DeletedAt,
)
return i, err
@@ -255,10 +255,10 @@ VALUES ($1, $2, $3, $4)
`
type InsertTeamMemberParams struct {
- UserID string `json:"user_id"`
- TeamID string `json:"team_id"`
- IsDefault bool `json:"is_default"`
- Role string `json:"role"`
+ UserID pgtype.UUID `json:"user_id"`
+ TeamID pgtype.UUID `json:"team_id"`
+ IsDefault bool `json:"is_default"`
+ Role string `json:"role"`
}
func (q *Queries) InsertTeamMember(ctx context.Context, arg InsertTeamMemberParams) error {
@@ -276,8 +276,8 @@ UPDATE teams SET is_byoc = $2 WHERE id = $1
`
type SetTeamBYOCParams struct {
- ID string `json:"id"`
- IsByoc bool `json:"is_byoc"`
+ ID pgtype.UUID `json:"id"`
+ IsByoc bool `json:"is_byoc"`
}
func (q *Queries) SetTeamBYOC(ctx context.Context, arg SetTeamBYOCParams) error {
@@ -289,7 +289,7 @@ const softDeleteTeam = `-- name: SoftDeleteTeam :exec
UPDATE teams SET deleted_at = NOW() WHERE id = $1
`
-func (q *Queries) SoftDeleteTeam(ctx context.Context, id string) error {
+func (q *Queries) SoftDeleteTeam(ctx context.Context, id pgtype.UUID) error {
_, err := q.db.Exec(ctx, softDeleteTeam, id)
return err
}
@@ -299,9 +299,9 @@ UPDATE users_teams SET role = $3 WHERE team_id = $1 AND user_id = $2
`
type UpdateMemberRoleParams struct {
- TeamID string `json:"team_id"`
- UserID string `json:"user_id"`
- Role string `json:"role"`
+ TeamID pgtype.UUID `json:"team_id"`
+ UserID pgtype.UUID `json:"user_id"`
+ Role string `json:"role"`
}
func (q *Queries) UpdateMemberRole(ctx context.Context, arg UpdateMemberRoleParams) error {
@@ -314,8 +314,8 @@ UPDATE teams SET name = $2 WHERE id = $1 AND deleted_at IS NULL
`
type UpdateTeamNameParams struct {
- ID string `json:"id"`
- Name string `json:"name"`
+ ID pgtype.UUID `json:"id"`
+ Name string `json:"name"`
}
func (q *Queries) UpdateTeamName(ctx context.Context, arg UpdateTeamNameParams) error {
diff --git a/internal/db/template_builds.sql.go b/internal/db/template_builds.sql.go
new file mode 100644
index 00000000..facfb199
--- /dev/null
+++ b/internal/db/template_builds.sql.go
@@ -0,0 +1,241 @@
+// Code generated by sqlc. DO NOT EDIT.
+// versions:
+// sqlc v1.30.0
+// source: template_builds.sql
+
+package db
+
+import (
+ "context"
+
+ "github.com/jackc/pgx/v5/pgtype"
+)
+
+const getTemplateBuild = `-- name: GetTemplateBuild :one
+SELECT id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at, template_id, team_id, skip_pre_post FROM template_builds WHERE id = $1
+`
+
+func (q *Queries) GetTemplateBuild(ctx context.Context, id pgtype.UUID) (TemplateBuild, error) {
+ row := q.db.QueryRow(ctx, getTemplateBuild, id)
+ var i TemplateBuild
+ err := row.Scan(
+ &i.ID,
+ &i.Name,
+ &i.BaseTemplate,
+ &i.Recipe,
+ &i.Healthcheck,
+ &i.Vcpus,
+ &i.MemoryMb,
+ &i.Status,
+ &i.CurrentStep,
+ &i.TotalSteps,
+ &i.Logs,
+ &i.Error,
+ &i.SandboxID,
+ &i.HostID,
+ &i.CreatedAt,
+ &i.StartedAt,
+ &i.CompletedAt,
+ &i.TemplateID,
+ &i.TeamID,
+ &i.SkipPrePost,
+ )
+ return i, err
+}
+
+const insertTemplateBuild = `-- name: InsertTemplateBuild :one
+INSERT INTO template_builds (id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, total_steps, template_id, team_id, skip_pre_post)
+VALUES ($1, $2, $3, $4, $5, $6, $7, 'pending', $8, $9, $10, $11)
+RETURNING id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at, template_id, team_id, skip_pre_post
+`
+
+type InsertTemplateBuildParams struct {
+ ID pgtype.UUID `json:"id"`
+ Name string `json:"name"`
+ BaseTemplate string `json:"base_template"`
+ Recipe []byte `json:"recipe"`
+ Healthcheck string `json:"healthcheck"`
+ Vcpus int32 `json:"vcpus"`
+ MemoryMb int32 `json:"memory_mb"`
+ TotalSteps int32 `json:"total_steps"`
+ TemplateID pgtype.UUID `json:"template_id"`
+ TeamID pgtype.UUID `json:"team_id"`
+ SkipPrePost bool `json:"skip_pre_post"`
+}
+
+func (q *Queries) InsertTemplateBuild(ctx context.Context, arg InsertTemplateBuildParams) (TemplateBuild, error) {
+ row := q.db.QueryRow(ctx, insertTemplateBuild,
+ arg.ID,
+ arg.Name,
+ arg.BaseTemplate,
+ arg.Recipe,
+ arg.Healthcheck,
+ arg.Vcpus,
+ arg.MemoryMb,
+ arg.TotalSteps,
+ arg.TemplateID,
+ arg.TeamID,
+ arg.SkipPrePost,
+ )
+ var i TemplateBuild
+ err := row.Scan(
+ &i.ID,
+ &i.Name,
+ &i.BaseTemplate,
+ &i.Recipe,
+ &i.Healthcheck,
+ &i.Vcpus,
+ &i.MemoryMb,
+ &i.Status,
+ &i.CurrentStep,
+ &i.TotalSteps,
+ &i.Logs,
+ &i.Error,
+ &i.SandboxID,
+ &i.HostID,
+ &i.CreatedAt,
+ &i.StartedAt,
+ &i.CompletedAt,
+ &i.TemplateID,
+ &i.TeamID,
+ &i.SkipPrePost,
+ )
+ return i, err
+}
+
+const listTemplateBuilds = `-- name: ListTemplateBuilds :many
+SELECT id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at, template_id, team_id, skip_pre_post FROM template_builds ORDER BY created_at DESC
+`
+
+func (q *Queries) ListTemplateBuilds(ctx context.Context) ([]TemplateBuild, error) {
+ rows, err := q.db.Query(ctx, listTemplateBuilds)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+ var items []TemplateBuild
+ for rows.Next() {
+ var i TemplateBuild
+ if err := rows.Scan(
+ &i.ID,
+ &i.Name,
+ &i.BaseTemplate,
+ &i.Recipe,
+ &i.Healthcheck,
+ &i.Vcpus,
+ &i.MemoryMb,
+ &i.Status,
+ &i.CurrentStep,
+ &i.TotalSteps,
+ &i.Logs,
+ &i.Error,
+ &i.SandboxID,
+ &i.HostID,
+ &i.CreatedAt,
+ &i.StartedAt,
+ &i.CompletedAt,
+ &i.TemplateID,
+ &i.TeamID,
+ &i.SkipPrePost,
+ ); err != nil {
+ return nil, err
+ }
+ items = append(items, i)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return items, nil
+}
+
+const updateBuildError = `-- name: UpdateBuildError :exec
+UPDATE template_builds
+SET error = $2, status = 'failed', completed_at = NOW()
+WHERE id = $1
+`
+
+type UpdateBuildErrorParams struct {
+ ID pgtype.UUID `json:"id"`
+ Error string `json:"error"`
+}
+
+func (q *Queries) UpdateBuildError(ctx context.Context, arg UpdateBuildErrorParams) error {
+ _, err := q.db.Exec(ctx, updateBuildError, arg.ID, arg.Error)
+ return err
+}
+
+const updateBuildProgress = `-- name: UpdateBuildProgress :exec
+UPDATE template_builds
+SET current_step = $2, logs = $3
+WHERE id = $1
+`
+
+type UpdateBuildProgressParams struct {
+ ID pgtype.UUID `json:"id"`
+ CurrentStep int32 `json:"current_step"`
+ Logs []byte `json:"logs"`
+}
+
+func (q *Queries) UpdateBuildProgress(ctx context.Context, arg UpdateBuildProgressParams) error {
+ _, err := q.db.Exec(ctx, updateBuildProgress, arg.ID, arg.CurrentStep, arg.Logs)
+ return err
+}
+
+const updateBuildSandbox = `-- name: UpdateBuildSandbox :exec
+UPDATE template_builds
+SET sandbox_id = $2, host_id = $3
+WHERE id = $1
+`
+
+type UpdateBuildSandboxParams struct {
+ ID pgtype.UUID `json:"id"`
+ SandboxID pgtype.UUID `json:"sandbox_id"`
+ HostID pgtype.UUID `json:"host_id"`
+}
+
+func (q *Queries) UpdateBuildSandbox(ctx context.Context, arg UpdateBuildSandboxParams) error {
+ _, err := q.db.Exec(ctx, updateBuildSandbox, arg.ID, arg.SandboxID, arg.HostID)
+ return err
+}
+
+const updateBuildStatus = `-- name: UpdateBuildStatus :one
+UPDATE template_builds
+SET status = $2,
+ started_at = CASE WHEN $2 = 'running' AND started_at IS NULL THEN NOW() ELSE started_at END,
+ completed_at = CASE WHEN $2 IN ('success', 'failed', 'cancelled') THEN NOW() ELSE completed_at END
+WHERE id = $1
+RETURNING id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at, template_id, team_id, skip_pre_post
+`
+
+type UpdateBuildStatusParams struct {
+ ID pgtype.UUID `json:"id"`
+ Status string `json:"status"`
+}
+
+func (q *Queries) UpdateBuildStatus(ctx context.Context, arg UpdateBuildStatusParams) (TemplateBuild, error) {
+ row := q.db.QueryRow(ctx, updateBuildStatus, arg.ID, arg.Status)
+ var i TemplateBuild
+ err := row.Scan(
+ &i.ID,
+ &i.Name,
+ &i.BaseTemplate,
+ &i.Recipe,
+ &i.Healthcheck,
+ &i.Vcpus,
+ &i.MemoryMb,
+ &i.Status,
+ &i.CurrentStep,
+ &i.TotalSteps,
+ &i.Logs,
+ &i.Error,
+ &i.SandboxID,
+ &i.HostID,
+ &i.CreatedAt,
+ &i.StartedAt,
+ &i.CompletedAt,
+ &i.TemplateID,
+ &i.TeamID,
+ &i.SkipPrePost,
+ )
+ return i, err
+}
diff --git a/internal/db/templates.sql.go b/internal/db/templates.sql.go
index cafae692..7d37808d 100644
--- a/internal/db/templates.sql.go
+++ b/internal/db/templates.sql.go
@@ -12,11 +12,11 @@ import (
)
const deleteTemplate = `-- name: DeleteTemplate :exec
-DELETE FROM templates WHERE name = $1
+DELETE FROM templates WHERE id = $1
`
-func (q *Queries) DeleteTemplate(ctx context.Context, name string) error {
- _, err := q.db.Exec(ctx, deleteTemplate, name)
+func (q *Queries) DeleteTemplate(ctx context.Context, id pgtype.UUID) error {
+ _, err := q.db.Exec(ctx, deleteTemplate, id)
return err
}
@@ -25,8 +25,8 @@ DELETE FROM templates WHERE name = $1 AND team_id = $2
`
type DeleteTemplateByTeamParams struct {
- Name string `json:"name"`
- TeamID string `json:"team_id"`
+ Name string `json:"name"`
+ TeamID pgtype.UUID `json:"team_id"`
}
func (q *Queries) DeleteTemplateByTeam(ctx context.Context, arg DeleteTemplateByTeamParams) error {
@@ -34,12 +34,23 @@ func (q *Queries) DeleteTemplateByTeam(ctx context.Context, arg DeleteTemplateBy
return err
}
-const getTemplate = `-- name: GetTemplate :one
-SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE name = $1
+const deleteTemplatesByTeam = `-- name: DeleteTemplatesByTeam :exec
+DELETE FROM templates WHERE team_id = $1
`
-func (q *Queries) GetTemplate(ctx context.Context, name string) (Template, error) {
- row := q.db.QueryRow(ctx, getTemplate, name)
+// Bulk delete all templates owned by a team (for team soft-delete cleanup).
+func (q *Queries) DeleteTemplatesByTeam(ctx context.Context, teamID pgtype.UUID) error {
+ _, err := q.db.Exec(ctx, deleteTemplatesByTeam, teamID)
+ return err
+}
+
+const getPlatformTemplateByName = `-- name: GetPlatformTemplateByName :one
+SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE team_id = '00000000-0000-0000-0000-000000000000' AND name = $1
+`
+
+// Check if a global (platform) template exists with the given name.
+func (q *Queries) GetPlatformTemplateByName(ctx context.Context, name string) (Template, error) {
+ row := q.db.QueryRow(ctx, getPlatformTemplateByName, name)
var i Template
err := row.Scan(
&i.Name,
@@ -49,19 +60,67 @@ func (q *Queries) GetTemplate(ctx context.Context, name string) (Template, error
&i.SizeBytes,
&i.CreatedAt,
&i.TeamID,
+ &i.ID,
+ )
+ return i, err
+}
+
+const getTemplate = `-- name: GetTemplate :one
+SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE id = $1
+`
+
+func (q *Queries) GetTemplate(ctx context.Context, id pgtype.UUID) (Template, error) {
+ row := q.db.QueryRow(ctx, getTemplate, id)
+ var i Template
+ err := row.Scan(
+ &i.Name,
+ &i.Type,
+ &i.Vcpus,
+ &i.MemoryMb,
+ &i.SizeBytes,
+ &i.CreatedAt,
+ &i.TeamID,
+ &i.ID,
+ )
+ return i, err
+}
+
+const getTemplateByName = `-- name: GetTemplateByName :one
+SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE team_id = $1 AND name = $2
+`
+
+type GetTemplateByNameParams struct {
+ TeamID pgtype.UUID `json:"team_id"`
+ Name string `json:"name"`
+}
+
+// Look up a template by team_id and name (exact team match, no global fallback).
+func (q *Queries) GetTemplateByName(ctx context.Context, arg GetTemplateByNameParams) (Template, error) {
+ row := q.db.QueryRow(ctx, getTemplateByName, arg.TeamID, arg.Name)
+ var i Template
+ err := row.Scan(
+ &i.Name,
+ &i.Type,
+ &i.Vcpus,
+ &i.MemoryMb,
+ &i.SizeBytes,
+ &i.CreatedAt,
+ &i.TeamID,
+ &i.ID,
)
return i, err
}
const getTemplateByTeam = `-- name: GetTemplateByTeam :one
-SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE name = $1 AND team_id = $2
+SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE name = $1 AND (team_id = $2 OR team_id = '00000000-0000-0000-0000-000000000000')
`
type GetTemplateByTeamParams struct {
- Name string `json:"name"`
- TeamID string `json:"team_id"`
+ Name string `json:"name"`
+ TeamID pgtype.UUID `json:"team_id"`
}
+// Platform templates (team_id = 00000000-...) are visible to all teams.
func (q *Queries) GetTemplateByTeam(ctx context.Context, arg GetTemplateByTeamParams) (Template, error) {
row := q.db.QueryRow(ctx, getTemplateByTeam, arg.Name, arg.TeamID)
var i Template
@@ -73,27 +132,30 @@ func (q *Queries) GetTemplateByTeam(ctx context.Context, arg GetTemplateByTeamPa
&i.SizeBytes,
&i.CreatedAt,
&i.TeamID,
+ &i.ID,
)
return i, err
}
const insertTemplate = `-- name: InsertTemplate :one
-INSERT INTO templates (name, type, vcpus, memory_mb, size_bytes, team_id)
-VALUES ($1, $2, $3, $4, $5, $6)
-RETURNING name, type, vcpus, memory_mb, size_bytes, created_at, team_id
+INSERT INTO templates (id, name, type, vcpus, memory_mb, size_bytes, team_id)
+VALUES ($1, $2, $3, $4, $5, $6, $7)
+RETURNING name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id
`
type InsertTemplateParams struct {
+ ID pgtype.UUID `json:"id"`
Name string `json:"name"`
Type string `json:"type"`
- Vcpus pgtype.Int4 `json:"vcpus"`
- MemoryMb pgtype.Int4 `json:"memory_mb"`
+ Vcpus int32 `json:"vcpus"`
+ MemoryMb int32 `json:"memory_mb"`
SizeBytes int64 `json:"size_bytes"`
- TeamID string `json:"team_id"`
+ TeamID pgtype.UUID `json:"team_id"`
}
func (q *Queries) InsertTemplate(ctx context.Context, arg InsertTemplateParams) (Template, error) {
row := q.db.QueryRow(ctx, insertTemplate,
+ arg.ID,
arg.Name,
arg.Type,
arg.Vcpus,
@@ -110,12 +172,13 @@ func (q *Queries) InsertTemplate(ctx context.Context, arg InsertTemplateParams)
&i.SizeBytes,
&i.CreatedAt,
&i.TeamID,
+ &i.ID,
)
return i, err
}
const listTemplates = `-- name: ListTemplates :many
-SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates ORDER BY created_at DESC
+SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates ORDER BY created_at DESC
`
func (q *Queries) ListTemplates(ctx context.Context) ([]Template, error) {
@@ -135,6 +198,7 @@ func (q *Queries) ListTemplates(ctx context.Context) ([]Template, error) {
&i.SizeBytes,
&i.CreatedAt,
&i.TeamID,
+ &i.ID,
); err != nil {
return nil, err
}
@@ -147,10 +211,11 @@ func (q *Queries) ListTemplates(ctx context.Context) ([]Template, error) {
}
const listTemplatesByTeam = `-- name: ListTemplatesByTeam :many
-SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE team_id = $1 ORDER BY created_at DESC
+SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE (team_id = $1 OR team_id = '00000000-0000-0000-0000-000000000000') ORDER BY created_at DESC
`
-func (q *Queries) ListTemplatesByTeam(ctx context.Context, teamID string) ([]Template, error) {
+// Platform templates are visible to all teams.
+func (q *Queries) ListTemplatesByTeam(ctx context.Context, teamID pgtype.UUID) ([]Template, error) {
rows, err := q.db.Query(ctx, listTemplatesByTeam, teamID)
if err != nil {
return nil, err
@@ -167,6 +232,7 @@ func (q *Queries) ListTemplatesByTeam(ctx context.Context, teamID string) ([]Tem
&i.SizeBytes,
&i.CreatedAt,
&i.TeamID,
+ &i.ID,
); err != nil {
return nil, err
}
@@ -179,14 +245,15 @@ func (q *Queries) ListTemplatesByTeam(ctx context.Context, teamID string) ([]Tem
}
const listTemplatesByTeamAndType = `-- name: ListTemplatesByTeamAndType :many
-SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE team_id = $1 AND type = $2 ORDER BY created_at DESC
+SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE (team_id = $1 OR team_id = '00000000-0000-0000-0000-000000000000') AND type = $2 ORDER BY created_at DESC
`
type ListTemplatesByTeamAndTypeParams struct {
- TeamID string `json:"team_id"`
- Type string `json:"type"`
+ TeamID pgtype.UUID `json:"team_id"`
+ Type string `json:"type"`
}
+// Platform templates are visible to all teams.
func (q *Queries) ListTemplatesByTeamAndType(ctx context.Context, arg ListTemplatesByTeamAndTypeParams) ([]Template, error) {
rows, err := q.db.Query(ctx, listTemplatesByTeamAndType, arg.TeamID, arg.Type)
if err != nil {
@@ -204,6 +271,41 @@ func (q *Queries) ListTemplatesByTeamAndType(ctx context.Context, arg ListTempla
&i.SizeBytes,
&i.CreatedAt,
&i.TeamID,
+ &i.ID,
+ ); err != nil {
+ return nil, err
+ }
+ items = append(items, i)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return items, nil
+}
+
+const listTemplatesByTeamOnly = `-- name: ListTemplatesByTeamOnly :many
+SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE team_id = $1 ORDER BY created_at DESC
+`
+
+// List templates owned by a specific team (NOT including platform templates).
+func (q *Queries) ListTemplatesByTeamOnly(ctx context.Context, teamID pgtype.UUID) ([]Template, error) {
+ rows, err := q.db.Query(ctx, listTemplatesByTeamOnly, teamID)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+ var items []Template
+ for rows.Next() {
+ var i Template
+ if err := rows.Scan(
+ &i.Name,
+ &i.Type,
+ &i.Vcpus,
+ &i.MemoryMb,
+ &i.SizeBytes,
+ &i.CreatedAt,
+ &i.TeamID,
+ &i.ID,
); err != nil {
return nil, err
}
@@ -216,7 +318,7 @@ func (q *Queries) ListTemplatesByTeamAndType(ctx context.Context, arg ListTempla
}
const listTemplatesByType = `-- name: ListTemplatesByType :many
-SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE type = $1 ORDER BY created_at DESC
+SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE type = $1 ORDER BY created_at DESC
`
func (q *Queries) ListTemplatesByType(ctx context.Context, type_ string) ([]Template, error) {
@@ -236,6 +338,7 @@ func (q *Queries) ListTemplatesByType(ctx context.Context, type_ string) ([]Temp
&i.SizeBytes,
&i.CreatedAt,
&i.TeamID,
+ &i.ID,
); err != nil {
return nil, err
}
diff --git a/internal/db/users.sql.go b/internal/db/users.sql.go
index 50ba287e..9de866bb 100644
--- a/internal/db/users.sql.go
+++ b/internal/db/users.sql.go
@@ -16,8 +16,8 @@ DELETE FROM admin_permissions WHERE user_id = $1 AND permission = $2
`
type DeleteAdminPermissionParams struct {
- UserID string `json:"user_id"`
- Permission string `json:"permission"`
+ UserID pgtype.UUID `json:"user_id"`
+ Permission string `json:"permission"`
}
func (q *Queries) DeleteAdminPermission(ctx context.Context, arg DeleteAdminPermissionParams) error {
@@ -29,7 +29,7 @@ const getAdminPermissions = `-- name: GetAdminPermissions :many
SELECT id, user_id, permission, created_at FROM admin_permissions WHERE user_id = $1 ORDER BY permission
`
-func (q *Queries) GetAdminPermissions(ctx context.Context, userID string) ([]AdminPermission, error) {
+func (q *Queries) GetAdminPermissions(ctx context.Context, userID pgtype.UUID) ([]AdminPermission, error) {
rows, err := q.db.Query(ctx, getAdminPermissions, userID)
if err != nil {
return nil, err
@@ -55,7 +55,7 @@ func (q *Queries) GetAdminPermissions(ctx context.Context, userID string) ([]Adm
}
const getAdminUsers = `-- name: GetAdminUsers :many
-SELECT id, email, password_hash, created_at, updated_at, is_admin, name FROM users WHERE is_admin = TRUE ORDER BY created_at
+SELECT id, email, password_hash, name, is_admin, created_at, updated_at FROM users WHERE is_admin = TRUE ORDER BY created_at
`
func (q *Queries) GetAdminUsers(ctx context.Context) ([]User, error) {
@@ -71,10 +71,10 @@ func (q *Queries) GetAdminUsers(ctx context.Context) ([]User, error) {
&i.ID,
&i.Email,
&i.PasswordHash,
+ &i.Name,
+ &i.IsAdmin,
&i.CreatedAt,
&i.UpdatedAt,
- &i.IsAdmin,
- &i.Name,
); err != nil {
return nil, err
}
@@ -87,7 +87,7 @@ func (q *Queries) GetAdminUsers(ctx context.Context) ([]User, error) {
}
const getUserByEmail = `-- name: GetUserByEmail :one
-SELECT id, email, password_hash, created_at, updated_at, is_admin, name FROM users WHERE email = $1
+SELECT id, email, password_hash, name, is_admin, created_at, updated_at FROM users WHERE email = $1
`
func (q *Queries) GetUserByEmail(ctx context.Context, email string) (User, error) {
@@ -97,29 +97,29 @@ func (q *Queries) GetUserByEmail(ctx context.Context, email string) (User, error
&i.ID,
&i.Email,
&i.PasswordHash,
+ &i.Name,
+ &i.IsAdmin,
&i.CreatedAt,
&i.UpdatedAt,
- &i.IsAdmin,
- &i.Name,
)
return i, err
}
const getUserByID = `-- name: GetUserByID :one
-SELECT id, email, password_hash, created_at, updated_at, is_admin, name FROM users WHERE id = $1
+SELECT id, email, password_hash, name, is_admin, created_at, updated_at FROM users WHERE id = $1
`
-func (q *Queries) GetUserByID(ctx context.Context, id string) (User, error) {
+func (q *Queries) GetUserByID(ctx context.Context, id pgtype.UUID) (User, error) {
row := q.db.QueryRow(ctx, getUserByID, id)
var i User
err := row.Scan(
&i.ID,
&i.Email,
&i.PasswordHash,
+ &i.Name,
+ &i.IsAdmin,
&i.CreatedAt,
&i.UpdatedAt,
- &i.IsAdmin,
- &i.Name,
)
return i, err
}
@@ -131,8 +131,8 @@ SELECT EXISTS(
`
type HasAdminPermissionParams struct {
- UserID string `json:"user_id"`
- Permission string `json:"permission"`
+ UserID pgtype.UUID `json:"user_id"`
+ Permission string `json:"permission"`
}
func (q *Queries) HasAdminPermission(ctx context.Context, arg HasAdminPermissionParams) (bool, error) {
@@ -148,9 +148,9 @@ VALUES ($1, $2, $3)
`
type InsertAdminPermissionParams struct {
- ID string `json:"id"`
- UserID string `json:"user_id"`
- Permission string `json:"permission"`
+ ID pgtype.UUID `json:"id"`
+ UserID pgtype.UUID `json:"user_id"`
+ Permission string `json:"permission"`
}
func (q *Queries) InsertAdminPermission(ctx context.Context, arg InsertAdminPermissionParams) error {
@@ -161,11 +161,11 @@ func (q *Queries) InsertAdminPermission(ctx context.Context, arg InsertAdminPerm
const insertUser = `-- name: InsertUser :one
INSERT INTO users (id, email, password_hash, name)
VALUES ($1, $2, $3, $4)
-RETURNING id, email, password_hash, created_at, updated_at, is_admin, name
+RETURNING id, email, password_hash, name, is_admin, created_at, updated_at
`
type InsertUserParams struct {
- ID string `json:"id"`
+ ID pgtype.UUID `json:"id"`
Email string `json:"email"`
PasswordHash pgtype.Text `json:"password_hash"`
Name string `json:"name"`
@@ -183,10 +183,10 @@ func (q *Queries) InsertUser(ctx context.Context, arg InsertUserParams) (User, e
&i.ID,
&i.Email,
&i.PasswordHash,
+ &i.Name,
+ &i.IsAdmin,
&i.CreatedAt,
&i.UpdatedAt,
- &i.IsAdmin,
- &i.Name,
)
return i, err
}
@@ -194,13 +194,13 @@ func (q *Queries) InsertUser(ctx context.Context, arg InsertUserParams) (User, e
const insertUserOAuth = `-- name: InsertUserOAuth :one
INSERT INTO users (id, email, name)
VALUES ($1, $2, $3)
-RETURNING id, email, password_hash, created_at, updated_at, is_admin, name
+RETURNING id, email, password_hash, name, is_admin, created_at, updated_at
`
type InsertUserOAuthParams struct {
- ID string `json:"id"`
- Email string `json:"email"`
- Name string `json:"name"`
+ ID pgtype.UUID `json:"id"`
+ Email string `json:"email"`
+ Name string `json:"name"`
}
func (q *Queries) InsertUserOAuth(ctx context.Context, arg InsertUserOAuthParams) (User, error) {
@@ -210,10 +210,10 @@ func (q *Queries) InsertUserOAuth(ctx context.Context, arg InsertUserOAuthParams
&i.ID,
&i.Email,
&i.PasswordHash,
+ &i.Name,
+ &i.IsAdmin,
&i.CreatedAt,
&i.UpdatedAt,
- &i.IsAdmin,
- &i.Name,
)
return i, err
}
@@ -223,8 +223,8 @@ SELECT id, email FROM users WHERE email LIKE $1 || '%' ORDER BY email LIMIT 10
`
type SearchUsersByEmailPrefixRow struct {
- ID string `json:"id"`
- Email string `json:"email"`
+ ID pgtype.UUID `json:"id"`
+ Email string `json:"email"`
}
func (q *Queries) SearchUsersByEmailPrefix(ctx context.Context, dollar_1 pgtype.Text) ([]SearchUsersByEmailPrefixRow, error) {
@@ -252,8 +252,8 @@ UPDATE users SET is_admin = $2, updated_at = NOW() WHERE id = $1
`
type SetUserAdminParams struct {
- ID string `json:"id"`
- IsAdmin bool `json:"is_admin"`
+ ID pgtype.UUID `json:"id"`
+ IsAdmin bool `json:"is_admin"`
}
func (q *Queries) SetUserAdmin(ctx context.Context, arg SetUserAdminParams) error {
@@ -266,8 +266,8 @@ UPDATE users SET name = $2, updated_at = NOW() WHERE id = $1
`
type UpdateUserNameParams struct {
- ID string `json:"id"`
- Name string `json:"name"`
+ ID pgtype.UUID `json:"id"`
+ Name string `json:"name"`
}
func (q *Queries) UpdateUserName(ctx context.Context, arg UpdateUserNameParams) error {
diff --git a/internal/devicemapper/devicemapper.go b/internal/devicemapper/devicemapper.go
index ea14fcd8..9fa08332 100644
--- a/internal/devicemapper/devicemapper.go
+++ b/internal/devicemapper/devicemapper.go
@@ -116,9 +116,10 @@ type SnapshotDevice struct {
// writable CoW layer.
//
// The origin loop device must already exist (from LoopRegistry.Acquire).
-func CreateSnapshot(name, originLoopDev, cowPath string, originSizeBytes int64) (*SnapshotDevice, error) {
- // Create sparse CoW file sized to match the origin.
- if err := createSparseFile(cowPath, originSizeBytes); err != nil {
+func CreateSnapshot(name, originLoopDev, cowPath string, originSizeBytes, cowSizeBytes int64) (*SnapshotDevice, error) {
+ // Create sparse CoW file. The logical size limits how many blocks can be
+ // modified; because the file is sparse, only written blocks use real disk.
+ if err := createSparseFile(cowPath, cowSizeBytes); err != nil {
return nil, fmt.Errorf("create cow file: %w", err)
}
@@ -128,6 +129,9 @@ func CreateSnapshot(name, originLoopDev, cowPath string, originSizeBytes int64)
return nil, fmt.Errorf("losetup cow: %w", err)
}
+ // The dm-snapshot virtual device size must match the origin — the snapshot
+ // target maps 1:1 onto origin sectors. The CoW file just needs enough
+ // space to store all modified blocks (it's sparse, so 20GB costs nothing).
sectors := originSizeBytes / 512
if err := dmsetupCreate(name, originLoopDev, cowLoopDev, sectors); err != nil {
if detachErr := losetupDetach(cowLoopDev); detachErr != nil {
@@ -220,6 +224,7 @@ func FlattenSnapshot(dmDevPath, outputPath string) error {
"if="+dmDevPath,
"of="+outputPath,
"bs=4M",
+ "conv=sparse",
"status=none",
)
if out, err := cmd.CombinedOutput(); err != nil {
diff --git a/internal/envdclient/client.go b/internal/envdclient/client.go
index 04a1dc2e..49765690 100644
--- a/internal/envdclient/client.go
+++ b/internal/envdclient/client.go
@@ -3,14 +3,12 @@ package envdclient
import (
"bytes"
"context"
- "encoding/json"
"fmt"
"io"
"log/slog"
"mime/multipart"
"net/http"
"net/url"
- "time"
"connectrpc.com/connect"
@@ -49,35 +47,6 @@ func (c *Client) BaseURL() string {
return c.base
}
-// Init calls POST /init on envd to sync the guest clock with the host.
-// This is important after snapshot resume where the guest clock is frozen.
-func (c *Client) Init(ctx context.Context) error {
- now := time.Now().UTC()
- body, err := json.Marshal(map[string]any{"timestamp": now})
- if err != nil {
- return fmt.Errorf("marshal init body: %w", err)
- }
-
- req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.base+"/init", bytes.NewReader(body))
- if err != nil {
- return fmt.Errorf("create init request: %w", err)
- }
- req.Header.Set("Content-Type", "application/json")
-
- resp, err := c.httpClient.Do(req)
- if err != nil {
- return fmt.Errorf("init request: %w", err)
- }
- defer resp.Body.Close()
-
- if resp.StatusCode != http.StatusNoContent {
- respBody, _ := io.ReadAll(resp.Body)
- return fmt.Errorf("init: status %d: %s", resp.StatusCode, string(respBody))
- }
-
- return nil
-}
-
// ExecResult holds the output of a command execution.
type ExecResult struct {
Stdout []byte
diff --git a/internal/hostagent/certstore.go b/internal/hostagent/certstore.go
new file mode 100644
index 00000000..4260ba2f
--- /dev/null
+++ b/internal/hostagent/certstore.go
@@ -0,0 +1,42 @@
+package hostagent
+
+import (
+ "crypto/tls"
+ "fmt"
+ "sync/atomic"
+)
+
+// CertStore holds the agent's current TLS certificate behind an atomic
+// pointer, so reads and writes need no lock. Its GetCert method is meant to be
+// plugged into tls.Config.GetCertificate, which lets the agent's cert be
+// hot-swapped on JWT refresh without restarting the server.
+//
+// A zero CertStore is ready to use; GetCert fails until a cert is stored.
+type CertStore struct {
+	ptr atomic.Pointer[tls.Certificate]
+}
+
+// Store atomically installs cert as the current certificate.
+func (s *CertStore) Store(cert *tls.Certificate) {
+	s.ptr.Store(cert)
+}
+
+// ParseAndStore parses the certPEM/keyPEM pair and, on success, atomically
+// installs it. When parsing fails the previously stored cert is left as is.
+func (s *CertStore) ParseAndStore(certPEM, keyPEM string) error {
+	pair, err := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM))
+	if err == nil {
+		s.Store(&pair)
+		return nil
+	}
+	return fmt.Errorf("parse TLS key pair: %w", err)
+}
+
+// GetCert has the signature required by tls.Config.GetCertificate. It returns
+// the current certificate, or an error when none has been stored yet.
+func (s *CertStore) GetCert(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
+	if current := s.ptr.Load(); current != nil {
+		return current, nil
+	}
+	return nil, fmt.Errorf("no TLS certificate available")
+}
diff --git a/internal/hostagent/proxy.go b/internal/hostagent/proxy.go
new file mode 100644
index 00000000..bbee4741
--- /dev/null
+++ b/internal/hostagent/proxy.go
@@ -0,0 +1,114 @@
+package hostagent
+
+import (
+	"fmt"
+	"log/slog"
+	"net/http"
+	"net/http/httputil"
+	"strconv"
+	"strings"
+
+	"git.omukk.dev/wrenn/sandbox/internal/sandbox"
+)
+
+// ProxyHandler reverse-proxies HTTP requests to services running inside
+// sandboxes. It handles requests of the form:
+//
+//	/proxy/{sandbox_id}/{port}/{path...}
+//
+// The sandbox's HostIP (routable on this machine) is used as the upstream.
+// This supports any protocol that rides on HTTP, including WebSocket upgrades.
+type ProxyHandler struct {
+	mgr       *sandbox.Manager
+	transport http.RoundTripper
+}
+
+// NewProxyHandler creates a new sandbox proxy handler that looks sandboxes up
+// through mgr and forwards requests over http.DefaultTransport.
+func NewProxyHandler(mgr *sandbox.Manager) *ProxyHandler {
+	return &ProxyHandler{
+		mgr:       mgr,
+		transport: http.DefaultTransport,
+	}
+}
+
+// proxyPathParts strips the "/proxy/" prefix from p and splits the rest into
+// at most three "/"-separated parts: sandbox ID, port, and upstream path.
+// The boolean result is false when p does not start with "/proxy/".
+func proxyPathParts(p string) ([]string, bool) {
+	trimmed, ok := strings.CutPrefix(p, "/proxy/")
+	if !ok {
+		return nil, false
+	}
+	return strings.SplitN(trimmed, "/", 3), true
+}
+
+// ServeHTTP implements http.Handler.
+//
+// It responds 400 when the path is not /proxy/{sandbox_id}/{port}/... or the
+// port is outside 1-65535, 503 when AcquireProxyConn reports the sandbox as
+// unavailable, and 502 when the upstream request fails. Everything after the
+// port segment, together with the query string, is forwarded to
+// http://{hostIP}:{port}/.
+func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	parts, hasPrefix := proxyPathParts(r.URL.Path)
+	if !hasPrefix {
+		http.Error(w, "invalid proxy path", http.StatusBadRequest)
+		return
+	}
+	if len(parts) < 2 {
+		http.Error(w, "expected /proxy/{sandbox_id}/{port}/...", http.StatusBadRequest)
+		return
+	}
+
+	sandboxID := parts[0]
+	port := parts[1]
+	remainder := ""
+	if len(parts) == 3 {
+		remainder = parts[2]
+	}
+
+	// Validate port is a number in the valid range.
+	portNum, err := strconv.Atoi(port)
+	if err != nil || portNum < 1 || portNum > 65535 {
+		http.Error(w, "invalid port", http.StatusBadRequest)
+		return
+	}
+
+	// Recover the still-percent-encoded form of the remainder so escapes such
+	// as %2F reach the upstream unchanged. Left empty when there is no
+	// remainder, in which case net/url encodes Path with its default rules.
+	rawPath := ""
+	if rawParts, ok := proxyPathParts(r.URL.EscapedPath()); ok && len(rawParts) == 3 {
+		rawPath = "/" + rawParts[2]
+	}
+
+	hostIP, tracker, ok := h.mgr.AcquireProxyConn(sandboxID)
+	if !ok {
+		http.Error(w, "sandbox is not available", http.StatusServiceUnavailable)
+		return
+	}
+	defer tracker.Release()
+
+	targetHost := fmt.Sprintf("%s:%d", hostIP, portNum)
+
+	proxy := &httputil.ReverseProxy{
+		Transport: h.transport,
+		Director: func(req *http.Request) {
+			req.URL.Scheme = "http"
+			req.URL.Host = targetHost
+			req.URL.Path = "/" + remainder
+			// Always overwrite RawPath: req.URL starts as a copy of the inbound
+			// URL, whose /proxy/... RawPath no longer matches the new Path.
+			req.URL.RawPath = rawPath
+			req.URL.RawQuery = r.URL.RawQuery
+			req.Host = targetHost
+		},
+		ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
+			slog.Debug("proxy error", "sandbox_id", sandboxID, "port", port, "error", err)
+			http.Error(w, "proxy error: "+err.Error(), http.StatusBadGateway)
+		},
+	}
+
+	proxy.ServeHTTP(w, r)
+}
diff --git a/internal/hostagent/registration.go b/internal/hostagent/registration.go
index 9f39c3b0..07909ee1 100644
--- a/internal/hostagent/registration.go
+++ b/internal/hostagent/registration.go
@@ -17,18 +17,24 @@ import (
"golang.org/x/sys/unix"
)
-// tokenFile is the JSON format persisted to AGENT_FILES_ROOTDIR/host.jwt.
-type tokenFile struct {
+// TokenFile is the JSON format persisted to WRENN_DIR/host-credentials.json.
+// It holds all credentials the agent needs: the host JWT, refresh token, and
+// (when mTLS is enabled) the TLS certificate material for the agent's server.
+type TokenFile struct {
HostID string `json:"host_id"`
JWT string `json:"jwt"`
RefreshToken string `json:"refresh_token"`
+ // mTLS fields — empty when the CP has no CA configured.
+ CertPEM string `json:"cert_pem,omitempty"`
+ KeyPEM string `json:"key_pem,omitempty"`
+ CACertPEM string `json:"ca_cert_pem,omitempty"`
}
// RegistrationConfig holds the configuration for host registration.
type RegistrationConfig struct {
CPURL string // Control plane base URL (e.g., http://localhost:8000)
RegistrationToken string // One-time registration token from the control plane
- TokenFile string // Path to persist the host JWT after registration
+ TokenFile string // Path to persist the credentials after registration
Address string // Externally-reachable address (ip:port) for this host
}
@@ -41,22 +47,20 @@ type registerRequest struct {
Address string `json:"address"`
}
-type registerResponse struct {
+// authResponse is the shared JSON shape for both register and refresh responses.
+type authResponse struct {
Host json.RawMessage `json:"host"`
Token string `json:"token"`
RefreshToken string `json:"refresh_token"`
+ CertPEM string `json:"cert_pem,omitempty"`
+ KeyPEM string `json:"key_pem,omitempty"`
+ CACertPEM string `json:"ca_cert_pem,omitempty"`
}
type refreshRequest struct {
RefreshToken string `json:"refresh_token"`
}
-type refreshResponse struct {
- Host json.RawMessage `json:"host"`
- Token string `json:"token"`
- RefreshToken string `json:"refresh_token"`
-}
-
type errorResponse struct {
Error struct {
Code string `json:"code"`
@@ -64,8 +68,8 @@ type errorResponse struct {
} `json:"error"`
}
-// loadTokenFile reads and parses the persisted token file.
-func loadTokenFile(path string) (*tokenFile, error) {
+// LoadTokenFile reads and parses the persisted credentials file.
+func LoadTokenFile(path string) (*TokenFile, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
@@ -75,36 +79,36 @@ func loadTokenFile(path string) (*tokenFile, error) {
if !strings.HasPrefix(trimmed, "{") {
// Old format: just the JWT, no refresh token.
hostID, _ := hostIDFromJWT(trimmed)
- return &tokenFile{HostID: hostID, JWT: trimmed}, nil
+ return &TokenFile{HostID: hostID, JWT: trimmed}, nil
}
- var tf tokenFile
+ var tf TokenFile
if err := json.Unmarshal(data, &tf); err != nil {
- return nil, fmt.Errorf("parse token file: %w", err)
+ return nil, fmt.Errorf("parse credentials file: %w", err)
}
return &tf, nil
}
-// saveTokenFile writes the token file as JSON with 0600 permissions.
-func saveTokenFile(path string, tf tokenFile) error {
+// saveTokenFile writes the credentials file as JSON with 0600 permissions.
+func saveTokenFile(path string, tf TokenFile) error {
data, err := json.MarshalIndent(tf, "", " ")
if err != nil {
- return fmt.Errorf("marshal token file: %w", err)
+ return fmt.Errorf("marshal credentials file: %w", err)
}
return os.WriteFile(path, data, 0600)
}
// Register calls the control plane to register this host agent and persists
-// the returned JWT and refresh token to disk. Returns the host JWT token string.
-func Register(ctx context.Context, cfg RegistrationConfig) (string, error) {
- // If no explicit registration token was given, reuse the saved JWT.
+// the returned credentials to disk. Returns the full TokenFile on success.
+func Register(ctx context.Context, cfg RegistrationConfig) (*TokenFile, error) {
+ // If no explicit registration token was given, reuse the saved credentials.
// A --register flag always overrides the local file so operators can
- // force re-registration without manually deleting host.jwt.
+ // force re-registration without manually deleting the credentials file.
if cfg.RegistrationToken == "" {
- if tf, err := loadTokenFile(cfg.TokenFile); err == nil && tf.JWT != "" {
- slog.Info("loaded existing host token", "file", cfg.TokenFile, "host_id", tf.HostID)
- return tf.JWT, nil
+ if tf, err := LoadTokenFile(cfg.TokenFile); err == nil && tf.JWT != "" {
+ slog.Info("loaded existing host credentials", "file", cfg.TokenFile, "host_id", tf.HostID)
+ return tf, nil
}
- return "", fmt.Errorf("no saved host token and no registration token provided (use --register flag)")
+ return nil, fmt.Errorf("no saved host credentials and no registration token provided (use --register flag)")
}
arch := runtime.GOARCH
@@ -123,87 +127,90 @@ func Register(ctx context.Context, cfg RegistrationConfig) (string, error) {
body, err := json.Marshal(reqBody)
if err != nil {
- return "", fmt.Errorf("marshal registration request: %w", err)
+ return nil, fmt.Errorf("marshal registration request: %w", err)
}
url := strings.TrimRight(cfg.CPURL, "/") + "/v1/hosts/register"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
- return "", fmt.Errorf("create registration request: %w", err)
+ return nil, fmt.Errorf("create registration request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
- return "", fmt.Errorf("registration request failed: %w", err)
+ return nil, fmt.Errorf("registration request failed: %w", err)
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
- return "", fmt.Errorf("read registration response: %w", err)
+ return nil, fmt.Errorf("read registration response: %w", err)
}
if resp.StatusCode != http.StatusCreated {
var errResp errorResponse
if err := json.Unmarshal(respBody, &errResp); err == nil {
- return "", fmt.Errorf("registration failed (%d): %s", resp.StatusCode, errResp.Error.Message)
+ return nil, fmt.Errorf("registration failed (%d): %s", resp.StatusCode, errResp.Error.Message)
}
- return "", fmt.Errorf("registration failed (%d): %s", resp.StatusCode, string(respBody))
+ return nil, fmt.Errorf("registration failed (%d): %s", resp.StatusCode, string(respBody))
}
- var regResp registerResponse
+ var regResp authResponse
if err := json.Unmarshal(respBody, ®Resp); err != nil {
- return "", fmt.Errorf("parse registration response: %w", err)
+ return nil, fmt.Errorf("parse registration response: %w", err)
}
if regResp.Token == "" {
- return "", fmt.Errorf("registration response missing token")
+ return nil, fmt.Errorf("registration response missing token")
}
hostID, err := hostIDFromJWT(regResp.Token)
if err != nil {
- return "", fmt.Errorf("extract host ID from JWT: %w", err)
+ return nil, fmt.Errorf("extract host ID from JWT: %w", err)
}
- // Persist JWT + refresh token.
- tf := tokenFile{
+ tf := TokenFile{
HostID: hostID,
JWT: regResp.Token,
RefreshToken: regResp.RefreshToken,
+ CertPEM: regResp.CertPEM,
+ KeyPEM: regResp.KeyPEM,
+ CACertPEM: regResp.CACertPEM,
}
if err := saveTokenFile(cfg.TokenFile, tf); err != nil {
- return "", fmt.Errorf("save host token: %w", err)
+ return nil, fmt.Errorf("save host credentials: %w", err)
}
- slog.Info("host registered and token saved", "file", cfg.TokenFile, "host_id", hostID)
+ slog.Info("host registered and credentials saved", "file", cfg.TokenFile, "host_id", hostID)
- return regResp.Token, nil
+ return &tf, nil
}
-// RefreshJWT exchanges the refresh token for a new JWT + rotated refresh token.
-// It reads and updates the token file in place.
-func RefreshJWT(ctx context.Context, cpURL, tokenFilePath string) (string, error) {
- tf, err := loadTokenFile(tokenFilePath)
+// RefreshCredentials exchanges the refresh token for a new JWT, rotated refresh
+// token, and (when mTLS is enabled) a new TLS certificate. The credentials file
+// is updated in place. Returns the updated TokenFile.
+func RefreshCredentials(ctx context.Context, cpURL, credentialsFilePath string) (*TokenFile, error) {
+ tf, err := LoadTokenFile(credentialsFilePath)
if err != nil {
- return "", fmt.Errorf("load token file: %w", err)
+ return nil, fmt.Errorf("load credentials file: %w", err)
}
if tf.RefreshToken == "" {
- return "", fmt.Errorf("no refresh token available; host must re-register")
+ return nil, fmt.Errorf("no refresh token available; host must re-register")
}
body, _ := json.Marshal(refreshRequest{RefreshToken: tf.RefreshToken})
url := strings.TrimRight(cpURL, "/") + "/v1/hosts/auth/refresh"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
- return "", fmt.Errorf("create refresh request: %w", err)
+ return nil, fmt.Errorf("create refresh request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 15 * time.Second}
resp, err := client.Do(req)
if err != nil {
- return "", fmt.Errorf("refresh request failed: %w", err)
+ return nil, fmt.Errorf("refresh request failed: %w", err)
}
defer resp.Body.Close()
@@ -212,39 +219,47 @@ func RefreshJWT(ctx context.Context, cpURL, tokenFilePath string) (string, error
if resp.StatusCode != http.StatusOK {
var errResp errorResponse
if json.Unmarshal(respBody, &errResp) == nil {
- return "", fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, errResp.Error.Message)
+ return nil, fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, errResp.Error.Message)
}
- return "", fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, string(respBody))
+ return nil, fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, string(respBody))
}
- var refResp refreshResponse
+ var refResp authResponse
if err := json.Unmarshal(respBody, &refResp); err != nil {
- return "", fmt.Errorf("parse refresh response: %w", err)
+ return nil, fmt.Errorf("parse refresh response: %w", err)
}
tf.JWT = refResp.Token
tf.RefreshToken = refResp.RefreshToken
- if err := saveTokenFile(tokenFilePath, *tf); err != nil {
- return "", fmt.Errorf("save refreshed token: %w", err)
+ if refResp.CertPEM != "" {
+ tf.CertPEM = refResp.CertPEM
+ tf.KeyPEM = refResp.KeyPEM
+ tf.CACertPEM = refResp.CACertPEM
+ }
+ if err := saveTokenFile(credentialsFilePath, *tf); err != nil {
+ return nil, fmt.Errorf("save refreshed credentials: %w", err)
}
- slog.Info("host JWT refreshed", "host_id", tf.HostID)
- return refResp.Token, nil
+ slog.Info("host credentials refreshed", "host_id", tf.HostID)
+ return tf, nil
}
// StartHeartbeat launches a background goroutine that sends periodic heartbeats
// to the control plane. It runs until the context is cancelled.
//
-// On 401/403: the heartbeat loop attempts to refresh the JWT. If the refresh
+// On 401/403: the heartbeat loop attempts to refresh credentials. If the refresh
// also fails (expired refresh token), it calls pauseAll and stops.
//
// On repeated network failures (3 consecutive), it calls pauseAll but keeps
// retrying — the connection may recover and the host should resume heartbeating.
//
// onDeleted is called when CP returns 404, meaning this host record was deleted.
-// The token file is removed before calling onDeleted so subsequent starts prompt
-// for a new registration token.
-func StartHeartbeat(ctx context.Context, cpURL, tokenFilePath, hostID string, interval time.Duration, pauseAll func(), onDeleted func()) {
+// The credentials file is removed before calling onDeleted so subsequent starts
+// prompt for a new registration token.
+//
+// onCredsRefreshed is called after a successful credential refresh (JWT + cert).
+// It may be nil. The caller uses it to hot-swap the agent's TLS certificate.
+func StartHeartbeat(ctx context.Context, cpURL, credentialsFilePath, hostID string, interval time.Duration, pauseAll func(), onDeleted func(), onCredsRefreshed func(*TokenFile)) {
client := &http.Client{Timeout: 10 * time.Second}
go func() {
@@ -255,8 +270,8 @@ func StartHeartbeat(ctx context.Context, cpURL, tokenFilePath, hostID string, in
pausedDueToFailure := false
currentJWT := ""
- // Load the current JWT from disk.
- if tf, err := loadTokenFile(tokenFilePath); err == nil {
+ // Load the current JWT from the credentials file.
+ if tf, err := LoadTokenFile(credentialsFilePath); err == nil {
currentJWT = tf.JWT
}
@@ -294,10 +309,10 @@ func StartHeartbeat(ctx context.Context, cpURL, tokenFilePath, hostID string, in
pausedDueToFailure = false
case http.StatusUnauthorized, http.StatusForbidden:
- slog.Warn("heartbeat: JWT rejected — attempting token refresh")
- newJWT, refreshErr := RefreshJWT(ctx, cpURL, tokenFilePath)
+ slog.Warn("heartbeat: JWT rejected — attempting credentials refresh")
+ newCreds, refreshErr := RefreshCredentials(ctx, cpURL, credentialsFilePath)
if refreshErr != nil {
- slog.Error("heartbeat: JWT refresh failed — pausing all sandboxes; manual re-registration required",
+ slog.Error("heartbeat: credentials refresh failed — pausing all sandboxes; manual re-registration required",
"error", refreshErr)
if pauseAll != nil && !pausedDueToFailure {
pauseAll()
@@ -306,13 +321,16 @@ func StartHeartbeat(ctx context.Context, cpURL, tokenFilePath, hostID string, in
// Stop the heartbeat loop — operator must re-register.
return true
}
- currentJWT = newJWT
- slog.Info("heartbeat: JWT refreshed successfully")
+ currentJWT = newCreds.JWT
+ slog.Info("heartbeat: credentials refreshed successfully")
+ if onCredsRefreshed != nil {
+ onCredsRefreshed(newCreds)
+ }
case http.StatusNotFound:
- slog.Error("heartbeat: host no longer exists in CP — host was deleted; removing token file and exiting")
- if err := os.Remove(tokenFilePath); err != nil && !os.IsNotExist(err) {
- slog.Warn("heartbeat: failed to remove token file", "error", err)
+ slog.Error("heartbeat: host no longer exists in CP — host was deleted; removing credentials file and exiting")
+ if err := os.Remove(credentialsFilePath); err != nil && !os.IsNotExist(err) {
+ slog.Warn("heartbeat: failed to remove credentials file", "error", err)
}
if onDeleted != nil {
onDeleted()
@@ -351,7 +369,7 @@ func HostIDFromToken(token string) (string, error) {
}
// hostIDFromJWT is the internal implementation used by both HostIDFromToken and
-// the token file loader.
+// the credentials file loader.
func hostIDFromJWT(token string) (string, error) {
parts := strings.Split(token, ".")
if len(parts) != 3 {
diff --git a/internal/hostagent/server.go b/internal/hostagent/server.go
index fb7fb664..7cd78f4c 100644
--- a/internal/hostagent/server.go
+++ b/internal/hostagent/server.go
@@ -12,6 +12,8 @@ import (
"time"
"connectrpc.com/connect"
+ "github.com/google/uuid"
+ "github.com/jackc/pgx/v5/pgtype"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
@@ -33,13 +35,35 @@ func NewServer(mgr *sandbox.Manager, terminate func()) *Server {
return &Server{mgr: mgr, terminate: terminate}
}
+// parseUUIDString parses a UUID hex string into a pgtype.UUID.
+// An empty string yields an all-zeros UUID (valid).
+func parseUUIDString(s string) (pgtype.UUID, error) {
+ if s == "" {
+ return pgtype.UUID{Bytes: [16]byte{}, Valid: true}, nil
+ }
+ parsed, err := uuid.Parse(s)
+ if err != nil {
+ return pgtype.UUID{}, fmt.Errorf("invalid UUID %q: %w", s, err)
+ }
+ return pgtype.UUID{Bytes: parsed, Valid: true}, nil
+}
+
func (s *Server) CreateSandbox(
ctx context.Context,
req *connect.Request[pb.CreateSandboxRequest],
) (*connect.Response[pb.CreateSandboxResponse], error) {
msg := req.Msg
- sb, err := s.mgr.Create(ctx, msg.SandboxId, msg.Template, int(msg.Vcpus), int(msg.MemoryMb), int(msg.TimeoutSec))
+ teamID, err := parseUUIDString(msg.TeamId)
+ if err != nil {
+ return nil, connect.NewError(connect.CodeInvalidArgument, err)
+ }
+ templateID, err := parseUUIDString(msg.TemplateId)
+ if err != nil {
+ return nil, connect.NewError(connect.CodeInvalidArgument, err)
+ }
+
+ sb, err := s.mgr.Create(ctx, msg.SandboxId, teamID, templateID, int(msg.Vcpus), int(msg.MemoryMb), int(msg.TimeoutSec), int(msg.DiskSizeMb))
if err != nil {
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("create sandbox: %w", err))
}
@@ -90,12 +114,21 @@ func (s *Server) CreateSnapshot(
ctx context.Context,
req *connect.Request[pb.CreateSnapshotRequest],
) (*connect.Response[pb.CreateSnapshotResponse], error) {
- sizeBytes, err := s.mgr.CreateSnapshot(ctx, req.Msg.SandboxId, req.Msg.Name)
+ msg := req.Msg
+ teamID, err := parseUUIDString(msg.TeamId)
+ if err != nil {
+ return nil, connect.NewError(connect.CodeInvalidArgument, err)
+ }
+ templateID, err := parseUUIDString(msg.TemplateId)
+ if err != nil {
+ return nil, connect.NewError(connect.CodeInvalidArgument, err)
+ }
+
+ sizeBytes, err := s.mgr.CreateSnapshot(ctx, msg.SandboxId, teamID, templateID)
if err != nil {
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("create snapshot: %w", err))
}
return connect.NewResponse(&pb.CreateSnapshotResponse{
- Name: req.Msg.Name,
SizeBytes: sizeBytes,
}), nil
}
@@ -104,12 +137,45 @@ func (s *Server) DeleteSnapshot(
ctx context.Context,
req *connect.Request[pb.DeleteSnapshotRequest],
) (*connect.Response[pb.DeleteSnapshotResponse], error) {
- if err := s.mgr.DeleteSnapshot(req.Msg.Name); err != nil {
+ msg := req.Msg
+ teamID, err := parseUUIDString(msg.TeamId)
+ if err != nil {
+ return nil, connect.NewError(connect.CodeInvalidArgument, err)
+ }
+ templateID, err := parseUUIDString(msg.TemplateId)
+ if err != nil {
+ return nil, connect.NewError(connect.CodeInvalidArgument, err)
+ }
+
+ if err := s.mgr.DeleteSnapshot(teamID, templateID); err != nil {
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("delete snapshot: %w", err))
}
return connect.NewResponse(&pb.DeleteSnapshotResponse{}), nil
}
+func (s *Server) FlattenRootfs(
+ ctx context.Context,
+ req *connect.Request[pb.FlattenRootfsRequest],
+) (*connect.Response[pb.FlattenRootfsResponse], error) {
+ msg := req.Msg
+ teamID, err := parseUUIDString(msg.TeamId)
+ if err != nil {
+ return nil, connect.NewError(connect.CodeInvalidArgument, err)
+ }
+ templateID, err := parseUUIDString(msg.TemplateId)
+ if err != nil {
+ return nil, connect.NewError(connect.CodeInvalidArgument, err)
+ }
+
+ sizeBytes, err := s.mgr.FlattenRootfs(ctx, msg.SandboxId, teamID, templateID)
+ if err != nil {
+ return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("flatten rootfs: %w", err))
+ }
+ return connect.NewResponse(&pb.FlattenRootfsResponse{
+ SizeBytes: sizeBytes,
+ }), nil
+}
+
func (s *Server) PingSandbox(
ctx context.Context,
req *connect.Request[pb.PingSandboxRequest],
@@ -400,7 +466,8 @@ func (s *Server) ListSandboxes(
infos[i] = &pb.SandboxInfo{
SandboxId: sb.ID,
Status: string(sb.Status),
- Template: sb.Template,
+ TeamId: uuid.UUID(sb.TemplateTeamID).String(),
+ TemplateId: uuid.UUID(sb.TemplateID).String(),
Vcpus: int32(sb.VCPUs),
MemoryMb: int32(sb.MemoryMB),
HostIp: sb.HostIP.String(),
diff --git a/internal/id/id.go b/internal/id/id.go
index bbda47c7..f4b6cdb6 100644
--- a/internal/id/id.go
+++ b/internal/id/id.go
@@ -4,8 +4,167 @@ import (
"crypto/rand"
"encoding/hex"
"fmt"
+ "math/big"
+ "strings"
+
+ "github.com/google/uuid"
+ "github.com/jackc/pgx/v5/pgtype"
)
+const (
+ base36Alphabet = "0123456789abcdefghijklmnopqrstuvwxyz"
+ base36IDLen = 25 // ceil(128 * log2 / log36) = 25 chars for a full UUID
+)
+
+var base36Base = big.NewInt(36)
+
+// --- Generation ---
+
+// newUUID returns a new random (v4) UUID wrapped in pgtype.UUID for direct DB use.
+func newUUID() pgtype.UUID {
+ return pgtype.UUID{Bytes: uuid.New(), Valid: true}
+}
+
+func NewSandboxID() pgtype.UUID { return newUUID() }
+func NewUserID() pgtype.UUID { return newUUID() }
+func NewTeamID() pgtype.UUID { return newUUID() }
+func NewAPIKeyID() pgtype.UUID { return newUUID() }
+func NewHostID() pgtype.UUID { return newUUID() }
+func NewHostTokenID() pgtype.UUID { return newUUID() }
+func NewRefreshTokenID() pgtype.UUID { return newUUID() }
+func NewAuditLogID() pgtype.UUID { return newUUID() }
+func NewBuildID() pgtype.UUID { return newUUID() }
+func NewAdminPermissionID() pgtype.UUID { return newUUID() }
+
+func NewTemplateID() pgtype.UUID { return newUUID() }
+
+// NewSnapshotName generates a snapshot name: "template-" + 8 hex chars.
+func NewSnapshotName() string {
+ return "template-" + hex8()
+}
+
+// NewTeamSlug generates a unique team slug in the format "xxxxxx-yyyyyy".
+func NewTeamSlug() string {
+ b := make([]byte, 6)
+ if _, err := rand.Read(b); err != nil {
+ panic(fmt.Sprintf("crypto/rand failed: %v", err))
+ }
+ return hex.EncodeToString(b[:3]) + "-" + hex.EncodeToString(b[3:])
+}
+
+// NewRegistrationToken generates a 64-char hex token (32 bytes of entropy).
+func NewRegistrationToken() string {
+ return hexToken(32)
+}
+
+// NewRefreshToken generates a 64-char hex token (32 bytes of entropy).
+func NewRefreshToken() string {
+ return hexToken(32)
+}
+
+// --- Formatting (pgtype.UUID → prefixed string for API/RPC output) ---
+
+const (
+ PrefixSandbox = "cl-"
+ PrefixUser = "usr-"
+ PrefixTeam = "team-"
+ PrefixAPIKey = "key-"
+ PrefixHost = "host-"
+ PrefixHostToken = "htok-"
+ PrefixRefreshToken = "hrt-"
+ PrefixAuditLog = "log-"
+ PrefixBuild = "bld-"
+ PrefixAdminPermission = "perm-"
+)
+
+// UUIDToBase36 encodes 16 UUID bytes as a 25-char base36 string (0-9a-z).
+func UUIDToBase36(b [16]byte) string {
+ n := new(big.Int).SetBytes(b[:])
+ buf := make([]byte, base36IDLen)
+ mod := new(big.Int)
+ for i := base36IDLen - 1; i >= 0; i-- {
+ n.DivMod(n, base36Base, mod)
+ buf[i] = base36Alphabet[mod.Int64()]
+ }
+ return string(buf)
+}
+
+// base36ToUUID decodes a 25-char base36 string back to 16 UUID bytes.
+func base36ToUUID(s string) ([16]byte, error) {
+ if len(s) != base36IDLen {
+ return [16]byte{}, fmt.Errorf("expected %d-char base36 ID, got %d", base36IDLen, len(s))
+ }
+ n := new(big.Int)
+ for _, c := range s {
+ idx := strings.IndexRune(base36Alphabet, c)
+ if idx < 0 {
+ return [16]byte{}, fmt.Errorf("invalid base36 character: %c", c)
+ }
+ n.Mul(n, base36Base)
+ n.Add(n, big.NewInt(int64(idx)))
+ }
+ b := n.Bytes()
+ var out [16]byte
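+ // big.Int.Bytes() strips leading zeros; right-align into the 16-byte array. NOTE(review): 25 base36 digits can encode up to 36^25-1, which exceeds 2^128-1; for such input Bytes() returns 17 bytes and the slice expression below panics — reject n.BitLen() > 128 first.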
+ copy(out[16-len(b):], b)
+ return out, nil
+}
+
+func formatUUID(prefix string, id pgtype.UUID) string {
+ return prefix + UUIDToBase36(id.Bytes)
+}
+
+func FormatSandboxID(id pgtype.UUID) string { return formatUUID(PrefixSandbox, id) }
+func FormatUserID(id pgtype.UUID) string { return formatUUID(PrefixUser, id) }
+func FormatTeamID(id pgtype.UUID) string { return formatUUID(PrefixTeam, id) }
+func FormatAPIKeyID(id pgtype.UUID) string { return formatUUID(PrefixAPIKey, id) }
+func FormatHostID(id pgtype.UUID) string { return formatUUID(PrefixHost, id) }
+func FormatHostTokenID(id pgtype.UUID) string { return formatUUID(PrefixHostToken, id) }
+func FormatRefreshTokenID(id pgtype.UUID) string { return formatUUID(PrefixRefreshToken, id) }
+func FormatAuditLogID(id pgtype.UUID) string { return formatUUID(PrefixAuditLog, id) }
+func FormatBuildID(id pgtype.UUID) string { return formatUUID(PrefixBuild, id) }
+
+// --- Parsing (prefixed string from API/RPC input → pgtype.UUID) ---
+
+func parseUUID(prefix, s string) (pgtype.UUID, error) {
+ if !strings.HasPrefix(s, prefix) {
+ return pgtype.UUID{}, fmt.Errorf("invalid ID: expected %q prefix, got %q", prefix, s)
+ }
+ b, err := base36ToUUID(strings.TrimPrefix(s, prefix))
+ if err != nil {
+ return pgtype.UUID{}, fmt.Errorf("invalid ID %q: %w", s, err)
+ }
+ return pgtype.UUID{Bytes: b, Valid: true}, nil
+}
+
+func ParseSandboxID(s string) (pgtype.UUID, error) { return parseUUID(PrefixSandbox, s) }
+func ParseUserID(s string) (pgtype.UUID, error) { return parseUUID(PrefixUser, s) }
+func ParseTeamID(s string) (pgtype.UUID, error) { return parseUUID(PrefixTeam, s) }
+func ParseAPIKeyID(s string) (pgtype.UUID, error) { return parseUUID(PrefixAPIKey, s) }
+func ParseHostID(s string) (pgtype.UUID, error) { return parseUUID(PrefixHost, s) }
+func ParseHostTokenID(s string) (pgtype.UUID, error) { return parseUUID(PrefixHostToken, s) }
+func ParseAuditLogID(s string) (pgtype.UUID, error) { return parseUUID(PrefixAuditLog, s) }
+func ParseBuildID(s string) (pgtype.UUID, error) { return parseUUID(PrefixBuild, s) }
+
+// --- Well-known IDs ---
+
+// PlatformTeamID is the all-zeros UUID reserved for platform-owned resources
+// (e.g. base templates, shared infrastructure).
+var PlatformTeamID = pgtype.UUID{Bytes: [16]byte{}, Valid: true}
+
+// MinimalTemplateID is the all-zeros UUID sentinel for the built-in "minimal"
+// template. When both team_id and template_id are zero, the host agent uses
+// the minimal rootfs at WRENN_DIR/images/minimal/.
+var MinimalTemplateID = pgtype.UUID{Bytes: [16]byte{}, Valid: true}
+
+// UUIDString converts a pgtype.UUID to a standard hyphenated UUID string
+// (e.g., "6ba7b810-9dad-11d1-80b4-00c04fd430c8"). Used for RPC wire format.
+func UUIDString(id pgtype.UUID) string {
+ return uuid.UUID(id.Bytes).String()
+}
+
+// --- Helpers ---
+
func hex8() string {
b := make([]byte, 4)
if _, err := rand.Read(b); err != nil {
@@ -14,73 +173,8 @@ func hex8() string {
return hex.EncodeToString(b)
}
-// NewSandboxID generates a new sandbox ID in the format "sb-" + 8 hex chars.
-func NewSandboxID() string {
- return "sb-" + hex8()
-}
-
-// NewSnapshotName generates a snapshot name in the format "template-" + 8 hex chars.
-func NewSnapshotName() string {
- return "template-" + hex8()
-}
-
-// NewUserID generates a new user ID in the format "usr-" + 8 hex chars.
-func NewUserID() string {
- return "usr-" + hex8()
-}
-
-// NewTeamID generates a new team ID in the format "team-" + 8 hex chars.
-func NewTeamID() string {
- return "team-" + hex8()
-}
-
-// NewTeamSlug generates a unique team slug in the format "xxxxxx-yyyyyy"
-// where each part is 3 random bytes encoded as hex (6 hex chars each).
-func NewTeamSlug() string {
- b := make([]byte, 6)
- if _, err := rand.Read(b); err != nil {
- panic(fmt.Sprintf("crypto/rand failed: %v", err))
- }
- return hex.EncodeToString(b[:3]) + "-" + hex.EncodeToString(b[3:])
-}
-
-// NewAPIKeyID generates a new API key ID in the format "key-" + 8 hex chars.
-func NewAPIKeyID() string {
- return "key-" + hex8()
-}
-
-// NewHostID generates a new host ID in the format "host-" + 8 hex chars.
-func NewHostID() string {
- return "host-" + hex8()
-}
-
-// NewHostTokenID generates a new host token audit ID in the format "htok-" + 8 hex chars.
-func NewHostTokenID() string {
- return "htok-" + hex8()
-}
-
-// NewRegistrationToken generates a 64-char hex token (32 bytes of entropy).
-func NewRegistrationToken() string {
- b := make([]byte, 32)
- if _, err := rand.Read(b); err != nil {
- panic(fmt.Sprintf("crypto/rand failed: %v", err))
- }
- return hex.EncodeToString(b)
-}
-
-// NewRefreshTokenID generates a new refresh token record ID in the format "hrt-" + 8 hex chars.
-func NewRefreshTokenID() string {
- return "hrt-" + hex8()
-}
-
-// NewAuditLogID generates a new audit log ID in the format "log-" + 8 hex chars.
-func NewAuditLogID() string {
- return "log-" + hex8()
-}
-
-// NewRefreshToken generates a 64-char hex token (32 bytes of entropy) for use as a host refresh token.
-func NewRefreshToken() string {
- b := make([]byte, 32)
+func hexToken(nBytes int) string {
+ b := make([]byte, nBytes)
if _, err := rand.Read(b); err != nil {
panic(fmt.Sprintf("crypto/rand failed: %v", err))
}
diff --git a/internal/id/id_test.go b/internal/id/id_test.go
new file mode 100644
index 00000000..2000e9cb
--- /dev/null
+++ b/internal/id/id_test.go
@@ -0,0 +1,118 @@
+package id
+
+import (
+ "testing"
+
+ "github.com/google/uuid"
+ "github.com/jackc/pgx/v5/pgtype"
+)
+
+func TestBase36RoundTrip(t *testing.T) {
+ for i := 0; i < 1000; i++ {
+ orig := uuid.New()
+ encoded := UUIDToBase36(orig)
+
+ if len(encoded) != base36IDLen {
+ t.Fatalf("expected %d chars, got %d: %s", base36IDLen, len(encoded), encoded)
+ }
+
+ decoded, err := base36ToUUID(encoded)
+ if err != nil {
+ t.Fatalf("decode failed: %v", err)
+ }
+
+ if decoded != orig {
+ t.Fatalf("round-trip failed: %v → %s → %v", orig, encoded, decoded)
+ }
+ }
+}
+
+func TestBase36ZeroUUID(t *testing.T) {
+ var zero [16]byte
+ encoded := UUIDToBase36(zero)
+ if encoded != "0000000000000000000000000" {
+ t.Fatalf("zero UUID should encode to all zeros, got %s", encoded)
+ }
+ decoded, err := base36ToUUID(encoded)
+ if err != nil {
+ t.Fatalf("decode failed: %v", err)
+ }
+ if decoded != zero {
+ t.Fatalf("round-trip failed for zero UUID")
+ }
+}
+
+func TestFormatParseRoundTrip(t *testing.T) {
+ id := NewSandboxID()
+ formatted := FormatSandboxID(id)
+
+ if formatted[:3] != "cl-" {
+ t.Fatalf("expected cl- prefix, got %s", formatted)
+ }
+ if len(formatted) != 3+base36IDLen {
+ t.Fatalf("expected %d chars total, got %d: %s", 3+base36IDLen, len(formatted), formatted)
+ }
+
+ parsed, err := ParseSandboxID(formatted)
+ if err != nil {
+ t.Fatalf("parse failed: %v", err)
+ }
+ if parsed != id {
+ t.Fatalf("round-trip failed: %v → %s → %v", id, formatted, parsed)
+ }
+}
+
+func TestBase36InvalidInput(t *testing.T) {
+ // Wrong length.
+ if _, err := base36ToUUID("abc"); err == nil {
+ t.Fatal("expected error for short input")
+ }
+ // Invalid character.
+ if _, err := base36ToUUID("000000000000000000000000!"); err == nil {
+ t.Fatal("expected error for invalid character")
+ }
+}
+
+func TestPlatformTeamIDFormats(t *testing.T) {
+ formatted := FormatTeamID(PlatformTeamID)
+ parsed, err := ParseTeamID(formatted)
+ if err != nil {
+ t.Fatalf("parse failed: %v", err)
+ }
+ if parsed != PlatformTeamID {
+ t.Fatalf("platform team ID round-trip failed")
+ }
+}
+
+func TestMaxUUID(t *testing.T) {
+ max := [16]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}
+ encoded := UUIDToBase36(max)
+ if len(encoded) != base36IDLen {
+ t.Fatalf("max UUID encoding wrong length: %d", len(encoded))
+ }
+ decoded, err := base36ToUUID(encoded)
+ if err != nil {
+ t.Fatalf("decode failed: %v", err)
+ }
+ if decoded != max {
+ t.Fatalf("round-trip failed for max UUID")
+ }
+}
+
+func BenchmarkFormatSandboxID(b *testing.B) {
+ id := pgtype.UUID{Bytes: uuid.New(), Valid: true}
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ FormatSandboxID(id)
+ }
+}
+
+func BenchmarkParseSandboxID(b *testing.B) {
+ id := pgtype.UUID{Bytes: uuid.New(), Valid: true}
+ s := FormatSandboxID(id)
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ _, _ = ParseSandboxID(s)
+ }
+}
diff --git a/internal/layout/layout.go b/internal/layout/layout.go
new file mode 100644
index 00000000..d4084f5f
--- /dev/null
+++ b/internal/layout/layout.go
@@ -0,0 +1,58 @@
+package layout
+
+import (
+ "path/filepath"
+
+ "github.com/jackc/pgx/v5/pgtype"
+
+ "git.omukk.dev/wrenn/sandbox/internal/id"
+)
+
+// IsMinimal reports whether the given team and template IDs represent the
+// built-in "minimal" template (both all-zeros).
+func IsMinimal(teamID, templateID pgtype.UUID) bool {
+ return teamID.Bytes == id.PlatformTeamID.Bytes && templateID.Bytes == id.MinimalTemplateID.Bytes
+}
+
+// TemplateDir returns the on-disk directory for a template.
+//
+// minimal (zeros, zeros): {wrennDir}/images/minimal
+// all others: {wrennDir}/images/teams/{base36(teamID)}/{base36(templateID)}
+func TemplateDir(wrennDir string, teamID, templateID pgtype.UUID) string {
+ if IsMinimal(teamID, templateID) {
+ return filepath.Join(wrennDir, "images", "minimal")
+ }
+ return filepath.Join(wrennDir, "images", "teams",
+ id.UUIDToBase36(teamID.Bytes),
+ id.UUIDToBase36(templateID.Bytes))
+}
+
+// TemplateRootfs returns the path to a template's rootfs.ext4.
+func TemplateRootfs(wrennDir string, teamID, templateID pgtype.UUID) string {
+ return filepath.Join(TemplateDir(wrennDir, teamID, templateID), "rootfs.ext4")
+}
+
+// PauseSnapshotDir returns the directory for a paused sandbox's snapshot files.
+func PauseSnapshotDir(wrennDir, sandboxID string) string {
+ return filepath.Join(wrennDir, "snapshots", sandboxID)
+}
+
+// SandboxesDir returns the directory for running sandbox CoW files.
+func SandboxesDir(wrennDir string) string {
+ return filepath.Join(wrennDir, "sandboxes")
+}
+
+// KernelPath returns the path to the Firecracker kernel.
+func KernelPath(wrennDir string) string {
+ return filepath.Join(wrennDir, "kernels", "vmlinux")
+}
+
+// ImagesRoot returns the root images directory.
+func ImagesRoot(wrennDir string) string {
+ return filepath.Join(wrennDir, "images")
+}
+
+// TeamsDir returns the directory containing all team template subdirectories.
+func TeamsDir(wrennDir string) string {
+ return filepath.Join(wrennDir, "images", "teams")
+}
diff --git a/internal/layout/layout_test.go b/internal/layout/layout_test.go
new file mode 100644
index 00000000..f7b9afd3
--- /dev/null
+++ b/internal/layout/layout_test.go
@@ -0,0 +1,120 @@
+package layout
+
+import (
+ "path/filepath"
+ "testing"
+
+ "github.com/jackc/pgx/v5/pgtype"
+
+ "git.omukk.dev/wrenn/sandbox/internal/id"
+)
+
+func TestIsMinimal(t *testing.T) {
+ tests := []struct {
+ name string
+ teamID pgtype.UUID
+ templateID pgtype.UUID
+ want bool
+ }{
+ {
+ name: "both zeros",
+ teamID: id.PlatformTeamID,
+ templateID: id.MinimalTemplateID,
+ want: true,
+ },
+ {
+ name: "non-zero team",
+ teamID: pgtype.UUID{Bytes: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, Valid: true},
+ templateID: id.MinimalTemplateID,
+ want: false,
+ },
+ {
+ name: "non-zero template",
+ teamID: id.PlatformTeamID,
+ templateID: pgtype.UUID{Bytes: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, Valid: true},
+ want: false,
+ },
+ {
+ name: "both non-zero",
+ teamID: pgtype.UUID{Bytes: [16]byte{1}, Valid: true},
+ templateID: pgtype.UUID{Bytes: [16]byte{2}, Valid: true},
+ want: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := IsMinimal(tt.teamID, tt.templateID); got != tt.want {
+ t.Errorf("IsMinimal() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestTemplateDir(t *testing.T) {
+ wrennDir := "/var/lib/wrenn"
+
+ t.Run("minimal", func(t *testing.T) {
+ got := TemplateDir(wrennDir, id.PlatformTeamID, id.MinimalTemplateID)
+ want := filepath.Join(wrennDir, "images", "minimal")
+ if got != want {
+ t.Errorf("TemplateDir() = %q, want %q", got, want)
+ }
+ })
+
+ t.Run("team template", func(t *testing.T) {
+ teamID := pgtype.UUID{Bytes: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, Valid: true}
+ tmplID := pgtype.UUID{Bytes: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}, Valid: true}
+ got := TemplateDir(wrennDir, teamID, tmplID)
+ want := filepath.Join(wrennDir, "images", "teams",
+ id.UUIDToBase36(teamID.Bytes),
+ id.UUIDToBase36(tmplID.Bytes))
+ if got != want {
+ t.Errorf("TemplateDir() = %q, want %q", got, want)
+ }
+ })
+
+ t.Run("global template (platform team, non-zero template)", func(t *testing.T) {
+ tmplID := pgtype.UUID{Bytes: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5}, Valid: true}
+ got := TemplateDir(wrennDir, id.PlatformTeamID, tmplID)
+ want := filepath.Join(wrennDir, "images", "teams",
+ id.UUIDToBase36(id.PlatformTeamID.Bytes),
+ id.UUIDToBase36(tmplID.Bytes))
+ if got != want {
+ t.Errorf("TemplateDir() = %q, want %q", got, want)
+ }
+ })
+}
+
+func TestTemplateRootfs(t *testing.T) {
+ wrennDir := "/var/lib/wrenn"
+ got := TemplateRootfs(wrennDir, id.PlatformTeamID, id.MinimalTemplateID)
+ want := filepath.Join(wrennDir, "images", "minimal", "rootfs.ext4")
+ if got != want {
+ t.Errorf("TemplateRootfs() = %q, want %q", got, want)
+ }
+}
+
+func TestPauseSnapshotDir(t *testing.T) {
+ got := PauseSnapshotDir("/var/lib/wrenn", "cl-abc123")
+ want := "/var/lib/wrenn/snapshots/cl-abc123"
+ if got != want {
+ t.Errorf("PauseSnapshotDir() = %q, want %q", got, want)
+ }
+}
+
+func TestSandboxesDir(t *testing.T) {
+ got := SandboxesDir("/var/lib/wrenn")
+ want := "/var/lib/wrenn/sandboxes"
+ if got != want {
+ t.Errorf("SandboxesDir() = %q, want %q", got, want)
+ }
+}
+
+func TestKernelPath(t *testing.T) {
+ got := KernelPath("/var/lib/wrenn")
+ want := "/var/lib/wrenn/kernels/vmlinux"
+ if got != want {
+ t.Errorf("KernelPath() = %q, want %q", got, want)
+ }
+}
diff --git a/internal/lifecycle/hostpool.go b/internal/lifecycle/hostpool.go
index 0caf5ece..f5784898 100644
--- a/internal/lifecycle/hostpool.go
+++ b/internal/lifecycle/hostpool.go
@@ -1,6 +1,7 @@
package lifecycle
import (
+ "crypto/tls"
"fmt"
"net/http"
"strings"
@@ -8,6 +9,7 @@ import (
"time"
"git.omukk.dev/wrenn/sandbox/internal/db"
+ "git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
)
@@ -18,14 +20,33 @@ type HostClientPool struct {
mu sync.RWMutex
clients map[string]hostagentv1connect.HostAgentServiceClient
httpClient *http.Client
+ scheme string // "http://" or "https://"
}
-// NewHostClientPool creates a new pool. The underlying HTTP client uses a
-// 10-minute timeout to support long-running streaming operations.
+// NewHostClientPool creates a pool that connects to agents over plain HTTP.
+// Use NewHostClientPoolTLS when mTLS is required.
func NewHostClientPool() *HostClientPool {
return &HostClientPool{
clients: make(map[string]hostagentv1connect.HostAgentServiceClient),
httpClient: &http.Client{Timeout: 10 * time.Minute},
+ scheme: "http://",
+ }
+}
+
+// NewHostClientPoolTLS creates a pool that connects to agents over mTLS.
+// tlsCfg should already carry the CP client cert and CA trust anchor
+// (use auth.CPClientTLSConfig to construct it).
+func NewHostClientPoolTLS(tlsCfg *tls.Config) *HostClientPool {
+ transport := &http.Transport{
+ TLSClientConfig: tlsCfg,
+ }
+ return &HostClientPool{
+ clients: make(map[string]hostagentv1connect.HostAgentServiceClient),
+ httpClient: &http.Client{
+ Timeout: 10 * time.Minute,
+ Transport: transport,
+ },
+ scheme: "https://",
}
}
@@ -45,7 +66,7 @@ func (p *HostClientPool) Get(hostID, address string) hostagentv1connect.HostAgen
if c, ok = p.clients[hostID]; ok {
return c
}
- c = hostagentv1connect.NewHostAgentServiceClient(p.httpClient, ensureScheme(address))
+ c = hostagentv1connect.NewHostAgentServiceClient(p.httpClient, p.ensureScheme(address))
p.clients[hostID] = c
return c
}
@@ -53,10 +74,10 @@ func (p *HostClientPool) Get(hostID, address string) hostagentv1connect.HostAgen
// GetForHost is a convenience wrapper that extracts the address from a db.Host
// and returns an error if the host has no address recorded yet.
func (p *HostClientPool) GetForHost(h db.Host) (hostagentv1connect.HostAgentServiceClient, error) {
- if !h.Address.Valid || h.Address.String == "" {
- return nil, fmt.Errorf("host %s has no address", h.ID)
+ if h.Address == "" {
+ return nil, fmt.Errorf("host %s has no address", id.FormatHostID(h.ID))
}
- return p.Get(h.ID, h.Address.String), nil
+ return p.Get(id.FormatHostID(h.ID), h.Address), nil
}
// Evict removes the cached client for the given host, forcing a new client to be
@@ -68,8 +89,35 @@ func (p *HostClientPool) Evict(hostID string) {
p.mu.Unlock()
}
-// ensureScheme adds "http://" if the address has no scheme.
-func ensureScheme(addr string) string {
+// ensureScheme prepends the pool's configured scheme if the address has none.
+func (p *HostClientPool) ensureScheme(addr string) string {
+ if strings.HasPrefix(addr, "http://") || strings.HasPrefix(addr, "https://") {
+ return addr
+ }
+ return p.scheme + addr
+}
+
+// Transport returns the http.RoundTripper used by this pool. Use this when you
+// need to make raw HTTP requests to agent addresses with the same TLS settings
+// as the pool's Connect RPC clients (e.g., the sandbox reverse proxy).
+func (p *HostClientPool) Transport() http.RoundTripper {
+ if p.httpClient.Transport != nil {
+ return p.httpClient.Transport
+ }
+ return http.DefaultTransport
+}
+
+// ResolveAddr prepends the pool's configured scheme to addr if it has none.
+// Use this when constructing URLs that must use the same transport as the pool
+// (e.g., the sandbox proxy handler). Get and GetForHost apply the same rule
+// internally; ResolveAddr exposes it for callers that only need the URL.
+func (p *HostClientPool) ResolveAddr(addr string) string {
+ return p.ensureScheme(addr)
+}
+
+// Deprecated: EnsureScheme always adds "http://" if the address has no scheme;
+// use HostClientPool.ResolveAddr, which respects the pool's TLS setting.
+func EnsureScheme(addr string) string {
if strings.HasPrefix(addr, "http://") || strings.HasPrefix(addr, "https://") {
return addr
}
diff --git a/internal/models/sandbox.go b/internal/models/sandbox.go
index b99bd6b8..ab72cd3e 100644
--- a/internal/models/sandbox.go
+++ b/internal/models/sandbox.go
@@ -18,15 +18,16 @@ const (
// Sandbox holds all state for a running sandbox on this host.
type Sandbox struct {
- ID string
- Status SandboxStatus
- Template string
- VCPUs int
- MemoryMB int
- TimeoutSec int
- SlotIndex int
- HostIP net.IP
- RootfsPath string
- CreatedAt time.Time
- LastActiveAt time.Time
+ ID string
+ Status SandboxStatus
+ TemplateTeamID [16]byte
+ TemplateID [16]byte
+ VCPUs int
+ MemoryMB int
+ TimeoutSec int
+ SlotIndex int
+ HostIP net.IP
+ RootfsPath string
+ CreatedAt time.Time
+ LastActiveAt time.Time
}
diff --git a/internal/network/setup.go b/internal/network/setup.go
index 70a8a547..ee06d391 100644
--- a/internal/network/setup.go
+++ b/internal/network/setup.go
@@ -5,13 +5,91 @@ import (
"fmt"
"log/slog"
"net"
+ "os"
"os/exec"
"runtime"
+ "strings"
"github.com/vishvananda/netlink"
"github.com/vishvananda/netns"
)
+const nsPrefix = "wrenn-ns-"
+
+// CleanupStaleNamespaces deletes every entry in /run/netns whose name starts with
+// nsPrefix, together with the matching host-side "wrenn-veth-<suffix>" link, and
+// then calls cleanupStaleIptablesRules. Intended to run once at agent startup.
+func CleanupStaleNamespaces() {
+	nsEntries, readErr := os.ReadDir("/run/netns")
+	if readErr != nil {
+		return // /run/netns is missing or unreadable: nothing is cleaned at all
+	}
+	for _, entry := range nsEntries {
+		nsName := entry.Name()
+		if !strings.HasPrefix(nsName, nsPrefix) {
+			continue
+		}
+		// Best effort: drop the host end of the veth pair before the namespace.
+		hostVeth, lookupErr := netlink.LinkByName("wrenn-veth-" + nsName[len(nsPrefix):])
+		if lookupErr == nil {
+			_ = netlink.LinkDel(hostVeth)
+		}
+		delErr := netns.DeleteNamed(nsName)
+		if delErr == nil {
+			slog.Info("removed stale namespace", "ns", nsName)
+			continue
+		}
+		slog.Warn("failed to remove stale namespace", "ns", nsName, "error", delErr)
+	}
+	cleanupStaleIptablesRules() // iptables rules and routes that still mention wrenn-veth-*
+}
+
+// cleanupStaleIptablesRules deletes every rule in the filter and nat tables that mentions
+// "wrenn-veth-" and every IPv4 route on such a link; it does not check that the link is gone.
+func cleanupStaleIptablesRules() {
+ for _, table := range []string{"filter", "nat"} {
+ cmd := exec.Command("iptables-save", "-t", table)
+ out, err := cmd.Output()
+ if err != nil {
+ continue
+ }
+ for _, line := range strings.Split(string(out), "\n") {
+ if !strings.Contains(line, "wrenn-veth-") {
+ continue
+ }
+ // Lines look like "-A FORWARD -i wrenn-veth-1 -o wlo1 -j ACCEPT"
+ // Convert -A to -D to delete the rule. NOTE(review): Fields would split a quoted argument containing spaces.
+ if !strings.HasPrefix(line, "-A ") {
+ continue
+ }
+ delRule := "-D " + line[3:]
+ args := strings.Fields(delRule)
+ delCmd := exec.Command("iptables", append([]string{"-t", table}, args...)...)
+ if err := delCmd.Run(); err != nil {
+ slog.Debug("failed to remove stale iptables rule", "rule", line, "error", err)
+ }
+ }
+ }
+
+ // Also remove IPv4 routes whose link is named wrenn-veth-*, whatever their destination.
+ routes, err := netlink.RouteList(nil, netlink.FAMILY_V4)
+ if err != nil {
+ return
+ }
+ for _, r := range routes {
+ if r.LinkIndex == 0 {
+ continue
+ }
+ link, err := netlink.LinkByIndex(r.LinkIndex)
+ if err != nil {
+ continue
+ }
+ if strings.HasPrefix(link.Attrs().Name, "wrenn-veth-") {
+ _ = netlink.RouteDel(&r)
+ }
+ }
+}
+
const (
// Fixed addresses inside each network namespace (safe because each
// sandbox gets its own netns).
@@ -84,8 +162,8 @@ func NewSlot(index int) *Slot {
GuestIP: guestIP,
GuestNetMask: guestNetMask,
TapName: tapName,
- NamespaceID: fmt.Sprintf("ns-%d", index),
- VethName: fmt.Sprintf("veth-%d", index),
+ NamespaceID: fmt.Sprintf("wrenn-ns-%d", index),
+ VethName: fmt.Sprintf("wrenn-veth-%d", index),
}
}
diff --git a/internal/recipe/context.go b/internal/recipe/context.go
new file mode 100644
index 00000000..db4c39cc
--- /dev/null
+++ b/internal/recipe/context.go
@@ -0,0 +1,63 @@
+package recipe
+
+import ("sort"; "strings")
+
+// ExecContext holds mutable state that persists across recipe steps. It starts
+// empty and is updated by ENV and WORKDIR steps; a nil EnvVars map is valid.
+type ExecContext struct {
+ WorkDir string // when non-empty, commands are prefixed with "cd <dir> && "
+ EnvVars map[string]string // emitted as KEY='value' assignments in front of the command
+}
+
+// WrappedCommand returns the shell text for a RUN step, ready to be passed as
+// the argument of /bin/sh -c.
+//
+// With no WORKDIR/ENV recorded, cmd is returned unchanged. Otherwise it is
+// quoted and handed to a nested shell behind the context preamble:
+//
+//	cd '/the/dir' && KEY='val' /bin/sh -c 'original command'
+func (c *ExecContext) WrappedCommand(cmd string) string {
+	if preamble := c.shellPrefix(); preamble != "" {
+		return preamble + "/bin/sh -c " + shellescape(cmd)
+	}
+	return cmd
+}
+
+// StartCommand returns the shell text for a START step: cmd runs under nohup
+// in the background with stdout/stderr sent to /dev/null, so the outer shell
+// returns immediately and the build can continue while the process keeps
+// running in the VM. Several START steps may be issued back to back.
+//
+// The trailing "&" applies to the whole "cd ... && nohup ..." list, so a
+// failing cd is not reflected in the outer shell's exit status.
+func (c *ExecContext) StartCommand(cmd string) string {
+	detached := "nohup /bin/sh -c " + shellescape(cmd) + " >/dev/null 2>&1 &"
+	return c.shellPrefix() + detached
+}
+
+// shellPrefix builds the "cd ... && KEY=val " preamble for a shell command and
+// returns "" when neither WorkDir nor EnvVars is set. Variables are emitted in
+// sorted key order: Go randomizes map iteration, which would otherwise make the
+// generated command (and any log or test comparing it) differ between runs.
+func (c *ExecContext) shellPrefix() string {
+	var sb strings.Builder
+	if c.WorkDir != "" {
+		sb.WriteString("cd " + shellescape(c.WorkDir) + " && ")
+	}
+	keys := make([]string, 0, len(c.EnvVars))
+	for k := range c.EnvVars {
+		keys = append(keys, k)
+	}
+	sort.Strings(keys)
+	for _, k := range keys {
+		// Only the value is quoted; the key is written verbatim.
+		sb.WriteString(k + "=" + shellescape(c.EnvVars[k]) + " ")
+	}
+	return sb.String()
+}
+
+// shellescape single-quotes s for a POSIX shell. Every embedded ' becomes '\''
+// (close the quote, emit an escaped quote, reopen), so "" yields ''.
+func shellescape(s string) string {
+	return "'" + strings.Join(strings.Split(s, "'"), `'\''`) + "'"
+}
diff --git a/internal/recipe/context_test.go b/internal/recipe/context_test.go
new file mode 100644
index 00000000..b00dfce0
--- /dev/null
+++ b/internal/recipe/context_test.go
@@ -0,0 +1,114 @@
+package recipe
+
+import "testing"
+
+func TestExecContext_WrappedCommand(t *testing.T) {
+ tests := []struct {
+ name string
+ ctx ExecContext
+ cmd string
+ want string
+ }{
+ {
+ name: "no context",
+ ctx: ExecContext{},
+ cmd: "apt install -y curl",
+ want: "apt install -y curl",
+ },
+ {
+ name: "workdir only",
+ ctx: ExecContext{WorkDir: "/app"},
+ cmd: "npm install",
+ want: "cd '/app' && /bin/sh -c 'npm install'",
+ },
+ {
+ name: "env only",
+ ctx: ExecContext{EnvVars: map[string]string{"PORT": "8080"}},
+ cmd: "node server.js",
+ want: "PORT='8080' /bin/sh -c 'node server.js'",
+ },
+ {
+ name: "workdir with space",
+ ctx: ExecContext{WorkDir: "/my project"},
+ cmd: "make build",
+ want: "cd '/my project' && /bin/sh -c 'make build'",
+ },
+ {
+ name: "command with single quotes",
+ ctx: ExecContext{WorkDir: "/app"},
+ cmd: "echo 'hello'",
+ want: "cd '/app' && /bin/sh -c 'echo '\\''hello'\\'''",
+ },
+ {
+ name: "env value with single quotes",
+ ctx: ExecContext{EnvVars: map[string]string{"MSG": "it's fine"}},
+ cmd: "echo $MSG",
+ want: "MSG='it'\\''s fine' /bin/sh -c 'echo $MSG'",
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ got := tc.ctx.WrappedCommand(tc.cmd)
+ if got != tc.want {
+ t.Errorf("WrappedCommand(%q)\n got %q\n want %q", tc.cmd, got, tc.want)
+ }
+ })
+ }
+}
+
+func TestExecContext_StartCommand(t *testing.T) {
+ tests := []struct {
+ name string
+ ctx ExecContext
+ cmd string
+ want string
+ }{
+ {
+ name: "no context",
+ ctx: ExecContext{},
+ cmd: "python3 app.py",
+ want: "nohup /bin/sh -c 'python3 app.py' >/dev/null 2>&1 &",
+ },
+ {
+ name: "with workdir",
+ ctx: ExecContext{WorkDir: "/app"},
+ cmd: "python3 server.py",
+ want: "cd '/app' && nohup /bin/sh -c 'python3 server.py' >/dev/null 2>&1 &",
+ },
+ {
+ name: "with env",
+ ctx: ExecContext{EnvVars: map[string]string{"PORT": "9000"}},
+ cmd: "node index.js",
+ want: "PORT='9000' nohup /bin/sh -c 'node index.js' >/dev/null 2>&1 &",
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ got := tc.ctx.StartCommand(tc.cmd)
+ if got != tc.want {
+ t.Errorf("StartCommand(%q)\n got %q\n want %q", tc.cmd, got, tc.want)
+ }
+ })
+ }
+}
+
+func TestShellescape(t *testing.T) {
+ tests := []struct {
+ input string
+ want string
+ }{
+ {"simple", "'simple'"},
+ {"/path/to/dir", "'/path/to/dir'"},
+ {"it's fine", "'it'\\''s fine'"},
+ {"", "''"},
+ {"a'b'c", "'a'\\''b'\\''c'"},
+ }
+ for _, tc := range tests {
+ got := shellescape(tc.input)
+ if got != tc.want {
+ t.Errorf("shellescape(%q) = %q, want %q", tc.input, got, tc.want)
+ }
+ }
+}
diff --git a/internal/recipe/executor.go b/internal/recipe/executor.go
new file mode 100644
index 00000000..3df45dc5
--- /dev/null
+++ b/internal/recipe/executor.go
@@ -0,0 +1,185 @@
+package recipe
+
+import (
+ "context"
+ "fmt"
+ "log/slog"
+ "strings"
+ "time"
+
+ "connectrpc.com/connect"
+
+ pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
+)
+
+// DefaultStepTimeout is the fallback timeout for RUN steps that carry no explicit --timeout
+// flag. NOTE(review): Execute does not read it; its own fallback for defaultTimeout <= 0 is 10 minutes.
+const DefaultStepTimeout = 30 * time.Second
+
+// BuildLogEntry is the per-step record stored in template_builds.logs (JSONB).
+type BuildLogEntry struct {
+ Step int `json:"step"` // step number, counted across phases
+ Phase string `json:"phase"` // phase label passed to Execute
+ Cmd string `json:"cmd"` // the raw recipe instruction
+ Stdout string `json:"stdout"` // command stdout (RUN steps only)
+ Stderr string `json:"stderr"` // command stderr, or an executor-generated failure message
+ Exit int32 `json:"exit"` // exit code from the exec response
+ Ok bool `json:"ok"` // true when the step succeeded
+ Elapsed int64 `json:"elapsed_ms"` // duration of the exec call in milliseconds
+}
+
+// ExecFunc is the agent.Exec call signature used by the executor. It matches the method on
+// the hostagent Connect RPC client and is invoked once for each RUN or START step.
+type ExecFunc func(ctx context.Context, req *connect.Request[pb.ExecRequest]) (*connect.Response[pb.ExecResponse], error)
+
+// Execute runs steps sequentially against sandboxID using execFn.
+//
+//   - phase labels the log entries (e.g., "pre-build", "recipe", "post-build").
+//   - startStep is the number of steps already run by earlier phases (0 for the
+//     first phase); entries are numbered from startStep+1 across phases.
+//   - defaultTimeout applies to RUN steps with no per-step --timeout; values <= 0
+//     select 10 minutes (DefaultStepTimeout is not consulted here).
+//   - bctx is mutated in place as ENV/WORKDIR steps execute, and carries forward
+//     into subsequent phases when the caller passes the same pointer. A nil bctx
+//     is replaced by a call-local context instead of causing a panic.
+//
+// Returns all log entries appended during this call, the number of the last step
+// attempted (pass it as startStep of the next phase), and whether all steps
+// succeeded. On false the last entry contains failure details; the caller is
+// responsible for destroying the sandbox and recording the build error.
+func Execute(
+	ctx context.Context,
+	phase string,
+	steps []Step,
+	sandboxID string,
+	startStep int,
+	defaultTimeout time.Duration,
+	bctx *ExecContext,
+	execFn ExecFunc,
+) (entries []BuildLogEntry, nextStep int, ok bool) {
+	if defaultTimeout <= 0 {
+		defaultTimeout = 10 * time.Minute
+	}
+	if bctx == nil {
+		bctx = &ExecContext{}
+	}
+
+	step := startStep
+	for _, st := range steps {
+		step++
+		slog.Info("executing build step", "phase", phase, "step", step, "instruction", st.Raw)
+
+		// ENV and WORKDIR only update bctx and keep this entry; other kinds replace it.
+		entry, succeeded := BuildLogEntry{Step: step, Phase: phase, Cmd: st.Raw, Ok: true}, true
+		switch st.Kind {
+		case KindENV:
+			if bctx.EnvVars == nil {
+				bctx.EnvVars = make(map[string]string)
+			}
+			bctx.EnvVars[st.Key] = st.Value
+		case KindWORKDIR:
+			bctx.WorkDir = st.Path
+		case KindSTART:
+			entry, succeeded = execStart(ctx, st, sandboxID, phase, step, bctx, execFn)
+
+		case KindRUN:
+			timeout := defaultTimeout
+			if st.Timeout > 0 {
+				timeout = st.Timeout
+			}
+			entry, succeeded = execRun(ctx, st, sandboxID, phase, step, timeout, bctx, execFn)
+
+		default:
+			// KindUSER, KindCOPY, or a Kind added later without executor support:
+			// fail the build instead of silently skipping the step. Raw can be
+			// empty on a hand-built Step, so Fields is never indexed blindly.
+			verb := "instruction"
+			if words := strings.Fields(st.Raw); len(words) > 0 {
+				verb = strings.ToUpper(words[0])
+			}
+			entry.Stderr, entry.Ok, succeeded = verb+" is not yet supported", false, false
+		}
+
+		entries = append(entries, entry)
+		if !succeeded {
+			return entries, step, false
+		}
+	}
+	return entries, step, true
+}
+
+// execRun executes one RUN step through execFn and returns its log entry plus
+// whether the step succeeded (no RPC error and exit code 0). The command is
+// bctx.WrappedCommand(st.Shell) run via "/bin/sh -c"; timeout bounds the RPC
+// context and, rounded up to whole seconds, the TimeoutSec sent to the agent.
+func execRun(
+	ctx context.Context, st Step, sandboxID, phase string, step int,
+	timeout time.Duration, bctx *ExecContext, execFn ExecFunc,
+) (BuildLogEntry, bool) {
+	execCtx, cancel := context.WithTimeout(ctx, timeout)
+	defer cancel()
+
+	// Round up: truncating would send 0 for sub-second timeouts and would give
+	// the agent less time than the RPC deadline for fractional ones.
+	timeoutSec := int32((timeout + time.Second - 1) / time.Second)
+
+	start := time.Now()
+	resp, err := execFn(execCtx, connect.NewRequest(&pb.ExecRequest{
+		SandboxId:  sandboxID,
+		Cmd:        "/bin/sh",
+		Args:       []string{"-c", bctx.WrappedCommand(st.Shell)},
+		TimeoutSec: timeoutSec,
+	}))
+
+	entry := BuildLogEntry{Step: step, Phase: phase, Cmd: st.Raw}
+	entry.Elapsed = time.Since(start).Milliseconds()
+	if err != nil {
+		entry.Stderr = fmt.Sprintf("exec error: %v", err)
+		return entry, false
+	}
+
+	entry.Stdout = string(resp.Msg.Stdout)
+	entry.Stderr = string(resp.Msg.Stderr)
+	entry.Exit = resp.Msg.ExitCode
+	entry.Ok = resp.Msg.ExitCode == 0
+	return entry, entry.Ok
+}
+
+func execStart(
+ ctx context.Context,
+ st Step,
+ sandboxID, phase string,
+ step int,
+ bctx *ExecContext,
+ execFn ExecFunc,
+) (BuildLogEntry, bool) {
+ // START uses a short timeout: just long enough for the shell to fork and
+ // return. The background process itself runs indefinitely inside the VM.
+ execCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
+ defer cancel()
+
+ start := time.Now()
+ resp, err := execFn(execCtx, connect.NewRequest(&pb.ExecRequest{
+ SandboxId: sandboxID,
+ Cmd: "/bin/sh",
+ Args: []string{"-c", bctx.StartCommand(st.Shell)},
+ TimeoutSec: 10,
+ }))
+
+ entry := BuildLogEntry{
+ Step: step,
+ Phase: phase,
+ Cmd: st.Raw,
+ Elapsed: time.Since(start).Milliseconds(),
+ }
+ if err != nil {
+ entry.Stderr = fmt.Sprintf("start error: %v", err)
+ return entry, false
+ }
+ entry.Exit = resp.Msg.ExitCode
+ entry.Ok = resp.Msg.ExitCode == 0
+ if !entry.Ok {
+ entry.Stderr = fmt.Sprintf("start failed with exit code %d: %s", resp.Msg.ExitCode, string(resp.Msg.Stderr))
+ }
+ return entry, entry.Ok
+}
diff --git a/internal/recipe/step.go b/internal/recipe/step.go
new file mode 100644
index 00000000..7d510362
--- /dev/null
+++ b/internal/recipe/step.go
@@ -0,0 +1,129 @@
+package recipe
+
+import (
+ "fmt"
+ "strings"
+ "time"
+)
+
+// Kind identifies the instruction type in a recipe line.
+type Kind int
+
+const (
+ KindRUN Kind = iota // Execute a command and wait for it to exit. Also the zero value of Kind.
+ KindSTART // Start a command in the background (non-blocking).
+ KindENV // Set an environment variable for subsequent steps.
+ KindWORKDIR // Set the working directory for subsequent steps.
+ KindUSER // Switch the unix user for subsequent steps. (stub: parsed, but Execute fails the build)
+ KindCOPY // Copy files into the sandbox. (stub: parsed, but Execute fails the build)
+)
+
+// Step is the parsed representation of one recipe instruction; ParseStep leaves fields of other kinds at zero.
+type Step struct {
+ Kind Kind
+ Raw string // original instruction with surrounding whitespace trimmed, preserved for logging
+ Shell string // KindRUN, KindSTART: the shell command text
+ Timeout time.Duration // KindRUN: 0 means use caller's default
+ Key string // KindENV: variable name (text before the first '=')
+ Value string // KindENV: variable value (text after the first '=', may be empty)
+ Path string // KindWORKDIR: directory path
+}
+
+// ParseStep parses a single recipe instruction string into a Step.
+// Instructions are Dockerfile-like: a case-insensitive keyword, then arguments
+// separated from it by spaces or tabs. Supported syntax:
+//
+//	RUN CMD                     — run CMD, wait for exit
+//	RUN --timeout=DURATION CMD  — same, with an explicit timeout (e.g. --timeout=5m)
+//	START CMD                   — start CMD in background, return immediately
+//	ENV KEY=VALUE               — set environment variable
+//	WORKDIR PATH                — set working directory
+//	USER ..., COPY ...          — recognized, but not yet supported
+func ParseStep(s string) (Step, error) {
+	s = strings.TrimSpace(s)
+	if s == "" {
+		return Step{}, fmt.Errorf("empty step")
+	}
+
+	// The keyword ends at the first space or tab, so "RUN\tcmd" parses like "RUN cmd".
+	keyword, rest := s, ""
+	if i := strings.IndexAny(s, " \t"); i >= 0 {
+		keyword, rest = s[:i], strings.TrimSpace(s[i+1:])
+	}
+
+	switch strings.ToUpper(keyword) {
+	case "RUN":
+		return parseRUN(s, rest)
+	case "START":
+		return parseSTART(s, rest)
+	case "ENV":
+		return parseENV(s, rest)
+	case "WORKDIR":
+		return parseWORKDIR(s, rest)
+	case "USER":
+		return Step{Kind: KindUSER, Raw: s}, nil
+	case "COPY":
+		return Step{Kind: KindCOPY, Raw: s}, nil
+	default:
+		return Step{}, fmt.Errorf("unknown instruction %q (expected RUN, START, ENV, WORKDIR, USER, or COPY)", keyword)
+	}
+}
+
+// ParseRecipe parses all recipe lines, returning on the first error.
+func ParseRecipe(lines []string) ([]Step, error) {
+ steps := make([]Step, 0, len(lines))
+ for i, line := range lines {
+ st, err := ParseStep(line)
+ if err != nil {
+ return nil, fmt.Errorf("recipe line %d: %w", i+1, err)
+ }
+ steps = append(steps, st)
+ }
+ return steps, nil
+}
+
+// parseRUN handles "RUN [--timeout=DURATION] CMD". DURATION uses time.ParseDuration
+// syntax and is separated from CMD by a space; without the flag Timeout stays 0.
+func parseRUN(raw, rest string) (Step, error) {
+	st := Step{Kind: KindRUN, Raw: raw, Shell: rest}
+	if strings.HasPrefix(rest, "--timeout=") {
+		flag, cmd, _ := strings.Cut(strings.TrimPrefix(rest, "--timeout="), " ")
+		// A missing command is reported before the duration is even parsed.
+		if st.Shell = strings.TrimSpace(cmd); st.Shell == "" {
+			return Step{}, fmt.Errorf("RUN --timeout= flag has no command: %q", raw)
+		}
+		var err error
+		if st.Timeout, err = time.ParseDuration(flag); err != nil {
+			return Step{}, fmt.Errorf("RUN --timeout= invalid duration %q: %w", flag, err)
+		}
+	}
+	if st.Shell == "" {
+		return Step{}, fmt.Errorf("RUN requires a command: %q", raw)
+	}
+	return st, nil
+}
+
+func parseSTART(raw, rest string) (Step, error) {
+	if rest != "" { // the command is kept verbatim; only its absence is an error
+		return Step{Kind: KindSTART, Raw: raw, Shell: rest}, nil
+	}
+	return Step{}, fmt.Errorf("START requires a command: %q", raw)
+}
+
+func parseENV(raw, rest string) (Step, error) {
+	key, value, hasEq := strings.Cut(rest, "=") // split at the first '=' so the value may contain '='
+	switch {
+	case !hasEq:
+		return Step{}, fmt.Errorf("ENV requires KEY=VALUE format: %q", raw)
+	case key == "":
+		return Step{}, fmt.Errorf("ENV key is empty: %q", raw)
+	}
+	return Step{Kind: KindENV, Raw: raw, Key: key, Value: value}, nil
+}
+
+func parseWORKDIR(raw, path string) (Step, error) {
+	if path != "" { // the path is stored as given, spaces included
+		return Step{Kind: KindWORKDIR, Raw: raw, Path: path}, nil
+	}
+	return Step{}, fmt.Errorf("WORKDIR requires a path: %q", raw)
+}
diff --git a/internal/recipe/step_test.go b/internal/recipe/step_test.go
new file mode 100644
index 00000000..2370bb21
--- /dev/null
+++ b/internal/recipe/step_test.go
@@ -0,0 +1,208 @@
+package recipe
+
+import (
+ "testing"
+ "time"
+)
+
+func TestParseStep(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ want Step
+ wantErr bool
+ }{
+ // RUN
+ {
+ name: "RUN basic",
+ input: "RUN apt install -y curl",
+ want: Step{Kind: KindRUN, Raw: "RUN apt install -y curl", Shell: "apt install -y curl"},
+ },
+ {
+ name: "RUN lowercase",
+ input: "run echo hello",
+ want: Step{Kind: KindRUN, Raw: "run echo hello", Shell: "echo hello"},
+ },
+ {
+ name: "RUN with timeout",
+ input: "RUN --timeout=5m npm install",
+ want: Step{Kind: KindRUN, Raw: "RUN --timeout=5m npm install", Shell: "npm install", Timeout: 5 * time.Minute},
+ },
+ {
+ name: "RUN with timeout seconds",
+ input: "RUN --timeout=30s make build",
+ want: Step{Kind: KindRUN, Raw: "RUN --timeout=30s make build", Shell: "make build", Timeout: 30 * time.Second},
+ },
+ {
+ name: "RUN no command",
+ input: "RUN",
+ wantErr: true,
+ },
+ {
+ name: "RUN timeout no command",
+ input: "RUN --timeout=5m",
+ wantErr: true,
+ },
+ {
+ name: "RUN invalid timeout",
+ input: "RUN --timeout=notaduration echo hi",
+ wantErr: true,
+ },
+ // START
+ {
+ name: "START basic",
+ input: "START python3 app.py",
+ want: Step{Kind: KindSTART, Raw: "START python3 app.py", Shell: "python3 app.py"},
+ },
+ {
+ name: "START uppercase",
+ input: "START node server.js --port=8080",
+ want: Step{Kind: KindSTART, Raw: "START node server.js --port=8080", Shell: "node server.js --port=8080"},
+ },
+ {
+ name: "START no command",
+ input: "START",
+ wantErr: true,
+ },
+ // ENV
+ {
+ name: "ENV basic",
+ input: "ENV FOO=bar",
+ want: Step{Kind: KindENV, Raw: "ENV FOO=bar", Key: "FOO", Value: "bar"},
+ },
+ {
+ name: "ENV value with spaces",
+ input: "ENV GREETING=hello world",
+ want: Step{Kind: KindENV, Raw: "ENV GREETING=hello world", Key: "GREETING", Value: "hello world"},
+ },
+ {
+ name: "ENV value with equals sign",
+ input: "ENV URL=http://example.com?a=1",
+ want: Step{Kind: KindENV, Raw: "ENV URL=http://example.com?a=1", Key: "URL", Value: "http://example.com?a=1"},
+ },
+ {
+ name: "ENV empty value",
+ input: "ENV FOO=",
+ want: Step{Kind: KindENV, Raw: "ENV FOO=", Key: "FOO", Value: ""},
+ },
+ {
+ name: "ENV missing equals",
+ input: "ENV FOO",
+ wantErr: true,
+ },
+ {
+ name: "ENV empty key",
+ input: "ENV =value",
+ wantErr: true,
+ },
+ // WORKDIR
+ {
+ name: "WORKDIR basic",
+ input: "WORKDIR /app",
+ want: Step{Kind: KindWORKDIR, Raw: "WORKDIR /app", Path: "/app"},
+ },
+ {
+ name: "WORKDIR with spaces in path",
+ input: "WORKDIR /my project",
+ want: Step{Kind: KindWORKDIR, Raw: "WORKDIR /my project", Path: "/my project"},
+ },
+ {
+ name: "WORKDIR empty",
+ input: "WORKDIR",
+ wantErr: true,
+ },
+ // USER and COPY stubs
+ {
+ name: "USER stub",
+ input: "USER www-data",
+ want: Step{Kind: KindUSER, Raw: "USER www-data"},
+ },
+ {
+ name: "COPY stub",
+ input: "COPY config.yaml /etc/app/config.yaml",
+ want: Step{Kind: KindCOPY, Raw: "COPY config.yaml /etc/app/config.yaml"},
+ },
+ // Unknown keyword
+ {
+ name: "unknown keyword",
+ input: "FROBNICATE something",
+ wantErr: true,
+ },
+ // Empty input
+ {
+ name: "empty string",
+ input: "",
+ wantErr: true,
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ got, err := ParseStep(tc.input)
+ if tc.wantErr {
+ if err == nil {
+ t.Fatalf("ParseStep(%q) expected error, got %+v", tc.input, got)
+ }
+ return
+ }
+ if err != nil {
+ t.Fatalf("ParseStep(%q) unexpected error: %v", tc.input, err)
+ }
+ if got != tc.want {
+ t.Errorf("ParseStep(%q)\n got %+v\n want %+v", tc.input, got, tc.want)
+ }
+ })
+ }
+}
+
+func TestParseRecipe(t *testing.T) {
+ t.Run("valid recipe", func(t *testing.T) {
+ lines := []string{
+ "RUN apt update",
+ "WORKDIR /app",
+ "ENV PORT=8080",
+ "START python3 server.py",
+ "RUN --timeout=2m pip install -r requirements.txt",
+ }
+ steps, err := ParseRecipe(lines)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if len(steps) != 5 {
+ t.Fatalf("expected 5 steps, got %d", len(steps))
+ }
+ if steps[0].Kind != KindRUN {
+ t.Errorf("step 0: want KindRUN, got %v", steps[0].Kind)
+ }
+ if steps[1].Kind != KindWORKDIR {
+ t.Errorf("step 1: want KindWORKDIR, got %v", steps[1].Kind)
+ }
+ if steps[3].Kind != KindSTART {
+ t.Errorf("step 3: want KindSTART, got %v", steps[3].Kind)
+ }
+ if steps[4].Timeout != 2*time.Minute {
+ t.Errorf("step 4: want 2m timeout, got %v", steps[4].Timeout)
+ }
+ })
+
+ t.Run("error on invalid line", func(t *testing.T) {
+ lines := []string{
+ "RUN apt update",
+ "BADCMD something",
+ }
+ _, err := ParseRecipe(lines)
+ if err == nil {
+ t.Fatal("expected error for invalid line, got nil")
+ }
+ })
+
+ t.Run("empty recipe", func(t *testing.T) {
+ steps, err := ParseRecipe(nil)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if len(steps) != 0 {
+ t.Fatalf("expected 0 steps, got %d", len(steps))
+ }
+ })
+}
diff --git a/internal/sandbox/conntracker.go b/internal/sandbox/conntracker.go
new file mode 100644
index 00000000..d9eac72b
--- /dev/null
+++ b/internal/sandbox/conntracker.go
@@ -0,0 +1,66 @@
+package sandbox
+
+import (
+ "sync"
+ "sync/atomic"
+ "time"
+)
+
+// ConnTracker tracks active proxy connections for a single sandbox and provides a
+// drain mechanism for pre-pause graceful shutdown. It is safe for concurrent use.
+// The in-flight count is guarded by a mutex rather than kept in a sync.WaitGroup,
+// whose Add from a zero counter must not run concurrently with Wait.
+type ConnTracker struct {
+	draining atomic.Bool   // set by Drain, cleared by Reset
+	mu       sync.Mutex    // guards active and idle
+	active   int           // connections acquired but not yet released
+	idle     chan struct{} // non-nil while a Drain waits; closed when active reaches 0
+}
+
+// Acquire registers one in-flight connection. On false (tracker draining) the caller must not call Release.
+func (t *ConnTracker) Acquire() bool {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	if t.draining.Load() { // Drain sets this under mu: a connection is counted before it or rejected
+		return false
+	}
+	t.active++
+	return true
+}
+
+// Release marks one connection as complete: exactly one call per successful Acquire, anything else panics.
+func (t *ConnTracker) Release() {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	if t.active--; t.active < 0 {
+		panic("sandbox: ConnTracker.Release without a matching Acquire")
+	}
+	if t.active == 0 && t.idle != nil {
+		close(t.idle) // wakes every Drain currently waiting
+		t.idle = nil
+	}
+}
+
+// Drain marks the tracker as draining (all future Acquire calls return false) and waits
+// up to timeout for in-flight connections to finish. No goroutine outlives the call.
+func (t *ConnTracker) Drain(timeout time.Duration) {
+	t.mu.Lock()
+	t.draining.Store(true)
+	if t.active > 0 && t.idle == nil {
+		t.idle = make(chan struct{})
+	}
+	idle := t.idle
+	t.mu.Unlock()
+	if idle == nil {
+		return // nothing in flight
+	}
+	select {
+	case <-idle:
+	case <-time.After(timeout):
+	}
+}
+
+// Reset clears the draining state so connections are accepted again, e.g. when a pause fails and the VM is resumed.
+func (t *ConnTracker) Reset() {
+	t.draining.Store(false)
+}
diff --git a/internal/sandbox/images.go b/internal/sandbox/images.go
new file mode 100644
index 00000000..1716d80d
--- /dev/null
+++ b/internal/sandbox/images.go
@@ -0,0 +1,106 @@
+package sandbox
+
+import (
+ "fmt"
+ "log/slog"
+ "os"
+ "os/exec"
+ "path/filepath"
+
+ "git.omukk.dev/wrenn/sandbox/internal/id"
+ "git.omukk.dev/wrenn/sandbox/internal/layout"
+)
+
+// DefaultDiskSizeMB is the standard disk size for base images. Images smaller
+// than this are expanded at startup so that dm-snapshot sandboxes see the full
+// size without per-sandbox copies. The image file is grown sparsely with
+// truncate and the ext4 filesystem is then resized to fill it (see expandImage).
+const DefaultDiskSizeMB = 5120 // 5 GB
+
+// EnsureImageSizes grows every template rootfs.ext4 that is smaller than targetMB
+// megabytes (DefaultDiskSizeMB when targetMB <= 0): first the built-in minimal image,
+// then each <teams dir>/<team>/<template>/rootfs.ext4. Images already at or above the
+// target are left untouched. Call once at host agent startup, before creating sandboxes.
+func EnsureImageSizes(wrennDir string, targetMB int) error {
+	if targetMB <= 0 {
+		targetMB = DefaultDiskSizeMB
+	}
+	wantBytes := int64(targetMB) * 1024 * 1024
+
+	// The built-in minimal image (platform team ID, minimal template ID).
+	builtin := layout.TemplateRootfs(wrennDir, id.PlatformTeamID, id.MinimalTemplateID)
+	if err := expandImage(builtin, wantBytes, targetMB); err != nil {
+		return err
+	}
+
+	teamsRoot := layout.TeamsDir(wrennDir)
+	teams, err := os.ReadDir(teamsRoot)
+	switch {
+	case err == nil:
+	case os.IsNotExist(err):
+		return nil // no teams directory yet, nothing more to expand
+	default:
+		return fmt.Errorf("read teams dir: %w", err)
+	}
+
+	// Two levels deep: one directory per team, one sub-directory per template.
+	for _, team := range teams {
+		if !team.IsDir() {
+			continue
+		}
+		teamDir := filepath.Join(teamsRoot, team.Name())
+		templates, listErr := os.ReadDir(teamDir)
+		if listErr != nil {
+			continue // unreadable team directory: skip it
+		}
+		for _, tmpl := range templates {
+			if !tmpl.IsDir() {
+				continue
+			}
+			image := filepath.Join(teamDir, tmpl.Name(), "rootfs.ext4")
+			if err := expandImage(image, wantBytes, targetMB); err != nil {
+				return err
+			}
+		}
+	}
+	return nil
+}
+
+// expandImage expands a single rootfs image if it is smaller than targetBytes:
+// sparse truncate to the new size, e2fsck -fy, then resize2fs. A path that
+// cannot be stat'ed, or an image already large enough, is skipped with nil.
+// NOTE(review): if e2fsck/resize2fs fails the file keeps its new size, so the
+// size check below skips it on the next run — confirm that is acceptable.
+func expandImage(rootfs string, targetBytes int64, targetMB int) error {
+	info, err := os.Stat(rootfs)
+	if err != nil {
+		return nil // not every template dir has a rootfs.ext4
+	}
+	if info.Size() >= targetBytes {
+		return nil // already large enough
+	}
+	slog.Info("expanding base image", "path", rootfs, "from_mb", info.Size()/(1024*1024), "to_mb", targetMB)
+
+	// Expand the file (sparse — instant, no physical disk used).
+	if err := os.Truncate(rootfs, targetBytes); err != nil {
+		return fmt.Errorf("truncate %s: %w", rootfs, err)
+	}
+
+	// Check the filesystem before resizing. Exit status 1 means e2fsck corrected
+	// errors, which is fine. Anything else is fatal, including failures that are
+	// not an exit status at all (e2fsck missing, not executable, or killed).
+	if out, err := exec.Command("e2fsck", "-fy", rootfs).CombinedOutput(); err != nil {
+		exitErr, isExit := err.(*exec.ExitError)
+		if !isExit || exitErr.ExitCode() != 1 {
+			return fmt.Errorf("e2fsck %s: %s: %w", rootfs, string(out), err)
+		}
+	}
+
+	// Grow the ext4 filesystem to fill the new file size.
+	if out, err := exec.Command("resize2fs", rootfs).CombinedOutput(); err != nil {
+		return fmt.Errorf("resize2fs %s: %s: %w", rootfs, string(out), err)
+	}
+
+	slog.Info("base image expanded", "path", rootfs, "size_mb", targetMB)
+	return nil
+}
diff --git a/internal/sandbox/manager.go b/internal/sandbox/manager.go
index 9a795b5c..67a70ca0 100644
--- a/internal/sandbox/manager.go
+++ b/internal/sandbox/manager.go
@@ -4,31 +4,32 @@ import (
"context"
"fmt"
"log/slog"
+ "net"
"os"
+ "os/exec"
"path/filepath"
+ "strings"
"sync"
"time"
"github.com/google/uuid"
+ "github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/devicemapper"
"git.omukk.dev/wrenn/sandbox/internal/envdclient"
"git.omukk.dev/wrenn/sandbox/internal/id"
+ "git.omukk.dev/wrenn/sandbox/internal/layout"
"git.omukk.dev/wrenn/sandbox/internal/models"
"git.omukk.dev/wrenn/sandbox/internal/network"
"git.omukk.dev/wrenn/sandbox/internal/snapshot"
"git.omukk.dev/wrenn/sandbox/internal/uffd"
- "git.omukk.dev/wrenn/sandbox/internal/validate"
"git.omukk.dev/wrenn/sandbox/internal/vm"
)
// Config holds the paths and defaults for the sandbox manager.
type Config struct {
- KernelPath string
- ImagesDir string // directory containing template images (e.g., /var/lib/wrenn/images/{name}/rootfs.ext4)
- SandboxesDir string // directory for per-sandbox rootfs clones (e.g., /var/lib/wrenn/sandboxes)
- SnapshotsDir string // directory for pause snapshots (e.g., /var/lib/wrenn/snapshots/{sandbox-id}/)
- EnvdTimeout time.Duration
+ WrennDir string // root directory (e.g. /var/lib/wrenn); all sub-paths derived via layout package
+ EnvdTimeout time.Duration
}
// Manager orchestrates sandbox lifecycle: VM, network, filesystem, envd.
@@ -50,7 +51,8 @@ type sandboxState struct {
models.Sandbox
slot *network.Slot
client *envdclient.Client
- uffdSocketPath string // non-empty for sandboxes restored from snapshot
+ connTracker *ConnTracker // tracks in-flight proxy connections for pre-pause drain
+ uffdSocketPath string // non-empty for sandboxes restored from snapshot
dmDevice *devicemapper.SnapshotDevice
baseImagePath string // path to the base template rootfs (for loop registry release)
@@ -74,8 +76,12 @@ type snapshotParent struct {
}
// maxDiffGenerations caps how many incremental diff generations we chain
-// before falling back to a Full snapshot to collapse the chain.
-const maxDiffGenerations = 10
+// before falling back to a Full snapshot to collapse the chain. Firecracker
+// snapshot/restore of a Go process (envd) accumulates runtime memory state
+// drift; empirically, ~10 diff-based cycles corrupt the Go page allocator.
+// A Full snapshot resets the generation counter and produces a clean base,
+// preventing the crash.
+const maxDiffGenerations = 8
// New creates a new sandbox manager.
func New(cfg Config) *Manager {
@@ -94,9 +100,9 @@ func New(cfg Config) *Manager {
// Create boots a new sandbox: clone rootfs, set up network, start VM, wait for envd.
// If sandboxID is empty, a new ID is generated.
-func (m *Manager) Create(ctx context.Context, sandboxID, template string, vcpus, memoryMB, timeoutSec int) (*models.Sandbox, error) {
+func (m *Manager) Create(ctx context.Context, sandboxID string, teamID, templateID pgtype.UUID, vcpus, memoryMB, timeoutSec, diskSizeMB int) (*models.Sandbox, error) {
if sandboxID == "" {
- sandboxID = id.NewSandboxID()
+ sandboxID = id.FormatSandboxID(id.NewSandboxID())
}
if vcpus <= 0 {
@@ -105,21 +111,18 @@ func (m *Manager) Create(ctx context.Context, sandboxID, template string, vcpus,
if memoryMB <= 0 {
memoryMB = 512
}
-
- if template == "" {
- template = "minimal"
- }
- if err := validate.SafeName(template); err != nil {
- return nil, fmt.Errorf("invalid template name: %w", err)
+ if diskSizeMB <= 0 {
+ diskSizeMB = 5120 // 5 GB default
}
// Check if template refers to a snapshot (has snapfile + memfile + header + rootfs).
- if snapshot.IsSnapshot(m.cfg.ImagesDir, template) {
- return m.createFromSnapshot(ctx, sandboxID, template, vcpus, memoryMB, timeoutSec)
+ tmplDir := layout.TemplateDir(m.cfg.WrennDir, teamID, templateID)
+ if _, err := os.Stat(filepath.Join(tmplDir, snapshot.SnapFileName)); err == nil {
+ return m.createFromSnapshot(ctx, sandboxID, teamID, templateID, vcpus, memoryMB, timeoutSec, diskSizeMB)
}
- // Resolve base rootfs image: /var/lib/wrenn/images/{template}/rootfs.ext4
- baseRootfs := filepath.Join(m.cfg.ImagesDir, template, "rootfs.ext4")
+ // Resolve base rootfs image.
+ baseRootfs := layout.TemplateRootfs(m.cfg.WrennDir, teamID, templateID)
if _, err := os.Stat(baseRootfs); err != nil {
return nil, fmt.Errorf("base rootfs not found at %s: %w", baseRootfs, err)
}
@@ -138,8 +141,9 @@ func (m *Manager) Create(ctx context.Context, sandboxID, template string, vcpus,
// Create dm-snapshot with per-sandbox CoW file.
dmName := "wrenn-" + sandboxID
- cowPath := filepath.Join(m.cfg.SandboxesDir, fmt.Sprintf("%s.cow", sandboxID))
- dmDev, err := devicemapper.CreateSnapshot(dmName, originLoop, cowPath, originSize)
+ cowPath := filepath.Join(layout.SandboxesDir(m.cfg.WrennDir), fmt.Sprintf("%s.cow", sandboxID))
+ cowSize := int64(diskSizeMB) * 1024 * 1024
+ dmDev, err := devicemapper.CreateSnapshot(dmName, originLoop, cowPath, originSize, cowSize)
if err != nil {
m.loops.Release(baseRootfs)
return nil, fmt.Errorf("create dm-snapshot: %w", err)
@@ -167,7 +171,8 @@ func (m *Manager) Create(ctx context.Context, sandboxID, template string, vcpus,
// Boot VM — Firecracker gets the dm device path.
vmCfg := vm.VMConfig{
SandboxID: sandboxID,
- KernelPath: m.cfg.KernelPath,
+ TemplateID: id.UUIDString(templateID),
+ KernelPath: layout.KernelPath(m.cfg.WrennDir),
RootfsPath: dmDev.DevicePath,
VCPUs: vcpus,
MemoryMB: memoryMB,
@@ -203,33 +208,25 @@ func (m *Manager) Create(ctx context.Context, sandboxID, template string, vcpus,
return nil, fmt.Errorf("wait for envd: %w", err)
}
- // Sync guest clock in background. Non-fatal — sandbox is usable before this completes.
- // Run in a goroutine so Init latency doesn't block the RPC response back to the control plane.
- go func() {
- initCtx, initCancel := context.WithTimeout(context.Background(), 10*time.Second)
- defer initCancel()
- if err := client.Init(initCtx); err != nil {
- slog.Warn("envd init (clock sync) failed", "sandbox", sandboxID, "error", err)
- }
- }()
-
now := time.Now()
sb := &sandboxState{
Sandbox: models.Sandbox{
- ID: sandboxID,
- Status: models.StatusRunning,
- Template: template,
- VCPUs: vcpus,
- MemoryMB: memoryMB,
- TimeoutSec: timeoutSec,
- SlotIndex: slotIdx,
- HostIP: slot.HostIP,
- RootfsPath: dmDev.DevicePath,
- CreatedAt: now,
- LastActiveAt: now,
+ ID: sandboxID,
+ Status: models.StatusRunning,
+ TemplateTeamID: teamID.Bytes,
+ TemplateID: templateID.Bytes,
+ VCPUs: vcpus,
+ MemoryMB: memoryMB,
+ TimeoutSec: timeoutSec,
+ SlotIndex: slotIdx,
+ HostIP: slot.HostIP,
+ RootfsPath: dmDev.DevicePath,
+ CreatedAt: now,
+ LastActiveAt: now,
},
slot: slot,
client: client,
+ connTracker: &ConnTracker{},
dmDevice: dmDev,
baseImagePath: baseRootfs,
}
@@ -242,7 +239,8 @@ func (m *Manager) Create(ctx context.Context, sandboxID, template string, vcpus,
slog.Info("sandbox created",
"id", sandboxID,
- "template", template,
+ "team_id", teamID,
+ "template_id", templateID,
"host_ip", slot.HostIP.String(),
"dm_device", dmDev.DevicePath,
)
@@ -265,7 +263,9 @@ func (m *Manager) Destroy(ctx context.Context, sandboxID string) error {
}
// Always clean up pause snapshot files (may exist if sandbox was paused).
- warnErr("snapshot cleanup error", sandboxID, snapshot.Remove(m.cfg.SnapshotsDir, sandboxID))
+ if err := os.RemoveAll(layout.PauseSnapshotDir(m.cfg.WrennDir, sandboxID)); err != nil {
+ slog.Warn("snapshot cleanup error", "id", sandboxID, "error", err)
+ }
slog.Info("sandbox destroyed", "id", sandboxID)
return nil
@@ -311,43 +311,53 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error {
return fmt.Errorf("sandbox %s is not running (status: %s)", sandboxID, sb.Status)
}
+ // Step 0: Drain in-flight proxy connections before freezing vCPUs.
+ // This prevents Go runtime corruption inside the guest caused by stale
+ // TCP state from connections that were alive when the VM was snapshotted.
+ sb.connTracker.Drain(2 * time.Second)
+ slog.Debug("pause: proxy connections drained", "id", sandboxID)
+
pauseStart := time.Now()
// Step 1: Pause the VM (freeze vCPUs).
if err := m.vm.Pause(ctx, sandboxID); err != nil {
+ sb.connTracker.Reset()
return fmt.Errorf("pause VM: %w", err)
}
slog.Debug("pause: VM paused", "id", sandboxID, "elapsed", time.Since(pauseStart))
- // Determine snapshot type: Diff if resumed from snapshot (avoids UFFD
- // fault-in storm), Full otherwise or if generation cap is reached.
+ // Always use Diff when we have a parent snapshot — Diff only captures
+ // changed pages and is much faster than Full (which dumps all memory).
+ // For first-time pauses (no parent) we must use Full.
snapshotType := "Full"
- if sb.parent != nil && sb.parent.header.Metadata.Generation < maxDiffGenerations {
+ if sb.parent != nil {
snapshotType = "Diff"
}
// resumeOnError unpauses the VM so the sandbox stays usable when a
// post-freeze step fails. If the resume itself fails, the sandbox is
- // left frozen — the caller should destroy it.
+ // left frozen — the caller should destroy it. It also resets the
+ // connection tracker so the sandbox can accept proxy connections again.
resumeOnError := func() {
+ sb.connTracker.Reset()
if err := m.vm.Resume(ctx, sandboxID); err != nil {
slog.Error("failed to resume VM after pause error — sandbox is frozen", "id", sandboxID, "error", err)
}
}
// Step 2: Take VM state snapshot (snapfile + memfile).
- if err := snapshot.EnsureDir(m.cfg.SnapshotsDir, sandboxID); err != nil {
+ pauseDir := layout.PauseSnapshotDir(m.cfg.WrennDir, sandboxID)
+ if err := os.MkdirAll(pauseDir, 0755); err != nil {
resumeOnError()
return fmt.Errorf("create snapshot dir: %w", err)
}
- snapDir := snapshot.DirPath(m.cfg.SnapshotsDir, sandboxID)
- rawMemPath := filepath.Join(snapDir, "memfile.raw")
- snapPath := snapshot.SnapPath(m.cfg.SnapshotsDir, sandboxID)
+ rawMemPath := filepath.Join(pauseDir, "memfile.raw")
+ snapPath := filepath.Join(pauseDir, snapshot.SnapFileName)
snapshotStart := time.Now()
if err := m.vm.Snapshot(ctx, sandboxID, snapPath, rawMemPath, snapshotType); err != nil {
- warnErr("snapshot dir cleanup error", sandboxID, snapshot.Remove(m.cfg.SnapshotsDir, sandboxID))
+ warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir))
resumeOnError()
return fmt.Errorf("create VM snapshot: %w", err)
}
@@ -355,34 +365,76 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error {
// Step 3: Process the raw memfile into a compact diff + header.
buildID := uuid.New()
- headerPath := snapshot.MemHeaderPath(m.cfg.SnapshotsDir, sandboxID)
+ headerPath := filepath.Join(pauseDir, snapshot.MemHeaderName)
processStart := time.Now()
- if sb.parent != nil && snapshotType == "Diff" {
+ if sb.parent != nil {
// Diff: process against parent header, producing only changed blocks.
- diffPath := snapshot.MemDiffPathForBuild(m.cfg.SnapshotsDir, sandboxID, buildID)
+ diffPath := snapshot.MemDiffPathForBuild(pauseDir, "", buildID)
if _, err := snapshot.ProcessMemfileWithParent(rawMemPath, diffPath, headerPath, sb.parent.header, buildID); err != nil {
- warnErr("snapshot dir cleanup error", sandboxID, snapshot.Remove(m.cfg.SnapshotsDir, sandboxID))
+ warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir))
resumeOnError()
return fmt.Errorf("process memfile with parent: %w", err)
}
// Copy previous generation diff files into the snapshot directory.
for prevBuildID, prevPath := range sb.parent.diffPaths {
- dstPath := snapshot.MemDiffPathForBuild(m.cfg.SnapshotsDir, sandboxID, uuid.MustParse(prevBuildID))
+ dstPath := snapshot.MemDiffPathForBuild(pauseDir, "", uuid.MustParse(prevBuildID))
if prevPath != dstPath {
if err := copyFile(prevPath, dstPath); err != nil {
- warnErr("snapshot dir cleanup error", sandboxID, snapshot.Remove(m.cfg.SnapshotsDir, sandboxID))
+ warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir))
resumeOnError()
return fmt.Errorf("copy parent diff file: %w", err)
}
}
}
+
+ // If the generation cap is reached, merge all diff files into a
+ // single file to collapse the chain. This is a file-level operation
+ // (no Firecracker involvement) so it's fast and reliable.
+ generation := sb.parent.header.Metadata.Generation + 1
+ if generation >= maxDiffGenerations {
+ slog.Debug("pause: merging diff generations", "id", sandboxID, "generation", generation)
+
+ // Load the header we just wrote (it references all generations).
+ headerData, err := os.ReadFile(headerPath)
+ if err != nil {
+ warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir))
+ resumeOnError()
+ return fmt.Errorf("read header for merge: %w", err)
+ }
+ currentHeader, err := snapshot.Deserialize(headerData)
+ if err != nil {
+ warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir))
+ resumeOnError()
+ return fmt.Errorf("deserialize header for merge: %w", err)
+ }
+
+ // Locate all diff files referenced by the header.
+ diffFiles, err := snapshot.ListDiffFiles(pauseDir, "", currentHeader)
+ if err != nil {
+ warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir))
+ resumeOnError()
+ return fmt.Errorf("list diff files for merge: %w", err)
+ }
+
+ // Merge into a single new diff file.
+ mergedPath := snapshot.MemDiffPath(pauseDir, "")
+ if _, err := snapshot.MergeDiffs(currentHeader, diffFiles, mergedPath, headerPath); err != nil {
+ warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir))
+ resumeOnError()
+ return fmt.Errorf("merge diff files: %w", err)
+ }
+
+ // Remove the old per-generation diff files.
+ removeStaleMemDiffs(pauseDir)
+ slog.Debug("pause: diff merge complete", "id", sandboxID)
+ }
} else {
- // Full: first generation or generation cap reached — single diff file.
- diffPath := snapshot.MemDiffPath(m.cfg.SnapshotsDir, sandboxID)
+ // Full: first pause — no parent to diff against.
+ diffPath := snapshot.MemDiffPath(pauseDir, "")
if _, err := snapshot.ProcessMemfile(rawMemPath, diffPath, headerPath, buildID); err != nil {
- warnErr("snapshot dir cleanup error", sandboxID, snapshot.Remove(m.cfg.SnapshotsDir, sandboxID))
+ warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir))
resumeOnError()
return fmt.Errorf("process memfile: %w", err)
}
@@ -412,7 +464,7 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error {
if sb.uffdSocketPath != "" {
os.Remove(sb.uffdSocketPath)
}
- warnErr("snapshot dir cleanup error", sandboxID, snapshot.Remove(m.cfg.SnapshotsDir, sandboxID))
+ warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir))
m.mu.Lock()
delete(m.boxes, sandboxID)
m.mu.Unlock()
@@ -420,9 +472,9 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error {
}
// Move (not copy) the CoW file into the snapshot directory.
- snapshotCow := snapshot.CowPath(m.cfg.SnapshotsDir, sandboxID)
+ snapshotCow := snapshot.CowPath(pauseDir, "")
if err := os.Rename(sb.dmDevice.CowPath, snapshotCow); err != nil {
- warnErr("snapshot dir cleanup error", sandboxID, snapshot.Remove(m.cfg.SnapshotsDir, sandboxID))
+ warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir))
// VM and dm-snapshot are already gone — clean up remaining resources.
warnErr("network cleanup error during pause", sandboxID, network.RemoveNetwork(sb.slot))
m.slots.Release(sb.SlotIndex)
@@ -439,10 +491,10 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error {
}
// Record which base template this CoW was built against.
- if err := snapshot.WriteMeta(m.cfg.SnapshotsDir, sandboxID, &snapshot.RootfsMeta{
+ if err := snapshot.WriteMeta(pauseDir, "", &snapshot.RootfsMeta{
BaseTemplate: sb.baseImagePath,
}); err != nil {
- warnErr("snapshot dir cleanup error", sandboxID, snapshot.Remove(m.cfg.SnapshotsDir, sandboxID))
+ warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir))
// VM and dm-snapshot are already gone — clean up remaining resources.
warnErr("network cleanup error during pause", sandboxID, network.RemoveNetwork(sb.slot))
m.slots.Release(sb.SlotIndex)
@@ -482,13 +534,13 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error {
// Resume restores a paused sandbox from its snapshot using UFFD for
// lazy memory loading. The sandbox gets a new network slot.
func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int) (*models.Sandbox, error) {
- snapDir := m.cfg.SnapshotsDir
- if !snapshot.Exists(snapDir, sandboxID) {
+ pauseDir := layout.PauseSnapshotDir(m.cfg.WrennDir, sandboxID)
+ if _, err := os.Stat(pauseDir); err != nil {
return nil, fmt.Errorf("no snapshot found for sandbox %s", sandboxID)
}
// Read the header to set up the UFFD memory source.
- headerData, err := os.ReadFile(snapshot.MemHeaderPath(snapDir, sandboxID))
+ headerData, err := os.ReadFile(filepath.Join(pauseDir, snapshot.MemHeaderName))
if err != nil {
return nil, fmt.Errorf("read header: %w", err)
}
@@ -499,7 +551,7 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int)
}
// Build diff file map — supports both single-generation and multi-generation.
- diffPaths, err := snapshot.ListDiffFiles(snapDir, sandboxID, header)
+ diffPaths, err := snapshot.ListDiffFiles(pauseDir, "", header)
if err != nil {
return nil, fmt.Errorf("list diff files: %w", err)
}
@@ -510,7 +562,7 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int)
}
// Read rootfs metadata to find the base template image.
- meta, err := snapshot.ReadMeta(snapDir, sandboxID)
+ meta, err := snapshot.ReadMeta(pauseDir, "")
if err != nil {
source.Close()
return nil, fmt.Errorf("read rootfs meta: %w", err)
@@ -532,8 +584,8 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int)
}
// Move CoW file from snapshot dir to sandboxes dir for the running sandbox.
- savedCow := snapshot.CowPath(snapDir, sandboxID)
- cowPath := filepath.Join(m.cfg.SandboxesDir, fmt.Sprintf("%s.cow", sandboxID))
+ savedCow := snapshot.CowPath(pauseDir, "")
+ cowPath := filepath.Join(layout.SandboxesDir(m.cfg.WrennDir), fmt.Sprintf("%s.cow", sandboxID))
if err := os.Rename(savedCow, cowPath); err != nil {
source.Close()
m.loops.Release(baseImagePath)
@@ -579,7 +631,7 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int)
}
// Start UFFD server.
- uffdSocketPath := filepath.Join(m.cfg.SandboxesDir, fmt.Sprintf("%s-uffd.sock", sandboxID))
+ uffdSocketPath := filepath.Join(layout.SandboxesDir(m.cfg.WrennDir), fmt.Sprintf("%s-uffd.sock", sandboxID))
os.Remove(uffdSocketPath) // Clean stale socket.
uffdServer := uffd.NewServer(uffdSocketPath, source)
if err := uffdServer.Start(ctx); err != nil {
@@ -595,7 +647,7 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int)
// Restore VM from snapshot.
vmCfg := vm.VMConfig{
SandboxID: sandboxID,
- KernelPath: m.cfg.KernelPath,
+ KernelPath: layout.KernelPath(m.cfg.WrennDir),
RootfsPath: dmDev.DevicePath,
VCPUs: 1, // Placeholder; overridden by snapshot.
MemoryMB: int(header.Metadata.Size / (1024 * 1024)), // Placeholder; overridden by snapshot.
@@ -607,8 +659,8 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int)
NetMask: slot.GuestNetMask,
}
- snapPath := snapshot.SnapPath(snapDir, sandboxID)
- if _, err := m.vm.CreateFromSnapshot(ctx, vmCfg, snapPath, uffdSocketPath); err != nil {
+ resumeSnapPath := filepath.Join(pauseDir, snapshot.SnapFileName)
+ if _, err := m.vm.CreateFromSnapshot(ctx, vmCfg, resumeSnapPath, uffdSocketPath); err != nil {
warnErr("uffd server stop error", sandboxID, uffdServer.Stop())
source.Close()
warnErr("network cleanup error", sandboxID, network.RemoveNetwork(slot))
@@ -636,22 +688,11 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int)
return nil, fmt.Errorf("wait for envd: %w", err)
}
- // Sync guest clock in background. Non-fatal — sandbox is usable before this completes.
- // Run in a goroutine so Init latency doesn't block the RPC response back to the control plane.
- go func() {
- initCtx, initCancel := context.WithTimeout(context.Background(), 10*time.Second)
- defer initCancel()
- if err := client.Init(initCtx); err != nil {
- slog.Warn("envd init (clock sync) failed", "sandbox", sandboxID, "error", err)
- }
- }()
-
now := time.Now()
sb := &sandboxState{
Sandbox: models.Sandbox{
ID: sandboxID,
Status: models.StatusRunning,
- Template: "",
VCPUs: vmCfg.VCPUs,
MemoryMB: vmCfg.MemoryMB,
TimeoutSec: timeoutSec,
@@ -663,6 +704,7 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int)
},
slot: slot,
client: client,
+ connTracker: &ConnTracker{},
uffdSocketPath: uffdSocketPath,
dmDevice: dmDev,
baseImagePath: baseImagePath,
@@ -700,11 +742,7 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int)
// The rootfs is flattened (base + CoW merged) into a new standalone rootfs.ext4
// so the template has no dependency on the original base image. Memory state
// and VM snapshot files are copied as-is.
-func (m *Manager) CreateSnapshot(ctx context.Context, sandboxID, name string) (int64, error) {
- if err := validate.SafeName(name); err != nil {
- return 0, fmt.Errorf("invalid snapshot name: %w", err)
- }
-
+func (m *Manager) CreateSnapshot(ctx context.Context, sandboxID string, teamID, templateID pgtype.UUID) (int64, error) {
// If the sandbox is running, pause it first.
if _, err := m.get(sandboxID); err == nil {
if err := m.Pause(ctx, sandboxID); err != nil {
@@ -712,25 +750,26 @@ func (m *Manager) CreateSnapshot(ctx context.Context, sandboxID, name string) (i
}
}
- // At this point, pause snapshot files must exist in SnapshotsDir/{sandboxID}/.
- if !snapshot.Exists(m.cfg.SnapshotsDir, sandboxID) {
+ // At this point, pause snapshot files must exist.
+ pauseDir := layout.PauseSnapshotDir(m.cfg.WrennDir, sandboxID)
+ if _, err := os.Stat(pauseDir); err != nil {
return 0, fmt.Errorf("no snapshot found for sandbox %s", sandboxID)
}
// Create template directory.
- if err := snapshot.EnsureDir(m.cfg.ImagesDir, name); err != nil {
+ dstDir := layout.TemplateDir(m.cfg.WrennDir, teamID, templateID)
+ if err := os.MkdirAll(dstDir, 0755); err != nil {
return 0, fmt.Errorf("create template dir: %w", err)
}
// Copy VM snapshot file and memory header.
- srcDir := snapshot.DirPath(m.cfg.SnapshotsDir, sandboxID)
- dstDir := snapshot.DirPath(m.cfg.ImagesDir, name)
+ srcDir := pauseDir
for _, fname := range []string{snapshot.SnapFileName, snapshot.MemHeaderName} {
src := filepath.Join(srcDir, fname)
dst := filepath.Join(dstDir, fname)
if err := copyFile(src, dst); err != nil {
- warnErr("template dir cleanup error", name, snapshot.Remove(m.cfg.ImagesDir, name))
+ warnErr("template dir cleanup error", dstDir, os.RemoveAll(dstDir))
return 0, fmt.Errorf("copy %s: %w", fname, err)
}
}
@@ -738,59 +777,59 @@ func (m *Manager) CreateSnapshot(ctx context.Context, sandboxID, name string) (i
// Copy all memory diff files referenced by the header (supports multi-generation).
headerData, err := os.ReadFile(filepath.Join(srcDir, snapshot.MemHeaderName))
if err != nil {
- warnErr("template dir cleanup error", name, snapshot.Remove(m.cfg.ImagesDir, name))
+ warnErr("template dir cleanup error", dstDir, os.RemoveAll(dstDir))
return 0, fmt.Errorf("read header for template: %w", err)
}
srcHeader, err := snapshot.Deserialize(headerData)
if err != nil {
- warnErr("template dir cleanup error", name, snapshot.Remove(m.cfg.ImagesDir, name))
+ warnErr("template dir cleanup error", dstDir, os.RemoveAll(dstDir))
return 0, fmt.Errorf("deserialize header for template: %w", err)
}
- srcDiffPaths, err := snapshot.ListDiffFiles(m.cfg.SnapshotsDir, sandboxID, srcHeader)
+ srcDiffPaths, err := snapshot.ListDiffFiles(pauseDir, "", srcHeader)
if err != nil {
- warnErr("template dir cleanup error", name, snapshot.Remove(m.cfg.ImagesDir, name))
+ warnErr("template dir cleanup error", dstDir, os.RemoveAll(dstDir))
return 0, fmt.Errorf("list diff files for template: %w", err)
}
for _, srcPath := range srcDiffPaths {
dstPath := filepath.Join(dstDir, filepath.Base(srcPath))
if err := copyFile(srcPath, dstPath); err != nil {
- warnErr("template dir cleanup error", name, snapshot.Remove(m.cfg.ImagesDir, name))
+ warnErr("template dir cleanup error", dstDir, os.RemoveAll(dstDir))
return 0, fmt.Errorf("copy diff file %s: %w", filepath.Base(srcPath), err)
}
}
// Flatten rootfs: temporarily set up dm device from base + CoW, dd to new image.
- meta, err := snapshot.ReadMeta(m.cfg.SnapshotsDir, sandboxID)
+ meta, err := snapshot.ReadMeta(pauseDir, "")
if err != nil {
- warnErr("template dir cleanup error", name, snapshot.Remove(m.cfg.ImagesDir, name))
+ warnErr("template dir cleanup error", dstDir, os.RemoveAll(dstDir))
return 0, fmt.Errorf("read rootfs meta: %w", err)
}
originLoop, err := m.loops.Acquire(meta.BaseTemplate)
if err != nil {
- warnErr("template dir cleanup error", name, snapshot.Remove(m.cfg.ImagesDir, name))
+ warnErr("template dir cleanup error", dstDir, os.RemoveAll(dstDir))
return 0, fmt.Errorf("acquire loop device for flatten: %w", err)
}
originSize, err := devicemapper.OriginSizeBytes(originLoop)
if err != nil {
m.loops.Release(meta.BaseTemplate)
- warnErr("template dir cleanup error", name, snapshot.Remove(m.cfg.ImagesDir, name))
+ warnErr("template dir cleanup error", dstDir, os.RemoveAll(dstDir))
return 0, fmt.Errorf("get origin size: %w", err)
}
// Temporarily restore the dm-snapshot to read the merged view.
- cowPath := snapshot.CowPath(m.cfg.SnapshotsDir, sandboxID)
+ cowPath := snapshot.CowPath(pauseDir, "")
tmpDmName := "wrenn-flatten-" + sandboxID
tmpDev, err := devicemapper.RestoreSnapshot(ctx, tmpDmName, originLoop, cowPath, originSize)
if err != nil {
m.loops.Release(meta.BaseTemplate)
- warnErr("template dir cleanup error", name, snapshot.Remove(m.cfg.ImagesDir, name))
+ warnErr("template dir cleanup error", dstDir, os.RemoveAll(dstDir))
return 0, fmt.Errorf("restore dm-snapshot for flatten: %w", err)
}
// Flatten to new standalone rootfs.
- flattenedPath := snapshot.RootfsPath(m.cfg.ImagesDir, name)
+ flattenedPath := filepath.Join(dstDir, snapshot.RootfsFileName)
flattenErr := devicemapper.FlattenSnapshot(tmpDev.DevicePath, flattenedPath)
// Always clean up the temporary dm device.
@@ -798,40 +837,131 @@ func (m *Manager) CreateSnapshot(ctx context.Context, sandboxID, name string) (i
m.loops.Release(meta.BaseTemplate)
if flattenErr != nil {
- warnErr("template dir cleanup error", name, snapshot.Remove(m.cfg.ImagesDir, name))
+ warnErr("template dir cleanup error", dstDir, os.RemoveAll(dstDir))
return 0, fmt.Errorf("flatten rootfs: %w", flattenErr)
}
- sizeBytes, err := snapshot.DirSize(m.cfg.ImagesDir, name)
+ sizeBytes, err := snapshot.DirSize(dstDir, "")
if err != nil {
slog.Warn("failed to calculate snapshot size", "error", err)
}
slog.Info("template snapshot created (rootfs flattened)",
"sandbox", sandboxID,
- "name", name,
+ "team_id", teamID,
+ "template_id", templateID,
"size_bytes", sizeBytes,
)
return sizeBytes, nil
}
-// DeleteSnapshot removes a snapshot template from disk.
-func (m *Manager) DeleteSnapshot(name string) error {
- if err := validate.SafeName(name); err != nil {
- return fmt.Errorf("invalid snapshot name: %w", err)
+// FlattenRootfs stops a running sandbox, flattens its device-mapper CoW
+// rootfs into a standalone rootfs.ext4, and cleans up all resources.
+// The result is an image-only template (no VM memory/CPU state) stored in
+// the template directory for (teamID, templateID) under WrennDir.
+func (m *Manager) FlattenRootfs(ctx context.Context, sandboxID string, teamID, templateID pgtype.UUID) (int64, error) {
+ m.mu.Lock()
+ sb, ok := m.boxes[sandboxID]
+ if ok {
+ delete(m.boxes, sandboxID)
}
- return snapshot.Remove(m.cfg.ImagesDir, name)
+ m.mu.Unlock()
+
+ if !ok {
+ return 0, fmt.Errorf("sandbox %s not found", sandboxID)
+ }
+
+ // Stop the VM but keep the dm device alive for flattening.
+ m.stopSampler(sb)
+ if err := m.vm.Destroy(ctx, sb.ID); err != nil {
+ slog.Warn("vm destroy error during flatten", "id", sb.ID, "error", err)
+ }
+
+ // Release network resources — not needed after VM is stopped.
+ if err := network.RemoveNetwork(sb.slot); err != nil {
+ slog.Warn("network cleanup error during flatten", "id", sb.ID, "error", err)
+ }
+ m.slots.Release(sb.SlotIndex)
+
+ if sb.uffdSocketPath != "" {
+ os.Remove(sb.uffdSocketPath)
+ }
+
+ // Create template directory and flatten the dm-snapshot.
+ flattenDstDir := layout.TemplateDir(m.cfg.WrennDir, teamID, templateID)
+ if err := os.MkdirAll(flattenDstDir, 0755); err != nil {
+ m.cleanupDM(sb)
+ return 0, fmt.Errorf("create template dir: %w", err)
+ }
+
+ outputPath := filepath.Join(flattenDstDir, snapshot.RootfsFileName)
+ if sb.dmDevice == nil {
+ m.cleanupDM(sb)
+ warnErr("template dir cleanup error", flattenDstDir, os.RemoveAll(flattenDstDir))
+ return 0, fmt.Errorf("sandbox %s has no dm device", sandboxID)
+ }
+
+ if err := devicemapper.FlattenSnapshot(sb.dmDevice.DevicePath, outputPath); err != nil {
+ m.cleanupDM(sb)
+ warnErr("template dir cleanup error", flattenDstDir, os.RemoveAll(flattenDstDir))
+ return 0, fmt.Errorf("flatten rootfs: %w", err)
+ }
+
+ // Clean up dm device and loop device now that flatten is complete.
+ m.cleanupDM(sb)
+
+ // Shrink the flattened image to its minimum size so stored templates are
+ // compact. EnsureImageSizes will re-expand them on the next agent startup.
+ if out, err := exec.Command("e2fsck", "-fy", outputPath).CombinedOutput(); err != nil {
+ if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() > 1 {
+ slog.Warn("e2fsck before shrink failed (non-fatal)", "output", string(out), "error", err)
+ }
+ }
+ if out, err := exec.Command("resize2fs", "-M", outputPath).CombinedOutput(); err != nil {
+ slog.Warn("resize2fs -M failed (non-fatal)", "output", string(out), "error", err)
+ }
+
+ sizeBytes, err := snapshot.DirSize(flattenDstDir, "")
+ if err != nil {
+ slog.Warn("failed to calculate template size", "error", err)
+ }
+
+ slog.Info("rootfs flattened to image-only template",
+ "sandbox", sandboxID,
+ "team_id", teamID,
+ "template_id", templateID,
+ "size_bytes", sizeBytes,
+ )
+ return sizeBytes, nil
+}
+
+// cleanupDM tears down the dm-snapshot device and releases the base image loop device.
+func (m *Manager) cleanupDM(sb *sandboxState) {
+ if sb.dmDevice != nil {
+ if err := devicemapper.RemoveSnapshot(context.Background(), sb.dmDevice); err != nil {
+ slog.Warn("dm-snapshot remove error", "id", sb.ID, "error", err)
+ }
+ os.Remove(sb.dmDevice.CowPath)
+ }
+ if sb.baseImagePath != "" {
+ m.loops.Release(sb.baseImagePath)
+ }
+}
+
+// DeleteSnapshot removes a snapshot template from disk.
+func (m *Manager) DeleteSnapshot(teamID, templateID pgtype.UUID) error {
+ return os.RemoveAll(layout.TemplateDir(m.cfg.WrennDir, teamID, templateID))
}
// createFromSnapshot creates a new sandbox by restoring from a snapshot template
// in ImagesDir/{snapshotName}/. Uses UFFD for lazy memory loading.
// The template's rootfs.ext4 is a flattened standalone image — we create a
// dm-snapshot on top of it just like a normal Create.
-func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID, snapshotName string, vcpus, _, timeoutSec int) (*models.Sandbox, error) {
- imagesDir := m.cfg.ImagesDir
+func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID string, teamID, templateID pgtype.UUID, vcpus, _, timeoutSec, diskSizeMB int) (*models.Sandbox, error) {
+ tmplDir := layout.TemplateDir(m.cfg.WrennDir, teamID, templateID)
// Read the header.
- headerData, err := os.ReadFile(snapshot.MemHeaderPath(imagesDir, snapshotName))
+ headerData, err := os.ReadFile(filepath.Join(tmplDir, snapshot.MemHeaderName))
if err != nil {
return nil, fmt.Errorf("read snapshot header: %w", err)
}
@@ -845,7 +975,7 @@ func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID, snapshotNam
memoryMB := int(header.Metadata.Size / (1024 * 1024))
// Build diff file map — supports multi-generation templates.
- diffPaths, err := snapshot.ListDiffFiles(imagesDir, snapshotName, header)
+ diffPaths, err := snapshot.ListDiffFiles(tmplDir, "", header)
if err != nil {
return nil, fmt.Errorf("list diff files: %w", err)
}
@@ -856,7 +986,7 @@ func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID, snapshotNam
}
// Set up dm-snapshot on the template's flattened rootfs.
- baseRootfs := snapshot.RootfsPath(imagesDir, snapshotName)
+ baseRootfs := filepath.Join(tmplDir, snapshot.RootfsFileName)
originLoop, err := m.loops.Acquire(baseRootfs)
if err != nil {
source.Close()
@@ -871,8 +1001,9 @@ func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID, snapshotNam
}
dmName := "wrenn-" + sandboxID
- cowPath := filepath.Join(m.cfg.SandboxesDir, fmt.Sprintf("%s.cow", sandboxID))
- dmDev, err := devicemapper.CreateSnapshot(dmName, originLoop, cowPath, originSize)
+ cowPath := filepath.Join(layout.SandboxesDir(m.cfg.WrennDir), fmt.Sprintf("%s.cow", sandboxID))
+ cowSize := int64(diskSizeMB) * 1024 * 1024
+ dmDev, err := devicemapper.CreateSnapshot(dmName, originLoop, cowPath, originSize, cowSize)
if err != nil {
source.Close()
m.loops.Release(baseRootfs)
@@ -900,7 +1031,7 @@ func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID, snapshotNam
}
// Start UFFD server.
- uffdSocketPath := filepath.Join(m.cfg.SandboxesDir, fmt.Sprintf("%s-uffd.sock", sandboxID))
+ uffdSocketPath := filepath.Join(layout.SandboxesDir(m.cfg.WrennDir), fmt.Sprintf("%s-uffd.sock", sandboxID))
os.Remove(uffdSocketPath)
uffdServer := uffd.NewServer(uffdSocketPath, source)
if err := uffdServer.Start(ctx); err != nil {
@@ -916,7 +1047,8 @@ func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID, snapshotNam
// Restore VM.
vmCfg := vm.VMConfig{
SandboxID: sandboxID,
- KernelPath: m.cfg.KernelPath,
+ TemplateID: id.UUIDString(templateID),
+ KernelPath: layout.KernelPath(m.cfg.WrennDir),
RootfsPath: dmDev.DevicePath,
VCPUs: vcpus,
MemoryMB: memoryMB,
@@ -928,7 +1060,7 @@ func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID, snapshotNam
NetMask: slot.GuestNetMask,
}
- snapPath := snapshot.SnapPath(imagesDir, snapshotName)
+ snapPath := filepath.Join(tmplDir, snapshot.SnapFileName)
if _, err := m.vm.CreateFromSnapshot(ctx, vmCfg, snapPath, uffdSocketPath); err != nil {
warnErr("uffd server stop error", sandboxID, uffdServer.Stop())
source.Close()
@@ -957,33 +1089,25 @@ func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID, snapshotNam
return nil, fmt.Errorf("wait for envd: %w", err)
}
- // Sync guest clock in background. Non-fatal — sandbox is usable before this completes.
- // Run in a goroutine so Init latency doesn't block the RPC response back to the control plane.
- go func() {
- initCtx, initCancel := context.WithTimeout(context.Background(), 10*time.Second)
- defer initCancel()
- if err := client.Init(initCtx); err != nil {
- slog.Warn("envd init (clock sync) failed", "sandbox", sandboxID, "error", err)
- }
- }()
-
now := time.Now()
sb := &sandboxState{
Sandbox: models.Sandbox{
- ID: sandboxID,
- Status: models.StatusRunning,
- Template: snapshotName,
- VCPUs: vcpus,
- MemoryMB: memoryMB,
- TimeoutSec: timeoutSec,
- SlotIndex: slotIdx,
- HostIP: slot.HostIP,
- RootfsPath: dmDev.DevicePath,
- CreatedAt: now,
- LastActiveAt: now,
+ ID: sandboxID,
+ Status: models.StatusRunning,
+ TemplateTeamID: teamID.Bytes,
+ TemplateID: templateID.Bytes,
+ VCPUs: vcpus,
+ MemoryMB: memoryMB,
+ TimeoutSec: timeoutSec,
+ SlotIndex: slotIdx,
+ HostIP: slot.HostIP,
+ RootfsPath: dmDev.DevicePath,
+ CreatedAt: now,
+ LastActiveAt: now,
},
slot: slot,
client: client,
+ connTracker: &ConnTracker{},
uffdSocketPath: uffdSocketPath,
dmDevice: dmDev,
baseImagePath: baseRootfs,
@@ -1002,7 +1126,8 @@ func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID, snapshotNam
slog.Info("sandbox created from snapshot",
"id", sandboxID,
- "snapshot", snapshotName,
+ "team_id", teamID,
+ "template_id", templateID,
"host_ip", slot.HostIP.String(),
"dm_device", dmDev.DevicePath,
)
@@ -1079,6 +1204,25 @@ func (m *Manager) GetClient(sandboxID string) (*envdclient.Client, error) {
return sb.client, nil
}
+// AcquireProxyConn looks up a sandbox by ID and registers an in-flight proxy
+// connection against it. On success it returns the sandbox's HostIP, its
+// connection tracker, and true; the caller must call tracker.Release() when
+// the request completes. It returns (nil, nil, false) when the sandbox is not
+// in m.boxes, its Status is not models.StatusRunning, or
+// connTracker.Acquire() reports false.
+//
+// Every sandbox field is copied while m.mu is read-locked so that Status is
+// never read concurrently with a writer holding m.mu. Acquire itself runs
+// after the lock is released, so a slow or blocking tracker cannot stall
+// other Manager operations.
+func (m *Manager) AcquireProxyConn(sandboxID string) (net.IP, *ConnTracker, bool) {
+	var (
+		running bool
+		hostIP  net.IP
+		tracker *ConnTracker
+	)
+
+	m.mu.RLock()
+	if sb, ok := m.boxes[sandboxID]; ok {
+		running = sb.Status == models.StatusRunning
+		hostIP = sb.HostIP
+		tracker = sb.connTracker
+	}
+	m.mu.RUnlock()
+
+	// running is false both for an unknown ID and for a non-running sandbox.
+	if !running {
+		return nil, nil, false
+	}
+	if !tracker.Acquire() {
+		return nil, nil, false
+	}
+	return hostIP, tracker, true
+}
+
// Ping resets the inactivity timer for a running sandbox.
func (m *Manager) Ping(sandboxID string) error {
m.mu.Lock()
@@ -1218,6 +1362,23 @@ func (m *Manager) PauseAll(ctx context.Context) {
}
}
+// removeStaleMemDiffs removes memfile.{uuid} diff files from a snapshot
+// directory. Called before writing a Full snapshot to prevent orphaned diffs
+// from accumulating across generation resets.
+//
+// Every entry whose name starts with "memfile." is removed, except
+// snapshot.MemHeaderName and "memfile.raw". The sweep is best-effort: it
+// never returns an error, but failures other than "does not exist" are
+// logged so a directory that cannot be cleaned is visible to operators.
+func removeStaleMemDiffs(dir string) {
+	entries, err := os.ReadDir(dir)
+	if err != nil {
+		if !os.IsNotExist(err) {
+			slog.Warn("failed to list snapshot dir for stale mem diffs", "dir", dir, "error", err)
+		}
+		return
+	}
+	for _, e := range entries {
+		name := e.Name()
+		// Keep anything that is not a diff: other files, the header, and the raw memfile.
+		if !strings.HasPrefix(name, "memfile.") || name == snapshot.MemHeaderName || name == "memfile.raw" {
+			continue
+		}
+		path := filepath.Join(dir, name)
+		if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
+			slog.Warn("failed to remove stale mem diff", "path", path, "error", err)
+		}
+	}
+}
+
// warnErr logs a warning if err is non-nil. Used for best-effort cleanup
// in error paths where the primary error has already been captured.
func warnErr(msg string, id string, err error) {
diff --git a/internal/scheduler/round_robin.go b/internal/scheduler/round_robin.go
index 31433a0b..c2ab0f40 100644
--- a/internal/scheduler/round_robin.go
+++ b/internal/scheduler/round_robin.go
@@ -5,6 +5,8 @@ import (
"fmt"
"sync/atomic"
+ "github.com/jackc/pgx/v5/pgtype"
+
"git.omukk.dev/wrenn/sandbox/internal/db"
)
@@ -15,7 +17,7 @@ type HostScheduler interface {
// For BYOC teams (isByoc=true), only online BYOC hosts belonging to teamID
// are considered. For non-BYOC teams, only online regular (platform) hosts
// are considered. Returns an error if no suitable host is available.
- SelectHost(ctx context.Context, teamID string, isByoc bool) (db.Host, error)
+ SelectHost(ctx context.Context, teamID pgtype.UUID, isByoc bool) (db.Host, error)
}
// RoundRobinScheduler cycles through eligible online hosts in round-robin order.
@@ -32,7 +34,7 @@ func NewRoundRobinScheduler(queries *db.Queries) *RoundRobinScheduler {
}
// SelectHost returns the next eligible online host in round-robin order.
-func (s *RoundRobinScheduler) SelectHost(ctx context.Context, teamID string, isByoc bool) (db.Host, error) {
+func (s *RoundRobinScheduler) SelectHost(ctx context.Context, teamID pgtype.UUID, isByoc bool) (db.Host, error) {
hosts, err := s.db.ListActiveHosts(ctx)
if err != nil {
return db.Host{}, fmt.Errorf("list hosts: %w", err)
@@ -40,12 +42,12 @@ func (s *RoundRobinScheduler) SelectHost(ctx context.Context, teamID string, isB
var eligible []db.Host
for _, h := range hosts {
- if h.Status != "online" || !h.Address.Valid || h.Address.String == "" {
+ if h.Status != "online" || h.Address == "" {
continue
}
if isByoc {
// BYOC team: only use hosts belonging to this team.
- if h.Type != "byoc" || !h.TeamID.Valid || h.TeamID.String != teamID {
+ if h.Type != "byoc" || !h.TeamID.Valid || h.TeamID != teamID {
continue
}
} else {
diff --git a/internal/service/apikey.go b/internal/service/apikey.go
index c49ddcaa..7a2b073a 100644
--- a/internal/service/apikey.go
+++ b/internal/service/apikey.go
@@ -4,6 +4,8 @@ import (
"context"
"fmt"
+ "github.com/jackc/pgx/v5/pgtype"
+
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
@@ -22,7 +24,7 @@ type APIKeyCreateResult struct {
}
// Create generates a new API key for the given team.
-func (s *APIKeyService) Create(ctx context.Context, teamID, userID, name string) (APIKeyCreateResult, error) {
+func (s *APIKeyService) Create(ctx context.Context, teamID, userID pgtype.UUID, name string) (APIKeyCreateResult, error) {
if name == "" {
name = "Unnamed API Key"
}
@@ -48,16 +50,16 @@ func (s *APIKeyService) Create(ctx context.Context, teamID, userID, name string)
}
// List returns all API keys belonging to the given team.
-func (s *APIKeyService) List(ctx context.Context, teamID string) ([]db.TeamApiKey, error) {
+func (s *APIKeyService) List(ctx context.Context, teamID pgtype.UUID) ([]db.TeamApiKey, error) {
return s.DB.ListAPIKeysByTeam(ctx, teamID)
}
// ListWithCreator returns all API keys for the team, joined with the creator's email.
-func (s *APIKeyService) ListWithCreator(ctx context.Context, teamID string) ([]db.ListAPIKeysByTeamWithCreatorRow, error) {
+func (s *APIKeyService) ListWithCreator(ctx context.Context, teamID pgtype.UUID) ([]db.ListAPIKeysByTeamWithCreatorRow, error) {
return s.DB.ListAPIKeysByTeamWithCreator(ctx, teamID)
}
// Delete removes an API key by ID, scoped to the given team.
-func (s *APIKeyService) Delete(ctx context.Context, keyID, teamID string) error {
+func (s *APIKeyService) Delete(ctx context.Context, keyID, teamID pgtype.UUID) error {
return s.DB.DeleteAPIKey(ctx, db.DeleteAPIKeyParams{ID: keyID, TeamID: teamID})
}
diff --git a/internal/service/audit.go b/internal/service/audit.go
index 53061421..67faafa2 100644
--- a/internal/service/audit.go
+++ b/internal/service/audit.go
@@ -9,6 +9,7 @@ import (
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/db"
+ "git.omukk.dev/wrenn/sandbox/internal/id"
)
const auditMaxLimit = 200
@@ -31,13 +32,13 @@ type AuditEntry struct {
// AuditListParams controls the ListAuditLogs query.
type AuditListParams struct {
- TeamID string
- AdminScoped bool // true → include admin-scoped events; false → team-scoped only
- ResourceTypes []string // empty = no filter; multiple values = OR match
- Actions []string // empty = no filter; multiple values = OR match
- Before time.Time // zero = no cursor (start from latest)
- BeforeID string // tie-breaker: id of the last item at the Before timestamp; empty = no tie-break
- Limit int // clamped to auditMaxLimit by the handler
+ TeamID pgtype.UUID
+ AdminScoped bool // true → include admin-scoped events; false → team-scoped only
+ ResourceTypes []string // empty = no filter; multiple values = OR match
+ Actions []string // empty = no filter; multiple values = OR match
+ Before time.Time // zero = no cursor (start from latest)
+ BeforeID pgtype.UUID // tie-breaker: id of the last item at the Before timestamp; zero = no tie-break
+ Limit int // clamped to auditMaxLimit by the handler
}
// AuditService provides the read side of the audit log.
@@ -94,11 +95,11 @@ func (s *AuditService) List(ctx context.Context, p AuditListParams) ([]AuditEntr
_ = json.Unmarshal(row.Metadata, &meta)
}
entries[i] = AuditEntry{
- ID: row.ID,
- TeamID: row.TeamID,
+ ID: id.FormatAuditLogID(row.ID),
+ TeamID: id.FormatTeamID(row.TeamID),
ActorType: row.ActorType,
ActorID: row.ActorID.String,
- ActorName: row.ActorName.String,
+ ActorName: row.ActorName,
ResourceType: row.ResourceType,
ResourceID: row.ResourceID.String,
Action: row.Action,
diff --git a/internal/service/build.go b/internal/service/build.go
new file mode 100644
index 00000000..1108044d
--- /dev/null
+++ b/internal/service/build.go
@@ -0,0 +1,519 @@
+package service
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "log/slog"
+ "sync"
+ "time"
+
+ "connectrpc.com/connect"
+ "github.com/jackc/pgx/v5/pgtype"
+ "github.com/redis/go-redis/v9"
+
+ "git.omukk.dev/wrenn/sandbox/internal/db"
+ "git.omukk.dev/wrenn/sandbox/internal/id"
+ "git.omukk.dev/wrenn/sandbox/internal/lifecycle"
+ "git.omukk.dev/wrenn/sandbox/internal/recipe"
+ "git.omukk.dev/wrenn/sandbox/internal/scheduler"
+ pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
+)
+
+const (
+ // buildQueueKey is the Redis list that Create pushes build IDs onto (RPUSH)
+ // and that workers pop from (BLPOP).
+ buildQueueKey = "wrenn:build_queue"
+ // buildCommandTimeout is passed to recipe.Execute as the default timeout for
+ // the user "recipe" phase; the pre-build and post-build phases pass 0 instead.
+ buildCommandTimeout = 30 * time.Second
+ // healthcheckInterval is the tick period between healthcheck attempts.
+ healthcheckInterval = 1 * time.Second
+ // healthcheckTimeout bounds the whole healthcheck wait; once it elapses the
+ // healthcheck is reported as timed out.
+ healthcheckTimeout = 60 * time.Second
+)
+
+// preBuildCmds run before the user recipe to prepare the build environment.
+// They are recipe lines parsed with recipe.ParseRecipe, counted in the build's
+// TotalSteps, and skipped entirely when the build has SkipPrePost set.
+var preBuildCmds = []string{
+ "RUN apt update",
+}
+
+// postBuildCmds run after the user recipe to clean up caches and reduce image size.
+// Like preBuildCmds they are counted in TotalSteps and skipped when the build
+// has SkipPrePost set.
+var postBuildCmds = []string{
+ "RUN apt clean",
+ "RUN apt autoremove -y",
+ "RUN rm -rf /var/lib/apt/lists/*",
+}
+
+// buildAgentClient is the subset of the host agent client used by the build worker.
+// Keeping it as a narrow interface lets the build helpers (waitForHealthcheck,
+// destroySandbox) accept any client that offers these five RPCs.
+type buildAgentClient interface {
+ // CreateSandbox starts the sandbox in which the recipe is executed.
+ CreateSandbox(ctx context.Context, req *connect.Request[pb.CreateSandboxRequest]) (*connect.Response[pb.CreateSandboxResponse], error)
+ // DestroySandbox tears the build sandbox down on failure paths.
+ DestroySandbox(ctx context.Context, req *connect.Request[pb.DestroySandboxRequest]) (*connect.Response[pb.DestroySandboxResponse], error)
+ // Exec runs a command in the sandbox; used for recipe steps and healthcheck probes.
+ Exec(ctx context.Context, req *connect.Request[pb.ExecRequest]) (*connect.Response[pb.ExecResponse], error)
+ // CreateSnapshot is called after a healthcheck passes; its response carries SizeBytes.
+ CreateSnapshot(ctx context.Context, req *connect.Request[pb.CreateSnapshotRequest]) (*connect.Response[pb.CreateSnapshotResponse], error)
+ // FlattenRootfs is called when the build has no healthcheck; its response carries SizeBytes.
+ FlattenRootfs(ctx context.Context, req *connect.Request[pb.FlattenRootfsRequest]) (*connect.Response[pb.FlattenRootfsResponse], error)
+}
+
+// BuildService handles template build orchestration: it stores build records,
+// queues build IDs in Redis, and runs worker goroutines that execute queued
+// builds against a host agent.
+type BuildService struct {
+ // DB persists build records, progress, and the resulting template rows.
+ DB *db.Queries
+ // Redis holds the build queue (see buildQueueKey).
+ Redis *redis.Client
+ // Pool yields the agent client for the host chosen by Scheduler.
+ Pool *lifecycle.HostClientPool
+ // Scheduler picks the host on which a build sandbox is created.
+ Scheduler scheduler.HostScheduler
+
+ // mu guards cancelMap.
+ mu sync.Mutex
+ // cancelMap is created lazily by executeBuild; an entry exists only while
+ // that build is being executed and is read by Cancel.
+ cancelMap map[string]context.CancelFunc // buildID → per-build cancel func
+}
+
+// BuildCreateParams holds the parameters for creating a template build.
+type BuildCreateParams struct {
+ // Name is stored on the build and later used as the template name.
+ Name string
+ // BaseTemplate names the template to build on; Create defaults an empty value to "minimal".
+ BaseTemplate string
+ // Recipe is the list of user recipe lines; it is stored as JSON on the build row.
+ Recipe []string
+ // Healthcheck is an optional shell command. When non-empty the build polls it
+ // and then takes a snapshot; when empty the build flattens the rootfs instead.
+ Healthcheck string
+ // VCPUs for the build sandbox; Create replaces values <= 0 with 1.
+ VCPUs int32
+ // MemoryMB for the build sandbox; Create replaces values <= 0 with 512.
+ MemoryMB int32
+ // SkipPrePost disables the built-in preBuildCmds and postBuildCmds phases.
+ SkipPrePost bool
+}
+
+// Create inserts a new build record and enqueues it to Redis.
+//
+// Defaults are applied first: an empty BaseTemplate becomes "minimal",
+// VCPUs <= 0 becomes 1 and MemoryMB <= 0 becomes 512. TotalSteps is the number
+// of user recipe lines plus the built-in pre/post commands (omitted when
+// SkipPrePost is set). The build is always owned by id.PlatformTeamID and is
+// given a freshly generated template ID.
+//
+// Returns the inserted row, or a wrapped error when the recipe cannot be
+// marshalled, the insert fails, or the enqueue fails. When the enqueue fails
+// the row has already been written, so the failure is also recorded on the
+// build via failBuild; otherwise the build would sit in the table with no
+// queue entry and no worker would ever pick it up.
+func (s *BuildService) Create(ctx context.Context, p BuildCreateParams) (db.TemplateBuild, error) {
+	if p.BaseTemplate == "" {
+		p.BaseTemplate = "minimal"
+	}
+	if p.VCPUs <= 0 {
+		p.VCPUs = 1
+	}
+	if p.MemoryMB <= 0 {
+		p.MemoryMB = 512
+	}
+
+	recipeJSON, err := json.Marshal(p.Recipe)
+	if err != nil {
+		return db.TemplateBuild{}, fmt.Errorf("marshal recipe: %w", err)
+	}
+
+	buildID := id.NewBuildID()
+	buildIDStr := id.FormatBuildID(buildID)
+	newTemplateID := id.NewTemplateID()
+
+	defaultSteps := len(preBuildCmds) + len(postBuildCmds)
+	if p.SkipPrePost {
+		defaultSteps = 0
+	}
+
+	build, err := s.DB.InsertTemplateBuild(ctx, db.InsertTemplateBuildParams{
+		ID:           buildID,
+		Name:         p.Name,
+		BaseTemplate: p.BaseTemplate,
+		Recipe:       recipeJSON,
+		Healthcheck:  p.Healthcheck,
+		Vcpus:        p.VCPUs,
+		MemoryMb:     p.MemoryMB,
+		TotalSteps:   int32(len(p.Recipe) + defaultSteps),
+		TemplateID:   newTemplateID,
+		TeamID:       id.PlatformTeamID,
+		SkipPrePost:  p.SkipPrePost,
+	})
+	if err != nil {
+		return db.TemplateBuild{}, fmt.Errorf("insert build: %w", err)
+	}
+
+	// Enqueue build ID (as formatted string) to Redis for workers to pick up.
+	if err := s.Redis.RPush(ctx, buildQueueKey, buildIDStr).Err(); err != nil {
+		// The row exists but is unreachable by any worker; record why so it
+		// does not linger without a terminal state. failBuild uses its own
+		// detached context, so this also works when ctx is what failed.
+		s.failBuild(ctx, buildID, fmt.Sprintf("enqueue build: %v", err))
+		return db.TemplateBuild{}, fmt.Errorf("enqueue build: %w", err)
+	}
+
+	return build, nil
+}
+
+// Get looks up one template build by its ID and returns the row together
+// with whatever error the query produced.
+func (s *BuildService) Get(ctx context.Context, buildID pgtype.UUID) (db.TemplateBuild, error) {
+	row, err := s.DB.GetTemplateBuild(ctx, buildID)
+	return row, err
+}
+
+// List returns every template build as produced by the ListTemplateBuilds
+// query, along with that query's error.
+func (s *BuildService) List(ctx context.Context) ([]db.TemplateBuild, error) {
+	rows, err := s.DB.ListTemplateBuilds(ctx)
+	return rows, err
+}
+
+// Cancel stops a build that has not yet reached a terminal state.
+//
+// A build whose status is "success", "failed" or "cancelled" is rejected with
+// an error. Otherwise the status is first written as "cancelled" — a worker
+// that dequeues the build later sees that and skips it — and then, if a worker
+// is executing the build right now, its per-build context is cancelled so the
+// step in progress aborts.
+func (s *BuildService) Cancel(ctx context.Context, buildID pgtype.UUID) error {
+	current, err := s.DB.GetTemplateBuild(ctx, buildID)
+	if err != nil {
+		return fmt.Errorf("get build: %w", err)
+	}
+	if st := current.Status; st == "success" || st == "failed" || st == "cancelled" {
+		return fmt.Errorf("build is already %s", st)
+	}
+
+	// Persist the cancellation before signalling: this covers builds still
+	// waiting in the queue and is the flag executeBuild checks on start.
+	params := db.UpdateBuildStatusParams{ID: buildID, Status: "cancelled"}
+	if _, err = s.DB.UpdateBuildStatus(ctx, params); err != nil {
+		return fmt.Errorf("update build status: %w", err)
+	}
+
+	// Signal the worker, if one has registered a cancel func for this build.
+	key := id.FormatBuildID(buildID)
+	s.mu.Lock()
+	stop, found := s.cancelMap[key]
+	s.mu.Unlock()
+	if !found {
+		return nil
+	}
+	stop()
+	return nil
+}
+
+// StartWorkers spawns n worker goroutines, each consuming the Redis build
+// queue under a context derived from ctx. Calling the returned function
+// cancels that derived context, which stops all of them.
+func (s *BuildService) StartWorkers(ctx context.Context, n int) context.CancelFunc {
+	workerCtx, stop := context.WithCancel(ctx)
+	for idx := 0; idx < n; idx++ {
+		go s.worker(workerCtx, idx)
+	}
+	slog.Info("build workers started", "count", n)
+	return stop
+}
+
+// worker is the loop run by each build worker goroutine. It pops build IDs
+// from the Redis build queue and executes them one at a time until ctx is
+// cancelled.
+//
+// The pop uses a bounded block time instead of blocking forever, so shutdown
+// is noticed within pollTimeout even on an idle queue and does not rely on the
+// Redis client interrupting a blocked read when ctx is cancelled. A timed-out
+// pop surfaces as redis.Nil and simply polls again. If an ID is popped while
+// ctx is already cancelled it is pushed back to the head of the queue so the
+// build is not lost.
+func (s *BuildService) worker(ctx context.Context, workerID int) {
+	const pollTimeout = 5 * time.Second
+
+	log := slog.With("worker", workerID)
+	for {
+		if ctx.Err() != nil {
+			log.Info("build worker shutting down")
+			return
+		}
+
+		result, err := s.Redis.BLPop(ctx, pollTimeout, buildQueueKey).Result()
+		if err != nil {
+			if ctx.Err() != nil {
+				log.Info("build worker shutting down")
+				return
+			}
+			if err == redis.Nil {
+				// Block time elapsed with nothing queued.
+				continue
+			}
+			log.Error("redis BLPOP error", "error", err)
+			// Back off before retrying, but wake immediately on shutdown.
+			select {
+			case <-ctx.Done():
+			case <-time.After(time.Second):
+			}
+			continue
+		}
+		// result[0] is the key, result[1] is the build ID (formatted string).
+		if len(result) < 2 {
+			log.Error("unexpected BLPOP reply", "len", len(result))
+			continue
+		}
+		buildIDStr := result[1]
+
+		if ctx.Err() != nil {
+			// Popped during shutdown: hand the build back rather than run it
+			// under a cancelled context, which would drop it.
+			requeueCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+			if err := s.Redis.LPush(requeueCtx, buildQueueKey, buildIDStr).Err(); err != nil {
+				log.Error("failed to requeue build during shutdown", "build_id", buildIDStr, "error", err)
+			}
+			cancel()
+			log.Info("build worker shutting down")
+			return
+		}
+
+		log.Info("picked up build", "build_id", buildIDStr)
+		s.executeBuild(ctx, buildIDStr)
+	}
+}
+
+// executeBuild runs one queued template build from start to finish.
+//
+// Steps, in order: parse the build ID, register a per-build cancel func in
+// cancelMap, load the build row (skipping it if already "cancelled"), mark it
+// "running", parse the recipe, pick a platform host, create a sandbox, run the
+// pre-build / recipe / post-build phases, then either poll the healthcheck and
+// take a snapshot or flatten the rootfs, insert the template row, and mark the
+// build "success".
+//
+// ctx is the worker's context; cancelling it means the worker is shutting
+// down. The derived per-build context is additionally cancelled by Cancel.
+// Every early exit goes through abort, which distinguishes the two:
+//   - cancelled by Cancel: the "cancelled" status is written again, because
+//     the "running" write above may have overwritten the one Cancel made;
+//   - anything else, including shutdown: the build is marked failed via
+//     failBuild so it does not stay in "running".
+//
+// The user recipe is parsed before a sandbox is created so that an invalid
+// recipe never costs a sandbox create/destroy round trip.
+func (s *BuildService) executeBuild(ctx context.Context, buildIDStr string) {
+	log := slog.With("build_id", buildIDStr)
+
+	buildID, err := id.ParseBuildID(buildIDStr)
+	if err != nil {
+		log.Error("invalid build ID from queue", "error", err)
+		return
+	}
+
+	// Create a per-build context so this build can be cancelled independently of
+	// the worker. Register in cancelMap before fetching the build so that a
+	// concurrent Cancel call can always find and signal it.
+	buildCtx, buildCancel := context.WithCancel(ctx)
+	defer buildCancel()
+
+	s.mu.Lock()
+	if s.cancelMap == nil {
+		s.cancelMap = make(map[string]context.CancelFunc)
+	}
+	s.cancelMap[buildIDStr] = buildCancel
+	s.mu.Unlock()
+	defer func() {
+		s.mu.Lock()
+		delete(s.cancelMap, buildIDStr)
+		s.mu.Unlock()
+	}()
+
+	// cancelledByUser is true when buildCtx was cancelled while the worker's
+	// own context is still live, i.e. Cancel signalled this build.
+	cancelledByUser := func() bool {
+		return buildCtx.Err() != nil && ctx.Err() == nil
+	}
+	// markCancelled re-writes the "cancelled" status with a detached context.
+	markCancelled := func() {
+		dbCtx, dbCancel := context.WithTimeout(context.Background(), 10*time.Second)
+		defer dbCancel()
+		if _, err := s.DB.UpdateBuildStatus(dbCtx, db.UpdateBuildStatusParams{
+			ID: buildID, Status: "cancelled",
+		}); err != nil {
+			log.Error("failed to record cancelled status", "error", err)
+		}
+	}
+	// abort records the terminal state for a build that stops early.
+	abort := func(reason string) {
+		if cancelledByUser() {
+			log.Info("build cancelled")
+			markCancelled()
+			return
+		}
+		s.failBuild(buildCtx, buildID, reason)
+	}
+
+	build, err := s.DB.GetTemplateBuild(buildCtx, buildID)
+	if err != nil {
+		log.Error("failed to fetch build", "error", err)
+		return
+	}
+
+	// Skip if already cancelled (Cancel was called before we dequeued).
+	if build.Status == "cancelled" {
+		log.Info("build already cancelled, skipping")
+		return
+	}
+
+	// Mark as running.
+	if _, err := s.DB.UpdateBuildStatus(buildCtx, db.UpdateBuildStatusParams{
+		ID: buildID, Status: "running",
+	}); err != nil {
+		log.Error("failed to update build status", "error", err)
+		// The write may still have landed after Cancel's; restore its status.
+		if cancelledByUser() {
+			markCancelled()
+		}
+		return
+	}
+
+	// Parse user recipe.
+	var userRecipe []string
+	if err := json.Unmarshal(build.Recipe, &userRecipe); err != nil {
+		abort(fmt.Sprintf("invalid recipe JSON: %v", err))
+		return
+	}
+
+	// Parse recipe steps. preBuildCmds and postBuildCmds are hardcoded and always
+	// valid; panic on error is appropriate here since it would be a programmer mistake.
+	preBuildSteps, err := recipe.ParseRecipe(preBuildCmds)
+	if err != nil {
+		panic(fmt.Sprintf("invalid pre-build recipe: %v", err))
+	}
+	userRecipeSteps, err := recipe.ParseRecipe(userRecipe)
+	if err != nil {
+		abort(fmt.Sprintf("recipe parse error: %v", err))
+		return
+	}
+	postBuildSteps, err := recipe.ParseRecipe(postBuildCmds)
+	if err != nil {
+		panic(fmt.Sprintf("invalid post-build recipe: %v", err))
+	}
+
+	// Pick a platform host and create a sandbox.
+	host, err := s.Scheduler.SelectHost(buildCtx, id.PlatformTeamID, false)
+	if err != nil {
+		abort(fmt.Sprintf("no host available: %v", err))
+		return
+	}
+
+	agent, err := s.Pool.GetForHost(host)
+	if err != nil {
+		abort(fmt.Sprintf("agent client error: %v", err))
+		return
+	}
+
+	sandboxID := id.NewSandboxID()
+	sandboxIDStr := id.FormatSandboxID(sandboxID)
+	log = log.With("sandbox_id", sandboxIDStr, "host_id", id.FormatHostID(host.ID))
+
+	// Resolve the base template to UUIDs. "minimal" is the zero sentinel.
+	baseTeamID := id.PlatformTeamID
+	baseTemplateID := id.MinimalTemplateID
+	if build.BaseTemplate != "minimal" {
+		baseTmpl, err := s.DB.GetPlatformTemplateByName(buildCtx, build.BaseTemplate)
+		if err != nil {
+			abort(fmt.Sprintf("base template %q not found: %v", build.BaseTemplate, err))
+			return
+		}
+		baseTeamID = baseTmpl.TeamID
+		baseTemplateID = baseTmpl.ID
+	}
+
+	if _, err := agent.CreateSandbox(buildCtx, connect.NewRequest(&pb.CreateSandboxRequest{
+		SandboxId:  sandboxIDStr,
+		Template:   build.BaseTemplate,
+		TeamId:     id.UUIDString(baseTeamID),
+		TemplateId: id.UUIDString(baseTemplateID),
+		Vcpus:      build.Vcpus,
+		MemoryMb:   build.MemoryMb,
+		TimeoutSec: 0,    // no auto-pause for builds
+		DiskSizeMb: 5120, // 5 GB for template builds
+	})); err != nil {
+		if buildCtx.Err() != nil {
+			// The request was interrupted, so the agent may have created the
+			// sandbox anyway. It has no auto-pause timeout, so clean up
+			// best-effort rather than risk leaking it.
+			s.destroySandbox(buildCtx, agent, sandboxIDStr)
+		}
+		abort(fmt.Sprintf("create sandbox failed: %v", err))
+		return
+	}
+
+	// Record sandbox/host association. Non-fatal: the build itself can proceed.
+	if err := s.DB.UpdateBuildSandbox(buildCtx, db.UpdateBuildSandboxParams{
+		ID:        buildID,
+		SandboxID: sandboxID,
+		HostID:    host.ID,
+	}); err != nil {
+		log.Warn("failed to record build sandbox association", "error", err)
+	}
+
+	// Execute build phases: pre-build → user recipe → post-build.
+	// bctx carries working directory and env vars across all phases.
+	var logs []recipe.BuildLogEntry
+	step := 0
+	bctx := &recipe.ExecContext{}
+
+	runPhase := func(phase string, steps []recipe.Step, defaultTimeout time.Duration) bool {
+		newEntries, nextStep, ok := recipe.Execute(buildCtx, phase, steps, sandboxIDStr, step, defaultTimeout, bctx, agent.Exec)
+		logs = append(logs, newEntries...)
+		step = nextStep
+		s.updateLogs(buildCtx, buildID, step, logs)
+		if ok {
+			return true
+		}
+		s.destroySandbox(buildCtx, agent, sandboxIDStr)
+		// Derive the reason from the last log entry, if the phase produced any.
+		reason := "no step output recorded"
+		if n := len(newEntries); n > 0 {
+			last := newEntries[n-1]
+			reason = last.Stderr
+			if reason == "" {
+				reason = fmt.Sprintf("exit code %d", last.Exit)
+			}
+		}
+		abort(fmt.Sprintf("%s step %d failed: %s", phase, step, reason))
+		return false
+	}
+
+	if !build.SkipPrePost {
+		if !runPhase("pre-build", preBuildSteps, 0) {
+			return
+		}
+	}
+	if !runPhase("recipe", userRecipeSteps, buildCommandTimeout) {
+		return
+	}
+	if !build.SkipPrePost {
+		if !runPhase("post-build", postBuildSteps, 0) {
+			return
+		}
+	}
+
+	// Healthcheck or direct snapshot.
+	var sizeBytes int64
+	if build.Healthcheck != "" {
+		log.Info("running healthcheck", "cmd", build.Healthcheck)
+		if err := s.waitForHealthcheck(buildCtx, agent, sandboxIDStr, build.Healthcheck); err != nil {
+			s.destroySandbox(buildCtx, agent, sandboxIDStr)
+			abort(fmt.Sprintf("healthcheck failed: %v", err))
+			return
+		}
+
+		// Healthcheck passed → full snapshot (with memory/CPU state).
+		log.Info("healthcheck passed, creating snapshot")
+		snapResp, err := agent.CreateSnapshot(buildCtx, connect.NewRequest(&pb.CreateSnapshotRequest{
+			SandboxId:  sandboxIDStr,
+			Name:       build.Name,
+			TeamId:     id.UUIDString(build.TeamID),
+			TemplateId: id.UUIDString(build.TemplateID),
+		}))
+		if err != nil {
+			s.destroySandbox(buildCtx, agent, sandboxIDStr)
+			abort(fmt.Sprintf("create snapshot failed: %v", err))
+			return
+		}
+		sizeBytes = snapResp.Msg.SizeBytes
+	} else {
+		// No healthcheck → image-only template (rootfs only).
+		log.Info("no healthcheck, flattening rootfs")
+		flatResp, err := agent.FlattenRootfs(buildCtx, connect.NewRequest(&pb.FlattenRootfsRequest{
+			SandboxId:  sandboxIDStr,
+			Name:       build.Name,
+			TeamId:     id.UUIDString(build.TeamID),
+			TemplateId: id.UUIDString(build.TemplateID),
+		}))
+		if err != nil {
+			s.destroySandbox(buildCtx, agent, sandboxIDStr)
+			abort(fmt.Sprintf("flatten rootfs failed: %v", err))
+			return
+		}
+		sizeBytes = flatResp.Msg.SizeBytes
+	}
+
+	// Insert into templates table as a global (platform) template.
+	templateType := "base"
+	if build.Healthcheck != "" {
+		templateType = "snapshot"
+	}
+
+	if _, err := s.DB.InsertTemplate(buildCtx, db.InsertTemplateParams{
+		ID:        build.TemplateID,
+		Name:      build.Name,
+		Type:      templateType,
+		Vcpus:     build.Vcpus,
+		MemoryMb:  build.MemoryMb,
+		SizeBytes: sizeBytes,
+		TeamID:    id.PlatformTeamID,
+	}); err != nil {
+		log.Error("failed to insert template record", "error", err)
+		// Build succeeded on disk, just DB record failed — don't mark as failed.
+	}
+
+	// For CreateSnapshot, the sandbox is already destroyed by the snapshot process.
+	// For FlattenRootfs, the sandbox is already destroyed by the flatten process.
+	// No additional destroy needed.
+
+	// Mark build as success.
+	if _, err := s.DB.UpdateBuildStatus(buildCtx, db.UpdateBuildStatusParams{
+		ID: buildID, Status: "success",
+	}); err != nil {
+		log.Error("failed to mark build as success", "error", err)
+	}
+
+	log.Info("template build completed successfully", "name", build.Name)
+}
+
+// waitForHealthcheck polls cmd inside the sandbox until it exits with code 0.
+//
+// A probe runs on every healthcheckInterval tick as `/bin/sh -c cmd`, with a
+// 10-second limit on both the RPC context and the request's TimeoutSec. Exec
+// errors and non-zero exit codes are logged at debug level and retried.
+// Returns nil on the first successful probe, ctx.Err() if ctx ends first, or
+// a timeout error once healthcheckTimeout has elapsed.
+func (s *BuildService) waitForHealthcheck(ctx context.Context, agent buildAgentClient, sandboxIDStr, cmd string) error {
+	// probe performs a single attempt and reports whether it succeeded.
+	probe := func() bool {
+		probeCtx, done := context.WithTimeout(ctx, 10*time.Second)
+		defer done()
+
+		req := connect.NewRequest(&pb.ExecRequest{
+			SandboxId:  sandboxIDStr,
+			Cmd:        "/bin/sh",
+			Args:       []string{"-c", cmd},
+			TimeoutSec: 10,
+		})
+		resp, err := agent.Exec(probeCtx, req)
+		if err != nil {
+			slog.Debug("healthcheck exec error (retrying)", "error", err)
+			return false
+		}
+		if code := resp.Msg.ExitCode; code != 0 {
+			slog.Debug("healthcheck failed (retrying)", "exit_code", code)
+			return false
+		}
+		return true
+	}
+
+	giveUp := time.NewTimer(healthcheckTimeout)
+	defer giveUp.Stop()
+	poll := time.NewTicker(healthcheckInterval)
+	defer poll.Stop()
+
+	for {
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case <-giveUp.C:
+			return fmt.Errorf("healthcheck timed out after %s", healthcheckTimeout)
+		case <-poll.C:
+			if probe() {
+				return nil
+			}
+		}
+	}
+}
+
+// updateLogs persists the build's current step number and its accumulated log
+// entries (as JSON). Failures are logged as warnings and otherwise ignored.
+//
+// The write runs on a context detached from ctx's cancellation, with its own
+// 10-second limit. updateLogs is called right after a phase stops, which is
+// exactly when ctx may already be cancelled; writing with ctx directly would
+// drop the logs of the step that was interrupted. This mirrors failBuild and
+// destroySandbox, which also detach for their final writes.
+func (s *BuildService) updateLogs(ctx context.Context, buildID pgtype.UUID, step int, logs []recipe.BuildLogEntry) {
+	logsJSON, err := json.Marshal(logs)
+	if err != nil {
+		slog.Warn("failed to marshal build logs", "error", err)
+		return
+	}
+
+	dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 10*time.Second)
+	defer cancel()
+	if err := s.DB.UpdateBuildProgress(dbCtx, db.UpdateBuildProgressParams{
+		ID:          buildID,
+		CurrentStep: int32(step),
+		Logs:        logsJSON,
+	}); err != nil {
+		slog.Warn("failed to update build progress", "error", err)
+	}
+}
+
+// failBuild logs errMsg for the build and stores it on the build row through
+// UpdateBuildError. The context argument is ignored on purpose: the write uses
+// a fresh 10-second context so it still goes through when the caller's context
+// has been cancelled (e.g. shutdown). A failed write is logged, not returned.
+func (s *BuildService) failBuild(_ context.Context, buildID pgtype.UUID, errMsg string) {
+	buildIDStr := id.FormatBuildID(buildID)
+	slog.Error("build failed", "build_id", buildIDStr, "error", errMsg)
+
+	dbCtx, release := context.WithTimeout(context.Background(), 10*time.Second)
+	defer release()
+
+	params := db.UpdateBuildErrorParams{ID: buildID, Error: errMsg}
+	err := s.DB.UpdateBuildError(dbCtx, params)
+	if err == nil {
+		return
+	}
+	slog.Error("failed to update build error", "build_id", buildIDStr, "error", err)
+}
+
+// destroySandbox asks the agent to destroy the given build sandbox. The
+// context argument is ignored on purpose: the RPC runs on a fresh 30-second
+// context so cleanup still happens during shutdown. A failed destroy is only
+// logged as a warning.
+func (s *BuildService) destroySandbox(_ context.Context, agent buildAgentClient, sandboxIDStr string) {
+	rpcCtx, release := context.WithTimeout(context.Background(), 30*time.Second)
+	defer release()
+
+	req := connect.NewRequest(&pb.DestroySandboxRequest{
+		SandboxId: sandboxIDStr,
+	})
+	_, err := agent.DestroySandbox(rpcCtx, req)
+	if err != nil {
+		slog.Warn("failed to destroy build sandbox", "sandbox_id", sandboxIDStr, "error", err)
+	}
+}
diff --git a/internal/service/host.go b/internal/service/host.go
index b3538dfb..74018ebe 100644
--- a/internal/service/host.go
+++ b/internal/service/host.go
@@ -27,15 +27,16 @@ type HostService struct {
Redis *redis.Client
JWT []byte
Pool *lifecycle.HostClientPool
+ CA *auth.CA // nil disables mTLS cert issuance (dev/test environments)
}
// HostCreateParams holds the parameters for creating a host.
type HostCreateParams struct {
Type string
- TeamID string // required for BYOC, empty for regular
+ TeamID pgtype.UUID // required for BYOC, zero value for regular
Provider string
AvailabilityZone string
- RequestingUserID string
+ RequestingUserID pgtype.UUID
IsRequestorAdmin bool
}
@@ -55,18 +56,28 @@ type HostRegisterParams struct {
Address string
}
-// HostRegisterResult holds the registered host, its short-lived JWT, and a long-lived refresh token.
+// HostRegisterResult holds the registered host, its short-lived JWT, a long-lived
+// refresh token, and optionally the host's mTLS certificate material.
type HostRegisterResult struct {
Host db.Host
JWT string
RefreshToken string
+ // mTLS cert material — empty when CA is not configured.
+ CertPEM string
+ KeyPEM string
+ CACertPEM string
}
-// HostRefreshResult holds a new JWT and rotated refresh token after a successful refresh.
+// HostRefreshResult holds a new JWT and rotated refresh token after a successful
+// refresh, plus refreshed mTLS certificate material when CA is configured.
type HostRefreshResult struct {
Host db.Host
JWT string
RefreshToken string
+ // mTLS cert material — empty when CA is not configured.
+ CertPEM string
+ KeyPEM string
+ CACertPEM string
}
// HostDeletePreview describes what will be affected by deleting a host.
@@ -103,7 +114,7 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat
}
} else {
// BYOC: platform admin, or team owner/admin.
- if p.TeamID == "" {
+ if !p.TeamID.Valid {
return HostCreateResult{}, fmt.Errorf("invalid request: team_id is required for BYOC hosts")
}
if !p.IsRequestorAdmin {
@@ -124,7 +135,7 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat
}
// Validate team exists, is not deleted, and has BYOC enabled.
- if p.TeamID != "" {
+ if p.TeamID.Valid {
team, err := s.DB.GetTeam(ctx, p.TeamID)
if err != nil || team.DeletedAt.Valid {
return HostCreateResult{}, fmt.Errorf("invalid request: team not found")
@@ -136,25 +147,12 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat
hostID := id.NewHostID()
- var teamID pgtype.Text
- if p.TeamID != "" {
- teamID = pgtype.Text{String: p.TeamID, Valid: true}
- }
- var provider pgtype.Text
- if p.Provider != "" {
- provider = pgtype.Text{String: p.Provider, Valid: true}
- }
- var az pgtype.Text
- if p.AvailabilityZone != "" {
- az = pgtype.Text{String: p.AvailabilityZone, Valid: true}
- }
-
host, err := s.DB.InsertHost(ctx, db.InsertHostParams{
ID: hostID,
Type: p.Type,
- TeamID: teamID,
- Provider: provider,
- AvailabilityZone: az,
+ TeamID: p.TeamID,
+ Provider: p.Provider,
+ AvailabilityZone: p.AvailabilityZone,
CreatedBy: p.RequestingUserID,
})
if err != nil {
@@ -166,8 +164,8 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat
tokenID := id.NewHostTokenID()
payload, _ := json.Marshal(regTokenPayload{
- HostID: hostID,
- TokenID: tokenID,
+ HostID: id.FormatHostID(hostID),
+ TokenID: id.FormatHostTokenID(tokenID),
})
if err := s.Redis.Set(ctx, "host:reg:"+token, payload, regTokenTTL).Err(); err != nil {
return HostCreateResult{}, fmt.Errorf("store registration token: %w", err)
@@ -180,7 +178,7 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat
CreatedBy: p.RequestingUserID,
ExpiresAt: pgtype.Timestamptz{Time: now.Add(regTokenTTL), Valid: true},
}); err != nil {
- slog.Warn("failed to insert host token audit record", "host_id", hostID, "error", err)
+ slog.Warn("failed to insert host token audit record", "host_id", id.FormatHostID(hostID), "error", err)
}
return HostCreateResult{Host: host, RegistrationToken: token}, nil
@@ -189,7 +187,7 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat
// RegenerateToken issues a new registration token for a host still in "pending"
// status. This allows retry when a previous registration attempt failed after
// the original token was consumed.
-func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamID string, isAdmin bool) (HostCreateResult, error) {
+func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamID pgtype.UUID, isAdmin bool) (HostCreateResult, error) {
host, err := s.DB.GetHost(ctx, hostID)
if err != nil {
return HostCreateResult{}, fmt.Errorf("host not found: %w", err)
@@ -202,7 +200,7 @@ func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamI
if host.Type != "byoc" {
return HostCreateResult{}, fmt.Errorf("forbidden: only admins can manage regular hosts")
}
- if !host.TeamID.Valid || host.TeamID.String != teamID {
+ if !host.TeamID.Valid || host.TeamID != teamID {
return HostCreateResult{}, fmt.Errorf("forbidden: host does not belong to your team")
}
membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
@@ -224,8 +222,8 @@ func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamI
tokenID := id.NewHostTokenID()
payload, _ := json.Marshal(regTokenPayload{
- HostID: hostID,
- TokenID: tokenID,
+ HostID: id.FormatHostID(hostID),
+ TokenID: id.FormatHostTokenID(tokenID),
})
if err := s.Redis.Set(ctx, "host:reg:"+token, payload, regTokenTTL).Err(); err != nil {
return HostCreateResult{}, fmt.Errorf("store registration token: %w", err)
@@ -238,7 +236,7 @@ func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamI
CreatedBy: userID,
ExpiresAt: pgtype.Timestamptz{Time: now.Add(regTokenTTL), Valid: true},
}); err != nil {
- slog.Warn("failed to insert host token audit record", "host_id", hostID, "error", err)
+ slog.Warn("failed to insert host token audit record", "host_id", id.FormatHostID(hostID), "error", err)
}
return HostCreateResult{Host: host, RegistrationToken: token}, nil
@@ -262,24 +260,44 @@ func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostR
return HostRegisterResult{}, fmt.Errorf("corrupted registration token")
}
- if _, err := s.DB.GetHost(ctx, payload.HostID); err != nil {
+ hostID, err := id.ParseHostID(payload.HostID)
+ if err != nil {
+ return HostRegisterResult{}, fmt.Errorf("corrupted registration token: %w", err)
+ }
+ tokenID, err := id.ParseHostTokenID(payload.TokenID)
+ if err != nil {
+ return HostRegisterResult{}, fmt.Errorf("corrupted registration token: %w", err)
+ }
+
+ if _, err := s.DB.GetHost(ctx, hostID); err != nil {
return HostRegisterResult{}, fmt.Errorf("host not found: %w", err)
}
// Sign JWT before mutating DB — if signing fails, the host stays pending.
- hostJWT, err := auth.SignHostJWT(s.JWT, payload.HostID)
+ hostJWT, err := auth.SignHostJWT(s.JWT, hostID)
if err != nil {
return HostRegisterResult{}, fmt.Errorf("sign host token: %w", err)
}
+ // Issue mTLS certificate if CA is configured.
+ var hc auth.HostCert
+ if s.CA != nil {
+ hc, err = auth.IssueHostCert(s.CA, id.FormatHostID(hostID), p.Address)
+ if err != nil {
+ return HostRegisterResult{}, fmt.Errorf("issue host cert: %w", err)
+ }
+ }
+
// Atomically update only if still pending (defense-in-depth against races).
rowsAffected, err := s.DB.RegisterHost(ctx, db.RegisterHostParams{
- ID: payload.HostID,
- Arch: pgtype.Text{String: p.Arch, Valid: p.Arch != ""},
- CpuCores: pgtype.Int4{Int32: p.CPUCores, Valid: p.CPUCores > 0},
- MemoryMb: pgtype.Int4{Int32: p.MemoryMB, Valid: p.MemoryMB > 0},
- DiskGb: pgtype.Int4{Int32: p.DiskGB, Valid: p.DiskGB > 0},
- Address: pgtype.Text{String: p.Address, Valid: p.Address != ""},
+ ID: hostID,
+ Arch: p.Arch,
+ CpuCores: p.CPUCores,
+ MemoryMb: p.MemoryMB,
+ DiskGb: p.DiskGB,
+ Address: p.Address,
+ CertFingerprint: hc.Fingerprint,
+ CertExpiresAt: pgtype.Timestamptz{Time: hc.ExpiresAt, Valid: s.CA != nil},
})
if err != nil {
return HostRegisterResult{}, fmt.Errorf("register host: %w", err)
@@ -289,23 +307,29 @@ func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostR
}
// Mark audit trail.
- if err := s.DB.MarkHostTokenUsed(ctx, payload.TokenID); err != nil {
+ if err := s.DB.MarkHostTokenUsed(ctx, tokenID); err != nil {
slog.Warn("failed to mark host token used", "token_id", payload.TokenID, "error", err)
}
// Issue a long-lived refresh token.
- refreshToken, err := s.issueRefreshToken(ctx, payload.HostID)
+ refreshToken, err := s.issueRefreshToken(ctx, hostID)
if err != nil {
return HostRegisterResult{}, fmt.Errorf("issue refresh token: %w", err)
}
// Re-fetch the host to get the updated state.
- host, err := s.DB.GetHost(ctx, payload.HostID)
+ host, err := s.DB.GetHost(ctx, hostID)
if err != nil {
return HostRegisterResult{}, fmt.Errorf("fetch updated host: %w", err)
}
- return HostRegisterResult{Host: host, JWT: hostJWT, RefreshToken: refreshToken}, nil
+ result := HostRegisterResult{Host: host, JWT: hostJWT, RefreshToken: refreshToken}
+ if s.CA != nil {
+ result.CertPEM = hc.CertPEM
+ result.KeyPEM = hc.KeyPEM
+ result.CACertPEM = s.CA.PEM
+ }
+ return result, nil
}
// Refresh validates a refresh token, rotates it (revokes old, issues new),
@@ -332,6 +356,22 @@ func (s *HostService) Refresh(ctx context.Context, refreshToken string) (HostRef
return HostRefreshResult{}, fmt.Errorf("sign host JWT: %w", err)
}
+ // Renew mTLS certificate if CA is configured.
+ var hc auth.HostCert
+ if s.CA != nil {
+ hc, err = auth.IssueHostCert(s.CA, id.FormatHostID(host.ID), host.Address)
+ if err != nil {
+ return HostRefreshResult{}, fmt.Errorf("renew host cert: %w", err)
+ }
+ if err := s.DB.UpdateHostCert(ctx, db.UpdateHostCertParams{
+ ID: host.ID,
+ CertFingerprint: hc.Fingerprint,
+ CertExpiresAt: pgtype.Timestamptz{Time: hc.ExpiresAt, Valid: true},
+ }); err != nil {
+ return HostRefreshResult{}, fmt.Errorf("update host cert: %w", err)
+ }
+ }
+
// Issue-then-revoke rotation: insert new token first so a crash between
// the two DB calls leaves the host with two valid tokens rather than zero.
newRefreshToken, err := s.issueRefreshToken(ctx, host.ID)
@@ -344,12 +384,18 @@ func (s *HostService) Refresh(ctx context.Context, refreshToken string) (HostRef
return HostRefreshResult{}, fmt.Errorf("revoke old refresh token: %w", err)
}
- return HostRefreshResult{Host: host, JWT: hostJWT, RefreshToken: newRefreshToken}, nil
+ result := HostRefreshResult{Host: host, JWT: hostJWT, RefreshToken: newRefreshToken}
+ if s.CA != nil {
+ result.CertPEM = hc.CertPEM
+ result.KeyPEM = hc.KeyPEM
+ result.CACertPEM = s.CA.PEM
+ }
+ return result, nil
}
// issueRefreshToken creates a new refresh token record in the DB and returns
// the opaque token string.
-func (s *HostService) issueRefreshToken(ctx context.Context, hostID string) (string, error) {
+func (s *HostService) issueRefreshToken(ctx context.Context, hostID pgtype.UUID) (string, error) {
token := id.NewRefreshToken()
hash := hashToken(token)
now := time.Now()
@@ -375,7 +421,7 @@ func hashToken(token string) string {
// Heartbeat updates the last heartbeat timestamp for a host and transitions
// any 'unreachable' host back to 'online'. Returns a "host not found" error
// (which becomes 404) if the host record no longer exists (e.g., was deleted).
-func (s *HostService) Heartbeat(ctx context.Context, hostID string) error {
+func (s *HostService) Heartbeat(ctx context.Context, hostID pgtype.UUID) error {
n, err := s.DB.UpdateHostHeartbeatAndStatus(ctx, hostID)
if err != nil {
return err
@@ -388,21 +434,21 @@ func (s *HostService) Heartbeat(ctx context.Context, hostID string) error {
// List returns hosts visible to the caller.
// Admins see all hosts; non-admins see only BYOC hosts belonging to their team.
-func (s *HostService) List(ctx context.Context, teamID string, isAdmin bool) ([]db.Host, error) {
+func (s *HostService) List(ctx context.Context, teamID pgtype.UUID, isAdmin bool) ([]db.Host, error) {
if isAdmin {
return s.DB.ListHosts(ctx)
}
- return s.DB.ListHostsByTeam(ctx, pgtype.Text{String: teamID, Valid: true})
+ return s.DB.ListHostsByTeam(ctx, teamID)
}
// Get returns a single host, enforcing access control.
-func (s *HostService) Get(ctx context.Context, hostID, teamID string, isAdmin bool) (db.Host, error) {
+func (s *HostService) Get(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool) (db.Host, error) {
host, err := s.DB.GetHost(ctx, hostID)
if err != nil {
return db.Host{}, fmt.Errorf("host not found: %w", err)
}
if !isAdmin {
- if !host.TeamID.Valid || host.TeamID.String != teamID {
+ if !host.TeamID.Valid || host.TeamID != teamID {
return db.Host{}, fmt.Errorf("host not found")
}
}
@@ -411,8 +457,8 @@ func (s *HostService) Get(ctx context.Context, hostID, teamID string, isAdmin bo
// DeletePreview returns what would be affected by deleting the host, without
// making any changes. Use this to show the user a confirmation prompt.
-func (s *HostService) DeletePreview(ctx context.Context, hostID, teamID string, isAdmin bool) (HostDeletePreview, error) {
- host, err := s.checkDeletePermission(ctx, hostID, "", teamID, isAdmin)
+func (s *HostService) DeletePreview(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool) (HostDeletePreview, error) {
+ host, err := s.checkDeletePermission(ctx, hostID, pgtype.UUID{}, teamID, isAdmin)
if err != nil {
return HostDeletePreview{}, err
}
@@ -427,7 +473,7 @@ func (s *HostService) DeletePreview(ctx context.Context, hostID, teamID string,
ids := make([]string, len(sandboxes))
for i, sb := range sandboxes {
- ids[i] = sb.ID
+ ids[i] = id.FormatSandboxID(sb.ID)
}
return HostDeletePreview{Host: host, SandboxIDs: ids}, nil
@@ -436,7 +482,7 @@ func (s *HostService) DeletePreview(ctx context.Context, hostID, teamID string,
// Delete removes a host. Without force it returns an error listing active
// sandboxes so the caller can present a confirmation. With force it gracefully
// destroys all running sandboxes before deleting the host record.
-func (s *HostService) Delete(ctx context.Context, hostID, userID, teamID string, isAdmin, force bool) error {
+func (s *HostService) Delete(ctx context.Context, hostID, userID, teamID pgtype.UUID, isAdmin, force bool) error {
host, err := s.checkDeletePermission(ctx, hostID, userID, teamID, isAdmin)
if err != nil {
return err
@@ -453,35 +499,37 @@ func (s *HostService) Delete(ctx context.Context, hostID, userID, teamID string,
if len(sandboxes) > 0 && !force {
ids := make([]string, len(sandboxes))
for i, sb := range sandboxes {
- ids[i] = sb.ID
+ ids[i] = id.FormatSandboxID(sb.ID)
}
return &HostHasSandboxesError{SandboxIDs: ids}
}
+ hostIDStr := id.FormatHostID(hostID)
+
// Gracefully destroy running sandboxes and terminate the agent (best-effort).
- if host.Address.Valid && host.Address.String != "" {
+ if host.Address != "" {
agent, err := s.Pool.GetForHost(host)
if err == nil {
for _, sb := range sandboxes {
if sb.Status == "running" || sb.Status == "starting" {
_, rpcErr := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
- SandboxId: sb.ID,
+ SandboxId: id.FormatSandboxID(sb.ID),
}))
if rpcErr != nil && connect.CodeOf(rpcErr) != connect.CodeNotFound {
- slog.Warn("delete host: failed to destroy sandbox on agent", "sandbox_id", sb.ID, "error", rpcErr)
+ slog.Warn("delete host: failed to destroy sandbox on agent", "sandbox_id", id.FormatSandboxID(sb.ID), "error", rpcErr)
}
}
}
// Tell the agent to shut itself down immediately.
if _, rpcErr := agent.Terminate(ctx, connect.NewRequest(&pb.TerminateRequest{})); rpcErr != nil {
- slog.Warn("delete host: failed to send Terminate to agent", "host_id", hostID, "error", rpcErr)
+ slog.Warn("delete host: failed to send Terminate to agent", "host_id", hostIDStr, "error", rpcErr)
}
}
}
// Mark all affected sandboxes as stopped in DB.
if len(sandboxes) > 0 {
- sbIDs := make([]string, len(sandboxes))
+ sbIDs := make([]pgtype.UUID, len(sandboxes))
for i, sb := range sandboxes {
sbIDs[i] = sb.ID
}
@@ -489,18 +537,18 @@ func (s *HostService) Delete(ctx context.Context, hostID, userID, teamID string,
Column1: sbIDs,
Status: "stopped",
}); err != nil {
- slog.Warn("delete host: failed to mark sandboxes stopped", "host_id", hostID, "error", err)
+ slog.Warn("delete host: failed to mark sandboxes stopped", "host_id", hostIDStr, "error", err)
}
}
// Revoke all refresh tokens for this host.
if err := s.DB.RevokeHostRefreshTokensByHost(ctx, hostID); err != nil {
- slog.Warn("delete host: failed to revoke refresh tokens", "host_id", hostID, "error", err)
+ slog.Warn("delete host: failed to revoke refresh tokens", "host_id", hostIDStr, "error", err)
}
// Evict the client from the pool so no further RPCs are sent.
if s.Pool != nil {
- s.Pool.Evict(hostID)
+ s.Pool.Evict(id.FormatHostID(hostID))
}
return s.DB.DeleteHost(ctx, hostID)
@@ -508,7 +556,7 @@ func (s *HostService) Delete(ctx context.Context, hostID, userID, teamID string,
// checkDeletePermission verifies the caller has permission to delete the given
// host and returns the host record on success.
-func (s *HostService) checkDeletePermission(ctx context.Context, hostID, userID, teamID string, isAdmin bool) (db.Host, error) {
+func (s *HostService) checkDeletePermission(ctx context.Context, hostID, userID, teamID pgtype.UUID, isAdmin bool) (db.Host, error) {
host, err := s.DB.GetHost(ctx, hostID)
if err != nil {
return db.Host{}, fmt.Errorf("host not found: %w", err)
@@ -521,11 +569,11 @@ func (s *HostService) checkDeletePermission(ctx context.Context, hostID, userID,
if host.Type != "byoc" {
return db.Host{}, fmt.Errorf("forbidden: only admins can delete regular hosts")
}
- if !host.TeamID.Valid || host.TeamID.String != teamID {
+ if !host.TeamID.Valid || host.TeamID != teamID {
return db.Host{}, fmt.Errorf("forbidden: host does not belong to your team")
}
- if userID != "" {
+ if userID.Valid {
membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
UserID: userID,
TeamID: teamID,
@@ -545,7 +593,7 @@ func (s *HostService) checkDeletePermission(ctx context.Context, hostID, userID,
}
// AddTag adds a tag to a host.
-func (s *HostService) AddTag(ctx context.Context, hostID, teamID string, isAdmin bool, tag string) error {
+func (s *HostService) AddTag(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool, tag string) error {
if _, err := s.Get(ctx, hostID, teamID, isAdmin); err != nil {
return err
}
@@ -553,7 +601,7 @@ func (s *HostService) AddTag(ctx context.Context, hostID, teamID string, isAdmin
}
// RemoveTag removes a tag from a host.
-func (s *HostService) RemoveTag(ctx context.Context, hostID, teamID string, isAdmin bool, tag string) error {
+func (s *HostService) RemoveTag(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool, tag string) error {
if _, err := s.Get(ctx, hostID, teamID, isAdmin); err != nil {
return err
}
@@ -561,7 +609,7 @@ func (s *HostService) RemoveTag(ctx context.Context, hostID, teamID string, isAd
}
// ListTags returns all tags for a host.
-func (s *HostService) ListTags(ctx context.Context, hostID, teamID string, isAdmin bool) ([]string, error) {
+func (s *HostService) ListTags(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool) ([]string, error) {
if _, err := s.Get(ctx, hostID, teamID, isAdmin); err != nil {
return nil, err
}
diff --git a/internal/service/sandbox.go b/internal/service/sandbox.go
index 142b9bd0..2d1f68c4 100644
--- a/internal/service/sandbox.go
+++ b/internal/service/sandbox.go
@@ -27,15 +27,16 @@ type SandboxService struct {
// SandboxCreateParams holds the parameters for creating a sandbox.
type SandboxCreateParams struct {
- TeamID string
+ TeamID pgtype.UUID
Template string
VCPUs int32
MemoryMB int32
TimeoutSec int32
+ DiskSizeMB int32
}
// agentForSandbox looks up the host for the given sandbox and returns a client.
-func (s *SandboxService) agentForSandbox(ctx context.Context, sandboxID string) (hostagentClient, db.Sandbox, error) {
+func (s *SandboxService) agentForSandbox(ctx context.Context, sandboxID pgtype.UUID) (hostagentClient, db.Sandbox, error) {
sb, err := s.DB.GetSandbox(ctx, sandboxID)
if err != nil {
return nil, db.Sandbox{}, fmt.Errorf("sandbox not found: %w", err)
@@ -77,18 +78,28 @@ func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db.
if p.MemoryMB <= 0 {
p.MemoryMB = 512
}
+ if p.DiskSizeMB <= 0 {
+ p.DiskSizeMB = 5120 // 5 GB default
+ }
- // If the template is a snapshot, use its baked-in vcpus/memory.
- if tmpl, err := s.DB.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: p.Template, TeamID: p.TeamID}); err == nil && tmpl.Type == "snapshot" {
- if tmpl.Vcpus.Valid {
- p.VCPUs = tmpl.Vcpus.Int32
+ // Resolve template name → (teamID, templateID).
+ templateTeamID := id.PlatformTeamID
+ templateID := id.MinimalTemplateID
+ if p.Template != "minimal" {
+ tmpl, err := s.DB.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: p.Template, TeamID: p.TeamID})
+ if err != nil {
+ return db.Sandbox{}, fmt.Errorf("template %q not found: %w", p.Template, err)
}
- if tmpl.MemoryMb.Valid {
- p.MemoryMB = tmpl.MemoryMb.Int32
+ templateTeamID = tmpl.TeamID
+ templateID = tmpl.ID
+ // If the template is a snapshot, use its baked-in vcpus/memory.
+ if tmpl.Type == "snapshot" {
+ p.VCPUs = tmpl.Vcpus
+ p.MemoryMB = tmpl.MemoryMb
}
}
- if p.TeamID == "" {
+ if !p.TeamID.Valid {
return db.Sandbox{}, fmt.Errorf("invalid request: team_id is required")
}
@@ -110,32 +121,39 @@ func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db.
}
sandboxID := id.NewSandboxID()
+ sandboxIDStr := id.FormatSandboxID(sandboxID)
if _, err := s.DB.InsertSandbox(ctx, db.InsertSandboxParams{
- ID: sandboxID,
- TeamID: p.TeamID,
- HostID: host.ID,
- Template: p.Template,
- Status: "pending",
- Vcpus: p.VCPUs,
- MemoryMb: p.MemoryMB,
- TimeoutSec: p.TimeoutSec,
+ ID: sandboxID,
+ TeamID: p.TeamID,
+ HostID: host.ID,
+ Template: p.Template,
+ Status: "pending",
+ Vcpus: p.VCPUs,
+ MemoryMb: p.MemoryMB,
+ TimeoutSec: p.TimeoutSec,
+ DiskSizeMb: p.DiskSizeMB,
+ TemplateID: templateID,
+ TemplateTeamID: templateTeamID,
}); err != nil {
return db.Sandbox{}, fmt.Errorf("insert sandbox: %w", err)
}
resp, err := agent.CreateSandbox(ctx, connect.NewRequest(&pb.CreateSandboxRequest{
- SandboxId: sandboxID,
+ SandboxId: sandboxIDStr,
Template: p.Template,
+ TeamId: id.UUIDString(templateTeamID),
+ TemplateId: id.UUIDString(templateID),
Vcpus: p.VCPUs,
MemoryMb: p.MemoryMB,
TimeoutSec: p.TimeoutSec,
+ DiskSizeMb: p.DiskSizeMB,
}))
if err != nil {
if _, dbErr := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
ID: sandboxID, Status: "error",
}); dbErr != nil {
- slog.Warn("failed to update sandbox status to error", "id", sandboxID, "error", dbErr)
+ slog.Warn("failed to update sandbox status to error", "id", sandboxIDStr, "error", dbErr)
}
return db.Sandbox{}, fmt.Errorf("agent create: %w", err)
}
@@ -158,17 +176,17 @@ func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db.
}
// List returns active sandboxes (excludes stopped/error) belonging to the given team.
-func (s *SandboxService) List(ctx context.Context, teamID string) ([]db.Sandbox, error) {
+func (s *SandboxService) List(ctx context.Context, teamID pgtype.UUID) ([]db.Sandbox, error) {
return s.DB.ListSandboxesByTeam(ctx, teamID)
}
// Get returns a single sandbox by ID, scoped to the given team.
-func (s *SandboxService) Get(ctx context.Context, sandboxID, teamID string) (db.Sandbox, error) {
+func (s *SandboxService) Get(ctx context.Context, sandboxID, teamID pgtype.UUID) (db.Sandbox, error) {
return s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
}
// Pause snapshots and freezes a running sandbox to disk.
-func (s *SandboxService) Pause(ctx context.Context, sandboxID, teamID string) (db.Sandbox, error) {
+func (s *SandboxService) Pause(ctx context.Context, sandboxID, teamID pgtype.UUID) (db.Sandbox, error) {
sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
if err != nil {
return db.Sandbox{}, fmt.Errorf("sandbox not found: %w", err)
@@ -182,26 +200,40 @@ func (s *SandboxService) Pause(ctx context.Context, sandboxID, teamID string) (d
return db.Sandbox{}, err
}
+ sandboxIDStr := id.FormatSandboxID(sandboxID)
+
+ // Pre-mark as "paused" in DB before the RPC so the reconciler does not
+ // mark the sandbox "stopped" while the host agent processes the pause.
+ if _, err := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
+ ID: sandboxID, Status: "paused",
+ }); err != nil {
+ return db.Sandbox{}, fmt.Errorf("pre-mark paused: %w", err)
+ }
+
// Flush all metrics tiers before pausing so data survives in DB.
s.flushAndPersistMetrics(ctx, agent, sandboxID, true)
if _, err := agent.PauseSandbox(ctx, connect.NewRequest(&pb.PauseSandboxRequest{
- SandboxId: sandboxID,
+ SandboxId: sandboxIDStr,
})); err != nil {
+ // Revert status on failure.
+ if _, dbErr := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
+ ID: sandboxID, Status: "running",
+ }); dbErr != nil {
+ slog.Warn("failed to revert sandbox status after pause error", "sandbox_id", sandboxIDStr, "error", dbErr)
+ }
return db.Sandbox{}, fmt.Errorf("agent pause: %w", err)
}
- sb, err = s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
- ID: sandboxID, Status: "paused",
- })
+ sb, err = s.DB.GetSandbox(ctx, sandboxID)
if err != nil {
- return db.Sandbox{}, fmt.Errorf("update status: %w", err)
+ return db.Sandbox{}, fmt.Errorf("get sandbox after pause: %w", err)
}
return sb, nil
}
// Resume restores a paused sandbox from snapshot.
-func (s *SandboxService) Resume(ctx context.Context, sandboxID, teamID string) (db.Sandbox, error) {
+func (s *SandboxService) Resume(ctx context.Context, sandboxID, teamID pgtype.UUID) (db.Sandbox, error) {
sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
if err != nil {
return db.Sandbox{}, fmt.Errorf("sandbox not found: %w", err)
@@ -215,8 +247,10 @@ func (s *SandboxService) Resume(ctx context.Context, sandboxID, teamID string) (
return db.Sandbox{}, err
}
+ sandboxIDStr := id.FormatSandboxID(sandboxID)
+
resp, err := agent.ResumeSandbox(ctx, connect.NewRequest(&pb.ResumeSandboxRequest{
- SandboxId: sandboxID,
+ SandboxId: sandboxIDStr,
TimeoutSec: sb.TimeoutSec,
}))
if err != nil {
@@ -240,7 +274,7 @@ func (s *SandboxService) Resume(ctx context.Context, sandboxID, teamID string) (
}
// Destroy stops a sandbox and marks it as stopped.
-func (s *SandboxService) Destroy(ctx context.Context, sandboxID, teamID string) error {
+func (s *SandboxService) Destroy(ctx context.Context, sandboxID, teamID pgtype.UUID) error {
sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
if err != nil {
return fmt.Errorf("sandbox not found: %w", err)
@@ -251,6 +285,8 @@ func (s *SandboxService) Destroy(ctx context.Context, sandboxID, teamID string)
return err
}
+ sandboxIDStr := id.FormatSandboxID(sandboxID)
+
// If running, flush 24h tier metrics for analytics before destroying.
if sb.Status == "running" {
s.flushAndPersistMetrics(ctx, agent, sandboxID, false)
@@ -258,7 +294,7 @@ func (s *SandboxService) Destroy(ctx context.Context, sandboxID, teamID string)
// Destroy on host agent. A not-found response is fine — sandbox is already gone.
if _, err := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
- SandboxId: sandboxID,
+ SandboxId: sandboxIDStr,
})); err != nil && connect.CodeOf(err) != connect.CodeNotFound {
return fmt.Errorf("agent destroy: %w", err)
}
@@ -284,12 +320,13 @@ func (s *SandboxService) Destroy(ctx context.Context, sandboxID, teamID string)
// flushAndPersistMetrics calls FlushSandboxMetrics on the agent and stores
// the returned data to DB. If allTiers is true, all three tiers are saved;
// otherwise only the 24h tier (for post-destroy analytics).
-func (s *SandboxService) flushAndPersistMetrics(ctx context.Context, agent hostagentClient, sandboxID string, allTiers bool) {
+func (s *SandboxService) flushAndPersistMetrics(ctx context.Context, agent hostagentClient, sandboxID pgtype.UUID, allTiers bool) {
+ sandboxIDStr := id.FormatSandboxID(sandboxID)
resp, err := agent.FlushSandboxMetrics(ctx, connect.NewRequest(&pb.FlushSandboxMetricsRequest{
- SandboxId: sandboxID,
+ SandboxId: sandboxIDStr,
}))
if err != nil {
- slog.Warn("flush metrics failed (best-effort)", "sandbox_id", sandboxID, "error", err)
+ slog.Warn("flush metrics failed (best-effort)", "sandbox_id", sandboxIDStr, "error", err)
return
}
msg := resp.Msg
@@ -301,7 +338,8 @@ func (s *SandboxService) flushAndPersistMetrics(ctx context.Context, agent hosta
s.persistMetricPoints(ctx, sandboxID, "24h", msg.Points_24H)
}
-func (s *SandboxService) persistMetricPoints(ctx context.Context, sandboxID, tier string, points []*pb.MetricPoint) {
+func (s *SandboxService) persistMetricPoints(ctx context.Context, sandboxID pgtype.UUID, tier string, points []*pb.MetricPoint) {
+ sandboxIDStr := id.FormatSandboxID(sandboxID)
for _, p := range points {
if err := s.DB.InsertSandboxMetricPoint(ctx, db.InsertSandboxMetricPointParams{
SandboxID: sandboxID,
@@ -311,13 +349,13 @@ func (s *SandboxService) persistMetricPoints(ctx context.Context, sandboxID, tie
MemBytes: p.MemBytes,
DiskBytes: p.DiskBytes,
}); err != nil {
- slog.Warn("persist metric point failed", "sandbox_id", sandboxID, "tier", tier, "error", err)
+ slog.Warn("persist metric point failed", "sandbox_id", sandboxIDStr, "tier", tier, "error", err)
}
}
}
// Ping resets the inactivity timer for a running sandbox.
-func (s *SandboxService) Ping(ctx context.Context, sandboxID, teamID string) error {
+func (s *SandboxService) Ping(ctx context.Context, sandboxID, teamID pgtype.UUID) error {
sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
if err != nil {
return fmt.Errorf("sandbox not found: %w", err)
@@ -331,8 +369,10 @@ func (s *SandboxService) Ping(ctx context.Context, sandboxID, teamID string) err
return err
}
+ sandboxIDStr := id.FormatSandboxID(sandboxID)
+
if _, err := agent.PingSandbox(ctx, connect.NewRequest(&pb.PingSandboxRequest{
- SandboxId: sandboxID,
+ SandboxId: sandboxIDStr,
})); err != nil {
return fmt.Errorf("agent ping: %w", err)
}
@@ -344,7 +384,7 @@ func (s *SandboxService) Ping(ctx context.Context, sandboxID, teamID string) err
Valid: true,
},
}); err != nil {
- slog.Warn("ping: failed to update last_active_at", "sandbox_id", sandboxID, "error", err)
+ slog.Warn("ping: failed to update last_active_at", "sandbox_id", sandboxIDStr, "error", err)
}
return nil
}
diff --git a/internal/service/stats.go b/internal/service/stats.go
index 1a075aa0..88cace72 100644
--- a/internal/service/stats.go
+++ b/internal/service/stats.go
@@ -7,6 +7,7 @@ import (
"time"
"github.com/jackc/pgx/v5"
+ "github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxpool"
"git.omukk.dev/wrenn/sandbox/internal/db"
@@ -72,7 +73,7 @@ type StatsService struct {
// GetStats returns current stats, 30-day peaks, and a time-series for the
// given team and time range. If no snapshots exist yet, zeros are returned.
-func (s *StatsService) GetStats(ctx context.Context, teamID string, r TimeRange) (CurrentStats, PeakStats, []StatPoint, error) {
+func (s *StatsService) GetStats(ctx context.Context, teamID pgtype.UUID, r TimeRange) (CurrentStats, PeakStats, []StatPoint, error) {
cfg, ok := rangeConfigs[r]
if !ok {
return CurrentStats{}, PeakStats{}, nil, fmt.Errorf("unknown range: %s", r)
@@ -132,7 +133,7 @@ GROUP BY bucket
ORDER BY bucket ASC
`
-func (s *StatsService) queryTimeSeries(ctx context.Context, teamID string, cfg rangeConfig) ([]StatPoint, error) {
+func (s *StatsService) queryTimeSeries(ctx context.Context, teamID pgtype.UUID, cfg rangeConfig) ([]StatPoint, error) {
rows, err := s.Pool.Query(ctx, timeSeriesSQL, cfg.bucketSec, teamID, cfg.intervalLiteral)
if err != nil {
return nil, err
diff --git a/internal/service/team.go b/internal/service/team.go
index d4c911c2..a7acbac1 100644
--- a/internal/service/team.go
+++ b/internal/service/team.go
@@ -9,6 +9,7 @@ import (
"connectrpc.com/connect"
"github.com/jackc/pgx/v5"
+ "github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxpool"
"git.omukk.dev/wrenn/sandbox/internal/db"
@@ -43,7 +44,7 @@ type MemberInfo struct {
// callerRole fetches the calling user's role in the given team from DB.
// Returns an error wrapping "forbidden" if the caller is not a member.
-func (s *TeamService) callerRole(ctx context.Context, teamID, callerUserID string) (string, error) {
+func (s *TeamService) callerRole(ctx context.Context, teamID, callerUserID pgtype.UUID) (string, error) {
m, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
UserID: callerUserID,
TeamID: teamID,
@@ -66,7 +67,7 @@ func requireAdmin(role string) error {
}
// GetTeam returns the team by ID. Returns an error if the team is deleted or not found.
-func (s *TeamService) GetTeam(ctx context.Context, teamID string) (db.Team, error) {
+func (s *TeamService) GetTeam(ctx context.Context, teamID pgtype.UUID) (db.Team, error) {
team, err := s.DB.GetTeam(ctx, teamID)
if err != nil {
if err == pgx.ErrNoRows {
@@ -81,7 +82,7 @@ func (s *TeamService) GetTeam(ctx context.Context, teamID string) (db.Team, erro
}
// ListTeamsForUser returns all active teams the user belongs to, with their role in each.
-func (s *TeamService) ListTeamsForUser(ctx context.Context, userID string) ([]TeamWithRole, error) {
+func (s *TeamService) ListTeamsForUser(ctx context.Context, userID pgtype.UUID) ([]TeamWithRole, error) {
rows, err := s.DB.GetTeamsForUser(ctx, userID)
if err != nil {
return nil, fmt.Errorf("list teams: %w", err)
@@ -97,7 +98,7 @@ func (s *TeamService) ListTeamsForUser(ctx context.Context, userID string) ([]Te
}
// CreateTeam creates a new team owned by the given user.
-func (s *TeamService) CreateTeam(ctx context.Context, ownerUserID, name string) (TeamWithRole, error) {
+func (s *TeamService) CreateTeam(ctx context.Context, ownerUserID pgtype.UUID, name string) (TeamWithRole, error) {
if !teamNameRE.MatchString(name) {
return TeamWithRole{}, fmt.Errorf("invalid team name: must be 1-128 characters, A-Z a-z 0-9 space _")
}
@@ -137,7 +138,7 @@ func (s *TeamService) CreateTeam(ctx context.Context, ownerUserID, name string)
}
// RenameTeam updates the team name. Caller must be admin or owner (verified from DB).
-func (s *TeamService) RenameTeam(ctx context.Context, teamID, callerUserID, newName string) error {
+func (s *TeamService) RenameTeam(ctx context.Context, teamID, callerUserID pgtype.UUID, newName string) error {
if !teamNameRE.MatchString(newName) {
return fmt.Errorf("invalid team name: must be 1-128 characters, A-Z a-z 0-9 space _")
}
@@ -159,7 +160,7 @@ func (s *TeamService) RenameTeam(ctx context.Context, teamID, callerUserID, newN
// DeleteTeam soft-deletes the team and destroys all running/paused/starting sandboxes.
// Caller must be owner (verified from DB). All DB records (sandboxes, keys, templates)
// are preserved; only the team's deleted_at is set and active VMs are stopped.
-func (s *TeamService) DeleteTeam(ctx context.Context, teamID, callerUserID string) error {
+func (s *TeamService) DeleteTeam(ctx context.Context, teamID, callerUserID pgtype.UUID) error {
role, err := s.callerRole(ctx, teamID, callerUserID)
if err != nil {
return err
@@ -174,16 +175,16 @@ func (s *TeamService) DeleteTeam(ctx context.Context, teamID, callerUserID strin
return fmt.Errorf("list active sandboxes: %w", err)
}
- var stopIDs []string
+ var stopIDs []pgtype.UUID
for _, sb := range sandboxes {
host, hostErr := s.DB.GetHost(ctx, sb.HostID)
if hostErr == nil {
agent, agentErr := s.HostPool.GetForHost(host)
if agentErr == nil {
if _, err := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
- SandboxId: sb.ID,
+ SandboxId: id.FormatSandboxID(sb.ID),
})); err != nil && connect.CodeOf(err) != connect.CodeNotFound {
- slog.Warn("team delete: failed to destroy sandbox", "sandbox_id", sb.ID, "error", err)
+ slog.Warn("team delete: failed to destroy sandbox", "sandbox_id", id.FormatSandboxID(sb.ID), "error", err)
}
}
}
@@ -201,14 +202,63 @@ func (s *TeamService) DeleteTeam(ctx context.Context, teamID, callerUserID strin
}
}
+ // Clean up team-owned templates from all hosts in the background.
+ go s.cleanupTeamTemplates(context.Background(), teamID)
+
if err := s.DB.SoftDeleteTeam(ctx, teamID); err != nil {
return fmt.Errorf("soft delete team: %w", err)
}
return nil
}
+// cleanupTeamTemplates asks every online host to delete the team's template
+// files, then removes the template DB records. It is best-effort: failures are
+// logged, not returned. DeleteTeam launches it in a background goroutine.
+func (s *TeamService) cleanupTeamTemplates(ctx context.Context, teamID pgtype.UUID) {
+	templates, err := s.DB.ListTemplatesByTeamOnly(ctx, teamID)
+	if err != nil {
+		slog.Warn("team delete: failed to list templates for cleanup", "team_id", id.FormatTeamID(teamID), "error", err)
+		return
+	}
+	if len(templates) == 0 {
+		return
+	}
+
+	hosts, err := s.DB.ListActiveHosts(ctx)
+	if err != nil {
+		slog.Warn("team delete: failed to list hosts for template cleanup", "error", err)
+		return
+	}
+
+	// Host-major order: resolve each host's agent client once, not per template.
+	for _, host := range hosts {
+		if host.Status != "online" {
+			continue
+		}
+		agent, err := s.HostPool.GetForHost(host)
+		if err != nil {
+			continue
+		}
+		for _, tmpl := range templates {
+			_, err := agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{
+				TeamId:     id.UUIDString(tmpl.TeamID),
+				TemplateId: id.UUIDString(tmpl.ID),
+			}))
+			// A CodeNotFound reply is tolerated and not logged.
+			if err != nil && connect.CodeOf(err) != connect.CodeNotFound {
+				slog.Warn("team delete: failed to delete template on host", "host_id", id.FormatHostID(host.ID), "template", tmpl.Name, "error", err)
+			}
+		}
+	}
+
+	// Remove DB records, even if some host-side deletions above only logged a warning.
+	if err := s.DB.DeleteTemplatesByTeam(ctx, teamID); err != nil {
+		slog.Warn("team delete: failed to delete template records", "team_id", id.FormatTeamID(teamID), "error", err)
+	}
+}
+
// GetMembers returns all members of the team with their emails and roles.
-func (s *TeamService) GetMembers(ctx context.Context, teamID string) ([]MemberInfo, error) {
+func (s *TeamService) GetMembers(ctx context.Context, teamID pgtype.UUID) ([]MemberInfo, error) {
rows, err := s.DB.GetTeamMembers(ctx, teamID)
if err != nil {
return nil, fmt.Errorf("get members: %w", err)
@@ -220,7 +270,7 @@ func (s *TeamService) GetMembers(ctx context.Context, teamID string) ([]MemberIn
joinedAt = r.JoinedAt.Time
}
members[i] = MemberInfo{
- UserID: r.ID,
+ UserID: id.FormatUserID(r.ID),
Name: r.Name,
Email: r.Email,
Role: r.Role,
@@ -232,7 +282,7 @@ func (s *TeamService) GetMembers(ctx context.Context, teamID string) ([]MemberIn
// AddMember adds an existing user (looked up by email) to the team as a member.
// Caller must be admin or owner (verified from DB).
-func (s *TeamService) AddMember(ctx context.Context, teamID, callerUserID, email string) (MemberInfo, error) {
+func (s *TeamService) AddMember(ctx context.Context, teamID, callerUserID pgtype.UUID, email string) (MemberInfo, error) {
role, err := s.callerRole(ctx, teamID, callerUserID)
if err != nil {
return MemberInfo{}, err
@@ -269,12 +319,12 @@ func (s *TeamService) AddMember(ctx context.Context, teamID, callerUserID, email
return MemberInfo{}, fmt.Errorf("insert member: %w", err)
}
- return MemberInfo{UserID: target.ID, Name: target.Name, Email: target.Email, Role: "member"}, nil
+ return MemberInfo{UserID: id.FormatUserID(target.ID), Name: target.Name, Email: target.Email, Role: "member"}, nil
}
// RemoveMember removes a user from the team.
// Caller must be admin or owner (verified from DB). Owner cannot be removed.
-func (s *TeamService) RemoveMember(ctx context.Context, teamID, callerUserID, targetUserID string) error {
+func (s *TeamService) RemoveMember(ctx context.Context, teamID, callerUserID, targetUserID pgtype.UUID) error {
callerRole, err := s.callerRole(ctx, teamID, callerUserID)
if err != nil {
return err
@@ -310,7 +360,7 @@ func (s *TeamService) RemoveMember(ctx context.Context, teamID, callerUserID, ta
// UpdateMemberRole changes a member's role to admin or member.
// Caller must be admin or owner (verified from DB). Owner's role cannot be changed.
// Valid target roles: "admin", "member".
-func (s *TeamService) UpdateMemberRole(ctx context.Context, teamID, callerUserID, targetUserID, newRole string) error {
+func (s *TeamService) UpdateMemberRole(ctx context.Context, teamID, callerUserID, targetUserID pgtype.UUID, newRole string) error {
if newRole != "admin" && newRole != "member" {
return fmt.Errorf("invalid: role must be admin or member")
}
@@ -350,7 +400,7 @@ func (s *TeamService) UpdateMemberRole(ctx context.Context, teamID, callerUserID
// LeaveTeam removes the calling user from the team.
// The owner cannot leave; they must delete the team instead.
-func (s *TeamService) LeaveTeam(ctx context.Context, teamID, callerUserID string) error {
+func (s *TeamService) LeaveTeam(ctx context.Context, teamID, callerUserID pgtype.UUID) error {
role, err := s.callerRole(ctx, teamID, callerUserID)
if err != nil {
return err
@@ -371,7 +421,7 @@ func (s *TeamService) LeaveTeam(ctx context.Context, teamID, callerUserID string
// SetBYOC enables the BYOC feature flag for a team. Once enabled, BYOC cannot
// be disabled — it is a one-way transition.
// Admin-only — the caller must verify admin status before invoking this.
-func (s *TeamService) SetBYOC(ctx context.Context, teamID string, enabled bool) error {
+func (s *TeamService) SetBYOC(ctx context.Context, teamID pgtype.UUID, enabled bool) error {
team, err := s.DB.GetTeam(ctx, teamID)
if err != nil {
return fmt.Errorf("team not found: %w", err)
diff --git a/internal/service/template.go b/internal/service/template.go
index d669e455..22bc4d60 100644
--- a/internal/service/template.go
+++ b/internal/service/template.go
@@ -3,6 +3,8 @@ package service
import (
"context"
+ "github.com/jackc/pgx/v5/pgtype"
+
"git.omukk.dev/wrenn/sandbox/internal/db"
)
@@ -14,7 +16,7 @@ type TemplateService struct {
// List returns all templates belonging to the given team. If typeFilter is
// non-empty, only templates of that type ("base" or "snapshot") are returned.
-func (s *TemplateService) List(ctx context.Context, teamID, typeFilter string) ([]db.Template, error) {
+func (s *TemplateService) List(ctx context.Context, teamID pgtype.UUID, typeFilter string) ([]db.Template, error) {
if typeFilter != "" {
return s.DB.ListTemplatesByTeamAndType(ctx, db.ListTemplatesByTeamAndTypeParams{
TeamID: teamID,
diff --git a/internal/snapshot/memfile.go b/internal/snapshot/memfile.go
index aabe8851..f7b14f9a 100644
--- a/internal/snapshot/memfile.go
+++ b/internal/snapshot/memfile.go
@@ -4,6 +4,7 @@
package snapshot
import (
+ "context"
"fmt"
"io"
"os"
@@ -172,6 +173,99 @@ func ProcessMemfileWithParent(memfilePath, diffPath, headerPath string, parentHe
return header, nil
}
+// MergeDiffs consolidates multiple generation diff files into a single diff
+// file and resets the generation counter to 0. This is a pure file-level
+// operation — no Firecracker involvement.
+//
+// Each block that header maps to a non-nil build ID is read from that build's
+// diff file and appended to mergedDiffPath; a fresh header pointing only at
+// the merged file is written to headerPath. diffFiles maps build ID (string)
+// → diff file path. On failure the partial merged diff file is removed.
+func MergeDiffs(header *Header, diffFiles map[string]string, mergedDiffPath, headerPath string) (_ *Header, retErr error) {
+	blockSize := int64(header.Metadata.BlockSize)
+	if blockSize <= 0 {
+		return nil, fmt.Errorf("merge diffs: invalid block size %d", blockSize)
+	}
+	mergedBuildID := uuid.New()
+
+	// Open all source diff files; the deferred close also covers early returns.
+	sources := make(map[string]*os.File, len(diffFiles))
+	defer func() {
+		for _, f := range sources {
+			f.Close()
+		}
+	}()
+	for id, path := range diffFiles {
+		f, err := os.Open(path)
+		if err != nil {
+			return nil, fmt.Errorf("open diff file for build %s: %w", id, err)
+		}
+		sources[id] = f
+	}
+
+	dst, err := os.Create(mergedDiffPath)
+	if err != nil {
+		return nil, fmt.Errorf("create merged diff file: %w", err)
+	}
+	defer func() {
+		dst.Close() // harmless second close once the explicit Close below has run
+		if retErr != nil {
+			_ = os.Remove(mergedDiffPath) // don't leave a partial diff behind
+		}
+	}()
+
+	totalBlocks := TotalBlocks(int64(header.Metadata.Size), blockSize)
+	dirty := make([]bool, totalBlocks)
+	empty := make([]bool, totalBlocks)
+	buf := make([]byte, blockSize)
+
+	for i := int64(0); i < totalBlocks; i++ {
+		mappedOffset, _, buildID, err := header.GetShiftedMapping(context.Background(), i*blockSize)
+		if err != nil {
+			return nil, fmt.Errorf("lookup block %d: %w", i, err)
+		}
+		if *buildID == uuid.Nil {
+			empty[i] = true
+			continue
+		}
+
+		src, ok := sources[buildID.String()]
+		if !ok {
+			return nil, fmt.Errorf("no diff file for build %s (block %d)", buildID, i)
+		}
+		if _, err := src.ReadAt(buf, mappedOffset); err != nil {
+			return nil, fmt.Errorf("read block %d from build %s: %w", i, buildID, err)
+		}
+		if _, err := dst.Write(buf); err != nil {
+			return nil, fmt.Errorf("write merged block %d: %w", i, err)
+		}
+		dirty[i] = true
+	}
+
+	// Close explicitly: a close error can mean written data never reached disk.
+	if err := dst.Close(); err != nil {
+		return nil, fmt.Errorf("close merged diff file: %w", err)
+	}
+
+	// Build fresh header with generation 0.
+	dirtyMappings := CreateMapping(mergedBuildID, dirty, blockSize)
+	emptyMappings := CreateMapping(uuid.Nil, empty, blockSize)
+	normalized := NormalizeMappings(MergeMappings(dirtyMappings, emptyMappings))
+	metadata := NewMetadata(mergedBuildID, uint64(blockSize), header.Metadata.Size)
+	newHeader, err := NewHeader(metadata, normalized)
+	if err != nil {
+		return nil, fmt.Errorf("create merged header: %w", err)
+	}
+	headerData, err := Serialize(metadata, normalized)
+	if err != nil {
+		return nil, fmt.Errorf("serialize merged header: %w", err)
+	}
+	if err := os.WriteFile(headerPath, headerData, 0644); err != nil {
+		return nil, fmt.Errorf("write merged header: %w", err)
+	}
+	return newHeader, nil
+}
+
// isZeroBlock checks if a block is entirely zero bytes.
func isZeroBlock(block []byte) bool {
// Fast path: compare 8 bytes at a time.
diff --git a/internal/validate/name_test.go b/internal/validate/name_test.go
index 4b7769e2..f17e210a 100644
--- a/internal/validate/name_test.go
+++ b/internal/validate/name_test.go
@@ -11,7 +11,7 @@ func TestSafeName(t *testing.T) {
{"simple", "minimal", false},
{"with-dash", "template-abc123", false},
{"with-dot", "my-snapshot.v2", false},
- {"sandbox-id", "sb-12345678", false},
+ {"sandbox-id", "cl-12345678", false},
{"single-char", "a", false},
{"numbers", "123", false},
{"max-length", "abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz01", false},
diff --git a/internal/vm/config.go b/internal/vm/config.go
index 35bc2939..0c1f2582 100644
--- a/internal/vm/config.go
+++ b/internal/vm/config.go
@@ -4,9 +4,13 @@ import "fmt"
// VMConfig holds the configuration for creating a Firecracker microVM.
type VMConfig struct {
- // SandboxID is the unique identifier for this sandbox (e.g., "sb-a1b2c3d4").
+ // SandboxID is the unique identifier for this sandbox (e.g., "cl-a1b2c3d4").
SandboxID string
+ // TemplateID is the template UUID string used to populate MMDS metadata
+ // so that envd can read WRENN_TEMPLATE_ID from inside the guest.
+ TemplateID string
+
// KernelPath is the path to the uncompressed Linux kernel (vmlinux).
KernelPath string
@@ -91,7 +95,7 @@ func (c *VMConfig) kernelArgs() string {
)
return fmt.Sprintf(
- "console=ttyS0 reboot=k panic=1 pci=off quiet loglevel=1 init=%s %s",
+ "console=ttyS0 reboot=k panic=1 pci=off quiet loglevel=1 clocksource=kvm-clock init=%s %s",
c.InitPath, ipArg,
)
}
diff --git a/internal/vm/fc.go b/internal/vm/fc.go
index b5af5dbb..3d0f246d 100644
--- a/internal/vm/fc.go
+++ b/internal/vm/fc.go
@@ -101,6 +101,31 @@ func (c *fcClient) setMachineConfig(ctx context.Context, vcpus, memMB int) error
})
}
+// setMMDSConfig PUTs /mmds/config, selecting MMDS version "V2" (token-based
+// access) and exposing the store on interface ifaceID. Call before startVM.
+func (c *fcClient) setMMDSConfig(ctx context.Context, ifaceID string) error {
+ return c.do(ctx, http.MethodPut, "/mmds/config", map[string]any{
+ "version": "V2",
+ "network_interfaces": []string{ifaceID},
+ })
+}
+
+// mmdsMetadata is the JSON payload written to the Firecracker MMDS store; envd
+// reads it via PollForMMDSOpts to set WRENN_SANDBOX_ID and WRENN_TEMPLATE_ID.
+type mmdsMetadata struct {
+ SandboxID string `json:"instanceID"` // sandbox ID; JSON key is "instanceID"
+ TemplateID string `json:"envID"` // template ID; JSON key is "envID"
+}
+
+// setMMDS PUTs an mmdsMetadata built from sandboxID and templateID to /mmds,
+// the Firecracker MMDS data store. Can be called after the VM has started.
+func (c *fcClient) setMMDS(ctx context.Context, sandboxID, templateID string) error {
+ return c.do(ctx, http.MethodPut, "/mmds", mmdsMetadata{
+ SandboxID: sandboxID,
+ TemplateID: templateID,
+ })
+}
+
// startVM issues the InstanceStart action.
func (c *fcClient) startVM(ctx context.Context) error {
return c.do(ctx, http.MethodPut, "/actions", map[string]string{
diff --git a/internal/vm/manager.go b/internal/vm/manager.go
index c7e34797..9e9466f5 100644
--- a/internal/vm/manager.go
+++ b/internal/vm/manager.go
@@ -71,6 +71,13 @@ func (m *Manager) Create(ctx context.Context, cfg VMConfig) (*VM, error) {
return nil, fmt.Errorf("start VM: %w", err)
}
+ // Step 5: Push sandbox metadata into MMDS so envd can read
+ // WRENN_SANDBOX_ID and WRENN_TEMPLATE_ID from inside the guest.
+ if err := client.setMMDS(ctx, cfg.SandboxID, cfg.TemplateID); err != nil {
+ _ = proc.stop()
+ return nil, fmt.Errorf("set MMDS metadata: %w", err)
+ }
+
vm := &VM{
Config: cfg,
process: proc,
@@ -108,6 +115,12 @@ func configureVM(ctx context.Context, client *fcClient, cfg *VMConfig) error {
return fmt.Errorf("set machine config: %w", err)
}
+ // MMDS config — enable V2 token access on eth0 so that envd can read
+ // WRENN_SANDBOX_ID and WRENN_TEMPLATE_ID from inside the guest.
+ if err := client.setMMDSConfig(ctx, "eth0"); err != nil {
+ return fmt.Errorf("set MMDS config: %w", err)
+ }
+
return nil
}
@@ -238,6 +251,12 @@ func (m *Manager) CreateFromSnapshot(ctx context.Context, cfg VMConfig, snapPath
return nil, fmt.Errorf("resume VM: %w", err)
}
+ // Step 5: Push sandbox metadata into MMDS.
+ if err := client.setMMDS(ctx, cfg.SandboxID, cfg.TemplateID); err != nil {
+ _ = proc.stop()
+ return nil, fmt.Errorf("set MMDS metadata: %w", err)
+ }
+
vm := &VM{
Config: cfg,
process: proc,
diff --git a/proto/hostagent/gen/hostagent.pb.go b/proto/hostagent/gen/hostagent.pb.go
index f496b2cb..516e4d25 100644
--- a/proto/hostagent/gen/hostagent.pb.go
+++ b/proto/hostagent/gen/hostagent.pb.go
@@ -25,7 +25,7 @@ type CreateSandboxRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
// Sandbox ID assigned by the control plane. If empty, the host agent generates one.
SandboxId string `protobuf:"bytes,5,opt,name=sandbox_id,json=sandboxId,proto3" json:"sandbox_id,omitempty"`
- // Template name (e.g., "minimal", "python311"). Determines base rootfs.
+ // Deprecated: use team_id + template_id instead.
Template string `protobuf:"bytes,1,opt,name=template,proto3" json:"template,omitempty"`
// Number of virtual CPUs (default: 1).
Vcpus int32 `protobuf:"varint,2,opt,name=vcpus,proto3" json:"vcpus,omitempty"`
@@ -33,7 +33,14 @@ type CreateSandboxRequest struct {
MemoryMb int32 `protobuf:"varint,3,opt,name=memory_mb,json=memoryMb,proto3" json:"memory_mb,omitempty"`
// TTL in seconds. Sandbox is auto-paused after this duration of
// inactivity. 0 means no auto-pause.
- TimeoutSec int32 `protobuf:"varint,4,opt,name=timeout_sec,json=timeoutSec,proto3" json:"timeout_sec,omitempty"`
+ TimeoutSec int32 `protobuf:"varint,4,opt,name=timeout_sec,json=timeoutSec,proto3" json:"timeout_sec,omitempty"`
+ // Disk size in MB for the rootfs. Base images are expanded to this size
+ // at host agent startup. Default: 5120 (5 GB).
+ DiskSizeMb int32 `protobuf:"varint,6,opt,name=disk_size_mb,json=diskSizeMb,proto3" json:"disk_size_mb,omitempty"`
+ // Team UUID that owns the template (hex string). All-zeros = platform.
+ TeamId string `protobuf:"bytes,7,opt,name=team_id,json=teamId,proto3" json:"team_id,omitempty"`
+ // Template UUID (hex string). Both zeros + team zeros = "minimal" sentinel.
+ TemplateId string `protobuf:"bytes,8,opt,name=template_id,json=templateId,proto3" json:"template_id,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -103,6 +110,27 @@ func (x *CreateSandboxRequest) GetTimeoutSec() int32 {
return 0
}
+func (x *CreateSandboxRequest) GetDiskSizeMb() int32 {
+ if x != nil {
+ return x.DiskSizeMb
+ }
+ return 0
+}
+
+func (x *CreateSandboxRequest) GetTeamId() string {
+ if x != nil {
+ return x.TeamId
+ }
+ return ""
+}
+
+func (x *CreateSandboxRequest) GetTemplateId() string {
+ if x != nil {
+ return x.TemplateId
+ }
+ return ""
+}
+
type CreateSandboxResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
SandboxId string `protobuf:"bytes,1,opt,name=sandbox_id,json=sandboxId,proto3" json:"sandbox_id,omitempty"`
@@ -438,9 +466,14 @@ func (x *ResumeSandboxResponse) GetHostIp() string {
}
type CreateSnapshotRequest struct {
- state protoimpl.MessageState `protogen:"open.v1"`
- SandboxId string `protobuf:"bytes,1,opt,name=sandbox_id,json=sandboxId,proto3" json:"sandbox_id,omitempty"`
- Name string `protobuf:"bytes,2,opt,name=name,proto3" json:"name,omitempty"`
+ state protoimpl.MessageState `protogen:"open.v1"`
+ SandboxId string `protobuf:"bytes,1,opt,name=sandbox_id,json=sandboxId,proto3" json:"sandbox_id,omitempty"`
+ // Deprecated: use team_id + template_id instead.
+ Name string `protobuf:"bytes,2,opt,name=name,proto3" json:"name,omitempty"`
+ // Team UUID that will own the new template.
+ TeamId string `protobuf:"bytes,3,opt,name=team_id,json=teamId,proto3" json:"team_id,omitempty"`
+ // Template UUID for the new snapshot template.
+ TemplateId string `protobuf:"bytes,4,opt,name=template_id,json=templateId,proto3" json:"template_id,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -489,6 +522,20 @@ func (x *CreateSnapshotRequest) GetName() string {
return ""
}
+func (x *CreateSnapshotRequest) GetTeamId() string {
+ if x != nil {
+ return x.TeamId
+ }
+ return ""
+}
+
+func (x *CreateSnapshotRequest) GetTemplateId() string {
+ if x != nil {
+ return x.TemplateId
+ }
+ return ""
+}
+
type CreateSnapshotResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
@@ -542,8 +589,13 @@ func (x *CreateSnapshotResponse) GetSizeBytes() int64 {
}
type DeleteSnapshotRequest struct {
- state protoimpl.MessageState `protogen:"open.v1"`
- Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
+ state protoimpl.MessageState `protogen:"open.v1"`
+ // Deprecated: use team_id + template_id instead.
+ Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
+ // Team UUID that owns the template.
+ TeamId string `protobuf:"bytes,2,opt,name=team_id,json=teamId,proto3" json:"team_id,omitempty"`
+ // Template UUID to delete.
+ TemplateId string `protobuf:"bytes,3,opt,name=template_id,json=templateId,proto3" json:"template_id,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -585,6 +637,20 @@ func (x *DeleteSnapshotRequest) GetName() string {
return ""
}
+func (x *DeleteSnapshotRequest) GetTeamId() string {
+ if x != nil {
+ return x.TeamId
+ }
+ return ""
+}
+
+func (x *DeleteSnapshotRequest) GetTemplateId() string {
+ if x != nil {
+ return x.TemplateId
+ }
+ return ""
+}
+
type DeleteSnapshotResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
@@ -841,16 +907,19 @@ func (x *ListSandboxesResponse) GetAutoPausedSandboxIds() []string {
}
type SandboxInfo struct {
- state protoimpl.MessageState `protogen:"open.v1"`
- SandboxId string `protobuf:"bytes,1,opt,name=sandbox_id,json=sandboxId,proto3" json:"sandbox_id,omitempty"`
- Status string `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"`
- Template string `protobuf:"bytes,3,opt,name=template,proto3" json:"template,omitempty"`
- Vcpus int32 `protobuf:"varint,4,opt,name=vcpus,proto3" json:"vcpus,omitempty"`
- MemoryMb int32 `protobuf:"varint,5,opt,name=memory_mb,json=memoryMb,proto3" json:"memory_mb,omitempty"`
- HostIp string `protobuf:"bytes,6,opt,name=host_ip,json=hostIp,proto3" json:"host_ip,omitempty"`
- CreatedAtUnix int64 `protobuf:"varint,7,opt,name=created_at_unix,json=createdAtUnix,proto3" json:"created_at_unix,omitempty"`
- LastActiveAtUnix int64 `protobuf:"varint,8,opt,name=last_active_at_unix,json=lastActiveAtUnix,proto3" json:"last_active_at_unix,omitempty"`
- TimeoutSec int32 `protobuf:"varint,9,opt,name=timeout_sec,json=timeoutSec,proto3" json:"timeout_sec,omitempty"`
+ state protoimpl.MessageState `protogen:"open.v1"`
+ SandboxId string `protobuf:"bytes,1,opt,name=sandbox_id,json=sandboxId,proto3" json:"sandbox_id,omitempty"`
+ Status string `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"`
+ // Deprecated: use team_id + template_id instead.
+ Template string `protobuf:"bytes,3,opt,name=template,proto3" json:"template,omitempty"`
+ Vcpus int32 `protobuf:"varint,4,opt,name=vcpus,proto3" json:"vcpus,omitempty"`
+ MemoryMb int32 `protobuf:"varint,5,opt,name=memory_mb,json=memoryMb,proto3" json:"memory_mb,omitempty"`
+ HostIp string `protobuf:"bytes,6,opt,name=host_ip,json=hostIp,proto3" json:"host_ip,omitempty"`
+ CreatedAtUnix int64 `protobuf:"varint,7,opt,name=created_at_unix,json=createdAtUnix,proto3" json:"created_at_unix,omitempty"`
+ LastActiveAtUnix int64 `protobuf:"varint,8,opt,name=last_active_at_unix,json=lastActiveAtUnix,proto3" json:"last_active_at_unix,omitempty"`
+ TimeoutSec int32 `protobuf:"varint,9,opt,name=timeout_sec,json=timeoutSec,proto3" json:"timeout_sec,omitempty"`
+ TeamId string `protobuf:"bytes,10,opt,name=team_id,json=teamId,proto3" json:"team_id,omitempty"`
+ TemplateId string `protobuf:"bytes,11,opt,name=template_id,json=templateId,proto3" json:"template_id,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -948,6 +1017,20 @@ func (x *SandboxInfo) GetTimeoutSec() int32 {
return 0
}
+func (x *SandboxInfo) GetTeamId() string {
+ if x != nil {
+ return x.TeamId
+ }
+ return ""
+}
+
+func (x *SandboxInfo) GetTemplateId() string {
+ if x != nil {
+ return x.TemplateId
+ }
+ return ""
+}
+
type WriteFileRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
SandboxId string `protobuf:"bytes,1,opt,name=sandbox_id,json=sandboxId,proto3" json:"sandbox_id,omitempty"`
@@ -2171,11 +2254,126 @@ func (x *FlushSandboxMetricsResponse) GetPoints_24H() []*MetricPoint {
return nil
}
+type FlattenRootfsRequest struct {
+ state protoimpl.MessageState `protogen:"open.v1"`
+ SandboxId string `protobuf:"bytes,1,opt,name=sandbox_id,json=sandboxId,proto3" json:"sandbox_id,omitempty"`
+ // Deprecated: use team_id + template_id instead.
+ Name string `protobuf:"bytes,2,opt,name=name,proto3" json:"name,omitempty"`
+ // Team UUID that will own the resulting template.
+ TeamId string `protobuf:"bytes,3,opt,name=team_id,json=teamId,proto3" json:"team_id,omitempty"`
+ // Template UUID for the output.
+ TemplateId string `protobuf:"bytes,4,opt,name=template_id,json=templateId,proto3" json:"template_id,omitempty"`
+ unknownFields protoimpl.UnknownFields
+ sizeCache protoimpl.SizeCache
+}
+
+func (x *FlattenRootfsRequest) Reset() {
+ *x = FlattenRootfsRequest{}
+ mi := &file_hostagent_proto_msgTypes[40]
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ ms.StoreMessageInfo(mi)
+}
+
+func (x *FlattenRootfsRequest) String() string {
+ return protoimpl.X.MessageStringOf(x)
+}
+
+func (*FlattenRootfsRequest) ProtoMessage() {}
+
+func (x *FlattenRootfsRequest) ProtoReflect() protoreflect.Message {
+ mi := &file_hostagent_proto_msgTypes[40]
+ if x != nil {
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ if ms.LoadMessageInfo() == nil {
+ ms.StoreMessageInfo(mi)
+ }
+ return ms
+ }
+ return mi.MessageOf(x)
+}
+
+// Deprecated: Use FlattenRootfsRequest.ProtoReflect.Descriptor instead.
+func (*FlattenRootfsRequest) Descriptor() ([]byte, []int) {
+ return file_hostagent_proto_rawDescGZIP(), []int{40}
+}
+
+func (x *FlattenRootfsRequest) GetSandboxId() string {
+ if x != nil {
+ return x.SandboxId
+ }
+ return ""
+}
+
+func (x *FlattenRootfsRequest) GetName() string {
+ if x != nil {
+ return x.Name
+ }
+ return ""
+}
+
+func (x *FlattenRootfsRequest) GetTeamId() string {
+ if x != nil {
+ return x.TeamId
+ }
+ return ""
+}
+
+func (x *FlattenRootfsRequest) GetTemplateId() string {
+ if x != nil {
+ return x.TemplateId
+ }
+ return ""
+}
+
+type FlattenRootfsResponse struct {
+ state protoimpl.MessageState `protogen:"open.v1"`
+ SizeBytes int64 `protobuf:"varint,1,opt,name=size_bytes,json=sizeBytes,proto3" json:"size_bytes,omitempty"`
+ unknownFields protoimpl.UnknownFields
+ sizeCache protoimpl.SizeCache
+}
+
+func (x *FlattenRootfsResponse) Reset() {
+ *x = FlattenRootfsResponse{}
+ mi := &file_hostagent_proto_msgTypes[41]
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ ms.StoreMessageInfo(mi)
+}
+
+func (x *FlattenRootfsResponse) String() string {
+ return protoimpl.X.MessageStringOf(x)
+}
+
+func (*FlattenRootfsResponse) ProtoMessage() {}
+
+func (x *FlattenRootfsResponse) ProtoReflect() protoreflect.Message {
+ mi := &file_hostagent_proto_msgTypes[41]
+ if x != nil {
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ if ms.LoadMessageInfo() == nil {
+ ms.StoreMessageInfo(mi)
+ }
+ return ms
+ }
+ return mi.MessageOf(x)
+}
+
+// Deprecated: Use FlattenRootfsResponse.ProtoReflect.Descriptor instead.
+func (*FlattenRootfsResponse) Descriptor() ([]byte, []int) {
+ return file_hostagent_proto_rawDescGZIP(), []int{41}
+}
+
+func (x *FlattenRootfsResponse) GetSizeBytes() int64 {
+ if x != nil {
+ return x.SizeBytes
+ }
+ return 0
+}
+
var File_hostagent_proto protoreflect.FileDescriptor
const file_hostagent_proto_rawDesc = "" +
"\n" +
- "\x0fhostagent.proto\x12\fhostagent.v1\"\xa5\x01\n" +
+ "\x0fhostagent.proto\x12\fhostagent.v1\"\x81\x02\n" +
"\x14CreateSandboxRequest\x12\x1d\n" +
"\n" +
"sandbox_id\x18\x05 \x01(\tR\tsandboxId\x12\x1a\n" +
@@ -2183,7 +2381,12 @@ const file_hostagent_proto_rawDesc = "" +
"\x05vcpus\x18\x02 \x01(\x05R\x05vcpus\x12\x1b\n" +
"\tmemory_mb\x18\x03 \x01(\x05R\bmemoryMb\x12\x1f\n" +
"\vtimeout_sec\x18\x04 \x01(\x05R\n" +
- "timeoutSec\"g\n" +
+ "timeoutSec\x12 \n" +
+ "\fdisk_size_mb\x18\x06 \x01(\x05R\n" +
+ "diskSizeMb\x12\x17\n" +
+ "\ateam_id\x18\a \x01(\tR\x06teamId\x12\x1f\n" +
+ "\vtemplate_id\x18\b \x01(\tR\n" +
+ "templateId\"g\n" +
"\x15CreateSandboxResponse\x12\x1d\n" +
"\n" +
"sandbox_id\x18\x01 \x01(\tR\tsandboxId\x12\x16\n" +
@@ -2206,17 +2409,23 @@ const file_hostagent_proto_rawDesc = "" +
"\n" +
"sandbox_id\x18\x01 \x01(\tR\tsandboxId\x12\x16\n" +
"\x06status\x18\x02 \x01(\tR\x06status\x12\x17\n" +
- "\ahost_ip\x18\x03 \x01(\tR\x06hostIp\"J\n" +
+ "\ahost_ip\x18\x03 \x01(\tR\x06hostIp\"\x84\x01\n" +
"\x15CreateSnapshotRequest\x12\x1d\n" +
"\n" +
"sandbox_id\x18\x01 \x01(\tR\tsandboxId\x12\x12\n" +
- "\x04name\x18\x02 \x01(\tR\x04name\"K\n" +
+ "\x04name\x18\x02 \x01(\tR\x04name\x12\x17\n" +
+ "\ateam_id\x18\x03 \x01(\tR\x06teamId\x12\x1f\n" +
+ "\vtemplate_id\x18\x04 \x01(\tR\n" +
+ "templateId\"K\n" +
"\x16CreateSnapshotResponse\x12\x12\n" +
"\x04name\x18\x01 \x01(\tR\x04name\x12\x1d\n" +
"\n" +
- "size_bytes\x18\x02 \x01(\x03R\tsizeBytes\"+\n" +
+ "size_bytes\x18\x02 \x01(\x03R\tsizeBytes\"e\n" +
"\x15DeleteSnapshotRequest\x12\x12\n" +
- "\x04name\x18\x01 \x01(\tR\x04name\"\x18\n" +
+ "\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n" +
+ "\ateam_id\x18\x02 \x01(\tR\x06teamId\x12\x1f\n" +
+ "\vtemplate_id\x18\x03 \x01(\tR\n" +
+ "templateId\"\x18\n" +
"\x16DeleteSnapshotResponse\"s\n" +
"\vExecRequest\x12\x1d\n" +
"\n" +
@@ -2232,7 +2441,7 @@ const file_hostagent_proto_rawDesc = "" +
"\x14ListSandboxesRequest\"\x87\x01\n" +
"\x15ListSandboxesResponse\x127\n" +
"\tsandboxes\x18\x01 \x03(\v2\x19.hostagent.v1.SandboxInfoR\tsandboxes\x125\n" +
- "\x17auto_paused_sandbox_ids\x18\x02 \x03(\tR\x14autoPausedSandboxIds\"\xa4\x02\n" +
+ "\x17auto_paused_sandbox_ids\x18\x02 \x03(\tR\x14autoPausedSandboxIds\"\xde\x02\n" +
"\vSandboxInfo\x12\x1d\n" +
"\n" +
"sandbox_id\x18\x01 \x01(\tR\tsandboxId\x12\x16\n" +
@@ -2244,7 +2453,11 @@ const file_hostagent_proto_rawDesc = "" +
"\x0fcreated_at_unix\x18\a \x01(\x03R\rcreatedAtUnix\x12-\n" +
"\x13last_active_at_unix\x18\b \x01(\x03R\x10lastActiveAtUnix\x12\x1f\n" +
"\vtimeout_sec\x18\t \x01(\x05R\n" +
- "timeoutSec\"_\n" +
+ "timeoutSec\x12\x17\n" +
+ "\ateam_id\x18\n" +
+ " \x01(\tR\x06teamId\x12\x1f\n" +
+ "\vtemplate_id\x18\v \x01(\tR\n" +
+ "templateId\"_\n" +
"\x10WriteFileRequest\x12\x1d\n" +
"\n" +
"sandbox_id\x18\x01 \x01(\tR\tsandboxId\x12\x12\n" +
@@ -2319,7 +2532,17 @@ const file_hostagent_proto_rawDesc = "" +
"points_10m\x18\x01 \x03(\v2\x19.hostagent.v1.MetricPointR\tpoints10m\x126\n" +
"\tpoints_2h\x18\x02 \x03(\v2\x19.hostagent.v1.MetricPointR\bpoints2h\x128\n" +
"\n" +
- "points_24h\x18\x03 \x03(\v2\x19.hostagent.v1.MetricPointR\tpoints24h2\xee\v\n" +
+ "points_24h\x18\x03 \x03(\v2\x19.hostagent.v1.MetricPointR\tpoints24h\"\x83\x01\n" +
+ "\x14FlattenRootfsRequest\x12\x1d\n" +
+ "\n" +
+ "sandbox_id\x18\x01 \x01(\tR\tsandboxId\x12\x12\n" +
+ "\x04name\x18\x02 \x01(\tR\x04name\x12\x17\n" +
+ "\ateam_id\x18\x03 \x01(\tR\x06teamId\x12\x1f\n" +
+ "\vtemplate_id\x18\x04 \x01(\tR\n" +
+ "templateId\"6\n" +
+ "\x15FlattenRootfsResponse\x12\x1d\n" +
+ "\n" +
+ "size_bytes\x18\x01 \x01(\x03R\tsizeBytes2\xc8\f\n" +
"\x10HostAgentService\x12X\n" +
"\rCreateSandbox\x12\".hostagent.v1.CreateSandboxRequest\x1a#.hostagent.v1.CreateSandboxResponse\x12[\n" +
"\x0eDestroySandbox\x12#.hostagent.v1.DestroySandboxRequest\x1a$.hostagent.v1.DestroySandboxResponse\x12U\n" +
@@ -2338,7 +2561,8 @@ const file_hostagent_proto_rawDesc = "" +
"\vPingSandbox\x12 .hostagent.v1.PingSandboxRequest\x1a!.hostagent.v1.PingSandboxResponse\x12L\n" +
"\tTerminate\x12\x1e.hostagent.v1.TerminateRequest\x1a\x1f.hostagent.v1.TerminateResponse\x12d\n" +
"\x11GetSandboxMetrics\x12&.hostagent.v1.GetSandboxMetricsRequest\x1a'.hostagent.v1.GetSandboxMetricsResponse\x12j\n" +
- "\x13FlushSandboxMetrics\x12(.hostagent.v1.FlushSandboxMetricsRequest\x1a).hostagent.v1.FlushSandboxMetricsResponseB\xb0\x01\n" +
+ "\x13FlushSandboxMetrics\x12(.hostagent.v1.FlushSandboxMetricsRequest\x1a).hostagent.v1.FlushSandboxMetricsResponse\x12X\n" +
+ "\rFlattenRootfs\x12\".hostagent.v1.FlattenRootfsRequest\x1a#.hostagent.v1.FlattenRootfsResponseB\xb0\x01\n" +
"\x10com.hostagent.v1B\x0eHostagentProtoP\x01Z;git.omukk.dev/wrenn/sandbox/proto/hostagent/gen;hostagentv1\xa2\x02\x03HXX\xaa\x02\fHostagent.V1\xca\x02\fHostagent\\V1\xe2\x02\x18Hostagent\\V1\\GPBMetadata\xea\x02\rHostagent::V1b\x06proto3"
var (
@@ -2353,7 +2577,7 @@ func file_hostagent_proto_rawDescGZIP() []byte {
return file_hostagent_proto_rawDescData
}
-var file_hostagent_proto_msgTypes = make([]protoimpl.MessageInfo, 40)
+var file_hostagent_proto_msgTypes = make([]protoimpl.MessageInfo, 42)
var file_hostagent_proto_goTypes = []any{
(*CreateSandboxRequest)(nil), // 0: hostagent.v1.CreateSandboxRequest
(*CreateSandboxResponse)(nil), // 1: hostagent.v1.CreateSandboxResponse
@@ -2395,6 +2619,8 @@ var file_hostagent_proto_goTypes = []any{
(*GetSandboxMetricsResponse)(nil), // 37: hostagent.v1.GetSandboxMetricsResponse
(*FlushSandboxMetricsRequest)(nil), // 38: hostagent.v1.FlushSandboxMetricsRequest
(*FlushSandboxMetricsResponse)(nil), // 39: hostagent.v1.FlushSandboxMetricsResponse
+ (*FlattenRootfsRequest)(nil), // 40: hostagent.v1.FlattenRootfsRequest
+ (*FlattenRootfsResponse)(nil), // 41: hostagent.v1.FlattenRootfsResponse
}
var file_hostagent_proto_depIdxs = []int32{
16, // 0: hostagent.v1.ListSandboxesResponse.sandboxes:type_name -> hostagent.v1.SandboxInfo
@@ -2423,25 +2649,27 @@ var file_hostagent_proto_depIdxs = []int32{
33, // 23: hostagent.v1.HostAgentService.Terminate:input_type -> hostagent.v1.TerminateRequest
36, // 24: hostagent.v1.HostAgentService.GetSandboxMetrics:input_type -> hostagent.v1.GetSandboxMetricsRequest
38, // 25: hostagent.v1.HostAgentService.FlushSandboxMetrics:input_type -> hostagent.v1.FlushSandboxMetricsRequest
- 1, // 26: hostagent.v1.HostAgentService.CreateSandbox:output_type -> hostagent.v1.CreateSandboxResponse
- 3, // 27: hostagent.v1.HostAgentService.DestroySandbox:output_type -> hostagent.v1.DestroySandboxResponse
- 5, // 28: hostagent.v1.HostAgentService.PauseSandbox:output_type -> hostagent.v1.PauseSandboxResponse
- 7, // 29: hostagent.v1.HostAgentService.ResumeSandbox:output_type -> hostagent.v1.ResumeSandboxResponse
- 13, // 30: hostagent.v1.HostAgentService.Exec:output_type -> hostagent.v1.ExecResponse
- 15, // 31: hostagent.v1.HostAgentService.ListSandboxes:output_type -> hostagent.v1.ListSandboxesResponse
- 18, // 32: hostagent.v1.HostAgentService.WriteFile:output_type -> hostagent.v1.WriteFileResponse
- 20, // 33: hostagent.v1.HostAgentService.ReadFile:output_type -> hostagent.v1.ReadFileResponse
- 9, // 34: hostagent.v1.HostAgentService.CreateSnapshot:output_type -> hostagent.v1.CreateSnapshotResponse
- 11, // 35: hostagent.v1.HostAgentService.DeleteSnapshot:output_type -> hostagent.v1.DeleteSnapshotResponse
- 22, // 36: hostagent.v1.HostAgentService.ExecStream:output_type -> hostagent.v1.ExecStreamResponse
- 28, // 37: hostagent.v1.HostAgentService.WriteFileStream:output_type -> hostagent.v1.WriteFileStreamResponse
- 30, // 38: hostagent.v1.HostAgentService.ReadFileStream:output_type -> hostagent.v1.ReadFileStreamResponse
- 32, // 39: hostagent.v1.HostAgentService.PingSandbox:output_type -> hostagent.v1.PingSandboxResponse
- 34, // 40: hostagent.v1.HostAgentService.Terminate:output_type -> hostagent.v1.TerminateResponse
- 37, // 41: hostagent.v1.HostAgentService.GetSandboxMetrics:output_type -> hostagent.v1.GetSandboxMetricsResponse
- 39, // 42: hostagent.v1.HostAgentService.FlushSandboxMetrics:output_type -> hostagent.v1.FlushSandboxMetricsResponse
- 26, // [26:43] is the sub-list for method output_type
- 9, // [9:26] is the sub-list for method input_type
+ 40, // 26: hostagent.v1.HostAgentService.FlattenRootfs:input_type -> hostagent.v1.FlattenRootfsRequest
+ 1, // 27: hostagent.v1.HostAgentService.CreateSandbox:output_type -> hostagent.v1.CreateSandboxResponse
+ 3, // 28: hostagent.v1.HostAgentService.DestroySandbox:output_type -> hostagent.v1.DestroySandboxResponse
+ 5, // 29: hostagent.v1.HostAgentService.PauseSandbox:output_type -> hostagent.v1.PauseSandboxResponse
+ 7, // 30: hostagent.v1.HostAgentService.ResumeSandbox:output_type -> hostagent.v1.ResumeSandboxResponse
+ 13, // 31: hostagent.v1.HostAgentService.Exec:output_type -> hostagent.v1.ExecResponse
+ 15, // 32: hostagent.v1.HostAgentService.ListSandboxes:output_type -> hostagent.v1.ListSandboxesResponse
+ 18, // 33: hostagent.v1.HostAgentService.WriteFile:output_type -> hostagent.v1.WriteFileResponse
+ 20, // 34: hostagent.v1.HostAgentService.ReadFile:output_type -> hostagent.v1.ReadFileResponse
+ 9, // 35: hostagent.v1.HostAgentService.CreateSnapshot:output_type -> hostagent.v1.CreateSnapshotResponse
+ 11, // 36: hostagent.v1.HostAgentService.DeleteSnapshot:output_type -> hostagent.v1.DeleteSnapshotResponse
+ 22, // 37: hostagent.v1.HostAgentService.ExecStream:output_type -> hostagent.v1.ExecStreamResponse
+ 28, // 38: hostagent.v1.HostAgentService.WriteFileStream:output_type -> hostagent.v1.WriteFileStreamResponse
+ 30, // 39: hostagent.v1.HostAgentService.ReadFileStream:output_type -> hostagent.v1.ReadFileStreamResponse
+ 32, // 40: hostagent.v1.HostAgentService.PingSandbox:output_type -> hostagent.v1.PingSandboxResponse
+ 34, // 41: hostagent.v1.HostAgentService.Terminate:output_type -> hostagent.v1.TerminateResponse
+ 37, // 42: hostagent.v1.HostAgentService.GetSandboxMetrics:output_type -> hostagent.v1.GetSandboxMetricsResponse
+ 39, // 43: hostagent.v1.HostAgentService.FlushSandboxMetrics:output_type -> hostagent.v1.FlushSandboxMetricsResponse
+ 41, // 44: hostagent.v1.HostAgentService.FlattenRootfs:output_type -> hostagent.v1.FlattenRootfsResponse
+ 27, // [27:45] is the sub-list for method output_type
+ 9, // [9:27] is the sub-list for method input_type
9, // [9:9] is the sub-list for extension type_name
9, // [9:9] is the sub-list for extension extendee
0, // [0:9] is the sub-list for field type_name
@@ -2471,7 +2699,7 @@ func file_hostagent_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_hostagent_proto_rawDesc), len(file_hostagent_proto_rawDesc)),
NumEnums: 0,
- NumMessages: 40,
+ NumMessages: 42,
NumExtensions: 0,
NumServices: 1,
},
diff --git a/proto/hostagent/gen/hostagentv1connect/hostagent.connect.go b/proto/hostagent/gen/hostagentv1connect/hostagent.connect.go
index 7f0fa707..02f4ecc6 100644
--- a/proto/hostagent/gen/hostagentv1connect/hostagent.connect.go
+++ b/proto/hostagent/gen/hostagentv1connect/hostagent.connect.go
@@ -83,6 +83,9 @@ const (
// HostAgentServiceFlushSandboxMetricsProcedure is the fully-qualified name of the
// HostAgentService's FlushSandboxMetrics RPC.
HostAgentServiceFlushSandboxMetricsProcedure = "/hostagent.v1.HostAgentService/FlushSandboxMetrics"
+ // HostAgentServiceFlattenRootfsProcedure is the fully-qualified name of the HostAgentService's
+ // FlattenRootfs RPC.
+ HostAgentServiceFlattenRootfsProcedure = "/hostagent.v1.HostAgentService/FlattenRootfs"
)
// HostAgentServiceClient is a client for the hostagent.v1.HostAgentService service.
@@ -126,6 +129,11 @@ type HostAgentServiceClient interface {
// FlushSandboxMetrics returns all ring buffer tiers and clears them.
// Called by the control plane before pause/destroy to persist metrics to DB.
FlushSandboxMetrics(context.Context, *connect.Request[gen.FlushSandboxMetricsRequest]) (*connect.Response[gen.FlushSandboxMetricsResponse], error)
+ // FlattenRootfs stops the sandbox VM, flattens the device-mapper CoW
+ // snapshot into a standalone rootfs.ext4 in the images directory, then
+ // cleans up all sandbox resources. Used by the template build system to
+ // produce image-only templates (no memory/CPU state).
+ FlattenRootfs(context.Context, *connect.Request[gen.FlattenRootfsRequest]) (*connect.Response[gen.FlattenRootfsResponse], error)
}
// NewHostAgentServiceClient constructs a client for the hostagent.v1.HostAgentService service. By
@@ -241,6 +249,12 @@ func NewHostAgentServiceClient(httpClient connect.HTTPClient, baseURL string, op
connect.WithSchema(hostAgentServiceMethods.ByName("FlushSandboxMetrics")),
connect.WithClientOptions(opts...),
),
+ flattenRootfs: connect.NewClient[gen.FlattenRootfsRequest, gen.FlattenRootfsResponse](
+ httpClient,
+ baseURL+HostAgentServiceFlattenRootfsProcedure,
+ connect.WithSchema(hostAgentServiceMethods.ByName("FlattenRootfs")),
+ connect.WithClientOptions(opts...),
+ ),
}
}
@@ -263,6 +277,7 @@ type hostAgentServiceClient struct {
terminate *connect.Client[gen.TerminateRequest, gen.TerminateResponse]
getSandboxMetrics *connect.Client[gen.GetSandboxMetricsRequest, gen.GetSandboxMetricsResponse]
flushSandboxMetrics *connect.Client[gen.FlushSandboxMetricsRequest, gen.FlushSandboxMetricsResponse]
+ flattenRootfs *connect.Client[gen.FlattenRootfsRequest, gen.FlattenRootfsResponse]
}
// CreateSandbox calls hostagent.v1.HostAgentService.CreateSandbox.
@@ -350,6 +365,11 @@ func (c *hostAgentServiceClient) FlushSandboxMetrics(ctx context.Context, req *c
return c.flushSandboxMetrics.CallUnary(ctx, req)
}
+// FlattenRootfs calls hostagent.v1.HostAgentService.FlattenRootfs.
+func (c *hostAgentServiceClient) FlattenRootfs(ctx context.Context, req *connect.Request[gen.FlattenRootfsRequest]) (*connect.Response[gen.FlattenRootfsResponse], error) {
+ return c.flattenRootfs.CallUnary(ctx, req)
+}
+
// HostAgentServiceHandler is an implementation of the hostagent.v1.HostAgentService service.
type HostAgentServiceHandler interface {
// CreateSandbox boots a new microVM with the given configuration.
@@ -391,6 +411,11 @@ type HostAgentServiceHandler interface {
// FlushSandboxMetrics returns all ring buffer tiers and clears them.
// Called by the control plane before pause/destroy to persist metrics to DB.
FlushSandboxMetrics(context.Context, *connect.Request[gen.FlushSandboxMetricsRequest]) (*connect.Response[gen.FlushSandboxMetricsResponse], error)
+ // FlattenRootfs stops the sandbox VM, flattens the device-mapper CoW
+ // snapshot into a standalone rootfs.ext4 in the images directory, then
+ // cleans up all sandbox resources. Used by the template build system to
+ // produce image-only templates (no memory/CPU state).
+ FlattenRootfs(context.Context, *connect.Request[gen.FlattenRootfsRequest]) (*connect.Response[gen.FlattenRootfsResponse], error)
}
// NewHostAgentServiceHandler builds an HTTP handler from the service implementation. It returns the
@@ -502,6 +527,12 @@ func NewHostAgentServiceHandler(svc HostAgentServiceHandler, opts ...connect.Han
connect.WithSchema(hostAgentServiceMethods.ByName("FlushSandboxMetrics")),
connect.WithHandlerOptions(opts...),
)
+ hostAgentServiceFlattenRootfsHandler := connect.NewUnaryHandler(
+ HostAgentServiceFlattenRootfsProcedure,
+ svc.FlattenRootfs,
+ connect.WithSchema(hostAgentServiceMethods.ByName("FlattenRootfs")),
+ connect.WithHandlerOptions(opts...),
+ )
return "/hostagent.v1.HostAgentService/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case HostAgentServiceCreateSandboxProcedure:
@@ -538,6 +569,8 @@ func NewHostAgentServiceHandler(svc HostAgentServiceHandler, opts ...connect.Han
hostAgentServiceGetSandboxMetricsHandler.ServeHTTP(w, r)
case HostAgentServiceFlushSandboxMetricsProcedure:
hostAgentServiceFlushSandboxMetricsHandler.ServeHTTP(w, r)
+ case HostAgentServiceFlattenRootfsProcedure:
+ hostAgentServiceFlattenRootfsHandler.ServeHTTP(w, r)
default:
http.NotFound(w, r)
}
@@ -614,3 +647,7 @@ func (UnimplementedHostAgentServiceHandler) GetSandboxMetrics(context.Context, *
func (UnimplementedHostAgentServiceHandler) FlushSandboxMetrics(context.Context, *connect.Request[gen.FlushSandboxMetricsRequest]) (*connect.Response[gen.FlushSandboxMetricsResponse], error) {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("hostagent.v1.HostAgentService.FlushSandboxMetrics is not implemented"))
}
+
+func (UnimplementedHostAgentServiceHandler) FlattenRootfs(context.Context, *connect.Request[gen.FlattenRootfsRequest]) (*connect.Response[gen.FlattenRootfsResponse], error) {
+ return nil, connect.NewError(connect.CodeUnimplemented, errors.New("hostagent.v1.HostAgentService.FlattenRootfs is not implemented"))
+}
diff --git a/proto/hostagent/hostagent.proto b/proto/hostagent/hostagent.proto
index 214a84e0..817d5359 100644
--- a/proto/hostagent/hostagent.proto
+++ b/proto/hostagent/hostagent.proto
@@ -61,13 +61,19 @@ service HostAgentService {
// Called by the control plane before pause/destroy to persist metrics to DB.
rpc FlushSandboxMetrics(FlushSandboxMetricsRequest) returns (FlushSandboxMetricsResponse);
+ // FlattenRootfs stops the sandbox VM, flattens the device-mapper CoW
+ // snapshot into a standalone rootfs.ext4 in the images directory, then
+ // cleans up all sandbox resources. Used by the template build system to
+ // produce image-only templates (no memory/CPU state).
+ rpc FlattenRootfs(FlattenRootfsRequest) returns (FlattenRootfsResponse);
+
}
message CreateSandboxRequest {
// Sandbox ID assigned by the control plane. If empty, the host agent generates one.
string sandbox_id = 5;
- // Template name (e.g., "minimal", "python311"). Determines base rootfs.
+ // Deprecated: use team_id + template_id instead.
string template = 1;
// Number of virtual CPUs (default: 1).
@@ -79,6 +85,16 @@ message CreateSandboxRequest {
// TTL in seconds. Sandbox is auto-paused after this duration of
// inactivity. 0 means no auto-pause.
int32 timeout_sec = 4;
+
+ // Disk size in MB for the rootfs. Base images are expanded to this size
+ // at host agent startup. Default: 5120 (5 GB).
+ int32 disk_size_mb = 6;
+
+ // Team UUID that owns the template (hex string). All-zeros = platform.
+ string team_id = 7;
+
+ // Template UUID (hex string). An all-zeros template_id together with an all-zeros team_id is the "minimal" sentinel.
+ string template_id = 8;
}
message CreateSandboxResponse {
@@ -115,7 +131,12 @@ message ResumeSandboxResponse {
message CreateSnapshotRequest {
string sandbox_id = 1;
+ // Deprecated: use team_id + template_id instead.
string name = 2;
+ // Team UUID that will own the new template.
+ string team_id = 3;
+ // Template UUID for the new snapshot template.
+ string template_id = 4;
}
message CreateSnapshotResponse {
@@ -124,7 +145,12 @@ message CreateSnapshotResponse {
}
message DeleteSnapshotRequest {
+ // Deprecated: use team_id + template_id instead.
string name = 1;
+ // Team UUID that owns the template.
+ string team_id = 2;
+ // Template UUID to delete.
+ string template_id = 3;
}
message DeleteSnapshotResponse {}
@@ -156,6 +182,7 @@ message ListSandboxesResponse {
message SandboxInfo {
string sandbox_id = 1;
string status = 2;
+ // Deprecated: use team_id + template_id instead.
string template = 3;
int32 vcpus = 4;
int32 memory_mb = 5;
@@ -163,6 +190,8 @@ message SandboxInfo {
int64 created_at_unix = 7;
int64 last_active_at_unix = 8;
int32 timeout_sec = 9;
+ string team_id = 10;
+ string template_id = 11;
}
message WriteFileRequest {
@@ -284,3 +313,19 @@ message FlushSandboxMetricsResponse {
repeated MetricPoint points_2h = 2;
repeated MetricPoint points_24h = 3;
}
+
+// ── FlattenRootfs ────────────────────────────────────────────────────
+
+message FlattenRootfsRequest {
+ string sandbox_id = 1; // ID of the sandbox whose rootfs is flattened (see the FlattenRootfs RPC).
+ // Deprecated: use team_id + template_id instead.
+ string name = 2;
+ // Team UUID that will own the resulting template.
+ string team_id = 3;
+ // Template UUID for the output.
+ string template_id = 4;
+}
+
+message FlattenRootfsResponse {
+ int64 size_bytes = 1; // NOTE(review): presumably the byte size of the flattened rootfs.ext4 — confirm against the host agent implementation.
+}
diff --git a/scripts/rootfs-from-container.sh b/scripts/rootfs-from-container.sh
index ce1dd526..2f96a3a7 100755
--- a/scripts/rootfs-from-container.sh
+++ b/scripts/rootfs-from-container.sh
@@ -3,7 +3,10 @@
# rootfs-from-container.sh — Create a bootable Wrenn rootfs from a Docker container.
#
# Exports a container's filesystem, writes it into an ext4 image, injects
-# envd + wrenn-init, and shrinks the image to minimum size.
+# envd + wrenn-init + tini, and shrinks the image to minimum size.
+#
+# The container image must already include: socat, chrony, curl, ca-certificates, git.
+# These are installed via apt in the container before export, not injected here.
#
# Usage:
# bash scripts/rootfs-from-container.sh
@@ -13,7 +16,7 @@
# image_name — Directory name under images dir (e.g. "waitlist")
#
# Output:
-# ${AGENT_FILES_ROOTDIR}/images//rootfs.ext4
+# ${WRENN_DIR}/images//rootfs.ext4
#
# Requires: docker, mkfs.ext4, resize2fs, e2fsck, make (for building envd), curl (for tini download)
# Sudo is used only for mount/umount/copy-into-image operations.
@@ -22,8 +25,8 @@ set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
PROJECT_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
-AGENT_FILES_ROOTDIR="${AGENT_FILES_ROOTDIR:-/var/lib/wrenn}"
-AGENT_IMAGES_PATH="${AGENT_FILES_ROOTDIR}/images"
+WRENN_DIR="${WRENN_DIR:-/var/lib/wrenn}"
+WRENN_IMAGES_PATH="${WRENN_DIR}/images"
if [ $# -lt 2 ]; then
echo "Usage: $0 "
@@ -32,7 +35,7 @@ fi
CONTAINER="$1"
IMAGE_NAME="$2"
-OUTPUT_DIR="${AGENT_IMAGES_PATH}/${IMAGE_NAME}"
+OUTPUT_DIR="${WRENN_IMAGES_PATH}/${IMAGE_NAME}"
OUTPUT_FILE="${OUTPUT_DIR}/rootfs.ext4"
MOUNT_DIR="/tmp/wrenn-rootfs-build"
TAR_FILE="/tmp/wrenn-rootfs-export-${IMAGE_NAME}.tar"
@@ -130,16 +133,29 @@ sudo mkdir -p "${MOUNT_DIR}/sbin"
sudo cp "${TINI_BIN}" "${MOUNT_DIR}/sbin/tini"
sudo chmod 755 "${MOUNT_DIR}/sbin/tini"
-# Step 6: Verify.
+# Step 6: Verify injected binaries and required container packages.
echo ""
echo "==> Installed guest binaries:"
ls -la "${MOUNT_DIR}/usr/local/bin/envd" "${MOUNT_DIR}/usr/local/bin/wrenn-init" "${MOUNT_DIR}/sbin/tini"
+echo ""
+echo "==> Checking required container packages..."
+MISSING_PKGS=""
+for bin in socat chronyd curl git; do
+    # Test find's output, not a pipeline status: under `set -o pipefail` one
+    # unreadable directory (e.g. root-only lost+found) flagged every binary.
+    [ -n "$(find "${MOUNT_DIR}" -name "${bin}" -type f -print -quit 2>/dev/null)" ] || MISSING_PKGS="${MISSING_PKGS} ${bin}"
+done
+if [ -n "${MISSING_PKGS}" ]; then
+    echo "WARNING: The following binaries were not found in the container image:${MISSING_PKGS}"
+    echo " Install them in the container (via apt) before exporting."
+fi
+
# Unmount before shrinking.
sudo umount "${MOUNT_DIR}"
rmdir "${MOUNT_DIR}" 2>/dev/null || true
# Step 7: Shrink the image to minimum size.
echo ""
echo "==> Shrinking image..."
sudo e2fsck -fy "${OUTPUT_FILE}"
diff --git a/scripts/update-debug-rootfs.sh b/scripts/update-debug-rootfs.sh
index 7d0544e1..bdedded2 100755
--- a/scripts/update-debug-rootfs.sh
+++ b/scripts/update-debug-rootfs.sh
@@ -1,11 +1,11 @@
#!/usr/bin/env bash
#
-# update-debug-rootfs.sh — Build envd and inject it (plus wrenn-init) into the debug rootfs.
+# update-debug-rootfs.sh — Build envd and inject it (plus wrenn-init + tini) into the debug rootfs.
#
# This script:
# 1. Builds a fresh envd static binary via make
# 2. Mounts the rootfs image
-# 3. Copies envd and wrenn-init into the image
+# 3. Copies envd, wrenn-init, and tini into the image
# 4. Unmounts cleanly
#
# Usage: