From 602ee470d9ac213a6df1690c2f6f0092d3828788 Mon Sep 17 00:00:00 2001 From: pptx704 Date: Thu, 26 Mar 2026 02:11:54 +0600 Subject: [PATCH 01/28] WIP: Add socat injection to rootfs build scripts MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Inject a statically-linked socat binary into rootfs images. envd's port forwarder requires socat to bridge localhost-listening services (e.g. Jupyter kernel) to the guest TAP interface. Both scripts follow the same 3-step resolution: check rootfs, check host, build from source (http://www.dest-unreach.org/socat/ v1.8.1.1). Static linkage is verified before injection. This is an intermediate state — needs further work for the full code interpreter feature. --- scripts/rootfs-from-container.sh | 44 +++++++++++++++++++++++++++++--- scripts/update-debug-rootfs.sh | 37 ++++++++++++++++++++++++++- 2 files changed, 76 insertions(+), 5 deletions(-) diff --git a/scripts/rootfs-from-container.sh b/scripts/rootfs-from-container.sh index ce1dd526..c5756d76 100755 --- a/scripts/rootfs-from-container.sh +++ b/scripts/rootfs-from-container.sh @@ -15,7 +15,8 @@ # Output: # ${AGENT_FILES_ROOTDIR}/images//rootfs.ext4 # -# Requires: docker, mkfs.ext4, resize2fs, e2fsck, make (for building envd), curl (for tini download) +# Requires: docker, mkfs.ext4, resize2fs, e2fsck, make (for building envd), curl (for tini/socat download), +# gcc, make (for building socat from source) # Sudo is used only for mount/umount/copy-into-image operations. set -euo pipefail @@ -130,16 +131,51 @@ sudo mkdir -p "${MOUNT_DIR}/sbin" sudo cp "${TINI_BIN}" "${MOUNT_DIR}/sbin/tini" sudo chmod 755 "${MOUNT_DIR}/sbin/tini" -# Step 6: Verify. +echo "==> Installing socat..." +SOCAT_BIN="" +# 1. Already in the exported container image? +for p in "${MOUNT_DIR}/usr/bin/socat" "${MOUNT_DIR}/usr/local/bin/socat"; do + if [ -f "$p" ]; then SOCAT_BIN="$p"; break; fi +done +# 2. Available on the host? +if [ -z "${SOCAT_BIN}" ]; then + for p in /usr/bin/socat /usr/local/bin/socat; do + if [ -f "$p" ]; then SOCAT_BIN="$p"; break; fi + done +fi +# 3. Build from source. +if [ -z "${SOCAT_BIN}" ]; then + SOCAT_VERSION="1.8.1.1" + SOCAT_URL="http://www.dest-unreach.org/socat/download/socat-${SOCAT_VERSION}.tar.gz" + SOCAT_BUILD_DIR="/tmp/socat-build" + echo " Building socat ${SOCAT_VERSION} from source..." + rm -rf "${SOCAT_BUILD_DIR}" + mkdir -p "${SOCAT_BUILD_DIR}" + curl -fsSL "${SOCAT_URL}" | tar xz -C "${SOCAT_BUILD_DIR}" --strip-components=1 + (cd "${SOCAT_BUILD_DIR}" && LDFLAGS="-static" ./configure --quiet && make -j"$(nproc)" -s) + SOCAT_BIN="${SOCAT_BUILD_DIR}/socat" + if [ ! -f "${SOCAT_BIN}" ]; then + echo "ERROR: socat build failed" + exit 1 + fi + if ! file "${SOCAT_BIN}" | grep -q "statically linked"; then + echo "ERROR: socat is not statically linked!" + exit 1 + fi +fi +sudo cp "${SOCAT_BIN}" "${MOUNT_DIR}/usr/local/bin/socat" +sudo chmod 755 "${MOUNT_DIR}/usr/local/bin/socat" + +# Step 7: Verify. echo "" echo "==> Installed guest binaries:" -ls -la "${MOUNT_DIR}/usr/local/bin/envd" "${MOUNT_DIR}/usr/local/bin/wrenn-init" "${MOUNT_DIR}/sbin/tini" +ls -la "${MOUNT_DIR}/usr/local/bin/envd" "${MOUNT_DIR}/usr/local/bin/wrenn-init" "${MOUNT_DIR}/sbin/tini" "${MOUNT_DIR}/usr/local/bin/socat" # Unmount before shrinking. sudo umount "${MOUNT_DIR}" rmdir "${MOUNT_DIR}" 2>/dev/null || true -# Step 7: Shrink the image to minimum size. +# Step 8: Shrink the image to minimum size. echo "" echo "==> Shrinking image..." sudo e2fsck -fy "${OUTPUT_FILE}" diff --git a/scripts/update-debug-rootfs.sh b/scripts/update-debug-rootfs.sh index 7d0544e1..76d3ed42 100755 --- a/scripts/update-debug-rootfs.sh +++ b/scripts/update-debug-rootfs.sh @@ -96,10 +96,45 @@ sudo mkdir -p "${MOUNT_DIR}/sbin" sudo cp "${TINI_BIN}" "${MOUNT_DIR}/sbin/tini" sudo chmod 755 "${MOUNT_DIR}/sbin/tini" +echo "==> Installing socat..." +SOCAT_BIN="" +# 1. Already in the rootfs? +for p in "${MOUNT_DIR}/usr/bin/socat" "${MOUNT_DIR}/usr/local/bin/socat"; do + if [ -f "$p" ]; then SOCAT_BIN="$p"; break; fi +done +# 2. Available on the host? +if [ -z "${SOCAT_BIN}" ]; then + for p in /usr/bin/socat /usr/local/bin/socat; do + if [ -f "$p" ]; then SOCAT_BIN="$p"; break; fi + done +fi +# 3. Build from source. +if [ -z "${SOCAT_BIN}" ]; then + SOCAT_VERSION="1.8.1.1" + SOCAT_URL="http://www.dest-unreach.org/socat/download/socat-${SOCAT_VERSION}.tar.gz" + SOCAT_BUILD_DIR="/tmp/socat-build" + echo " Building socat ${SOCAT_VERSION} from source..." + rm -rf "${SOCAT_BUILD_DIR}" + mkdir -p "${SOCAT_BUILD_DIR}" + curl -fsSL "${SOCAT_URL}" | tar xz -C "${SOCAT_BUILD_DIR}" --strip-components=1 + (cd "${SOCAT_BUILD_DIR}" && LDFLAGS="-static" ./configure --quiet && make -j"$(nproc)" -s) + SOCAT_BIN="${SOCAT_BUILD_DIR}/socat" + if [ ! -f "${SOCAT_BIN}" ]; then + echo "ERROR: socat build failed" + exit 1 + fi + if ! file "${SOCAT_BIN}" | grep -q "statically linked"; then + echo "ERROR: socat is not statically linked!" + exit 1 + fi +fi +sudo cp "${SOCAT_BIN}" "${MOUNT_DIR}/usr/local/bin/socat" +sudo chmod 755 "${MOUNT_DIR}/usr/local/bin/socat" + # Step 4: Verify. echo "" echo "==> Installed files:" -ls -la "${MOUNT_DIR}/usr/local/bin/envd" "${MOUNT_DIR}/usr/local/bin/wrenn-init" "${MOUNT_DIR}/sbin/tini" +ls -la "${MOUNT_DIR}/usr/local/bin/envd" "${MOUNT_DIR}/usr/local/bin/wrenn-init" "${MOUNT_DIR}/sbin/tini" "${MOUNT_DIR}/usr/local/bin/socat" echo "" echo "==> Done. Rootfs updated: ${ROOTFS}" From f4675ebfc08f998457a2e4d13bc563d4f9bf62eb Mon Sep 17 00:00:00 2001 From: pptx704 Date: Thu, 26 Mar 2026 02:12:01 +0600 Subject: [PATCH 02/28] WIP: Add HTTP proxy endpoint to host agent MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add /proxy/{sandbox_id}/{port}/* handler that reverse-proxies HTTP requests to services running inside sandbox VMs. The sandbox's host IP (10.11.0.{idx}) is used as the upstream target. Includes port validation (1-65535) and shared HTTP transport for connection pooling. Supports WebSocket upgrades for protocols like Jupyter's streaming API. This is an intermediate state — needs further work for the full code interpreter feature. --- cmd/host-agent/main.go | 3 ++ internal/hostagent/proxy.go | 94 +++++++++++++++++++++++++++++++++++++ 2 files changed, 97 insertions(+) create mode 100644 internal/hostagent/proxy.go diff --git a/cmd/host-agent/main.go b/cmd/host-agent/main.go index 2d34cd1d..130faaf7 100644 --- a/cmd/host-agent/main.go +++ b/cmd/host-agent/main.go @@ -119,8 +119,11 @@ func main() { }) path, handler := hostagentv1connect.NewHostAgentServiceHandler(srv) + proxyHandler := hostagent.NewProxyHandler(mgr) + mux := http.NewServeMux() mux.Handle(path, handler) + mux.Handle("/proxy/", proxyHandler) httpServer.Handler = mux // Start heartbeat loop. Handler must be set before this because the diff --git a/internal/hostagent/proxy.go b/internal/hostagent/proxy.go new file mode 100644 index 00000000..b4a39ee0 --- /dev/null +++ b/internal/hostagent/proxy.go @@ -0,0 +1,94 @@ +package hostagent + +import ( + "fmt" + "log/slog" + "net/http" + "net/http/httputil" + "strconv" + "strings" + + "git.omukk.dev/wrenn/sandbox/internal/models" + "git.omukk.dev/wrenn/sandbox/internal/sandbox" +) + +// ProxyHandler reverse-proxies HTTP requests to services running inside +// sandboxes. It handles requests of the form: +// +// /proxy/{sandbox_id}/{port}/{path...} +// +// The sandbox's HostIP (routable on this machine) is used as the upstream. +// This supports any protocol that rides on HTTP, including WebSocket upgrades. +type ProxyHandler struct { + mgr *sandbox.Manager + transport http.RoundTripper +} + +// NewProxyHandler creates a new sandbox proxy handler. +func NewProxyHandler(mgr *sandbox.Manager) *ProxyHandler { + return &ProxyHandler{ + mgr: mgr, + transport: http.DefaultTransport, + } +} + +// ServeHTTP implements http.Handler. +func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Expected path: /proxy/{sandbox_id}/{port}/... + // After trimming "/proxy/", we get "{sandbox_id}/{port}/..." + trimmed := strings.TrimPrefix(r.URL.Path, "/proxy/") + if trimmed == r.URL.Path { + http.Error(w, "invalid proxy path", http.StatusBadRequest) + return + } + + parts := strings.SplitN(trimmed, "/", 3) + if len(parts) < 2 { + http.Error(w, "expected /proxy/{sandbox_id}/{port}/...", http.StatusBadRequest) + return + } + + sandboxID := parts[0] + port := parts[1] + remainder := "" + if len(parts) == 3 { + remainder = parts[2] + } + + // Validate port is a number in the valid range. + portNum, err := strconv.Atoi(port) + if err != nil || portNum < 1 || portNum > 65535 { + http.Error(w, "invalid port", http.StatusBadRequest) + return + } + + sb, err := h.mgr.Get(sandboxID) + if err != nil { + http.Error(w, "sandbox not found", http.StatusNotFound) + return + } + + if sb.Status != models.StatusRunning { + http.Error(w, fmt.Sprintf("sandbox is not running (status: %s)", sb.Status), http.StatusConflict) + return + } + + targetHost := fmt.Sprintf("%s:%d", sb.HostIP.String(), portNum) + + proxy := &httputil.ReverseProxy{ + Transport: h.transport, + Director: func(req *http.Request) { + req.URL.Scheme = "http" + req.URL.Host = targetHost + req.URL.Path = "/" + remainder + req.URL.RawQuery = r.URL.RawQuery + req.Host = targetHost + }, + ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) { + slog.Debug("proxy error", "sandbox_id", sandboxID, "port", port, "error", err) + http.Error(w, "proxy error: "+err.Error(), http.StatusBadGateway) + }, + } + + proxy.ServeHTTP(w, r) +} From 4be65b0abbebfcf4692845f13d95e2002131419b Mon Sep 17 00:00:00 2001 From: pptx704 Date: Thu, 26 Mar 2026 02:12:10 +0600 Subject: [PATCH 03/28] WIP: Add sandbox proxy catch-all to control plane MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add SandboxProxyWrapper that intercepts requests with Host headers matching {port}-{sandbox_id}.{domain} and proxies them through the owning host agent's /proxy endpoint. Authentication is via X-API-Key only (no JWT). The API key's team must own the sandbox. Export EnsureScheme from lifecycle package for reuse. Request flow: SDK -> Caddy -> CP catch-all -> Host Agent -> sandbox VM. This is an intermediate state — needs further work for the full code interpreter feature. --- cmd/control-plane/main.go | 7 +- internal/api/handler_sandbox_proxy.go | 149 ++++++++++++++++++++++++++ internal/lifecycle/hostpool.go | 6 +- 3 files changed, 158 insertions(+), 4 deletions(-) create mode 100644 internal/api/handler_sandbox_proxy.go diff --git a/cmd/control-plane/main.go b/cmd/control-plane/main.go index 9f84edcc..40c3c48a 100644 --- a/cmd/control-plane/main.go +++ b/cmd/control-plane/main.go @@ -98,9 +98,14 @@ func main() { sampler := api.NewMetricsSampler(queries, 10*time.Second) sampler.Start(ctx) + // Wrap the API handler with the sandbox proxy so that requests with + // {port}-{sandbox_id}.{domain} Host headers are routed to the sandbox's + // host agent. All other requests pass through to the normal API router. + proxyWrapper := api.NewSandboxProxyWrapper(srv.Handler(), queries, hostPool) + httpServer := &http.Server{ Addr: cfg.ListenAddr, - Handler: srv.Handler(), + Handler: proxyWrapper, } // Graceful shutdown on signal. diff --git a/internal/api/handler_sandbox_proxy.go b/internal/api/handler_sandbox_proxy.go new file mode 100644 index 00000000..019fd7fd --- /dev/null +++ b/internal/api/handler_sandbox_proxy.go @@ -0,0 +1,149 @@ +package api + +import ( + "fmt" + "log/slog" + "net/http" + "net/http/httputil" + "net/url" + "regexp" + "strconv" + "strings" + + "git.omukk.dev/wrenn/sandbox/internal/auth" + "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/lifecycle" +) + +// sandboxHostPattern matches hostnames like "49999-sb-abcd1234.localhost" or +// "49999-sb-abcd1234.example.com". Captures: port, sandbox ID. +var sandboxHostPattern = regexp.MustCompile(`^(\d+)-(sb-[0-9a-f]+)\.`) + +// SandboxProxyWrapper wraps an existing HTTP handler and intercepts requests +// whose Host header matches the {port}-{sandbox_id}.{domain} pattern. Matching +// requests are reverse-proxied through the host agent that owns the sandbox. +// All other requests are passed through to the inner handler. +// +// Authentication is via X-API-Key header only (no JWT). The API key's team +// must own the sandbox. +type SandboxProxyWrapper struct { + inner http.Handler + db *db.Queries + pool *lifecycle.HostClientPool + transport http.RoundTripper +} + +// NewSandboxProxyWrapper creates a new proxy wrapper. +func NewSandboxProxyWrapper(inner http.Handler, queries *db.Queries, pool *lifecycle.HostClientPool) *SandboxProxyWrapper { + return &SandboxProxyWrapper{ + inner: inner, + db: queries, + pool: pool, + transport: http.DefaultTransport, + } +} + +func (h *SandboxProxyWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request) { + host := r.Host + // Strip port from Host header (e.g. "49999-sb-abcd1234.localhost:8000" → "49999-sb-abcd1234.localhost") + if colonIdx := strings.LastIndex(host, ":"); colonIdx != -1 { + host = host[:colonIdx] + } + + matches := sandboxHostPattern.FindStringSubmatch(host) + if matches == nil { + h.inner.ServeHTTP(w, r) + return + } + + port := matches[1] + sandboxID := matches[2] + + // Validate port. + portNum, err := strconv.Atoi(port) + if err != nil || portNum < 1 || portNum > 65535 { + http.Error(w, "invalid port", http.StatusBadRequest) + return + } + + // Authenticate: require API key or JWT, extract team ID. + teamID, err := h.authenticateRequest(r) + if err != nil { + writeError(w, http.StatusUnauthorized, "unauthorized", err.Error()) + return + } + + ctx := r.Context() + + // Look up sandbox and verify ownership. + sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ + ID: sandboxID, + TeamID: teamID, + }) + if err != nil { + http.Error(w, "sandbox not found", http.StatusNotFound) + return + } + + if sb.Status != "running" { + http.Error(w, fmt.Sprintf("sandbox is not running (status: %s)", sb.Status), http.StatusConflict) + return + } + + agentHost, err := h.db.GetHost(ctx, sb.HostID) + if err != nil { + http.Error(w, "host agent not found", http.StatusServiceUnavailable) + return + } + + if !agentHost.Address.Valid || agentHost.Address.String == "" { + http.Error(w, "host agent has no address", http.StatusServiceUnavailable) + return + } + + agentAddr := lifecycle.EnsureScheme(agentHost.Address.String) + upstreamPath := fmt.Sprintf("/proxy/%s/%s%s", sandboxID, port, r.URL.Path) + + target, err := url.Parse(agentAddr) + if err != nil { + http.Error(w, "invalid host agent address", http.StatusInternalServerError) + return + } + + proxy := &httputil.ReverseProxy{ + Transport: h.transport, + Director: func(req *http.Request) { + req.URL.Scheme = target.Scheme + req.URL.Host = target.Host + req.URL.Path = upstreamPath + req.URL.RawQuery = r.URL.RawQuery + req.Host = target.Host + }, + ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) { + slog.Debug("sandbox proxy error", + "sandbox_id", sandboxID, + "port", port, + "error", err, + ) + http.Error(w, "proxy error: "+err.Error(), http.StatusBadGateway) + }, + } + + proxy.ServeHTTP(w, r) +} + +// authenticateRequest validates the request's API key and returns the team ID. +// Only API key authentication is supported for sandbox proxy requests (not JWT). +func (h *SandboxProxyWrapper) authenticateRequest(r *http.Request) (string, error) { + key := r.Header.Get("X-API-Key") + if key == "" { + return "", fmt.Errorf("X-API-Key header required") + } + + hash := auth.HashAPIKey(key) + row, err := h.db.GetAPIKeyByHash(r.Context(), hash) + if err != nil { + return "", fmt.Errorf("invalid API key") + } + return row.TeamID, nil +} diff --git a/internal/lifecycle/hostpool.go b/internal/lifecycle/hostpool.go index 0caf5ece..c6e724b1 100644 --- a/internal/lifecycle/hostpool.go +++ b/internal/lifecycle/hostpool.go @@ -45,7 +45,7 @@ func (p *HostClientPool) Get(hostID, address string) hostagentv1connect.HostAgen if c, ok = p.clients[hostID]; ok { return c } - c = hostagentv1connect.NewHostAgentServiceClient(p.httpClient, ensureScheme(address)) + c = hostagentv1connect.NewHostAgentServiceClient(p.httpClient, EnsureScheme(address)) p.clients[hostID] = c return c } @@ -68,8 +68,8 @@ func (p *HostClientPool) Evict(hostID string) { p.mu.Unlock() } -// ensureScheme adds "http://" if the address has no scheme. -func ensureScheme(addr string) string { +// EnsureScheme adds "http://" if the address has no scheme. +func EnsureScheme(addr string) string { if strings.HasPrefix(addr, "http://") || strings.HasPrefix(addr, "https://") { return addr } From b0a8b498a891f0b11a35487efd18d6bfc69f488a Mon Sep 17 00:00:00 2001 From: pptx704 Date: Thu, 26 Mar 2026 02:12:21 +0600 Subject: [PATCH 04/28] WIP: Add Caddy reverse proxy for dev environment MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add Caddy to docker-compose as the single entry point on port 8000: - localhost -> /api/* stripped and proxied to CP:8080, /* to frontend:5173 - *.localhost -> proxied to CP:8080 (sandbox proxy catch-all) - Direct /v1/*, /auth/*, /docs routes proxied to CP Move CP from :8000 to :8080 (its default). Caddy takes :8000. Update .env.example, vite proxy target (kept as fallback), and Makefile dev targets (pg_isready via docker exec, frontend binds 0.0.0.0). This is an intermediate state — needs further work for the full code interpreter feature. --- .env.example | 4 ++-- Makefile | 4 ++-- deploy/Caddyfile.dev | 41 +++++++++++++++++++++++++++++++++++ deploy/docker-compose.dev.yml | 17 +++++---------- frontend/vite.config.ts | 2 +- 5 files changed, 52 insertions(+), 16 deletions(-) create mode 100644 deploy/Caddyfile.dev diff --git a/.env.example b/.env.example index dee152cf..bf528015 100644 --- a/.env.example +++ b/.env.example @@ -5,13 +5,13 @@ DATABASE_URL=postgres://wrenn:wrenn@localhost:5432/wrenn?sslmode=disable REDIS_URL=redis://localhost:6379/0 # Control Plane -CP_LISTEN_ADDR=:8000 +CP_LISTEN_ADDR=:8080 # Host Agent AGENT_LISTEN_ADDR=:50051 AGENT_FILES_ROOTDIR=/var/lib/wrenn AGENT_HOST_INTERFACE=eth0 -AGENT_CP_URL=http://localhost:8000 +AGENT_CP_URL=http://localhost:8080 # Lago (billing — external service) LAGO_API_URL=http://localhost:3000 diff --git a/Makefile b/Makefile index 4a2e0b61..80fbd3a7 100644 --- a/Makefile +++ b/Makefile @@ -39,7 +39,7 @@ dev: dev-infra migrate-up dev-cp dev-infra: docker compose -f deploy/docker-compose.dev.yml up -d @echo "Waiting for PostgreSQL..." - @until pg_isready -h localhost -p 5432 -q; do sleep 0.5; done + @until docker compose -f deploy/docker-compose.dev.yml exec -T postgres pg_isready -q 2>/dev/null; do sleep 0.5; done @echo "Dev infrastructure ready." dev-down: @@ -53,7 +53,7 @@ dev-agent: sudo go run ./cmd/host-agent dev-frontend: - cd frontend && pnpm dev --port 5173 + cd frontend && pnpm dev --port 5173 --host 0.0.0.0 dev-envd: cd $(ENVD_DIR) && go run . --debug --listen-tcp :3002 diff --git a/deploy/Caddyfile.dev b/deploy/Caddyfile.dev new file mode 100644 index 00000000..789f8dfd --- /dev/null +++ b/deploy/Caddyfile.dev @@ -0,0 +1,41 @@ +# Sandbox port forwarding: {port}-{sandbox_id}.localhost +# Matches subdomains like 49999-sb-abcd1234.localhost and proxies them +# to the control plane, which inspects the Host header and routes to +# the correct host agent. +# +# NOTE: Wildcard *.localhost DNS resolution requires local setup. +# Option 1: Add entries to /etc/hosts for each sandbox +# Option 2: Use dnsmasq: address=/.localhost/127.0.0.1 +# Option 3: Use systemd-resolved (Ubuntu default — *.localhost resolves to 127.0.0.1) +http://*.localhost { + reverse_proxy host.docker.internal:8080 +} + +# Main entry point: API + frontend +http://localhost { + # API routes — strip /api prefix and proxy to the control plane. + # The frontend calls /api/v1/... which becomes /v1/... at the CP. + handle_path /api/* { + reverse_proxy host.docker.internal:8080 + } + + # Backend routes served directly (SDK clients, OAuth initiation) + handle /v1/* { + reverse_proxy host.docker.internal:8080 + } + handle /openapi.yaml { + reverse_proxy host.docker.internal:8080 + } + handle /docs { + reverse_proxy host.docker.internal:8080 + } + handle /auth/oauth/* { + reverse_proxy host.docker.internal:8080 + } + + # Everything else — proxy to the frontend dev server + # This includes: /login, /dashboard/*, /admin/*, /auth/github/callback + handle { + reverse_proxy host.docker.internal:5173 + } +} diff --git a/deploy/docker-compose.dev.yml b/deploy/docker-compose.dev.yml index ebcd3087..28f401f6 100644 --- a/deploy/docker-compose.dev.yml +++ b/deploy/docker-compose.dev.yml @@ -15,19 +15,14 @@ services: ports: - "6379:6379" - prometheus: - image: prom/prometheus:latest + caddy: + image: caddy:2-alpine ports: - - "9090:9090" + - "8000:80" volumes: - - ./deploy/prometheus.yml:/etc/prometheus/prometheus.yml - - grafana: - image: grafana/grafana:latest - ports: - - "3001:3000" - environment: - GF_SECURITY_ADMIN_PASSWORD: admin + - ./Caddyfile.dev:/etc/caddy/Caddyfile:ro + extra_hosts: + - "host.docker.internal:host-gateway" volumes: pgdata: diff --git a/frontend/vite.config.ts b/frontend/vite.config.ts index 89afaba3..350070b5 100644 --- a/frontend/vite.config.ts +++ b/frontend/vite.config.ts @@ -7,7 +7,7 @@ export default defineConfig({ server: { proxy: { '/api': { - target: 'http://localhost:8000', + target: 'http://localhost:8080', rewrite: (path) => path.replace(/^\/api/, '') } } From 139f86bf9c6ba9ce9c6af2cb87e099108e43cd0b Mon Sep 17 00:00:00 2001 From: pptx704 Date: Thu, 26 Mar 2026 02:13:12 +0600 Subject: [PATCH 05/28] Fix static build: disable prerender for dynamic capsule detail route The [id] route cannot be prerendered at build time since IDs are unknown. With adapter-static's index.html fallback, the route is handled client-side. --- frontend/src/routes/dashboard/capsules/[id]/+page.js | 1 + 1 file changed, 1 insertion(+) create mode 100644 frontend/src/routes/dashboard/capsules/[id]/+page.js diff --git a/frontend/src/routes/dashboard/capsules/[id]/+page.js b/frontend/src/routes/dashboard/capsules/[id]/+page.js new file mode 100644 index 00000000..d43d0cd2 --- /dev/null +++ b/frontend/src/routes/dashboard/capsules/[id]/+page.js @@ -0,0 +1 @@ +export const prerender = false; From 12d1e356fa8785dce276a1005f2215a03296c48d Mon Sep 17 00:00:00 2001 From: pptx704 Date: Thu, 26 Mar 2026 03:58:12 +0600 Subject: [PATCH 06/28] Minor UI copy updates across capsules and templates pages --- .../lib/components/CreateCapsuleDialog.svelte | 5 +- .../routes/dashboard/capsules/+layout.svelte | 2 +- .../routes/dashboard/capsules/+page.svelte | 22 +++++--- .../dashboard/capsules/[id]/+page.svelte | 2 +- .../routes/dashboard/snapshots/+page.svelte | 53 ++++++++++--------- 5 files changed, 51 insertions(+), 33 deletions(-) diff --git a/frontend/src/lib/components/CreateCapsuleDialog.svelte b/frontend/src/lib/components/CreateCapsuleDialog.svelte index b570f2b4..2bd027d6 100644 --- a/frontend/src/lib/components/CreateCapsuleDialog.svelte +++ b/frontend/src/lib/components/CreateCapsuleDialog.svelte @@ -56,6 +56,7 @@ class="w-full rounded-[var(--radius-input)] border border-[var(--color-border)] bg-[var(--color-bg-4)] px-3 py-2 font-mono text-ui text-[var(--color-text-bright)] outline-none placeholder:text-[var(--color-text-muted)] transition-colors duration-150 focus:border-[var(--color-accent)]" placeholder="minimal" /> +

Name of a snapshot or base image to boot from.

@@ -85,14 +86,16 @@
- + +

Seconds of inactivity before the capsule pauses. Set to 0 to keep it running indefinitely.

diff --git a/frontend/src/routes/dashboard/capsules/+layout.svelte b/frontend/src/routes/dashboard/capsules/+layout.svelte index d9f93f82..5c400b9a 100644 --- a/frontend/src/routes/dashboard/capsules/+layout.svelte +++ b/frontend/src/routes/dashboard/capsules/+layout.svelte @@ -47,7 +47,7 @@ Capsules

- Isolated VMs. Start cold in under a second — pause, snapshot, or destroy at will. + All active and recent capsules across your team.

diff --git a/frontend/src/routes/dashboard/capsules/+page.svelte b/frontend/src/routes/dashboard/capsules/+page.svelte index 4f270035..afb9de0c 100644 --- a/frontend/src/routes/dashboard/capsules/+page.svelte +++ b/frontend/src/routes/dashboard/capsules/+page.svelte @@ -247,6 +247,13 @@ return `${Math.floor(seconds / 86400)}d ago`; } + function fmtTimeout(sec: number): string { + if (!sec) return 'None'; + if (sec < 60) return `${sec}s`; + if (sec < 3600) return `${Math.round(sec / 60)}m`; + return `${Math.round(sec / 3600)}h`; + } + function handleClickOutside(event: MouseEvent) { if (openMenuId && !(event.target as Element)?.closest('.status-menu-container')) { openMenuId = null; @@ -300,7 +307,7 @@ class="w-full rounded-[var(--radius-input)] border border-[var(--color-border)] bg-[var(--color-bg-2)] py-2 pl-9 pr-3 font-mono text-ui text-[var(--color-text-bright)] outline-none placeholder:text-[var(--color-text-muted)] transition-colors duration-150 focus:border-[var(--color-accent)]" /> - {filteredCapsules.length} total + {filteredCapsules.length} capsule{filteredCapsules.length !== 1 ? 's' : ''}
@@ -363,8 +370,11 @@ {#if error} -
- {error} +
+ + + + {error}. Try refreshing the page.
{/if} @@ -466,14 +476,14 @@
- {capsule.timeout_sec ? `${capsule.timeout_sec}s` : '—'} + {fmtTimeout(capsule.timeout_sec)}
{formatTime(capsule.started_at)} {#if capsule.last_active_at} - {timeAgo(capsule.last_active_at)} + active {timeAgo(capsule.last_active_at)} {/if}
@@ -612,7 +622,7 @@ -

This capsule will be paused first — memory state is captured at rest.

+

This capsule will be paused first, then its full state (memory + disk) will be captured.

{:else}

The capsule's current memory state will be captured and stored as a reusable snapshot.

diff --git a/frontend/src/routes/dashboard/capsules/[id]/+page.svelte b/frontend/src/routes/dashboard/capsules/[id]/+page.svelte index e932209a..ed26426d 100644 --- a/frontend/src/routes/dashboard/capsules/[id]/+page.svelte +++ b/frontend/src/routes/dashboard/capsules/[id]/+page.svelte @@ -512,7 +512,7 @@ - Failed to load metrics: {metricsError} + Could not load metrics: {metricsError}. Will retry automatically. {/if} diff --git a/frontend/src/routes/dashboard/snapshots/+page.svelte b/frontend/src/routes/dashboard/snapshots/+page.svelte index e39bf3b6..6f86d4c2 100644 --- a/frontend/src/routes/dashboard/snapshots/+page.svelte +++ b/frontend/src/routes/dashboard/snapshots/+page.svelte @@ -114,15 +114,15 @@ } function emptyHeading(f: TypeFilter): string { - if (f === 'snapshot') return 'No snapshots'; - if (f === 'base') return 'No images'; - return 'No templates yet'; + if (f === 'snapshot') return 'No snapshots yet'; + if (f === 'base') return 'No base images'; + return 'No snapshots yet'; } function emptyDescription(f: TypeFilter): string { - if (f === 'snapshot') return 'Pause a capsule from the Capsules page, then snapshot it to capture its state.'; - if (f === 'base') return 'Base images are added by the Wrenn team. Contact support to request a custom image.'; - return 'To create a snapshot, go to Capsules, pause a running capsule, then choose Snapshot.'; + if (f === 'snapshot') return 'Pause a running capsule, then choose Snapshot to save its state.'; + if (f === 'base') return 'Base images are provided by the Wrenn team. Contact support to request a custom one.'; + return 'Pause a running capsule, then choose Snapshot to save its state. You can launch new capsules from any snapshot.'; } onMount(fetchSnapshots); @@ -162,7 +162,7 @@ Templates

- Snapshots capture a live capsule state. Base images are the rootfs every capsule starts from. Launch a full VM from any template. + Snapshots capture a running capsule's state. Base images are the starting point for every new capsule. Launch from either.

@@ -206,8 +206,11 @@ {#if pageTab === 'snapshots'}
{#if error} -
- {error} +
+ + + + {error}. Try refreshing the page.
{/if} @@ -274,11 +277,11 @@
{filteredSnapshots.length} - {typeFilter === 'all' - ? filteredSnapshots.length === 1 ? 'template' : 'templates' - : typeFilter === 'snapshot' - ? filteredSnapshots.length === 1 ? 'snapshot' : 'snapshots' - : filteredSnapshots.length === 1 ? 'image' : 'images'} + {typeFilter === 'snapshot' + ? filteredSnapshots.length === 1 ? 'snapshot' : 'snapshots' + : typeFilter === 'base' + ? filteredSnapshots.length === 1 ? 'image' : 'images' + : filteredSnapshots.length === 1 ? 'item' : 'total'}
@@ -447,11 +450,11 @@

{filteredSnapshots.length} - {typeFilter === 'all' - ? filteredSnapshots.length === 1 ? 'template' : 'templates' - : typeFilter === 'snapshot' - ? filteredSnapshots.length === 1 ? 'snapshot' : 'snapshots' - : filteredSnapshots.length === 1 ? 'image' : 'images'} + {typeFilter === 'snapshot' + ? filteredSnapshots.length === 1 ? 'snapshot' : 'snapshots' + : typeFilter === 'base' + ? filteredSnapshots.length === 1 ? 'image' : 'images' + : filteredSnapshots.length === 1 ? 'item' : 'total'} {typeFilter !== 'all' ? '· filtered' : '· total'}

{/if} @@ -513,10 +516,10 @@ class="relative w-full max-w-[380px] rounded-[var(--radius-card)] border border-[var(--color-border-mid)] bg-[var(--color-bg-2)] p-6" style="animation: fadeUp 0.2s ease both" > -

Delete Snapshot

+

Delete snapshot

Permanently delete {deleteTarget.name}. - Any capsule using this template will not be affected, but you won't be able to launch from it again. + Running capsules won't be affected, but you won't be able to launch new ones from it.

{#if deleteTarget.type === 'snapshot'} @@ -526,7 +529,7 @@

- This live capture includes saved memory state. Any capsule relying on it will be unable to resume. + This snapshot includes memory state. Paused capsules that depend on it won't be able to resume.

{/if} @@ -580,7 +583,7 @@ >

Launch Capsule

- Configure resources and launch. The VM will clone from this template and be ready in seconds. + Configure resources and launch a new capsule from this snapshot.

{#if launchError} @@ -655,14 +658,16 @@
- + +

Seconds of inactivity before the capsule pauses. Set to 0 to keep it running indefinitely.

From 6898528096b1fbf0d81168d12dca0cef1407a2c2 Mon Sep 17 00:00:00 2001 From: pptx704 Date: Thu, 26 Mar 2026 04:47:44 +0600 Subject: [PATCH 07/28] Replace one-shot clock_settime with chrony for continuous guest time sync Switch from the envd /init endpoint pushing host time via syscall to chronyd reading the KVM PTP hardware clock (/dev/ptp0) continuously. This fixes clock drift between init calls and handles snapshot resume gracefully. Changes: - Add clocksource=kvm-clock kernel boot arg - Start chronyd in wrenn-init.sh before tini (PHC /dev/ptp0, makestep 1.0 -1) - Remove clock_settime logic from envd SetData and shouldSetSystemTime - Remove client.Init() clock sync calls from sandbox manager (3 sites) - Remove Init() method from envdclient (no longer needed) - Simplify rootfs scripts: socat/chrony now come from apt in the container image, only envd/wrenn-init/tini are injected by build scripts --- envd/internal/api/init.go | 27 ------------- envd/internal/api/init_test.go | 66 -------------------------------- images/wrenn-init.sh | 14 ++++++- internal/envdclient/client.go | 31 --------------- internal/sandbox/manager.go | 30 --------------- internal/vm/config.go | 2 +- scripts/rootfs-from-container.sh | 60 ++++++++++------------------- scripts/update-debug-rootfs.sh | 41 ++------------------ 8 files changed, 37 insertions(+), 234 deletions(-) diff --git a/envd/internal/api/init.go b/envd/internal/api/init.go index a4894599..bd2456e8 100644 --- a/envd/internal/api/init.go +++ b/envd/internal/api/init.go @@ -17,8 +17,6 @@ import ( "github.com/awnumar/memguard" "github.com/rs/zerolog" "github.com/txn2/txeh" - "golang.org/x/sys/unix" - "git.omukk.dev/wrenn/sandbox/envd/internal/host" "git.omukk.dev/wrenn/sandbox/envd/internal/logs" "git.omukk.dev/wrenn/sandbox/envd/internal/shared/keys" @@ -29,11 +27,6 @@ var ( ErrAccessTokenResetNotAuthorized = errors.New("access token reset not authorized") ) -const ( - maxTimeInPast = 50 * time.Millisecond - maxTimeInFuture = 5 * time.Second -) - // validateInitAccessToken validates the access token for /init requests. // Token is valid if it matches the existing token OR the MMDS hash. // If neither exists, first-time setup is allowed. @@ -172,20 +165,6 @@ func (a *API) SetData(ctx context.Context, logger zerolog.Logger, data PostInitJ return err } - if data.Timestamp != nil { - // Check if current time differs significantly from the received timestamp - if shouldSetSystemTime(time.Now(), *data.Timestamp) { - logger.Debug().Msgf("Setting sandbox start time to: %v", *data.Timestamp) - ts := unix.NsecToTimespec(data.Timestamp.UnixNano()) - err := unix.ClockSettime(unix.CLOCK_REALTIME, &ts) - if err != nil { - logger.Error().Msgf("Failed to set system time: %v", err) - } - } else { - logger.Debug().Msgf("Current time is within acceptable range of timestamp %v, not setting system time", *data.Timestamp) - } - } - if data.EnvVars != nil { logger.Debug().Msg(fmt.Sprintf("Setting %d env vars", len(*data.EnvVars))) @@ -309,9 +288,3 @@ func getIPFamily(address string) (txeh.IPFamily, error) { } } -// shouldSetSystemTime returns true if the current time differs significantly from the received timestamp, -// indicating the system clock should be adjusted. Returns true when the sandboxTime is more than -// maxTimeInPast before the hostTime or more than maxTimeInFuture after the hostTime. -func shouldSetSystemTime(sandboxTime, hostTime time.Time) bool { - return sandboxTime.Before(hostTime.Add(-maxTimeInPast)) || sandboxTime.After(hostTime.Add(maxTimeInFuture)) -} diff --git a/envd/internal/api/init_test.go b/envd/internal/api/init_test.go index c4b6f4b5..f3db3615 100644 --- a/envd/internal/api/init_test.go +++ b/envd/internal/api/init_test.go @@ -9,7 +9,6 @@ import ( "path/filepath" "strings" "testing" - "time" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" @@ -59,71 +58,6 @@ func TestSimpleCases(t *testing.T) { } } -func TestShouldSetSystemTime(t *testing.T) { - t.Parallel() - sandboxTime := time.Now() - - tests := []struct { - name string - hostTime time.Time - want bool - }{ - { - name: "sandbox time far ahead of host time (should set)", - hostTime: sandboxTime.Add(-10 * time.Second), - want: true, - }, - { - name: "sandbox time at maxTimeInPast boundary ahead of host time (should not set)", - hostTime: sandboxTime.Add(-50 * time.Millisecond), - want: false, - }, - { - name: "sandbox time just within maxTimeInPast ahead of host time (should not set)", - hostTime: sandboxTime.Add(-40 * time.Millisecond), - want: false, - }, - { - name: "sandbox time slightly ahead of host time (should not set)", - hostTime: sandboxTime.Add(-10 * time.Millisecond), - want: false, - }, - { - name: "sandbox time equals host time (should not set)", - hostTime: sandboxTime, - want: false, - }, - { - name: "sandbox time slightly behind host time (should not set)", - hostTime: sandboxTime.Add(1 * time.Second), - want: false, - }, - { - name: "sandbox time just within maxTimeInFuture behind host time (should not set)", - hostTime: sandboxTime.Add(4 * time.Second), - want: false, - }, - { - name: "sandbox time at maxTimeInFuture boundary behind host time (should not set)", - hostTime: sandboxTime.Add(5 * time.Second), - want: false, - }, - { - name: "sandbox time far behind host time (should set)", - hostTime: sandboxTime.Add(1 * time.Minute), - want: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - got := shouldSetSystemTime(tt.hostTime, sandboxTime) - assert.Equal(t, tt.want, got) - }) - } -} - func secureTokenPtr(s string) *SecureToken { token := &SecureToken{} _ = token.Set([]byte(s)) diff --git a/images/wrenn-init.sh b/images/wrenn-init.sh index 32285ea7..4c393711 100644 --- a/images/wrenn-init.sh +++ b/images/wrenn-init.sh @@ -1,6 +1,6 @@ #!/bin/sh # wrenn-init: minimal PID 1 init for Firecracker microVMs. -# Mounts virtual filesystems then execs envd. +# Mounts virtual filesystems, starts chronyd for time sync, then execs tini + envd. set -e @@ -27,5 +27,17 @@ echo "nameserver 8.8.4.4" >> /etc/resolv.conf # Set a standard PATH so envd and all child processes can find common binaries. export PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin +# Write chrony config to sync time from the KVM PTP hardware clock. +# /dev/ptp0 is a paravirtual clock exposed by KVM — no network required. +mkdir -p /etc/chrony /run/chrony +cat > /etc/chrony/chrony.conf </dev/null || true + # Exec tini as PID 1 — it reaps zombie processes and forwards signals to envd. exec /sbin/tini -- /usr/local/bin/envd diff --git a/internal/envdclient/client.go b/internal/envdclient/client.go index 04a1dc2e..49765690 100644 --- a/internal/envdclient/client.go +++ b/internal/envdclient/client.go @@ -3,14 +3,12 @@ package envdclient import ( "bytes" "context" - "encoding/json" "fmt" "io" "log/slog" "mime/multipart" "net/http" "net/url" - "time" "connectrpc.com/connect" @@ -49,35 +47,6 @@ func (c *Client) BaseURL() string { return c.base } -// Init calls POST /init on envd to sync the guest clock with the host. -// This is important after snapshot resume where the guest clock is frozen. -func (c *Client) Init(ctx context.Context) error { - now := time.Now().UTC() - body, err := json.Marshal(map[string]any{"timestamp": now}) - if err != nil { - return fmt.Errorf("marshal init body: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.base+"/init", bytes.NewReader(body)) - if err != nil { - return fmt.Errorf("create init request: %w", err) - } - req.Header.Set("Content-Type", "application/json") - - resp, err := c.httpClient.Do(req) - if err != nil { - return fmt.Errorf("init request: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusNoContent { - respBody, _ := io.ReadAll(resp.Body) - return fmt.Errorf("init: status %d: %s", resp.StatusCode, string(respBody)) - } - - return nil -} - // ExecResult holds the output of a command execution. type ExecResult struct { Stdout []byte diff --git a/internal/sandbox/manager.go b/internal/sandbox/manager.go index 9a795b5c..94642639 100644 --- a/internal/sandbox/manager.go +++ b/internal/sandbox/manager.go @@ -203,16 +203,6 @@ func (m *Manager) Create(ctx context.Context, sandboxID, template string, vcpus, return nil, fmt.Errorf("wait for envd: %w", err) } - // Sync guest clock in background. Non-fatal — sandbox is usable before this completes. - // Run in a goroutine so Init latency doesn't block the RPC response back to the control plane. - go func() { - initCtx, initCancel := context.WithTimeout(context.Background(), 10*time.Second) - defer initCancel() - if err := client.Init(initCtx); err != nil { - slog.Warn("envd init (clock sync) failed", "sandbox", sandboxID, "error", err) - } - }() - now := time.Now() sb := &sandboxState{ Sandbox: models.Sandbox{ @@ -636,16 +626,6 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int) return nil, fmt.Errorf("wait for envd: %w", err) } - // Sync guest clock in background. Non-fatal — sandbox is usable before this completes. - // Run in a goroutine so Init latency doesn't block the RPC response back to the control plane. - go func() { - initCtx, initCancel := context.WithTimeout(context.Background(), 10*time.Second) - defer initCancel() - if err := client.Init(initCtx); err != nil { - slog.Warn("envd init (clock sync) failed", "sandbox", sandboxID, "error", err) - } - }() - now := time.Now() sb := &sandboxState{ Sandbox: models.Sandbox{ @@ -957,16 +937,6 @@ func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID, snapshotNam return nil, fmt.Errorf("wait for envd: %w", err) } - // Sync guest clock in background. Non-fatal — sandbox is usable before this completes. - // Run in a goroutine so Init latency doesn't block the RPC response back to the control plane. - go func() { - initCtx, initCancel := context.WithTimeout(context.Background(), 10*time.Second) - defer initCancel() - if err := client.Init(initCtx); err != nil { - slog.Warn("envd init (clock sync) failed", "sandbox", sandboxID, "error", err) - } - }() - now := time.Now() sb := &sandboxState{ Sandbox: models.Sandbox{ diff --git a/internal/vm/config.go b/internal/vm/config.go index 35bc2939..b99480e4 100644 --- a/internal/vm/config.go +++ b/internal/vm/config.go @@ -91,7 +91,7 @@ func (c *VMConfig) kernelArgs() string { ) return fmt.Sprintf( - "console=ttyS0 reboot=k panic=1 pci=off quiet loglevel=1 init=%s %s", + "console=ttyS0 reboot=k panic=1 pci=off quiet loglevel=1 clocksource=kvm-clock init=%s %s", c.InitPath, ipArg, ) } diff --git a/scripts/rootfs-from-container.sh b/scripts/rootfs-from-container.sh index c5756d76..2159ac7c 100755 --- a/scripts/rootfs-from-container.sh +++ b/scripts/rootfs-from-container.sh @@ -3,7 +3,10 @@ # rootfs-from-container.sh — Create a bootable Wrenn rootfs from a Docker container. # # Exports a container's filesystem, writes it into an ext4 image, injects -# envd + wrenn-init, and shrinks the image to minimum size. +# envd + wrenn-init + tini, and shrinks the image to minimum size. +# +# The container image must already include: socat, chrony, curl, ca-certificates, git. +# These are installed via apt in the container before export, not injected here. # # Usage: # bash scripts/rootfs-from-container.sh @@ -15,8 +18,7 @@ # Output: # ${AGENT_FILES_ROOTDIR}/images//rootfs.ext4 # -# Requires: docker, mkfs.ext4, resize2fs, e2fsck, make (for building envd), curl (for tini/socat download), -# gcc, make (for building socat from source) +# Requires: docker, mkfs.ext4, resize2fs, e2fsck, make (for building envd), curl (for tini download) # Sudo is used only for mount/umount/copy-into-image operations. set -euo pipefail @@ -131,45 +133,23 @@ sudo mkdir -p "${MOUNT_DIR}/sbin" sudo cp "${TINI_BIN}" "${MOUNT_DIR}/sbin/tini" sudo chmod 755 "${MOUNT_DIR}/sbin/tini" -echo "==> Installing socat..." -SOCAT_BIN="" -# 1. Already in the exported container image? -for p in "${MOUNT_DIR}/usr/bin/socat" "${MOUNT_DIR}/usr/local/bin/socat"; do - if [ -f "$p" ]; then SOCAT_BIN="$p"; break; fi -done -# 2. Available on the host? -if [ -z "${SOCAT_BIN}" ]; then - for p in /usr/bin/socat /usr/local/bin/socat; do - if [ -f "$p" ]; then SOCAT_BIN="$p"; break; fi - done -fi -# 3. Build from source. -if [ -z "${SOCAT_BIN}" ]; then - SOCAT_VERSION="1.8.1.1" - SOCAT_URL="http://www.dest-unreach.org/socat/download/socat-${SOCAT_VERSION}.tar.gz" - SOCAT_BUILD_DIR="/tmp/socat-build" - echo " Building socat ${SOCAT_VERSION} from source..." - rm -rf "${SOCAT_BUILD_DIR}" - mkdir -p "${SOCAT_BUILD_DIR}" - curl -fsSL "${SOCAT_URL}" | tar xz -C "${SOCAT_BUILD_DIR}" --strip-components=1 - (cd "${SOCAT_BUILD_DIR}" && LDFLAGS="-static" ./configure --quiet && make -j"$(nproc)" -s) - SOCAT_BIN="${SOCAT_BUILD_DIR}/socat" - if [ ! -f "${SOCAT_BIN}" ]; then - echo "ERROR: socat build failed" - exit 1 - fi - if ! file "${SOCAT_BIN}" | grep -q "statically linked"; then - echo "ERROR: socat is not statically linked!" - exit 1 - fi -fi -sudo cp "${SOCAT_BIN}" "${MOUNT_DIR}/usr/local/bin/socat" -sudo chmod 755 "${MOUNT_DIR}/usr/local/bin/socat" - -# Step 7: Verify. +# Step 6: Verify injected binaries and required container packages. echo "" echo "==> Installed guest binaries:" -ls -la "${MOUNT_DIR}/usr/local/bin/envd" "${MOUNT_DIR}/usr/local/bin/wrenn-init" "${MOUNT_DIR}/sbin/tini" "${MOUNT_DIR}/usr/local/bin/socat" +ls -la "${MOUNT_DIR}/usr/local/bin/envd" "${MOUNT_DIR}/usr/local/bin/wrenn-init" "${MOUNT_DIR}/sbin/tini" + +echo "" +echo "==> Checking required container packages..." +MISSING_PKGS="" +for bin in socat chronyd curl git; do + if ! find "${MOUNT_DIR}" -name "${bin}" -type f 2>/dev/null | head -1 | grep -q .; then + MISSING_PKGS="${MISSING_PKGS} ${bin}" + fi +done +if [ -n "${MISSING_PKGS}" ]; then + echo "WARNING: The following binaries were not found in the container image:${MISSING_PKGS}" + echo " Install them in the container (via apt) before exporting." +fi # Unmount before shrinking. sudo umount "${MOUNT_DIR}" diff --git a/scripts/update-debug-rootfs.sh b/scripts/update-debug-rootfs.sh index 76d3ed42..bdedded2 100755 --- a/scripts/update-debug-rootfs.sh +++ b/scripts/update-debug-rootfs.sh @@ -1,11 +1,11 @@ #!/usr/bin/env bash # -# update-debug-rootfs.sh — Build envd and inject it (plus wrenn-init) into the debug rootfs. +# update-debug-rootfs.sh — Build envd and inject it (plus wrenn-init + tini) into the debug rootfs. # # This script: # 1. Builds a fresh envd static binary via make # 2. Mounts the rootfs image -# 3. Copies envd and wrenn-init into the image +# 3. Copies envd, wrenn-init, and tini into the image # 4. Unmounts cleanly # # Usage: @@ -96,45 +96,10 @@ sudo mkdir -p "${MOUNT_DIR}/sbin" sudo cp "${TINI_BIN}" "${MOUNT_DIR}/sbin/tini" sudo chmod 755 "${MOUNT_DIR}/sbin/tini" -echo "==> Installing socat..." -SOCAT_BIN="" -# 1. Already in the rootfs? -for p in "${MOUNT_DIR}/usr/bin/socat" "${MOUNT_DIR}/usr/local/bin/socat"; do - if [ -f "$p" ]; then SOCAT_BIN="$p"; break; fi -done -# 2. Available on the host? -if [ -z "${SOCAT_BIN}" ]; then - for p in /usr/bin/socat /usr/local/bin/socat; do - if [ -f "$p" ]; then SOCAT_BIN="$p"; break; fi - done -fi -# 3. Build from source. -if [ -z "${SOCAT_BIN}" ]; then - SOCAT_VERSION="1.8.1.1" - SOCAT_URL="http://www.dest-unreach.org/socat/download/socat-${SOCAT_VERSION}.tar.gz" - SOCAT_BUILD_DIR="/tmp/socat-build" - echo " Building socat ${SOCAT_VERSION} from source..." - rm -rf "${SOCAT_BUILD_DIR}" - mkdir -p "${SOCAT_BUILD_DIR}" - curl -fsSL "${SOCAT_URL}" | tar xz -C "${SOCAT_BUILD_DIR}" --strip-components=1 - (cd "${SOCAT_BUILD_DIR}" && LDFLAGS="-static" ./configure --quiet && make -j"$(nproc)" -s) - SOCAT_BIN="${SOCAT_BUILD_DIR}/socat" - if [ ! -f "${SOCAT_BIN}" ]; then - echo "ERROR: socat build failed" - exit 1 - fi - if ! file "${SOCAT_BIN}" | grep -q "statically linked"; then - echo "ERROR: socat is not statically linked!" - exit 1 - fi -fi -sudo cp "${SOCAT_BIN}" "${MOUNT_DIR}/usr/local/bin/socat" -sudo chmod 755 "${MOUNT_DIR}/usr/local/bin/socat" - # Step 4: Verify. echo "" echo "==> Installed files:" -ls -la "${MOUNT_DIR}/usr/local/bin/envd" "${MOUNT_DIR}/usr/local/bin/wrenn-init" "${MOUNT_DIR}/sbin/tini" "${MOUNT_DIR}/usr/local/bin/socat" +ls -la "${MOUNT_DIR}/usr/local/bin/envd" "${MOUNT_DIR}/usr/local/bin/wrenn-init" "${MOUNT_DIR}/sbin/tini" echo "" echo "==> Done. Rootfs updated: ${ROOTFS}" From 1ce62934b3cdc174dad42361cd082e3468fd76c2 Mon Sep 17 00:00:00 2001 From: pptx704 Date: Thu, 26 Mar 2026 15:27:21 +0600 Subject: [PATCH 08/28] Add template build system with admin panel, async workers, and FlattenRootfs RPC Introduces an end-to-end template building pipeline: admins submit a recipe (list of shell commands) via the dashboard, a Redis-backed worker pool spins up a sandbox, executes each command, and produces either a full snapshot (with healthcheck) or an image-only template (rootfs flattened via a new FlattenRootfs host-agent RPC). Build progress and per-step logs are persisted to a new template_builds table and polled by the frontend. Backend: - New FlattenRootfs RPC (proto + host agent + sandbox manager) - BuildService with Redis queue (BLPOP) and configurable worker pool (default 2) - Admin-only REST endpoints: POST/GET /v1/admin/builds, GET /v1/admin/builds/{id} - Migration for template_builds table with JSONB logs and recipe columns - sqlc queries for build CRUD and progress updates Frontend: - /admin/templates page with Templates + Builds tabs - Create Template dialog with recipe textarea, healthcheck, specs - Build history with expandable per-step logs, status badges, progress bars - Auto-polling every 3s for active builds - AdminSidebar updated with Templates nav item --- cmd/control-plane/main.go | 4 + .../20260326090649_template_builds.sql | 25 + db/queries/template_builds.sql | 33 + frontend/src/lib/api/builds.ts | 52 ++ .../src/lib/components/AdminSidebar.svelte | 4 +- .../src/routes/admin/templates/+page.svelte | 837 ++++++++++++++++++ internal/api/handlers_builds.go | 156 ++++ internal/api/server.go | 10 +- internal/db/models.go | 20 + internal/db/template_builds.sql.go | 223 +++++ internal/hostagent/server.go | 13 + internal/id/id.go | 5 + internal/sandbox/manager.go | 82 ++ internal/service/build.go | 385 ++++++++ proto/hostagent/gen/hostagent.pb.go | 154 +++- .../hostagentv1connect/hostagent.connect.go | 37 + proto/hostagent/hostagent.proto | 17 + 17 files changed, 2031 insertions(+), 26 deletions(-) create mode 100644 db/migrations/20260326090649_template_builds.sql create mode 100644 db/queries/template_builds.sql create mode 100644 frontend/src/lib/api/builds.ts create mode 100644 frontend/src/routes/admin/templates/+page.svelte create mode 100644 internal/api/handlers_builds.go create mode 100644 internal/db/template_builds.sql.go create mode 100644 internal/service/build.go diff --git a/cmd/control-plane/main.go b/cmd/control-plane/main.go index 40c3c48a..af57d2be 100644 --- a/cmd/control-plane/main.go +++ b/cmd/control-plane/main.go @@ -90,6 +90,10 @@ func main() { // API server. srv := api.New(queries, hostPool, hostScheduler, pool, rdb, []byte(cfg.JWTSecret), oauthRegistry, cfg.OAuthRedirectURL) + // Start template build workers (2 concurrent). + stopBuildWorkers := srv.BuildSvc.StartWorkers(ctx, 2) + defer stopBuildWorkers() + // Start host monitor (passive + active reconciliation every 30s). monitor := api.NewHostMonitor(queries, hostPool, audit.New(queries), 30*time.Second) monitor.Start(ctx) diff --git a/db/migrations/20260326090649_template_builds.sql b/db/migrations/20260326090649_template_builds.sql new file mode 100644 index 00000000..8e5326dd --- /dev/null +++ b/db/migrations/20260326090649_template_builds.sql @@ -0,0 +1,25 @@ +-- +goose Up + +CREATE TABLE template_builds ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + base_template TEXT NOT NULL DEFAULT 'minimal', + recipe JSONB NOT NULL DEFAULT '[]', + healthcheck TEXT, + vcpus INTEGER NOT NULL DEFAULT 1, + memory_mb INTEGER NOT NULL DEFAULT 512, + status TEXT NOT NULL DEFAULT 'pending', + current_step INTEGER NOT NULL DEFAULT 0, + total_steps INTEGER NOT NULL DEFAULT 0, + logs JSONB NOT NULL DEFAULT '[]', + error TEXT, + sandbox_id TEXT, + host_id TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + started_at TIMESTAMPTZ, + completed_at TIMESTAMPTZ +); + +-- +goose Down + +DROP TABLE template_builds; diff --git a/db/queries/template_builds.sql b/db/queries/template_builds.sql new file mode 100644 index 00000000..ead4d925 --- /dev/null +++ b/db/queries/template_builds.sql @@ -0,0 +1,33 @@ +-- name: InsertTemplateBuild :one +INSERT INTO template_builds (id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, total_steps) +VALUES ($1, $2, $3, $4, $5, $6, $7, 'pending', $8) +RETURNING *; + +-- name: GetTemplateBuild :one +SELECT * FROM template_builds WHERE id = $1; + +-- name: ListTemplateBuilds :many +SELECT * FROM template_builds ORDER BY created_at DESC; + +-- name: UpdateBuildStatus :one +UPDATE template_builds +SET status = $2, + started_at = CASE WHEN $2 = 'running' AND started_at IS NULL THEN NOW() ELSE started_at END, + completed_at = CASE WHEN $2 IN ('success', 'failed') THEN NOW() ELSE completed_at END +WHERE id = $1 +RETURNING *; + +-- name: UpdateBuildProgress :exec +UPDATE template_builds +SET current_step = $2, logs = $3 +WHERE id = $1; + +-- name: UpdateBuildSandbox :exec +UPDATE template_builds +SET sandbox_id = $2, host_id = $3 +WHERE id = $1; + +-- name: UpdateBuildError :exec +UPDATE template_builds +SET error = $2, status = 'failed', completed_at = NOW() +WHERE id = $1; diff --git a/frontend/src/lib/api/builds.ts b/frontend/src/lib/api/builds.ts new file mode 100644 index 00000000..d826b36b --- /dev/null +++ b/frontend/src/lib/api/builds.ts @@ -0,0 +1,52 @@ +import { apiFetch, type ApiResult } from '$lib/api/client'; + +export type BuildLogEntry = { + step: number; + cmd: string; + stdout: string; + stderr: string; + exit: number; + ok: boolean; + elapsed_ms: number; +}; + +export type Build = { + id: string; + name: string; + base_template: string; + recipe: string[]; + healthcheck?: string; + vcpus: number; + memory_mb: number; + status: string; + current_step: number; + total_steps: number; + logs: BuildLogEntry[]; + error?: string; + sandbox_id?: string; + host_id?: string; + created_at: string; + started_at?: string; + completed_at?: string; +}; + +export type CreateBuildParams = { + name: string; + base_template?: string; + recipe: string[]; + healthcheck?: string; + vcpus?: number; + memory_mb?: number; +}; + +export async function createBuild(params: CreateBuildParams): Promise> { + return apiFetch('POST', '/api/v1/admin/builds', params); +} + +export async function listBuilds(): Promise> { + return apiFetch('GET', '/api/v1/admin/builds'); +} + +export async function getBuild(id: string): Promise> { + return apiFetch('GET', `/api/v1/admin/builds/${id}`); +} diff --git a/frontend/src/lib/components/AdminSidebar.svelte b/frontend/src/lib/components/AdminSidebar.svelte index 4bed5cc9..ebf4b646 100644 --- a/frontend/src/lib/components/AdminSidebar.svelte +++ b/frontend/src/lib/components/AdminSidebar.svelte @@ -3,6 +3,7 @@ import { auth } from '$lib/auth.svelte'; import { IconServer, + IconTemplate, IconSettings, IconLogout, IconSidebar, @@ -21,7 +22,8 @@ }; const managementItems: NavItem[] = [ - { label: 'Hosts', icon: IconServer, href: '/admin/hosts' } + { label: 'Hosts', icon: IconServer, href: '/admin/hosts' }, + { label: 'Templates', icon: IconTemplate, href: '/admin/templates' } ]; function isActive(href: string): boolean { diff --git a/frontend/src/routes/admin/templates/+page.svelte b/frontend/src/routes/admin/templates/+page.svelte new file mode 100644 index 00000000..904bbad4 --- /dev/null +++ b/frontend/src/routes/admin/templates/+page.svelte @@ -0,0 +1,837 @@ + + +
+ + +
+ +
+
+
+

+ Templates +

+

+ Build and manage global templates available to all teams. +

+
+ +
+ + + {#if !templatesLoading && !templatesError} +
+
+ {templateCount} + templates +
+
+ {baseCount} + base +
+
+ {snapshotCount} + snapshots +
+ {#if runningBuilds > 0} +
+ + + + + {runningBuilds} + building +
+ {/if} +
+ {/if} +
+ + +
+ {#each [['templates', 'Templates', templateCount], ['builds', 'Builds', builds.length]] as [id, label, count] (id)} + + {/each} +
+ + +
+ {#if activeTab === 'templates'} + {#if templatesLoading} + {@render skeletonRows(5, ['Name', 'Type', 'Specs', 'Size', 'Created', ''])} + {:else if templatesError} +
+ {templatesError} +
+ {:else if templates.length === 0} + {@render emptyState('templates')} + {:else} + {@render templatesTable()} + {/if} + {:else} + {#if buildsLoading} + {@render skeletonRows(4, ['Build', 'Name', 'Status', 'Progress', 'Started', 'Duration'])} + {:else if buildsError} +
+ {buildsError} +
+ {:else if builds.length === 0} + {@render emptyState('builds')} + {:else} + {@render buildsTable()} + {/if} + {/if} +
+
+
+ + + +{#snippet skeletonRows(count: number, headers: string[])} +
+ + + + {#each headers as h} + + {/each} + + + + {#each Array(count) as _, i} + + {#each headers as _h, j} + + {/each} + + {/each} + +
{h}
+
+
+
+{/snippet} + +{#snippet emptyState(type: 'templates' | 'builds')} +
+
+ {#if type === 'templates'} + + {:else} + + {/if} +
+

+ {type === 'templates' ? 'No templates yet.' : 'No builds yet.'} +

+

+ {type === 'templates' + ? 'Create a template to provide pre-configured environments for all teams.' + : 'Start a template build to see progress and logs here.'} +

+
+{/snippet} + +{#snippet templatesTable()} +
+ + + + + + + + + + + + + {#each templates as tmpl (tmpl.name)} + + + + + + + + + {/each} + +
NameType
+ {tmpl.name} + + {#if tmpl.type === 'snapshot'} + + snapshot + + {:else} + + base + + {/if} + + {#if tmpl.type === 'snapshot'} + + {/if} +
+
+{/snippet} + +{#snippet buildsTable()} +
+ + + + + + + + + + + + + + {#each builds as build (build.id)} + toggleBuildExpand(build.id)} + > + + + + + + + + + + {#if expandedBuildId === build.id} + + + + {/if} + {/each} + +
BuildNameStatus
+
+ + + + {build.id} +
+
+ {build.name} + + + {#if build.status === 'running'} + + + + + {:else if build.status === 'success'} + + {:else if build.status === 'failed'} + + {:else} + + {/if} + {build.status} + +
+
+ {#if build.error} +
+ {build.error} +
+ {/if} + + {#if build.logs && build.logs.length > 0} +
+ {#each build.logs as log, i (i)} +
+ + + + + {#if expandedSteps.has(log.step)} +
+ {#if log.stdout} +
+ stdout +
{log.stdout}
+
+ {/if} + {#if log.stderr} +
+ stderr +
{log.stderr}
+
+ {/if} + {#if !log.stdout && !log.stderr} + No output + {/if} +
+ {/if} +
+ {/each} +
+ {:else} +
+ {#if build.status === 'pending' || build.status === 'running'} + + {build.status === 'pending' ? 'Waiting for worker…' : 'Running…'} + {:else} + No build logs recorded. + {/if} +
+ {/if} + + + {#if build.recipe && build.recipe.length > 0} +
+ Recipe +
+ {#each build.recipe as cmd, i} +
+ {i + 1}. + {cmd} +
+ {/each} +
+
+ {/if} + + {#if build.healthcheck} +
+ Healthcheck + {build.healthcheck} +
+ {/if} +
+
+
+{/snippet} + + +{#if showCreate} +
+
{ if (!creating) showCreate = false; }} + onkeydown={(e) => { if (e.key === 'Escape' && !creating) showCreate = false; }} + >
+
+

+ Create Template +

+

+ Build a new global template by running commands on a base image. +

+ + {#if createError} +
+ {createError} +
+ {/if} + +
+
+ + +
+ +
+
+ + +
+
+ + +
+
+ + +
+
+ +
+ + +

+ Each command runs with a 30s timeout. Non-zero exit codes abort the build. +

+
+ +
+ + +

+ If set, the build will poll this command every 1s (up to 60s) after the recipe completes. On success, a full snapshot (with memory state) is created. Without a healthcheck, only the rootfs is saved. +

+
+
+ +
+ + +
+
+
+{/if} + + +{#if deleteTarget} +
+
{ if (!deleting) deleteTarget = null; }} + onkeydown={(e) => { if (e.key === 'Escape' && !deleting) deleteTarget = null; }} + >
+
+

+ Delete Template +

+

+ Permanently remove {deleteTarget.name} from all hosts. +

+ + {#if deleteError} +
+ {deleteError} +
+ {/if} + +
+ + +
+
+
+{/if} + + diff --git a/internal/api/handlers_builds.go b/internal/api/handlers_builds.go new file mode 100644 index 00000000..ae9a48e1 --- /dev/null +++ b/internal/api/handlers_builds.go @@ -0,0 +1,156 @@ +package api + +import ( + "encoding/json" + "fmt" + "net/http" + "time" + + "github.com/go-chi/chi/v5" + + "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/service" + "git.omukk.dev/wrenn/sandbox/internal/validate" +) + +type buildHandler struct { + svc *service.BuildService +} + +func newBuildHandler(svc *service.BuildService) *buildHandler { + return &buildHandler{svc: svc} +} + +type createBuildRequest struct { + Name string `json:"name"` + BaseTemplate string `json:"base_template"` + Recipe []string `json:"recipe"` + Healthcheck string `json:"healthcheck"` + VCPUs int32 `json:"vcpus"` + MemoryMB int32 `json:"memory_mb"` +} + +type buildResponse struct { + ID string `json:"id"` + Name string `json:"name"` + BaseTemplate string `json:"base_template"` + Recipe json.RawMessage `json:"recipe"` + Healthcheck *string `json:"healthcheck,omitempty"` + VCPUs int32 `json:"vcpus"` + MemoryMB int32 `json:"memory_mb"` + Status string `json:"status"` + CurrentStep int32 `json:"current_step"` + TotalSteps int32 `json:"total_steps"` + Logs json.RawMessage `json:"logs"` + Error *string `json:"error,omitempty"` + SandboxID *string `json:"sandbox_id,omitempty"` + HostID *string `json:"host_id,omitempty"` + CreatedAt string `json:"created_at"` + StartedAt *string `json:"started_at,omitempty"` + CompletedAt *string `json:"completed_at,omitempty"` +} + +func buildToResponse(b db.TemplateBuild) buildResponse { + resp := buildResponse{ + ID: b.ID, + Name: b.Name, + BaseTemplate: b.BaseTemplate, + Recipe: b.Recipe, + VCPUs: b.Vcpus, + MemoryMB: b.MemoryMb, + Status: b.Status, + CurrentStep: b.CurrentStep, + TotalSteps: b.TotalSteps, + Logs: b.Logs, + } + if b.Healthcheck.Valid { + resp.Healthcheck = &b.Healthcheck.String + } + if b.Error.Valid { + resp.Error = &b.Error.String + } + if b.SandboxID.Valid { + resp.SandboxID = &b.SandboxID.String + } + if b.HostID.Valid { + resp.HostID = &b.HostID.String + } + if b.CreatedAt.Valid { + resp.CreatedAt = b.CreatedAt.Time.Format(time.RFC3339) + } + if b.StartedAt.Valid { + s := b.StartedAt.Time.Format(time.RFC3339) + resp.StartedAt = &s + } + if b.CompletedAt.Valid { + s := b.CompletedAt.Time.Format(time.RFC3339) + resp.CompletedAt = &s + } + return resp +} + +// Create handles POST /v1/admin/builds. +func (h *buildHandler) Create(w http.ResponseWriter, r *http.Request) { + var req createBuildRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body") + return + } + + if req.Name == "" { + writeError(w, http.StatusBadRequest, "invalid_request", "name is required") + return + } + if err := validate.SafeName(req.Name); err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", fmt.Sprintf("invalid template name: %s", err)) + return + } + if len(req.Recipe) == 0 { + writeError(w, http.StatusBadRequest, "invalid_request", "recipe must contain at least one command") + return + } + + build, err := h.svc.Create(r.Context(), service.BuildCreateParams{ + Name: req.Name, + BaseTemplate: req.BaseTemplate, + Recipe: req.Recipe, + Healthcheck: req.Healthcheck, + VCPUs: req.VCPUs, + MemoryMB: req.MemoryMB, + }) + if err != nil { + writeError(w, http.StatusInternalServerError, "build_error", err.Error()) + return + } + + writeJSON(w, http.StatusCreated, buildToResponse(build)) +} + +// List handles GET /v1/admin/builds. +func (h *buildHandler) List(w http.ResponseWriter, r *http.Request) { + builds, err := h.svc.List(r.Context()) + if err != nil { + writeError(w, http.StatusInternalServerError, "db_error", "failed to list builds") + return + } + + resp := make([]buildResponse, len(builds)) + for i, b := range builds { + resp[i] = buildToResponse(b) + } + + writeJSON(w, http.StatusOK, resp) +} + +// Get handles GET /v1/admin/builds/{id}. +func (h *buildHandler) Get(w http.ResponseWriter, r *http.Request) { + buildID := chi.URLParam(r, "id") + + build, err := h.svc.Get(r.Context(), buildID) + if err != nil { + writeError(w, http.StatusNotFound, "not_found", "build not found") + return + } + + writeJSON(w, http.StatusOK, buildToResponse(build)) +} diff --git a/internal/api/server.go b/internal/api/server.go index 918476bf..6999b8a8 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -22,7 +22,8 @@ var openapiYAML []byte // Server is the control plane HTTP server. type Server struct { - router chi.Router + router chi.Router + BuildSvc *service.BuildService } // New constructs the chi router and registers all routes. @@ -47,6 +48,7 @@ func New( teamSvc := &service.TeamService{DB: queries, Pool: pgPool, HostPool: pool} auditSvc := &service.AuditService{DB: queries} statsSvc := &service.StatsService{DB: queries, Pool: pgPool} + buildSvc := &service.BuildService{DB: queries, Redis: rdb, Pool: pool, Scheduler: sched} al := audit.New(queries) @@ -65,6 +67,7 @@ func New( auditH := newAuditHandler(auditSvc) statsH := newStatsHandler(statsSvc) metricsH := newSandboxMetricsHandler(queries, pool) + buildH := newBuildHandler(buildSvc) // OpenAPI spec and docs. r.Get("/openapi.yaml", serveOpenAPI) @@ -174,9 +177,12 @@ func New( r.Use(requireJWT(jwtSecret)) r.Use(requireAdmin(queries)) r.Put("/teams/{id}/byoc", teamH.SetBYOC) + r.Post("/builds", buildH.Create) + r.Get("/builds", buildH.List) + r.Get("/builds/{id}", buildH.Get) }) - return &Server{router: r} + return &Server{router: r, BuildSvc: buildSvc} } // Handler returns the HTTP handler. diff --git a/internal/db/models.go b/internal/db/models.go index 0128f4a8..74596c6c 100644 --- a/internal/db/models.go +++ b/internal/db/models.go @@ -147,6 +147,26 @@ type Template struct { TeamID string `json:"team_id"` } +type TemplateBuild struct { + ID string `json:"id"` + Name string `json:"name"` + BaseTemplate string `json:"base_template"` + Recipe []byte `json:"recipe"` + Healthcheck pgtype.Text `json:"healthcheck"` + Vcpus int32 `json:"vcpus"` + MemoryMb int32 `json:"memory_mb"` + Status string `json:"status"` + CurrentStep int32 `json:"current_step"` + TotalSteps int32 `json:"total_steps"` + Logs []byte `json:"logs"` + Error pgtype.Text `json:"error"` + SandboxID pgtype.Text `json:"sandbox_id"` + HostID pgtype.Text `json:"host_id"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + StartedAt pgtype.Timestamptz `json:"started_at"` + CompletedAt pgtype.Timestamptz `json:"completed_at"` +} + type User struct { ID string `json:"id"` Email string `json:"email"` diff --git a/internal/db/template_builds.sql.go b/internal/db/template_builds.sql.go new file mode 100644 index 00000000..8142d294 --- /dev/null +++ b/internal/db/template_builds.sql.go @@ -0,0 +1,223 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: template_builds.sql + +package db + +import ( + "context" + + "github.com/jackc/pgx/v5/pgtype" +) + +const getTemplateBuild = `-- name: GetTemplateBuild :one +SELECT id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at FROM template_builds WHERE id = $1 +` + +func (q *Queries) GetTemplateBuild(ctx context.Context, id string) (TemplateBuild, error) { + row := q.db.QueryRow(ctx, getTemplateBuild, id) + var i TemplateBuild + err := row.Scan( + &i.ID, + &i.Name, + &i.BaseTemplate, + &i.Recipe, + &i.Healthcheck, + &i.Vcpus, + &i.MemoryMb, + &i.Status, + &i.CurrentStep, + &i.TotalSteps, + &i.Logs, + &i.Error, + &i.SandboxID, + &i.HostID, + &i.CreatedAt, + &i.StartedAt, + &i.CompletedAt, + ) + return i, err +} + +const insertTemplateBuild = `-- name: InsertTemplateBuild :one +INSERT INTO template_builds (id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, total_steps) +VALUES ($1, $2, $3, $4, $5, $6, $7, 'pending', $8) +RETURNING id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at +` + +type InsertTemplateBuildParams struct { + ID string `json:"id"` + Name string `json:"name"` + BaseTemplate string `json:"base_template"` + Recipe []byte `json:"recipe"` + Healthcheck pgtype.Text `json:"healthcheck"` + Vcpus int32 `json:"vcpus"` + MemoryMb int32 `json:"memory_mb"` + TotalSteps int32 `json:"total_steps"` +} + +func (q *Queries) InsertTemplateBuild(ctx context.Context, arg InsertTemplateBuildParams) (TemplateBuild, error) { + row := q.db.QueryRow(ctx, insertTemplateBuild, + arg.ID, + arg.Name, + arg.BaseTemplate, + arg.Recipe, + arg.Healthcheck, + arg.Vcpus, + arg.MemoryMb, + arg.TotalSteps, + ) + var i TemplateBuild + err := row.Scan( + &i.ID, + &i.Name, + &i.BaseTemplate, + &i.Recipe, + &i.Healthcheck, + &i.Vcpus, + &i.MemoryMb, + &i.Status, + &i.CurrentStep, + &i.TotalSteps, + &i.Logs, + &i.Error, + &i.SandboxID, + &i.HostID, + &i.CreatedAt, + &i.StartedAt, + &i.CompletedAt, + ) + return i, err +} + +const listTemplateBuilds = `-- name: ListTemplateBuilds :many +SELECT id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at FROM template_builds ORDER BY created_at DESC +` + +func (q *Queries) ListTemplateBuilds(ctx context.Context) ([]TemplateBuild, error) { + rows, err := q.db.Query(ctx, listTemplateBuilds) + if err != nil { + return nil, err + } + defer rows.Close() + var items []TemplateBuild + for rows.Next() { + var i TemplateBuild + if err := rows.Scan( + &i.ID, + &i.Name, + &i.BaseTemplate, + &i.Recipe, + &i.Healthcheck, + &i.Vcpus, + &i.MemoryMb, + &i.Status, + &i.CurrentStep, + &i.TotalSteps, + &i.Logs, + &i.Error, + &i.SandboxID, + &i.HostID, + &i.CreatedAt, + &i.StartedAt, + &i.CompletedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const updateBuildError = `-- name: UpdateBuildError :exec +UPDATE template_builds +SET error = $2, status = 'failed', completed_at = NOW() +WHERE id = $1 +` + +type UpdateBuildErrorParams struct { + ID string `json:"id"` + Error pgtype.Text `json:"error"` +} + +func (q *Queries) UpdateBuildError(ctx context.Context, arg UpdateBuildErrorParams) error { + _, err := q.db.Exec(ctx, updateBuildError, arg.ID, arg.Error) + return err +} + +const updateBuildProgress = `-- name: UpdateBuildProgress :exec +UPDATE template_builds +SET current_step = $2, logs = $3 +WHERE id = $1 +` + +type UpdateBuildProgressParams struct { + ID string `json:"id"` + CurrentStep int32 `json:"current_step"` + Logs []byte `json:"logs"` +} + +func (q *Queries) UpdateBuildProgress(ctx context.Context, arg UpdateBuildProgressParams) error { + _, err := q.db.Exec(ctx, updateBuildProgress, arg.ID, arg.CurrentStep, arg.Logs) + return err +} + +const updateBuildSandbox = `-- name: UpdateBuildSandbox :exec +UPDATE template_builds +SET sandbox_id = $2, host_id = $3 +WHERE id = $1 +` + +type UpdateBuildSandboxParams struct { + ID string `json:"id"` + SandboxID pgtype.Text `json:"sandbox_id"` + HostID pgtype.Text `json:"host_id"` +} + +func (q *Queries) UpdateBuildSandbox(ctx context.Context, arg UpdateBuildSandboxParams) error { + _, err := q.db.Exec(ctx, updateBuildSandbox, arg.ID, arg.SandboxID, arg.HostID) + return err +} + +const updateBuildStatus = `-- name: UpdateBuildStatus :one +UPDATE template_builds +SET status = $2, + started_at = CASE WHEN $2 = 'running' AND started_at IS NULL THEN NOW() ELSE started_at END, + completed_at = CASE WHEN $2 IN ('success', 'failed') THEN NOW() ELSE completed_at END +WHERE id = $1 +RETURNING id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at +` + +type UpdateBuildStatusParams struct { + ID string `json:"id"` + Status string `json:"status"` +} + +func (q *Queries) UpdateBuildStatus(ctx context.Context, arg UpdateBuildStatusParams) (TemplateBuild, error) { + row := q.db.QueryRow(ctx, updateBuildStatus, arg.ID, arg.Status) + var i TemplateBuild + err := row.Scan( + &i.ID, + &i.Name, + &i.BaseTemplate, + &i.Recipe, + &i.Healthcheck, + &i.Vcpus, + &i.MemoryMb, + &i.Status, + &i.CurrentStep, + &i.TotalSteps, + &i.Logs, + &i.Error, + &i.SandboxID, + &i.HostID, + &i.CreatedAt, + &i.StartedAt, + &i.CompletedAt, + ) + return i, err +} diff --git a/internal/hostagent/server.go b/internal/hostagent/server.go index fb7fb664..86fdda0e 100644 --- a/internal/hostagent/server.go +++ b/internal/hostagent/server.go @@ -110,6 +110,19 @@ func (s *Server) DeleteSnapshot( return connect.NewResponse(&pb.DeleteSnapshotResponse{}), nil } +func (s *Server) FlattenRootfs( + ctx context.Context, + req *connect.Request[pb.FlattenRootfsRequest], +) (*connect.Response[pb.FlattenRootfsResponse], error) { + sizeBytes, err := s.mgr.FlattenRootfs(ctx, req.Msg.SandboxId, req.Msg.Name) + if err != nil { + return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("flatten rootfs: %w", err)) + } + return connect.NewResponse(&pb.FlattenRootfsResponse{ + SizeBytes: sizeBytes, + }), nil +} + func (s *Server) PingSandbox( ctx context.Context, req *connect.Request[pb.PingSandboxRequest], diff --git a/internal/id/id.go b/internal/id/id.go index bbda47c7..836af6df 100644 --- a/internal/id/id.go +++ b/internal/id/id.go @@ -78,6 +78,11 @@ func NewAuditLogID() string { return "log-" + hex8() } +// NewBuildID generates a new template build ID in the format "bld-" + 8 hex chars. +func NewBuildID() string { + return "bld-" + hex8() +} + // NewRefreshToken generates a 64-char hex token (32 bytes of entropy) for use as a host refresh token. func NewRefreshToken() string { b := make([]byte, 32) diff --git a/internal/sandbox/manager.go b/internal/sandbox/manager.go index 94642639..1d103bc2 100644 --- a/internal/sandbox/manager.go +++ b/internal/sandbox/manager.go @@ -795,6 +795,88 @@ func (m *Manager) CreateSnapshot(ctx context.Context, sandboxID, name string) (i return sizeBytes, nil } +// FlattenRootfs stops a running sandbox, flattens its device-mapper CoW +// rootfs into a standalone rootfs.ext4, and cleans up all resources. +// The result is an image-only template (no VM memory/CPU state) stored in +// ImagesDir/{name}/rootfs.ext4. +func (m *Manager) FlattenRootfs(ctx context.Context, sandboxID, name string) (int64, error) { + if err := validate.SafeName(name); err != nil { + return 0, fmt.Errorf("invalid template name: %w", err) + } + + m.mu.Lock() + sb, ok := m.boxes[sandboxID] + if ok { + delete(m.boxes, sandboxID) + } + m.mu.Unlock() + + if !ok { + return 0, fmt.Errorf("sandbox %s not found", sandboxID) + } + + // Stop the VM but keep the dm device alive for flattening. + m.stopSampler(sb) + if err := m.vm.Destroy(ctx, sb.ID); err != nil { + slog.Warn("vm destroy error during flatten", "id", sb.ID, "error", err) + } + + // Release network resources — not needed after VM is stopped. + if err := network.RemoveNetwork(sb.slot); err != nil { + slog.Warn("network cleanup error during flatten", "id", sb.ID, "error", err) + } + m.slots.Release(sb.SlotIndex) + + if sb.uffdSocketPath != "" { + os.Remove(sb.uffdSocketPath) + } + + // Create template directory and flatten the dm-snapshot. + if err := snapshot.EnsureDir(m.cfg.ImagesDir, name); err != nil { + m.cleanupDM(sb) + return 0, fmt.Errorf("create template dir: %w", err) + } + + outputPath := snapshot.RootfsPath(m.cfg.ImagesDir, name) + if sb.dmDevice == nil { + return 0, fmt.Errorf("sandbox %s has no dm device", sandboxID) + } + + if err := devicemapper.FlattenSnapshot(sb.dmDevice.DevicePath, outputPath); err != nil { + m.cleanupDM(sb) + warnErr("template dir cleanup error", name, snapshot.Remove(m.cfg.ImagesDir, name)) + return 0, fmt.Errorf("flatten rootfs: %w", err) + } + + // Clean up dm device and loop device now that flatten is complete. + m.cleanupDM(sb) + + sizeBytes, err := snapshot.DirSize(m.cfg.ImagesDir, name) + if err != nil { + slog.Warn("failed to calculate template size", "error", err) + } + + slog.Info("rootfs flattened to image-only template", + "sandbox", sandboxID, + "name", name, + "size_bytes", sizeBytes, + ) + return sizeBytes, nil +} + +// cleanupDM tears down the dm-snapshot device and releases the base image loop device. +func (m *Manager) cleanupDM(sb *sandboxState) { + if sb.dmDevice != nil { + if err := devicemapper.RemoveSnapshot(context.Background(), sb.dmDevice); err != nil { + slog.Warn("dm-snapshot remove error", "id", sb.ID, "error", err) + } + os.Remove(sb.dmDevice.CowPath) + } + if sb.baseImagePath != "" { + m.loops.Release(sb.baseImagePath) + } +} + // DeleteSnapshot removes a snapshot template from disk. func (m *Manager) DeleteSnapshot(name string) error { if err := validate.SafeName(name); err != nil { diff --git a/internal/service/build.go b/internal/service/build.go new file mode 100644 index 00000000..054f8a24 --- /dev/null +++ b/internal/service/build.go @@ -0,0 +1,385 @@ +package service + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "time" + + "connectrpc.com/connect" + "github.com/jackc/pgx/v5/pgtype" + "github.com/redis/go-redis/v9" + + "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/id" + "git.omukk.dev/wrenn/sandbox/internal/lifecycle" + "git.omukk.dev/wrenn/sandbox/internal/scheduler" + pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen" +) + +const ( + buildQueueKey = "wrenn:build_queue" + buildCommandTimeout = 30 * time.Second + healthcheckInterval = 1 * time.Second + healthcheckTimeout = 60 * time.Second + platformTeamID = "platform" +) + +// buildAgentClient is the subset of the host agent client used by the build worker. +type buildAgentClient interface { + CreateSandbox(ctx context.Context, req *connect.Request[pb.CreateSandboxRequest]) (*connect.Response[pb.CreateSandboxResponse], error) + DestroySandbox(ctx context.Context, req *connect.Request[pb.DestroySandboxRequest]) (*connect.Response[pb.DestroySandboxResponse], error) + Exec(ctx context.Context, req *connect.Request[pb.ExecRequest]) (*connect.Response[pb.ExecResponse], error) + CreateSnapshot(ctx context.Context, req *connect.Request[pb.CreateSnapshotRequest]) (*connect.Response[pb.CreateSnapshotResponse], error) + FlattenRootfs(ctx context.Context, req *connect.Request[pb.FlattenRootfsRequest]) (*connect.Response[pb.FlattenRootfsResponse], error) +} + +// BuildLogEntry represents a single entry in the build log JSONB array. +type BuildLogEntry struct { + Step int `json:"step"` + Cmd string `json:"cmd"` + Stdout string `json:"stdout"` + Stderr string `json:"stderr"` + Exit int32 `json:"exit"` + Ok bool `json:"ok"` + Elapsed int64 `json:"elapsed_ms"` +} + +// BuildService handles template build orchestration. +type BuildService struct { + DB *db.Queries + Redis *redis.Client + Pool *lifecycle.HostClientPool + Scheduler scheduler.HostScheduler +} + +// BuildCreateParams holds the parameters for creating a template build. +type BuildCreateParams struct { + Name string + BaseTemplate string + Recipe []string + Healthcheck string + VCPUs int32 + MemoryMB int32 +} + +// Create inserts a new build record and enqueues it to Redis. +func (s *BuildService) Create(ctx context.Context, p BuildCreateParams) (db.TemplateBuild, error) { + if p.BaseTemplate == "" { + p.BaseTemplate = "minimal" + } + if p.VCPUs <= 0 { + p.VCPUs = 1 + } + if p.MemoryMB <= 0 { + p.MemoryMB = 512 + } + + recipeJSON, err := json.Marshal(p.Recipe) + if err != nil { + return db.TemplateBuild{}, fmt.Errorf("marshal recipe: %w", err) + } + + buildID := id.NewBuildID() + + build, err := s.DB.InsertTemplateBuild(ctx, db.InsertTemplateBuildParams{ + ID: buildID, + Name: p.Name, + BaseTemplate: p.BaseTemplate, + Recipe: recipeJSON, + Healthcheck: pgtype.Text{String: p.Healthcheck, Valid: p.Healthcheck != ""}, + Vcpus: p.VCPUs, + MemoryMb: p.MemoryMB, + TotalSteps: int32(len(p.Recipe)), + }) + if err != nil { + return db.TemplateBuild{}, fmt.Errorf("insert build: %w", err) + } + + // Enqueue build ID to Redis for workers to pick up. + if err := s.Redis.RPush(ctx, buildQueueKey, buildID).Err(); err != nil { + return db.TemplateBuild{}, fmt.Errorf("enqueue build: %w", err) + } + + return build, nil +} + +// Get returns a single build by ID. +func (s *BuildService) Get(ctx context.Context, buildID string) (db.TemplateBuild, error) { + return s.DB.GetTemplateBuild(ctx, buildID) +} + +// List returns all builds ordered by creation time. +func (s *BuildService) List(ctx context.Context) ([]db.TemplateBuild, error) { + return s.DB.ListTemplateBuilds(ctx) +} + +// StartWorkers launches n goroutines that consume from the Redis build queue. +// The returned cancel function stops all workers. +func (s *BuildService) StartWorkers(ctx context.Context, n int) context.CancelFunc { + ctx, cancel := context.WithCancel(ctx) + for i := range n { + go s.worker(ctx, i) + } + slog.Info("build workers started", "count", n) + return cancel +} + +func (s *BuildService) worker(ctx context.Context, workerID int) { + log := slog.With("worker", workerID) + for { + // BLPOP blocks until a build ID is available or context is cancelled. + result, err := s.Redis.BLPop(ctx, 0, buildQueueKey).Result() + if err != nil { + if ctx.Err() != nil { + log.Info("build worker shutting down") + return + } + log.Error("redis BLPOP error", "error", err) + time.Sleep(time.Second) + continue + } + // result[0] is the key, result[1] is the build ID. + buildID := result[1] + log.Info("picked up build", "build_id", buildID) + s.executeBuild(ctx, buildID) + } +} + +func (s *BuildService) executeBuild(ctx context.Context, buildID string) { + log := slog.With("build_id", buildID) + + build, err := s.DB.GetTemplateBuild(ctx, buildID) + if err != nil { + log.Error("failed to fetch build", "error", err) + return + } + + // Mark as running. + if _, err := s.DB.UpdateBuildStatus(ctx, db.UpdateBuildStatusParams{ + ID: buildID, Status: "running", + }); err != nil { + log.Error("failed to update build status", "error", err) + return + } + + // Parse recipe. + var recipe []string + if err := json.Unmarshal(build.Recipe, &recipe); err != nil { + s.failBuild(ctx, buildID, fmt.Sprintf("invalid recipe JSON: %v", err)) + return + } + + // Pick a platform host and create a sandbox. + host, err := s.Scheduler.SelectHost(ctx, platformTeamID, false) + if err != nil { + s.failBuild(ctx, buildID, fmt.Sprintf("no host available: %v", err)) + return + } + + agent, err := s.Pool.GetForHost(host) + if err != nil { + s.failBuild(ctx, buildID, fmt.Sprintf("agent client error: %v", err)) + return + } + + sandboxID := id.NewSandboxID() + log = log.With("sandbox_id", sandboxID, "host_id", host.ID) + + resp, err := agent.CreateSandbox(ctx, connect.NewRequest(&pb.CreateSandboxRequest{ + SandboxId: sandboxID, + Template: build.BaseTemplate, + Vcpus: build.Vcpus, + MemoryMb: build.MemoryMb, + TimeoutSec: 0, // no auto-pause for builds + })) + if err != nil { + s.failBuild(ctx, buildID, fmt.Sprintf("create sandbox failed: %v", err)) + return + } + _ = resp + + // Record sandbox/host association. + _ = s.DB.UpdateBuildSandbox(ctx, db.UpdateBuildSandboxParams{ + ID: buildID, + SandboxID: pgtype.Text{String: sandboxID, Valid: true}, + HostID: pgtype.Text{String: host.ID, Valid: true}, + }) + + // Execute recipe commands. + var logs []BuildLogEntry + for i, cmd := range recipe { + log.Info("executing build step", "step", i+1, "cmd", cmd) + + execCtx, cancel := context.WithTimeout(ctx, buildCommandTimeout) + start := time.Now() + + execResp, err := agent.Exec(execCtx, connect.NewRequest(&pb.ExecRequest{ + SandboxId: sandboxID, + Cmd: "/bin/sh", + Args: []string{"-c", cmd}, + TimeoutSec: int32(buildCommandTimeout.Seconds()), + })) + cancel() + + entry := BuildLogEntry{ + Step: i + 1, + Cmd: cmd, + Elapsed: time.Since(start).Milliseconds(), + } + + if err != nil { + entry.Stderr = err.Error() + entry.Ok = false + logs = append(logs, entry) + s.updateLogs(ctx, buildID, i+1, logs) + s.destroySandbox(ctx, agent, sandboxID) + s.failBuild(ctx, buildID, fmt.Sprintf("step %d exec error: %v", i+1, err)) + return + } + + entry.Stdout = string(execResp.Msg.Stdout) + entry.Stderr = string(execResp.Msg.Stderr) + entry.Exit = execResp.Msg.ExitCode + entry.Ok = execResp.Msg.ExitCode == 0 + logs = append(logs, entry) + + s.updateLogs(ctx, buildID, i+1, logs) + + if execResp.Msg.ExitCode != 0 { + s.destroySandbox(ctx, agent, sandboxID) + s.failBuild(ctx, buildID, fmt.Sprintf("step %d failed with exit code %d", i+1, execResp.Msg.ExitCode)) + return + } + } + + // Healthcheck or direct snapshot. + if build.Healthcheck.Valid && build.Healthcheck.String != "" { + log.Info("running healthcheck", "cmd", build.Healthcheck.String) + if err := s.waitForHealthcheck(ctx, agent, sandboxID, build.Healthcheck.String); err != nil { + s.destroySandbox(ctx, agent, sandboxID) + s.failBuild(ctx, buildID, fmt.Sprintf("healthcheck failed: %v", err)) + return + } + + // Healthcheck passed → full snapshot (with memory/CPU state). + log.Info("healthcheck passed, creating snapshot") + if _, err := agent.CreateSnapshot(ctx, connect.NewRequest(&pb.CreateSnapshotRequest{ + SandboxId: sandboxID, + Name: build.Name, + })); err != nil { + s.destroySandbox(ctx, agent, sandboxID) + s.failBuild(ctx, buildID, fmt.Sprintf("create snapshot failed: %v", err)) + return + } + } else { + // No healthcheck → image-only template (rootfs only). + log.Info("no healthcheck, flattening rootfs") + if _, err := agent.FlattenRootfs(ctx, connect.NewRequest(&pb.FlattenRootfsRequest{ + SandboxId: sandboxID, + Name: build.Name, + })); err != nil { + s.destroySandbox(ctx, agent, sandboxID) + s.failBuild(ctx, buildID, fmt.Sprintf("flatten rootfs failed: %v", err)) + return + } + } + + // Insert into templates table as a global (platform) template. + templateType := "base" + if build.Healthcheck.Valid && build.Healthcheck.String != "" { + templateType = "snapshot" + } + + if _, err := s.DB.InsertTemplate(ctx, db.InsertTemplateParams{ + Name: build.Name, + Type: templateType, + Vcpus: pgtype.Int4{Int32: build.Vcpus, Valid: true}, + MemoryMb: pgtype.Int4{Int32: build.MemoryMb, Valid: true}, + SizeBytes: 0, // Could query the host, but the template is created. + TeamID: platformTeamID, + }); err != nil { + log.Error("failed to insert template record", "error", err) + // Build succeeded on disk, just DB record failed — don't mark as failed. + } + + // For CreateSnapshot, the sandbox is already destroyed by the snapshot process. + // For FlattenRootfs, the sandbox is already destroyed by the flatten process. + // No additional destroy needed. + + // Mark build as success. + if _, err := s.DB.UpdateBuildStatus(ctx, db.UpdateBuildStatusParams{ + ID: buildID, Status: "success", + }); err != nil { + log.Error("failed to mark build as success", "error", err) + } + + log.Info("template build completed successfully", "name", build.Name) +} + +func (s *BuildService) waitForHealthcheck(ctx context.Context, agent buildAgentClient, sandboxID, cmd string) error { + deadline := time.After(healthcheckTimeout) + ticker := time.NewTicker(healthcheckInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-deadline: + return fmt.Errorf("healthcheck timed out after %s", healthcheckTimeout) + case <-ticker.C: + execCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + resp, err := agent.Exec(execCtx, connect.NewRequest(&pb.ExecRequest{ + SandboxId: sandboxID, + Cmd: "/bin/sh", + Args: []string{"-c", cmd}, + TimeoutSec: 10, + })) + cancel() + + if err != nil { + slog.Debug("healthcheck exec error (retrying)", "error", err) + continue + } + if resp.Msg.ExitCode == 0 { + return nil + } + slog.Debug("healthcheck failed (retrying)", "exit_code", resp.Msg.ExitCode) + } + } +} + +func (s *BuildService) updateLogs(ctx context.Context, buildID string, step int, logs []BuildLogEntry) { + logsJSON, err := json.Marshal(logs) + if err != nil { + slog.Warn("failed to marshal build logs", "error", err) + return + } + if err := s.DB.UpdateBuildProgress(ctx, db.UpdateBuildProgressParams{ + ID: buildID, + CurrentStep: int32(step), + Logs: logsJSON, + }); err != nil { + slog.Warn("failed to update build progress", "error", err) + } +} + +func (s *BuildService) failBuild(ctx context.Context, buildID, errMsg string) { + slog.Error("build failed", "build_id", buildID, "error", errMsg) + if err := s.DB.UpdateBuildError(ctx, db.UpdateBuildErrorParams{ + ID: buildID, + Error: pgtype.Text{String: errMsg, Valid: true}, + }); err != nil { + slog.Error("failed to update build error", "build_id", buildID, "error", err) + } +} + +func (s *BuildService) destroySandbox(ctx context.Context, agent buildAgentClient, sandboxID string) { + if _, err := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{ + SandboxId: sandboxID, + })); err != nil { + slog.Warn("failed to destroy build sandbox", "sandbox_id", sandboxID, "error", err) + } +} diff --git a/proto/hostagent/gen/hostagent.pb.go b/proto/hostagent/gen/hostagent.pb.go index f496b2cb..c7436b76 100644 --- a/proto/hostagent/gen/hostagent.pb.go +++ b/proto/hostagent/gen/hostagent.pb.go @@ -2171,6 +2171,102 @@ func (x *FlushSandboxMetricsResponse) GetPoints_24H() []*MetricPoint { return nil } +type FlattenRootfsRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + SandboxId string `protobuf:"bytes,1,opt,name=sandbox_id,json=sandboxId,proto3" json:"sandbox_id,omitempty"` + Name string `protobuf:"bytes,2,opt,name=name,proto3" json:"name,omitempty"` // template name — output written to images/{name}/rootfs.ext4 + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *FlattenRootfsRequest) Reset() { + *x = FlattenRootfsRequest{} + mi := &file_hostagent_proto_msgTypes[40] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *FlattenRootfsRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*FlattenRootfsRequest) ProtoMessage() {} + +func (x *FlattenRootfsRequest) ProtoReflect() protoreflect.Message { + mi := &file_hostagent_proto_msgTypes[40] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use FlattenRootfsRequest.ProtoReflect.Descriptor instead. +func (*FlattenRootfsRequest) Descriptor() ([]byte, []int) { + return file_hostagent_proto_rawDescGZIP(), []int{40} +} + +func (x *FlattenRootfsRequest) GetSandboxId() string { + if x != nil { + return x.SandboxId + } + return "" +} + +func (x *FlattenRootfsRequest) GetName() string { + if x != nil { + return x.Name + } + return "" +} + +type FlattenRootfsResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + SizeBytes int64 `protobuf:"varint,1,opt,name=size_bytes,json=sizeBytes,proto3" json:"size_bytes,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *FlattenRootfsResponse) Reset() { + *x = FlattenRootfsResponse{} + mi := &file_hostagent_proto_msgTypes[41] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *FlattenRootfsResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*FlattenRootfsResponse) ProtoMessage() {} + +func (x *FlattenRootfsResponse) ProtoReflect() protoreflect.Message { + mi := &file_hostagent_proto_msgTypes[41] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use FlattenRootfsResponse.ProtoReflect.Descriptor instead. +func (*FlattenRootfsResponse) Descriptor() ([]byte, []int) { + return file_hostagent_proto_rawDescGZIP(), []int{41} +} + +func (x *FlattenRootfsResponse) GetSizeBytes() int64 { + if x != nil { + return x.SizeBytes + } + return 0 +} + var File_hostagent_proto protoreflect.FileDescriptor const file_hostagent_proto_rawDesc = "" + @@ -2319,7 +2415,14 @@ const file_hostagent_proto_rawDesc = "" + "points_10m\x18\x01 \x03(\v2\x19.hostagent.v1.MetricPointR\tpoints10m\x126\n" + "\tpoints_2h\x18\x02 \x03(\v2\x19.hostagent.v1.MetricPointR\bpoints2h\x128\n" + "\n" + - "points_24h\x18\x03 \x03(\v2\x19.hostagent.v1.MetricPointR\tpoints24h2\xee\v\n" + + "points_24h\x18\x03 \x03(\v2\x19.hostagent.v1.MetricPointR\tpoints24h\"I\n" + + "\x14FlattenRootfsRequest\x12\x1d\n" + + "\n" + + "sandbox_id\x18\x01 \x01(\tR\tsandboxId\x12\x12\n" + + "\x04name\x18\x02 \x01(\tR\x04name\"6\n" + + "\x15FlattenRootfsResponse\x12\x1d\n" + + "\n" + + "size_bytes\x18\x01 \x01(\x03R\tsizeBytes2\xc8\f\n" + "\x10HostAgentService\x12X\n" + "\rCreateSandbox\x12\".hostagent.v1.CreateSandboxRequest\x1a#.hostagent.v1.CreateSandboxResponse\x12[\n" + "\x0eDestroySandbox\x12#.hostagent.v1.DestroySandboxRequest\x1a$.hostagent.v1.DestroySandboxResponse\x12U\n" + @@ -2338,7 +2441,8 @@ const file_hostagent_proto_rawDesc = "" + "\vPingSandbox\x12 .hostagent.v1.PingSandboxRequest\x1a!.hostagent.v1.PingSandboxResponse\x12L\n" + "\tTerminate\x12\x1e.hostagent.v1.TerminateRequest\x1a\x1f.hostagent.v1.TerminateResponse\x12d\n" + "\x11GetSandboxMetrics\x12&.hostagent.v1.GetSandboxMetricsRequest\x1a'.hostagent.v1.GetSandboxMetricsResponse\x12j\n" + - "\x13FlushSandboxMetrics\x12(.hostagent.v1.FlushSandboxMetricsRequest\x1a).hostagent.v1.FlushSandboxMetricsResponseB\xb0\x01\n" + + "\x13FlushSandboxMetrics\x12(.hostagent.v1.FlushSandboxMetricsRequest\x1a).hostagent.v1.FlushSandboxMetricsResponse\x12X\n" + + "\rFlattenRootfs\x12\".hostagent.v1.FlattenRootfsRequest\x1a#.hostagent.v1.FlattenRootfsResponseB\xb0\x01\n" + "\x10com.hostagent.v1B\x0eHostagentProtoP\x01Z;git.omukk.dev/wrenn/sandbox/proto/hostagent/gen;hostagentv1\xa2\x02\x03HXX\xaa\x02\fHostagent.V1\xca\x02\fHostagent\\V1\xe2\x02\x18Hostagent\\V1\\GPBMetadata\xea\x02\rHostagent::V1b\x06proto3" var ( @@ -2353,7 +2457,7 @@ func file_hostagent_proto_rawDescGZIP() []byte { return file_hostagent_proto_rawDescData } -var file_hostagent_proto_msgTypes = make([]protoimpl.MessageInfo, 40) +var file_hostagent_proto_msgTypes = make([]protoimpl.MessageInfo, 42) var file_hostagent_proto_goTypes = []any{ (*CreateSandboxRequest)(nil), // 0: hostagent.v1.CreateSandboxRequest (*CreateSandboxResponse)(nil), // 1: hostagent.v1.CreateSandboxResponse @@ -2395,6 +2499,8 @@ var file_hostagent_proto_goTypes = []any{ (*GetSandboxMetricsResponse)(nil), // 37: hostagent.v1.GetSandboxMetricsResponse (*FlushSandboxMetricsRequest)(nil), // 38: hostagent.v1.FlushSandboxMetricsRequest (*FlushSandboxMetricsResponse)(nil), // 39: hostagent.v1.FlushSandboxMetricsResponse + (*FlattenRootfsRequest)(nil), // 40: hostagent.v1.FlattenRootfsRequest + (*FlattenRootfsResponse)(nil), // 41: hostagent.v1.FlattenRootfsResponse } var file_hostagent_proto_depIdxs = []int32{ 16, // 0: hostagent.v1.ListSandboxesResponse.sandboxes:type_name -> hostagent.v1.SandboxInfo @@ -2423,25 +2529,27 @@ var file_hostagent_proto_depIdxs = []int32{ 33, // 23: hostagent.v1.HostAgentService.Terminate:input_type -> hostagent.v1.TerminateRequest 36, // 24: hostagent.v1.HostAgentService.GetSandboxMetrics:input_type -> hostagent.v1.GetSandboxMetricsRequest 38, // 25: hostagent.v1.HostAgentService.FlushSandboxMetrics:input_type -> hostagent.v1.FlushSandboxMetricsRequest - 1, // 26: hostagent.v1.HostAgentService.CreateSandbox:output_type -> hostagent.v1.CreateSandboxResponse - 3, // 27: hostagent.v1.HostAgentService.DestroySandbox:output_type -> hostagent.v1.DestroySandboxResponse - 5, // 28: hostagent.v1.HostAgentService.PauseSandbox:output_type -> hostagent.v1.PauseSandboxResponse - 7, // 29: hostagent.v1.HostAgentService.ResumeSandbox:output_type -> hostagent.v1.ResumeSandboxResponse - 13, // 30: hostagent.v1.HostAgentService.Exec:output_type -> hostagent.v1.ExecResponse - 15, // 31: hostagent.v1.HostAgentService.ListSandboxes:output_type -> hostagent.v1.ListSandboxesResponse - 18, // 32: hostagent.v1.HostAgentService.WriteFile:output_type -> hostagent.v1.WriteFileResponse - 20, // 33: hostagent.v1.HostAgentService.ReadFile:output_type -> hostagent.v1.ReadFileResponse - 9, // 34: hostagent.v1.HostAgentService.CreateSnapshot:output_type -> hostagent.v1.CreateSnapshotResponse - 11, // 35: hostagent.v1.HostAgentService.DeleteSnapshot:output_type -> hostagent.v1.DeleteSnapshotResponse - 22, // 36: hostagent.v1.HostAgentService.ExecStream:output_type -> hostagent.v1.ExecStreamResponse - 28, // 37: hostagent.v1.HostAgentService.WriteFileStream:output_type -> hostagent.v1.WriteFileStreamResponse - 30, // 38: hostagent.v1.HostAgentService.ReadFileStream:output_type -> hostagent.v1.ReadFileStreamResponse - 32, // 39: hostagent.v1.HostAgentService.PingSandbox:output_type -> hostagent.v1.PingSandboxResponse - 34, // 40: hostagent.v1.HostAgentService.Terminate:output_type -> hostagent.v1.TerminateResponse - 37, // 41: hostagent.v1.HostAgentService.GetSandboxMetrics:output_type -> hostagent.v1.GetSandboxMetricsResponse - 39, // 42: hostagent.v1.HostAgentService.FlushSandboxMetrics:output_type -> hostagent.v1.FlushSandboxMetricsResponse - 26, // [26:43] is the sub-list for method output_type - 9, // [9:26] is the sub-list for method input_type + 40, // 26: hostagent.v1.HostAgentService.FlattenRootfs:input_type -> hostagent.v1.FlattenRootfsRequest + 1, // 27: hostagent.v1.HostAgentService.CreateSandbox:output_type -> hostagent.v1.CreateSandboxResponse + 3, // 28: hostagent.v1.HostAgentService.DestroySandbox:output_type -> hostagent.v1.DestroySandboxResponse + 5, // 29: hostagent.v1.HostAgentService.PauseSandbox:output_type -> hostagent.v1.PauseSandboxResponse + 7, // 30: hostagent.v1.HostAgentService.ResumeSandbox:output_type -> hostagent.v1.ResumeSandboxResponse + 13, // 31: hostagent.v1.HostAgentService.Exec:output_type -> hostagent.v1.ExecResponse + 15, // 32: hostagent.v1.HostAgentService.ListSandboxes:output_type -> hostagent.v1.ListSandboxesResponse + 18, // 33: hostagent.v1.HostAgentService.WriteFile:output_type -> hostagent.v1.WriteFileResponse + 20, // 34: hostagent.v1.HostAgentService.ReadFile:output_type -> hostagent.v1.ReadFileResponse + 9, // 35: hostagent.v1.HostAgentService.CreateSnapshot:output_type -> hostagent.v1.CreateSnapshotResponse + 11, // 36: hostagent.v1.HostAgentService.DeleteSnapshot:output_type -> hostagent.v1.DeleteSnapshotResponse + 22, // 37: hostagent.v1.HostAgentService.ExecStream:output_type -> hostagent.v1.ExecStreamResponse + 28, // 38: hostagent.v1.HostAgentService.WriteFileStream:output_type -> hostagent.v1.WriteFileStreamResponse + 30, // 39: hostagent.v1.HostAgentService.ReadFileStream:output_type -> hostagent.v1.ReadFileStreamResponse + 32, // 40: hostagent.v1.HostAgentService.PingSandbox:output_type -> hostagent.v1.PingSandboxResponse + 34, // 41: hostagent.v1.HostAgentService.Terminate:output_type -> hostagent.v1.TerminateResponse + 37, // 42: hostagent.v1.HostAgentService.GetSandboxMetrics:output_type -> hostagent.v1.GetSandboxMetricsResponse + 39, // 43: hostagent.v1.HostAgentService.FlushSandboxMetrics:output_type -> hostagent.v1.FlushSandboxMetricsResponse + 41, // 44: hostagent.v1.HostAgentService.FlattenRootfs:output_type -> hostagent.v1.FlattenRootfsResponse + 27, // [27:45] is the sub-list for method output_type + 9, // [9:27] is the sub-list for method input_type 9, // [9:9] is the sub-list for extension type_name 9, // [9:9] is the sub-list for extension extendee 0, // [0:9] is the sub-list for field type_name @@ -2471,7 +2579,7 @@ func file_hostagent_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_hostagent_proto_rawDesc), len(file_hostagent_proto_rawDesc)), NumEnums: 0, - NumMessages: 40, + NumMessages: 42, NumExtensions: 0, NumServices: 1, }, diff --git a/proto/hostagent/gen/hostagentv1connect/hostagent.connect.go b/proto/hostagent/gen/hostagentv1connect/hostagent.connect.go index 7f0fa707..02f4ecc6 100644 --- a/proto/hostagent/gen/hostagentv1connect/hostagent.connect.go +++ b/proto/hostagent/gen/hostagentv1connect/hostagent.connect.go @@ -83,6 +83,9 @@ const ( // HostAgentServiceFlushSandboxMetricsProcedure is the fully-qualified name of the // HostAgentService's FlushSandboxMetrics RPC. HostAgentServiceFlushSandboxMetricsProcedure = "/hostagent.v1.HostAgentService/FlushSandboxMetrics" + // HostAgentServiceFlattenRootfsProcedure is the fully-qualified name of the HostAgentService's + // FlattenRootfs RPC. + HostAgentServiceFlattenRootfsProcedure = "/hostagent.v1.HostAgentService/FlattenRootfs" ) // HostAgentServiceClient is a client for the hostagent.v1.HostAgentService service. @@ -126,6 +129,11 @@ type HostAgentServiceClient interface { // FlushSandboxMetrics returns all ring buffer tiers and clears them. // Called by the control plane before pause/destroy to persist metrics to DB. FlushSandboxMetrics(context.Context, *connect.Request[gen.FlushSandboxMetricsRequest]) (*connect.Response[gen.FlushSandboxMetricsResponse], error) + // FlattenRootfs stops the sandbox VM, flattens the device-mapper CoW + // snapshot into a standalone rootfs.ext4 in the images directory, then + // cleans up all sandbox resources. Used by the template build system to + // produce image-only templates (no memory/CPU state). + FlattenRootfs(context.Context, *connect.Request[gen.FlattenRootfsRequest]) (*connect.Response[gen.FlattenRootfsResponse], error) } // NewHostAgentServiceClient constructs a client for the hostagent.v1.HostAgentService service. By @@ -241,6 +249,12 @@ func NewHostAgentServiceClient(httpClient connect.HTTPClient, baseURL string, op connect.WithSchema(hostAgentServiceMethods.ByName("FlushSandboxMetrics")), connect.WithClientOptions(opts...), ), + flattenRootfs: connect.NewClient[gen.FlattenRootfsRequest, gen.FlattenRootfsResponse]( + httpClient, + baseURL+HostAgentServiceFlattenRootfsProcedure, + connect.WithSchema(hostAgentServiceMethods.ByName("FlattenRootfs")), + connect.WithClientOptions(opts...), + ), } } @@ -263,6 +277,7 @@ type hostAgentServiceClient struct { terminate *connect.Client[gen.TerminateRequest, gen.TerminateResponse] getSandboxMetrics *connect.Client[gen.GetSandboxMetricsRequest, gen.GetSandboxMetricsResponse] flushSandboxMetrics *connect.Client[gen.FlushSandboxMetricsRequest, gen.FlushSandboxMetricsResponse] + flattenRootfs *connect.Client[gen.FlattenRootfsRequest, gen.FlattenRootfsResponse] } // CreateSandbox calls hostagent.v1.HostAgentService.CreateSandbox. @@ -350,6 +365,11 @@ func (c *hostAgentServiceClient) FlushSandboxMetrics(ctx context.Context, req *c return c.flushSandboxMetrics.CallUnary(ctx, req) } +// FlattenRootfs calls hostagent.v1.HostAgentService.FlattenRootfs. +func (c *hostAgentServiceClient) FlattenRootfs(ctx context.Context, req *connect.Request[gen.FlattenRootfsRequest]) (*connect.Response[gen.FlattenRootfsResponse], error) { + return c.flattenRootfs.CallUnary(ctx, req) +} + // HostAgentServiceHandler is an implementation of the hostagent.v1.HostAgentService service. type HostAgentServiceHandler interface { // CreateSandbox boots a new microVM with the given configuration. @@ -391,6 +411,11 @@ type HostAgentServiceHandler interface { // FlushSandboxMetrics returns all ring buffer tiers and clears them. // Called by the control plane before pause/destroy to persist metrics to DB. FlushSandboxMetrics(context.Context, *connect.Request[gen.FlushSandboxMetricsRequest]) (*connect.Response[gen.FlushSandboxMetricsResponse], error) + // FlattenRootfs stops the sandbox VM, flattens the device-mapper CoW + // snapshot into a standalone rootfs.ext4 in the images directory, then + // cleans up all sandbox resources. Used by the template build system to + // produce image-only templates (no memory/CPU state). + FlattenRootfs(context.Context, *connect.Request[gen.FlattenRootfsRequest]) (*connect.Response[gen.FlattenRootfsResponse], error) } // NewHostAgentServiceHandler builds an HTTP handler from the service implementation. It returns the @@ -502,6 +527,12 @@ func NewHostAgentServiceHandler(svc HostAgentServiceHandler, opts ...connect.Han connect.WithSchema(hostAgentServiceMethods.ByName("FlushSandboxMetrics")), connect.WithHandlerOptions(opts...), ) + hostAgentServiceFlattenRootfsHandler := connect.NewUnaryHandler( + HostAgentServiceFlattenRootfsProcedure, + svc.FlattenRootfs, + connect.WithSchema(hostAgentServiceMethods.ByName("FlattenRootfs")), + connect.WithHandlerOptions(opts...), + ) return "/hostagent.v1.HostAgentService/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case HostAgentServiceCreateSandboxProcedure: @@ -538,6 +569,8 @@ func NewHostAgentServiceHandler(svc HostAgentServiceHandler, opts ...connect.Han hostAgentServiceGetSandboxMetricsHandler.ServeHTTP(w, r) case HostAgentServiceFlushSandboxMetricsProcedure: hostAgentServiceFlushSandboxMetricsHandler.ServeHTTP(w, r) + case HostAgentServiceFlattenRootfsProcedure: + hostAgentServiceFlattenRootfsHandler.ServeHTTP(w, r) default: http.NotFound(w, r) } @@ -614,3 +647,7 @@ func (UnimplementedHostAgentServiceHandler) GetSandboxMetrics(context.Context, * func (UnimplementedHostAgentServiceHandler) FlushSandboxMetrics(context.Context, *connect.Request[gen.FlushSandboxMetricsRequest]) (*connect.Response[gen.FlushSandboxMetricsResponse], error) { return nil, connect.NewError(connect.CodeUnimplemented, errors.New("hostagent.v1.HostAgentService.FlushSandboxMetrics is not implemented")) } + +func (UnimplementedHostAgentServiceHandler) FlattenRootfs(context.Context, *connect.Request[gen.FlattenRootfsRequest]) (*connect.Response[gen.FlattenRootfsResponse], error) { + return nil, connect.NewError(connect.CodeUnimplemented, errors.New("hostagent.v1.HostAgentService.FlattenRootfs is not implemented")) +} diff --git a/proto/hostagent/hostagent.proto b/proto/hostagent/hostagent.proto index 214a84e0..cd93a2db 100644 --- a/proto/hostagent/hostagent.proto +++ b/proto/hostagent/hostagent.proto @@ -61,6 +61,12 @@ service HostAgentService { // Called by the control plane before pause/destroy to persist metrics to DB. rpc FlushSandboxMetrics(FlushSandboxMetricsRequest) returns (FlushSandboxMetricsResponse); + // FlattenRootfs stops the sandbox VM, flattens the device-mapper CoW + // snapshot into a standalone rootfs.ext4 in the images directory, then + // cleans up all sandbox resources. Used by the template build system to + // produce image-only templates (no memory/CPU state). + rpc FlattenRootfs(FlattenRootfsRequest) returns (FlattenRootfsResponse); + } message CreateSandboxRequest { @@ -284,3 +290,14 @@ message FlushSandboxMetricsResponse { repeated MetricPoint points_2h = 2; repeated MetricPoint points_24h = 3; } + +// ── FlattenRootfs ──────────────────────────────────────────────────── + +message FlattenRootfsRequest { + string sandbox_id = 1; + string name = 2; // template name — output written to images/{name}/rootfs.ext4 +} + +message FlattenRootfsResponse { + int64 size_bytes = 1; +} From cdd89a7cee8ec76cdc43a5e81922c3d15db0b733 Mon Sep 17 00:00:00 2001 From: pptx704 Date: Thu, 26 Mar 2026 15:31:38 +0600 Subject: [PATCH 09/28] Fix review issues: detached contexts, loop device leak, timer leak, size_bytes - Use context.Background() with timeout in destroySandbox/failBuild so cleanup and DB writes survive parent context cancellation on shutdown - Fix loop device refcount leak in FlattenRootfs when dmDevice is nil - Replace time.After with time.NewTimer in healthcheck polling to avoid goroutine leak when healthcheck passes early - Capture size_bytes from CreateSnapshot/FlattenRootfs RPC responses instead of hardcoding 0 in the templates table insert - Avoid leaking internal error details to API clients in build handler --- internal/api/handlers_builds.go | 4 +++- internal/sandbox/manager.go | 2 ++ internal/service/build.go | 30 +++++++++++++++++++++--------- 3 files changed, 26 insertions(+), 10 deletions(-) diff --git a/internal/api/handlers_builds.go b/internal/api/handlers_builds.go index ae9a48e1..e62b0c60 100644 --- a/internal/api/handlers_builds.go +++ b/internal/api/handlers_builds.go @@ -3,6 +3,7 @@ package api import ( "encoding/json" "fmt" + "log/slog" "net/http" "time" @@ -119,7 +120,8 @@ func (h *buildHandler) Create(w http.ResponseWriter, r *http.Request) { MemoryMB: req.MemoryMB, }) if err != nil { - writeError(w, http.StatusInternalServerError, "build_error", err.Error()) + slog.Error("failed to create build", "error", err) + writeError(w, http.StatusInternalServerError, "build_error", "failed to create build") return } diff --git a/internal/sandbox/manager.go b/internal/sandbox/manager.go index 1d103bc2..88b058c6 100644 --- a/internal/sandbox/manager.go +++ b/internal/sandbox/manager.go @@ -839,6 +839,8 @@ func (m *Manager) FlattenRootfs(ctx context.Context, sandboxID, name string) (in outputPath := snapshot.RootfsPath(m.cfg.ImagesDir, name) if sb.dmDevice == nil { + m.cleanupDM(sb) + warnErr("template dir cleanup error", name, snapshot.Remove(m.cfg.ImagesDir, name)) return 0, fmt.Errorf("sandbox %s has no dm device", sandboxID) } diff --git a/internal/service/build.go b/internal/service/build.go index 054f8a24..3c7975d0 100644 --- a/internal/service/build.go +++ b/internal/service/build.go @@ -255,6 +255,7 @@ func (s *BuildService) executeBuild(ctx context.Context, buildID string) { } // Healthcheck or direct snapshot. + var sizeBytes int64 if build.Healthcheck.Valid && build.Healthcheck.String != "" { log.Info("running healthcheck", "cmd", build.Healthcheck.String) if err := s.waitForHealthcheck(ctx, agent, sandboxID, build.Healthcheck.String); err != nil { @@ -265,25 +266,29 @@ func (s *BuildService) executeBuild(ctx context.Context, buildID string) { // Healthcheck passed → full snapshot (with memory/CPU state). log.Info("healthcheck passed, creating snapshot") - if _, err := agent.CreateSnapshot(ctx, connect.NewRequest(&pb.CreateSnapshotRequest{ + snapResp, err := agent.CreateSnapshot(ctx, connect.NewRequest(&pb.CreateSnapshotRequest{ SandboxId: sandboxID, Name: build.Name, - })); err != nil { + })) + if err != nil { s.destroySandbox(ctx, agent, sandboxID) s.failBuild(ctx, buildID, fmt.Sprintf("create snapshot failed: %v", err)) return } + sizeBytes = snapResp.Msg.SizeBytes } else { // No healthcheck → image-only template (rootfs only). log.Info("no healthcheck, flattening rootfs") - if _, err := agent.FlattenRootfs(ctx, connect.NewRequest(&pb.FlattenRootfsRequest{ + flatResp, err := agent.FlattenRootfs(ctx, connect.NewRequest(&pb.FlattenRootfsRequest{ SandboxId: sandboxID, Name: build.Name, - })); err != nil { + })) + if err != nil { s.destroySandbox(ctx, agent, sandboxID) s.failBuild(ctx, buildID, fmt.Sprintf("flatten rootfs failed: %v", err)) return } + sizeBytes = flatResp.Msg.SizeBytes } // Insert into templates table as a global (platform) template. @@ -297,7 +302,7 @@ func (s *BuildService) executeBuild(ctx context.Context, buildID string) { Type: templateType, Vcpus: pgtype.Int4{Int32: build.Vcpus, Valid: true}, MemoryMb: pgtype.Int4{Int32: build.MemoryMb, Valid: true}, - SizeBytes: 0, // Could query the host, but the template is created. + SizeBytes: sizeBytes, TeamID: platformTeamID, }); err != nil { log.Error("failed to insert template record", "error", err) @@ -319,7 +324,8 @@ func (s *BuildService) executeBuild(ctx context.Context, buildID string) { } func (s *BuildService) waitForHealthcheck(ctx context.Context, agent buildAgentClient, sandboxID, cmd string) error { - deadline := time.After(healthcheckTimeout) + deadline := time.NewTimer(healthcheckTimeout) + defer deadline.Stop() ticker := time.NewTicker(healthcheckInterval) defer ticker.Stop() @@ -327,7 +333,7 @@ func (s *BuildService) waitForHealthcheck(ctx context.Context, agent buildAgentC select { case <-ctx.Done(): return ctx.Err() - case <-deadline: + case <-deadline.C: return fmt.Errorf("healthcheck timed out after %s", healthcheckTimeout) case <-ticker.C: execCtx, cancel := context.WithTimeout(ctx, 10*time.Second) @@ -366,8 +372,11 @@ func (s *BuildService) updateLogs(ctx context.Context, buildID string, step int, } } -func (s *BuildService) failBuild(ctx context.Context, buildID, errMsg string) { +func (s *BuildService) failBuild(_ context.Context, buildID, errMsg string) { slog.Error("build failed", "build_id", buildID, "error", errMsg) + // Use a detached context so DB writes survive parent context cancellation (e.g. shutdown). + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() if err := s.DB.UpdateBuildError(ctx, db.UpdateBuildErrorParams{ ID: buildID, Error: pgtype.Text{String: errMsg, Valid: true}, @@ -376,7 +385,10 @@ func (s *BuildService) failBuild(ctx context.Context, buildID, errMsg string) { } } -func (s *BuildService) destroySandbox(ctx context.Context, agent buildAgentClient, sandboxID string) { +func (s *BuildService) destroySandbox(_ context.Context, agent buildAgentClient, sandboxID string) { + // Use a detached context so cleanup succeeds even during shutdown. + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() if _, err := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{ SandboxId: sandboxID, })); err != nil { From 4ddd49416082fea825ae6b0ea8b4106a676fa06e Mon Sep 17 00:00:00 2001 From: pptx704 Date: Thu, 26 Mar 2026 16:16:21 +0600 Subject: [PATCH 10/28] Switch database IDs from TEXT to native UUID Consolidate 16 migrations into one with UUID columns for all entity IDs. TEXT is kept only for polymorphic fields (audit_logs.actor_id, resource_id) and template names. The id package now generates UUIDs via google/uuid, with Format*/Parse* helpers for the prefixed wire format (sb-{uuid}, usr-{uuid}, etc.). Auth context, services, and handlers pass pgtype.UUID internally; conversion to/from prefixed strings happens at API and RPC boundaries. Adds PlatformTeamID (all-zeros UUID) for shared resources. --- db/migrations/20260310094104_initial.sql | 245 ++++++++++++++++-- db/migrations/20260311224925_snapshots.sql | 14 - db/migrations/20260313210608_auth.sql | 46 ---- .../20260313210611_team_ownership.sql | 31 --- db/migrations/20260315001514_oauth.sql | 22 -- db/migrations/20260316203135_admin_users.sql | 21 -- db/migrations/20260316203138_byoc_teams.sql | 9 - db/migrations/20260316203142_hosts.sql | 47 ---- db/migrations/20260316223629_host_mtls.sql | 11 - .../20260324071453_team_management.sql | 17 -- db/migrations/20260324100234_user_names.sql | 5 - .../20260324120214_host_refresh_tokens.sql | 19 -- db/migrations/20260324220743_audit_logs.sql | 28 -- .../20260325074949_metrics_snapshots.sql | 18 -- ...260325135035_add_sandbox_metric_points.sql | 16 -- .../20260326090649_template_builds.sql | 25 -- db/queries/sandboxes.sql | 4 +- envd/internal/api/init.go | 7 +- internal/api/agent_helper.go | 4 +- internal/api/handler_sandbox_proxy.go | 27 +- internal/api/handlers_apikeys.go | 21 +- internal/api/handlers_audit.go | 16 +- internal/api/handlers_auth.go | 26 +- internal/api/handlers_builds.go | 25 +- internal/api/handlers_exec.go | 17 +- internal/api/handlers_exec_stream.go | 13 +- internal/api/handlers_files.go | 21 +- internal/api/handlers_files_stream.go | 21 +- internal/api/handlers_hosts.go | 159 ++++++++---- internal/api/handlers_metrics.go | 24 +- internal/api/handlers_oauth.go | 6 +- internal/api/handlers_sandbox.go | 45 +++- internal/api/handlers_snapshots.go | 25 +- internal/api/handlers_team.go | 47 +++- internal/api/handlers_users.go | 3 +- internal/api/host_monitor.go | 53 ++-- internal/api/middleware_auth.go | 18 +- internal/api/middleware_hosttoken.go | 9 +- internal/api/middleware_jwt.go | 16 +- internal/audit/logger.go | 146 ++++++----- internal/auth/context.go | 24 +- internal/auth/jwt.go | 18 +- internal/db/api_keys.sql.go | 28 +- internal/db/audit.sql.go | 10 +- internal/db/host_refresh_tokens.sql.go | 8 +- internal/db/hosts.sql.go | 62 ++--- internal/db/metrics.sql.go | 46 ++-- internal/db/models.go | 106 ++++---- internal/db/oauth.sql.go | 10 +- internal/db/sandboxes.sql.go | 86 +++--- internal/db/teams.sql.go | 80 +++--- internal/db/template_builds.sql.go | 26 +- internal/db/templates.sql.go | 20 +- internal/db/users.sql.go | 68 ++--- internal/id/id.go | 180 ++++++++----- internal/lifecycle/hostpool.go | 7 +- internal/sandbox/manager.go | 2 +- internal/scheduler/round_robin.go | 10 +- internal/service/apikey.go | 10 +- internal/service/audit.go | 21 +- internal/service/build.go | 97 +++---- internal/service/host.go | 126 +++++---- internal/service/sandbox.go | 61 +++-- internal/service/stats.go | 5 +- internal/service/team.go | 35 +-- internal/service/template.go | 4 +- 66 files changed, 1350 insertions(+), 1127 deletions(-) delete mode 100644 db/migrations/20260311224925_snapshots.sql delete mode 100644 db/migrations/20260313210608_auth.sql delete mode 100644 db/migrations/20260313210611_team_ownership.sql delete mode 100644 db/migrations/20260315001514_oauth.sql delete mode 100644 db/migrations/20260316203135_admin_users.sql delete mode 100644 db/migrations/20260316203138_byoc_teams.sql delete mode 100644 db/migrations/20260316203142_hosts.sql delete mode 100644 db/migrations/20260316223629_host_mtls.sql delete mode 100644 db/migrations/20260324071453_team_management.sql delete mode 100644 db/migrations/20260324100234_user_names.sql delete mode 100644 db/migrations/20260324120214_host_refresh_tokens.sql delete mode 100644 db/migrations/20260324220743_audit_logs.sql delete mode 100644 db/migrations/20260325074949_metrics_snapshots.sql delete mode 100644 db/migrations/20260325135035_add_sandbox_metric_points.sql delete mode 100644 db/migrations/20260326090649_template_builds.sql diff --git a/db/migrations/20260310094104_initial.sql b/db/migrations/20260310094104_initial.sql index c291815a..be5d29f6 100644 --- a/db/migrations/20260310094104_initial.sql +++ b/db/migrations/20260310094104_initial.sql @@ -1,25 +1,236 @@ -- +goose Up -CREATE TABLE sandboxes ( - id TEXT PRIMARY KEY, - owner_id TEXT NOT NULL DEFAULT '', - host_id TEXT NOT NULL DEFAULT 'default', - template TEXT NOT NULL DEFAULT 'minimal', - status TEXT NOT NULL DEFAULT 'pending', - vcpus INTEGER NOT NULL DEFAULT 1, - memory_mb INTEGER NOT NULL DEFAULT 512, - timeout_sec INTEGER NOT NULL DEFAULT 0, - guest_ip TEXT NOT NULL DEFAULT '', - host_ip TEXT NOT NULL DEFAULT '', - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - started_at TIMESTAMPTZ, - last_active_at TIMESTAMPTZ, - last_updated TIMESTAMPTZ NOT NULL DEFAULT NOW() +-- teams +CREATE TABLE teams ( + id UUID PRIMARY KEY, + name TEXT NOT NULL, + slug TEXT NOT NULL UNIQUE, + is_byoc BOOLEAN NOT NULL DEFAULT FALSE, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + deleted_at TIMESTAMPTZ +); +CREATE INDEX idx_teams_slug ON teams(slug); + +-- users +CREATE TABLE users ( + id UUID PRIMARY KEY, + email TEXT NOT NULL UNIQUE, + password_hash TEXT, + name TEXT NOT NULL DEFAULT '', + is_admin BOOLEAN NOT NULL DEFAULT FALSE, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ); +-- users_teams (junction) +CREATE TABLE users_teams ( + user_id UUID NOT NULL REFERENCES users(id), + team_id UUID NOT NULL REFERENCES teams(id), + is_default BOOLEAN NOT NULL DEFAULT FALSE, + role TEXT NOT NULL DEFAULT 'member', + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + PRIMARY KEY (team_id, user_id) +); +CREATE INDEX idx_users_teams_user ON users_teams(user_id); + +-- team_api_keys +CREATE TABLE team_api_keys ( + id UUID PRIMARY KEY, + team_id UUID NOT NULL REFERENCES teams(id), + name TEXT NOT NULL, + key_hash TEXT NOT NULL UNIQUE, + key_prefix TEXT NOT NULL, + created_by UUID NOT NULL REFERENCES users(id), + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + last_used TIMESTAMPTZ +); +CREATE INDEX idx_team_api_keys_team ON team_api_keys(team_id); + +-- oauth_providers +CREATE TABLE oauth_providers ( + provider TEXT NOT NULL, + provider_id TEXT NOT NULL, + user_id UUID NOT NULL REFERENCES users(id), + email TEXT NOT NULL DEFAULT '', + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + PRIMARY KEY (provider, provider_id) +); +CREATE INDEX idx_oauth_providers_user ON oauth_providers(user_id); + +-- admin_permissions +CREATE TABLE admin_permissions ( + id UUID PRIMARY KEY, + user_id UUID NOT NULL REFERENCES users(id), + permission TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + UNIQUE (user_id, permission) +); +CREATE INDEX idx_admin_permissions_user ON admin_permissions(user_id); + +-- hosts +CREATE TABLE hosts ( + id UUID PRIMARY KEY, + type TEXT NOT NULL DEFAULT 'regular', + team_id UUID REFERENCES teams(id), + provider TEXT NOT NULL DEFAULT '', + availability_zone TEXT NOT NULL DEFAULT '', + arch TEXT NOT NULL DEFAULT '', + cpu_cores INTEGER NOT NULL DEFAULT 0, + memory_mb INTEGER NOT NULL DEFAULT 0, + disk_gb INTEGER NOT NULL DEFAULT 0, + address TEXT NOT NULL DEFAULT '', + status TEXT NOT NULL DEFAULT 'pending', + last_heartbeat_at TIMESTAMPTZ, + metadata JSONB NOT NULL DEFAULT '{}', + created_by UUID NOT NULL REFERENCES users(id), + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + cert_fingerprint TEXT NOT NULL DEFAULT '', + mtls_enabled BOOLEAN NOT NULL DEFAULT FALSE +); +CREATE INDEX idx_hosts_type ON hosts(type); +CREATE INDEX idx_hosts_team ON hosts(team_id); +CREATE INDEX idx_hosts_status ON hosts(status); + +-- host_tokens +CREATE TABLE host_tokens ( + id UUID PRIMARY KEY, + host_id UUID NOT NULL REFERENCES hosts(id) ON DELETE CASCADE, + created_by UUID NOT NULL REFERENCES users(id), + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + expires_at TIMESTAMPTZ NOT NULL, + used_at TIMESTAMPTZ +); +CREATE INDEX idx_host_tokens_host ON host_tokens(host_id); + +-- host_tags +CREATE TABLE host_tags ( + host_id UUID NOT NULL REFERENCES hosts(id) ON DELETE CASCADE, + tag TEXT NOT NULL, + PRIMARY KEY (host_id, tag) +); +CREATE INDEX idx_host_tags_tag ON host_tags(tag); + +-- host_refresh_tokens +CREATE TABLE host_refresh_tokens ( + id UUID PRIMARY KEY, + host_id UUID NOT NULL REFERENCES hosts(id) ON DELETE CASCADE, + token_hash TEXT NOT NULL UNIQUE, + expires_at TIMESTAMPTZ NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + revoked_at TIMESTAMPTZ +); +CREATE INDEX idx_host_refresh_tokens_host ON host_refresh_tokens(host_id); + +-- templates (TEXT primary key — not UUID) +CREATE TABLE templates ( + name TEXT PRIMARY KEY, + type TEXT NOT NULL DEFAULT 'base', + vcpus INTEGER NOT NULL DEFAULT 1, + memory_mb INTEGER NOT NULL DEFAULT 512, + size_bytes BIGINT NOT NULL DEFAULT 0, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + team_id UUID NOT NULL +); +CREATE INDEX idx_templates_team ON templates(team_id); + +-- sandboxes +CREATE TABLE sandboxes ( + id UUID PRIMARY KEY, + team_id UUID NOT NULL REFERENCES teams(id), + host_id UUID NOT NULL, + template TEXT NOT NULL DEFAULT 'minimal', + status TEXT NOT NULL DEFAULT 'pending', + vcpus INTEGER NOT NULL DEFAULT 1, + memory_mb INTEGER NOT NULL DEFAULT 512, + timeout_sec INTEGER NOT NULL DEFAULT 300, + guest_ip TEXT NOT NULL DEFAULT '', + host_ip TEXT NOT NULL DEFAULT '', + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + started_at TIMESTAMPTZ, + last_active_at TIMESTAMPTZ, + last_updated TIMESTAMPTZ NOT NULL DEFAULT NOW() +); CREATE INDEX idx_sandboxes_status ON sandboxes(status); CREATE INDEX idx_sandboxes_host_status ON sandboxes(host_id, status); +CREATE INDEX idx_sandboxes_team ON sandboxes(team_id); + +-- audit_logs (id and team_id are UUID; actor_id and resource_id are TEXT for polymorphism) +CREATE TABLE audit_logs ( + id UUID PRIMARY KEY, + team_id UUID NOT NULL, + actor_type TEXT NOT NULL, + actor_id TEXT, + actor_name TEXT NOT NULL DEFAULT '', + resource_type TEXT NOT NULL, + resource_id TEXT, + action TEXT NOT NULL, + scope TEXT NOT NULL DEFAULT 'team', + status TEXT NOT NULL DEFAULT 'success', + metadata JSONB NOT NULL DEFAULT '{}', + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); +CREATE INDEX idx_audit_logs_team_time ON audit_logs(team_id, created_at DESC); +CREATE INDEX idx_audit_logs_team_resource ON audit_logs(team_id, resource_type, created_at DESC); + +-- sandbox_metrics_snapshots +CREATE TABLE sandbox_metrics_snapshots ( + id BIGSERIAL PRIMARY KEY, + team_id UUID NOT NULL, + sampled_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + running_count INTEGER NOT NULL DEFAULT 0, + vcpus_reserved INTEGER NOT NULL DEFAULT 0, + memory_mb_reserved INTEGER NOT NULL DEFAULT 0 +); +CREATE INDEX idx_metrics_snapshots_team_time ON sandbox_metrics_snapshots(team_id, sampled_at DESC); + +-- sandbox_metric_points +CREATE TABLE sandbox_metric_points ( + sandbox_id UUID NOT NULL, + tier TEXT NOT NULL CHECK (tier IN ('10m', '2h', '24h')), + ts BIGINT NOT NULL, + cpu_pct FLOAT8 NOT NULL DEFAULT 0, + mem_bytes BIGINT NOT NULL DEFAULT 0, + disk_bytes BIGINT NOT NULL DEFAULT 0, + PRIMARY KEY (sandbox_id, tier, ts) +); +CREATE INDEX idx_sandbox_metric_points_sandbox_tier ON sandbox_metric_points(sandbox_id, tier); + +-- template_builds +CREATE TABLE template_builds ( + id UUID PRIMARY KEY, + name TEXT NOT NULL, + base_template TEXT NOT NULL, + recipe JSONB NOT NULL DEFAULT '[]', + healthcheck TEXT NOT NULL DEFAULT '', + vcpus INTEGER NOT NULL DEFAULT 1, + memory_mb INTEGER NOT NULL DEFAULT 512, + status TEXT NOT NULL DEFAULT 'pending', + current_step INTEGER NOT NULL DEFAULT 0, + total_steps INTEGER NOT NULL DEFAULT 0, + logs JSONB NOT NULL DEFAULT '[]', + error TEXT NOT NULL DEFAULT '', + sandbox_id UUID, + host_id UUID, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + started_at TIMESTAMPTZ, + completed_at TIMESTAMPTZ +); -- +goose Down - -DROP TABLE sandboxes; +DROP TABLE IF EXISTS template_builds; +DROP TABLE IF EXISTS sandbox_metric_points; +DROP TABLE IF EXISTS sandbox_metrics_snapshots; +DROP TABLE IF EXISTS audit_logs; +DROP TABLE IF EXISTS sandboxes; +DROP TABLE IF EXISTS templates; +DROP TABLE IF EXISTS host_refresh_tokens; +DROP TABLE IF EXISTS host_tags; +DROP TABLE IF EXISTS host_tokens; +DROP TABLE IF EXISTS hosts; +DROP TABLE IF EXISTS admin_permissions; +DROP TABLE IF EXISTS oauth_providers; +DROP TABLE IF EXISTS team_api_keys; +DROP TABLE IF EXISTS users_teams; +DROP TABLE IF EXISTS users; +DROP TABLE IF EXISTS teams; diff --git a/db/migrations/20260311224925_snapshots.sql b/db/migrations/20260311224925_snapshots.sql deleted file mode 100644 index 8a0427c2..00000000 --- a/db/migrations/20260311224925_snapshots.sql +++ /dev/null @@ -1,14 +0,0 @@ --- +goose Up - -CREATE TABLE templates ( - name TEXT PRIMARY KEY, - type TEXT NOT NULL DEFAULT 'base', -- 'base' or 'snapshot' - vcpus INTEGER, - memory_mb INTEGER, - size_bytes BIGINT NOT NULL DEFAULT 0, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() -); - --- +goose Down - -DROP TABLE templates; diff --git a/db/migrations/20260313210608_auth.sql b/db/migrations/20260313210608_auth.sql deleted file mode 100644 index 03970a8f..00000000 --- a/db/migrations/20260313210608_auth.sql +++ /dev/null @@ -1,46 +0,0 @@ --- +goose Up - -CREATE TABLE users ( - id TEXT PRIMARY KEY, - email TEXT NOT NULL UNIQUE, - password_hash TEXT NOT NULL, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() -); - -CREATE TABLE teams ( - id TEXT PRIMARY KEY, - name TEXT NOT NULL, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() -); - -CREATE TABLE users_teams ( - user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE, - team_id TEXT NOT NULL REFERENCES teams(id) ON DELETE CASCADE, - is_default BOOLEAN NOT NULL DEFAULT TRUE, - role TEXT NOT NULL DEFAULT 'owner', - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - PRIMARY KEY (team_id, user_id) -); - -CREATE INDEX idx_users_teams_user ON users_teams(user_id); - -CREATE TABLE team_api_keys ( - id TEXT PRIMARY KEY, - team_id TEXT NOT NULL REFERENCES teams(id) ON DELETE CASCADE, - name TEXT NOT NULL DEFAULT '', - key_hash TEXT NOT NULL UNIQUE, - key_prefix TEXT NOT NULL DEFAULT '', - created_by TEXT NOT NULL REFERENCES users(id), - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - last_used TIMESTAMPTZ -); - -CREATE INDEX idx_team_api_keys_team ON team_api_keys(team_id); - --- +goose Down - -DROP TABLE team_api_keys; -DROP TABLE users_teams; -DROP TABLE teams; -DROP TABLE users; diff --git a/db/migrations/20260313210611_team_ownership.sql b/db/migrations/20260313210611_team_ownership.sql deleted file mode 100644 index 849e781e..00000000 --- a/db/migrations/20260313210611_team_ownership.sql +++ /dev/null @@ -1,31 +0,0 @@ --- +goose Up - -ALTER TABLE sandboxes - ADD COLUMN team_id TEXT NOT NULL DEFAULT ''; - -UPDATE sandboxes SET team_id = owner_id WHERE owner_id != ''; - -ALTER TABLE sandboxes - DROP COLUMN owner_id; - -ALTER TABLE templates - ADD COLUMN team_id TEXT NOT NULL DEFAULT ''; - -CREATE INDEX idx_sandboxes_team ON sandboxes(team_id); -CREATE INDEX idx_templates_team ON templates(team_id); - --- +goose Down - -ALTER TABLE sandboxes - ADD COLUMN owner_id TEXT NOT NULL DEFAULT ''; - -UPDATE sandboxes SET owner_id = team_id WHERE team_id != ''; - -ALTER TABLE sandboxes - DROP COLUMN team_id; - -ALTER TABLE templates - DROP COLUMN team_id; - -DROP INDEX IF EXISTS idx_sandboxes_team; -DROP INDEX IF EXISTS idx_templates_team; diff --git a/db/migrations/20260315001514_oauth.sql b/db/migrations/20260315001514_oauth.sql deleted file mode 100644 index c3c33e9b..00000000 --- a/db/migrations/20260315001514_oauth.sql +++ /dev/null @@ -1,22 +0,0 @@ --- +goose Up - -ALTER TABLE users - ALTER COLUMN password_hash DROP NOT NULL; - -CREATE TABLE oauth_providers ( - provider TEXT NOT NULL, - provider_id TEXT NOT NULL, - user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE, - email TEXT NOT NULL, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - PRIMARY KEY (provider, provider_id) -); - -CREATE INDEX idx_oauth_providers_user ON oauth_providers(user_id); - --- +goose Down - -DROP TABLE oauth_providers; - -UPDATE users SET password_hash = '' WHERE password_hash IS NULL; -ALTER TABLE users ALTER COLUMN password_hash SET NOT NULL; diff --git a/db/migrations/20260316203135_admin_users.sql b/db/migrations/20260316203135_admin_users.sql deleted file mode 100644 index eff669b5..00000000 --- a/db/migrations/20260316203135_admin_users.sql +++ /dev/null @@ -1,21 +0,0 @@ --- +goose Up - -ALTER TABLE users - ADD COLUMN is_admin BOOLEAN NOT NULL DEFAULT FALSE; - -CREATE TABLE admin_permissions ( - id TEXT PRIMARY KEY, - user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE, - permission TEXT NOT NULL, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - UNIQUE (user_id, permission) -); - -CREATE INDEX idx_admin_permissions_user ON admin_permissions(user_id); - --- +goose Down - -DROP TABLE admin_permissions; - -ALTER TABLE users - DROP COLUMN is_admin; diff --git a/db/migrations/20260316203138_byoc_teams.sql b/db/migrations/20260316203138_byoc_teams.sql deleted file mode 100644 index bb2c8ec2..00000000 --- a/db/migrations/20260316203138_byoc_teams.sql +++ /dev/null @@ -1,9 +0,0 @@ --- +goose Up - -ALTER TABLE teams - ADD COLUMN is_byoc BOOLEAN NOT NULL DEFAULT FALSE; - --- +goose Down - -ALTER TABLE teams - DROP COLUMN is_byoc; diff --git a/db/migrations/20260316203142_hosts.sql b/db/migrations/20260316203142_hosts.sql deleted file mode 100644 index 372b3802..00000000 --- a/db/migrations/20260316203142_hosts.sql +++ /dev/null @@ -1,47 +0,0 @@ --- +goose Up - -CREATE TABLE hosts ( - id TEXT PRIMARY KEY, - type TEXT NOT NULL DEFAULT 'regular', -- 'regular' or 'byoc' - team_id TEXT REFERENCES teams(id) ON DELETE SET NULL, - provider TEXT, - availability_zone TEXT, - arch TEXT, - cpu_cores INTEGER, - memory_mb INTEGER, - disk_gb INTEGER, - address TEXT, -- ip:port of host agent - status TEXT NOT NULL DEFAULT 'pending', -- 'pending', 'online', 'offline', 'draining' - last_heartbeat_at TIMESTAMPTZ, - metadata JSONB NOT NULL DEFAULT '{}', - created_by TEXT NOT NULL REFERENCES users(id), - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() -); - -CREATE TABLE host_tokens ( - id TEXT PRIMARY KEY, - host_id TEXT NOT NULL REFERENCES hosts(id) ON DELETE CASCADE, - created_by TEXT NOT NULL REFERENCES users(id), - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - expires_at TIMESTAMPTZ NOT NULL, - used_at TIMESTAMPTZ -); - -CREATE TABLE host_tags ( - host_id TEXT NOT NULL REFERENCES hosts(id) ON DELETE CASCADE, - tag TEXT NOT NULL, - PRIMARY KEY (host_id, tag) -); - -CREATE INDEX idx_hosts_type ON hosts(type); -CREATE INDEX idx_hosts_team ON hosts(team_id); -CREATE INDEX idx_hosts_status ON hosts(status); -CREATE INDEX idx_host_tokens_host ON host_tokens(host_id); -CREATE INDEX idx_host_tags_tag ON host_tags(tag); - --- +goose Down - -DROP TABLE host_tags; -DROP TABLE host_tokens; -DROP TABLE hosts; diff --git a/db/migrations/20260316223629_host_mtls.sql b/db/migrations/20260316223629_host_mtls.sql deleted file mode 100644 index f56b923e..00000000 --- a/db/migrations/20260316223629_host_mtls.sql +++ /dev/null @@ -1,11 +0,0 @@ --- +goose Up - -ALTER TABLE hosts - ADD COLUMN cert_fingerprint TEXT, - ADD COLUMN mtls_enabled BOOLEAN NOT NULL DEFAULT FALSE; - --- +goose Down - -ALTER TABLE hosts - DROP COLUMN cert_fingerprint, - DROP COLUMN mtls_enabled; diff --git a/db/migrations/20260324071453_team_management.sql b/db/migrations/20260324071453_team_management.sql deleted file mode 100644 index 1495d6dd..00000000 --- a/db/migrations/20260324071453_team_management.sql +++ /dev/null @@ -1,17 +0,0 @@ --- +goose Up - -ALTER TABLE teams ADD COLUMN slug TEXT; -ALTER TABLE teams ADD COLUMN deleted_at TIMESTAMPTZ; - --- Backfill slugs for existing teams using MD5 of their ID. --- MD5 returns 32 hex chars; take chars 1-6 and 7-12 to form a 6-6 slug. -UPDATE teams SET slug = LEFT(MD5(id), 6) || '-' || SUBSTRING(MD5(id), 7, 6); - -ALTER TABLE teams ALTER COLUMN slug SET NOT NULL; -CREATE UNIQUE INDEX idx_teams_slug ON teams(slug); - --- +goose Down - -DROP INDEX idx_teams_slug; -ALTER TABLE teams DROP COLUMN deleted_at; -ALTER TABLE teams DROP COLUMN slug; diff --git a/db/migrations/20260324100234_user_names.sql b/db/migrations/20260324100234_user_names.sql deleted file mode 100644 index 2775d12d..00000000 --- a/db/migrations/20260324100234_user_names.sql +++ /dev/null @@ -1,5 +0,0 @@ --- +goose Up -ALTER TABLE users ADD COLUMN name TEXT NOT NULL DEFAULT ''; - --- +goose Down -ALTER TABLE users DROP COLUMN name; diff --git a/db/migrations/20260324120214_host_refresh_tokens.sql b/db/migrations/20260324120214_host_refresh_tokens.sql deleted file mode 100644 index 02a13f74..00000000 --- a/db/migrations/20260324120214_host_refresh_tokens.sql +++ /dev/null @@ -1,19 +0,0 @@ --- +goose Up - --- Refresh tokens for host agent JWT rotation. --- Hosts exchange a refresh token for a new short-lived JWT + new refresh token (rotation). --- Refresh tokens expire after 60 days; hosts must re-register with a new one-time token after that. -CREATE TABLE host_refresh_tokens ( - id TEXT PRIMARY KEY, - host_id TEXT NOT NULL REFERENCES hosts(id) ON DELETE CASCADE, - token_hash TEXT NOT NULL UNIQUE, -- SHA-256 hex of the opaque token - expires_at TIMESTAMPTZ NOT NULL, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - revoked_at TIMESTAMPTZ -- NULL = active; set on rotation or host delete -); - -CREATE INDEX idx_host_refresh_tokens_host ON host_refresh_tokens(host_id); - --- +goose Down - -DROP TABLE host_refresh_tokens; diff --git a/db/migrations/20260324220743_audit_logs.sql b/db/migrations/20260324220743_audit_logs.sql deleted file mode 100644 index 91b73754..00000000 --- a/db/migrations/20260324220743_audit_logs.sql +++ /dev/null @@ -1,28 +0,0 @@ --- +goose Up - -CREATE TABLE audit_logs ( - id TEXT PRIMARY KEY, - team_id TEXT NOT NULL, - actor_type TEXT NOT NULL, -- 'user', 'api_key', 'system' - actor_id TEXT, -- user_id or api_key_id; NULL for system - actor_name TEXT, -- display name snapshotted at write time; NULL for system - resource_type TEXT NOT NULL, -- 'sandbox', 'snapshot', 'team', 'api_key', 'member', 'host' - resource_id TEXT, -- primary ID of the affected resource; NULL when not applicable - action TEXT NOT NULL, -- 'create', 'pause', 'resume', 'destroy', 'delete', 'rename', - -- 'revoke', 'add', 'remove', 'leave', 'role_update', - -- 'marked_down', 'marked_up' - scope TEXT NOT NULL, -- 'team' or 'admin' - status TEXT NOT NULL, -- 'success', 'info', 'warning', 'error' - metadata JSONB NOT NULL DEFAULT '{}', - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() -); - --- Primary access pattern: team feed sorted newest-first with cursor pagination. -CREATE INDEX idx_audit_logs_team_time ON audit_logs (team_id, created_at DESC); - --- Secondary index: filtered by resource_type and action within a team. -CREATE INDEX idx_audit_logs_team_resource ON audit_logs (team_id, resource_type, action, created_at DESC); - --- +goose Down - -DROP TABLE audit_logs; diff --git a/db/migrations/20260325074949_metrics_snapshots.sql b/db/migrations/20260325074949_metrics_snapshots.sql deleted file mode 100644 index 7d373e86..00000000 --- a/db/migrations/20260325074949_metrics_snapshots.sql +++ /dev/null @@ -1,18 +0,0 @@ --- +goose Up - -CREATE TABLE sandbox_metrics_snapshots ( - id BIGSERIAL PRIMARY KEY, - team_id TEXT NOT NULL, - sampled_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - running_count INTEGER NOT NULL, - vcpus_reserved INTEGER NOT NULL, - memory_mb_reserved INTEGER NOT NULL -); - --- All queries filter on team_id first then range-scan sampled_at. -CREATE INDEX idx_metrics_snapshots_team_time - ON sandbox_metrics_snapshots (team_id, sampled_at DESC); - --- +goose Down - -DROP TABLE sandbox_metrics_snapshots; diff --git a/db/migrations/20260325135035_add_sandbox_metric_points.sql b/db/migrations/20260325135035_add_sandbox_metric_points.sql deleted file mode 100644 index 08e86835..00000000 --- a/db/migrations/20260325135035_add_sandbox_metric_points.sql +++ /dev/null @@ -1,16 +0,0 @@ --- +goose Up -CREATE TABLE sandbox_metric_points ( - sandbox_id TEXT NOT NULL, - tier TEXT NOT NULL CHECK (tier IN ('10m', '2h', '24h')), - ts BIGINT NOT NULL, - cpu_pct FLOAT8 NOT NULL DEFAULT 0, - mem_bytes BIGINT NOT NULL DEFAULT 0, - disk_bytes BIGINT NOT NULL DEFAULT 0, - PRIMARY KEY (sandbox_id, tier, ts) -); - -CREATE INDEX idx_sandbox_metric_points_sandbox_tier - ON sandbox_metric_points (sandbox_id, tier); - --- +goose Down -DROP TABLE IF EXISTS sandbox_metric_points; diff --git a/db/migrations/20260326090649_template_builds.sql b/db/migrations/20260326090649_template_builds.sql deleted file mode 100644 index 8e5326dd..00000000 --- a/db/migrations/20260326090649_template_builds.sql +++ /dev/null @@ -1,25 +0,0 @@ --- +goose Up - -CREATE TABLE template_builds ( - id TEXT PRIMARY KEY, - name TEXT NOT NULL, - base_template TEXT NOT NULL DEFAULT 'minimal', - recipe JSONB NOT NULL DEFAULT '[]', - healthcheck TEXT, - vcpus INTEGER NOT NULL DEFAULT 1, - memory_mb INTEGER NOT NULL DEFAULT 512, - status TEXT NOT NULL DEFAULT 'pending', - current_step INTEGER NOT NULL DEFAULT 0, - total_steps INTEGER NOT NULL DEFAULT 0, - logs JSONB NOT NULL DEFAULT '[]', - error TEXT, - sandbox_id TEXT, - host_id TEXT, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - started_at TIMESTAMPTZ, - completed_at TIMESTAMPTZ -); - --- +goose Down - -DROP TABLE template_builds; diff --git a/db/queries/sandboxes.sql b/db/queries/sandboxes.sql index 131fe1ed..71e61dc9 100644 --- a/db/queries/sandboxes.sql +++ b/db/queries/sandboxes.sql @@ -50,7 +50,7 @@ WHERE id = $1; UPDATE sandboxes SET status = $2, last_updated = NOW() -WHERE id = ANY($1::text[]); +WHERE id = ANY($1::uuid[]); -- name: ListActiveSandboxesByTeam :many SELECT * FROM sandboxes @@ -72,4 +72,4 @@ WHERE host_id = $1 AND status IN ('running', 'starting', 'pending'); UPDATE sandboxes SET status = 'running', last_updated = NOW() -WHERE id = ANY($1::text[]) AND status = 'missing'; +WHERE id = ANY($1::uuid[]) AND status = 'missing'; diff --git a/envd/internal/api/init.go b/envd/internal/api/init.go index bd2456e8..301400cb 100644 --- a/envd/internal/api/init.go +++ b/envd/internal/api/init.go @@ -14,12 +14,12 @@ import ( "os/exec" "time" - "github.com/awnumar/memguard" - "github.com/rs/zerolog" - "github.com/txn2/txeh" "git.omukk.dev/wrenn/sandbox/envd/internal/host" "git.omukk.dev/wrenn/sandbox/envd/internal/logs" "git.omukk.dev/wrenn/sandbox/envd/internal/shared/keys" + "github.com/awnumar/memguard" + "github.com/rs/zerolog" + "github.com/txn2/txeh" ) var ( @@ -287,4 +287,3 @@ func getIPFamily(address string) (txeh.IPFamily, error) { return txeh.IPFamilyV4, fmt.Errorf("%w: %s", ErrUnknownAddressFormat, address) } } - diff --git a/internal/api/agent_helper.go b/internal/api/agent_helper.go index ac5b38e0..98a881d6 100644 --- a/internal/api/agent_helper.go +++ b/internal/api/agent_helper.go @@ -4,6 +4,8 @@ import ( "context" "fmt" + "github.com/jackc/pgx/v5/pgtype" + "git.omukk.dev/wrenn/sandbox/internal/db" "git.omukk.dev/wrenn/sandbox/internal/lifecycle" "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect" @@ -11,7 +13,7 @@ import ( // agentForHost looks up the host record and returns a Connect RPC client for it. // Returns an error if the host is not found or has no address. -func agentForHost(ctx context.Context, queries *db.Queries, pool *lifecycle.HostClientPool, hostID string) (hostagentv1connect.HostAgentServiceClient, error) { +func agentForHost(ctx context.Context, queries *db.Queries, pool *lifecycle.HostClientPool, hostID pgtype.UUID) (hostagentv1connect.HostAgentServiceClient, error) { host, err := queries.GetHost(ctx, hostID) if err != nil { return nil, fmt.Errorf("host not found: %w", err) diff --git a/internal/api/handler_sandbox_proxy.go b/internal/api/handler_sandbox_proxy.go index 019fd7fd..322a559c 100644 --- a/internal/api/handler_sandbox_proxy.go +++ b/internal/api/handler_sandbox_proxy.go @@ -10,14 +10,17 @@ import ( "strconv" "strings" + "github.com/jackc/pgx/v5/pgtype" + "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/id" "git.omukk.dev/wrenn/sandbox/internal/lifecycle" ) // sandboxHostPattern matches hostnames like "49999-sb-abcd1234.localhost" or // "49999-sb-abcd1234.example.com". Captures: port, sandbox ID. -var sandboxHostPattern = regexp.MustCompile(`^(\d+)-(sb-[0-9a-f]+)\.`) +var sandboxHostPattern = regexp.MustCompile(`^(\d+)-(sb-[0-9a-f-]+)\.`) // SandboxProxyWrapper wraps an existing HTTP handler and intercepts requests // whose Host header matches the {port}-{sandbox_id}.{domain} pattern. Matching @@ -57,7 +60,7 @@ func (h *SandboxProxyWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request) } port := matches[1] - sandboxID := matches[2] + sandboxIDStr := matches[2] // Validate port. portNum, err := strconv.Atoi(port) @@ -73,6 +76,12 @@ func (h *SandboxProxyWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request) return } + sandboxID, err := id.ParseSandboxID(sandboxIDStr) + if err != nil { + http.Error(w, "invalid sandbox ID", http.StatusBadRequest) + return + } + ctx := r.Context() // Look up sandbox and verify ownership. @@ -96,13 +105,13 @@ func (h *SandboxProxyWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request) return } - if !agentHost.Address.Valid || agentHost.Address.String == "" { + if agentHost.Address == "" { http.Error(w, "host agent has no address", http.StatusServiceUnavailable) return } - agentAddr := lifecycle.EnsureScheme(agentHost.Address.String) - upstreamPath := fmt.Sprintf("/proxy/%s/%s%s", sandboxID, port, r.URL.Path) + agentAddr := lifecycle.EnsureScheme(agentHost.Address) + upstreamPath := fmt.Sprintf("/proxy/%s/%s%s", sandboxIDStr, port, r.URL.Path) target, err := url.Parse(agentAddr) if err != nil { @@ -121,7 +130,7 @@ func (h *SandboxProxyWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request) }, ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) { slog.Debug("sandbox proxy error", - "sandbox_id", sandboxID, + "sandbox_id", sandboxIDStr, "port", port, "error", err, ) @@ -134,16 +143,16 @@ func (h *SandboxProxyWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request) // authenticateRequest validates the request's API key and returns the team ID. // Only API key authentication is supported for sandbox proxy requests (not JWT). -func (h *SandboxProxyWrapper) authenticateRequest(r *http.Request) (string, error) { +func (h *SandboxProxyWrapper) authenticateRequest(r *http.Request) (pgtype.UUID, error) { key := r.Header.Get("X-API-Key") if key == "" { - return "", fmt.Errorf("X-API-Key header required") + return pgtype.UUID{}, fmt.Errorf("X-API-Key header required") } hash := auth.HashAPIKey(key) row, err := h.db.GetAPIKeyByHash(r.Context(), hash) if err != nil { - return "", fmt.Errorf("invalid API key") + return pgtype.UUID{}, fmt.Errorf("invalid API key") } return row.TeamID, nil } diff --git a/internal/api/handlers_apikeys.go b/internal/api/handlers_apikeys.go index 2637181d..700ddc52 100644 --- a/internal/api/handlers_apikeys.go +++ b/internal/api/handlers_apikeys.go @@ -9,6 +9,7 @@ import ( "git.omukk.dev/wrenn/sandbox/internal/audit" "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/id" "git.omukk.dev/wrenn/sandbox/internal/service" ) @@ -39,11 +40,11 @@ type apiKeyResponse struct { func apiKeyToResponse(k db.TeamApiKey) apiKeyResponse { resp := apiKeyResponse{ - ID: k.ID, - TeamID: k.TeamID, + ID: id.FormatAPIKeyID(k.ID), + TeamID: id.FormatTeamID(k.TeamID), Name: k.Name, KeyPrefix: k.KeyPrefix, - CreatedBy: k.CreatedBy, + CreatedBy: id.FormatUserID(k.CreatedBy), } if k.CreatedAt.Valid { resp.CreatedAt = k.CreatedAt.Time.Format(time.RFC3339) @@ -57,11 +58,11 @@ func apiKeyToResponse(k db.TeamApiKey) apiKeyResponse { func apiKeyWithCreatorToResponse(k db.ListAPIKeysByTeamWithCreatorRow) apiKeyResponse { resp := apiKeyResponse{ - ID: k.ID, - TeamID: k.TeamID, + ID: id.FormatAPIKeyID(k.ID), + TeamID: id.FormatTeamID(k.TeamID), Name: k.Name, KeyPrefix: k.KeyPrefix, - CreatedBy: k.CreatedBy, + CreatedBy: id.FormatUserID(k.CreatedBy), CreatorEmail: k.CreatorEmail, } if k.CreatedAt.Valid { @@ -118,7 +119,13 @@ func (h *apiKeyHandler) List(w http.ResponseWriter, r *http.Request) { // Delete handles DELETE /v1/api-keys/{id}. func (h *apiKeyHandler) Delete(w http.ResponseWriter, r *http.Request) { ac := auth.MustFromContext(r.Context()) - keyID := chi.URLParam(r, "id") + keyIDStr := chi.URLParam(r, "id") + + keyID, err := id.ParseAPIKeyID(keyIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid API key ID") + return + } if err := h.svc.Delete(r.Context(), keyID, ac.TeamID); err != nil { writeError(w, http.StatusInternalServerError, "db_error", "failed to delete API key") diff --git a/internal/api/handlers_audit.go b/internal/api/handlers_audit.go index 7812309d..a19ab1dc 100644 --- a/internal/api/handlers_audit.go +++ b/internal/api/handlers_audit.go @@ -6,7 +6,10 @@ import ( "strings" "time" + "github.com/jackc/pgx/v5/pgtype" + "git.omukk.dev/wrenn/sandbox/internal/auth" + "git.omukk.dev/wrenn/sandbox/internal/id" "git.omukk.dev/wrenn/sandbox/internal/service" ) @@ -65,13 +68,24 @@ func (h *auditHandler) List(w http.ResponseWriter, r *http.Request) { limit = n } + // Parse ?before_id cursor (UUID). + var beforeID pgtype.UUID + if s := r.URL.Query().Get("before_id"); s != "" { + parsed, err := id.ParseAuditLogID(s) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "before_id must be a valid audit log ID") + return + } + beforeID = parsed + } + entries, err := h.svc.List(r.Context(), service.AuditListParams{ TeamID: ac.TeamID, AdminScoped: ac.Role == "owner" || ac.Role == "admin", ResourceTypes: parseMultiParam(r.URL.Query()["resource_type"]), Actions: parseMultiParam(r.URL.Query()["action"]), Before: before, - BeforeID: r.URL.Query().Get("before_id"), + BeforeID: beforeID, Limit: limit, }) if err != nil { diff --git a/internal/api/handlers_auth.go b/internal/api/handlers_auth.go index ba60d8e1..b1d4915f 100644 --- a/internal/api/handlers_auth.go +++ b/internal/api/handlers_auth.go @@ -20,7 +20,7 @@ import ( // It prefers the user's default team; if none is flagged as default it falls // back to the earliest-joined team. Returns pgx.ErrNoRows when the user has // no team memberships at all. -func loginTeam(ctx context.Context, q *db.Queries, userID string) (db.Team, string, error) { +func loginTeam(ctx context.Context, q *db.Queries, userID pgtype.UUID) (db.Team, string, error) { team, err := q.GetDefaultTeamForUser(ctx, userID) if err == nil { membership, err := q.GetTeamMembership(ctx, db.GetTeamMembershipParams{UserID: userID, TeamID: team.ID}) @@ -176,8 +176,8 @@ func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusCreated, authResponse{ Token: token, - UserID: userID, - TeamID: teamID, + UserID: id.FormatUserID(userID), + TeamID: id.FormatTeamID(teamID), Email: req.Email, Name: req.Name, }) @@ -236,8 +236,8 @@ func (h *authHandler) Login(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusOK, authResponse{ Token: token, - UserID: user.ID, - TeamID: team.ID, + UserID: id.FormatUserID(user.ID), + TeamID: id.FormatTeamID(team.ID), Email: user.Email, Name: user.Name, }) @@ -260,10 +260,16 @@ func (h *authHandler) SwitchTeam(w http.ResponseWriter, r *http.Request) { return } + teamID, err := id.ParseTeamID(req.TeamID) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid team_id") + return + } + ctx := r.Context() // Verify team exists and is not deleted. - team, err := h.db.GetTeam(ctx, req.TeamID) + team, err := h.db.GetTeam(ctx, teamID) if err != nil { if errors.Is(err, pgx.ErrNoRows) { writeError(w, http.StatusNotFound, "not_found", "team not found") @@ -280,7 +286,7 @@ func (h *authHandler) SwitchTeam(w http.ResponseWriter, r *http.Request) { // Verify membership from DB — JWT role is not trusted here. membership, err := h.db.GetTeamMembership(ctx, db.GetTeamMembershipParams{ UserID: ac.UserID, - TeamID: req.TeamID, + TeamID: teamID, }) if err != nil { if errors.Is(err, pgx.ErrNoRows) { @@ -298,7 +304,7 @@ func (h *authHandler) SwitchTeam(w http.ResponseWriter, r *http.Request) { return } - token, err := auth.SignJWT(h.jwtSecret, ac.UserID, req.TeamID, ac.Email, user.Name, membership.Role, user.IsAdmin) + token, err := auth.SignJWT(h.jwtSecret, ac.UserID, teamID, ac.Email, user.Name, membership.Role, user.IsAdmin) if err != nil { writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token") return @@ -306,8 +312,8 @@ func (h *authHandler) SwitchTeam(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusOK, authResponse{ Token: token, - UserID: ac.UserID, - TeamID: req.TeamID, + UserID: id.FormatUserID(ac.UserID), + TeamID: id.FormatTeamID(teamID), Email: ac.Email, Name: user.Name, }) diff --git a/internal/api/handlers_builds.go b/internal/api/handlers_builds.go index e62b0c60..f1b3973d 100644 --- a/internal/api/handlers_builds.go +++ b/internal/api/handlers_builds.go @@ -10,6 +10,7 @@ import ( "github.com/go-chi/chi/v5" "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/id" "git.omukk.dev/wrenn/sandbox/internal/service" "git.omukk.dev/wrenn/sandbox/internal/validate" ) @@ -53,7 +54,7 @@ type buildResponse struct { func buildToResponse(b db.TemplateBuild) buildResponse { resp := buildResponse{ - ID: b.ID, + ID: id.FormatBuildID(b.ID), Name: b.Name, BaseTemplate: b.BaseTemplate, Recipe: b.Recipe, @@ -64,17 +65,19 @@ func buildToResponse(b db.TemplateBuild) buildResponse { TotalSteps: b.TotalSteps, Logs: b.Logs, } - if b.Healthcheck.Valid { - resp.Healthcheck = &b.Healthcheck.String + if b.Healthcheck != "" { + resp.Healthcheck = &b.Healthcheck } - if b.Error.Valid { - resp.Error = &b.Error.String + if b.Error != "" { + resp.Error = &b.Error } if b.SandboxID.Valid { - resp.SandboxID = &b.SandboxID.String + s := id.FormatSandboxID(b.SandboxID) + resp.SandboxID = &s } if b.HostID.Valid { - resp.HostID = &b.HostID.String + s := id.FormatHostID(b.HostID) + resp.HostID = &s } if b.CreatedAt.Valid { resp.CreatedAt = b.CreatedAt.Time.Format(time.RFC3339) @@ -146,7 +149,13 @@ func (h *buildHandler) List(w http.ResponseWriter, r *http.Request) { // Get handles GET /v1/admin/builds/{id}. func (h *buildHandler) Get(w http.ResponseWriter, r *http.Request) { - buildID := chi.URLParam(r, "id") + buildIDStr := chi.URLParam(r, "id") + + buildID, err := id.ParseBuildID(buildIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid build ID") + return + } build, err := h.svc.Get(r.Context(), buildID) if err != nil { diff --git a/internal/api/handlers_exec.go b/internal/api/handlers_exec.go index 84b38330..596457b2 100644 --- a/internal/api/handlers_exec.go +++ b/internal/api/handlers_exec.go @@ -14,6 +14,7 @@ import ( "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/id" "git.omukk.dev/wrenn/sandbox/internal/lifecycle" pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen" ) @@ -46,10 +47,16 @@ type execResponse struct { // Exec handles POST /v1/sandboxes/{id}/exec. func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) { - sandboxID := chi.URLParam(r, "id") + sandboxIDStr := chi.URLParam(r, "id") ctx := r.Context() ac := auth.MustFromContext(ctx) + sandboxID, err := id.ParseSandboxID(sandboxIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID") + return + } + sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID}) if err != nil { writeError(w, http.StatusNotFound, "not_found", "sandbox not found") @@ -80,7 +87,7 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) { } resp, err := agent.Exec(ctx, connect.NewRequest(&pb.ExecRequest{ - SandboxId: sandboxID, + SandboxId: sandboxIDStr, Cmd: req.Cmd, Args: req.Args, TimeoutSec: req.TimeoutSec, @@ -101,7 +108,7 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) { Valid: true, }, }); err != nil { - slog.Warn("failed to update last_active_at", "id", sandboxID, "error", err) + slog.Warn("failed to update last_active_at", "id", sandboxIDStr, "error", err) } // Use base64 encoding if output contains non-UTF-8 bytes. @@ -112,7 +119,7 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) { if !utf8.Valid(stdout) || !utf8.Valid(stderr) { encoding = "base64" writeJSON(w, http.StatusOK, execResponse{ - SandboxID: sandboxID, + SandboxID: sandboxIDStr, Cmd: req.Cmd, Stdout: base64.StdEncoding.EncodeToString(stdout), Stderr: base64.StdEncoding.EncodeToString(stderr), @@ -124,7 +131,7 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) { } writeJSON(w, http.StatusOK, execResponse{ - SandboxID: sandboxID, + SandboxID: sandboxIDStr, Cmd: req.Cmd, Stdout: string(stdout), Stderr: string(stderr), diff --git a/internal/api/handlers_exec_stream.go b/internal/api/handlers_exec_stream.go index 3ecfdfe7..52dfd17e 100644 --- a/internal/api/handlers_exec_stream.go +++ b/internal/api/handlers_exec_stream.go @@ -14,6 +14,7 @@ import ( "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/id" "git.omukk.dev/wrenn/sandbox/internal/lifecycle" pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen" ) @@ -48,10 +49,16 @@ type wsOutMsg struct { // ExecStream handles WS /v1/sandboxes/{id}/exec/stream. func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) { - sandboxID := chi.URLParam(r, "id") + sandboxIDStr := chi.URLParam(r, "id") ctx := r.Context() ac := auth.MustFromContext(ctx) + sandboxID, err := id.ParseSandboxID(sandboxIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID") + return + } + sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID}) if err != nil { writeError(w, http.StatusNotFound, "not_found", "sandbox not found") @@ -91,7 +98,7 @@ func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) { defer cancel() stream, err := agent.ExecStream(streamCtx, connect.NewRequest(&pb.ExecStreamRequest{ - SandboxId: sandboxID, + SandboxId: sandboxIDStr, Cmd: startMsg.Cmd, Args: startMsg.Args, })) @@ -157,7 +164,7 @@ func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) { Valid: true, }, }); err != nil { - slog.Warn("failed to update last active after stream exec", "sandbox_id", sandboxID, "error", err) + slog.Warn("failed to update last active after stream exec", "sandbox_id", sandboxIDStr, "error", err) } } diff --git a/internal/api/handlers_files.go b/internal/api/handlers_files.go index c5fff70d..a2e9936c 100644 --- a/internal/api/handlers_files.go +++ b/internal/api/handlers_files.go @@ -11,6 +11,7 @@ import ( "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/id" "git.omukk.dev/wrenn/sandbox/internal/lifecycle" pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen" ) @@ -29,10 +30,16 @@ func newFilesHandler(db *db.Queries, pool *lifecycle.HostClientPool) *filesHandl // - "path" text field: absolute destination path inside the sandbox // - "file" file field: binary content to write func (h *filesHandler) Upload(w http.ResponseWriter, r *http.Request) { - sandboxID := chi.URLParam(r, "id") + sandboxIDStr := chi.URLParam(r, "id") ctx := r.Context() ac := auth.MustFromContext(ctx) + sandboxID, err := id.ParseSandboxID(sandboxIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID") + return + } + sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID}) if err != nil { writeError(w, http.StatusNotFound, "not_found", "sandbox not found") @@ -82,7 +89,7 @@ func (h *filesHandler) Upload(w http.ResponseWriter, r *http.Request) { } if _, err := agent.WriteFile(ctx, connect.NewRequest(&pb.WriteFileRequest{ - SandboxId: sandboxID, + SandboxId: sandboxIDStr, Path: filePath, Content: content, })); err != nil { @@ -101,10 +108,16 @@ type readFileRequest struct { // Download handles POST /v1/sandboxes/{id}/files/read. // Accepts JSON body with path, returns raw file content with Content-Disposition. func (h *filesHandler) Download(w http.ResponseWriter, r *http.Request) { - sandboxID := chi.URLParam(r, "id") + sandboxIDStr := chi.URLParam(r, "id") ctx := r.Context() ac := auth.MustFromContext(ctx) + sandboxID, err := id.ParseSandboxID(sandboxIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID") + return + } + sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID}) if err != nil { writeError(w, http.StatusNotFound, "not_found", "sandbox not found") @@ -133,7 +146,7 @@ func (h *filesHandler) Download(w http.ResponseWriter, r *http.Request) { } resp, err := agent.ReadFile(ctx, connect.NewRequest(&pb.ReadFileRequest{ - SandboxId: sandboxID, + SandboxId: sandboxIDStr, Path: req.Path, })) if err != nil { diff --git a/internal/api/handlers_files_stream.go b/internal/api/handlers_files_stream.go index 66e89c73..e6c040f2 100644 --- a/internal/api/handlers_files_stream.go +++ b/internal/api/handlers_files_stream.go @@ -12,6 +12,7 @@ import ( "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/id" "git.omukk.dev/wrenn/sandbox/internal/lifecycle" pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen" ) @@ -29,10 +30,16 @@ func newFilesStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool) *file // Expects multipart/form-data with "path" text field and "file" file field. // Streams file content directly from the request body to the host agent without buffering. func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request) { - sandboxID := chi.URLParam(r, "id") + sandboxIDStr := chi.URLParam(r, "id") ctx := r.Context() ac := auth.MustFromContext(ctx) + sandboxID, err := id.ParseSandboxID(sandboxIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID") + return + } + sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID}) if err != nil { writeError(w, http.StatusNotFound, "not_found", "sandbox not found") @@ -101,7 +108,7 @@ func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request if err := stream.Send(&pb.WriteFileStreamRequest{ Content: &pb.WriteFileStreamRequest_Meta{ Meta: &pb.WriteFileStreamMeta{ - SandboxId: sandboxID, + SandboxId: sandboxIDStr, Path: filePath, }, }, @@ -146,10 +153,16 @@ func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request // StreamDownload handles POST /v1/sandboxes/{id}/files/stream/read. // Accepts JSON body with path, streams file content back without buffering. func (h *filesStreamHandler) StreamDownload(w http.ResponseWriter, r *http.Request) { - sandboxID := chi.URLParam(r, "id") + sandboxIDStr := chi.URLParam(r, "id") ctx := r.Context() ac := auth.MustFromContext(ctx) + sandboxID, err := id.ParseSandboxID(sandboxIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID") + return + } + sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID}) if err != nil { writeError(w, http.StatusNotFound, "not_found", "sandbox not found") @@ -178,7 +191,7 @@ func (h *filesStreamHandler) StreamDownload(w http.ResponseWriter, r *http.Reque // Open server-streaming RPC to host agent. stream, err := agent.ReadFileStream(ctx, connect.NewRequest(&pb.ReadFileStreamRequest{ - SandboxId: sandboxID, + SandboxId: sandboxIDStr, Path: req.Path, })) if err != nil { diff --git a/internal/api/handlers_hosts.go b/internal/api/handlers_hosts.go index f4f79173..c910c612 100644 --- a/internal/api/handlers_hosts.go +++ b/internal/api/handlers_hosts.go @@ -8,9 +8,12 @@ import ( "github.com/go-chi/chi/v5" + "github.com/jackc/pgx/v5/pgtype" + "git.omukk.dev/wrenn/sandbox/internal/audit" "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/id" "git.omukk.dev/wrenn/sandbox/internal/service" ) @@ -93,34 +96,35 @@ type hostResponse struct { func hostToResponse(h db.Host) hostResponse { resp := hostResponse{ - ID: h.ID, + ID: id.FormatHostID(h.ID), Type: h.Type, Status: h.Status, - CreatedBy: h.CreatedBy, + CreatedBy: id.FormatUserID(h.CreatedBy), } if h.TeamID.Valid { - resp.TeamID = &h.TeamID.String + s := id.FormatTeamID(h.TeamID) + resp.TeamID = &s } - if h.Provider.Valid { - resp.Provider = &h.Provider.String + if h.Provider != "" { + resp.Provider = &h.Provider } - if h.AvailabilityZone.Valid { - resp.AvailabilityZone = &h.AvailabilityZone.String + if h.AvailabilityZone != "" { + resp.AvailabilityZone = &h.AvailabilityZone } - if h.Arch.Valid { - resp.Arch = &h.Arch.String + if h.Arch != "" { + resp.Arch = &h.Arch } - if h.CpuCores.Valid { - resp.CPUCores = &h.CpuCores.Int32 + if h.CpuCores != 0 { + resp.CPUCores = &h.CpuCores } - if h.MemoryMb.Valid { - resp.MemoryMB = &h.MemoryMb.Int32 + if h.MemoryMb != 0 { + resp.MemoryMB = &h.MemoryMb } - if h.DiskGb.Valid { - resp.DiskGB = &h.DiskGb.Int32 + if h.DiskGb != 0 { + resp.DiskGB = &h.DiskGb } - if h.Address.Valid { - resp.Address = &h.Address.String + if h.Address != "" { + resp.Address = &h.Address } if h.LastHeartbeatAt.Valid { s := h.LastHeartbeatAt.Time.Format(time.RFC3339) @@ -133,7 +137,7 @@ func hostToResponse(h db.Host) hostResponse { } // isAdmin fetches the user record and returns whether they are an admin. -func (h *hostHandler) isAdmin(r *http.Request, userID string) bool { +func (h *hostHandler) isAdmin(r *http.Request, userID pgtype.UUID) bool { user, err := h.queries.GetUserByID(r.Context(), userID) if err != nil { return false @@ -151,14 +155,23 @@ func (h *hostHandler) Create(w http.ResponseWriter, r *http.Request) { ac := auth.MustFromContext(r.Context()) - result, err := h.svc.Create(r.Context(), service.HostCreateParams{ - Type: req.Type, - TeamID: req.TeamID, - Provider: req.Provider, - AvailabilityZone: req.AvailabilityZone, - RequestingUserID: ac.UserID, - IsRequestorAdmin: h.isAdmin(r, ac.UserID), - }) + // Parse optional team ID from request body. + var params service.HostCreateParams + params.Type = req.Type + params.Provider = req.Provider + params.AvailabilityZone = req.AvailabilityZone + params.RequestingUserID = ac.UserID + params.IsRequestorAdmin = h.isAdmin(r, ac.UserID) + if req.TeamID != "" { + teamID, err := id.ParseTeamID(req.TeamID) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid team_id") + return + } + params.TeamID = teamID + } + + result, err := h.svc.Create(r.Context(), params) if err != nil { status, code, msg := serviceErrToHTTP(err) writeError(w, status, code, msg) @@ -166,8 +179,7 @@ func (h *hostHandler) Create(w http.ResponseWriter, r *http.Request) { } // Log audit for the owning team (BYOC hosts have a team; shared hosts use caller's team). - hostTeamID := result.Host.TeamID.String - h.audit.LogHostCreate(r.Context(), ac, result.Host.ID, hostTeamID) + h.audit.LogHostCreate(r.Context(), ac, result.Host.ID, result.Host.TeamID) writeJSON(w, http.StatusCreated, createHostResponse{ Host: hostToResponse(result.Host), @@ -192,14 +204,22 @@ func (h *hostHandler) List(w http.ResponseWriter, r *http.Request) { seen := make(map[string]struct{}) for _, host := range hosts { if host.TeamID.Valid { - seen[host.TeamID.String] = struct{}{} + key := id.FormatTeamID(host.TeamID) + seen[key] = struct{}{} } } if len(seen) > 0 { teamNames = make(map[string]string, len(seen)) - for id := range seen { - if team, err := h.queries.GetTeam(r.Context(), id); err == nil { - teamNames[id] = team.Name + for _, host := range hosts { + if !host.TeamID.Valid { + continue + } + key := id.FormatTeamID(host.TeamID) + if _, ok := teamNames[key]; ok { + continue + } + if team, err := h.queries.GetTeam(r.Context(), host.TeamID); err == nil { + teamNames[key] = team.Name } } } @@ -209,7 +229,8 @@ func (h *hostHandler) List(w http.ResponseWriter, r *http.Request) { for i, host := range hosts { resp[i] = hostToResponse(host) if host.TeamID.Valid { - if name, ok := teamNames[host.TeamID.String]; ok { + key := id.FormatTeamID(host.TeamID) + if name, ok := teamNames[key]; ok { resp[i].TeamName = &name } } @@ -220,9 +241,15 @@ func (h *hostHandler) List(w http.ResponseWriter, r *http.Request) { // Get handles GET /v1/hosts/{id}. func (h *hostHandler) Get(w http.ResponseWriter, r *http.Request) { - hostID := chi.URLParam(r, "id") + hostIDStr := chi.URLParam(r, "id") ac := auth.MustFromContext(r.Context()) + hostID, err := id.ParseHostID(hostIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID") + return + } + host, err := h.svc.Get(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID)) if err != nil { status, code, msg := serviceErrToHTTP(err) @@ -236,9 +263,15 @@ func (h *hostHandler) Get(w http.ResponseWriter, r *http.Request) { // DeletePreview handles GET /v1/hosts/{id}/delete-preview. // Returns what would be affected without making changes, for confirmation UI. func (h *hostHandler) DeletePreview(w http.ResponseWriter, r *http.Request) { - hostID := chi.URLParam(r, "id") + hostIDStr := chi.URLParam(r, "id") ac := auth.MustFromContext(r.Context()) + hostID, err := id.ParseHostID(hostIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID") + return + } + preview, err := h.svc.DeletePreview(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID)) if err != nil { status, code, msg := serviceErrToHTTP(err) @@ -256,19 +289,25 @@ func (h *hostHandler) DeletePreview(w http.ResponseWriter, r *http.Request) { // Without ?force=true: returns 409 with affected sandbox IDs if any are active. // With ?force=true: gracefully stops all sandboxes then deletes the host. func (h *hostHandler) Delete(w http.ResponseWriter, r *http.Request) { - hostID := chi.URLParam(r, "id") + hostIDStr := chi.URLParam(r, "id") ac := auth.MustFromContext(r.Context()) force := r.URL.Query().Get("force") == "true" + hostID, err := id.ParseHostID(hostIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID") + return + } + // Fetch host before deletion to capture team_id for audit. deletedHost, hostErr := h.queries.GetHost(r.Context(), hostID) if hostErr != nil { - slog.Warn("audit: could not fetch host before delete", "host_id", hostID, "error", hostErr) + slog.Warn("audit: could not fetch host before delete", "host_id", hostIDStr, "error", hostErr) } - err := h.svc.Delete(r.Context(), hostID, ac.UserID, ac.TeamID, h.isAdmin(r, ac.UserID), force) + err = h.svc.Delete(r.Context(), hostID, ac.UserID, ac.TeamID, h.isAdmin(r, ac.UserID), force) if err == nil { - h.audit.LogHostDelete(r.Context(), ac, hostID, deletedHost.TeamID.String) + h.audit.LogHostDelete(r.Context(), ac, hostID, deletedHost.TeamID) w.WriteHeader(http.StatusNoContent) return } @@ -292,9 +331,15 @@ func (h *hostHandler) Delete(w http.ResponseWriter, r *http.Request) { // RegenerateToken handles POST /v1/hosts/{id}/token. func (h *hostHandler) RegenerateToken(w http.ResponseWriter, r *http.Request) { - hostID := chi.URLParam(r, "id") + hostIDStr := chi.URLParam(r, "id") ac := auth.MustFromContext(r.Context()) + hostID, err := id.ParseHostID(hostIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID") + return + } + result, err := h.svc.RegenerateToken(r.Context(), hostID, ac.UserID, ac.TeamID, h.isAdmin(r, ac.UserID)) if err != nil { status, code, msg := serviceErrToHTTP(err) @@ -348,9 +393,15 @@ func (h *hostHandler) Register(w http.ResponseWriter, r *http.Request) { // Heartbeat handles POST /v1/hosts/{id}/heartbeat (host-token-authenticated). func (h *hostHandler) Heartbeat(w http.ResponseWriter, r *http.Request) { - hostID := chi.URLParam(r, "id") + hostIDStr := chi.URLParam(r, "id") hc := auth.MustHostFromContext(r.Context()) + hostID, err := id.ParseHostID(hostIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID") + return + } + // Prevent a host from heartbeating for a different host. if hostID != hc.HostID { writeError(w, http.StatusForbidden, "forbidden", "host ID mismatch") @@ -368,7 +419,7 @@ func (h *hostHandler) Heartbeat(w http.ResponseWriter, r *http.Request) { // Log marked_up if the host just recovered from unreachable. if prevHost.Status == "unreachable" { - h.audit.LogHostMarkedUp(r.Context(), prevHost.TeamID.String, hc.HostID) + h.audit.LogHostMarkedUp(r.Context(), prevHost.TeamID, hc.HostID) } w.WriteHeader(http.StatusNoContent) @@ -376,10 +427,16 @@ func (h *hostHandler) Heartbeat(w http.ResponseWriter, r *http.Request) { // AddTag handles POST /v1/hosts/{id}/tags. func (h *hostHandler) AddTag(w http.ResponseWriter, r *http.Request) { - hostID := chi.URLParam(r, "id") + hostIDStr := chi.URLParam(r, "id") ac := auth.MustFromContext(r.Context()) admin := h.isAdmin(r, ac.UserID) + hostID, err := id.ParseHostID(hostIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID") + return + } + var req addTagRequest if err := decodeJSON(r, &req); err != nil { writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body") @@ -401,10 +458,16 @@ func (h *hostHandler) AddTag(w http.ResponseWriter, r *http.Request) { // RemoveTag handles DELETE /v1/hosts/{id}/tags/{tag}. func (h *hostHandler) RemoveTag(w http.ResponseWriter, r *http.Request) { - hostID := chi.URLParam(r, "id") + hostIDStr := chi.URLParam(r, "id") tag := chi.URLParam(r, "tag") ac := auth.MustFromContext(r.Context()) + hostID, err := id.ParseHostID(hostIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID") + return + } + if err := h.svc.RemoveTag(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID), tag); err != nil { status, code, msg := serviceErrToHTTP(err) writeError(w, status, code, msg) @@ -443,9 +506,15 @@ func (h *hostHandler) RefreshToken(w http.ResponseWriter, r *http.Request) { // ListTags handles GET /v1/hosts/{id}/tags. func (h *hostHandler) ListTags(w http.ResponseWriter, r *http.Request) { - hostID := chi.URLParam(r, "id") + hostIDStr := chi.URLParam(r, "id") ac := auth.MustFromContext(r.Context()) + hostID, err := id.ParseHostID(hostIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID") + return + } + tags, err := h.svc.ListTags(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID)) if err != nil { status, code, msg := serviceErrToHTTP(err) diff --git a/internal/api/handlers_metrics.go b/internal/api/handlers_metrics.go index 793349e5..25f485c6 100644 --- a/internal/api/handlers_metrics.go +++ b/internal/api/handlers_metrics.go @@ -7,9 +7,11 @@ import ( "connectrpc.com/connect" "github.com/go-chi/chi/v5" + "github.com/jackc/pgx/v5/pgtype" "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/id" "git.omukk.dev/wrenn/sandbox/internal/lifecycle" pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen" ) @@ -38,10 +40,16 @@ type metricsResponse struct { // GetMetrics handles GET /v1/sandboxes/{id}/metrics?range=10m|2h|24h. func (h *sandboxMetricsHandler) GetMetrics(w http.ResponseWriter, r *http.Request) { - sandboxID := chi.URLParam(r, "id") + sandboxIDStr := chi.URLParam(r, "id") ctx := r.Context() ac := auth.MustFromContext(ctx) + sandboxID, err := id.ParseSandboxID(sandboxIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID") + return + } + rangeTier := r.URL.Query().Get("range") if rangeTier == "" { rangeTier = "10m" @@ -60,15 +68,15 @@ func (h *sandboxMetricsHandler) GetMetrics(w http.ResponseWriter, r *http.Reques switch sb.Status { case "running": - h.getFromAgent(w, r, sandboxID, rangeTier, sb.HostID) + h.getFromAgent(w, r, sandboxIDStr, rangeTier, sb.HostID) case "paused": - h.getFromDB(ctx, w, sandboxID, rangeTier) + h.getFromDB(ctx, w, sandboxIDStr, sandboxID, rangeTier) default: writeError(w, http.StatusNotFound, "not_found", "metrics not available for sandbox in state: "+sb.Status) } } -func (h *sandboxMetricsHandler) getFromAgent(w http.ResponseWriter, r *http.Request, sandboxID, rangeTier, hostID string) { +func (h *sandboxMetricsHandler) getFromAgent(w http.ResponseWriter, r *http.Request, sandboxIDStr, rangeTier string, hostID pgtype.UUID) { ctx := r.Context() agent, err := agentForHost(ctx, h.db, h.pool, hostID) @@ -78,7 +86,7 @@ func (h *sandboxMetricsHandler) getFromAgent(w http.ResponseWriter, r *http.Requ } resp, err := agent.GetSandboxMetrics(ctx, connect.NewRequest(&pb.GetSandboxMetricsRequest{ - SandboxId: sandboxID, + SandboxId: sandboxIDStr, Range: rangeTier, })) if err != nil { @@ -98,7 +106,7 @@ func (h *sandboxMetricsHandler) getFromAgent(w http.ResponseWriter, r *http.Requ } writeJSON(w, http.StatusOK, metricsResponse{ - SandboxID: sandboxID, + SandboxID: sandboxIDStr, Range: rangeTier, Points: points, }) @@ -118,7 +126,7 @@ var rangeToDB = map[string]struct { "24h": {"24h", 24 * time.Hour}, } -func (h *sandboxMetricsHandler) getFromDB(ctx context.Context, w http.ResponseWriter, sandboxID, rangeTier string) { +func (h *sandboxMetricsHandler) getFromDB(ctx context.Context, w http.ResponseWriter, sandboxIDStr string, sandboxID pgtype.UUID, rangeTier string) { mapping := rangeToDB[rangeTier] rows, err := h.db.GetSandboxMetricPoints(ctx, db.GetSandboxMetricPointsParams{ SandboxID: sandboxID, @@ -141,7 +149,7 @@ func (h *sandboxMetricsHandler) getFromDB(ctx context.Context, w http.ResponseWr } writeJSON(w, http.StatusOK, metricsResponse{ - SandboxID: sandboxID, + SandboxID: sandboxIDStr, Range: rangeTier, Points: points, }) diff --git a/internal/api/handlers_oauth.go b/internal/api/handlers_oauth.go index 348dd859..a9c448ac 100644 --- a/internal/api/handlers_oauth.go +++ b/internal/api/handlers_oauth.go @@ -162,7 +162,7 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) { redirectWithError(w, r, redirectBase, "internal_error") return } - redirectWithToken(w, r, redirectBase, token, user.ID, team.ID, user.Email, user.Name) + redirectWithToken(w, r, redirectBase, token, id.FormatUserID(user.ID), id.FormatTeamID(team.ID), user.Email, user.Name) return } if !errors.Is(err, pgx.ErrNoRows) { @@ -262,7 +262,7 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) { return } - redirectWithToken(w, r, redirectBase, token, userID, teamID, email, profile.Name) + redirectWithToken(w, r, redirectBase, token, id.FormatUserID(userID), id.FormatTeamID(teamID), email, profile.Name) } // retryAsLogin handles the race where a concurrent request already created the user. @@ -296,7 +296,7 @@ func (h *oauthHandler) retryAsLogin(w http.ResponseWriter, r *http.Request, prov redirectWithError(w, r, redirectBase, "internal_error") return } - redirectWithToken(w, r, redirectBase, token, user.ID, team.ID, user.Email, user.Name) + redirectWithToken(w, r, redirectBase, token, id.FormatUserID(user.ID), id.FormatTeamID(team.ID), user.Email, user.Name) } func redirectWithToken(w http.ResponseWriter, r *http.Request, base, token, userID, teamID, email, name string) { diff --git a/internal/api/handlers_sandbox.go b/internal/api/handlers_sandbox.go index b2709a5d..a19a7cc3 100644 --- a/internal/api/handlers_sandbox.go +++ b/internal/api/handlers_sandbox.go @@ -10,6 +10,7 @@ import ( "git.omukk.dev/wrenn/sandbox/internal/audit" "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/id" "git.omukk.dev/wrenn/sandbox/internal/service" ) @@ -46,7 +47,7 @@ type sandboxResponse struct { func sandboxToResponse(sb db.Sandbox) sandboxResponse { resp := sandboxResponse{ - ID: sb.ID, + ID: id.FormatSandboxID(sb.ID), Status: sb.Status, Template: sb.Template, VCPUs: sb.Vcpus, @@ -81,7 +82,7 @@ func (h *sandboxHandler) Create(w http.ResponseWriter, r *http.Request) { } ac := auth.MustFromContext(r.Context()) - if ac.TeamID == "" { + if !ac.TeamID.Valid { writeError(w, http.StatusForbidden, "no_team", "no active team context; re-authenticate") return } @@ -122,9 +123,15 @@ func (h *sandboxHandler) List(w http.ResponseWriter, r *http.Request) { // Get handles GET /v1/sandboxes/{id}. func (h *sandboxHandler) Get(w http.ResponseWriter, r *http.Request) { - sandboxID := chi.URLParam(r, "id") + sandboxIDStr := chi.URLParam(r, "id") ac := auth.MustFromContext(r.Context()) + sandboxID, err := id.ParseSandboxID(sandboxIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID") + return + } + sb, err := h.svc.Get(r.Context(), sandboxID, ac.TeamID) if err != nil { writeError(w, http.StatusNotFound, "not_found", "sandbox not found") @@ -136,9 +143,15 @@ func (h *sandboxHandler) Get(w http.ResponseWriter, r *http.Request) { // Pause handles POST /v1/sandboxes/{id}/pause. func (h *sandboxHandler) Pause(w http.ResponseWriter, r *http.Request) { - sandboxID := chi.URLParam(r, "id") + sandboxIDStr := chi.URLParam(r, "id") ac := auth.MustFromContext(r.Context()) + sandboxID, err := id.ParseSandboxID(sandboxIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID") + return + } + sb, err := h.svc.Pause(r.Context(), sandboxID, ac.TeamID) if err != nil { status, code, msg := serviceErrToHTTP(err) @@ -152,9 +165,15 @@ func (h *sandboxHandler) Pause(w http.ResponseWriter, r *http.Request) { // Resume handles POST /v1/sandboxes/{id}/resume. func (h *sandboxHandler) Resume(w http.ResponseWriter, r *http.Request) { - sandboxID := chi.URLParam(r, "id") + sandboxIDStr := chi.URLParam(r, "id") ac := auth.MustFromContext(r.Context()) + sandboxID, err := id.ParseSandboxID(sandboxIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID") + return + } + sb, err := h.svc.Resume(r.Context(), sandboxID, ac.TeamID) if err != nil { status, code, msg := serviceErrToHTTP(err) @@ -168,9 +187,15 @@ func (h *sandboxHandler) Resume(w http.ResponseWriter, r *http.Request) { // Ping handles POST /v1/sandboxes/{id}/ping. func (h *sandboxHandler) Ping(w http.ResponseWriter, r *http.Request) { - sandboxID := chi.URLParam(r, "id") + sandboxIDStr := chi.URLParam(r, "id") ac := auth.MustFromContext(r.Context()) + sandboxID, err := id.ParseSandboxID(sandboxIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID") + return + } + if err := h.svc.Ping(r.Context(), sandboxID, ac.TeamID); err != nil { status, code, msg := serviceErrToHTTP(err) writeError(w, status, code, msg) @@ -182,9 +207,15 @@ func (h *sandboxHandler) Ping(w http.ResponseWriter, r *http.Request) { // Destroy handles DELETE /v1/sandboxes/{id}. func (h *sandboxHandler) Destroy(w http.ResponseWriter, r *http.Request) { - sandboxID := chi.URLParam(r, "id") + sandboxIDStr := chi.URLParam(r, "id") ac := auth.MustFromContext(r.Context()) + sandboxID, err := id.ParseSandboxID(sandboxIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID") + return + } + if err := h.svc.Destroy(r.Context(), sandboxID, ac.TeamID); err != nil { status, code, msg := serviceErrToHTTP(err) writeError(w, status, code, msg) diff --git a/internal/api/handlers_snapshots.go b/internal/api/handlers_snapshots.go index fbfcdc18..f3e29074 100644 --- a/internal/api/handlers_snapshots.go +++ b/internal/api/handlers_snapshots.go @@ -10,7 +10,6 @@ import ( "connectrpc.com/connect" "github.com/go-chi/chi/v5" - "github.com/jackc/pgx/v5/pgtype" "git.omukk.dev/wrenn/sandbox/internal/audit" "git.omukk.dev/wrenn/sandbox/internal/auth" @@ -51,7 +50,7 @@ func (h *snapshotHandler) deleteSnapshotBroadcast(ctx context.Context, name stri } if _, err := agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{Name: name})); err != nil { if connect.CodeOf(err) != connect.CodeNotFound { - slog.Warn("snapshot: failed to delete on host", "host_id", host.ID, "name", name, "error", err) + slog.Warn("snapshot: failed to delete on host", "host_id", id.FormatHostID(host.ID), "name", name, "error", err) } } } @@ -78,11 +77,11 @@ func templateToResponse(t db.Template) snapshotResponse { Type: t.Type, SizeBytes: t.SizeBytes, } - if t.Vcpus.Valid { - resp.VCPUs = &t.Vcpus.Int32 + if t.Vcpus != 0 { + resp.VCPUs = &t.Vcpus } - if t.MemoryMb.Valid { - resp.MemoryMB = &t.MemoryMb.Int32 + if t.MemoryMb != 0 { + resp.MemoryMB = &t.MemoryMb } if t.CreatedAt.Valid { resp.CreatedAt = t.CreatedAt.Time.Format(time.RFC3339) @@ -103,6 +102,12 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) { return } + sandboxID, err := id.ParseSandboxID(req.SandboxID) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox_id") + return + } + if req.Name == "" { req.Name = id.NewSnapshotName() } @@ -133,7 +138,7 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) { } // Verify sandbox exists, belongs to team, and is running or paused. - sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: req.SandboxID, TeamID: ac.TeamID}) + sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID}) if err != nil { writeError(w, http.StatusNotFound, "not_found", "sandbox not found") return @@ -162,7 +167,7 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) { // Mark sandbox as paused (if it was running, it got paused by the snapshot). if sb.Status != "paused" { if _, err := h.db.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{ - ID: req.SandboxID, Status: "paused", + ID: sandboxID, Status: "paused", }); err != nil { slog.Error("failed to update sandbox status after snapshot", "sandbox_id", req.SandboxID, "error", err) } @@ -171,8 +176,8 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) { tmpl, err := h.db.InsertTemplate(ctx, db.InsertTemplateParams{ Name: req.Name, Type: "snapshot", - Vcpus: pgtype.Int4{Int32: sb.Vcpus, Valid: true}, - MemoryMb: pgtype.Int4{Int32: sb.MemoryMb, Valid: true}, + Vcpus: sb.Vcpus, + MemoryMb: sb.MemoryMb, SizeBytes: resp.Msg.SizeBytes, TeamID: ac.TeamID, }) diff --git a/internal/api/handlers_team.go b/internal/api/handlers_team.go index 3950ab73..2bf99f9a 100644 --- a/internal/api/handlers_team.go +++ b/internal/api/handlers_team.go @@ -7,10 +7,12 @@ import ( "time" "github.com/go-chi/chi/v5" + "github.com/jackc/pgx/v5/pgtype" "git.omukk.dev/wrenn/sandbox/internal/audit" "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/id" "git.omukk.dev/wrenn/sandbox/internal/service" ) @@ -48,7 +50,7 @@ type memberResponse struct { func teamToResponse(t db.Team) teamResponse { resp := teamResponse{ - ID: t.ID, + ID: id.FormatTeamID(t.ID), Name: t.Name, Slug: t.Slug, IsByoc: t.IsByoc, @@ -72,11 +74,16 @@ func memberInfoToResponse(m service.MemberInfo) memberResponse { // requireTeamAccess is an inline check used by every team-scoped handler: // the JWT team_id must match the URL {id} before any DB call is made. // Returns false and writes 403 if they don't match. -func requireTeamAccess(w http.ResponseWriter, r *http.Request, ac auth.AuthContext) (string, bool) { - teamID := chi.URLParam(r, "id") +func requireTeamAccess(w http.ResponseWriter, r *http.Request, ac auth.AuthContext) (pgtype.UUID, bool) { + teamIDStr := chi.URLParam(r, "id") + teamID, err := id.ParseTeamID(teamIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid team ID") + return pgtype.UUID{}, false + } if ac.TeamID != teamID { writeError(w, http.StatusForbidden, "forbidden", "JWT team does not match requested team; use switch-team first") - return "", false + return pgtype.UUID{}, false } return teamID, true } @@ -185,7 +192,7 @@ func (h *teamHandler) Rename(w http.ResponseWriter, r *http.Request) { // Fetch old name for audit log before renaming. oldTeam, err := h.svc.GetTeam(r.Context(), teamID) if err != nil { - slog.Warn("audit: could not fetch old team name for rename log", "team_id", teamID, "error", err) + slog.Warn("audit: could not fetch old team name for rename log", "team_id", id.FormatTeamID(teamID), "error", err) } if err := h.svc.RenameTeam(r.Context(), teamID, ac.UserID, req.Name); err != nil { @@ -267,7 +274,11 @@ func (h *teamHandler) AddMember(w http.ResponseWriter, r *http.Request) { return } - h.audit.LogMemberAdd(r.Context(), ac, member.UserID, member.Email, member.Role) + // member.UserID is already formatted with prefix; parse it back for the audit logger. + targetUserID, parseErr := id.ParseUserID(member.UserID) + if parseErr == nil { + h.audit.LogMemberAdd(r.Context(), ac, targetUserID, member.Email, member.Role) + } writeJSON(w, http.StatusCreated, memberInfoToResponse(member)) } @@ -279,7 +290,13 @@ func (h *teamHandler) RemoveMember(w http.ResponseWriter, r *http.Request) { if !ok { return } - targetUserID := chi.URLParam(r, "uid") + targetUserIDStr := chi.URLParam(r, "uid") + + targetUserID, err := id.ParseUserID(targetUserIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid user ID") + return + } if err := h.svc.RemoveMember(r.Context(), teamID, ac.UserID, targetUserID); err != nil { status, code, msg := serviceErrToHTTP(err) @@ -299,7 +316,13 @@ func (h *teamHandler) UpdateMemberRole(w http.ResponseWriter, r *http.Request) { if !ok { return } - targetUserID := chi.URLParam(r, "uid") + targetUserIDStr := chi.URLParam(r, "uid") + + targetUserID, err := id.ParseUserID(targetUserIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid user ID") + return + } var req struct { Role string `json:"role"` @@ -341,7 +364,13 @@ func (h *teamHandler) Leave(w http.ResponseWriter, r *http.Request) { // SetBYOC handles PUT /v1/admin/teams/{id}/byoc (admin only). // Enables or disables the BYOC feature flag for a team. func (h *teamHandler) SetBYOC(w http.ResponseWriter, r *http.Request) { - teamID := chi.URLParam(r, "id") + teamIDStr := chi.URLParam(r, "id") + + teamID, err := id.ParseTeamID(teamIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid team ID") + return + } var req struct { Enabled bool `json:"enabled"` diff --git a/internal/api/handlers_users.go b/internal/api/handlers_users.go index 8269d3c0..17050641 100644 --- a/internal/api/handlers_users.go +++ b/internal/api/handlers_users.go @@ -8,6 +8,7 @@ import ( "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/id" ) type usersHandler struct { @@ -45,7 +46,7 @@ func (h *usersHandler) Search(w http.ResponseWriter, r *http.Request) { } resp := make([]userResult, len(results)) for i, u := range results { - resp[i] = userResult{UserID: u.ID, Email: u.Email} + resp[i] = userResult{UserID: id.FormatUserID(u.ID), Email: u.Email} } writeJSON(w, http.StatusOK, resp) } diff --git a/internal/api/host_monitor.go b/internal/api/host_monitor.go index 4bf19d87..95fde106 100644 --- a/internal/api/host_monitor.go +++ b/internal/api/host_monitor.go @@ -6,9 +6,11 @@ import ( "time" "connectrpc.com/connect" + "github.com/jackc/pgx/v5/pgtype" "git.omukk.dev/wrenn/sandbox/internal/audit" "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/id" "git.omukk.dev/wrenn/sandbox/internal/lifecycle" pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen" ) @@ -82,15 +84,15 @@ func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) { time.Since(host.LastHeartbeatAt.Time) > unreachableThreshold if stale && host.Status != "unreachable" { - slog.Info("host monitor: marking host unreachable", "host_id", host.ID, + slog.Info("host monitor: marking host unreachable", "host_id", id.FormatHostID(host.ID), "last_heartbeat", host.LastHeartbeatAt.Time) if err := m.db.MarkHostUnreachable(ctx, host.ID); err != nil { - slog.Warn("host monitor: failed to mark host unreachable", "host_id", host.ID, "error", err) + slog.Warn("host monitor: failed to mark host unreachable", "host_id", id.FormatHostID(host.ID), "error", err) } if err := m.db.MarkSandboxesMissingByHost(ctx, host.ID); err != nil { - slog.Warn("host monitor: failed to mark sandboxes missing", "host_id", host.ID, "error", err) + slog.Warn("host monitor: failed to mark sandboxes missing", "host_id", id.FormatHostID(host.ID), "error", err) } - m.audit.LogHostMarkedDown(ctx, host.TeamID.String, host.ID) + m.audit.LogHostMarkedDown(ctx, host.TeamID, host.ID) return } @@ -110,19 +112,20 @@ func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) { if err != nil { // RPC failure is a transient condition; the passive phase will catch it // if heartbeats stop arriving. - slog.Debug("host monitor: ListSandboxes failed (transient)", "host_id", host.ID, "error", err) + slog.Debug("host monitor: ListSandboxes failed (transient)", "host_id", id.FormatHostID(host.ID), "error", err) return } // Build set of sandbox IDs alive on the host. + // The host agent returns sandbox IDs as strings (formatted with prefix). alive := make(map[string]struct{}, len(resp.Msg.Sandboxes)) for _, sb := range resp.Msg.Sandboxes { alive[sb.SandboxId] = struct{}{} } autoPaused := make(map[string]struct{}, len(resp.Msg.AutoPausedSandboxIds)) - for _, id := range resp.Msg.AutoPausedSandboxIds { - autoPaused[id] = struct{}{} + for _, apID := range resp.Msg.AutoPausedSandboxIds { + autoPaused[apID] = struct{}{} } // --- Restore sandboxes that are "missing" in DB but alive on host --- @@ -134,30 +137,31 @@ func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) { Column2: []string{"missing"}, }) if err != nil { - slog.Warn("host monitor: failed to list missing sandboxes", "host_id", host.ID, "error", err) + slog.Warn("host monitor: failed to list missing sandboxes", "host_id", id.FormatHostID(host.ID), "error", err) } else { - var toRestore []string - var toStop []string + var toRestore []pgtype.UUID + var toStop []pgtype.UUID for _, sb := range missingSandboxes { - if _, ok := alive[sb.ID]; ok { + sbIDStr := id.FormatSandboxID(sb.ID) + if _, ok := alive[sbIDStr]; ok { toRestore = append(toRestore, sb.ID) } else { toStop = append(toStop, sb.ID) } } if len(toRestore) > 0 { - slog.Info("host monitor: restoring missing sandboxes", "host_id", host.ID, "count", len(toRestore)) + slog.Info("host monitor: restoring missing sandboxes", "host_id", id.FormatHostID(host.ID), "count", len(toRestore)) if err := m.db.BulkRestoreRunning(ctx, toRestore); err != nil { - slog.Warn("host monitor: failed to restore missing sandboxes", "host_id", host.ID, "error", err) + slog.Warn("host monitor: failed to restore missing sandboxes", "host_id", id.FormatHostID(host.ID), "error", err) } } if len(toStop) > 0 { - slog.Info("host monitor: stopping confirmed-dead missing sandboxes", "host_id", host.ID, "count", len(toStop)) + slog.Info("host monitor: stopping confirmed-dead missing sandboxes", "host_id", id.FormatHostID(host.ID), "count", len(toStop)) if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{ Column1: toStop, Status: "stopped", }); err != nil { - slog.Warn("host monitor: failed to stop missing sandboxes", "host_id", host.ID, "error", err) + slog.Warn("host monitor: failed to stop missing sandboxes", "host_id", id.FormatHostID(host.ID), "error", err) } } } @@ -169,18 +173,19 @@ func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) { Column2: []string{"running"}, }) if err != nil { - slog.Warn("host monitor: failed to list running sandboxes", "host_id", host.ID, "error", err) + slog.Warn("host monitor: failed to list running sandboxes", "host_id", id.FormatHostID(host.ID), "error", err) return } - var toPause, toStop []string - sbTeamID := make(map[string]string, len(runningSandboxes)) + var toPause, toStop []pgtype.UUID + sbTeamID := make(map[pgtype.UUID]pgtype.UUID, len(runningSandboxes)) for _, sb := range runningSandboxes { + sbIDStr := id.FormatSandboxID(sb.ID) sbTeamID[sb.ID] = sb.TeamID - if _, ok := alive[sb.ID]; ok { + if _, ok := alive[sbIDStr]; ok { continue } - if _, ok := autoPaused[sb.ID]; ok { + if _, ok := autoPaused[sbIDStr]; ok { toPause = append(toPause, sb.ID) } else { toStop = append(toStop, sb.ID) @@ -188,24 +193,24 @@ func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) { } if len(toPause) > 0 { - slog.Info("host monitor: marking auto-paused sandboxes", "host_id", host.ID, "count", len(toPause)) + slog.Info("host monitor: marking auto-paused sandboxes", "host_id", id.FormatHostID(host.ID), "count", len(toPause)) if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{ Column1: toPause, Status: "paused", }); err != nil { - slog.Warn("host monitor: failed to mark paused", "host_id", host.ID, "error", err) + slog.Warn("host monitor: failed to mark paused", "host_id", id.FormatHostID(host.ID), "error", err) } for _, sbID := range toPause { m.audit.LogSandboxAutoPause(ctx, sbTeamID[sbID], sbID) } } if len(toStop) > 0 { - slog.Info("host monitor: marking orphaned sandboxes stopped", "host_id", host.ID, "count", len(toStop)) + slog.Info("host monitor: marking orphaned sandboxes stopped", "host_id", id.FormatHostID(host.ID), "count", len(toStop)) if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{ Column1: toStop, Status: "stopped", }); err != nil { - slog.Warn("host monitor: failed to mark stopped", "host_id", host.ID, "error", err) + slog.Warn("host monitor: failed to mark stopped", "host_id", id.FormatHostID(host.ID), "error", err) } } } diff --git a/internal/api/middleware_auth.go b/internal/api/middleware_auth.go index dee4240f..985b2899 100644 --- a/internal/api/middleware_auth.go +++ b/internal/api/middleware_auth.go @@ -7,6 +7,7 @@ import ( "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/id" ) // requireAPIKeyOrJWT accepts either X-API-Key header or Authorization: Bearer JWT. @@ -24,7 +25,7 @@ func requireAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler } if err := queries.UpdateAPIKeyLastUsed(r.Context(), row.ID); err != nil { - slog.Warn("failed to update api key last_used", "key_id", row.ID, "error", err) + slog.Warn("failed to update api key last_used", "key_id", id.FormatAPIKeyID(row.ID), "error", err) } ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{ @@ -45,9 +46,20 @@ func requireAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler return } + teamID, err := id.ParseTeamID(claims.TeamID) + if err != nil { + writeError(w, http.StatusUnauthorized, "unauthorized", "invalid team ID in token") + return + } + userID, err := id.ParseUserID(claims.Subject) + if err != nil { + writeError(w, http.StatusUnauthorized, "unauthorized", "invalid user ID in token") + return + } + ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{ - TeamID: claims.TeamID, - UserID: claims.Subject, + TeamID: teamID, + UserID: userID, Email: claims.Email, Name: claims.Name, Role: claims.Role, diff --git a/internal/api/middleware_hosttoken.go b/internal/api/middleware_hosttoken.go index a5c5e6f1..b926e41f 100644 --- a/internal/api/middleware_hosttoken.go +++ b/internal/api/middleware_hosttoken.go @@ -4,6 +4,7 @@ import ( "net/http" "git.omukk.dev/wrenn/sandbox/internal/auth" + "git.omukk.dev/wrenn/sandbox/internal/id" ) // requireHostToken validates the X-Host-Token header containing a host JWT, @@ -23,7 +24,13 @@ func requireHostToken(secret []byte) func(http.Handler) http.Handler { return } - ctx := auth.WithHostContext(r.Context(), auth.HostContext{HostID: claims.HostID}) + hostID, err := id.ParseHostID(claims.HostID) + if err != nil { + writeError(w, http.StatusUnauthorized, "unauthorized", "invalid host ID in token") + return + } + + ctx := auth.WithHostContext(r.Context(), auth.HostContext{HostID: hostID}) next.ServeHTTP(w, r.WithContext(ctx)) }) } diff --git a/internal/api/middleware_jwt.go b/internal/api/middleware_jwt.go index 96b1c68a..37215388 100644 --- a/internal/api/middleware_jwt.go +++ b/internal/api/middleware_jwt.go @@ -5,6 +5,7 @@ import ( "strings" "git.omukk.dev/wrenn/sandbox/internal/auth" + "git.omukk.dev/wrenn/sandbox/internal/id" ) // requireJWT validates the Authorization: Bearer header, verifies the JWT @@ -25,9 +26,20 @@ func requireJWT(secret []byte) func(http.Handler) http.Handler { return } + teamID, err := id.ParseTeamID(claims.TeamID) + if err != nil { + writeError(w, http.StatusUnauthorized, "unauthorized", "invalid team ID in token") + return + } + userID, err := id.ParseUserID(claims.Subject) + if err != nil { + writeError(w, http.StatusUnauthorized, "unauthorized", "invalid user ID in token") + return + } + ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{ - TeamID: claims.TeamID, - UserID: claims.Subject, + TeamID: teamID, + UserID: userID, Email: claims.Email, Name: claims.Name, Role: claims.Role, diff --git a/internal/audit/logger.go b/internal/audit/logger.go index 8f44059f..e60d8a6c 100644 --- a/internal/audit/logger.go +++ b/internal/audit/logger.go @@ -25,18 +25,15 @@ func New(queries *db.Queries) *AuditLogger { } // actorFields extracts actor_type, actor_id, and actor_name from an AuthContext. -func actorFields(ac auth.AuthContext) (actorType string, actorID pgtype.Text, actorName pgtype.Text) { - if ac.UserID != "" { - return "user", - pgtype.Text{String: ac.UserID, Valid: true}, - pgtype.Text{String: ac.Name, Valid: ac.Name != ""} +// actor_id is stored as a prefixed string in the TEXT column. +func actorFields(ac auth.AuthContext) (actorType, actorID, actorName string) { + if ac.UserID.Valid { + return "user", id.FormatUserID(ac.UserID), ac.Name } - if ac.APIKeyID != "" { - return "api_key", - pgtype.Text{String: ac.APIKeyID, Valid: true}, - pgtype.Text{String: ac.APIKeyName, Valid: true} + if ac.APIKeyID.Valid { + return "api_key", id.FormatAPIKeyID(ac.APIKeyID), ac.APIKeyName } - return "system", pgtype.Text{}, pgtype.Text{} + return "system", "", "" } func (l *AuditLogger) write(ctx context.Context, p db.InsertAuditLogParams) { @@ -44,7 +41,6 @@ func (l *AuditLogger) write(ctx context.Context, p db.InsertAuditLogParams) { slog.Warn("audit: failed to write log entry", "action", p.Action, "resource_type", p.ResourceType, - "team_id", p.TeamID, "error", err, ) } @@ -61,18 +57,26 @@ func marshalMeta(meta map[string]any) []byte { return b } +// optText returns a valid pgtype.Text if s is non-empty, otherwise an invalid (NULL) one. +func optText(s string) pgtype.Text { + if s == "" { + return pgtype.Text{} + } + return pgtype.Text{String: s, Valid: true} +} + // --- Sandbox events (scope: team) --- -func (l *AuditLogger) LogSandboxCreate(ctx context.Context, ac auth.AuthContext, sandboxID, template string) { +func (l *AuditLogger) LogSandboxCreate(ctx context.Context, ac auth.AuthContext, sandboxID pgtype.UUID, template string) { actorType, actorID, actorName := actorFields(ac) l.write(ctx, db.InsertAuditLogParams{ ID: id.NewAuditLogID(), TeamID: ac.TeamID, ActorType: actorType, - ActorID: actorID, + ActorID: optText(actorID), ActorName: actorName, ResourceType: "sandbox", - ResourceID: pgtype.Text{String: sandboxID, Valid: true}, + ResourceID: optText(id.FormatSandboxID(sandboxID)), Action: "create", Scope: "team", Status: "success", @@ -80,16 +84,16 @@ func (l *AuditLogger) LogSandboxCreate(ctx context.Context, ac auth.AuthContext, }) } -func (l *AuditLogger) LogSandboxPause(ctx context.Context, ac auth.AuthContext, sandboxID string) { +func (l *AuditLogger) LogSandboxPause(ctx context.Context, ac auth.AuthContext, sandboxID pgtype.UUID) { actorType, actorID, actorName := actorFields(ac) l.write(ctx, db.InsertAuditLogParams{ ID: id.NewAuditLogID(), TeamID: ac.TeamID, ActorType: actorType, - ActorID: actorID, + ActorID: optText(actorID), ActorName: actorName, ResourceType: "sandbox", - ResourceID: pgtype.Text{String: sandboxID, Valid: true}, + ResourceID: optText(id.FormatSandboxID(sandboxID)), Action: "pause", Scope: "team", Status: "success", @@ -98,15 +102,15 @@ func (l *AuditLogger) LogSandboxPause(ctx context.Context, ac auth.AuthContext, } // LogSandboxAutoPause records a system-initiated auto-pause (TTL or host reconciler). -func (l *AuditLogger) LogSandboxAutoPause(ctx context.Context, teamID, sandboxID string) { +func (l *AuditLogger) LogSandboxAutoPause(ctx context.Context, teamID, sandboxID pgtype.UUID) { l.write(ctx, db.InsertAuditLogParams{ ID: id.NewAuditLogID(), TeamID: teamID, ActorType: "system", ActorID: pgtype.Text{}, - ActorName: pgtype.Text{}, + ActorName: "", ResourceType: "sandbox", - ResourceID: pgtype.Text{String: sandboxID, Valid: true}, + ResourceID: optText(id.FormatSandboxID(sandboxID)), Action: "pause", Scope: "team", Status: "info", @@ -114,16 +118,16 @@ func (l *AuditLogger) LogSandboxAutoPause(ctx context.Context, teamID, sandboxID }) } -func (l *AuditLogger) LogSandboxResume(ctx context.Context, ac auth.AuthContext, sandboxID string) { +func (l *AuditLogger) LogSandboxResume(ctx context.Context, ac auth.AuthContext, sandboxID pgtype.UUID) { actorType, actorID, actorName := actorFields(ac) l.write(ctx, db.InsertAuditLogParams{ ID: id.NewAuditLogID(), TeamID: ac.TeamID, ActorType: actorType, - ActorID: actorID, + ActorID: optText(actorID), ActorName: actorName, ResourceType: "sandbox", - ResourceID: pgtype.Text{String: sandboxID, Valid: true}, + ResourceID: optText(id.FormatSandboxID(sandboxID)), Action: "resume", Scope: "team", Status: "success", @@ -131,16 +135,16 @@ func (l *AuditLogger) LogSandboxResume(ctx context.Context, ac auth.AuthContext, }) } -func (l *AuditLogger) LogSandboxDestroy(ctx context.Context, ac auth.AuthContext, sandboxID string) { +func (l *AuditLogger) LogSandboxDestroy(ctx context.Context, ac auth.AuthContext, sandboxID pgtype.UUID) { actorType, actorID, actorName := actorFields(ac) l.write(ctx, db.InsertAuditLogParams{ ID: id.NewAuditLogID(), TeamID: ac.TeamID, ActorType: actorType, - ActorID: actorID, + ActorID: optText(actorID), ActorName: actorName, ResourceType: "sandbox", - ResourceID: pgtype.Text{String: sandboxID, Valid: true}, + ResourceID: optText(id.FormatSandboxID(sandboxID)), Action: "destroy", Scope: "team", Status: "warning", @@ -156,10 +160,10 @@ func (l *AuditLogger) LogSnapshotCreate(ctx context.Context, ac auth.AuthContext ID: id.NewAuditLogID(), TeamID: ac.TeamID, ActorType: actorType, - ActorID: actorID, + ActorID: optText(actorID), ActorName: actorName, ResourceType: "snapshot", - ResourceID: pgtype.Text{String: name, Valid: true}, + ResourceID: optText(name), Action: "create", Scope: "team", Status: "success", @@ -173,10 +177,10 @@ func (l *AuditLogger) LogSnapshotDelete(ctx context.Context, ac auth.AuthContext ID: id.NewAuditLogID(), TeamID: ac.TeamID, ActorType: actorType, - ActorID: actorID, + ActorID: optText(actorID), ActorName: actorName, ResourceType: "snapshot", - ResourceID: pgtype.Text{String: name, Valid: true}, + ResourceID: optText(name), Action: "delete", Scope: "team", Status: "warning", @@ -186,16 +190,16 @@ func (l *AuditLogger) LogSnapshotDelete(ctx context.Context, ac auth.AuthContext // --- Team events (scope: team) --- -func (l *AuditLogger) LogTeamRename(ctx context.Context, ac auth.AuthContext, teamID, oldName, newName string) { +func (l *AuditLogger) LogTeamRename(ctx context.Context, ac auth.AuthContext, teamID pgtype.UUID, oldName, newName string) { actorType, actorID, actorName := actorFields(ac) l.write(ctx, db.InsertAuditLogParams{ ID: id.NewAuditLogID(), TeamID: ac.TeamID, ActorType: actorType, - ActorID: actorID, + ActorID: optText(actorID), ActorName: actorName, ResourceType: "team", - ResourceID: pgtype.Text{String: teamID, Valid: true}, + ResourceID: optText(id.FormatTeamID(teamID)), Action: "rename", Scope: "team", Status: "info", @@ -205,16 +209,16 @@ func (l *AuditLogger) LogTeamRename(ctx context.Context, ac auth.AuthContext, te // --- API key events (scope: team) --- -func (l *AuditLogger) LogAPIKeyCreate(ctx context.Context, ac auth.AuthContext, keyID, keyName string) { +func (l *AuditLogger) LogAPIKeyCreate(ctx context.Context, ac auth.AuthContext, keyID pgtype.UUID, keyName string) { actorType, actorID, actorName := actorFields(ac) l.write(ctx, db.InsertAuditLogParams{ ID: id.NewAuditLogID(), TeamID: ac.TeamID, ActorType: actorType, - ActorID: actorID, + ActorID: optText(actorID), ActorName: actorName, ResourceType: "api_key", - ResourceID: pgtype.Text{String: keyID, Valid: true}, + ResourceID: optText(id.FormatAPIKeyID(keyID)), Action: "create", Scope: "team", Status: "success", @@ -222,16 +226,16 @@ func (l *AuditLogger) LogAPIKeyCreate(ctx context.Context, ac auth.AuthContext, }) } -func (l *AuditLogger) LogAPIKeyRevoke(ctx context.Context, ac auth.AuthContext, keyID string) { +func (l *AuditLogger) LogAPIKeyRevoke(ctx context.Context, ac auth.AuthContext, keyID pgtype.UUID) { actorType, actorID, actorName := actorFields(ac) l.write(ctx, db.InsertAuditLogParams{ ID: id.NewAuditLogID(), TeamID: ac.TeamID, ActorType: actorType, - ActorID: actorID, + ActorID: optText(actorID), ActorName: actorName, ResourceType: "api_key", - ResourceID: pgtype.Text{String: keyID, Valid: true}, + ResourceID: optText(id.FormatAPIKeyID(keyID)), Action: "revoke", Scope: "team", Status: "warning", @@ -241,16 +245,16 @@ func (l *AuditLogger) LogAPIKeyRevoke(ctx context.Context, ac auth.AuthContext, // --- Member events (scope: admin) --- -func (l *AuditLogger) LogMemberAdd(ctx context.Context, ac auth.AuthContext, targetUserID, targetEmail, role string) { +func (l *AuditLogger) LogMemberAdd(ctx context.Context, ac auth.AuthContext, targetUserID pgtype.UUID, targetEmail, role string) { actorType, actorID, actorName := actorFields(ac) l.write(ctx, db.InsertAuditLogParams{ ID: id.NewAuditLogID(), TeamID: ac.TeamID, ActorType: actorType, - ActorID: actorID, + ActorID: optText(actorID), ActorName: actorName, ResourceType: "member", - ResourceID: pgtype.Text{String: targetUserID, Valid: true}, + ResourceID: optText(id.FormatUserID(targetUserID)), Action: "add", Scope: "admin", Status: "success", @@ -258,16 +262,16 @@ func (l *AuditLogger) LogMemberAdd(ctx context.Context, ac auth.AuthContext, tar }) } -func (l *AuditLogger) LogMemberRemove(ctx context.Context, ac auth.AuthContext, targetUserID string) { +func (l *AuditLogger) LogMemberRemove(ctx context.Context, ac auth.AuthContext, targetUserID pgtype.UUID) { actorType, actorID, actorName := actorFields(ac) l.write(ctx, db.InsertAuditLogParams{ ID: id.NewAuditLogID(), TeamID: ac.TeamID, ActorType: actorType, - ActorID: actorID, + ActorID: optText(actorID), ActorName: actorName, ResourceType: "member", - ResourceID: pgtype.Text{String: targetUserID, Valid: true}, + ResourceID: optText(id.FormatUserID(targetUserID)), Action: "remove", Scope: "admin", Status: "warning", @@ -277,14 +281,18 @@ func (l *AuditLogger) LogMemberRemove(ctx context.Context, ac auth.AuthContext, func (l *AuditLogger) LogMemberLeave(ctx context.Context, ac auth.AuthContext) { actorType, actorID, actorName := actorFields(ac) + resourceID := "" + if ac.UserID.Valid { + resourceID = id.FormatUserID(ac.UserID) + } l.write(ctx, db.InsertAuditLogParams{ ID: id.NewAuditLogID(), TeamID: ac.TeamID, ActorType: actorType, - ActorID: actorID, + ActorID: optText(actorID), ActorName: actorName, ResourceType: "member", - ResourceID: pgtype.Text{String: ac.UserID, Valid: ac.UserID != ""}, + ResourceID: optText(resourceID), Action: "leave", Scope: "admin", Status: "info", @@ -292,16 +300,16 @@ func (l *AuditLogger) LogMemberLeave(ctx context.Context, ac auth.AuthContext) { }) } -func (l *AuditLogger) LogMemberRoleUpdate(ctx context.Context, ac auth.AuthContext, targetUserID, newRole string) { +func (l *AuditLogger) LogMemberRoleUpdate(ctx context.Context, ac auth.AuthContext, targetUserID pgtype.UUID, newRole string) { actorType, actorID, actorName := actorFields(ac) l.write(ctx, db.InsertAuditLogParams{ ID: id.NewAuditLogID(), TeamID: ac.TeamID, ActorType: actorType, - ActorID: actorID, + ActorID: optText(actorID), ActorName: actorName, ResourceType: "member", - ResourceID: pgtype.Text{String: targetUserID, Valid: true}, + ResourceID: optText(id.FormatUserID(targetUserID)), Action: "role_update", Scope: "admin", Status: "info", @@ -311,24 +319,24 @@ func (l *AuditLogger) LogMemberRoleUpdate(ctx context.Context, ac auth.AuthConte // --- Host events (scope: admin) --- -func (l *AuditLogger) LogHostCreate(ctx context.Context, ac auth.AuthContext, hostID, teamID string) { +func (l *AuditLogger) LogHostCreate(ctx context.Context, ac auth.AuthContext, hostID, teamID pgtype.UUID) { actorType, actorID, actorName := actorFields(ac) // For shared hosts with no owning team, use the caller's team. logTeamID := teamID - if logTeamID == "" { + if !logTeamID.Valid { logTeamID = ac.TeamID } - if logTeamID == "" { + if !logTeamID.Valid { return } l.write(ctx, db.InsertAuditLogParams{ ID: id.NewAuditLogID(), TeamID: logTeamID, ActorType: actorType, - ActorID: actorID, + ActorID: optText(actorID), ActorName: actorName, ResourceType: "host", - ResourceID: pgtype.Text{String: hostID, Valid: true}, + ResourceID: optText(id.FormatHostID(hostID)), Action: "create", Scope: "admin", Status: "success", @@ -336,23 +344,23 @@ func (l *AuditLogger) LogHostCreate(ctx context.Context, ac auth.AuthContext, ho }) } -func (l *AuditLogger) LogHostDelete(ctx context.Context, ac auth.AuthContext, hostID, teamID string) { +func (l *AuditLogger) LogHostDelete(ctx context.Context, ac auth.AuthContext, hostID, teamID pgtype.UUID) { actorType, actorID, actorName := actorFields(ac) logTeamID := teamID - if logTeamID == "" { + if !logTeamID.Valid { logTeamID = ac.TeamID } - if logTeamID == "" { + if !logTeamID.Valid { return } l.write(ctx, db.InsertAuditLogParams{ ID: id.NewAuditLogID(), TeamID: logTeamID, ActorType: actorType, - ActorID: actorID, + ActorID: optText(actorID), ActorName: actorName, ResourceType: "host", - ResourceID: pgtype.Text{String: hostID, Valid: true}, + ResourceID: optText(id.FormatHostID(hostID)), Action: "delete", Scope: "admin", Status: "warning", @@ -361,9 +369,8 @@ func (l *AuditLogger) LogHostDelete(ctx context.Context, ac auth.AuthContext, ho } // LogHostMarkedDown records a system-initiated host status transition to unreachable. -// teamID must be non-empty (BYOC hosts only); shared hosts are not logged. -func (l *AuditLogger) LogHostMarkedDown(ctx context.Context, teamID, hostID string) { - if teamID == "" { +func (l *AuditLogger) LogHostMarkedDown(ctx context.Context, teamID, hostID pgtype.UUID) { + if !teamID.Valid { return } l.write(ctx, db.InsertAuditLogParams{ @@ -371,9 +378,9 @@ func (l *AuditLogger) LogHostMarkedDown(ctx context.Context, teamID, hostID stri TeamID: teamID, ActorType: "system", ActorID: pgtype.Text{}, - ActorName: pgtype.Text{}, + ActorName: "", ResourceType: "host", - ResourceID: pgtype.Text{String: hostID, Valid: true}, + ResourceID: optText(id.FormatHostID(hostID)), Action: "marked_down", Scope: "admin", Status: "error", @@ -382,9 +389,8 @@ func (l *AuditLogger) LogHostMarkedDown(ctx context.Context, teamID, hostID stri } // LogHostMarkedUp records a system-initiated host status transition back to online. -// teamID must be non-empty (BYOC hosts only); shared hosts are not logged. -func (l *AuditLogger) LogHostMarkedUp(ctx context.Context, teamID, hostID string) { - if teamID == "" { +func (l *AuditLogger) LogHostMarkedUp(ctx context.Context, teamID, hostID pgtype.UUID) { + if !teamID.Valid { return } l.write(ctx, db.InsertAuditLogParams{ @@ -392,9 +398,9 @@ func (l *AuditLogger) LogHostMarkedUp(ctx context.Context, teamID, hostID string TeamID: teamID, ActorType: "system", ActorID: pgtype.Text{}, - ActorName: pgtype.Text{}, + ActorName: "", ResourceType: "host", - ResourceID: pgtype.Text{String: hostID, Valid: true}, + ResourceID: optText(id.FormatHostID(hostID)), Action: "marked_up", Scope: "admin", Status: "success", diff --git a/internal/auth/context.go b/internal/auth/context.go index 22bf7957..762227cf 100644 --- a/internal/auth/context.go +++ b/internal/auth/context.go @@ -1,6 +1,10 @@ package auth -import "context" +import ( + "context" + + "github.com/jackc/pgx/v5/pgtype" +) type contextKey int @@ -8,14 +12,14 @@ const authCtxKey contextKey = 0 // AuthContext is stamped into request context by auth middleware. type AuthContext struct { - TeamID string - UserID string // empty when authenticated via API key - Email string // empty when authenticated via API key - Name string // empty when authenticated via API key - Role string // owner, admin, or member; empty when authenticated via API key - IsAdmin bool // platform-level admin; always false when authenticated via API key - APIKeyID string // populated when authenticated via API key; empty for JWT auth - APIKeyName string // display name of the key, snapshotted at auth time; empty for JWT auth + TeamID pgtype.UUID + UserID pgtype.UUID // zero value (Valid=false) when authenticated via API key + Email string // empty when authenticated via API key + Name string // empty when authenticated via API key + Role string // owner, admin, or member; empty when authenticated via API key + IsAdmin bool // platform-level admin; always false when authenticated via API key + APIKeyID pgtype.UUID // populated when authenticated via API key; zero value for JWT auth + APIKeyName string // display name of the key, snapshotted at auth time; empty for JWT auth } // WithAuthContext returns a new context with the given AuthContext. @@ -43,7 +47,7 @@ const hostCtxKey contextKey = 1 // HostContext is stamped into request context by host token middleware. type HostContext struct { - HostID string + HostID pgtype.UUID } // WithHostContext returns a new context with the given HostContext. diff --git a/internal/auth/jwt.go b/internal/auth/jwt.go index fd1bc02d..840cd3bb 100644 --- a/internal/auth/jwt.go +++ b/internal/auth/jwt.go @@ -5,6 +5,9 @@ import ( "time" "github.com/golang-jwt/jwt/v5" + "github.com/jackc/pgx/v5/pgtype" + + "git.omukk.dev/wrenn/sandbox/internal/id" ) const jwtExpiry = 6 * time.Hour @@ -23,16 +26,16 @@ type Claims struct { } // SignJWT signs a new 6-hour JWT for the given user. -func SignJWT(secret []byte, userID, teamID, email, name, role string, isAdmin bool) (string, error) { +func SignJWT(secret []byte, userID, teamID pgtype.UUID, email, name, role string, isAdmin bool) (string, error) { now := time.Now() claims := Claims{ - TeamID: teamID, + TeamID: id.FormatTeamID(teamID), Role: role, Email: email, Name: name, IsAdmin: isAdmin, RegisteredClaims: jwt.RegisteredClaims{ - Subject: userID, + Subject: id.FormatUserID(userID), IssuedAt: jwt.NewNumericDate(now), ExpiresAt: jwt.NewNumericDate(now.Add(jwtExpiry)), }, @@ -70,14 +73,15 @@ type HostClaims struct { jwt.RegisteredClaims } -// SignHostJWT signs a long-lived (1 year) JWT for a registered host agent. -func SignHostJWT(secret []byte, hostID string) (string, error) { +// SignHostJWT signs a long-lived (7-day) JWT for a registered host agent. +func SignHostJWT(secret []byte, hostID pgtype.UUID) (string, error) { + formatted := id.FormatHostID(hostID) now := time.Now() claims := HostClaims{ Type: "host", - HostID: hostID, + HostID: formatted, RegisteredClaims: jwt.RegisteredClaims{ - Subject: hostID, + Subject: formatted, IssuedAt: jwt.NewNumericDate(now), ExpiresAt: jwt.NewNumericDate(now.Add(hostJWTExpiry)), }, diff --git a/internal/db/api_keys.sql.go b/internal/db/api_keys.sql.go index b4f0ffcc..4b8d3699 100644 --- a/internal/db/api_keys.sql.go +++ b/internal/db/api_keys.sql.go @@ -16,8 +16,8 @@ DELETE FROM team_api_keys WHERE id = $1 AND team_id = $2 ` type DeleteAPIKeyParams struct { - ID string `json:"id"` - TeamID string `json:"team_id"` + ID pgtype.UUID `json:"id"` + TeamID pgtype.UUID `json:"team_id"` } func (q *Queries) DeleteAPIKey(ctx context.Context, arg DeleteAPIKeyParams) error { @@ -52,12 +52,12 @@ RETURNING id, team_id, name, key_hash, key_prefix, created_by, created_at, last_ ` type InsertAPIKeyParams struct { - ID string `json:"id"` - TeamID string `json:"team_id"` - Name string `json:"name"` - KeyHash string `json:"key_hash"` - KeyPrefix string `json:"key_prefix"` - CreatedBy string `json:"created_by"` + ID pgtype.UUID `json:"id"` + TeamID pgtype.UUID `json:"team_id"` + Name string `json:"name"` + KeyHash string `json:"key_hash"` + KeyPrefix string `json:"key_prefix"` + CreatedBy pgtype.UUID `json:"created_by"` } func (q *Queries) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (TeamApiKey, error) { @@ -87,7 +87,7 @@ const listAPIKeysByTeam = `-- name: ListAPIKeysByTeam :many SELECT id, team_id, name, key_hash, key_prefix, created_by, created_at, last_used FROM team_api_keys WHERE team_id = $1 ORDER BY created_at DESC ` -func (q *Queries) ListAPIKeysByTeam(ctx context.Context, teamID string) ([]TeamApiKey, error) { +func (q *Queries) ListAPIKeysByTeam(ctx context.Context, teamID pgtype.UUID) ([]TeamApiKey, error) { rows, err := q.db.Query(ctx, listAPIKeysByTeam, teamID) if err != nil { return nil, err @@ -126,18 +126,18 @@ ORDER BY k.created_at DESC ` type ListAPIKeysByTeamWithCreatorRow struct { - ID string `json:"id"` - TeamID string `json:"team_id"` + ID pgtype.UUID `json:"id"` + TeamID pgtype.UUID `json:"team_id"` Name string `json:"name"` KeyHash string `json:"key_hash"` KeyPrefix string `json:"key_prefix"` - CreatedBy string `json:"created_by"` + CreatedBy pgtype.UUID `json:"created_by"` CreatedAt pgtype.Timestamptz `json:"created_at"` LastUsed pgtype.Timestamptz `json:"last_used"` CreatorEmail string `json:"creator_email"` } -func (q *Queries) ListAPIKeysByTeamWithCreator(ctx context.Context, teamID string) ([]ListAPIKeysByTeamWithCreatorRow, error) { +func (q *Queries) ListAPIKeysByTeamWithCreator(ctx context.Context, teamID pgtype.UUID) ([]ListAPIKeysByTeamWithCreatorRow, error) { rows, err := q.db.Query(ctx, listAPIKeysByTeamWithCreator, teamID) if err != nil { return nil, err @@ -171,7 +171,7 @@ const updateAPIKeyLastUsed = `-- name: UpdateAPIKeyLastUsed :exec UPDATE team_api_keys SET last_used = NOW() WHERE id = $1 ` -func (q *Queries) UpdateAPIKeyLastUsed(ctx context.Context, id string) error { +func (q *Queries) UpdateAPIKeyLastUsed(ctx context.Context, id pgtype.UUID) error { _, err := q.db.Exec(ctx, updateAPIKeyLastUsed, id) return err } diff --git a/internal/db/audit.sql.go b/internal/db/audit.sql.go index 9370eca9..69b2b8ce 100644 --- a/internal/db/audit.sql.go +++ b/internal/db/audit.sql.go @@ -17,11 +17,11 @@ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) ` type InsertAuditLogParams struct { - ID string `json:"id"` - TeamID string `json:"team_id"` + ID pgtype.UUID `json:"id"` + TeamID pgtype.UUID `json:"team_id"` ActorType string `json:"actor_type"` ActorID pgtype.Text `json:"actor_id"` - ActorName pgtype.Text `json:"actor_name"` + ActorName string `json:"actor_name"` ResourceType string `json:"resource_type"` ResourceID pgtype.Text `json:"resource_id"` Action string `json:"action"` @@ -60,12 +60,12 @@ LIMIT $7 ` type ListAuditLogsParams struct { - TeamID string `json:"team_id"` + TeamID pgtype.UUID `json:"team_id"` Column2 []string `json:"column_2"` Column3 []string `json:"column_3"` Column4 []string `json:"column_4"` Column5 pgtype.Timestamptz `json:"column_5"` - ID string `json:"id"` + ID pgtype.UUID `json:"id"` Limit int32 `json:"limit"` } diff --git a/internal/db/host_refresh_tokens.sql.go b/internal/db/host_refresh_tokens.sql.go index d02a0e7a..0ec162d9 100644 --- a/internal/db/host_refresh_tokens.sql.go +++ b/internal/db/host_refresh_tokens.sql.go @@ -47,8 +47,8 @@ RETURNING id, host_id, token_hash, expires_at, created_at, revoked_at ` type InsertHostRefreshTokenParams struct { - ID string `json:"id"` - HostID string `json:"host_id"` + ID pgtype.UUID `json:"id"` + HostID pgtype.UUID `json:"host_id"` TokenHash string `json:"token_hash"` ExpiresAt pgtype.Timestamptz `json:"expires_at"` } @@ -76,7 +76,7 @@ const revokeHostRefreshToken = `-- name: RevokeHostRefreshToken :exec UPDATE host_refresh_tokens SET revoked_at = NOW() WHERE id = $1 ` -func (q *Queries) RevokeHostRefreshToken(ctx context.Context, id string) error { +func (q *Queries) RevokeHostRefreshToken(ctx context.Context, id pgtype.UUID) error { _, err := q.db.Exec(ctx, revokeHostRefreshToken, id) return err } @@ -86,7 +86,7 @@ UPDATE host_refresh_tokens SET revoked_at = NOW() WHERE host_id = $1 AND revoked_at IS NULL ` -func (q *Queries) RevokeHostRefreshTokensByHost(ctx context.Context, hostID string) error { +func (q *Queries) RevokeHostRefreshTokensByHost(ctx context.Context, hostID pgtype.UUID) error { _, err := q.db.Exec(ctx, revokeHostRefreshTokensByHost, hostID) return err } diff --git a/internal/db/hosts.sql.go b/internal/db/hosts.sql.go index 2d7b8e0c..8bfd8d32 100644 --- a/internal/db/hosts.sql.go +++ b/internal/db/hosts.sql.go @@ -16,8 +16,8 @@ INSERT INTO host_tags (host_id, tag) VALUES ($1, $2) ON CONFLICT DO NOTHING ` type AddHostTagParams struct { - HostID string `json:"host_id"` - Tag string `json:"tag"` + HostID pgtype.UUID `json:"host_id"` + Tag string `json:"tag"` } func (q *Queries) AddHostTag(ctx context.Context, arg AddHostTagParams) error { @@ -29,7 +29,7 @@ const deleteHost = `-- name: DeleteHost :exec DELETE FROM hosts WHERE id = $1 ` -func (q *Queries) DeleteHost(ctx context.Context, id string) error { +func (q *Queries) DeleteHost(ctx context.Context, id pgtype.UUID) error { _, err := q.db.Exec(ctx, deleteHost, id) return err } @@ -38,7 +38,7 @@ const getHost = `-- name: GetHost :one SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE id = $1 ` -func (q *Queries) GetHost(ctx context.Context, id string) (Host, error) { +func (q *Queries) GetHost(ctx context.Context, id pgtype.UUID) (Host, error) { row := q.db.QueryRow(ctx, getHost, id) var i Host err := row.Scan( @@ -69,8 +69,8 @@ SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_m ` type GetHostByTeamParams struct { - ID string `json:"id"` - TeamID pgtype.Text `json:"team_id"` + ID pgtype.UUID `json:"id"` + TeamID pgtype.UUID `json:"team_id"` } func (q *Queries) GetHostByTeam(ctx context.Context, arg GetHostByTeamParams) (Host, error) { @@ -103,7 +103,7 @@ const getHostTags = `-- name: GetHostTags :many SELECT tag FROM host_tags WHERE host_id = $1 ORDER BY tag ` -func (q *Queries) GetHostTags(ctx context.Context, hostID string) ([]string, error) { +func (q *Queries) GetHostTags(ctx context.Context, hostID pgtype.UUID) ([]string, error) { rows, err := q.db.Query(ctx, getHostTags, hostID) if err != nil { return nil, err @@ -127,7 +127,7 @@ const getHostTokensByHost = `-- name: GetHostTokensByHost :many SELECT id, host_id, created_by, created_at, expires_at, used_at FROM host_tokens WHERE host_id = $1 ORDER BY created_at DESC ` -func (q *Queries) GetHostTokensByHost(ctx context.Context, hostID string) ([]HostToken, error) { +func (q *Queries) GetHostTokensByHost(ctx context.Context, hostID pgtype.UUID) ([]HostToken, error) { rows, err := q.db.Query(ctx, getHostTokensByHost, hostID) if err != nil { return nil, err @@ -161,12 +161,12 @@ RETURNING id, type, team_id, provider, availability_zone, arch, cpu_cores, memor ` type InsertHostParams struct { - ID string `json:"id"` + ID pgtype.UUID `json:"id"` Type string `json:"type"` - TeamID pgtype.Text `json:"team_id"` - Provider pgtype.Text `json:"provider"` - AvailabilityZone pgtype.Text `json:"availability_zone"` - CreatedBy string `json:"created_by"` + TeamID pgtype.UUID `json:"team_id"` + Provider string `json:"provider"` + AvailabilityZone string `json:"availability_zone"` + CreatedBy pgtype.UUID `json:"created_by"` } func (q *Queries) InsertHost(ctx context.Context, arg InsertHostParams) (Host, error) { @@ -209,9 +209,9 @@ RETURNING id, host_id, created_by, created_at, expires_at, used_at ` type InsertHostTokenParams struct { - ID string `json:"id"` - HostID string `json:"host_id"` - CreatedBy string `json:"created_by"` + ID pgtype.UUID `json:"id"` + HostID pgtype.UUID `json:"host_id"` + CreatedBy pgtype.UUID `json:"created_by"` ExpiresAt pgtype.Timestamptz `json:"expires_at"` } @@ -414,7 +414,7 @@ const listHostsByTeam = `-- name: ListHostsByTeam :many SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE team_id = $1 AND type = 'byoc' ORDER BY created_at DESC ` -func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.Text) ([]Host, error) { +func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.UUID) ([]Host, error) { rows, err := q.db.Query(ctx, listHostsByTeam, teamID) if err != nil { return nil, err @@ -500,7 +500,7 @@ const markHostTokenUsed = `-- name: MarkHostTokenUsed :exec UPDATE host_tokens SET used_at = NOW() WHERE id = $1 ` -func (q *Queries) MarkHostTokenUsed(ctx context.Context, id string) error { +func (q *Queries) MarkHostTokenUsed(ctx context.Context, id pgtype.UUID) error { _, err := q.db.Exec(ctx, markHostTokenUsed, id) return err } @@ -509,7 +509,7 @@ const markHostUnreachable = `-- name: MarkHostUnreachable :exec UPDATE hosts SET status = 'unreachable', updated_at = NOW() WHERE id = $1 ` -func (q *Queries) MarkHostUnreachable(ctx context.Context, id string) error { +func (q *Queries) MarkHostUnreachable(ctx context.Context, id pgtype.UUID) error { _, err := q.db.Exec(ctx, markHostUnreachable, id) return err } @@ -528,12 +528,12 @@ WHERE id = $1 AND status = 'pending' ` type RegisterHostParams struct { - ID string `json:"id"` - Arch pgtype.Text `json:"arch"` - CpuCores pgtype.Int4 `json:"cpu_cores"` - MemoryMb pgtype.Int4 `json:"memory_mb"` - DiskGb pgtype.Int4 `json:"disk_gb"` - Address pgtype.Text `json:"address"` + ID pgtype.UUID `json:"id"` + Arch string `json:"arch"` + CpuCores int32 `json:"cpu_cores"` + MemoryMb int32 `json:"memory_mb"` + DiskGb int32 `json:"disk_gb"` + Address string `json:"address"` } func (q *Queries) RegisterHost(ctx context.Context, arg RegisterHostParams) (int64, error) { @@ -556,8 +556,8 @@ DELETE FROM host_tags WHERE host_id = $1 AND tag = $2 ` type RemoveHostTagParams struct { - HostID string `json:"host_id"` - Tag string `json:"tag"` + HostID pgtype.UUID `json:"host_id"` + Tag string `json:"tag"` } func (q *Queries) RemoveHostTag(ctx context.Context, arg RemoveHostTagParams) error { @@ -569,7 +569,7 @@ const updateHostHeartbeat = `-- name: UpdateHostHeartbeat :exec UPDATE hosts SET last_heartbeat_at = NOW(), updated_at = NOW() WHERE id = $1 ` -func (q *Queries) UpdateHostHeartbeat(ctx context.Context, id string) error { +func (q *Queries) UpdateHostHeartbeat(ctx context.Context, id pgtype.UUID) error { _, err := q.db.Exec(ctx, updateHostHeartbeat, id) return err } @@ -584,7 +584,7 @@ WHERE id = $1 // Updates last_heartbeat_at and transitions unreachable hosts back to online. // Returns 0 if no host was found (deleted), which the caller treats as 404. -func (q *Queries) UpdateHostHeartbeatAndStatus(ctx context.Context, id string) (int64, error) { +func (q *Queries) UpdateHostHeartbeatAndStatus(ctx context.Context, id pgtype.UUID) (int64, error) { result, err := q.db.Exec(ctx, updateHostHeartbeatAndStatus, id) if err != nil { return 0, err @@ -597,8 +597,8 @@ UPDATE hosts SET status = $2, updated_at = NOW() WHERE id = $1 ` type UpdateHostStatusParams struct { - ID string `json:"id"` - Status string `json:"status"` + ID pgtype.UUID `json:"id"` + Status string `json:"status"` } func (q *Queries) UpdateHostStatus(ctx context.Context, arg UpdateHostStatusParams) error { diff --git a/internal/db/metrics.sql.go b/internal/db/metrics.sql.go index 80501559..f522dc2d 100644 --- a/internal/db/metrics.sql.go +++ b/internal/db/metrics.sql.go @@ -7,6 +7,8 @@ package db import ( "context" + + "github.com/jackc/pgx/v5/pgtype" ) const deleteSandboxMetricPoints = `-- name: DeleteSandboxMetricPoints :exec @@ -14,7 +16,7 @@ DELETE FROM sandbox_metric_points WHERE sandbox_id = $1 ` -func (q *Queries) DeleteSandboxMetricPoints(ctx context.Context, sandboxID string) error { +func (q *Queries) DeleteSandboxMetricPoints(ctx context.Context, sandboxID pgtype.UUID) error { _, err := q.db.Exec(ctx, deleteSandboxMetricPoints, sandboxID) return err } @@ -25,8 +27,8 @@ WHERE sandbox_id = $1 AND tier = $2 ` type DeleteSandboxMetricPointsByTierParams struct { - SandboxID string `json:"sandbox_id"` - Tier string `json:"tier"` + SandboxID pgtype.UUID `json:"sandbox_id"` + Tier string `json:"tier"` } func (q *Queries) DeleteSandboxMetricPointsByTier(ctx context.Context, arg DeleteSandboxMetricPointsByTierParams) error { @@ -53,7 +55,7 @@ type GetLiveMetricsRow struct { // Reads directly from sandboxes for accurate real-time current values. // CPU reserved = running + starting only (paused VMs release CPU). // RAM reserved = running + starting + sum(ceil(each_paused/2)) (per-VM ceiling). -func (q *Queries) GetLiveMetrics(ctx context.Context, teamID string) (GetLiveMetricsRow, error) { +func (q *Queries) GetLiveMetrics(ctx context.Context, teamID pgtype.UUID) (GetLiveMetricsRow, error) { row := q.db.QueryRow(ctx, getLiveMetrics, teamID) var i GetLiveMetricsRow err := row.Scan(&i.RunningCount, &i.VcpusReserved, &i.MemoryMbReserved) @@ -76,7 +78,7 @@ type GetPeakMetricsRow struct { PeakMemoryMb int32 `json:"peak_memory_mb"` } -func (q *Queries) GetPeakMetrics(ctx context.Context, teamID string) (GetPeakMetricsRow, error) { +func (q *Queries) GetPeakMetrics(ctx context.Context, teamID pgtype.UUID) (GetPeakMetricsRow, error) { row := q.db.QueryRow(ctx, getPeakMetrics, teamID) var i GetPeakMetricsRow err := row.Scan(&i.PeakRunningCount, &i.PeakVcpus, &i.PeakMemoryMb) @@ -91,9 +93,9 @@ ORDER BY ts ASC ` type GetSandboxMetricPointsParams struct { - SandboxID string `json:"sandbox_id"` - Tier string `json:"tier"` - Ts int64 `json:"ts"` + SandboxID pgtype.UUID `json:"sandbox_id"` + Tier string `json:"tier"` + Ts int64 `json:"ts"` } type GetSandboxMetricPointsRow struct { @@ -134,10 +136,10 @@ VALUES ($1, $2, $3, $4) ` type InsertMetricsSnapshotParams struct { - TeamID string `json:"team_id"` - RunningCount int32 `json:"running_count"` - VcpusReserved int32 `json:"vcpus_reserved"` - MemoryMbReserved int32 `json:"memory_mb_reserved"` + TeamID pgtype.UUID `json:"team_id"` + RunningCount int32 `json:"running_count"` + VcpusReserved int32 `json:"vcpus_reserved"` + MemoryMbReserved int32 `json:"memory_mb_reserved"` } func (q *Queries) InsertMetricsSnapshot(ctx context.Context, arg InsertMetricsSnapshotParams) error { @@ -157,12 +159,12 @@ ON CONFLICT (sandbox_id, tier, ts) DO NOTHING ` type InsertSandboxMetricPointParams struct { - SandboxID string `json:"sandbox_id"` - Tier string `json:"tier"` - Ts int64 `json:"ts"` - CpuPct float64 `json:"cpu_pct"` - MemBytes int64 `json:"mem_bytes"` - DiskBytes int64 `json:"disk_bytes"` + SandboxID pgtype.UUID `json:"sandbox_id"` + Tier string `json:"tier"` + Ts int64 `json:"ts"` + CpuPct float64 `json:"cpu_pct"` + MemBytes int64 `json:"mem_bytes"` + DiskBytes int64 `json:"disk_bytes"` } func (q *Queries) InsertSandboxMetricPoint(ctx context.Context, arg InsertSandboxMetricPointParams) error { @@ -210,10 +212,10 @@ GROUP BY team_id ` type SampleSandboxMetricsRow struct { - TeamID string `json:"team_id"` - RunningCount int32 `json:"running_count"` - VcpusReserved int32 `json:"vcpus_reserved"` - MemoryMbReserved int32 `json:"memory_mb_reserved"` + TeamID pgtype.UUID `json:"team_id"` + RunningCount int32 `json:"running_count"` + VcpusReserved int32 `json:"vcpus_reserved"` + MemoryMbReserved int32 `json:"memory_mb_reserved"` } // Aggregates per-team resource usage from the live sandboxes table. diff --git a/internal/db/models.go b/internal/db/models.go index 74596c6c..3aa765c5 100644 --- a/internal/db/models.go +++ b/internal/db/models.go @@ -9,18 +9,18 @@ import ( ) type AdminPermission struct { - ID string `json:"id"` - UserID string `json:"user_id"` + ID pgtype.UUID `json:"id"` + UserID pgtype.UUID `json:"user_id"` Permission string `json:"permission"` CreatedAt pgtype.Timestamptz `json:"created_at"` } type AuditLog struct { - ID string `json:"id"` - TeamID string `json:"team_id"` + ID pgtype.UUID `json:"id"` + TeamID pgtype.UUID `json:"team_id"` ActorType string `json:"actor_type"` ActorID pgtype.Text `json:"actor_id"` - ActorName pgtype.Text `json:"actor_name"` + ActorName string `json:"actor_name"` ResourceType string `json:"resource_type"` ResourceID pgtype.Text `json:"resource_id"` Action string `json:"action"` @@ -31,29 +31,29 @@ type AuditLog struct { } type Host struct { - ID string `json:"id"` + ID pgtype.UUID `json:"id"` Type string `json:"type"` - TeamID pgtype.Text `json:"team_id"` - Provider pgtype.Text `json:"provider"` - AvailabilityZone pgtype.Text `json:"availability_zone"` - Arch pgtype.Text `json:"arch"` - CpuCores pgtype.Int4 `json:"cpu_cores"` - MemoryMb pgtype.Int4 `json:"memory_mb"` - DiskGb pgtype.Int4 `json:"disk_gb"` - Address pgtype.Text `json:"address"` + TeamID pgtype.UUID `json:"team_id"` + Provider string `json:"provider"` + AvailabilityZone string `json:"availability_zone"` + Arch string `json:"arch"` + CpuCores int32 `json:"cpu_cores"` + MemoryMb int32 `json:"memory_mb"` + DiskGb int32 `json:"disk_gb"` + Address string `json:"address"` Status string `json:"status"` LastHeartbeatAt pgtype.Timestamptz `json:"last_heartbeat_at"` Metadata []byte `json:"metadata"` - CreatedBy string `json:"created_by"` + CreatedBy pgtype.UUID `json:"created_by"` CreatedAt pgtype.Timestamptz `json:"created_at"` UpdatedAt pgtype.Timestamptz `json:"updated_at"` - CertFingerprint pgtype.Text `json:"cert_fingerprint"` + CertFingerprint string `json:"cert_fingerprint"` MtlsEnabled bool `json:"mtls_enabled"` } type HostRefreshToken struct { - ID string `json:"id"` - HostID string `json:"host_id"` + ID pgtype.UUID `json:"id"` + HostID pgtype.UUID `json:"host_id"` TokenHash string `json:"token_hash"` ExpiresAt pgtype.Timestamptz `json:"expires_at"` CreatedAt pgtype.Timestamptz `json:"created_at"` @@ -61,14 +61,14 @@ type HostRefreshToken struct { } type HostTag struct { - HostID string `json:"host_id"` - Tag string `json:"tag"` + HostID pgtype.UUID `json:"host_id"` + Tag string `json:"tag"` } type HostToken struct { - ID string `json:"id"` - HostID string `json:"host_id"` - CreatedBy string `json:"created_by"` + ID pgtype.UUID `json:"id"` + HostID pgtype.UUID `json:"host_id"` + CreatedBy pgtype.UUID `json:"created_by"` CreatedAt pgtype.Timestamptz `json:"created_at"` ExpiresAt pgtype.Timestamptz `json:"expires_at"` UsedAt pgtype.Timestamptz `json:"used_at"` @@ -77,14 +77,15 @@ type HostToken struct { type OauthProvider struct { Provider string `json:"provider"` ProviderID string `json:"provider_id"` - UserID string `json:"user_id"` + UserID pgtype.UUID `json:"user_id"` Email string `json:"email"` CreatedAt pgtype.Timestamptz `json:"created_at"` } type Sandbox struct { - ID string `json:"id"` - HostID string `json:"host_id"` + ID pgtype.UUID `json:"id"` + TeamID pgtype.UUID `json:"team_id"` + HostID pgtype.UUID `json:"host_id"` Template string `json:"template"` Status string `json:"status"` Vcpus int32 `json:"vcpus"` @@ -96,21 +97,20 @@ type Sandbox struct { StartedAt pgtype.Timestamptz `json:"started_at"` LastActiveAt pgtype.Timestamptz `json:"last_active_at"` LastUpdated pgtype.Timestamptz `json:"last_updated"` - TeamID string `json:"team_id"` } type SandboxMetricPoint struct { - SandboxID string `json:"sandbox_id"` - Tier string `json:"tier"` - Ts int64 `json:"ts"` - CpuPct float64 `json:"cpu_pct"` - MemBytes int64 `json:"mem_bytes"` - DiskBytes int64 `json:"disk_bytes"` + SandboxID pgtype.UUID `json:"sandbox_id"` + Tier string `json:"tier"` + Ts int64 `json:"ts"` + CpuPct float64 `json:"cpu_pct"` + MemBytes int64 `json:"mem_bytes"` + DiskBytes int64 `json:"disk_bytes"` } type SandboxMetricsSnapshot struct { ID int64 `json:"id"` - TeamID string `json:"team_id"` + TeamID pgtype.UUID `json:"team_id"` SampledAt pgtype.Timestamptz `json:"sampled_at"` RunningCount int32 `json:"running_count"` VcpusReserved int32 `json:"vcpus_reserved"` @@ -118,21 +118,21 @@ type SandboxMetricsSnapshot struct { } type Team struct { - ID string `json:"id"` + ID pgtype.UUID `json:"id"` Name string `json:"name"` - CreatedAt pgtype.Timestamptz `json:"created_at"` - IsByoc bool `json:"is_byoc"` Slug string `json:"slug"` + IsByoc bool `json:"is_byoc"` + CreatedAt pgtype.Timestamptz `json:"created_at"` DeletedAt pgtype.Timestamptz `json:"deleted_at"` } type TeamApiKey struct { - ID string `json:"id"` - TeamID string `json:"team_id"` + ID pgtype.UUID `json:"id"` + TeamID pgtype.UUID `json:"team_id"` Name string `json:"name"` KeyHash string `json:"key_hash"` KeyPrefix string `json:"key_prefix"` - CreatedBy string `json:"created_by"` + CreatedBy pgtype.UUID `json:"created_by"` CreatedAt pgtype.Timestamptz `json:"created_at"` LastUsed pgtype.Timestamptz `json:"last_used"` } @@ -140,46 +140,46 @@ type TeamApiKey struct { type Template struct { Name string `json:"name"` Type string `json:"type"` - Vcpus pgtype.Int4 `json:"vcpus"` - MemoryMb pgtype.Int4 `json:"memory_mb"` + Vcpus int32 `json:"vcpus"` + MemoryMb int32 `json:"memory_mb"` SizeBytes int64 `json:"size_bytes"` CreatedAt pgtype.Timestamptz `json:"created_at"` - TeamID string `json:"team_id"` + TeamID pgtype.UUID `json:"team_id"` } type TemplateBuild struct { - ID string `json:"id"` + ID pgtype.UUID `json:"id"` Name string `json:"name"` BaseTemplate string `json:"base_template"` Recipe []byte `json:"recipe"` - Healthcheck pgtype.Text `json:"healthcheck"` + Healthcheck string `json:"healthcheck"` Vcpus int32 `json:"vcpus"` MemoryMb int32 `json:"memory_mb"` Status string `json:"status"` CurrentStep int32 `json:"current_step"` TotalSteps int32 `json:"total_steps"` Logs []byte `json:"logs"` - Error pgtype.Text `json:"error"` - SandboxID pgtype.Text `json:"sandbox_id"` - HostID pgtype.Text `json:"host_id"` + Error string `json:"error"` + SandboxID pgtype.UUID `json:"sandbox_id"` + HostID pgtype.UUID `json:"host_id"` CreatedAt pgtype.Timestamptz `json:"created_at"` StartedAt pgtype.Timestamptz `json:"started_at"` CompletedAt pgtype.Timestamptz `json:"completed_at"` } type User struct { - ID string `json:"id"` + ID pgtype.UUID `json:"id"` Email string `json:"email"` PasswordHash pgtype.Text `json:"password_hash"` + Name string `json:"name"` + IsAdmin bool `json:"is_admin"` CreatedAt pgtype.Timestamptz `json:"created_at"` UpdatedAt pgtype.Timestamptz `json:"updated_at"` - IsAdmin bool `json:"is_admin"` - Name string `json:"name"` } type UsersTeam struct { - UserID string `json:"user_id"` - TeamID string `json:"team_id"` + UserID pgtype.UUID `json:"user_id"` + TeamID pgtype.UUID `json:"team_id"` IsDefault bool `json:"is_default"` Role string `json:"role"` CreatedAt pgtype.Timestamptz `json:"created_at"` diff --git a/internal/db/oauth.sql.go b/internal/db/oauth.sql.go index ab79eec5..0270defa 100644 --- a/internal/db/oauth.sql.go +++ b/internal/db/oauth.sql.go @@ -7,6 +7,8 @@ package db import ( "context" + + "github.com/jackc/pgx/v5/pgtype" ) const getOAuthProvider = `-- name: GetOAuthProvider :one @@ -38,10 +40,10 @@ VALUES ($1, $2, $3, $4) ` type InsertOAuthProviderParams struct { - Provider string `json:"provider"` - ProviderID string `json:"provider_id"` - UserID string `json:"user_id"` - Email string `json:"email"` + Provider string `json:"provider"` + ProviderID string `json:"provider_id"` + UserID pgtype.UUID `json:"user_id"` + Email string `json:"email"` } func (q *Queries) InsertOAuthProvider(ctx context.Context, arg InsertOAuthProviderParams) error { diff --git a/internal/db/sandboxes.sql.go b/internal/db/sandboxes.sql.go index 620f77e2..07effdf2 100644 --- a/internal/db/sandboxes.sql.go +++ b/internal/db/sandboxes.sql.go @@ -15,12 +15,12 @@ const bulkRestoreRunning = `-- name: BulkRestoreRunning :exec UPDATE sandboxes SET status = 'running', last_updated = NOW() -WHERE id = ANY($1::text[]) AND status = 'missing' +WHERE id = ANY($1::uuid[]) AND status = 'missing' ` // Called by the reconciler when a host comes back online and its sandboxes are // confirmed alive. Restores only sandboxes that are in 'missing' state. -func (q *Queries) BulkRestoreRunning(ctx context.Context, dollar_1 []string) error { +func (q *Queries) BulkRestoreRunning(ctx context.Context, dollar_1 []pgtype.UUID) error { _, err := q.db.Exec(ctx, bulkRestoreRunning, dollar_1) return err } @@ -29,12 +29,12 @@ const bulkUpdateStatusByIDs = `-- name: BulkUpdateStatusByIDs :exec UPDATE sandboxes SET status = $2, last_updated = NOW() -WHERE id = ANY($1::text[]) +WHERE id = ANY($1::uuid[]) ` type BulkUpdateStatusByIDsParams struct { - Column1 []string `json:"column_1"` - Status string `json:"status"` + Column1 []pgtype.UUID `json:"column_1"` + Status string `json:"status"` } func (q *Queries) BulkUpdateStatusByIDs(ctx context.Context, arg BulkUpdateStatusByIDsParams) error { @@ -43,14 +43,15 @@ func (q *Queries) BulkUpdateStatusByIDs(ctx context.Context, arg BulkUpdateStatu } const getSandbox = `-- name: GetSandbox :one -SELECT id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id FROM sandboxes WHERE id = $1 +SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated FROM sandboxes WHERE id = $1 ` -func (q *Queries) GetSandbox(ctx context.Context, id string) (Sandbox, error) { +func (q *Queries) GetSandbox(ctx context.Context, id pgtype.UUID) (Sandbox, error) { row := q.db.QueryRow(ctx, getSandbox, id) var i Sandbox err := row.Scan( &i.ID, + &i.TeamID, &i.HostID, &i.Template, &i.Status, @@ -63,18 +64,17 @@ func (q *Queries) GetSandbox(ctx context.Context, id string) (Sandbox, error) { &i.StartedAt, &i.LastActiveAt, &i.LastUpdated, - &i.TeamID, ) return i, err } const getSandboxByTeam = `-- name: GetSandboxByTeam :one -SELECT id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id FROM sandboxes WHERE id = $1 AND team_id = $2 +SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated FROM sandboxes WHERE id = $1 AND team_id = $2 ` type GetSandboxByTeamParams struct { - ID string `json:"id"` - TeamID string `json:"team_id"` + ID pgtype.UUID `json:"id"` + TeamID pgtype.UUID `json:"team_id"` } func (q *Queries) GetSandboxByTeam(ctx context.Context, arg GetSandboxByTeamParams) (Sandbox, error) { @@ -82,6 +82,7 @@ func (q *Queries) GetSandboxByTeam(ctx context.Context, arg GetSandboxByTeamPara var i Sandbox err := row.Scan( &i.ID, + &i.TeamID, &i.HostID, &i.Template, &i.Status, @@ -94,7 +95,6 @@ func (q *Queries) GetSandboxByTeam(ctx context.Context, arg GetSandboxByTeamPara &i.StartedAt, &i.LastActiveAt, &i.LastUpdated, - &i.TeamID, ) return i, err } @@ -102,18 +102,18 @@ func (q *Queries) GetSandboxByTeam(ctx context.Context, arg GetSandboxByTeamPara const insertSandbox = `-- name: InsertSandbox :one INSERT INTO sandboxes (id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) -RETURNING id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id +RETURNING id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated ` type InsertSandboxParams struct { - ID string `json:"id"` - TeamID string `json:"team_id"` - HostID string `json:"host_id"` - Template string `json:"template"` - Status string `json:"status"` - Vcpus int32 `json:"vcpus"` - MemoryMb int32 `json:"memory_mb"` - TimeoutSec int32 `json:"timeout_sec"` + ID pgtype.UUID `json:"id"` + TeamID pgtype.UUID `json:"team_id"` + HostID pgtype.UUID `json:"host_id"` + Template string `json:"template"` + Status string `json:"status"` + Vcpus int32 `json:"vcpus"` + MemoryMb int32 `json:"memory_mb"` + TimeoutSec int32 `json:"timeout_sec"` } func (q *Queries) InsertSandbox(ctx context.Context, arg InsertSandboxParams) (Sandbox, error) { @@ -130,6 +130,7 @@ func (q *Queries) InsertSandbox(ctx context.Context, arg InsertSandboxParams) (S var i Sandbox err := row.Scan( &i.ID, + &i.TeamID, &i.HostID, &i.Template, &i.Status, @@ -142,18 +143,17 @@ func (q *Queries) InsertSandbox(ctx context.Context, arg InsertSandboxParams) (S &i.StartedAt, &i.LastActiveAt, &i.LastUpdated, - &i.TeamID, ) return i, err } const listActiveSandboxesByTeam = `-- name: ListActiveSandboxesByTeam :many -SELECT id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id FROM sandboxes +SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated FROM sandboxes WHERE team_id = $1 AND status IN ('running', 'paused', 'starting') ORDER BY created_at DESC ` -func (q *Queries) ListActiveSandboxesByTeam(ctx context.Context, teamID string) ([]Sandbox, error) { +func (q *Queries) ListActiveSandboxesByTeam(ctx context.Context, teamID pgtype.UUID) ([]Sandbox, error) { rows, err := q.db.Query(ctx, listActiveSandboxesByTeam, teamID) if err != nil { return nil, err @@ -164,6 +164,7 @@ func (q *Queries) ListActiveSandboxesByTeam(ctx context.Context, teamID string) var i Sandbox if err := rows.Scan( &i.ID, + &i.TeamID, &i.HostID, &i.Template, &i.Status, @@ -176,7 +177,6 @@ func (q *Queries) ListActiveSandboxesByTeam(ctx context.Context, teamID string) &i.StartedAt, &i.LastActiveAt, &i.LastUpdated, - &i.TeamID, ); err != nil { return nil, err } @@ -189,7 +189,7 @@ func (q *Queries) ListActiveSandboxesByTeam(ctx context.Context, teamID string) } const listSandboxes = `-- name: ListSandboxes :many -SELECT id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id FROM sandboxes ORDER BY created_at DESC +SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated FROM sandboxes ORDER BY created_at DESC ` func (q *Queries) ListSandboxes(ctx context.Context) ([]Sandbox, error) { @@ -203,6 +203,7 @@ func (q *Queries) ListSandboxes(ctx context.Context) ([]Sandbox, error) { var i Sandbox if err := rows.Scan( &i.ID, + &i.TeamID, &i.HostID, &i.Template, &i.Status, @@ -215,7 +216,6 @@ func (q *Queries) ListSandboxes(ctx context.Context) ([]Sandbox, error) { &i.StartedAt, &i.LastActiveAt, &i.LastUpdated, - &i.TeamID, ); err != nil { return nil, err } @@ -228,14 +228,14 @@ func (q *Queries) ListSandboxes(ctx context.Context) ([]Sandbox, error) { } const listSandboxesByHostAndStatus = `-- name: ListSandboxesByHostAndStatus :many -SELECT id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id FROM sandboxes +SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated FROM sandboxes WHERE host_id = $1 AND status = ANY($2::text[]) ORDER BY created_at DESC ` type ListSandboxesByHostAndStatusParams struct { - HostID string `json:"host_id"` - Column2 []string `json:"column_2"` + HostID pgtype.UUID `json:"host_id"` + Column2 []string `json:"column_2"` } func (q *Queries) ListSandboxesByHostAndStatus(ctx context.Context, arg ListSandboxesByHostAndStatusParams) ([]Sandbox, error) { @@ -249,6 +249,7 @@ func (q *Queries) ListSandboxesByHostAndStatus(ctx context.Context, arg ListSand var i Sandbox if err := rows.Scan( &i.ID, + &i.TeamID, &i.HostID, &i.Template, &i.Status, @@ -261,7 +262,6 @@ func (q *Queries) ListSandboxesByHostAndStatus(ctx context.Context, arg ListSand &i.StartedAt, &i.LastActiveAt, &i.LastUpdated, - &i.TeamID, ); err != nil { return nil, err } @@ -274,12 +274,12 @@ func (q *Queries) ListSandboxesByHostAndStatus(ctx context.Context, arg ListSand } const listSandboxesByTeam = `-- name: ListSandboxesByTeam :many -SELECT id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id FROM sandboxes +SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated FROM sandboxes WHERE team_id = $1 AND status NOT IN ('stopped', 'error') ORDER BY created_at DESC ` -func (q *Queries) ListSandboxesByTeam(ctx context.Context, teamID string) ([]Sandbox, error) { +func (q *Queries) ListSandboxesByTeam(ctx context.Context, teamID pgtype.UUID) ([]Sandbox, error) { rows, err := q.db.Query(ctx, listSandboxesByTeam, teamID) if err != nil { return nil, err @@ -290,6 +290,7 @@ func (q *Queries) ListSandboxesByTeam(ctx context.Context, teamID string) ([]San var i Sandbox if err := rows.Scan( &i.ID, + &i.TeamID, &i.HostID, &i.Template, &i.Status, @@ -302,7 +303,6 @@ func (q *Queries) ListSandboxesByTeam(ctx context.Context, teamID string) ([]San &i.StartedAt, &i.LastActiveAt, &i.LastUpdated, - &i.TeamID, ); err != nil { return nil, err } @@ -324,7 +324,7 @@ WHERE host_id = $1 AND status IN ('running', 'starting', 'pending') // Called when the host monitor marks a host unreachable. // Marks running/starting/pending sandboxes on that host as 'missing' so users see // the sandbox is not currently reachable, without permanently losing the record. -func (q *Queries) MarkSandboxesMissingByHost(ctx context.Context, hostID string) error { +func (q *Queries) MarkSandboxesMissingByHost(ctx context.Context, hostID pgtype.UUID) error { _, err := q.db.Exec(ctx, markSandboxesMissingByHost, hostID) return err } @@ -337,7 +337,7 @@ WHERE id = $1 ` type UpdateLastActiveParams struct { - ID string `json:"id"` + ID pgtype.UUID `json:"id"` LastActiveAt pgtype.Timestamptz `json:"last_active_at"` } @@ -355,11 +355,11 @@ SET status = 'running', last_active_at = $4, last_updated = NOW() WHERE id = $1 -RETURNING id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id +RETURNING id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated ` type UpdateSandboxRunningParams struct { - ID string `json:"id"` + ID pgtype.UUID `json:"id"` HostIp string `json:"host_ip"` GuestIp string `json:"guest_ip"` StartedAt pgtype.Timestamptz `json:"started_at"` @@ -375,6 +375,7 @@ func (q *Queries) UpdateSandboxRunning(ctx context.Context, arg UpdateSandboxRun var i Sandbox err := row.Scan( &i.ID, + &i.TeamID, &i.HostID, &i.Template, &i.Status, @@ -387,7 +388,6 @@ func (q *Queries) UpdateSandboxRunning(ctx context.Context, arg UpdateSandboxRun &i.StartedAt, &i.LastActiveAt, &i.LastUpdated, - &i.TeamID, ) return i, err } @@ -397,12 +397,12 @@ UPDATE sandboxes SET status = $2, last_updated = NOW() WHERE id = $1 -RETURNING id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id +RETURNING id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated ` type UpdateSandboxStatusParams struct { - ID string `json:"id"` - Status string `json:"status"` + ID pgtype.UUID `json:"id"` + Status string `json:"status"` } func (q *Queries) UpdateSandboxStatus(ctx context.Context, arg UpdateSandboxStatusParams) (Sandbox, error) { @@ -410,6 +410,7 @@ func (q *Queries) UpdateSandboxStatus(ctx context.Context, arg UpdateSandboxStat var i Sandbox err := row.Scan( &i.ID, + &i.TeamID, &i.HostID, &i.Template, &i.Status, @@ -422,7 +423,6 @@ func (q *Queries) UpdateSandboxStatus(ctx context.Context, arg UpdateSandboxStat &i.StartedAt, &i.LastActiveAt, &i.LastUpdated, - &i.TeamID, ) return i, err } diff --git a/internal/db/teams.sql.go b/internal/db/teams.sql.go index a00f5efc..334141ff 100644 --- a/internal/db/teams.sql.go +++ b/internal/db/teams.sql.go @@ -16,8 +16,8 @@ DELETE FROM users_teams WHERE team_id = $1 AND user_id = $2 ` type DeleteTeamMemberParams struct { - TeamID string `json:"team_id"` - UserID string `json:"user_id"` + TeamID pgtype.UUID `json:"team_id"` + UserID pgtype.UUID `json:"user_id"` } func (q *Queries) DeleteTeamMember(ctx context.Context, arg DeleteTeamMemberParams) error { @@ -26,7 +26,7 @@ func (q *Queries) DeleteTeamMember(ctx context.Context, arg DeleteTeamMemberPara } const getBYOCTeams = `-- name: GetBYOCTeams :many -SELECT id, name, created_at, is_byoc, slug, deleted_at FROM teams WHERE is_byoc = TRUE AND deleted_at IS NULL ORDER BY created_at +SELECT id, name, slug, is_byoc, created_at, deleted_at FROM teams WHERE is_byoc = TRUE AND deleted_at IS NULL ORDER BY created_at ` func (q *Queries) GetBYOCTeams(ctx context.Context) ([]Team, error) { @@ -41,9 +41,9 @@ func (q *Queries) GetBYOCTeams(ctx context.Context) ([]Team, error) { if err := rows.Scan( &i.ID, &i.Name, - &i.CreatedAt, - &i.IsByoc, &i.Slug, + &i.IsByoc, + &i.CreatedAt, &i.DeletedAt, ); err != nil { return nil, err @@ -57,46 +57,46 @@ func (q *Queries) GetBYOCTeams(ctx context.Context) ([]Team, error) { } const getDefaultTeamForUser = `-- name: GetDefaultTeamForUser :one -SELECT t.id, t.name, t.created_at, t.is_byoc, t.slug, t.deleted_at FROM teams t +SELECT t.id, t.name, t.slug, t.is_byoc, t.created_at, t.deleted_at FROM teams t JOIN users_teams ut ON ut.team_id = t.id WHERE ut.user_id = $1 AND ut.is_default = TRUE AND t.deleted_at IS NULL LIMIT 1 ` -func (q *Queries) GetDefaultTeamForUser(ctx context.Context, userID string) (Team, error) { +func (q *Queries) GetDefaultTeamForUser(ctx context.Context, userID pgtype.UUID) (Team, error) { row := q.db.QueryRow(ctx, getDefaultTeamForUser, userID) var i Team err := row.Scan( &i.ID, &i.Name, - &i.CreatedAt, - &i.IsByoc, &i.Slug, + &i.IsByoc, + &i.CreatedAt, &i.DeletedAt, ) return i, err } const getTeam = `-- name: GetTeam :one -SELECT id, name, created_at, is_byoc, slug, deleted_at FROM teams WHERE id = $1 +SELECT id, name, slug, is_byoc, created_at, deleted_at FROM teams WHERE id = $1 ` -func (q *Queries) GetTeam(ctx context.Context, id string) (Team, error) { +func (q *Queries) GetTeam(ctx context.Context, id pgtype.UUID) (Team, error) { row := q.db.QueryRow(ctx, getTeam, id) var i Team err := row.Scan( &i.ID, &i.Name, - &i.CreatedAt, - &i.IsByoc, &i.Slug, + &i.IsByoc, + &i.CreatedAt, &i.DeletedAt, ) return i, err } const getTeamBySlug = `-- name: GetTeamBySlug :one -SELECT id, name, created_at, is_byoc, slug, deleted_at FROM teams WHERE slug = $1 AND deleted_at IS NULL +SELECT id, name, slug, is_byoc, created_at, deleted_at FROM teams WHERE slug = $1 AND deleted_at IS NULL ` func (q *Queries) GetTeamBySlug(ctx context.Context, slug string) (Team, error) { @@ -105,9 +105,9 @@ func (q *Queries) GetTeamBySlug(ctx context.Context, slug string) (Team, error) err := row.Scan( &i.ID, &i.Name, - &i.CreatedAt, - &i.IsByoc, &i.Slug, + &i.IsByoc, + &i.CreatedAt, &i.DeletedAt, ) return i, err @@ -122,14 +122,14 @@ ORDER BY ut.created_at ` type GetTeamMembersRow struct { - ID string `json:"id"` + ID pgtype.UUID `json:"id"` Name string `json:"name"` Email string `json:"email"` Role string `json:"role"` JoinedAt pgtype.Timestamptz `json:"joined_at"` } -func (q *Queries) GetTeamMembers(ctx context.Context, teamID string) ([]GetTeamMembersRow, error) { +func (q *Queries) GetTeamMembers(ctx context.Context, teamID pgtype.UUID) ([]GetTeamMembersRow, error) { rows, err := q.db.Query(ctx, getTeamMembers, teamID) if err != nil { return nil, err @@ -160,8 +160,8 @@ SELECT user_id, team_id, is_default, role, created_at FROM users_teams WHERE use ` type GetTeamMembershipParams struct { - UserID string `json:"user_id"` - TeamID string `json:"team_id"` + UserID pgtype.UUID `json:"user_id"` + TeamID pgtype.UUID `json:"team_id"` } func (q *Queries) GetTeamMembership(ctx context.Context, arg GetTeamMembershipParams) (UsersTeam, error) { @@ -186,7 +186,7 @@ ORDER BY ut.created_at ` type GetTeamsForUserRow struct { - ID string `json:"id"` + ID pgtype.UUID `json:"id"` Name string `json:"name"` Slug string `json:"slug"` IsByoc bool `json:"is_byoc"` @@ -195,7 +195,7 @@ type GetTeamsForUserRow struct { Role string `json:"role"` } -func (q *Queries) GetTeamsForUser(ctx context.Context, userID string) ([]GetTeamsForUserRow, error) { +func (q *Queries) GetTeamsForUser(ctx context.Context, userID pgtype.UUID) ([]GetTeamsForUserRow, error) { rows, err := q.db.Query(ctx, getTeamsForUser, userID) if err != nil { return nil, err @@ -226,13 +226,13 @@ func (q *Queries) GetTeamsForUser(ctx context.Context, userID string) ([]GetTeam const insertTeam = `-- name: InsertTeam :one INSERT INTO teams (id, name, slug) VALUES ($1, $2, $3) -RETURNING id, name, created_at, is_byoc, slug, deleted_at +RETURNING id, name, slug, is_byoc, created_at, deleted_at ` type InsertTeamParams struct { - ID string `json:"id"` - Name string `json:"name"` - Slug string `json:"slug"` + ID pgtype.UUID `json:"id"` + Name string `json:"name"` + Slug string `json:"slug"` } func (q *Queries) InsertTeam(ctx context.Context, arg InsertTeamParams) (Team, error) { @@ -241,9 +241,9 @@ func (q *Queries) InsertTeam(ctx context.Context, arg InsertTeamParams) (Team, e err := row.Scan( &i.ID, &i.Name, - &i.CreatedAt, - &i.IsByoc, &i.Slug, + &i.IsByoc, + &i.CreatedAt, &i.DeletedAt, ) return i, err @@ -255,10 +255,10 @@ VALUES ($1, $2, $3, $4) ` type InsertTeamMemberParams struct { - UserID string `json:"user_id"` - TeamID string `json:"team_id"` - IsDefault bool `json:"is_default"` - Role string `json:"role"` + UserID pgtype.UUID `json:"user_id"` + TeamID pgtype.UUID `json:"team_id"` + IsDefault bool `json:"is_default"` + Role string `json:"role"` } func (q *Queries) InsertTeamMember(ctx context.Context, arg InsertTeamMemberParams) error { @@ -276,8 +276,8 @@ UPDATE teams SET is_byoc = $2 WHERE id = $1 ` type SetTeamBYOCParams struct { - ID string `json:"id"` - IsByoc bool `json:"is_byoc"` + ID pgtype.UUID `json:"id"` + IsByoc bool `json:"is_byoc"` } func (q *Queries) SetTeamBYOC(ctx context.Context, arg SetTeamBYOCParams) error { @@ -289,7 +289,7 @@ const softDeleteTeam = `-- name: SoftDeleteTeam :exec UPDATE teams SET deleted_at = NOW() WHERE id = $1 ` -func (q *Queries) SoftDeleteTeam(ctx context.Context, id string) error { +func (q *Queries) SoftDeleteTeam(ctx context.Context, id pgtype.UUID) error { _, err := q.db.Exec(ctx, softDeleteTeam, id) return err } @@ -299,9 +299,9 @@ UPDATE users_teams SET role = $3 WHERE team_id = $1 AND user_id = $2 ` type UpdateMemberRoleParams struct { - TeamID string `json:"team_id"` - UserID string `json:"user_id"` - Role string `json:"role"` + TeamID pgtype.UUID `json:"team_id"` + UserID pgtype.UUID `json:"user_id"` + Role string `json:"role"` } func (q *Queries) UpdateMemberRole(ctx context.Context, arg UpdateMemberRoleParams) error { @@ -314,8 +314,8 @@ UPDATE teams SET name = $2 WHERE id = $1 AND deleted_at IS NULL ` type UpdateTeamNameParams struct { - ID string `json:"id"` - Name string `json:"name"` + ID pgtype.UUID `json:"id"` + Name string `json:"name"` } func (q *Queries) UpdateTeamName(ctx context.Context, arg UpdateTeamNameParams) error { diff --git a/internal/db/template_builds.sql.go b/internal/db/template_builds.sql.go index 8142d294..9e770ee1 100644 --- a/internal/db/template_builds.sql.go +++ b/internal/db/template_builds.sql.go @@ -15,7 +15,7 @@ const getTemplateBuild = `-- name: GetTemplateBuild :one SELECT id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at FROM template_builds WHERE id = $1 ` -func (q *Queries) GetTemplateBuild(ctx context.Context, id string) (TemplateBuild, error) { +func (q *Queries) GetTemplateBuild(ctx context.Context, id pgtype.UUID) (TemplateBuild, error) { row := q.db.QueryRow(ctx, getTemplateBuild, id) var i TemplateBuild err := row.Scan( @@ -47,11 +47,11 @@ RETURNING id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status ` type InsertTemplateBuildParams struct { - ID string `json:"id"` + ID pgtype.UUID `json:"id"` Name string `json:"name"` BaseTemplate string `json:"base_template"` Recipe []byte `json:"recipe"` - Healthcheck pgtype.Text `json:"healthcheck"` + Healthcheck string `json:"healthcheck"` Vcpus int32 `json:"vcpus"` MemoryMb int32 `json:"memory_mb"` TotalSteps int32 `json:"total_steps"` @@ -140,8 +140,8 @@ WHERE id = $1 ` type UpdateBuildErrorParams struct { - ID string `json:"id"` - Error pgtype.Text `json:"error"` + ID pgtype.UUID `json:"id"` + Error string `json:"error"` } func (q *Queries) UpdateBuildError(ctx context.Context, arg UpdateBuildErrorParams) error { @@ -156,9 +156,9 @@ WHERE id = $1 ` type UpdateBuildProgressParams struct { - ID string `json:"id"` - CurrentStep int32 `json:"current_step"` - Logs []byte `json:"logs"` + ID pgtype.UUID `json:"id"` + CurrentStep int32 `json:"current_step"` + Logs []byte `json:"logs"` } func (q *Queries) UpdateBuildProgress(ctx context.Context, arg UpdateBuildProgressParams) error { @@ -173,9 +173,9 @@ WHERE id = $1 ` type UpdateBuildSandboxParams struct { - ID string `json:"id"` - SandboxID pgtype.Text `json:"sandbox_id"` - HostID pgtype.Text `json:"host_id"` + ID pgtype.UUID `json:"id"` + SandboxID pgtype.UUID `json:"sandbox_id"` + HostID pgtype.UUID `json:"host_id"` } func (q *Queries) UpdateBuildSandbox(ctx context.Context, arg UpdateBuildSandboxParams) error { @@ -193,8 +193,8 @@ RETURNING id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status ` type UpdateBuildStatusParams struct { - ID string `json:"id"` - Status string `json:"status"` + ID pgtype.UUID `json:"id"` + Status string `json:"status"` } func (q *Queries) UpdateBuildStatus(ctx context.Context, arg UpdateBuildStatusParams) (TemplateBuild, error) { diff --git a/internal/db/templates.sql.go b/internal/db/templates.sql.go index cafae692..8703bc92 100644 --- a/internal/db/templates.sql.go +++ b/internal/db/templates.sql.go @@ -25,8 +25,8 @@ DELETE FROM templates WHERE name = $1 AND team_id = $2 ` type DeleteTemplateByTeamParams struct { - Name string `json:"name"` - TeamID string `json:"team_id"` + Name string `json:"name"` + TeamID pgtype.UUID `json:"team_id"` } func (q *Queries) DeleteTemplateByTeam(ctx context.Context, arg DeleteTemplateByTeamParams) error { @@ -58,8 +58,8 @@ SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templa ` type GetTemplateByTeamParams struct { - Name string `json:"name"` - TeamID string `json:"team_id"` + Name string `json:"name"` + TeamID pgtype.UUID `json:"team_id"` } func (q *Queries) GetTemplateByTeam(ctx context.Context, arg GetTemplateByTeamParams) (Template, error) { @@ -86,10 +86,10 @@ RETURNING name, type, vcpus, memory_mb, size_bytes, created_at, team_id type InsertTemplateParams struct { Name string `json:"name"` Type string `json:"type"` - Vcpus pgtype.Int4 `json:"vcpus"` - MemoryMb pgtype.Int4 `json:"memory_mb"` + Vcpus int32 `json:"vcpus"` + MemoryMb int32 `json:"memory_mb"` SizeBytes int64 `json:"size_bytes"` - TeamID string `json:"team_id"` + TeamID pgtype.UUID `json:"team_id"` } func (q *Queries) InsertTemplate(ctx context.Context, arg InsertTemplateParams) (Template, error) { @@ -150,7 +150,7 @@ const listTemplatesByTeam = `-- name: ListTemplatesByTeam :many SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE team_id = $1 ORDER BY created_at DESC ` -func (q *Queries) ListTemplatesByTeam(ctx context.Context, teamID string) ([]Template, error) { +func (q *Queries) ListTemplatesByTeam(ctx context.Context, teamID pgtype.UUID) ([]Template, error) { rows, err := q.db.Query(ctx, listTemplatesByTeam, teamID) if err != nil { return nil, err @@ -183,8 +183,8 @@ SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templa ` type ListTemplatesByTeamAndTypeParams struct { - TeamID string `json:"team_id"` - Type string `json:"type"` + TeamID pgtype.UUID `json:"team_id"` + Type string `json:"type"` } func (q *Queries) ListTemplatesByTeamAndType(ctx context.Context, arg ListTemplatesByTeamAndTypeParams) ([]Template, error) { diff --git a/internal/db/users.sql.go b/internal/db/users.sql.go index 50ba287e..9de866bb 100644 --- a/internal/db/users.sql.go +++ b/internal/db/users.sql.go @@ -16,8 +16,8 @@ DELETE FROM admin_permissions WHERE user_id = $1 AND permission = $2 ` type DeleteAdminPermissionParams struct { - UserID string `json:"user_id"` - Permission string `json:"permission"` + UserID pgtype.UUID `json:"user_id"` + Permission string `json:"permission"` } func (q *Queries) DeleteAdminPermission(ctx context.Context, arg DeleteAdminPermissionParams) error { @@ -29,7 +29,7 @@ const getAdminPermissions = `-- name: GetAdminPermissions :many SELECT id, user_id, permission, created_at FROM admin_permissions WHERE user_id = $1 ORDER BY permission ` -func (q *Queries) GetAdminPermissions(ctx context.Context, userID string) ([]AdminPermission, error) { +func (q *Queries) GetAdminPermissions(ctx context.Context, userID pgtype.UUID) ([]AdminPermission, error) { rows, err := q.db.Query(ctx, getAdminPermissions, userID) if err != nil { return nil, err @@ -55,7 +55,7 @@ func (q *Queries) GetAdminPermissions(ctx context.Context, userID string) ([]Adm } const getAdminUsers = `-- name: GetAdminUsers :many -SELECT id, email, password_hash, created_at, updated_at, is_admin, name FROM users WHERE is_admin = TRUE ORDER BY created_at +SELECT id, email, password_hash, name, is_admin, created_at, updated_at FROM users WHERE is_admin = TRUE ORDER BY created_at ` func (q *Queries) GetAdminUsers(ctx context.Context) ([]User, error) { @@ -71,10 +71,10 @@ func (q *Queries) GetAdminUsers(ctx context.Context) ([]User, error) { &i.ID, &i.Email, &i.PasswordHash, + &i.Name, + &i.IsAdmin, &i.CreatedAt, &i.UpdatedAt, - &i.IsAdmin, - &i.Name, ); err != nil { return nil, err } @@ -87,7 +87,7 @@ func (q *Queries) GetAdminUsers(ctx context.Context) ([]User, error) { } const getUserByEmail = `-- name: GetUserByEmail :one -SELECT id, email, password_hash, created_at, updated_at, is_admin, name FROM users WHERE email = $1 +SELECT id, email, password_hash, name, is_admin, created_at, updated_at FROM users WHERE email = $1 ` func (q *Queries) GetUserByEmail(ctx context.Context, email string) (User, error) { @@ -97,29 +97,29 @@ func (q *Queries) GetUserByEmail(ctx context.Context, email string) (User, error &i.ID, &i.Email, &i.PasswordHash, + &i.Name, + &i.IsAdmin, &i.CreatedAt, &i.UpdatedAt, - &i.IsAdmin, - &i.Name, ) return i, err } const getUserByID = `-- name: GetUserByID :one -SELECT id, email, password_hash, created_at, updated_at, is_admin, name FROM users WHERE id = $1 +SELECT id, email, password_hash, name, is_admin, created_at, updated_at FROM users WHERE id = $1 ` -func (q *Queries) GetUserByID(ctx context.Context, id string) (User, error) { +func (q *Queries) GetUserByID(ctx context.Context, id pgtype.UUID) (User, error) { row := q.db.QueryRow(ctx, getUserByID, id) var i User err := row.Scan( &i.ID, &i.Email, &i.PasswordHash, + &i.Name, + &i.IsAdmin, &i.CreatedAt, &i.UpdatedAt, - &i.IsAdmin, - &i.Name, ) return i, err } @@ -131,8 +131,8 @@ SELECT EXISTS( ` type HasAdminPermissionParams struct { - UserID string `json:"user_id"` - Permission string `json:"permission"` + UserID pgtype.UUID `json:"user_id"` + Permission string `json:"permission"` } func (q *Queries) HasAdminPermission(ctx context.Context, arg HasAdminPermissionParams) (bool, error) { @@ -148,9 +148,9 @@ VALUES ($1, $2, $3) ` type InsertAdminPermissionParams struct { - ID string `json:"id"` - UserID string `json:"user_id"` - Permission string `json:"permission"` + ID pgtype.UUID `json:"id"` + UserID pgtype.UUID `json:"user_id"` + Permission string `json:"permission"` } func (q *Queries) InsertAdminPermission(ctx context.Context, arg InsertAdminPermissionParams) error { @@ -161,11 +161,11 @@ func (q *Queries) InsertAdminPermission(ctx context.Context, arg InsertAdminPerm const insertUser = `-- name: InsertUser :one INSERT INTO users (id, email, password_hash, name) VALUES ($1, $2, $3, $4) -RETURNING id, email, password_hash, created_at, updated_at, is_admin, name +RETURNING id, email, password_hash, name, is_admin, created_at, updated_at ` type InsertUserParams struct { - ID string `json:"id"` + ID pgtype.UUID `json:"id"` Email string `json:"email"` PasswordHash pgtype.Text `json:"password_hash"` Name string `json:"name"` @@ -183,10 +183,10 @@ func (q *Queries) InsertUser(ctx context.Context, arg InsertUserParams) (User, e &i.ID, &i.Email, &i.PasswordHash, + &i.Name, + &i.IsAdmin, &i.CreatedAt, &i.UpdatedAt, - &i.IsAdmin, - &i.Name, ) return i, err } @@ -194,13 +194,13 @@ func (q *Queries) InsertUser(ctx context.Context, arg InsertUserParams) (User, e const insertUserOAuth = `-- name: InsertUserOAuth :one INSERT INTO users (id, email, name) VALUES ($1, $2, $3) -RETURNING id, email, password_hash, created_at, updated_at, is_admin, name +RETURNING id, email, password_hash, name, is_admin, created_at, updated_at ` type InsertUserOAuthParams struct { - ID string `json:"id"` - Email string `json:"email"` - Name string `json:"name"` + ID pgtype.UUID `json:"id"` + Email string `json:"email"` + Name string `json:"name"` } func (q *Queries) InsertUserOAuth(ctx context.Context, arg InsertUserOAuthParams) (User, error) { @@ -210,10 +210,10 @@ func (q *Queries) InsertUserOAuth(ctx context.Context, arg InsertUserOAuthParams &i.ID, &i.Email, &i.PasswordHash, + &i.Name, + &i.IsAdmin, &i.CreatedAt, &i.UpdatedAt, - &i.IsAdmin, - &i.Name, ) return i, err } @@ -223,8 +223,8 @@ SELECT id, email FROM users WHERE email LIKE $1 || '%' ORDER BY email LIMIT 10 ` type SearchUsersByEmailPrefixRow struct { - ID string `json:"id"` - Email string `json:"email"` + ID pgtype.UUID `json:"id"` + Email string `json:"email"` } func (q *Queries) SearchUsersByEmailPrefix(ctx context.Context, dollar_1 pgtype.Text) ([]SearchUsersByEmailPrefixRow, error) { @@ -252,8 +252,8 @@ UPDATE users SET is_admin = $2, updated_at = NOW() WHERE id = $1 ` type SetUserAdminParams struct { - ID string `json:"id"` - IsAdmin bool `json:"is_admin"` + ID pgtype.UUID `json:"id"` + IsAdmin bool `json:"is_admin"` } func (q *Queries) SetUserAdmin(ctx context.Context, arg SetUserAdminParams) error { @@ -266,8 +266,8 @@ UPDATE users SET name = $2, updated_at = NOW() WHERE id = $1 ` type UpdateUserNameParams struct { - ID string `json:"id"` - Name string `json:"name"` + ID pgtype.UUID `json:"id"` + Name string `json:"name"` } func (q *Queries) UpdateUserName(ctx context.Context, arg UpdateUserNameParams) error { diff --git a/internal/id/id.go b/internal/id/id.go index 836af6df..c27869ad 100644 --- a/internal/id/id.go +++ b/internal/id/id.go @@ -4,8 +4,114 @@ import ( "crypto/rand" "encoding/hex" "fmt" + "strings" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5/pgtype" ) +// --- Generation --- + +// newUUID returns a new random (v4) UUID wrapped in pgtype.UUID for direct DB use. +func newUUID() pgtype.UUID { + return pgtype.UUID{Bytes: uuid.New(), Valid: true} +} + +func NewSandboxID() pgtype.UUID { return newUUID() } +func NewUserID() pgtype.UUID { return newUUID() } +func NewTeamID() pgtype.UUID { return newUUID() } +func NewAPIKeyID() pgtype.UUID { return newUUID() } +func NewHostID() pgtype.UUID { return newUUID() } +func NewHostTokenID() pgtype.UUID { return newUUID() } +func NewRefreshTokenID() pgtype.UUID { return newUUID() } +func NewAuditLogID() pgtype.UUID { return newUUID() } +func NewBuildID() pgtype.UUID { return newUUID() } +func NewAdminPermissionID() pgtype.UUID { return newUUID() } + +// NewSnapshotName generates a snapshot name: "template-" + 8 hex chars. +// Templates use TEXT primary keys (not UUID), so this stays as a string. +func NewSnapshotName() string { + return "template-" + hex8() +} + +// NewTeamSlug generates a unique team slug in the format "xxxxxx-yyyyyy". +func NewTeamSlug() string { + b := make([]byte, 6) + if _, err := rand.Read(b); err != nil { + panic(fmt.Sprintf("crypto/rand failed: %v", err)) + } + return hex.EncodeToString(b[:3]) + "-" + hex.EncodeToString(b[3:]) +} + +// NewRegistrationToken generates a 64-char hex token (32 bytes of entropy). +func NewRegistrationToken() string { + return hexToken(32) +} + +// NewRefreshToken generates a 64-char hex token (32 bytes of entropy). +func NewRefreshToken() string { + return hexToken(32) +} + +// --- Formatting (pgtype.UUID → prefixed string for API/RPC output) --- + +const ( + PrefixSandbox = "sb-" + PrefixUser = "usr-" + PrefixTeam = "team-" + PrefixAPIKey = "key-" + PrefixHost = "host-" + PrefixHostToken = "htok-" + PrefixRefreshToken = "hrt-" + PrefixAuditLog = "log-" + PrefixBuild = "bld-" + PrefixAdminPermission = "perm-" +) + +func formatUUID(prefix string, id pgtype.UUID) string { + return prefix + uuid.UUID(id.Bytes).String() +} + +func FormatSandboxID(id pgtype.UUID) string { return formatUUID(PrefixSandbox, id) } +func FormatUserID(id pgtype.UUID) string { return formatUUID(PrefixUser, id) } +func FormatTeamID(id pgtype.UUID) string { return formatUUID(PrefixTeam, id) } +func FormatAPIKeyID(id pgtype.UUID) string { return formatUUID(PrefixAPIKey, id) } +func FormatHostID(id pgtype.UUID) string { return formatUUID(PrefixHost, id) } +func FormatHostTokenID(id pgtype.UUID) string { return formatUUID(PrefixHostToken, id) } +func FormatRefreshTokenID(id pgtype.UUID) string { return formatUUID(PrefixRefreshToken, id) } +func FormatAuditLogID(id pgtype.UUID) string { return formatUUID(PrefixAuditLog, id) } +func FormatBuildID(id pgtype.UUID) string { return formatUUID(PrefixBuild, id) } + +// --- Parsing (prefixed string from API/RPC input → pgtype.UUID) --- + +func parseUUID(prefix, s string) (pgtype.UUID, error) { + if !strings.HasPrefix(s, prefix) { + return pgtype.UUID{}, fmt.Errorf("invalid ID: expected %q prefix, got %q", prefix, s) + } + u, err := uuid.Parse(strings.TrimPrefix(s, prefix)) + if err != nil { + return pgtype.UUID{}, fmt.Errorf("invalid ID %q: %w", s, err) + } + return pgtype.UUID{Bytes: u, Valid: true}, nil +} + +func ParseSandboxID(s string) (pgtype.UUID, error) { return parseUUID(PrefixSandbox, s) } +func ParseUserID(s string) (pgtype.UUID, error) { return parseUUID(PrefixUser, s) } +func ParseTeamID(s string) (pgtype.UUID, error) { return parseUUID(PrefixTeam, s) } +func ParseAPIKeyID(s string) (pgtype.UUID, error) { return parseUUID(PrefixAPIKey, s) } +func ParseHostID(s string) (pgtype.UUID, error) { return parseUUID(PrefixHost, s) } +func ParseHostTokenID(s string) (pgtype.UUID, error) { return parseUUID(PrefixHostToken, s) } +func ParseAuditLogID(s string) (pgtype.UUID, error) { return parseUUID(PrefixAuditLog, s) } +func ParseBuildID(s string) (pgtype.UUID, error) { return parseUUID(PrefixBuild, s) } + +// --- Well-known IDs --- + +// PlatformTeamID is the all-zeros UUID reserved for platform-owned resources +// (e.g. base templates, shared infrastructure). +var PlatformTeamID = pgtype.UUID{Bytes: [16]byte{}, Valid: true} + +// --- Helpers --- + func hex8() string { b := make([]byte, 4) if _, err := rand.Read(b); err != nil { @@ -14,78 +120,8 @@ func hex8() string { return hex.EncodeToString(b) } -// NewSandboxID generates a new sandbox ID in the format "sb-" + 8 hex chars. -func NewSandboxID() string { - return "sb-" + hex8() -} - -// NewSnapshotName generates a snapshot name in the format "template-" + 8 hex chars. -func NewSnapshotName() string { - return "template-" + hex8() -} - -// NewUserID generates a new user ID in the format "usr-" + 8 hex chars. -func NewUserID() string { - return "usr-" + hex8() -} - -// NewTeamID generates a new team ID in the format "team-" + 8 hex chars. -func NewTeamID() string { - return "team-" + hex8() -} - -// NewTeamSlug generates a unique team slug in the format "xxxxxx-yyyyyy" -// where each part is 3 random bytes encoded as hex (6 hex chars each). -func NewTeamSlug() string { - b := make([]byte, 6) - if _, err := rand.Read(b); err != nil { - panic(fmt.Sprintf("crypto/rand failed: %v", err)) - } - return hex.EncodeToString(b[:3]) + "-" + hex.EncodeToString(b[3:]) -} - -// NewAPIKeyID generates a new API key ID in the format "key-" + 8 hex chars. -func NewAPIKeyID() string { - return "key-" + hex8() -} - -// NewHostID generates a new host ID in the format "host-" + 8 hex chars. -func NewHostID() string { - return "host-" + hex8() -} - -// NewHostTokenID generates a new host token audit ID in the format "htok-" + 8 hex chars. -func NewHostTokenID() string { - return "htok-" + hex8() -} - -// NewRegistrationToken generates a 64-char hex token (32 bytes of entropy). -func NewRegistrationToken() string { - b := make([]byte, 32) - if _, err := rand.Read(b); err != nil { - panic(fmt.Sprintf("crypto/rand failed: %v", err)) - } - return hex.EncodeToString(b) -} - -// NewRefreshTokenID generates a new refresh token record ID in the format "hrt-" + 8 hex chars. -func NewRefreshTokenID() string { - return "hrt-" + hex8() -} - -// NewAuditLogID generates a new audit log ID in the format "log-" + 8 hex chars. -func NewAuditLogID() string { - return "log-" + hex8() -} - -// NewBuildID generates a new template build ID in the format "bld-" + 8 hex chars. -func NewBuildID() string { - return "bld-" + hex8() -} - -// NewRefreshToken generates a 64-char hex token (32 bytes of entropy) for use as a host refresh token. -func NewRefreshToken() string { - b := make([]byte, 32) +func hexToken(nBytes int) string { + b := make([]byte, nBytes) if _, err := rand.Read(b); err != nil { panic(fmt.Sprintf("crypto/rand failed: %v", err)) } diff --git a/internal/lifecycle/hostpool.go b/internal/lifecycle/hostpool.go index c6e724b1..f1341653 100644 --- a/internal/lifecycle/hostpool.go +++ b/internal/lifecycle/hostpool.go @@ -8,6 +8,7 @@ import ( "time" "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/id" "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect" ) @@ -53,10 +54,10 @@ func (p *HostClientPool) Get(hostID, address string) hostagentv1connect.HostAgen // GetForHost is a convenience wrapper that extracts the address from a db.Host // and returns an error if the host has no address recorded yet. func (p *HostClientPool) GetForHost(h db.Host) (hostagentv1connect.HostAgentServiceClient, error) { - if !h.Address.Valid || h.Address.String == "" { - return nil, fmt.Errorf("host %s has no address", h.ID) + if h.Address == "" { + return nil, fmt.Errorf("host %s has no address", id.FormatHostID(h.ID)) } - return p.Get(h.ID, h.Address.String), nil + return p.Get(id.FormatHostID(h.ID), h.Address), nil } // Evict removes the cached client for the given host, forcing a new client to be diff --git a/internal/sandbox/manager.go b/internal/sandbox/manager.go index 88b058c6..15453eb4 100644 --- a/internal/sandbox/manager.go +++ b/internal/sandbox/manager.go @@ -96,7 +96,7 @@ func New(cfg Config) *Manager { // If sandboxID is empty, a new ID is generated. func (m *Manager) Create(ctx context.Context, sandboxID, template string, vcpus, memoryMB, timeoutSec int) (*models.Sandbox, error) { if sandboxID == "" { - sandboxID = id.NewSandboxID() + sandboxID = id.FormatSandboxID(id.NewSandboxID()) } if vcpus <= 0 { diff --git a/internal/scheduler/round_robin.go b/internal/scheduler/round_robin.go index 31433a0b..c2ab0f40 100644 --- a/internal/scheduler/round_robin.go +++ b/internal/scheduler/round_robin.go @@ -5,6 +5,8 @@ import ( "fmt" "sync/atomic" + "github.com/jackc/pgx/v5/pgtype" + "git.omukk.dev/wrenn/sandbox/internal/db" ) @@ -15,7 +17,7 @@ type HostScheduler interface { // For BYOC teams (isByoc=true), only online BYOC hosts belonging to teamID // are considered. For non-BYOC teams, only online regular (platform) hosts // are considered. Returns an error if no suitable host is available. - SelectHost(ctx context.Context, teamID string, isByoc bool) (db.Host, error) + SelectHost(ctx context.Context, teamID pgtype.UUID, isByoc bool) (db.Host, error) } // RoundRobinScheduler cycles through eligible online hosts in round-robin order. @@ -32,7 +34,7 @@ func NewRoundRobinScheduler(queries *db.Queries) *RoundRobinScheduler { } // SelectHost returns the next eligible online host in round-robin order. -func (s *RoundRobinScheduler) SelectHost(ctx context.Context, teamID string, isByoc bool) (db.Host, error) { +func (s *RoundRobinScheduler) SelectHost(ctx context.Context, teamID pgtype.UUID, isByoc bool) (db.Host, error) { hosts, err := s.db.ListActiveHosts(ctx) if err != nil { return db.Host{}, fmt.Errorf("list hosts: %w", err) @@ -40,12 +42,12 @@ func (s *RoundRobinScheduler) SelectHost(ctx context.Context, teamID string, isB var eligible []db.Host for _, h := range hosts { - if h.Status != "online" || !h.Address.Valid || h.Address.String == "" { + if h.Status != "online" || h.Address == "" { continue } if isByoc { // BYOC team: only use hosts belonging to this team. - if h.Type != "byoc" || !h.TeamID.Valid || h.TeamID.String != teamID { + if h.Type != "byoc" || !h.TeamID.Valid || h.TeamID != teamID { continue } } else { diff --git a/internal/service/apikey.go b/internal/service/apikey.go index c49ddcaa..7a2b073a 100644 --- a/internal/service/apikey.go +++ b/internal/service/apikey.go @@ -4,6 +4,8 @@ import ( "context" "fmt" + "github.com/jackc/pgx/v5/pgtype" + "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/db" "git.omukk.dev/wrenn/sandbox/internal/id" @@ -22,7 +24,7 @@ type APIKeyCreateResult struct { } // Create generates a new API key for the given team. -func (s *APIKeyService) Create(ctx context.Context, teamID, userID, name string) (APIKeyCreateResult, error) { +func (s *APIKeyService) Create(ctx context.Context, teamID, userID pgtype.UUID, name string) (APIKeyCreateResult, error) { if name == "" { name = "Unnamed API Key" } @@ -48,16 +50,16 @@ func (s *APIKeyService) Create(ctx context.Context, teamID, userID, name string) } // List returns all API keys belonging to the given team. -func (s *APIKeyService) List(ctx context.Context, teamID string) ([]db.TeamApiKey, error) { +func (s *APIKeyService) List(ctx context.Context, teamID pgtype.UUID) ([]db.TeamApiKey, error) { return s.DB.ListAPIKeysByTeam(ctx, teamID) } // ListWithCreator returns all API keys for the team, joined with the creator's email. -func (s *APIKeyService) ListWithCreator(ctx context.Context, teamID string) ([]db.ListAPIKeysByTeamWithCreatorRow, error) { +func (s *APIKeyService) ListWithCreator(ctx context.Context, teamID pgtype.UUID) ([]db.ListAPIKeysByTeamWithCreatorRow, error) { return s.DB.ListAPIKeysByTeamWithCreator(ctx, teamID) } // Delete removes an API key by ID, scoped to the given team. -func (s *APIKeyService) Delete(ctx context.Context, keyID, teamID string) error { +func (s *APIKeyService) Delete(ctx context.Context, keyID, teamID pgtype.UUID) error { return s.DB.DeleteAPIKey(ctx, db.DeleteAPIKeyParams{ID: keyID, TeamID: teamID}) } diff --git a/internal/service/audit.go b/internal/service/audit.go index 53061421..67faafa2 100644 --- a/internal/service/audit.go +++ b/internal/service/audit.go @@ -9,6 +9,7 @@ import ( "github.com/jackc/pgx/v5/pgtype" "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/id" ) const auditMaxLimit = 200 @@ -31,13 +32,13 @@ type AuditEntry struct { // AuditListParams controls the ListAuditLogs query. type AuditListParams struct { - TeamID string - AdminScoped bool // true → include admin-scoped events; false → team-scoped only - ResourceTypes []string // empty = no filter; multiple values = OR match - Actions []string // empty = no filter; multiple values = OR match - Before time.Time // zero = no cursor (start from latest) - BeforeID string // tie-breaker: id of the last item at the Before timestamp; empty = no tie-break - Limit int // clamped to auditMaxLimit by the handler + TeamID pgtype.UUID + AdminScoped bool // true → include admin-scoped events; false → team-scoped only + ResourceTypes []string // empty = no filter; multiple values = OR match + Actions []string // empty = no filter; multiple values = OR match + Before time.Time // zero = no cursor (start from latest) + BeforeID pgtype.UUID // tie-breaker: id of the last item at the Before timestamp; zero = no tie-break + Limit int // clamped to auditMaxLimit by the handler } // AuditService provides the read side of the audit log. @@ -94,11 +95,11 @@ func (s *AuditService) List(ctx context.Context, p AuditListParams) ([]AuditEntr _ = json.Unmarshal(row.Metadata, &meta) } entries[i] = AuditEntry{ - ID: row.ID, - TeamID: row.TeamID, + ID: id.FormatAuditLogID(row.ID), + TeamID: id.FormatTeamID(row.TeamID), ActorType: row.ActorType, ActorID: row.ActorID.String, - ActorName: row.ActorName.String, + ActorName: row.ActorName, ResourceType: row.ResourceType, ResourceID: row.ResourceID.String, Action: row.Action, diff --git a/internal/service/build.go b/internal/service/build.go index 3c7975d0..1bd82a8f 100644 --- a/internal/service/build.go +++ b/internal/service/build.go @@ -19,11 +19,10 @@ import ( ) const ( - buildQueueKey = "wrenn:build_queue" - buildCommandTimeout = 30 * time.Second - healthcheckInterval = 1 * time.Second - healthcheckTimeout = 60 * time.Second - platformTeamID = "platform" + buildQueueKey = "wrenn:build_queue" + buildCommandTimeout = 30 * time.Second + healthcheckInterval = 1 * time.Second + healthcheckTimeout = 60 * time.Second ) // buildAgentClient is the subset of the host agent client used by the build worker. @@ -82,13 +81,14 @@ func (s *BuildService) Create(ctx context.Context, p BuildCreateParams) (db.Temp } buildID := id.NewBuildID() + buildIDStr := id.FormatBuildID(buildID) build, err := s.DB.InsertTemplateBuild(ctx, db.InsertTemplateBuildParams{ ID: buildID, Name: p.Name, BaseTemplate: p.BaseTemplate, Recipe: recipeJSON, - Healthcheck: pgtype.Text{String: p.Healthcheck, Valid: p.Healthcheck != ""}, + Healthcheck: p.Healthcheck, Vcpus: p.VCPUs, MemoryMb: p.MemoryMB, TotalSteps: int32(len(p.Recipe)), @@ -97,8 +97,8 @@ func (s *BuildService) Create(ctx context.Context, p BuildCreateParams) (db.Temp return db.TemplateBuild{}, fmt.Errorf("insert build: %w", err) } - // Enqueue build ID to Redis for workers to pick up. - if err := s.Redis.RPush(ctx, buildQueueKey, buildID).Err(); err != nil { + // Enqueue build ID (as formatted string) to Redis for workers to pick up. + if err := s.Redis.RPush(ctx, buildQueueKey, buildIDStr).Err(); err != nil { return db.TemplateBuild{}, fmt.Errorf("enqueue build: %w", err) } @@ -106,7 +106,7 @@ func (s *BuildService) Create(ctx context.Context, p BuildCreateParams) (db.Temp } // Get returns a single build by ID. -func (s *BuildService) Get(ctx context.Context, buildID string) (db.TemplateBuild, error) { +func (s *BuildService) Get(ctx context.Context, buildID pgtype.UUID) (db.TemplateBuild, error) { return s.DB.GetTemplateBuild(ctx, buildID) } @@ -140,15 +140,21 @@ func (s *BuildService) worker(ctx context.Context, workerID int) { time.Sleep(time.Second) continue } - // result[0] is the key, result[1] is the build ID. - buildID := result[1] - log.Info("picked up build", "build_id", buildID) - s.executeBuild(ctx, buildID) + // result[0] is the key, result[1] is the build ID (formatted string). + buildIDStr := result[1] + log.Info("picked up build", "build_id", buildIDStr) + s.executeBuild(ctx, buildIDStr) } } -func (s *BuildService) executeBuild(ctx context.Context, buildID string) { - log := slog.With("build_id", buildID) +func (s *BuildService) executeBuild(ctx context.Context, buildIDStr string) { + log := slog.With("build_id", buildIDStr) + + buildID, err := id.ParseBuildID(buildIDStr) + if err != nil { + log.Error("invalid build ID from queue", "error", err) + return + } build, err := s.DB.GetTemplateBuild(ctx, buildID) if err != nil { @@ -172,7 +178,7 @@ func (s *BuildService) executeBuild(ctx context.Context, buildID string) { } // Pick a platform host and create a sandbox. - host, err := s.Scheduler.SelectHost(ctx, platformTeamID, false) + host, err := s.Scheduler.SelectHost(ctx, id.PlatformTeamID, false) if err != nil { s.failBuild(ctx, buildID, fmt.Sprintf("no host available: %v", err)) return @@ -185,10 +191,11 @@ func (s *BuildService) executeBuild(ctx context.Context, buildID string) { } sandboxID := id.NewSandboxID() - log = log.With("sandbox_id", sandboxID, "host_id", host.ID) + sandboxIDStr := id.FormatSandboxID(sandboxID) + log = log.With("sandbox_id", sandboxIDStr, "host_id", id.FormatHostID(host.ID)) resp, err := agent.CreateSandbox(ctx, connect.NewRequest(&pb.CreateSandboxRequest{ - SandboxId: sandboxID, + SandboxId: sandboxIDStr, Template: build.BaseTemplate, Vcpus: build.Vcpus, MemoryMb: build.MemoryMb, @@ -203,8 +210,8 @@ func (s *BuildService) executeBuild(ctx context.Context, buildID string) { // Record sandbox/host association. _ = s.DB.UpdateBuildSandbox(ctx, db.UpdateBuildSandboxParams{ ID: buildID, - SandboxID: pgtype.Text{String: sandboxID, Valid: true}, - HostID: pgtype.Text{String: host.ID, Valid: true}, + SandboxID: sandboxID, + HostID: host.ID, }) // Execute recipe commands. @@ -216,7 +223,7 @@ func (s *BuildService) executeBuild(ctx context.Context, buildID string) { start := time.Now() execResp, err := agent.Exec(execCtx, connect.NewRequest(&pb.ExecRequest{ - SandboxId: sandboxID, + SandboxId: sandboxIDStr, Cmd: "/bin/sh", Args: []string{"-c", cmd}, TimeoutSec: int32(buildCommandTimeout.Seconds()), @@ -234,7 +241,7 @@ func (s *BuildService) executeBuild(ctx context.Context, buildID string) { entry.Ok = false logs = append(logs, entry) s.updateLogs(ctx, buildID, i+1, logs) - s.destroySandbox(ctx, agent, sandboxID) + s.destroySandbox(ctx, agent, sandboxIDStr) s.failBuild(ctx, buildID, fmt.Sprintf("step %d exec error: %v", i+1, err)) return } @@ -248,7 +255,7 @@ func (s *BuildService) executeBuild(ctx context.Context, buildID string) { s.updateLogs(ctx, buildID, i+1, logs) if execResp.Msg.ExitCode != 0 { - s.destroySandbox(ctx, agent, sandboxID) + s.destroySandbox(ctx, agent, sandboxIDStr) s.failBuild(ctx, buildID, fmt.Sprintf("step %d failed with exit code %d", i+1, execResp.Msg.ExitCode)) return } @@ -256,10 +263,10 @@ func (s *BuildService) executeBuild(ctx context.Context, buildID string) { // Healthcheck or direct snapshot. var sizeBytes int64 - if build.Healthcheck.Valid && build.Healthcheck.String != "" { - log.Info("running healthcheck", "cmd", build.Healthcheck.String) - if err := s.waitForHealthcheck(ctx, agent, sandboxID, build.Healthcheck.String); err != nil { - s.destroySandbox(ctx, agent, sandboxID) + if build.Healthcheck != "" { + log.Info("running healthcheck", "cmd", build.Healthcheck) + if err := s.waitForHealthcheck(ctx, agent, sandboxIDStr, build.Healthcheck); err != nil { + s.destroySandbox(ctx, agent, sandboxIDStr) s.failBuild(ctx, buildID, fmt.Sprintf("healthcheck failed: %v", err)) return } @@ -267,11 +274,11 @@ func (s *BuildService) executeBuild(ctx context.Context, buildID string) { // Healthcheck passed → full snapshot (with memory/CPU state). log.Info("healthcheck passed, creating snapshot") snapResp, err := agent.CreateSnapshot(ctx, connect.NewRequest(&pb.CreateSnapshotRequest{ - SandboxId: sandboxID, + SandboxId: sandboxIDStr, Name: build.Name, })) if err != nil { - s.destroySandbox(ctx, agent, sandboxID) + s.destroySandbox(ctx, agent, sandboxIDStr) s.failBuild(ctx, buildID, fmt.Sprintf("create snapshot failed: %v", err)) return } @@ -280,11 +287,11 @@ func (s *BuildService) executeBuild(ctx context.Context, buildID string) { // No healthcheck → image-only template (rootfs only). log.Info("no healthcheck, flattening rootfs") flatResp, err := agent.FlattenRootfs(ctx, connect.NewRequest(&pb.FlattenRootfsRequest{ - SandboxId: sandboxID, + SandboxId: sandboxIDStr, Name: build.Name, })) if err != nil { - s.destroySandbox(ctx, agent, sandboxID) + s.destroySandbox(ctx, agent, sandboxIDStr) s.failBuild(ctx, buildID, fmt.Sprintf("flatten rootfs failed: %v", err)) return } @@ -293,17 +300,17 @@ func (s *BuildService) executeBuild(ctx context.Context, buildID string) { // Insert into templates table as a global (platform) template. templateType := "base" - if build.Healthcheck.Valid && build.Healthcheck.String != "" { + if build.Healthcheck != "" { templateType = "snapshot" } if _, err := s.DB.InsertTemplate(ctx, db.InsertTemplateParams{ Name: build.Name, Type: templateType, - Vcpus: pgtype.Int4{Int32: build.Vcpus, Valid: true}, - MemoryMb: pgtype.Int4{Int32: build.MemoryMb, Valid: true}, + Vcpus: build.Vcpus, + MemoryMb: build.MemoryMb, SizeBytes: sizeBytes, - TeamID: platformTeamID, + TeamID: id.PlatformTeamID, }); err != nil { log.Error("failed to insert template record", "error", err) // Build succeeded on disk, just DB record failed — don't mark as failed. @@ -323,7 +330,7 @@ func (s *BuildService) executeBuild(ctx context.Context, buildID string) { log.Info("template build completed successfully", "name", build.Name) } -func (s *BuildService) waitForHealthcheck(ctx context.Context, agent buildAgentClient, sandboxID, cmd string) error { +func (s *BuildService) waitForHealthcheck(ctx context.Context, agent buildAgentClient, sandboxIDStr, cmd string) error { deadline := time.NewTimer(healthcheckTimeout) defer deadline.Stop() ticker := time.NewTicker(healthcheckInterval) @@ -338,7 +345,7 @@ func (s *BuildService) waitForHealthcheck(ctx context.Context, agent buildAgentC case <-ticker.C: execCtx, cancel := context.WithTimeout(ctx, 10*time.Second) resp, err := agent.Exec(execCtx, connect.NewRequest(&pb.ExecRequest{ - SandboxId: sandboxID, + SandboxId: sandboxIDStr, Cmd: "/bin/sh", Args: []string{"-c", cmd}, TimeoutSec: 10, @@ -357,7 +364,7 @@ func (s *BuildService) waitForHealthcheck(ctx context.Context, agent buildAgentC } } -func (s *BuildService) updateLogs(ctx context.Context, buildID string, step int, logs []BuildLogEntry) { +func (s *BuildService) updateLogs(ctx context.Context, buildID pgtype.UUID, step int, logs []BuildLogEntry) { logsJSON, err := json.Marshal(logs) if err != nil { slog.Warn("failed to marshal build logs", "error", err) @@ -372,26 +379,26 @@ func (s *BuildService) updateLogs(ctx context.Context, buildID string, step int, } } -func (s *BuildService) failBuild(_ context.Context, buildID, errMsg string) { - slog.Error("build failed", "build_id", buildID, "error", errMsg) +func (s *BuildService) failBuild(_ context.Context, buildID pgtype.UUID, errMsg string) { + slog.Error("build failed", "build_id", id.FormatBuildID(buildID), "error", errMsg) // Use a detached context so DB writes survive parent context cancellation (e.g. shutdown). ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() if err := s.DB.UpdateBuildError(ctx, db.UpdateBuildErrorParams{ ID: buildID, - Error: pgtype.Text{String: errMsg, Valid: true}, + Error: errMsg, }); err != nil { - slog.Error("failed to update build error", "build_id", buildID, "error", err) + slog.Error("failed to update build error", "build_id", id.FormatBuildID(buildID), "error", err) } } -func (s *BuildService) destroySandbox(_ context.Context, agent buildAgentClient, sandboxID string) { +func (s *BuildService) destroySandbox(_ context.Context, agent buildAgentClient, sandboxIDStr string) { // Use a detached context so cleanup succeeds even during shutdown. ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() if _, err := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{ - SandboxId: sandboxID, + SandboxId: sandboxIDStr, })); err != nil { - slog.Warn("failed to destroy build sandbox", "sandbox_id", sandboxID, "error", err) + slog.Warn("failed to destroy build sandbox", "sandbox_id", sandboxIDStr, "error", err) } } diff --git a/internal/service/host.go b/internal/service/host.go index b3538dfb..195b9ff8 100644 --- a/internal/service/host.go +++ b/internal/service/host.go @@ -32,10 +32,10 @@ type HostService struct { // HostCreateParams holds the parameters for creating a host. type HostCreateParams struct { Type string - TeamID string // required for BYOC, empty for regular + TeamID pgtype.UUID // required for BYOC, zero value for regular Provider string AvailabilityZone string - RequestingUserID string + RequestingUserID pgtype.UUID IsRequestorAdmin bool } @@ -103,7 +103,7 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat } } else { // BYOC: platform admin, or team owner/admin. - if p.TeamID == "" { + if !p.TeamID.Valid { return HostCreateResult{}, fmt.Errorf("invalid request: team_id is required for BYOC hosts") } if !p.IsRequestorAdmin { @@ -124,7 +124,7 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat } // Validate team exists, is not deleted, and has BYOC enabled. - if p.TeamID != "" { + if p.TeamID.Valid { team, err := s.DB.GetTeam(ctx, p.TeamID) if err != nil || team.DeletedAt.Valid { return HostCreateResult{}, fmt.Errorf("invalid request: team not found") @@ -136,25 +136,12 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat hostID := id.NewHostID() - var teamID pgtype.Text - if p.TeamID != "" { - teamID = pgtype.Text{String: p.TeamID, Valid: true} - } - var provider pgtype.Text - if p.Provider != "" { - provider = pgtype.Text{String: p.Provider, Valid: true} - } - var az pgtype.Text - if p.AvailabilityZone != "" { - az = pgtype.Text{String: p.AvailabilityZone, Valid: true} - } - host, err := s.DB.InsertHost(ctx, db.InsertHostParams{ ID: hostID, Type: p.Type, - TeamID: teamID, - Provider: provider, - AvailabilityZone: az, + TeamID: p.TeamID, + Provider: p.Provider, + AvailabilityZone: p.AvailabilityZone, CreatedBy: p.RequestingUserID, }) if err != nil { @@ -166,8 +153,8 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat tokenID := id.NewHostTokenID() payload, _ := json.Marshal(regTokenPayload{ - HostID: hostID, - TokenID: tokenID, + HostID: id.FormatHostID(hostID), + TokenID: id.FormatHostTokenID(tokenID), }) if err := s.Redis.Set(ctx, "host:reg:"+token, payload, regTokenTTL).Err(); err != nil { return HostCreateResult{}, fmt.Errorf("store registration token: %w", err) @@ -180,7 +167,7 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat CreatedBy: p.RequestingUserID, ExpiresAt: pgtype.Timestamptz{Time: now.Add(regTokenTTL), Valid: true}, }); err != nil { - slog.Warn("failed to insert host token audit record", "host_id", hostID, "error", err) + slog.Warn("failed to insert host token audit record", "host_id", id.FormatHostID(hostID), "error", err) } return HostCreateResult{Host: host, RegistrationToken: token}, nil @@ -189,7 +176,7 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat // RegenerateToken issues a new registration token for a host still in "pending" // status. This allows retry when a previous registration attempt failed after // the original token was consumed. -func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamID string, isAdmin bool) (HostCreateResult, error) { +func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamID pgtype.UUID, isAdmin bool) (HostCreateResult, error) { host, err := s.DB.GetHost(ctx, hostID) if err != nil { return HostCreateResult{}, fmt.Errorf("host not found: %w", err) @@ -202,7 +189,7 @@ func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamI if host.Type != "byoc" { return HostCreateResult{}, fmt.Errorf("forbidden: only admins can manage regular hosts") } - if !host.TeamID.Valid || host.TeamID.String != teamID { + if !host.TeamID.Valid || host.TeamID != teamID { return HostCreateResult{}, fmt.Errorf("forbidden: host does not belong to your team") } membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{ @@ -224,8 +211,8 @@ func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamI tokenID := id.NewHostTokenID() payload, _ := json.Marshal(regTokenPayload{ - HostID: hostID, - TokenID: tokenID, + HostID: id.FormatHostID(hostID), + TokenID: id.FormatHostTokenID(tokenID), }) if err := s.Redis.Set(ctx, "host:reg:"+token, payload, regTokenTTL).Err(); err != nil { return HostCreateResult{}, fmt.Errorf("store registration token: %w", err) @@ -238,7 +225,7 @@ func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamI CreatedBy: userID, ExpiresAt: pgtype.Timestamptz{Time: now.Add(regTokenTTL), Valid: true}, }); err != nil { - slog.Warn("failed to insert host token audit record", "host_id", hostID, "error", err) + slog.Warn("failed to insert host token audit record", "host_id", id.FormatHostID(hostID), "error", err) } return HostCreateResult{Host: host, RegistrationToken: token}, nil @@ -262,24 +249,33 @@ func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostR return HostRegisterResult{}, fmt.Errorf("corrupted registration token") } - if _, err := s.DB.GetHost(ctx, payload.HostID); err != nil { + hostID, err := id.ParseHostID(payload.HostID) + if err != nil { + return HostRegisterResult{}, fmt.Errorf("corrupted registration token: %w", err) + } + tokenID, err := id.ParseHostTokenID(payload.TokenID) + if err != nil { + return HostRegisterResult{}, fmt.Errorf("corrupted registration token: %w", err) + } + + if _, err := s.DB.GetHost(ctx, hostID); err != nil { return HostRegisterResult{}, fmt.Errorf("host not found: %w", err) } // Sign JWT before mutating DB — if signing fails, the host stays pending. - hostJWT, err := auth.SignHostJWT(s.JWT, payload.HostID) + hostJWT, err := auth.SignHostJWT(s.JWT, hostID) if err != nil { return HostRegisterResult{}, fmt.Errorf("sign host token: %w", err) } // Atomically update only if still pending (defense-in-depth against races). rowsAffected, err := s.DB.RegisterHost(ctx, db.RegisterHostParams{ - ID: payload.HostID, - Arch: pgtype.Text{String: p.Arch, Valid: p.Arch != ""}, - CpuCores: pgtype.Int4{Int32: p.CPUCores, Valid: p.CPUCores > 0}, - MemoryMb: pgtype.Int4{Int32: p.MemoryMB, Valid: p.MemoryMB > 0}, - DiskGb: pgtype.Int4{Int32: p.DiskGB, Valid: p.DiskGB > 0}, - Address: pgtype.Text{String: p.Address, Valid: p.Address != ""}, + ID: hostID, + Arch: p.Arch, + CpuCores: p.CPUCores, + MemoryMb: p.MemoryMB, + DiskGb: p.DiskGB, + Address: p.Address, }) if err != nil { return HostRegisterResult{}, fmt.Errorf("register host: %w", err) @@ -289,18 +285,18 @@ func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostR } // Mark audit trail. - if err := s.DB.MarkHostTokenUsed(ctx, payload.TokenID); err != nil { + if err := s.DB.MarkHostTokenUsed(ctx, tokenID); err != nil { slog.Warn("failed to mark host token used", "token_id", payload.TokenID, "error", err) } // Issue a long-lived refresh token. - refreshToken, err := s.issueRefreshToken(ctx, payload.HostID) + refreshToken, err := s.issueRefreshToken(ctx, hostID) if err != nil { return HostRegisterResult{}, fmt.Errorf("issue refresh token: %w", err) } // Re-fetch the host to get the updated state. - host, err := s.DB.GetHost(ctx, payload.HostID) + host, err := s.DB.GetHost(ctx, hostID) if err != nil { return HostRegisterResult{}, fmt.Errorf("fetch updated host: %w", err) } @@ -349,7 +345,7 @@ func (s *HostService) Refresh(ctx context.Context, refreshToken string) (HostRef // issueRefreshToken creates a new refresh token record in the DB and returns // the opaque token string. -func (s *HostService) issueRefreshToken(ctx context.Context, hostID string) (string, error) { +func (s *HostService) issueRefreshToken(ctx context.Context, hostID pgtype.UUID) (string, error) { token := id.NewRefreshToken() hash := hashToken(token) now := time.Now() @@ -375,7 +371,7 @@ func hashToken(token string) string { // Heartbeat updates the last heartbeat timestamp for a host and transitions // any 'unreachable' host back to 'online'. Returns a "host not found" error // (which becomes 404) if the host record no longer exists (e.g., was deleted). -func (s *HostService) Heartbeat(ctx context.Context, hostID string) error { +func (s *HostService) Heartbeat(ctx context.Context, hostID pgtype.UUID) error { n, err := s.DB.UpdateHostHeartbeatAndStatus(ctx, hostID) if err != nil { return err @@ -388,21 +384,21 @@ func (s *HostService) Heartbeat(ctx context.Context, hostID string) error { // List returns hosts visible to the caller. // Admins see all hosts; non-admins see only BYOC hosts belonging to their team. -func (s *HostService) List(ctx context.Context, teamID string, isAdmin bool) ([]db.Host, error) { +func (s *HostService) List(ctx context.Context, teamID pgtype.UUID, isAdmin bool) ([]db.Host, error) { if isAdmin { return s.DB.ListHosts(ctx) } - return s.DB.ListHostsByTeam(ctx, pgtype.Text{String: teamID, Valid: true}) + return s.DB.ListHostsByTeam(ctx, teamID) } // Get returns a single host, enforcing access control. -func (s *HostService) Get(ctx context.Context, hostID, teamID string, isAdmin bool) (db.Host, error) { +func (s *HostService) Get(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool) (db.Host, error) { host, err := s.DB.GetHost(ctx, hostID) if err != nil { return db.Host{}, fmt.Errorf("host not found: %w", err) } if !isAdmin { - if !host.TeamID.Valid || host.TeamID.String != teamID { + if !host.TeamID.Valid || host.TeamID != teamID { return db.Host{}, fmt.Errorf("host not found") } } @@ -411,8 +407,8 @@ func (s *HostService) Get(ctx context.Context, hostID, teamID string, isAdmin bo // DeletePreview returns what would be affected by deleting the host, without // making any changes. Use this to show the user a confirmation prompt. -func (s *HostService) DeletePreview(ctx context.Context, hostID, teamID string, isAdmin bool) (HostDeletePreview, error) { - host, err := s.checkDeletePermission(ctx, hostID, "", teamID, isAdmin) +func (s *HostService) DeletePreview(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool) (HostDeletePreview, error) { + host, err := s.checkDeletePermission(ctx, hostID, pgtype.UUID{}, teamID, isAdmin) if err != nil { return HostDeletePreview{}, err } @@ -427,7 +423,7 @@ func (s *HostService) DeletePreview(ctx context.Context, hostID, teamID string, ids := make([]string, len(sandboxes)) for i, sb := range sandboxes { - ids[i] = sb.ID + ids[i] = id.FormatSandboxID(sb.ID) } return HostDeletePreview{Host: host, SandboxIDs: ids}, nil @@ -436,7 +432,7 @@ func (s *HostService) DeletePreview(ctx context.Context, hostID, teamID string, // Delete removes a host. Without force it returns an error listing active // sandboxes so the caller can present a confirmation. With force it gracefully // destroys all running sandboxes before deleting the host record. -func (s *HostService) Delete(ctx context.Context, hostID, userID, teamID string, isAdmin, force bool) error { +func (s *HostService) Delete(ctx context.Context, hostID, userID, teamID pgtype.UUID, isAdmin, force bool) error { host, err := s.checkDeletePermission(ctx, hostID, userID, teamID, isAdmin) if err != nil { return err @@ -453,35 +449,37 @@ func (s *HostService) Delete(ctx context.Context, hostID, userID, teamID string, if len(sandboxes) > 0 && !force { ids := make([]string, len(sandboxes)) for i, sb := range sandboxes { - ids[i] = sb.ID + ids[i] = id.FormatSandboxID(sb.ID) } return &HostHasSandboxesError{SandboxIDs: ids} } + hostIDStr := id.FormatHostID(hostID) + // Gracefully destroy running sandboxes and terminate the agent (best-effort). - if host.Address.Valid && host.Address.String != "" { + if host.Address != "" { agent, err := s.Pool.GetForHost(host) if err == nil { for _, sb := range sandboxes { if sb.Status == "running" || sb.Status == "starting" { _, rpcErr := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{ - SandboxId: sb.ID, + SandboxId: id.FormatSandboxID(sb.ID), })) if rpcErr != nil && connect.CodeOf(rpcErr) != connect.CodeNotFound { - slog.Warn("delete host: failed to destroy sandbox on agent", "sandbox_id", sb.ID, "error", rpcErr) + slog.Warn("delete host: failed to destroy sandbox on agent", "sandbox_id", id.FormatSandboxID(sb.ID), "error", rpcErr) } } } // Tell the agent to shut itself down immediately. if _, rpcErr := agent.Terminate(ctx, connect.NewRequest(&pb.TerminateRequest{})); rpcErr != nil { - slog.Warn("delete host: failed to send Terminate to agent", "host_id", hostID, "error", rpcErr) + slog.Warn("delete host: failed to send Terminate to agent", "host_id", hostIDStr, "error", rpcErr) } } } // Mark all affected sandboxes as stopped in DB. if len(sandboxes) > 0 { - sbIDs := make([]string, len(sandboxes)) + sbIDs := make([]pgtype.UUID, len(sandboxes)) for i, sb := range sandboxes { sbIDs[i] = sb.ID } @@ -489,18 +487,18 @@ func (s *HostService) Delete(ctx context.Context, hostID, userID, teamID string, Column1: sbIDs, Status: "stopped", }); err != nil { - slog.Warn("delete host: failed to mark sandboxes stopped", "host_id", hostID, "error", err) + slog.Warn("delete host: failed to mark sandboxes stopped", "host_id", hostIDStr, "error", err) } } // Revoke all refresh tokens for this host. if err := s.DB.RevokeHostRefreshTokensByHost(ctx, hostID); err != nil { - slog.Warn("delete host: failed to revoke refresh tokens", "host_id", hostID, "error", err) + slog.Warn("delete host: failed to revoke refresh tokens", "host_id", hostIDStr, "error", err) } // Evict the client from the pool so no further RPCs are sent. if s.Pool != nil { - s.Pool.Evict(hostID) + s.Pool.Evict(id.FormatHostID(hostID)) } return s.DB.DeleteHost(ctx, hostID) @@ -508,7 +506,7 @@ func (s *HostService) Delete(ctx context.Context, hostID, userID, teamID string, // checkDeletePermission verifies the caller has permission to delete the given // host and returns the host record on success. -func (s *HostService) checkDeletePermission(ctx context.Context, hostID, userID, teamID string, isAdmin bool) (db.Host, error) { +func (s *HostService) checkDeletePermission(ctx context.Context, hostID, userID, teamID pgtype.UUID, isAdmin bool) (db.Host, error) { host, err := s.DB.GetHost(ctx, hostID) if err != nil { return db.Host{}, fmt.Errorf("host not found: %w", err) @@ -521,11 +519,11 @@ func (s *HostService) checkDeletePermission(ctx context.Context, hostID, userID, if host.Type != "byoc" { return db.Host{}, fmt.Errorf("forbidden: only admins can delete regular hosts") } - if !host.TeamID.Valid || host.TeamID.String != teamID { + if !host.TeamID.Valid || host.TeamID != teamID { return db.Host{}, fmt.Errorf("forbidden: host does not belong to your team") } - if userID != "" { + if userID.Valid { membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{ UserID: userID, TeamID: teamID, @@ -545,7 +543,7 @@ func (s *HostService) checkDeletePermission(ctx context.Context, hostID, userID, } // AddTag adds a tag to a host. -func (s *HostService) AddTag(ctx context.Context, hostID, teamID string, isAdmin bool, tag string) error { +func (s *HostService) AddTag(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool, tag string) error { if _, err := s.Get(ctx, hostID, teamID, isAdmin); err != nil { return err } @@ -553,7 +551,7 @@ func (s *HostService) AddTag(ctx context.Context, hostID, teamID string, isAdmin } // RemoveTag removes a tag from a host. -func (s *HostService) RemoveTag(ctx context.Context, hostID, teamID string, isAdmin bool, tag string) error { +func (s *HostService) RemoveTag(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool, tag string) error { if _, err := s.Get(ctx, hostID, teamID, isAdmin); err != nil { return err } @@ -561,7 +559,7 @@ func (s *HostService) RemoveTag(ctx context.Context, hostID, teamID string, isAd } // ListTags returns all tags for a host. -func (s *HostService) ListTags(ctx context.Context, hostID, teamID string, isAdmin bool) ([]string, error) { +func (s *HostService) ListTags(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool) ([]string, error) { if _, err := s.Get(ctx, hostID, teamID, isAdmin); err != nil { return nil, err } diff --git a/internal/service/sandbox.go b/internal/service/sandbox.go index 142b9bd0..89e40c51 100644 --- a/internal/service/sandbox.go +++ b/internal/service/sandbox.go @@ -27,7 +27,7 @@ type SandboxService struct { // SandboxCreateParams holds the parameters for creating a sandbox. type SandboxCreateParams struct { - TeamID string + TeamID pgtype.UUID Template string VCPUs int32 MemoryMB int32 @@ -35,7 +35,7 @@ type SandboxCreateParams struct { } // agentForSandbox looks up the host for the given sandbox and returns a client. -func (s *SandboxService) agentForSandbox(ctx context.Context, sandboxID string) (hostagentClient, db.Sandbox, error) { +func (s *SandboxService) agentForSandbox(ctx context.Context, sandboxID pgtype.UUID) (hostagentClient, db.Sandbox, error) { sb, err := s.DB.GetSandbox(ctx, sandboxID) if err != nil { return nil, db.Sandbox{}, fmt.Errorf("sandbox not found: %w", err) @@ -80,15 +80,11 @@ func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db. // If the template is a snapshot, use its baked-in vcpus/memory. if tmpl, err := s.DB.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: p.Template, TeamID: p.TeamID}); err == nil && tmpl.Type == "snapshot" { - if tmpl.Vcpus.Valid { - p.VCPUs = tmpl.Vcpus.Int32 - } - if tmpl.MemoryMb.Valid { - p.MemoryMB = tmpl.MemoryMb.Int32 - } + p.VCPUs = tmpl.Vcpus + p.MemoryMB = tmpl.MemoryMb } - if p.TeamID == "" { + if !p.TeamID.Valid { return db.Sandbox{}, fmt.Errorf("invalid request: team_id is required") } @@ -110,6 +106,7 @@ func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db. } sandboxID := id.NewSandboxID() + sandboxIDStr := id.FormatSandboxID(sandboxID) if _, err := s.DB.InsertSandbox(ctx, db.InsertSandboxParams{ ID: sandboxID, @@ -125,7 +122,7 @@ func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db. } resp, err := agent.CreateSandbox(ctx, connect.NewRequest(&pb.CreateSandboxRequest{ - SandboxId: sandboxID, + SandboxId: sandboxIDStr, Template: p.Template, Vcpus: p.VCPUs, MemoryMb: p.MemoryMB, @@ -135,7 +132,7 @@ func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db. if _, dbErr := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{ ID: sandboxID, Status: "error", }); dbErr != nil { - slog.Warn("failed to update sandbox status to error", "id", sandboxID, "error", dbErr) + slog.Warn("failed to update sandbox status to error", "id", sandboxIDStr, "error", dbErr) } return db.Sandbox{}, fmt.Errorf("agent create: %w", err) } @@ -158,17 +155,17 @@ func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db. } // List returns active sandboxes (excludes stopped/error) belonging to the given team. -func (s *SandboxService) List(ctx context.Context, teamID string) ([]db.Sandbox, error) { +func (s *SandboxService) List(ctx context.Context, teamID pgtype.UUID) ([]db.Sandbox, error) { return s.DB.ListSandboxesByTeam(ctx, teamID) } // Get returns a single sandbox by ID, scoped to the given team. -func (s *SandboxService) Get(ctx context.Context, sandboxID, teamID string) (db.Sandbox, error) { +func (s *SandboxService) Get(ctx context.Context, sandboxID, teamID pgtype.UUID) (db.Sandbox, error) { return s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID}) } // Pause snapshots and freezes a running sandbox to disk. -func (s *SandboxService) Pause(ctx context.Context, sandboxID, teamID string) (db.Sandbox, error) { +func (s *SandboxService) Pause(ctx context.Context, sandboxID, teamID pgtype.UUID) (db.Sandbox, error) { sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID}) if err != nil { return db.Sandbox{}, fmt.Errorf("sandbox not found: %w", err) @@ -182,11 +179,13 @@ func (s *SandboxService) Pause(ctx context.Context, sandboxID, teamID string) (d return db.Sandbox{}, err } + sandboxIDStr := id.FormatSandboxID(sandboxID) + // Flush all metrics tiers before pausing so data survives in DB. s.flushAndPersistMetrics(ctx, agent, sandboxID, true) if _, err := agent.PauseSandbox(ctx, connect.NewRequest(&pb.PauseSandboxRequest{ - SandboxId: sandboxID, + SandboxId: sandboxIDStr, })); err != nil { return db.Sandbox{}, fmt.Errorf("agent pause: %w", err) } @@ -201,7 +200,7 @@ func (s *SandboxService) Pause(ctx context.Context, sandboxID, teamID string) (d } // Resume restores a paused sandbox from snapshot. -func (s *SandboxService) Resume(ctx context.Context, sandboxID, teamID string) (db.Sandbox, error) { +func (s *SandboxService) Resume(ctx context.Context, sandboxID, teamID pgtype.UUID) (db.Sandbox, error) { sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID}) if err != nil { return db.Sandbox{}, fmt.Errorf("sandbox not found: %w", err) @@ -215,8 +214,10 @@ func (s *SandboxService) Resume(ctx context.Context, sandboxID, teamID string) ( return db.Sandbox{}, err } + sandboxIDStr := id.FormatSandboxID(sandboxID) + resp, err := agent.ResumeSandbox(ctx, connect.NewRequest(&pb.ResumeSandboxRequest{ - SandboxId: sandboxID, + SandboxId: sandboxIDStr, TimeoutSec: sb.TimeoutSec, })) if err != nil { @@ -240,7 +241,7 @@ func (s *SandboxService) Resume(ctx context.Context, sandboxID, teamID string) ( } // Destroy stops a sandbox and marks it as stopped. -func (s *SandboxService) Destroy(ctx context.Context, sandboxID, teamID string) error { +func (s *SandboxService) Destroy(ctx context.Context, sandboxID, teamID pgtype.UUID) error { sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID}) if err != nil { return fmt.Errorf("sandbox not found: %w", err) @@ -251,6 +252,8 @@ func (s *SandboxService) Destroy(ctx context.Context, sandboxID, teamID string) return err } + sandboxIDStr := id.FormatSandboxID(sandboxID) + // If running, flush 24h tier metrics for analytics before destroying. if sb.Status == "running" { s.flushAndPersistMetrics(ctx, agent, sandboxID, false) @@ -258,7 +261,7 @@ func (s *SandboxService) Destroy(ctx context.Context, sandboxID, teamID string) // Destroy on host agent. A not-found response is fine — sandbox is already gone. if _, err := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{ - SandboxId: sandboxID, + SandboxId: sandboxIDStr, })); err != nil && connect.CodeOf(err) != connect.CodeNotFound { return fmt.Errorf("agent destroy: %w", err) } @@ -284,12 +287,13 @@ func (s *SandboxService) Destroy(ctx context.Context, sandboxID, teamID string) // flushAndPersistMetrics calls FlushSandboxMetrics on the agent and stores // the returned data to DB. If allTiers is true, all three tiers are saved; // otherwise only the 24h tier (for post-destroy analytics). -func (s *SandboxService) flushAndPersistMetrics(ctx context.Context, agent hostagentClient, sandboxID string, allTiers bool) { +func (s *SandboxService) flushAndPersistMetrics(ctx context.Context, agent hostagentClient, sandboxID pgtype.UUID, allTiers bool) { + sandboxIDStr := id.FormatSandboxID(sandboxID) resp, err := agent.FlushSandboxMetrics(ctx, connect.NewRequest(&pb.FlushSandboxMetricsRequest{ - SandboxId: sandboxID, + SandboxId: sandboxIDStr, })) if err != nil { - slog.Warn("flush metrics failed (best-effort)", "sandbox_id", sandboxID, "error", err) + slog.Warn("flush metrics failed (best-effort)", "sandbox_id", sandboxIDStr, "error", err) return } msg := resp.Msg @@ -301,7 +305,8 @@ func (s *SandboxService) flushAndPersistMetrics(ctx context.Context, agent hosta s.persistMetricPoints(ctx, sandboxID, "24h", msg.Points_24H) } -func (s *SandboxService) persistMetricPoints(ctx context.Context, sandboxID, tier string, points []*pb.MetricPoint) { +func (s *SandboxService) persistMetricPoints(ctx context.Context, sandboxID pgtype.UUID, tier string, points []*pb.MetricPoint) { + sandboxIDStr := id.FormatSandboxID(sandboxID) for _, p := range points { if err := s.DB.InsertSandboxMetricPoint(ctx, db.InsertSandboxMetricPointParams{ SandboxID: sandboxID, @@ -311,13 +316,13 @@ func (s *SandboxService) persistMetricPoints(ctx context.Context, sandboxID, tie MemBytes: p.MemBytes, DiskBytes: p.DiskBytes, }); err != nil { - slog.Warn("persist metric point failed", "sandbox_id", sandboxID, "tier", tier, "error", err) + slog.Warn("persist metric point failed", "sandbox_id", sandboxIDStr, "tier", tier, "error", err) } } } // Ping resets the inactivity timer for a running sandbox. -func (s *SandboxService) Ping(ctx context.Context, sandboxID, teamID string) error { +func (s *SandboxService) Ping(ctx context.Context, sandboxID, teamID pgtype.UUID) error { sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID}) if err != nil { return fmt.Errorf("sandbox not found: %w", err) @@ -331,8 +336,10 @@ func (s *SandboxService) Ping(ctx context.Context, sandboxID, teamID string) err return err } + sandboxIDStr := id.FormatSandboxID(sandboxID) + if _, err := agent.PingSandbox(ctx, connect.NewRequest(&pb.PingSandboxRequest{ - SandboxId: sandboxID, + SandboxId: sandboxIDStr, })); err != nil { return fmt.Errorf("agent ping: %w", err) } @@ -344,7 +351,7 @@ func (s *SandboxService) Ping(ctx context.Context, sandboxID, teamID string) err Valid: true, }, }); err != nil { - slog.Warn("ping: failed to update last_active_at", "sandbox_id", sandboxID, "error", err) + slog.Warn("ping: failed to update last_active_at", "sandbox_id", sandboxIDStr, "error", err) } return nil } diff --git a/internal/service/stats.go b/internal/service/stats.go index 1a075aa0..88cace72 100644 --- a/internal/service/stats.go +++ b/internal/service/stats.go @@ -7,6 +7,7 @@ import ( "time" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgxpool" "git.omukk.dev/wrenn/sandbox/internal/db" @@ -72,7 +73,7 @@ type StatsService struct { // GetStats returns current stats, 30-day peaks, and a time-series for the // given team and time range. If no snapshots exist yet, zeros are returned. -func (s *StatsService) GetStats(ctx context.Context, teamID string, r TimeRange) (CurrentStats, PeakStats, []StatPoint, error) { +func (s *StatsService) GetStats(ctx context.Context, teamID pgtype.UUID, r TimeRange) (CurrentStats, PeakStats, []StatPoint, error) { cfg, ok := rangeConfigs[r] if !ok { return CurrentStats{}, PeakStats{}, nil, fmt.Errorf("unknown range: %s", r) @@ -132,7 +133,7 @@ GROUP BY bucket ORDER BY bucket ASC ` -func (s *StatsService) queryTimeSeries(ctx context.Context, teamID string, cfg rangeConfig) ([]StatPoint, error) { +func (s *StatsService) queryTimeSeries(ctx context.Context, teamID pgtype.UUID, cfg rangeConfig) ([]StatPoint, error) { rows, err := s.Pool.Query(ctx, timeSeriesSQL, cfg.bucketSec, teamID, cfg.intervalLiteral) if err != nil { return nil, err diff --git a/internal/service/team.go b/internal/service/team.go index d4c911c2..667cd044 100644 --- a/internal/service/team.go +++ b/internal/service/team.go @@ -9,6 +9,7 @@ import ( "connectrpc.com/connect" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgxpool" "git.omukk.dev/wrenn/sandbox/internal/db" @@ -43,7 +44,7 @@ type MemberInfo struct { // callerRole fetches the calling user's role in the given team from DB. // Returns an error wrapping "forbidden" if the caller is not a member. -func (s *TeamService) callerRole(ctx context.Context, teamID, callerUserID string) (string, error) { +func (s *TeamService) callerRole(ctx context.Context, teamID, callerUserID pgtype.UUID) (string, error) { m, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{ UserID: callerUserID, TeamID: teamID, @@ -66,7 +67,7 @@ func requireAdmin(role string) error { } // GetTeam returns the team by ID. Returns an error if the team is deleted or not found. -func (s *TeamService) GetTeam(ctx context.Context, teamID string) (db.Team, error) { +func (s *TeamService) GetTeam(ctx context.Context, teamID pgtype.UUID) (db.Team, error) { team, err := s.DB.GetTeam(ctx, teamID) if err != nil { if err == pgx.ErrNoRows { @@ -81,7 +82,7 @@ func (s *TeamService) GetTeam(ctx context.Context, teamID string) (db.Team, erro } // ListTeamsForUser returns all active teams the user belongs to, with their role in each. -func (s *TeamService) ListTeamsForUser(ctx context.Context, userID string) ([]TeamWithRole, error) { +func (s *TeamService) ListTeamsForUser(ctx context.Context, userID pgtype.UUID) ([]TeamWithRole, error) { rows, err := s.DB.GetTeamsForUser(ctx, userID) if err != nil { return nil, fmt.Errorf("list teams: %w", err) @@ -97,7 +98,7 @@ func (s *TeamService) ListTeamsForUser(ctx context.Context, userID string) ([]Te } // CreateTeam creates a new team owned by the given user. -func (s *TeamService) CreateTeam(ctx context.Context, ownerUserID, name string) (TeamWithRole, error) { +func (s *TeamService) CreateTeam(ctx context.Context, ownerUserID pgtype.UUID, name string) (TeamWithRole, error) { if !teamNameRE.MatchString(name) { return TeamWithRole{}, fmt.Errorf("invalid team name: must be 1-128 characters, A-Z a-z 0-9 space _") } @@ -137,7 +138,7 @@ func (s *TeamService) CreateTeam(ctx context.Context, ownerUserID, name string) } // RenameTeam updates the team name. Caller must be admin or owner (verified from DB). -func (s *TeamService) RenameTeam(ctx context.Context, teamID, callerUserID, newName string) error { +func (s *TeamService) RenameTeam(ctx context.Context, teamID, callerUserID pgtype.UUID, newName string) error { if !teamNameRE.MatchString(newName) { return fmt.Errorf("invalid team name: must be 1-128 characters, A-Z a-z 0-9 space _") } @@ -159,7 +160,7 @@ func (s *TeamService) RenameTeam(ctx context.Context, teamID, callerUserID, newN // DeleteTeam soft-deletes the team and destroys all running/paused/starting sandboxes. // Caller must be owner (verified from DB). All DB records (sandboxes, keys, templates) // are preserved; only the team's deleted_at is set and active VMs are stopped. -func (s *TeamService) DeleteTeam(ctx context.Context, teamID, callerUserID string) error { +func (s *TeamService) DeleteTeam(ctx context.Context, teamID, callerUserID pgtype.UUID) error { role, err := s.callerRole(ctx, teamID, callerUserID) if err != nil { return err @@ -174,16 +175,16 @@ func (s *TeamService) DeleteTeam(ctx context.Context, teamID, callerUserID strin return fmt.Errorf("list active sandboxes: %w", err) } - var stopIDs []string + var stopIDs []pgtype.UUID for _, sb := range sandboxes { host, hostErr := s.DB.GetHost(ctx, sb.HostID) if hostErr == nil { agent, agentErr := s.HostPool.GetForHost(host) if agentErr == nil { if _, err := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{ - SandboxId: sb.ID, + SandboxId: id.FormatSandboxID(sb.ID), })); err != nil && connect.CodeOf(err) != connect.CodeNotFound { - slog.Warn("team delete: failed to destroy sandbox", "sandbox_id", sb.ID, "error", err) + slog.Warn("team delete: failed to destroy sandbox", "sandbox_id", id.FormatSandboxID(sb.ID), "error", err) } } } @@ -208,7 +209,7 @@ func (s *TeamService) DeleteTeam(ctx context.Context, teamID, callerUserID strin } // GetMembers returns all members of the team with their emails and roles. -func (s *TeamService) GetMembers(ctx context.Context, teamID string) ([]MemberInfo, error) { +func (s *TeamService) GetMembers(ctx context.Context, teamID pgtype.UUID) ([]MemberInfo, error) { rows, err := s.DB.GetTeamMembers(ctx, teamID) if err != nil { return nil, fmt.Errorf("get members: %w", err) @@ -220,7 +221,7 @@ func (s *TeamService) GetMembers(ctx context.Context, teamID string) ([]MemberIn joinedAt = r.JoinedAt.Time } members[i] = MemberInfo{ - UserID: r.ID, + UserID: id.FormatUserID(r.ID), Name: r.Name, Email: r.Email, Role: r.Role, @@ -232,7 +233,7 @@ func (s *TeamService) GetMembers(ctx context.Context, teamID string) ([]MemberIn // AddMember adds an existing user (looked up by email) to the team as a member. // Caller must be admin or owner (verified from DB). -func (s *TeamService) AddMember(ctx context.Context, teamID, callerUserID, email string) (MemberInfo, error) { +func (s *TeamService) AddMember(ctx context.Context, teamID, callerUserID pgtype.UUID, email string) (MemberInfo, error) { role, err := s.callerRole(ctx, teamID, callerUserID) if err != nil { return MemberInfo{}, err @@ -269,12 +270,12 @@ func (s *TeamService) AddMember(ctx context.Context, teamID, callerUserID, email return MemberInfo{}, fmt.Errorf("insert member: %w", err) } - return MemberInfo{UserID: target.ID, Name: target.Name, Email: target.Email, Role: "member"}, nil + return MemberInfo{UserID: id.FormatUserID(target.ID), Name: target.Name, Email: target.Email, Role: "member"}, nil } // RemoveMember removes a user from the team. // Caller must be admin or owner (verified from DB). Owner cannot be removed. -func (s *TeamService) RemoveMember(ctx context.Context, teamID, callerUserID, targetUserID string) error { +func (s *TeamService) RemoveMember(ctx context.Context, teamID, callerUserID, targetUserID pgtype.UUID) error { callerRole, err := s.callerRole(ctx, teamID, callerUserID) if err != nil { return err @@ -310,7 +311,7 @@ func (s *TeamService) RemoveMember(ctx context.Context, teamID, callerUserID, ta // UpdateMemberRole changes a member's role to admin or member. // Caller must be admin or owner (verified from DB). Owner's role cannot be changed. // Valid target roles: "admin", "member". -func (s *TeamService) UpdateMemberRole(ctx context.Context, teamID, callerUserID, targetUserID, newRole string) error { +func (s *TeamService) UpdateMemberRole(ctx context.Context, teamID, callerUserID, targetUserID pgtype.UUID, newRole string) error { if newRole != "admin" && newRole != "member" { return fmt.Errorf("invalid: role must be admin or member") } @@ -350,7 +351,7 @@ func (s *TeamService) UpdateMemberRole(ctx context.Context, teamID, callerUserID // LeaveTeam removes the calling user from the team. // The owner cannot leave; they must delete the team instead. -func (s *TeamService) LeaveTeam(ctx context.Context, teamID, callerUserID string) error { +func (s *TeamService) LeaveTeam(ctx context.Context, teamID, callerUserID pgtype.UUID) error { role, err := s.callerRole(ctx, teamID, callerUserID) if err != nil { return err @@ -371,7 +372,7 @@ func (s *TeamService) LeaveTeam(ctx context.Context, teamID, callerUserID string // SetBYOC enables the BYOC feature flag for a team. Once enabled, BYOC cannot // be disabled — it is a one-way transition. // Admin-only — the caller must verify admin status before invoking this. -func (s *TeamService) SetBYOC(ctx context.Context, teamID string, enabled bool) error { +func (s *TeamService) SetBYOC(ctx context.Context, teamID pgtype.UUID, enabled bool) error { team, err := s.DB.GetTeam(ctx, teamID) if err != nil { return fmt.Errorf("team not found: %w", err) diff --git a/internal/service/template.go b/internal/service/template.go index d669e455..22bc4d60 100644 --- a/internal/service/template.go +++ b/internal/service/template.go @@ -3,6 +3,8 @@ package service import ( "context" + "github.com/jackc/pgx/v5/pgtype" + "git.omukk.dev/wrenn/sandbox/internal/db" ) @@ -14,7 +16,7 @@ type TemplateService struct { // List returns all templates belonging to the given team. If typeFilter is // non-empty, only templates of that type ("base" or "snapshot") are returned. -func (s *TemplateService) List(ctx context.Context, teamID, typeFilter string) ([]db.Template, error) { +func (s *TemplateService) List(ctx context.Context, teamID pgtype.UUID, typeFilter string) ([]db.Template, error) { if typeFilter != "" { return s.DB.ListTemplatesByTeamAndType(ctx, db.ListTemplatesByTeamAndTypeParams{ TeamID: teamID, From c0d6381bbe6e53740923cde3d7c52add0f151d48 Mon Sep 17 00:00:00 2001 From: pptx704 Date: Thu, 26 Mar 2026 23:45:41 +0600 Subject: [PATCH 11/28] Add disk_size_mb, auto-expand base images, admin templates endpoint MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Disk sizing: - Add disk_size_mb column to sandboxes table (default 20480 = 20GB) - Add disk_size_mb to CreateSandboxRequest proto, passed through the full chain: service → RPC → host agent → sandbox manager → devicemapper - devicemapper.CreateSnapshot takes separate cowSizeBytes param so the sparse CoW file can be sized independently from the origin - EnsureImageSizes() runs at host agent startup: expands any base image smaller than 20GB via truncate + resize2fs (sparse, no extra physical disk). Sandboxes then get the full 20GB via fast dm-snapshot path - FlattenRootfs shrinks output images with resize2fs -M so stored templates are compact; EnsureImageSizes re-expands on next startup Admin templates visibility: - Add GET /v1/admin/templates endpoint listing all templates across teams - Frontend admin templates page uses listAdminTemplates() instead of team-scoped listSnapshots() - Platform templates (team_id = all-zeros UUID) now visible to all teams: GetTemplateByTeam, ListTemplatesByTeam, ListTemplatesByTeamAndType queries include platform team_id in WHERE clause --- cmd/host-agent/main.go | 8 ++ db/migrations/20260310094104_initial.sql | 1 + db/queries/sandboxes.sql | 4 +- db/queries/templates.sql | 9 ++- frontend/src/lib/api/builds.ts | 14 ++++ .../src/routes/admin/templates/+page.svelte | 12 +-- internal/api/handlers_builds.go | 41 +++++++++- internal/api/server.go | 3 +- internal/db/models.go | 1 + internal/db/sandboxes.sql.go | 33 ++++++--- internal/db/templates.sql.go | 9 ++- internal/devicemapper/devicemapper.go | 10 ++- internal/hostagent/server.go | 2 +- internal/sandbox/images.go | 74 +++++++++++++++++++ internal/sandbox/manager.go | 31 ++++++-- internal/service/build.go | 3 +- internal/service/sandbox.go | 6 ++ proto/hostagent/gen/hostagent.pb.go | 18 ++++- proto/hostagent/hostagent.proto | 4 + 19 files changed, 241 insertions(+), 42 deletions(-) create mode 100644 internal/sandbox/images.go diff --git a/cmd/host-agent/main.go b/cmd/host-agent/main.go index 130faaf7..76dc2390 100644 --- a/cmd/host-agent/main.go +++ b/cmd/host-agent/main.go @@ -59,6 +59,14 @@ func main() { os.Exit(1) } + // Expand base images to the standard disk size (sparse, no extra physical + // disk). This ensures dm-snapshot sandboxes see the full size from boot. + imagesDir := filepath.Join(rootDir, "images") + if err := sandbox.EnsureImageSizes(imagesDir, sandbox.DefaultDiskSizeMB); err != nil { + slog.Error("failed to expand base images", "error", err) + os.Exit(1) + } + cfg := sandbox.Config{ KernelPath: filepath.Join(rootDir, "kernels", "vmlinux"), ImagesDir: filepath.Join(rootDir, "images"), diff --git a/db/migrations/20260310094104_initial.sql b/db/migrations/20260310094104_initial.sql index be5d29f6..da6607f3 100644 --- a/db/migrations/20260310094104_initial.sql +++ b/db/migrations/20260310094104_initial.sql @@ -144,6 +144,7 @@ CREATE TABLE sandboxes ( vcpus INTEGER NOT NULL DEFAULT 1, memory_mb INTEGER NOT NULL DEFAULT 512, timeout_sec INTEGER NOT NULL DEFAULT 300, + disk_size_mb INTEGER NOT NULL DEFAULT 20480, guest_ip TEXT NOT NULL DEFAULT '', host_ip TEXT NOT NULL DEFAULT '', created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), diff --git a/db/queries/sandboxes.sql b/db/queries/sandboxes.sql index 71e61dc9..b8ae8de5 100644 --- a/db/queries/sandboxes.sql +++ b/db/queries/sandboxes.sql @@ -1,6 +1,6 @@ -- name: InsertSandbox :one -INSERT INTO sandboxes (id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec) -VALUES ($1, $2, $3, $4, $5, $6, $7, $8) +INSERT INTO sandboxes (id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb) +VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING *; -- name: GetSandbox :one diff --git a/db/queries/templates.sql b/db/queries/templates.sql index b17abc3a..c7b70855 100644 --- a/db/queries/templates.sql +++ b/db/queries/templates.sql @@ -7,7 +7,8 @@ RETURNING *; SELECT * FROM templates WHERE name = $1; -- name: GetTemplateByTeam :one -SELECT * FROM templates WHERE name = $1 AND team_id = $2; +-- Platform templates (team_id = 00000000-...) are visible to all teams. +SELECT * FROM templates WHERE name = $1 AND (team_id = $2 OR team_id = '00000000-0000-0000-0000-000000000000'); -- name: ListTemplates :many SELECT * FROM templates ORDER BY created_at DESC; @@ -16,10 +17,12 @@ SELECT * FROM templates ORDER BY created_at DESC; SELECT * FROM templates WHERE type = $1 ORDER BY created_at DESC; -- name: ListTemplatesByTeam :many -SELECT * FROM templates WHERE team_id = $1 ORDER BY created_at DESC; +-- Platform templates are visible to all teams. +SELECT * FROM templates WHERE (team_id = $1 OR team_id = '00000000-0000-0000-0000-000000000000') ORDER BY created_at DESC; -- name: ListTemplatesByTeamAndType :many -SELECT * FROM templates WHERE team_id = $1 AND type = $2 ORDER BY created_at DESC; +-- Platform templates are visible to all teams. +SELECT * FROM templates WHERE (team_id = $1 OR team_id = '00000000-0000-0000-0000-000000000000') AND type = $2 ORDER BY created_at DESC; -- name: DeleteTemplate :exec DELETE FROM templates WHERE name = $1; diff --git a/frontend/src/lib/api/builds.ts b/frontend/src/lib/api/builds.ts index d826b36b..bfa69fa5 100644 --- a/frontend/src/lib/api/builds.ts +++ b/frontend/src/lib/api/builds.ts @@ -50,3 +50,17 @@ export async function listBuilds(): Promise> { export async function getBuild(id: string): Promise> { return apiFetch('GET', `/api/v1/admin/builds/${id}`); } + +export type AdminTemplate = { + name: string; + type: string; + vcpus: number; + memory_mb: number; + size_bytes: number; + team_id: string; + created_at: string; +}; + +export async function listAdminTemplates(): Promise> { + return apiFetch('GET', '/api/v1/admin/templates'); +} diff --git a/frontend/src/routes/admin/templates/+page.svelte b/frontend/src/routes/admin/templates/+page.svelte index 904bbad4..c320ea8f 100644 --- a/frontend/src/routes/admin/templates/+page.svelte +++ b/frontend/src/routes/admin/templates/+page.svelte @@ -3,12 +3,14 @@ import { onMount, onDestroy } from 'svelte'; import { toast } from '$lib/toast.svelte'; import { formatDate, timeAgo } from '$lib/utils/format'; - import { listSnapshots, deleteSnapshot, type Snapshot } from '$lib/api/capsules'; + import { deleteSnapshot } from '$lib/api/capsules'; import { listBuilds, createBuild, + listAdminTemplates, type Build, - type BuildLogEntry + type BuildLogEntry, + type AdminTemplate } from '$lib/api/builds'; let collapsed = $state( @@ -20,7 +22,7 @@ let activeTab = $state<'templates' | 'builds'>('templates'); // Templates state - let templates = $state([]); + let templates = $state([]); let templatesLoading = $state(true); let templatesError = $state(null); @@ -38,7 +40,7 @@ let expandedSteps = $state>(new Set()); // Delete template state - let deleteTarget = $state(null); + let deleteTarget = $state(null); let deleting = $state(false); let deleteError = $state(null); @@ -64,7 +66,7 @@ async function fetchTemplates() { templatesLoading = true; templatesError = null; - const result = await listSnapshots(); + const result = await listAdminTemplates(); if (result.ok) { templates = result.data; } else { diff --git a/internal/api/handlers_builds.go b/internal/api/handlers_builds.go index f1b3973d..61eebe76 100644 --- a/internal/api/handlers_builds.go +++ b/internal/api/handlers_builds.go @@ -17,10 +17,11 @@ import ( type buildHandler struct { svc *service.BuildService + db *db.Queries } -func newBuildHandler(svc *service.BuildService) *buildHandler { - return &buildHandler{svc: svc} +func newBuildHandler(svc *service.BuildService, db *db.Queries) *buildHandler { + return &buildHandler{svc: svc, db: db} } type createBuildRequest struct { @@ -165,3 +166,39 @@ func (h *buildHandler) Get(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusOK, buildToResponse(build)) } + +// ListTemplates handles GET /v1/admin/templates — returns all templates across all teams. +func (h *buildHandler) ListTemplates(w http.ResponseWriter, r *http.Request) { + templates, err := h.db.ListTemplates(r.Context()) + if err != nil { + writeError(w, http.StatusInternalServerError, "db_error", "failed to list templates") + return + } + + type templateResponse struct { + Name string `json:"name"` + Type string `json:"type"` + VCPUs int32 `json:"vcpus"` + MemoryMB int32 `json:"memory_mb"` + SizeBytes int64 `json:"size_bytes"` + TeamID string `json:"team_id"` + CreatedAt string `json:"created_at"` + } + + resp := make([]templateResponse, len(templates)) + for i, t := range templates { + resp[i] = templateResponse{ + Name: t.Name, + Type: t.Type, + VCPUs: t.Vcpus, + MemoryMB: t.MemoryMb, + SizeBytes: t.SizeBytes, + TeamID: id.FormatTeamID(t.TeamID), + } + if t.CreatedAt.Valid { + resp[i].CreatedAt = t.CreatedAt.Time.Format(time.RFC3339) + } + } + + writeJSON(w, http.StatusOK, resp) +} diff --git a/internal/api/server.go b/internal/api/server.go index 6999b8a8..1be4473d 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -67,7 +67,7 @@ func New( auditH := newAuditHandler(auditSvc) statsH := newStatsHandler(statsSvc) metricsH := newSandboxMetricsHandler(queries, pool) - buildH := newBuildHandler(buildSvc) + buildH := newBuildHandler(buildSvc, queries) // OpenAPI spec and docs. r.Get("/openapi.yaml", serveOpenAPI) @@ -177,6 +177,7 @@ func New( r.Use(requireJWT(jwtSecret)) r.Use(requireAdmin(queries)) r.Put("/teams/{id}/byoc", teamH.SetBYOC) + r.Get("/templates", buildH.ListTemplates) r.Post("/builds", buildH.Create) r.Get("/builds", buildH.List) r.Get("/builds/{id}", buildH.Get) diff --git a/internal/db/models.go b/internal/db/models.go index 3aa765c5..f35bfe7a 100644 --- a/internal/db/models.go +++ b/internal/db/models.go @@ -91,6 +91,7 @@ type Sandbox struct { Vcpus int32 `json:"vcpus"` MemoryMb int32 `json:"memory_mb"` TimeoutSec int32 `json:"timeout_sec"` + DiskSizeMb int32 `json:"disk_size_mb"` GuestIp string `json:"guest_ip"` HostIp string `json:"host_ip"` CreatedAt pgtype.Timestamptz `json:"created_at"` diff --git a/internal/db/sandboxes.sql.go b/internal/db/sandboxes.sql.go index 07effdf2..ace43709 100644 --- a/internal/db/sandboxes.sql.go +++ b/internal/db/sandboxes.sql.go @@ -43,7 +43,7 @@ func (q *Queries) BulkUpdateStatusByIDs(ctx context.Context, arg BulkUpdateStatu } const getSandbox = `-- name: GetSandbox :one -SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated FROM sandboxes WHERE id = $1 +SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated FROM sandboxes WHERE id = $1 ` func (q *Queries) GetSandbox(ctx context.Context, id pgtype.UUID) (Sandbox, error) { @@ -58,6 +58,7 @@ func (q *Queries) GetSandbox(ctx context.Context, id pgtype.UUID) (Sandbox, erro &i.Vcpus, &i.MemoryMb, &i.TimeoutSec, + &i.DiskSizeMb, &i.GuestIp, &i.HostIp, &i.CreatedAt, @@ -69,7 +70,7 @@ func (q *Queries) GetSandbox(ctx context.Context, id pgtype.UUID) (Sandbox, erro } const getSandboxByTeam = `-- name: GetSandboxByTeam :one -SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated FROM sandboxes WHERE id = $1 AND team_id = $2 +SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated FROM sandboxes WHERE id = $1 AND team_id = $2 ` type GetSandboxByTeamParams struct { @@ -89,6 +90,7 @@ func (q *Queries) GetSandboxByTeam(ctx context.Context, arg GetSandboxByTeamPara &i.Vcpus, &i.MemoryMb, &i.TimeoutSec, + &i.DiskSizeMb, &i.GuestIp, &i.HostIp, &i.CreatedAt, @@ -100,9 +102,9 @@ func (q *Queries) GetSandboxByTeam(ctx context.Context, arg GetSandboxByTeamPara } const insertSandbox = `-- name: InsertSandbox :one -INSERT INTO sandboxes (id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec) -VALUES ($1, $2, $3, $4, $5, $6, $7, $8) -RETURNING id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated +INSERT INTO sandboxes (id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb) +VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) +RETURNING id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated ` type InsertSandboxParams struct { @@ -114,6 +116,7 @@ type InsertSandboxParams struct { Vcpus int32 `json:"vcpus"` MemoryMb int32 `json:"memory_mb"` TimeoutSec int32 `json:"timeout_sec"` + DiskSizeMb int32 `json:"disk_size_mb"` } func (q *Queries) InsertSandbox(ctx context.Context, arg InsertSandboxParams) (Sandbox, error) { @@ -126,6 +129,7 @@ func (q *Queries) InsertSandbox(ctx context.Context, arg InsertSandboxParams) (S arg.Vcpus, arg.MemoryMb, arg.TimeoutSec, + arg.DiskSizeMb, ) var i Sandbox err := row.Scan( @@ -137,6 +141,7 @@ func (q *Queries) InsertSandbox(ctx context.Context, arg InsertSandboxParams) (S &i.Vcpus, &i.MemoryMb, &i.TimeoutSec, + &i.DiskSizeMb, &i.GuestIp, &i.HostIp, &i.CreatedAt, @@ -148,7 +153,7 @@ func (q *Queries) InsertSandbox(ctx context.Context, arg InsertSandboxParams) (S } const listActiveSandboxesByTeam = `-- name: ListActiveSandboxesByTeam :many -SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated FROM sandboxes +SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated FROM sandboxes WHERE team_id = $1 AND status IN ('running', 'paused', 'starting') ORDER BY created_at DESC ` @@ -171,6 +176,7 @@ func (q *Queries) ListActiveSandboxesByTeam(ctx context.Context, teamID pgtype.U &i.Vcpus, &i.MemoryMb, &i.TimeoutSec, + &i.DiskSizeMb, &i.GuestIp, &i.HostIp, &i.CreatedAt, @@ -189,7 +195,7 @@ func (q *Queries) ListActiveSandboxesByTeam(ctx context.Context, teamID pgtype.U } const listSandboxes = `-- name: ListSandboxes :many -SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated FROM sandboxes ORDER BY created_at DESC +SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated FROM sandboxes ORDER BY created_at DESC ` func (q *Queries) ListSandboxes(ctx context.Context) ([]Sandbox, error) { @@ -210,6 +216,7 @@ func (q *Queries) ListSandboxes(ctx context.Context) ([]Sandbox, error) { &i.Vcpus, &i.MemoryMb, &i.TimeoutSec, + &i.DiskSizeMb, &i.GuestIp, &i.HostIp, &i.CreatedAt, @@ -228,7 +235,7 @@ func (q *Queries) ListSandboxes(ctx context.Context) ([]Sandbox, error) { } const listSandboxesByHostAndStatus = `-- name: ListSandboxesByHostAndStatus :many -SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated FROM sandboxes +SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated FROM sandboxes WHERE host_id = $1 AND status = ANY($2::text[]) ORDER BY created_at DESC ` @@ -256,6 +263,7 @@ func (q *Queries) ListSandboxesByHostAndStatus(ctx context.Context, arg ListSand &i.Vcpus, &i.MemoryMb, &i.TimeoutSec, + &i.DiskSizeMb, &i.GuestIp, &i.HostIp, &i.CreatedAt, @@ -274,7 +282,7 @@ func (q *Queries) ListSandboxesByHostAndStatus(ctx context.Context, arg ListSand } const listSandboxesByTeam = `-- name: ListSandboxesByTeam :many -SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated FROM sandboxes +SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated FROM sandboxes WHERE team_id = $1 AND status NOT IN ('stopped', 'error') ORDER BY created_at DESC ` @@ -297,6 +305,7 @@ func (q *Queries) ListSandboxesByTeam(ctx context.Context, teamID pgtype.UUID) ( &i.Vcpus, &i.MemoryMb, &i.TimeoutSec, + &i.DiskSizeMb, &i.GuestIp, &i.HostIp, &i.CreatedAt, @@ -355,7 +364,7 @@ SET status = 'running', last_active_at = $4, last_updated = NOW() WHERE id = $1 -RETURNING id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated +RETURNING id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated ` type UpdateSandboxRunningParams struct { @@ -382,6 +391,7 @@ func (q *Queries) UpdateSandboxRunning(ctx context.Context, arg UpdateSandboxRun &i.Vcpus, &i.MemoryMb, &i.TimeoutSec, + &i.DiskSizeMb, &i.GuestIp, &i.HostIp, &i.CreatedAt, @@ -397,7 +407,7 @@ UPDATE sandboxes SET status = $2, last_updated = NOW() WHERE id = $1 -RETURNING id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated +RETURNING id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated ` type UpdateSandboxStatusParams struct { @@ -417,6 +427,7 @@ func (q *Queries) UpdateSandboxStatus(ctx context.Context, arg UpdateSandboxStat &i.Vcpus, &i.MemoryMb, &i.TimeoutSec, + &i.DiskSizeMb, &i.GuestIp, &i.HostIp, &i.CreatedAt, diff --git a/internal/db/templates.sql.go b/internal/db/templates.sql.go index 8703bc92..45a673c2 100644 --- a/internal/db/templates.sql.go +++ b/internal/db/templates.sql.go @@ -54,7 +54,7 @@ func (q *Queries) GetTemplate(ctx context.Context, name string) (Template, error } const getTemplateByTeam = `-- name: GetTemplateByTeam :one -SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE name = $1 AND team_id = $2 +SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE name = $1 AND (team_id = $2 OR team_id = '00000000-0000-0000-0000-000000000000') ` type GetTemplateByTeamParams struct { @@ -62,6 +62,7 @@ type GetTemplateByTeamParams struct { TeamID pgtype.UUID `json:"team_id"` } +// Platform templates (team_id = 00000000-...) are visible to all teams. func (q *Queries) GetTemplateByTeam(ctx context.Context, arg GetTemplateByTeamParams) (Template, error) { row := q.db.QueryRow(ctx, getTemplateByTeam, arg.Name, arg.TeamID) var i Template @@ -147,9 +148,10 @@ func (q *Queries) ListTemplates(ctx context.Context) ([]Template, error) { } const listTemplatesByTeam = `-- name: ListTemplatesByTeam :many -SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE team_id = $1 ORDER BY created_at DESC +SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE (team_id = $1 OR team_id = '00000000-0000-0000-0000-000000000000') ORDER BY created_at DESC ` +// Platform templates are visible to all teams. func (q *Queries) ListTemplatesByTeam(ctx context.Context, teamID pgtype.UUID) ([]Template, error) { rows, err := q.db.Query(ctx, listTemplatesByTeam, teamID) if err != nil { @@ -179,7 +181,7 @@ func (q *Queries) ListTemplatesByTeam(ctx context.Context, teamID pgtype.UUID) ( } const listTemplatesByTeamAndType = `-- name: ListTemplatesByTeamAndType :many -SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE team_id = $1 AND type = $2 ORDER BY created_at DESC +SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE (team_id = $1 OR team_id = '00000000-0000-0000-0000-000000000000') AND type = $2 ORDER BY created_at DESC ` type ListTemplatesByTeamAndTypeParams struct { @@ -187,6 +189,7 @@ type ListTemplatesByTeamAndTypeParams struct { Type string `json:"type"` } +// Platform templates are visible to all teams. func (q *Queries) ListTemplatesByTeamAndType(ctx context.Context, arg ListTemplatesByTeamAndTypeParams) ([]Template, error) { rows, err := q.db.Query(ctx, listTemplatesByTeamAndType, arg.TeamID, arg.Type) if err != nil { diff --git a/internal/devicemapper/devicemapper.go b/internal/devicemapper/devicemapper.go index ea14fcd8..ba801f13 100644 --- a/internal/devicemapper/devicemapper.go +++ b/internal/devicemapper/devicemapper.go @@ -116,9 +116,10 @@ type SnapshotDevice struct { // writable CoW layer. // // The origin loop device must already exist (from LoopRegistry.Acquire). -func CreateSnapshot(name, originLoopDev, cowPath string, originSizeBytes int64) (*SnapshotDevice, error) { - // Create sparse CoW file sized to match the origin. - if err := createSparseFile(cowPath, originSizeBytes); err != nil { +func CreateSnapshot(name, originLoopDev, cowPath string, originSizeBytes, cowSizeBytes int64) (*SnapshotDevice, error) { + // Create sparse CoW file. The logical size limits how many blocks can be + // modified; because the file is sparse, only written blocks use real disk. + if err := createSparseFile(cowPath, cowSizeBytes); err != nil { return nil, fmt.Errorf("create cow file: %w", err) } @@ -128,6 +129,9 @@ func CreateSnapshot(name, originLoopDev, cowPath string, originSizeBytes int64) return nil, fmt.Errorf("losetup cow: %w", err) } + // The dm-snapshot virtual device size must match the origin — the snapshot + // target maps 1:1 onto origin sectors. The CoW file just needs enough + // space to store all modified blocks (it's sparse, so 20GB costs nothing). sectors := originSizeBytes / 512 if err := dmsetupCreate(name, originLoopDev, cowLoopDev, sectors); err != nil { if detachErr := losetupDetach(cowLoopDev); detachErr != nil { diff --git a/internal/hostagent/server.go b/internal/hostagent/server.go index 86fdda0e..549158f0 100644 --- a/internal/hostagent/server.go +++ b/internal/hostagent/server.go @@ -39,7 +39,7 @@ func (s *Server) CreateSandbox( ) (*connect.Response[pb.CreateSandboxResponse], error) { msg := req.Msg - sb, err := s.mgr.Create(ctx, msg.SandboxId, msg.Template, int(msg.Vcpus), int(msg.MemoryMb), int(msg.TimeoutSec)) + sb, err := s.mgr.Create(ctx, msg.SandboxId, msg.Template, int(msg.Vcpus), int(msg.MemoryMb), int(msg.TimeoutSec), int(msg.DiskSizeMb)) if err != nil { return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("create sandbox: %w", err)) } diff --git a/internal/sandbox/images.go b/internal/sandbox/images.go new file mode 100644 index 00000000..0f3e24ad --- /dev/null +++ b/internal/sandbox/images.go @@ -0,0 +1,74 @@ +package sandbox + +import ( + "fmt" + "log/slog" + "os" + "os/exec" + "path/filepath" +) + +// DefaultDiskSizeMB is the standard disk size for base images. Images smaller +// than this are expanded at startup so that dm-snapshot sandboxes see the full +// size without per-sandbox copies. The expansion is sparse — only metadata +// changes; no physical disk is consumed beyond the original content. +const DefaultDiskSizeMB = 20480 // 20 GB + +// EnsureImageSizes walks the images directory and expands any rootfs.ext4 that +// is smaller than the target size. This is idempotent: images already at or +// above the target size are left untouched. Should be called once at host agent +// startup before any sandboxes are created. +func EnsureImageSizes(imagesDir string, targetMB int) error { + if targetMB <= 0 { + targetMB = DefaultDiskSizeMB + } + targetBytes := int64(targetMB) * 1024 * 1024 + + entries, err := os.ReadDir(imagesDir) + if err != nil { + return fmt.Errorf("read images dir: %w", err) + } + + for _, entry := range entries { + if !entry.IsDir() { + continue + } + rootfs := filepath.Join(imagesDir, entry.Name(), "rootfs.ext4") + info, err := os.Stat(rootfs) + if err != nil { + continue // not every template dir has a rootfs.ext4 + } + + if info.Size() >= targetBytes { + continue // already large enough + } + + slog.Info("expanding base image", + "template", entry.Name(), + "from_mb", info.Size()/(1024*1024), + "to_mb", targetMB, + ) + + // Expand the file (sparse — instant, no physical disk used). + if err := os.Truncate(rootfs, targetBytes); err != nil { + return fmt.Errorf("truncate %s: %w", rootfs, err) + } + + // Check filesystem before resize. + if out, err := exec.Command("e2fsck", "-fy", rootfs).CombinedOutput(); err != nil { + // e2fsck returns 1 if it fixed errors, which is fine. + if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() > 1 { + return fmt.Errorf("e2fsck %s: %s: %w", rootfs, string(out), err) + } + } + + // Grow the ext4 filesystem to fill the new file size. + if out, err := exec.Command("resize2fs", rootfs).CombinedOutput(); err != nil { + return fmt.Errorf("resize2fs %s: %s: %w", rootfs, string(out), err) + } + + slog.Info("base image expanded", "template", entry.Name(), "size_mb", targetMB) + } + + return nil +} diff --git a/internal/sandbox/manager.go b/internal/sandbox/manager.go index 15453eb4..b91fed14 100644 --- a/internal/sandbox/manager.go +++ b/internal/sandbox/manager.go @@ -5,6 +5,7 @@ import ( "fmt" "log/slog" "os" + "os/exec" "path/filepath" "sync" "time" @@ -51,8 +52,8 @@ type sandboxState struct { slot *network.Slot client *envdclient.Client uffdSocketPath string // non-empty for sandboxes restored from snapshot - dmDevice *devicemapper.SnapshotDevice - baseImagePath string // path to the base template rootfs (for loop registry release) + dmDevice *devicemapper.SnapshotDevice + baseImagePath string // path to the base template rootfs (for loop registry release) // parent holds the snapshot header and diff file paths from which this // sandbox was restored. Non-nil means re-pause should use "Diff" snapshot @@ -94,7 +95,7 @@ func New(cfg Config) *Manager { // Create boots a new sandbox: clone rootfs, set up network, start VM, wait for envd. // If sandboxID is empty, a new ID is generated. -func (m *Manager) Create(ctx context.Context, sandboxID, template string, vcpus, memoryMB, timeoutSec int) (*models.Sandbox, error) { +func (m *Manager) Create(ctx context.Context, sandboxID, template string, vcpus, memoryMB, timeoutSec, diskSizeMB int) (*models.Sandbox, error) { if sandboxID == "" { sandboxID = id.FormatSandboxID(id.NewSandboxID()) } @@ -105,6 +106,9 @@ func (m *Manager) Create(ctx context.Context, sandboxID, template string, vcpus, if memoryMB <= 0 { memoryMB = 512 } + if diskSizeMB <= 0 { + diskSizeMB = 20480 // 20 GB default + } if template == "" { template = "minimal" @@ -115,7 +119,7 @@ func (m *Manager) Create(ctx context.Context, sandboxID, template string, vcpus, // Check if template refers to a snapshot (has snapfile + memfile + header + rootfs). if snapshot.IsSnapshot(m.cfg.ImagesDir, template) { - return m.createFromSnapshot(ctx, sandboxID, template, vcpus, memoryMB, timeoutSec) + return m.createFromSnapshot(ctx, sandboxID, template, vcpus, memoryMB, timeoutSec, diskSizeMB) } // Resolve base rootfs image: /var/lib/wrenn/images/{template}/rootfs.ext4 @@ -139,7 +143,8 @@ func (m *Manager) Create(ctx context.Context, sandboxID, template string, vcpus, // Create dm-snapshot with per-sandbox CoW file. dmName := "wrenn-" + sandboxID cowPath := filepath.Join(m.cfg.SandboxesDir, fmt.Sprintf("%s.cow", sandboxID)) - dmDev, err := devicemapper.CreateSnapshot(dmName, originLoop, cowPath, originSize) + cowSize := int64(diskSizeMB) * 1024 * 1024 + dmDev, err := devicemapper.CreateSnapshot(dmName, originLoop, cowPath, originSize, cowSize) if err != nil { m.loops.Release(baseRootfs) return nil, fmt.Errorf("create dm-snapshot: %w", err) @@ -853,6 +858,17 @@ func (m *Manager) FlattenRootfs(ctx context.Context, sandboxID, name string) (in // Clean up dm device and loop device now that flatten is complete. m.cleanupDM(sb) + // Shrink the flattened image to its minimum size so stored templates are + // compact. EnsureImageSizes will re-expand them on the next agent startup. + if out, err := exec.Command("e2fsck", "-fy", outputPath).CombinedOutput(); err != nil { + if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() > 1 { + slog.Warn("e2fsck before shrink failed (non-fatal)", "output", string(out), "error", err) + } + } + if out, err := exec.Command("resize2fs", "-M", outputPath).CombinedOutput(); err != nil { + slog.Warn("resize2fs -M failed (non-fatal)", "output", string(out), "error", err) + } + sizeBytes, err := snapshot.DirSize(m.cfg.ImagesDir, name) if err != nil { slog.Warn("failed to calculate template size", "error", err) @@ -891,7 +907,7 @@ func (m *Manager) DeleteSnapshot(name string) error { // in ImagesDir/{snapshotName}/. Uses UFFD for lazy memory loading. // The template's rootfs.ext4 is a flattened standalone image — we create a // dm-snapshot on top of it just like a normal Create. -func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID, snapshotName string, vcpus, _, timeoutSec int) (*models.Sandbox, error) { +func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID, snapshotName string, vcpus, _, timeoutSec, diskSizeMB int) (*models.Sandbox, error) { imagesDir := m.cfg.ImagesDir // Read the header. @@ -936,7 +952,8 @@ func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID, snapshotNam dmName := "wrenn-" + sandboxID cowPath := filepath.Join(m.cfg.SandboxesDir, fmt.Sprintf("%s.cow", sandboxID)) - dmDev, err := devicemapper.CreateSnapshot(dmName, originLoop, cowPath, originSize) + cowSize := int64(diskSizeMB) * 1024 * 1024 + dmDev, err := devicemapper.CreateSnapshot(dmName, originLoop, cowPath, originSize, cowSize) if err != nil { source.Close() m.loops.Release(baseRootfs) diff --git a/internal/service/build.go b/internal/service/build.go index 1bd82a8f..19fb5d21 100644 --- a/internal/service/build.go +++ b/internal/service/build.go @@ -199,7 +199,8 @@ func (s *BuildService) executeBuild(ctx context.Context, buildIDStr string) { Template: build.BaseTemplate, Vcpus: build.Vcpus, MemoryMb: build.MemoryMb, - TimeoutSec: 0, // no auto-pause for builds + TimeoutSec: 0, // no auto-pause for builds + DiskSizeMb: 20480, // 20 GB for template builds })) if err != nil { s.failBuild(ctx, buildID, fmt.Sprintf("create sandbox failed: %v", err)) diff --git a/internal/service/sandbox.go b/internal/service/sandbox.go index 89e40c51..96d1282c 100644 --- a/internal/service/sandbox.go +++ b/internal/service/sandbox.go @@ -32,6 +32,7 @@ type SandboxCreateParams struct { VCPUs int32 MemoryMB int32 TimeoutSec int32 + DiskSizeMB int32 } // agentForSandbox looks up the host for the given sandbox and returns a client. @@ -77,6 +78,9 @@ func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db. if p.MemoryMB <= 0 { p.MemoryMB = 512 } + if p.DiskSizeMB <= 0 { + p.DiskSizeMB = 20480 // 20 GB default + } // If the template is a snapshot, use its baked-in vcpus/memory. if tmpl, err := s.DB.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: p.Template, TeamID: p.TeamID}); err == nil && tmpl.Type == "snapshot" { @@ -117,6 +121,7 @@ func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db. Vcpus: p.VCPUs, MemoryMb: p.MemoryMB, TimeoutSec: p.TimeoutSec, + DiskSizeMb: p.DiskSizeMB, }); err != nil { return db.Sandbox{}, fmt.Errorf("insert sandbox: %w", err) } @@ -127,6 +132,7 @@ func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db. Vcpus: p.VCPUs, MemoryMb: p.MemoryMB, TimeoutSec: p.TimeoutSec, + DiskSizeMb: p.DiskSizeMB, })) if err != nil { if _, dbErr := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{ diff --git a/proto/hostagent/gen/hostagent.pb.go b/proto/hostagent/gen/hostagent.pb.go index c7436b76..aa29db98 100644 --- a/proto/hostagent/gen/hostagent.pb.go +++ b/proto/hostagent/gen/hostagent.pb.go @@ -33,7 +33,10 @@ type CreateSandboxRequest struct { MemoryMb int32 `protobuf:"varint,3,opt,name=memory_mb,json=memoryMb,proto3" json:"memory_mb,omitempty"` // TTL in seconds. Sandbox is auto-paused after this duration of // inactivity. 0 means no auto-pause. - TimeoutSec int32 `protobuf:"varint,4,opt,name=timeout_sec,json=timeoutSec,proto3" json:"timeout_sec,omitempty"` + TimeoutSec int32 `protobuf:"varint,4,opt,name=timeout_sec,json=timeoutSec,proto3" json:"timeout_sec,omitempty"` + // Disk size in MB for the sparse CoW file. Limits how much data the + // sandbox can write beyond the base image. Default: 20480 (20 GB). + DiskSizeMb int32 `protobuf:"varint,6,opt,name=disk_size_mb,json=diskSizeMb,proto3" json:"disk_size_mb,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -103,6 +106,13 @@ func (x *CreateSandboxRequest) GetTimeoutSec() int32 { return 0 } +func (x *CreateSandboxRequest) GetDiskSizeMb() int32 { + if x != nil { + return x.DiskSizeMb + } + return 0 +} + type CreateSandboxResponse struct { state protoimpl.MessageState `protogen:"open.v1"` SandboxId string `protobuf:"bytes,1,opt,name=sandbox_id,json=sandboxId,proto3" json:"sandbox_id,omitempty"` @@ -2271,7 +2281,7 @@ var File_hostagent_proto protoreflect.FileDescriptor const file_hostagent_proto_rawDesc = "" + "\n" + - "\x0fhostagent.proto\x12\fhostagent.v1\"\xa5\x01\n" + + "\x0fhostagent.proto\x12\fhostagent.v1\"\xc7\x01\n" + "\x14CreateSandboxRequest\x12\x1d\n" + "\n" + "sandbox_id\x18\x05 \x01(\tR\tsandboxId\x12\x1a\n" + @@ -2279,7 +2289,9 @@ const file_hostagent_proto_rawDesc = "" + "\x05vcpus\x18\x02 \x01(\x05R\x05vcpus\x12\x1b\n" + "\tmemory_mb\x18\x03 \x01(\x05R\bmemoryMb\x12\x1f\n" + "\vtimeout_sec\x18\x04 \x01(\x05R\n" + - "timeoutSec\"g\n" + + "timeoutSec\x12 \n" + + "\fdisk_size_mb\x18\x06 \x01(\x05R\n" + + "diskSizeMb\"g\n" + "\x15CreateSandboxResponse\x12\x1d\n" + "\n" + "sandbox_id\x18\x01 \x01(\tR\tsandboxId\x12\x16\n" + diff --git a/proto/hostagent/hostagent.proto b/proto/hostagent/hostagent.proto index cd93a2db..9af40eb3 100644 --- a/proto/hostagent/hostagent.proto +++ b/proto/hostagent/hostagent.proto @@ -85,6 +85,10 @@ message CreateSandboxRequest { // TTL in seconds. Sandbox is auto-paused after this duration of // inactivity. 0 means no auto-pause. int32 timeout_sec = 4; + + // Disk size in MB for the sparse CoW file. Limits how much data the + // sandbox can write beyond the base image. Default: 20480 (20 GB). + int32 disk_size_mb = 6; } message CreateSandboxResponse { From 5cb37bf2a0144abc270c38db96a4e20f72afa1d6 Mon Sep 17 00:00:00 2001 From: pptx704 Date: Thu, 26 Mar 2026 23:53:08 +0600 Subject: [PATCH 12/28] Add admin template deletion with broadcast to all hosts - DELETE /v1/admin/templates/{name} endpoint (admin-only) - Broadcasts DeleteSnapshot RPC to all online hosts before removing DB record - Frontend admin templates page uses deleteAdminTemplate() instead of team-scoped deleteSnapshot() - Delete button shown for all template types, not just snapshots --- frontend/src/lib/api/builds.ts | 4 ++ .../src/routes/admin/templates/+page.svelte | 18 +++---- internal/api/handlers_builds.go | 51 +++++++++++++++++-- internal/api/server.go | 3 +- 4 files changed, 61 insertions(+), 15 deletions(-) diff --git a/frontend/src/lib/api/builds.ts b/frontend/src/lib/api/builds.ts index bfa69fa5..900acf2d 100644 --- a/frontend/src/lib/api/builds.ts +++ b/frontend/src/lib/api/builds.ts @@ -64,3 +64,7 @@ export type AdminTemplate = { export async function listAdminTemplates(): Promise> { return apiFetch('GET', '/api/v1/admin/templates'); } + +export async function deleteAdminTemplate(name: string): Promise> { + return apiFetch('DELETE', `/api/v1/admin/templates/${name}`); +} diff --git a/frontend/src/routes/admin/templates/+page.svelte b/frontend/src/routes/admin/templates/+page.svelte index c320ea8f..0d719bd0 100644 --- a/frontend/src/routes/admin/templates/+page.svelte +++ b/frontend/src/routes/admin/templates/+page.svelte @@ -3,11 +3,11 @@ import { onMount, onDestroy } from 'svelte'; import { toast } from '$lib/toast.svelte'; import { formatDate, timeAgo } from '$lib/utils/format'; - import { deleteSnapshot } from '$lib/api/capsules'; import { listBuilds, createBuild, listAdminTemplates, + deleteAdminTemplate, type Build, type BuildLogEntry, type AdminTemplate @@ -145,7 +145,7 @@ deleting = true; deleteError = null; const name = deleteTarget.name; - const result = await deleteSnapshot(name); + const result = await deleteAdminTemplate(name); if (result.ok) { templates = templates.filter((t) => t.name !== name); deleteTarget = null; @@ -413,14 +413,12 @@ - {#if tmpl.type === 'snapshot'} - - {/if} + {/each} diff --git a/internal/api/handlers_builds.go b/internal/api/handlers_builds.go index 61eebe76..8b8fd5cd 100644 --- a/internal/api/handlers_builds.go +++ b/internal/api/handlers_builds.go @@ -7,21 +7,25 @@ import ( "net/http" "time" + "connectrpc.com/connect" "github.com/go-chi/chi/v5" "git.omukk.dev/wrenn/sandbox/internal/db" "git.omukk.dev/wrenn/sandbox/internal/id" + "git.omukk.dev/wrenn/sandbox/internal/lifecycle" "git.omukk.dev/wrenn/sandbox/internal/service" "git.omukk.dev/wrenn/sandbox/internal/validate" + pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen" ) type buildHandler struct { - svc *service.BuildService - db *db.Queries + svc *service.BuildService + db *db.Queries + pool *lifecycle.HostClientPool } -func newBuildHandler(svc *service.BuildService, db *db.Queries) *buildHandler { - return &buildHandler{svc: svc, db: db} +func newBuildHandler(svc *service.BuildService, db *db.Queries, pool *lifecycle.HostClientPool) *buildHandler { + return &buildHandler{svc: svc, db: db, pool: pool} } type createBuildRequest struct { @@ -202,3 +206,42 @@ func (h *buildHandler) ListTemplates(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusOK, resp) } + +// DeleteTemplate handles DELETE /v1/admin/templates/{name}. +func (h *buildHandler) DeleteTemplate(w http.ResponseWriter, r *http.Request) { + name := chi.URLParam(r, "name") + if err := validate.SafeName(name); err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", fmt.Sprintf("invalid template name: %s", err)) + return + } + ctx := r.Context() + + if _, err := h.db.GetTemplate(ctx, name); err != nil { + writeError(w, http.StatusNotFound, "not_found", "template not found") + return + } + + // Broadcast delete to all online hosts. + hosts, _ := h.db.ListActiveHosts(ctx) + for _, host := range hosts { + if host.Status != "online" { + continue + } + agent, err := h.pool.GetForHost(host) + if err != nil { + continue + } + if _, err := agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{Name: name})); err != nil { + if connect.CodeOf(err) != connect.CodeNotFound { + slog.Warn("admin: failed to delete template on host", "host_id", id.FormatHostID(host.ID), "name", name, "error", err) + } + } + } + + if err := h.db.DeleteTemplate(ctx, name); err != nil { + writeError(w, http.StatusInternalServerError, "db_error", "failed to delete template record") + return + } + + w.WriteHeader(http.StatusNoContent) +} diff --git a/internal/api/server.go b/internal/api/server.go index 1be4473d..d298b298 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -67,7 +67,7 @@ func New( auditH := newAuditHandler(auditSvc) statsH := newStatsHandler(statsSvc) metricsH := newSandboxMetricsHandler(queries, pool) - buildH := newBuildHandler(buildSvc, queries) + buildH := newBuildHandler(buildSvc, queries, pool) // OpenAPI spec and docs. r.Get("/openapi.yaml", serveOpenAPI) @@ -178,6 +178,7 @@ func New( r.Use(requireAdmin(queries)) r.Put("/teams/{id}/byoc", teamH.SetBYOC) r.Get("/templates", buildH.ListTemplates) + r.Delete("/templates/{name}", buildH.DeleteTemplate) r.Post("/builds", buildH.Create) r.Get("/builds", buildH.List) r.Get("/builds/{id}", buildH.Get) From c8acac92cc046c96cd33f37264671b10e22d820f Mon Sep 17 00:00:00 2001 From: pptx704 Date: Fri, 27 Mar 2026 00:00:48 +0600 Subject: [PATCH 13/28] Add pre/post build stages to template builds Pre-build: apt update Post-build: apt clean, apt autoremove, rm apt lists Total steps count includes pre/post commands for accurate progress bars. --- internal/service/build.go | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/internal/service/build.go b/internal/service/build.go index 19fb5d21..5510da0f 100644 --- a/internal/service/build.go +++ b/internal/service/build.go @@ -25,6 +25,18 @@ const ( healthcheckTimeout = 60 * time.Second ) +// preBuildCmds run before the user recipe to prepare the build environment. +var preBuildCmds = []string{ + "apt update", +} + +// postBuildCmds run after the user recipe to clean up caches and reduce image size. +var postBuildCmds = []string{ + "apt clean", + "apt autoremove -y", + "rm -rf /var/lib/apt/lists/*", +} + // buildAgentClient is the subset of the host agent client used by the build worker. type buildAgentClient interface { CreateSandbox(ctx context.Context, req *connect.Request[pb.CreateSandboxRequest]) (*connect.Response[pb.CreateSandboxResponse], error) @@ -91,7 +103,7 @@ func (s *BuildService) Create(ctx context.Context, p BuildCreateParams) (db.Temp Healthcheck: p.Healthcheck, Vcpus: p.VCPUs, MemoryMb: p.MemoryMB, - TotalSteps: int32(len(p.Recipe)), + TotalSteps: int32(len(p.Recipe) + len(preBuildCmds) + len(postBuildCmds)), }) if err != nil { return db.TemplateBuild{}, fmt.Errorf("insert build: %w", err) @@ -170,13 +182,18 @@ func (s *BuildService) executeBuild(ctx context.Context, buildIDStr string) { return } - // Parse recipe. - var recipe []string - if err := json.Unmarshal(build.Recipe, &recipe); err != nil { + // Parse user recipe and wrap with pre/post build stages. + var userRecipe []string + if err := json.Unmarshal(build.Recipe, &userRecipe); err != nil { s.failBuild(ctx, buildID, fmt.Sprintf("invalid recipe JSON: %v", err)) return } + recipe := make([]string, 0, len(userRecipe)+len(preBuildCmds)+len(postBuildCmds)) + recipe = append(recipe, preBuildCmds...) + recipe = append(recipe, userRecipe...) + recipe = append(recipe, postBuildCmds...) + // Pick a platform host and create a sandbox. host, err := s.Scheduler.SelectHost(ctx, id.PlatformTeamID, false) if err != nil { From 3509ca90e8d5820cf708b42d6ac2bdc023bb8894 Mon Sep 17 00:00:00 2001 From: pptx704 Date: Fri, 27 Mar 2026 00:28:32 +0600 Subject: [PATCH 14/28] Add pre/post build stages, fix exec timeout, expand guest PATH Build phases: - Pre-build (apt update) and post-build (apt clean, autoremove, rm lists) run with 10-minute timeout; user recipe commands keep 30s timeout - Log entries include phase field for UI grouping - Always send explicit TimeoutSec to host agent (0 defaulted to 30s) Frontend: - Pre-build/post-build steps show phase label without exposing commands - Recipe steps numbered independently starting from 1 Guest PATH: - Add /usr/games:/usr/local/games to wrenn-init.sh PATH export (standard Ubuntu paths, needed for packages like cowsay) --- frontend/src/lib/api/builds.ts | 1 + .../src/routes/admin/templates/+page.svelte | 17 ++- images/wrenn-init.sh | 2 +- internal/service/build.go | 114 +++++++++++------- 4 files changed, 84 insertions(+), 50 deletions(-) diff --git a/frontend/src/lib/api/builds.ts b/frontend/src/lib/api/builds.ts index 900acf2d..349c6e1c 100644 --- a/frontend/src/lib/api/builds.ts +++ b/frontend/src/lib/api/builds.ts @@ -2,6 +2,7 @@ import { apiFetch, type ApiResult } from '$lib/api/client'; export type BuildLogEntry = { step: number; + phase: string; // "pre-build", "recipe", or "post-build" cmd: string; stdout: string; stderr: string; diff --git a/frontend/src/routes/admin/templates/+page.svelte b/frontend/src/routes/admin/templates/+page.svelte index 0d719bd0..dde8fc3a 100644 --- a/frontend/src/routes/admin/templates/+page.svelte +++ b/frontend/src/routes/admin/templates/+page.svelte @@ -521,6 +521,9 @@ {#if build.logs && build.logs.length > 0}
{#each build.logs as log, i (i)} + {@const isInternal = log.phase === 'pre-build' || log.phase === 'post-build'} + {@const recipeIdx = log.phase === 'recipe' ? build.logs.filter(l => l.phase === 'recipe' && l.step <= log.step).length : 0} + {@const phaseLabel = isInternal ? (log.phase === 'pre-build' ? 'Pre-build' : 'Post-build') : `Step ${recipeIdx}`}
{/if} {/if} diff --git a/internal/api/handlers_snapshots.go b/internal/api/handlers_snapshots.go index f3e29074..1b5f1f35 100644 --- a/internal/api/handlers_snapshots.go +++ b/internal/api/handlers_snapshots.go @@ -69,6 +69,7 @@ type snapshotResponse struct { MemoryMB *int32 `json:"memory_mb,omitempty"` SizeBytes int64 `json:"size_bytes"` CreatedAt string `json:"created_at"` + Platform bool `json:"platform"` } func templateToResponse(t db.Template) snapshotResponse { @@ -76,6 +77,7 @@ func templateToResponse(t db.Template) snapshotResponse { Name: t.Name, Type: t.Type, SizeBytes: t.SizeBytes, + Platform: t.TeamID == id.PlatformTeamID, } if t.Vcpus != 0 { resp.VCPUs = &t.Vcpus @@ -154,26 +156,43 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) { return } - resp, err := agent.CreateSnapshot(ctx, connect.NewRequest(&pb.CreateSnapshotRequest{ + // Pre-mark sandbox as "paused" in DB BEFORE issuing the snapshot RPC. + // The host agent's CreateSnapshot removes the sandbox from its in-memory + // map immediately; if the reconciler fires during the flatten window and + // the DB still says "running", it will mark the sandbox "stopped". + if sb.Status == "running" { + if _, err := h.db.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{ + ID: sandboxID, Status: "paused", + }); err != nil { + writeError(w, http.StatusInternalServerError, "db_error", "failed to update sandbox status") + return + } + } + + // Use a detached context with a generous timeout so the snapshot completes + // even if the client disconnects (the flatten step can take 10-20s). + snapCtx, snapCancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer snapCancel() + + resp, err := agent.CreateSnapshot(snapCtx, connect.NewRequest(&pb.CreateSnapshotRequest{ SandboxId: req.SandboxID, Name: req.Name, })) if err != nil { + // Snapshot failed — revert status back to what it was. + if sb.Status == "running" { + if _, dbErr := h.db.UpdateSandboxStatus(snapCtx, db.UpdateSandboxStatusParams{ + ID: sandboxID, Status: "running", + }); dbErr != nil { + slog.Error("failed to revert sandbox status after snapshot error", "sandbox_id", req.SandboxID, "error", dbErr) + } + } status, code, msg := agentErrToHTTP(err) writeError(w, status, code, msg) return } - // Mark sandbox as paused (if it was running, it got paused by the snapshot). - if sb.Status != "paused" { - if _, err := h.db.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{ - ID: sandboxID, Status: "paused", - }); err != nil { - slog.Error("failed to update sandbox status after snapshot", "sandbox_id", req.SandboxID, "error", err) - } - } - - tmpl, err := h.db.InsertTemplate(ctx, db.InsertTemplateParams{ + tmpl, err := h.db.InsertTemplate(snapCtx, db.InsertTemplateParams{ Name: req.Name, Type: "snapshot", Vcpus: sb.Vcpus, @@ -187,7 +206,12 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) { return } - h.audit.LogSnapshotCreate(r.Context(), ac, req.Name) + h.audit.LogSnapshotCreate(snapCtx, ac, req.Name) + + if ctx.Err() != nil { + slog.Info("snapshot created but client disconnected before response", "name", req.Name) + return + } writeJSON(w, http.StatusCreated, templateToResponse(tmpl)) } @@ -220,10 +244,16 @@ func (h *snapshotHandler) Delete(w http.ResponseWriter, r *http.Request) { ctx := r.Context() ac := auth.MustFromContext(ctx) - if _, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: name, TeamID: ac.TeamID}); err != nil { + tmpl, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: name, TeamID: ac.TeamID}) + if err != nil { writeError(w, http.StatusNotFound, "not_found", "template not found") return } + // Platform templates can only be deleted by admins via /v1/admin/templates. + if tmpl.TeamID == id.PlatformTeamID { + writeError(w, http.StatusForbidden, "forbidden", "platform templates cannot be deleted here") + return + } if err := h.deleteSnapshotBroadcast(ctx, name); err != nil { writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete snapshot files") diff --git a/internal/devicemapper/devicemapper.go b/internal/devicemapper/devicemapper.go index ba801f13..9fa08332 100644 --- a/internal/devicemapper/devicemapper.go +++ b/internal/devicemapper/devicemapper.go @@ -224,6 +224,7 @@ func FlattenSnapshot(dmDevPath, outputPath string) error { "if="+dmDevPath, "of="+outputPath, "bs=4M", + "conv=sparse", "status=none", ) if out, err := cmd.CombinedOutput(); err != nil { diff --git a/internal/sandbox/images.go b/internal/sandbox/images.go index 0f3e24ad..feb3398b 100644 --- a/internal/sandbox/images.go +++ b/internal/sandbox/images.go @@ -12,7 +12,7 @@ import ( // than this are expanded at startup so that dm-snapshot sandboxes see the full // size without per-sandbox copies. The expansion is sparse — only metadata // changes; no physical disk is consumed beyond the original content. -const DefaultDiskSizeMB = 20480 // 20 GB +const DefaultDiskSizeMB = 5120 // 5 GB // EnsureImageSizes walks the images directory and expands any rootfs.ext4 that // is smaller than the target size. This is idempotent: images already at or diff --git a/internal/sandbox/manager.go b/internal/sandbox/manager.go index b91fed14..9fcecb5a 100644 --- a/internal/sandbox/manager.go +++ b/internal/sandbox/manager.go @@ -107,7 +107,7 @@ func (m *Manager) Create(ctx context.Context, sandboxID, template string, vcpus, memoryMB = 512 } if diskSizeMB <= 0 { - diskSizeMB = 20480 // 20 GB default + diskSizeMB = 5120 // 5 GB default } if template == "" { diff --git a/internal/service/build.go b/internal/service/build.go index 23f3b8c9..5142a133 100644 --- a/internal/service/build.go +++ b/internal/service/build.go @@ -213,7 +213,7 @@ func (s *BuildService) executeBuild(ctx context.Context, buildIDStr string) { Vcpus: build.Vcpus, MemoryMb: build.MemoryMb, TimeoutSec: 0, // no auto-pause for builds - DiskSizeMb: 20480, // 20 GB for template builds + DiskSizeMb: 5120, // 5 GB for template builds })) if err != nil { s.failBuild(ctx, buildID, fmt.Sprintf("create sandbox failed: %v", err)) diff --git a/internal/service/sandbox.go b/internal/service/sandbox.go index 96d1282c..43f1bd3f 100644 --- a/internal/service/sandbox.go +++ b/internal/service/sandbox.go @@ -79,7 +79,7 @@ func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db. p.MemoryMB = 512 } if p.DiskSizeMB <= 0 { - p.DiskSizeMB = 20480 // 20 GB default + p.DiskSizeMB = 5120 // 5 GB default } // If the template is a snapshot, use its baked-in vcpus/memory. @@ -187,20 +187,32 @@ func (s *SandboxService) Pause(ctx context.Context, sandboxID, teamID pgtype.UUI sandboxIDStr := id.FormatSandboxID(sandboxID) + // Pre-mark as "paused" in DB before the RPC so the reconciler does not + // mark the sandbox "stopped" while the host agent processes the pause. + if _, err := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{ + ID: sandboxID, Status: "paused", + }); err != nil { + return db.Sandbox{}, fmt.Errorf("pre-mark paused: %w", err) + } + // Flush all metrics tiers before pausing so data survives in DB. s.flushAndPersistMetrics(ctx, agent, sandboxID, true) if _, err := agent.PauseSandbox(ctx, connect.NewRequest(&pb.PauseSandboxRequest{ SandboxId: sandboxIDStr, })); err != nil { + // Revert status on failure. + if _, dbErr := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{ + ID: sandboxID, Status: "running", + }); dbErr != nil { + slog.Warn("failed to revert sandbox status after pause error", "sandbox_id", sandboxIDStr, "error", dbErr) + } return db.Sandbox{}, fmt.Errorf("agent pause: %w", err) } - sb, err = s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{ - ID: sandboxID, Status: "paused", - }) + sb, err = s.DB.GetSandbox(ctx, sandboxID) if err != nil { - return db.Sandbox{}, fmt.Errorf("update status: %w", err) + return db.Sandbox{}, fmt.Errorf("get sandbox after pause: %w", err) } return sb, nil } diff --git a/proto/hostagent/gen/hostagent.pb.go b/proto/hostagent/gen/hostagent.pb.go index aa29db98..9f984f91 100644 --- a/proto/hostagent/gen/hostagent.pb.go +++ b/proto/hostagent/gen/hostagent.pb.go @@ -34,8 +34,8 @@ type CreateSandboxRequest struct { // TTL in seconds. Sandbox is auto-paused after this duration of // inactivity. 0 means no auto-pause. TimeoutSec int32 `protobuf:"varint,4,opt,name=timeout_sec,json=timeoutSec,proto3" json:"timeout_sec,omitempty"` - // Disk size in MB for the sparse CoW file. Limits how much data the - // sandbox can write beyond the base image. Default: 20480 (20 GB). + // Disk size in MB for the rootfs. Base images are expanded to this size + // at host agent startup. Default: 5120 (5 GB). DiskSizeMb int32 `protobuf:"varint,6,opt,name=disk_size_mb,json=diskSizeMb,proto3" json:"disk_size_mb,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache diff --git a/proto/hostagent/hostagent.proto b/proto/hostagent/hostagent.proto index 9af40eb3..5a3205b7 100644 --- a/proto/hostagent/hostagent.proto +++ b/proto/hostagent/hostagent.proto @@ -86,8 +86,8 @@ message CreateSandboxRequest { // inactivity. 0 means no auto-pause. int32 timeout_sec = 4; - // Disk size in MB for the sparse CoW file. Limits how much data the - // sandbox can write beyond the base image. Default: 20480 (20 GB). + // Disk size in MB for the rootfs. Base images are expanded to this size + // at host agent startup. Default: 5120 (5 GB). int32 disk_size_mb = 6; } From 03e96629c70002af8dc4c7a2a27ff1f434a9b3a0 Mon Sep 17 00:00:00 2001 From: pptx704 Date: Sat, 28 Mar 2026 20:45:57 +0600 Subject: [PATCH 17/28] Remove slug from team page UI --- .../src/routes/dashboard/team/+page.svelte | 165 ++++++------------ 1 file changed, 51 insertions(+), 114 deletions(-) diff --git a/frontend/src/routes/dashboard/team/+page.svelte b/frontend/src/routes/dashboard/team/+page.svelte index 773cdf4f..f17311ba 100644 --- a/frontend/src/routes/dashboard/team/+page.svelte +++ b/frontend/src/routes/dashboard/team/+page.svelte @@ -50,7 +50,6 @@ let nameInputEl = $state(null); // Copy state - let copiedSlug = $state(false); let copiedId = $state(false); // Add member dialog @@ -139,16 +138,11 @@ savingName = false; } - async function copyToClipboard(text: string, which: 'slug' | 'id') { + async function copyToClipboard(text: string) { try { await navigator.clipboard.writeText(text); - if (which === 'slug') { - copiedSlug = true; - setTimeout(() => (copiedSlug = false), 2000); - } else { - copiedId = true; - setTimeout(() => (copiedId = false), 2000); - } + copiedId = true; + setTimeout(() => (copiedId = false), 2000); } catch { toast.error('Copy failed — select the text and copy manually.'); } @@ -514,115 +508,58 @@
- -
- -
-
-
- Slug -
- {team.slug} -
- -
- - -
-
-
- Team ID -
- {team.id} + Team ID
-
+
From 75b28ed8992b0d1fda4a4fad1854a9605966a174 Mon Sep 17 00:00:00 2001 From: pptx704 Date: Sun, 29 Mar 2026 00:30:10 +0600 Subject: [PATCH 18/28] Add UUID-based template IDs and team-scoped template directory layout Introduces internal/layout package for centralized path construction, migrates templates from name-based TEXT primary keys to UUID PKs with team-scoped directories (WRENN_DIR/images/teams/{team_id}/{template_id}). The built-in minimal template uses sentinel zero UUIDs. Proto messages carry team_id + template_id alongside deprecated template name field. Team deletion now cleans up template files across all hosts. --- .../20260328162803_template_uuid_pk.sql | 64 +++++ db/queries/sandboxes.sql | 4 +- db/queries/template_builds.sql | 4 +- db/queries/templates.sql | 24 +- internal/api/handlers_builds.go | 24 +- internal/api/handlers_snapshots.go | 35 ++- internal/api/middleware.go | 8 + internal/db/models.go | 35 +-- internal/db/sandboxes.sql.go | 62 +++-- internal/db/template_builds.sql.go | 24 +- internal/db/templates.sql.go | 130 ++++++++-- internal/hostagent/server.go | 67 ++++- internal/id/id.go | 20 +- internal/id/id_test.go | 6 +- internal/layout/layout.go | 58 +++++ internal/layout/layout_test.go | 120 +++++++++ internal/models/sandbox.go | 23 +- internal/sandbox/images.go | 108 +++++--- internal/sandbox/manager.go | 244 +++++++++--------- internal/service/build.go | 33 ++- internal/service/sandbox.go | 41 ++- internal/service/team.go | 49 ++++ proto/hostagent/gen/hostagent.pb.go | 168 +++++++++--- proto/hostagent/hostagent.proto | 28 +- 24 files changed, 1057 insertions(+), 322 deletions(-) create mode 100644 db/migrations/20260328162803_template_uuid_pk.sql create mode 100644 internal/layout/layout.go create mode 100644 internal/layout/layout_test.go diff --git a/db/migrations/20260328162803_template_uuid_pk.sql b/db/migrations/20260328162803_template_uuid_pk.sql new file mode 100644 index 00000000..8665241f --- /dev/null +++ b/db/migrations/20260328162803_template_uuid_pk.sql @@ -0,0 +1,64 @@ +-- +goose Up + +-- 1. Add UUID id column to templates and make it the primary key. +ALTER TABLE templates ADD COLUMN id UUID DEFAULT gen_random_uuid(); +UPDATE templates SET id = gen_random_uuid() WHERE id IS NULL; +ALTER TABLE templates ALTER COLUMN id SET NOT NULL; +ALTER TABLE templates DROP CONSTRAINT templates_pkey; +ALTER TABLE templates ADD PRIMARY KEY (id); + +-- 2. Name becomes a display field with team-scoped uniqueness. +ALTER TABLE templates ADD CONSTRAINT uq_templates_team_name UNIQUE (team_id, name); + +-- 3. Prevent team templates from using names that belong to global (platform) templates. +-- A team template insert/update with a name matching any platform template is rejected. +CREATE OR REPLACE FUNCTION check_global_template_name_collision() +RETURNS TRIGGER AS $$ +BEGIN + IF NEW.team_id != '00000000-0000-0000-0000-000000000000' THEN + IF EXISTS ( + SELECT 1 FROM templates + WHERE name = NEW.name + AND team_id = '00000000-0000-0000-0000-000000000000' + ) THEN + RAISE EXCEPTION 'template name "%" is reserved by a global template', NEW.name + USING ERRCODE = 'unique_violation'; + END IF; + END IF; + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +CREATE TRIGGER trg_check_global_template_name + BEFORE INSERT OR UPDATE ON templates + FOR EACH ROW + EXECUTE FUNCTION check_global_template_name_collision(); + +-- 4. Add template UUID references to template_builds. +ALTER TABLE template_builds + ADD COLUMN template_id UUID, + ADD COLUMN team_id UUID; + +-- 5. Add template UUID references to sandboxes. +ALTER TABLE sandboxes + ADD COLUMN template_id UUID, + ADD COLUMN template_team_id UUID; + +-- +goose Down + +ALTER TABLE sandboxes + DROP COLUMN IF EXISTS template_team_id, + DROP COLUMN IF EXISTS template_id; + +ALTER TABLE template_builds + DROP COLUMN IF EXISTS team_id, + DROP COLUMN IF EXISTS template_id; + +DROP TRIGGER IF EXISTS trg_check_global_template_name ON templates; +DROP FUNCTION IF EXISTS check_global_template_name_collision(); + +ALTER TABLE templates DROP CONSTRAINT IF EXISTS uq_templates_team_name; + +ALTER TABLE templates DROP CONSTRAINT IF EXISTS templates_pkey; +ALTER TABLE templates ADD PRIMARY KEY (name); +ALTER TABLE templates DROP COLUMN IF EXISTS id; diff --git a/db/queries/sandboxes.sql b/db/queries/sandboxes.sql index b8ae8de5..8cbd10be 100644 --- a/db/queries/sandboxes.sql +++ b/db/queries/sandboxes.sql @@ -1,6 +1,6 @@ -- name: InsertSandbox :one -INSERT INTO sandboxes (id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb) -VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) +INSERT INTO sandboxes (id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, template_id, template_team_id) +VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) RETURNING *; -- name: GetSandbox :one diff --git a/db/queries/template_builds.sql b/db/queries/template_builds.sql index ead4d925..be1c09e5 100644 --- a/db/queries/template_builds.sql +++ b/db/queries/template_builds.sql @@ -1,6 +1,6 @@ -- name: InsertTemplateBuild :one -INSERT INTO template_builds (id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, total_steps) -VALUES ($1, $2, $3, $4, $5, $6, $7, 'pending', $8) +INSERT INTO template_builds (id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, total_steps, template_id, team_id) +VALUES ($1, $2, $3, $4, $5, $6, $7, 'pending', $8, $9, $10) RETURNING *; -- name: GetTemplateBuild :one diff --git a/db/queries/templates.sql b/db/queries/templates.sql index c7b70855..de4d6f2a 100644 --- a/db/queries/templates.sql +++ b/db/queries/templates.sql @@ -1,15 +1,23 @@ -- name: InsertTemplate :one -INSERT INTO templates (name, type, vcpus, memory_mb, size_bytes, team_id) -VALUES ($1, $2, $3, $4, $5, $6) +INSERT INTO templates (id, name, type, vcpus, memory_mb, size_bytes, team_id) +VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING *; -- name: GetTemplate :one -SELECT * FROM templates WHERE name = $1; +SELECT * FROM templates WHERE id = $1; -- name: GetTemplateByTeam :one -- Platform templates (team_id = 00000000-...) are visible to all teams. SELECT * FROM templates WHERE name = $1 AND (team_id = $2 OR team_id = '00000000-0000-0000-0000-000000000000'); +-- name: GetTemplateByName :one +-- Look up a template by team_id and name (exact team match, no global fallback). +SELECT * FROM templates WHERE team_id = $1 AND name = $2; + +-- name: GetPlatformTemplateByName :one +-- Check if a global (platform) template exists with the given name. +SELECT * FROM templates WHERE team_id = '00000000-0000-0000-0000-000000000000' AND name = $1; + -- name: ListTemplates :many SELECT * FROM templates ORDER BY created_at DESC; @@ -25,7 +33,15 @@ SELECT * FROM templates WHERE (team_id = $1 OR team_id = '00000000-0000-0000-000 SELECT * FROM templates WHERE (team_id = $1 OR team_id = '00000000-0000-0000-0000-000000000000') AND type = $2 ORDER BY created_at DESC; -- name: DeleteTemplate :exec -DELETE FROM templates WHERE name = $1; +DELETE FROM templates WHERE id = $1; -- name: DeleteTemplateByTeam :exec DELETE FROM templates WHERE name = $1 AND team_id = $2; + +-- name: DeleteTemplatesByTeam :exec +-- Bulk delete all templates owned by a team (for team soft-delete cleanup). +DELETE FROM templates WHERE team_id = $1; + +-- name: ListTemplatesByTeamOnly :many +-- List templates owned by a specific team (NOT including platform templates). +SELECT * FROM templates WHERE team_id = $1 ORDER BY created_at DESC; diff --git a/internal/api/handlers_builds.go b/internal/api/handlers_builds.go index 8b8fd5cd..58a7ed44 100644 --- a/internal/api/handlers_builds.go +++ b/internal/api/handlers_builds.go @@ -180,13 +180,13 @@ func (h *buildHandler) ListTemplates(w http.ResponseWriter, r *http.Request) { } type templateResponse struct { - Name string `json:"name"` - Type string `json:"type"` - VCPUs int32 `json:"vcpus"` - MemoryMB int32 `json:"memory_mb"` - SizeBytes int64 `json:"size_bytes"` - TeamID string `json:"team_id"` - CreatedAt string `json:"created_at"` + Name string `json:"name"` + Type string `json:"type"` + VCPUs int32 `json:"vcpus"` + MemoryMB int32 `json:"memory_mb"` + SizeBytes int64 `json:"size_bytes"` + TeamID string `json:"team_id"` + CreatedAt string `json:"created_at"` } resp := make([]templateResponse, len(templates)) @@ -216,7 +216,8 @@ func (h *buildHandler) DeleteTemplate(w http.ResponseWriter, r *http.Request) { } ctx := r.Context() - if _, err := h.db.GetTemplate(ctx, name); err != nil { + tmpl, err := h.db.GetPlatformTemplateByName(ctx, name) + if err != nil { writeError(w, http.StatusNotFound, "not_found", "template not found") return } @@ -231,14 +232,17 @@ func (h *buildHandler) DeleteTemplate(w http.ResponseWriter, r *http.Request) { if err != nil { continue } - if _, err := agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{Name: name})); err != nil { + if _, err := agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{ + TeamId: formatUUIDForRPC(tmpl.TeamID), + TemplateId: formatUUIDForRPC(tmpl.ID), + })); err != nil { if connect.CodeOf(err) != connect.CodeNotFound { slog.Warn("admin: failed to delete template on host", "host_id", id.FormatHostID(host.ID), "name", name, "error", err) } } } - if err := h.db.DeleteTemplate(ctx, name); err != nil { + if err := h.db.DeleteTemplate(ctx, tmpl.ID); err != nil { writeError(w, http.StatusInternalServerError, "db_error", "failed to delete template record") return } diff --git a/internal/api/handlers_snapshots.go b/internal/api/handlers_snapshots.go index 1b5f1f35..07bd0301 100644 --- a/internal/api/handlers_snapshots.go +++ b/internal/api/handlers_snapshots.go @@ -11,6 +11,8 @@ import ( "connectrpc.com/connect" "github.com/go-chi/chi/v5" + "github.com/jackc/pgx/v5/pgtype" + "git.omukk.dev/wrenn/sandbox/internal/audit" "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/db" @@ -34,8 +36,8 @@ func newSnapshotHandler(svc *service.TemplateService, db *db.Queries, pool *life // deleteSnapshotBroadcast attempts to delete snapshot files on all online hosts. // Snapshots aren't currently host-tracked in the DB, so we broadcast to all hosts -// and ignore NotFound errors. TODO: add host_id to templates table. -func (h *snapshotHandler) deleteSnapshotBroadcast(ctx context.Context, name string) error { +// and ignore NotFound errors. +func (h *snapshotHandler) deleteSnapshotBroadcast(ctx context.Context, teamID, templateID pgtype.UUID) error { hosts, err := h.db.ListActiveHosts(ctx) if err != nil { return fmt.Errorf("list hosts: %w", err) @@ -48,9 +50,12 @@ func (h *snapshotHandler) deleteSnapshotBroadcast(ctx context.Context, name stri if err != nil { continue } - if _, err := agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{Name: name})); err != nil { + if _, err := agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{ + TeamId: formatUUIDForRPC(teamID), + TemplateId: formatUUIDForRPC(templateID), + })); err != nil { if connect.CodeOf(err) != connect.CodeNotFound { - slog.Warn("snapshot: failed to delete on host", "host_id", id.FormatHostID(host.ID), "name", name, "error", err) + slog.Warn("snapshot: failed to delete on host", "host_id", id.FormatHostID(host.ID), "error", err) } } } @@ -122,14 +127,20 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) { ac := auth.MustFromContext(ctx) overwrite := r.URL.Query().Get("overwrite") == "true" + // Check for global name collision. + if _, err := h.db.GetPlatformTemplateByName(ctx, req.Name); err == nil { + writeError(w, http.StatusConflict, "name_reserved", "template name is reserved by a global template") + return + } + // Check if name already exists for this team. - if _, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err == nil { + if existing, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err == nil { if !overwrite { writeError(w, http.StatusConflict, "already_exists", "snapshot name already exists; use ?overwrite=true to replace") return } // Delete old snapshot files from all hosts before removing the DB record. - if err := h.deleteSnapshotBroadcast(ctx, req.Name); err != nil { + if err := h.deleteSnapshotBroadcast(ctx, existing.TeamID, existing.ID); err != nil { writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete existing snapshot files") return } @@ -174,9 +185,14 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) { snapCtx, snapCancel := context.WithTimeout(context.Background(), 5*time.Minute) defer snapCancel() + // Generate the new template ID upfront so the host agent knows where to store files. + newTemplateID := id.NewTemplateID() + resp, err := agent.CreateSnapshot(snapCtx, connect.NewRequest(&pb.CreateSnapshotRequest{ - SandboxId: req.SandboxID, - Name: req.Name, + SandboxId: req.SandboxID, + Name: req.Name, + TeamId: formatUUIDForRPC(ac.TeamID), + TemplateId: formatUUIDForRPC(newTemplateID), })) if err != nil { // Snapshot failed — revert status back to what it was. @@ -193,6 +209,7 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) { } tmpl, err := h.db.InsertTemplate(snapCtx, db.InsertTemplateParams{ + ID: newTemplateID, Name: req.Name, Type: "snapshot", Vcpus: sb.Vcpus, @@ -255,7 +272,7 @@ func (h *snapshotHandler) Delete(w http.ResponseWriter, r *http.Request) { return } - if err := h.deleteSnapshotBroadcast(ctx, name); err != nil { + if err := h.deleteSnapshotBroadcast(ctx, tmpl.TeamID, tmpl.ID); err != nil { writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete snapshot files") return } diff --git a/internal/api/middleware.go b/internal/api/middleware.go index 6a56293c..5c9d8cdf 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -12,6 +12,9 @@ import ( "time" "connectrpc.com/connect" + "github.com/jackc/pgx/v5/pgtype" + + "git.omukk.dev/wrenn/sandbox/internal/id" ) type errorResponse struct { @@ -35,6 +38,11 @@ func writeError(w http.ResponseWriter, status int, code, message string) { }) } +// formatUUIDForRPC converts a pgtype.UUID to a hex string for RPC messages. +func formatUUIDForRPC(u pgtype.UUID) string { + return id.UUIDString(u) +} + // agentErrToHTTP maps a Connect RPC error to an HTTP status, error code, and message. func agentErrToHTTP(err error) (int, string, string) { switch connect.CodeOf(err) { diff --git a/internal/db/models.go b/internal/db/models.go index f35bfe7a..d5bfc0f2 100644 --- a/internal/db/models.go +++ b/internal/db/models.go @@ -83,21 +83,23 @@ type OauthProvider struct { } type Sandbox struct { - ID pgtype.UUID `json:"id"` - TeamID pgtype.UUID `json:"team_id"` - HostID pgtype.UUID `json:"host_id"` - Template string `json:"template"` - Status string `json:"status"` - Vcpus int32 `json:"vcpus"` - MemoryMb int32 `json:"memory_mb"` - TimeoutSec int32 `json:"timeout_sec"` - DiskSizeMb int32 `json:"disk_size_mb"` - GuestIp string `json:"guest_ip"` - HostIp string `json:"host_ip"` - CreatedAt pgtype.Timestamptz `json:"created_at"` - StartedAt pgtype.Timestamptz `json:"started_at"` - LastActiveAt pgtype.Timestamptz `json:"last_active_at"` - LastUpdated pgtype.Timestamptz `json:"last_updated"` + ID pgtype.UUID `json:"id"` + TeamID pgtype.UUID `json:"team_id"` + HostID pgtype.UUID `json:"host_id"` + Template string `json:"template"` + Status string `json:"status"` + Vcpus int32 `json:"vcpus"` + MemoryMb int32 `json:"memory_mb"` + TimeoutSec int32 `json:"timeout_sec"` + DiskSizeMb int32 `json:"disk_size_mb"` + GuestIp string `json:"guest_ip"` + HostIp string `json:"host_ip"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + StartedAt pgtype.Timestamptz `json:"started_at"` + LastActiveAt pgtype.Timestamptz `json:"last_active_at"` + LastUpdated pgtype.Timestamptz `json:"last_updated"` + TemplateID pgtype.UUID `json:"template_id"` + TemplateTeamID pgtype.UUID `json:"template_team_id"` } type SandboxMetricPoint struct { @@ -146,6 +148,7 @@ type Template struct { SizeBytes int64 `json:"size_bytes"` CreatedAt pgtype.Timestamptz `json:"created_at"` TeamID pgtype.UUID `json:"team_id"` + ID pgtype.UUID `json:"id"` } type TemplateBuild struct { @@ -166,6 +169,8 @@ type TemplateBuild struct { CreatedAt pgtype.Timestamptz `json:"created_at"` StartedAt pgtype.Timestamptz `json:"started_at"` CompletedAt pgtype.Timestamptz `json:"completed_at"` + TemplateID pgtype.UUID `json:"template_id"` + TeamID pgtype.UUID `json:"team_id"` } type User struct { diff --git a/internal/db/sandboxes.sql.go b/internal/db/sandboxes.sql.go index ace43709..4107f1ab 100644 --- a/internal/db/sandboxes.sql.go +++ b/internal/db/sandboxes.sql.go @@ -43,7 +43,7 @@ func (q *Queries) BulkUpdateStatusByIDs(ctx context.Context, arg BulkUpdateStatu } const getSandbox = `-- name: GetSandbox :one -SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated FROM sandboxes WHERE id = $1 +SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes WHERE id = $1 ` func (q *Queries) GetSandbox(ctx context.Context, id pgtype.UUID) (Sandbox, error) { @@ -65,12 +65,14 @@ func (q *Queries) GetSandbox(ctx context.Context, id pgtype.UUID) (Sandbox, erro &i.StartedAt, &i.LastActiveAt, &i.LastUpdated, + &i.TemplateID, + &i.TemplateTeamID, ) return i, err } const getSandboxByTeam = `-- name: GetSandboxByTeam :one -SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated FROM sandboxes WHERE id = $1 AND team_id = $2 +SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes WHERE id = $1 AND team_id = $2 ` type GetSandboxByTeamParams struct { @@ -97,26 +99,30 @@ func (q *Queries) GetSandboxByTeam(ctx context.Context, arg GetSandboxByTeamPara &i.StartedAt, &i.LastActiveAt, &i.LastUpdated, + &i.TemplateID, + &i.TemplateTeamID, ) return i, err } const insertSandbox = `-- name: InsertSandbox :one -INSERT INTO sandboxes (id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb) -VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) -RETURNING id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated +INSERT INTO sandboxes (id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, template_id, template_team_id) +VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) +RETURNING id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id ` type InsertSandboxParams struct { - ID pgtype.UUID `json:"id"` - TeamID pgtype.UUID `json:"team_id"` - HostID pgtype.UUID `json:"host_id"` - Template string `json:"template"` - Status string `json:"status"` - Vcpus int32 `json:"vcpus"` - MemoryMb int32 `json:"memory_mb"` - TimeoutSec int32 `json:"timeout_sec"` - DiskSizeMb int32 `json:"disk_size_mb"` + ID pgtype.UUID `json:"id"` + TeamID pgtype.UUID `json:"team_id"` + HostID pgtype.UUID `json:"host_id"` + Template string `json:"template"` + Status string `json:"status"` + Vcpus int32 `json:"vcpus"` + MemoryMb int32 `json:"memory_mb"` + TimeoutSec int32 `json:"timeout_sec"` + DiskSizeMb int32 `json:"disk_size_mb"` + TemplateID pgtype.UUID `json:"template_id"` + TemplateTeamID pgtype.UUID `json:"template_team_id"` } func (q *Queries) InsertSandbox(ctx context.Context, arg InsertSandboxParams) (Sandbox, error) { @@ -130,6 +136,8 @@ func (q *Queries) InsertSandbox(ctx context.Context, arg InsertSandboxParams) (S arg.MemoryMb, arg.TimeoutSec, arg.DiskSizeMb, + arg.TemplateID, + arg.TemplateTeamID, ) var i Sandbox err := row.Scan( @@ -148,12 +156,14 @@ func (q *Queries) InsertSandbox(ctx context.Context, arg InsertSandboxParams) (S &i.StartedAt, &i.LastActiveAt, &i.LastUpdated, + &i.TemplateID, + &i.TemplateTeamID, ) return i, err } const listActiveSandboxesByTeam = `-- name: ListActiveSandboxesByTeam :many -SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated FROM sandboxes +SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes WHERE team_id = $1 AND status IN ('running', 'paused', 'starting') ORDER BY created_at DESC ` @@ -183,6 +193,8 @@ func (q *Queries) ListActiveSandboxesByTeam(ctx context.Context, teamID pgtype.U &i.StartedAt, &i.LastActiveAt, &i.LastUpdated, + &i.TemplateID, + &i.TemplateTeamID, ); err != nil { return nil, err } @@ -195,7 +207,7 @@ func (q *Queries) ListActiveSandboxesByTeam(ctx context.Context, teamID pgtype.U } const listSandboxes = `-- name: ListSandboxes :many -SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated FROM sandboxes ORDER BY created_at DESC +SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes ORDER BY created_at DESC ` func (q *Queries) ListSandboxes(ctx context.Context) ([]Sandbox, error) { @@ -223,6 +235,8 @@ func (q *Queries) ListSandboxes(ctx context.Context) ([]Sandbox, error) { &i.StartedAt, &i.LastActiveAt, &i.LastUpdated, + &i.TemplateID, + &i.TemplateTeamID, ); err != nil { return nil, err } @@ -235,7 +249,7 @@ func (q *Queries) ListSandboxes(ctx context.Context) ([]Sandbox, error) { } const listSandboxesByHostAndStatus = `-- name: ListSandboxesByHostAndStatus :many -SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated FROM sandboxes +SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes WHERE host_id = $1 AND status = ANY($2::text[]) ORDER BY created_at DESC ` @@ -270,6 +284,8 @@ func (q *Queries) ListSandboxesByHostAndStatus(ctx context.Context, arg ListSand &i.StartedAt, &i.LastActiveAt, &i.LastUpdated, + &i.TemplateID, + &i.TemplateTeamID, ); err != nil { return nil, err } @@ -282,7 +298,7 @@ func (q *Queries) ListSandboxesByHostAndStatus(ctx context.Context, arg ListSand } const listSandboxesByTeam = `-- name: ListSandboxesByTeam :many -SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated FROM sandboxes +SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes WHERE team_id = $1 AND status NOT IN ('stopped', 'error') ORDER BY created_at DESC ` @@ -312,6 +328,8 @@ func (q *Queries) ListSandboxesByTeam(ctx context.Context, teamID pgtype.UUID) ( &i.StartedAt, &i.LastActiveAt, &i.LastUpdated, + &i.TemplateID, + &i.TemplateTeamID, ); err != nil { return nil, err } @@ -364,7 +382,7 @@ SET status = 'running', last_active_at = $4, last_updated = NOW() WHERE id = $1 -RETURNING id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated +RETURNING id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id ` type UpdateSandboxRunningParams struct { @@ -398,6 +416,8 @@ func (q *Queries) UpdateSandboxRunning(ctx context.Context, arg UpdateSandboxRun &i.StartedAt, &i.LastActiveAt, &i.LastUpdated, + &i.TemplateID, + &i.TemplateTeamID, ) return i, err } @@ -407,7 +427,7 @@ UPDATE sandboxes SET status = $2, last_updated = NOW() WHERE id = $1 -RETURNING id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated +RETURNING id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id ` type UpdateSandboxStatusParams struct { @@ -434,6 +454,8 @@ func (q *Queries) UpdateSandboxStatus(ctx context.Context, arg UpdateSandboxStat &i.StartedAt, &i.LastActiveAt, &i.LastUpdated, + &i.TemplateID, + &i.TemplateTeamID, ) return i, err } diff --git a/internal/db/template_builds.sql.go b/internal/db/template_builds.sql.go index 9e770ee1..7aa1b67e 100644 --- a/internal/db/template_builds.sql.go +++ b/internal/db/template_builds.sql.go @@ -12,7 +12,7 @@ import ( ) const getTemplateBuild = `-- name: GetTemplateBuild :one -SELECT id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at FROM template_builds WHERE id = $1 +SELECT id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at, template_id, team_id FROM template_builds WHERE id = $1 ` func (q *Queries) GetTemplateBuild(ctx context.Context, id pgtype.UUID) (TemplateBuild, error) { @@ -36,14 +36,16 @@ func (q *Queries) GetTemplateBuild(ctx context.Context, id pgtype.UUID) (Templat &i.CreatedAt, &i.StartedAt, &i.CompletedAt, + &i.TemplateID, + &i.TeamID, ) return i, err } const insertTemplateBuild = `-- name: InsertTemplateBuild :one -INSERT INTO template_builds (id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, total_steps) -VALUES ($1, $2, $3, $4, $5, $6, $7, 'pending', $8) -RETURNING id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at +INSERT INTO template_builds (id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, total_steps, template_id, team_id) +VALUES ($1, $2, $3, $4, $5, $6, $7, 'pending', $8, $9, $10) +RETURNING id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at, template_id, team_id ` type InsertTemplateBuildParams struct { @@ -55,6 +57,8 @@ type InsertTemplateBuildParams struct { Vcpus int32 `json:"vcpus"` MemoryMb int32 `json:"memory_mb"` TotalSteps int32 `json:"total_steps"` + TemplateID pgtype.UUID `json:"template_id"` + TeamID pgtype.UUID `json:"team_id"` } func (q *Queries) InsertTemplateBuild(ctx context.Context, arg InsertTemplateBuildParams) (TemplateBuild, error) { @@ -67,6 +71,8 @@ func (q *Queries) InsertTemplateBuild(ctx context.Context, arg InsertTemplateBui arg.Vcpus, arg.MemoryMb, arg.TotalSteps, + arg.TemplateID, + arg.TeamID, ) var i TemplateBuild err := row.Scan( @@ -87,12 +93,14 @@ func (q *Queries) InsertTemplateBuild(ctx context.Context, arg InsertTemplateBui &i.CreatedAt, &i.StartedAt, &i.CompletedAt, + &i.TemplateID, + &i.TeamID, ) return i, err } const listTemplateBuilds = `-- name: ListTemplateBuilds :many -SELECT id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at FROM template_builds ORDER BY created_at DESC +SELECT id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at, template_id, team_id FROM template_builds ORDER BY created_at DESC ` func (q *Queries) ListTemplateBuilds(ctx context.Context) ([]TemplateBuild, error) { @@ -122,6 +130,8 @@ func (q *Queries) ListTemplateBuilds(ctx context.Context) ([]TemplateBuild, erro &i.CreatedAt, &i.StartedAt, &i.CompletedAt, + &i.TemplateID, + &i.TeamID, ); err != nil { return nil, err } @@ -189,7 +199,7 @@ SET status = $2, started_at = CASE WHEN $2 = 'running' AND started_at IS NULL THEN NOW() ELSE started_at END, completed_at = CASE WHEN $2 IN ('success', 'failed') THEN NOW() ELSE completed_at END WHERE id = $1 -RETURNING id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at +RETURNING id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at, template_id, team_id ` type UpdateBuildStatusParams struct { @@ -218,6 +228,8 @@ func (q *Queries) UpdateBuildStatus(ctx context.Context, arg UpdateBuildStatusPa &i.CreatedAt, &i.StartedAt, &i.CompletedAt, + &i.TemplateID, + &i.TeamID, ) return i, err } diff --git a/internal/db/templates.sql.go b/internal/db/templates.sql.go index 45a673c2..7d37808d 100644 --- a/internal/db/templates.sql.go +++ b/internal/db/templates.sql.go @@ -12,11 +12,11 @@ import ( ) const deleteTemplate = `-- name: DeleteTemplate :exec -DELETE FROM templates WHERE name = $1 +DELETE FROM templates WHERE id = $1 ` -func (q *Queries) DeleteTemplate(ctx context.Context, name string) error { - _, err := q.db.Exec(ctx, deleteTemplate, name) +func (q *Queries) DeleteTemplate(ctx context.Context, id pgtype.UUID) error { + _, err := q.db.Exec(ctx, deleteTemplate, id) return err } @@ -34,12 +34,23 @@ func (q *Queries) DeleteTemplateByTeam(ctx context.Context, arg DeleteTemplateBy return err } -const getTemplate = `-- name: GetTemplate :one -SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE name = $1 +const deleteTemplatesByTeam = `-- name: DeleteTemplatesByTeam :exec +DELETE FROM templates WHERE team_id = $1 ` -func (q *Queries) GetTemplate(ctx context.Context, name string) (Template, error) { - row := q.db.QueryRow(ctx, getTemplate, name) +// Bulk delete all templates owned by a team (for team soft-delete cleanup). +func (q *Queries) DeleteTemplatesByTeam(ctx context.Context, teamID pgtype.UUID) error { + _, err := q.db.Exec(ctx, deleteTemplatesByTeam, teamID) + return err +} + +const getPlatformTemplateByName = `-- name: GetPlatformTemplateByName :one +SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE team_id = '00000000-0000-0000-0000-000000000000' AND name = $1 +` + +// Check if a global (platform) template exists with the given name. +func (q *Queries) GetPlatformTemplateByName(ctx context.Context, name string) (Template, error) { + row := q.db.QueryRow(ctx, getPlatformTemplateByName, name) var i Template err := row.Scan( &i.Name, @@ -49,12 +60,59 @@ func (q *Queries) GetTemplate(ctx context.Context, name string) (Template, error &i.SizeBytes, &i.CreatedAt, &i.TeamID, + &i.ID, + ) + return i, err +} + +const getTemplate = `-- name: GetTemplate :one +SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE id = $1 +` + +func (q *Queries) GetTemplate(ctx context.Context, id pgtype.UUID) (Template, error) { + row := q.db.QueryRow(ctx, getTemplate, id) + var i Template + err := row.Scan( + &i.Name, + &i.Type, + &i.Vcpus, + &i.MemoryMb, + &i.SizeBytes, + &i.CreatedAt, + &i.TeamID, + &i.ID, + ) + return i, err +} + +const getTemplateByName = `-- name: GetTemplateByName :one +SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE team_id = $1 AND name = $2 +` + +type GetTemplateByNameParams struct { + TeamID pgtype.UUID `json:"team_id"` + Name string `json:"name"` +} + +// Look up a template by team_id and name (exact team match, no global fallback). +func (q *Queries) GetTemplateByName(ctx context.Context, arg GetTemplateByNameParams) (Template, error) { + row := q.db.QueryRow(ctx, getTemplateByName, arg.TeamID, arg.Name) + var i Template + err := row.Scan( + &i.Name, + &i.Type, + &i.Vcpus, + &i.MemoryMb, + &i.SizeBytes, + &i.CreatedAt, + &i.TeamID, + &i.ID, ) return i, err } const getTemplateByTeam = `-- name: GetTemplateByTeam :one -SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE name = $1 AND (team_id = $2 OR team_id = '00000000-0000-0000-0000-000000000000') +SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE name = $1 AND (team_id = $2 OR team_id = '00000000-0000-0000-0000-000000000000') ` type GetTemplateByTeamParams struct { @@ -74,17 +132,19 @@ func (q *Queries) GetTemplateByTeam(ctx context.Context, arg GetTemplateByTeamPa &i.SizeBytes, &i.CreatedAt, &i.TeamID, + &i.ID, ) return i, err } const insertTemplate = `-- name: InsertTemplate :one -INSERT INTO templates (name, type, vcpus, memory_mb, size_bytes, team_id) -VALUES ($1, $2, $3, $4, $5, $6) -RETURNING name, type, vcpus, memory_mb, size_bytes, created_at, team_id +INSERT INTO templates (id, name, type, vcpus, memory_mb, size_bytes, team_id) +VALUES ($1, $2, $3, $4, $5, $6, $7) +RETURNING name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id ` type InsertTemplateParams struct { + ID pgtype.UUID `json:"id"` Name string `json:"name"` Type string `json:"type"` Vcpus int32 `json:"vcpus"` @@ -95,6 +155,7 @@ type InsertTemplateParams struct { func (q *Queries) InsertTemplate(ctx context.Context, arg InsertTemplateParams) (Template, error) { row := q.db.QueryRow(ctx, insertTemplate, + arg.ID, arg.Name, arg.Type, arg.Vcpus, @@ -111,12 +172,13 @@ func (q *Queries) InsertTemplate(ctx context.Context, arg InsertTemplateParams) &i.SizeBytes, &i.CreatedAt, &i.TeamID, + &i.ID, ) return i, err } const listTemplates = `-- name: ListTemplates :many -SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates ORDER BY created_at DESC +SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates ORDER BY created_at DESC ` func (q *Queries) ListTemplates(ctx context.Context) ([]Template, error) { @@ -136,6 +198,7 @@ func (q *Queries) ListTemplates(ctx context.Context) ([]Template, error) { &i.SizeBytes, &i.CreatedAt, &i.TeamID, + &i.ID, ); err != nil { return nil, err } @@ -148,7 +211,7 @@ func (q *Queries) ListTemplates(ctx context.Context) ([]Template, error) { } const listTemplatesByTeam = `-- name: ListTemplatesByTeam :many -SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE (team_id = $1 OR team_id = '00000000-0000-0000-0000-000000000000') ORDER BY created_at DESC +SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE (team_id = $1 OR team_id = '00000000-0000-0000-0000-000000000000') ORDER BY created_at DESC ` // Platform templates are visible to all teams. @@ -169,6 +232,7 @@ func (q *Queries) ListTemplatesByTeam(ctx context.Context, teamID pgtype.UUID) ( &i.SizeBytes, &i.CreatedAt, &i.TeamID, + &i.ID, ); err != nil { return nil, err } @@ -181,7 +245,7 @@ func (q *Queries) ListTemplatesByTeam(ctx context.Context, teamID pgtype.UUID) ( } const listTemplatesByTeamAndType = `-- name: ListTemplatesByTeamAndType :many -SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE (team_id = $1 OR team_id = '00000000-0000-0000-0000-000000000000') AND type = $2 ORDER BY created_at DESC +SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE (team_id = $1 OR team_id = '00000000-0000-0000-0000-000000000000') AND type = $2 ORDER BY created_at DESC ` type ListTemplatesByTeamAndTypeParams struct { @@ -207,6 +271,41 @@ func (q *Queries) ListTemplatesByTeamAndType(ctx context.Context, arg ListTempla &i.SizeBytes, &i.CreatedAt, &i.TeamID, + &i.ID, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listTemplatesByTeamOnly = `-- name: ListTemplatesByTeamOnly :many +SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE team_id = $1 ORDER BY created_at DESC +` + +// List templates owned by a specific team (NOT including platform templates). +func (q *Queries) ListTemplatesByTeamOnly(ctx context.Context, teamID pgtype.UUID) ([]Template, error) { + rows, err := q.db.Query(ctx, listTemplatesByTeamOnly, teamID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Template + for rows.Next() { + var i Template + if err := rows.Scan( + &i.Name, + &i.Type, + &i.Vcpus, + &i.MemoryMb, + &i.SizeBytes, + &i.CreatedAt, + &i.TeamID, + &i.ID, ); err != nil { return nil, err } @@ -219,7 +318,7 @@ func (q *Queries) ListTemplatesByTeamAndType(ctx context.Context, arg ListTempla } const listTemplatesByType = `-- name: ListTemplatesByType :many -SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE type = $1 ORDER BY created_at DESC +SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE type = $1 ORDER BY created_at DESC ` func (q *Queries) ListTemplatesByType(ctx context.Context, type_ string) ([]Template, error) { @@ -239,6 +338,7 @@ func (q *Queries) ListTemplatesByType(ctx context.Context, type_ string) ([]Temp &i.SizeBytes, &i.CreatedAt, &i.TeamID, + &i.ID, ); err != nil { return nil, err } diff --git a/internal/hostagent/server.go b/internal/hostagent/server.go index 549158f0..ab016f8a 100644 --- a/internal/hostagent/server.go +++ b/internal/hostagent/server.go @@ -12,6 +12,8 @@ import ( "time" "connectrpc.com/connect" + "github.com/google/uuid" + "github.com/jackc/pgx/v5/pgtype" pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen" "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect" @@ -33,13 +35,35 @@ func NewServer(mgr *sandbox.Manager, terminate func()) *Server { return &Server{mgr: mgr, terminate: terminate} } +// parseUUIDString parses a UUID hex string into a pgtype.UUID. +// An empty string yields an all-zeros UUID (valid). +func parseUUIDString(s string) (pgtype.UUID, error) { + if s == "" { + return pgtype.UUID{Bytes: [16]byte{}, Valid: true}, nil + } + parsed, err := uuid.Parse(s) + if err != nil { + return pgtype.UUID{}, fmt.Errorf("invalid UUID %q: %w", s, err) + } + return pgtype.UUID{Bytes: parsed, Valid: true}, nil +} + func (s *Server) CreateSandbox( ctx context.Context, req *connect.Request[pb.CreateSandboxRequest], ) (*connect.Response[pb.CreateSandboxResponse], error) { msg := req.Msg - sb, err := s.mgr.Create(ctx, msg.SandboxId, msg.Template, int(msg.Vcpus), int(msg.MemoryMb), int(msg.TimeoutSec), int(msg.DiskSizeMb)) + teamID, err := parseUUIDString(msg.TeamId) + if err != nil { + return nil, connect.NewError(connect.CodeInvalidArgument, err) + } + templateID, err := parseUUIDString(msg.TemplateId) + if err != nil { + return nil, connect.NewError(connect.CodeInvalidArgument, err) + } + + sb, err := s.mgr.Create(ctx, msg.SandboxId, teamID, templateID, int(msg.Vcpus), int(msg.MemoryMb), int(msg.TimeoutSec), int(msg.DiskSizeMb)) if err != nil { return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("create sandbox: %w", err)) } @@ -90,12 +114,22 @@ func (s *Server) CreateSnapshot( ctx context.Context, req *connect.Request[pb.CreateSnapshotRequest], ) (*connect.Response[pb.CreateSnapshotResponse], error) { - sizeBytes, err := s.mgr.CreateSnapshot(ctx, req.Msg.SandboxId, req.Msg.Name) + msg := req.Msg + teamID, err := parseUUIDString(msg.TeamId) + if err != nil { + return nil, connect.NewError(connect.CodeInvalidArgument, err) + } + templateID, err := parseUUIDString(msg.TemplateId) + if err != nil { + return nil, connect.NewError(connect.CodeInvalidArgument, err) + } + + sizeBytes, err := s.mgr.CreateSnapshot(ctx, msg.SandboxId, teamID, templateID) if err != nil { return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("create snapshot: %w", err)) } return connect.NewResponse(&pb.CreateSnapshotResponse{ - Name: req.Msg.Name, + Name: msg.Name, SizeBytes: sizeBytes, }), nil } @@ -104,7 +138,17 @@ func (s *Server) DeleteSnapshot( ctx context.Context, req *connect.Request[pb.DeleteSnapshotRequest], ) (*connect.Response[pb.DeleteSnapshotResponse], error) { - if err := s.mgr.DeleteSnapshot(req.Msg.Name); err != nil { + msg := req.Msg + teamID, err := parseUUIDString(msg.TeamId) + if err != nil { + return nil, connect.NewError(connect.CodeInvalidArgument, err) + } + templateID, err := parseUUIDString(msg.TemplateId) + if err != nil { + return nil, connect.NewError(connect.CodeInvalidArgument, err) + } + + if err := s.mgr.DeleteSnapshot(teamID, templateID); err != nil { return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("delete snapshot: %w", err)) } return connect.NewResponse(&pb.DeleteSnapshotResponse{}), nil @@ -114,7 +158,17 @@ func (s *Server) FlattenRootfs( ctx context.Context, req *connect.Request[pb.FlattenRootfsRequest], ) (*connect.Response[pb.FlattenRootfsResponse], error) { - sizeBytes, err := s.mgr.FlattenRootfs(ctx, req.Msg.SandboxId, req.Msg.Name) + msg := req.Msg + teamID, err := parseUUIDString(msg.TeamId) + if err != nil { + return nil, connect.NewError(connect.CodeInvalidArgument, err) + } + templateID, err := parseUUIDString(msg.TemplateId) + if err != nil { + return nil, connect.NewError(connect.CodeInvalidArgument, err) + } + + sizeBytes, err := s.mgr.FlattenRootfs(ctx, msg.SandboxId, teamID, templateID) if err != nil { return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("flatten rootfs: %w", err)) } @@ -413,7 +467,8 @@ func (s *Server) ListSandboxes( infos[i] = &pb.SandboxInfo{ SandboxId: sb.ID, Status: string(sb.Status), - Template: sb.Template, + TeamId: uuid.UUID(sb.TemplateTeamID).String(), + TemplateId: uuid.UUID(sb.TemplateID).String(), Vcpus: int32(sb.VCPUs), MemoryMb: int32(sb.MemoryMB), HostIp: sb.HostIP.String(), diff --git a/internal/id/id.go b/internal/id/id.go index 35c44aeb..45cba6c5 100644 --- a/internal/id/id.go +++ b/internal/id/id.go @@ -36,8 +36,9 @@ func NewAuditLogID() pgtype.UUID { return newUUID() } func NewBuildID() pgtype.UUID { return newUUID() } func NewAdminPermissionID() pgtype.UUID { return newUUID() } +func NewTemplateID() pgtype.UUID { return newUUID() } + // NewSnapshotName generates a snapshot name: "template-" + 8 hex chars. -// Templates use TEXT primary keys (not UUID), so this stays as a string. func NewSnapshotName() string { return "template-" + hex8() } @@ -76,8 +77,8 @@ const ( PrefixAdminPermission = "perm-" ) -// uuidToBase36 encodes 16 UUID bytes as a 25-char base36 string (0-9a-z). -func uuidToBase36(b [16]byte) string { +// UUIDToBase36 encodes 16 UUID bytes as a 25-char base36 string (0-9a-z). +func UUIDToBase36(b [16]byte) string { n := new(big.Int).SetBytes(b[:]) buf := make([]byte, base36IDLen) mod := new(big.Int) @@ -110,7 +111,7 @@ func base36ToUUID(s string) ([16]byte, error) { } func formatUUID(prefix string, id pgtype.UUID) string { - return prefix + uuidToBase36(id.Bytes) + return prefix + UUIDToBase36(id.Bytes) } func FormatSandboxID(id pgtype.UUID) string { return formatUUID(PrefixSandbox, id) } @@ -151,6 +152,17 @@ func ParseBuildID(s string) (pgtype.UUID, error) { return parseUUID(PrefixBu // (e.g. base templates, shared infrastructure). var PlatformTeamID = pgtype.UUID{Bytes: [16]byte{}, Valid: true} +// MinimalTemplateID is the all-zeros UUID sentinel for the built-in "minimal" +// template. When both team_id and template_id are zero, the host agent uses +// the minimal rootfs at WRENN_DIR/images/minimal/. +var MinimalTemplateID = pgtype.UUID{Bytes: [16]byte{}, Valid: true} + +// UUIDString converts a pgtype.UUID to a standard hyphenated UUID string +// (e.g., "6ba7b810-9dad-11d1-80b4-00c04fd430c8"). Used for RPC wire format. +func UUIDString(id pgtype.UUID) string { + return uuid.UUID(id.Bytes).String() +} + // --- Helpers --- func hex8() string { diff --git a/internal/id/id_test.go b/internal/id/id_test.go index f8ae2853..c16ec7a0 100644 --- a/internal/id/id_test.go +++ b/internal/id/id_test.go @@ -10,7 +10,7 @@ import ( func TestBase36RoundTrip(t *testing.T) { for i := 0; i < 1000; i++ { orig := uuid.New() - encoded := uuidToBase36(orig) + encoded := UUIDToBase36(orig) if len(encoded) != base36IDLen { t.Fatalf("expected %d chars, got %d: %s", base36IDLen, len(encoded), encoded) @@ -29,7 +29,7 @@ func TestBase36RoundTrip(t *testing.T) { func TestBase36ZeroUUID(t *testing.T) { var zero [16]byte - encoded := uuidToBase36(zero) + encoded := UUIDToBase36(zero) if encoded != "0000000000000000000000000" { t.Fatalf("zero UUID should encode to all zeros, got %s", encoded) } @@ -87,7 +87,7 @@ func TestPlatformTeamIDFormats(t *testing.T) { func TestMaxUUID(t *testing.T) { max := [16]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff} - encoded := uuidToBase36(max) + encoded := UUIDToBase36(max) if len(encoded) != base36IDLen { t.Fatalf("max UUID encoding wrong length: %d", len(encoded)) } diff --git a/internal/layout/layout.go b/internal/layout/layout.go new file mode 100644 index 00000000..d4084f5f --- /dev/null +++ b/internal/layout/layout.go @@ -0,0 +1,58 @@ +package layout + +import ( + "path/filepath" + + "github.com/jackc/pgx/v5/pgtype" + + "git.omukk.dev/wrenn/sandbox/internal/id" +) + +// IsMinimal reports whether the given team and template IDs represent the +// built-in "minimal" template (both all-zeros). +func IsMinimal(teamID, templateID pgtype.UUID) bool { + return teamID.Bytes == id.PlatformTeamID.Bytes && templateID.Bytes == id.MinimalTemplateID.Bytes +} + +// TemplateDir returns the on-disk directory for a template. +// +// minimal (zeros, zeros): {wrennDir}/images/minimal +// all others: {wrennDir}/images/teams/{base36(teamID)}/{base36(templateID)} +func TemplateDir(wrennDir string, teamID, templateID pgtype.UUID) string { + if IsMinimal(teamID, templateID) { + return filepath.Join(wrennDir, "images", "minimal") + } + return filepath.Join(wrennDir, "images", "teams", + id.UUIDToBase36(teamID.Bytes), + id.UUIDToBase36(templateID.Bytes)) +} + +// TemplateRootfs returns the path to a template's rootfs.ext4. +func TemplateRootfs(wrennDir string, teamID, templateID pgtype.UUID) string { + return filepath.Join(TemplateDir(wrennDir, teamID, templateID), "rootfs.ext4") +} + +// PauseSnapshotDir returns the directory for a paused sandbox's snapshot files. +func PauseSnapshotDir(wrennDir, sandboxID string) string { + return filepath.Join(wrennDir, "snapshots", sandboxID) +} + +// SandboxesDir returns the directory for running sandbox CoW files. +func SandboxesDir(wrennDir string) string { + return filepath.Join(wrennDir, "sandboxes") +} + +// KernelPath returns the path to the Firecracker kernel. +func KernelPath(wrennDir string) string { + return filepath.Join(wrennDir, "kernels", "vmlinux") +} + +// ImagesRoot returns the root images directory. +func ImagesRoot(wrennDir string) string { + return filepath.Join(wrennDir, "images") +} + +// TeamsDir returns the directory containing all team template subdirectories. +func TeamsDir(wrennDir string) string { + return filepath.Join(wrennDir, "images", "teams") +} diff --git a/internal/layout/layout_test.go b/internal/layout/layout_test.go new file mode 100644 index 00000000..bffdae26 --- /dev/null +++ b/internal/layout/layout_test.go @@ -0,0 +1,120 @@ +package layout + +import ( + "path/filepath" + "testing" + + "github.com/jackc/pgx/v5/pgtype" + + "git.omukk.dev/wrenn/sandbox/internal/id" +) + +func TestIsMinimal(t *testing.T) { + tests := []struct { + name string + teamID pgtype.UUID + templateID pgtype.UUID + want bool + }{ + { + name: "both zeros", + teamID: id.PlatformTeamID, + templateID: id.MinimalTemplateID, + want: true, + }, + { + name: "non-zero team", + teamID: pgtype.UUID{Bytes: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, Valid: true}, + templateID: id.MinimalTemplateID, + want: false, + }, + { + name: "non-zero template", + teamID: id.PlatformTeamID, + templateID: pgtype.UUID{Bytes: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, Valid: true}, + want: false, + }, + { + name: "both non-zero", + teamID: pgtype.UUID{Bytes: [16]byte{1}, Valid: true}, + templateID: pgtype.UUID{Bytes: [16]byte{2}, Valid: true}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := IsMinimal(tt.teamID, tt.templateID); got != tt.want { + t.Errorf("IsMinimal() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestTemplateDir(t *testing.T) { + wrennDir := "/var/lib/wrenn" + + t.Run("minimal", func(t *testing.T) { + got := TemplateDir(wrennDir, id.PlatformTeamID, id.MinimalTemplateID) + want := filepath.Join(wrennDir, "images", "minimal") + if got != want { + t.Errorf("TemplateDir() = %q, want %q", got, want) + } + }) + + t.Run("team template", func(t *testing.T) { + teamID := pgtype.UUID{Bytes: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, Valid: true} + tmplID := pgtype.UUID{Bytes: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}, Valid: true} + got := TemplateDir(wrennDir, teamID, tmplID) + want := filepath.Join(wrennDir, "images", "teams", + id.UUIDToBase36(teamID.Bytes), + id.UUIDToBase36(tmplID.Bytes)) + if got != want { + t.Errorf("TemplateDir() = %q, want %q", got, want) + } + }) + + t.Run("global template (platform team, non-zero template)", func(t *testing.T) { + tmplID := pgtype.UUID{Bytes: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5}, Valid: true} + got := TemplateDir(wrennDir, id.PlatformTeamID, tmplID) + want := filepath.Join(wrennDir, "images", "teams", + id.UUIDToBase36(id.PlatformTeamID.Bytes), + id.UUIDToBase36(tmplID.Bytes)) + if got != want { + t.Errorf("TemplateDir() = %q, want %q", got, want) + } + }) +} + +func TestTemplateRootfs(t *testing.T) { + wrennDir := "/var/lib/wrenn" + got := TemplateRootfs(wrennDir, id.PlatformTeamID, id.MinimalTemplateID) + want := filepath.Join(wrennDir, "images", "minimal", "rootfs.ext4") + if got != want { + t.Errorf("TemplateRootfs() = %q, want %q", got, want) + } +} + +func TestPauseSnapshotDir(t *testing.T) { + got := PauseSnapshotDir("/var/lib/wrenn", "sb-abc123") + want := "/var/lib/wrenn/snapshots/sb-abc123" + if got != want { + t.Errorf("PauseSnapshotDir() = %q, want %q", got, want) + } +} + +func TestSandboxesDir(t *testing.T) { + got := SandboxesDir("/var/lib/wrenn") + want := "/var/lib/wrenn/sandboxes" + if got != want { + t.Errorf("SandboxesDir() = %q, want %q", got, want) + } +} + +func TestKernelPath(t *testing.T) { + got := KernelPath("/var/lib/wrenn") + want := "/var/lib/wrenn/kernels/vmlinux" + if got != want { + t.Errorf("KernelPath() = %q, want %q", got, want) + } +} diff --git a/internal/models/sandbox.go b/internal/models/sandbox.go index b99bd6b8..ab72cd3e 100644 --- a/internal/models/sandbox.go +++ b/internal/models/sandbox.go @@ -18,15 +18,16 @@ const ( // Sandbox holds all state for a running sandbox on this host. type Sandbox struct { - ID string - Status SandboxStatus - Template string - VCPUs int - MemoryMB int - TimeoutSec int - SlotIndex int - HostIP net.IP - RootfsPath string - CreatedAt time.Time - LastActiveAt time.Time + ID string + Status SandboxStatus + TemplateTeamID [16]byte + TemplateID [16]byte + VCPUs int + MemoryMB int + TimeoutSec int + SlotIndex int + HostIP net.IP + RootfsPath string + CreatedAt time.Time + LastActiveAt time.Time } diff --git a/internal/sandbox/images.go b/internal/sandbox/images.go index feb3398b..1716d80d 100644 --- a/internal/sandbox/images.go +++ b/internal/sandbox/images.go @@ -6,6 +6,9 @@ import ( "os" "os/exec" "path/filepath" + + "git.omukk.dev/wrenn/sandbox/internal/id" + "git.omukk.dev/wrenn/sandbox/internal/layout" ) // DefaultDiskSizeMB is the standard disk size for base images. Images smaller @@ -14,61 +17,90 @@ import ( // changes; no physical disk is consumed beyond the original content. const DefaultDiskSizeMB = 5120 // 5 GB -// EnsureImageSizes walks the images directory and expands any rootfs.ext4 that +// EnsureImageSizes walks template directories and expands any rootfs.ext4 that // is smaller than the target size. This is idempotent: images already at or // above the target size are left untouched. Should be called once at host agent // startup before any sandboxes are created. -func EnsureImageSizes(imagesDir string, targetMB int) error { +func EnsureImageSizes(wrennDir string, targetMB int) error { if targetMB <= 0 { targetMB = DefaultDiskSizeMB } targetBytes := int64(targetMB) * 1024 * 1024 - entries, err := os.ReadDir(imagesDir) - if err != nil { - return fmt.Errorf("read images dir: %w", err) + // Expand the built-in minimal image. + minimalRootfs := layout.TemplateRootfs(wrennDir, id.PlatformTeamID, id.MinimalTemplateID) + if err := expandImage(minimalRootfs, targetBytes, targetMB); err != nil { + return err } - for _, entry := range entries { - if !entry.IsDir() { + // Walk teams/{teamDir}/{templateDir}/rootfs.ext4 two levels deep. + teamsDir := layout.TeamsDir(wrennDir) + teamEntries, err := os.ReadDir(teamsDir) + if err != nil { + if os.IsNotExist(err) { + return nil // teams dir doesn't exist yet — nothing to expand + } + return fmt.Errorf("read teams dir: %w", err) + } + + for _, teamEntry := range teamEntries { + if !teamEntry.IsDir() { continue } - rootfs := filepath.Join(imagesDir, entry.Name(), "rootfs.ext4") - info, err := os.Stat(rootfs) + teamPath := filepath.Join(teamsDir, teamEntry.Name()) + templateEntries, err := os.ReadDir(teamPath) if err != nil { - continue // not every template dir has a rootfs.ext4 + continue } - - if info.Size() >= targetBytes { - continue // already large enough - } - - slog.Info("expanding base image", - "template", entry.Name(), - "from_mb", info.Size()/(1024*1024), - "to_mb", targetMB, - ) - - // Expand the file (sparse — instant, no physical disk used). - if err := os.Truncate(rootfs, targetBytes); err != nil { - return fmt.Errorf("truncate %s: %w", rootfs, err) - } - - // Check filesystem before resize. - if out, err := exec.Command("e2fsck", "-fy", rootfs).CombinedOutput(); err != nil { - // e2fsck returns 1 if it fixed errors, which is fine. - if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() > 1 { - return fmt.Errorf("e2fsck %s: %s: %w", rootfs, string(out), err) + for _, tmplEntry := range templateEntries { + if !tmplEntry.IsDir() { + continue + } + rootfs := filepath.Join(teamPath, tmplEntry.Name(), "rootfs.ext4") + if err := expandImage(rootfs, targetBytes, targetMB); err != nil { + return err } } - - // Grow the ext4 filesystem to fill the new file size. - if out, err := exec.Command("resize2fs", rootfs).CombinedOutput(); err != nil { - return fmt.Errorf("resize2fs %s: %s: %w", rootfs, string(out), err) - } - - slog.Info("base image expanded", "template", entry.Name(), "size_mb", targetMB) } return nil } + +// expandImage expands a single rootfs image if it is smaller than targetBytes. +func expandImage(rootfs string, targetBytes int64, targetMB int) error { + info, err := os.Stat(rootfs) + if err != nil { + return nil // not every template dir has a rootfs.ext4 + } + + if info.Size() >= targetBytes { + return nil // already large enough + } + + slog.Info("expanding base image", + "path", rootfs, + "from_mb", info.Size()/(1024*1024), + "to_mb", targetMB, + ) + + // Expand the file (sparse — instant, no physical disk used). + if err := os.Truncate(rootfs, targetBytes); err != nil { + return fmt.Errorf("truncate %s: %w", rootfs, err) + } + + // Check filesystem before resize. + if out, err := exec.Command("e2fsck", "-fy", rootfs).CombinedOutput(); err != nil { + // e2fsck returns 1 if it fixed errors, which is fine. + if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() > 1 { + return fmt.Errorf("e2fsck %s: %s: %w", rootfs, string(out), err) + } + } + + // Grow the ext4 filesystem to fill the new file size. + if out, err := exec.Command("resize2fs", rootfs).CombinedOutput(); err != nil { + return fmt.Errorf("resize2fs %s: %s: %w", rootfs, string(out), err) + } + + slog.Info("base image expanded", "path", rootfs, "size_mb", targetMB) + return nil +} diff --git a/internal/sandbox/manager.go b/internal/sandbox/manager.go index 9fcecb5a..4e97c49b 100644 --- a/internal/sandbox/manager.go +++ b/internal/sandbox/manager.go @@ -11,25 +11,23 @@ import ( "time" "github.com/google/uuid" + "github.com/jackc/pgx/v5/pgtype" "git.omukk.dev/wrenn/sandbox/internal/devicemapper" "git.omukk.dev/wrenn/sandbox/internal/envdclient" "git.omukk.dev/wrenn/sandbox/internal/id" + "git.omukk.dev/wrenn/sandbox/internal/layout" "git.omukk.dev/wrenn/sandbox/internal/models" "git.omukk.dev/wrenn/sandbox/internal/network" "git.omukk.dev/wrenn/sandbox/internal/snapshot" "git.omukk.dev/wrenn/sandbox/internal/uffd" - "git.omukk.dev/wrenn/sandbox/internal/validate" "git.omukk.dev/wrenn/sandbox/internal/vm" ) // Config holds the paths and defaults for the sandbox manager. type Config struct { - KernelPath string - ImagesDir string // directory containing template images (e.g., /var/lib/wrenn/images/{name}/rootfs.ext4) - SandboxesDir string // directory for per-sandbox rootfs clones (e.g., /var/lib/wrenn/sandboxes) - SnapshotsDir string // directory for pause snapshots (e.g., /var/lib/wrenn/snapshots/{sandbox-id}/) - EnvdTimeout time.Duration + WrennDir string // root directory (e.g. /var/lib/wrenn); all sub-paths derived via layout package + EnvdTimeout time.Duration } // Manager orchestrates sandbox lifecycle: VM, network, filesystem, envd. @@ -52,8 +50,8 @@ type sandboxState struct { slot *network.Slot client *envdclient.Client uffdSocketPath string // non-empty for sandboxes restored from snapshot - dmDevice *devicemapper.SnapshotDevice - baseImagePath string // path to the base template rootfs (for loop registry release) + dmDevice *devicemapper.SnapshotDevice + baseImagePath string // path to the base template rootfs (for loop registry release) // parent holds the snapshot header and diff file paths from which this // sandbox was restored. Non-nil means re-pause should use "Diff" snapshot @@ -95,7 +93,7 @@ func New(cfg Config) *Manager { // Create boots a new sandbox: clone rootfs, set up network, start VM, wait for envd. // If sandboxID is empty, a new ID is generated. -func (m *Manager) Create(ctx context.Context, sandboxID, template string, vcpus, memoryMB, timeoutSec, diskSizeMB int) (*models.Sandbox, error) { +func (m *Manager) Create(ctx context.Context, sandboxID string, teamID, templateID pgtype.UUID, vcpus, memoryMB, timeoutSec, diskSizeMB int) (*models.Sandbox, error) { if sandboxID == "" { sandboxID = id.FormatSandboxID(id.NewSandboxID()) } @@ -110,20 +108,14 @@ func (m *Manager) Create(ctx context.Context, sandboxID, template string, vcpus, diskSizeMB = 5120 // 5 GB default } - if template == "" { - template = "minimal" - } - if err := validate.SafeName(template); err != nil { - return nil, fmt.Errorf("invalid template name: %w", err) - } - // Check if template refers to a snapshot (has snapfile + memfile + header + rootfs). - if snapshot.IsSnapshot(m.cfg.ImagesDir, template) { - return m.createFromSnapshot(ctx, sandboxID, template, vcpus, memoryMB, timeoutSec, diskSizeMB) + tmplDir := layout.TemplateDir(m.cfg.WrennDir, teamID, templateID) + if _, err := os.Stat(filepath.Join(tmplDir, snapshot.SnapFileName)); err == nil { + return m.createFromSnapshot(ctx, sandboxID, teamID, templateID, vcpus, memoryMB, timeoutSec, diskSizeMB) } - // Resolve base rootfs image: /var/lib/wrenn/images/{template}/rootfs.ext4 - baseRootfs := filepath.Join(m.cfg.ImagesDir, template, "rootfs.ext4") + // Resolve base rootfs image. + baseRootfs := layout.TemplateRootfs(m.cfg.WrennDir, teamID, templateID) if _, err := os.Stat(baseRootfs); err != nil { return nil, fmt.Errorf("base rootfs not found at %s: %w", baseRootfs, err) } @@ -142,7 +134,7 @@ func (m *Manager) Create(ctx context.Context, sandboxID, template string, vcpus, // Create dm-snapshot with per-sandbox CoW file. dmName := "wrenn-" + sandboxID - cowPath := filepath.Join(m.cfg.SandboxesDir, fmt.Sprintf("%s.cow", sandboxID)) + cowPath := filepath.Join(layout.SandboxesDir(m.cfg.WrennDir), fmt.Sprintf("%s.cow", sandboxID)) cowSize := int64(diskSizeMB) * 1024 * 1024 dmDev, err := devicemapper.CreateSnapshot(dmName, originLoop, cowPath, originSize, cowSize) if err != nil { @@ -172,7 +164,7 @@ func (m *Manager) Create(ctx context.Context, sandboxID, template string, vcpus, // Boot VM — Firecracker gets the dm device path. vmCfg := vm.VMConfig{ SandboxID: sandboxID, - KernelPath: m.cfg.KernelPath, + KernelPath: layout.KernelPath(m.cfg.WrennDir), RootfsPath: dmDev.DevicePath, VCPUs: vcpus, MemoryMB: memoryMB, @@ -211,17 +203,18 @@ func (m *Manager) Create(ctx context.Context, sandboxID, template string, vcpus, now := time.Now() sb := &sandboxState{ Sandbox: models.Sandbox{ - ID: sandboxID, - Status: models.StatusRunning, - Template: template, - VCPUs: vcpus, - MemoryMB: memoryMB, - TimeoutSec: timeoutSec, - SlotIndex: slotIdx, - HostIP: slot.HostIP, - RootfsPath: dmDev.DevicePath, - CreatedAt: now, - LastActiveAt: now, + ID: sandboxID, + Status: models.StatusRunning, + TemplateTeamID: teamID.Bytes, + TemplateID: templateID.Bytes, + VCPUs: vcpus, + MemoryMB: memoryMB, + TimeoutSec: timeoutSec, + SlotIndex: slotIdx, + HostIP: slot.HostIP, + RootfsPath: dmDev.DevicePath, + CreatedAt: now, + LastActiveAt: now, }, slot: slot, client: client, @@ -237,7 +230,8 @@ func (m *Manager) Create(ctx context.Context, sandboxID, template string, vcpus, slog.Info("sandbox created", "id", sandboxID, - "template", template, + "team_id", teamID, + "template_id", templateID, "host_ip", slot.HostIP.String(), "dm_device", dmDev.DevicePath, ) @@ -260,7 +254,9 @@ func (m *Manager) Destroy(ctx context.Context, sandboxID string) error { } // Always clean up pause snapshot files (may exist if sandbox was paused). - warnErr("snapshot cleanup error", sandboxID, snapshot.Remove(m.cfg.SnapshotsDir, sandboxID)) + if err := os.RemoveAll(layout.PauseSnapshotDir(m.cfg.WrennDir, sandboxID)); err != nil { + slog.Warn("snapshot cleanup error", "id", sandboxID, "error", err) + } slog.Info("sandbox destroyed", "id", sandboxID) return nil @@ -331,18 +327,18 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error { } // Step 2: Take VM state snapshot (snapfile + memfile). - if err := snapshot.EnsureDir(m.cfg.SnapshotsDir, sandboxID); err != nil { + pauseDir := layout.PauseSnapshotDir(m.cfg.WrennDir, sandboxID) + if err := os.MkdirAll(pauseDir, 0755); err != nil { resumeOnError() return fmt.Errorf("create snapshot dir: %w", err) } - snapDir := snapshot.DirPath(m.cfg.SnapshotsDir, sandboxID) - rawMemPath := filepath.Join(snapDir, "memfile.raw") - snapPath := snapshot.SnapPath(m.cfg.SnapshotsDir, sandboxID) + rawMemPath := filepath.Join(pauseDir, "memfile.raw") + snapPath := filepath.Join(pauseDir, snapshot.SnapFileName) snapshotStart := time.Now() if err := m.vm.Snapshot(ctx, sandboxID, snapPath, rawMemPath, snapshotType); err != nil { - warnErr("snapshot dir cleanup error", sandboxID, snapshot.Remove(m.cfg.SnapshotsDir, sandboxID)) + warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir)) resumeOnError() return fmt.Errorf("create VM snapshot: %w", err) } @@ -350,24 +346,24 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error { // Step 3: Process the raw memfile into a compact diff + header. buildID := uuid.New() - headerPath := snapshot.MemHeaderPath(m.cfg.SnapshotsDir, sandboxID) + headerPath := filepath.Join(pauseDir, snapshot.MemHeaderName) processStart := time.Now() if sb.parent != nil && snapshotType == "Diff" { // Diff: process against parent header, producing only changed blocks. - diffPath := snapshot.MemDiffPathForBuild(m.cfg.SnapshotsDir, sandboxID, buildID) + diffPath := snapshot.MemDiffPathForBuild(pauseDir, "", buildID) if _, err := snapshot.ProcessMemfileWithParent(rawMemPath, diffPath, headerPath, sb.parent.header, buildID); err != nil { - warnErr("snapshot dir cleanup error", sandboxID, snapshot.Remove(m.cfg.SnapshotsDir, sandboxID)) + warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir)) resumeOnError() return fmt.Errorf("process memfile with parent: %w", err) } // Copy previous generation diff files into the snapshot directory. for prevBuildID, prevPath := range sb.parent.diffPaths { - dstPath := snapshot.MemDiffPathForBuild(m.cfg.SnapshotsDir, sandboxID, uuid.MustParse(prevBuildID)) + dstPath := snapshot.MemDiffPathForBuild(pauseDir, "", uuid.MustParse(prevBuildID)) if prevPath != dstPath { if err := copyFile(prevPath, dstPath); err != nil { - warnErr("snapshot dir cleanup error", sandboxID, snapshot.Remove(m.cfg.SnapshotsDir, sandboxID)) + warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir)) resumeOnError() return fmt.Errorf("copy parent diff file: %w", err) } @@ -375,9 +371,9 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error { } } else { // Full: first generation or generation cap reached — single diff file. - diffPath := snapshot.MemDiffPath(m.cfg.SnapshotsDir, sandboxID) + diffPath := snapshot.MemDiffPath(pauseDir, "") if _, err := snapshot.ProcessMemfile(rawMemPath, diffPath, headerPath, buildID); err != nil { - warnErr("snapshot dir cleanup error", sandboxID, snapshot.Remove(m.cfg.SnapshotsDir, sandboxID)) + warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir)) resumeOnError() return fmt.Errorf("process memfile: %w", err) } @@ -407,7 +403,7 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error { if sb.uffdSocketPath != "" { os.Remove(sb.uffdSocketPath) } - warnErr("snapshot dir cleanup error", sandboxID, snapshot.Remove(m.cfg.SnapshotsDir, sandboxID)) + warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir)) m.mu.Lock() delete(m.boxes, sandboxID) m.mu.Unlock() @@ -415,9 +411,9 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error { } // Move (not copy) the CoW file into the snapshot directory. - snapshotCow := snapshot.CowPath(m.cfg.SnapshotsDir, sandboxID) + snapshotCow := snapshot.CowPath(pauseDir, "") if err := os.Rename(sb.dmDevice.CowPath, snapshotCow); err != nil { - warnErr("snapshot dir cleanup error", sandboxID, snapshot.Remove(m.cfg.SnapshotsDir, sandboxID)) + warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir)) // VM and dm-snapshot are already gone — clean up remaining resources. warnErr("network cleanup error during pause", sandboxID, network.RemoveNetwork(sb.slot)) m.slots.Release(sb.SlotIndex) @@ -434,10 +430,10 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error { } // Record which base template this CoW was built against. - if err := snapshot.WriteMeta(m.cfg.SnapshotsDir, sandboxID, &snapshot.RootfsMeta{ + if err := snapshot.WriteMeta(pauseDir, "", &snapshot.RootfsMeta{ BaseTemplate: sb.baseImagePath, }); err != nil { - warnErr("snapshot dir cleanup error", sandboxID, snapshot.Remove(m.cfg.SnapshotsDir, sandboxID)) + warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir)) // VM and dm-snapshot are already gone — clean up remaining resources. warnErr("network cleanup error during pause", sandboxID, network.RemoveNetwork(sb.slot)) m.slots.Release(sb.SlotIndex) @@ -477,13 +473,13 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error { // Resume restores a paused sandbox from its snapshot using UFFD for // lazy memory loading. The sandbox gets a new network slot. func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int) (*models.Sandbox, error) { - snapDir := m.cfg.SnapshotsDir - if !snapshot.Exists(snapDir, sandboxID) { + pauseDir := layout.PauseSnapshotDir(m.cfg.WrennDir, sandboxID) + if _, err := os.Stat(pauseDir); err != nil { return nil, fmt.Errorf("no snapshot found for sandbox %s", sandboxID) } // Read the header to set up the UFFD memory source. - headerData, err := os.ReadFile(snapshot.MemHeaderPath(snapDir, sandboxID)) + headerData, err := os.ReadFile(filepath.Join(pauseDir, snapshot.MemHeaderName)) if err != nil { return nil, fmt.Errorf("read header: %w", err) } @@ -494,7 +490,7 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int) } // Build diff file map — supports both single-generation and multi-generation. - diffPaths, err := snapshot.ListDiffFiles(snapDir, sandboxID, header) + diffPaths, err := snapshot.ListDiffFiles(pauseDir, "", header) if err != nil { return nil, fmt.Errorf("list diff files: %w", err) } @@ -505,7 +501,7 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int) } // Read rootfs metadata to find the base template image. - meta, err := snapshot.ReadMeta(snapDir, sandboxID) + meta, err := snapshot.ReadMeta(pauseDir, "") if err != nil { source.Close() return nil, fmt.Errorf("read rootfs meta: %w", err) @@ -527,8 +523,8 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int) } // Move CoW file from snapshot dir to sandboxes dir for the running sandbox. - savedCow := snapshot.CowPath(snapDir, sandboxID) - cowPath := filepath.Join(m.cfg.SandboxesDir, fmt.Sprintf("%s.cow", sandboxID)) + savedCow := snapshot.CowPath(pauseDir, "") + cowPath := filepath.Join(layout.SandboxesDir(m.cfg.WrennDir), fmt.Sprintf("%s.cow", sandboxID)) if err := os.Rename(savedCow, cowPath); err != nil { source.Close() m.loops.Release(baseImagePath) @@ -574,7 +570,7 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int) } // Start UFFD server. - uffdSocketPath := filepath.Join(m.cfg.SandboxesDir, fmt.Sprintf("%s-uffd.sock", sandboxID)) + uffdSocketPath := filepath.Join(layout.SandboxesDir(m.cfg.WrennDir), fmt.Sprintf("%s-uffd.sock", sandboxID)) os.Remove(uffdSocketPath) // Clean stale socket. uffdServer := uffd.NewServer(uffdSocketPath, source) if err := uffdServer.Start(ctx); err != nil { @@ -590,7 +586,7 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int) // Restore VM from snapshot. vmCfg := vm.VMConfig{ SandboxID: sandboxID, - KernelPath: m.cfg.KernelPath, + KernelPath: layout.KernelPath(m.cfg.WrennDir), RootfsPath: dmDev.DevicePath, VCPUs: 1, // Placeholder; overridden by snapshot. MemoryMB: int(header.Metadata.Size / (1024 * 1024)), // Placeholder; overridden by snapshot. @@ -602,8 +598,8 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int) NetMask: slot.GuestNetMask, } - snapPath := snapshot.SnapPath(snapDir, sandboxID) - if _, err := m.vm.CreateFromSnapshot(ctx, vmCfg, snapPath, uffdSocketPath); err != nil { + resumeSnapPath := filepath.Join(pauseDir, snapshot.SnapFileName) + if _, err := m.vm.CreateFromSnapshot(ctx, vmCfg, resumeSnapPath, uffdSocketPath); err != nil { warnErr("uffd server stop error", sandboxID, uffdServer.Stop()) source.Close() warnErr("network cleanup error", sandboxID, network.RemoveNetwork(slot)) @@ -636,7 +632,6 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int) Sandbox: models.Sandbox{ ID: sandboxID, Status: models.StatusRunning, - Template: "", VCPUs: vmCfg.VCPUs, MemoryMB: vmCfg.MemoryMB, TimeoutSec: timeoutSec, @@ -685,11 +680,7 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int) // The rootfs is flattened (base + CoW merged) into a new standalone rootfs.ext4 // so the template has no dependency on the original base image. Memory state // and VM snapshot files are copied as-is. -func (m *Manager) CreateSnapshot(ctx context.Context, sandboxID, name string) (int64, error) { - if err := validate.SafeName(name); err != nil { - return 0, fmt.Errorf("invalid snapshot name: %w", err) - } - +func (m *Manager) CreateSnapshot(ctx context.Context, sandboxID string, teamID, templateID pgtype.UUID) (int64, error) { // If the sandbox is running, pause it first. if _, err := m.get(sandboxID); err == nil { if err := m.Pause(ctx, sandboxID); err != nil { @@ -697,25 +688,26 @@ func (m *Manager) CreateSnapshot(ctx context.Context, sandboxID, name string) (i } } - // At this point, pause snapshot files must exist in SnapshotsDir/{sandboxID}/. - if !snapshot.Exists(m.cfg.SnapshotsDir, sandboxID) { + // At this point, pause snapshot files must exist. + pauseDir := layout.PauseSnapshotDir(m.cfg.WrennDir, sandboxID) + if _, err := os.Stat(pauseDir); err != nil { return 0, fmt.Errorf("no snapshot found for sandbox %s", sandboxID) } // Create template directory. - if err := snapshot.EnsureDir(m.cfg.ImagesDir, name); err != nil { + dstDir := layout.TemplateDir(m.cfg.WrennDir, teamID, templateID) + if err := os.MkdirAll(dstDir, 0755); err != nil { return 0, fmt.Errorf("create template dir: %w", err) } // Copy VM snapshot file and memory header. - srcDir := snapshot.DirPath(m.cfg.SnapshotsDir, sandboxID) - dstDir := snapshot.DirPath(m.cfg.ImagesDir, name) + srcDir := pauseDir for _, fname := range []string{snapshot.SnapFileName, snapshot.MemHeaderName} { src := filepath.Join(srcDir, fname) dst := filepath.Join(dstDir, fname) if err := copyFile(src, dst); err != nil { - warnErr("template dir cleanup error", name, snapshot.Remove(m.cfg.ImagesDir, name)) + warnErr("template dir cleanup error", dstDir, os.RemoveAll(dstDir)) return 0, fmt.Errorf("copy %s: %w", fname, err) } } @@ -723,59 +715,59 @@ func (m *Manager) CreateSnapshot(ctx context.Context, sandboxID, name string) (i // Copy all memory diff files referenced by the header (supports multi-generation). headerData, err := os.ReadFile(filepath.Join(srcDir, snapshot.MemHeaderName)) if err != nil { - warnErr("template dir cleanup error", name, snapshot.Remove(m.cfg.ImagesDir, name)) + warnErr("template dir cleanup error", dstDir, os.RemoveAll(dstDir)) return 0, fmt.Errorf("read header for template: %w", err) } srcHeader, err := snapshot.Deserialize(headerData) if err != nil { - warnErr("template dir cleanup error", name, snapshot.Remove(m.cfg.ImagesDir, name)) + warnErr("template dir cleanup error", dstDir, os.RemoveAll(dstDir)) return 0, fmt.Errorf("deserialize header for template: %w", err) } - srcDiffPaths, err := snapshot.ListDiffFiles(m.cfg.SnapshotsDir, sandboxID, srcHeader) + srcDiffPaths, err := snapshot.ListDiffFiles(pauseDir, "", srcHeader) if err != nil { - warnErr("template dir cleanup error", name, snapshot.Remove(m.cfg.ImagesDir, name)) + warnErr("template dir cleanup error", dstDir, os.RemoveAll(dstDir)) return 0, fmt.Errorf("list diff files for template: %w", err) } for _, srcPath := range srcDiffPaths { dstPath := filepath.Join(dstDir, filepath.Base(srcPath)) if err := copyFile(srcPath, dstPath); err != nil { - warnErr("template dir cleanup error", name, snapshot.Remove(m.cfg.ImagesDir, name)) + warnErr("template dir cleanup error", dstDir, os.RemoveAll(dstDir)) return 0, fmt.Errorf("copy diff file %s: %w", filepath.Base(srcPath), err) } } // Flatten rootfs: temporarily set up dm device from base + CoW, dd to new image. - meta, err := snapshot.ReadMeta(m.cfg.SnapshotsDir, sandboxID) + meta, err := snapshot.ReadMeta(pauseDir, "") if err != nil { - warnErr("template dir cleanup error", name, snapshot.Remove(m.cfg.ImagesDir, name)) + warnErr("template dir cleanup error", dstDir, os.RemoveAll(dstDir)) return 0, fmt.Errorf("read rootfs meta: %w", err) } originLoop, err := m.loops.Acquire(meta.BaseTemplate) if err != nil { - warnErr("template dir cleanup error", name, snapshot.Remove(m.cfg.ImagesDir, name)) + warnErr("template dir cleanup error", dstDir, os.RemoveAll(dstDir)) return 0, fmt.Errorf("acquire loop device for flatten: %w", err) } originSize, err := devicemapper.OriginSizeBytes(originLoop) if err != nil { m.loops.Release(meta.BaseTemplate) - warnErr("template dir cleanup error", name, snapshot.Remove(m.cfg.ImagesDir, name)) + warnErr("template dir cleanup error", dstDir, os.RemoveAll(dstDir)) return 0, fmt.Errorf("get origin size: %w", err) } // Temporarily restore the dm-snapshot to read the merged view. - cowPath := snapshot.CowPath(m.cfg.SnapshotsDir, sandboxID) + cowPath := snapshot.CowPath(pauseDir, "") tmpDmName := "wrenn-flatten-" + sandboxID tmpDev, err := devicemapper.RestoreSnapshot(ctx, tmpDmName, originLoop, cowPath, originSize) if err != nil { m.loops.Release(meta.BaseTemplate) - warnErr("template dir cleanup error", name, snapshot.Remove(m.cfg.ImagesDir, name)) + warnErr("template dir cleanup error", dstDir, os.RemoveAll(dstDir)) return 0, fmt.Errorf("restore dm-snapshot for flatten: %w", err) } // Flatten to new standalone rootfs. - flattenedPath := snapshot.RootfsPath(m.cfg.ImagesDir, name) + flattenedPath := filepath.Join(dstDir, snapshot.RootfsFileName) flattenErr := devicemapper.FlattenSnapshot(tmpDev.DevicePath, flattenedPath) // Always clean up the temporary dm device. @@ -783,18 +775,19 @@ func (m *Manager) CreateSnapshot(ctx context.Context, sandboxID, name string) (i m.loops.Release(meta.BaseTemplate) if flattenErr != nil { - warnErr("template dir cleanup error", name, snapshot.Remove(m.cfg.ImagesDir, name)) + warnErr("template dir cleanup error", dstDir, os.RemoveAll(dstDir)) return 0, fmt.Errorf("flatten rootfs: %w", flattenErr) } - sizeBytes, err := snapshot.DirSize(m.cfg.ImagesDir, name) + sizeBytes, err := snapshot.DirSize(dstDir, "") if err != nil { slog.Warn("failed to calculate snapshot size", "error", err) } slog.Info("template snapshot created (rootfs flattened)", "sandbox", sandboxID, - "name", name, + "team_id", teamID, + "template_id", templateID, "size_bytes", sizeBytes, ) return sizeBytes, nil @@ -804,11 +797,7 @@ func (m *Manager) CreateSnapshot(ctx context.Context, sandboxID, name string) (i // rootfs into a standalone rootfs.ext4, and cleans up all resources. // The result is an image-only template (no VM memory/CPU state) stored in // ImagesDir/{name}/rootfs.ext4. -func (m *Manager) FlattenRootfs(ctx context.Context, sandboxID, name string) (int64, error) { - if err := validate.SafeName(name); err != nil { - return 0, fmt.Errorf("invalid template name: %w", err) - } - +func (m *Manager) FlattenRootfs(ctx context.Context, sandboxID string, teamID, templateID pgtype.UUID) (int64, error) { m.mu.Lock() sb, ok := m.boxes[sandboxID] if ok { @@ -837,21 +826,22 @@ func (m *Manager) FlattenRootfs(ctx context.Context, sandboxID, name string) (in } // Create template directory and flatten the dm-snapshot. - if err := snapshot.EnsureDir(m.cfg.ImagesDir, name); err != nil { + flattenDstDir := layout.TemplateDir(m.cfg.WrennDir, teamID, templateID) + if err := os.MkdirAll(flattenDstDir, 0755); err != nil { m.cleanupDM(sb) return 0, fmt.Errorf("create template dir: %w", err) } - outputPath := snapshot.RootfsPath(m.cfg.ImagesDir, name) + outputPath := filepath.Join(flattenDstDir, snapshot.RootfsFileName) if sb.dmDevice == nil { m.cleanupDM(sb) - warnErr("template dir cleanup error", name, snapshot.Remove(m.cfg.ImagesDir, name)) + warnErr("template dir cleanup error", flattenDstDir, os.RemoveAll(flattenDstDir)) return 0, fmt.Errorf("sandbox %s has no dm device", sandboxID) } if err := devicemapper.FlattenSnapshot(sb.dmDevice.DevicePath, outputPath); err != nil { m.cleanupDM(sb) - warnErr("template dir cleanup error", name, snapshot.Remove(m.cfg.ImagesDir, name)) + warnErr("template dir cleanup error", flattenDstDir, os.RemoveAll(flattenDstDir)) return 0, fmt.Errorf("flatten rootfs: %w", err) } @@ -869,14 +859,15 @@ func (m *Manager) FlattenRootfs(ctx context.Context, sandboxID, name string) (in slog.Warn("resize2fs -M failed (non-fatal)", "output", string(out), "error", err) } - sizeBytes, err := snapshot.DirSize(m.cfg.ImagesDir, name) + sizeBytes, err := snapshot.DirSize(flattenDstDir, "") if err != nil { slog.Warn("failed to calculate template size", "error", err) } slog.Info("rootfs flattened to image-only template", "sandbox", sandboxID, - "name", name, + "team_id", teamID, + "template_id", templateID, "size_bytes", sizeBytes, ) return sizeBytes, nil @@ -896,22 +887,19 @@ func (m *Manager) cleanupDM(sb *sandboxState) { } // DeleteSnapshot removes a snapshot template from disk. -func (m *Manager) DeleteSnapshot(name string) error { - if err := validate.SafeName(name); err != nil { - return fmt.Errorf("invalid snapshot name: %w", err) - } - return snapshot.Remove(m.cfg.ImagesDir, name) +func (m *Manager) DeleteSnapshot(teamID, templateID pgtype.UUID) error { + return os.RemoveAll(layout.TemplateDir(m.cfg.WrennDir, teamID, templateID)) } // createFromSnapshot creates a new sandbox by restoring from a snapshot template // in ImagesDir/{snapshotName}/. Uses UFFD for lazy memory loading. // The template's rootfs.ext4 is a flattened standalone image — we create a // dm-snapshot on top of it just like a normal Create. -func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID, snapshotName string, vcpus, _, timeoutSec, diskSizeMB int) (*models.Sandbox, error) { - imagesDir := m.cfg.ImagesDir +func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID string, teamID, templateID pgtype.UUID, vcpus, _, timeoutSec, diskSizeMB int) (*models.Sandbox, error) { + tmplDir := layout.TemplateDir(m.cfg.WrennDir, teamID, templateID) // Read the header. - headerData, err := os.ReadFile(snapshot.MemHeaderPath(imagesDir, snapshotName)) + headerData, err := os.ReadFile(filepath.Join(tmplDir, snapshot.MemHeaderName)) if err != nil { return nil, fmt.Errorf("read snapshot header: %w", err) } @@ -925,7 +913,7 @@ func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID, snapshotNam memoryMB := int(header.Metadata.Size / (1024 * 1024)) // Build diff file map — supports multi-generation templates. - diffPaths, err := snapshot.ListDiffFiles(imagesDir, snapshotName, header) + diffPaths, err := snapshot.ListDiffFiles(tmplDir, "", header) if err != nil { return nil, fmt.Errorf("list diff files: %w", err) } @@ -936,7 +924,7 @@ func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID, snapshotNam } // Set up dm-snapshot on the template's flattened rootfs. - baseRootfs := snapshot.RootfsPath(imagesDir, snapshotName) + baseRootfs := filepath.Join(tmplDir, snapshot.RootfsFileName) originLoop, err := m.loops.Acquire(baseRootfs) if err != nil { source.Close() @@ -951,7 +939,7 @@ func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID, snapshotNam } dmName := "wrenn-" + sandboxID - cowPath := filepath.Join(m.cfg.SandboxesDir, fmt.Sprintf("%s.cow", sandboxID)) + cowPath := filepath.Join(layout.SandboxesDir(m.cfg.WrennDir), fmt.Sprintf("%s.cow", sandboxID)) cowSize := int64(diskSizeMB) * 1024 * 1024 dmDev, err := devicemapper.CreateSnapshot(dmName, originLoop, cowPath, originSize, cowSize) if err != nil { @@ -981,7 +969,7 @@ func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID, snapshotNam } // Start UFFD server. - uffdSocketPath := filepath.Join(m.cfg.SandboxesDir, fmt.Sprintf("%s-uffd.sock", sandboxID)) + uffdSocketPath := filepath.Join(layout.SandboxesDir(m.cfg.WrennDir), fmt.Sprintf("%s-uffd.sock", sandboxID)) os.Remove(uffdSocketPath) uffdServer := uffd.NewServer(uffdSocketPath, source) if err := uffdServer.Start(ctx); err != nil { @@ -997,7 +985,7 @@ func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID, snapshotNam // Restore VM. vmCfg := vm.VMConfig{ SandboxID: sandboxID, - KernelPath: m.cfg.KernelPath, + KernelPath: layout.KernelPath(m.cfg.WrennDir), RootfsPath: dmDev.DevicePath, VCPUs: vcpus, MemoryMB: memoryMB, @@ -1009,7 +997,7 @@ func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID, snapshotNam NetMask: slot.GuestNetMask, } - snapPath := snapshot.SnapPath(imagesDir, snapshotName) + snapPath := filepath.Join(tmplDir, snapshot.SnapFileName) if _, err := m.vm.CreateFromSnapshot(ctx, vmCfg, snapPath, uffdSocketPath); err != nil { warnErr("uffd server stop error", sandboxID, uffdServer.Stop()) source.Close() @@ -1041,17 +1029,18 @@ func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID, snapshotNam now := time.Now() sb := &sandboxState{ Sandbox: models.Sandbox{ - ID: sandboxID, - Status: models.StatusRunning, - Template: snapshotName, - VCPUs: vcpus, - MemoryMB: memoryMB, - TimeoutSec: timeoutSec, - SlotIndex: slotIdx, - HostIP: slot.HostIP, - RootfsPath: dmDev.DevicePath, - CreatedAt: now, - LastActiveAt: now, + ID: sandboxID, + Status: models.StatusRunning, + TemplateTeamID: teamID.Bytes, + TemplateID: templateID.Bytes, + VCPUs: vcpus, + MemoryMB: memoryMB, + TimeoutSec: timeoutSec, + SlotIndex: slotIdx, + HostIP: slot.HostIP, + RootfsPath: dmDev.DevicePath, + CreatedAt: now, + LastActiveAt: now, }, slot: slot, client: client, @@ -1073,7 +1062,8 @@ func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID, snapshotNam slog.Info("sandbox created from snapshot", "id", sandboxID, - "snapshot", snapshotName, + "team_id", teamID, + "template_id", templateID, "host_ip", slot.HostIP.String(), "dm_device", dmDev.DevicePath, ) diff --git a/internal/service/build.go b/internal/service/build.go index 5142a133..2592a6d9 100644 --- a/internal/service/build.go +++ b/internal/service/build.go @@ -95,6 +95,7 @@ func (s *BuildService) Create(ctx context.Context, p BuildCreateParams) (db.Temp buildID := id.NewBuildID() buildIDStr := id.FormatBuildID(buildID) + newTemplateID := id.NewTemplateID() build, err := s.DB.InsertTemplateBuild(ctx, db.InsertTemplateBuildParams{ ID: buildID, @@ -105,6 +106,8 @@ func (s *BuildService) Create(ctx context.Context, p BuildCreateParams) (db.Temp Vcpus: p.VCPUs, MemoryMb: p.MemoryMB, TotalSteps: int32(len(p.Recipe) + len(preBuildCmds) + len(postBuildCmds)), + TemplateID: newTemplateID, + TeamID: id.PlatformTeamID, }) if err != nil { return db.TemplateBuild{}, fmt.Errorf("insert build: %w", err) @@ -207,12 +210,27 @@ func (s *BuildService) executeBuild(ctx context.Context, buildIDStr string) { sandboxIDStr := id.FormatSandboxID(sandboxID) log = log.With("sandbox_id", sandboxIDStr, "host_id", id.FormatHostID(host.ID)) + // Resolve the base template to UUIDs. "minimal" is the zero sentinel. + baseTeamID := id.PlatformTeamID + baseTemplateID := id.MinimalTemplateID + if build.BaseTemplate != "minimal" { + baseTmpl, err := s.DB.GetPlatformTemplateByName(ctx, build.BaseTemplate) + if err != nil { + s.failBuild(ctx, buildID, fmt.Sprintf("base template %q not found: %v", build.BaseTemplate, err)) + return + } + baseTeamID = baseTmpl.TeamID + baseTemplateID = baseTmpl.ID + } + resp, err := agent.CreateSandbox(ctx, connect.NewRequest(&pb.CreateSandboxRequest{ SandboxId: sandboxIDStr, Template: build.BaseTemplate, + TeamId: id.UUIDString(baseTeamID), + TemplateId: id.UUIDString(baseTemplateID), Vcpus: build.Vcpus, MemoryMb: build.MemoryMb, - TimeoutSec: 0, // no auto-pause for builds + TimeoutSec: 0, // no auto-pause for builds DiskSizeMb: 5120, // 5 GB for template builds })) if err != nil { @@ -316,8 +334,10 @@ func (s *BuildService) executeBuild(ctx context.Context, buildIDStr string) { // Healthcheck passed → full snapshot (with memory/CPU state). log.Info("healthcheck passed, creating snapshot") snapResp, err := agent.CreateSnapshot(ctx, connect.NewRequest(&pb.CreateSnapshotRequest{ - SandboxId: sandboxIDStr, - Name: build.Name, + SandboxId: sandboxIDStr, + Name: build.Name, + TeamId: id.UUIDString(build.TeamID), + TemplateId: id.UUIDString(build.TemplateID), })) if err != nil { s.destroySandbox(ctx, agent, sandboxIDStr) @@ -329,8 +349,10 @@ func (s *BuildService) executeBuild(ctx context.Context, buildIDStr string) { // No healthcheck → image-only template (rootfs only). log.Info("no healthcheck, flattening rootfs") flatResp, err := agent.FlattenRootfs(ctx, connect.NewRequest(&pb.FlattenRootfsRequest{ - SandboxId: sandboxIDStr, - Name: build.Name, + SandboxId: sandboxIDStr, + Name: build.Name, + TeamId: id.UUIDString(build.TeamID), + TemplateId: id.UUIDString(build.TemplateID), })) if err != nil { s.destroySandbox(ctx, agent, sandboxIDStr) @@ -347,6 +369,7 @@ func (s *BuildService) executeBuild(ctx context.Context, buildIDStr string) { } if _, err := s.DB.InsertTemplate(ctx, db.InsertTemplateParams{ + ID: build.TemplateID, Name: build.Name, Type: templateType, Vcpus: build.Vcpus, diff --git a/internal/service/sandbox.go b/internal/service/sandbox.go index 43f1bd3f..2d1f68c4 100644 --- a/internal/service/sandbox.go +++ b/internal/service/sandbox.go @@ -82,10 +82,21 @@ func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db. p.DiskSizeMB = 5120 // 5 GB default } - // If the template is a snapshot, use its baked-in vcpus/memory. - if tmpl, err := s.DB.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: p.Template, TeamID: p.TeamID}); err == nil && tmpl.Type == "snapshot" { - p.VCPUs = tmpl.Vcpus - p.MemoryMB = tmpl.MemoryMb + // Resolve template name → (teamID, templateID). + templateTeamID := id.PlatformTeamID + templateID := id.MinimalTemplateID + if p.Template != "minimal" { + tmpl, err := s.DB.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: p.Template, TeamID: p.TeamID}) + if err != nil { + return db.Sandbox{}, fmt.Errorf("template %q not found: %w", p.Template, err) + } + templateTeamID = tmpl.TeamID + templateID = tmpl.ID + // If the template is a snapshot, use its baked-in vcpus/memory. + if tmpl.Type == "snapshot" { + p.VCPUs = tmpl.Vcpus + p.MemoryMB = tmpl.MemoryMb + } } if !p.TeamID.Valid { @@ -113,15 +124,17 @@ func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db. sandboxIDStr := id.FormatSandboxID(sandboxID) if _, err := s.DB.InsertSandbox(ctx, db.InsertSandboxParams{ - ID: sandboxID, - TeamID: p.TeamID, - HostID: host.ID, - Template: p.Template, - Status: "pending", - Vcpus: p.VCPUs, - MemoryMb: p.MemoryMB, - TimeoutSec: p.TimeoutSec, - DiskSizeMb: p.DiskSizeMB, + ID: sandboxID, + TeamID: p.TeamID, + HostID: host.ID, + Template: p.Template, + Status: "pending", + Vcpus: p.VCPUs, + MemoryMb: p.MemoryMB, + TimeoutSec: p.TimeoutSec, + DiskSizeMb: p.DiskSizeMB, + TemplateID: templateID, + TemplateTeamID: templateTeamID, }); err != nil { return db.Sandbox{}, fmt.Errorf("insert sandbox: %w", err) } @@ -129,6 +142,8 @@ func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db. resp, err := agent.CreateSandbox(ctx, connect.NewRequest(&pb.CreateSandboxRequest{ SandboxId: sandboxIDStr, Template: p.Template, + TeamId: id.UUIDString(templateTeamID), + TemplateId: id.UUIDString(templateID), Vcpus: p.VCPUs, MemoryMb: p.MemoryMB, TimeoutSec: p.TimeoutSec, diff --git a/internal/service/team.go b/internal/service/team.go index 667cd044..a7acbac1 100644 --- a/internal/service/team.go +++ b/internal/service/team.go @@ -202,12 +202,61 @@ func (s *TeamService) DeleteTeam(ctx context.Context, teamID, callerUserID pgtyp } } + // Clean up team-owned templates from all hosts in the background. + go s.cleanupTeamTemplates(context.Background(), teamID) + if err := s.DB.SoftDeleteTeam(ctx, teamID); err != nil { return fmt.Errorf("soft delete team: %w", err) } return nil } +// cleanupTeamTemplates deletes all template files for a team from all online hosts, +// then removes the DB records. Called asynchronously during team deletion. +func (s *TeamService) cleanupTeamTemplates(ctx context.Context, teamID pgtype.UUID) { + templates, err := s.DB.ListTemplatesByTeamOnly(ctx, teamID) + if err != nil { + slog.Warn("team delete: failed to list templates for cleanup", "team_id", id.FormatTeamID(teamID), "error", err) + return + } + if len(templates) == 0 { + return + } + + hosts, err := s.DB.ListActiveHosts(ctx) + if err != nil { + slog.Warn("team delete: failed to list hosts for template cleanup", "error", err) + return + } + + for _, tmpl := range templates { + for _, host := range hosts { + if host.Status != "online" { + continue + } + agent, err := s.HostPool.GetForHost(host) + if err != nil { + continue + } + if _, err := agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{ + TeamId: id.UUIDString(tmpl.TeamID), + TemplateId: id.UUIDString(tmpl.ID), + })); err != nil && connect.CodeOf(err) != connect.CodeNotFound { + slog.Warn("team delete: failed to delete template on host", + "host_id", id.FormatHostID(host.ID), + "template", tmpl.Name, + "error", err, + ) + } + } + } + + // Remove DB records. + if err := s.DB.DeleteTemplatesByTeam(ctx, teamID); err != nil { + slog.Warn("team delete: failed to delete template records", "team_id", id.FormatTeamID(teamID), "error", err) + } +} + // GetMembers returns all members of the team with their emails and roles. func (s *TeamService) GetMembers(ctx context.Context, teamID pgtype.UUID) ([]MemberInfo, error) { rows, err := s.DB.GetTeamMembers(ctx, teamID) diff --git a/proto/hostagent/gen/hostagent.pb.go b/proto/hostagent/gen/hostagent.pb.go index 9f984f91..516e4d25 100644 --- a/proto/hostagent/gen/hostagent.pb.go +++ b/proto/hostagent/gen/hostagent.pb.go @@ -25,7 +25,7 @@ type CreateSandboxRequest struct { state protoimpl.MessageState `protogen:"open.v1"` // Sandbox ID assigned by the control plane. If empty, the host agent generates one. SandboxId string `protobuf:"bytes,5,opt,name=sandbox_id,json=sandboxId,proto3" json:"sandbox_id,omitempty"` - // Template name (e.g., "minimal", "python311"). Determines base rootfs. + // Deprecated: use team_id + template_id instead. Template string `protobuf:"bytes,1,opt,name=template,proto3" json:"template,omitempty"` // Number of virtual CPUs (default: 1). Vcpus int32 `protobuf:"varint,2,opt,name=vcpus,proto3" json:"vcpus,omitempty"` @@ -36,7 +36,11 @@ type CreateSandboxRequest struct { TimeoutSec int32 `protobuf:"varint,4,opt,name=timeout_sec,json=timeoutSec,proto3" json:"timeout_sec,omitempty"` // Disk size in MB for the rootfs. Base images are expanded to this size // at host agent startup. Default: 5120 (5 GB). - DiskSizeMb int32 `protobuf:"varint,6,opt,name=disk_size_mb,json=diskSizeMb,proto3" json:"disk_size_mb,omitempty"` + DiskSizeMb int32 `protobuf:"varint,6,opt,name=disk_size_mb,json=diskSizeMb,proto3" json:"disk_size_mb,omitempty"` + // Team UUID that owns the template (hex string). All-zeros = platform. + TeamId string `protobuf:"bytes,7,opt,name=team_id,json=teamId,proto3" json:"team_id,omitempty"` + // Template UUID (hex string). Both zeros + team zeros = "minimal" sentinel. + TemplateId string `protobuf:"bytes,8,opt,name=template_id,json=templateId,proto3" json:"template_id,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -113,6 +117,20 @@ func (x *CreateSandboxRequest) GetDiskSizeMb() int32 { return 0 } +func (x *CreateSandboxRequest) GetTeamId() string { + if x != nil { + return x.TeamId + } + return "" +} + +func (x *CreateSandboxRequest) GetTemplateId() string { + if x != nil { + return x.TemplateId + } + return "" +} + type CreateSandboxResponse struct { state protoimpl.MessageState `protogen:"open.v1"` SandboxId string `protobuf:"bytes,1,opt,name=sandbox_id,json=sandboxId,proto3" json:"sandbox_id,omitempty"` @@ -448,9 +466,14 @@ func (x *ResumeSandboxResponse) GetHostIp() string { } type CreateSnapshotRequest struct { - state protoimpl.MessageState `protogen:"open.v1"` - SandboxId string `protobuf:"bytes,1,opt,name=sandbox_id,json=sandboxId,proto3" json:"sandbox_id,omitempty"` - Name string `protobuf:"bytes,2,opt,name=name,proto3" json:"name,omitempty"` + state protoimpl.MessageState `protogen:"open.v1"` + SandboxId string `protobuf:"bytes,1,opt,name=sandbox_id,json=sandboxId,proto3" json:"sandbox_id,omitempty"` + // Deprecated: use team_id + template_id instead. + Name string `protobuf:"bytes,2,opt,name=name,proto3" json:"name,omitempty"` + // Team UUID that will own the new template. + TeamId string `protobuf:"bytes,3,opt,name=team_id,json=teamId,proto3" json:"team_id,omitempty"` + // Template UUID for the new snapshot template. + TemplateId string `protobuf:"bytes,4,opt,name=template_id,json=templateId,proto3" json:"template_id,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -499,6 +522,20 @@ func (x *CreateSnapshotRequest) GetName() string { return "" } +func (x *CreateSnapshotRequest) GetTeamId() string { + if x != nil { + return x.TeamId + } + return "" +} + +func (x *CreateSnapshotRequest) GetTemplateId() string { + if x != nil { + return x.TemplateId + } + return "" +} + type CreateSnapshotResponse struct { state protoimpl.MessageState `protogen:"open.v1"` Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` @@ -552,8 +589,13 @@ func (x *CreateSnapshotResponse) GetSizeBytes() int64 { } type DeleteSnapshotRequest struct { - state protoimpl.MessageState `protogen:"open.v1"` - Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + state protoimpl.MessageState `protogen:"open.v1"` + // Deprecated: use team_id + template_id instead. + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + // Team UUID that owns the template. + TeamId string `protobuf:"bytes,2,opt,name=team_id,json=teamId,proto3" json:"team_id,omitempty"` + // Template UUID to delete. + TemplateId string `protobuf:"bytes,3,opt,name=template_id,json=templateId,proto3" json:"template_id,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -595,6 +637,20 @@ func (x *DeleteSnapshotRequest) GetName() string { return "" } +func (x *DeleteSnapshotRequest) GetTeamId() string { + if x != nil { + return x.TeamId + } + return "" +} + +func (x *DeleteSnapshotRequest) GetTemplateId() string { + if x != nil { + return x.TemplateId + } + return "" +} + type DeleteSnapshotResponse struct { state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields @@ -851,16 +907,19 @@ func (x *ListSandboxesResponse) GetAutoPausedSandboxIds() []string { } type SandboxInfo struct { - state protoimpl.MessageState `protogen:"open.v1"` - SandboxId string `protobuf:"bytes,1,opt,name=sandbox_id,json=sandboxId,proto3" json:"sandbox_id,omitempty"` - Status string `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"` - Template string `protobuf:"bytes,3,opt,name=template,proto3" json:"template,omitempty"` - Vcpus int32 `protobuf:"varint,4,opt,name=vcpus,proto3" json:"vcpus,omitempty"` - MemoryMb int32 `protobuf:"varint,5,opt,name=memory_mb,json=memoryMb,proto3" json:"memory_mb,omitempty"` - HostIp string `protobuf:"bytes,6,opt,name=host_ip,json=hostIp,proto3" json:"host_ip,omitempty"` - CreatedAtUnix int64 `protobuf:"varint,7,opt,name=created_at_unix,json=createdAtUnix,proto3" json:"created_at_unix,omitempty"` - LastActiveAtUnix int64 `protobuf:"varint,8,opt,name=last_active_at_unix,json=lastActiveAtUnix,proto3" json:"last_active_at_unix,omitempty"` - TimeoutSec int32 `protobuf:"varint,9,opt,name=timeout_sec,json=timeoutSec,proto3" json:"timeout_sec,omitempty"` + state protoimpl.MessageState `protogen:"open.v1"` + SandboxId string `protobuf:"bytes,1,opt,name=sandbox_id,json=sandboxId,proto3" json:"sandbox_id,omitempty"` + Status string `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"` + // Deprecated: use team_id + template_id instead. + Template string `protobuf:"bytes,3,opt,name=template,proto3" json:"template,omitempty"` + Vcpus int32 `protobuf:"varint,4,opt,name=vcpus,proto3" json:"vcpus,omitempty"` + MemoryMb int32 `protobuf:"varint,5,opt,name=memory_mb,json=memoryMb,proto3" json:"memory_mb,omitempty"` + HostIp string `protobuf:"bytes,6,opt,name=host_ip,json=hostIp,proto3" json:"host_ip,omitempty"` + CreatedAtUnix int64 `protobuf:"varint,7,opt,name=created_at_unix,json=createdAtUnix,proto3" json:"created_at_unix,omitempty"` + LastActiveAtUnix int64 `protobuf:"varint,8,opt,name=last_active_at_unix,json=lastActiveAtUnix,proto3" json:"last_active_at_unix,omitempty"` + TimeoutSec int32 `protobuf:"varint,9,opt,name=timeout_sec,json=timeoutSec,proto3" json:"timeout_sec,omitempty"` + TeamId string `protobuf:"bytes,10,opt,name=team_id,json=teamId,proto3" json:"team_id,omitempty"` + TemplateId string `protobuf:"bytes,11,opt,name=template_id,json=templateId,proto3" json:"template_id,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -958,6 +1017,20 @@ func (x *SandboxInfo) GetTimeoutSec() int32 { return 0 } +func (x *SandboxInfo) GetTeamId() string { + if x != nil { + return x.TeamId + } + return "" +} + +func (x *SandboxInfo) GetTemplateId() string { + if x != nil { + return x.TemplateId + } + return "" +} + type WriteFileRequest struct { state protoimpl.MessageState `protogen:"open.v1"` SandboxId string `protobuf:"bytes,1,opt,name=sandbox_id,json=sandboxId,proto3" json:"sandbox_id,omitempty"` @@ -2182,9 +2255,14 @@ func (x *FlushSandboxMetricsResponse) GetPoints_24H() []*MetricPoint { } type FlattenRootfsRequest struct { - state protoimpl.MessageState `protogen:"open.v1"` - SandboxId string `protobuf:"bytes,1,opt,name=sandbox_id,json=sandboxId,proto3" json:"sandbox_id,omitempty"` - Name string `protobuf:"bytes,2,opt,name=name,proto3" json:"name,omitempty"` // template name — output written to images/{name}/rootfs.ext4 + state protoimpl.MessageState `protogen:"open.v1"` + SandboxId string `protobuf:"bytes,1,opt,name=sandbox_id,json=sandboxId,proto3" json:"sandbox_id,omitempty"` + // Deprecated: use team_id + template_id instead. + Name string `protobuf:"bytes,2,opt,name=name,proto3" json:"name,omitempty"` + // Team UUID that will own the resulting template. + TeamId string `protobuf:"bytes,3,opt,name=team_id,json=teamId,proto3" json:"team_id,omitempty"` + // Template UUID for the output. + TemplateId string `protobuf:"bytes,4,opt,name=template_id,json=templateId,proto3" json:"template_id,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -2233,6 +2311,20 @@ func (x *FlattenRootfsRequest) GetName() string { return "" } +func (x *FlattenRootfsRequest) GetTeamId() string { + if x != nil { + return x.TeamId + } + return "" +} + +func (x *FlattenRootfsRequest) GetTemplateId() string { + if x != nil { + return x.TemplateId + } + return "" +} + type FlattenRootfsResponse struct { state protoimpl.MessageState `protogen:"open.v1"` SizeBytes int64 `protobuf:"varint,1,opt,name=size_bytes,json=sizeBytes,proto3" json:"size_bytes,omitempty"` @@ -2281,7 +2373,7 @@ var File_hostagent_proto protoreflect.FileDescriptor const file_hostagent_proto_rawDesc = "" + "\n" + - "\x0fhostagent.proto\x12\fhostagent.v1\"\xc7\x01\n" + + "\x0fhostagent.proto\x12\fhostagent.v1\"\x81\x02\n" + "\x14CreateSandboxRequest\x12\x1d\n" + "\n" + "sandbox_id\x18\x05 \x01(\tR\tsandboxId\x12\x1a\n" + @@ -2291,7 +2383,10 @@ const file_hostagent_proto_rawDesc = "" + "\vtimeout_sec\x18\x04 \x01(\x05R\n" + "timeoutSec\x12 \n" + "\fdisk_size_mb\x18\x06 \x01(\x05R\n" + - "diskSizeMb\"g\n" + + "diskSizeMb\x12\x17\n" + + "\ateam_id\x18\a \x01(\tR\x06teamId\x12\x1f\n" + + "\vtemplate_id\x18\b \x01(\tR\n" + + "templateId\"g\n" + "\x15CreateSandboxResponse\x12\x1d\n" + "\n" + "sandbox_id\x18\x01 \x01(\tR\tsandboxId\x12\x16\n" + @@ -2314,17 +2409,23 @@ const file_hostagent_proto_rawDesc = "" + "\n" + "sandbox_id\x18\x01 \x01(\tR\tsandboxId\x12\x16\n" + "\x06status\x18\x02 \x01(\tR\x06status\x12\x17\n" + - "\ahost_ip\x18\x03 \x01(\tR\x06hostIp\"J\n" + + "\ahost_ip\x18\x03 \x01(\tR\x06hostIp\"\x84\x01\n" + "\x15CreateSnapshotRequest\x12\x1d\n" + "\n" + "sandbox_id\x18\x01 \x01(\tR\tsandboxId\x12\x12\n" + - "\x04name\x18\x02 \x01(\tR\x04name\"K\n" + + "\x04name\x18\x02 \x01(\tR\x04name\x12\x17\n" + + "\ateam_id\x18\x03 \x01(\tR\x06teamId\x12\x1f\n" + + "\vtemplate_id\x18\x04 \x01(\tR\n" + + "templateId\"K\n" + "\x16CreateSnapshotResponse\x12\x12\n" + "\x04name\x18\x01 \x01(\tR\x04name\x12\x1d\n" + "\n" + - "size_bytes\x18\x02 \x01(\x03R\tsizeBytes\"+\n" + + "size_bytes\x18\x02 \x01(\x03R\tsizeBytes\"e\n" + "\x15DeleteSnapshotRequest\x12\x12\n" + - "\x04name\x18\x01 \x01(\tR\x04name\"\x18\n" + + "\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n" + + "\ateam_id\x18\x02 \x01(\tR\x06teamId\x12\x1f\n" + + "\vtemplate_id\x18\x03 \x01(\tR\n" + + "templateId\"\x18\n" + "\x16DeleteSnapshotResponse\"s\n" + "\vExecRequest\x12\x1d\n" + "\n" + @@ -2340,7 +2441,7 @@ const file_hostagent_proto_rawDesc = "" + "\x14ListSandboxesRequest\"\x87\x01\n" + "\x15ListSandboxesResponse\x127\n" + "\tsandboxes\x18\x01 \x03(\v2\x19.hostagent.v1.SandboxInfoR\tsandboxes\x125\n" + - "\x17auto_paused_sandbox_ids\x18\x02 \x03(\tR\x14autoPausedSandboxIds\"\xa4\x02\n" + + "\x17auto_paused_sandbox_ids\x18\x02 \x03(\tR\x14autoPausedSandboxIds\"\xde\x02\n" + "\vSandboxInfo\x12\x1d\n" + "\n" + "sandbox_id\x18\x01 \x01(\tR\tsandboxId\x12\x16\n" + @@ -2352,7 +2453,11 @@ const file_hostagent_proto_rawDesc = "" + "\x0fcreated_at_unix\x18\a \x01(\x03R\rcreatedAtUnix\x12-\n" + "\x13last_active_at_unix\x18\b \x01(\x03R\x10lastActiveAtUnix\x12\x1f\n" + "\vtimeout_sec\x18\t \x01(\x05R\n" + - "timeoutSec\"_\n" + + "timeoutSec\x12\x17\n" + + "\ateam_id\x18\n" + + " \x01(\tR\x06teamId\x12\x1f\n" + + "\vtemplate_id\x18\v \x01(\tR\n" + + "templateId\"_\n" + "\x10WriteFileRequest\x12\x1d\n" + "\n" + "sandbox_id\x18\x01 \x01(\tR\tsandboxId\x12\x12\n" + @@ -2427,11 +2532,14 @@ const file_hostagent_proto_rawDesc = "" + "points_10m\x18\x01 \x03(\v2\x19.hostagent.v1.MetricPointR\tpoints10m\x126\n" + "\tpoints_2h\x18\x02 \x03(\v2\x19.hostagent.v1.MetricPointR\bpoints2h\x128\n" + "\n" + - "points_24h\x18\x03 \x03(\v2\x19.hostagent.v1.MetricPointR\tpoints24h\"I\n" + + "points_24h\x18\x03 \x03(\v2\x19.hostagent.v1.MetricPointR\tpoints24h\"\x83\x01\n" + "\x14FlattenRootfsRequest\x12\x1d\n" + "\n" + "sandbox_id\x18\x01 \x01(\tR\tsandboxId\x12\x12\n" + - "\x04name\x18\x02 \x01(\tR\x04name\"6\n" + + "\x04name\x18\x02 \x01(\tR\x04name\x12\x17\n" + + "\ateam_id\x18\x03 \x01(\tR\x06teamId\x12\x1f\n" + + "\vtemplate_id\x18\x04 \x01(\tR\n" + + "templateId\"6\n" + "\x15FlattenRootfsResponse\x12\x1d\n" + "\n" + "size_bytes\x18\x01 \x01(\x03R\tsizeBytes2\xc8\f\n" + diff --git a/proto/hostagent/hostagent.proto b/proto/hostagent/hostagent.proto index 5a3205b7..817d5359 100644 --- a/proto/hostagent/hostagent.proto +++ b/proto/hostagent/hostagent.proto @@ -73,7 +73,7 @@ message CreateSandboxRequest { // Sandbox ID assigned by the control plane. If empty, the host agent generates one. string sandbox_id = 5; - // Template name (e.g., "minimal", "python311"). Determines base rootfs. + // Deprecated: use team_id + template_id instead. string template = 1; // Number of virtual CPUs (default: 1). @@ -89,6 +89,12 @@ message CreateSandboxRequest { // Disk size in MB for the rootfs. Base images are expanded to this size // at host agent startup. Default: 5120 (5 GB). int32 disk_size_mb = 6; + + // Team UUID that owns the template (hex string). All-zeros = platform. + string team_id = 7; + + // Template UUID (hex string). Both zeros + team zeros = "minimal" sentinel. + string template_id = 8; } message CreateSandboxResponse { @@ -125,7 +131,12 @@ message ResumeSandboxResponse { message CreateSnapshotRequest { string sandbox_id = 1; + // Deprecated: use team_id + template_id instead. string name = 2; + // Team UUID that will own the new template. + string team_id = 3; + // Template UUID for the new snapshot template. + string template_id = 4; } message CreateSnapshotResponse { @@ -134,7 +145,12 @@ message CreateSnapshotResponse { } message DeleteSnapshotRequest { + // Deprecated: use team_id + template_id instead. string name = 1; + // Team UUID that owns the template. + string team_id = 2; + // Template UUID to delete. + string template_id = 3; } message DeleteSnapshotResponse {} @@ -166,6 +182,7 @@ message ListSandboxesResponse { message SandboxInfo { string sandbox_id = 1; string status = 2; + // Deprecated: use team_id + template_id instead. string template = 3; int32 vcpus = 4; int32 memory_mb = 5; @@ -173,6 +190,8 @@ message SandboxInfo { int64 created_at_unix = 7; int64 last_active_at_unix = 8; int32 timeout_sec = 9; + string team_id = 10; + string template_id = 11; } message WriteFileRequest { @@ -299,7 +318,12 @@ message FlushSandboxMetricsResponse { message FlattenRootfsRequest { string sandbox_id = 1; - string name = 2; // template name — output written to images/{name}/rootfs.ext4 + // Deprecated: use team_id + template_id instead. + string name = 2; + // Team UUID that will own the resulting template. + string team_id = 3; + // Template UUID for the output. + string template_id = 4; } message FlattenRootfsResponse { From 906cc42d1391b470e1e3bfb928c6512d60961790 Mon Sep 17 00:00:00 2001 From: pptx704 Date: Sun, 29 Mar 2026 00:30:20 +0600 Subject: [PATCH 19/28] Rename AGENT_*/CP_LISTEN_ADDR env vars to WRENN_* prefix MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit AGENT_FILES_ROOTDIR → WRENN_DIR, AGENT_LISTEN_ADDR → WRENN_HOST_LISTEN_ADDR, AGENT_CP_URL → WRENN_CP_URL, AGENT_HOST_INTERFACE → WRENN_HOST_INTERFACE, CP_LISTEN_ADDR → WRENN_CP_LISTEN_ADDR. Consolidates all env vars under a consistent WRENN_ namespace. --- .env.example | 10 +++++----- README.md | 16 ++++++++-------- cmd/host-agent/main.go | 16 ++++++---------- internal/config/config.go | 2 +- internal/hostagent/registration.go | 2 +- scripts/rootfs-from-container.sh | 8 ++++---- 6 files changed, 25 insertions(+), 29 deletions(-) diff --git a/.env.example b/.env.example index bf528015..32b235a7 100644 --- a/.env.example +++ b/.env.example @@ -5,13 +5,13 @@ DATABASE_URL=postgres://wrenn:wrenn@localhost:5432/wrenn?sslmode=disable REDIS_URL=redis://localhost:6379/0 # Control Plane -CP_LISTEN_ADDR=:8080 +WRENN_CP_LISTEN_ADDR=:8080 # Host Agent -AGENT_LISTEN_ADDR=:50051 -AGENT_FILES_ROOTDIR=/var/lib/wrenn -AGENT_HOST_INTERFACE=eth0 -AGENT_CP_URL=http://localhost:8080 +WRENN_HOST_LISTEN_ADDR=:50051 +WRENN_DIR=/var/lib/wrenn +WRENN_HOST_INTERFACE=eth0 +WRENN_CP_URL=http://localhost:8080 # Lago (billing — external service) LAGO_API_URL=http://localhost:3000 diff --git a/README.md b/README.md index dff19325..e2b290f0 100644 --- a/README.md +++ b/README.md @@ -51,12 +51,12 @@ Copy `.env.example` to `.env` and edit: DATABASE_URL=postgres://wrenn:wrenn@localhost:5432/wrenn?sslmode=disable # Control plane -CP_LISTEN_ADDR=:8000 +WRENN_CP_LISTEN_ADDR=:8000 CP_HOST_AGENT_ADDR=http://localhost:50051 # Host agent -AGENT_LISTEN_ADDR=:50051 -AGENT_FILES_ROOTDIR=/var/lib/wrenn +WRENN_HOST_LISTEN_ADDR=:50051 +WRENN_DIR=/var/lib/wrenn ``` ### Run @@ -69,7 +69,7 @@ make migrate-up ./builds/wrenn-cp ``` -Control plane listens on `CP_LISTEN_ADDR` (default `:8000`). +Control plane listens on `WRENN_CP_LISTEN_ADDR` (default `:8000`). ### Host registration @@ -87,16 +87,16 @@ Hosts must be registered with the control plane before they can serve sandboxes. 2. **Start the host agent** with the registration token and its externally-reachable address: ```bash - sudo AGENT_CP_URL=http://cp-host:8000 \ + sudo WRENN_CP_URL=http://cp-host:8000 \ ./builds/wrenn-agent \ --register \ --address 10.0.1.5:50051 ``` - On first startup the agent sends its specs (arch, CPU, memory, disk) to the control plane, receives a long-lived host JWT, and saves it to `$AGENT_FILES_ROOTDIR/host-token`. + On first startup the agent sends its specs (arch, CPU, memory, disk) to the control plane, receives a long-lived host JWT, and saves it to `$WRENN_DIR/host-token`. 3. **Subsequent startups** don't need `--register` — the agent loads the saved JWT automatically: ```bash - sudo AGENT_CP_URL=http://cp-host:8000 \ + sudo WRENN_CP_URL=http://cp-host:8000 \ ./builds/wrenn-agent --address 10.0.1.5:50051 ``` @@ -107,7 +107,7 @@ Hosts must be registered with the control plane before they can serve sandboxes. ``` Then restart the agent with the new token. -The agent sends heartbeats to the control plane every 30 seconds. Host agent listens on `AGENT_LISTEN_ADDR` (default `:50051`). +The agent sends heartbeats to the control plane every 30 seconds. Host agent listens on `WRENN_HOST_LISTEN_ADDR` (default `:50051`). ### Rootfs images diff --git a/cmd/host-agent/main.go b/cmd/host-agent/main.go index 76dc2390..6665ba9a 100644 --- a/cmd/host-agent/main.go +++ b/cmd/host-agent/main.go @@ -45,13 +45,13 @@ func main() { // Clean up any stale dm-snapshot devices from a previous crash. devicemapper.CleanupStaleDevices() - listenAddr := envOrDefault("AGENT_LISTEN_ADDR", ":50051") - rootDir := envOrDefault("AGENT_FILES_ROOTDIR", "/var/lib/wrenn") - cpURL := os.Getenv("AGENT_CP_URL") + listenAddr := envOrDefault("WRENN_HOST_LISTEN_ADDR", ":50051") + rootDir := envOrDefault("WRENN_DIR", "/var/lib/wrenn") + cpURL := os.Getenv("WRENN_CP_URL") tokenFile := filepath.Join(rootDir, "host.jwt") if cpURL == "" { - slog.Error("AGENT_CP_URL environment variable is required") + slog.Error("WRENN_CP_URL environment variable is required") os.Exit(1) } if *advertiseAddr == "" { @@ -61,17 +61,13 @@ func main() { // Expand base images to the standard disk size (sparse, no extra physical // disk). This ensures dm-snapshot sandboxes see the full size from boot. - imagesDir := filepath.Join(rootDir, "images") - if err := sandbox.EnsureImageSizes(imagesDir, sandbox.DefaultDiskSizeMB); err != nil { + if err := sandbox.EnsureImageSizes(rootDir, sandbox.DefaultDiskSizeMB); err != nil { slog.Error("failed to expand base images", "error", err) os.Exit(1) } cfg := sandbox.Config{ - KernelPath: filepath.Join(rootDir, "kernels", "vmlinux"), - ImagesDir: filepath.Join(rootDir, "images"), - SandboxesDir: filepath.Join(rootDir, "sandboxes"), - SnapshotsDir: filepath.Join(rootDir, "snapshots"), + WrennDir: rootDir, } mgr := sandbox.New(cfg) diff --git a/internal/config/config.go b/internal/config/config.go index 7ef0aa69..a4564aa5 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -28,7 +28,7 @@ func Load() Config { return Config{ DatabaseURL: envOrDefault("DATABASE_URL", "postgres://wrenn:wrenn@localhost:5432/wrenn?sslmode=disable"), RedisURL: envOrDefault("REDIS_URL", "redis://localhost:6379/0"), - ListenAddr: envOrDefault("CP_LISTEN_ADDR", ":8080"), + ListenAddr: envOrDefault("WRENN_CP_LISTEN_ADDR", ":8080"), JWTSecret: os.Getenv("JWT_SECRET"), OAuthGitHubClientID: os.Getenv("OAUTH_GITHUB_CLIENT_ID"), diff --git a/internal/hostagent/registration.go b/internal/hostagent/registration.go index 9f39c3b0..5948e0ce 100644 --- a/internal/hostagent/registration.go +++ b/internal/hostagent/registration.go @@ -17,7 +17,7 @@ import ( "golang.org/x/sys/unix" ) -// tokenFile is the JSON format persisted to AGENT_FILES_ROOTDIR/host.jwt. +// tokenFile is the JSON format persisted to WRENN_DIR/host.jwt. type tokenFile struct { HostID string `json:"host_id"` JWT string `json:"jwt"` diff --git a/scripts/rootfs-from-container.sh b/scripts/rootfs-from-container.sh index 2159ac7c..2f96a3a7 100755 --- a/scripts/rootfs-from-container.sh +++ b/scripts/rootfs-from-container.sh @@ -16,7 +16,7 @@ # image_name — Directory name under images dir (e.g. "waitlist") # # Output: -# ${AGENT_FILES_ROOTDIR}/images//rootfs.ext4 +# ${WRENN_DIR}/images//rootfs.ext4 # # Requires: docker, mkfs.ext4, resize2fs, e2fsck, make (for building envd), curl (for tini download) # Sudo is used only for mount/umount/copy-into-image operations. @@ -25,8 +25,8 @@ set -euo pipefail SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" PROJECT_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" -AGENT_FILES_ROOTDIR="${AGENT_FILES_ROOTDIR:-/var/lib/wrenn}" -AGENT_IMAGES_PATH="${AGENT_FILES_ROOTDIR}/images" +WRENN_DIR="${WRENN_DIR:-/var/lib/wrenn}" +WRENN_IMAGES_PATH="${WRENN_DIR}/images" if [ $# -lt 2 ]; then echo "Usage: $0 " @@ -35,7 +35,7 @@ fi CONTAINER="$1" IMAGE_NAME="$2" -OUTPUT_DIR="${AGENT_IMAGES_PATH}/${IMAGE_NAME}" +OUTPUT_DIR="${WRENN_IMAGES_PATH}/${IMAGE_NAME}" OUTPUT_FILE="${OUTPUT_DIR}/rootfs.ext4" MOUNT_DIR="/tmp/wrenn-rootfs-build" TAR_FILE="/tmp/wrenn-rootfs-export-${IMAGE_NAME}.tar" From 46d60fc5a586bc85ce6d662674ce196c707641c0 Mon Sep 17 00:00:00 2001 From: pptx704 Date: Sun, 29 Mar 2026 01:34:54 +0600 Subject: [PATCH 20/28] Seed minimal template in DB and protect it from deletion Insert a minimal template row (all-zeros UUID) so it appears in both team and admin template listings. Guard delete endpoints to prevent removal of the minimal template. --- .../20260328162803_template_uuid_pk.sql | 20 ++++++++++++++++++- internal/api/handlers_builds.go | 5 +++++ internal/api/handlers_snapshots.go | 5 +++++ 3 files changed, 29 insertions(+), 1 deletion(-) diff --git a/db/migrations/20260328162803_template_uuid_pk.sql b/db/migrations/20260328162803_template_uuid_pk.sql index 8665241f..0bb65668 100644 --- a/db/migrations/20260328162803_template_uuid_pk.sql +++ b/db/migrations/20260328162803_template_uuid_pk.sql @@ -12,6 +12,7 @@ ALTER TABLE templates ADD CONSTRAINT uq_templates_team_name UNIQUE (team_id, nam -- 3. Prevent team templates from using names that belong to global (platform) templates. -- A team template insert/update with a name matching any platform template is rejected. +-- +goose StatementBegin CREATE OR REPLACE FUNCTION check_global_template_name_collision() RETURNS TRIGGER AS $$ BEGIN @@ -28,13 +29,27 @@ BEGIN RETURN NEW; END; $$ LANGUAGE plpgsql; +-- +goose StatementEnd CREATE TRIGGER trg_check_global_template_name BEFORE INSERT OR UPDATE ON templates FOR EACH ROW EXECUTE FUNCTION check_global_template_name_collision(); --- 4. Add template UUID references to template_builds. +-- 4. Seed the built-in "minimal" template so it appears in all listings. +-- Both id and team_id are the all-zeros UUID (platform sentinel). +INSERT INTO templates (id, name, type, vcpus, memory_mb, size_bytes, team_id) +VALUES ( + '00000000-0000-0000-0000-000000000000', + 'minimal', + 'base', + 1, + 512, + 0, + '00000000-0000-0000-0000-000000000000' +) ON CONFLICT DO NOTHING; + +-- 5. Add template UUID references to template_builds. ALTER TABLE template_builds ADD COLUMN template_id UUID, ADD COLUMN team_id UUID; @@ -54,6 +69,9 @@ ALTER TABLE template_builds DROP COLUMN IF EXISTS team_id, DROP COLUMN IF EXISTS template_id; +-- Remove the seeded minimal template. +DELETE FROM templates WHERE id = '00000000-0000-0000-0000-000000000000'; + DROP TRIGGER IF EXISTS trg_check_global_template_name ON templates; DROP FUNCTION IF EXISTS check_global_template_name_collision(); diff --git a/internal/api/handlers_builds.go b/internal/api/handlers_builds.go index 58a7ed44..3b964008 100644 --- a/internal/api/handlers_builds.go +++ b/internal/api/handlers_builds.go @@ -12,6 +12,7 @@ import ( "git.omukk.dev/wrenn/sandbox/internal/db" "git.omukk.dev/wrenn/sandbox/internal/id" + "git.omukk.dev/wrenn/sandbox/internal/layout" "git.omukk.dev/wrenn/sandbox/internal/lifecycle" "git.omukk.dev/wrenn/sandbox/internal/service" "git.omukk.dev/wrenn/sandbox/internal/validate" @@ -221,6 +222,10 @@ func (h *buildHandler) DeleteTemplate(w http.ResponseWriter, r *http.Request) { writeError(w, http.StatusNotFound, "not_found", "template not found") return } + if layout.IsMinimal(tmpl.TeamID, tmpl.ID) { + writeError(w, http.StatusForbidden, "forbidden", "the minimal template cannot be deleted") + return + } // Broadcast delete to all online hosts. hosts, _ := h.db.ListActiveHosts(ctx) diff --git a/internal/api/handlers_snapshots.go b/internal/api/handlers_snapshots.go index 07bd0301..f7d05f2f 100644 --- a/internal/api/handlers_snapshots.go +++ b/internal/api/handlers_snapshots.go @@ -17,6 +17,7 @@ import ( "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/db" "git.omukk.dev/wrenn/sandbox/internal/id" + "git.omukk.dev/wrenn/sandbox/internal/layout" "git.omukk.dev/wrenn/sandbox/internal/lifecycle" "git.omukk.dev/wrenn/sandbox/internal/service" "git.omukk.dev/wrenn/sandbox/internal/validate" @@ -271,6 +272,10 @@ func (h *snapshotHandler) Delete(w http.ResponseWriter, r *http.Request) { writeError(w, http.StatusForbidden, "forbidden", "platform templates cannot be deleted here") return } + if layout.IsMinimal(tmpl.TeamID, tmpl.ID) { + writeError(w, http.StatusForbidden, "forbidden", "the minimal template cannot be deleted") + return + } if err := h.deleteSnapshotBroadcast(ctx, tmpl.TeamID, tmpl.ID); err != nil { writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete snapshot files") From 1ca10230a9f1c8641545510d7888cfaeca283d02 Mon Sep 17 00:00:00 2001 From: pptx704 Date: Sun, 29 Mar 2026 02:14:30 +0600 Subject: [PATCH 21/28] Prefix network namespaces with wrenn-, add stale cleanup, lower diff cap Rename ns-{idx} to wrenn-ns-{idx} and veth-{idx} to wrenn-veth-{idx} to avoid collisions with other tools. Add CleanupStaleNamespaces() at agent startup to remove orphaned namespaces, veths, iptables rules, and routes from a previous crash. Lower maxDiffGenerations from 10 to 8 to prevent Go runtime memory corruption from snapshot/restore drift. --- cmd/host-agent/main.go | 4 +- internal/network/setup.go | 82 ++++++++++++++++++++++++++++++++++++- internal/sandbox/manager.go | 8 +++- 3 files changed, 89 insertions(+), 5 deletions(-) diff --git a/cmd/host-agent/main.go b/cmd/host-agent/main.go index 6665ba9a..a6571df8 100644 --- a/cmd/host-agent/main.go +++ b/cmd/host-agent/main.go @@ -16,6 +16,7 @@ import ( "git.omukk.dev/wrenn/sandbox/internal/devicemapper" "git.omukk.dev/wrenn/sandbox/internal/hostagent" + "git.omukk.dev/wrenn/sandbox/internal/network" "git.omukk.dev/wrenn/sandbox/internal/sandbox" "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect" ) @@ -42,8 +43,9 @@ func main() { slog.Warn("failed to enable ip_forward", "error", err) } - // Clean up any stale dm-snapshot devices from a previous crash. + // Clean up stale resources from a previous crash. devicemapper.CleanupStaleDevices() + network.CleanupStaleNamespaces() listenAddr := envOrDefault("WRENN_HOST_LISTEN_ADDR", ":50051") rootDir := envOrDefault("WRENN_DIR", "/var/lib/wrenn") diff --git a/internal/network/setup.go b/internal/network/setup.go index 70a8a547..ee06d391 100644 --- a/internal/network/setup.go +++ b/internal/network/setup.go @@ -5,13 +5,91 @@ import ( "fmt" "log/slog" "net" + "os" "os/exec" "runtime" + "strings" "github.com/vishvananda/netlink" "github.com/vishvananda/netns" ) +const nsPrefix = "wrenn-ns-" + +// CleanupStaleNamespaces removes leftover wrenn network namespaces from a +// previous crash. Called once at agent startup. +func CleanupStaleNamespaces() { + entries, err := os.ReadDir("/run/netns") + if err != nil { + return // no /run/netns or unreadable — nothing to clean + } + for _, e := range entries { + name := e.Name() + if !strings.HasPrefix(name, nsPrefix) { + continue + } + // Also remove the associated veth from the host side. + vethName := "wrenn-veth-" + strings.TrimPrefix(name, nsPrefix) + if link, err := netlink.LinkByName(vethName); err == nil { + _ = netlink.LinkDel(link) + } + if err := netns.DeleteNamed(name); err != nil { + slog.Warn("failed to remove stale namespace", "ns", name, "error", err) + } else { + slog.Info("removed stale namespace", "ns", name) + } + } + + // Clean up any stale wrenn iptables rules referencing old veth interfaces. + cleanupStaleIptablesRules() +} + +// cleanupStaleIptablesRules removes host iptables rules that reference +// wrenn-veth interfaces no longer present on the system. +func cleanupStaleIptablesRules() { + for _, table := range []string{"filter", "nat"} { + cmd := exec.Command("iptables-save", "-t", table) + out, err := cmd.Output() + if err != nil { + continue + } + for _, line := range strings.Split(string(out), "\n") { + if !strings.Contains(line, "wrenn-veth-") { + continue + } + // Lines look like "-A FORWARD -i wrenn-veth-1 -o wlo1 -j ACCEPT" + // Convert -A to -D to delete the rule. + if !strings.HasPrefix(line, "-A ") { + continue + } + delRule := "-D " + line[3:] + args := strings.Fields(delRule) + delCmd := exec.Command("iptables", append([]string{"-t", table}, args...)...) + if err := delCmd.Run(); err != nil { + slog.Debug("failed to remove stale iptables rule", "rule", line, "error", err) + } + } + } + + // Also remove stale host routes to 10.11.0.x via wrenn-veth interfaces. + routes, err := netlink.RouteList(nil, netlink.FAMILY_V4) + if err != nil { + return + } + for _, r := range routes { + if r.LinkIndex == 0 { + continue + } + link, err := netlink.LinkByIndex(r.LinkIndex) + if err != nil { + continue + } + if strings.HasPrefix(link.Attrs().Name, "wrenn-veth-") { + _ = netlink.RouteDel(&r) + } + } +} + const ( // Fixed addresses inside each network namespace (safe because each // sandbox gets its own netns). @@ -84,8 +162,8 @@ func NewSlot(index int) *Slot { GuestIP: guestIP, GuestNetMask: guestNetMask, TapName: tapName, - NamespaceID: fmt.Sprintf("ns-%d", index), - VethName: fmt.Sprintf("veth-%d", index), + NamespaceID: fmt.Sprintf("wrenn-ns-%d", index), + VethName: fmt.Sprintf("wrenn-veth-%d", index), } } diff --git a/internal/sandbox/manager.go b/internal/sandbox/manager.go index 4e97c49b..84e04aa2 100644 --- a/internal/sandbox/manager.go +++ b/internal/sandbox/manager.go @@ -73,8 +73,12 @@ type snapshotParent struct { } // maxDiffGenerations caps how many incremental diff generations we chain -// before falling back to a Full snapshot to collapse the chain. -const maxDiffGenerations = 10 +// before falling back to a Full snapshot to collapse the chain. Firecracker +// snapshot/restore of a Go process (envd) accumulates runtime memory state +// drift; empirically, ~10 diff-based cycles corrupt the Go page allocator. +// A Full snapshot resets the generation counter and produces a clean base, +// preventing the crash. +const maxDiffGenerations = 8 // New creates a new sandbox manager. func New(cfg Config) *Manager { From 8f06fc554ae8b75496c0a2afdf201c1cf48cccc1 Mon Sep 17 00:00:00 2001 From: pptx704 Date: Sun, 29 Mar 2026 02:33:33 +0600 Subject: [PATCH 22/28] Replace Full snapshot fallback with file-level diff merge Always use Firecracker Diff snapshots (fast, only changed pages) and merge diff files at the file level when the generation cap is reached. The previous approach used Firecracker's Full snapshot type which dumps all memory to disk and can timeout, losing all snapshot data on failure. Add snapshot.MergeDiffs() which reads each block from the appropriate generation's diff file via the header mapping and writes them into a single consolidated file with a fresh generation-0 header. --- internal/sandbox/manager.go | 71 +++++++++++++++++++++++++-- internal/snapshot/memfile.go | 94 ++++++++++++++++++++++++++++++++++++ 2 files changed, 160 insertions(+), 5 deletions(-) diff --git a/internal/sandbox/manager.go b/internal/sandbox/manager.go index 84e04aa2..a8293e64 100644 --- a/internal/sandbox/manager.go +++ b/internal/sandbox/manager.go @@ -7,6 +7,7 @@ import ( "os" "os/exec" "path/filepath" + "strings" "sync" "time" @@ -314,10 +315,11 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error { } slog.Debug("pause: VM paused", "id", sandboxID, "elapsed", time.Since(pauseStart)) - // Determine snapshot type: Diff if resumed from snapshot (avoids UFFD - // fault-in storm), Full otherwise or if generation cap is reached. + // Always use Diff when we have a parent snapshot — Diff only captures + // changed pages and is much faster than Full (which dumps all memory). + // For first-time pauses (no parent) we must use Full. snapshotType := "Full" - if sb.parent != nil && sb.parent.header.Metadata.Generation < maxDiffGenerations { + if sb.parent != nil { snapshotType = "Diff" } @@ -353,7 +355,7 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error { headerPath := filepath.Join(pauseDir, snapshot.MemHeaderName) processStart := time.Now() - if sb.parent != nil && snapshotType == "Diff" { + if sb.parent != nil { // Diff: process against parent header, producing only changed blocks. diffPath := snapshot.MemDiffPathForBuild(pauseDir, "", buildID) if _, err := snapshot.ProcessMemfileWithParent(rawMemPath, diffPath, headerPath, sb.parent.header, buildID); err != nil { @@ -373,8 +375,50 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error { } } } + + // If the generation cap is reached, merge all diff files into a + // single file to collapse the chain. This is a file-level operation + // (no Firecracker involvement) so it's fast and reliable. + generation := sb.parent.header.Metadata.Generation + 1 + if generation >= maxDiffGenerations { + slog.Debug("pause: merging diff generations", "id", sandboxID, "generation", generation) + + // Load the header we just wrote (it references all generations). + headerData, err := os.ReadFile(headerPath) + if err != nil { + warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir)) + resumeOnError() + return fmt.Errorf("read header for merge: %w", err) + } + currentHeader, err := snapshot.Deserialize(headerData) + if err != nil { + warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir)) + resumeOnError() + return fmt.Errorf("deserialize header for merge: %w", err) + } + + // Locate all diff files referenced by the header. + diffFiles, err := snapshot.ListDiffFiles(pauseDir, "", currentHeader) + if err != nil { + warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir)) + resumeOnError() + return fmt.Errorf("list diff files for merge: %w", err) + } + + // Merge into a single new diff file. + mergedPath := snapshot.MemDiffPath(pauseDir, "") + if _, err := snapshot.MergeDiffs(currentHeader, diffFiles, mergedPath, headerPath); err != nil { + warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir)) + resumeOnError() + return fmt.Errorf("merge diff files: %w", err) + } + + // Remove the old per-generation diff files. + removeStaleMemDiffs(pauseDir) + slog.Debug("pause: diff merge complete", "id", sandboxID) + } } else { - // Full: first generation or generation cap reached — single diff file. + // Full: first pause — no parent to diff against. diffPath := snapshot.MemDiffPath(pauseDir, "") if _, err := snapshot.ProcessMemfile(rawMemPath, diffPath, headerPath, buildID); err != nil { warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir)) @@ -1283,6 +1327,23 @@ func (m *Manager) PauseAll(ctx context.Context) { } } +// removeStaleMemDiffs removes memfile.{uuid} diff files from a snapshot +// directory. Called before writing a Full snapshot to prevent orphaned diffs +// from accumulating across generation resets. +func removeStaleMemDiffs(dir string) { + entries, err := os.ReadDir(dir) + if err != nil { + return + } + for _, e := range entries { + name := e.Name() + // Match "memfile.{uuid}" but not "memfile", "memfile.header", or "memfile.raw". + if strings.HasPrefix(name, "memfile.") && name != snapshot.MemHeaderName && name != "memfile.raw" { + os.Remove(filepath.Join(dir, name)) + } + } +} + // warnErr logs a warning if err is non-nil. Used for best-effort cleanup // in error paths where the primary error has already been captured. func warnErr(msg string, id string, err error) { diff --git a/internal/snapshot/memfile.go b/internal/snapshot/memfile.go index aabe8851..f7b14f9a 100644 --- a/internal/snapshot/memfile.go +++ b/internal/snapshot/memfile.go @@ -4,6 +4,7 @@ package snapshot import ( + "context" "fmt" "io" "os" @@ -172,6 +173,99 @@ func ProcessMemfileWithParent(memfilePath, diffPath, headerPath string, parentHe return header, nil } +// MergeDiffs consolidates multiple generation diff files into a single diff +// file and resets the generation counter to 0. This is a pure file-level +// operation — no Firecracker involvement. +// +// It reads each non-nil block from the appropriate diff file (as mapped by +// the header), writes them all sequentially into a single new diff file, +// and produces a fresh header pointing only at that file. +// +// diffFiles maps build ID (string) → open file path for each generation's diff. +func MergeDiffs(header *Header, diffFiles map[string]string, mergedDiffPath, headerPath string) (*Header, error) { + blockSize := int64(header.Metadata.BlockSize) + mergedBuildID := uuid.New() + + // Open all source diff files. + sources := make(map[string]*os.File, len(diffFiles)) + for id, path := range diffFiles { + f, err := os.Open(path) + if err != nil { + // Close already opened files. + for _, sf := range sources { + sf.Close() + } + return nil, fmt.Errorf("open diff file for build %s: %w", id, err) + } + sources[id] = f + } + defer func() { + for _, f := range sources { + f.Close() + } + }() + + dst, err := os.Create(mergedDiffPath) + if err != nil { + return nil, fmt.Errorf("create merged diff file: %w", err) + } + defer dst.Close() + + totalBlocks := TotalBlocks(int64(header.Metadata.Size), blockSize) + dirty := make([]bool, totalBlocks) + empty := make([]bool, totalBlocks) + buf := make([]byte, blockSize) + + for i := int64(0); i < totalBlocks; i++ { + offset := i * blockSize + mappedOffset, _, buildID, err := header.GetShiftedMapping(context.Background(), offset) + if err != nil { + return nil, fmt.Errorf("lookup block %d: %w", i, err) + } + + if *buildID == uuid.Nil { + empty[i] = true + continue + } + + src, ok := sources[buildID.String()] + if !ok { + return nil, fmt.Errorf("no diff file for build %s (block %d)", buildID, i) + } + + if _, err := src.ReadAt(buf, mappedOffset); err != nil { + return nil, fmt.Errorf("read block %d from build %s: %w", i, buildID, err) + } + + dirty[i] = true + if _, err := dst.Write(buf); err != nil { + return nil, fmt.Errorf("write merged block %d: %w", i, err) + } + } + + // Build fresh header with generation 0. + dirtyMappings := CreateMapping(mergedBuildID, dirty, blockSize) + emptyMappings := CreateMapping(uuid.Nil, empty, blockSize) + merged := MergeMappings(dirtyMappings, emptyMappings) + normalized := NormalizeMappings(merged) + + metadata := NewMetadata(mergedBuildID, uint64(blockSize), header.Metadata.Size) + newHeader, err := NewHeader(metadata, normalized) + if err != nil { + return nil, fmt.Errorf("create merged header: %w", err) + } + + headerData, err := Serialize(metadata, normalized) + if err != nil { + return nil, fmt.Errorf("serialize merged header: %w", err) + } + if err := os.WriteFile(headerPath, headerData, 0644); err != nil { + return nil, fmt.Errorf("write merged header: %w", err) + } + + return newHeader, nil +} + // isZeroBlock checks if a block is entirely zero bytes. func isZeroBlock(block []byte) bool { // Fast path: compare 8 bytes at a time. From 88f919c4ca16af18ba82969396ec67513550dc30 Mon Sep 17 00:00:00 2001 From: pptx704 Date: Mon, 30 Mar 2026 17:12:05 +0600 Subject: [PATCH 23/28] Rename sandbox prefix to cl-, add MMDS metadata, fix proxy port routing - Change sandbox ID prefix from sb- to cl- (capsule) throughout - Fix proxy URL regex character class: base36 uses 0-9a-z, not just hex - Add MMDS V2 config and metadata to VM boot flow so envd can read WRENN_SANDBOX_ID and WRENN_TEMPLATE_ID from inside the guest - Pass TemplateID through VMConfig into both fresh and snapshot boot paths --- internal/api/handler_sandbox_proxy.go | 8 ++++---- internal/id/id.go | 2 +- internal/id/id_test.go | 4 ++-- internal/layout/layout_test.go | 4 ++-- internal/sandbox/manager.go | 2 ++ internal/validate/name_test.go | 2 +- internal/vm/config.go | 6 +++++- internal/vm/fc.go | 25 +++++++++++++++++++++++++ internal/vm/manager.go | 19 +++++++++++++++++++ 9 files changed, 61 insertions(+), 11 deletions(-) diff --git a/internal/api/handler_sandbox_proxy.go b/internal/api/handler_sandbox_proxy.go index 322a559c..299aea9f 100644 --- a/internal/api/handler_sandbox_proxy.go +++ b/internal/api/handler_sandbox_proxy.go @@ -18,9 +18,9 @@ import ( "git.omukk.dev/wrenn/sandbox/internal/lifecycle" ) -// sandboxHostPattern matches hostnames like "49999-sb-abcd1234.localhost" or -// "49999-sb-abcd1234.example.com". Captures: port, sandbox ID. -var sandboxHostPattern = regexp.MustCompile(`^(\d+)-(sb-[0-9a-f-]+)\.`) +// sandboxHostPattern matches hostnames like "49999-cl-abcd1234.localhost" or +// "49999-cl-abcd1234.example.com". Captures: port, sandbox ID. +var sandboxHostPattern = regexp.MustCompile(`^(\d+)-(cl-[0-9a-z]+)\.`) // SandboxProxyWrapper wraps an existing HTTP handler and intercepts requests // whose Host header matches the {port}-{sandbox_id}.{domain} pattern. Matching @@ -48,7 +48,7 @@ func NewSandboxProxyWrapper(inner http.Handler, queries *db.Queries, pool *lifec func (h *SandboxProxyWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request) { host := r.Host - // Strip port from Host header (e.g. "49999-sb-abcd1234.localhost:8000" → "49999-sb-abcd1234.localhost") + // Strip port from Host header (e.g. "49999-cl-abcd1234.localhost:8000" → "49999-cl-abcd1234.localhost") if colonIdx := strings.LastIndex(host, ":"); colonIdx != -1 { host = host[:colonIdx] } diff --git a/internal/id/id.go b/internal/id/id.go index 45cba6c5..f4b6cdb6 100644 --- a/internal/id/id.go +++ b/internal/id/id.go @@ -65,7 +65,7 @@ func NewRefreshToken() string { // --- Formatting (pgtype.UUID → prefixed string for API/RPC output) --- const ( - PrefixSandbox = "sb-" + PrefixSandbox = "cl-" PrefixUser = "usr-" PrefixTeam = "team-" PrefixAPIKey = "key-" diff --git a/internal/id/id_test.go b/internal/id/id_test.go index c16ec7a0..6fb23945 100644 --- a/internal/id/id_test.go +++ b/internal/id/id_test.go @@ -46,8 +46,8 @@ func TestFormatParseRoundTrip(t *testing.T) { id := NewSandboxID() formatted := FormatSandboxID(id) - if formatted[:3] != "sb-" { - t.Fatalf("expected sb- prefix, got %s", formatted) + if formatted[:3] != "cl-" { + t.Fatalf("expected cl- prefix, got %s", formatted) } if len(formatted) != 3+base36IDLen { t.Fatalf("expected %d chars total, got %d: %s", 3+base36IDLen, len(formatted), formatted) diff --git a/internal/layout/layout_test.go b/internal/layout/layout_test.go index bffdae26..f7b9afd3 100644 --- a/internal/layout/layout_test.go +++ b/internal/layout/layout_test.go @@ -96,8 +96,8 @@ func TestTemplateRootfs(t *testing.T) { } func TestPauseSnapshotDir(t *testing.T) { - got := PauseSnapshotDir("/var/lib/wrenn", "sb-abc123") - want := "/var/lib/wrenn/snapshots/sb-abc123" + got := PauseSnapshotDir("/var/lib/wrenn", "cl-abc123") + want := "/var/lib/wrenn/snapshots/cl-abc123" if got != want { t.Errorf("PauseSnapshotDir() = %q, want %q", got, want) } diff --git a/internal/sandbox/manager.go b/internal/sandbox/manager.go index a8293e64..ac2bc221 100644 --- a/internal/sandbox/manager.go +++ b/internal/sandbox/manager.go @@ -169,6 +169,7 @@ func (m *Manager) Create(ctx context.Context, sandboxID string, teamID, template // Boot VM — Firecracker gets the dm device path. vmCfg := vm.VMConfig{ SandboxID: sandboxID, + TemplateID: id.UUIDString(templateID), KernelPath: layout.KernelPath(m.cfg.WrennDir), RootfsPath: dmDev.DevicePath, VCPUs: vcpus, @@ -1033,6 +1034,7 @@ func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID string, team // Restore VM. vmCfg := vm.VMConfig{ SandboxID: sandboxID, + TemplateID: id.UUIDString(templateID), KernelPath: layout.KernelPath(m.cfg.WrennDir), RootfsPath: dmDev.DevicePath, VCPUs: vcpus, diff --git a/internal/validate/name_test.go b/internal/validate/name_test.go index 4b7769e2..f17e210a 100644 --- a/internal/validate/name_test.go +++ b/internal/validate/name_test.go @@ -11,7 +11,7 @@ func TestSafeName(t *testing.T) { {"simple", "minimal", false}, {"with-dash", "template-abc123", false}, {"with-dot", "my-snapshot.v2", false}, - {"sandbox-id", "sb-12345678", false}, + {"sandbox-id", "cl-12345678", false}, {"single-char", "a", false}, {"numbers", "123", false}, {"max-length", "abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz01", false}, diff --git a/internal/vm/config.go b/internal/vm/config.go index b99480e4..0c1f2582 100644 --- a/internal/vm/config.go +++ b/internal/vm/config.go @@ -4,9 +4,13 @@ import "fmt" // VMConfig holds the configuration for creating a Firecracker microVM. type VMConfig struct { - // SandboxID is the unique identifier for this sandbox (e.g., "sb-a1b2c3d4"). + // SandboxID is the unique identifier for this sandbox (e.g., "cl-a1b2c3d4"). SandboxID string + // TemplateID is the template UUID string used to populate MMDS metadata + // so that envd can read WRENN_TEMPLATE_ID from inside the guest. + TemplateID string + // KernelPath is the path to the uncompressed Linux kernel (vmlinux). KernelPath string diff --git a/internal/vm/fc.go b/internal/vm/fc.go index b5af5dbb..3d0f246d 100644 --- a/internal/vm/fc.go +++ b/internal/vm/fc.go @@ -101,6 +101,31 @@ func (c *fcClient) setMachineConfig(ctx context.Context, vcpus, memMB int) error }) } +// setMMDSConfig enables MMDS V2 token-based access on the given network interface. +// Must be called before startVM. +func (c *fcClient) setMMDSConfig(ctx context.Context, ifaceID string) error { + return c.do(ctx, http.MethodPut, "/mmds/config", map[string]any{ + "version": "V2", + "network_interfaces": []string{ifaceID}, + }) +} + +// mmdsMetadata is the metadata payload written to the Firecracker MMDS store. +// envd reads this via PollForMMDSOpts to populate WRENN_SANDBOX_ID and WRENN_TEMPLATE_ID. +type mmdsMetadata struct { + SandboxID string `json:"instanceID"` + TemplateID string `json:"envID"` +} + +// setMMDS writes sandbox metadata to the Firecracker MMDS store. +// Can be called after the VM has started. +func (c *fcClient) setMMDS(ctx context.Context, sandboxID, templateID string) error { + return c.do(ctx, http.MethodPut, "/mmds", mmdsMetadata{ + SandboxID: sandboxID, + TemplateID: templateID, + }) +} + // startVM issues the InstanceStart action. func (c *fcClient) startVM(ctx context.Context) error { return c.do(ctx, http.MethodPut, "/actions", map[string]string{ diff --git a/internal/vm/manager.go b/internal/vm/manager.go index c7e34797..9e9466f5 100644 --- a/internal/vm/manager.go +++ b/internal/vm/manager.go @@ -71,6 +71,13 @@ func (m *Manager) Create(ctx context.Context, cfg VMConfig) (*VM, error) { return nil, fmt.Errorf("start VM: %w", err) } + // Step 5: Push sandbox metadata into MMDS so envd can read + // WRENN_SANDBOX_ID and WRENN_TEMPLATE_ID from inside the guest. + if err := client.setMMDS(ctx, cfg.SandboxID, cfg.TemplateID); err != nil { + _ = proc.stop() + return nil, fmt.Errorf("set MMDS metadata: %w", err) + } + vm := &VM{ Config: cfg, process: proc, @@ -108,6 +115,12 @@ func configureVM(ctx context.Context, client *fcClient, cfg *VMConfig) error { return fmt.Errorf("set machine config: %w", err) } + // MMDS config — enable V2 token access on eth0 so that envd can read + // WRENN_SANDBOX_ID and WRENN_TEMPLATE_ID from inside the guest. + if err := client.setMMDSConfig(ctx, "eth0"); err != nil { + return fmt.Errorf("set MMDS config: %w", err) + } + return nil } @@ -238,6 +251,12 @@ func (m *Manager) CreateFromSnapshot(ctx context.Context, cfg VMConfig, snapPath return nil, fmt.Errorf("resume VM: %w", err) } + // Step 5: Push sandbox metadata into MMDS. + if err := client.setMMDS(ctx, cfg.SandboxID, cfg.TemplateID); err != nil { + _ = proc.stop() + return nil, fmt.Errorf("set MMDS metadata: %w", err) + } + vm := &VM{ Config: cfg, process: proc, From 25ce0729d5fc15c9c256deed58ecdff9143b60b3 Mon Sep 17 00:00:00 2001 From: pptx704 Date: Mon, 30 Mar 2026 21:24:35 +0600 Subject: [PATCH 24/28] =?UTF-8?q?Add=20mTLS=20to=20CP=E2=86=92agent=20chan?= =?UTF-8?q?nel?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Internal ECDSA P-256 CA (WRENN_CA_CERT/WRENN_CA_KEY env vars); when absent the system falls back to plain HTTP so dev mode works without certificates - Host leaf cert (7-day TTL, IP SAN) issued at registration and renewed on every JWT refresh; fingerprint + expiry stored in DB (cert_expires_at column replaces the removed mtls_enabled flag) - CP ephemeral client cert (24-hour TTL) via CPCertStore with atomic hot-swap; background goroutine renews it every 12 hours without restarting the server - Host agent uses tls.Listen + httpServer.Serve so GetCertificate callback is respected (ListenAndServeTLS always reads cert from disk) - Sandbox reverse proxy now uses pool.Transport() so it shares the same TLS config as the Connect RPC clients instead of http.DefaultTransport - Credentials file renamed host-credentials.json with cert_pem/key_pem/ ca_cert_pem fields; duplicate register/refresh response structs collapsed to authResponse --- .env.example | 8 + cmd/control-plane/main.go | 49 +++- cmd/host-agent/main.go | 76 ++++-- .../20260330112050_mtls_cert_expiry.sql | 7 + db/queries/hosts.sql | 23 +- internal/api/handler_sandbox_proxy.go | 4 +- internal/api/handlers_hosts.go | 12 + internal/api/server.go | 5 +- internal/auth/cert.go | 251 ++++++++++++++++++ internal/config/config.go | 8 + internal/db/hosts.sql.go | 87 +++--- internal/db/models.go | 3 +- internal/hostagent/certstore.go | 42 +++ internal/hostagent/registration.go | 162 ++++++----- internal/lifecycle/hostpool.go | 53 +++- internal/service/host.go | 70 ++++- 16 files changed, 716 insertions(+), 144 deletions(-) create mode 100644 db/migrations/20260330112050_mtls_cert_expiry.sql create mode 100644 internal/auth/cert.go create mode 100644 internal/hostagent/certstore.go diff --git a/.env.example b/.env.example index 32b235a7..62e9b4ce 100644 --- a/.env.example +++ b/.env.example @@ -27,6 +27,14 @@ AWS_SECRET_ACCESS_KEY= # Auth JWT_SECRET= +# mTLS — CP→Agent channel +# Generate a self-signed CA with: +# openssl ecparam -genkey -name P-256 -noout -out ca.key +# openssl req -new -x509 -key ca.key -days 3650 -out ca.crt -subj "/CN=wrenn-internal-ca" +# Then set these to the file contents (newlines replaced with \n or use multiline env). +WRENN_CA_CERT=-----BEGIN CERTIFICATE-----\nMIIBjTCCATOgAwIBAgIUJ61AjKri7lTAEIpmCXA+B/Gm0pwwCgYIKoZIzj0EAwIw\nHDEaMBgGA1UEAwwRd3Jlbm4taW50ZXJuYWwtY2EwHhcNMjYwMzMwMTIwNDI5WhcN\nMzYwMzI3MTIwNDI5WjAcMRowGAYDVQQDDBF3cmVubi1pbnRlcm5hbC1jYTBZMBMG\nByqGSM49AgEGCCqGSM49AwEHA0IABDkwv8a1Y7Xx7a5yUDLwDUUBn1fSfUlq6sGr\nVociS2Za+vo1353K61IFMNF9A3wvLXpsEAGZKbaw1iEfRs6LERijUzBRMB0GA1Ud\nDgQWBBQkuWu9flN+C/e4wPFtbWEDVWNjFjAfBgNVHSMEGDAWgBQkuWu9flN+C/e4\nwPFtbWEDVWNjFjAPBgNVHRMBAf8EBTADAQH/MAoGCCqGSM49BAMCA0gAMEUCIBL0\nHmdBQy/76eLKM/X/Qtsrt2yktfxIrWQBbrXOlBd2AiEAzx8n5O0r/ebxwmAxL3y7\nVM7hllXxL6AdxJtU2vsEoA0=\n-----END CERTIFICATE----- +WRENN_CA_KEY=-----BEGIN EC PRIVATE KEY-----\nMHcCAQEEIOjpTSFMhhR9Yi2mWtrzJ/FINEImtmz32GkwZ9eYUbDkoAoGCCqGSM49\nAwEHoUQDQgAEOTC/xrVjtfHtrnJQMvANRQGfV9J9SWrqwatWhyJLZlr6+jXfncrr\nUgUw0X0DfC8temwQAZkptrDWIR9GzosRGA==\n-----END EC PRIVATE KEY----- + # OAuth OAUTH_GITHUB_CLIENT_ID= OAUTH_GITHUB_CLIENT_SECRET= diff --git a/cmd/control-plane/main.go b/cmd/control-plane/main.go index af57d2be..419dc875 100644 --- a/cmd/control-plane/main.go +++ b/cmd/control-plane/main.go @@ -15,6 +15,7 @@ import ( "git.omukk.dev/wrenn/sandbox/internal/api" "git.omukk.dev/wrenn/sandbox/internal/audit" + "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/auth/oauth" "git.omukk.dev/wrenn/sandbox/internal/config" "git.omukk.dev/wrenn/sandbox/internal/db" @@ -68,8 +69,52 @@ func main() { } slog.Info("connected to redis") + // mTLS: parse internal CA and build a TLS-capable host client pool. + // When CA env vars are absent the pool falls back to plain HTTP (dev mode). + var ca *auth.CA + if cfg.CACert != "" && cfg.CAKey != "" { + var err error + ca, err = auth.ParseCA(cfg.CACert, cfg.CAKey) + if err != nil { + slog.Error("failed to parse mTLS CA from environment", "error", err) + os.Exit(1) + } + slog.Info("mTLS enabled: CA loaded") + } else { + slog.Warn("mTLS disabled: WRENN_CA_CERT/WRENN_CA_KEY not set — host agent connections are unencrypted") + } + // Host client pool — manages Connect RPC clients to host agents. - hostPool := lifecycle.NewHostClientPool() + var hostPool *lifecycle.HostClientPool + if ca != nil { + cpCertStore, err := auth.NewCPCertStore(ca) + if err != nil { + slog.Error("failed to issue CP client certificate", "error", err) + os.Exit(1) + } + // Renew the CP client certificate periodically so it never expires + // while the control plane is running (TTL = 24h, renewal = every 12h). + go func() { + ticker := time.NewTicker(auth.CPCertRenewInterval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := cpCertStore.Refresh(); err != nil { + slog.Error("failed to renew CP client certificate", "error", err) + } else { + slog.Info("CP client certificate renewed") + } + } + } + }() + hostPool = lifecycle.NewHostClientPoolTLS(auth.CPClientTLSConfig(ca, cpCertStore)) + slog.Info("host client pool: mTLS enabled") + } else { + hostPool = lifecycle.NewHostClientPool() + } // Scheduler — picks a host for each new sandbox (round-robin for now). hostScheduler := scheduler.NewRoundRobinScheduler(queries) @@ -88,7 +133,7 @@ func main() { } // API server. - srv := api.New(queries, hostPool, hostScheduler, pool, rdb, []byte(cfg.JWTSecret), oauthRegistry, cfg.OAuthRedirectURL) + srv := api.New(queries, hostPool, hostScheduler, pool, rdb, []byte(cfg.JWTSecret), oauthRegistry, cfg.OAuthRedirectURL, ca) // Start template build workers (2 concurrent). stopBuildWorkers := srv.BuildSvc.StartWorkers(ctx, 2) diff --git a/cmd/host-agent/main.go b/cmd/host-agent/main.go index a6571df8..de48a664 100644 --- a/cmd/host-agent/main.go +++ b/cmd/host-agent/main.go @@ -2,8 +2,10 @@ package main import ( "context" + "crypto/tls" "flag" "log/slog" + "net" "net/http" "os" "os/signal" @@ -14,6 +16,7 @@ import ( "github.com/joho/godotenv" + "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/devicemapper" "git.omukk.dev/wrenn/sandbox/internal/hostagent" "git.omukk.dev/wrenn/sandbox/internal/network" @@ -50,7 +53,7 @@ func main() { listenAddr := envOrDefault("WRENN_HOST_LISTEN_ADDR", ":50051") rootDir := envOrDefault("WRENN_DIR", "/var/lib/wrenn") cpURL := os.Getenv("WRENN_CP_URL") - tokenFile := filepath.Join(rootDir, "host.jwt") + credsFile := filepath.Join(rootDir, "host-credentials.json") if cpURL == "" { slog.Error("WRENN_CP_URL environment variable is required") @@ -80,10 +83,10 @@ func main() { mgr.StartTTLReaper(ctx) // Register with the control plane and start heartbeating. - hostToken, err := hostagent.Register(ctx, hostagent.RegistrationConfig{ + creds, err := hostagent.Register(ctx, hostagent.RegistrationConfig{ CPURL: cpURL, RegistrationToken: *registrationToken, - TokenFile: tokenFile, + TokenFile: credsFile, Address: *advertiseAddr, }) if err != nil { @@ -91,17 +94,29 @@ func main() { os.Exit(1) } - hostID, err := hostagent.HostIDFromToken(hostToken) - if err != nil { - slog.Error("failed to extract host ID from token", "error", err) - os.Exit(1) - } - - slog.Info("host registered", "host_id", hostID) + slog.Info("host registered", "host_id", creds.HostID) // httpServer is declared here so the shutdown func can reference it. httpServer := &http.Server{Addr: listenAddr} + // Set up mTLS if the CP issued a certificate during registration. + var certStore hostagent.CertStore + if creds.CertPEM != "" && creds.KeyPEM != "" && creds.CACertPEM != "" { + if err := certStore.ParseAndStore(creds.CertPEM, creds.KeyPEM); err != nil { + slog.Error("failed to load host TLS certificate", "error", err) + os.Exit(1) + } + tlsCfg := auth.AgentTLSConfigFromPEM(creds.CACertPEM, certStore.GetCert) + if tlsCfg == nil { + slog.Error("failed to build agent TLS config: invalid CA certificate PEM") + os.Exit(1) + } + httpServer.TLSConfig = tlsCfg + slog.Info("mTLS enabled on agent server") + } else { + slog.Warn("mTLS disabled: no certificate received from CP — agent serving plain HTTP") + } + // doShutdown is the single shutdown path. sync.Once ensures mgr.Shutdown // and httpServer.Shutdown are each called exactly once regardless of // whether shutdown is triggered by a signal, a heartbeat 404, or the @@ -134,7 +149,7 @@ func main() { // Start heartbeat loop. Handler must be set before this because the // immediate beat can trigger doShutdown → httpServer.Shutdown synchronously. - hostagent.StartHeartbeat(ctx, cpURL, tokenFile, hostID, 30*time.Second, + hostagent.StartHeartbeat(ctx, cpURL, credsFile, creds.HostID, 30*time.Second, // pauseAll: called on 3 consecutive network failures. func() { pauseCtx, pauseCancel := context.WithTimeout(context.Background(), 2*time.Minute) @@ -145,6 +160,17 @@ func main() { func() { doShutdown("host deleted from CP") }, + // onCredsRefreshed: hot-swap the TLS certificate after a JWT refresh. + func(tf *hostagent.TokenFile) { + if tf.CertPEM == "" || tf.KeyPEM == "" { + return + } + if err := certStore.ParseAndStore(tf.CertPEM, tf.KeyPEM); err != nil { + slog.Error("failed to hot-swap TLS cert after credentials refresh", "error", err) + } else { + slog.Info("TLS cert hot-swapped after credentials refresh") + } + }, ) // Graceful shutdown on SIGINT/SIGTERM. @@ -155,10 +181,30 @@ func main() { doShutdown("signal: " + sig.String()) }() - slog.Info("host agent starting", "addr", listenAddr, "host_id", hostID) - if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { - slog.Error("http server error", "error", err) - os.Exit(1) + slog.Info("host agent starting", "addr", listenAddr, "host_id", creds.HostID) + if httpServer.TLSConfig != nil { + // When TLSConfig is pre-populated (cert via GetCertificate callback), + // ListenAndServeTLS does not work because it requires on-disk cert/key paths. + // Instead, create the TLS listener manually and call Serve. + ln, err := tls.Listen("tcp", listenAddr, httpServer.TLSConfig) + if err != nil { + slog.Error("failed to start TLS listener", "error", err) + os.Exit(1) + } + if err := httpServer.Serve(ln); err != nil && err != http.ErrServerClosed { + slog.Error("https server error", "error", err) + os.Exit(1) + } + } else { + ln, err := net.Listen("tcp", listenAddr) + if err != nil { + slog.Error("failed to start listener", "error", err) + os.Exit(1) + } + if err := httpServer.Serve(ln); err != nil && err != http.ErrServerClosed { + slog.Error("http server error", "error", err) + os.Exit(1) + } } slog.Info("host agent stopped") diff --git a/db/migrations/20260330112050_mtls_cert_expiry.sql b/db/migrations/20260330112050_mtls_cert_expiry.sql new file mode 100644 index 00000000..e7245d27 --- /dev/null +++ b/db/migrations/20260330112050_mtls_cert_expiry.sql @@ -0,0 +1,7 @@ +-- +goose Up +ALTER TABLE hosts DROP COLUMN mtls_enabled; +ALTER TABLE hosts ADD COLUMN cert_expires_at TIMESTAMPTZ; + +-- +goose Down +ALTER TABLE hosts DROP COLUMN cert_expires_at; +ALTER TABLE hosts ADD COLUMN mtls_enabled BOOLEAN NOT NULL DEFAULT FALSE; diff --git a/db/queries/hosts.sql b/db/queries/hosts.sql index 27ece000..0a5a150b 100644 --- a/db/queries/hosts.sql +++ b/db/queries/hosts.sql @@ -20,16 +20,25 @@ SELECT * FROM hosts WHERE status = $1 ORDER BY created_at DESC; -- name: RegisterHost :execrows UPDATE hosts -SET arch = $2, - cpu_cores = $3, - memory_mb = $4, - disk_gb = $5, - address = $6, - status = 'online', +SET arch = $2, + cpu_cores = $3, + memory_mb = $4, + disk_gb = $5, + address = $6, + cert_fingerprint = $7, + cert_expires_at = $8, + status = 'online', last_heartbeat_at = NOW(), - updated_at = NOW() + updated_at = NOW() WHERE id = $1 AND status = 'pending'; +-- name: UpdateHostCert :exec +UPDATE hosts +SET cert_fingerprint = $2, + cert_expires_at = $3, + updated_at = NOW() +WHERE id = $1; + -- name: UpdateHostStatus :exec UPDATE hosts SET status = $2, updated_at = NOW() WHERE id = $1; diff --git a/internal/api/handler_sandbox_proxy.go b/internal/api/handler_sandbox_proxy.go index 299aea9f..a7b9f5b1 100644 --- a/internal/api/handler_sandbox_proxy.go +++ b/internal/api/handler_sandbox_proxy.go @@ -42,7 +42,7 @@ func NewSandboxProxyWrapper(inner http.Handler, queries *db.Queries, pool *lifec inner: inner, db: queries, pool: pool, - transport: http.DefaultTransport, + transport: pool.Transport(), } } @@ -110,7 +110,7 @@ func (h *SandboxProxyWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request) return } - agentAddr := lifecycle.EnsureScheme(agentHost.Address) + agentAddr := h.pool.ResolveAddr(agentHost.Address) upstreamPath := fmt.Sprintf("/proxy/%s/%s%s", sandboxIDStr, port, r.URL.Path) target, err := url.Parse(agentAddr) diff --git a/internal/api/handlers_hosts.go b/internal/api/handlers_hosts.go index c910c612..50652a00 100644 --- a/internal/api/handlers_hosts.go +++ b/internal/api/handlers_hosts.go @@ -49,6 +49,9 @@ type refreshTokenResponse struct { Host hostResponse `json:"host"` Token string `json:"token"` RefreshToken string `json:"refresh_token"` + CertPEM string `json:"cert_pem,omitempty"` + KeyPEM string `json:"key_pem,omitempty"` + CACertPEM string `json:"ca_cert_pem,omitempty"` } type deletePreviewResponse struct { @@ -69,6 +72,9 @@ type registerHostResponse struct { Host hostResponse `json:"host"` Token string `json:"token"` RefreshToken string `json:"refresh_token"` + CertPEM string `json:"cert_pem,omitempty"` + KeyPEM string `json:"key_pem,omitempty"` + CACertPEM string `json:"ca_cert_pem,omitempty"` } type addTagRequest struct { @@ -388,6 +394,9 @@ func (h *hostHandler) Register(w http.ResponseWriter, r *http.Request) { Host: hostToResponse(result.Host), Token: result.JWT, RefreshToken: result.RefreshToken, + CertPEM: result.CertPEM, + KeyPEM: result.KeyPEM, + CACertPEM: result.CACertPEM, }) } @@ -501,6 +510,9 @@ func (h *hostHandler) RefreshToken(w http.ResponseWriter, r *http.Request) { Host: hostToResponse(result.Host), Token: result.JWT, RefreshToken: result.RefreshToken, + CertPEM: result.CertPEM, + KeyPEM: result.KeyPEM, + CACertPEM: result.CACertPEM, }) } diff --git a/internal/api/server.go b/internal/api/server.go index d298b298..5d854b90 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -10,6 +10,7 @@ import ( "github.com/redis/go-redis/v9" "git.omukk.dev/wrenn/sandbox/internal/audit" + "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/auth/oauth" "git.omukk.dev/wrenn/sandbox/internal/db" "git.omukk.dev/wrenn/sandbox/internal/lifecycle" @@ -36,6 +37,7 @@ func New( jwtSecret []byte, oauthRegistry *oauth.Registry, oauthRedirectURL string, + ca *auth.CA, ) *Server { r := chi.NewRouter() r.Use(requestLogger()) @@ -44,7 +46,7 @@ func New( sandboxSvc := &service.SandboxService{DB: queries, Pool: pool, Scheduler: sched} apiKeySvc := &service.APIKeyService{DB: queries} templateSvc := &service.TemplateService{DB: queries} - hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret, Pool: pool} + hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret, Pool: pool, CA: ca} teamSvc := &service.TeamService{DB: queries, Pool: pgPool, HostPool: pool} auditSvc := &service.AuditService{DB: queries} statsSvc := &service.StatsService{DB: queries, Pool: pgPool} @@ -182,6 +184,7 @@ func New( r.Post("/builds", buildH.Create) r.Get("/builds", buildH.List) r.Get("/builds/{id}", buildH.Get) + r.Post("/builds/{id}/cancel", buildH.Cancel) }) return &Server{router: r, BuildSvc: buildSvc} diff --git a/internal/auth/cert.go b/internal/auth/cert.go new file mode 100644 index 00000000..1af48672 --- /dev/null +++ b/internal/auth/cert.go @@ -0,0 +1,251 @@ +package auth + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "net" + "sync/atomic" + "time" +) + +// CPCertRenewInterval is how often the control plane should renew its client +// certificate. It is set to half the cert TTL so there is always a wide safety +// margin before expiry. +const CPCertRenewInterval = cpCertTTL / 2 + +const ( + hostCertTTL = 7 * 24 * time.Hour + cpCertTTL = 24 * time.Hour +) + +// CA holds a parsed certificate authority ready to issue leaf certificates. +type CA struct { + Cert *x509.Certificate + Key *ecdsa.PrivateKey + PEM string // PEM-encoded certificate for embedding in register/refresh responses +} + +// ParseCA parses PEM-encoded CA certificate and private key strings. +// The cert and key are expected to be ECDSA P-256. +func ParseCA(certPEM, keyPEM string) (*CA, error) { + certBlock, _ := pem.Decode([]byte(certPEM)) + if certBlock == nil { + return nil, fmt.Errorf("failed to decode CA certificate PEM") + } + cert, err := x509.ParseCertificate(certBlock.Bytes) + if err != nil { + return nil, fmt.Errorf("parse CA certificate: %w", err) + } + + keyBlock, _ := pem.Decode([]byte(keyPEM)) + if keyBlock == nil { + return nil, fmt.Errorf("failed to decode CA key PEM") + } + keyIface, err := x509.ParseECPrivateKey(keyBlock.Bytes) + if err != nil { + return nil, fmt.Errorf("parse CA private key: %w", err) + } + + return &CA{Cert: cert, Key: keyIface, PEM: certPEM}, nil +} + +// HostCert holds all material returned when issuing a leaf cert for a host agent. +type HostCert struct { + CertPEM string + KeyPEM string + Fingerprint string // hex-encoded SHA-256 of DER bytes, stored in hosts.cert_fingerprint + ExpiresAt time.Time // stored in hosts.cert_expires_at + TLSCert tls.Certificate +} + +// IssueHostCert generates an ECDSA P-256 key pair and issues a 7-day server +// certificate for the host agent. hostID becomes the common name; the host's +// IP address (parsed from hostAddr) is added as an IP SAN so Go's TLS +// stack can verify the connection without disabling hostname checking. +func IssueHostCert(ca *CA, hostID, hostAddr string) (HostCert, error) { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return HostCert{}, fmt.Errorf("generate host key: %w", err) + } + + serial, err := randomSerial() + if err != nil { + return HostCert{}, err + } + + now := time.Now() + expires := now.Add(hostCertTTL) + + tmpl := &x509.Certificate{ + SerialNumber: serial, + Subject: pkix.Name{CommonName: hostID}, + NotBefore: now.Add(-time.Minute), // small clock-skew tolerance + NotAfter: expires, + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + } + + // Extract IP from "ip:port" address; fall back to DNS SAN if not parseable. + host, _, err := net.SplitHostPort(hostAddr) + if err != nil { + host = hostAddr + } + if ip := net.ParseIP(host); ip != nil { + tmpl.IPAddresses = []net.IP{ip} + } else { + tmpl.DNSNames = []string{host} + } + + derBytes, err := x509.CreateCertificate(rand.Reader, tmpl, ca.Cert, &key.PublicKey, ca.Key) + if err != nil { + return HostCert{}, fmt.Errorf("create host certificate: %w", err) + } + + certPEM := string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})) + keyDER, err := x509.MarshalECPrivateKey(key) + if err != nil { + return HostCert{}, fmt.Errorf("marshal host key: %w", err) + } + keyPEM := string(pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})) + + tlsCert, err := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM)) + if err != nil { + return HostCert{}, fmt.Errorf("build TLS certificate: %w", err) + } + + fp := fmt.Sprintf("%x", sha256.Sum256(derBytes)) + + return HostCert{ + CertPEM: certPEM, + KeyPEM: keyPEM, + Fingerprint: fp, + ExpiresAt: expires, + TLSCert: tlsCert, + }, nil +} + +// IssueCPClientCert generates a short-lived (24h) ECDSA client certificate for +// the control plane to present during mTLS handshakes with host agents. +// Called once at CP startup; the result is embedded into the shared HTTP client. +func IssueCPClientCert(ca *CA) (tls.Certificate, error) { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return tls.Certificate{}, fmt.Errorf("generate CP client key: %w", err) + } + + serial, err := randomSerial() + if err != nil { + return tls.Certificate{}, err + } + + now := time.Now() + tmpl := &x509.Certificate{ + SerialNumber: serial, + Subject: pkix.Name{CommonName: "wrenn-cp"}, + NotBefore: now.Add(-time.Minute), + NotAfter: now.Add(cpCertTTL), + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + } + + derBytes, err := x509.CreateCertificate(rand.Reader, tmpl, ca.Cert, &key.PublicKey, ca.Key) + if err != nil { + return tls.Certificate{}, fmt.Errorf("create CP client certificate: %w", err) + } + + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + keyDER, err := x509.MarshalECPrivateKey(key) + if err != nil { + return tls.Certificate{}, fmt.Errorf("marshal CP client key: %w", err) + } + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}) + + return tls.X509KeyPair(certPEM, keyPEM) +} + +// AgentTLSConfigFromPEM returns a tls.Config for the host agent using the +// PEM-encoded CA certificate. This is used on the agent side where only the +// CA certificate (not the private key) is available. +func AgentTLSConfigFromPEM(caCertPEM string, getCert func(*tls.ClientHelloInfo) (*tls.Certificate, error)) *tls.Config { + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM([]byte(caCertPEM)) { + return nil + } + return &tls.Config{ + ClientAuth: tls.RequireAndVerifyClientCert, + ClientCAs: pool, + GetCertificate: getCert, + MinVersion: tls.VersionTLS13, + } +} + +// CPCertStore provides lock-free read/write access to the control plane's +// current client TLS certificate. It is used with tls.Config.GetClientCertificate +// to enable hot-swap without restarting the HTTP client. +// +// The zero value is not usable; use NewCPCertStore to create one. +type CPCertStore struct { + ptr atomic.Pointer[tls.Certificate] + ca *CA +} + +// NewCPCertStore issues an initial CP client certificate from ca and returns a +// store that can renew it in place. Returns an error if the initial issuance fails. +func NewCPCertStore(ca *CA) (*CPCertStore, error) { + s := &CPCertStore{ca: ca} + if err := s.Refresh(); err != nil { + return nil, err + } + return s, nil +} + +// Refresh issues a fresh CP client certificate and atomically stores it. +// If issuance fails the existing cert is unchanged. +func (s *CPCertStore) Refresh() error { + cert, err := IssueCPClientCert(s.ca) + if err != nil { + return fmt.Errorf("renew CP client certificate: %w", err) + } + s.ptr.Store(&cert) + return nil +} + +// GetClientCertificate satisfies tls.Config.GetClientCertificate. It is called +// per-handshake and always returns the most recently stored certificate. +func (s *CPCertStore) GetClientCertificate(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) { + cert := s.ptr.Load() + if cert == nil { + return nil, fmt.Errorf("no CP client certificate available") + } + return cert, nil +} + +// CPClientTLSConfig returns a tls.Config for the CP's outbound HTTP client. +// It uses certStore.GetClientCertificate so the certificate can be renewed +// without replacing the config or transport. +func CPClientTLSConfig(ca *CA, certStore *CPCertStore) *tls.Config { + pool := x509.NewCertPool() + pool.AddCert(ca.Cert) + return &tls.Config{ + RootCAs: pool, + GetClientCertificate: certStore.GetClientCertificate, + MinVersion: tls.VersionTLS13, + } +} + +// randomSerial returns a random 128-bit certificate serial number. +func randomSerial() (*big.Int, error) { + serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + if err != nil { + return nil, fmt.Errorf("generate serial number: %w", err) + } + return serial, nil +} diff --git a/internal/config/config.go b/internal/config/config.go index a4564aa5..e4e67403 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -13,6 +13,11 @@ type Config struct { ListenAddr string JWTSecret string + // mTLS — CP→Agent channel. Both must be set to enable mTLS; omitting either + // disables cert issuance and leaves agent connections on plain HTTP (dev mode). + CACert string // WRENN_CA_CERT — PEM-encoded internal CA certificate + CAKey string // WRENN_CA_KEY — PEM-encoded internal CA private key + OAuthGitHubClientID string OAuthGitHubClientSecret string OAuthRedirectURL string @@ -31,6 +36,9 @@ func Load() Config { ListenAddr: envOrDefault("WRENN_CP_LISTEN_ADDR", ":8080"), JWTSecret: os.Getenv("JWT_SECRET"), + CACert: os.Getenv("WRENN_CA_CERT"), + CAKey: os.Getenv("WRENN_CA_KEY"), + OAuthGitHubClientID: os.Getenv("OAUTH_GITHUB_CLIENT_ID"), OAuthGitHubClientSecret: os.Getenv("OAUTH_GITHUB_CLIENT_SECRET"), OAuthRedirectURL: envOrDefault("OAUTH_REDIRECT_URL", "https://app.wrenn.dev"), diff --git a/internal/db/hosts.sql.go b/internal/db/hosts.sql.go index 8bfd8d32..2e3962b5 100644 --- a/internal/db/hosts.sql.go +++ b/internal/db/hosts.sql.go @@ -35,7 +35,7 @@ func (q *Queries) DeleteHost(ctx context.Context, id pgtype.UUID) error { } const getHost = `-- name: GetHost :one -SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE id = $1 +SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE id = $1 ` func (q *Queries) GetHost(ctx context.Context, id pgtype.UUID) (Host, error) { @@ -59,13 +59,13 @@ func (q *Queries) GetHost(ctx context.Context, id pgtype.UUID) (Host, error) { &i.CreatedAt, &i.UpdatedAt, &i.CertFingerprint, - &i.MtlsEnabled, + &i.CertExpiresAt, ) return i, err } const getHostByTeam = `-- name: GetHostByTeam :one -SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE id = $1 AND team_id = $2 +SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE id = $1 AND team_id = $2 ` type GetHostByTeamParams struct { @@ -94,7 +94,7 @@ func (q *Queries) GetHostByTeam(ctx context.Context, arg GetHostByTeamParams) (H &i.CreatedAt, &i.UpdatedAt, &i.CertFingerprint, - &i.MtlsEnabled, + &i.CertExpiresAt, ) return i, err } @@ -157,7 +157,7 @@ func (q *Queries) GetHostTokensByHost(ctx context.Context, hostID pgtype.UUID) ( const insertHost = `-- name: InsertHost :one INSERT INTO hosts (id, type, team_id, provider, availability_zone, created_by) VALUES ($1, $2, $3, $4, $5, $6) -RETURNING id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled +RETURNING id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at ` type InsertHostParams struct { @@ -197,7 +197,7 @@ func (q *Queries) InsertHost(ctx context.Context, arg InsertHostParams) (Host, e &i.CreatedAt, &i.UpdatedAt, &i.CertFingerprint, - &i.MtlsEnabled, + &i.CertExpiresAt, ) return i, err } @@ -235,7 +235,7 @@ func (q *Queries) InsertHostToken(ctx context.Context, arg InsertHostTokenParams } const listActiveHosts = `-- name: ListActiveHosts :many -SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE status NOT IN ('pending', 'offline') ORDER BY created_at +SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE status NOT IN ('pending', 'offline') ORDER BY created_at ` // Returns all hosts that have completed registration (not pending/offline). @@ -266,7 +266,7 @@ func (q *Queries) ListActiveHosts(ctx context.Context) ([]Host, error) { &i.CreatedAt, &i.UpdatedAt, &i.CertFingerprint, - &i.MtlsEnabled, + &i.CertExpiresAt, ); err != nil { return nil, err } @@ -279,7 +279,7 @@ func (q *Queries) ListActiveHosts(ctx context.Context) ([]Host, error) { } const listHosts = `-- name: ListHosts :many -SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts ORDER BY created_at DESC +SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts ORDER BY created_at DESC ` func (q *Queries) ListHosts(ctx context.Context) ([]Host, error) { @@ -309,7 +309,7 @@ func (q *Queries) ListHosts(ctx context.Context) ([]Host, error) { &i.CreatedAt, &i.UpdatedAt, &i.CertFingerprint, - &i.MtlsEnabled, + &i.CertExpiresAt, ); err != nil { return nil, err } @@ -322,7 +322,7 @@ func (q *Queries) ListHosts(ctx context.Context) ([]Host, error) { } const listHostsByStatus = `-- name: ListHostsByStatus :many -SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE status = $1 ORDER BY created_at DESC +SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE status = $1 ORDER BY created_at DESC ` func (q *Queries) ListHostsByStatus(ctx context.Context, status string) ([]Host, error) { @@ -352,7 +352,7 @@ func (q *Queries) ListHostsByStatus(ctx context.Context, status string) ([]Host, &i.CreatedAt, &i.UpdatedAt, &i.CertFingerprint, - &i.MtlsEnabled, + &i.CertExpiresAt, ); err != nil { return nil, err } @@ -365,7 +365,7 @@ func (q *Queries) ListHostsByStatus(ctx context.Context, status string) ([]Host, } const listHostsByTag = `-- name: ListHostsByTag :many -SELECT h.id, h.type, h.team_id, h.provider, h.availability_zone, h.arch, h.cpu_cores, h.memory_mb, h.disk_gb, h.address, h.status, h.last_heartbeat_at, h.metadata, h.created_by, h.created_at, h.updated_at, h.cert_fingerprint, h.mtls_enabled FROM hosts h +SELECT h.id, h.type, h.team_id, h.provider, h.availability_zone, h.arch, h.cpu_cores, h.memory_mb, h.disk_gb, h.address, h.status, h.last_heartbeat_at, h.metadata, h.created_by, h.created_at, h.updated_at, h.cert_fingerprint, h.cert_expires_at FROM hosts h JOIN host_tags ht ON ht.host_id = h.id WHERE ht.tag = $1 ORDER BY h.created_at DESC @@ -398,7 +398,7 @@ func (q *Queries) ListHostsByTag(ctx context.Context, tag string) ([]Host, error &i.CreatedAt, &i.UpdatedAt, &i.CertFingerprint, - &i.MtlsEnabled, + &i.CertExpiresAt, ); err != nil { return nil, err } @@ -411,7 +411,7 @@ func (q *Queries) ListHostsByTag(ctx context.Context, tag string) ([]Host, error } const listHostsByTeam = `-- name: ListHostsByTeam :many -SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE team_id = $1 AND type = 'byoc' ORDER BY created_at DESC +SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE team_id = $1 AND type = 'byoc' ORDER BY created_at DESC ` func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.UUID) ([]Host, error) { @@ -441,7 +441,7 @@ func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.UUID) ([]Ho &i.CreatedAt, &i.UpdatedAt, &i.CertFingerprint, - &i.MtlsEnabled, + &i.CertExpiresAt, ); err != nil { return nil, err } @@ -454,7 +454,7 @@ func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.UUID) ([]Ho } const listHostsByType = `-- name: ListHostsByType :many -SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE type = $1 ORDER BY created_at DESC +SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE type = $1 ORDER BY created_at DESC ` func (q *Queries) ListHostsByType(ctx context.Context, type_ string) ([]Host, error) { @@ -484,7 +484,7 @@ func (q *Queries) ListHostsByType(ctx context.Context, type_ string) ([]Host, er &i.CreatedAt, &i.UpdatedAt, &i.CertFingerprint, - &i.MtlsEnabled, + &i.CertExpiresAt, ); err != nil { return nil, err } @@ -516,24 +516,28 @@ func (q *Queries) MarkHostUnreachable(ctx context.Context, id pgtype.UUID) error const registerHost = `-- name: RegisterHost :execrows UPDATE hosts -SET arch = $2, - cpu_cores = $3, - memory_mb = $4, - disk_gb = $5, - address = $6, - status = 'online', +SET arch = $2, + cpu_cores = $3, + memory_mb = $4, + disk_gb = $5, + address = $6, + cert_fingerprint = $7, + cert_expires_at = $8, + status = 'online', last_heartbeat_at = NOW(), - updated_at = NOW() + updated_at = NOW() WHERE id = $1 AND status = 'pending' ` type RegisterHostParams struct { - ID pgtype.UUID `json:"id"` - Arch string `json:"arch"` - CpuCores int32 `json:"cpu_cores"` - MemoryMb int32 `json:"memory_mb"` - DiskGb int32 `json:"disk_gb"` - Address string `json:"address"` + ID pgtype.UUID `json:"id"` + Arch string `json:"arch"` + CpuCores int32 `json:"cpu_cores"` + MemoryMb int32 `json:"memory_mb"` + DiskGb int32 `json:"disk_gb"` + Address string `json:"address"` + CertFingerprint string `json:"cert_fingerprint"` + CertExpiresAt pgtype.Timestamptz `json:"cert_expires_at"` } func (q *Queries) RegisterHost(ctx context.Context, arg RegisterHostParams) (int64, error) { @@ -544,6 +548,8 @@ func (q *Queries) RegisterHost(ctx context.Context, arg RegisterHostParams) (int arg.MemoryMb, arg.DiskGb, arg.Address, + arg.CertFingerprint, + arg.CertExpiresAt, ) if err != nil { return 0, err @@ -565,6 +571,25 @@ func (q *Queries) RemoveHostTag(ctx context.Context, arg RemoveHostTagParams) er return err } +const updateHostCert = `-- name: UpdateHostCert :exec +UPDATE hosts +SET cert_fingerprint = $2, + cert_expires_at = $3, + updated_at = NOW() +WHERE id = $1 +` + +type UpdateHostCertParams struct { + ID pgtype.UUID `json:"id"` + CertFingerprint string `json:"cert_fingerprint"` + CertExpiresAt pgtype.Timestamptz `json:"cert_expires_at"` +} + +func (q *Queries) UpdateHostCert(ctx context.Context, arg UpdateHostCertParams) error { + _, err := q.db.Exec(ctx, updateHostCert, arg.ID, arg.CertFingerprint, arg.CertExpiresAt) + return err +} + const updateHostHeartbeat = `-- name: UpdateHostHeartbeat :exec UPDATE hosts SET last_heartbeat_at = NOW(), updated_at = NOW() WHERE id = $1 ` diff --git a/internal/db/models.go b/internal/db/models.go index d5bfc0f2..1e9a5d00 100644 --- a/internal/db/models.go +++ b/internal/db/models.go @@ -48,7 +48,7 @@ type Host struct { CreatedAt pgtype.Timestamptz `json:"created_at"` UpdatedAt pgtype.Timestamptz `json:"updated_at"` CertFingerprint string `json:"cert_fingerprint"` - MtlsEnabled bool `json:"mtls_enabled"` + CertExpiresAt pgtype.Timestamptz `json:"cert_expires_at"` } type HostRefreshToken struct { @@ -171,6 +171,7 @@ type TemplateBuild struct { CompletedAt pgtype.Timestamptz `json:"completed_at"` TemplateID pgtype.UUID `json:"template_id"` TeamID pgtype.UUID `json:"team_id"` + SkipPrePost bool `json:"skip_pre_post"` } type User struct { diff --git a/internal/hostagent/certstore.go b/internal/hostagent/certstore.go new file mode 100644 index 00000000..4260ba2f --- /dev/null +++ b/internal/hostagent/certstore.go @@ -0,0 +1,42 @@ +package hostagent + +import ( + "crypto/tls" + "fmt" + "sync/atomic" +) + +// CertStore provides lock-free read/write access to the agent's current TLS +// certificate. It is used with tls.Config.GetCertificate to enable hot-swap +// of the agent's cert on JWT refresh without restarting the server. +// +// The zero value is usable; GetCert returns an error until a cert is stored. +type CertStore struct { + ptr atomic.Pointer[tls.Certificate] +} + +// Store atomically replaces the current certificate. +func (s *CertStore) Store(cert *tls.Certificate) { + s.ptr.Store(cert) +} + +// ParseAndStore parses certPEM+keyPEM and atomically replaces the stored cert. +// If parsing fails the existing cert is unchanged. +func (s *CertStore) ParseAndStore(certPEM, keyPEM string) error { + cert, err := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM)) + if err != nil { + return fmt.Errorf("parse TLS key pair: %w", err) + } + s.ptr.Store(&cert) + return nil +} + +// GetCert satisfies tls.Config.GetCertificate. Returns an error if no cert has +// been stored yet. +func (s *CertStore) GetCert(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { + cert := s.ptr.Load() + if cert == nil { + return nil, fmt.Errorf("no TLS certificate available") + } + return cert, nil +} diff --git a/internal/hostagent/registration.go b/internal/hostagent/registration.go index 5948e0ce..07909ee1 100644 --- a/internal/hostagent/registration.go +++ b/internal/hostagent/registration.go @@ -17,18 +17,24 @@ import ( "golang.org/x/sys/unix" ) -// tokenFile is the JSON format persisted to WRENN_DIR/host.jwt. -type tokenFile struct { +// TokenFile is the JSON format persisted to WRENN_DIR/host-credentials.json. +// It holds all credentials the agent needs: the host JWT, refresh token, and +// (when mTLS is enabled) the TLS certificate material for the agent's server. +type TokenFile struct { HostID string `json:"host_id"` JWT string `json:"jwt"` RefreshToken string `json:"refresh_token"` + // mTLS fields — empty when the CP has no CA configured. + CertPEM string `json:"cert_pem,omitempty"` + KeyPEM string `json:"key_pem,omitempty"` + CACertPEM string `json:"ca_cert_pem,omitempty"` } // RegistrationConfig holds the configuration for host registration. type RegistrationConfig struct { CPURL string // Control plane base URL (e.g., http://localhost:8000) RegistrationToken string // One-time registration token from the control plane - TokenFile string // Path to persist the host JWT after registration + TokenFile string // Path to persist the credentials after registration Address string // Externally-reachable address (ip:port) for this host } @@ -41,22 +47,20 @@ type registerRequest struct { Address string `json:"address"` } -type registerResponse struct { +// authResponse is the shared JSON shape for both register and refresh responses. +type authResponse struct { Host json.RawMessage `json:"host"` Token string `json:"token"` RefreshToken string `json:"refresh_token"` + CertPEM string `json:"cert_pem,omitempty"` + KeyPEM string `json:"key_pem,omitempty"` + CACertPEM string `json:"ca_cert_pem,omitempty"` } type refreshRequest struct { RefreshToken string `json:"refresh_token"` } -type refreshResponse struct { - Host json.RawMessage `json:"host"` - Token string `json:"token"` - RefreshToken string `json:"refresh_token"` -} - type errorResponse struct { Error struct { Code string `json:"code"` @@ -64,8 +68,8 @@ type errorResponse struct { } `json:"error"` } -// loadTokenFile reads and parses the persisted token file. -func loadTokenFile(path string) (*tokenFile, error) { +// LoadTokenFile reads and parses the persisted credentials file. +func LoadTokenFile(path string) (*TokenFile, error) { data, err := os.ReadFile(path) if err != nil { return nil, err @@ -75,36 +79,36 @@ func loadTokenFile(path string) (*tokenFile, error) { if !strings.HasPrefix(trimmed, "{") { // Old format: just the JWT, no refresh token. hostID, _ := hostIDFromJWT(trimmed) - return &tokenFile{HostID: hostID, JWT: trimmed}, nil + return &TokenFile{HostID: hostID, JWT: trimmed}, nil } - var tf tokenFile + var tf TokenFile if err := json.Unmarshal(data, &tf); err != nil { - return nil, fmt.Errorf("parse token file: %w", err) + return nil, fmt.Errorf("parse credentials file: %w", err) } return &tf, nil } -// saveTokenFile writes the token file as JSON with 0600 permissions. -func saveTokenFile(path string, tf tokenFile) error { +// saveTokenFile writes the credentials file as JSON with 0600 permissions. +func saveTokenFile(path string, tf TokenFile) error { data, err := json.MarshalIndent(tf, "", " ") if err != nil { - return fmt.Errorf("marshal token file: %w", err) + return fmt.Errorf("marshal credentials file: %w", err) } return os.WriteFile(path, data, 0600) } // Register calls the control plane to register this host agent and persists -// the returned JWT and refresh token to disk. Returns the host JWT token string. -func Register(ctx context.Context, cfg RegistrationConfig) (string, error) { - // If no explicit registration token was given, reuse the saved JWT. +// the returned credentials to disk. Returns the full TokenFile on success. +func Register(ctx context.Context, cfg RegistrationConfig) (*TokenFile, error) { + // If no explicit registration token was given, reuse the saved credentials. // A --register flag always overrides the local file so operators can - // force re-registration without manually deleting host.jwt. + // force re-registration without manually deleting the credentials file. if cfg.RegistrationToken == "" { - if tf, err := loadTokenFile(cfg.TokenFile); err == nil && tf.JWT != "" { - slog.Info("loaded existing host token", "file", cfg.TokenFile, "host_id", tf.HostID) - return tf.JWT, nil + if tf, err := LoadTokenFile(cfg.TokenFile); err == nil && tf.JWT != "" { + slog.Info("loaded existing host credentials", "file", cfg.TokenFile, "host_id", tf.HostID) + return tf, nil } - return "", fmt.Errorf("no saved host token and no registration token provided (use --register flag)") + return nil, fmt.Errorf("no saved host credentials and no registration token provided (use --register flag)") } arch := runtime.GOARCH @@ -123,87 +127,90 @@ func Register(ctx context.Context, cfg RegistrationConfig) (string, error) { body, err := json.Marshal(reqBody) if err != nil { - return "", fmt.Errorf("marshal registration request: %w", err) + return nil, fmt.Errorf("marshal registration request: %w", err) } url := strings.TrimRight(cfg.CPURL, "/") + "/v1/hosts/register" req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) if err != nil { - return "", fmt.Errorf("create registration request: %w", err) + return nil, fmt.Errorf("create registration request: %w", err) } req.Header.Set("Content-Type", "application/json") client := &http.Client{Timeout: 30 * time.Second} resp, err := client.Do(req) if err != nil { - return "", fmt.Errorf("registration request failed: %w", err) + return nil, fmt.Errorf("registration request failed: %w", err) } defer resp.Body.Close() respBody, err := io.ReadAll(resp.Body) if err != nil { - return "", fmt.Errorf("read registration response: %w", err) + return nil, fmt.Errorf("read registration response: %w", err) } if resp.StatusCode != http.StatusCreated { var errResp errorResponse if err := json.Unmarshal(respBody, &errResp); err == nil { - return "", fmt.Errorf("registration failed (%d): %s", resp.StatusCode, errResp.Error.Message) + return nil, fmt.Errorf("registration failed (%d): %s", resp.StatusCode, errResp.Error.Message) } - return "", fmt.Errorf("registration failed (%d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("registration failed (%d): %s", resp.StatusCode, string(respBody)) } - var regResp registerResponse + var regResp authResponse if err := json.Unmarshal(respBody, ®Resp); err != nil { - return "", fmt.Errorf("parse registration response: %w", err) + return nil, fmt.Errorf("parse registration response: %w", err) } if regResp.Token == "" { - return "", fmt.Errorf("registration response missing token") + return nil, fmt.Errorf("registration response missing token") } hostID, err := hostIDFromJWT(regResp.Token) if err != nil { - return "", fmt.Errorf("extract host ID from JWT: %w", err) + return nil, fmt.Errorf("extract host ID from JWT: %w", err) } - // Persist JWT + refresh token. - tf := tokenFile{ + tf := TokenFile{ HostID: hostID, JWT: regResp.Token, RefreshToken: regResp.RefreshToken, + CertPEM: regResp.CertPEM, + KeyPEM: regResp.KeyPEM, + CACertPEM: regResp.CACertPEM, } if err := saveTokenFile(cfg.TokenFile, tf); err != nil { - return "", fmt.Errorf("save host token: %w", err) + return nil, fmt.Errorf("save host credentials: %w", err) } - slog.Info("host registered and token saved", "file", cfg.TokenFile, "host_id", hostID) + slog.Info("host registered and credentials saved", "file", cfg.TokenFile, "host_id", hostID) - return regResp.Token, nil + return &tf, nil } -// RefreshJWT exchanges the refresh token for a new JWT + rotated refresh token. -// It reads and updates the token file in place. -func RefreshJWT(ctx context.Context, cpURL, tokenFilePath string) (string, error) { - tf, err := loadTokenFile(tokenFilePath) +// RefreshCredentials exchanges the refresh token for a new JWT, rotated refresh +// token, and (when mTLS is enabled) a new TLS certificate. The credentials file +// is updated in place. Returns the updated TokenFile. +func RefreshCredentials(ctx context.Context, cpURL, credentialsFilePath string) (*TokenFile, error) { + tf, err := LoadTokenFile(credentialsFilePath) if err != nil { - return "", fmt.Errorf("load token file: %w", err) + return nil, fmt.Errorf("load credentials file: %w", err) } if tf.RefreshToken == "" { - return "", fmt.Errorf("no refresh token available; host must re-register") + return nil, fmt.Errorf("no refresh token available; host must re-register") } body, _ := json.Marshal(refreshRequest{RefreshToken: tf.RefreshToken}) url := strings.TrimRight(cpURL, "/") + "/v1/hosts/auth/refresh" req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) if err != nil { - return "", fmt.Errorf("create refresh request: %w", err) + return nil, fmt.Errorf("create refresh request: %w", err) } req.Header.Set("Content-Type", "application/json") client := &http.Client{Timeout: 15 * time.Second} resp, err := client.Do(req) if err != nil { - return "", fmt.Errorf("refresh request failed: %w", err) + return nil, fmt.Errorf("refresh request failed: %w", err) } defer resp.Body.Close() @@ -212,39 +219,47 @@ func RefreshJWT(ctx context.Context, cpURL, tokenFilePath string) (string, error if resp.StatusCode != http.StatusOK { var errResp errorResponse if json.Unmarshal(respBody, &errResp) == nil { - return "", fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, errResp.Error.Message) + return nil, fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, errResp.Error.Message) } - return "", fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, string(respBody)) } - var refResp refreshResponse + var refResp authResponse if err := json.Unmarshal(respBody, &refResp); err != nil { - return "", fmt.Errorf("parse refresh response: %w", err) + return nil, fmt.Errorf("parse refresh response: %w", err) } tf.JWT = refResp.Token tf.RefreshToken = refResp.RefreshToken - if err := saveTokenFile(tokenFilePath, *tf); err != nil { - return "", fmt.Errorf("save refreshed token: %w", err) + if refResp.CertPEM != "" { + tf.CertPEM = refResp.CertPEM + tf.KeyPEM = refResp.KeyPEM + tf.CACertPEM = refResp.CACertPEM + } + if err := saveTokenFile(credentialsFilePath, *tf); err != nil { + return nil, fmt.Errorf("save refreshed credentials: %w", err) } - slog.Info("host JWT refreshed", "host_id", tf.HostID) - return refResp.Token, nil + slog.Info("host credentials refreshed", "host_id", tf.HostID) + return tf, nil } // StartHeartbeat launches a background goroutine that sends periodic heartbeats // to the control plane. It runs until the context is cancelled. // -// On 401/403: the heartbeat loop attempts to refresh the JWT. If the refresh +// On 401/403: the heartbeat loop attempts to refresh credentials. If the refresh // also fails (expired refresh token), it calls pauseAll and stops. // // On repeated network failures (3 consecutive), it calls pauseAll but keeps // retrying — the connection may recover and the host should resume heartbeating. // // onDeleted is called when CP returns 404, meaning this host record was deleted. -// The token file is removed before calling onDeleted so subsequent starts prompt -// for a new registration token. -func StartHeartbeat(ctx context.Context, cpURL, tokenFilePath, hostID string, interval time.Duration, pauseAll func(), onDeleted func()) { +// The credentials file is removed before calling onDeleted so subsequent starts +// prompt for a new registration token. +// +// onCredsRefreshed is called after a successful credential refresh (JWT + cert). +// It may be nil. The caller uses it to hot-swap the agent's TLS certificate. +func StartHeartbeat(ctx context.Context, cpURL, credentialsFilePath, hostID string, interval time.Duration, pauseAll func(), onDeleted func(), onCredsRefreshed func(*TokenFile)) { client := &http.Client{Timeout: 10 * time.Second} go func() { @@ -255,8 +270,8 @@ func StartHeartbeat(ctx context.Context, cpURL, tokenFilePath, hostID string, in pausedDueToFailure := false currentJWT := "" - // Load the current JWT from disk. - if tf, err := loadTokenFile(tokenFilePath); err == nil { + // Load the current JWT from the credentials file. + if tf, err := LoadTokenFile(credentialsFilePath); err == nil { currentJWT = tf.JWT } @@ -294,10 +309,10 @@ func StartHeartbeat(ctx context.Context, cpURL, tokenFilePath, hostID string, in pausedDueToFailure = false case http.StatusUnauthorized, http.StatusForbidden: - slog.Warn("heartbeat: JWT rejected — attempting token refresh") - newJWT, refreshErr := RefreshJWT(ctx, cpURL, tokenFilePath) + slog.Warn("heartbeat: JWT rejected — attempting credentials refresh") + newCreds, refreshErr := RefreshCredentials(ctx, cpURL, credentialsFilePath) if refreshErr != nil { - slog.Error("heartbeat: JWT refresh failed — pausing all sandboxes; manual re-registration required", + slog.Error("heartbeat: credentials refresh failed — pausing all sandboxes; manual re-registration required", "error", refreshErr) if pauseAll != nil && !pausedDueToFailure { pauseAll() @@ -306,13 +321,16 @@ func StartHeartbeat(ctx context.Context, cpURL, tokenFilePath, hostID string, in // Stop the heartbeat loop — operator must re-register. return true } - currentJWT = newJWT - slog.Info("heartbeat: JWT refreshed successfully") + currentJWT = newCreds.JWT + slog.Info("heartbeat: credentials refreshed successfully") + if onCredsRefreshed != nil { + onCredsRefreshed(newCreds) + } case http.StatusNotFound: - slog.Error("heartbeat: host no longer exists in CP — host was deleted; removing token file and exiting") - if err := os.Remove(tokenFilePath); err != nil && !os.IsNotExist(err) { - slog.Warn("heartbeat: failed to remove token file", "error", err) + slog.Error("heartbeat: host no longer exists in CP — host was deleted; removing credentials file and exiting") + if err := os.Remove(credentialsFilePath); err != nil && !os.IsNotExist(err) { + slog.Warn("heartbeat: failed to remove credentials file", "error", err) } if onDeleted != nil { onDeleted() @@ -351,7 +369,7 @@ func HostIDFromToken(token string) (string, error) { } // hostIDFromJWT is the internal implementation used by both HostIDFromToken and -// the token file loader. +// the credentials file loader. func hostIDFromJWT(token string) (string, error) { parts := strings.Split(token, ".") if len(parts) != 3 { diff --git a/internal/lifecycle/hostpool.go b/internal/lifecycle/hostpool.go index f1341653..f5784898 100644 --- a/internal/lifecycle/hostpool.go +++ b/internal/lifecycle/hostpool.go @@ -1,6 +1,7 @@ package lifecycle import ( + "crypto/tls" "fmt" "net/http" "strings" @@ -19,14 +20,33 @@ type HostClientPool struct { mu sync.RWMutex clients map[string]hostagentv1connect.HostAgentServiceClient httpClient *http.Client + scheme string // "http://" or "https://" } -// NewHostClientPool creates a new pool. The underlying HTTP client uses a -// 10-minute timeout to support long-running streaming operations. +// NewHostClientPool creates a pool that connects to agents over plain HTTP. +// Use NewHostClientPoolTLS when mTLS is required. func NewHostClientPool() *HostClientPool { return &HostClientPool{ clients: make(map[string]hostagentv1connect.HostAgentServiceClient), httpClient: &http.Client{Timeout: 10 * time.Minute}, + scheme: "http://", + } +} + +// NewHostClientPoolTLS creates a pool that connects to agents over mTLS. +// tlsCfg should already carry the CP client cert and CA trust anchor +// (use auth.CPClientTLSConfig to construct it). +func NewHostClientPoolTLS(tlsCfg *tls.Config) *HostClientPool { + transport := &http.Transport{ + TLSClientConfig: tlsCfg, + } + return &HostClientPool{ + clients: make(map[string]hostagentv1connect.HostAgentServiceClient), + httpClient: &http.Client{ + Timeout: 10 * time.Minute, + Transport: transport, + }, + scheme: "https://", } } @@ -46,7 +66,7 @@ func (p *HostClientPool) Get(hostID, address string) hostagentv1connect.HostAgen if c, ok = p.clients[hostID]; ok { return c } - c = hostagentv1connect.NewHostAgentServiceClient(p.httpClient, EnsureScheme(address)) + c = hostagentv1connect.NewHostAgentServiceClient(p.httpClient, p.ensureScheme(address)) p.clients[hostID] = c return c } @@ -69,7 +89,34 @@ func (p *HostClientPool) Evict(hostID string) { p.mu.Unlock() } +// ensureScheme prepends the pool's configured scheme if the address has none. +func (p *HostClientPool) ensureScheme(addr string) string { + if strings.HasPrefix(addr, "http://") || strings.HasPrefix(addr, "https://") { + return addr + } + return p.scheme + addr +} + +// Transport returns the http.RoundTripper used by this pool. Use this when you +// need to make raw HTTP requests to agent addresses with the same TLS settings +// as the pool's Connect RPC clients (e.g., the sandbox reverse proxy). +func (p *HostClientPool) Transport() http.RoundTripper { + if p.httpClient.Transport != nil { + return p.httpClient.Transport + } + return http.DefaultTransport +} + +// ResolveAddr prepends the pool's configured scheme to addr if it has none. +// Use this when constructing URLs that must use the same transport as the pool +// (e.g., the sandbox proxy handler). Calling Get/GetForHost internally does +// the same thing, but ResolveAddr exposes it for callers that only need the URL. +func (p *HostClientPool) ResolveAddr(addr string) string { + return p.ensureScheme(addr) +} + // EnsureScheme adds "http://" if the address has no scheme. +// Deprecated: use pool.ResolveAddr which respects the pool's TLS setting. func EnsureScheme(addr string) string { if strings.HasPrefix(addr, "http://") || strings.HasPrefix(addr, "https://") { return addr diff --git a/internal/service/host.go b/internal/service/host.go index 195b9ff8..74018ebe 100644 --- a/internal/service/host.go +++ b/internal/service/host.go @@ -27,6 +27,7 @@ type HostService struct { Redis *redis.Client JWT []byte Pool *lifecycle.HostClientPool + CA *auth.CA // nil disables mTLS cert issuance (dev/test environments) } // HostCreateParams holds the parameters for creating a host. @@ -55,18 +56,28 @@ type HostRegisterParams struct { Address string } -// HostRegisterResult holds the registered host, its short-lived JWT, and a long-lived refresh token. +// HostRegisterResult holds the registered host, its short-lived JWT, a long-lived +// refresh token, and optionally the host's mTLS certificate material. type HostRegisterResult struct { Host db.Host JWT string RefreshToken string + // mTLS cert material — empty when CA is not configured. + CertPEM string + KeyPEM string + CACertPEM string } -// HostRefreshResult holds a new JWT and rotated refresh token after a successful refresh. +// HostRefreshResult holds a new JWT and rotated refresh token after a successful +// refresh, plus refreshed mTLS certificate material when CA is configured. type HostRefreshResult struct { Host db.Host JWT string RefreshToken string + // mTLS cert material — empty when CA is not configured. + CertPEM string + KeyPEM string + CACertPEM string } // HostDeletePreview describes what will be affected by deleting a host. @@ -268,14 +279,25 @@ func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostR return HostRegisterResult{}, fmt.Errorf("sign host token: %w", err) } + // Issue mTLS certificate if CA is configured. + var hc auth.HostCert + if s.CA != nil { + hc, err = auth.IssueHostCert(s.CA, id.FormatHostID(hostID), p.Address) + if err != nil { + return HostRegisterResult{}, fmt.Errorf("issue host cert: %w", err) + } + } + // Atomically update only if still pending (defense-in-depth against races). rowsAffected, err := s.DB.RegisterHost(ctx, db.RegisterHostParams{ - ID: hostID, - Arch: p.Arch, - CpuCores: p.CPUCores, - MemoryMb: p.MemoryMB, - DiskGb: p.DiskGB, - Address: p.Address, + ID: hostID, + Arch: p.Arch, + CpuCores: p.CPUCores, + MemoryMb: p.MemoryMB, + DiskGb: p.DiskGB, + Address: p.Address, + CertFingerprint: hc.Fingerprint, + CertExpiresAt: pgtype.Timestamptz{Time: hc.ExpiresAt, Valid: s.CA != nil}, }) if err != nil { return HostRegisterResult{}, fmt.Errorf("register host: %w", err) @@ -301,7 +323,13 @@ func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostR return HostRegisterResult{}, fmt.Errorf("fetch updated host: %w", err) } - return HostRegisterResult{Host: host, JWT: hostJWT, RefreshToken: refreshToken}, nil + result := HostRegisterResult{Host: host, JWT: hostJWT, RefreshToken: refreshToken} + if s.CA != nil { + result.CertPEM = hc.CertPEM + result.KeyPEM = hc.KeyPEM + result.CACertPEM = s.CA.PEM + } + return result, nil } // Refresh validates a refresh token, rotates it (revokes old, issues new), @@ -328,6 +356,22 @@ func (s *HostService) Refresh(ctx context.Context, refreshToken string) (HostRef return HostRefreshResult{}, fmt.Errorf("sign host JWT: %w", err) } + // Renew mTLS certificate if CA is configured. + var hc auth.HostCert + if s.CA != nil { + hc, err = auth.IssueHostCert(s.CA, id.FormatHostID(host.ID), host.Address) + if err != nil { + return HostRefreshResult{}, fmt.Errorf("renew host cert: %w", err) + } + if err := s.DB.UpdateHostCert(ctx, db.UpdateHostCertParams{ + ID: host.ID, + CertFingerprint: hc.Fingerprint, + CertExpiresAt: pgtype.Timestamptz{Time: hc.ExpiresAt, Valid: true}, + }); err != nil { + return HostRefreshResult{}, fmt.Errorf("update host cert: %w", err) + } + } + // Issue-then-revoke rotation: insert new token first so a crash between // the two DB calls leaves the host with two valid tokens rather than zero. newRefreshToken, err := s.issueRefreshToken(ctx, host.ID) @@ -340,7 +384,13 @@ func (s *HostService) Refresh(ctx context.Context, refreshToken string) (HostRef return HostRefreshResult{}, fmt.Errorf("revoke old refresh token: %w", err) } - return HostRefreshResult{Host: host, JWT: hostJWT, RefreshToken: newRefreshToken}, nil + result := HostRefreshResult{Host: host, JWT: hostJWT, RefreshToken: newRefreshToken} + if s.CA != nil { + result.CertPEM = hc.CertPEM + result.KeyPEM = hc.KeyPEM + result.CACertPEM = s.CA.PEM + } + return result, nil } // issueRefreshToken creates a new refresh token record in the DB and returns From 948db13bed3e3e65d55c3022520c2b00fdc9de1c Mon Sep 17 00:00:00 2001 From: pptx704 Date: Mon, 30 Mar 2026 21:24:52 +0600 Subject: [PATCH 25/28] Add skip_pre_post build option, cancel endpoint, and recipe package - skip_pre_post flag on builds bypasses apt update/clean pre/post steps for faster iteration when the recipe handles its own environment setup - POST /v1/admin/builds/{id}/cancel endpoint marks an in-progress build as cancelled; UpdateBuildStatus now also sets completed_at for 'cancelled' - internal/recipe: typed recipe parser and executor (RUN/ENV/COPY steps) replacing the raw string slice approach in the build worker - pre/post build commands prefixed with RUN to match recipe step format --- .../20260330150223_build_options.sql | 11 + db/queries/template_builds.sql | 8 +- frontend/src/lib/api/builds.ts | 5 + .../src/routes/admin/templates/+page.svelte | 92 +++++-- internal/api/handlers_builds.go | 20 ++ internal/db/template_builds.sql.go | 22 +- internal/recipe/context.go | 63 +++++ internal/recipe/context_test.go | 114 ++++++++ internal/recipe/executor.go | 185 +++++++++++++ internal/recipe/step.go | 129 +++++++++ internal/recipe/step_test.go | 208 ++++++++++++++ internal/service/build.go | 258 +++++++++++------- 12 files changed, 981 insertions(+), 134 deletions(-) create mode 100644 db/migrations/20260330150223_build_options.sql create mode 100644 internal/recipe/context.go create mode 100644 internal/recipe/context_test.go create mode 100644 internal/recipe/executor.go create mode 100644 internal/recipe/step.go create mode 100644 internal/recipe/step_test.go diff --git a/db/migrations/20260330150223_build_options.sql b/db/migrations/20260330150223_build_options.sql new file mode 100644 index 00000000..981ad065 --- /dev/null +++ b/db/migrations/20260330150223_build_options.sql @@ -0,0 +1,11 @@ +-- +goose Up + +-- Allow completed_at to be set when a build is cancelled. +-- (The UpdateBuildStatus query is updated in sqlc; no schema change needed for that.) + +-- Add skip_pre_post flag: when true, the pre-build and post-build command phases +-- are skipped for this build. +ALTER TABLE template_builds ADD COLUMN skip_pre_post BOOLEAN NOT NULL DEFAULT FALSE; + +-- +goose Down +ALTER TABLE template_builds DROP COLUMN skip_pre_post; diff --git a/db/queries/template_builds.sql b/db/queries/template_builds.sql index be1c09e5..1fb07be3 100644 --- a/db/queries/template_builds.sql +++ b/db/queries/template_builds.sql @@ -1,6 +1,6 @@ -- name: InsertTemplateBuild :one -INSERT INTO template_builds (id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, total_steps, template_id, team_id) -VALUES ($1, $2, $3, $4, $5, $6, $7, 'pending', $8, $9, $10) +INSERT INTO template_builds (id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, total_steps, template_id, team_id, skip_pre_post) +VALUES ($1, $2, $3, $4, $5, $6, $7, 'pending', $8, $9, $10, $11) RETURNING *; -- name: GetTemplateBuild :one @@ -12,8 +12,8 @@ SELECT * FROM template_builds ORDER BY created_at DESC; -- name: UpdateBuildStatus :one UPDATE template_builds SET status = $2, - started_at = CASE WHEN $2 = 'running' AND started_at IS NULL THEN NOW() ELSE started_at END, - completed_at = CASE WHEN $2 IN ('success', 'failed') THEN NOW() ELSE completed_at END + started_at = CASE WHEN $2 = 'running' AND started_at IS NULL THEN NOW() ELSE started_at END, + completed_at = CASE WHEN $2 IN ('success', 'failed', 'cancelled') THEN NOW() ELSE completed_at END WHERE id = $1 RETURNING *; diff --git a/frontend/src/lib/api/builds.ts b/frontend/src/lib/api/builds.ts index 349c6e1c..1de23b8d 100644 --- a/frontend/src/lib/api/builds.ts +++ b/frontend/src/lib/api/builds.ts @@ -38,6 +38,7 @@ export type CreateBuildParams = { healthcheck?: string; vcpus?: number; memory_mb?: number; + skip_pre_post?: boolean; }; export async function createBuild(params: CreateBuildParams): Promise> { @@ -69,3 +70,7 @@ export async function listAdminTemplates(): Promise> export async function deleteAdminTemplate(name: string): Promise> { return apiFetch('DELETE', `/api/v1/admin/templates/${name}`); } + +export async function cancelBuild(id: string): Promise> { + return apiFetch('POST', `/api/v1/admin/builds/${id}/cancel`); +} diff --git a/frontend/src/routes/admin/templates/+page.svelte b/frontend/src/routes/admin/templates/+page.svelte index dde8fc3a..4619e7b9 100644 --- a/frontend/src/routes/admin/templates/+page.svelte +++ b/frontend/src/routes/admin/templates/+page.svelte @@ -6,6 +6,7 @@ import { listBuilds, createBuild, + cancelBuild, listAdminTemplates, deleteAdminTemplate, type Build, @@ -52,11 +53,15 @@ vcpus: 1, memory_mb: 512, recipe: '', - healthcheck: '' + healthcheck: '', + skip_pre_post: false }); let creating = $state(false); let createError = $state(null); + // Cancel build state + let cancelingBuildId = $state(null); + // Stats let templateCount = $derived(templates.length); let snapshotCount = $derived(templates.filter((t) => t.type === 'snapshot').length); @@ -123,12 +128,13 @@ recipe: lines, healthcheck: createForm.healthcheck.trim() || undefined, vcpus: createForm.vcpus, - memory_mb: createForm.memory_mb + memory_mb: createForm.memory_mb, + skip_pre_post: createForm.skip_pre_post }); if (result.ok) { showCreate = false; - createForm = { name: '', base_template: 'minimal', vcpus: 1, memory_mb: 512, recipe: '', healthcheck: '' }; + createForm = { name: '', base_template: 'minimal', vcpus: 1, memory_mb: 512, recipe: '', healthcheck: '', skip_pre_post: false }; builds = [result.data, ...builds]; activeTab = 'builds'; expandedBuildId = result.data.id; @@ -156,6 +162,18 @@ deleting = false; } + async function handleCancelBuild(buildId: string) { + cancelingBuildId = buildId; + const result = await cancelBuild(buildId); + if (result.ok) { + builds = builds.map((b) => b.id === buildId ? { ...b, status: 'cancelled' } : b); + toast.success('Build cancelled'); + } else { + toast.error(result.error ?? 'Failed to cancel build'); + } + cancelingBuildId = null; + } + function toggleBuildExpand(buildId: string) { if (expandedBuildId === buildId) { expandedBuildId = null; @@ -198,10 +216,28 @@ case 'success': return 'var(--color-accent-bright)'; case 'failed': return 'var(--color-red)'; case 'running': return 'var(--color-blue)'; + case 'cancelled': return 'var(--color-amber)'; default: return 'var(--color-text-muted)'; } } + // Returns [keyword, rest] from a recipe instruction string. + function splitInstruction(cmd: string): [string, string] { + const idx = cmd.indexOf(' '); + if (idx === -1) return [cmd.toUpperCase(), '']; + return [cmd.slice(0, idx).toUpperCase(), cmd.slice(idx + 1)]; + } + + function keywordColor(keyword: string): string { + switch (keyword) { + case 'RUN': return 'var(--color-blue)'; + case 'START': return 'var(--color-accent-bright)'; + case 'ENV': return 'var(--color-amber)'; + case 'WORKDIR': return 'var(--color-text-tertiary)'; + default: return 'var(--color-text-muted)'; + } + } + onMount(() => { fetchTemplates(); fetchBuilds().then(startPolling); @@ -512,6 +548,22 @@
+ {#if build.status === 'pending' || build.status === 'running'} +
+ +
+ {/if} {#if build.error}
{build.error} @@ -524,6 +576,7 @@ {@const isInternal = log.phase === 'pre-build' || log.phase === 'post-build'} {@const recipeIdx = log.phase === 'recipe' ? build.logs.filter(l => l.phase === 'recipe' && l.step <= log.step).length : 0} {@const phaseLabel = isInternal ? (log.phase === 'pre-build' ? 'Pre-build' : 'Post-build') : `Step ${recipeIdx}`} + {@const [kw, kwRest] = splitInstruction(log.cmd)}
+ +
diff --git a/internal/api/handlers_builds.go b/internal/api/handlers_builds.go index 3b964008..282c3f48 100644 --- a/internal/api/handlers_builds.go +++ b/internal/api/handlers_builds.go @@ -36,6 +36,7 @@ type createBuildRequest struct { Healthcheck string `json:"healthcheck"` VCPUs int32 `json:"vcpus"` MemoryMB int32 `json:"memory_mb"` + SkipPrePost bool `json:"skip_pre_post"` } type buildResponse struct { @@ -127,6 +128,7 @@ func (h *buildHandler) Create(w http.ResponseWriter, r *http.Request) { Healthcheck: req.Healthcheck, VCPUs: req.VCPUs, MemoryMB: req.MemoryMB, + SkipPrePost: req.SkipPrePost, }) if err != nil { slog.Error("failed to create build", "error", err) @@ -254,3 +256,21 @@ func (h *buildHandler) DeleteTemplate(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNoContent) } + +// Cancel handles POST /v1/admin/builds/{id}/cancel. +func (h *buildHandler) Cancel(w http.ResponseWriter, r *http.Request) { + buildIDStr := chi.URLParam(r, "id") + + buildID, err := id.ParseBuildID(buildIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid build ID") + return + } + + if err := h.svc.Cancel(r.Context(), buildID); err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", err.Error()) + return + } + + w.WriteHeader(http.StatusNoContent) +} diff --git a/internal/db/template_builds.sql.go b/internal/db/template_builds.sql.go index 7aa1b67e..facfb199 100644 --- a/internal/db/template_builds.sql.go +++ b/internal/db/template_builds.sql.go @@ -12,7 +12,7 @@ import ( ) const getTemplateBuild = `-- name: GetTemplateBuild :one -SELECT id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at, template_id, team_id FROM template_builds WHERE id = $1 +SELECT id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at, template_id, team_id, skip_pre_post FROM template_builds WHERE id = $1 ` func (q *Queries) GetTemplateBuild(ctx context.Context, id pgtype.UUID) (TemplateBuild, error) { @@ -38,14 +38,15 @@ func (q *Queries) GetTemplateBuild(ctx context.Context, id pgtype.UUID) (Templat &i.CompletedAt, &i.TemplateID, &i.TeamID, + &i.SkipPrePost, ) return i, err } const insertTemplateBuild = `-- name: InsertTemplateBuild :one -INSERT INTO template_builds (id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, total_steps, template_id, team_id) -VALUES ($1, $2, $3, $4, $5, $6, $7, 'pending', $8, $9, $10) -RETURNING id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at, template_id, team_id +INSERT INTO template_builds (id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, total_steps, template_id, team_id, skip_pre_post) +VALUES ($1, $2, $3, $4, $5, $6, $7, 'pending', $8, $9, $10, $11) +RETURNING id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at, template_id, team_id, skip_pre_post ` type InsertTemplateBuildParams struct { @@ -59,6 +60,7 @@ type InsertTemplateBuildParams struct { TotalSteps int32 `json:"total_steps"` TemplateID pgtype.UUID `json:"template_id"` TeamID pgtype.UUID `json:"team_id"` + SkipPrePost bool `json:"skip_pre_post"` } func (q *Queries) InsertTemplateBuild(ctx context.Context, arg InsertTemplateBuildParams) (TemplateBuild, error) { @@ -73,6 +75,7 @@ func (q *Queries) InsertTemplateBuild(ctx context.Context, arg InsertTemplateBui arg.TotalSteps, arg.TemplateID, arg.TeamID, + arg.SkipPrePost, ) var i TemplateBuild err := row.Scan( @@ -95,12 +98,13 @@ func (q *Queries) InsertTemplateBuild(ctx context.Context, arg InsertTemplateBui &i.CompletedAt, &i.TemplateID, &i.TeamID, + &i.SkipPrePost, ) return i, err } const listTemplateBuilds = `-- name: ListTemplateBuilds :many -SELECT id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at, template_id, team_id FROM template_builds ORDER BY created_at DESC +SELECT id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at, template_id, team_id, skip_pre_post FROM template_builds ORDER BY created_at DESC ` func (q *Queries) ListTemplateBuilds(ctx context.Context) ([]TemplateBuild, error) { @@ -132,6 +136,7 @@ func (q *Queries) ListTemplateBuilds(ctx context.Context) ([]TemplateBuild, erro &i.CompletedAt, &i.TemplateID, &i.TeamID, + &i.SkipPrePost, ); err != nil { return nil, err } @@ -196,10 +201,10 @@ func (q *Queries) UpdateBuildSandbox(ctx context.Context, arg UpdateBuildSandbox const updateBuildStatus = `-- name: UpdateBuildStatus :one UPDATE template_builds SET status = $2, - started_at = CASE WHEN $2 = 'running' AND started_at IS NULL THEN NOW() ELSE started_at END, - completed_at = CASE WHEN $2 IN ('success', 'failed') THEN NOW() ELSE completed_at END + started_at = CASE WHEN $2 = 'running' AND started_at IS NULL THEN NOW() ELSE started_at END, + completed_at = CASE WHEN $2 IN ('success', 'failed', 'cancelled') THEN NOW() ELSE completed_at END WHERE id = $1 -RETURNING id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at, template_id, team_id +RETURNING id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at, template_id, team_id, skip_pre_post ` type UpdateBuildStatusParams struct { @@ -230,6 +235,7 @@ func (q *Queries) UpdateBuildStatus(ctx context.Context, arg UpdateBuildStatusPa &i.CompletedAt, &i.TemplateID, &i.TeamID, + &i.SkipPrePost, ) return i, err } diff --git a/internal/recipe/context.go b/internal/recipe/context.go new file mode 100644 index 00000000..db4c39cc --- /dev/null +++ b/internal/recipe/context.go @@ -0,0 +1,63 @@ +package recipe + +import "strings" + +// ExecContext holds mutable state that persists across recipe steps. +// It is initialized empty and updated by ENV and WORKDIR steps. +type ExecContext struct { + WorkDir string + EnvVars map[string]string +} + +// WrappedCommand returns the full shell command for a RUN step with context +// applied. The result is passed as the argument to /bin/sh -c. +// +// If WORKDIR and/or ENV are set, they are prepended as a shell preamble: +// +// cd '/the/dir' && KEY='val' /bin/sh -c 'original command' +func (c *ExecContext) WrappedCommand(cmd string) string { + prefix := c.shellPrefix() + if prefix == "" { + return cmd + } + return prefix + "/bin/sh -c " + shellescape(cmd) +} + +// StartCommand returns the shell command for a START step. The process is +// launched in the background via nohup so that the outer shell exits +// immediately, allowing the build to continue. stdout/stderr of the +// background process are discarded (the process keeps running in the VM). +// +// Multiple START steps can be issued to run several background processes +// simultaneously before a healthcheck is evaluated. +func (c *ExecContext) StartCommand(cmd string) string { + prefix := c.shellPrefix() + return prefix + "nohup /bin/sh -c " + shellescape(cmd) + " >/dev/null 2>&1 &" +} + +// shellPrefix builds the "cd ... && KEY=val " preamble for a shell command. +// Returns an empty string when no context is set. +func (c *ExecContext) shellPrefix() string { + if c.WorkDir == "" && len(c.EnvVars) == 0 { + return "" + } + var sb strings.Builder + if c.WorkDir != "" { + sb.WriteString("cd ") + sb.WriteString(shellescape(c.WorkDir)) + sb.WriteString(" && ") + } + for k, v := range c.EnvVars { + sb.WriteString(k) + sb.WriteByte('=') + sb.WriteString(shellescape(v)) + sb.WriteByte(' ') + } + return sb.String() +} + +// shellescape wraps s in single quotes, escaping any embedded single quotes. +// This is POSIX-safe for paths, env values, and shell commands. +func shellescape(s string) string { + return "'" + strings.ReplaceAll(s, "'", `'\''`) + "'" +} diff --git a/internal/recipe/context_test.go b/internal/recipe/context_test.go new file mode 100644 index 00000000..b00dfce0 --- /dev/null +++ b/internal/recipe/context_test.go @@ -0,0 +1,114 @@ +package recipe + +import "testing" + +func TestExecContext_WrappedCommand(t *testing.T) { + tests := []struct { + name string + ctx ExecContext + cmd string + want string + }{ + { + name: "no context", + ctx: ExecContext{}, + cmd: "apt install -y curl", + want: "apt install -y curl", + }, + { + name: "workdir only", + ctx: ExecContext{WorkDir: "/app"}, + cmd: "npm install", + want: "cd '/app' && /bin/sh -c 'npm install'", + }, + { + name: "env only", + ctx: ExecContext{EnvVars: map[string]string{"PORT": "8080"}}, + cmd: "node server.js", + want: "PORT='8080' /bin/sh -c 'node server.js'", + }, + { + name: "workdir with space", + ctx: ExecContext{WorkDir: "/my project"}, + cmd: "make build", + want: "cd '/my project' && /bin/sh -c 'make build'", + }, + { + name: "command with single quotes", + ctx: ExecContext{WorkDir: "/app"}, + cmd: "echo 'hello'", + want: "cd '/app' && /bin/sh -c 'echo '\\''hello'\\'''", + }, + { + name: "env value with single quotes", + ctx: ExecContext{EnvVars: map[string]string{"MSG": "it's fine"}}, + cmd: "echo $MSG", + want: "MSG='it'\\''s fine' /bin/sh -c 'echo $MSG'", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := tc.ctx.WrappedCommand(tc.cmd) + if got != tc.want { + t.Errorf("WrappedCommand(%q)\n got %q\n want %q", tc.cmd, got, tc.want) + } + }) + } +} + +func TestExecContext_StartCommand(t *testing.T) { + tests := []struct { + name string + ctx ExecContext + cmd string + want string + }{ + { + name: "no context", + ctx: ExecContext{}, + cmd: "python3 app.py", + want: "nohup /bin/sh -c 'python3 app.py' >/dev/null 2>&1 &", + }, + { + name: "with workdir", + ctx: ExecContext{WorkDir: "/app"}, + cmd: "python3 server.py", + want: "cd '/app' && nohup /bin/sh -c 'python3 server.py' >/dev/null 2>&1 &", + }, + { + name: "with env", + ctx: ExecContext{EnvVars: map[string]string{"PORT": "9000"}}, + cmd: "node index.js", + want: "PORT='9000' nohup /bin/sh -c 'node index.js' >/dev/null 2>&1 &", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := tc.ctx.StartCommand(tc.cmd) + if got != tc.want { + t.Errorf("StartCommand(%q)\n got %q\n want %q", tc.cmd, got, tc.want) + } + }) + } +} + +func TestShellescape(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"simple", "'simple'"}, + {"/path/to/dir", "'/path/to/dir'"}, + {"it's fine", "'it'\\''s fine'"}, + {"", "''"}, + {"a'b'c", "'a'\\''b'\\''c'"}, + } + for _, tc := range tests { + got := shellescape(tc.input) + if got != tc.want { + t.Errorf("shellescape(%q) = %q, want %q", tc.input, got, tc.want) + } + } +} diff --git a/internal/recipe/executor.go b/internal/recipe/executor.go new file mode 100644 index 00000000..3df45dc5 --- /dev/null +++ b/internal/recipe/executor.go @@ -0,0 +1,185 @@ +package recipe + +import ( + "context" + "fmt" + "log/slog" + "strings" + "time" + + "connectrpc.com/connect" + + pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen" +) + +// DefaultStepTimeout is the fallback timeout for RUN steps that carry no +// explicit --timeout flag. +const DefaultStepTimeout = 30 * time.Second + +// BuildLogEntry is the per-step record stored in template_builds.logs (JSONB). +type BuildLogEntry struct { + Step int `json:"step"` + Phase string `json:"phase"` + Cmd string `json:"cmd"` + Stdout string `json:"stdout"` + Stderr string `json:"stderr"` + Exit int32 `json:"exit"` + Ok bool `json:"ok"` + Elapsed int64 `json:"elapsed_ms"` +} + +// ExecFunc is the agent.Exec call signature used by the executor. It matches +// the method on the hostagent Connect RPC client. +type ExecFunc func(ctx context.Context, req *connect.Request[pb.ExecRequest]) (*connect.Response[pb.ExecResponse], error) + +// Execute runs steps sequentially against sandboxID using execFn. +// +// - phase labels the log entries (e.g., "pre-build", "recipe", "post-build"). +// - startStep is the 1-based offset so entries are globally numbered across phases. +// - defaultTimeout applies to RUN steps with no per-step --timeout; 0 → 10 minutes. +// - bctx is mutated in place as ENV/WORKDIR steps execute, and carries forward +// into subsequent phases when the caller passes the same pointer. +// +// Returns all log entries appended during this call, the next step counter +// value, and whether all steps succeeded. On false the last entry contains +// failure details; the caller is responsible for destroying the sandbox and +// recording the build error. +func Execute( + ctx context.Context, + phase string, + steps []Step, + sandboxID string, + startStep int, + defaultTimeout time.Duration, + bctx *ExecContext, + execFn ExecFunc, +) (entries []BuildLogEntry, nextStep int, ok bool) { + if defaultTimeout <= 0 { + defaultTimeout = 10 * time.Minute + } + + step := startStep + for _, st := range steps { + step++ + slog.Info("executing build step", "phase", phase, "step", step, "instruction", st.Raw) + + switch st.Kind { + case KindENV: + if bctx.EnvVars == nil { + bctx.EnvVars = make(map[string]string) + } + bctx.EnvVars[st.Key] = st.Value + entries = append(entries, BuildLogEntry{Step: step, Phase: phase, Cmd: st.Raw, Ok: true}) + + case KindWORKDIR: + bctx.WorkDir = st.Path + entries = append(entries, BuildLogEntry{Step: step, Phase: phase, Cmd: st.Raw, Ok: true}) + + case KindUSER, KindCOPY: + verb := strings.ToUpper(strings.Fields(st.Raw)[0]) + entries = append(entries, BuildLogEntry{ + Step: step, + Phase: phase, + Cmd: st.Raw, + Stderr: verb + " is not yet supported", + Ok: false, + }) + return entries, step, false + + case KindSTART: + entry, succeeded := execStart(ctx, st, sandboxID, phase, step, bctx, execFn) + entries = append(entries, entry) + if !succeeded { + return entries, step, false + } + + case KindRUN: + timeout := defaultTimeout + if st.Timeout > 0 { + timeout = st.Timeout + } + entry, succeeded := execRun(ctx, st, sandboxID, phase, step, timeout, bctx, execFn) + entries = append(entries, entry) + if !succeeded { + return entries, step, false + } + } + } + return entries, step, true +} + +func execRun( + ctx context.Context, + st Step, + sandboxID, phase string, + step int, + timeout time.Duration, + bctx *ExecContext, + execFn ExecFunc, +) (BuildLogEntry, bool) { + execCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + start := time.Now() + resp, err := execFn(execCtx, connect.NewRequest(&pb.ExecRequest{ + SandboxId: sandboxID, + Cmd: "/bin/sh", + Args: []string{"-c", bctx.WrappedCommand(st.Shell)}, + TimeoutSec: int32(timeout.Seconds()), + })) + + entry := BuildLogEntry{ + Step: step, + Phase: phase, + Cmd: st.Raw, + Elapsed: time.Since(start).Milliseconds(), + } + if err != nil { + entry.Stderr = fmt.Sprintf("exec error: %v", err) + return entry, false + } + entry.Stdout = string(resp.Msg.Stdout) + entry.Stderr = string(resp.Msg.Stderr) + entry.Exit = resp.Msg.ExitCode + entry.Ok = resp.Msg.ExitCode == 0 + return entry, entry.Ok +} + +func execStart( + ctx context.Context, + st Step, + sandboxID, phase string, + step int, + bctx *ExecContext, + execFn ExecFunc, +) (BuildLogEntry, bool) { + // START uses a short timeout: just long enough for the shell to fork and + // return. The background process itself runs indefinitely inside the VM. + execCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + start := time.Now() + resp, err := execFn(execCtx, connect.NewRequest(&pb.ExecRequest{ + SandboxId: sandboxID, + Cmd: "/bin/sh", + Args: []string{"-c", bctx.StartCommand(st.Shell)}, + TimeoutSec: 10, + })) + + entry := BuildLogEntry{ + Step: step, + Phase: phase, + Cmd: st.Raw, + Elapsed: time.Since(start).Milliseconds(), + } + if err != nil { + entry.Stderr = fmt.Sprintf("start error: %v", err) + return entry, false + } + entry.Exit = resp.Msg.ExitCode + entry.Ok = resp.Msg.ExitCode == 0 + if !entry.Ok { + entry.Stderr = fmt.Sprintf("start failed with exit code %d: %s", resp.Msg.ExitCode, string(resp.Msg.Stderr)) + } + return entry, entry.Ok +} diff --git a/internal/recipe/step.go b/internal/recipe/step.go new file mode 100644 index 00000000..7d510362 --- /dev/null +++ b/internal/recipe/step.go @@ -0,0 +1,129 @@ +package recipe + +import ( + "fmt" + "strings" + "time" +) + +// Kind identifies the instruction type in a recipe line. +type Kind int + +const ( + KindRUN Kind = iota // Execute a command and wait for it to exit. + KindSTART // Start a command in the background (non-blocking). + KindENV // Set an environment variable for subsequent steps. + KindWORKDIR // Set the working directory for subsequent steps. + KindUSER // Switch the unix user for subsequent steps. (stub) + KindCOPY // Copy files into the sandbox. (stub) +) + +// Step is the parsed representation of one recipe instruction. +type Step struct { + Kind Kind + Raw string // original string, preserved for logging + Shell string // KindRUN, KindSTART: the shell command text + Timeout time.Duration // KindRUN: 0 means use caller's default + Key string // KindENV: variable name + Value string // KindENV: variable value + Path string // KindWORKDIR: directory path +} + +// ParseStep parses a single recipe instruction string into a Step. +// Instructions are Dockerfile-like: a keyword followed by arguments. +// +// Supported syntax: +// +// RUN — run command, wait for exit +// RUN --timeout= — run command with explicit timeout (e.g. --timeout=5m) +// START — start command in background, return immediately +// ENV = — set environment variable +// WORKDIR — set working directory +// USER — not yet supported +// COPY — not yet supported +func ParseStep(s string) (Step, error) { + s = strings.TrimSpace(s) + if s == "" { + return Step{}, fmt.Errorf("empty step") + } + + // Split on first space to get the keyword. + keyword, rest, _ := strings.Cut(s, " ") + rest = strings.TrimSpace(rest) + + switch strings.ToUpper(keyword) { + case "RUN": + return parseRUN(s, rest) + case "START": + return parseSTART(s, rest) + case "ENV": + return parseENV(s, rest) + case "WORKDIR": + return parseWORKDIR(s, rest) + case "USER": + return Step{Kind: KindUSER, Raw: s}, nil + case "COPY": + return Step{Kind: KindCOPY, Raw: s}, nil + default: + return Step{}, fmt.Errorf("unknown instruction %q (expected RUN, START, ENV, WORKDIR, USER, or COPY)", keyword) + } +} + +// ParseRecipe parses all recipe lines, returning on the first error. +func ParseRecipe(lines []string) ([]Step, error) { + steps := make([]Step, 0, len(lines)) + for i, line := range lines { + st, err := ParseStep(line) + if err != nil { + return nil, fmt.Errorf("recipe line %d: %w", i+1, err) + } + steps = append(steps, st) + } + return steps, nil +} + +func parseRUN(raw, rest string) (Step, error) { + var timeout time.Duration + if strings.HasPrefix(rest, "--timeout=") { + rest = rest[len("--timeout="):] + flag, cmd, found := strings.Cut(rest, " ") + if !found || strings.TrimSpace(cmd) == "" { + return Step{}, fmt.Errorf("RUN --timeout= flag has no command: %q", raw) + } + d, err := time.ParseDuration(flag) + if err != nil { + return Step{}, fmt.Errorf("RUN --timeout= invalid duration %q: %w", flag, err) + } + timeout = d + rest = strings.TrimSpace(cmd) + } + if rest == "" { + return Step{}, fmt.Errorf("RUN requires a command: %q", raw) + } + return Step{Kind: KindRUN, Raw: raw, Shell: rest, Timeout: timeout}, nil +} + +func parseSTART(raw, rest string) (Step, error) { + if rest == "" { + return Step{}, fmt.Errorf("START requires a command: %q", raw) + } + return Step{Kind: KindSTART, Raw: raw, Shell: rest}, nil +} + +func parseENV(raw, rest string) (Step, error) { + key, value, found := strings.Cut(rest, "=") + if !found { + return Step{}, fmt.Errorf("ENV requires KEY=VALUE format: %q", raw) + } + if key == "" { + return Step{}, fmt.Errorf("ENV key is empty: %q", raw) + } + return Step{Kind: KindENV, Raw: raw, Key: key, Value: value}, nil +} + +func parseWORKDIR(raw, path string) (Step, error) { + if path == "" { + return Step{}, fmt.Errorf("WORKDIR requires a path: %q", raw) + } + return Step{Kind: KindWORKDIR, Raw: raw, Path: path}, nil +} diff --git a/internal/recipe/step_test.go b/internal/recipe/step_test.go new file mode 100644 index 00000000..2370bb21 --- /dev/null +++ b/internal/recipe/step_test.go @@ -0,0 +1,208 @@ +package recipe + +import ( + "testing" + "time" +) + +func TestParseStep(t *testing.T) { + tests := []struct { + name string + input string + want Step + wantErr bool + }{ + // RUN + { + name: "RUN basic", + input: "RUN apt install -y curl", + want: Step{Kind: KindRUN, Raw: "RUN apt install -y curl", Shell: "apt install -y curl"}, + }, + { + name: "RUN lowercase", + input: "run echo hello", + want: Step{Kind: KindRUN, Raw: "run echo hello", Shell: "echo hello"}, + }, + { + name: "RUN with timeout", + input: "RUN --timeout=5m npm install", + want: Step{Kind: KindRUN, Raw: "RUN --timeout=5m npm install", Shell: "npm install", Timeout: 5 * time.Minute}, + }, + { + name: "RUN with timeout seconds", + input: "RUN --timeout=30s make build", + want: Step{Kind: KindRUN, Raw: "RUN --timeout=30s make build", Shell: "make build", Timeout: 30 * time.Second}, + }, + { + name: "RUN no command", + input: "RUN", + wantErr: true, + }, + { + name: "RUN timeout no command", + input: "RUN --timeout=5m", + wantErr: true, + }, + { + name: "RUN invalid timeout", + input: "RUN --timeout=notaduration echo hi", + wantErr: true, + }, + // START + { + name: "START basic", + input: "START python3 app.py", + want: Step{Kind: KindSTART, Raw: "START python3 app.py", Shell: "python3 app.py"}, + }, + { + name: "START uppercase", + input: "START node server.js --port=8080", + want: Step{Kind: KindSTART, Raw: "START node server.js --port=8080", Shell: "node server.js --port=8080"}, + }, + { + name: "START no command", + input: "START", + wantErr: true, + }, + // ENV + { + name: "ENV basic", + input: "ENV FOO=bar", + want: Step{Kind: KindENV, Raw: "ENV FOO=bar", Key: "FOO", Value: "bar"}, + }, + { + name: "ENV value with spaces", + input: "ENV GREETING=hello world", + want: Step{Kind: KindENV, Raw: "ENV GREETING=hello world", Key: "GREETING", Value: "hello world"}, + }, + { + name: "ENV value with equals sign", + input: "ENV URL=http://example.com?a=1", + want: Step{Kind: KindENV, Raw: "ENV URL=http://example.com?a=1", Key: "URL", Value: "http://example.com?a=1"}, + }, + { + name: "ENV empty value", + input: "ENV FOO=", + want: Step{Kind: KindENV, Raw: "ENV FOO=", Key: "FOO", Value: ""}, + }, + { + name: "ENV missing equals", + input: "ENV FOO", + wantErr: true, + }, + { + name: "ENV empty key", + input: "ENV =value", + wantErr: true, + }, + // WORKDIR + { + name: "WORKDIR basic", + input: "WORKDIR /app", + want: Step{Kind: KindWORKDIR, Raw: "WORKDIR /app", Path: "/app"}, + }, + { + name: "WORKDIR with spaces in path", + input: "WORKDIR /my project", + want: Step{Kind: KindWORKDIR, Raw: "WORKDIR /my project", Path: "/my project"}, + }, + { + name: "WORKDIR empty", + input: "WORKDIR", + wantErr: true, + }, + // USER and COPY stubs + { + name: "USER stub", + input: "USER www-data", + want: Step{Kind: KindUSER, Raw: "USER www-data"}, + }, + { + name: "COPY stub", + input: "COPY config.yaml /etc/app/config.yaml", + want: Step{Kind: KindCOPY, Raw: "COPY config.yaml /etc/app/config.yaml"}, + }, + // Unknown keyword + { + name: "unknown keyword", + input: "FROBNICATE something", + wantErr: true, + }, + // Empty input + { + name: "empty string", + input: "", + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got, err := ParseStep(tc.input) + if tc.wantErr { + if err == nil { + t.Fatalf("ParseStep(%q) expected error, got %+v", tc.input, got) + } + return + } + if err != nil { + t.Fatalf("ParseStep(%q) unexpected error: %v", tc.input, err) + } + if got != tc.want { + t.Errorf("ParseStep(%q)\n got %+v\n want %+v", tc.input, got, tc.want) + } + }) + } +} + +func TestParseRecipe(t *testing.T) { + t.Run("valid recipe", func(t *testing.T) { + lines := []string{ + "RUN apt update", + "WORKDIR /app", + "ENV PORT=8080", + "START python3 server.py", + "RUN --timeout=2m pip install -r requirements.txt", + } + steps, err := ParseRecipe(lines) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(steps) != 5 { + t.Fatalf("expected 5 steps, got %d", len(steps)) + } + if steps[0].Kind != KindRUN { + t.Errorf("step 0: want KindRUN, got %v", steps[0].Kind) + } + if steps[1].Kind != KindWORKDIR { + t.Errorf("step 1: want KindWORKDIR, got %v", steps[1].Kind) + } + if steps[3].Kind != KindSTART { + t.Errorf("step 3: want KindSTART, got %v", steps[3].Kind) + } + if steps[4].Timeout != 2*time.Minute { + t.Errorf("step 4: want 2m timeout, got %v", steps[4].Timeout) + } + }) + + t.Run("error on invalid line", func(t *testing.T) { + lines := []string{ + "RUN apt update", + "BADCMD something", + } + _, err := ParseRecipe(lines) + if err == nil { + t.Fatal("expected error for invalid line, got nil") + } + }) + + t.Run("empty recipe", func(t *testing.T) { + steps, err := ParseRecipe(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(steps) != 0 { + t.Fatalf("expected 0 steps, got %d", len(steps)) + } + }) +} diff --git a/internal/service/build.go b/internal/service/build.go index 2592a6d9..1108044d 100644 --- a/internal/service/build.go +++ b/internal/service/build.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "log/slog" + "sync" "time" "connectrpc.com/connect" @@ -14,6 +15,7 @@ import ( "git.omukk.dev/wrenn/sandbox/internal/db" "git.omukk.dev/wrenn/sandbox/internal/id" "git.omukk.dev/wrenn/sandbox/internal/lifecycle" + "git.omukk.dev/wrenn/sandbox/internal/recipe" "git.omukk.dev/wrenn/sandbox/internal/scheduler" pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen" ) @@ -27,14 +29,14 @@ const ( // preBuildCmds run before the user recipe to prepare the build environment. var preBuildCmds = []string{ - "apt update", + "RUN apt update", } // postBuildCmds run after the user recipe to clean up caches and reduce image size. var postBuildCmds = []string{ - "apt clean", - "apt autoremove -y", - "rm -rf /var/lib/apt/lists/*", + "RUN apt clean", + "RUN apt autoremove -y", + "RUN rm -rf /var/lib/apt/lists/*", } // buildAgentClient is the subset of the host agent client used by the build worker. @@ -46,24 +48,15 @@ type buildAgentClient interface { FlattenRootfs(ctx context.Context, req *connect.Request[pb.FlattenRootfsRequest]) (*connect.Response[pb.FlattenRootfsResponse], error) } -// BuildLogEntry represents a single entry in the build log JSONB array. -type BuildLogEntry struct { - Step int `json:"step"` - Phase string `json:"phase"` // "pre-build", "recipe", or "post-build" - Cmd string `json:"cmd"` - Stdout string `json:"stdout"` - Stderr string `json:"stderr"` - Exit int32 `json:"exit"` - Ok bool `json:"ok"` - Elapsed int64 `json:"elapsed_ms"` -} - // BuildService handles template build orchestration. type BuildService struct { DB *db.Queries Redis *redis.Client Pool *lifecycle.HostClientPool Scheduler scheduler.HostScheduler + + mu sync.Mutex + cancelMap map[string]context.CancelFunc // buildID → per-build cancel func } // BuildCreateParams holds the parameters for creating a template build. @@ -74,6 +67,7 @@ type BuildCreateParams struct { Healthcheck string VCPUs int32 MemoryMB int32 + SkipPrePost bool } // Create inserts a new build record and enqueues it to Redis. @@ -97,6 +91,11 @@ func (s *BuildService) Create(ctx context.Context, p BuildCreateParams) (db.Temp buildIDStr := id.FormatBuildID(buildID) newTemplateID := id.NewTemplateID() + defaultSteps := len(preBuildCmds) + len(postBuildCmds) + if p.SkipPrePost { + defaultSteps = 0 + } + build, err := s.DB.InsertTemplateBuild(ctx, db.InsertTemplateBuildParams{ ID: buildID, Name: p.Name, @@ -105,9 +104,10 @@ func (s *BuildService) Create(ctx context.Context, p BuildCreateParams) (db.Temp Healthcheck: p.Healthcheck, Vcpus: p.VCPUs, MemoryMb: p.MemoryMB, - TotalSteps: int32(len(p.Recipe) + len(preBuildCmds) + len(postBuildCmds)), + TotalSteps: int32(len(p.Recipe) + defaultSteps), TemplateID: newTemplateID, TeamID: id.PlatformTeamID, + SkipPrePost: p.SkipPrePost, }) if err != nil { return db.TemplateBuild{}, fmt.Errorf("insert build: %w", err) @@ -131,6 +131,40 @@ func (s *BuildService) List(ctx context.Context) ([]db.TemplateBuild, error) { return s.DB.ListTemplateBuilds(ctx) } +// Cancel cancels a pending or running build. For pending builds the status is +// updated in the DB and the worker skips it when dequeued. For running builds +// the per-build context is cancelled, which causes the current exec step to +// abort; executeBuild then detects the cancellation and records the status. +func (s *BuildService) Cancel(ctx context.Context, buildID pgtype.UUID) error { + build, err := s.DB.GetTemplateBuild(ctx, buildID) + if err != nil { + return fmt.Errorf("get build: %w", err) + } + switch build.Status { + case "success", "failed", "cancelled": + return fmt.Errorf("build is already %s", build.Status) + } + + // Mark cancelled in DB first. This handles both pending builds (which haven't + // been picked up yet) and acts as a flag for executeBuild to check on start. + if _, err := s.DB.UpdateBuildStatus(ctx, db.UpdateBuildStatusParams{ + ID: buildID, Status: "cancelled", + }); err != nil { + return fmt.Errorf("update build status: %w", err) + } + + // If the build is currently running, signal its context. + buildIDStr := id.FormatBuildID(buildID) + s.mu.Lock() + cancel, running := s.cancelMap[buildIDStr] + s.mu.Unlock() + if running { + cancel() + } + + return nil +} + // StartWorkers launches n goroutines that consume from the Redis build queue. // The returned cancel function stops all workers. func (s *BuildService) StartWorkers(ctx context.Context, n int) context.CancelFunc { @@ -172,14 +206,38 @@ func (s *BuildService) executeBuild(ctx context.Context, buildIDStr string) { return } - build, err := s.DB.GetTemplateBuild(ctx, buildID) + // Create a per-build context so this build can be cancelled independently of + // the worker. Register in cancelMap before fetching the build so that a + // concurrent Cancel call can always find and signal it. + buildCtx, buildCancel := context.WithCancel(ctx) + defer buildCancel() + + s.mu.Lock() + if s.cancelMap == nil { + s.cancelMap = make(map[string]context.CancelFunc) + } + s.cancelMap[buildIDStr] = buildCancel + s.mu.Unlock() + defer func() { + s.mu.Lock() + delete(s.cancelMap, buildIDStr) + s.mu.Unlock() + }() + + build, err := s.DB.GetTemplateBuild(buildCtx, buildID) if err != nil { log.Error("failed to fetch build", "error", err) return } + // Skip if already cancelled (Cancel was called before we dequeued). + if build.Status == "cancelled" { + log.Info("build already cancelled, skipping") + return + } + // Mark as running. - if _, err := s.DB.UpdateBuildStatus(ctx, db.UpdateBuildStatusParams{ + if _, err := s.DB.UpdateBuildStatus(buildCtx, db.UpdateBuildStatusParams{ ID: buildID, Status: "running", }); err != nil { log.Error("failed to update build status", "error", err) @@ -187,22 +245,22 @@ func (s *BuildService) executeBuild(ctx context.Context, buildIDStr string) { } // Parse user recipe. - var recipe []string - if err := json.Unmarshal(build.Recipe, &recipe); err != nil { - s.failBuild(ctx, buildID, fmt.Sprintf("invalid recipe JSON: %v", err)) + var userRecipe []string + if err := json.Unmarshal(build.Recipe, &userRecipe); err != nil { + s.failBuild(buildCtx, buildID, fmt.Sprintf("invalid recipe JSON: %v", err)) return } // Pick a platform host and create a sandbox. - host, err := s.Scheduler.SelectHost(ctx, id.PlatformTeamID, false) + host, err := s.Scheduler.SelectHost(buildCtx, id.PlatformTeamID, false) if err != nil { - s.failBuild(ctx, buildID, fmt.Sprintf("no host available: %v", err)) + s.failBuild(buildCtx, buildID, fmt.Sprintf("no host available: %v", err)) return } agent, err := s.Pool.GetForHost(host) if err != nil { - s.failBuild(ctx, buildID, fmt.Sprintf("agent client error: %v", err)) + s.failBuild(buildCtx, buildID, fmt.Sprintf("agent client error: %v", err)) return } @@ -214,16 +272,16 @@ func (s *BuildService) executeBuild(ctx context.Context, buildIDStr string) { baseTeamID := id.PlatformTeamID baseTemplateID := id.MinimalTemplateID if build.BaseTemplate != "minimal" { - baseTmpl, err := s.DB.GetPlatformTemplateByName(ctx, build.BaseTemplate) + baseTmpl, err := s.DB.GetPlatformTemplateByName(buildCtx, build.BaseTemplate) if err != nil { - s.failBuild(ctx, buildID, fmt.Sprintf("base template %q not found: %v", build.BaseTemplate, err)) + s.failBuild(buildCtx, buildID, fmt.Sprintf("base template %q not found: %v", build.BaseTemplate, err)) return } baseTeamID = baseTmpl.TeamID baseTemplateID = baseTmpl.ID } - resp, err := agent.CreateSandbox(ctx, connect.NewRequest(&pb.CreateSandboxRequest{ + resp, err := agent.CreateSandbox(buildCtx, connect.NewRequest(&pb.CreateSandboxRequest{ SandboxId: sandboxIDStr, Template: build.BaseTemplate, TeamId: id.UUIDString(baseTeamID), @@ -234,129 +292,121 @@ func (s *BuildService) executeBuild(ctx context.Context, buildIDStr string) { DiskSizeMb: 5120, // 5 GB for template builds })) if err != nil { - s.failBuild(ctx, buildID, fmt.Sprintf("create sandbox failed: %v", err)) + s.failBuild(buildCtx, buildID, fmt.Sprintf("create sandbox failed: %v", err)) return } _ = resp // Record sandbox/host association. - _ = s.DB.UpdateBuildSandbox(ctx, db.UpdateBuildSandboxParams{ + _ = s.DB.UpdateBuildSandbox(buildCtx, db.UpdateBuildSandboxParams{ ID: buildID, SandboxID: sandboxID, HostID: host.ID, }) + // Parse recipe steps. preBuildCmds and postBuildCmds are hardcoded and always + // valid; panic on error is appropriate here since it would be a programmer mistake. + preBuildSteps, err := recipe.ParseRecipe(preBuildCmds) + if err != nil { + panic(fmt.Sprintf("invalid pre-build recipe: %v", err)) + } + userRecipeSteps, err := recipe.ParseRecipe(userRecipe) + if err != nil { + s.destroySandbox(buildCtx, agent, sandboxIDStr) + s.failBuild(buildCtx, buildID, fmt.Sprintf("recipe parse error: %v", err)) + return + } + postBuildSteps, err := recipe.ParseRecipe(postBuildCmds) + if err != nil { + panic(fmt.Sprintf("invalid post-build recipe: %v", err)) + } + // Execute build phases: pre-build → user recipe → post-build. - var logs []BuildLogEntry + // bctx carries working directory and env vars across all phases. + var logs []recipe.BuildLogEntry step := 0 + bctx := &recipe.ExecContext{} - // Helper to run a list of commands in a given phase. - // timeout=0 means no timeout (uses parent context). - runPhase := func(phase string, cmds []string, timeout time.Duration) bool { - for _, cmd := range cmds { - step++ - log.Info("executing build step", "phase", phase, "step", step, "cmd", cmd) - - execCtx := ctx - var cancel context.CancelFunc - // When no timeout is specified, use 10 minutes as a generous upper - // bound. The host agent defaults TimeoutSec=0 to 30s, so we must - // always send an explicit value. - effectiveTimeout := timeout - if effectiveTimeout <= 0 { - effectiveTimeout = 10 * time.Minute - } - execCtx, cancel = context.WithTimeout(ctx, effectiveTimeout) - timeoutSec := int32(effectiveTimeout.Seconds()) - - start := time.Now() - execResp, err := agent.Exec(execCtx, connect.NewRequest(&pb.ExecRequest{ - SandboxId: sandboxIDStr, - Cmd: "/bin/sh", - Args: []string{"-c", cmd}, - TimeoutSec: timeoutSec, - })) - cancel() - - entry := BuildLogEntry{ - Step: step, - Phase: phase, - Cmd: cmd, - Elapsed: time.Since(start).Milliseconds(), - } - - if err != nil { - entry.Stderr = err.Error() - entry.Ok = false - logs = append(logs, entry) - s.updateLogs(ctx, buildID, step, logs) - s.destroySandbox(ctx, agent, sandboxIDStr) - s.failBuild(ctx, buildID, fmt.Sprintf("%s step %d failed: %v", phase, step, err)) + runPhase := func(phase string, steps []recipe.Step, defaultTimeout time.Duration) bool { + newEntries, nextStep, ok := recipe.Execute(buildCtx, phase, steps, sandboxIDStr, step, defaultTimeout, bctx, agent.Exec) + logs = append(logs, newEntries...) + step = nextStep + s.updateLogs(buildCtx, buildID, step, logs) + if !ok { + s.destroySandbox(buildCtx, agent, sandboxIDStr) + // If the build was cancelled, status is already set — don't overwrite with "failed". + if buildCtx.Err() != nil { return false } - - entry.Stdout = string(execResp.Msg.Stdout) - entry.Stderr = string(execResp.Msg.Stderr) - entry.Exit = execResp.Msg.ExitCode - entry.Ok = execResp.Msg.ExitCode == 0 - logs = append(logs, entry) - s.updateLogs(ctx, buildID, step, logs) - - if execResp.Msg.ExitCode != 0 { - s.destroySandbox(ctx, agent, sandboxIDStr) - s.failBuild(ctx, buildID, fmt.Sprintf("%s step %d failed with exit code %d", phase, step, execResp.Msg.ExitCode)) - return false + last := newEntries[len(newEntries)-1] + reason := last.Stderr + if reason == "" { + reason = fmt.Sprintf("exit code %d", last.Exit) } + s.failBuild(buildCtx, buildID, fmt.Sprintf("%s step %d failed: %s", phase, step, reason)) } - return true + return ok } - if !runPhase("pre-build", preBuildCmds, 0) { + if !build.SkipPrePost { + if !runPhase("pre-build", preBuildSteps, 0) { + return + } + } + if !runPhase("recipe", userRecipeSteps, buildCommandTimeout) { return } - if !runPhase("recipe", recipe, buildCommandTimeout) { - return - } - if !runPhase("post-build", postBuildCmds, 0) { - return + if !build.SkipPrePost { + if !runPhase("post-build", postBuildSteps, 0) { + return + } } // Healthcheck or direct snapshot. var sizeBytes int64 if build.Healthcheck != "" { log.Info("running healthcheck", "cmd", build.Healthcheck) - if err := s.waitForHealthcheck(ctx, agent, sandboxIDStr, build.Healthcheck); err != nil { - s.destroySandbox(ctx, agent, sandboxIDStr) - s.failBuild(ctx, buildID, fmt.Sprintf("healthcheck failed: %v", err)) + if err := s.waitForHealthcheck(buildCtx, agent, sandboxIDStr, build.Healthcheck); err != nil { + s.destroySandbox(buildCtx, agent, sandboxIDStr) + if buildCtx.Err() != nil { + return + } + s.failBuild(buildCtx, buildID, fmt.Sprintf("healthcheck failed: %v", err)) return } // Healthcheck passed → full snapshot (with memory/CPU state). log.Info("healthcheck passed, creating snapshot") - snapResp, err := agent.CreateSnapshot(ctx, connect.NewRequest(&pb.CreateSnapshotRequest{ + snapResp, err := agent.CreateSnapshot(buildCtx, connect.NewRequest(&pb.CreateSnapshotRequest{ SandboxId: sandboxIDStr, Name: build.Name, TeamId: id.UUIDString(build.TeamID), TemplateId: id.UUIDString(build.TemplateID), })) if err != nil { - s.destroySandbox(ctx, agent, sandboxIDStr) - s.failBuild(ctx, buildID, fmt.Sprintf("create snapshot failed: %v", err)) + s.destroySandbox(buildCtx, agent, sandboxIDStr) + if buildCtx.Err() != nil { + return + } + s.failBuild(buildCtx, buildID, fmt.Sprintf("create snapshot failed: %v", err)) return } sizeBytes = snapResp.Msg.SizeBytes } else { // No healthcheck → image-only template (rootfs only). log.Info("no healthcheck, flattening rootfs") - flatResp, err := agent.FlattenRootfs(ctx, connect.NewRequest(&pb.FlattenRootfsRequest{ + flatResp, err := agent.FlattenRootfs(buildCtx, connect.NewRequest(&pb.FlattenRootfsRequest{ SandboxId: sandboxIDStr, Name: build.Name, TeamId: id.UUIDString(build.TeamID), TemplateId: id.UUIDString(build.TemplateID), })) if err != nil { - s.destroySandbox(ctx, agent, sandboxIDStr) - s.failBuild(ctx, buildID, fmt.Sprintf("flatten rootfs failed: %v", err)) + s.destroySandbox(buildCtx, agent, sandboxIDStr) + if buildCtx.Err() != nil { + return + } + s.failBuild(buildCtx, buildID, fmt.Sprintf("flatten rootfs failed: %v", err)) return } sizeBytes = flatResp.Msg.SizeBytes @@ -368,7 +418,7 @@ func (s *BuildService) executeBuild(ctx context.Context, buildIDStr string) { templateType = "snapshot" } - if _, err := s.DB.InsertTemplate(ctx, db.InsertTemplateParams{ + if _, err := s.DB.InsertTemplate(buildCtx, db.InsertTemplateParams{ ID: build.TemplateID, Name: build.Name, Type: templateType, @@ -386,7 +436,7 @@ func (s *BuildService) executeBuild(ctx context.Context, buildIDStr string) { // No additional destroy needed. // Mark build as success. - if _, err := s.DB.UpdateBuildStatus(ctx, db.UpdateBuildStatusParams{ + if _, err := s.DB.UpdateBuildStatus(buildCtx, db.UpdateBuildStatusParams{ ID: buildID, Status: "success", }); err != nil { log.Error("failed to mark build as success", "error", err) @@ -429,7 +479,7 @@ func (s *BuildService) waitForHealthcheck(ctx context.Context, agent buildAgentC } } -func (s *BuildService) updateLogs(ctx context.Context, buildID pgtype.UUID, step int, logs []BuildLogEntry) { +func (s *BuildService) updateLogs(ctx context.Context, buildID pgtype.UUID, step int, logs []recipe.BuildLogEntry) { logsJSON, err := json.Marshal(logs) if err != nil { slog.Warn("failed to marshal build logs", "error", err) From 377e856c8f7f80fa6f55f0489bebcf789369ed0b Mon Sep 17 00:00:00 2001 From: pptx704 Date: Mon, 30 Mar 2026 21:28:57 +0600 Subject: [PATCH 26/28] Fix lint warnings: drop deprecated Name field from snapshot response, check errcheck in benchmark Co-Authored-By: Claude Sonnet 4.6 --- internal/hostagent/server.go | 1 - internal/id/id_test.go | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/internal/hostagent/server.go b/internal/hostagent/server.go index ab016f8a..7cd78f4c 100644 --- a/internal/hostagent/server.go +++ b/internal/hostagent/server.go @@ -129,7 +129,6 @@ func (s *Server) CreateSnapshot( return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("create snapshot: %w", err)) } return connect.NewResponse(&pb.CreateSnapshotResponse{ - Name: msg.Name, SizeBytes: sizeBytes, }), nil } diff --git a/internal/id/id_test.go b/internal/id/id_test.go index 6fb23945..2000e9cb 100644 --- a/internal/id/id_test.go +++ b/internal/id/id_test.go @@ -113,6 +113,6 @@ func BenchmarkParseSandboxID(b *testing.B) { s := FormatSandboxID(id) b.ResetTimer() for i := 0; i < b.N; i++ { - ParseSandboxID(s) + _, _ = ParseSandboxID(s) } } From 2b4c5e0176a0aee8997f0bc7f92e1d3e94c14aaf Mon Sep 17 00:00:00 2001 From: pptx704 Date: Wed, 1 Apr 2026 15:09:44 +0600 Subject: [PATCH 27/28] Add pre-pause proxy connection drain and sandbox proxy caching Introduce ConnTracker (atomic.Bool + WaitGroup) to track in-flight proxy connections per sandbox. Before pausing a VM, the manager drains active connections with a 2s grace period, preventing Go runtime corruption inside the guest caused by stale TCP state surviving Firecracker snapshot/restore. Also add: - AcquireProxyConn on Manager for atomic lookup + connection tracking - Proxy cache (120s TTL) on CP SandboxProxyWrapper with single-query DB lookup (GetSandboxProxyTarget) to avoid two round-trips - Reset() on ConnTracker to re-enable connections if pause fails --- db/queries/sandboxes.sql | 8 ++ internal/api/handler_sandbox_proxy.go | 149 +++++++++++++++++++------- internal/auth/cert.go | 6 +- internal/db/sandboxes.sql.go | 26 +++++ internal/hostagent/proxy.go | 15 +-- internal/sandbox/conntracker.go | 66 ++++++++++++ internal/sandbox/manager.go | 37 ++++++- 7 files changed, 253 insertions(+), 54 deletions(-) create mode 100644 internal/sandbox/conntracker.go diff --git a/db/queries/sandboxes.sql b/db/queries/sandboxes.sql index 8cbd10be..2b195744 100644 --- a/db/queries/sandboxes.sql +++ b/db/queries/sandboxes.sql @@ -9,6 +9,14 @@ SELECT * FROM sandboxes WHERE id = $1; -- name: GetSandboxByTeam :one SELECT * FROM sandboxes WHERE id = $1 AND team_id = $2; +-- name: GetSandboxProxyTarget :one +-- Returns the sandbox status and its host's address in one query. +-- Used by SandboxProxyWrapper to avoid two round-trips. +SELECT s.status, h.address AS host_address +FROM sandboxes s +JOIN hosts h ON h.id = s.host_id +WHERE s.id = $1 AND s.team_id = $2; + -- name: ListSandboxes :many SELECT * FROM sandboxes ORDER BY created_at DESC; diff --git a/internal/api/handler_sandbox_proxy.go b/internal/api/handler_sandbox_proxy.go index a7b9f5b1..963dff69 100644 --- a/internal/api/handler_sandbox_proxy.go +++ b/internal/api/handler_sandbox_proxy.go @@ -1,6 +1,8 @@ package api import ( + "context" + "errors" "fmt" "log/slog" "net/http" @@ -9,6 +11,8 @@ import ( "regexp" "strconv" "strings" + "sync" + "time" "github.com/jackc/pgx/v5/pgtype" @@ -18,10 +22,45 @@ import ( "git.omukk.dev/wrenn/sandbox/internal/lifecycle" ) +// Sentinel errors returned by proxyTarget, used to map to HTTP status codes +// without relying on error message text. +var ( + errProxySandboxNotFound = errors.New("sandbox not found") + errProxyNoHostAddress = errors.New("host agent has no address") +) + +const proxyCacheTTL = 120 * time.Second + // sandboxHostPattern matches hostnames like "49999-cl-abcd1234.localhost" or // "49999-cl-abcd1234.example.com". Captures: port, sandbox ID. var sandboxHostPattern = regexp.MustCompile(`^(\d+)-(cl-[0-9a-z]+)\.`) +// errProxySandboxNotRunning carries the sandbox status so callers can include +// it in the HTTP response without parsing error strings. +type errProxySandboxNotRunning struct{ status string } + +func (e errProxySandboxNotRunning) Error() string { + return fmt.Sprintf("sandbox is not running (status: %s)", e.status) +} + +// proxyCacheEntry caches the resolved agent URL for a (sandbox, team) pair. +// The *httputil.ReverseProxy is built per-request (cheap) so the Director closure +// can capture the correct port without the cache key needing to include it. +type proxyCacheEntry struct { + agentURL *url.URL + expiresAt time.Time +} + +// proxyCacheKey is a fixed-size key from two UUIDs, avoids string allocation. +type proxyCacheKey [32]byte + +func makeProxyCacheKey(sandboxID, teamID pgtype.UUID) proxyCacheKey { + var k proxyCacheKey + copy(k[:16], sandboxID.Bytes[:]) + copy(k[16:], teamID.Bytes[:]) + return k +} + // SandboxProxyWrapper wraps an existing HTTP handler and intercepts requests // whose Host header matches the {port}-{sandbox_id}.{domain} pattern. Matching // requests are reverse-proxied through the host agent that owns the sandbox. @@ -34,6 +73,9 @@ type SandboxProxyWrapper struct { db *db.Queries pool *lifecycle.HostClientPool transport http.RoundTripper + + cacheMu sync.Mutex + cache map[proxyCacheKey]proxyCacheEntry } // NewSandboxProxyWrapper creates a new proxy wrapper. @@ -43,9 +85,63 @@ func NewSandboxProxyWrapper(inner http.Handler, queries *db.Queries, pool *lifec db: queries, pool: pool, transport: pool.Transport(), + cache: make(map[proxyCacheKey]proxyCacheEntry), } } +// proxyTarget looks up the cached agent URL for (sandboxID, teamID). +// On a miss it queries the DB, resolves the address, and populates the cache. +// The *httputil.ReverseProxy is built by the caller so the Director closure +// captures the correct port without the cache key needing to include it. +func (h *SandboxProxyWrapper) proxyTarget(ctx context.Context, sandboxID, teamID pgtype.UUID) (*url.URL, error) { + cacheKey := makeProxyCacheKey(sandboxID, teamID) + + h.cacheMu.Lock() + entry, ok := h.cache[cacheKey] + h.cacheMu.Unlock() + + if ok && time.Now().Before(entry.expiresAt) { + return entry.agentURL, nil + } + + // Cache miss or expired — query DB. + target, err := h.db.GetSandboxProxyTarget(ctx, db.GetSandboxProxyTargetParams{ + ID: sandboxID, + TeamID: teamID, + }) + if err != nil { + return nil, errProxySandboxNotFound + } + if target.Status != "running" { + return nil, errProxySandboxNotRunning{status: target.Status} + } + if target.HostAddress == "" { + return nil, errProxyNoHostAddress + } + + agentURL, err := url.Parse(h.pool.ResolveAddr(target.HostAddress)) + if err != nil { + return nil, fmt.Errorf("invalid host agent address: %w", err) + } + + h.cacheMu.Lock() + h.cache[cacheKey] = proxyCacheEntry{ + agentURL: agentURL, + expiresAt: time.Now().Add(proxyCacheTTL), + } + h.cacheMu.Unlock() + + return agentURL, nil +} + +// evictProxyCache removes the cached entry for a (sandbox, team) pair. +// Called on 502 so a stopped/moved sandbox is re-resolved on the next request. +func (h *SandboxProxyWrapper) evictProxyCache(sandboxID, teamID pgtype.UUID) { + h.cacheMu.Lock() + delete(h.cache, makeProxyCacheKey(sandboxID, teamID)) + h.cacheMu.Unlock() +} + func (h *SandboxProxyWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request) { host := r.Host // Strip port from Host header (e.g. "49999-cl-abcd1234.localhost:8000" → "49999-cl-abcd1234.localhost") @@ -82,51 +178,26 @@ func (h *SandboxProxyWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request) return } - ctx := r.Context() - - // Look up sandbox and verify ownership. - sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ - ID: sandboxID, - TeamID: teamID, - }) + agentURL, err := h.proxyTarget(r.Context(), sandboxID, teamID) if err != nil { - http.Error(w, "sandbox not found", http.StatusNotFound) - return - } - - if sb.Status != "running" { - http.Error(w, fmt.Sprintf("sandbox is not running (status: %s)", sb.Status), http.StatusConflict) - return - } - - agentHost, err := h.db.GetHost(ctx, sb.HostID) - if err != nil { - http.Error(w, "host agent not found", http.StatusServiceUnavailable) - return - } - - if agentHost.Address == "" { - http.Error(w, "host agent has no address", http.StatusServiceUnavailable) - return - } - - agentAddr := h.pool.ResolveAddr(agentHost.Address) - upstreamPath := fmt.Sprintf("/proxy/%s/%s%s", sandboxIDStr, port, r.URL.Path) - - target, err := url.Parse(agentAddr) - if err != nil { - http.Error(w, "invalid host agent address", http.StatusInternalServerError) + switch { + case errors.Is(err, errProxySandboxNotFound): + http.Error(w, err.Error(), http.StatusNotFound) + case errors.As(err, new(errProxySandboxNotRunning)): + http.Error(w, err.Error(), http.StatusConflict) + default: + http.Error(w, err.Error(), http.StatusServiceUnavailable) + } return } proxy := &httputil.ReverseProxy{ Transport: h.transport, Director: func(req *http.Request) { - req.URL.Scheme = target.Scheme - req.URL.Host = target.Host - req.URL.Path = upstreamPath - req.URL.RawQuery = r.URL.RawQuery - req.Host = target.Host + req.URL.Scheme = agentURL.Scheme + req.URL.Host = agentURL.Host + req.URL.Path = "/proxy/" + sandboxIDStr + "/" + port + req.URL.Path + req.Host = agentURL.Host }, ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) { slog.Debug("sandbox proxy error", @@ -134,10 +205,10 @@ func (h *SandboxProxyWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request) "port", port, "error", err, ) + h.evictProxyCache(sandboxID, teamID) http.Error(w, "proxy error: "+err.Error(), http.StatusBadGateway) }, } - proxy.ServeHTTP(w, r) } diff --git a/internal/auth/cert.go b/internal/auth/cert.go index 1af48672..d76f1de3 100644 --- a/internal/auth/cert.go +++ b/internal/auth/cert.go @@ -235,9 +235,9 @@ func CPClientTLSConfig(ca *CA, certStore *CPCertStore) *tls.Config { pool := x509.NewCertPool() pool.AddCert(ca.Cert) return &tls.Config{ - RootCAs: pool, - GetClientCertificate: certStore.GetClientCertificate, - MinVersion: tls.VersionTLS13, + RootCAs: pool, + GetClientCertificate: certStore.GetClientCertificate, + MinVersion: tls.VersionTLS13, } } diff --git a/internal/db/sandboxes.sql.go b/internal/db/sandboxes.sql.go index 4107f1ab..3ce16443 100644 --- a/internal/db/sandboxes.sql.go +++ b/internal/db/sandboxes.sql.go @@ -105,6 +105,32 @@ func (q *Queries) GetSandboxByTeam(ctx context.Context, arg GetSandboxByTeamPara return i, err } +const getSandboxProxyTarget = `-- name: GetSandboxProxyTarget :one +SELECT s.status, h.address AS host_address +FROM sandboxes s +JOIN hosts h ON h.id = s.host_id +WHERE s.id = $1 AND s.team_id = $2 +` + +type GetSandboxProxyTargetParams struct { + ID pgtype.UUID `json:"id"` + TeamID pgtype.UUID `json:"team_id"` +} + +type GetSandboxProxyTargetRow struct { + Status string `json:"status"` + HostAddress string `json:"host_address"` +} + +// Returns the sandbox status and its host's address in one query. +// Used by SandboxProxyWrapper to avoid two round-trips. +func (q *Queries) GetSandboxProxyTarget(ctx context.Context, arg GetSandboxProxyTargetParams) (GetSandboxProxyTargetRow, error) { + row := q.db.QueryRow(ctx, getSandboxProxyTarget, arg.ID, arg.TeamID) + var i GetSandboxProxyTargetRow + err := row.Scan(&i.Status, &i.HostAddress) + return i, err +} + const insertSandbox = `-- name: InsertSandbox :one INSERT INTO sandboxes (id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, template_id, template_team_id) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) diff --git a/internal/hostagent/proxy.go b/internal/hostagent/proxy.go index b4a39ee0..bbee4741 100644 --- a/internal/hostagent/proxy.go +++ b/internal/hostagent/proxy.go @@ -8,7 +8,6 @@ import ( "strconv" "strings" - "git.omukk.dev/wrenn/sandbox/internal/models" "git.omukk.dev/wrenn/sandbox/internal/sandbox" ) @@ -62,18 +61,14 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - sb, err := h.mgr.Get(sandboxID) - if err != nil { - http.Error(w, "sandbox not found", http.StatusNotFound) + hostIP, tracker, ok := h.mgr.AcquireProxyConn(sandboxID) + if !ok { + http.Error(w, "sandbox is not available", http.StatusServiceUnavailable) return } + defer tracker.Release() - if sb.Status != models.StatusRunning { - http.Error(w, fmt.Sprintf("sandbox is not running (status: %s)", sb.Status), http.StatusConflict) - return - } - - targetHost := fmt.Sprintf("%s:%d", sb.HostIP.String(), portNum) + targetHost := fmt.Sprintf("%s:%d", hostIP, portNum) proxy := &httputil.ReverseProxy{ Transport: h.transport, diff --git a/internal/sandbox/conntracker.go b/internal/sandbox/conntracker.go new file mode 100644 index 00000000..d9eac72b --- /dev/null +++ b/internal/sandbox/conntracker.go @@ -0,0 +1,66 @@ +package sandbox + +import ( + "sync" + "sync/atomic" + "time" +) + +// ConnTracker tracks active proxy connections for a single sandbox and +// provides a drain mechanism for pre-pause graceful shutdown. +// It is safe for concurrent use. +type ConnTracker struct { + draining atomic.Bool + wg sync.WaitGroup +} + +// Acquire registers one in-flight connection. Returns false if the tracker +// is already draining; the caller must not call Release in that case. +func (t *ConnTracker) Acquire() bool { + if t.draining.Load() { + return false + } + t.wg.Add(1) + // Re-check after Add: Drain may have set draining between our Load + // and Add. If so, undo the Add and reject the connection. + if t.draining.Load() { + t.wg.Done() + return false + } + return true +} + +// Release marks one connection as complete. Must be called exactly once +// per successful Acquire. +func (t *ConnTracker) Release() { + t.wg.Done() +} + +// Drain marks the tracker as draining (all future Acquire calls return +// false) and waits up to timeout for in-flight connections to finish. +// +// Note: if the timeout expires with connections still in-flight, the +// internal goroutine waiting on wg.Wait() will remain until those +// connections complete. This is bounded by the number of hung connections +// at drain time and self-heals once they close. +func (t *ConnTracker) Drain(timeout time.Duration) { + t.draining.Store(true) + + done := make(chan struct{}) + go func() { + t.wg.Wait() + close(done) + }() + + select { + case <-done: + case <-time.After(timeout): + } +} + +// Reset re-enables the tracker after a failed drain. This allows the +// sandbox to accept proxy connections again if the pause operation fails +// and the VM is resumed. +func (t *ConnTracker) Reset() { + t.draining.Store(false) +} diff --git a/internal/sandbox/manager.go b/internal/sandbox/manager.go index ac2bc221..67a70ca0 100644 --- a/internal/sandbox/manager.go +++ b/internal/sandbox/manager.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "log/slog" + "net" "os" "os/exec" "path/filepath" @@ -50,7 +51,8 @@ type sandboxState struct { models.Sandbox slot *network.Slot client *envdclient.Client - uffdSocketPath string // non-empty for sandboxes restored from snapshot + connTracker *ConnTracker // tracks in-flight proxy connections for pre-pause drain + uffdSocketPath string // non-empty for sandboxes restored from snapshot dmDevice *devicemapper.SnapshotDevice baseImagePath string // path to the base template rootfs (for loop registry release) @@ -224,6 +226,7 @@ func (m *Manager) Create(ctx context.Context, sandboxID string, teamID, template }, slot: slot, client: client, + connTracker: &ConnTracker{}, dmDevice: dmDev, baseImagePath: baseRootfs, } @@ -308,10 +311,17 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error { return fmt.Errorf("sandbox %s is not running (status: %s)", sandboxID, sb.Status) } + // Step 0: Drain in-flight proxy connections before freezing vCPUs. + // This prevents Go runtime corruption inside the guest caused by stale + // TCP state from connections that were alive when the VM was snapshotted. + sb.connTracker.Drain(2 * time.Second) + slog.Debug("pause: proxy connections drained", "id", sandboxID) + pauseStart := time.Now() // Step 1: Pause the VM (freeze vCPUs). if err := m.vm.Pause(ctx, sandboxID); err != nil { + sb.connTracker.Reset() return fmt.Errorf("pause VM: %w", err) } slog.Debug("pause: VM paused", "id", sandboxID, "elapsed", time.Since(pauseStart)) @@ -326,8 +336,10 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error { // resumeOnError unpauses the VM so the sandbox stays usable when a // post-freeze step fails. If the resume itself fails, the sandbox is - // left frozen — the caller should destroy it. + // left frozen — the caller should destroy it. It also resets the + // connection tracker so the sandbox can accept proxy connections again. resumeOnError := func() { + sb.connTracker.Reset() if err := m.vm.Resume(ctx, sandboxID); err != nil { slog.Error("failed to resume VM after pause error — sandbox is frozen", "id", sandboxID, "error", err) } @@ -692,6 +704,7 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int) }, slot: slot, client: client, + connTracker: &ConnTracker{}, uffdSocketPath: uffdSocketPath, dmDevice: dmDev, baseImagePath: baseImagePath, @@ -1094,6 +1107,7 @@ func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID string, team }, slot: slot, client: client, + connTracker: &ConnTracker{}, uffdSocketPath: uffdSocketPath, dmDevice: dmDev, baseImagePath: baseRootfs, @@ -1190,6 +1204,25 @@ func (m *Manager) GetClient(sandboxID string) (*envdclient.Client, error) { return sb.client, nil } +// AcquireProxyConn atomically looks up a sandbox by ID and registers an +// in-flight proxy connection. Returns the sandbox's host-reachable IP, the +// connection tracker, and true on success. The caller must call +// tracker.Release() when the request completes. Returns zero values and +// false if the sandbox is not found, not running, or is draining for a pause. +func (m *Manager) AcquireProxyConn(sandboxID string) (net.IP, *ConnTracker, bool) { + m.mu.RLock() + sb, ok := m.boxes[sandboxID] + m.mu.RUnlock() + + if !ok || sb.Status != models.StatusRunning { + return nil, nil, false + } + if !sb.connTracker.Acquire() { + return nil, nil, false + } + return sb.HostIP, sb.connTracker, true +} + // Ping resets the inactivity timer for a running sandbox. func (m *Manager) Ping(sandboxID string) error { m.mu.Lock() From 8b5fa3438efcd226f046581a4c2a244c06d76926 Mon Sep 17 00:00:00 2001 From: pptx704 Date: Wed, 1 Apr 2026 15:47:28 +0600 Subject: [PATCH 28/28] Replace gopsutil port scanner with direct /proc/net/tcp reading MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The envd port scanner used gopsutil's net.Connections() which walks /proc/{pid}/fd to enumerate socket inodes. This corrupts Go runtime semaphore state when the VM is paused mid-operation and restored from a Firecracker snapshot. Replace with a direct /proc/net/tcp + /proc/net/tcp6 parser that reads a single file per address family — no /proc/{pid}/fd walk, no goroutines, no WaitGroups. Also replace concurrent-map (smap) in the scanner with a plain sync.RWMutex-protected map, since concurrent-map's Items() spawns goroutines with a WaitGroup internally, which is equally unsafe across snapshot boundaries. Use socket inode instead of PID for the port forwarding map key, since inode is available directly from /proc/net/tcp without the fd walk. --- envd/internal/port/conn.go | 165 +++++++++++++++++++++++++++ envd/internal/port/forward.go | 18 +-- envd/internal/port/scan.go | 47 +++++--- envd/internal/port/scanSubscriber.go | 19 ++- envd/internal/port/scanfilter.go | 8 +- 5 files changed, 216 insertions(+), 41 deletions(-) create mode 100644 envd/internal/port/conn.go diff --git a/envd/internal/port/conn.go b/envd/internal/port/conn.go new file mode 100644 index 00000000..8a8c032c --- /dev/null +++ b/envd/internal/port/conn.go @@ -0,0 +1,165 @@ +// SPDX-License-Identifier: Apache-2.0 + +package port + +import ( + "bufio" + "encoding/hex" + "fmt" + "net" + "os" + "strconv" + "strings" + "syscall" +) + +// ConnStat represents a single TCP connection read from /proc/net/tcp(6). +// It contains only the fields needed by the port scanner and forwarder. +type ConnStat struct { + LocalIP string + LocalPort uint32 + Status string + Family uint32 // syscall.AF_INET or syscall.AF_INET6 + Inode uint64 // socket inode, unique per connection +} + +// tcpStates maps the hex state values from /proc/net/tcp to string names +// matching the gopsutil convention used by ScannerFilter. +var tcpStates = map[string]string{ + "01": "ESTABLISHED", + "02": "SYN_SENT", + "03": "SYN_RECV", + "04": "FIN_WAIT1", + "05": "FIN_WAIT2", + "06": "TIME_WAIT", + "07": "CLOSE", + "08": "CLOSE_WAIT", + "09": "LAST_ACK", + "0A": "LISTEN", + "0B": "CLOSING", +} + +// ReadTCPConnections reads /proc/net/tcp and /proc/net/tcp6 and returns +// all TCP connections. This avoids the /proc/{pid}/fd walk that gopsutil +// performs, which is unsafe across Firecracker snapshot/restore boundaries. +func ReadTCPConnections() ([]ConnStat, error) { + var conns []ConnStat + + tcp4, err := parseProcNetTCP("/proc/net/tcp", syscall.AF_INET) + if err != nil { + return nil, fmt.Errorf("parse /proc/net/tcp: %w", err) + } + conns = append(conns, tcp4...) + + tcp6, err := parseProcNetTCP("/proc/net/tcp6", syscall.AF_INET6) + if err != nil { + return nil, fmt.Errorf("parse /proc/net/tcp6: %w", err) + } + conns = append(conns, tcp6...) + + return conns, nil +} + +// parseProcNetTCP reads a single /proc/net/tcp or /proc/net/tcp6 file. +// +// Format (fields are whitespace-separated): +// +// sl local_address rem_address st tx_queue:rx_queue tr:tm->when retrnsmt uid timeout inode +// 0: 0100007F:1F90 00000000:0000 0A 00000000:00000000 00:00000000 00000000 1000 0 12345 +func parseProcNetTCP(path string, family uint32) ([]ConnStat, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + + var conns []ConnStat + scanner := bufio.NewScanner(f) + + // Skip header line. + scanner.Scan() + + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + + fields := strings.Fields(line) + if len(fields) < 10 { + continue + } + + // fields[1] = local_address (hex_ip:hex_port) + ip, port, err := parseHexAddr(fields[1], family) + if err != nil { + continue + } + + // fields[3] = state (hex) + state, ok := tcpStates[fields[3]] + if !ok { + state = "UNKNOWN" + } + + // fields[9] = inode + inode, err := strconv.ParseUint(fields[9], 10, 64) + if err != nil { + continue + } + + conns = append(conns, ConnStat{ + LocalIP: ip, + LocalPort: port, + Status: state, + Family: family, + Inode: inode, + }) + } + + return conns, scanner.Err() +} + +// parseHexAddr parses "HEXIP:HEXPORT" from /proc/net/tcp. +// IPv4 addresses are 8 hex chars (4 bytes, little-endian per 32-bit word). +// IPv6 addresses are 32 hex chars (16 bytes, little-endian per 32-bit word). +func parseHexAddr(s string, family uint32) (string, uint32, error) { + parts := strings.SplitN(s, ":", 2) + if len(parts) != 2 { + return "", 0, fmt.Errorf("invalid address: %s", s) + } + + port64, err := strconv.ParseUint(parts[1], 16, 32) + if err != nil { + return "", 0, err + } + + ipHex := parts[0] + ipBytes, err := hex.DecodeString(ipHex) + if err != nil { + return "", 0, err + } + + var ip net.IP + if family == syscall.AF_INET { + if len(ipBytes) != 4 { + return "", 0, fmt.Errorf("invalid IPv4 length: %d", len(ipBytes)) + } + // /proc/net/tcp stores IPv4 as a single little-endian 32-bit word. + ip = net.IPv4(ipBytes[3], ipBytes[2], ipBytes[1], ipBytes[0]) + } else { + if len(ipBytes) != 16 { + return "", 0, fmt.Errorf("invalid IPv6 length: %d", len(ipBytes)) + } + // /proc/net/tcp6 stores IPv6 as four little-endian 32-bit words. + ip = make(net.IP, 16) + for i := 0; i < 4; i++ { + ip[i*4+0] = ipBytes[i*4+3] + ip[i*4+1] = ipBytes[i*4+2] + ip[i*4+2] = ipBytes[i*4+1] + ip[i*4+3] = ipBytes[i*4+0] + } + } + + return ip.String(), uint32(port64), nil +} diff --git a/envd/internal/port/forward.go b/envd/internal/port/forward.go index e8365196..bf516ff9 100644 --- a/envd/internal/port/forward.go +++ b/envd/internal/port/forward.go @@ -31,8 +31,8 @@ var defaultGatewayIP = net.IPv4(169, 254, 0, 21) type PortToForward struct { socat *exec.Cmd - // Process ID of the process that's listening on port. - pid int32 + // Socket inode of the listening socket (unique per connection). + inode uint64 // family version of the ip. family uint32 state PortState @@ -94,7 +94,7 @@ func (f *Forwarder) StartForwarding(ctx context.Context) { // Let's refresh our map of currently forwarded ports and mark the currently opened ones with the "FORWARD" state. // This will make sure we won't delete them later. for _, p := range procs { - key := fmt.Sprintf("%d-%d", p.Pid, p.Laddr.Port) + key := fmt.Sprintf("%d-%d", p.Inode, p.LocalPort) // We check if the opened port is in our map of forwarded ports. val, portOk := f.ports[key] @@ -104,16 +104,16 @@ func (f *Forwarder) StartForwarding(ctx context.Context) { val.state = PortStateForward } else { f.logger.Debug(). - Str("ip", p.Laddr.IP). - Uint32("port", p.Laddr.Port). + Str("ip", p.LocalIP). + Uint32("port", p.LocalPort). Uint32("family", familyToIPVersion(p.Family)). Str("state", p.Status). Msg("Detected new opened port on localhost that is not forwarded") // The opened port wasn't in the map so we create a new PortToForward and start forwarding. ptf := &PortToForward{ - pid: p.Pid, - port: p.Laddr.Port, + inode: p.Inode, + port: p.LocalPort, state: PortStateForward, family: familyToIPVersion(p.Family), } @@ -153,7 +153,7 @@ func (f *Forwarder) startPortForwarding(ctx context.Context, p *PortToForward) { f.logger.Debug(). Str("socatCmd", cmd.String()). - Int32("pid", p.pid). + Uint64("inode", p.inode). Uint32("family", p.family). IPAddr("sourceIP", f.sourceIP.To4()). Uint32("port", p.port). @@ -191,7 +191,7 @@ func (f *Forwarder) stopPortForwarding(p *PortToForward) { logger := f.logger.With(). Str("socatCmd", p.socat.String()). - Int32("pid", p.pid). + Uint64("inode", p.inode). Uint32("family", p.family). IPAddr("sourceIP", f.sourceIP.To4()). Uint32("port", p.port). diff --git a/envd/internal/port/scan.go b/envd/internal/port/scan.go index 766202a6..2b155233 100644 --- a/envd/internal/port/scan.go +++ b/envd/internal/port/scan.go @@ -3,19 +3,21 @@ package port import ( + "sync" "time" "github.com/rs/zerolog" - "github.com/shirou/gopsutil/v4/net" - - "git.omukk.dev/wrenn/sandbox/envd/internal/shared/smap" ) type Scanner struct { - Processes chan net.ConnectionStat - scanExit chan struct{} - subs *smap.Map[*ScannerSubscriber] - period time.Duration + scanExit chan struct{} + period time.Duration + + // Plain mutex-protected map instead of concurrent-map. The concurrent-map + // library's Items() spawns goroutines and uses a WaitGroup internally, + // which corrupts Go runtime semaphore state across Firecracker snapshot/restore. + mu sync.RWMutex + subs map[string]*ScannerSubscriber } func (s *Scanner) Destroy() { @@ -24,33 +26,44 @@ func (s *Scanner) Destroy() { func NewScanner(period time.Duration) *Scanner { return &Scanner{ - period: period, - subs: smap.New[*ScannerSubscriber](), - scanExit: make(chan struct{}), - Processes: make(chan net.ConnectionStat), + period: period, + subs: make(map[string]*ScannerSubscriber), + scanExit: make(chan struct{}), } } func (s *Scanner) AddSubscriber(logger *zerolog.Logger, id string, filter *ScannerFilter) *ScannerSubscriber { subscriber := NewScannerSubscriber(logger, id, filter) - s.subs.Insert(id, subscriber) + + s.mu.Lock() + s.subs[id] = subscriber + s.mu.Unlock() return subscriber } func (s *Scanner) Unsubscribe(sub *ScannerSubscriber) { - s.subs.Remove(sub.ID()) + s.mu.Lock() + delete(s.subs, sub.ID()) + s.mu.Unlock() + sub.Destroy() } // ScanAndBroadcast starts scanning open TCP ports and broadcasts every open port to all subscribers. func (s *Scanner) ScanAndBroadcast() { for { - // tcp monitors both ipv4 and ipv6 connections. - processes, _ := net.Connections("tcp") - for _, sub := range s.subs.Items() { - sub.Signal(processes) + // Read directly from /proc/net/tcp and /proc/net/tcp6 instead of + // using gopsutil's net.Connections(), which walks /proc/{pid}/fd + // and causes Go runtime corruption after Firecracker snapshot/restore. + conns, _ := ReadTCPConnections() + + s.mu.RLock() + for _, sub := range s.subs { + sub.Signal(conns) } + s.mu.RUnlock() + select { case <-s.scanExit: return diff --git a/envd/internal/port/scanSubscriber.go b/envd/internal/port/scanSubscriber.go index 6a4f5b01..bad99083 100644 --- a/envd/internal/port/scanSubscriber.go +++ b/envd/internal/port/scanSubscriber.go @@ -4,7 +4,6 @@ package port import ( "github.com/rs/zerolog" - "github.com/shirou/gopsutil/v4/net" ) // If we want to create a listener/subscriber pattern somewhere else we should move @@ -13,7 +12,7 @@ import ( type ScannerSubscriber struct { logger *zerolog.Logger filter *ScannerFilter - Messages chan ([]net.ConnectionStat) + Messages chan ([]ConnStat) id string } @@ -22,7 +21,7 @@ func NewScannerSubscriber(logger *zerolog.Logger, id string, filter *ScannerFilt logger: logger, id: id, filter: filter, - Messages: make(chan []net.ConnectionStat), + Messages: make(chan []ConnStat), } } @@ -34,17 +33,17 @@ func (ss *ScannerSubscriber) Destroy() { close(ss.Messages) } -func (ss *ScannerSubscriber) Signal(proc []net.ConnectionStat) { +func (ss *ScannerSubscriber) Signal(conns []ConnStat) { // Filter isn't specified. Accept everything. if ss.filter == nil { - ss.Messages <- proc + ss.Messages <- conns } else { - filtered := []net.ConnectionStat{} - for i := range proc { + filtered := []ConnStat{} + for i := range conns { // We need to access the list directly otherwise there will be implicit memory aliasing - // If the filter matched a process, we will send it to a channel. - if ss.filter.Match(&proc[i]) { - filtered = append(filtered, proc[i]) + // If the filter matched a connection, we will send it to a channel. + if ss.filter.Match(&conns[i]) { + filtered = append(filtered, conns[i]) } } ss.Messages <- filtered diff --git a/envd/internal/port/scanfilter.go b/envd/internal/port/scanfilter.go index 941023d9..f87667f2 100644 --- a/envd/internal/port/scanfilter.go +++ b/envd/internal/port/scanfilter.go @@ -4,8 +4,6 @@ package port import ( "slices" - - "github.com/shirou/gopsutil/v4/net" ) type ScannerFilter struct { @@ -13,15 +11,15 @@ type ScannerFilter struct { IPs []string } -func (sf *ScannerFilter) Match(proc *net.ConnectionStat) bool { +func (sf *ScannerFilter) Match(conn *ConnStat) bool { // Filter is an empty struct. if sf.State == "" && len(sf.IPs) == 0 { return false } - ipMatch := slices.Contains(sf.IPs, proc.Laddr.IP) + ipMatch := slices.Contains(sf.IPs, conn.LocalIP) - if ipMatch && sf.State == proc.Status { + if ipMatch && sf.State == conn.Status { return true }