1
0
forked from wrenn/wrenn

Merge pull request 'Feature: HTTP communication with sandbox' (#10) from code-interpreter into dev

Reviewed-on: wrenn/sandbox#10
This commit is contained in:
2026-04-02 17:41:07 +00:00
129 changed files with 7527 additions and 1824 deletions

View File

@ -5,13 +5,13 @@ DATABASE_URL=postgres://wrenn:wrenn@localhost:5432/wrenn?sslmode=disable
REDIS_URL=redis://localhost:6379/0
# Control Plane
CP_LISTEN_ADDR=:8000
WRENN_CP_LISTEN_ADDR=:8080
# Host Agent
AGENT_LISTEN_ADDR=:50051
AGENT_FILES_ROOTDIR=/var/lib/wrenn
AGENT_HOST_INTERFACE=eth0
AGENT_CP_URL=http://localhost:8000
WRENN_HOST_LISTEN_ADDR=:50051
WRENN_DIR=/var/lib/wrenn
WRENN_HOST_INTERFACE=eth0
WRENN_CP_URL=http://localhost:8080
# Lago (billing — external service)
LAGO_API_URL=http://localhost:3000
@ -27,6 +27,14 @@ AWS_SECRET_ACCESS_KEY=
# Auth
JWT_SECRET=
# mTLS — CP→Agent channel
# Generate a self-signed CA with:
# openssl ecparam -genkey -name P-256 -noout -out ca.key
# openssl req -new -x509 -key ca.key -days 3650 -out ca.crt -subj "/CN=wrenn-internal-ca"
# Then set these to the file contents (newlines replaced with \n or use multiline env).
WRENN_CA_CERT=-----BEGIN CERTIFICATE-----\nMIIBjTCCATOgAwIBAgIUJ61AjKri7lTAEIpmCXA+B/Gm0pwwCgYIKoZIzj0EAwIw\nHDEaMBgGA1UEAwwRd3Jlbm4taW50ZXJuYWwtY2EwHhcNMjYwMzMwMTIwNDI5WhcN\nMzYwMzI3MTIwNDI5WjAcMRowGAYDVQQDDBF3cmVubi1pbnRlcm5hbC1jYTBZMBMG\nByqGSM49AgEGCCqGSM49AwEHA0IABDkwv8a1Y7Xx7a5yUDLwDUUBn1fSfUlq6sGr\nVociS2Za+vo1353K61IFMNF9A3wvLXpsEAGZKbaw1iEfRs6LERijUzBRMB0GA1Ud\nDgQWBBQkuWu9flN+C/e4wPFtbWEDVWNjFjAfBgNVHSMEGDAWgBQkuWu9flN+C/e4\nwPFtbWEDVWNjFjAPBgNVHRMBAf8EBTADAQH/MAoGCCqGSM49BAMCA0gAMEUCIBL0\nHmdBQy/76eLKM/X/Qtsrt2yktfxIrWQBbrXOlBd2AiEAzx8n5O0r/ebxwmAxL3y7\nVM7hllXxL6AdxJtU2vsEoA0=\n-----END CERTIFICATE-----
WRENN_CA_KEY=-----BEGIN EC PRIVATE KEY-----\nMHcCAQEEIOjpTSFMhhR9Yi2mWtrzJ/FINEImtmz32GkwZ9eYUbDkoAoGCCqGSM49\nAwEHoUQDQgAEOTC/xrVjtfHtrnJQMvANRQGfV9J9SWrqwatWhyJLZlr6+jXfncrr\nUgUw0X0DfC8temwQAZkptrDWIR9GzosRGA==\n-----END EC PRIVATE KEY-----
# OAuth
OAUTH_GITHUB_CLIENT_ID=
OAUTH_GITHUB_CLIENT_SECRET=

View File

@ -39,7 +39,7 @@ dev: dev-infra migrate-up dev-cp
dev-infra:
docker compose -f deploy/docker-compose.dev.yml up -d
@echo "Waiting for PostgreSQL..."
@until pg_isready -h localhost -p 5432 -q; do sleep 0.5; done
@until docker compose -f deploy/docker-compose.dev.yml exec -T postgres pg_isready -q 2>/dev/null; do sleep 0.5; done
@echo "Dev infrastructure ready."
dev-down:
@ -53,7 +53,7 @@ dev-agent:
sudo go run ./cmd/host-agent
dev-frontend:
cd frontend && pnpm dev --port 5173
cd frontend && pnpm dev --port 5173 --host 0.0.0.0
dev-envd:
cd $(ENVD_DIR) && go run . --debug --listen-tcp :3002

View File

@ -51,12 +51,12 @@ Copy `.env.example` to `.env` and edit:
DATABASE_URL=postgres://wrenn:wrenn@localhost:5432/wrenn?sslmode=disable
# Control plane
CP_LISTEN_ADDR=:8000
WRENN_CP_LISTEN_ADDR=:8000
CP_HOST_AGENT_ADDR=http://localhost:50051
# Host agent
AGENT_LISTEN_ADDR=:50051
AGENT_FILES_ROOTDIR=/var/lib/wrenn
WRENN_HOST_LISTEN_ADDR=:50051
WRENN_DIR=/var/lib/wrenn
```
### Run
@ -69,7 +69,7 @@ make migrate-up
./builds/wrenn-cp
```
Control plane listens on `CP_LISTEN_ADDR` (default `:8000`).
Control plane listens on `WRENN_CP_LISTEN_ADDR` (default `:8000`).
### Host registration
@ -87,16 +87,16 @@ Hosts must be registered with the control plane before they can serve sandboxes.
2. **Start the host agent** with the registration token and its externally-reachable address:
```bash
sudo AGENT_CP_URL=http://cp-host:8000 \
sudo WRENN_CP_URL=http://cp-host:8000 \
./builds/wrenn-agent \
--register <token-from-step-1> \
--address 10.0.1.5:50051
```
On first startup the agent sends its specs (arch, CPU, memory, disk) to the control plane, receives a long-lived host JWT, and saves it to `$AGENT_FILES_ROOTDIR/host-token`.
On first startup the agent sends its specs (arch, CPU, memory, disk) to the control plane, receives a long-lived host JWT, and saves it to `$WRENN_DIR/host-token`.
3. **Subsequent startups** don't need `--register` — the agent loads the saved JWT automatically:
```bash
sudo AGENT_CP_URL=http://cp-host:8000 \
sudo WRENN_CP_URL=http://cp-host:8000 \
./builds/wrenn-agent --address 10.0.1.5:50051
```
@ -107,7 +107,7 @@ Hosts must be registered with the control plane before they can serve sandboxes.
```
Then restart the agent with the new token.
The agent sends heartbeats to the control plane every 30 seconds. Host agent listens on `AGENT_LISTEN_ADDR` (default `:50051`).
The agent sends heartbeats to the control plane every 30 seconds. Host agent listens on `WRENN_HOST_LISTEN_ADDR` (default `:50051`).
### Rootfs images

View File

@ -15,6 +15,7 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/api"
"git.omukk.dev/wrenn/sandbox/internal/audit"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/auth/oauth"
"git.omukk.dev/wrenn/sandbox/internal/config"
"git.omukk.dev/wrenn/sandbox/internal/db"
@ -68,8 +69,52 @@ func main() {
}
slog.Info("connected to redis")
// mTLS: parse internal CA and build a TLS-capable host client pool.
// When CA env vars are absent the pool falls back to plain HTTP (dev mode).
var ca *auth.CA
if cfg.CACert != "" && cfg.CAKey != "" {
var err error
ca, err = auth.ParseCA(cfg.CACert, cfg.CAKey)
if err != nil {
slog.Error("failed to parse mTLS CA from environment", "error", err)
os.Exit(1)
}
slog.Info("mTLS enabled: CA loaded")
} else {
slog.Warn("mTLS disabled: WRENN_CA_CERT/WRENN_CA_KEY not set — host agent connections are unencrypted")
}
// Host client pool — manages Connect RPC clients to host agents.
hostPool := lifecycle.NewHostClientPool()
var hostPool *lifecycle.HostClientPool
if ca != nil {
cpCertStore, err := auth.NewCPCertStore(ca)
if err != nil {
slog.Error("failed to issue CP client certificate", "error", err)
os.Exit(1)
}
// Renew the CP client certificate periodically so it never expires
// while the control plane is running (TTL = 24h, renewal = every 12h).
go func() {
ticker := time.NewTicker(auth.CPCertRenewInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := cpCertStore.Refresh(); err != nil {
slog.Error("failed to renew CP client certificate", "error", err)
} else {
slog.Info("CP client certificate renewed")
}
}
}
}()
hostPool = lifecycle.NewHostClientPoolTLS(auth.CPClientTLSConfig(ca, cpCertStore))
slog.Info("host client pool: mTLS enabled")
} else {
hostPool = lifecycle.NewHostClientPool()
}
// Scheduler — picks a host for each new sandbox (round-robin for now).
hostScheduler := scheduler.NewRoundRobinScheduler(queries)
@ -88,7 +133,11 @@ func main() {
}
// API server.
srv := api.New(queries, hostPool, hostScheduler, pool, rdb, []byte(cfg.JWTSecret), oauthRegistry, cfg.OAuthRedirectURL)
srv := api.New(queries, hostPool, hostScheduler, pool, rdb, []byte(cfg.JWTSecret), oauthRegistry, cfg.OAuthRedirectURL, ca)
// Start template build workers (2 concurrent).
stopBuildWorkers := srv.BuildSvc.StartWorkers(ctx, 2)
defer stopBuildWorkers()
// Start host monitor (passive + active reconciliation every 30s).
monitor := api.NewHostMonitor(queries, hostPool, audit.New(queries), 30*time.Second)
@ -98,9 +147,14 @@ func main() {
sampler := api.NewMetricsSampler(queries, 10*time.Second)
sampler.Start(ctx)
// Wrap the API handler with the sandbox proxy so that requests with
// {port}-{sandbox_id}.{domain} Host headers are routed to the sandbox's
// host agent. All other requests pass through to the normal API router.
proxyWrapper := api.NewSandboxProxyWrapper(srv.Handler(), queries, hostPool)
httpServer := &http.Server{
Addr: cfg.ListenAddr,
Handler: srv.Handler(),
Handler: proxyWrapper,
}
// Graceful shutdown on signal.

View File

@ -2,8 +2,10 @@ package main
import (
"context"
"crypto/tls"
"flag"
"log/slog"
"net"
"net/http"
"os"
"os/signal"
@ -14,8 +16,10 @@ import (
"github.com/joho/godotenv"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/devicemapper"
"git.omukk.dev/wrenn/sandbox/internal/hostagent"
"git.omukk.dev/wrenn/sandbox/internal/network"
"git.omukk.dev/wrenn/sandbox/internal/sandbox"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
)
@ -42,16 +46,17 @@ func main() {
slog.Warn("failed to enable ip_forward", "error", err)
}
// Clean up any stale dm-snapshot devices from a previous crash.
// Clean up stale resources from a previous crash.
devicemapper.CleanupStaleDevices()
network.CleanupStaleNamespaces()
listenAddr := envOrDefault("AGENT_LISTEN_ADDR", ":50051")
rootDir := envOrDefault("AGENT_FILES_ROOTDIR", "/var/lib/wrenn")
cpURL := os.Getenv("AGENT_CP_URL")
tokenFile := filepath.Join(rootDir, "host.jwt")
listenAddr := envOrDefault("WRENN_HOST_LISTEN_ADDR", ":50051")
rootDir := envOrDefault("WRENN_DIR", "/var/lib/wrenn")
cpURL := os.Getenv("WRENN_CP_URL")
credsFile := filepath.Join(rootDir, "host-credentials.json")
if cpURL == "" {
slog.Error("AGENT_CP_URL environment variable is required")
slog.Error("WRENN_CP_URL environment variable is required")
os.Exit(1)
}
if *advertiseAddr == "" {
@ -59,11 +64,15 @@ func main() {
os.Exit(1)
}
// Expand base images to the standard disk size (sparse, no extra physical
// disk). This ensures dm-snapshot sandboxes see the full size from boot.
if err := sandbox.EnsureImageSizes(rootDir, sandbox.DefaultDiskSizeMB); err != nil {
slog.Error("failed to expand base images", "error", err)
os.Exit(1)
}
cfg := sandbox.Config{
KernelPath: filepath.Join(rootDir, "kernels", "vmlinux"),
ImagesDir: filepath.Join(rootDir, "images"),
SandboxesDir: filepath.Join(rootDir, "sandboxes"),
SnapshotsDir: filepath.Join(rootDir, "snapshots"),
WrennDir: rootDir,
}
mgr := sandbox.New(cfg)
@ -74,10 +83,10 @@ func main() {
mgr.StartTTLReaper(ctx)
// Register with the control plane and start heartbeating.
hostToken, err := hostagent.Register(ctx, hostagent.RegistrationConfig{
creds, err := hostagent.Register(ctx, hostagent.RegistrationConfig{
CPURL: cpURL,
RegistrationToken: *registrationToken,
TokenFile: tokenFile,
TokenFile: credsFile,
Address: *advertiseAddr,
})
if err != nil {
@ -85,17 +94,29 @@ func main() {
os.Exit(1)
}
hostID, err := hostagent.HostIDFromToken(hostToken)
if err != nil {
slog.Error("failed to extract host ID from token", "error", err)
os.Exit(1)
}
slog.Info("host registered", "host_id", hostID)
slog.Info("host registered", "host_id", creds.HostID)
// httpServer is declared here so the shutdown func can reference it.
httpServer := &http.Server{Addr: listenAddr}
// Set up mTLS if the CP issued a certificate during registration.
var certStore hostagent.CertStore
if creds.CertPEM != "" && creds.KeyPEM != "" && creds.CACertPEM != "" {
if err := certStore.ParseAndStore(creds.CertPEM, creds.KeyPEM); err != nil {
slog.Error("failed to load host TLS certificate", "error", err)
os.Exit(1)
}
tlsCfg := auth.AgentTLSConfigFromPEM(creds.CACertPEM, certStore.GetCert)
if tlsCfg == nil {
slog.Error("failed to build agent TLS config: invalid CA certificate PEM")
os.Exit(1)
}
httpServer.TLSConfig = tlsCfg
slog.Info("mTLS enabled on agent server")
} else {
slog.Warn("mTLS disabled: no certificate received from CP — agent serving plain HTTP")
}
// doShutdown is the single shutdown path. sync.Once ensures mgr.Shutdown
// and httpServer.Shutdown are each called exactly once regardless of
// whether shutdown is triggered by a signal, a heartbeat 404, or the
@ -119,13 +140,16 @@ func main() {
})
path, handler := hostagentv1connect.NewHostAgentServiceHandler(srv)
proxyHandler := hostagent.NewProxyHandler(mgr)
mux := http.NewServeMux()
mux.Handle(path, handler)
mux.Handle("/proxy/", proxyHandler)
httpServer.Handler = mux
// Start heartbeat loop. Handler must be set before this because the
// immediate beat can trigger doShutdown → httpServer.Shutdown synchronously.
hostagent.StartHeartbeat(ctx, cpURL, tokenFile, hostID, 30*time.Second,
hostagent.StartHeartbeat(ctx, cpURL, credsFile, creds.HostID, 30*time.Second,
// pauseAll: called on 3 consecutive network failures.
func() {
pauseCtx, pauseCancel := context.WithTimeout(context.Background(), 2*time.Minute)
@ -136,6 +160,17 @@ func main() {
func() {
doShutdown("host deleted from CP")
},
// onCredsRefreshed: hot-swap the TLS certificate after a JWT refresh.
func(tf *hostagent.TokenFile) {
if tf.CertPEM == "" || tf.KeyPEM == "" {
return
}
if err := certStore.ParseAndStore(tf.CertPEM, tf.KeyPEM); err != nil {
slog.Error("failed to hot-swap TLS cert after credentials refresh", "error", err)
} else {
slog.Info("TLS cert hot-swapped after credentials refresh")
}
},
)
// Graceful shutdown on SIGINT/SIGTERM.
@ -146,10 +181,30 @@ func main() {
doShutdown("signal: " + sig.String())
}()
slog.Info("host agent starting", "addr", listenAddr, "host_id", hostID)
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
slog.Error("http server error", "error", err)
os.Exit(1)
slog.Info("host agent starting", "addr", listenAddr, "host_id", creds.HostID)
if httpServer.TLSConfig != nil {
// When TLSConfig is pre-populated (cert via GetCertificate callback),
// ListenAndServeTLS does not work because it requires on-disk cert/key paths.
// Instead, create the TLS listener manually and call Serve.
ln, err := tls.Listen("tcp", listenAddr, httpServer.TLSConfig)
if err != nil {
slog.Error("failed to start TLS listener", "error", err)
os.Exit(1)
}
if err := httpServer.Serve(ln); err != nil && err != http.ErrServerClosed {
slog.Error("https server error", "error", err)
os.Exit(1)
}
} else {
ln, err := net.Listen("tcp", listenAddr)
if err != nil {
slog.Error("failed to start listener", "error", err)
os.Exit(1)
}
if err := httpServer.Serve(ln); err != nil && err != http.ErrServerClosed {
slog.Error("http server error", "error", err)
os.Exit(1)
}
}
slog.Info("host agent stopped")

View File

@ -1,25 +1,237 @@
-- +goose Up
CREATE TABLE sandboxes (
id TEXT PRIMARY KEY,
owner_id TEXT NOT NULL DEFAULT '',
host_id TEXT NOT NULL DEFAULT 'default',
template TEXT NOT NULL DEFAULT 'minimal',
status TEXT NOT NULL DEFAULT 'pending',
vcpus INTEGER NOT NULL DEFAULT 1,
memory_mb INTEGER NOT NULL DEFAULT 512,
timeout_sec INTEGER NOT NULL DEFAULT 0,
guest_ip TEXT NOT NULL DEFAULT '',
host_ip TEXT NOT NULL DEFAULT '',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
started_at TIMESTAMPTZ,
last_active_at TIMESTAMPTZ,
last_updated TIMESTAMPTZ NOT NULL DEFAULT NOW()
-- teams
CREATE TABLE teams (
id UUID PRIMARY KEY,
name TEXT NOT NULL,
slug TEXT NOT NULL UNIQUE,
is_byoc BOOLEAN NOT NULL DEFAULT FALSE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
deleted_at TIMESTAMPTZ
);
CREATE INDEX idx_teams_slug ON teams(slug);
-- users
CREATE TABLE users (
id UUID PRIMARY KEY,
email TEXT NOT NULL UNIQUE,
password_hash TEXT,
name TEXT NOT NULL DEFAULT '',
is_admin BOOLEAN NOT NULL DEFAULT FALSE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
-- users_teams (junction)
CREATE TABLE users_teams (
user_id UUID NOT NULL REFERENCES users(id),
team_id UUID NOT NULL REFERENCES teams(id),
is_default BOOLEAN NOT NULL DEFAULT FALSE,
role TEXT NOT NULL DEFAULT 'member',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
PRIMARY KEY (team_id, user_id)
);
CREATE INDEX idx_users_teams_user ON users_teams(user_id);
-- team_api_keys
CREATE TABLE team_api_keys (
id UUID PRIMARY KEY,
team_id UUID NOT NULL REFERENCES teams(id),
name TEXT NOT NULL,
key_hash TEXT NOT NULL UNIQUE,
key_prefix TEXT NOT NULL,
created_by UUID NOT NULL REFERENCES users(id),
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
last_used TIMESTAMPTZ
);
CREATE INDEX idx_team_api_keys_team ON team_api_keys(team_id);
-- oauth_providers
CREATE TABLE oauth_providers (
provider TEXT NOT NULL,
provider_id TEXT NOT NULL,
user_id UUID NOT NULL REFERENCES users(id),
email TEXT NOT NULL DEFAULT '',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
PRIMARY KEY (provider, provider_id)
);
CREATE INDEX idx_oauth_providers_user ON oauth_providers(user_id);
-- admin_permissions
CREATE TABLE admin_permissions (
id UUID PRIMARY KEY,
user_id UUID NOT NULL REFERENCES users(id),
permission TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
UNIQUE (user_id, permission)
);
CREATE INDEX idx_admin_permissions_user ON admin_permissions(user_id);
-- hosts
CREATE TABLE hosts (
id UUID PRIMARY KEY,
type TEXT NOT NULL DEFAULT 'regular',
team_id UUID REFERENCES teams(id),
provider TEXT NOT NULL DEFAULT '',
availability_zone TEXT NOT NULL DEFAULT '',
arch TEXT NOT NULL DEFAULT '',
cpu_cores INTEGER NOT NULL DEFAULT 0,
memory_mb INTEGER NOT NULL DEFAULT 0,
disk_gb INTEGER NOT NULL DEFAULT 0,
address TEXT NOT NULL DEFAULT '',
status TEXT NOT NULL DEFAULT 'pending',
last_heartbeat_at TIMESTAMPTZ,
metadata JSONB NOT NULL DEFAULT '{}',
created_by UUID NOT NULL REFERENCES users(id),
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
cert_fingerprint TEXT NOT NULL DEFAULT '',
mtls_enabled BOOLEAN NOT NULL DEFAULT FALSE
);
CREATE INDEX idx_hosts_type ON hosts(type);
CREATE INDEX idx_hosts_team ON hosts(team_id);
CREATE INDEX idx_hosts_status ON hosts(status);
-- host_tokens
CREATE TABLE host_tokens (
id UUID PRIMARY KEY,
host_id UUID NOT NULL REFERENCES hosts(id) ON DELETE CASCADE,
created_by UUID NOT NULL REFERENCES users(id),
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
expires_at TIMESTAMPTZ NOT NULL,
used_at TIMESTAMPTZ
);
CREATE INDEX idx_host_tokens_host ON host_tokens(host_id);
-- host_tags
CREATE TABLE host_tags (
host_id UUID NOT NULL REFERENCES hosts(id) ON DELETE CASCADE,
tag TEXT NOT NULL,
PRIMARY KEY (host_id, tag)
);
CREATE INDEX idx_host_tags_tag ON host_tags(tag);
-- host_refresh_tokens
CREATE TABLE host_refresh_tokens (
id UUID PRIMARY KEY,
host_id UUID NOT NULL REFERENCES hosts(id) ON DELETE CASCADE,
token_hash TEXT NOT NULL UNIQUE,
expires_at TIMESTAMPTZ NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
revoked_at TIMESTAMPTZ
);
CREATE INDEX idx_host_refresh_tokens_host ON host_refresh_tokens(host_id);
-- templates (TEXT primary key — not UUID)
CREATE TABLE templates (
name TEXT PRIMARY KEY,
type TEXT NOT NULL DEFAULT 'base',
vcpus INTEGER NOT NULL DEFAULT 1,
memory_mb INTEGER NOT NULL DEFAULT 512,
size_bytes BIGINT NOT NULL DEFAULT 0,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
team_id UUID NOT NULL
);
CREATE INDEX idx_templates_team ON templates(team_id);
-- sandboxes
CREATE TABLE sandboxes (
id UUID PRIMARY KEY,
team_id UUID NOT NULL REFERENCES teams(id),
host_id UUID NOT NULL,
template TEXT NOT NULL DEFAULT 'minimal',
status TEXT NOT NULL DEFAULT 'pending',
vcpus INTEGER NOT NULL DEFAULT 1,
memory_mb INTEGER NOT NULL DEFAULT 512,
timeout_sec INTEGER NOT NULL DEFAULT 300,
disk_size_mb INTEGER NOT NULL DEFAULT 5120,
guest_ip TEXT NOT NULL DEFAULT '',
host_ip TEXT NOT NULL DEFAULT '',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
started_at TIMESTAMPTZ,
last_active_at TIMESTAMPTZ,
last_updated TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX idx_sandboxes_status ON sandboxes(status);
CREATE INDEX idx_sandboxes_host_status ON sandboxes(host_id, status);
CREATE INDEX idx_sandboxes_team ON sandboxes(team_id);
-- audit_logs (id and team_id are UUID; actor_id and resource_id are TEXT for polymorphism)
CREATE TABLE audit_logs (
id UUID PRIMARY KEY,
team_id UUID NOT NULL,
actor_type TEXT NOT NULL,
actor_id TEXT,
actor_name TEXT NOT NULL DEFAULT '',
resource_type TEXT NOT NULL,
resource_id TEXT,
action TEXT NOT NULL,
scope TEXT NOT NULL DEFAULT 'team',
status TEXT NOT NULL DEFAULT 'success',
metadata JSONB NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX idx_audit_logs_team_time ON audit_logs(team_id, created_at DESC);
CREATE INDEX idx_audit_logs_team_resource ON audit_logs(team_id, resource_type, created_at DESC);
-- sandbox_metrics_snapshots
CREATE TABLE sandbox_metrics_snapshots (
id BIGSERIAL PRIMARY KEY,
team_id UUID NOT NULL,
sampled_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
running_count INTEGER NOT NULL DEFAULT 0,
vcpus_reserved INTEGER NOT NULL DEFAULT 0,
memory_mb_reserved INTEGER NOT NULL DEFAULT 0
);
CREATE INDEX idx_metrics_snapshots_team_time ON sandbox_metrics_snapshots(team_id, sampled_at DESC);
-- sandbox_metric_points
CREATE TABLE sandbox_metric_points (
sandbox_id UUID NOT NULL,
tier TEXT NOT NULL CHECK (tier IN ('10m', '2h', '24h')),
ts BIGINT NOT NULL,
cpu_pct FLOAT8 NOT NULL DEFAULT 0,
mem_bytes BIGINT NOT NULL DEFAULT 0,
disk_bytes BIGINT NOT NULL DEFAULT 0,
PRIMARY KEY (sandbox_id, tier, ts)
);
CREATE INDEX idx_sandbox_metric_points_sandbox_tier ON sandbox_metric_points(sandbox_id, tier);
-- template_builds
CREATE TABLE template_builds (
id UUID PRIMARY KEY,
name TEXT NOT NULL,
base_template TEXT NOT NULL,
recipe JSONB NOT NULL DEFAULT '[]',
healthcheck TEXT NOT NULL DEFAULT '',
vcpus INTEGER NOT NULL DEFAULT 1,
memory_mb INTEGER NOT NULL DEFAULT 512,
status TEXT NOT NULL DEFAULT 'pending',
current_step INTEGER NOT NULL DEFAULT 0,
total_steps INTEGER NOT NULL DEFAULT 0,
logs JSONB NOT NULL DEFAULT '[]',
error TEXT NOT NULL DEFAULT '',
sandbox_id UUID,
host_id UUID,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
started_at TIMESTAMPTZ,
completed_at TIMESTAMPTZ
);
-- +goose Down
DROP TABLE sandboxes;
DROP TABLE IF EXISTS template_builds;
DROP TABLE IF EXISTS sandbox_metric_points;
DROP TABLE IF EXISTS sandbox_metrics_snapshots;
DROP TABLE IF EXISTS audit_logs;
DROP TABLE IF EXISTS sandboxes;
DROP TABLE IF EXISTS templates;
DROP TABLE IF EXISTS host_refresh_tokens;
DROP TABLE IF EXISTS host_tags;
DROP TABLE IF EXISTS host_tokens;
DROP TABLE IF EXISTS hosts;
DROP TABLE IF EXISTS admin_permissions;
DROP TABLE IF EXISTS oauth_providers;
DROP TABLE IF EXISTS team_api_keys;
DROP TABLE IF EXISTS users_teams;
DROP TABLE IF EXISTS users;
DROP TABLE IF EXISTS teams;

View File

@ -1,14 +0,0 @@
-- +goose Up
CREATE TABLE templates (
name TEXT PRIMARY KEY,
type TEXT NOT NULL DEFAULT 'base', -- 'base' or 'snapshot'
vcpus INTEGER,
memory_mb INTEGER,
size_bytes BIGINT NOT NULL DEFAULT 0,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
-- +goose Down
DROP TABLE templates;

View File

@ -1,46 +0,0 @@
-- +goose Up
CREATE TABLE users (
id TEXT PRIMARY KEY,
email TEXT NOT NULL UNIQUE,
password_hash TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TABLE teams (
id TEXT PRIMARY KEY,
name TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TABLE users_teams (
user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
team_id TEXT NOT NULL REFERENCES teams(id) ON DELETE CASCADE,
is_default BOOLEAN NOT NULL DEFAULT TRUE,
role TEXT NOT NULL DEFAULT 'owner',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
PRIMARY KEY (team_id, user_id)
);
CREATE INDEX idx_users_teams_user ON users_teams(user_id);
CREATE TABLE team_api_keys (
id TEXT PRIMARY KEY,
team_id TEXT NOT NULL REFERENCES teams(id) ON DELETE CASCADE,
name TEXT NOT NULL DEFAULT '',
key_hash TEXT NOT NULL UNIQUE,
key_prefix TEXT NOT NULL DEFAULT '',
created_by TEXT NOT NULL REFERENCES users(id),
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
last_used TIMESTAMPTZ
);
CREATE INDEX idx_team_api_keys_team ON team_api_keys(team_id);
-- +goose Down
DROP TABLE team_api_keys;
DROP TABLE users_teams;
DROP TABLE teams;
DROP TABLE users;

View File

@ -1,31 +0,0 @@
-- +goose Up
ALTER TABLE sandboxes
ADD COLUMN team_id TEXT NOT NULL DEFAULT '';
UPDATE sandboxes SET team_id = owner_id WHERE owner_id != '';
ALTER TABLE sandboxes
DROP COLUMN owner_id;
ALTER TABLE templates
ADD COLUMN team_id TEXT NOT NULL DEFAULT '';
CREATE INDEX idx_sandboxes_team ON sandboxes(team_id);
CREATE INDEX idx_templates_team ON templates(team_id);
-- +goose Down
ALTER TABLE sandboxes
ADD COLUMN owner_id TEXT NOT NULL DEFAULT '';
UPDATE sandboxes SET owner_id = team_id WHERE team_id != '';
ALTER TABLE sandboxes
DROP COLUMN team_id;
ALTER TABLE templates
DROP COLUMN team_id;
DROP INDEX IF EXISTS idx_sandboxes_team;
DROP INDEX IF EXISTS idx_templates_team;

View File

@ -1,22 +0,0 @@
-- +goose Up
ALTER TABLE users
ALTER COLUMN password_hash DROP NOT NULL;
CREATE TABLE oauth_providers (
provider TEXT NOT NULL,
provider_id TEXT NOT NULL,
user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
email TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
PRIMARY KEY (provider, provider_id)
);
CREATE INDEX idx_oauth_providers_user ON oauth_providers(user_id);
-- +goose Down
DROP TABLE oauth_providers;
UPDATE users SET password_hash = '' WHERE password_hash IS NULL;
ALTER TABLE users ALTER COLUMN password_hash SET NOT NULL;

View File

@ -1,21 +0,0 @@
-- +goose Up
ALTER TABLE users
ADD COLUMN is_admin BOOLEAN NOT NULL DEFAULT FALSE;
CREATE TABLE admin_permissions (
id TEXT PRIMARY KEY,
user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
permission TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
UNIQUE (user_id, permission)
);
CREATE INDEX idx_admin_permissions_user ON admin_permissions(user_id);
-- +goose Down
DROP TABLE admin_permissions;
ALTER TABLE users
DROP COLUMN is_admin;

View File

@ -1,9 +0,0 @@
-- +goose Up
ALTER TABLE teams
ADD COLUMN is_byoc BOOLEAN NOT NULL DEFAULT FALSE;
-- +goose Down
ALTER TABLE teams
DROP COLUMN is_byoc;

View File

@ -1,47 +0,0 @@
-- +goose Up
CREATE TABLE hosts (
id TEXT PRIMARY KEY,
type TEXT NOT NULL DEFAULT 'regular', -- 'regular' or 'byoc'
team_id TEXT REFERENCES teams(id) ON DELETE SET NULL,
provider TEXT,
availability_zone TEXT,
arch TEXT,
cpu_cores INTEGER,
memory_mb INTEGER,
disk_gb INTEGER,
address TEXT, -- ip:port of host agent
status TEXT NOT NULL DEFAULT 'pending', -- 'pending', 'online', 'offline', 'draining'
last_heartbeat_at TIMESTAMPTZ,
metadata JSONB NOT NULL DEFAULT '{}',
created_by TEXT NOT NULL REFERENCES users(id),
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TABLE host_tokens (
id TEXT PRIMARY KEY,
host_id TEXT NOT NULL REFERENCES hosts(id) ON DELETE CASCADE,
created_by TEXT NOT NULL REFERENCES users(id),
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
expires_at TIMESTAMPTZ NOT NULL,
used_at TIMESTAMPTZ
);
CREATE TABLE host_tags (
host_id TEXT NOT NULL REFERENCES hosts(id) ON DELETE CASCADE,
tag TEXT NOT NULL,
PRIMARY KEY (host_id, tag)
);
CREATE INDEX idx_hosts_type ON hosts(type);
CREATE INDEX idx_hosts_team ON hosts(team_id);
CREATE INDEX idx_hosts_status ON hosts(status);
CREATE INDEX idx_host_tokens_host ON host_tokens(host_id);
CREATE INDEX idx_host_tags_tag ON host_tags(tag);
-- +goose Down
DROP TABLE host_tags;
DROP TABLE host_tokens;
DROP TABLE hosts;

View File

@ -1,11 +0,0 @@
-- +goose Up
ALTER TABLE hosts
ADD COLUMN cert_fingerprint TEXT,
ADD COLUMN mtls_enabled BOOLEAN NOT NULL DEFAULT FALSE;
-- +goose Down
ALTER TABLE hosts
DROP COLUMN cert_fingerprint,
DROP COLUMN mtls_enabled;

View File

@ -1,17 +0,0 @@
-- +goose Up
ALTER TABLE teams ADD COLUMN slug TEXT;
ALTER TABLE teams ADD COLUMN deleted_at TIMESTAMPTZ;
-- Backfill slugs for existing teams using MD5 of their ID.
-- MD5 returns 32 hex chars; take chars 1-6 and 7-12 to form a 6-6 slug.
UPDATE teams SET slug = LEFT(MD5(id), 6) || '-' || SUBSTRING(MD5(id), 7, 6);
ALTER TABLE teams ALTER COLUMN slug SET NOT NULL;
CREATE UNIQUE INDEX idx_teams_slug ON teams(slug);
-- +goose Down
DROP INDEX idx_teams_slug;
ALTER TABLE teams DROP COLUMN deleted_at;
ALTER TABLE teams DROP COLUMN slug;

View File

@ -1,5 +0,0 @@
-- +goose Up
ALTER TABLE users ADD COLUMN name TEXT NOT NULL DEFAULT '';
-- +goose Down
ALTER TABLE users DROP COLUMN name;

View File

@ -1,19 +0,0 @@
-- +goose Up
-- Refresh tokens for host agent JWT rotation.
-- Hosts exchange a refresh token for a new short-lived JWT + new refresh token (rotation).
-- Refresh tokens expire after 60 days; hosts must re-register with a new one-time token after that.
CREATE TABLE host_refresh_tokens (
id TEXT PRIMARY KEY,
host_id TEXT NOT NULL REFERENCES hosts(id) ON DELETE CASCADE,
token_hash TEXT NOT NULL UNIQUE, -- SHA-256 hex of the opaque token
expires_at TIMESTAMPTZ NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
revoked_at TIMESTAMPTZ -- NULL = active; set on rotation or host delete
);
CREATE INDEX idx_host_refresh_tokens_host ON host_refresh_tokens(host_id);
-- +goose Down
DROP TABLE host_refresh_tokens;

View File

@ -1,28 +0,0 @@
-- +goose Up
CREATE TABLE audit_logs (
id TEXT PRIMARY KEY,
team_id TEXT NOT NULL,
actor_type TEXT NOT NULL, -- 'user', 'api_key', 'system'
actor_id TEXT, -- user_id or api_key_id; NULL for system
actor_name TEXT, -- display name snapshotted at write time; NULL for system
resource_type TEXT NOT NULL, -- 'sandbox', 'snapshot', 'team', 'api_key', 'member', 'host'
resource_id TEXT, -- primary ID of the affected resource; NULL when not applicable
action TEXT NOT NULL, -- 'create', 'pause', 'resume', 'destroy', 'delete', 'rename',
-- 'revoke', 'add', 'remove', 'leave', 'role_update',
-- 'marked_down', 'marked_up'
scope TEXT NOT NULL, -- 'team' or 'admin'
status TEXT NOT NULL, -- 'success', 'info', 'warning', 'error'
metadata JSONB NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
-- Primary access pattern: team feed sorted newest-first with cursor pagination.
CREATE INDEX idx_audit_logs_team_time ON audit_logs (team_id, created_at DESC);
-- Secondary index: filtered by resource_type and action within a team.
CREATE INDEX idx_audit_logs_team_resource ON audit_logs (team_id, resource_type, action, created_at DESC);
-- +goose Down
DROP TABLE audit_logs;

View File

@ -1,18 +0,0 @@
-- +goose Up
CREATE TABLE sandbox_metrics_snapshots (
id BIGSERIAL PRIMARY KEY,
team_id TEXT NOT NULL,
sampled_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
running_count INTEGER NOT NULL,
vcpus_reserved INTEGER NOT NULL,
memory_mb_reserved INTEGER NOT NULL
);
-- All queries filter on team_id first then range-scan sampled_at.
CREATE INDEX idx_metrics_snapshots_team_time
ON sandbox_metrics_snapshots (team_id, sampled_at DESC);
-- +goose Down
DROP TABLE sandbox_metrics_snapshots;

View File

@ -1,16 +0,0 @@
-- +goose Up
CREATE TABLE sandbox_metric_points (
sandbox_id TEXT NOT NULL,
tier TEXT NOT NULL CHECK (tier IN ('10m', '2h', '24h')),
ts BIGINT NOT NULL,
cpu_pct FLOAT8 NOT NULL DEFAULT 0,
mem_bytes BIGINT NOT NULL DEFAULT 0,
disk_bytes BIGINT NOT NULL DEFAULT 0,
PRIMARY KEY (sandbox_id, tier, ts)
);
CREATE INDEX idx_sandbox_metric_points_sandbox_tier
ON sandbox_metric_points (sandbox_id, tier);
-- +goose Down
DROP TABLE IF EXISTS sandbox_metric_points;

View File

@ -0,0 +1,82 @@
-- +goose Up
-- 1. Add UUID id column to templates and make it the primary key.
ALTER TABLE templates ADD COLUMN id UUID DEFAULT gen_random_uuid();
UPDATE templates SET id = gen_random_uuid() WHERE id IS NULL;
ALTER TABLE templates ALTER COLUMN id SET NOT NULL;
ALTER TABLE templates DROP CONSTRAINT templates_pkey;
ALTER TABLE templates ADD PRIMARY KEY (id);
-- 2. Name becomes a display field with team-scoped uniqueness.
ALTER TABLE templates ADD CONSTRAINT uq_templates_team_name UNIQUE (team_id, name);
-- 3. Prevent team templates from using names that belong to global (platform) templates.
-- A team template insert/update with a name matching any platform template is rejected.
-- +goose StatementBegin
CREATE OR REPLACE FUNCTION check_global_template_name_collision()
RETURNS TRIGGER AS $$
BEGIN
IF NEW.team_id != '00000000-0000-0000-0000-000000000000' THEN
IF EXISTS (
SELECT 1 FROM templates
WHERE name = NEW.name
AND team_id = '00000000-0000-0000-0000-000000000000'
) THEN
RAISE EXCEPTION 'template name "%" is reserved by a global template', NEW.name
USING ERRCODE = 'unique_violation';
END IF;
END IF;
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
-- +goose StatementEnd
CREATE TRIGGER trg_check_global_template_name
BEFORE INSERT OR UPDATE ON templates
FOR EACH ROW
EXECUTE FUNCTION check_global_template_name_collision();
-- 4. Seed the built-in "minimal" template so it appears in all listings.
-- Both id and team_id are the all-zeros UUID (platform sentinel).
INSERT INTO templates (id, name, type, vcpus, memory_mb, size_bytes, team_id)
VALUES (
'00000000-0000-0000-0000-000000000000',
'minimal',
'base',
1,
512,
0,
'00000000-0000-0000-0000-000000000000'
) ON CONFLICT DO NOTHING;
-- 5. Add template UUID references to template_builds.
ALTER TABLE template_builds
ADD COLUMN template_id UUID,
ADD COLUMN team_id UUID;
-- 5. Add template UUID references to sandboxes.
ALTER TABLE sandboxes
ADD COLUMN template_id UUID,
ADD COLUMN template_team_id UUID;
-- +goose Down
ALTER TABLE sandboxes
DROP COLUMN IF EXISTS template_team_id,
DROP COLUMN IF EXISTS template_id;
ALTER TABLE template_builds
DROP COLUMN IF EXISTS team_id,
DROP COLUMN IF EXISTS template_id;
-- Remove the seeded minimal template.
DELETE FROM templates WHERE id = '00000000-0000-0000-0000-000000000000';
DROP TRIGGER IF EXISTS trg_check_global_template_name ON templates;
DROP FUNCTION IF EXISTS check_global_template_name_collision();
ALTER TABLE templates DROP CONSTRAINT IF EXISTS uq_templates_team_name;
ALTER TABLE templates DROP CONSTRAINT IF EXISTS templates_pkey;
ALTER TABLE templates ADD PRIMARY KEY (name);
ALTER TABLE templates DROP COLUMN IF EXISTS id;

View File

@ -0,0 +1,7 @@
-- +goose Up
ALTER TABLE hosts DROP COLUMN mtls_enabled;
ALTER TABLE hosts ADD COLUMN cert_expires_at TIMESTAMPTZ;
-- +goose Down
ALTER TABLE hosts DROP COLUMN cert_expires_at;
ALTER TABLE hosts ADD COLUMN mtls_enabled BOOLEAN NOT NULL DEFAULT FALSE;

View File

@ -0,0 +1,11 @@
-- +goose Up
-- Allow completed_at to be set when a build is cancelled.
-- (The UpdateBuildStatus query is updated in sqlc; no schema change needed for that.)
-- Add skip_pre_post flag: when true, the pre-build and post-build command phases
-- are skipped for this build.
ALTER TABLE template_builds ADD COLUMN skip_pre_post BOOLEAN NOT NULL DEFAULT FALSE;
-- +goose Down
ALTER TABLE template_builds DROP COLUMN skip_pre_post;

View File

@ -20,16 +20,25 @@ SELECT * FROM hosts WHERE status = $1 ORDER BY created_at DESC;
-- name: RegisterHost :execrows
UPDATE hosts
SET arch = $2,
cpu_cores = $3,
memory_mb = $4,
disk_gb = $5,
address = $6,
status = 'online',
SET arch = $2,
cpu_cores = $3,
memory_mb = $4,
disk_gb = $5,
address = $6,
cert_fingerprint = $7,
cert_expires_at = $8,
status = 'online',
last_heartbeat_at = NOW(),
updated_at = NOW()
updated_at = NOW()
WHERE id = $1 AND status = 'pending';
-- name: UpdateHostCert :exec
UPDATE hosts
SET cert_fingerprint = $2,
cert_expires_at = $3,
updated_at = NOW()
WHERE id = $1;
-- name: UpdateHostStatus :exec
UPDATE hosts SET status = $2, updated_at = NOW() WHERE id = $1;

View File

@ -1,6 +1,6 @@
-- name: InsertSandbox :one
INSERT INTO sandboxes (id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
INSERT INTO sandboxes (id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, template_id, template_team_id)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
RETURNING *;
-- name: GetSandbox :one
@ -9,6 +9,14 @@ SELECT * FROM sandboxes WHERE id = $1;
-- name: GetSandboxByTeam :one
SELECT * FROM sandboxes WHERE id = $1 AND team_id = $2;
-- name: GetSandboxProxyTarget :one
-- Returns the sandbox status and its host's address in one query.
-- Used by SandboxProxyWrapper to avoid two round-trips.
SELECT s.status, h.address AS host_address
FROM sandboxes s
JOIN hosts h ON h.id = s.host_id
WHERE s.id = $1 AND s.team_id = $2;
-- name: ListSandboxes :many
SELECT * FROM sandboxes ORDER BY created_at DESC;
@ -50,7 +58,7 @@ WHERE id = $1;
UPDATE sandboxes
SET status = $2,
last_updated = NOW()
WHERE id = ANY($1::text[]);
WHERE id = ANY($1::uuid[]);
-- name: ListActiveSandboxesByTeam :many
SELECT * FROM sandboxes
@ -72,4 +80,4 @@ WHERE host_id = $1 AND status IN ('running', 'starting', 'pending');
UPDATE sandboxes
SET status = 'running',
last_updated = NOW()
WHERE id = ANY($1::text[]) AND status = 'missing';
WHERE id = ANY($1::uuid[]) AND status = 'missing';

View File

@ -0,0 +1,33 @@
-- name: InsertTemplateBuild :one
INSERT INTO template_builds (id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, total_steps, template_id, team_id, skip_pre_post)
VALUES ($1, $2, $3, $4, $5, $6, $7, 'pending', $8, $9, $10, $11)
RETURNING *;
-- name: GetTemplateBuild :one
SELECT * FROM template_builds WHERE id = $1;
-- name: ListTemplateBuilds :many
SELECT * FROM template_builds ORDER BY created_at DESC;
-- name: UpdateBuildStatus :one
UPDATE template_builds
SET status = $2,
started_at = CASE WHEN $2 = 'running' AND started_at IS NULL THEN NOW() ELSE started_at END,
completed_at = CASE WHEN $2 IN ('success', 'failed', 'cancelled') THEN NOW() ELSE completed_at END
WHERE id = $1
RETURNING *;
-- name: UpdateBuildProgress :exec
UPDATE template_builds
SET current_step = $2, logs = $3
WHERE id = $1;
-- name: UpdateBuildSandbox :exec
UPDATE template_builds
SET sandbox_id = $2, host_id = $3
WHERE id = $1;
-- name: UpdateBuildError :exec
UPDATE template_builds
SET error = $2, status = 'failed', completed_at = NOW()
WHERE id = $1;

View File

@ -1,13 +1,22 @@
-- name: InsertTemplate :one
INSERT INTO templates (name, type, vcpus, memory_mb, size_bytes, team_id)
VALUES ($1, $2, $3, $4, $5, $6)
INSERT INTO templates (id, name, type, vcpus, memory_mb, size_bytes, team_id)
VALUES ($1, $2, $3, $4, $5, $6, $7)
RETURNING *;
-- name: GetTemplate :one
SELECT * FROM templates WHERE name = $1;
SELECT * FROM templates WHERE id = $1;
-- name: GetTemplateByTeam :one
SELECT * FROM templates WHERE name = $1 AND team_id = $2;
-- Platform templates (team_id = 00000000-...) are visible to all teams.
SELECT * FROM templates WHERE name = $1 AND (team_id = $2 OR team_id = '00000000-0000-0000-0000-000000000000');
-- name: GetTemplateByName :one
-- Look up a template by team_id and name (exact team match, no global fallback).
SELECT * FROM templates WHERE team_id = $1 AND name = $2;
-- name: GetPlatformTemplateByName :one
-- Check if a global (platform) template exists with the given name.
SELECT * FROM templates WHERE team_id = '00000000-0000-0000-0000-000000000000' AND name = $1;
-- name: ListTemplates :many
SELECT * FROM templates ORDER BY created_at DESC;
@ -16,13 +25,23 @@ SELECT * FROM templates ORDER BY created_at DESC;
SELECT * FROM templates WHERE type = $1 ORDER BY created_at DESC;
-- name: ListTemplatesByTeam :many
SELECT * FROM templates WHERE team_id = $1 ORDER BY created_at DESC;
-- Platform templates are visible to all teams.
SELECT * FROM templates WHERE (team_id = $1 OR team_id = '00000000-0000-0000-0000-000000000000') ORDER BY created_at DESC;
-- name: ListTemplatesByTeamAndType :many
SELECT * FROM templates WHERE team_id = $1 AND type = $2 ORDER BY created_at DESC;
-- Platform templates are visible to all teams.
SELECT * FROM templates WHERE (team_id = $1 OR team_id = '00000000-0000-0000-0000-000000000000') AND type = $2 ORDER BY created_at DESC;
-- name: DeleteTemplate :exec
DELETE FROM templates WHERE name = $1;
DELETE FROM templates WHERE id = $1;
-- name: DeleteTemplateByTeam :exec
DELETE FROM templates WHERE name = $1 AND team_id = $2;
-- name: DeleteTemplatesByTeam :exec
-- Bulk delete all templates owned by a team (for team soft-delete cleanup).
DELETE FROM templates WHERE team_id = $1;
-- name: ListTemplatesByTeamOnly :many
-- List templates owned by a specific team (NOT including platform templates).
SELECT * FROM templates WHERE team_id = $1 ORDER BY created_at DESC;

41
deploy/Caddyfile.dev Normal file
View File

@ -0,0 +1,41 @@
# Sandbox port forwarding: {port}-{sandbox_id}.localhost
# Matches subdomains like 49999-sb-abcd1234.localhost and proxies them
# to the control plane, which inspects the Host header and routes to
# the correct host agent.
#
# NOTE: Wildcard *.localhost DNS resolution requires local setup.
# Option 1: Add entries to /etc/hosts for each sandbox
# Option 2: Use dnsmasq: address=/.localhost/127.0.0.1
# Option 3: Use systemd-resolved (Ubuntu default — *.localhost resolves to 127.0.0.1)
http://*.localhost {
reverse_proxy host.docker.internal:8080
}
# Main entry point: API + frontend
http://localhost {
# API routes — strip /api prefix and proxy to the control plane.
# The frontend calls /api/v1/... which becomes /v1/... at the CP.
handle_path /api/* {
reverse_proxy host.docker.internal:8080
}
# Backend routes served directly (SDK clients, OAuth initiation)
handle /v1/* {
reverse_proxy host.docker.internal:8080
}
handle /openapi.yaml {
reverse_proxy host.docker.internal:8080
}
handle /docs {
reverse_proxy host.docker.internal:8080
}
handle /auth/oauth/* {
reverse_proxy host.docker.internal:8080
}
# Everything else — proxy to the frontend dev server
# This includes: /login, /dashboard/*, /admin/*, /auth/github/callback
handle {
reverse_proxy host.docker.internal:5173
}
}

View File

@ -15,19 +15,14 @@ services:
ports:
- "6379:6379"
prometheus:
image: prom/prometheus:latest
caddy:
image: caddy:2-alpine
ports:
- "9090:9090"
- "8000:80"
volumes:
- ./deploy/prometheus.yml:/etc/prometheus/prometheus.yml
grafana:
image: grafana/grafana:latest
ports:
- "3001:3000"
environment:
GF_SECURITY_ADMIN_PASSWORD: admin
- ./Caddyfile.dev:/etc/caddy/Caddyfile:ro
extra_hosts:
- "host.docker.internal:host-gateway"
volumes:
pgdata:

View File

@ -14,14 +14,12 @@ import (
"os/exec"
"time"
"github.com/awnumar/memguard"
"github.com/rs/zerolog"
"github.com/txn2/txeh"
"golang.org/x/sys/unix"
"git.omukk.dev/wrenn/sandbox/envd/internal/host"
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/keys"
"github.com/awnumar/memguard"
"github.com/rs/zerolog"
"github.com/txn2/txeh"
)
var (
@ -29,11 +27,6 @@ var (
ErrAccessTokenResetNotAuthorized = errors.New("access token reset not authorized")
)
const (
maxTimeInPast = 50 * time.Millisecond
maxTimeInFuture = 5 * time.Second
)
// validateInitAccessToken validates the access token for /init requests.
// Token is valid if it matches the existing token OR the MMDS hash.
// If neither exists, first-time setup is allowed.
@ -172,20 +165,6 @@ func (a *API) SetData(ctx context.Context, logger zerolog.Logger, data PostInitJ
return err
}
if data.Timestamp != nil {
// Check if current time differs significantly from the received timestamp
if shouldSetSystemTime(time.Now(), *data.Timestamp) {
logger.Debug().Msgf("Setting sandbox start time to: %v", *data.Timestamp)
ts := unix.NsecToTimespec(data.Timestamp.UnixNano())
err := unix.ClockSettime(unix.CLOCK_REALTIME, &ts)
if err != nil {
logger.Error().Msgf("Failed to set system time: %v", err)
}
} else {
logger.Debug().Msgf("Current time is within acceptable range of timestamp %v, not setting system time", *data.Timestamp)
}
}
if data.EnvVars != nil {
logger.Debug().Msg(fmt.Sprintf("Setting %d env vars", len(*data.EnvVars)))
@ -308,10 +287,3 @@ func getIPFamily(address string) (txeh.IPFamily, error) {
return txeh.IPFamilyV4, fmt.Errorf("%w: %s", ErrUnknownAddressFormat, address)
}
}
// shouldSetSystemTime returns true if the current time differs significantly from the received timestamp,
// indicating the system clock should be adjusted. Returns true when the sandboxTime is more than
// maxTimeInPast before the hostTime or more than maxTimeInFuture after the hostTime.
func shouldSetSystemTime(sandboxTime, hostTime time.Time) bool {
return sandboxTime.Before(hostTime.Add(-maxTimeInPast)) || sandboxTime.After(hostTime.Add(maxTimeInFuture))
}

View File

@ -9,7 +9,6 @@ import (
"path/filepath"
"strings"
"testing"
"time"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
@ -59,71 +58,6 @@ func TestSimpleCases(t *testing.T) {
}
}
func TestShouldSetSystemTime(t *testing.T) {
t.Parallel()
sandboxTime := time.Now()
tests := []struct {
name string
hostTime time.Time
want bool
}{
{
name: "sandbox time far ahead of host time (should set)",
hostTime: sandboxTime.Add(-10 * time.Second),
want: true,
},
{
name: "sandbox time at maxTimeInPast boundary ahead of host time (should not set)",
hostTime: sandboxTime.Add(-50 * time.Millisecond),
want: false,
},
{
name: "sandbox time just within maxTimeInPast ahead of host time (should not set)",
hostTime: sandboxTime.Add(-40 * time.Millisecond),
want: false,
},
{
name: "sandbox time slightly ahead of host time (should not set)",
hostTime: sandboxTime.Add(-10 * time.Millisecond),
want: false,
},
{
name: "sandbox time equals host time (should not set)",
hostTime: sandboxTime,
want: false,
},
{
name: "sandbox time slightly behind host time (should not set)",
hostTime: sandboxTime.Add(1 * time.Second),
want: false,
},
{
name: "sandbox time just within maxTimeInFuture behind host time (should not set)",
hostTime: sandboxTime.Add(4 * time.Second),
want: false,
},
{
name: "sandbox time at maxTimeInFuture boundary behind host time (should not set)",
hostTime: sandboxTime.Add(5 * time.Second),
want: false,
},
{
name: "sandbox time far behind host time (should set)",
hostTime: sandboxTime.Add(1 * time.Minute),
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := shouldSetSystemTime(tt.hostTime, sandboxTime)
assert.Equal(t, tt.want, got)
})
}
}
func secureTokenPtr(s string) *SecureToken {
token := &SecureToken{}
_ = token.Set([]byte(s))

165
envd/internal/port/conn.go Normal file
View File

@ -0,0 +1,165 @@
// SPDX-License-Identifier: Apache-2.0
package port
import (
"bufio"
"encoding/hex"
"fmt"
"net"
"os"
"strconv"
"strings"
"syscall"
)
// ConnStat represents a single TCP connection read from /proc/net/tcp(6).
// It contains only the fields needed by the port scanner and forwarder.
type ConnStat struct {
LocalIP string
LocalPort uint32
Status string
Family uint32 // syscall.AF_INET or syscall.AF_INET6
Inode uint64 // socket inode, unique per connection
}
// tcpStates maps the hex state values from /proc/net/tcp to string names
// matching the gopsutil convention used by ScannerFilter.
var tcpStates = map[string]string{
"01": "ESTABLISHED",
"02": "SYN_SENT",
"03": "SYN_RECV",
"04": "FIN_WAIT1",
"05": "FIN_WAIT2",
"06": "TIME_WAIT",
"07": "CLOSE",
"08": "CLOSE_WAIT",
"09": "LAST_ACK",
"0A": "LISTEN",
"0B": "CLOSING",
}
// ReadTCPConnections reads /proc/net/tcp and /proc/net/tcp6 and returns
// all TCP connections. This avoids the /proc/{pid}/fd walk that gopsutil
// performs, which is unsafe across Firecracker snapshot/restore boundaries.
func ReadTCPConnections() ([]ConnStat, error) {
var conns []ConnStat
tcp4, err := parseProcNetTCP("/proc/net/tcp", syscall.AF_INET)
if err != nil {
return nil, fmt.Errorf("parse /proc/net/tcp: %w", err)
}
conns = append(conns, tcp4...)
tcp6, err := parseProcNetTCP("/proc/net/tcp6", syscall.AF_INET6)
if err != nil {
return nil, fmt.Errorf("parse /proc/net/tcp6: %w", err)
}
conns = append(conns, tcp6...)
return conns, nil
}
// parseProcNetTCP reads a single /proc/net/tcp or /proc/net/tcp6 file.
//
// Format (fields are whitespace-separated):
//
// sl local_address rem_address st tx_queue:rx_queue tr:tm->when retrnsmt uid timeout inode
// 0: 0100007F:1F90 00000000:0000 0A 00000000:00000000 00:00000000 00000000 1000 0 12345
func parseProcNetTCP(path string, family uint32) ([]ConnStat, error) {
f, err := os.Open(path)
if err != nil {
return nil, err
}
defer f.Close()
var conns []ConnStat
scanner := bufio.NewScanner(f)
// Skip header line.
scanner.Scan()
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" {
continue
}
fields := strings.Fields(line)
if len(fields) < 10 {
continue
}
// fields[1] = local_address (hex_ip:hex_port)
ip, port, err := parseHexAddr(fields[1], family)
if err != nil {
continue
}
// fields[3] = state (hex)
state, ok := tcpStates[fields[3]]
if !ok {
state = "UNKNOWN"
}
// fields[9] = inode
inode, err := strconv.ParseUint(fields[9], 10, 64)
if err != nil {
continue
}
conns = append(conns, ConnStat{
LocalIP: ip,
LocalPort: port,
Status: state,
Family: family,
Inode: inode,
})
}
return conns, scanner.Err()
}
// parseHexAddr parses "HEXIP:HEXPORT" from /proc/net/tcp.
// IPv4 addresses are 8 hex chars (4 bytes, little-endian per 32-bit word).
// IPv6 addresses are 32 hex chars (16 bytes, little-endian per 32-bit word).
func parseHexAddr(s string, family uint32) (string, uint32, error) {
parts := strings.SplitN(s, ":", 2)
if len(parts) != 2 {
return "", 0, fmt.Errorf("invalid address: %s", s)
}
port64, err := strconv.ParseUint(parts[1], 16, 32)
if err != nil {
return "", 0, err
}
ipHex := parts[0]
ipBytes, err := hex.DecodeString(ipHex)
if err != nil {
return "", 0, err
}
var ip net.IP
if family == syscall.AF_INET {
if len(ipBytes) != 4 {
return "", 0, fmt.Errorf("invalid IPv4 length: %d", len(ipBytes))
}
// /proc/net/tcp stores IPv4 as a single little-endian 32-bit word.
ip = net.IPv4(ipBytes[3], ipBytes[2], ipBytes[1], ipBytes[0])
} else {
if len(ipBytes) != 16 {
return "", 0, fmt.Errorf("invalid IPv6 length: %d", len(ipBytes))
}
// /proc/net/tcp6 stores IPv6 as four little-endian 32-bit words.
ip = make(net.IP, 16)
for i := 0; i < 4; i++ {
ip[i*4+0] = ipBytes[i*4+3]
ip[i*4+1] = ipBytes[i*4+2]
ip[i*4+2] = ipBytes[i*4+1]
ip[i*4+3] = ipBytes[i*4+0]
}
}
return ip.String(), uint32(port64), nil
}

View File

@ -31,8 +31,8 @@ var defaultGatewayIP = net.IPv4(169, 254, 0, 21)
type PortToForward struct {
socat *exec.Cmd
// Process ID of the process that's listening on port.
pid int32
// Socket inode of the listening socket (unique per connection).
inode uint64
// family version of the ip.
family uint32
state PortState
@ -94,7 +94,7 @@ func (f *Forwarder) StartForwarding(ctx context.Context) {
// Let's refresh our map of currently forwarded ports and mark the currently opened ones with the "FORWARD" state.
// This will make sure we won't delete them later.
for _, p := range procs {
key := fmt.Sprintf("%d-%d", p.Pid, p.Laddr.Port)
key := fmt.Sprintf("%d-%d", p.Inode, p.LocalPort)
// We check if the opened port is in our map of forwarded ports.
val, portOk := f.ports[key]
@ -104,16 +104,16 @@ func (f *Forwarder) StartForwarding(ctx context.Context) {
val.state = PortStateForward
} else {
f.logger.Debug().
Str("ip", p.Laddr.IP).
Uint32("port", p.Laddr.Port).
Str("ip", p.LocalIP).
Uint32("port", p.LocalPort).
Uint32("family", familyToIPVersion(p.Family)).
Str("state", p.Status).
Msg("Detected new opened port on localhost that is not forwarded")
// The opened port wasn't in the map so we create a new PortToForward and start forwarding.
ptf := &PortToForward{
pid: p.Pid,
port: p.Laddr.Port,
inode: p.Inode,
port: p.LocalPort,
state: PortStateForward,
family: familyToIPVersion(p.Family),
}
@ -153,7 +153,7 @@ func (f *Forwarder) startPortForwarding(ctx context.Context, p *PortToForward) {
f.logger.Debug().
Str("socatCmd", cmd.String()).
Int32("pid", p.pid).
Uint64("inode", p.inode).
Uint32("family", p.family).
IPAddr("sourceIP", f.sourceIP.To4()).
Uint32("port", p.port).
@ -191,7 +191,7 @@ func (f *Forwarder) stopPortForwarding(p *PortToForward) {
logger := f.logger.With().
Str("socatCmd", p.socat.String()).
Int32("pid", p.pid).
Uint64("inode", p.inode).
Uint32("family", p.family).
IPAddr("sourceIP", f.sourceIP.To4()).
Uint32("port", p.port).

View File

@ -3,19 +3,21 @@
package port
import (
"sync"
"time"
"github.com/rs/zerolog"
"github.com/shirou/gopsutil/v4/net"
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/smap"
)
type Scanner struct {
Processes chan net.ConnectionStat
scanExit chan struct{}
subs *smap.Map[*ScannerSubscriber]
period time.Duration
scanExit chan struct{}
period time.Duration
// Plain mutex-protected map instead of concurrent-map. The concurrent-map
// library's Items() spawns goroutines and uses a WaitGroup internally,
// which corrupts Go runtime semaphore state across Firecracker snapshot/restore.
mu sync.RWMutex
subs map[string]*ScannerSubscriber
}
func (s *Scanner) Destroy() {
@ -24,33 +26,44 @@ func (s *Scanner) Destroy() {
func NewScanner(period time.Duration) *Scanner {
return &Scanner{
period: period,
subs: smap.New[*ScannerSubscriber](),
scanExit: make(chan struct{}),
Processes: make(chan net.ConnectionStat),
period: period,
subs: make(map[string]*ScannerSubscriber),
scanExit: make(chan struct{}),
}
}
func (s *Scanner) AddSubscriber(logger *zerolog.Logger, id string, filter *ScannerFilter) *ScannerSubscriber {
subscriber := NewScannerSubscriber(logger, id, filter)
s.subs.Insert(id, subscriber)
s.mu.Lock()
s.subs[id] = subscriber
s.mu.Unlock()
return subscriber
}
func (s *Scanner) Unsubscribe(sub *ScannerSubscriber) {
s.subs.Remove(sub.ID())
s.mu.Lock()
delete(s.subs, sub.ID())
s.mu.Unlock()
sub.Destroy()
}
// ScanAndBroadcast starts scanning open TCP ports and broadcasts every open port to all subscribers.
func (s *Scanner) ScanAndBroadcast() {
for {
// tcp monitors both ipv4 and ipv6 connections.
processes, _ := net.Connections("tcp")
for _, sub := range s.subs.Items() {
sub.Signal(processes)
// Read directly from /proc/net/tcp and /proc/net/tcp6 instead of
// using gopsutil's net.Connections(), which walks /proc/{pid}/fd
// and causes Go runtime corruption after Firecracker snapshot/restore.
conns, _ := ReadTCPConnections()
s.mu.RLock()
for _, sub := range s.subs {
sub.Signal(conns)
}
s.mu.RUnlock()
select {
case <-s.scanExit:
return

View File

@ -4,7 +4,6 @@ package port
import (
"github.com/rs/zerolog"
"github.com/shirou/gopsutil/v4/net"
)
// If we want to create a listener/subscriber pattern somewhere else we should move
@ -13,7 +12,7 @@ import (
type ScannerSubscriber struct {
logger *zerolog.Logger
filter *ScannerFilter
Messages chan ([]net.ConnectionStat)
Messages chan ([]ConnStat)
id string
}
@ -22,7 +21,7 @@ func NewScannerSubscriber(logger *zerolog.Logger, id string, filter *ScannerFilt
logger: logger,
id: id,
filter: filter,
Messages: make(chan []net.ConnectionStat),
Messages: make(chan []ConnStat),
}
}
@ -34,17 +33,17 @@ func (ss *ScannerSubscriber) Destroy() {
close(ss.Messages)
}
func (ss *ScannerSubscriber) Signal(proc []net.ConnectionStat) {
func (ss *ScannerSubscriber) Signal(conns []ConnStat) {
// Filter isn't specified. Accept everything.
if ss.filter == nil {
ss.Messages <- proc
ss.Messages <- conns
} else {
filtered := []net.ConnectionStat{}
for i := range proc {
filtered := []ConnStat{}
for i := range conns {
// We need to access the list directly otherwise there will be implicit memory aliasing
// If the filter matched a process, we will send it to a channel.
if ss.filter.Match(&proc[i]) {
filtered = append(filtered, proc[i])
// If the filter matched a connection, we will send it to a channel.
if ss.filter.Match(&conns[i]) {
filtered = append(filtered, conns[i])
}
}
ss.Messages <- filtered

View File

@ -4,8 +4,6 @@ package port
import (
"slices"
"github.com/shirou/gopsutil/v4/net"
)
type ScannerFilter struct {
@ -13,15 +11,15 @@ type ScannerFilter struct {
IPs []string
}
func (sf *ScannerFilter) Match(proc *net.ConnectionStat) bool {
func (sf *ScannerFilter) Match(conn *ConnStat) bool {
// Filter is an empty struct.
if sf.State == "" && len(sf.IPs) == 0 {
return false
}
ipMatch := slices.Contains(sf.IPs, proc.Laddr.IP)
ipMatch := slices.Contains(sf.IPs, conn.LocalIP)
if ipMatch && sf.State == proc.Status {
if ipMatch && sf.State == conn.Status {
return true
}

View File

@ -0,0 +1,76 @@
import { apiFetch, type ApiResult } from '$lib/api/client';
export type BuildLogEntry = {
step: number;
phase: string; // "pre-build", "recipe", or "post-build"
cmd: string;
stdout: string;
stderr: string;
exit: number;
ok: boolean;
elapsed_ms: number;
};
export type Build = {
id: string;
name: string;
base_template: string;
recipe: string[];
healthcheck?: string;
vcpus: number;
memory_mb: number;
status: string;
current_step: number;
total_steps: number;
logs: BuildLogEntry[];
error?: string;
sandbox_id?: string;
host_id?: string;
created_at: string;
started_at?: string;
completed_at?: string;
};
export type CreateBuildParams = {
name: string;
base_template?: string;
recipe: string[];
healthcheck?: string;
vcpus?: number;
memory_mb?: number;
skip_pre_post?: boolean;
};
export async function createBuild(params: CreateBuildParams): Promise<ApiResult<Build>> {
return apiFetch('POST', '/api/v1/admin/builds', params);
}
export async function listBuilds(): Promise<ApiResult<Build[]>> {
return apiFetch('GET', '/api/v1/admin/builds');
}
export async function getBuild(id: string): Promise<ApiResult<Build>> {
return apiFetch('GET', `/api/v1/admin/builds/${id}`);
}
export type AdminTemplate = {
name: string;
type: string;
vcpus: number;
memory_mb: number;
size_bytes: number;
team_id: string;
created_at: string;
};
export async function listAdminTemplates(): Promise<ApiResult<AdminTemplate[]>> {
return apiFetch('GET', '/api/v1/admin/templates');
}
export async function deleteAdminTemplate(name: string): Promise<ApiResult<void>> {
return apiFetch('DELETE', `/api/v1/admin/templates/${name}`);
}
export async function cancelBuild(id: string): Promise<ApiResult<void>> {
return apiFetch('POST', `/api/v1/admin/builds/${id}/cancel`);
}

View File

@ -54,6 +54,7 @@ export type Snapshot = {
memory_mb?: number;
size_bytes: number;
created_at: string;
platform: boolean;
};
export async function createSnapshot(sandboxId: string, name?: string): Promise<ApiResult<Snapshot>> {

View File

@ -3,6 +3,7 @@
import { auth } from '$lib/auth.svelte';
import {
IconServer,
IconTemplate,
IconSettings,
IconLogout,
IconSidebar,
@ -21,7 +22,8 @@
};
const managementItems: NavItem[] = [
{ label: 'Hosts', icon: IconServer, href: '/admin/hosts' }
{ label: 'Hosts', icon: IconServer, href: '/admin/hosts' },
{ label: 'Templates', icon: IconTemplate, href: '/admin/templates' }
];
function isActive(href: string): boolean {

View File

@ -56,6 +56,7 @@
class="w-full rounded-[var(--radius-input)] border border-[var(--color-border)] bg-[var(--color-bg-4)] px-3 py-2 font-mono text-ui text-[var(--color-text-bright)] outline-none placeholder:text-[var(--color-text-muted)] transition-colors duration-150 focus:border-[var(--color-accent)]"
placeholder="minimal"
/>
<p class="mt-1.5 text-meta text-[var(--color-text-muted)]">Name of a snapshot or base image to boot from.</p>
</div>
<div class="grid grid-cols-2 gap-3">
@ -85,14 +86,16 @@
</div>
<div>
<label class="mb-1.5 block text-label font-semibold uppercase tracking-[0.05em] text-[var(--color-text-tertiary)]" for="create-timeout">Idle timeout (seconds — 0 = never pause)</label>
<label class="mb-1.5 block text-label font-semibold uppercase tracking-[0.05em] text-[var(--color-text-tertiary)]" for="create-timeout">Idle timeout</label>
<input
id="create-timeout"
type="number"
min="0"
bind:value={createForm.timeout_sec}
class="w-full rounded-[var(--radius-input)] border border-[var(--color-border)] bg-[var(--color-bg-4)] px-3 py-2 font-mono text-ui text-[var(--color-text-bright)] outline-none transition-colors duration-150 focus:border-[var(--color-accent)]"
placeholder="0"
/>
<p class="mt-1.5 text-meta text-[var(--color-text-muted)]">Seconds of inactivity before the capsule pauses. Set to 0 to keep it running indefinitely.</p>
</div>
</div>

View File

@ -0,0 +1,902 @@
<script lang="ts">
import AdminSidebar from '$lib/components/AdminSidebar.svelte';
import { onMount, onDestroy } from 'svelte';
import { toast } from '$lib/toast.svelte';
import { formatDate, timeAgo } from '$lib/utils/format';
import {
listBuilds,
createBuild,
cancelBuild,
listAdminTemplates,
deleteAdminTemplate,
type Build,
type BuildLogEntry,
type AdminTemplate
} from '$lib/api/builds';
let collapsed = $state(
typeof window !== 'undefined'
? localStorage.getItem('wrenn_sidebar_collapsed') === 'true'
: false
);
let activeTab = $state<'templates' | 'builds'>('templates');
// Templates state
let templates = $state<AdminTemplate[]>([]);
let templatesLoading = $state(true);
let templatesError = $state<string | null>(null);
// Builds state
let builds = $state<Build[]>([]);
let buildsLoading = $state(true);
let buildsError = $state<string | null>(null);
// Polling
let pollInterval: ReturnType<typeof setInterval> | null = null;
let hasActiveBuilds = $derived(builds.some((b) => b.status === 'pending' || b.status === 'running'));
// Build log expansion
let expandedBuildId = $state<string | null>(null);
let expandedSteps = $state<Set<number>>(new Set());
// Delete template state
let deleteTarget = $state<AdminTemplate | null>(null);
let deleting = $state(false);
let deleteError = $state<string | null>(null);
// Create dialog state
let showCreate = $state(false);
let createForm = $state({
name: '',
base_template: 'minimal',
vcpus: 1,
memory_mb: 512,
recipe: '',
healthcheck: '',
skip_pre_post: false
});
let creating = $state(false);
let createError = $state<string | null>(null);
// Cancel build state
let cancelingBuildId = $state<string | null>(null);
// Stats
let templateCount = $derived(templates.length);
let snapshotCount = $derived(templates.filter((t) => t.type === 'snapshot').length);
let baseCount = $derived(templates.filter((t) => t.type === 'base').length);
let runningBuilds = $derived(builds.filter((b) => b.status === 'running').length);
async function fetchTemplates() {
templatesLoading = true;
templatesError = null;
const result = await listAdminTemplates();
if (result.ok) {
templates = result.data;
} else {
templatesError = result.error;
}
templatesLoading = false;
}
async function fetchBuilds() {
const wasFirst = buildsLoading;
if (wasFirst) buildsLoading = true;
buildsError = null;
const result = await listBuilds();
if (result.ok) {
builds = result.data;
} else {
buildsError = result.error;
}
if (wasFirst) buildsLoading = false;
}
function startPolling() {
stopPolling();
pollInterval = setInterval(() => {
if (hasActiveBuilds) fetchBuilds();
}, 3000);
}
function stopPolling() {
if (pollInterval) {
clearInterval(pollInterval);
pollInterval = null;
}
}
async function handleCreate() {
creating = true;
createError = null;
const lines = createForm.recipe
.split('\n')
.map((l) => l.trim())
.filter((l) => l.length > 0);
if (lines.length === 0) {
createError = 'Recipe must contain at least one command.';
creating = false;
return;
}
const result = await createBuild({
name: createForm.name.trim(),
base_template: createForm.base_template.trim() || 'minimal',
recipe: lines,
healthcheck: createForm.healthcheck.trim() || undefined,
vcpus: createForm.vcpus,
memory_mb: createForm.memory_mb,
skip_pre_post: createForm.skip_pre_post
});
if (result.ok) {
showCreate = false;
createForm = { name: '', base_template: 'minimal', vcpus: 1, memory_mb: 512, recipe: '', healthcheck: '', skip_pre_post: false };
builds = [result.data, ...builds];
activeTab = 'builds';
expandedBuildId = result.data.id;
toast.success('Build queued');
startPolling();
} else {
createError = result.error;
}
creating = false;
}
async function handleDeleteTemplate() {
if (!deleteTarget) return;
deleting = true;
deleteError = null;
const name = deleteTarget.name;
const result = await deleteAdminTemplate(name);
if (result.ok) {
templates = templates.filter((t) => t.name !== name);
deleteTarget = null;
toast.success('Template deleted');
} else {
deleteError = result.error;
}
deleting = false;
}
async function handleCancelBuild(buildId: string) {
cancelingBuildId = buildId;
const result = await cancelBuild(buildId);
if (result.ok) {
builds = builds.map((b) => b.id === buildId ? { ...b, status: 'cancelled' } : b);
toast.success('Build cancelled');
} else {
toast.error(result.error ?? 'Failed to cancel build');
}
cancelingBuildId = null;
}
function toggleBuildExpand(buildId: string) {
if (expandedBuildId === buildId) {
expandedBuildId = null;
expandedSteps = new Set();
} else {
expandedBuildId = buildId;
expandedSteps = new Set();
}
}
function toggleStepExpand(step: number) {
const next = new Set(expandedSteps);
if (next.has(step)) {
next.delete(step);
} else {
next.add(step);
}
expandedSteps = next;
}
function formatBytes(bytes: number): string {
if (bytes === 0) return '0 B';
const k = 1024;
const sizes = ['B', 'KB', 'MB', 'GB'];
const i = Math.floor(Math.log(bytes) / Math.log(k));
return parseFloat((bytes / Math.pow(k, i)).toFixed(1)) + ' ' + sizes[i];
}
function formatDuration(startedAt?: string, completedAt?: string): string {
if (!startedAt) return '—';
const start = new Date(startedAt).getTime();
const end = completedAt ? new Date(completedAt).getTime() : Date.now();
const sec = Math.round((end - start) / 1000);
if (sec < 60) return `${sec}s`;
return `${Math.floor(sec / 60)}m ${sec % 60}s`;
}
function statusColor(status: string): string {
switch (status) {
case 'success': return 'var(--color-accent-bright)';
case 'failed': return 'var(--color-red)';
case 'running': return 'var(--color-blue)';
case 'cancelled': return 'var(--color-amber)';
default: return 'var(--color-text-muted)';
}
}
// Returns [keyword, rest] from a recipe instruction string.
function splitInstruction(cmd: string): [string, string] {
const idx = cmd.indexOf(' ');
if (idx === -1) return [cmd.toUpperCase(), ''];
return [cmd.slice(0, idx).toUpperCase(), cmd.slice(idx + 1)];
}
function keywordColor(keyword: string): string {
switch (keyword) {
case 'RUN': return 'var(--color-blue)';
case 'START': return 'var(--color-accent-bright)';
case 'ENV': return 'var(--color-amber)';
case 'WORKDIR': return 'var(--color-text-tertiary)';
default: return 'var(--color-text-muted)';
}
}
onMount(() => {
fetchTemplates();
fetchBuilds().then(startPolling);
});
onDestroy(stopPolling);
</script>
<div class="flex h-screen overflow-hidden bg-[var(--color-bg-0)]">
<AdminSidebar bind:collapsed />
<main class="flex min-w-0 flex-1 flex-col overflow-hidden">
<!-- Header -->
<header class="flex shrink-0 flex-col gap-4 border-b border-[var(--color-border)] bg-[var(--color-bg-1)] px-6 py-5">
<div class="flex items-start justify-between">
<div>
<h1 class="font-serif text-[1.75rem] leading-none tracking-[-0.03em] text-[var(--color-text-bright)]">
Templates
</h1>
<p class="mt-1.5 text-ui text-[var(--color-text-tertiary)]">
Build and manage global templates available to all teams.
</p>
</div>
<button
onclick={() => { showCreate = true; createError = null; createForm = { name: '', base_template: 'minimal', vcpus: 1, memory_mb: 512, recipe: '', healthcheck: '' }; }}
class="flex items-center gap-2 rounded-[var(--radius-button)] bg-[var(--color-accent)] px-4 py-2 text-ui font-semibold text-white shadow-sm transition-all duration-150 hover:brightness-115 hover:-translate-y-px active:translate-y-0"
>
<svg width="13" height="13" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2.5" stroke-linecap="round" stroke-linejoin="round"><line x1="12" y1="5" x2="12" y2="19"/><line x1="5" y1="12" x2="19" y2="12"/></svg>
Create Template
</button>
</div>
<!-- Stat pills -->
{#if !templatesLoading && !templatesError}
<div class="flex items-center gap-2">
<div class="flex items-baseline gap-1 rounded-[var(--radius-button)] border border-[var(--color-border)] bg-[var(--color-bg-2)] px-2.5 py-1">
<span class="font-mono font-semibold text-ui tabular-nums text-[var(--color-text-bright)]">{templateCount}</span>
<span class="text-label text-[var(--color-text-muted)]">templates</span>
</div>
<div class="flex items-baseline gap-1 rounded-[var(--radius-button)] border border-[var(--color-border)] bg-[var(--color-bg-2)] px-2.5 py-1">
<span class="font-mono font-semibold text-ui tabular-nums text-[var(--color-text-bright)]">{baseCount}</span>
<span class="text-label text-[var(--color-text-muted)]">base</span>
</div>
<div class="flex items-baseline gap-1 rounded-[var(--radius-button)] border border-[var(--color-accent)]/25 bg-[var(--color-accent)]/8 px-2.5 py-1">
<span class="font-mono font-semibold text-ui tabular-nums text-[var(--color-accent-bright)]">{snapshotCount}</span>
<span class="text-label text-[var(--color-accent-bright)]/70">snapshots</span>
</div>
{#if runningBuilds > 0}
<div class="flex items-baseline gap-1.5 rounded-[var(--radius-button)] border border-[var(--color-blue)]/25 bg-[var(--color-blue)]/8 px-2.5 py-1">
<span class="relative mt-px flex h-1.5 w-1.5 shrink-0 self-center">
<span class="absolute inline-flex h-full w-full animate-ping rounded-full bg-[var(--color-blue)] opacity-60"></span>
<span class="relative inline-flex h-1.5 w-1.5 rounded-full bg-[var(--color-blue)]"></span>
</span>
<span class="font-mono font-semibold text-ui tabular-nums text-[var(--color-blue)]">{runningBuilds}</span>
<span class="text-label text-[var(--color-blue)]/70">building</span>
</div>
{/if}
</div>
{/if}
</header>
<!-- Tabs -->
<div class="flex shrink-0 border-b border-[var(--color-border)] bg-[var(--color-bg-1)] px-6">
{#each [['templates', 'Templates', templateCount], ['builds', 'Builds', builds.length]] as [id, label, count] (id)}
<button
onclick={() => { activeTab = id as 'templates' | 'builds'; }}
class="relative py-3 pr-5 text-ui transition-colors duration-150 {activeTab === id
? 'font-medium text-[var(--color-text-bright)]'
: 'text-[var(--color-text-tertiary)] hover:text-[var(--color-text-secondary)]'}"
>
{label}
{#if activeTab === id}
<span class="absolute bottom-0 left-0 right-5 h-[2px] rounded-t-full bg-[var(--color-accent)]"></span>
{/if}
{#if !templatesLoading}
<span class="ml-2 rounded-full bg-[var(--color-bg-4)] px-1.5 py-0.5 text-label text-[var(--color-text-muted)]">
{count}
</span>
{/if}
</button>
{/each}
</div>
<!-- Body -->
<div class="flex-1 overflow-y-auto p-6">
{#if activeTab === 'templates'}
{#if templatesLoading}
{@render skeletonRows(5, ['Name', 'Type', 'Specs', 'Size', 'Created', ''])}
{:else if templatesError}
<div class="rounded-[var(--radius-card)] border border-[var(--color-red)]/30 bg-[var(--color-red)]/5 px-4 py-3 text-ui text-[var(--color-red)]">
{templatesError}
</div>
{:else if templates.length === 0}
{@render emptyState('templates')}
{:else}
{@render templatesTable()}
{/if}
{:else}
{#if buildsLoading}
{@render skeletonRows(4, ['Build', 'Name', 'Status', 'Progress', 'Started', 'Duration'])}
{:else if buildsError}
<div class="rounded-[var(--radius-card)] border border-[var(--color-red)]/30 bg-[var(--color-red)]/5 px-4 py-3 text-ui text-[var(--color-red)]">
{buildsError}
</div>
{:else if builds.length === 0}
{@render emptyState('builds')}
{:else}
{@render buildsTable()}
{/if}
{/if}
</div>
</main>
</div>
<!-- ── Snippets ─────────────────────────────────────────────────────── -->
{#snippet skeletonRows(count: number, headers: string[])}
<div class="rounded-[var(--radius-card)] border border-[var(--color-border)] bg-[var(--color-bg-1)] overflow-hidden">
<table class="w-full">
<thead>
<tr class="border-b border-[var(--color-border)]">
{#each headers as h}
<th class="px-4 py-3 text-left text-label font-semibold uppercase tracking-[0.06em] text-[var(--color-text-tertiary)]">{h}</th>
{/each}
</tr>
</thead>
<tbody>
{#each Array(count) as _, i}
<tr class="border-b border-[var(--color-border)] last:border-0" style="animation-delay: {i * 60}ms">
{#each headers as _h, j}
<td class="px-4 py-3.5">
<div class="skeleton h-3 rounded" style="width: {60 + j * 12}px"></div>
</td>
{/each}
</tr>
{/each}
</tbody>
</table>
</div>
{/snippet}
{#snippet emptyState(type: 'templates' | 'builds')}
<div class="flex flex-col items-center justify-center py-24 text-center">
<div class="mb-5 flex h-16 w-16 items-center justify-center rounded-[var(--radius-card)] border border-[var(--color-border)] bg-[var(--color-bg-2)]">
{#if type === 'templates'}
<svg width="28" height="28" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="1.25" stroke-linecap="round" stroke-linejoin="round" class="text-[var(--color-text-muted)]"><rect x="3" y="3" width="7" height="7"/><rect x="14" y="3" width="7" height="7"/><rect x="14" y="14" width="7" height="7"/><rect x="3" y="14" width="7" height="7"/></svg>
{:else}
<svg width="28" height="28" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="1.25" stroke-linecap="round" stroke-linejoin="round" class="text-[var(--color-text-muted)]"><path d="M14 2H6a2 2 0 0 0-2 2v16a2 2 0 0 0 2 2h12a2 2 0 0 0 2-2V8Z"/><path d="M14 2v6h6"/><path d="m16 13-3.5 3.5-2-2L8 17"/></svg>
{/if}
</div>
<p class="font-serif text-[1.125rem] leading-snug text-[var(--color-text-secondary)]">
{type === 'templates' ? 'No templates yet.' : 'No builds yet.'}
</p>
<p class="mt-1.5 text-ui text-[var(--color-text-muted)]">
{type === 'templates'
? 'Create a template to provide pre-configured environments for all teams.'
: 'Start a template build to see progress and logs here.'}
</p>
</div>
{/snippet}
{#snippet templatesTable()}
<div class="rounded-[var(--radius-card)] border border-[var(--color-border)] bg-[var(--color-bg-1)] overflow-hidden">
<table class="w-full">
<thead>
<tr class="border-b border-[var(--color-border)]">
<th class="px-4 py-3 text-left text-label font-semibold uppercase tracking-[0.06em] text-[var(--color-text-tertiary)]">Name</th>
<th class="px-4 py-3 text-left text-label font-semibold uppercase tracking-[0.06em] text-[var(--color-text-tertiary)]">Type</th>
<th class="hidden px-4 py-3 text-left text-label font-semibold uppercase tracking-[0.06em] text-[var(--color-text-tertiary)] md:table-cell">Specs</th>
<th class="hidden px-4 py-3 text-left text-label font-semibold uppercase tracking-[0.06em] text-[var(--color-text-tertiary)] lg:table-cell">Size</th>
<th class="hidden px-4 py-3 text-left text-label font-semibold uppercase tracking-[0.06em] text-[var(--color-text-tertiary)] lg:table-cell">Created</th>
<th class="px-4 py-3"></th>
</tr>
</thead>
<tbody>
{#each templates as tmpl (tmpl.name)}
<tr class="border-b border-[var(--color-border)] last:border-0 transition-colors duration-200 hover:bg-[var(--color-bg-2)]">
<td class="px-4 py-3.5">
<span class="font-mono text-meta text-[var(--color-text-primary)]">{tmpl.name}</span>
</td>
<td class="px-4 py-3.5">
{#if tmpl.type === 'snapshot'}
<span class="inline-flex items-center rounded-full border border-[var(--color-accent)]/25 bg-[var(--color-accent)]/8 px-2 py-0.5 text-label font-medium text-[var(--color-accent-bright)]">
snapshot
</span>
{:else}
<span class="inline-flex items-center rounded-full border border-[var(--color-border)] bg-[var(--color-bg-3)] px-2 py-0.5 text-label font-medium text-[var(--color-text-secondary)]">
base
</span>
{/if}
</td>
<td class="hidden px-4 py-3.5 md:table-cell">
{#if tmpl.vcpus && tmpl.memory_mb}
<span class="text-meta text-[var(--color-text-secondary)]">
{tmpl.vcpus} vCPU · {tmpl.memory_mb} MB
</span>
{:else}
<span class="text-meta text-[var(--color-text-muted)]"></span>
{/if}
</td>
<td class="hidden px-4 py-3.5 lg:table-cell">
<span class="font-mono text-meta text-[var(--color-text-muted)]">
{tmpl.size_bytes ? formatBytes(tmpl.size_bytes) : '—'}
</span>
</td>
<td class="hidden px-4 py-3.5 lg:table-cell">
<span class="text-meta text-[var(--color-text-muted)]" title={formatDate(tmpl.created_at)}>
{timeAgo(tmpl.created_at)}
</span>
</td>
<td class="px-4 py-3.5 text-right">
<button
onclick={() => { deleteTarget = tmpl; deleteError = null; }}
class="rounded-[var(--radius-button)] px-3 py-1.5 text-meta text-[var(--color-text-tertiary)] transition-colors duration-150 hover:bg-[var(--color-red)]/10 hover:text-[var(--color-red)]"
>
Delete
</button>
</td>
</tr>
{/each}
</tbody>
</table>
</div>
{/snippet}
{#snippet buildsTable()}
<div class="space-y-0 rounded-[var(--radius-card)] border border-[var(--color-border)] bg-[var(--color-bg-1)] overflow-hidden">
<table class="w-full">
<thead>
<tr class="border-b border-[var(--color-border)]">
<th class="px-4 py-3 text-left text-label font-semibold uppercase tracking-[0.06em] text-[var(--color-text-tertiary)]">Build</th>
<th class="px-4 py-3 text-left text-label font-semibold uppercase tracking-[0.06em] text-[var(--color-text-tertiary)]">Name</th>
<th class="hidden px-4 py-3 text-left text-label font-semibold uppercase tracking-[0.06em] text-[var(--color-text-tertiary)] md:table-cell">Base</th>
<th class="px-4 py-3 text-left text-label font-semibold uppercase tracking-[0.06em] text-[var(--color-text-tertiary)]">Status</th>
<th class="hidden px-4 py-3 text-left text-label font-semibold uppercase tracking-[0.06em] text-[var(--color-text-tertiary)] md:table-cell">Progress</th>
<th class="hidden px-4 py-3 text-left text-label font-semibold uppercase tracking-[0.06em] text-[var(--color-text-tertiary)] lg:table-cell">Started</th>
<th class="hidden px-4 py-3 text-left text-label font-semibold uppercase tracking-[0.06em] text-[var(--color-text-tertiary)] lg:table-cell">Duration</th>
</tr>
</thead>
<tbody>
{#each builds as build (build.id)}
<tr
class="border-b border-[var(--color-border)] last:border-0 cursor-pointer transition-colors duration-200
{expandedBuildId === build.id ? 'bg-[var(--color-bg-2)]' : 'hover:bg-[var(--color-bg-2)]'}"
onclick={() => toggleBuildExpand(build.id)}
>
<td class="px-4 py-3.5">
<div class="flex items-center gap-2">
<svg
width="12" height="12" viewBox="0 0 24 24" fill="none" stroke="currentColor"
stroke-width="2" stroke-linecap="round" stroke-linejoin="round"
class="shrink-0 text-[var(--color-text-muted)] transition-transform duration-200 {expandedBuildId === build.id ? 'rotate-90' : ''}"
>
<polyline points="9 18 15 12 9 6"/>
</svg>
<span class="font-mono text-meta text-[var(--color-text-primary)]">{build.id}</span>
</div>
</td>
<td class="px-4 py-3.5">
<span class="text-meta text-[var(--color-text-primary)]">{build.name}</span>
</td>
<td class="hidden px-4 py-3.5 md:table-cell">
<span class="font-mono text-meta text-[var(--color-text-muted)]">{build.base_template}</span>
</td>
<td class="px-4 py-3.5">
<span class="flex items-center gap-1.5 text-meta font-medium" style="color: {statusColor(build.status)}">
{#if build.status === 'running'}
<span class="relative flex h-1.5 w-1.5 shrink-0">
<span class="absolute inline-flex h-full w-full animate-ping rounded-full opacity-60" style="background: {statusColor(build.status)}"></span>
<span class="relative inline-flex h-1.5 w-1.5 rounded-full" style="background: {statusColor(build.status)}"></span>
</span>
{:else if build.status === 'success'}
<svg width="12" height="12" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2.5" stroke-linecap="round" stroke-linejoin="round"><polyline points="20 6 9 17 4 12"/></svg>
{:else if build.status === 'failed'}
<svg width="12" height="12" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2.5" stroke-linecap="round" stroke-linejoin="round"><line x1="18" y1="6" x2="6" y2="18"/><line x1="6" y1="6" x2="18" y2="18"/></svg>
{:else}
<span class="h-1.5 w-1.5 shrink-0 rounded-full" style="background: {statusColor(build.status)}"></span>
{/if}
{build.status}
</span>
</td>
<td class="hidden px-4 py-3.5 md:table-cell">
<span class="font-mono text-meta text-[var(--color-text-muted)]">
{build.current_step} / {build.total_steps}
</span>
{#if build.status === 'running' && build.total_steps > 0}
<div class="mt-1.5 h-1 w-20 overflow-hidden rounded-full bg-[var(--color-bg-4)]">
<div
class="h-full rounded-full bg-[var(--color-blue)] transition-all duration-500"
style="width: {(build.current_step / build.total_steps) * 100}%"
></div>
</div>
{/if}
</td>
<td class="hidden px-4 py-3.5 lg:table-cell">
<span class="text-meta text-[var(--color-text-muted)]" title={formatDate(build.started_at)}>
{build.started_at ? timeAgo(build.started_at) : '—'}
</span>
</td>
<td class="hidden px-4 py-3.5 lg:table-cell">
<span class="font-mono text-meta text-[var(--color-text-muted)]">
{formatDuration(build.started_at, build.completed_at)}
</span>
</td>
</tr>
<!-- Expanded build logs -->
{#if expandedBuildId === build.id}
<tr>
<td colspan="7" class="border-b border-[var(--color-border)] last:border-0">
<div class="bg-[var(--color-bg-0)] px-6 py-4" style="animation: fadeUp 0.15s ease both">
{#if build.status === 'pending' || build.status === 'running'}
<div class="mb-4 flex justify-end">
<button
onclick={(e) => { e.stopPropagation(); handleCancelBuild(build.id); }}
disabled={cancelingBuildId === build.id}
class="flex items-center gap-1.5 rounded-[var(--radius-button)] border border-[var(--color-red)]/30 bg-[var(--color-red)]/8 px-3 py-1.5 text-meta text-[var(--color-red)] transition-colors duration-150 hover:bg-[var(--color-red)]/15 disabled:opacity-50"
>
{#if cancelingBuildId === build.id}
<svg class="animate-spin" width="11" height="11" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M21 12a9 9 0 1 1-6.219-8.56"/></svg>
{:else}
<svg width="11" height="11" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2.5" stroke-linecap="round" stroke-linejoin="round"><line x1="18" y1="6" x2="6" y2="18"/><line x1="6" y1="6" x2="18" y2="18"/></svg>
{/if}
Cancel build
</button>
</div>
{/if}
{#if build.error}
<div class="mb-4 rounded-[var(--radius-input)] border border-[var(--color-red)]/30 bg-[var(--color-red)]/5 px-3 py-2 text-meta text-[var(--color-red)]">
{build.error}
</div>
{/if}
{#if build.logs && build.logs.length > 0}
<div class="space-y-1">
{#each build.logs as log, i (i)}
{@const isInternal = log.phase === 'pre-build' || log.phase === 'post-build'}
{@const recipeIdx = log.phase === 'recipe' ? build.logs.filter(l => l.phase === 'recipe' && l.step <= log.step).length : 0}
{@const phaseLabel = isInternal ? (log.phase === 'pre-build' ? 'Pre-build' : 'Post-build') : `Step ${recipeIdx}`}
{@const [kw, kwRest] = splitInstruction(log.cmd)}
<div class="rounded-[var(--radius-input)] border border-[var(--color-border)] bg-[var(--color-bg-1)] overflow-hidden">
<!-- Step header -->
<button
onclick={(e) => { e.stopPropagation(); toggleStepExpand(log.step); }}
class="flex w-full items-center gap-3 px-3 py-2.5 text-left transition-colors duration-150 hover:bg-[var(--color-bg-2)]"
>
<!-- Status icon -->
{#if log.ok}
<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="var(--color-accent-bright)" stroke-width="2.5" stroke-linecap="round" stroke-linejoin="round" class="shrink-0"><polyline points="20 6 9 17 4 12"/></svg>
{:else}
<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="var(--color-red)" stroke-width="2.5" stroke-linecap="round" stroke-linejoin="round" class="shrink-0"><line x1="18" y1="6" x2="6" y2="18"/><line x1="6" y1="6" x2="18" y2="18"/></svg>
{/if}
<span class="shrink-0 text-label font-semibold text-[var(--color-text-tertiary)]">{phaseLabel}</span>
<code class="flex-1 truncate font-mono text-meta"><span style="color: {keywordColor(kw)}">{kw}</span>{#if kwRest}<span class="text-[var(--color-text-secondary)]"> {kwRest}</span>{/if}</code>
<span class="shrink-0 font-mono text-label text-[var(--color-text-muted)]">{log.elapsed_ms}ms</span>
{#if log.exit !== 0}
<span class="shrink-0 rounded-full bg-[var(--color-red)]/10 px-1.5 py-0.5 font-mono text-label text-[var(--color-red)]">
exit {log.exit}
</span>
{/if}
<svg
width="12" height="12" viewBox="0 0 24 24" fill="none" stroke="currentColor"
stroke-width="2" stroke-linecap="round" stroke-linejoin="round"
class="shrink-0 text-[var(--color-text-muted)] transition-transform duration-200 {expandedSteps.has(log.step) ? 'rotate-90' : ''}"
>
<polyline points="9 18 15 12 9 6"/>
</svg>
</button>
<!-- Step output -->
{#if expandedSteps.has(log.step)}
<div class="border-t border-[var(--color-border)] bg-[var(--color-bg-0)] px-3 py-3" style="animation: fadeUp 0.12s ease both">
{#if log.stdout}
<div class="mb-2">
<span class="text-label font-semibold uppercase tracking-[0.06em] text-[var(--color-text-tertiary)]">stdout</span>
<pre class="mt-1 max-h-48 overflow-auto rounded-[var(--radius-input)] bg-[var(--color-bg-1)] px-3 py-2 font-mono text-meta leading-relaxed text-[var(--color-text-secondary)]">{log.stdout}</pre>
</div>
{/if}
{#if log.stderr}
<div>
<span class="text-label font-semibold uppercase tracking-[0.06em] text-[var(--color-text-tertiary)]">stderr</span>
<pre class="mt-1 max-h-48 overflow-auto rounded-[var(--radius-input)] bg-[var(--color-bg-1)] px-3 py-2 font-mono text-meta leading-relaxed text-[var(--color-red)]/80">{log.stderr}</pre>
</div>
{/if}
{#if !log.stdout && !log.stderr}
<span class="text-meta text-[var(--color-text-muted)]">No output</span>
{/if}
</div>
{/if}
</div>
{/each}
</div>
{:else}
<div class="flex items-center gap-2 text-meta text-[var(--color-text-muted)]">
{#if build.status === 'pending' || build.status === 'running'}
<svg class="animate-spin" width="13" height="13" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M21 12a9 9 0 1 1-6.219-8.56"/></svg>
{build.status === 'pending' ? 'Waiting for worker…' : 'Running…'}
{:else}
No build logs recorded.
{/if}
</div>
{/if}
<!-- Recipe reference -->
{#if build.recipe && build.recipe.length > 0}
<div class="mt-4 border-t border-[var(--color-border)] pt-4">
<span class="text-label font-semibold uppercase tracking-[0.06em] text-[var(--color-text-tertiary)]">Recipe</span>
<div class="mt-2 rounded-[var(--radius-input)] bg-[var(--color-bg-1)] border border-[var(--color-border)] px-3 py-2">
{#each build.recipe as cmd, i}
{@const [kw, kwRest] = splitInstruction(cmd)}
<div class="flex gap-2 py-0.5">
<span class="shrink-0 font-mono text-label text-[var(--color-text-muted)] tabular-nums">{i + 1}.</span>
<code class="font-mono text-meta"><span style="color: {keywordColor(kw)}">{kw}</span>{#if kwRest}<span class="text-[var(--color-text-secondary)]"> {kwRest}</span>{/if}</code>
</div>
{/each}
</div>
</div>
{/if}
{#if build.healthcheck}
<div class="mt-3">
<span class="text-label font-semibold uppercase tracking-[0.06em] text-[var(--color-text-tertiary)]">Healthcheck</span>
<code class="ml-2 font-mono text-meta text-[var(--color-text-secondary)]">{build.healthcheck}</code>
</div>
{/if}
</div>
</td>
</tr>
{/if}
{/each}
</tbody>
</table>
</div>
{/snippet}
<!-- ── Create Template Dialog ──────────────────────────────────────── -->
{#if showCreate}
<div class="fixed inset-0 z-50 flex items-center justify-center">
<div
class="absolute inset-0 bg-black/60"
role="button"
tabindex="-1"
onclick={() => { if (!creating) showCreate = false; }}
onkeydown={(e) => { if (e.key === 'Escape' && !creating) showCreate = false; }}
></div>
<div
class="relative w-full max-w-[520px] max-h-[90vh] overflow-y-auto rounded-[var(--radius-card)] border border-[var(--color-border-mid)] bg-[var(--color-bg-2)] p-6 shadow-xl"
style="animation: fadeUp 0.18s cubic-bezier(0.25,1,0.5,1) both"
>
<h2 class="font-serif text-[1.375rem] leading-tight tracking-[-0.02em] text-[var(--color-text-bright)]">
Create Template
</h2>
<p class="mt-1.5 text-ui text-[var(--color-text-tertiary)]">
Build a new global template by running commands on a base image.
</p>
{#if createError}
<div class="mt-4 rounded-[var(--radius-input)] border border-[var(--color-red)]/30 bg-[var(--color-red)]/5 px-3 py-2 text-meta text-[var(--color-red)]">
{createError}
</div>
{/if}
<div class="mt-5 space-y-4">
<div>
<label class="mb-1.5 block text-label font-semibold uppercase tracking-[0.05em] text-[var(--color-text-tertiary)]" for="tmpl-name">
Template Name
</label>
<input
id="tmpl-name"
type="text"
placeholder="e.g. python312, node20-full"
bind:value={createForm.name}
disabled={creating}
class="w-full rounded-[var(--radius-input)] border border-[var(--color-border)] bg-[var(--color-bg-4)] px-3 py-2 text-ui text-[var(--color-text-bright)] outline-none placeholder:text-[var(--color-text-muted)] transition-colors duration-150 focus:border-[var(--color-accent)] disabled:opacity-60"
/>
</div>
<div class="grid grid-cols-3 gap-3">
<div>
<label class="mb-1.5 block text-label font-semibold uppercase tracking-[0.05em] text-[var(--color-text-tertiary)]" for="tmpl-base">
Base
</label>
<input
id="tmpl-base"
type="text"
bind:value={createForm.base_template}
disabled={creating}
class="w-full rounded-[var(--radius-input)] border border-[var(--color-border)] bg-[var(--color-bg-4)] px-3 py-2 text-ui text-[var(--color-text-bright)] outline-none placeholder:text-[var(--color-text-muted)] transition-colors duration-150 focus:border-[var(--color-accent)] disabled:opacity-60"
/>
</div>
<div>
<label class="mb-1.5 block text-label font-semibold uppercase tracking-[0.05em] text-[var(--color-text-tertiary)]" for="tmpl-vcpus">
vCPUs
</label>
<input
id="tmpl-vcpus"
type="number"
min="1"
bind:value={createForm.vcpus}
disabled={creating}
class="w-full rounded-[var(--radius-input)] border border-[var(--color-border)] bg-[var(--color-bg-4)] px-3 py-2 text-ui text-[var(--color-text-bright)] outline-none placeholder:text-[var(--color-text-muted)] transition-colors duration-150 focus:border-[var(--color-accent)] disabled:opacity-60"
/>
</div>
<div>
<label class="mb-1.5 block text-label font-semibold uppercase tracking-[0.05em] text-[var(--color-text-tertiary)]" for="tmpl-memory">
Memory MB
</label>
<input
id="tmpl-memory"
type="number"
min="128"
step="128"
bind:value={createForm.memory_mb}
disabled={creating}
class="w-full rounded-[var(--radius-input)] border border-[var(--color-border)] bg-[var(--color-bg-4)] px-3 py-2 text-ui text-[var(--color-text-bright)] outline-none placeholder:text-[var(--color-text-muted)] transition-colors duration-150 focus:border-[var(--color-accent)] disabled:opacity-60"
/>
</div>
</div>
<div>
<label class="mb-1.5 block text-label font-semibold uppercase tracking-[0.05em] text-[var(--color-text-tertiary)]" for="tmpl-recipe">
Recipe <span class="normal-case font-normal text-[var(--color-text-muted)]">(one instruction per line)</span>
</label>
<textarea
id="tmpl-recipe"
rows="7"
placeholder={"RUN apt-get install -y python3 python3-pip\nWORKDIR /app\nENV PORT=8080\nRUN pip3 install numpy pandas\nSTART python3 server.py"}
bind:value={createForm.recipe}
disabled={creating}
class="w-full resize-y rounded-[var(--radius-input)] border border-[var(--color-border)] bg-[var(--color-bg-4)] px-3 py-2 font-mono text-meta leading-relaxed text-[var(--color-text-bright)] outline-none placeholder:text-[var(--color-text-muted)] transition-colors duration-150 focus:border-[var(--color-accent)] disabled:opacity-60"
></textarea>
<p class="mt-1 text-label text-[var(--color-text-muted)]">
Supports <code class="font-mono">RUN</code>, <code class="font-mono">START</code>, <code class="font-mono">WORKDIR</code>, <code class="font-mono">ENV key=value</code>. RUN steps have a 30s timeout; override with <code class="font-mono">RUN --timeout=5m</code>.
</p>
</div>
<div>
<label class="mb-1.5 block text-label font-semibold uppercase tracking-[0.05em] text-[var(--color-text-tertiary)]" for="tmpl-healthcheck">
Healthcheck <span class="normal-case font-normal text-[var(--color-text-muted)]">(optional)</span>
</label>
<input
id="tmpl-healthcheck"
type="text"
placeholder="e.g. curl -s http://localhost:8080/health"
bind:value={createForm.healthcheck}
disabled={creating}
class="w-full rounded-[var(--radius-input)] border border-[var(--color-border)] bg-[var(--color-bg-4)] px-3 py-2 font-mono text-meta text-[var(--color-text-bright)] outline-none placeholder:text-[var(--color-text-muted)] transition-colors duration-150 focus:border-[var(--color-accent)] disabled:opacity-60"
/>
<p class="mt-1 text-label text-[var(--color-text-muted)]">
If set, the build will poll this command every 1s (up to 60s) after the recipe completes. On success, a full snapshot (with memory state) is created. Without a healthcheck, only the rootfs is saved.
</p>
</div>
<label class="flex cursor-pointer items-center gap-2.5">
<input
type="checkbox"
bind:checked={createForm.skip_pre_post}
disabled={creating}
class="h-4 w-4 cursor-pointer rounded border border-[var(--color-border)] bg-[var(--color-bg-4)] accent-[var(--color-accent)]"
/>
<span class="text-ui text-[var(--color-text-secondary)]">Skip pre-build and post-build steps</span>
</label>
</div>
<div class="mt-6 flex justify-end gap-3">
<button
onclick={() => (showCreate = false)}
disabled={creating}
class="rounded-[var(--radius-button)] border border-[var(--color-border)] px-4 py-2 text-ui text-[var(--color-text-secondary)] transition-colors duration-150 hover:border-[var(--color-border-mid)] hover:text-[var(--color-text-primary)] disabled:opacity-50"
>
Cancel
</button>
<button
onclick={handleCreate}
disabled={creating || !createForm.name.trim() || !createForm.recipe.trim()}
class="flex items-center gap-2 rounded-[var(--radius-button)] bg-[var(--color-accent)] px-5 py-2 text-ui font-semibold text-white transition-all duration-150 hover:brightness-115 hover:-translate-y-px active:translate-y-0 disabled:opacity-50 disabled:hover:translate-y-0"
>
{#if creating}
<svg class="animate-spin" width="13" height="13" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M21 12a9 9 0 1 1-6.219-8.56"/></svg>
Creating…
{:else}
Start Build
{/if}
</button>
</div>
</div>
</div>
{/if}
<!-- ── Delete Template Confirmation ────────────────────────────────── -->
{#if deleteTarget}
<div class="fixed inset-0 z-50 flex items-center justify-center">
<div
class="absolute inset-0 bg-black/60"
role="button"
tabindex="-1"
onclick={() => { if (!deleting) deleteTarget = null; }}
onkeydown={(e) => { if (e.key === 'Escape' && !deleting) deleteTarget = null; }}
></div>
<div
class="relative w-full max-w-[420px] rounded-[var(--radius-card)] border border-[var(--color-border-mid)] bg-[var(--color-bg-2)] p-6 shadow-xl"
style="animation: fadeUp 0.18s cubic-bezier(0.25,1,0.5,1) both"
>
<h2 class="font-serif text-[1.375rem] leading-tight tracking-[-0.02em] text-[var(--color-text-bright)]">
Delete Template
</h2>
<p class="mt-1.5 text-ui text-[var(--color-text-tertiary)]">
Permanently remove <code class="rounded bg-[var(--color-bg-4)] px-1.5 py-0.5 font-mono text-[0.8rem] text-[var(--color-text-primary)]">{deleteTarget.name}</code> from all hosts.
</p>
{#if deleteError}
<div class="mt-3 rounded-[var(--radius-input)] border border-[var(--color-red)]/30 bg-[var(--color-red)]/5 px-3 py-2 text-meta text-[var(--color-red)]">
{deleteError}
</div>
{/if}
<div class="mt-6 flex justify-end gap-3">
<button
onclick={() => (deleteTarget = null)}
disabled={deleting}
class="rounded-[var(--radius-button)] border border-[var(--color-border)] px-4 py-2 text-ui text-[var(--color-text-secondary)] transition-colors duration-150 hover:border-[var(--color-border-mid)] hover:text-[var(--color-text-primary)] disabled:opacity-50"
>
Cancel
</button>
<button
onclick={handleDeleteTemplate}
disabled={deleting}
class="flex items-center gap-2 rounded-[var(--radius-button)] bg-[var(--color-red)] px-5 py-2 text-ui font-semibold text-white transition-all duration-150 hover:brightness-110 disabled:opacity-50"
>
{#if deleting}
<svg class="animate-spin" width="13" height="13" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M21 12a9 9 0 1 1-6.219-8.56"/></svg>
Deleting…
{:else}
Delete
{/if}
</button>
</div>
</div>
</div>
{/if}
<style>
@keyframes fadeUp {
from { opacity: 0; transform: translateY(10px); }
to { opacity: 1; transform: translateY(0); }
}
@keyframes shimmer {
0% { background-position: -200% 0; }
100% { background-position: 200% 0; }
}
.skeleton {
background: linear-gradient(
90deg,
var(--color-bg-3) 25%,
var(--color-bg-4) 50%,
var(--color-bg-3) 75%
);
background-size: 200% 100%;
animation: shimmer 1.4s ease infinite;
}
</style>

View File

@ -47,7 +47,7 @@
Capsules
</h1>
<p class="mt-2 text-ui text-[var(--color-text-secondary)]">
Isolated VMs. Start cold in under a second — pause, snapshot, or destroy at will.
All active and recent capsules across your team.
</p>
</div>

View File

@ -247,6 +247,13 @@
return `${Math.floor(seconds / 86400)}d ago`;
}
function fmtTimeout(sec: number): string {
if (!sec) return 'None';
if (sec < 60) return `${sec}s`;
if (sec < 3600) return `${Math.round(sec / 60)}m`;
return `${Math.round(sec / 3600)}h`;
}
function handleClickOutside(event: MouseEvent) {
if (openMenuId && !(event.target as Element)?.closest('.status-menu-container')) {
openMenuId = null;
@ -300,7 +307,7 @@
class="w-full rounded-[var(--radius-input)] border border-[var(--color-border)] bg-[var(--color-bg-2)] py-2 pl-9 pr-3 font-mono text-ui text-[var(--color-text-bright)] outline-none placeholder:text-[var(--color-text-muted)] transition-colors duration-150 focus:border-[var(--color-accent)]"
/>
</div>
<span class="text-ui text-[var(--color-text-secondary)]">{filteredCapsules.length} total</span>
<span class="text-ui text-[var(--color-text-secondary)]">{filteredCapsules.length} capsule{filteredCapsules.length !== 1 ? 's' : ''}</span>
<div class="flex-1"></div>
@ -363,8 +370,11 @@
</div>
{#if error}
<div class="mb-4 rounded-[var(--radius-card)] border border-[var(--color-red)]/30 bg-[var(--color-red)]/5 px-4 py-3 text-ui text-[var(--color-red)]">
{error}
<div class="mb-4 flex items-start gap-3 rounded-[var(--radius-card)] border border-[var(--color-red)]/30 bg-[var(--color-red)]/5 px-4 py-3">
<svg class="mt-0.5 shrink-0 text-[var(--color-red)]" width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<circle cx="12" cy="12" r="10" /><line x1="12" y1="8" x2="12" y2="12" /><line x1="12" y1="16" x2="12.01" y2="16" />
</svg>
<span class="text-ui text-[var(--color-red)]">{error}. Try refreshing the page.</span>
</div>
{/if}
@ -466,14 +476,14 @@
<!-- Idle Timeout -->
<div class="px-5 py-4">
<span class="font-mono text-ui text-[var(--color-text-secondary)]">{capsule.timeout_sec ? `${capsule.timeout_sec}s` : '—'}</span>
<span class="font-mono text-ui text-[var(--color-text-secondary)]">{fmtTimeout(capsule.timeout_sec)}</span>
</div>
<!-- Started -->
<div class="px-5 py-4">
<span class="text-ui text-[var(--color-text-secondary)]" title={capsule.started_at ?? ''}>{formatTime(capsule.started_at)}</span>
{#if capsule.last_active_at}
<span class="ml-1.5 text-label text-[var(--color-text-muted)]">{timeAgo(capsule.last_active_at)}</span>
<span class="ml-1.5 text-label text-[var(--color-text-muted)]">active {timeAgo(capsule.last_active_at)}</span>
{/if}
</div>
@ -612,7 +622,7 @@
<line x1="12" y1="9" x2="12" y2="13" />
<line x1="12" y1="17" x2="12.01" y2="17" />
</svg>
<p class="text-meta text-[var(--color-amber)] leading-relaxed">This capsule will be <strong class="font-semibold">paused first</strong> — memory state is captured at rest.</p>
<p class="text-meta text-[var(--color-amber)] leading-relaxed">This capsule will be <strong class="font-semibold">paused first</strong>, then its full state (memory + disk) will be captured.</p>
</div>
{:else}
<p class="text-ui text-[var(--color-text-tertiary)]">The capsule's current memory state will be captured and stored as a reusable snapshot.</p>

View File

@ -0,0 +1 @@
export const prerender = false;

View File

@ -512,7 +512,7 @@
<svg class="shrink-0 text-[var(--color-red)]" width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<circle cx="12" cy="12" r="10" /><line x1="12" y1="8" x2="12" y2="12" /><line x1="12" y1="16" x2="12.01" y2="16" />
</svg>
<span class="text-ui text-[var(--color-red)]">Failed to load metrics: {metricsError}</span>
<span class="text-ui text-[var(--color-red)]">Could not load metrics: {metricsError}. Will retry automatically.</span>
</div>
{/if}

View File

@ -114,15 +114,15 @@
}
function emptyHeading(f: TypeFilter): string {
if (f === 'snapshot') return 'No snapshots';
if (f === 'base') return 'No images';
return 'No templates yet';
if (f === 'snapshot') return 'No snapshots yet';
if (f === 'base') return 'No base images';
return 'No snapshots yet';
}
function emptyDescription(f: TypeFilter): string {
if (f === 'snapshot') return 'Pause a capsule from the Capsules page, then snapshot it to capture its state.';
if (f === 'base') return 'Base images are added by the Wrenn team. Contact support to request a custom image.';
return 'To create a snapshot, go to Capsules, pause a running capsule, then choose Snapshot.';
if (f === 'snapshot') return 'Pause a running capsule, then choose Snapshot to save its state.';
if (f === 'base') return 'Base images are provided by the Wrenn team. Contact support to request a custom one.';
return 'Pause a running capsule, then choose Snapshot to save its state. You can launch new capsules from any snapshot.';
}
onMount(fetchSnapshots);
@ -162,7 +162,7 @@
Templates
</h1>
<p class="mt-2 text-ui text-[var(--color-text-secondary)]">
Snapshots capture a live capsule state. Base images are the rootfs every capsule starts from. Launch a full VM from any template.
Snapshots capture a running capsule's state. Base images are the starting point for every new capsule. Launch from either.
</p>
</div>
</div>
@ -206,8 +206,11 @@
{#if pageTab === 'snapshots'}
<div class="p-8" style="animation: fadeUp 0.35s ease both">
{#if error}
<div class="mb-4 rounded-[var(--radius-card)] border border-[var(--color-red)]/30 bg-[var(--color-red)]/5 px-4 py-3 text-ui text-[var(--color-red)]">
{error}
<div class="mb-4 flex items-start gap-3 rounded-[var(--radius-card)] border border-[var(--color-red)]/30 bg-[var(--color-red)]/5 px-4 py-3">
<svg class="mt-0.5 shrink-0 text-[var(--color-red)]" width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<circle cx="12" cy="12" r="10" /><line x1="12" y1="8" x2="12" y2="12" /><line x1="12" y1="16" x2="12.01" y2="16" />
</svg>
<span class="text-ui text-[var(--color-red)]">{error}. Try refreshing the page.</span>
</div>
{/if}
@ -274,11 +277,11 @@
</div>
<span class="text-meta text-[var(--color-text-muted)]">
{filteredSnapshots.length}
{typeFilter === 'all'
? filteredSnapshots.length === 1 ? 'template' : 'templates'
: typeFilter === 'snapshot'
? filteredSnapshots.length === 1 ? 'snapshot' : 'snapshots'
: filteredSnapshots.length === 1 ? 'image' : 'images'}
{typeFilter === 'snapshot'
? filteredSnapshots.length === 1 ? 'snapshot' : 'snapshots'
: typeFilter === 'base'
? filteredSnapshots.length === 1 ? 'image' : 'images'
: filteredSnapshots.length === 1 ? 'item' : 'total'}
</span>
</div>
@ -420,6 +423,7 @@
<div class="w-px shrink-0 bg-[var(--color-border-mid)]"></div>
<!-- Chevron / dropdown trigger -->
<button
disabled={snapshot.platform}
onclick={(e) => {
e.stopPropagation();
if (openDropdownName === snapshot.name) {
@ -430,7 +434,7 @@
openDropdownName = snapshot.name;
}
}}
class="flex items-center px-2 py-1.5 text-[var(--color-text-secondary)] transition-colors duration-150 hover:bg-[var(--color-bg-4)] hover:text-[var(--color-text-bright)]"
class="flex items-center px-2 py-1.5 text-[var(--color-text-secondary)] transition-colors duration-150 hover:bg-[var(--color-bg-4)] hover:text-[var(--color-text-bright)] disabled:cursor-not-allowed disabled:opacity-30 disabled:hover:bg-transparent disabled:hover:text-[var(--color-text-secondary)]"
>
<svg
class="transition-transform duration-150 {openDropdownName === snapshot.name ? 'rotate-180' : ''}"
@ -447,11 +451,11 @@
<p class="mt-3 text-meta text-[var(--color-text-muted)]">
{filteredSnapshots.length}
{typeFilter === 'all'
? filteredSnapshots.length === 1 ? 'template' : 'templates'
: typeFilter === 'snapshot'
? filteredSnapshots.length === 1 ? 'snapshot' : 'snapshots'
: filteredSnapshots.length === 1 ? 'image' : 'images'}
{typeFilter === 'snapshot'
? filteredSnapshots.length === 1 ? 'snapshot' : 'snapshots'
: typeFilter === 'base'
? filteredSnapshots.length === 1 ? 'image' : 'images'
: filteredSnapshots.length === 1 ? 'item' : 'total'}
{typeFilter !== 'all' ? '· filtered' : '· total'}
</p>
{/if}
@ -481,21 +485,23 @@
class="fixed z-50 w-32 overflow-hidden rounded-[var(--radius-card)] border border-[var(--color-border-mid)] bg-[var(--color-bg-2)] py-1"
style="top: {dropdownPos.top}px; left: {dropdownPos.left}px; animation: fadeUp 0.15s ease both"
>
<button
onclick={(e) => {
e.stopPropagation();
const target = snapshots.find((s) => s.name === openDropdownName);
openDropdownName = null;
if (target) { deleteTarget = target; deleteError = null; }
}}
class="flex w-full items-center gap-2 px-3 py-2 text-meta text-[var(--color-red)] transition-colors duration-150 hover:bg-[var(--color-red)]/5"
>
<svg width="13" height="13" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="shrink-0">
<polyline points="3 6 5 6 21 6" />
<path d="M19 6v14a2 2 0 0 1-2 2H7a2 2 0 0 1-2-2V6m3 0V4a2 2 0 0 1 2-2h4a2 2 0 0 1 2 2v2" />
</svg>
Delete
</button>
{#if !dropdownSnapshot.platform}
<button
onclick={(e) => {
e.stopPropagation();
const target = snapshots.find((s) => s.name === openDropdownName);
openDropdownName = null;
if (target) { deleteTarget = target; deleteError = null; }
}}
class="flex w-full items-center gap-2 px-3 py-2 text-meta text-[var(--color-red)] transition-colors duration-150 hover:bg-[var(--color-red)]/5"
>
<svg width="13" height="13" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="shrink-0">
<polyline points="3 6 5 6 21 6" />
<path d="M19 6v14a2 2 0 0 1-2 2H7a2 2 0 0 1-2-2V6m3 0V4a2 2 0 0 1 2-2h4a2 2 0 0 1 2 2v2" />
</svg>
Delete
</button>
{/if}
</div>
{/if}
{/if}
@ -513,10 +519,10 @@
class="relative w-full max-w-[380px] rounded-[var(--radius-card)] border border-[var(--color-border-mid)] bg-[var(--color-bg-2)] p-6"
style="animation: fadeUp 0.2s ease both"
>
<h2 class="font-serif text-heading tracking-[-0.02em] text-[var(--color-text-bright)]">Delete Snapshot</h2>
<h2 class="font-serif text-heading tracking-[-0.02em] text-[var(--color-text-bright)]">Delete snapshot</h2>
<p class="mt-2 text-ui text-[var(--color-text-tertiary)]">
Permanently delete <span class="font-mono font-medium text-[var(--color-text-secondary)]">{deleteTarget.name}</span>.
Any capsule using this template will not be affected, but you won't be able to launch from it again.
Running capsules won't be affected, but you won't be able to launch new ones from it.
</p>
{#if deleteTarget.type === 'snapshot'}
@ -526,7 +532,7 @@
<line x1="12" y1="9" x2="12" y2="13" /><line x1="12" y1="17" x2="12.01" y2="17" />
</svg>
<p class="text-meta leading-relaxed text-[var(--color-amber)]">
This live capture includes saved memory state. Any capsule relying on it will be unable to resume.
This snapshot includes memory state. Paused capsules that depend on it won't be able to resume.
</p>
</div>
{/if}
@ -580,7 +586,7 @@
>
<h2 class="font-serif text-heading tracking-[-0.02em] text-[var(--color-text-bright)]">Launch Capsule</h2>
<p class="mt-1 text-ui text-[var(--color-text-tertiary)]">
Configure resources and launch. The VM will clone from this template and be ready in seconds.
Configure resources and launch a new capsule from this snapshot.
</p>
{#if launchError}
@ -655,14 +661,16 @@
<!-- Timeout -->
<div class="mt-4">
<label class="mb-1.5 block text-label font-semibold uppercase tracking-[0.05em] text-[var(--color-text-tertiary)]" for="launch-timeout">Auto-pause timeout (seconds, 0 = never)</label>
<label class="mb-1.5 block text-label font-semibold uppercase tracking-[0.05em] text-[var(--color-text-tertiary)]" for="launch-timeout">Idle timeout</label>
<input
id="launch-timeout"
type="number"
min="0"
bind:value={launchTimeoutSec}
placeholder="0"
class="w-full rounded-[var(--radius-input)] border border-[var(--color-border)] bg-[var(--color-bg-4)] px-3 py-2 font-mono text-ui text-[var(--color-text-bright)] outline-none transition-colors duration-150 focus:border-[var(--color-accent)]"
/>
<p class="mt-1.5 text-meta text-[var(--color-text-muted)]">Seconds of inactivity before the capsule pauses. Set to 0 to keep it running indefinitely.</p>
</div>
<div class="mt-6 flex justify-end gap-3">

View File

@ -50,7 +50,6 @@
let nameInputEl = $state<HTMLInputElement | null>(null);
// Copy state
let copiedSlug = $state(false);
let copiedId = $state(false);
// Add member dialog
@ -139,16 +138,11 @@
savingName = false;
}
async function copyToClipboard(text: string, which: 'slug' | 'id') {
async function copyToClipboard(text: string) {
try {
await navigator.clipboard.writeText(text);
if (which === 'slug') {
copiedSlug = true;
setTimeout(() => (copiedSlug = false), 2000);
} else {
copiedId = true;
setTimeout(() => (copiedId = false), 2000);
}
copiedId = true;
setTimeout(() => (copiedId = false), 2000);
} catch {
toast.error('Copy failed — select the text and copy manually.');
}
@ -514,115 +508,58 @@
</div>
</div>
<!-- Slug + ID row (2-col grid) -->
<div class="grid grid-cols-2 divide-x divide-[var(--color-border)]">
<!-- Slug -->
<div class="flex items-center gap-3 px-5 py-4">
<div class="min-w-0 flex-1">
<div
class="mb-1 text-label font-semibold uppercase tracking-[0.05em] text-[var(--color-text-muted)]"
>
Slug
</div>
<span class="block truncate font-mono text-ui text-[var(--color-text-secondary)]"
>{team.slug}</span
>
</div>
<button
onclick={() => copyToClipboard(team!.slug, 'slug')}
title="Copy slug"
class="flex shrink-0 items-center gap-1.5 rounded-[var(--radius-button)] border px-3 py-1.5 text-meta font-semibold transition-all duration-150
{copiedSlug
? 'border-[var(--color-accent)]/40 bg-[var(--color-accent-glow-mid)] text-[var(--color-accent-mid)]'
: 'border-[var(--color-border-mid)] text-[var(--color-text-secondary)] hover:text-[var(--color-text-primary)]'}"
<!-- Team ID -->
<div class="flex items-center gap-3 px-5 py-4">
<div class="min-w-0 flex-1">
<div
class="mb-1 text-label font-semibold uppercase tracking-[0.05em] text-[var(--color-text-muted)]"
>
{#if copiedSlug}
<svg
width="12"
height="12"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2.5"
stroke-linecap="round"
stroke-linejoin="round"
class="checkmark-draw"
>
<polyline points="20 6 9 17 4 12" />
</svg>
Copied
{:else}
<svg
width="12"
height="12"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
>
<rect x="9" y="9" width="13" height="13" rx="2" ry="2" />
<path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1" />
</svg>
Copy
{/if}
</button>
</div>
<!-- Team ID -->
<div class="flex items-center gap-3 px-5 py-4">
<div class="min-w-0 flex-1">
<div
class="mb-1 text-label font-semibold uppercase tracking-[0.05em] text-[var(--color-text-muted)]"
>
Team ID
</div>
<span class="block truncate font-mono text-ui text-[var(--color-text-secondary)]"
>{team.id}</span
>
Team ID
</div>
<button
onclick={() => copyToClipboard(team!.id, 'id')}
title="Copy team ID"
class="flex shrink-0 items-center gap-1.5 rounded-[var(--radius-button)] border px-3 py-1.5 text-meta font-semibold transition-all duration-150
{copiedId
? 'border-[var(--color-accent)]/40 bg-[var(--color-accent-glow-mid)] text-[var(--color-accent-mid)]'
: 'border-[var(--color-border-mid)] text-[var(--color-text-secondary)] hover:text-[var(--color-text-primary)]'}"
<span class="block truncate font-mono text-ui text-[var(--color-text-secondary)]"
>{team.id}</span
>
{#if copiedId}
<svg
width="12"
height="12"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2.5"
stroke-linecap="round"
stroke-linejoin="round"
class="checkmark-draw"
>
<polyline points="20 6 9 17 4 12" />
</svg>
Copied
{:else}
<svg
width="12"
height="12"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
>
<rect x="9" y="9" width="13" height="13" rx="2" ry="2" />
<path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1" />
</svg>
Copy
{/if}
</button>
</div>
<button
onclick={() => copyToClipboard(team!.id)}
title="Copy team ID"
class="flex shrink-0 items-center gap-1.5 rounded-[var(--radius-button)] border px-3 py-1.5 text-meta font-semibold transition-all duration-150
{copiedId
? 'border-[var(--color-accent)]/40 bg-[var(--color-accent-glow-mid)] text-[var(--color-accent-mid)]'
: 'border-[var(--color-border-mid)] text-[var(--color-text-secondary)] hover:text-[var(--color-text-primary)]'}"
>
{#if copiedId}
<svg
width="12"
height="12"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2.5"
stroke-linecap="round"
stroke-linejoin="round"
class="checkmark-draw"
>
<polyline points="20 6 9 17 4 12" />
</svg>
Copied
{:else}
<svg
width="12"
height="12"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
>
<rect x="9" y="9" width="13" height="13" rx="2" ry="2" />
<path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1" />
</svg>
Copy
{/if}
</button>
</div>
</div>
</section>

View File

@ -7,7 +7,7 @@ export default defineConfig({
server: {
proxy: {
'/api': {
target: 'http://localhost:8000',
target: 'http://localhost:8080',
rewrite: (path) => path.replace(/^\/api/, '')
}
}

View File

@ -1,6 +1,6 @@
#!/bin/sh
# wrenn-init: minimal PID 1 init for Firecracker microVMs.
# Mounts virtual filesystems then execs envd.
# Mounts virtual filesystems, starts chronyd for time sync, then execs tini + envd.
set -e
@ -25,7 +25,19 @@ echo "nameserver 8.8.8.8" > /etc/resolv.conf
echo "nameserver 8.8.4.4" >> /etc/resolv.conf
# Set a standard PATH so envd and all child processes can find common binaries.
export PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
export PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/local/games
# Write chrony config to sync time from the KVM PTP hardware clock.
# /dev/ptp0 is a paravirtual clock exposed by KVM — no network required.
mkdir -p /etc/chrony /run/chrony
cat > /etc/chrony/chrony.conf <<EOF
refclock PHC /dev/ptp0 poll 2 dpoll 2
driftfile /run/chrony/chrony.drift
makestep 1.0 -1
EOF
# Start chronyd in the background before handing off to tini.
chronyd -f /etc/chrony/chrony.conf 2>/dev/null || true
# Exec tini as PID 1 — it reaps zombie processes and forwards signals to envd.
exec /sbin/tini -- /usr/local/bin/envd

View File

@ -4,6 +4,8 @@ import (
"context"
"fmt"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
@ -11,7 +13,7 @@ import (
// agentForHost looks up the host record and returns a Connect RPC client for it.
// Returns an error if the host is not found or has no address.
func agentForHost(ctx context.Context, queries *db.Queries, pool *lifecycle.HostClientPool, hostID string) (hostagentv1connect.HostAgentServiceClient, error) {
func agentForHost(ctx context.Context, queries *db.Queries, pool *lifecycle.HostClientPool, hostID pgtype.UUID) (hostagentv1connect.HostAgentServiceClient, error) {
host, err := queries.GetHost(ctx, hostID)
if err != nil {
return nil, fmt.Errorf("host not found: %w", err)

View File

@ -0,0 +1,229 @@
package api
import (
"context"
"errors"
"fmt"
"log/slog"
"net/http"
"net/http/httputil"
"net/url"
"regexp"
"strconv"
"strings"
"sync"
"time"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
)
// Sentinel errors returned by proxyTarget, used to map to HTTP status codes
// without relying on error message text.
var (
errProxySandboxNotFound = errors.New("sandbox not found")
errProxyNoHostAddress = errors.New("host agent has no address")
)
const proxyCacheTTL = 120 * time.Second
// sandboxHostPattern matches hostnames like "49999-cl-abcd1234.localhost" or
// "49999-cl-abcd1234.example.com". Captures: port, sandbox ID.
var sandboxHostPattern = regexp.MustCompile(`^(\d+)-(cl-[0-9a-z]+)\.`)
// errProxySandboxNotRunning carries the sandbox status so callers can include
// it in the HTTP response without parsing error strings.
type errProxySandboxNotRunning struct{ status string }
func (e errProxySandboxNotRunning) Error() string {
return fmt.Sprintf("sandbox is not running (status: %s)", e.status)
}
// proxyCacheEntry caches the resolved agent URL for a (sandbox, team) pair.
// The *httputil.ReverseProxy is built per-request (cheap) so the Director closure
// can capture the correct port without the cache key needing to include it.
type proxyCacheEntry struct {
agentURL *url.URL
expiresAt time.Time
}
// proxyCacheKey is a fixed-size key from two UUIDs, avoids string allocation.
type proxyCacheKey [32]byte
func makeProxyCacheKey(sandboxID, teamID pgtype.UUID) proxyCacheKey {
var k proxyCacheKey
copy(k[:16], sandboxID.Bytes[:])
copy(k[16:], teamID.Bytes[:])
return k
}
// SandboxProxyWrapper wraps an existing HTTP handler and intercepts requests
// whose Host header matches the {port}-{sandbox_id}.{domain} pattern. Matching
// requests are reverse-proxied through the host agent that owns the sandbox.
// All other requests are passed through to the inner handler.
//
// Authentication is via X-API-Key header only (no JWT). The API key's team
// must own the sandbox.
type SandboxProxyWrapper struct {
inner http.Handler
db *db.Queries
pool *lifecycle.HostClientPool
transport http.RoundTripper
cacheMu sync.Mutex
cache map[proxyCacheKey]proxyCacheEntry
}
// NewSandboxProxyWrapper creates a new proxy wrapper.
func NewSandboxProxyWrapper(inner http.Handler, queries *db.Queries, pool *lifecycle.HostClientPool) *SandboxProxyWrapper {
return &SandboxProxyWrapper{
inner: inner,
db: queries,
pool: pool,
transport: pool.Transport(),
cache: make(map[proxyCacheKey]proxyCacheEntry),
}
}
// proxyTarget looks up the cached agent URL for (sandboxID, teamID).
// On a miss it queries the DB, resolves the address, and populates the cache.
// The *httputil.ReverseProxy is built by the caller so the Director closure
// captures the correct port without the cache key needing to include it.
func (h *SandboxProxyWrapper) proxyTarget(ctx context.Context, sandboxID, teamID pgtype.UUID) (*url.URL, error) {
cacheKey := makeProxyCacheKey(sandboxID, teamID)
h.cacheMu.Lock()
entry, ok := h.cache[cacheKey]
h.cacheMu.Unlock()
if ok && time.Now().Before(entry.expiresAt) {
return entry.agentURL, nil
}
// Cache miss or expired — query DB.
target, err := h.db.GetSandboxProxyTarget(ctx, db.GetSandboxProxyTargetParams{
ID: sandboxID,
TeamID: teamID,
})
if err != nil {
return nil, errProxySandboxNotFound
}
if target.Status != "running" {
return nil, errProxySandboxNotRunning{status: target.Status}
}
if target.HostAddress == "" {
return nil, errProxyNoHostAddress
}
agentURL, err := url.Parse(h.pool.ResolveAddr(target.HostAddress))
if err != nil {
return nil, fmt.Errorf("invalid host agent address: %w", err)
}
h.cacheMu.Lock()
h.cache[cacheKey] = proxyCacheEntry{
agentURL: agentURL,
expiresAt: time.Now().Add(proxyCacheTTL),
}
h.cacheMu.Unlock()
return agentURL, nil
}
// evictProxyCache removes the cached entry for a (sandbox, team) pair.
// Called on 502 so a stopped/moved sandbox is re-resolved on the next request.
func (h *SandboxProxyWrapper) evictProxyCache(sandboxID, teamID pgtype.UUID) {
h.cacheMu.Lock()
delete(h.cache, makeProxyCacheKey(sandboxID, teamID))
h.cacheMu.Unlock()
}
func (h *SandboxProxyWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request) {
host := r.Host
// Strip port from Host header (e.g. "49999-cl-abcd1234.localhost:8000" → "49999-cl-abcd1234.localhost")
if colonIdx := strings.LastIndex(host, ":"); colonIdx != -1 {
host = host[:colonIdx]
}
matches := sandboxHostPattern.FindStringSubmatch(host)
if matches == nil {
h.inner.ServeHTTP(w, r)
return
}
port := matches[1]
sandboxIDStr := matches[2]
// Validate port.
portNum, err := strconv.Atoi(port)
if err != nil || portNum < 1 || portNum > 65535 {
http.Error(w, "invalid port", http.StatusBadRequest)
return
}
// Authenticate: require API key or JWT, extract team ID.
teamID, err := h.authenticateRequest(r)
if err != nil {
writeError(w, http.StatusUnauthorized, "unauthorized", err.Error())
return
}
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
http.Error(w, "invalid sandbox ID", http.StatusBadRequest)
return
}
agentURL, err := h.proxyTarget(r.Context(), sandboxID, teamID)
if err != nil {
switch {
case errors.Is(err, errProxySandboxNotFound):
http.Error(w, err.Error(), http.StatusNotFound)
case errors.As(err, new(errProxySandboxNotRunning)):
http.Error(w, err.Error(), http.StatusConflict)
default:
http.Error(w, err.Error(), http.StatusServiceUnavailable)
}
return
}
proxy := &httputil.ReverseProxy{
Transport: h.transport,
Director: func(req *http.Request) {
req.URL.Scheme = agentURL.Scheme
req.URL.Host = agentURL.Host
req.URL.Path = "/proxy/" + sandboxIDStr + "/" + port + req.URL.Path
req.Host = agentURL.Host
},
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
slog.Debug("sandbox proxy error",
"sandbox_id", sandboxIDStr,
"port", port,
"error", err,
)
h.evictProxyCache(sandboxID, teamID)
http.Error(w, "proxy error: "+err.Error(), http.StatusBadGateway)
},
}
proxy.ServeHTTP(w, r)
}
// authenticateRequest validates the request's API key and returns the team ID.
// Only API key authentication is supported for sandbox proxy requests (not JWT).
func (h *SandboxProxyWrapper) authenticateRequest(r *http.Request) (pgtype.UUID, error) {
key := r.Header.Get("X-API-Key")
if key == "" {
return pgtype.UUID{}, fmt.Errorf("X-API-Key header required")
}
hash := auth.HashAPIKey(key)
row, err := h.db.GetAPIKeyByHash(r.Context(), hash)
if err != nil {
return pgtype.UUID{}, fmt.Errorf("invalid API key")
}
return row.TeamID, nil
}

View File

@ -9,6 +9,7 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/audit"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/service"
)
@ -39,11 +40,11 @@ type apiKeyResponse struct {
func apiKeyToResponse(k db.TeamApiKey) apiKeyResponse {
resp := apiKeyResponse{
ID: k.ID,
TeamID: k.TeamID,
ID: id.FormatAPIKeyID(k.ID),
TeamID: id.FormatTeamID(k.TeamID),
Name: k.Name,
KeyPrefix: k.KeyPrefix,
CreatedBy: k.CreatedBy,
CreatedBy: id.FormatUserID(k.CreatedBy),
}
if k.CreatedAt.Valid {
resp.CreatedAt = k.CreatedAt.Time.Format(time.RFC3339)
@ -57,11 +58,11 @@ func apiKeyToResponse(k db.TeamApiKey) apiKeyResponse {
func apiKeyWithCreatorToResponse(k db.ListAPIKeysByTeamWithCreatorRow) apiKeyResponse {
resp := apiKeyResponse{
ID: k.ID,
TeamID: k.TeamID,
ID: id.FormatAPIKeyID(k.ID),
TeamID: id.FormatTeamID(k.TeamID),
Name: k.Name,
KeyPrefix: k.KeyPrefix,
CreatedBy: k.CreatedBy,
CreatedBy: id.FormatUserID(k.CreatedBy),
CreatorEmail: k.CreatorEmail,
}
if k.CreatedAt.Valid {
@ -118,7 +119,13 @@ func (h *apiKeyHandler) List(w http.ResponseWriter, r *http.Request) {
// Delete handles DELETE /v1/api-keys/{id}.
func (h *apiKeyHandler) Delete(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
keyID := chi.URLParam(r, "id")
keyIDStr := chi.URLParam(r, "id")
keyID, err := id.ParseAPIKeyID(keyIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid API key ID")
return
}
if err := h.svc.Delete(r.Context(), keyID, ac.TeamID); err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to delete API key")

View File

@ -6,7 +6,10 @@ import (
"strings"
"time"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/service"
)
@ -65,13 +68,24 @@ func (h *auditHandler) List(w http.ResponseWriter, r *http.Request) {
limit = n
}
// Parse ?before_id cursor (UUID).
var beforeID pgtype.UUID
if s := r.URL.Query().Get("before_id"); s != "" {
parsed, err := id.ParseAuditLogID(s)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "before_id must be a valid audit log ID")
return
}
beforeID = parsed
}
entries, err := h.svc.List(r.Context(), service.AuditListParams{
TeamID: ac.TeamID,
AdminScoped: ac.Role == "owner" || ac.Role == "admin",
ResourceTypes: parseMultiParam(r.URL.Query()["resource_type"]),
Actions: parseMultiParam(r.URL.Query()["action"]),
Before: before,
BeforeID: r.URL.Query().Get("before_id"),
BeforeID: beforeID,
Limit: limit,
})
if err != nil {

View File

@ -20,7 +20,7 @@ import (
// It prefers the user's default team; if none is flagged as default it falls
// back to the earliest-joined team. Returns pgx.ErrNoRows when the user has
// no team memberships at all.
func loginTeam(ctx context.Context, q *db.Queries, userID string) (db.Team, string, error) {
func loginTeam(ctx context.Context, q *db.Queries, userID pgtype.UUID) (db.Team, string, error) {
team, err := q.GetDefaultTeamForUser(ctx, userID)
if err == nil {
membership, err := q.GetTeamMembership(ctx, db.GetTeamMembershipParams{UserID: userID, TeamID: team.ID})
@ -176,8 +176,8 @@ func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusCreated, authResponse{
Token: token,
UserID: userID,
TeamID: teamID,
UserID: id.FormatUserID(userID),
TeamID: id.FormatTeamID(teamID),
Email: req.Email,
Name: req.Name,
})
@ -236,8 +236,8 @@ func (h *authHandler) Login(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, authResponse{
Token: token,
UserID: user.ID,
TeamID: team.ID,
UserID: id.FormatUserID(user.ID),
TeamID: id.FormatTeamID(team.ID),
Email: user.Email,
Name: user.Name,
})
@ -260,10 +260,16 @@ func (h *authHandler) SwitchTeam(w http.ResponseWriter, r *http.Request) {
return
}
teamID, err := id.ParseTeamID(req.TeamID)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid team_id")
return
}
ctx := r.Context()
// Verify team exists and is not deleted.
team, err := h.db.GetTeam(ctx, req.TeamID)
team, err := h.db.GetTeam(ctx, teamID)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
writeError(w, http.StatusNotFound, "not_found", "team not found")
@ -280,7 +286,7 @@ func (h *authHandler) SwitchTeam(w http.ResponseWriter, r *http.Request) {
// Verify membership from DB — JWT role is not trusted here.
membership, err := h.db.GetTeamMembership(ctx, db.GetTeamMembershipParams{
UserID: ac.UserID,
TeamID: req.TeamID,
TeamID: teamID,
})
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
@ -298,7 +304,7 @@ func (h *authHandler) SwitchTeam(w http.ResponseWriter, r *http.Request) {
return
}
token, err := auth.SignJWT(h.jwtSecret, ac.UserID, req.TeamID, ac.Email, user.Name, membership.Role, user.IsAdmin)
token, err := auth.SignJWT(h.jwtSecret, ac.UserID, teamID, ac.Email, user.Name, membership.Role, user.IsAdmin)
if err != nil {
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
return
@ -306,8 +312,8 @@ func (h *authHandler) SwitchTeam(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, authResponse{
Token: token,
UserID: ac.UserID,
TeamID: req.TeamID,
UserID: id.FormatUserID(ac.UserID),
TeamID: id.FormatTeamID(teamID),
Email: ac.Email,
Name: user.Name,
})

View File

@ -0,0 +1,276 @@
package api
import (
"encoding/json"
"fmt"
"log/slog"
"net/http"
"time"
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/layout"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
"git.omukk.dev/wrenn/sandbox/internal/service"
"git.omukk.dev/wrenn/sandbox/internal/validate"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
)
type buildHandler struct {
svc *service.BuildService
db *db.Queries
pool *lifecycle.HostClientPool
}
func newBuildHandler(svc *service.BuildService, db *db.Queries, pool *lifecycle.HostClientPool) *buildHandler {
return &buildHandler{svc: svc, db: db, pool: pool}
}
type createBuildRequest struct {
Name string `json:"name"`
BaseTemplate string `json:"base_template"`
Recipe []string `json:"recipe"`
Healthcheck string `json:"healthcheck"`
VCPUs int32 `json:"vcpus"`
MemoryMB int32 `json:"memory_mb"`
SkipPrePost bool `json:"skip_pre_post"`
}
type buildResponse struct {
ID string `json:"id"`
Name string `json:"name"`
BaseTemplate string `json:"base_template"`
Recipe json.RawMessage `json:"recipe"`
Healthcheck *string `json:"healthcheck,omitempty"`
VCPUs int32 `json:"vcpus"`
MemoryMB int32 `json:"memory_mb"`
Status string `json:"status"`
CurrentStep int32 `json:"current_step"`
TotalSteps int32 `json:"total_steps"`
Logs json.RawMessage `json:"logs"`
Error *string `json:"error,omitempty"`
SandboxID *string `json:"sandbox_id,omitempty"`
HostID *string `json:"host_id,omitempty"`
CreatedAt string `json:"created_at"`
StartedAt *string `json:"started_at,omitempty"`
CompletedAt *string `json:"completed_at,omitempty"`
}
func buildToResponse(b db.TemplateBuild) buildResponse {
resp := buildResponse{
ID: id.FormatBuildID(b.ID),
Name: b.Name,
BaseTemplate: b.BaseTemplate,
Recipe: b.Recipe,
VCPUs: b.Vcpus,
MemoryMB: b.MemoryMb,
Status: b.Status,
CurrentStep: b.CurrentStep,
TotalSteps: b.TotalSteps,
Logs: b.Logs,
}
if b.Healthcheck != "" {
resp.Healthcheck = &b.Healthcheck
}
if b.Error != "" {
resp.Error = &b.Error
}
if b.SandboxID.Valid {
s := id.FormatSandboxID(b.SandboxID)
resp.SandboxID = &s
}
if b.HostID.Valid {
s := id.FormatHostID(b.HostID)
resp.HostID = &s
}
if b.CreatedAt.Valid {
resp.CreatedAt = b.CreatedAt.Time.Format(time.RFC3339)
}
if b.StartedAt.Valid {
s := b.StartedAt.Time.Format(time.RFC3339)
resp.StartedAt = &s
}
if b.CompletedAt.Valid {
s := b.CompletedAt.Time.Format(time.RFC3339)
resp.CompletedAt = &s
}
return resp
}
// Create handles POST /v1/admin/builds.
func (h *buildHandler) Create(w http.ResponseWriter, r *http.Request) {
var req createBuildRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
return
}
if req.Name == "" {
writeError(w, http.StatusBadRequest, "invalid_request", "name is required")
return
}
if err := validate.SafeName(req.Name); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", fmt.Sprintf("invalid template name: %s", err))
return
}
if len(req.Recipe) == 0 {
writeError(w, http.StatusBadRequest, "invalid_request", "recipe must contain at least one command")
return
}
build, err := h.svc.Create(r.Context(), service.BuildCreateParams{
Name: req.Name,
BaseTemplate: req.BaseTemplate,
Recipe: req.Recipe,
Healthcheck: req.Healthcheck,
VCPUs: req.VCPUs,
MemoryMB: req.MemoryMB,
SkipPrePost: req.SkipPrePost,
})
if err != nil {
slog.Error("failed to create build", "error", err)
writeError(w, http.StatusInternalServerError, "build_error", "failed to create build")
return
}
writeJSON(w, http.StatusCreated, buildToResponse(build))
}
// List handles GET /v1/admin/builds.
func (h *buildHandler) List(w http.ResponseWriter, r *http.Request) {
builds, err := h.svc.List(r.Context())
if err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to list builds")
return
}
resp := make([]buildResponse, len(builds))
for i, b := range builds {
resp[i] = buildToResponse(b)
}
writeJSON(w, http.StatusOK, resp)
}
// Get handles GET /v1/admin/builds/{id}.
func (h *buildHandler) Get(w http.ResponseWriter, r *http.Request) {
buildIDStr := chi.URLParam(r, "id")
buildID, err := id.ParseBuildID(buildIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid build ID")
return
}
build, err := h.svc.Get(r.Context(), buildID)
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "build not found")
return
}
writeJSON(w, http.StatusOK, buildToResponse(build))
}
// ListTemplates handles GET /v1/admin/templates — returns all templates across all teams.
func (h *buildHandler) ListTemplates(w http.ResponseWriter, r *http.Request) {
templates, err := h.db.ListTemplates(r.Context())
if err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to list templates")
return
}
type templateResponse struct {
Name string `json:"name"`
Type string `json:"type"`
VCPUs int32 `json:"vcpus"`
MemoryMB int32 `json:"memory_mb"`
SizeBytes int64 `json:"size_bytes"`
TeamID string `json:"team_id"`
CreatedAt string `json:"created_at"`
}
resp := make([]templateResponse, len(templates))
for i, t := range templates {
resp[i] = templateResponse{
Name: t.Name,
Type: t.Type,
VCPUs: t.Vcpus,
MemoryMB: t.MemoryMb,
SizeBytes: t.SizeBytes,
TeamID: id.FormatTeamID(t.TeamID),
}
if t.CreatedAt.Valid {
resp[i].CreatedAt = t.CreatedAt.Time.Format(time.RFC3339)
}
}
writeJSON(w, http.StatusOK, resp)
}
// DeleteTemplate handles DELETE /v1/admin/templates/{name}.
func (h *buildHandler) DeleteTemplate(w http.ResponseWriter, r *http.Request) {
name := chi.URLParam(r, "name")
if err := validate.SafeName(name); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", fmt.Sprintf("invalid template name: %s", err))
return
}
ctx := r.Context()
tmpl, err := h.db.GetPlatformTemplateByName(ctx, name)
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "template not found")
return
}
if layout.IsMinimal(tmpl.TeamID, tmpl.ID) {
writeError(w, http.StatusForbidden, "forbidden", "the minimal template cannot be deleted")
return
}
// Broadcast delete to all online hosts.
hosts, _ := h.db.ListActiveHosts(ctx)
for _, host := range hosts {
if host.Status != "online" {
continue
}
agent, err := h.pool.GetForHost(host)
if err != nil {
continue
}
if _, err := agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{
TeamId: formatUUIDForRPC(tmpl.TeamID),
TemplateId: formatUUIDForRPC(tmpl.ID),
})); err != nil {
if connect.CodeOf(err) != connect.CodeNotFound {
slog.Warn("admin: failed to delete template on host", "host_id", id.FormatHostID(host.ID), "name", name, "error", err)
}
}
}
if err := h.db.DeleteTemplate(ctx, tmpl.ID); err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to delete template record")
return
}
w.WriteHeader(http.StatusNoContent)
}
// Cancel handles POST /v1/admin/builds/{id}/cancel.
func (h *buildHandler) Cancel(w http.ResponseWriter, r *http.Request) {
buildIDStr := chi.URLParam(r, "id")
buildID, err := id.ParseBuildID(buildIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid build ID")
return
}
if err := h.svc.Cancel(r.Context(), buildID); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", err.Error())
return
}
w.WriteHeader(http.StatusNoContent)
}

View File

@ -14,6 +14,7 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
)
@ -46,10 +47,16 @@ type execResponse struct {
// Exec handles POST /v1/sandboxes/{id}/exec.
func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
@ -80,7 +87,7 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
}
resp, err := agent.Exec(ctx, connect.NewRequest(&pb.ExecRequest{
SandboxId: sandboxID,
SandboxId: sandboxIDStr,
Cmd: req.Cmd,
Args: req.Args,
TimeoutSec: req.TimeoutSec,
@ -101,7 +108,7 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
Valid: true,
},
}); err != nil {
slog.Warn("failed to update last_active_at", "id", sandboxID, "error", err)
slog.Warn("failed to update last_active_at", "id", sandboxIDStr, "error", err)
}
// Use base64 encoding if output contains non-UTF-8 bytes.
@ -112,7 +119,7 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
if !utf8.Valid(stdout) || !utf8.Valid(stderr) {
encoding = "base64"
writeJSON(w, http.StatusOK, execResponse{
SandboxID: sandboxID,
SandboxID: sandboxIDStr,
Cmd: req.Cmd,
Stdout: base64.StdEncoding.EncodeToString(stdout),
Stderr: base64.StdEncoding.EncodeToString(stderr),
@ -124,7 +131,7 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
}
writeJSON(w, http.StatusOK, execResponse{
SandboxID: sandboxID,
SandboxID: sandboxIDStr,
Cmd: req.Cmd,
Stdout: string(stdout),
Stderr: string(stderr),

View File

@ -14,6 +14,7 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
)
@ -48,10 +49,16 @@ type wsOutMsg struct {
// ExecStream handles WS /v1/sandboxes/{id}/exec/stream.
func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
@ -91,7 +98,7 @@ func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
defer cancel()
stream, err := agent.ExecStream(streamCtx, connect.NewRequest(&pb.ExecStreamRequest{
SandboxId: sandboxID,
SandboxId: sandboxIDStr,
Cmd: startMsg.Cmd,
Args: startMsg.Args,
}))
@ -157,7 +164,7 @@ func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
Valid: true,
},
}); err != nil {
slog.Warn("failed to update last active after stream exec", "sandbox_id", sandboxID, "error", err)
slog.Warn("failed to update last active after stream exec", "sandbox_id", sandboxIDStr, "error", err)
}
}

View File

@ -11,6 +11,7 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
)
@ -29,10 +30,16 @@ func newFilesHandler(db *db.Queries, pool *lifecycle.HostClientPool) *filesHandl
// - "path" text field: absolute destination path inside the sandbox
// - "file" file field: binary content to write
func (h *filesHandler) Upload(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
@ -82,7 +89,7 @@ func (h *filesHandler) Upload(w http.ResponseWriter, r *http.Request) {
}
if _, err := agent.WriteFile(ctx, connect.NewRequest(&pb.WriteFileRequest{
SandboxId: sandboxID,
SandboxId: sandboxIDStr,
Path: filePath,
Content: content,
})); err != nil {
@ -101,10 +108,16 @@ type readFileRequest struct {
// Download handles POST /v1/sandboxes/{id}/files/read.
// Accepts JSON body with path, returns raw file content with Content-Disposition.
func (h *filesHandler) Download(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
@ -133,7 +146,7 @@ func (h *filesHandler) Download(w http.ResponseWriter, r *http.Request) {
}
resp, err := agent.ReadFile(ctx, connect.NewRequest(&pb.ReadFileRequest{
SandboxId: sandboxID,
SandboxId: sandboxIDStr,
Path: req.Path,
}))
if err != nil {

View File

@ -12,6 +12,7 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
)
@ -29,10 +30,16 @@ func newFilesStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool) *file
// Expects multipart/form-data with "path" text field and "file" file field.
// Streams file content directly from the request body to the host agent without buffering.
func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
@ -101,7 +108,7 @@ func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request
if err := stream.Send(&pb.WriteFileStreamRequest{
Content: &pb.WriteFileStreamRequest_Meta{
Meta: &pb.WriteFileStreamMeta{
SandboxId: sandboxID,
SandboxId: sandboxIDStr,
Path: filePath,
},
},
@ -146,10 +153,16 @@ func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request
// StreamDownload handles POST /v1/sandboxes/{id}/files/stream/read.
// Accepts JSON body with path, streams file content back without buffering.
func (h *filesStreamHandler) StreamDownload(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
@ -178,7 +191,7 @@ func (h *filesStreamHandler) StreamDownload(w http.ResponseWriter, r *http.Reque
// Open server-streaming RPC to host agent.
stream, err := agent.ReadFileStream(ctx, connect.NewRequest(&pb.ReadFileStreamRequest{
SandboxId: sandboxID,
SandboxId: sandboxIDStr,
Path: req.Path,
}))
if err != nil {

View File

@ -8,9 +8,12 @@ import (
"github.com/go-chi/chi/v5"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/audit"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/service"
)
@ -46,6 +49,9 @@ type refreshTokenResponse struct {
Host hostResponse `json:"host"`
Token string `json:"token"`
RefreshToken string `json:"refresh_token"`
CertPEM string `json:"cert_pem,omitempty"`
KeyPEM string `json:"key_pem,omitempty"`
CACertPEM string `json:"ca_cert_pem,omitempty"`
}
type deletePreviewResponse struct {
@ -66,6 +72,9 @@ type registerHostResponse struct {
Host hostResponse `json:"host"`
Token string `json:"token"`
RefreshToken string `json:"refresh_token"`
CertPEM string `json:"cert_pem,omitempty"`
KeyPEM string `json:"key_pem,omitempty"`
CACertPEM string `json:"ca_cert_pem,omitempty"`
}
type addTagRequest struct {
@ -93,34 +102,35 @@ type hostResponse struct {
func hostToResponse(h db.Host) hostResponse {
resp := hostResponse{
ID: h.ID,
ID: id.FormatHostID(h.ID),
Type: h.Type,
Status: h.Status,
CreatedBy: h.CreatedBy,
CreatedBy: id.FormatUserID(h.CreatedBy),
}
if h.TeamID.Valid {
resp.TeamID = &h.TeamID.String
s := id.FormatTeamID(h.TeamID)
resp.TeamID = &s
}
if h.Provider.Valid {
resp.Provider = &h.Provider.String
if h.Provider != "" {
resp.Provider = &h.Provider
}
if h.AvailabilityZone.Valid {
resp.AvailabilityZone = &h.AvailabilityZone.String
if h.AvailabilityZone != "" {
resp.AvailabilityZone = &h.AvailabilityZone
}
if h.Arch.Valid {
resp.Arch = &h.Arch.String
if h.Arch != "" {
resp.Arch = &h.Arch
}
if h.CpuCores.Valid {
resp.CPUCores = &h.CpuCores.Int32
if h.CpuCores != 0 {
resp.CPUCores = &h.CpuCores
}
if h.MemoryMb.Valid {
resp.MemoryMB = &h.MemoryMb.Int32
if h.MemoryMb != 0 {
resp.MemoryMB = &h.MemoryMb
}
if h.DiskGb.Valid {
resp.DiskGB = &h.DiskGb.Int32
if h.DiskGb != 0 {
resp.DiskGB = &h.DiskGb
}
if h.Address.Valid {
resp.Address = &h.Address.String
if h.Address != "" {
resp.Address = &h.Address
}
if h.LastHeartbeatAt.Valid {
s := h.LastHeartbeatAt.Time.Format(time.RFC3339)
@ -133,7 +143,7 @@ func hostToResponse(h db.Host) hostResponse {
}
// isAdmin fetches the user record and returns whether they are an admin.
func (h *hostHandler) isAdmin(r *http.Request, userID string) bool {
func (h *hostHandler) isAdmin(r *http.Request, userID pgtype.UUID) bool {
user, err := h.queries.GetUserByID(r.Context(), userID)
if err != nil {
return false
@ -151,14 +161,23 @@ func (h *hostHandler) Create(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
result, err := h.svc.Create(r.Context(), service.HostCreateParams{
Type: req.Type,
TeamID: req.TeamID,
Provider: req.Provider,
AvailabilityZone: req.AvailabilityZone,
RequestingUserID: ac.UserID,
IsRequestorAdmin: h.isAdmin(r, ac.UserID),
})
// Parse optional team ID from request body.
var params service.HostCreateParams
params.Type = req.Type
params.Provider = req.Provider
params.AvailabilityZone = req.AvailabilityZone
params.RequestingUserID = ac.UserID
params.IsRequestorAdmin = h.isAdmin(r, ac.UserID)
if req.TeamID != "" {
teamID, err := id.ParseTeamID(req.TeamID)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid team_id")
return
}
params.TeamID = teamID
}
result, err := h.svc.Create(r.Context(), params)
if err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
@ -166,8 +185,7 @@ func (h *hostHandler) Create(w http.ResponseWriter, r *http.Request) {
}
// Log audit for the owning team (BYOC hosts have a team; shared hosts use caller's team).
hostTeamID := result.Host.TeamID.String
h.audit.LogHostCreate(r.Context(), ac, result.Host.ID, hostTeamID)
h.audit.LogHostCreate(r.Context(), ac, result.Host.ID, result.Host.TeamID)
writeJSON(w, http.StatusCreated, createHostResponse{
Host: hostToResponse(result.Host),
@ -192,14 +210,22 @@ func (h *hostHandler) List(w http.ResponseWriter, r *http.Request) {
seen := make(map[string]struct{})
for _, host := range hosts {
if host.TeamID.Valid {
seen[host.TeamID.String] = struct{}{}
key := id.FormatTeamID(host.TeamID)
seen[key] = struct{}{}
}
}
if len(seen) > 0 {
teamNames = make(map[string]string, len(seen))
for id := range seen {
if team, err := h.queries.GetTeam(r.Context(), id); err == nil {
teamNames[id] = team.Name
for _, host := range hosts {
if !host.TeamID.Valid {
continue
}
key := id.FormatTeamID(host.TeamID)
if _, ok := teamNames[key]; ok {
continue
}
if team, err := h.queries.GetTeam(r.Context(), host.TeamID); err == nil {
teamNames[key] = team.Name
}
}
}
@ -209,7 +235,8 @@ func (h *hostHandler) List(w http.ResponseWriter, r *http.Request) {
for i, host := range hosts {
resp[i] = hostToResponse(host)
if host.TeamID.Valid {
if name, ok := teamNames[host.TeamID.String]; ok {
key := id.FormatTeamID(host.TeamID)
if name, ok := teamNames[key]; ok {
resp[i].TeamName = &name
}
}
@ -220,9 +247,15 @@ func (h *hostHandler) List(w http.ResponseWriter, r *http.Request) {
// Get handles GET /v1/hosts/{id}.
func (h *hostHandler) Get(w http.ResponseWriter, r *http.Request) {
hostID := chi.URLParam(r, "id")
hostIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
hostID, err := id.ParseHostID(hostIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
return
}
host, err := h.svc.Get(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID))
if err != nil {
status, code, msg := serviceErrToHTTP(err)
@ -236,9 +269,15 @@ func (h *hostHandler) Get(w http.ResponseWriter, r *http.Request) {
// DeletePreview handles GET /v1/hosts/{id}/delete-preview.
// Returns what would be affected without making changes, for confirmation UI.
func (h *hostHandler) DeletePreview(w http.ResponseWriter, r *http.Request) {
hostID := chi.URLParam(r, "id")
hostIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
hostID, err := id.ParseHostID(hostIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
return
}
preview, err := h.svc.DeletePreview(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID))
if err != nil {
status, code, msg := serviceErrToHTTP(err)
@ -256,19 +295,25 @@ func (h *hostHandler) DeletePreview(w http.ResponseWriter, r *http.Request) {
// Without ?force=true: returns 409 with affected sandbox IDs if any are active.
// With ?force=true: gracefully stops all sandboxes then deletes the host.
func (h *hostHandler) Delete(w http.ResponseWriter, r *http.Request) {
hostID := chi.URLParam(r, "id")
hostIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
force := r.URL.Query().Get("force") == "true"
hostID, err := id.ParseHostID(hostIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
return
}
// Fetch host before deletion to capture team_id for audit.
deletedHost, hostErr := h.queries.GetHost(r.Context(), hostID)
if hostErr != nil {
slog.Warn("audit: could not fetch host before delete", "host_id", hostID, "error", hostErr)
slog.Warn("audit: could not fetch host before delete", "host_id", hostIDStr, "error", hostErr)
}
err := h.svc.Delete(r.Context(), hostID, ac.UserID, ac.TeamID, h.isAdmin(r, ac.UserID), force)
err = h.svc.Delete(r.Context(), hostID, ac.UserID, ac.TeamID, h.isAdmin(r, ac.UserID), force)
if err == nil {
h.audit.LogHostDelete(r.Context(), ac, hostID, deletedHost.TeamID.String)
h.audit.LogHostDelete(r.Context(), ac, hostID, deletedHost.TeamID)
w.WriteHeader(http.StatusNoContent)
return
}
@ -292,9 +337,15 @@ func (h *hostHandler) Delete(w http.ResponseWriter, r *http.Request) {
// RegenerateToken handles POST /v1/hosts/{id}/token.
func (h *hostHandler) RegenerateToken(w http.ResponseWriter, r *http.Request) {
hostID := chi.URLParam(r, "id")
hostIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
hostID, err := id.ParseHostID(hostIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
return
}
result, err := h.svc.RegenerateToken(r.Context(), hostID, ac.UserID, ac.TeamID, h.isAdmin(r, ac.UserID))
if err != nil {
status, code, msg := serviceErrToHTTP(err)
@ -343,14 +394,23 @@ func (h *hostHandler) Register(w http.ResponseWriter, r *http.Request) {
Host: hostToResponse(result.Host),
Token: result.JWT,
RefreshToken: result.RefreshToken,
CertPEM: result.CertPEM,
KeyPEM: result.KeyPEM,
CACertPEM: result.CACertPEM,
})
}
// Heartbeat handles POST /v1/hosts/{id}/heartbeat (host-token-authenticated).
func (h *hostHandler) Heartbeat(w http.ResponseWriter, r *http.Request) {
hostID := chi.URLParam(r, "id")
hostIDStr := chi.URLParam(r, "id")
hc := auth.MustHostFromContext(r.Context())
hostID, err := id.ParseHostID(hostIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
return
}
// Prevent a host from heartbeating for a different host.
if hostID != hc.HostID {
writeError(w, http.StatusForbidden, "forbidden", "host ID mismatch")
@ -368,7 +428,7 @@ func (h *hostHandler) Heartbeat(w http.ResponseWriter, r *http.Request) {
// Log marked_up if the host just recovered from unreachable.
if prevHost.Status == "unreachable" {
h.audit.LogHostMarkedUp(r.Context(), prevHost.TeamID.String, hc.HostID)
h.audit.LogHostMarkedUp(r.Context(), prevHost.TeamID, hc.HostID)
}
w.WriteHeader(http.StatusNoContent)
@ -376,10 +436,16 @@ func (h *hostHandler) Heartbeat(w http.ResponseWriter, r *http.Request) {
// AddTag handles POST /v1/hosts/{id}/tags.
func (h *hostHandler) AddTag(w http.ResponseWriter, r *http.Request) {
hostID := chi.URLParam(r, "id")
hostIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
admin := h.isAdmin(r, ac.UserID)
hostID, err := id.ParseHostID(hostIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
return
}
var req addTagRequest
if err := decodeJSON(r, &req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
@ -401,10 +467,16 @@ func (h *hostHandler) AddTag(w http.ResponseWriter, r *http.Request) {
// RemoveTag handles DELETE /v1/hosts/{id}/tags/{tag}.
func (h *hostHandler) RemoveTag(w http.ResponseWriter, r *http.Request) {
hostID := chi.URLParam(r, "id")
hostIDStr := chi.URLParam(r, "id")
tag := chi.URLParam(r, "tag")
ac := auth.MustFromContext(r.Context())
hostID, err := id.ParseHostID(hostIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
return
}
if err := h.svc.RemoveTag(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID), tag); err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
@ -438,14 +510,23 @@ func (h *hostHandler) RefreshToken(w http.ResponseWriter, r *http.Request) {
Host: hostToResponse(result.Host),
Token: result.JWT,
RefreshToken: result.RefreshToken,
CertPEM: result.CertPEM,
KeyPEM: result.KeyPEM,
CACertPEM: result.CACertPEM,
})
}
// ListTags handles GET /v1/hosts/{id}/tags.
func (h *hostHandler) ListTags(w http.ResponseWriter, r *http.Request) {
hostID := chi.URLParam(r, "id")
hostIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
hostID, err := id.ParseHostID(hostIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
return
}
tags, err := h.svc.ListTags(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID))
if err != nil {
status, code, msg := serviceErrToHTTP(err)

View File

@ -7,9 +7,11 @@ import (
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
)
@ -38,10 +40,16 @@ type metricsResponse struct {
// GetMetrics handles GET /v1/sandboxes/{id}/metrics?range=10m|2h|24h.
func (h *sandboxMetricsHandler) GetMetrics(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
rangeTier := r.URL.Query().Get("range")
if rangeTier == "" {
rangeTier = "10m"
@ -60,15 +68,15 @@ func (h *sandboxMetricsHandler) GetMetrics(w http.ResponseWriter, r *http.Reques
switch sb.Status {
case "running":
h.getFromAgent(w, r, sandboxID, rangeTier, sb.HostID)
h.getFromAgent(w, r, sandboxIDStr, rangeTier, sb.HostID)
case "paused":
h.getFromDB(ctx, w, sandboxID, rangeTier)
h.getFromDB(ctx, w, sandboxIDStr, sandboxID, rangeTier)
default:
writeError(w, http.StatusNotFound, "not_found", "metrics not available for sandbox in state: "+sb.Status)
}
}
func (h *sandboxMetricsHandler) getFromAgent(w http.ResponseWriter, r *http.Request, sandboxID, rangeTier, hostID string) {
func (h *sandboxMetricsHandler) getFromAgent(w http.ResponseWriter, r *http.Request, sandboxIDStr, rangeTier string, hostID pgtype.UUID) {
ctx := r.Context()
agent, err := agentForHost(ctx, h.db, h.pool, hostID)
@ -78,7 +86,7 @@ func (h *sandboxMetricsHandler) getFromAgent(w http.ResponseWriter, r *http.Requ
}
resp, err := agent.GetSandboxMetrics(ctx, connect.NewRequest(&pb.GetSandboxMetricsRequest{
SandboxId: sandboxID,
SandboxId: sandboxIDStr,
Range: rangeTier,
}))
if err != nil {
@ -98,7 +106,7 @@ func (h *sandboxMetricsHandler) getFromAgent(w http.ResponseWriter, r *http.Requ
}
writeJSON(w, http.StatusOK, metricsResponse{
SandboxID: sandboxID,
SandboxID: sandboxIDStr,
Range: rangeTier,
Points: points,
})
@ -118,7 +126,7 @@ var rangeToDB = map[string]struct {
"24h": {"24h", 24 * time.Hour},
}
func (h *sandboxMetricsHandler) getFromDB(ctx context.Context, w http.ResponseWriter, sandboxID, rangeTier string) {
func (h *sandboxMetricsHandler) getFromDB(ctx context.Context, w http.ResponseWriter, sandboxIDStr string, sandboxID pgtype.UUID, rangeTier string) {
mapping := rangeToDB[rangeTier]
rows, err := h.db.GetSandboxMetricPoints(ctx, db.GetSandboxMetricPointsParams{
SandboxID: sandboxID,
@ -141,7 +149,7 @@ func (h *sandboxMetricsHandler) getFromDB(ctx context.Context, w http.ResponseWr
}
writeJSON(w, http.StatusOK, metricsResponse{
SandboxID: sandboxID,
SandboxID: sandboxIDStr,
Range: rangeTier,
Points: points,
})

View File

@ -162,7 +162,7 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
redirectWithError(w, r, redirectBase, "internal_error")
return
}
redirectWithToken(w, r, redirectBase, token, user.ID, team.ID, user.Email, user.Name)
redirectWithToken(w, r, redirectBase, token, id.FormatUserID(user.ID), id.FormatTeamID(team.ID), user.Email, user.Name)
return
}
if !errors.Is(err, pgx.ErrNoRows) {
@ -262,7 +262,7 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
return
}
redirectWithToken(w, r, redirectBase, token, userID, teamID, email, profile.Name)
redirectWithToken(w, r, redirectBase, token, id.FormatUserID(userID), id.FormatTeamID(teamID), email, profile.Name)
}
// retryAsLogin handles the race where a concurrent request already created the user.
@ -296,7 +296,7 @@ func (h *oauthHandler) retryAsLogin(w http.ResponseWriter, r *http.Request, prov
redirectWithError(w, r, redirectBase, "internal_error")
return
}
redirectWithToken(w, r, redirectBase, token, user.ID, team.ID, user.Email, user.Name)
redirectWithToken(w, r, redirectBase, token, id.FormatUserID(user.ID), id.FormatTeamID(team.ID), user.Email, user.Name)
}
func redirectWithToken(w http.ResponseWriter, r *http.Request, base, token, userID, teamID, email, name string) {

View File

@ -10,6 +10,7 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/audit"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/service"
)
@ -46,7 +47,7 @@ type sandboxResponse struct {
func sandboxToResponse(sb db.Sandbox) sandboxResponse {
resp := sandboxResponse{
ID: sb.ID,
ID: id.FormatSandboxID(sb.ID),
Status: sb.Status,
Template: sb.Template,
VCPUs: sb.Vcpus,
@ -81,7 +82,7 @@ func (h *sandboxHandler) Create(w http.ResponseWriter, r *http.Request) {
}
ac := auth.MustFromContext(r.Context())
if ac.TeamID == "" {
if !ac.TeamID.Valid {
writeError(w, http.StatusForbidden, "no_team", "no active team context; re-authenticate")
return
}
@ -122,9 +123,15 @@ func (h *sandboxHandler) List(w http.ResponseWriter, r *http.Request) {
// Get handles GET /v1/sandboxes/{id}.
func (h *sandboxHandler) Get(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.svc.Get(r.Context(), sandboxID, ac.TeamID)
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
@ -136,9 +143,15 @@ func (h *sandboxHandler) Get(w http.ResponseWriter, r *http.Request) {
// Pause handles POST /v1/sandboxes/{id}/pause.
func (h *sandboxHandler) Pause(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.svc.Pause(r.Context(), sandboxID, ac.TeamID)
if err != nil {
status, code, msg := serviceErrToHTTP(err)
@ -152,9 +165,15 @@ func (h *sandboxHandler) Pause(w http.ResponseWriter, r *http.Request) {
// Resume handles POST /v1/sandboxes/{id}/resume.
func (h *sandboxHandler) Resume(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.svc.Resume(r.Context(), sandboxID, ac.TeamID)
if err != nil {
status, code, msg := serviceErrToHTTP(err)
@ -168,9 +187,15 @@ func (h *sandboxHandler) Resume(w http.ResponseWriter, r *http.Request) {
// Ping handles POST /v1/sandboxes/{id}/ping.
func (h *sandboxHandler) Ping(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
if err := h.svc.Ping(r.Context(), sandboxID, ac.TeamID); err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
@ -182,9 +207,15 @@ func (h *sandboxHandler) Ping(w http.ResponseWriter, r *http.Request) {
// Destroy handles DELETE /v1/sandboxes/{id}.
func (h *sandboxHandler) Destroy(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
if err := h.svc.Destroy(r.Context(), sandboxID, ac.TeamID); err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)

View File

@ -10,12 +10,14 @@ import (
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/audit"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/layout"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
"git.omukk.dev/wrenn/sandbox/internal/service"
"git.omukk.dev/wrenn/sandbox/internal/validate"
@ -35,8 +37,8 @@ func newSnapshotHandler(svc *service.TemplateService, db *db.Queries, pool *life
// deleteSnapshotBroadcast attempts to delete snapshot files on all online hosts.
// Snapshots aren't currently host-tracked in the DB, so we broadcast to all hosts
// and ignore NotFound errors. TODO: add host_id to templates table.
func (h *snapshotHandler) deleteSnapshotBroadcast(ctx context.Context, name string) error {
// and ignore NotFound errors.
func (h *snapshotHandler) deleteSnapshotBroadcast(ctx context.Context, teamID, templateID pgtype.UUID) error {
hosts, err := h.db.ListActiveHosts(ctx)
if err != nil {
return fmt.Errorf("list hosts: %w", err)
@ -49,9 +51,12 @@ func (h *snapshotHandler) deleteSnapshotBroadcast(ctx context.Context, name stri
if err != nil {
continue
}
if _, err := agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{Name: name})); err != nil {
if _, err := agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{
TeamId: formatUUIDForRPC(teamID),
TemplateId: formatUUIDForRPC(templateID),
})); err != nil {
if connect.CodeOf(err) != connect.CodeNotFound {
slog.Warn("snapshot: failed to delete on host", "host_id", host.ID, "name", name, "error", err)
slog.Warn("snapshot: failed to delete on host", "host_id", id.FormatHostID(host.ID), "error", err)
}
}
}
@ -70,6 +75,7 @@ type snapshotResponse struct {
MemoryMB *int32 `json:"memory_mb,omitempty"`
SizeBytes int64 `json:"size_bytes"`
CreatedAt string `json:"created_at"`
Platform bool `json:"platform"`
}
func templateToResponse(t db.Template) snapshotResponse {
@ -77,12 +83,13 @@ func templateToResponse(t db.Template) snapshotResponse {
Name: t.Name,
Type: t.Type,
SizeBytes: t.SizeBytes,
Platform: t.TeamID == id.PlatformTeamID,
}
if t.Vcpus.Valid {
resp.VCPUs = &t.Vcpus.Int32
if t.Vcpus != 0 {
resp.VCPUs = &t.Vcpus
}
if t.MemoryMb.Valid {
resp.MemoryMB = &t.MemoryMb.Int32
if t.MemoryMb != 0 {
resp.MemoryMB = &t.MemoryMb
}
if t.CreatedAt.Valid {
resp.CreatedAt = t.CreatedAt.Time.Format(time.RFC3339)
@ -103,6 +110,12 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
return
}
sandboxID, err := id.ParseSandboxID(req.SandboxID)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox_id")
return
}
if req.Name == "" {
req.Name = id.NewSnapshotName()
}
@ -115,14 +128,20 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(ctx)
overwrite := r.URL.Query().Get("overwrite") == "true"
// Check for global name collision.
if _, err := h.db.GetPlatformTemplateByName(ctx, req.Name); err == nil {
writeError(w, http.StatusConflict, "name_reserved", "template name is reserved by a global template")
return
}
// Check if name already exists for this team.
if _, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err == nil {
if existing, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err == nil {
if !overwrite {
writeError(w, http.StatusConflict, "already_exists", "snapshot name already exists; use ?overwrite=true to replace")
return
}
// Delete old snapshot files from all hosts before removing the DB record.
if err := h.deleteSnapshotBroadcast(ctx, req.Name); err != nil {
if err := h.deleteSnapshotBroadcast(ctx, existing.TeamID, existing.ID); err != nil {
writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete existing snapshot files")
return
}
@ -133,7 +152,7 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
}
// Verify sandbox exists, belongs to team, and is running or paused.
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: req.SandboxID, TeamID: ac.TeamID})
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
@ -149,30 +168,53 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
return
}
resp, err := agent.CreateSnapshot(ctx, connect.NewRequest(&pb.CreateSnapshotRequest{
SandboxId: req.SandboxID,
Name: req.Name,
// Pre-mark sandbox as "paused" in DB BEFORE issuing the snapshot RPC.
// The host agent's CreateSnapshot removes the sandbox from its in-memory
// map immediately; if the reconciler fires during the flatten window and
// the DB still says "running", it will mark the sandbox "stopped".
if sb.Status == "running" {
if _, err := h.db.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
ID: sandboxID, Status: "paused",
}); err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to update sandbox status")
return
}
}
// Use a detached context with a generous timeout so the snapshot completes
// even if the client disconnects (the flatten step can take 10-20s).
snapCtx, snapCancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer snapCancel()
// Generate the new template ID upfront so the host agent knows where to store files.
newTemplateID := id.NewTemplateID()
resp, err := agent.CreateSnapshot(snapCtx, connect.NewRequest(&pb.CreateSnapshotRequest{
SandboxId: req.SandboxID,
Name: req.Name,
TeamId: formatUUIDForRPC(ac.TeamID),
TemplateId: formatUUIDForRPC(newTemplateID),
}))
if err != nil {
// Snapshot failed — revert status back to what it was.
if sb.Status == "running" {
if _, dbErr := h.db.UpdateSandboxStatus(snapCtx, db.UpdateSandboxStatusParams{
ID: sandboxID, Status: "running",
}); dbErr != nil {
slog.Error("failed to revert sandbox status after snapshot error", "sandbox_id", req.SandboxID, "error", dbErr)
}
}
status, code, msg := agentErrToHTTP(err)
writeError(w, status, code, msg)
return
}
// Mark sandbox as paused (if it was running, it got paused by the snapshot).
if sb.Status != "paused" {
if _, err := h.db.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
ID: req.SandboxID, Status: "paused",
}); err != nil {
slog.Error("failed to update sandbox status after snapshot", "sandbox_id", req.SandboxID, "error", err)
}
}
tmpl, err := h.db.InsertTemplate(ctx, db.InsertTemplateParams{
tmpl, err := h.db.InsertTemplate(snapCtx, db.InsertTemplateParams{
ID: newTemplateID,
Name: req.Name,
Type: "snapshot",
Vcpus: pgtype.Int4{Int32: sb.Vcpus, Valid: true},
MemoryMb: pgtype.Int4{Int32: sb.MemoryMb, Valid: true},
Vcpus: sb.Vcpus,
MemoryMb: sb.MemoryMb,
SizeBytes: resp.Msg.SizeBytes,
TeamID: ac.TeamID,
})
@ -182,7 +224,12 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
return
}
h.audit.LogSnapshotCreate(r.Context(), ac, req.Name)
h.audit.LogSnapshotCreate(snapCtx, ac, req.Name)
if ctx.Err() != nil {
slog.Info("snapshot created but client disconnected before response", "name", req.Name)
return
}
writeJSON(w, http.StatusCreated, templateToResponse(tmpl))
}
@ -215,12 +262,22 @@ func (h *snapshotHandler) Delete(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ac := auth.MustFromContext(ctx)
if _, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: name, TeamID: ac.TeamID}); err != nil {
tmpl, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: name, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "template not found")
return
}
// Platform templates can only be deleted by admins via /v1/admin/templates.
if tmpl.TeamID == id.PlatformTeamID {
writeError(w, http.StatusForbidden, "forbidden", "platform templates cannot be deleted here")
return
}
if layout.IsMinimal(tmpl.TeamID, tmpl.ID) {
writeError(w, http.StatusForbidden, "forbidden", "the minimal template cannot be deleted")
return
}
if err := h.deleteSnapshotBroadcast(ctx, name); err != nil {
if err := h.deleteSnapshotBroadcast(ctx, tmpl.TeamID, tmpl.ID); err != nil {
writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete snapshot files")
return
}

View File

@ -7,10 +7,12 @@ import (
"time"
"github.com/go-chi/chi/v5"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/audit"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/service"
)
@ -48,7 +50,7 @@ type memberResponse struct {
func teamToResponse(t db.Team) teamResponse {
resp := teamResponse{
ID: t.ID,
ID: id.FormatTeamID(t.ID),
Name: t.Name,
Slug: t.Slug,
IsByoc: t.IsByoc,
@ -72,11 +74,16 @@ func memberInfoToResponse(m service.MemberInfo) memberResponse {
// requireTeamAccess is an inline check used by every team-scoped handler:
// the JWT team_id must match the URL {id} before any DB call is made.
// Returns false and writes 403 if they don't match.
func requireTeamAccess(w http.ResponseWriter, r *http.Request, ac auth.AuthContext) (string, bool) {
teamID := chi.URLParam(r, "id")
func requireTeamAccess(w http.ResponseWriter, r *http.Request, ac auth.AuthContext) (pgtype.UUID, bool) {
teamIDStr := chi.URLParam(r, "id")
teamID, err := id.ParseTeamID(teamIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid team ID")
return pgtype.UUID{}, false
}
if ac.TeamID != teamID {
writeError(w, http.StatusForbidden, "forbidden", "JWT team does not match requested team; use switch-team first")
return "", false
return pgtype.UUID{}, false
}
return teamID, true
}
@ -185,7 +192,7 @@ func (h *teamHandler) Rename(w http.ResponseWriter, r *http.Request) {
// Fetch old name for audit log before renaming.
oldTeam, err := h.svc.GetTeam(r.Context(), teamID)
if err != nil {
slog.Warn("audit: could not fetch old team name for rename log", "team_id", teamID, "error", err)
slog.Warn("audit: could not fetch old team name for rename log", "team_id", id.FormatTeamID(teamID), "error", err)
}
if err := h.svc.RenameTeam(r.Context(), teamID, ac.UserID, req.Name); err != nil {
@ -267,7 +274,11 @@ func (h *teamHandler) AddMember(w http.ResponseWriter, r *http.Request) {
return
}
h.audit.LogMemberAdd(r.Context(), ac, member.UserID, member.Email, member.Role)
// member.UserID is already formatted with prefix; parse it back for the audit logger.
targetUserID, parseErr := id.ParseUserID(member.UserID)
if parseErr == nil {
h.audit.LogMemberAdd(r.Context(), ac, targetUserID, member.Email, member.Role)
}
writeJSON(w, http.StatusCreated, memberInfoToResponse(member))
}
@ -279,7 +290,13 @@ func (h *teamHandler) RemoveMember(w http.ResponseWriter, r *http.Request) {
if !ok {
return
}
targetUserID := chi.URLParam(r, "uid")
targetUserIDStr := chi.URLParam(r, "uid")
targetUserID, err := id.ParseUserID(targetUserIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid user ID")
return
}
if err := h.svc.RemoveMember(r.Context(), teamID, ac.UserID, targetUserID); err != nil {
status, code, msg := serviceErrToHTTP(err)
@ -299,7 +316,13 @@ func (h *teamHandler) UpdateMemberRole(w http.ResponseWriter, r *http.Request) {
if !ok {
return
}
targetUserID := chi.URLParam(r, "uid")
targetUserIDStr := chi.URLParam(r, "uid")
targetUserID, err := id.ParseUserID(targetUserIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid user ID")
return
}
var req struct {
Role string `json:"role"`
@ -341,7 +364,13 @@ func (h *teamHandler) Leave(w http.ResponseWriter, r *http.Request) {
// SetBYOC handles PUT /v1/admin/teams/{id}/byoc (admin only).
// Enables or disables the BYOC feature flag for a team.
func (h *teamHandler) SetBYOC(w http.ResponseWriter, r *http.Request) {
teamID := chi.URLParam(r, "id")
teamIDStr := chi.URLParam(r, "id")
teamID, err := id.ParseTeamID(teamIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid team ID")
return
}
var req struct {
Enabled bool `json:"enabled"`

View File

@ -8,6 +8,7 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
)
type usersHandler struct {
@ -45,7 +46,7 @@ func (h *usersHandler) Search(w http.ResponseWriter, r *http.Request) {
}
resp := make([]userResult, len(results))
for i, u := range results {
resp[i] = userResult{UserID: u.ID, Email: u.Email}
resp[i] = userResult{UserID: id.FormatUserID(u.ID), Email: u.Email}
}
writeJSON(w, http.StatusOK, resp)
}

View File

@ -6,9 +6,11 @@ import (
"time"
"connectrpc.com/connect"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/audit"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
)
@ -82,15 +84,15 @@ func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) {
time.Since(host.LastHeartbeatAt.Time) > unreachableThreshold
if stale && host.Status != "unreachable" {
slog.Info("host monitor: marking host unreachable", "host_id", host.ID,
slog.Info("host monitor: marking host unreachable", "host_id", id.FormatHostID(host.ID),
"last_heartbeat", host.LastHeartbeatAt.Time)
if err := m.db.MarkHostUnreachable(ctx, host.ID); err != nil {
slog.Warn("host monitor: failed to mark host unreachable", "host_id", host.ID, "error", err)
slog.Warn("host monitor: failed to mark host unreachable", "host_id", id.FormatHostID(host.ID), "error", err)
}
if err := m.db.MarkSandboxesMissingByHost(ctx, host.ID); err != nil {
slog.Warn("host monitor: failed to mark sandboxes missing", "host_id", host.ID, "error", err)
slog.Warn("host monitor: failed to mark sandboxes missing", "host_id", id.FormatHostID(host.ID), "error", err)
}
m.audit.LogHostMarkedDown(ctx, host.TeamID.String, host.ID)
m.audit.LogHostMarkedDown(ctx, host.TeamID, host.ID)
return
}
@ -110,19 +112,20 @@ func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) {
if err != nil {
// RPC failure is a transient condition; the passive phase will catch it
// if heartbeats stop arriving.
slog.Debug("host monitor: ListSandboxes failed (transient)", "host_id", host.ID, "error", err)
slog.Debug("host monitor: ListSandboxes failed (transient)", "host_id", id.FormatHostID(host.ID), "error", err)
return
}
// Build set of sandbox IDs alive on the host.
// The host agent returns sandbox IDs as strings (formatted with prefix).
alive := make(map[string]struct{}, len(resp.Msg.Sandboxes))
for _, sb := range resp.Msg.Sandboxes {
alive[sb.SandboxId] = struct{}{}
}
autoPaused := make(map[string]struct{}, len(resp.Msg.AutoPausedSandboxIds))
for _, id := range resp.Msg.AutoPausedSandboxIds {
autoPaused[id] = struct{}{}
for _, apID := range resp.Msg.AutoPausedSandboxIds {
autoPaused[apID] = struct{}{}
}
// --- Restore sandboxes that are "missing" in DB but alive on host ---
@ -134,30 +137,31 @@ func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) {
Column2: []string{"missing"},
})
if err != nil {
slog.Warn("host monitor: failed to list missing sandboxes", "host_id", host.ID, "error", err)
slog.Warn("host monitor: failed to list missing sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
} else {
var toRestore []string
var toStop []string
var toRestore []pgtype.UUID
var toStop []pgtype.UUID
for _, sb := range missingSandboxes {
if _, ok := alive[sb.ID]; ok {
sbIDStr := id.FormatSandboxID(sb.ID)
if _, ok := alive[sbIDStr]; ok {
toRestore = append(toRestore, sb.ID)
} else {
toStop = append(toStop, sb.ID)
}
}
if len(toRestore) > 0 {
slog.Info("host monitor: restoring missing sandboxes", "host_id", host.ID, "count", len(toRestore))
slog.Info("host monitor: restoring missing sandboxes", "host_id", id.FormatHostID(host.ID), "count", len(toRestore))
if err := m.db.BulkRestoreRunning(ctx, toRestore); err != nil {
slog.Warn("host monitor: failed to restore missing sandboxes", "host_id", host.ID, "error", err)
slog.Warn("host monitor: failed to restore missing sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
}
}
if len(toStop) > 0 {
slog.Info("host monitor: stopping confirmed-dead missing sandboxes", "host_id", host.ID, "count", len(toStop))
slog.Info("host monitor: stopping confirmed-dead missing sandboxes", "host_id", id.FormatHostID(host.ID), "count", len(toStop))
if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
Column1: toStop,
Status: "stopped",
}); err != nil {
slog.Warn("host monitor: failed to stop missing sandboxes", "host_id", host.ID, "error", err)
slog.Warn("host monitor: failed to stop missing sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
}
}
}
@ -169,18 +173,19 @@ func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) {
Column2: []string{"running"},
})
if err != nil {
slog.Warn("host monitor: failed to list running sandboxes", "host_id", host.ID, "error", err)
slog.Warn("host monitor: failed to list running sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
return
}
var toPause, toStop []string
sbTeamID := make(map[string]string, len(runningSandboxes))
var toPause, toStop []pgtype.UUID
sbTeamID := make(map[pgtype.UUID]pgtype.UUID, len(runningSandboxes))
for _, sb := range runningSandboxes {
sbIDStr := id.FormatSandboxID(sb.ID)
sbTeamID[sb.ID] = sb.TeamID
if _, ok := alive[sb.ID]; ok {
if _, ok := alive[sbIDStr]; ok {
continue
}
if _, ok := autoPaused[sb.ID]; ok {
if _, ok := autoPaused[sbIDStr]; ok {
toPause = append(toPause, sb.ID)
} else {
toStop = append(toStop, sb.ID)
@ -188,24 +193,24 @@ func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) {
}
if len(toPause) > 0 {
slog.Info("host monitor: marking auto-paused sandboxes", "host_id", host.ID, "count", len(toPause))
slog.Info("host monitor: marking auto-paused sandboxes", "host_id", id.FormatHostID(host.ID), "count", len(toPause))
if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
Column1: toPause,
Status: "paused",
}); err != nil {
slog.Warn("host monitor: failed to mark paused", "host_id", host.ID, "error", err)
slog.Warn("host monitor: failed to mark paused", "host_id", id.FormatHostID(host.ID), "error", err)
}
for _, sbID := range toPause {
m.audit.LogSandboxAutoPause(ctx, sbTeamID[sbID], sbID)
}
}
if len(toStop) > 0 {
slog.Info("host monitor: marking orphaned sandboxes stopped", "host_id", host.ID, "count", len(toStop))
slog.Info("host monitor: marking orphaned sandboxes stopped", "host_id", id.FormatHostID(host.ID), "count", len(toStop))
if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
Column1: toStop,
Status: "stopped",
}); err != nil {
slog.Warn("host monitor: failed to mark stopped", "host_id", host.ID, "error", err)
slog.Warn("host monitor: failed to mark stopped", "host_id", id.FormatHostID(host.ID), "error", err)
}
}
}

View File

@ -12,6 +12,9 @@ import (
"time"
"connectrpc.com/connect"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/id"
)
type errorResponse struct {
@ -35,6 +38,11 @@ func writeError(w http.ResponseWriter, status int, code, message string) {
})
}
// formatUUIDForRPC converts a pgtype.UUID to a hex string for RPC messages.
func formatUUIDForRPC(u pgtype.UUID) string {
return id.UUIDString(u)
}
// agentErrToHTTP maps a Connect RPC error to an HTTP status, error code, and message.
func agentErrToHTTP(err error) (int, string, string) {
switch connect.CodeOf(err) {

View File

@ -7,6 +7,7 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
)
// requireAPIKeyOrJWT accepts either X-API-Key header or Authorization: Bearer JWT.
@ -24,7 +25,7 @@ func requireAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler
}
if err := queries.UpdateAPIKeyLastUsed(r.Context(), row.ID); err != nil {
slog.Warn("failed to update api key last_used", "key_id", row.ID, "error", err)
slog.Warn("failed to update api key last_used", "key_id", id.FormatAPIKeyID(row.ID), "error", err)
}
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
@ -45,9 +46,20 @@ func requireAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler
return
}
teamID, err := id.ParseTeamID(claims.TeamID)
if err != nil {
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid team ID in token")
return
}
userID, err := id.ParseUserID(claims.Subject)
if err != nil {
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid user ID in token")
return
}
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
TeamID: claims.TeamID,
UserID: claims.Subject,
TeamID: teamID,
UserID: userID,
Email: claims.Email,
Name: claims.Name,
Role: claims.Role,

View File

@ -4,6 +4,7 @@ import (
"net/http"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/id"
)
// requireHostToken validates the X-Host-Token header containing a host JWT,
@ -23,7 +24,13 @@ func requireHostToken(secret []byte) func(http.Handler) http.Handler {
return
}
ctx := auth.WithHostContext(r.Context(), auth.HostContext{HostID: claims.HostID})
hostID, err := id.ParseHostID(claims.HostID)
if err != nil {
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid host ID in token")
return
}
ctx := auth.WithHostContext(r.Context(), auth.HostContext{HostID: hostID})
next.ServeHTTP(w, r.WithContext(ctx))
})
}

View File

@ -5,6 +5,7 @@ import (
"strings"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/id"
)
// requireJWT validates the Authorization: Bearer <token> header, verifies the JWT
@ -25,9 +26,20 @@ func requireJWT(secret []byte) func(http.Handler) http.Handler {
return
}
teamID, err := id.ParseTeamID(claims.TeamID)
if err != nil {
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid team ID in token")
return
}
userID, err := id.ParseUserID(claims.Subject)
if err != nil {
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid user ID in token")
return
}
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
TeamID: claims.TeamID,
UserID: claims.Subject,
TeamID: teamID,
UserID: userID,
Email: claims.Email,
Name: claims.Name,
Role: claims.Role,

View File

@ -10,6 +10,7 @@ import (
"github.com/redis/go-redis/v9"
"git.omukk.dev/wrenn/sandbox/internal/audit"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/auth/oauth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
@ -22,7 +23,8 @@ var openapiYAML []byte
// Server is the control plane HTTP server.
type Server struct {
router chi.Router
router chi.Router
BuildSvc *service.BuildService
}
// New constructs the chi router and registers all routes.
@ -35,6 +37,7 @@ func New(
jwtSecret []byte,
oauthRegistry *oauth.Registry,
oauthRedirectURL string,
ca *auth.CA,
) *Server {
r := chi.NewRouter()
r.Use(requestLogger())
@ -43,10 +46,11 @@ func New(
sandboxSvc := &service.SandboxService{DB: queries, Pool: pool, Scheduler: sched}
apiKeySvc := &service.APIKeyService{DB: queries}
templateSvc := &service.TemplateService{DB: queries}
hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret, Pool: pool}
hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret, Pool: pool, CA: ca}
teamSvc := &service.TeamService{DB: queries, Pool: pgPool, HostPool: pool}
auditSvc := &service.AuditService{DB: queries}
statsSvc := &service.StatsService{DB: queries, Pool: pgPool}
buildSvc := &service.BuildService{DB: queries, Redis: rdb, Pool: pool, Scheduler: sched}
al := audit.New(queries)
@ -65,6 +69,7 @@ func New(
auditH := newAuditHandler(auditSvc)
statsH := newStatsHandler(statsSvc)
metricsH := newSandboxMetricsHandler(queries, pool)
buildH := newBuildHandler(buildSvc, queries, pool)
// OpenAPI spec and docs.
r.Get("/openapi.yaml", serveOpenAPI)
@ -174,9 +179,15 @@ func New(
r.Use(requireJWT(jwtSecret))
r.Use(requireAdmin(queries))
r.Put("/teams/{id}/byoc", teamH.SetBYOC)
r.Get("/templates", buildH.ListTemplates)
r.Delete("/templates/{name}", buildH.DeleteTemplate)
r.Post("/builds", buildH.Create)
r.Get("/builds", buildH.List)
r.Get("/builds/{id}", buildH.Get)
r.Post("/builds/{id}/cancel", buildH.Cancel)
})
return &Server{router: r}
return &Server{router: r, BuildSvc: buildSvc}
}
// Handler returns the HTTP handler.

View File

@ -25,18 +25,15 @@ func New(queries *db.Queries) *AuditLogger {
}
// actorFields extracts actor_type, actor_id, and actor_name from an AuthContext.
func actorFields(ac auth.AuthContext) (actorType string, actorID pgtype.Text, actorName pgtype.Text) {
if ac.UserID != "" {
return "user",
pgtype.Text{String: ac.UserID, Valid: true},
pgtype.Text{String: ac.Name, Valid: ac.Name != ""}
// actor_id is stored as a prefixed string in the TEXT column.
func actorFields(ac auth.AuthContext) (actorType, actorID, actorName string) {
if ac.UserID.Valid {
return "user", id.FormatUserID(ac.UserID), ac.Name
}
if ac.APIKeyID != "" {
return "api_key",
pgtype.Text{String: ac.APIKeyID, Valid: true},
pgtype.Text{String: ac.APIKeyName, Valid: true}
if ac.APIKeyID.Valid {
return "api_key", id.FormatAPIKeyID(ac.APIKeyID), ac.APIKeyName
}
return "system", pgtype.Text{}, pgtype.Text{}
return "system", "", ""
}
func (l *AuditLogger) write(ctx context.Context, p db.InsertAuditLogParams) {
@ -44,7 +41,6 @@ func (l *AuditLogger) write(ctx context.Context, p db.InsertAuditLogParams) {
slog.Warn("audit: failed to write log entry",
"action", p.Action,
"resource_type", p.ResourceType,
"team_id", p.TeamID,
"error", err,
)
}
@ -61,18 +57,26 @@ func marshalMeta(meta map[string]any) []byte {
return b
}
// optText returns a valid pgtype.Text if s is non-empty, otherwise an invalid (NULL) one.
func optText(s string) pgtype.Text {
if s == "" {
return pgtype.Text{}
}
return pgtype.Text{String: s, Valid: true}
}
// --- Sandbox events (scope: team) ---
func (l *AuditLogger) LogSandboxCreate(ctx context.Context, ac auth.AuthContext, sandboxID, template string) {
func (l *AuditLogger) LogSandboxCreate(ctx context.Context, ac auth.AuthContext, sandboxID pgtype.UUID, template string) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: actorID,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "sandbox",
ResourceID: pgtype.Text{String: sandboxID, Valid: true},
ResourceID: optText(id.FormatSandboxID(sandboxID)),
Action: "create",
Scope: "team",
Status: "success",
@ -80,16 +84,16 @@ func (l *AuditLogger) LogSandboxCreate(ctx context.Context, ac auth.AuthContext,
})
}
func (l *AuditLogger) LogSandboxPause(ctx context.Context, ac auth.AuthContext, sandboxID string) {
func (l *AuditLogger) LogSandboxPause(ctx context.Context, ac auth.AuthContext, sandboxID pgtype.UUID) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: actorID,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "sandbox",
ResourceID: pgtype.Text{String: sandboxID, Valid: true},
ResourceID: optText(id.FormatSandboxID(sandboxID)),
Action: "pause",
Scope: "team",
Status: "success",
@ -98,15 +102,15 @@ func (l *AuditLogger) LogSandboxPause(ctx context.Context, ac auth.AuthContext,
}
// LogSandboxAutoPause records a system-initiated auto-pause (TTL or host reconciler).
func (l *AuditLogger) LogSandboxAutoPause(ctx context.Context, teamID, sandboxID string) {
func (l *AuditLogger) LogSandboxAutoPause(ctx context.Context, teamID, sandboxID pgtype.UUID) {
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: teamID,
ActorType: "system",
ActorID: pgtype.Text{},
ActorName: pgtype.Text{},
ActorName: "",
ResourceType: "sandbox",
ResourceID: pgtype.Text{String: sandboxID, Valid: true},
ResourceID: optText(id.FormatSandboxID(sandboxID)),
Action: "pause",
Scope: "team",
Status: "info",
@ -114,16 +118,16 @@ func (l *AuditLogger) LogSandboxAutoPause(ctx context.Context, teamID, sandboxID
})
}
func (l *AuditLogger) LogSandboxResume(ctx context.Context, ac auth.AuthContext, sandboxID string) {
func (l *AuditLogger) LogSandboxResume(ctx context.Context, ac auth.AuthContext, sandboxID pgtype.UUID) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: actorID,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "sandbox",
ResourceID: pgtype.Text{String: sandboxID, Valid: true},
ResourceID: optText(id.FormatSandboxID(sandboxID)),
Action: "resume",
Scope: "team",
Status: "success",
@ -131,16 +135,16 @@ func (l *AuditLogger) LogSandboxResume(ctx context.Context, ac auth.AuthContext,
})
}
func (l *AuditLogger) LogSandboxDestroy(ctx context.Context, ac auth.AuthContext, sandboxID string) {
func (l *AuditLogger) LogSandboxDestroy(ctx context.Context, ac auth.AuthContext, sandboxID pgtype.UUID) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: actorID,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "sandbox",
ResourceID: pgtype.Text{String: sandboxID, Valid: true},
ResourceID: optText(id.FormatSandboxID(sandboxID)),
Action: "destroy",
Scope: "team",
Status: "warning",
@ -156,10 +160,10 @@ func (l *AuditLogger) LogSnapshotCreate(ctx context.Context, ac auth.AuthContext
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: actorID,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "snapshot",
ResourceID: pgtype.Text{String: name, Valid: true},
ResourceID: optText(name),
Action: "create",
Scope: "team",
Status: "success",
@ -173,10 +177,10 @@ func (l *AuditLogger) LogSnapshotDelete(ctx context.Context, ac auth.AuthContext
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: actorID,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "snapshot",
ResourceID: pgtype.Text{String: name, Valid: true},
ResourceID: optText(name),
Action: "delete",
Scope: "team",
Status: "warning",
@ -186,16 +190,16 @@ func (l *AuditLogger) LogSnapshotDelete(ctx context.Context, ac auth.AuthContext
// --- Team events (scope: team) ---
func (l *AuditLogger) LogTeamRename(ctx context.Context, ac auth.AuthContext, teamID, oldName, newName string) {
func (l *AuditLogger) LogTeamRename(ctx context.Context, ac auth.AuthContext, teamID pgtype.UUID, oldName, newName string) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: actorID,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "team",
ResourceID: pgtype.Text{String: teamID, Valid: true},
ResourceID: optText(id.FormatTeamID(teamID)),
Action: "rename",
Scope: "team",
Status: "info",
@ -205,16 +209,16 @@ func (l *AuditLogger) LogTeamRename(ctx context.Context, ac auth.AuthContext, te
// --- API key events (scope: team) ---
func (l *AuditLogger) LogAPIKeyCreate(ctx context.Context, ac auth.AuthContext, keyID, keyName string) {
func (l *AuditLogger) LogAPIKeyCreate(ctx context.Context, ac auth.AuthContext, keyID pgtype.UUID, keyName string) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: actorID,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "api_key",
ResourceID: pgtype.Text{String: keyID, Valid: true},
ResourceID: optText(id.FormatAPIKeyID(keyID)),
Action: "create",
Scope: "team",
Status: "success",
@ -222,16 +226,16 @@ func (l *AuditLogger) LogAPIKeyCreate(ctx context.Context, ac auth.AuthContext,
})
}
func (l *AuditLogger) LogAPIKeyRevoke(ctx context.Context, ac auth.AuthContext, keyID string) {
func (l *AuditLogger) LogAPIKeyRevoke(ctx context.Context, ac auth.AuthContext, keyID pgtype.UUID) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: actorID,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "api_key",
ResourceID: pgtype.Text{String: keyID, Valid: true},
ResourceID: optText(id.FormatAPIKeyID(keyID)),
Action: "revoke",
Scope: "team",
Status: "warning",
@ -241,16 +245,16 @@ func (l *AuditLogger) LogAPIKeyRevoke(ctx context.Context, ac auth.AuthContext,
// --- Member events (scope: admin) ---
func (l *AuditLogger) LogMemberAdd(ctx context.Context, ac auth.AuthContext, targetUserID, targetEmail, role string) {
func (l *AuditLogger) LogMemberAdd(ctx context.Context, ac auth.AuthContext, targetUserID pgtype.UUID, targetEmail, role string) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: actorID,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "member",
ResourceID: pgtype.Text{String: targetUserID, Valid: true},
ResourceID: optText(id.FormatUserID(targetUserID)),
Action: "add",
Scope: "admin",
Status: "success",
@ -258,16 +262,16 @@ func (l *AuditLogger) LogMemberAdd(ctx context.Context, ac auth.AuthContext, tar
})
}
func (l *AuditLogger) LogMemberRemove(ctx context.Context, ac auth.AuthContext, targetUserID string) {
func (l *AuditLogger) LogMemberRemove(ctx context.Context, ac auth.AuthContext, targetUserID pgtype.UUID) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: actorID,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "member",
ResourceID: pgtype.Text{String: targetUserID, Valid: true},
ResourceID: optText(id.FormatUserID(targetUserID)),
Action: "remove",
Scope: "admin",
Status: "warning",
@ -277,14 +281,18 @@ func (l *AuditLogger) LogMemberRemove(ctx context.Context, ac auth.AuthContext,
func (l *AuditLogger) LogMemberLeave(ctx context.Context, ac auth.AuthContext) {
actorType, actorID, actorName := actorFields(ac)
resourceID := ""
if ac.UserID.Valid {
resourceID = id.FormatUserID(ac.UserID)
}
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: actorID,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "member",
ResourceID: pgtype.Text{String: ac.UserID, Valid: ac.UserID != ""},
ResourceID: optText(resourceID),
Action: "leave",
Scope: "admin",
Status: "info",
@ -292,16 +300,16 @@ func (l *AuditLogger) LogMemberLeave(ctx context.Context, ac auth.AuthContext) {
})
}
func (l *AuditLogger) LogMemberRoleUpdate(ctx context.Context, ac auth.AuthContext, targetUserID, newRole string) {
func (l *AuditLogger) LogMemberRoleUpdate(ctx context.Context, ac auth.AuthContext, targetUserID pgtype.UUID, newRole string) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: actorID,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "member",
ResourceID: pgtype.Text{String: targetUserID, Valid: true},
ResourceID: optText(id.FormatUserID(targetUserID)),
Action: "role_update",
Scope: "admin",
Status: "info",
@ -311,24 +319,24 @@ func (l *AuditLogger) LogMemberRoleUpdate(ctx context.Context, ac auth.AuthConte
// --- Host events (scope: admin) ---
func (l *AuditLogger) LogHostCreate(ctx context.Context, ac auth.AuthContext, hostID, teamID string) {
func (l *AuditLogger) LogHostCreate(ctx context.Context, ac auth.AuthContext, hostID, teamID pgtype.UUID) {
actorType, actorID, actorName := actorFields(ac)
// For shared hosts with no owning team, use the caller's team.
logTeamID := teamID
if logTeamID == "" {
if !logTeamID.Valid {
logTeamID = ac.TeamID
}
if logTeamID == "" {
if !logTeamID.Valid {
return
}
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: logTeamID,
ActorType: actorType,
ActorID: actorID,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "host",
ResourceID: pgtype.Text{String: hostID, Valid: true},
ResourceID: optText(id.FormatHostID(hostID)),
Action: "create",
Scope: "admin",
Status: "success",
@ -336,23 +344,23 @@ func (l *AuditLogger) LogHostCreate(ctx context.Context, ac auth.AuthContext, ho
})
}
func (l *AuditLogger) LogHostDelete(ctx context.Context, ac auth.AuthContext, hostID, teamID string) {
func (l *AuditLogger) LogHostDelete(ctx context.Context, ac auth.AuthContext, hostID, teamID pgtype.UUID) {
actorType, actorID, actorName := actorFields(ac)
logTeamID := teamID
if logTeamID == "" {
if !logTeamID.Valid {
logTeamID = ac.TeamID
}
if logTeamID == "" {
if !logTeamID.Valid {
return
}
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: logTeamID,
ActorType: actorType,
ActorID: actorID,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "host",
ResourceID: pgtype.Text{String: hostID, Valid: true},
ResourceID: optText(id.FormatHostID(hostID)),
Action: "delete",
Scope: "admin",
Status: "warning",
@ -361,9 +369,8 @@ func (l *AuditLogger) LogHostDelete(ctx context.Context, ac auth.AuthContext, ho
}
// LogHostMarkedDown records a system-initiated host status transition to unreachable.
// teamID must be non-empty (BYOC hosts only); shared hosts are not logged.
func (l *AuditLogger) LogHostMarkedDown(ctx context.Context, teamID, hostID string) {
if teamID == "" {
func (l *AuditLogger) LogHostMarkedDown(ctx context.Context, teamID, hostID pgtype.UUID) {
if !teamID.Valid {
return
}
l.write(ctx, db.InsertAuditLogParams{
@ -371,9 +378,9 @@ func (l *AuditLogger) LogHostMarkedDown(ctx context.Context, teamID, hostID stri
TeamID: teamID,
ActorType: "system",
ActorID: pgtype.Text{},
ActorName: pgtype.Text{},
ActorName: "",
ResourceType: "host",
ResourceID: pgtype.Text{String: hostID, Valid: true},
ResourceID: optText(id.FormatHostID(hostID)),
Action: "marked_down",
Scope: "admin",
Status: "error",
@ -382,9 +389,8 @@ func (l *AuditLogger) LogHostMarkedDown(ctx context.Context, teamID, hostID stri
}
// LogHostMarkedUp records a system-initiated host status transition back to online.
// teamID must be non-empty (BYOC hosts only); shared hosts are not logged.
func (l *AuditLogger) LogHostMarkedUp(ctx context.Context, teamID, hostID string) {
if teamID == "" {
func (l *AuditLogger) LogHostMarkedUp(ctx context.Context, teamID, hostID pgtype.UUID) {
if !teamID.Valid {
return
}
l.write(ctx, db.InsertAuditLogParams{
@ -392,9 +398,9 @@ func (l *AuditLogger) LogHostMarkedUp(ctx context.Context, teamID, hostID string
TeamID: teamID,
ActorType: "system",
ActorID: pgtype.Text{},
ActorName: pgtype.Text{},
ActorName: "",
ResourceType: "host",
ResourceID: pgtype.Text{String: hostID, Valid: true},
ResourceID: optText(id.FormatHostID(hostID)),
Action: "marked_up",
Scope: "admin",
Status: "success",

251
internal/auth/cert.go Normal file
View File

@ -0,0 +1,251 @@
package auth
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"net"
"sync/atomic"
"time"
)
// CPCertRenewInterval is how often the control plane should renew its client
// certificate. It is set to half the cert TTL so there is always a wide safety
// margin before expiry.
const CPCertRenewInterval = cpCertTTL / 2
const (
hostCertTTL = 7 * 24 * time.Hour
cpCertTTL = 24 * time.Hour
)
// CA holds a parsed certificate authority ready to issue leaf certificates.
type CA struct {
Cert *x509.Certificate
Key *ecdsa.PrivateKey
PEM string // PEM-encoded certificate for embedding in register/refresh responses
}
// ParseCA parses PEM-encoded CA certificate and private key strings.
// The cert and key are expected to be ECDSA P-256.
func ParseCA(certPEM, keyPEM string) (*CA, error) {
certBlock, _ := pem.Decode([]byte(certPEM))
if certBlock == nil {
return nil, fmt.Errorf("failed to decode CA certificate PEM")
}
cert, err := x509.ParseCertificate(certBlock.Bytes)
if err != nil {
return nil, fmt.Errorf("parse CA certificate: %w", err)
}
keyBlock, _ := pem.Decode([]byte(keyPEM))
if keyBlock == nil {
return nil, fmt.Errorf("failed to decode CA key PEM")
}
keyIface, err := x509.ParseECPrivateKey(keyBlock.Bytes)
if err != nil {
return nil, fmt.Errorf("parse CA private key: %w", err)
}
return &CA{Cert: cert, Key: keyIface, PEM: certPEM}, nil
}
// HostCert holds all material returned when issuing a leaf cert for a host agent.
type HostCert struct {
CertPEM string
KeyPEM string
Fingerprint string // hex-encoded SHA-256 of DER bytes, stored in hosts.cert_fingerprint
ExpiresAt time.Time // stored in hosts.cert_expires_at
TLSCert tls.Certificate
}
// IssueHostCert generates an ECDSA P-256 key pair and issues a 7-day server
// certificate for the host agent. hostID becomes the common name; the host's
// IP address (parsed from hostAddr) is added as an IP SAN so Go's TLS
// stack can verify the connection without disabling hostname checking.
func IssueHostCert(ca *CA, hostID, hostAddr string) (HostCert, error) {
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return HostCert{}, fmt.Errorf("generate host key: %w", err)
}
serial, err := randomSerial()
if err != nil {
return HostCert{}, err
}
now := time.Now()
expires := now.Add(hostCertTTL)
tmpl := &x509.Certificate{
SerialNumber: serial,
Subject: pkix.Name{CommonName: hostID},
NotBefore: now.Add(-time.Minute), // small clock-skew tolerance
NotAfter: expires,
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
}
// Extract IP from "ip:port" address; fall back to DNS SAN if not parseable.
host, _, err := net.SplitHostPort(hostAddr)
if err != nil {
host = hostAddr
}
if ip := net.ParseIP(host); ip != nil {
tmpl.IPAddresses = []net.IP{ip}
} else {
tmpl.DNSNames = []string{host}
}
derBytes, err := x509.CreateCertificate(rand.Reader, tmpl, ca.Cert, &key.PublicKey, ca.Key)
if err != nil {
return HostCert{}, fmt.Errorf("create host certificate: %w", err)
}
certPEM := string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}))
keyDER, err := x509.MarshalECPrivateKey(key)
if err != nil {
return HostCert{}, fmt.Errorf("marshal host key: %w", err)
}
keyPEM := string(pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}))
tlsCert, err := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM))
if err != nil {
return HostCert{}, fmt.Errorf("build TLS certificate: %w", err)
}
fp := fmt.Sprintf("%x", sha256.Sum256(derBytes))
return HostCert{
CertPEM: certPEM,
KeyPEM: keyPEM,
Fingerprint: fp,
ExpiresAt: expires,
TLSCert: tlsCert,
}, nil
}
// IssueCPClientCert generates a short-lived (24h) ECDSA client certificate for
// the control plane to present during mTLS handshakes with host agents.
// Called once at CP startup; the result is embedded into the shared HTTP client.
func IssueCPClientCert(ca *CA) (tls.Certificate, error) {
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return tls.Certificate{}, fmt.Errorf("generate CP client key: %w", err)
}
serial, err := randomSerial()
if err != nil {
return tls.Certificate{}, err
}
now := time.Now()
tmpl := &x509.Certificate{
SerialNumber: serial,
Subject: pkix.Name{CommonName: "wrenn-cp"},
NotBefore: now.Add(-time.Minute),
NotAfter: now.Add(cpCertTTL),
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
}
derBytes, err := x509.CreateCertificate(rand.Reader, tmpl, ca.Cert, &key.PublicKey, ca.Key)
if err != nil {
return tls.Certificate{}, fmt.Errorf("create CP client certificate: %w", err)
}
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
keyDER, err := x509.MarshalECPrivateKey(key)
if err != nil {
return tls.Certificate{}, fmt.Errorf("marshal CP client key: %w", err)
}
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
return tls.X509KeyPair(certPEM, keyPEM)
}
// AgentTLSConfigFromPEM returns a tls.Config for the host agent using the
// PEM-encoded CA certificate. This is used on the agent side where only the
// CA certificate (not the private key) is available.
func AgentTLSConfigFromPEM(caCertPEM string, getCert func(*tls.ClientHelloInfo) (*tls.Certificate, error)) *tls.Config {
pool := x509.NewCertPool()
if !pool.AppendCertsFromPEM([]byte(caCertPEM)) {
return nil
}
return &tls.Config{
ClientAuth: tls.RequireAndVerifyClientCert,
ClientCAs: pool,
GetCertificate: getCert,
MinVersion: tls.VersionTLS13,
}
}
// CPCertStore provides lock-free read/write access to the control plane's
// current client TLS certificate. It is used with tls.Config.GetClientCertificate
// to enable hot-swap without restarting the HTTP client.
//
// The zero value is not usable; use NewCPCertStore to create one.
type CPCertStore struct {
ptr atomic.Pointer[tls.Certificate]
ca *CA
}
// NewCPCertStore issues an initial CP client certificate from ca and returns a
// store that can renew it in place. Returns an error if the initial issuance fails.
func NewCPCertStore(ca *CA) (*CPCertStore, error) {
s := &CPCertStore{ca: ca}
if err := s.Refresh(); err != nil {
return nil, err
}
return s, nil
}
// Refresh issues a fresh CP client certificate and atomically stores it.
// If issuance fails the existing cert is unchanged.
func (s *CPCertStore) Refresh() error {
cert, err := IssueCPClientCert(s.ca)
if err != nil {
return fmt.Errorf("renew CP client certificate: %w", err)
}
s.ptr.Store(&cert)
return nil
}
// GetClientCertificate satisfies tls.Config.GetClientCertificate. It is called
// per-handshake and always returns the most recently stored certificate.
func (s *CPCertStore) GetClientCertificate(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
cert := s.ptr.Load()
if cert == nil {
return nil, fmt.Errorf("no CP client certificate available")
}
return cert, nil
}
// CPClientTLSConfig returns a tls.Config for the CP's outbound HTTP client.
// It uses certStore.GetClientCertificate so the certificate can be renewed
// without replacing the config or transport.
func CPClientTLSConfig(ca *CA, certStore *CPCertStore) *tls.Config {
pool := x509.NewCertPool()
pool.AddCert(ca.Cert)
return &tls.Config{
RootCAs: pool,
GetClientCertificate: certStore.GetClientCertificate,
MinVersion: tls.VersionTLS13,
}
}
// randomSerial returns a random 128-bit certificate serial number.
func randomSerial() (*big.Int, error) {
serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
if err != nil {
return nil, fmt.Errorf("generate serial number: %w", err)
}
return serial, nil
}

View File

@ -1,6 +1,10 @@
package auth
import "context"
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
type contextKey int
@ -8,14 +12,14 @@ const authCtxKey contextKey = 0
// AuthContext is stamped into request context by auth middleware.
type AuthContext struct {
TeamID string
UserID string // empty when authenticated via API key
Email string // empty when authenticated via API key
Name string // empty when authenticated via API key
Role string // owner, admin, or member; empty when authenticated via API key
IsAdmin bool // platform-level admin; always false when authenticated via API key
APIKeyID string // populated when authenticated via API key; empty for JWT auth
APIKeyName string // display name of the key, snapshotted at auth time; empty for JWT auth
TeamID pgtype.UUID
UserID pgtype.UUID // zero value (Valid=false) when authenticated via API key
Email string // empty when authenticated via API key
Name string // empty when authenticated via API key
Role string // owner, admin, or member; empty when authenticated via API key
IsAdmin bool // platform-level admin; always false when authenticated via API key
APIKeyID pgtype.UUID // populated when authenticated via API key; zero value for JWT auth
APIKeyName string // display name of the key, snapshotted at auth time; empty for JWT auth
}
// WithAuthContext returns a new context with the given AuthContext.
@ -43,7 +47,7 @@ const hostCtxKey contextKey = 1
// HostContext is stamped into request context by host token middleware.
type HostContext struct {
HostID string
HostID pgtype.UUID
}
// WithHostContext returns a new context with the given HostContext.

View File

@ -5,6 +5,9 @@ import (
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/id"
)
const jwtExpiry = 6 * time.Hour
@ -23,16 +26,16 @@ type Claims struct {
}
// SignJWT signs a new 6-hour JWT for the given user.
func SignJWT(secret []byte, userID, teamID, email, name, role string, isAdmin bool) (string, error) {
func SignJWT(secret []byte, userID, teamID pgtype.UUID, email, name, role string, isAdmin bool) (string, error) {
now := time.Now()
claims := Claims{
TeamID: teamID,
TeamID: id.FormatTeamID(teamID),
Role: role,
Email: email,
Name: name,
IsAdmin: isAdmin,
RegisteredClaims: jwt.RegisteredClaims{
Subject: userID,
Subject: id.FormatUserID(userID),
IssuedAt: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(now.Add(jwtExpiry)),
},
@ -70,14 +73,15 @@ type HostClaims struct {
jwt.RegisteredClaims
}
// SignHostJWT signs a long-lived (1 year) JWT for a registered host agent.
func SignHostJWT(secret []byte, hostID string) (string, error) {
// SignHostJWT signs a long-lived (7-day) JWT for a registered host agent.
func SignHostJWT(secret []byte, hostID pgtype.UUID) (string, error) {
formatted := id.FormatHostID(hostID)
now := time.Now()
claims := HostClaims{
Type: "host",
HostID: hostID,
HostID: formatted,
RegisteredClaims: jwt.RegisteredClaims{
Subject: hostID,
Subject: formatted,
IssuedAt: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(now.Add(hostJWTExpiry)),
},

View File

@ -13,6 +13,11 @@ type Config struct {
ListenAddr string
JWTSecret string
// mTLS — CP→Agent channel. Both must be set to enable mTLS; omitting either
// disables cert issuance and leaves agent connections on plain HTTP (dev mode).
CACert string // WRENN_CA_CERT — PEM-encoded internal CA certificate
CAKey string // WRENN_CA_KEY — PEM-encoded internal CA private key
OAuthGitHubClientID string
OAuthGitHubClientSecret string
OAuthRedirectURL string
@ -28,9 +33,12 @@ func Load() Config {
return Config{
DatabaseURL: envOrDefault("DATABASE_URL", "postgres://wrenn:wrenn@localhost:5432/wrenn?sslmode=disable"),
RedisURL: envOrDefault("REDIS_URL", "redis://localhost:6379/0"),
ListenAddr: envOrDefault("CP_LISTEN_ADDR", ":8080"),
ListenAddr: envOrDefault("WRENN_CP_LISTEN_ADDR", ":8080"),
JWTSecret: os.Getenv("JWT_SECRET"),
CACert: os.Getenv("WRENN_CA_CERT"),
CAKey: os.Getenv("WRENN_CA_KEY"),
OAuthGitHubClientID: os.Getenv("OAUTH_GITHUB_CLIENT_ID"),
OAuthGitHubClientSecret: os.Getenv("OAUTH_GITHUB_CLIENT_SECRET"),
OAuthRedirectURL: envOrDefault("OAUTH_REDIRECT_URL", "https://app.wrenn.dev"),

View File

@ -16,8 +16,8 @@ DELETE FROM team_api_keys WHERE id = $1 AND team_id = $2
`
type DeleteAPIKeyParams struct {
ID string `json:"id"`
TeamID string `json:"team_id"`
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
}
func (q *Queries) DeleteAPIKey(ctx context.Context, arg DeleteAPIKeyParams) error {
@ -52,12 +52,12 @@ RETURNING id, team_id, name, key_hash, key_prefix, created_by, created_at, last_
`
type InsertAPIKeyParams struct {
ID string `json:"id"`
TeamID string `json:"team_id"`
Name string `json:"name"`
KeyHash string `json:"key_hash"`
KeyPrefix string `json:"key_prefix"`
CreatedBy string `json:"created_by"`
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
Name string `json:"name"`
KeyHash string `json:"key_hash"`
KeyPrefix string `json:"key_prefix"`
CreatedBy pgtype.UUID `json:"created_by"`
}
func (q *Queries) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (TeamApiKey, error) {
@ -87,7 +87,7 @@ const listAPIKeysByTeam = `-- name: ListAPIKeysByTeam :many
SELECT id, team_id, name, key_hash, key_prefix, created_by, created_at, last_used FROM team_api_keys WHERE team_id = $1 ORDER BY created_at DESC
`
func (q *Queries) ListAPIKeysByTeam(ctx context.Context, teamID string) ([]TeamApiKey, error) {
func (q *Queries) ListAPIKeysByTeam(ctx context.Context, teamID pgtype.UUID) ([]TeamApiKey, error) {
rows, err := q.db.Query(ctx, listAPIKeysByTeam, teamID)
if err != nil {
return nil, err
@ -126,18 +126,18 @@ ORDER BY k.created_at DESC
`
type ListAPIKeysByTeamWithCreatorRow struct {
ID string `json:"id"`
TeamID string `json:"team_id"`
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
Name string `json:"name"`
KeyHash string `json:"key_hash"`
KeyPrefix string `json:"key_prefix"`
CreatedBy string `json:"created_by"`
CreatedBy pgtype.UUID `json:"created_by"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
LastUsed pgtype.Timestamptz `json:"last_used"`
CreatorEmail string `json:"creator_email"`
}
func (q *Queries) ListAPIKeysByTeamWithCreator(ctx context.Context, teamID string) ([]ListAPIKeysByTeamWithCreatorRow, error) {
func (q *Queries) ListAPIKeysByTeamWithCreator(ctx context.Context, teamID pgtype.UUID) ([]ListAPIKeysByTeamWithCreatorRow, error) {
rows, err := q.db.Query(ctx, listAPIKeysByTeamWithCreator, teamID)
if err != nil {
return nil, err
@ -171,7 +171,7 @@ const updateAPIKeyLastUsed = `-- name: UpdateAPIKeyLastUsed :exec
UPDATE team_api_keys SET last_used = NOW() WHERE id = $1
`
func (q *Queries) UpdateAPIKeyLastUsed(ctx context.Context, id string) error {
func (q *Queries) UpdateAPIKeyLastUsed(ctx context.Context, id pgtype.UUID) error {
_, err := q.db.Exec(ctx, updateAPIKeyLastUsed, id)
return err
}

View File

@ -17,11 +17,11 @@ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
`
type InsertAuditLogParams struct {
ID string `json:"id"`
TeamID string `json:"team_id"`
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
ActorType string `json:"actor_type"`
ActorID pgtype.Text `json:"actor_id"`
ActorName pgtype.Text `json:"actor_name"`
ActorName string `json:"actor_name"`
ResourceType string `json:"resource_type"`
ResourceID pgtype.Text `json:"resource_id"`
Action string `json:"action"`
@ -60,12 +60,12 @@ LIMIT $7
`
type ListAuditLogsParams struct {
TeamID string `json:"team_id"`
TeamID pgtype.UUID `json:"team_id"`
Column2 []string `json:"column_2"`
Column3 []string `json:"column_3"`
Column4 []string `json:"column_4"`
Column5 pgtype.Timestamptz `json:"column_5"`
ID string `json:"id"`
ID pgtype.UUID `json:"id"`
Limit int32 `json:"limit"`
}

View File

@ -47,8 +47,8 @@ RETURNING id, host_id, token_hash, expires_at, created_at, revoked_at
`
type InsertHostRefreshTokenParams struct {
ID string `json:"id"`
HostID string `json:"host_id"`
ID pgtype.UUID `json:"id"`
HostID pgtype.UUID `json:"host_id"`
TokenHash string `json:"token_hash"`
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
}
@ -76,7 +76,7 @@ const revokeHostRefreshToken = `-- name: RevokeHostRefreshToken :exec
UPDATE host_refresh_tokens SET revoked_at = NOW() WHERE id = $1
`
func (q *Queries) RevokeHostRefreshToken(ctx context.Context, id string) error {
func (q *Queries) RevokeHostRefreshToken(ctx context.Context, id pgtype.UUID) error {
_, err := q.db.Exec(ctx, revokeHostRefreshToken, id)
return err
}
@ -86,7 +86,7 @@ UPDATE host_refresh_tokens SET revoked_at = NOW()
WHERE host_id = $1 AND revoked_at IS NULL
`
func (q *Queries) RevokeHostRefreshTokensByHost(ctx context.Context, hostID string) error {
func (q *Queries) RevokeHostRefreshTokensByHost(ctx context.Context, hostID pgtype.UUID) error {
_, err := q.db.Exec(ctx, revokeHostRefreshTokensByHost, hostID)
return err
}

View File

@ -16,8 +16,8 @@ INSERT INTO host_tags (host_id, tag) VALUES ($1, $2) ON CONFLICT DO NOTHING
`
type AddHostTagParams struct {
HostID string `json:"host_id"`
Tag string `json:"tag"`
HostID pgtype.UUID `json:"host_id"`
Tag string `json:"tag"`
}
func (q *Queries) AddHostTag(ctx context.Context, arg AddHostTagParams) error {
@ -29,16 +29,16 @@ const deleteHost = `-- name: DeleteHost :exec
DELETE FROM hosts WHERE id = $1
`
func (q *Queries) DeleteHost(ctx context.Context, id string) error {
func (q *Queries) DeleteHost(ctx context.Context, id pgtype.UUID) error {
_, err := q.db.Exec(ctx, deleteHost, id)
return err
}
const getHost = `-- name: GetHost :one
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE id = $1
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE id = $1
`
func (q *Queries) GetHost(ctx context.Context, id string) (Host, error) {
func (q *Queries) GetHost(ctx context.Context, id pgtype.UUID) (Host, error) {
row := q.db.QueryRow(ctx, getHost, id)
var i Host
err := row.Scan(
@ -59,18 +59,18 @@ func (q *Queries) GetHost(ctx context.Context, id string) (Host, error) {
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
&i.MtlsEnabled,
&i.CertExpiresAt,
)
return i, err
}
const getHostByTeam = `-- name: GetHostByTeam :one
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE id = $1 AND team_id = $2
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE id = $1 AND team_id = $2
`
type GetHostByTeamParams struct {
ID string `json:"id"`
TeamID pgtype.Text `json:"team_id"`
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
}
func (q *Queries) GetHostByTeam(ctx context.Context, arg GetHostByTeamParams) (Host, error) {
@ -94,7 +94,7 @@ func (q *Queries) GetHostByTeam(ctx context.Context, arg GetHostByTeamParams) (H
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
&i.MtlsEnabled,
&i.CertExpiresAt,
)
return i, err
}
@ -103,7 +103,7 @@ const getHostTags = `-- name: GetHostTags :many
SELECT tag FROM host_tags WHERE host_id = $1 ORDER BY tag
`
func (q *Queries) GetHostTags(ctx context.Context, hostID string) ([]string, error) {
func (q *Queries) GetHostTags(ctx context.Context, hostID pgtype.UUID) ([]string, error) {
rows, err := q.db.Query(ctx, getHostTags, hostID)
if err != nil {
return nil, err
@ -127,7 +127,7 @@ const getHostTokensByHost = `-- name: GetHostTokensByHost :many
SELECT id, host_id, created_by, created_at, expires_at, used_at FROM host_tokens WHERE host_id = $1 ORDER BY created_at DESC
`
func (q *Queries) GetHostTokensByHost(ctx context.Context, hostID string) ([]HostToken, error) {
func (q *Queries) GetHostTokensByHost(ctx context.Context, hostID pgtype.UUID) ([]HostToken, error) {
rows, err := q.db.Query(ctx, getHostTokensByHost, hostID)
if err != nil {
return nil, err
@ -157,16 +157,16 @@ func (q *Queries) GetHostTokensByHost(ctx context.Context, hostID string) ([]Hos
const insertHost = `-- name: InsertHost :one
INSERT INTO hosts (id, type, team_id, provider, availability_zone, created_by)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled
RETURNING id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at
`
type InsertHostParams struct {
ID string `json:"id"`
ID pgtype.UUID `json:"id"`
Type string `json:"type"`
TeamID pgtype.Text `json:"team_id"`
Provider pgtype.Text `json:"provider"`
AvailabilityZone pgtype.Text `json:"availability_zone"`
CreatedBy string `json:"created_by"`
TeamID pgtype.UUID `json:"team_id"`
Provider string `json:"provider"`
AvailabilityZone string `json:"availability_zone"`
CreatedBy pgtype.UUID `json:"created_by"`
}
func (q *Queries) InsertHost(ctx context.Context, arg InsertHostParams) (Host, error) {
@ -197,7 +197,7 @@ func (q *Queries) InsertHost(ctx context.Context, arg InsertHostParams) (Host, e
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
&i.MtlsEnabled,
&i.CertExpiresAt,
)
return i, err
}
@ -209,9 +209,9 @@ RETURNING id, host_id, created_by, created_at, expires_at, used_at
`
type InsertHostTokenParams struct {
ID string `json:"id"`
HostID string `json:"host_id"`
CreatedBy string `json:"created_by"`
ID pgtype.UUID `json:"id"`
HostID pgtype.UUID `json:"host_id"`
CreatedBy pgtype.UUID `json:"created_by"`
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
}
@ -235,7 +235,7 @@ func (q *Queries) InsertHostToken(ctx context.Context, arg InsertHostTokenParams
}
const listActiveHosts = `-- name: ListActiveHosts :many
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE status NOT IN ('pending', 'offline') ORDER BY created_at
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE status NOT IN ('pending', 'offline') ORDER BY created_at
`
// Returns all hosts that have completed registration (not pending/offline).
@ -266,7 +266,7 @@ func (q *Queries) ListActiveHosts(ctx context.Context) ([]Host, error) {
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
&i.MtlsEnabled,
&i.CertExpiresAt,
); err != nil {
return nil, err
}
@ -279,7 +279,7 @@ func (q *Queries) ListActiveHosts(ctx context.Context) ([]Host, error) {
}
const listHosts = `-- name: ListHosts :many
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts ORDER BY created_at DESC
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts ORDER BY created_at DESC
`
func (q *Queries) ListHosts(ctx context.Context) ([]Host, error) {
@ -309,7 +309,7 @@ func (q *Queries) ListHosts(ctx context.Context) ([]Host, error) {
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
&i.MtlsEnabled,
&i.CertExpiresAt,
); err != nil {
return nil, err
}
@ -322,7 +322,7 @@ func (q *Queries) ListHosts(ctx context.Context) ([]Host, error) {
}
const listHostsByStatus = `-- name: ListHostsByStatus :many
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE status = $1 ORDER BY created_at DESC
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE status = $1 ORDER BY created_at DESC
`
func (q *Queries) ListHostsByStatus(ctx context.Context, status string) ([]Host, error) {
@ -352,7 +352,7 @@ func (q *Queries) ListHostsByStatus(ctx context.Context, status string) ([]Host,
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
&i.MtlsEnabled,
&i.CertExpiresAt,
); err != nil {
return nil, err
}
@ -365,7 +365,7 @@ func (q *Queries) ListHostsByStatus(ctx context.Context, status string) ([]Host,
}
const listHostsByTag = `-- name: ListHostsByTag :many
SELECT h.id, h.type, h.team_id, h.provider, h.availability_zone, h.arch, h.cpu_cores, h.memory_mb, h.disk_gb, h.address, h.status, h.last_heartbeat_at, h.metadata, h.created_by, h.created_at, h.updated_at, h.cert_fingerprint, h.mtls_enabled FROM hosts h
SELECT h.id, h.type, h.team_id, h.provider, h.availability_zone, h.arch, h.cpu_cores, h.memory_mb, h.disk_gb, h.address, h.status, h.last_heartbeat_at, h.metadata, h.created_by, h.created_at, h.updated_at, h.cert_fingerprint, h.cert_expires_at FROM hosts h
JOIN host_tags ht ON ht.host_id = h.id
WHERE ht.tag = $1
ORDER BY h.created_at DESC
@ -398,7 +398,7 @@ func (q *Queries) ListHostsByTag(ctx context.Context, tag string) ([]Host, error
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
&i.MtlsEnabled,
&i.CertExpiresAt,
); err != nil {
return nil, err
}
@ -411,10 +411,10 @@ func (q *Queries) ListHostsByTag(ctx context.Context, tag string) ([]Host, error
}
const listHostsByTeam = `-- name: ListHostsByTeam :many
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE team_id = $1 AND type = 'byoc' ORDER BY created_at DESC
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE team_id = $1 AND type = 'byoc' ORDER BY created_at DESC
`
func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.Text) ([]Host, error) {
func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.UUID) ([]Host, error) {
rows, err := q.db.Query(ctx, listHostsByTeam, teamID)
if err != nil {
return nil, err
@ -441,7 +441,7 @@ func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.Text) ([]Ho
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
&i.MtlsEnabled,
&i.CertExpiresAt,
); err != nil {
return nil, err
}
@ -454,7 +454,7 @@ func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.Text) ([]Ho
}
const listHostsByType = `-- name: ListHostsByType :many
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE type = $1 ORDER BY created_at DESC
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE type = $1 ORDER BY created_at DESC
`
func (q *Queries) ListHostsByType(ctx context.Context, type_ string) ([]Host, error) {
@ -484,7 +484,7 @@ func (q *Queries) ListHostsByType(ctx context.Context, type_ string) ([]Host, er
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
&i.MtlsEnabled,
&i.CertExpiresAt,
); err != nil {
return nil, err
}
@ -500,7 +500,7 @@ const markHostTokenUsed = `-- name: MarkHostTokenUsed :exec
UPDATE host_tokens SET used_at = NOW() WHERE id = $1
`
func (q *Queries) MarkHostTokenUsed(ctx context.Context, id string) error {
func (q *Queries) MarkHostTokenUsed(ctx context.Context, id pgtype.UUID) error {
_, err := q.db.Exec(ctx, markHostTokenUsed, id)
return err
}
@ -509,31 +509,35 @@ const markHostUnreachable = `-- name: MarkHostUnreachable :exec
UPDATE hosts SET status = 'unreachable', updated_at = NOW() WHERE id = $1
`
func (q *Queries) MarkHostUnreachable(ctx context.Context, id string) error {
func (q *Queries) MarkHostUnreachable(ctx context.Context, id pgtype.UUID) error {
_, err := q.db.Exec(ctx, markHostUnreachable, id)
return err
}
const registerHost = `-- name: RegisterHost :execrows
UPDATE hosts
SET arch = $2,
cpu_cores = $3,
memory_mb = $4,
disk_gb = $5,
address = $6,
status = 'online',
SET arch = $2,
cpu_cores = $3,
memory_mb = $4,
disk_gb = $5,
address = $6,
cert_fingerprint = $7,
cert_expires_at = $8,
status = 'online',
last_heartbeat_at = NOW(),
updated_at = NOW()
updated_at = NOW()
WHERE id = $1 AND status = 'pending'
`
type RegisterHostParams struct {
ID string `json:"id"`
Arch pgtype.Text `json:"arch"`
CpuCores pgtype.Int4 `json:"cpu_cores"`
MemoryMb pgtype.Int4 `json:"memory_mb"`
DiskGb pgtype.Int4 `json:"disk_gb"`
Address pgtype.Text `json:"address"`
ID pgtype.UUID `json:"id"`
Arch string `json:"arch"`
CpuCores int32 `json:"cpu_cores"`
MemoryMb int32 `json:"memory_mb"`
DiskGb int32 `json:"disk_gb"`
Address string `json:"address"`
CertFingerprint string `json:"cert_fingerprint"`
CertExpiresAt pgtype.Timestamptz `json:"cert_expires_at"`
}
func (q *Queries) RegisterHost(ctx context.Context, arg RegisterHostParams) (int64, error) {
@ -544,6 +548,8 @@ func (q *Queries) RegisterHost(ctx context.Context, arg RegisterHostParams) (int
arg.MemoryMb,
arg.DiskGb,
arg.Address,
arg.CertFingerprint,
arg.CertExpiresAt,
)
if err != nil {
return 0, err
@ -556,8 +562,8 @@ DELETE FROM host_tags WHERE host_id = $1 AND tag = $2
`
type RemoveHostTagParams struct {
HostID string `json:"host_id"`
Tag string `json:"tag"`
HostID pgtype.UUID `json:"host_id"`
Tag string `json:"tag"`
}
func (q *Queries) RemoveHostTag(ctx context.Context, arg RemoveHostTagParams) error {
@ -565,11 +571,30 @@ func (q *Queries) RemoveHostTag(ctx context.Context, arg RemoveHostTagParams) er
return err
}
const updateHostCert = `-- name: UpdateHostCert :exec
UPDATE hosts
SET cert_fingerprint = $2,
cert_expires_at = $3,
updated_at = NOW()
WHERE id = $1
`
type UpdateHostCertParams struct {
ID pgtype.UUID `json:"id"`
CertFingerprint string `json:"cert_fingerprint"`
CertExpiresAt pgtype.Timestamptz `json:"cert_expires_at"`
}
func (q *Queries) UpdateHostCert(ctx context.Context, arg UpdateHostCertParams) error {
_, err := q.db.Exec(ctx, updateHostCert, arg.ID, arg.CertFingerprint, arg.CertExpiresAt)
return err
}
const updateHostHeartbeat = `-- name: UpdateHostHeartbeat :exec
UPDATE hosts SET last_heartbeat_at = NOW(), updated_at = NOW() WHERE id = $1
`
func (q *Queries) UpdateHostHeartbeat(ctx context.Context, id string) error {
func (q *Queries) UpdateHostHeartbeat(ctx context.Context, id pgtype.UUID) error {
_, err := q.db.Exec(ctx, updateHostHeartbeat, id)
return err
}
@ -584,7 +609,7 @@ WHERE id = $1
// Updates last_heartbeat_at and transitions unreachable hosts back to online.
// Returns 0 if no host was found (deleted), which the caller treats as 404.
func (q *Queries) UpdateHostHeartbeatAndStatus(ctx context.Context, id string) (int64, error) {
func (q *Queries) UpdateHostHeartbeatAndStatus(ctx context.Context, id pgtype.UUID) (int64, error) {
result, err := q.db.Exec(ctx, updateHostHeartbeatAndStatus, id)
if err != nil {
return 0, err
@ -597,8 +622,8 @@ UPDATE hosts SET status = $2, updated_at = NOW() WHERE id = $1
`
type UpdateHostStatusParams struct {
ID string `json:"id"`
Status string `json:"status"`
ID pgtype.UUID `json:"id"`
Status string `json:"status"`
}
func (q *Queries) UpdateHostStatus(ctx context.Context, arg UpdateHostStatusParams) error {

View File

@ -7,6 +7,8 @@ package db
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
const deleteSandboxMetricPoints = `-- name: DeleteSandboxMetricPoints :exec
@ -14,7 +16,7 @@ DELETE FROM sandbox_metric_points
WHERE sandbox_id = $1
`
func (q *Queries) DeleteSandboxMetricPoints(ctx context.Context, sandboxID string) error {
func (q *Queries) DeleteSandboxMetricPoints(ctx context.Context, sandboxID pgtype.UUID) error {
_, err := q.db.Exec(ctx, deleteSandboxMetricPoints, sandboxID)
return err
}
@ -25,8 +27,8 @@ WHERE sandbox_id = $1 AND tier = $2
`
type DeleteSandboxMetricPointsByTierParams struct {
SandboxID string `json:"sandbox_id"`
Tier string `json:"tier"`
SandboxID pgtype.UUID `json:"sandbox_id"`
Tier string `json:"tier"`
}
func (q *Queries) DeleteSandboxMetricPointsByTier(ctx context.Context, arg DeleteSandboxMetricPointsByTierParams) error {
@ -53,7 +55,7 @@ type GetLiveMetricsRow struct {
// Reads directly from sandboxes for accurate real-time current values.
// CPU reserved = running + starting only (paused VMs release CPU).
// RAM reserved = running + starting + sum(ceil(each_paused/2)) (per-VM ceiling).
func (q *Queries) GetLiveMetrics(ctx context.Context, teamID string) (GetLiveMetricsRow, error) {
func (q *Queries) GetLiveMetrics(ctx context.Context, teamID pgtype.UUID) (GetLiveMetricsRow, error) {
row := q.db.QueryRow(ctx, getLiveMetrics, teamID)
var i GetLiveMetricsRow
err := row.Scan(&i.RunningCount, &i.VcpusReserved, &i.MemoryMbReserved)
@ -76,7 +78,7 @@ type GetPeakMetricsRow struct {
PeakMemoryMb int32 `json:"peak_memory_mb"`
}
func (q *Queries) GetPeakMetrics(ctx context.Context, teamID string) (GetPeakMetricsRow, error) {
func (q *Queries) GetPeakMetrics(ctx context.Context, teamID pgtype.UUID) (GetPeakMetricsRow, error) {
row := q.db.QueryRow(ctx, getPeakMetrics, teamID)
var i GetPeakMetricsRow
err := row.Scan(&i.PeakRunningCount, &i.PeakVcpus, &i.PeakMemoryMb)
@ -91,9 +93,9 @@ ORDER BY ts ASC
`
type GetSandboxMetricPointsParams struct {
SandboxID string `json:"sandbox_id"`
Tier string `json:"tier"`
Ts int64 `json:"ts"`
SandboxID pgtype.UUID `json:"sandbox_id"`
Tier string `json:"tier"`
Ts int64 `json:"ts"`
}
type GetSandboxMetricPointsRow struct {
@ -134,10 +136,10 @@ VALUES ($1, $2, $3, $4)
`
type InsertMetricsSnapshotParams struct {
TeamID string `json:"team_id"`
RunningCount int32 `json:"running_count"`
VcpusReserved int32 `json:"vcpus_reserved"`
MemoryMbReserved int32 `json:"memory_mb_reserved"`
TeamID pgtype.UUID `json:"team_id"`
RunningCount int32 `json:"running_count"`
VcpusReserved int32 `json:"vcpus_reserved"`
MemoryMbReserved int32 `json:"memory_mb_reserved"`
}
func (q *Queries) InsertMetricsSnapshot(ctx context.Context, arg InsertMetricsSnapshotParams) error {
@ -157,12 +159,12 @@ ON CONFLICT (sandbox_id, tier, ts) DO NOTHING
`
type InsertSandboxMetricPointParams struct {
SandboxID string `json:"sandbox_id"`
Tier string `json:"tier"`
Ts int64 `json:"ts"`
CpuPct float64 `json:"cpu_pct"`
MemBytes int64 `json:"mem_bytes"`
DiskBytes int64 `json:"disk_bytes"`
SandboxID pgtype.UUID `json:"sandbox_id"`
Tier string `json:"tier"`
Ts int64 `json:"ts"`
CpuPct float64 `json:"cpu_pct"`
MemBytes int64 `json:"mem_bytes"`
DiskBytes int64 `json:"disk_bytes"`
}
func (q *Queries) InsertSandboxMetricPoint(ctx context.Context, arg InsertSandboxMetricPointParams) error {
@ -210,10 +212,10 @@ GROUP BY team_id
`
type SampleSandboxMetricsRow struct {
TeamID string `json:"team_id"`
RunningCount int32 `json:"running_count"`
VcpusReserved int32 `json:"vcpus_reserved"`
MemoryMbReserved int32 `json:"memory_mb_reserved"`
TeamID pgtype.UUID `json:"team_id"`
RunningCount int32 `json:"running_count"`
VcpusReserved int32 `json:"vcpus_reserved"`
MemoryMbReserved int32 `json:"memory_mb_reserved"`
}
// Aggregates per-team resource usage from the live sandboxes table.

View File

@ -9,18 +9,18 @@ import (
)
type AdminPermission struct {
ID string `json:"id"`
UserID string `json:"user_id"`
ID pgtype.UUID `json:"id"`
UserID pgtype.UUID `json:"user_id"`
Permission string `json:"permission"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
}
type AuditLog struct {
ID string `json:"id"`
TeamID string `json:"team_id"`
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
ActorType string `json:"actor_type"`
ActorID pgtype.Text `json:"actor_id"`
ActorName pgtype.Text `json:"actor_name"`
ActorName string `json:"actor_name"`
ResourceType string `json:"resource_type"`
ResourceID pgtype.Text `json:"resource_id"`
Action string `json:"action"`
@ -31,29 +31,29 @@ type AuditLog struct {
}
type Host struct {
ID string `json:"id"`
ID pgtype.UUID `json:"id"`
Type string `json:"type"`
TeamID pgtype.Text `json:"team_id"`
Provider pgtype.Text `json:"provider"`
AvailabilityZone pgtype.Text `json:"availability_zone"`
Arch pgtype.Text `json:"arch"`
CpuCores pgtype.Int4 `json:"cpu_cores"`
MemoryMb pgtype.Int4 `json:"memory_mb"`
DiskGb pgtype.Int4 `json:"disk_gb"`
Address pgtype.Text `json:"address"`
TeamID pgtype.UUID `json:"team_id"`
Provider string `json:"provider"`
AvailabilityZone string `json:"availability_zone"`
Arch string `json:"arch"`
CpuCores int32 `json:"cpu_cores"`
MemoryMb int32 `json:"memory_mb"`
DiskGb int32 `json:"disk_gb"`
Address string `json:"address"`
Status string `json:"status"`
LastHeartbeatAt pgtype.Timestamptz `json:"last_heartbeat_at"`
Metadata []byte `json:"metadata"`
CreatedBy string `json:"created_by"`
CreatedBy pgtype.UUID `json:"created_by"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
CertFingerprint pgtype.Text `json:"cert_fingerprint"`
MtlsEnabled bool `json:"mtls_enabled"`
CertFingerprint string `json:"cert_fingerprint"`
CertExpiresAt pgtype.Timestamptz `json:"cert_expires_at"`
}
type HostRefreshToken struct {
ID string `json:"id"`
HostID string `json:"host_id"`
ID pgtype.UUID `json:"id"`
HostID pgtype.UUID `json:"host_id"`
TokenHash string `json:"token_hash"`
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
@ -61,14 +61,14 @@ type HostRefreshToken struct {
}
type HostTag struct {
HostID string `json:"host_id"`
Tag string `json:"tag"`
HostID pgtype.UUID `json:"host_id"`
Tag string `json:"tag"`
}
type HostToken struct {
ID string `json:"id"`
HostID string `json:"host_id"`
CreatedBy string `json:"created_by"`
ID pgtype.UUID `json:"id"`
HostID pgtype.UUID `json:"host_id"`
CreatedBy pgtype.UUID `json:"created_by"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
UsedAt pgtype.Timestamptz `json:"used_at"`
@ -77,40 +77,43 @@ type HostToken struct {
type OauthProvider struct {
Provider string `json:"provider"`
ProviderID string `json:"provider_id"`
UserID string `json:"user_id"`
UserID pgtype.UUID `json:"user_id"`
Email string `json:"email"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
}
type Sandbox struct {
ID string `json:"id"`
HostID string `json:"host_id"`
Template string `json:"template"`
Status string `json:"status"`
Vcpus int32 `json:"vcpus"`
MemoryMb int32 `json:"memory_mb"`
TimeoutSec int32 `json:"timeout_sec"`
GuestIp string `json:"guest_ip"`
HostIp string `json:"host_ip"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
StartedAt pgtype.Timestamptz `json:"started_at"`
LastActiveAt pgtype.Timestamptz `json:"last_active_at"`
LastUpdated pgtype.Timestamptz `json:"last_updated"`
TeamID string `json:"team_id"`
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
HostID pgtype.UUID `json:"host_id"`
Template string `json:"template"`
Status string `json:"status"`
Vcpus int32 `json:"vcpus"`
MemoryMb int32 `json:"memory_mb"`
TimeoutSec int32 `json:"timeout_sec"`
DiskSizeMb int32 `json:"disk_size_mb"`
GuestIp string `json:"guest_ip"`
HostIp string `json:"host_ip"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
StartedAt pgtype.Timestamptz `json:"started_at"`
LastActiveAt pgtype.Timestamptz `json:"last_active_at"`
LastUpdated pgtype.Timestamptz `json:"last_updated"`
TemplateID pgtype.UUID `json:"template_id"`
TemplateTeamID pgtype.UUID `json:"template_team_id"`
}
type SandboxMetricPoint struct {
SandboxID string `json:"sandbox_id"`
Tier string `json:"tier"`
Ts int64 `json:"ts"`
CpuPct float64 `json:"cpu_pct"`
MemBytes int64 `json:"mem_bytes"`
DiskBytes int64 `json:"disk_bytes"`
SandboxID pgtype.UUID `json:"sandbox_id"`
Tier string `json:"tier"`
Ts int64 `json:"ts"`
CpuPct float64 `json:"cpu_pct"`
MemBytes int64 `json:"mem_bytes"`
DiskBytes int64 `json:"disk_bytes"`
}
type SandboxMetricsSnapshot struct {
ID int64 `json:"id"`
TeamID string `json:"team_id"`
TeamID pgtype.UUID `json:"team_id"`
SampledAt pgtype.Timestamptz `json:"sampled_at"`
RunningCount int32 `json:"running_count"`
VcpusReserved int32 `json:"vcpus_reserved"`
@ -118,21 +121,21 @@ type SandboxMetricsSnapshot struct {
}
type Team struct {
ID string `json:"id"`
ID pgtype.UUID `json:"id"`
Name string `json:"name"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
IsByoc bool `json:"is_byoc"`
Slug string `json:"slug"`
IsByoc bool `json:"is_byoc"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
DeletedAt pgtype.Timestamptz `json:"deleted_at"`
}
type TeamApiKey struct {
ID string `json:"id"`
TeamID string `json:"team_id"`
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
Name string `json:"name"`
KeyHash string `json:"key_hash"`
KeyPrefix string `json:"key_prefix"`
CreatedBy string `json:"created_by"`
CreatedBy pgtype.UUID `json:"created_by"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
LastUsed pgtype.Timestamptz `json:"last_used"`
}
@ -140,26 +143,50 @@ type TeamApiKey struct {
type Template struct {
Name string `json:"name"`
Type string `json:"type"`
Vcpus pgtype.Int4 `json:"vcpus"`
MemoryMb pgtype.Int4 `json:"memory_mb"`
Vcpus int32 `json:"vcpus"`
MemoryMb int32 `json:"memory_mb"`
SizeBytes int64 `json:"size_bytes"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
TeamID string `json:"team_id"`
TeamID pgtype.UUID `json:"team_id"`
ID pgtype.UUID `json:"id"`
}
type TemplateBuild struct {
ID pgtype.UUID `json:"id"`
Name string `json:"name"`
BaseTemplate string `json:"base_template"`
Recipe []byte `json:"recipe"`
Healthcheck string `json:"healthcheck"`
Vcpus int32 `json:"vcpus"`
MemoryMb int32 `json:"memory_mb"`
Status string `json:"status"`
CurrentStep int32 `json:"current_step"`
TotalSteps int32 `json:"total_steps"`
Logs []byte `json:"logs"`
Error string `json:"error"`
SandboxID pgtype.UUID `json:"sandbox_id"`
HostID pgtype.UUID `json:"host_id"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
StartedAt pgtype.Timestamptz `json:"started_at"`
CompletedAt pgtype.Timestamptz `json:"completed_at"`
TemplateID pgtype.UUID `json:"template_id"`
TeamID pgtype.UUID `json:"team_id"`
SkipPrePost bool `json:"skip_pre_post"`
}
type User struct {
ID string `json:"id"`
ID pgtype.UUID `json:"id"`
Email string `json:"email"`
PasswordHash pgtype.Text `json:"password_hash"`
Name string `json:"name"`
IsAdmin bool `json:"is_admin"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
IsAdmin bool `json:"is_admin"`
Name string `json:"name"`
}
type UsersTeam struct {
UserID string `json:"user_id"`
TeamID string `json:"team_id"`
UserID pgtype.UUID `json:"user_id"`
TeamID pgtype.UUID `json:"team_id"`
IsDefault bool `json:"is_default"`
Role string `json:"role"`
CreatedAt pgtype.Timestamptz `json:"created_at"`

View File

@ -7,6 +7,8 @@ package db
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
const getOAuthProvider = `-- name: GetOAuthProvider :one
@ -38,10 +40,10 @@ VALUES ($1, $2, $3, $4)
`
type InsertOAuthProviderParams struct {
Provider string `json:"provider"`
ProviderID string `json:"provider_id"`
UserID string `json:"user_id"`
Email string `json:"email"`
Provider string `json:"provider"`
ProviderID string `json:"provider_id"`
UserID pgtype.UUID `json:"user_id"`
Email string `json:"email"`
}
func (q *Queries) InsertOAuthProvider(ctx context.Context, arg InsertOAuthProviderParams) error {

View File

@ -15,12 +15,12 @@ const bulkRestoreRunning = `-- name: BulkRestoreRunning :exec
UPDATE sandboxes
SET status = 'running',
last_updated = NOW()
WHERE id = ANY($1::text[]) AND status = 'missing'
WHERE id = ANY($1::uuid[]) AND status = 'missing'
`
// Called by the reconciler when a host comes back online and its sandboxes are
// confirmed alive. Restores only sandboxes that are in 'missing' state.
func (q *Queries) BulkRestoreRunning(ctx context.Context, dollar_1 []string) error {
func (q *Queries) BulkRestoreRunning(ctx context.Context, dollar_1 []pgtype.UUID) error {
_, err := q.db.Exec(ctx, bulkRestoreRunning, dollar_1)
return err
}
@ -29,12 +29,12 @@ const bulkUpdateStatusByIDs = `-- name: BulkUpdateStatusByIDs :exec
UPDATE sandboxes
SET status = $2,
last_updated = NOW()
WHERE id = ANY($1::text[])
WHERE id = ANY($1::uuid[])
`
type BulkUpdateStatusByIDsParams struct {
Column1 []string `json:"column_1"`
Status string `json:"status"`
Column1 []pgtype.UUID `json:"column_1"`
Status string `json:"status"`
}
func (q *Queries) BulkUpdateStatusByIDs(ctx context.Context, arg BulkUpdateStatusByIDsParams) error {
@ -43,38 +43,41 @@ func (q *Queries) BulkUpdateStatusByIDs(ctx context.Context, arg BulkUpdateStatu
}
const getSandbox = `-- name: GetSandbox :one
SELECT id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id FROM sandboxes WHERE id = $1
SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes WHERE id = $1
`
func (q *Queries) GetSandbox(ctx context.Context, id string) (Sandbox, error) {
func (q *Queries) GetSandbox(ctx context.Context, id pgtype.UUID) (Sandbox, error) {
row := q.db.QueryRow(ctx, getSandbox, id)
var i Sandbox
err := row.Scan(
&i.ID,
&i.TeamID,
&i.HostID,
&i.Template,
&i.Status,
&i.Vcpus,
&i.MemoryMb,
&i.TimeoutSec,
&i.DiskSizeMb,
&i.GuestIp,
&i.HostIp,
&i.CreatedAt,
&i.StartedAt,
&i.LastActiveAt,
&i.LastUpdated,
&i.TeamID,
&i.TemplateID,
&i.TemplateTeamID,
)
return i, err
}
const getSandboxByTeam = `-- name: GetSandboxByTeam :one
SELECT id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id FROM sandboxes WHERE id = $1 AND team_id = $2
SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes WHERE id = $1 AND team_id = $2
`
type GetSandboxByTeamParams struct {
ID string `json:"id"`
TeamID string `json:"team_id"`
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
}
func (q *Queries) GetSandboxByTeam(ctx context.Context, arg GetSandboxByTeamParams) (Sandbox, error) {
@ -82,38 +85,70 @@ func (q *Queries) GetSandboxByTeam(ctx context.Context, arg GetSandboxByTeamPara
var i Sandbox
err := row.Scan(
&i.ID,
&i.TeamID,
&i.HostID,
&i.Template,
&i.Status,
&i.Vcpus,
&i.MemoryMb,
&i.TimeoutSec,
&i.DiskSizeMb,
&i.GuestIp,
&i.HostIp,
&i.CreatedAt,
&i.StartedAt,
&i.LastActiveAt,
&i.LastUpdated,
&i.TeamID,
&i.TemplateID,
&i.TemplateTeamID,
)
return i, err
}
const getSandboxProxyTarget = `-- name: GetSandboxProxyTarget :one
SELECT s.status, h.address AS host_address
FROM sandboxes s
JOIN hosts h ON h.id = s.host_id
WHERE s.id = $1 AND s.team_id = $2
`
type GetSandboxProxyTargetParams struct {
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
}
type GetSandboxProxyTargetRow struct {
Status string `json:"status"`
HostAddress string `json:"host_address"`
}
// Returns the sandbox status and its host's address in one query.
// Used by SandboxProxyWrapper to avoid two round-trips.
func (q *Queries) GetSandboxProxyTarget(ctx context.Context, arg GetSandboxProxyTargetParams) (GetSandboxProxyTargetRow, error) {
row := q.db.QueryRow(ctx, getSandboxProxyTarget, arg.ID, arg.TeamID)
var i GetSandboxProxyTargetRow
err := row.Scan(&i.Status, &i.HostAddress)
return i, err
}
const insertSandbox = `-- name: InsertSandbox :one
INSERT INTO sandboxes (id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
RETURNING id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id
INSERT INTO sandboxes (id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, template_id, template_team_id)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
RETURNING id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id
`
type InsertSandboxParams struct {
ID string `json:"id"`
TeamID string `json:"team_id"`
HostID string `json:"host_id"`
Template string `json:"template"`
Status string `json:"status"`
Vcpus int32 `json:"vcpus"`
MemoryMb int32 `json:"memory_mb"`
TimeoutSec int32 `json:"timeout_sec"`
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
HostID pgtype.UUID `json:"host_id"`
Template string `json:"template"`
Status string `json:"status"`
Vcpus int32 `json:"vcpus"`
MemoryMb int32 `json:"memory_mb"`
TimeoutSec int32 `json:"timeout_sec"`
DiskSizeMb int32 `json:"disk_size_mb"`
TemplateID pgtype.UUID `json:"template_id"`
TemplateTeamID pgtype.UUID `json:"template_team_id"`
}
func (q *Queries) InsertSandbox(ctx context.Context, arg InsertSandboxParams) (Sandbox, error) {
@ -126,34 +161,40 @@ func (q *Queries) InsertSandbox(ctx context.Context, arg InsertSandboxParams) (S
arg.Vcpus,
arg.MemoryMb,
arg.TimeoutSec,
arg.DiskSizeMb,
arg.TemplateID,
arg.TemplateTeamID,
)
var i Sandbox
err := row.Scan(
&i.ID,
&i.TeamID,
&i.HostID,
&i.Template,
&i.Status,
&i.Vcpus,
&i.MemoryMb,
&i.TimeoutSec,
&i.DiskSizeMb,
&i.GuestIp,
&i.HostIp,
&i.CreatedAt,
&i.StartedAt,
&i.LastActiveAt,
&i.LastUpdated,
&i.TeamID,
&i.TemplateID,
&i.TemplateTeamID,
)
return i, err
}
const listActiveSandboxesByTeam = `-- name: ListActiveSandboxesByTeam :many
SELECT id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id FROM sandboxes
SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes
WHERE team_id = $1 AND status IN ('running', 'paused', 'starting')
ORDER BY created_at DESC
`
func (q *Queries) ListActiveSandboxesByTeam(ctx context.Context, teamID string) ([]Sandbox, error) {
func (q *Queries) ListActiveSandboxesByTeam(ctx context.Context, teamID pgtype.UUID) ([]Sandbox, error) {
rows, err := q.db.Query(ctx, listActiveSandboxesByTeam, teamID)
if err != nil {
return nil, err
@ -164,19 +205,22 @@ func (q *Queries) ListActiveSandboxesByTeam(ctx context.Context, teamID string)
var i Sandbox
if err := rows.Scan(
&i.ID,
&i.TeamID,
&i.HostID,
&i.Template,
&i.Status,
&i.Vcpus,
&i.MemoryMb,
&i.TimeoutSec,
&i.DiskSizeMb,
&i.GuestIp,
&i.HostIp,
&i.CreatedAt,
&i.StartedAt,
&i.LastActiveAt,
&i.LastUpdated,
&i.TeamID,
&i.TemplateID,
&i.TemplateTeamID,
); err != nil {
return nil, err
}
@ -189,7 +233,7 @@ func (q *Queries) ListActiveSandboxesByTeam(ctx context.Context, teamID string)
}
const listSandboxes = `-- name: ListSandboxes :many
SELECT id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id FROM sandboxes ORDER BY created_at DESC
SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes ORDER BY created_at DESC
`
func (q *Queries) ListSandboxes(ctx context.Context) ([]Sandbox, error) {
@ -203,19 +247,22 @@ func (q *Queries) ListSandboxes(ctx context.Context) ([]Sandbox, error) {
var i Sandbox
if err := rows.Scan(
&i.ID,
&i.TeamID,
&i.HostID,
&i.Template,
&i.Status,
&i.Vcpus,
&i.MemoryMb,
&i.TimeoutSec,
&i.DiskSizeMb,
&i.GuestIp,
&i.HostIp,
&i.CreatedAt,
&i.StartedAt,
&i.LastActiveAt,
&i.LastUpdated,
&i.TeamID,
&i.TemplateID,
&i.TemplateTeamID,
); err != nil {
return nil, err
}
@ -228,14 +275,14 @@ func (q *Queries) ListSandboxes(ctx context.Context) ([]Sandbox, error) {
}
const listSandboxesByHostAndStatus = `-- name: ListSandboxesByHostAndStatus :many
SELECT id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id FROM sandboxes
SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes
WHERE host_id = $1 AND status = ANY($2::text[])
ORDER BY created_at DESC
`
type ListSandboxesByHostAndStatusParams struct {
HostID string `json:"host_id"`
Column2 []string `json:"column_2"`
HostID pgtype.UUID `json:"host_id"`
Column2 []string `json:"column_2"`
}
func (q *Queries) ListSandboxesByHostAndStatus(ctx context.Context, arg ListSandboxesByHostAndStatusParams) ([]Sandbox, error) {
@ -249,19 +296,22 @@ func (q *Queries) ListSandboxesByHostAndStatus(ctx context.Context, arg ListSand
var i Sandbox
if err := rows.Scan(
&i.ID,
&i.TeamID,
&i.HostID,
&i.Template,
&i.Status,
&i.Vcpus,
&i.MemoryMb,
&i.TimeoutSec,
&i.DiskSizeMb,
&i.GuestIp,
&i.HostIp,
&i.CreatedAt,
&i.StartedAt,
&i.LastActiveAt,
&i.LastUpdated,
&i.TeamID,
&i.TemplateID,
&i.TemplateTeamID,
); err != nil {
return nil, err
}
@ -274,12 +324,12 @@ func (q *Queries) ListSandboxesByHostAndStatus(ctx context.Context, arg ListSand
}
const listSandboxesByTeam = `-- name: ListSandboxesByTeam :many
SELECT id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id FROM sandboxes
SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes
WHERE team_id = $1 AND status NOT IN ('stopped', 'error')
ORDER BY created_at DESC
`
func (q *Queries) ListSandboxesByTeam(ctx context.Context, teamID string) ([]Sandbox, error) {
func (q *Queries) ListSandboxesByTeam(ctx context.Context, teamID pgtype.UUID) ([]Sandbox, error) {
rows, err := q.db.Query(ctx, listSandboxesByTeam, teamID)
if err != nil {
return nil, err
@ -290,19 +340,22 @@ func (q *Queries) ListSandboxesByTeam(ctx context.Context, teamID string) ([]San
var i Sandbox
if err := rows.Scan(
&i.ID,
&i.TeamID,
&i.HostID,
&i.Template,
&i.Status,
&i.Vcpus,
&i.MemoryMb,
&i.TimeoutSec,
&i.DiskSizeMb,
&i.GuestIp,
&i.HostIp,
&i.CreatedAt,
&i.StartedAt,
&i.LastActiveAt,
&i.LastUpdated,
&i.TeamID,
&i.TemplateID,
&i.TemplateTeamID,
); err != nil {
return nil, err
}
@ -324,7 +377,7 @@ WHERE host_id = $1 AND status IN ('running', 'starting', 'pending')
// Called when the host monitor marks a host unreachable.
// Marks running/starting/pending sandboxes on that host as 'missing' so users see
// the sandbox is not currently reachable, without permanently losing the record.
func (q *Queries) MarkSandboxesMissingByHost(ctx context.Context, hostID string) error {
func (q *Queries) MarkSandboxesMissingByHost(ctx context.Context, hostID pgtype.UUID) error {
_, err := q.db.Exec(ctx, markSandboxesMissingByHost, hostID)
return err
}
@ -337,7 +390,7 @@ WHERE id = $1
`
type UpdateLastActiveParams struct {
ID string `json:"id"`
ID pgtype.UUID `json:"id"`
LastActiveAt pgtype.Timestamptz `json:"last_active_at"`
}
@ -355,11 +408,11 @@ SET status = 'running',
last_active_at = $4,
last_updated = NOW()
WHERE id = $1
RETURNING id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id
RETURNING id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id
`
type UpdateSandboxRunningParams struct {
ID string `json:"id"`
ID pgtype.UUID `json:"id"`
HostIp string `json:"host_ip"`
GuestIp string `json:"guest_ip"`
StartedAt pgtype.Timestamptz `json:"started_at"`
@ -375,19 +428,22 @@ func (q *Queries) UpdateSandboxRunning(ctx context.Context, arg UpdateSandboxRun
var i Sandbox
err := row.Scan(
&i.ID,
&i.TeamID,
&i.HostID,
&i.Template,
&i.Status,
&i.Vcpus,
&i.MemoryMb,
&i.TimeoutSec,
&i.DiskSizeMb,
&i.GuestIp,
&i.HostIp,
&i.CreatedAt,
&i.StartedAt,
&i.LastActiveAt,
&i.LastUpdated,
&i.TeamID,
&i.TemplateID,
&i.TemplateTeamID,
)
return i, err
}
@ -397,12 +453,12 @@ UPDATE sandboxes
SET status = $2,
last_updated = NOW()
WHERE id = $1
RETURNING id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id
RETURNING id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id
`
type UpdateSandboxStatusParams struct {
ID string `json:"id"`
Status string `json:"status"`
ID pgtype.UUID `json:"id"`
Status string `json:"status"`
}
func (q *Queries) UpdateSandboxStatus(ctx context.Context, arg UpdateSandboxStatusParams) (Sandbox, error) {
@ -410,19 +466,22 @@ func (q *Queries) UpdateSandboxStatus(ctx context.Context, arg UpdateSandboxStat
var i Sandbox
err := row.Scan(
&i.ID,
&i.TeamID,
&i.HostID,
&i.Template,
&i.Status,
&i.Vcpus,
&i.MemoryMb,
&i.TimeoutSec,
&i.DiskSizeMb,
&i.GuestIp,
&i.HostIp,
&i.CreatedAt,
&i.StartedAt,
&i.LastActiveAt,
&i.LastUpdated,
&i.TeamID,
&i.TemplateID,
&i.TemplateTeamID,
)
return i, err
}

View File

@ -16,8 +16,8 @@ DELETE FROM users_teams WHERE team_id = $1 AND user_id = $2
`
type DeleteTeamMemberParams struct {
TeamID string `json:"team_id"`
UserID string `json:"user_id"`
TeamID pgtype.UUID `json:"team_id"`
UserID pgtype.UUID `json:"user_id"`
}
func (q *Queries) DeleteTeamMember(ctx context.Context, arg DeleteTeamMemberParams) error {
@ -26,7 +26,7 @@ func (q *Queries) DeleteTeamMember(ctx context.Context, arg DeleteTeamMemberPara
}
const getBYOCTeams = `-- name: GetBYOCTeams :many
SELECT id, name, created_at, is_byoc, slug, deleted_at FROM teams WHERE is_byoc = TRUE AND deleted_at IS NULL ORDER BY created_at
SELECT id, name, slug, is_byoc, created_at, deleted_at FROM teams WHERE is_byoc = TRUE AND deleted_at IS NULL ORDER BY created_at
`
func (q *Queries) GetBYOCTeams(ctx context.Context) ([]Team, error) {
@ -41,9 +41,9 @@ func (q *Queries) GetBYOCTeams(ctx context.Context) ([]Team, error) {
if err := rows.Scan(
&i.ID,
&i.Name,
&i.CreatedAt,
&i.IsByoc,
&i.Slug,
&i.IsByoc,
&i.CreatedAt,
&i.DeletedAt,
); err != nil {
return nil, err
@ -57,46 +57,46 @@ func (q *Queries) GetBYOCTeams(ctx context.Context) ([]Team, error) {
}
const getDefaultTeamForUser = `-- name: GetDefaultTeamForUser :one
SELECT t.id, t.name, t.created_at, t.is_byoc, t.slug, t.deleted_at FROM teams t
SELECT t.id, t.name, t.slug, t.is_byoc, t.created_at, t.deleted_at FROM teams t
JOIN users_teams ut ON ut.team_id = t.id
WHERE ut.user_id = $1 AND ut.is_default = TRUE AND t.deleted_at IS NULL
LIMIT 1
`
func (q *Queries) GetDefaultTeamForUser(ctx context.Context, userID string) (Team, error) {
func (q *Queries) GetDefaultTeamForUser(ctx context.Context, userID pgtype.UUID) (Team, error) {
row := q.db.QueryRow(ctx, getDefaultTeamForUser, userID)
var i Team
err := row.Scan(
&i.ID,
&i.Name,
&i.CreatedAt,
&i.IsByoc,
&i.Slug,
&i.IsByoc,
&i.CreatedAt,
&i.DeletedAt,
)
return i, err
}
const getTeam = `-- name: GetTeam :one
SELECT id, name, created_at, is_byoc, slug, deleted_at FROM teams WHERE id = $1
SELECT id, name, slug, is_byoc, created_at, deleted_at FROM teams WHERE id = $1
`
func (q *Queries) GetTeam(ctx context.Context, id string) (Team, error) {
func (q *Queries) GetTeam(ctx context.Context, id pgtype.UUID) (Team, error) {
row := q.db.QueryRow(ctx, getTeam, id)
var i Team
err := row.Scan(
&i.ID,
&i.Name,
&i.CreatedAt,
&i.IsByoc,
&i.Slug,
&i.IsByoc,
&i.CreatedAt,
&i.DeletedAt,
)
return i, err
}
const getTeamBySlug = `-- name: GetTeamBySlug :one
SELECT id, name, created_at, is_byoc, slug, deleted_at FROM teams WHERE slug = $1 AND deleted_at IS NULL
SELECT id, name, slug, is_byoc, created_at, deleted_at FROM teams WHERE slug = $1 AND deleted_at IS NULL
`
func (q *Queries) GetTeamBySlug(ctx context.Context, slug string) (Team, error) {
@ -105,9 +105,9 @@ func (q *Queries) GetTeamBySlug(ctx context.Context, slug string) (Team, error)
err := row.Scan(
&i.ID,
&i.Name,
&i.CreatedAt,
&i.IsByoc,
&i.Slug,
&i.IsByoc,
&i.CreatedAt,
&i.DeletedAt,
)
return i, err
@ -122,14 +122,14 @@ ORDER BY ut.created_at
`
type GetTeamMembersRow struct {
ID string `json:"id"`
ID pgtype.UUID `json:"id"`
Name string `json:"name"`
Email string `json:"email"`
Role string `json:"role"`
JoinedAt pgtype.Timestamptz `json:"joined_at"`
}
func (q *Queries) GetTeamMembers(ctx context.Context, teamID string) ([]GetTeamMembersRow, error) {
func (q *Queries) GetTeamMembers(ctx context.Context, teamID pgtype.UUID) ([]GetTeamMembersRow, error) {
rows, err := q.db.Query(ctx, getTeamMembers, teamID)
if err != nil {
return nil, err
@ -160,8 +160,8 @@ SELECT user_id, team_id, is_default, role, created_at FROM users_teams WHERE use
`
type GetTeamMembershipParams struct {
UserID string `json:"user_id"`
TeamID string `json:"team_id"`
UserID pgtype.UUID `json:"user_id"`
TeamID pgtype.UUID `json:"team_id"`
}
func (q *Queries) GetTeamMembership(ctx context.Context, arg GetTeamMembershipParams) (UsersTeam, error) {
@ -186,7 +186,7 @@ ORDER BY ut.created_at
`
type GetTeamsForUserRow struct {
ID string `json:"id"`
ID pgtype.UUID `json:"id"`
Name string `json:"name"`
Slug string `json:"slug"`
IsByoc bool `json:"is_byoc"`
@ -195,7 +195,7 @@ type GetTeamsForUserRow struct {
Role string `json:"role"`
}
func (q *Queries) GetTeamsForUser(ctx context.Context, userID string) ([]GetTeamsForUserRow, error) {
func (q *Queries) GetTeamsForUser(ctx context.Context, userID pgtype.UUID) ([]GetTeamsForUserRow, error) {
rows, err := q.db.Query(ctx, getTeamsForUser, userID)
if err != nil {
return nil, err
@ -226,13 +226,13 @@ func (q *Queries) GetTeamsForUser(ctx context.Context, userID string) ([]GetTeam
const insertTeam = `-- name: InsertTeam :one
INSERT INTO teams (id, name, slug)
VALUES ($1, $2, $3)
RETURNING id, name, created_at, is_byoc, slug, deleted_at
RETURNING id, name, slug, is_byoc, created_at, deleted_at
`
type InsertTeamParams struct {
ID string `json:"id"`
Name string `json:"name"`
Slug string `json:"slug"`
ID pgtype.UUID `json:"id"`
Name string `json:"name"`
Slug string `json:"slug"`
}
func (q *Queries) InsertTeam(ctx context.Context, arg InsertTeamParams) (Team, error) {
@ -241,9 +241,9 @@ func (q *Queries) InsertTeam(ctx context.Context, arg InsertTeamParams) (Team, e
err := row.Scan(
&i.ID,
&i.Name,
&i.CreatedAt,
&i.IsByoc,
&i.Slug,
&i.IsByoc,
&i.CreatedAt,
&i.DeletedAt,
)
return i, err
@ -255,10 +255,10 @@ VALUES ($1, $2, $3, $4)
`
type InsertTeamMemberParams struct {
UserID string `json:"user_id"`
TeamID string `json:"team_id"`
IsDefault bool `json:"is_default"`
Role string `json:"role"`
UserID pgtype.UUID `json:"user_id"`
TeamID pgtype.UUID `json:"team_id"`
IsDefault bool `json:"is_default"`
Role string `json:"role"`
}
func (q *Queries) InsertTeamMember(ctx context.Context, arg InsertTeamMemberParams) error {
@ -276,8 +276,8 @@ UPDATE teams SET is_byoc = $2 WHERE id = $1
`
type SetTeamBYOCParams struct {
ID string `json:"id"`
IsByoc bool `json:"is_byoc"`
ID pgtype.UUID `json:"id"`
IsByoc bool `json:"is_byoc"`
}
func (q *Queries) SetTeamBYOC(ctx context.Context, arg SetTeamBYOCParams) error {
@ -289,7 +289,7 @@ const softDeleteTeam = `-- name: SoftDeleteTeam :exec
UPDATE teams SET deleted_at = NOW() WHERE id = $1
`
func (q *Queries) SoftDeleteTeam(ctx context.Context, id string) error {
func (q *Queries) SoftDeleteTeam(ctx context.Context, id pgtype.UUID) error {
_, err := q.db.Exec(ctx, softDeleteTeam, id)
return err
}
@ -299,9 +299,9 @@ UPDATE users_teams SET role = $3 WHERE team_id = $1 AND user_id = $2
`
type UpdateMemberRoleParams struct {
TeamID string `json:"team_id"`
UserID string `json:"user_id"`
Role string `json:"role"`
TeamID pgtype.UUID `json:"team_id"`
UserID pgtype.UUID `json:"user_id"`
Role string `json:"role"`
}
func (q *Queries) UpdateMemberRole(ctx context.Context, arg UpdateMemberRoleParams) error {
@ -314,8 +314,8 @@ UPDATE teams SET name = $2 WHERE id = $1 AND deleted_at IS NULL
`
type UpdateTeamNameParams struct {
ID string `json:"id"`
Name string `json:"name"`
ID pgtype.UUID `json:"id"`
Name string `json:"name"`
}
func (q *Queries) UpdateTeamName(ctx context.Context, arg UpdateTeamNameParams) error {

View File

@ -0,0 +1,241 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: template_builds.sql
package db
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
const getTemplateBuild = `-- name: GetTemplateBuild :one
SELECT id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at, template_id, team_id, skip_pre_post FROM template_builds WHERE id = $1
`
func (q *Queries) GetTemplateBuild(ctx context.Context, id pgtype.UUID) (TemplateBuild, error) {
row := q.db.QueryRow(ctx, getTemplateBuild, id)
var i TemplateBuild
err := row.Scan(
&i.ID,
&i.Name,
&i.BaseTemplate,
&i.Recipe,
&i.Healthcheck,
&i.Vcpus,
&i.MemoryMb,
&i.Status,
&i.CurrentStep,
&i.TotalSteps,
&i.Logs,
&i.Error,
&i.SandboxID,
&i.HostID,
&i.CreatedAt,
&i.StartedAt,
&i.CompletedAt,
&i.TemplateID,
&i.TeamID,
&i.SkipPrePost,
)
return i, err
}
const insertTemplateBuild = `-- name: InsertTemplateBuild :one
INSERT INTO template_builds (id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, total_steps, template_id, team_id, skip_pre_post)
VALUES ($1, $2, $3, $4, $5, $6, $7, 'pending', $8, $9, $10, $11)
RETURNING id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at, template_id, team_id, skip_pre_post
`
type InsertTemplateBuildParams struct {
ID pgtype.UUID `json:"id"`
Name string `json:"name"`
BaseTemplate string `json:"base_template"`
Recipe []byte `json:"recipe"`
Healthcheck string `json:"healthcheck"`
Vcpus int32 `json:"vcpus"`
MemoryMb int32 `json:"memory_mb"`
TotalSteps int32 `json:"total_steps"`
TemplateID pgtype.UUID `json:"template_id"`
TeamID pgtype.UUID `json:"team_id"`
SkipPrePost bool `json:"skip_pre_post"`
}
func (q *Queries) InsertTemplateBuild(ctx context.Context, arg InsertTemplateBuildParams) (TemplateBuild, error) {
row := q.db.QueryRow(ctx, insertTemplateBuild,
arg.ID,
arg.Name,
arg.BaseTemplate,
arg.Recipe,
arg.Healthcheck,
arg.Vcpus,
arg.MemoryMb,
arg.TotalSteps,
arg.TemplateID,
arg.TeamID,
arg.SkipPrePost,
)
var i TemplateBuild
err := row.Scan(
&i.ID,
&i.Name,
&i.BaseTemplate,
&i.Recipe,
&i.Healthcheck,
&i.Vcpus,
&i.MemoryMb,
&i.Status,
&i.CurrentStep,
&i.TotalSteps,
&i.Logs,
&i.Error,
&i.SandboxID,
&i.HostID,
&i.CreatedAt,
&i.StartedAt,
&i.CompletedAt,
&i.TemplateID,
&i.TeamID,
&i.SkipPrePost,
)
return i, err
}
const listTemplateBuilds = `-- name: ListTemplateBuilds :many
SELECT id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at, template_id, team_id, skip_pre_post FROM template_builds ORDER BY created_at DESC
`
func (q *Queries) ListTemplateBuilds(ctx context.Context) ([]TemplateBuild, error) {
rows, err := q.db.Query(ctx, listTemplateBuilds)
if err != nil {
return nil, err
}
defer rows.Close()
var items []TemplateBuild
for rows.Next() {
var i TemplateBuild
if err := rows.Scan(
&i.ID,
&i.Name,
&i.BaseTemplate,
&i.Recipe,
&i.Healthcheck,
&i.Vcpus,
&i.MemoryMb,
&i.Status,
&i.CurrentStep,
&i.TotalSteps,
&i.Logs,
&i.Error,
&i.SandboxID,
&i.HostID,
&i.CreatedAt,
&i.StartedAt,
&i.CompletedAt,
&i.TemplateID,
&i.TeamID,
&i.SkipPrePost,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const updateBuildError = `-- name: UpdateBuildError :exec
UPDATE template_builds
SET error = $2, status = 'failed', completed_at = NOW()
WHERE id = $1
`
type UpdateBuildErrorParams struct {
ID pgtype.UUID `json:"id"`
Error string `json:"error"`
}
func (q *Queries) UpdateBuildError(ctx context.Context, arg UpdateBuildErrorParams) error {
_, err := q.db.Exec(ctx, updateBuildError, arg.ID, arg.Error)
return err
}
const updateBuildProgress = `-- name: UpdateBuildProgress :exec
UPDATE template_builds
SET current_step = $2, logs = $3
WHERE id = $1
`
type UpdateBuildProgressParams struct {
ID pgtype.UUID `json:"id"`
CurrentStep int32 `json:"current_step"`
Logs []byte `json:"logs"`
}
func (q *Queries) UpdateBuildProgress(ctx context.Context, arg UpdateBuildProgressParams) error {
_, err := q.db.Exec(ctx, updateBuildProgress, arg.ID, arg.CurrentStep, arg.Logs)
return err
}
const updateBuildSandbox = `-- name: UpdateBuildSandbox :exec
UPDATE template_builds
SET sandbox_id = $2, host_id = $3
WHERE id = $1
`
type UpdateBuildSandboxParams struct {
ID pgtype.UUID `json:"id"`
SandboxID pgtype.UUID `json:"sandbox_id"`
HostID pgtype.UUID `json:"host_id"`
}
func (q *Queries) UpdateBuildSandbox(ctx context.Context, arg UpdateBuildSandboxParams) error {
_, err := q.db.Exec(ctx, updateBuildSandbox, arg.ID, arg.SandboxID, arg.HostID)
return err
}
const updateBuildStatus = `-- name: UpdateBuildStatus :one
UPDATE template_builds
SET status = $2,
started_at = CASE WHEN $2 = 'running' AND started_at IS NULL THEN NOW() ELSE started_at END,
completed_at = CASE WHEN $2 IN ('success', 'failed', 'cancelled') THEN NOW() ELSE completed_at END
WHERE id = $1
RETURNING id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at, template_id, team_id, skip_pre_post
`
type UpdateBuildStatusParams struct {
ID pgtype.UUID `json:"id"`
Status string `json:"status"`
}
func (q *Queries) UpdateBuildStatus(ctx context.Context, arg UpdateBuildStatusParams) (TemplateBuild, error) {
row := q.db.QueryRow(ctx, updateBuildStatus, arg.ID, arg.Status)
var i TemplateBuild
err := row.Scan(
&i.ID,
&i.Name,
&i.BaseTemplate,
&i.Recipe,
&i.Healthcheck,
&i.Vcpus,
&i.MemoryMb,
&i.Status,
&i.CurrentStep,
&i.TotalSteps,
&i.Logs,
&i.Error,
&i.SandboxID,
&i.HostID,
&i.CreatedAt,
&i.StartedAt,
&i.CompletedAt,
&i.TemplateID,
&i.TeamID,
&i.SkipPrePost,
)
return i, err
}

View File

@ -12,11 +12,11 @@ import (
)
const deleteTemplate = `-- name: DeleteTemplate :exec
DELETE FROM templates WHERE name = $1
DELETE FROM templates WHERE id = $1
`
func (q *Queries) DeleteTemplate(ctx context.Context, name string) error {
_, err := q.db.Exec(ctx, deleteTemplate, name)
func (q *Queries) DeleteTemplate(ctx context.Context, id pgtype.UUID) error {
_, err := q.db.Exec(ctx, deleteTemplate, id)
return err
}
@ -25,8 +25,8 @@ DELETE FROM templates WHERE name = $1 AND team_id = $2
`
type DeleteTemplateByTeamParams struct {
Name string `json:"name"`
TeamID string `json:"team_id"`
Name string `json:"name"`
TeamID pgtype.UUID `json:"team_id"`
}
func (q *Queries) DeleteTemplateByTeam(ctx context.Context, arg DeleteTemplateByTeamParams) error {
@ -34,12 +34,23 @@ func (q *Queries) DeleteTemplateByTeam(ctx context.Context, arg DeleteTemplateBy
return err
}
const getTemplate = `-- name: GetTemplate :one
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE name = $1
const deleteTemplatesByTeam = `-- name: DeleteTemplatesByTeam :exec
DELETE FROM templates WHERE team_id = $1
`
func (q *Queries) GetTemplate(ctx context.Context, name string) (Template, error) {
row := q.db.QueryRow(ctx, getTemplate, name)
// Bulk delete all templates owned by a team (for team soft-delete cleanup).
func (q *Queries) DeleteTemplatesByTeam(ctx context.Context, teamID pgtype.UUID) error {
_, err := q.db.Exec(ctx, deleteTemplatesByTeam, teamID)
return err
}
const getPlatformTemplateByName = `-- name: GetPlatformTemplateByName :one
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE team_id = '00000000-0000-0000-0000-000000000000' AND name = $1
`
// Check if a global (platform) template exists with the given name.
func (q *Queries) GetPlatformTemplateByName(ctx context.Context, name string) (Template, error) {
row := q.db.QueryRow(ctx, getPlatformTemplateByName, name)
var i Template
err := row.Scan(
&i.Name,
@ -49,19 +60,67 @@ func (q *Queries) GetTemplate(ctx context.Context, name string) (Template, error
&i.SizeBytes,
&i.CreatedAt,
&i.TeamID,
&i.ID,
)
return i, err
}
const getTemplate = `-- name: GetTemplate :one
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE id = $1
`
func (q *Queries) GetTemplate(ctx context.Context, id pgtype.UUID) (Template, error) {
row := q.db.QueryRow(ctx, getTemplate, id)
var i Template
err := row.Scan(
&i.Name,
&i.Type,
&i.Vcpus,
&i.MemoryMb,
&i.SizeBytes,
&i.CreatedAt,
&i.TeamID,
&i.ID,
)
return i, err
}
const getTemplateByName = `-- name: GetTemplateByName :one
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE team_id = $1 AND name = $2
`
type GetTemplateByNameParams struct {
TeamID pgtype.UUID `json:"team_id"`
Name string `json:"name"`
}
// Look up a template by team_id and name (exact team match, no global fallback).
func (q *Queries) GetTemplateByName(ctx context.Context, arg GetTemplateByNameParams) (Template, error) {
row := q.db.QueryRow(ctx, getTemplateByName, arg.TeamID, arg.Name)
var i Template
err := row.Scan(
&i.Name,
&i.Type,
&i.Vcpus,
&i.MemoryMb,
&i.SizeBytes,
&i.CreatedAt,
&i.TeamID,
&i.ID,
)
return i, err
}
const getTemplateByTeam = `-- name: GetTemplateByTeam :one
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE name = $1 AND team_id = $2
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE name = $1 AND (team_id = $2 OR team_id = '00000000-0000-0000-0000-000000000000')
`
type GetTemplateByTeamParams struct {
Name string `json:"name"`
TeamID string `json:"team_id"`
Name string `json:"name"`
TeamID pgtype.UUID `json:"team_id"`
}
// Platform templates (team_id = 00000000-...) are visible to all teams.
func (q *Queries) GetTemplateByTeam(ctx context.Context, arg GetTemplateByTeamParams) (Template, error) {
row := q.db.QueryRow(ctx, getTemplateByTeam, arg.Name, arg.TeamID)
var i Template
@ -73,27 +132,30 @@ func (q *Queries) GetTemplateByTeam(ctx context.Context, arg GetTemplateByTeamPa
&i.SizeBytes,
&i.CreatedAt,
&i.TeamID,
&i.ID,
)
return i, err
}
const insertTemplate = `-- name: InsertTemplate :one
INSERT INTO templates (name, type, vcpus, memory_mb, size_bytes, team_id)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING name, type, vcpus, memory_mb, size_bytes, created_at, team_id
INSERT INTO templates (id, name, type, vcpus, memory_mb, size_bytes, team_id)
VALUES ($1, $2, $3, $4, $5, $6, $7)
RETURNING name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id
`
type InsertTemplateParams struct {
ID pgtype.UUID `json:"id"`
Name string `json:"name"`
Type string `json:"type"`
Vcpus pgtype.Int4 `json:"vcpus"`
MemoryMb pgtype.Int4 `json:"memory_mb"`
Vcpus int32 `json:"vcpus"`
MemoryMb int32 `json:"memory_mb"`
SizeBytes int64 `json:"size_bytes"`
TeamID string `json:"team_id"`
TeamID pgtype.UUID `json:"team_id"`
}
func (q *Queries) InsertTemplate(ctx context.Context, arg InsertTemplateParams) (Template, error) {
row := q.db.QueryRow(ctx, insertTemplate,
arg.ID,
arg.Name,
arg.Type,
arg.Vcpus,
@ -110,12 +172,13 @@ func (q *Queries) InsertTemplate(ctx context.Context, arg InsertTemplateParams)
&i.SizeBytes,
&i.CreatedAt,
&i.TeamID,
&i.ID,
)
return i, err
}
const listTemplates = `-- name: ListTemplates :many
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates ORDER BY created_at DESC
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates ORDER BY created_at DESC
`
func (q *Queries) ListTemplates(ctx context.Context) ([]Template, error) {
@ -135,6 +198,7 @@ func (q *Queries) ListTemplates(ctx context.Context) ([]Template, error) {
&i.SizeBytes,
&i.CreatedAt,
&i.TeamID,
&i.ID,
); err != nil {
return nil, err
}
@ -147,10 +211,11 @@ func (q *Queries) ListTemplates(ctx context.Context) ([]Template, error) {
}
const listTemplatesByTeam = `-- name: ListTemplatesByTeam :many
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE team_id = $1 ORDER BY created_at DESC
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE (team_id = $1 OR team_id = '00000000-0000-0000-0000-000000000000') ORDER BY created_at DESC
`
func (q *Queries) ListTemplatesByTeam(ctx context.Context, teamID string) ([]Template, error) {
// Platform templates are visible to all teams.
func (q *Queries) ListTemplatesByTeam(ctx context.Context, teamID pgtype.UUID) ([]Template, error) {
rows, err := q.db.Query(ctx, listTemplatesByTeam, teamID)
if err != nil {
return nil, err
@ -167,6 +232,7 @@ func (q *Queries) ListTemplatesByTeam(ctx context.Context, teamID string) ([]Tem
&i.SizeBytes,
&i.CreatedAt,
&i.TeamID,
&i.ID,
); err != nil {
return nil, err
}
@ -179,14 +245,15 @@ func (q *Queries) ListTemplatesByTeam(ctx context.Context, teamID string) ([]Tem
}
const listTemplatesByTeamAndType = `-- name: ListTemplatesByTeamAndType :many
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE team_id = $1 AND type = $2 ORDER BY created_at DESC
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE (team_id = $1 OR team_id = '00000000-0000-0000-0000-000000000000') AND type = $2 ORDER BY created_at DESC
`
type ListTemplatesByTeamAndTypeParams struct {
TeamID string `json:"team_id"`
Type string `json:"type"`
TeamID pgtype.UUID `json:"team_id"`
Type string `json:"type"`
}
// Platform templates are visible to all teams.
func (q *Queries) ListTemplatesByTeamAndType(ctx context.Context, arg ListTemplatesByTeamAndTypeParams) ([]Template, error) {
rows, err := q.db.Query(ctx, listTemplatesByTeamAndType, arg.TeamID, arg.Type)
if err != nil {
@ -204,6 +271,41 @@ func (q *Queries) ListTemplatesByTeamAndType(ctx context.Context, arg ListTempla
&i.SizeBytes,
&i.CreatedAt,
&i.TeamID,
&i.ID,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listTemplatesByTeamOnly = `-- name: ListTemplatesByTeamOnly :many
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE team_id = $1 ORDER BY created_at DESC
`
// List templates owned by a specific team (NOT including platform templates).
func (q *Queries) ListTemplatesByTeamOnly(ctx context.Context, teamID pgtype.UUID) ([]Template, error) {
rows, err := q.db.Query(ctx, listTemplatesByTeamOnly, teamID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Template
for rows.Next() {
var i Template
if err := rows.Scan(
&i.Name,
&i.Type,
&i.Vcpus,
&i.MemoryMb,
&i.SizeBytes,
&i.CreatedAt,
&i.TeamID,
&i.ID,
); err != nil {
return nil, err
}
@ -216,7 +318,7 @@ func (q *Queries) ListTemplatesByTeamAndType(ctx context.Context, arg ListTempla
}
const listTemplatesByType = `-- name: ListTemplatesByType :many
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE type = $1 ORDER BY created_at DESC
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE type = $1 ORDER BY created_at DESC
`
func (q *Queries) ListTemplatesByType(ctx context.Context, type_ string) ([]Template, error) {
@ -236,6 +338,7 @@ func (q *Queries) ListTemplatesByType(ctx context.Context, type_ string) ([]Temp
&i.SizeBytes,
&i.CreatedAt,
&i.TeamID,
&i.ID,
); err != nil {
return nil, err
}

View File

@ -16,8 +16,8 @@ DELETE FROM admin_permissions WHERE user_id = $1 AND permission = $2
`
type DeleteAdminPermissionParams struct {
UserID string `json:"user_id"`
Permission string `json:"permission"`
UserID pgtype.UUID `json:"user_id"`
Permission string `json:"permission"`
}
func (q *Queries) DeleteAdminPermission(ctx context.Context, arg DeleteAdminPermissionParams) error {
@ -29,7 +29,7 @@ const getAdminPermissions = `-- name: GetAdminPermissions :many
SELECT id, user_id, permission, created_at FROM admin_permissions WHERE user_id = $1 ORDER BY permission
`
func (q *Queries) GetAdminPermissions(ctx context.Context, userID string) ([]AdminPermission, error) {
func (q *Queries) GetAdminPermissions(ctx context.Context, userID pgtype.UUID) ([]AdminPermission, error) {
rows, err := q.db.Query(ctx, getAdminPermissions, userID)
if err != nil {
return nil, err
@ -55,7 +55,7 @@ func (q *Queries) GetAdminPermissions(ctx context.Context, userID string) ([]Adm
}
const getAdminUsers = `-- name: GetAdminUsers :many
SELECT id, email, password_hash, created_at, updated_at, is_admin, name FROM users WHERE is_admin = TRUE ORDER BY created_at
SELECT id, email, password_hash, name, is_admin, created_at, updated_at FROM users WHERE is_admin = TRUE ORDER BY created_at
`
func (q *Queries) GetAdminUsers(ctx context.Context) ([]User, error) {
@ -71,10 +71,10 @@ func (q *Queries) GetAdminUsers(ctx context.Context) ([]User, error) {
&i.ID,
&i.Email,
&i.PasswordHash,
&i.Name,
&i.IsAdmin,
&i.CreatedAt,
&i.UpdatedAt,
&i.IsAdmin,
&i.Name,
); err != nil {
return nil, err
}
@ -87,7 +87,7 @@ func (q *Queries) GetAdminUsers(ctx context.Context) ([]User, error) {
}
const getUserByEmail = `-- name: GetUserByEmail :one
SELECT id, email, password_hash, created_at, updated_at, is_admin, name FROM users WHERE email = $1
SELECT id, email, password_hash, name, is_admin, created_at, updated_at FROM users WHERE email = $1
`
func (q *Queries) GetUserByEmail(ctx context.Context, email string) (User, error) {
@ -97,29 +97,29 @@ func (q *Queries) GetUserByEmail(ctx context.Context, email string) (User, error
&i.ID,
&i.Email,
&i.PasswordHash,
&i.Name,
&i.IsAdmin,
&i.CreatedAt,
&i.UpdatedAt,
&i.IsAdmin,
&i.Name,
)
return i, err
}
const getUserByID = `-- name: GetUserByID :one
SELECT id, email, password_hash, created_at, updated_at, is_admin, name FROM users WHERE id = $1
SELECT id, email, password_hash, name, is_admin, created_at, updated_at FROM users WHERE id = $1
`
func (q *Queries) GetUserByID(ctx context.Context, id string) (User, error) {
func (q *Queries) GetUserByID(ctx context.Context, id pgtype.UUID) (User, error) {
row := q.db.QueryRow(ctx, getUserByID, id)
var i User
err := row.Scan(
&i.ID,
&i.Email,
&i.PasswordHash,
&i.Name,
&i.IsAdmin,
&i.CreatedAt,
&i.UpdatedAt,
&i.IsAdmin,
&i.Name,
)
return i, err
}
@ -131,8 +131,8 @@ SELECT EXISTS(
`
type HasAdminPermissionParams struct {
UserID string `json:"user_id"`
Permission string `json:"permission"`
UserID pgtype.UUID `json:"user_id"`
Permission string `json:"permission"`
}
func (q *Queries) HasAdminPermission(ctx context.Context, arg HasAdminPermissionParams) (bool, error) {
@ -148,9 +148,9 @@ VALUES ($1, $2, $3)
`
type InsertAdminPermissionParams struct {
ID string `json:"id"`
UserID string `json:"user_id"`
Permission string `json:"permission"`
ID pgtype.UUID `json:"id"`
UserID pgtype.UUID `json:"user_id"`
Permission string `json:"permission"`
}
func (q *Queries) InsertAdminPermission(ctx context.Context, arg InsertAdminPermissionParams) error {
@ -161,11 +161,11 @@ func (q *Queries) InsertAdminPermission(ctx context.Context, arg InsertAdminPerm
const insertUser = `-- name: InsertUser :one
INSERT INTO users (id, email, password_hash, name)
VALUES ($1, $2, $3, $4)
RETURNING id, email, password_hash, created_at, updated_at, is_admin, name
RETURNING id, email, password_hash, name, is_admin, created_at, updated_at
`
type InsertUserParams struct {
ID string `json:"id"`
ID pgtype.UUID `json:"id"`
Email string `json:"email"`
PasswordHash pgtype.Text `json:"password_hash"`
Name string `json:"name"`
@ -183,10 +183,10 @@ func (q *Queries) InsertUser(ctx context.Context, arg InsertUserParams) (User, e
&i.ID,
&i.Email,
&i.PasswordHash,
&i.Name,
&i.IsAdmin,
&i.CreatedAt,
&i.UpdatedAt,
&i.IsAdmin,
&i.Name,
)
return i, err
}
@ -194,13 +194,13 @@ func (q *Queries) InsertUser(ctx context.Context, arg InsertUserParams) (User, e
const insertUserOAuth = `-- name: InsertUserOAuth :one
INSERT INTO users (id, email, name)
VALUES ($1, $2, $3)
RETURNING id, email, password_hash, created_at, updated_at, is_admin, name
RETURNING id, email, password_hash, name, is_admin, created_at, updated_at
`
type InsertUserOAuthParams struct {
ID string `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
ID pgtype.UUID `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
}
func (q *Queries) InsertUserOAuth(ctx context.Context, arg InsertUserOAuthParams) (User, error) {
@ -210,10 +210,10 @@ func (q *Queries) InsertUserOAuth(ctx context.Context, arg InsertUserOAuthParams
&i.ID,
&i.Email,
&i.PasswordHash,
&i.Name,
&i.IsAdmin,
&i.CreatedAt,
&i.UpdatedAt,
&i.IsAdmin,
&i.Name,
)
return i, err
}
@ -223,8 +223,8 @@ SELECT id, email FROM users WHERE email LIKE $1 || '%' ORDER BY email LIMIT 10
`
type SearchUsersByEmailPrefixRow struct {
ID string `json:"id"`
Email string `json:"email"`
ID pgtype.UUID `json:"id"`
Email string `json:"email"`
}
func (q *Queries) SearchUsersByEmailPrefix(ctx context.Context, dollar_1 pgtype.Text) ([]SearchUsersByEmailPrefixRow, error) {
@ -252,8 +252,8 @@ UPDATE users SET is_admin = $2, updated_at = NOW() WHERE id = $1
`
type SetUserAdminParams struct {
ID string `json:"id"`
IsAdmin bool `json:"is_admin"`
ID pgtype.UUID `json:"id"`
IsAdmin bool `json:"is_admin"`
}
func (q *Queries) SetUserAdmin(ctx context.Context, arg SetUserAdminParams) error {
@ -266,8 +266,8 @@ UPDATE users SET name = $2, updated_at = NOW() WHERE id = $1
`
type UpdateUserNameParams struct {
ID string `json:"id"`
Name string `json:"name"`
ID pgtype.UUID `json:"id"`
Name string `json:"name"`
}
func (q *Queries) UpdateUserName(ctx context.Context, arg UpdateUserNameParams) error {

View File

@ -116,9 +116,10 @@ type SnapshotDevice struct {
// writable CoW layer.
//
// The origin loop device must already exist (from LoopRegistry.Acquire).
func CreateSnapshot(name, originLoopDev, cowPath string, originSizeBytes int64) (*SnapshotDevice, error) {
// Create sparse CoW file sized to match the origin.
if err := createSparseFile(cowPath, originSizeBytes); err != nil {
func CreateSnapshot(name, originLoopDev, cowPath string, originSizeBytes, cowSizeBytes int64) (*SnapshotDevice, error) {
// Create sparse CoW file. The logical size limits how many blocks can be
// modified; because the file is sparse, only written blocks use real disk.
if err := createSparseFile(cowPath, cowSizeBytes); err != nil {
return nil, fmt.Errorf("create cow file: %w", err)
}
@ -128,6 +129,9 @@ func CreateSnapshot(name, originLoopDev, cowPath string, originSizeBytes int64)
return nil, fmt.Errorf("losetup cow: %w", err)
}
// The dm-snapshot virtual device size must match the origin — the snapshot
// target maps 1:1 onto origin sectors. The CoW file just needs enough
// space to store all modified blocks (it's sparse, so 20GB costs nothing).
sectors := originSizeBytes / 512
if err := dmsetupCreate(name, originLoopDev, cowLoopDev, sectors); err != nil {
if detachErr := losetupDetach(cowLoopDev); detachErr != nil {
@ -220,6 +224,7 @@ func FlattenSnapshot(dmDevPath, outputPath string) error {
"if="+dmDevPath,
"of="+outputPath,
"bs=4M",
"conv=sparse",
"status=none",
)
if out, err := cmd.CombinedOutput(); err != nil {

View File

@ -3,14 +3,12 @@ package envdclient
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"mime/multipart"
"net/http"
"net/url"
"time"
"connectrpc.com/connect"
@ -49,35 +47,6 @@ func (c *Client) BaseURL() string {
return c.base
}
// Init calls POST /init on envd to sync the guest clock with the host.
// This is important after snapshot resume where the guest clock is frozen.
func (c *Client) Init(ctx context.Context) error {
now := time.Now().UTC()
body, err := json.Marshal(map[string]any{"timestamp": now})
if err != nil {
return fmt.Errorf("marshal init body: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.base+"/init", bytes.NewReader(body))
if err != nil {
return fmt.Errorf("create init request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return fmt.Errorf("init request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusNoContent {
respBody, _ := io.ReadAll(resp.Body)
return fmt.Errorf("init: status %d: %s", resp.StatusCode, string(respBody))
}
return nil
}
// ExecResult holds the output of a command execution.
type ExecResult struct {
Stdout []byte

View File

@ -0,0 +1,42 @@
package hostagent
import (
"crypto/tls"
"fmt"
"sync/atomic"
)
// CertStore provides lock-free read/write access to the agent's current TLS
// certificate. It is used with tls.Config.GetCertificate to enable hot-swap
// of the agent's cert on JWT refresh without restarting the server.
//
// The zero value is usable; GetCert returns an error until a cert is stored.
type CertStore struct {
ptr atomic.Pointer[tls.Certificate]
}
// Store atomically replaces the current certificate.
func (s *CertStore) Store(cert *tls.Certificate) {
s.ptr.Store(cert)
}
// ParseAndStore parses certPEM+keyPEM and atomically replaces the stored cert.
// If parsing fails the existing cert is unchanged.
func (s *CertStore) ParseAndStore(certPEM, keyPEM string) error {
cert, err := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM))
if err != nil {
return fmt.Errorf("parse TLS key pair: %w", err)
}
s.ptr.Store(&cert)
return nil
}
// GetCert satisfies tls.Config.GetCertificate. Returns an error if no cert has
// been stored yet.
func (s *CertStore) GetCert(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
cert := s.ptr.Load()
if cert == nil {
return nil, fmt.Errorf("no TLS certificate available")
}
return cert, nil
}

View File

@ -0,0 +1,89 @@
package hostagent
import (
"fmt"
"log/slog"
"net/http"
"net/http/httputil"
"strconv"
"strings"
"git.omukk.dev/wrenn/sandbox/internal/sandbox"
)
// ProxyHandler reverse-proxies HTTP requests to services running inside
// sandboxes. It handles requests of the form:
//
// /proxy/{sandbox_id}/{port}/{path...}
//
// The sandbox's HostIP (routable on this machine) is used as the upstream.
// This supports any protocol that rides on HTTP, including WebSocket upgrades.
type ProxyHandler struct {
mgr *sandbox.Manager
transport http.RoundTripper
}
// NewProxyHandler creates a new sandbox proxy handler.
func NewProxyHandler(mgr *sandbox.Manager) *ProxyHandler {
return &ProxyHandler{
mgr: mgr,
transport: http.DefaultTransport,
}
}
// ServeHTTP implements http.Handler.
func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Expected path: /proxy/{sandbox_id}/{port}/...
// After trimming "/proxy/", we get "{sandbox_id}/{port}/..."
trimmed := strings.TrimPrefix(r.URL.Path, "/proxy/")
if trimmed == r.URL.Path {
http.Error(w, "invalid proxy path", http.StatusBadRequest)
return
}
parts := strings.SplitN(trimmed, "/", 3)
if len(parts) < 2 {
http.Error(w, "expected /proxy/{sandbox_id}/{port}/...", http.StatusBadRequest)
return
}
sandboxID := parts[0]
port := parts[1]
remainder := ""
if len(parts) == 3 {
remainder = parts[2]
}
// Validate port is a number in the valid range.
portNum, err := strconv.Atoi(port)
if err != nil || portNum < 1 || portNum > 65535 {
http.Error(w, "invalid port", http.StatusBadRequest)
return
}
hostIP, tracker, ok := h.mgr.AcquireProxyConn(sandboxID)
if !ok {
http.Error(w, "sandbox is not available", http.StatusServiceUnavailable)
return
}
defer tracker.Release()
targetHost := fmt.Sprintf("%s:%d", hostIP, portNum)
proxy := &httputil.ReverseProxy{
Transport: h.transport,
Director: func(req *http.Request) {
req.URL.Scheme = "http"
req.URL.Host = targetHost
req.URL.Path = "/" + remainder
req.URL.RawQuery = r.URL.RawQuery
req.Host = targetHost
},
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
slog.Debug("proxy error", "sandbox_id", sandboxID, "port", port, "error", err)
http.Error(w, "proxy error: "+err.Error(), http.StatusBadGateway)
},
}
proxy.ServeHTTP(w, r)
}

View File

@ -17,18 +17,24 @@ import (
"golang.org/x/sys/unix"
)
// tokenFile is the JSON format persisted to AGENT_FILES_ROOTDIR/host.jwt.
type tokenFile struct {
// TokenFile is the JSON format persisted to WRENN_DIR/host-credentials.json.
// It holds all credentials the agent needs: the host JWT, refresh token, and
// (when mTLS is enabled) the TLS certificate material for the agent's server.
type TokenFile struct {
HostID string `json:"host_id"`
JWT string `json:"jwt"`
RefreshToken string `json:"refresh_token"`
// mTLS fields — empty when the CP has no CA configured.
CertPEM string `json:"cert_pem,omitempty"`
KeyPEM string `json:"key_pem,omitempty"`
CACertPEM string `json:"ca_cert_pem,omitempty"`
}
// RegistrationConfig holds the configuration for host registration.
type RegistrationConfig struct {
CPURL string // Control plane base URL (e.g., http://localhost:8000)
RegistrationToken string // One-time registration token from the control plane
TokenFile string // Path to persist the host JWT after registration
TokenFile string // Path to persist the credentials after registration
Address string // Externally-reachable address (ip:port) for this host
}
@ -41,22 +47,20 @@ type registerRequest struct {
Address string `json:"address"`
}
type registerResponse struct {
// authResponse is the shared JSON shape for both register and refresh responses.
type authResponse struct {
Host json.RawMessage `json:"host"`
Token string `json:"token"`
RefreshToken string `json:"refresh_token"`
CertPEM string `json:"cert_pem,omitempty"`
KeyPEM string `json:"key_pem,omitempty"`
CACertPEM string `json:"ca_cert_pem,omitempty"`
}
type refreshRequest struct {
RefreshToken string `json:"refresh_token"`
}
type refreshResponse struct {
Host json.RawMessage `json:"host"`
Token string `json:"token"`
RefreshToken string `json:"refresh_token"`
}
type errorResponse struct {
Error struct {
Code string `json:"code"`
@ -64,8 +68,8 @@ type errorResponse struct {
} `json:"error"`
}
// loadTokenFile reads and parses the persisted token file.
func loadTokenFile(path string) (*tokenFile, error) {
// LoadTokenFile reads and parses the persisted credentials file.
func LoadTokenFile(path string) (*TokenFile, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
@ -75,36 +79,36 @@ func loadTokenFile(path string) (*tokenFile, error) {
if !strings.HasPrefix(trimmed, "{") {
// Old format: just the JWT, no refresh token.
hostID, _ := hostIDFromJWT(trimmed)
return &tokenFile{HostID: hostID, JWT: trimmed}, nil
return &TokenFile{HostID: hostID, JWT: trimmed}, nil
}
var tf tokenFile
var tf TokenFile
if err := json.Unmarshal(data, &tf); err != nil {
return nil, fmt.Errorf("parse token file: %w", err)
return nil, fmt.Errorf("parse credentials file: %w", err)
}
return &tf, nil
}
// saveTokenFile writes the token file as JSON with 0600 permissions.
func saveTokenFile(path string, tf tokenFile) error {
// saveTokenFile writes the credentials file as JSON with 0600 permissions.
func saveTokenFile(path string, tf TokenFile) error {
data, err := json.MarshalIndent(tf, "", " ")
if err != nil {
return fmt.Errorf("marshal token file: %w", err)
return fmt.Errorf("marshal credentials file: %w", err)
}
return os.WriteFile(path, data, 0600)
}
// Register calls the control plane to register this host agent and persists
// the returned JWT and refresh token to disk. Returns the host JWT token string.
func Register(ctx context.Context, cfg RegistrationConfig) (string, error) {
// If no explicit registration token was given, reuse the saved JWT.
// the returned credentials to disk. Returns the full TokenFile on success.
func Register(ctx context.Context, cfg RegistrationConfig) (*TokenFile, error) {
// If no explicit registration token was given, reuse the saved credentials.
// A --register flag always overrides the local file so operators can
// force re-registration without manually deleting host.jwt.
// force re-registration without manually deleting the credentials file.
if cfg.RegistrationToken == "" {
if tf, err := loadTokenFile(cfg.TokenFile); err == nil && tf.JWT != "" {
slog.Info("loaded existing host token", "file", cfg.TokenFile, "host_id", tf.HostID)
return tf.JWT, nil
if tf, err := LoadTokenFile(cfg.TokenFile); err == nil && tf.JWT != "" {
slog.Info("loaded existing host credentials", "file", cfg.TokenFile, "host_id", tf.HostID)
return tf, nil
}
return "", fmt.Errorf("no saved host token and no registration token provided (use --register flag)")
return nil, fmt.Errorf("no saved host credentials and no registration token provided (use --register flag)")
}
arch := runtime.GOARCH
@ -123,87 +127,90 @@ func Register(ctx context.Context, cfg RegistrationConfig) (string, error) {
body, err := json.Marshal(reqBody)
if err != nil {
return "", fmt.Errorf("marshal registration request: %w", err)
return nil, fmt.Errorf("marshal registration request: %w", err)
}
url := strings.TrimRight(cfg.CPURL, "/") + "/v1/hosts/register"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return "", fmt.Errorf("create registration request: %w", err)
return nil, fmt.Errorf("create registration request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("registration request failed: %w", err)
return nil, fmt.Errorf("registration request failed: %w", err)
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("read registration response: %w", err)
return nil, fmt.Errorf("read registration response: %w", err)
}
if resp.StatusCode != http.StatusCreated {
var errResp errorResponse
if err := json.Unmarshal(respBody, &errResp); err == nil {
return "", fmt.Errorf("registration failed (%d): %s", resp.StatusCode, errResp.Error.Message)
return nil, fmt.Errorf("registration failed (%d): %s", resp.StatusCode, errResp.Error.Message)
}
return "", fmt.Errorf("registration failed (%d): %s", resp.StatusCode, string(respBody))
return nil, fmt.Errorf("registration failed (%d): %s", resp.StatusCode, string(respBody))
}
var regResp registerResponse
var regResp authResponse
if err := json.Unmarshal(respBody, &regResp); err != nil {
return "", fmt.Errorf("parse registration response: %w", err)
return nil, fmt.Errorf("parse registration response: %w", err)
}
if regResp.Token == "" {
return "", fmt.Errorf("registration response missing token")
return nil, fmt.Errorf("registration response missing token")
}
hostID, err := hostIDFromJWT(regResp.Token)
if err != nil {
return "", fmt.Errorf("extract host ID from JWT: %w", err)
return nil, fmt.Errorf("extract host ID from JWT: %w", err)
}
// Persist JWT + refresh token.
tf := tokenFile{
tf := TokenFile{
HostID: hostID,
JWT: regResp.Token,
RefreshToken: regResp.RefreshToken,
CertPEM: regResp.CertPEM,
KeyPEM: regResp.KeyPEM,
CACertPEM: regResp.CACertPEM,
}
if err := saveTokenFile(cfg.TokenFile, tf); err != nil {
return "", fmt.Errorf("save host token: %w", err)
return nil, fmt.Errorf("save host credentials: %w", err)
}
slog.Info("host registered and token saved", "file", cfg.TokenFile, "host_id", hostID)
slog.Info("host registered and credentials saved", "file", cfg.TokenFile, "host_id", hostID)
return regResp.Token, nil
return &tf, nil
}
// RefreshJWT exchanges the refresh token for a new JWT + rotated refresh token.
// It reads and updates the token file in place.
func RefreshJWT(ctx context.Context, cpURL, tokenFilePath string) (string, error) {
tf, err := loadTokenFile(tokenFilePath)
// RefreshCredentials exchanges the refresh token for a new JWT, rotated refresh
// token, and (when mTLS is enabled) a new TLS certificate. The credentials file
// is updated in place. Returns the updated TokenFile.
func RefreshCredentials(ctx context.Context, cpURL, credentialsFilePath string) (*TokenFile, error) {
tf, err := LoadTokenFile(credentialsFilePath)
if err != nil {
return "", fmt.Errorf("load token file: %w", err)
return nil, fmt.Errorf("load credentials file: %w", err)
}
if tf.RefreshToken == "" {
return "", fmt.Errorf("no refresh token available; host must re-register")
return nil, fmt.Errorf("no refresh token available; host must re-register")
}
body, _ := json.Marshal(refreshRequest{RefreshToken: tf.RefreshToken})
url := strings.TrimRight(cpURL, "/") + "/v1/hosts/auth/refresh"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return "", fmt.Errorf("create refresh request: %w", err)
return nil, fmt.Errorf("create refresh request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 15 * time.Second}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("refresh request failed: %w", err)
return nil, fmt.Errorf("refresh request failed: %w", err)
}
defer resp.Body.Close()
@ -212,39 +219,47 @@ func RefreshJWT(ctx context.Context, cpURL, tokenFilePath string) (string, error
if resp.StatusCode != http.StatusOK {
var errResp errorResponse
if json.Unmarshal(respBody, &errResp) == nil {
return "", fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, errResp.Error.Message)
return nil, fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, errResp.Error.Message)
}
return "", fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, string(respBody))
return nil, fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, string(respBody))
}
var refResp refreshResponse
var refResp authResponse
if err := json.Unmarshal(respBody, &refResp); err != nil {
return "", fmt.Errorf("parse refresh response: %w", err)
return nil, fmt.Errorf("parse refresh response: %w", err)
}
tf.JWT = refResp.Token
tf.RefreshToken = refResp.RefreshToken
if err := saveTokenFile(tokenFilePath, *tf); err != nil {
return "", fmt.Errorf("save refreshed token: %w", err)
if refResp.CertPEM != "" {
tf.CertPEM = refResp.CertPEM
tf.KeyPEM = refResp.KeyPEM
tf.CACertPEM = refResp.CACertPEM
}
if err := saveTokenFile(credentialsFilePath, *tf); err != nil {
return nil, fmt.Errorf("save refreshed credentials: %w", err)
}
slog.Info("host JWT refreshed", "host_id", tf.HostID)
return refResp.Token, nil
slog.Info("host credentials refreshed", "host_id", tf.HostID)
return tf, nil
}
// StartHeartbeat launches a background goroutine that sends periodic heartbeats
// to the control plane. It runs until the context is cancelled.
//
// On 401/403: the heartbeat loop attempts to refresh the JWT. If the refresh
// On 401/403: the heartbeat loop attempts to refresh credentials. If the refresh
// also fails (expired refresh token), it calls pauseAll and stops.
//
// On repeated network failures (3 consecutive), it calls pauseAll but keeps
// retrying — the connection may recover and the host should resume heartbeating.
//
// onDeleted is called when CP returns 404, meaning this host record was deleted.
// The token file is removed before calling onDeleted so subsequent starts prompt
// for a new registration token.
func StartHeartbeat(ctx context.Context, cpURL, tokenFilePath, hostID string, interval time.Duration, pauseAll func(), onDeleted func()) {
// The credentials file is removed before calling onDeleted so subsequent starts
// prompt for a new registration token.
//
// onCredsRefreshed is called after a successful credential refresh (JWT + cert).
// It may be nil. The caller uses it to hot-swap the agent's TLS certificate.
func StartHeartbeat(ctx context.Context, cpURL, credentialsFilePath, hostID string, interval time.Duration, pauseAll func(), onDeleted func(), onCredsRefreshed func(*TokenFile)) {
client := &http.Client{Timeout: 10 * time.Second}
go func() {
@ -255,8 +270,8 @@ func StartHeartbeat(ctx context.Context, cpURL, tokenFilePath, hostID string, in
pausedDueToFailure := false
currentJWT := ""
// Load the current JWT from disk.
if tf, err := loadTokenFile(tokenFilePath); err == nil {
// Load the current JWT from the credentials file.
if tf, err := LoadTokenFile(credentialsFilePath); err == nil {
currentJWT = tf.JWT
}
@ -294,10 +309,10 @@ func StartHeartbeat(ctx context.Context, cpURL, tokenFilePath, hostID string, in
pausedDueToFailure = false
case http.StatusUnauthorized, http.StatusForbidden:
slog.Warn("heartbeat: JWT rejected — attempting token refresh")
newJWT, refreshErr := RefreshJWT(ctx, cpURL, tokenFilePath)
slog.Warn("heartbeat: JWT rejected — attempting credentials refresh")
newCreds, refreshErr := RefreshCredentials(ctx, cpURL, credentialsFilePath)
if refreshErr != nil {
slog.Error("heartbeat: JWT refresh failed — pausing all sandboxes; manual re-registration required",
slog.Error("heartbeat: credentials refresh failed — pausing all sandboxes; manual re-registration required",
"error", refreshErr)
if pauseAll != nil && !pausedDueToFailure {
pauseAll()
@ -306,13 +321,16 @@ func StartHeartbeat(ctx context.Context, cpURL, tokenFilePath, hostID string, in
// Stop the heartbeat loop — operator must re-register.
return true
}
currentJWT = newJWT
slog.Info("heartbeat: JWT refreshed successfully")
currentJWT = newCreds.JWT
slog.Info("heartbeat: credentials refreshed successfully")
if onCredsRefreshed != nil {
onCredsRefreshed(newCreds)
}
case http.StatusNotFound:
slog.Error("heartbeat: host no longer exists in CP — host was deleted; removing token file and exiting")
if err := os.Remove(tokenFilePath); err != nil && !os.IsNotExist(err) {
slog.Warn("heartbeat: failed to remove token file", "error", err)
slog.Error("heartbeat: host no longer exists in CP — host was deleted; removing credentials file and exiting")
if err := os.Remove(credentialsFilePath); err != nil && !os.IsNotExist(err) {
slog.Warn("heartbeat: failed to remove credentials file", "error", err)
}
if onDeleted != nil {
onDeleted()
@ -351,7 +369,7 @@ func HostIDFromToken(token string) (string, error) {
}
// hostIDFromJWT is the internal implementation used by both HostIDFromToken and
// the token file loader.
// the credentials file loader.
func hostIDFromJWT(token string) (string, error) {
parts := strings.Split(token, ".")
if len(parts) != 3 {

View File

@ -12,6 +12,8 @@ import (
"time"
"connectrpc.com/connect"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgtype"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
@ -33,13 +35,35 @@ func NewServer(mgr *sandbox.Manager, terminate func()) *Server {
return &Server{mgr: mgr, terminate: terminate}
}
// parseUUIDString parses a UUID hex string into a pgtype.UUID.
// An empty string yields an all-zeros UUID (valid).
func parseUUIDString(s string) (pgtype.UUID, error) {
if s == "" {
return pgtype.UUID{Bytes: [16]byte{}, Valid: true}, nil
}
parsed, err := uuid.Parse(s)
if err != nil {
return pgtype.UUID{}, fmt.Errorf("invalid UUID %q: %w", s, err)
}
return pgtype.UUID{Bytes: parsed, Valid: true}, nil
}
func (s *Server) CreateSandbox(
ctx context.Context,
req *connect.Request[pb.CreateSandboxRequest],
) (*connect.Response[pb.CreateSandboxResponse], error) {
msg := req.Msg
sb, err := s.mgr.Create(ctx, msg.SandboxId, msg.Template, int(msg.Vcpus), int(msg.MemoryMb), int(msg.TimeoutSec))
teamID, err := parseUUIDString(msg.TeamId)
if err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, err)
}
templateID, err := parseUUIDString(msg.TemplateId)
if err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, err)
}
sb, err := s.mgr.Create(ctx, msg.SandboxId, teamID, templateID, int(msg.Vcpus), int(msg.MemoryMb), int(msg.TimeoutSec), int(msg.DiskSizeMb))
if err != nil {
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("create sandbox: %w", err))
}
@ -90,12 +114,21 @@ func (s *Server) CreateSnapshot(
ctx context.Context,
req *connect.Request[pb.CreateSnapshotRequest],
) (*connect.Response[pb.CreateSnapshotResponse], error) {
sizeBytes, err := s.mgr.CreateSnapshot(ctx, req.Msg.SandboxId, req.Msg.Name)
msg := req.Msg
teamID, err := parseUUIDString(msg.TeamId)
if err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, err)
}
templateID, err := parseUUIDString(msg.TemplateId)
if err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, err)
}
sizeBytes, err := s.mgr.CreateSnapshot(ctx, msg.SandboxId, teamID, templateID)
if err != nil {
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("create snapshot: %w", err))
}
return connect.NewResponse(&pb.CreateSnapshotResponse{
Name: req.Msg.Name,
SizeBytes: sizeBytes,
}), nil
}
@ -104,12 +137,45 @@ func (s *Server) DeleteSnapshot(
ctx context.Context,
req *connect.Request[pb.DeleteSnapshotRequest],
) (*connect.Response[pb.DeleteSnapshotResponse], error) {
if err := s.mgr.DeleteSnapshot(req.Msg.Name); err != nil {
msg := req.Msg
teamID, err := parseUUIDString(msg.TeamId)
if err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, err)
}
templateID, err := parseUUIDString(msg.TemplateId)
if err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, err)
}
if err := s.mgr.DeleteSnapshot(teamID, templateID); err != nil {
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("delete snapshot: %w", err))
}
return connect.NewResponse(&pb.DeleteSnapshotResponse{}), nil
}
func (s *Server) FlattenRootfs(
ctx context.Context,
req *connect.Request[pb.FlattenRootfsRequest],
) (*connect.Response[pb.FlattenRootfsResponse], error) {
msg := req.Msg
teamID, err := parseUUIDString(msg.TeamId)
if err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, err)
}
templateID, err := parseUUIDString(msg.TemplateId)
if err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, err)
}
sizeBytes, err := s.mgr.FlattenRootfs(ctx, msg.SandboxId, teamID, templateID)
if err != nil {
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("flatten rootfs: %w", err))
}
return connect.NewResponse(&pb.FlattenRootfsResponse{
SizeBytes: sizeBytes,
}), nil
}
func (s *Server) PingSandbox(
ctx context.Context,
req *connect.Request[pb.PingSandboxRequest],
@ -400,7 +466,8 @@ func (s *Server) ListSandboxes(
infos[i] = &pb.SandboxInfo{
SandboxId: sb.ID,
Status: string(sb.Status),
Template: sb.Template,
TeamId: uuid.UUID(sb.TemplateTeamID).String(),
TemplateId: uuid.UUID(sb.TemplateID).String(),
Vcpus: int32(sb.VCPUs),
MemoryMb: int32(sb.MemoryMB),
HostIp: sb.HostIP.String(),

View File

@ -4,8 +4,167 @@ import (
"crypto/rand"
"encoding/hex"
"fmt"
"math/big"
"strings"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgtype"
)
const (
base36Alphabet = "0123456789abcdefghijklmnopqrstuvwxyz"
base36IDLen = 25 // ceil(128 * log2 / log36) = 25 chars for a full UUID
)
var base36Base = big.NewInt(36)
// --- Generation ---
// newUUID returns a new random (v4) UUID wrapped in pgtype.UUID for direct DB use.
func newUUID() pgtype.UUID {
return pgtype.UUID{Bytes: uuid.New(), Valid: true}
}
func NewSandboxID() pgtype.UUID { return newUUID() }
func NewUserID() pgtype.UUID { return newUUID() }
func NewTeamID() pgtype.UUID { return newUUID() }
func NewAPIKeyID() pgtype.UUID { return newUUID() }
func NewHostID() pgtype.UUID { return newUUID() }
func NewHostTokenID() pgtype.UUID { return newUUID() }
func NewRefreshTokenID() pgtype.UUID { return newUUID() }
func NewAuditLogID() pgtype.UUID { return newUUID() }
func NewBuildID() pgtype.UUID { return newUUID() }
func NewAdminPermissionID() pgtype.UUID { return newUUID() }
func NewTemplateID() pgtype.UUID { return newUUID() }
// NewSnapshotName generates a snapshot name: "template-" + 8 hex chars.
func NewSnapshotName() string {
return "template-" + hex8()
}
// NewTeamSlug generates a unique team slug in the format "xxxxxx-yyyyyy".
func NewTeamSlug() string {
b := make([]byte, 6)
if _, err := rand.Read(b); err != nil {
panic(fmt.Sprintf("crypto/rand failed: %v", err))
}
return hex.EncodeToString(b[:3]) + "-" + hex.EncodeToString(b[3:])
}
// NewRegistrationToken generates a 64-char hex token (32 bytes of entropy).
func NewRegistrationToken() string {
return hexToken(32)
}
// NewRefreshToken generates a 64-char hex token (32 bytes of entropy).
func NewRefreshToken() string {
return hexToken(32)
}
// --- Formatting (pgtype.UUID → prefixed string for API/RPC output) ---
const (
PrefixSandbox = "cl-"
PrefixUser = "usr-"
PrefixTeam = "team-"
PrefixAPIKey = "key-"
PrefixHost = "host-"
PrefixHostToken = "htok-"
PrefixRefreshToken = "hrt-"
PrefixAuditLog = "log-"
PrefixBuild = "bld-"
PrefixAdminPermission = "perm-"
)
// UUIDToBase36 encodes 16 UUID bytes as a 25-char base36 string (0-9a-z).
func UUIDToBase36(b [16]byte) string {
n := new(big.Int).SetBytes(b[:])
buf := make([]byte, base36IDLen)
mod := new(big.Int)
for i := base36IDLen - 1; i >= 0; i-- {
n.DivMod(n, base36Base, mod)
buf[i] = base36Alphabet[mod.Int64()]
}
return string(buf)
}
// base36ToUUID decodes a 25-char base36 string back to 16 UUID bytes.
func base36ToUUID(s string) ([16]byte, error) {
if len(s) != base36IDLen {
return [16]byte{}, fmt.Errorf("expected %d-char base36 ID, got %d", base36IDLen, len(s))
}
n := new(big.Int)
for _, c := range s {
idx := strings.IndexRune(base36Alphabet, c)
if idx < 0 {
return [16]byte{}, fmt.Errorf("invalid base36 character: %c", c)
}
n.Mul(n, base36Base)
n.Add(n, big.NewInt(int64(idx)))
}
b := n.Bytes()
var out [16]byte
// big.Int.Bytes() strips leading zeros; right-align into 16-byte array.
copy(out[16-len(b):], b)
return out, nil
}
func formatUUID(prefix string, id pgtype.UUID) string {
return prefix + UUIDToBase36(id.Bytes)
}
func FormatSandboxID(id pgtype.UUID) string { return formatUUID(PrefixSandbox, id) }
func FormatUserID(id pgtype.UUID) string { return formatUUID(PrefixUser, id) }
func FormatTeamID(id pgtype.UUID) string { return formatUUID(PrefixTeam, id) }
func FormatAPIKeyID(id pgtype.UUID) string { return formatUUID(PrefixAPIKey, id) }
func FormatHostID(id pgtype.UUID) string { return formatUUID(PrefixHost, id) }
func FormatHostTokenID(id pgtype.UUID) string { return formatUUID(PrefixHostToken, id) }
func FormatRefreshTokenID(id pgtype.UUID) string { return formatUUID(PrefixRefreshToken, id) }
func FormatAuditLogID(id pgtype.UUID) string { return formatUUID(PrefixAuditLog, id) }
func FormatBuildID(id pgtype.UUID) string { return formatUUID(PrefixBuild, id) }
// --- Parsing (prefixed string from API/RPC input → pgtype.UUID) ---
func parseUUID(prefix, s string) (pgtype.UUID, error) {
if !strings.HasPrefix(s, prefix) {
return pgtype.UUID{}, fmt.Errorf("invalid ID: expected %q prefix, got %q", prefix, s)
}
b, err := base36ToUUID(strings.TrimPrefix(s, prefix))
if err != nil {
return pgtype.UUID{}, fmt.Errorf("invalid ID %q: %w", s, err)
}
return pgtype.UUID{Bytes: b, Valid: true}, nil
}
func ParseSandboxID(s string) (pgtype.UUID, error) { return parseUUID(PrefixSandbox, s) }
func ParseUserID(s string) (pgtype.UUID, error) { return parseUUID(PrefixUser, s) }
func ParseTeamID(s string) (pgtype.UUID, error) { return parseUUID(PrefixTeam, s) }
func ParseAPIKeyID(s string) (pgtype.UUID, error) { return parseUUID(PrefixAPIKey, s) }
func ParseHostID(s string) (pgtype.UUID, error) { return parseUUID(PrefixHost, s) }
func ParseHostTokenID(s string) (pgtype.UUID, error) { return parseUUID(PrefixHostToken, s) }
func ParseAuditLogID(s string) (pgtype.UUID, error) { return parseUUID(PrefixAuditLog, s) }
func ParseBuildID(s string) (pgtype.UUID, error) { return parseUUID(PrefixBuild, s) }
// --- Well-known IDs ---
// PlatformTeamID is the all-zeros UUID reserved for platform-owned resources
// (e.g. base templates, shared infrastructure).
var PlatformTeamID = pgtype.UUID{Bytes: [16]byte{}, Valid: true}
// MinimalTemplateID is the all-zeros UUID sentinel for the built-in "minimal"
// template. When both team_id and template_id are zero, the host agent uses
// the minimal rootfs at WRENN_DIR/images/minimal/.
var MinimalTemplateID = pgtype.UUID{Bytes: [16]byte{}, Valid: true}
// UUIDString converts a pgtype.UUID to a standard hyphenated UUID string
// (e.g., "6ba7b810-9dad-11d1-80b4-00c04fd430c8"). Used for RPC wire format.
func UUIDString(id pgtype.UUID) string {
return uuid.UUID(id.Bytes).String()
}
// --- Helpers ---
func hex8() string {
b := make([]byte, 4)
if _, err := rand.Read(b); err != nil {
@ -14,73 +173,8 @@ func hex8() string {
return hex.EncodeToString(b)
}
// NewSandboxID generates a new sandbox ID in the format "sb-" + 8 hex chars.
func NewSandboxID() string {
return "sb-" + hex8()
}
// NewSnapshotName generates a snapshot name in the format "template-" + 8 hex chars.
func NewSnapshotName() string {
return "template-" + hex8()
}
// NewUserID generates a new user ID in the format "usr-" + 8 hex chars.
func NewUserID() string {
return "usr-" + hex8()
}
// NewTeamID generates a new team ID in the format "team-" + 8 hex chars.
func NewTeamID() string {
return "team-" + hex8()
}
// NewTeamSlug generates a unique team slug in the format "xxxxxx-yyyyyy"
// where each part is 3 random bytes encoded as hex (6 hex chars each).
func NewTeamSlug() string {
b := make([]byte, 6)
if _, err := rand.Read(b); err != nil {
panic(fmt.Sprintf("crypto/rand failed: %v", err))
}
return hex.EncodeToString(b[:3]) + "-" + hex.EncodeToString(b[3:])
}
// NewAPIKeyID generates a new API key ID in the format "key-" + 8 hex chars.
func NewAPIKeyID() string {
return "key-" + hex8()
}
// NewHostID generates a new host ID in the format "host-" + 8 hex chars.
func NewHostID() string {
return "host-" + hex8()
}
// NewHostTokenID generates a new host token audit ID in the format "htok-" + 8 hex chars.
func NewHostTokenID() string {
return "htok-" + hex8()
}
// NewRegistrationToken generates a 64-char hex token (32 bytes of entropy).
func NewRegistrationToken() string {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
panic(fmt.Sprintf("crypto/rand failed: %v", err))
}
return hex.EncodeToString(b)
}
// NewRefreshTokenID generates a new refresh token record ID in the format "hrt-" + 8 hex chars.
func NewRefreshTokenID() string {
return "hrt-" + hex8()
}
// NewAuditLogID generates a new audit log ID in the format "log-" + 8 hex chars.
func NewAuditLogID() string {
return "log-" + hex8()
}
// NewRefreshToken generates a 64-char hex token (32 bytes of entropy) for use as a host refresh token.
func NewRefreshToken() string {
b := make([]byte, 32)
func hexToken(nBytes int) string {
b := make([]byte, nBytes)
if _, err := rand.Read(b); err != nil {
panic(fmt.Sprintf("crypto/rand failed: %v", err))
}

118
internal/id/id_test.go Normal file
View File

@ -0,0 +1,118 @@
package id
import (
"testing"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgtype"
)
func TestBase36RoundTrip(t *testing.T) {
for i := 0; i < 1000; i++ {
orig := uuid.New()
encoded := UUIDToBase36(orig)
if len(encoded) != base36IDLen {
t.Fatalf("expected %d chars, got %d: %s", base36IDLen, len(encoded), encoded)
}
decoded, err := base36ToUUID(encoded)
if err != nil {
t.Fatalf("decode failed: %v", err)
}
if decoded != orig {
t.Fatalf("round-trip failed: %v → %s → %v", orig, encoded, decoded)
}
}
}
func TestBase36ZeroUUID(t *testing.T) {
var zero [16]byte
encoded := UUIDToBase36(zero)
if encoded != "0000000000000000000000000" {
t.Fatalf("zero UUID should encode to all zeros, got %s", encoded)
}
decoded, err := base36ToUUID(encoded)
if err != nil {
t.Fatalf("decode failed: %v", err)
}
if decoded != zero {
t.Fatalf("round-trip failed for zero UUID")
}
}
func TestFormatParseRoundTrip(t *testing.T) {
id := NewSandboxID()
formatted := FormatSandboxID(id)
if formatted[:3] != "cl-" {
t.Fatalf("expected cl- prefix, got %s", formatted)
}
if len(formatted) != 3+base36IDLen {
t.Fatalf("expected %d chars total, got %d: %s", 3+base36IDLen, len(formatted), formatted)
}
parsed, err := ParseSandboxID(formatted)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if parsed != id {
t.Fatalf("round-trip failed: %v → %s → %v", id, formatted, parsed)
}
}
func TestBase36InvalidInput(t *testing.T) {
// Wrong length.
if _, err := base36ToUUID("abc"); err == nil {
t.Fatal("expected error for short input")
}
// Invalid character.
if _, err := base36ToUUID("000000000000000000000000!"); err == nil {
t.Fatal("expected error for invalid character")
}
}
func TestPlatformTeamIDFormats(t *testing.T) {
formatted := FormatTeamID(PlatformTeamID)
parsed, err := ParseTeamID(formatted)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if parsed != PlatformTeamID {
t.Fatalf("platform team ID round-trip failed")
}
}
func TestMaxUUID(t *testing.T) {
max := [16]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}
encoded := UUIDToBase36(max)
if len(encoded) != base36IDLen {
t.Fatalf("max UUID encoding wrong length: %d", len(encoded))
}
decoded, err := base36ToUUID(encoded)
if err != nil {
t.Fatalf("decode failed: %v", err)
}
if decoded != max {
t.Fatalf("round-trip failed for max UUID")
}
}
func BenchmarkFormatSandboxID(b *testing.B) {
id := pgtype.UUID{Bytes: uuid.New(), Valid: true}
b.ResetTimer()
for i := 0; i < b.N; i++ {
FormatSandboxID(id)
}
}
func BenchmarkParseSandboxID(b *testing.B) {
id := pgtype.UUID{Bytes: uuid.New(), Valid: true}
s := FormatSandboxID(id)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = ParseSandboxID(s)
}
}

58
internal/layout/layout.go Normal file
View File

@ -0,0 +1,58 @@
package layout
import (
"path/filepath"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/id"
)
// IsMinimal reports whether the given team and template IDs represent the
// built-in "minimal" template (both all-zeros).
func IsMinimal(teamID, templateID pgtype.UUID) bool {
return teamID.Bytes == id.PlatformTeamID.Bytes && templateID.Bytes == id.MinimalTemplateID.Bytes
}
// TemplateDir returns the on-disk directory for a template.
//
// minimal (zeros, zeros): {wrennDir}/images/minimal
// all others: {wrennDir}/images/teams/{base36(teamID)}/{base36(templateID)}
func TemplateDir(wrennDir string, teamID, templateID pgtype.UUID) string {
if IsMinimal(teamID, templateID) {
return filepath.Join(wrennDir, "images", "minimal")
}
return filepath.Join(wrennDir, "images", "teams",
id.UUIDToBase36(teamID.Bytes),
id.UUIDToBase36(templateID.Bytes))
}
// TemplateRootfs returns the path to a template's rootfs.ext4.
func TemplateRootfs(wrennDir string, teamID, templateID pgtype.UUID) string {
return filepath.Join(TemplateDir(wrennDir, teamID, templateID), "rootfs.ext4")
}
// PauseSnapshotDir returns the directory for a paused sandbox's snapshot files.
func PauseSnapshotDir(wrennDir, sandboxID string) string {
return filepath.Join(wrennDir, "snapshots", sandboxID)
}
// SandboxesDir returns the directory for running sandbox CoW files.
func SandboxesDir(wrennDir string) string {
return filepath.Join(wrennDir, "sandboxes")
}
// KernelPath returns the path to the Firecracker kernel.
func KernelPath(wrennDir string) string {
return filepath.Join(wrennDir, "kernels", "vmlinux")
}
// ImagesRoot returns the root images directory.
func ImagesRoot(wrennDir string) string {
return filepath.Join(wrennDir, "images")
}
// TeamsDir returns the directory containing all team template subdirectories.
func TeamsDir(wrennDir string) string {
return filepath.Join(wrennDir, "images", "teams")
}

View File

@ -0,0 +1,120 @@
package layout
import (
"path/filepath"
"testing"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/id"
)
func TestIsMinimal(t *testing.T) {
tests := []struct {
name string
teamID pgtype.UUID
templateID pgtype.UUID
want bool
}{
{
name: "both zeros",
teamID: id.PlatformTeamID,
templateID: id.MinimalTemplateID,
want: true,
},
{
name: "non-zero team",
teamID: pgtype.UUID{Bytes: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, Valid: true},
templateID: id.MinimalTemplateID,
want: false,
},
{
name: "non-zero template",
teamID: id.PlatformTeamID,
templateID: pgtype.UUID{Bytes: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, Valid: true},
want: false,
},
{
name: "both non-zero",
teamID: pgtype.UUID{Bytes: [16]byte{1}, Valid: true},
templateID: pgtype.UUID{Bytes: [16]byte{2}, Valid: true},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := IsMinimal(tt.teamID, tt.templateID); got != tt.want {
t.Errorf("IsMinimal() = %v, want %v", got, tt.want)
}
})
}
}
func TestTemplateDir(t *testing.T) {
wrennDir := "/var/lib/wrenn"
t.Run("minimal", func(t *testing.T) {
got := TemplateDir(wrennDir, id.PlatformTeamID, id.MinimalTemplateID)
want := filepath.Join(wrennDir, "images", "minimal")
if got != want {
t.Errorf("TemplateDir() = %q, want %q", got, want)
}
})
t.Run("team template", func(t *testing.T) {
teamID := pgtype.UUID{Bytes: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, Valid: true}
tmplID := pgtype.UUID{Bytes: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}, Valid: true}
got := TemplateDir(wrennDir, teamID, tmplID)
want := filepath.Join(wrennDir, "images", "teams",
id.UUIDToBase36(teamID.Bytes),
id.UUIDToBase36(tmplID.Bytes))
if got != want {
t.Errorf("TemplateDir() = %q, want %q", got, want)
}
})
t.Run("global template (platform team, non-zero template)", func(t *testing.T) {
tmplID := pgtype.UUID{Bytes: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5}, Valid: true}
got := TemplateDir(wrennDir, id.PlatformTeamID, tmplID)
want := filepath.Join(wrennDir, "images", "teams",
id.UUIDToBase36(id.PlatformTeamID.Bytes),
id.UUIDToBase36(tmplID.Bytes))
if got != want {
t.Errorf("TemplateDir() = %q, want %q", got, want)
}
})
}
func TestTemplateRootfs(t *testing.T) {
wrennDir := "/var/lib/wrenn"
got := TemplateRootfs(wrennDir, id.PlatformTeamID, id.MinimalTemplateID)
want := filepath.Join(wrennDir, "images", "minimal", "rootfs.ext4")
if got != want {
t.Errorf("TemplateRootfs() = %q, want %q", got, want)
}
}
func TestPauseSnapshotDir(t *testing.T) {
got := PauseSnapshotDir("/var/lib/wrenn", "cl-abc123")
want := "/var/lib/wrenn/snapshots/cl-abc123"
if got != want {
t.Errorf("PauseSnapshotDir() = %q, want %q", got, want)
}
}
func TestSandboxesDir(t *testing.T) {
got := SandboxesDir("/var/lib/wrenn")
want := "/var/lib/wrenn/sandboxes"
if got != want {
t.Errorf("SandboxesDir() = %q, want %q", got, want)
}
}
func TestKernelPath(t *testing.T) {
got := KernelPath("/var/lib/wrenn")
want := "/var/lib/wrenn/kernels/vmlinux"
if got != want {
t.Errorf("KernelPath() = %q, want %q", got, want)
}
}

View File

@ -1,6 +1,7 @@
package lifecycle
import (
"crypto/tls"
"fmt"
"net/http"
"strings"
@ -8,6 +9,7 @@ import (
"time"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
)
@ -18,14 +20,33 @@ type HostClientPool struct {
mu sync.RWMutex
clients map[string]hostagentv1connect.HostAgentServiceClient
httpClient *http.Client
scheme string // "http://" or "https://"
}
// NewHostClientPool creates a new pool. The underlying HTTP client uses a
// 10-minute timeout to support long-running streaming operations.
// NewHostClientPool creates a pool that connects to agents over plain HTTP.
// Use NewHostClientPoolTLS when mTLS is required.
func NewHostClientPool() *HostClientPool {
return &HostClientPool{
clients: make(map[string]hostagentv1connect.HostAgentServiceClient),
httpClient: &http.Client{Timeout: 10 * time.Minute},
scheme: "http://",
}
}
// NewHostClientPoolTLS creates a pool that connects to agents over mTLS.
// tlsCfg should already carry the CP client cert and CA trust anchor
// (use auth.CPClientTLSConfig to construct it).
func NewHostClientPoolTLS(tlsCfg *tls.Config) *HostClientPool {
transport := &http.Transport{
TLSClientConfig: tlsCfg,
}
return &HostClientPool{
clients: make(map[string]hostagentv1connect.HostAgentServiceClient),
httpClient: &http.Client{
Timeout: 10 * time.Minute,
Transport: transport,
},
scheme: "https://",
}
}
@ -45,7 +66,7 @@ func (p *HostClientPool) Get(hostID, address string) hostagentv1connect.HostAgen
if c, ok = p.clients[hostID]; ok {
return c
}
c = hostagentv1connect.NewHostAgentServiceClient(p.httpClient, ensureScheme(address))
c = hostagentv1connect.NewHostAgentServiceClient(p.httpClient, p.ensureScheme(address))
p.clients[hostID] = c
return c
}
@ -53,10 +74,10 @@ func (p *HostClientPool) Get(hostID, address string) hostagentv1connect.HostAgen
// GetForHost is a convenience wrapper that extracts the address from a db.Host
// and returns an error if the host has no address recorded yet.
func (p *HostClientPool) GetForHost(h db.Host) (hostagentv1connect.HostAgentServiceClient, error) {
if !h.Address.Valid || h.Address.String == "" {
return nil, fmt.Errorf("host %s has no address", h.ID)
if h.Address == "" {
return nil, fmt.Errorf("host %s has no address", id.FormatHostID(h.ID))
}
return p.Get(h.ID, h.Address.String), nil
return p.Get(id.FormatHostID(h.ID), h.Address), nil
}
// Evict removes the cached client for the given host, forcing a new client to be
@ -68,8 +89,35 @@ func (p *HostClientPool) Evict(hostID string) {
p.mu.Unlock()
}
// ensureScheme adds "http://" if the address has no scheme.
func ensureScheme(addr string) string {
// ensureScheme prepends the pool's configured scheme if the address has none.
func (p *HostClientPool) ensureScheme(addr string) string {
if strings.HasPrefix(addr, "http://") || strings.HasPrefix(addr, "https://") {
return addr
}
return p.scheme + addr
}
// Transport returns the http.RoundTripper used by this pool. Use this when you
// need to make raw HTTP requests to agent addresses with the same TLS settings
// as the pool's Connect RPC clients (e.g., the sandbox reverse proxy).
func (p *HostClientPool) Transport() http.RoundTripper {
if p.httpClient.Transport != nil {
return p.httpClient.Transport
}
return http.DefaultTransport
}
// ResolveAddr prepends the pool's configured scheme to addr if it has none.
// Use this when constructing URLs that must use the same transport as the pool
// (e.g., the sandbox proxy handler). Calling Get/GetForHost internally does
// the same thing, but ResolveAddr exposes it for callers that only need the URL.
func (p *HostClientPool) ResolveAddr(addr string) string {
return p.ensureScheme(addr)
}
// EnsureScheme adds "http://" if the address has no scheme.
// Deprecated: use pool.ResolveAddr which respects the pool's TLS setting.
func EnsureScheme(addr string) string {
if strings.HasPrefix(addr, "http://") || strings.HasPrefix(addr, "https://") {
return addr
}

Some files were not shown because too many files have changed in this diff Show More