1
0
forked from wrenn/wrenn
Co-authored-by: Tasnim Kabir Sadik <tksadik92@gmail.com>
Reviewed-on: wrenn/sandbox#8
This commit is contained in:
2026-04-09 19:24:49 +00:00
parent 32e5a5a715
commit d3e4812e46
199 changed files with 24552 additions and 2776 deletions

View File

@ -0,0 +1,22 @@
package api
import (
"context"
"fmt"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
)
// agentForHost resolves hostID to its host record and hands back the Connect
// RPC client that the pool associates with that host.
//
// A failed record lookup is reported as a wrapped "host not found" error; any
// error produced by the pool is returned to the caller unchanged.
func agentForHost(ctx context.Context, queries *db.Queries, pool *lifecycle.HostClientPool, hostID pgtype.UUID) (hostagentv1connect.HostAgentServiceClient, error) {
	record, lookupErr := queries.GetHost(ctx, hostID)
	if lookupErr == nil {
		return pool.GetForHost(record)
	}
	return nil, fmt.Errorf("host not found: %w", lookupErr)
}

View File

@ -0,0 +1,230 @@
package api
import (
"context"
"errors"
"fmt"
"log/slog"
"net/http"
"net/http/httputil"
"net/url"
"path"
"regexp"
"strconv"
"strings"
"sync"
"time"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
)
// Sentinel errors returned by proxyTarget, used to map to HTTP status codes
// without relying on error message text.
var (
// errProxySandboxNotFound is returned when the sandbox lookup query fails;
// ServeHTTP answers it with 404.
errProxySandboxNotFound = errors.New("sandbox not found")
// errProxyNoHostAddress is returned when the sandbox's host record has an
// empty address; ServeHTTP answers it with 503 via its default branch.
errProxyNoHostAddress = errors.New("host agent has no address")
)
// proxyCacheTTL is how long a resolved agent URL stays valid in the proxy cache
// before the next request re-resolves it from the database.
const proxyCacheTTL = 120 * time.Second
// sandboxHostPattern matches hostnames like "49999-cl-abcd1234.localhost" or
// "49999-cl-abcd1234.example.com". Captures: port, sandbox ID.
// The pattern is anchored at the start only, so anything may follow the first dot.
var sandboxHostPattern = regexp.MustCompile(`^(\d+)-(cl-[0-9a-z]+)\.`)
// errProxySandboxNotRunning is the error value used when a sandbox exists but
// is in some state other than running. It keeps the observed status as data so
// that handlers can match on the type instead of inspecting message text.
type errProxySandboxNotRunning struct {
	status string
}

// Error renders the message together with the recorded sandbox status.
func (notRunning errProxySandboxNotRunning) Error() string {
	message := fmt.Sprintf("sandbox is not running (status: %s)", notRunning.status)
	return message
}
// proxyCacheEntry caches the resolved agent URL for a (sandbox, team) pair.
// The *httputil.ReverseProxy is built per-request (cheap) so the Director closure
// can capture the correct port without the cache key needing to include it.
type proxyCacheEntry struct {
// agentURL is the parsed host agent address the sandbox is reached through.
agentURL *url.URL
// expiresAt is the instant from which the entry is no longer served;
// it is set to insertion time plus proxyCacheTTL.
expiresAt time.Time
}
// proxyCacheKey is the map key of the proxy cache: the raw bytes of the
// sandbox UUID followed by the raw bytes of the team UUID. Being a fixed-size
// array it is comparable and needs no string allocation.
type proxyCacheKey [32]byte

// makeProxyCacheKey concatenates the two UUIDs' bytes into a proxyCacheKey,
// sandbox first, team second.
func makeProxyCacheKey(sandboxID, teamID pgtype.UUID) proxyCacheKey {
	var key proxyCacheKey
	written := copy(key[:], sandboxID.Bytes[:])
	copy(key[written:], teamID.Bytes[:])
	return key
}
// SandboxProxyWrapper wraps an existing HTTP handler and intercepts requests
// whose Host header matches the {port}-{sandbox_id}.{domain} pattern. Matching
// requests are reverse-proxied through the host agent that owns the sandbox.
// All other requests are passed through to the inner handler.
//
// Authentication is via X-API-Key header only (no JWT). The API key's team
// must own the sandbox.
type SandboxProxyWrapper struct {
// inner receives every request whose Host does not match the sandbox pattern.
inner http.Handler
// db is used to validate API keys and to look up the sandbox's proxy target.
db *db.Queries
// pool resolves a host's stored address into a dialable agent address.
pool *lifecycle.HostClientPool
// transport is the pool's RoundTripper, captured once at construction and
// reused by every per-request reverse proxy.
transport http.RoundTripper
// cacheMu guards cache.
cacheMu sync.Mutex
// cache maps (sandbox, team) to the resolved agent URL and its expiry.
cache map[proxyCacheKey]proxyCacheEntry
}
// NewSandboxProxyWrapper builds a SandboxProxyWrapper in front of inner.
// The pool's transport is captured once here and the agent-URL cache starts
// out empty.
func NewSandboxProxyWrapper(inner http.Handler, queries *db.Queries, pool *lifecycle.HostClientPool) *SandboxProxyWrapper {
	wrapper := new(SandboxProxyWrapper)
	wrapper.inner = inner
	wrapper.db = queries
	wrapper.pool = pool
	wrapper.transport = pool.Transport()
	wrapper.cache = map[proxyCacheKey]proxyCacheEntry{}
	return wrapper
}
// proxyTarget returns the host agent URL through which the sandbox identified
// by (sandboxID, teamID) is reached.
//
// A fresh cache entry is returned directly. On a miss, or once the entry has
// expired, the sandbox is looked up in the DB, its host address is resolved
// through the pool, and the result is cached for proxyCacheTTL. An expired
// entry is removed as soon as it is seen, so a sandbox that has since been
// deleted or stopped does not leave a dead entry in the map forever.
//
// Errors:
//   - errProxySandboxNotFound when the DB lookup fails (any lookup error is
//     reported this way, not only "no rows")
//   - errProxySandboxNotRunning when the sandbox status is not "running"
//   - errProxyNoHostAddress when the host has no address recorded
//   - a wrapped parse error when the resolved address is not a valid URL
//
// The *httputil.ReverseProxy is built by the caller so the Director closure
// captures the correct port without the cache key needing to include it.
func (h *SandboxProxyWrapper) proxyTarget(ctx context.Context, sandboxID, teamID pgtype.UUID) (*url.URL, error) {
	cacheKey := makeProxyCacheKey(sandboxID, teamID)

	now := time.Now()
	h.cacheMu.Lock()
	entry, ok := h.cache[cacheKey]
	if ok && !now.Before(entry.expiresAt) {
		// Expired: drop it now. If the lookup below fails, nothing will
		// overwrite this key and the stale entry would otherwise never go away.
		delete(h.cache, cacheKey)
		ok = false
	}
	h.cacheMu.Unlock()
	if ok {
		return entry.agentURL, nil
	}

	// Cache miss or expired — query DB.
	target, err := h.db.GetSandboxProxyTarget(ctx, db.GetSandboxProxyTargetParams{
		ID:     sandboxID,
		TeamID: teamID,
	})
	if err != nil {
		return nil, errProxySandboxNotFound
	}
	if target.Status != "running" {
		return nil, errProxySandboxNotRunning{status: target.Status}
	}
	if target.HostAddress == "" {
		return nil, errProxyNoHostAddress
	}

	agentURL, err := url.Parse(h.pool.ResolveAddr(target.HostAddress))
	if err != nil {
		return nil, fmt.Errorf("invalid host agent address: %w", err)
	}

	h.cacheMu.Lock()
	h.cache[cacheKey] = proxyCacheEntry{
		agentURL:  agentURL,
		expiresAt: time.Now().Add(proxyCacheTTL),
	}
	h.cacheMu.Unlock()

	return agentURL, nil
}
// evictProxyCache forgets the cached agent URL for a (sandbox, team) pair, if
// one is present. The proxy error handler calls it when a round trip fails so
// that the next request resolves the sandbox's location again.
func (h *SandboxProxyWrapper) evictProxyCache(sandboxID, teamID pgtype.UUID) {
	staleKey := makeProxyCacheKey(sandboxID, teamID)

	h.cacheMu.Lock()
	defer h.cacheMu.Unlock()
	delete(h.cache, staleKey)
}
// ServeHTTP dispatches on the request's Host header. Hosts that do not match
// sandboxHostPattern are handed to the inner handler untouched. Matching hosts
// are authenticated with X-API-Key, resolved to the owning host agent, and
// reverse-proxied to /proxy/{sandbox_id}/{port}/{original path} on that agent.
//
// Responses produced here: 400 for an out-of-range port or a malformed sandbox
// ID, 401 when authentication fails, 404 when the sandbox lookup fails, 409
// when the sandbox is not running, 503 for other resolution errors, and 502
// when the round trip to the agent fails.
func (h *SandboxProxyWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	host := r.Host
	// Strip port from Host header (e.g. "49999-cl-abcd1234.localhost:8000" → "49999-cl-abcd1234.localhost")
	if colonIdx := strings.LastIndex(host, ":"); colonIdx != -1 {
		host = host[:colonIdx]
	}

	matches := sandboxHostPattern.FindStringSubmatch(host)
	if matches == nil {
		h.inner.ServeHTTP(w, r)
		return
	}

	port := matches[1]
	sandboxIDStr := matches[2]

	// Validate port.
	portNum, err := strconv.Atoi(port)
	if err != nil || portNum < 1 || portNum > 65535 {
		http.Error(w, "invalid port", http.StatusBadRequest)
		return
	}

	// Authenticate: an X-API-Key header is required (JWTs are not accepted on
	// this path); the key's team ID is what the sandbox lookup is scoped to.
	teamID, err := h.authenticateRequest(r)
	if err != nil {
		writeError(w, http.StatusUnauthorized, "unauthorized", err.Error())
		return
	}

	sandboxID, err := id.ParseSandboxID(sandboxIDStr)
	if err != nil {
		http.Error(w, "invalid sandbox ID", http.StatusBadRequest)
		return
	}

	agentURL, err := h.proxyTarget(r.Context(), sandboxID, teamID)
	if err != nil {
		switch {
		case errors.Is(err, errProxySandboxNotFound):
			http.Error(w, err.Error(), http.StatusNotFound)
		case errors.As(err, new(errProxySandboxNotRunning)):
			http.Error(w, err.Error(), http.StatusConflict)
		default:
			http.Error(w, err.Error(), http.StatusServiceUnavailable)
		}
		return
	}

	proxy := &httputil.ReverseProxy{
		Transport: h.transport,
		Director: func(req *http.Request) {
			cleaned := path.Clean("/" + req.URL.Path)
			upstreamPath := path.Join("/proxy", sandboxIDStr, port, cleaned)
			// path.Clean and path.Join both drop a trailing slash. Put it back
			// for non-root paths so "/dir/" and "/dir" stay distinct for the
			// service inside the sandbox.
			if cleaned != "/" && strings.HasSuffix(req.URL.Path, "/") {
				upstreamPath += "/"
			}

			req.URL.Scheme = agentURL.Scheme
			req.URL.Host = agentURL.Host
			req.URL.Path = upstreamPath
			// The inbound RawPath encodes the pre-rewrite path and no longer
			// corresponds to Path; clear it so the encoded form is derived
			// from the rewritten Path.
			req.URL.RawPath = ""
			req.Host = agentURL.Host
		},
		ErrorHandler: func(rw http.ResponseWriter, _ *http.Request, proxyErr error) {
			slog.Debug("sandbox proxy error",
				"sandbox_id", sandboxIDStr,
				"port", port,
				"error", proxyErr,
			)
			// The sandbox may have stopped or moved; force a fresh lookup next time.
			h.evictProxyCache(sandboxID, teamID)
			// The transport error can contain internal agent addresses, so it
			// is logged above and not echoed to the client.
			http.Error(rw, "proxy error", http.StatusBadGateway)
		},
	}
	proxy.ServeHTTP(w, r)
}
// authenticateRequest checks the X-API-Key header of r and yields the team ID
// stored with that key. A missing header and a key whose hash cannot be found
// both produce an error together with a zero-value UUID. JWTs are not
// considered for sandbox proxy requests.
func (h *SandboxProxyWrapper) authenticateRequest(r *http.Request) (pgtype.UUID, error) {
	var noTeam pgtype.UUID

	apiKey := r.Header.Get("X-API-Key")
	if len(apiKey) == 0 {
		return noTeam, fmt.Errorf("X-API-Key header required")
	}

	keyRow, lookupErr := h.db.GetAPIKeyByHash(r.Context(), auth.HashAPIKey(apiKey))
	if lookupErr != nil {
		return noTeam, fmt.Errorf("invalid API key")
	}
	return keyRow.TeamID, nil
}

View File

@ -6,17 +6,20 @@ import (
"github.com/go-chi/chi/v5"
"git.omukk.dev/wrenn/sandbox/internal/audit"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/service"
)
type apiKeyHandler struct {
svc *service.APIKeyService
svc *service.APIKeyService
audit *audit.AuditLogger
}
func newAPIKeyHandler(svc *service.APIKeyService) *apiKeyHandler {
return &apiKeyHandler{svc: svc}
func newAPIKeyHandler(svc *service.APIKeyService, al *audit.AuditLogger) *apiKeyHandler {
return &apiKeyHandler{svc: svc, audit: al}
}
type createAPIKeyRequest struct {
@ -37,11 +40,11 @@ type apiKeyResponse struct {
func apiKeyToResponse(k db.TeamApiKey) apiKeyResponse {
resp := apiKeyResponse{
ID: k.ID,
TeamID: k.TeamID,
ID: id.FormatAPIKeyID(k.ID),
TeamID: id.FormatTeamID(k.TeamID),
Name: k.Name,
KeyPrefix: k.KeyPrefix,
CreatedBy: k.CreatedBy,
CreatedBy: id.FormatUserID(k.CreatedBy),
}
if k.CreatedAt.Valid {
resp.CreatedAt = k.CreatedAt.Time.Format(time.RFC3339)
@ -55,11 +58,11 @@ func apiKeyToResponse(k db.TeamApiKey) apiKeyResponse {
func apiKeyWithCreatorToResponse(k db.ListAPIKeysByTeamWithCreatorRow) apiKeyResponse {
resp := apiKeyResponse{
ID: k.ID,
TeamID: k.TeamID,
ID: id.FormatAPIKeyID(k.ID),
TeamID: id.FormatTeamID(k.TeamID),
Name: k.Name,
KeyPrefix: k.KeyPrefix,
CreatedBy: k.CreatedBy,
CreatedBy: id.FormatUserID(k.CreatedBy),
CreatorEmail: k.CreatorEmail,
}
if k.CreatedAt.Valid {
@ -91,6 +94,7 @@ func (h *apiKeyHandler) Create(w http.ResponseWriter, r *http.Request) {
resp := apiKeyToResponse(result.Row)
resp.Key = &result.Plaintext
h.audit.LogAPIKeyCreate(r.Context(), ac, result.Row.ID, result.Row.Name)
writeJSON(w, http.StatusCreated, resp)
}
@ -115,12 +119,19 @@ func (h *apiKeyHandler) List(w http.ResponseWriter, r *http.Request) {
// Delete handles DELETE /v1/api-keys/{id}.
func (h *apiKeyHandler) Delete(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
keyID := chi.URLParam(r, "id")
keyIDStr := chi.URLParam(r, "id")
keyID, err := id.ParseAPIKeyID(keyIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid API key ID")
return
}
if err := h.svc.Delete(r.Context(), keyID, ac.TeamID); err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to delete API key")
return
}
h.audit.LogAPIKeyRevoke(r.Context(), ac, keyID)
w.WriteHeader(http.StatusNoContent)
}

View File

@ -0,0 +1,148 @@
package api
import (
"net/http"
"strconv"
"strings"
"time"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/service"
)
// auditHandler serves the audit-log HTTP endpoints on top of service.AuditService.
type auditHandler struct {
	svc *service.AuditService
}

// newAuditHandler returns an auditHandler backed by svc.
func newAuditHandler(svc *service.AuditService) *auditHandler {
	handler := new(auditHandler)
	handler.svc = svc
	return handler
}
// auditLogResponse is the JSON shape of one audit log entry as returned by
// GET /v1/audit-logs. The omitempty fields are left out of the output when
// they are empty.
type auditLogResponse struct {
ID string `json:"id"`
ActorType string `json:"actor_type"`
ActorID string `json:"actor_id,omitempty"`
ActorName string `json:"actor_name,omitempty"`
ResourceType string `json:"resource_type"`
ResourceID string `json:"resource_id,omitempty"`
Action string `json:"action"`
Scope string `json:"scope"`
Status string `json:"status"`
Metadata map[string]any `json:"metadata,omitempty"`
// CreatedAt is the entry timestamp, formatted in UTC as RFC3339.
CreatedAt string `json:"created_at"`
}
// List handles GET /v1/audit-logs.
// Query params:
//   - before: RFC3339 timestamp cursor (exclusive); omit to start from latest
//   - before_id: audit log ID cursor, passed to the service together with before
//   - limit: page size, default 50; values above 200 are clamped to 200
//   - resource_type: filter by resource type (sandbox, snapshot, team, api_key, member, host)
//   - action: filter by action verb
//
// resource_type and action may be repeated and/or comma-separated.
//
// The response body is {"items": [...]}. When at least one item is returned it
// also carries next_before and next_before_id, taken from the last item, to be
// sent back as before / before_id when requesting the following page.
//
// Members see only team-scoped events; admins/owners see all.
func (h *auditHandler) List(w http.ResponseWriter, r *http.Request) {
	// maxLimit is the documented upper bound on the page size.
	const maxLimit = 200

	ac := auth.MustFromContext(r.Context())
	query := r.URL.Query()

	// Parse ?before cursor.
	var before time.Time
	if s := query.Get("before"); s != "" {
		var err error
		before, err = time.Parse(time.RFC3339, s)
		if err != nil {
			writeError(w, http.StatusBadRequest, "invalid_request", "before must be an RFC3339 timestamp")
			return
		}
	}

	// Parse ?limit.
	limit := 50
	if s := query.Get("limit"); s != "" {
		n, err := strconv.Atoi(s)
		if err != nil || n < 1 {
			writeError(w, http.StatusBadRequest, "invalid_request", "limit must be a positive integer")
			return
		}
		// Clamp rather than reject so oversized requests keep working but
		// cannot ask for an unbounded page.
		if n > maxLimit {
			n = maxLimit
		}
		limit = n
	}

	// Parse ?before_id cursor (UUID).
	var beforeID pgtype.UUID
	if s := query.Get("before_id"); s != "" {
		parsed, err := id.ParseAuditLogID(s)
		if err != nil {
			writeError(w, http.StatusBadRequest, "invalid_request", "before_id must be a valid audit log ID")
			return
		}
		beforeID = parsed
	}

	entries, err := h.svc.List(r.Context(), service.AuditListParams{
		TeamID:        ac.TeamID,
		AdminScoped:   ac.Role == "owner" || ac.Role == "admin",
		ResourceTypes: parseMultiParam(query["resource_type"]),
		Actions:       parseMultiParam(query["action"]),
		Before:        before,
		BeforeID:      beforeID,
		Limit:         limit,
	})
	if err != nil {
		writeError(w, http.StatusInternalServerError, "db_error", "failed to list audit logs")
		return
	}

	items := make([]auditLogResponse, len(entries))
	for i, e := range entries {
		items[i] = auditLogResponse{
			ID:           e.ID,
			ActorType:    e.ActorType,
			ActorID:      e.ActorID,
			ActorName:    e.ActorName,
			ResourceType: e.ResourceType,
			ResourceID:   e.ResourceID,
			Action:       e.Action,
			Scope:        e.Scope,
			Status:       e.Status,
			Metadata:     e.Metadata,
			CreatedAt:    e.CreatedAt.UTC().Format(time.RFC3339),
		}
	}

	resp := map[string]any{"items": items}
	if len(items) > 0 {
		// Cursor for the next page comes from the last entry of this page.
		last := entries[len(entries)-1]
		resp["next_before"] = last.CreatedAt.UTC().Format(time.RFC3339)
		resp["next_before_id"] = last.ID
	}
	writeJSON(w, http.StatusOK, resp)
}
// parseMultiParam turns the raw values of a multi-valued query parameter into
// a flat list of distinct tokens, in first-seen order. Each value may itself
// be a comma-separated list; tokens are trimmed of surrounding whitespace and
// blank tokens are ignored. The result is nil (meaning "no filter") when
// nothing usable was supplied.
//
// Both ?resource_type=sandbox&resource_type=snapshot
// and ?resource_type=sandbox,snapshot are accepted.
func parseMultiParam(values []string) []string {
	if len(values) == 0 {
		return nil
	}

	var tokens []string
	known := map[string]struct{}{}
	for _, raw := range values {
		rest, more := raw, true
		for more {
			var token string
			token, rest, more = strings.Cut(rest, ",")
			token = strings.TrimSpace(token)
			if token == "" {
				continue
			}
			if _, dup := known[token]; dup {
				continue
			}
			known[token] = struct{}{}
			tokens = append(tokens, token)
		}
	}
	return tokens
}

View File

@ -1,7 +1,9 @@
package api
import (
"context"
"errors"
"log/slog"
"net/http"
"strings"
@ -15,6 +17,45 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/id"
)
// loginTeam picks the team and role that a login JWT is issued for.
//
// The user's default team wins when one exists; its role is read from the
// membership row. When no default team is found, the first row returned by
// GetTeamsForUser is used instead (expected to be the earliest-joined team).
// pgx.ErrNoRows is returned when the user has no team memberships at all; any
// other query error is returned as is.
func loginTeam(ctx context.Context, q *db.Queries, userID pgtype.UUID) (db.Team, string, error) {
	defaultTeam, defaultErr := q.GetDefaultTeamForUser(ctx, userID)
	switch {
	case defaultErr == nil:
		member, memberErr := q.GetTeamMembership(ctx, db.GetTeamMembershipParams{UserID: userID, TeamID: defaultTeam.ID})
		if memberErr != nil {
			return db.Team{}, "", memberErr
		}
		return defaultTeam, member.Role, nil
	case !errors.Is(defaultErr, pgx.ErrNoRows):
		return db.Team{}, "", defaultErr
	}

	// No default set — fall back to the first team in the user's list.
	memberships, listErr := q.GetTeamsForUser(ctx, userID)
	if listErr != nil {
		return db.Team{}, "", listErr
	}
	if len(memberships) < 1 {
		return db.Team{}, "", pgx.ErrNoRows
	}

	earliest := memberships[0]
	fallback := db.Team{
		ID:        earliest.ID,
		Name:      earliest.Name,
		Slug:      earliest.Slug,
		IsByoc:    earliest.IsByoc,
		CreatedAt: earliest.CreatedAt,
		DeletedAt: earliest.DeletedAt,
	}
	return fallback, earliest.Role, nil
}
// switchTeamRequest is the JSON body of POST /v1/auth/switch-team.
type switchTeamRequest struct {
// TeamID is the formatted ID of the team to switch to; it is required.
TeamID string `json:"team_id"`
}
type authHandler struct {
db *db.Queries
pool *pgxpool.Pool
@ -28,6 +69,7 @@ func newAuthHandler(db *db.Queries, pool *pgxpool.Pool, jwtSecret []byte) *authH
type signupRequest struct {
Email string `json:"email"`
Password string `json:"password"`
Name string `json:"name"`
}
type loginRequest struct {
@ -40,6 +82,7 @@ type authResponse struct {
UserID string `json:"user_id"`
TeamID string `json:"team_id"`
Email string `json:"email"`
Name string `json:"name"`
}
// Signup handles POST /v1/auth/signup.
@ -51,6 +94,7 @@ func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) {
}
req.Email = strings.TrimSpace(strings.ToLower(req.Email))
req.Name = strings.TrimSpace(req.Name)
if !strings.Contains(req.Email, "@") || len(req.Email) < 3 {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid email address")
return
@ -59,6 +103,10 @@ func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusBadRequest, "invalid_request", "password must be at least 8 characters")
return
}
if req.Name == "" || len(req.Name) > 100 {
writeError(w, http.StatusBadRequest, "invalid_request", "name must be between 1 and 100 characters")
return
}
ctx := r.Context()
@ -83,6 +131,7 @@ func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) {
ID: userID,
Email: req.Email,
PasswordHash: pgtype.Text{String: passwordHash, Valid: true},
Name: req.Name,
})
if err != nil {
var pgErr *pgconn.PgError
@ -98,7 +147,8 @@ func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) {
teamID := id.NewTeamID()
if _, err := qtx.InsertTeam(ctx, db.InsertTeamParams{
ID: teamID,
Name: req.Email + "'s Team",
Name: req.Name + "'s Team",
Slug: id.NewTeamSlug(),
}); err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to create team")
return
@ -119,7 +169,7 @@ func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) {
return
}
token, err := auth.SignJWT(h.jwtSecret, userID, teamID, req.Email)
token, err := auth.SignJWT(h.jwtSecret, userID, teamID, req.Email, req.Name, "owner", false)
if err != nil {
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
return
@ -127,9 +177,10 @@ func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusCreated, authResponse{
Token: token,
UserID: userID,
TeamID: teamID,
UserID: id.FormatUserID(userID),
TeamID: id.FormatTeamID(teamID),
Email: req.Email,
Name: req.Name,
})
}
@ -152,6 +203,7 @@ func (h *authHandler) Login(w http.ResponseWriter, r *http.Request) {
user, err := h.db.GetUserByEmail(ctx, req.Email)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
slog.Warn("login failed: unknown email", "email", req.Email, "ip", r.RemoteAddr)
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid email or password")
return
}
@ -160,21 +212,27 @@ func (h *authHandler) Login(w http.ResponseWriter, r *http.Request) {
}
if !user.PasswordHash.Valid {
slog.Warn("login failed: no password set", "email", req.Email, "ip", r.RemoteAddr)
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid email or password")
return
}
if err := auth.CheckPassword(user.PasswordHash.String, req.Password); err != nil {
slog.Warn("login failed: wrong password", "email", req.Email, "ip", r.RemoteAddr)
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid email or password")
return
}
team, err := h.db.GetDefaultTeamForUser(ctx, user.ID)
team, role, err := loginTeam(ctx, h.db, user.ID)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
writeError(w, http.StatusForbidden, "no_team", "user is not a member of any team")
return
}
writeError(w, http.StatusInternalServerError, "db_error", "failed to look up team")
return
}
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email)
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, user.IsAdmin)
if err != nil {
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
return
@ -182,8 +240,85 @@ func (h *authHandler) Login(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, authResponse{
Token: token,
UserID: user.ID,
TeamID: team.ID,
UserID: id.FormatUserID(user.ID),
TeamID: id.FormatTeamID(team.ID),
Email: user.Email,
Name: user.Name,
})
}
// SwitchTeam handles POST /v1/auth/switch-team.
//
// The handler confirms against the DB that the target team exists, is not
// deleted, and that the caller is a member of it, then issues a new JWT scoped
// to that team. The JWT's team_id acts as a pre-filter on later team-scoped
// requests; the DB stays the source of truth for actual permissions.
func (h *authHandler) SwitchTeam(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	ac := auth.MustFromContext(ctx)

	var body switchTeamRequest
	if decodeErr := decodeJSON(r, &body); decodeErr != nil {
		writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
		return
	}
	if body.TeamID == "" {
		writeError(w, http.StatusBadRequest, "invalid_request", "team_id is required")
		return
	}
	targetTeamID, parseErr := id.ParseTeamID(body.TeamID)
	if parseErr != nil {
		writeError(w, http.StatusBadRequest, "invalid_request", "invalid team_id")
		return
	}

	// The team must exist and must not be soft-deleted.
	target, teamErr := h.db.GetTeam(ctx, targetTeamID)
	switch {
	case errors.Is(teamErr, pgx.ErrNoRows):
		writeError(w, http.StatusNotFound, "not_found", "team not found")
		return
	case teamErr != nil:
		writeError(w, http.StatusInternalServerError, "db_error", "failed to look up team")
		return
	case target.DeletedAt.Valid:
		writeError(w, http.StatusNotFound, "not_found", "team not found")
		return
	}

	// Membership (and therefore role) comes from the DB, never from the JWT.
	member, memberErr := h.db.GetTeamMembership(ctx, db.GetTeamMembershipParams{
		UserID: ac.UserID,
		TeamID: targetTeamID,
	})
	switch {
	case errors.Is(memberErr, pgx.ErrNoRows):
		writeError(w, http.StatusForbidden, "forbidden", "not a member of this team")
		return
	case memberErr != nil:
		writeError(w, http.StatusInternalServerError, "db_error", "failed to look up membership")
		return
	}

	// The display name is re-read from the DB because the one inside the JWT
	// may be stale, or empty for tokens issued before names existed.
	account, userErr := h.db.GetUserByID(ctx, ac.UserID)
	if userErr != nil {
		writeError(w, http.StatusInternalServerError, "db_error", "failed to look up user")
		return
	}

	signed, signErr := auth.SignJWT(h.jwtSecret, ac.UserID, targetTeamID, ac.Email, account.Name, member.Role, account.IsAdmin)
	if signErr != nil {
		writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
		return
	}

	writeJSON(w, http.StatusOK, authResponse{
		Token:  signed,
		UserID: id.FormatUserID(ac.UserID),
		TeamID: id.FormatTeamID(targetTeamID),
		Email:  ac.Email,
		Name:   account.Name,
	})
}

View File

@ -0,0 +1,276 @@
package api
import (
"encoding/json"
"fmt"
"log/slog"
"net/http"
"time"
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/layout"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
"git.omukk.dev/wrenn/sandbox/internal/service"
"git.omukk.dev/wrenn/sandbox/internal/validate"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
)
// buildHandler serves the admin build and template endpoints. It uses the
// build service for build records, the query layer for templates and hosts,
// and the host client pool to reach host agents.
type buildHandler struct {
	svc  *service.BuildService
	db   *db.Queries
	pool *lifecycle.HostClientPool
}

// newBuildHandler assembles a buildHandler from its three dependencies.
func newBuildHandler(svc *service.BuildService, db *db.Queries, pool *lifecycle.HostClientPool) *buildHandler {
	handler := new(buildHandler)
	handler.svc = svc
	handler.db = db
	handler.pool = pool
	return handler
}
// createBuildRequest is the JSON body of POST /v1/admin/builds.
// Name is required and must pass validate.SafeName; Recipe must contain at
// least one command. The remaining fields are forwarded to the build service
// unchanged.
type createBuildRequest struct {
Name string `json:"name"`
BaseTemplate string `json:"base_template"`
Recipe []string `json:"recipe"`
Healthcheck string `json:"healthcheck"`
VCPUs int32 `json:"vcpus"`
MemoryMB int32 `json:"memory_mb"`
SkipPrePost bool `json:"skip_pre_post"`
}
// buildResponse is the JSON representation of a template build returned by the
// admin build endpoints. Pointer fields are optional: buildToResponse leaves
// them nil (and so omitted from the JSON) when the underlying value is empty
// or not set.
type buildResponse struct {
ID string `json:"id"`
Name string `json:"name"`
BaseTemplate string `json:"base_template"`
// Recipe is passed through as stored, without being re-encoded.
Recipe json.RawMessage `json:"recipe"`
Healthcheck *string `json:"healthcheck,omitempty"`
VCPUs int32 `json:"vcpus"`
MemoryMB int32 `json:"memory_mb"`
Status string `json:"status"`
CurrentStep int32 `json:"current_step"`
TotalSteps int32 `json:"total_steps"`
// Logs is passed through as stored, without being re-encoded.
Logs json.RawMessage `json:"logs"`
Error *string `json:"error,omitempty"`
SandboxID *string `json:"sandbox_id,omitempty"`
HostID *string `json:"host_id,omitempty"`
// CreatedAt, StartedAt and CompletedAt are RFC3339 timestamps.
CreatedAt string `json:"created_at"`
StartedAt *string `json:"started_at,omitempty"`
CompletedAt *string `json:"completed_at,omitempty"`
}
// buildToResponse converts a stored template build into its API shape.
// IDs are rendered with the id package's formatters, timestamps as RFC3339,
// and optional values (healthcheck, error, sandbox/host ID, start/completion
// time) are only populated when present on the record.
func buildToResponse(b db.TemplateBuild) buildResponse {
	var out buildResponse

	// Always-present fields.
	out.ID = id.FormatBuildID(b.ID)
	out.Name = b.Name
	out.BaseTemplate = b.BaseTemplate
	out.Recipe = b.Recipe
	out.VCPUs = b.Vcpus
	out.MemoryMB = b.MemoryMb
	out.Status = b.Status
	out.CurrentStep = b.CurrentStep
	out.TotalSteps = b.TotalSteps
	out.Logs = b.Logs

	// Optional strings: empty means "not set".
	if healthcheck := b.Healthcheck; healthcheck != "" {
		out.Healthcheck = &healthcheck
	}
	if buildErr := b.Error; buildErr != "" {
		out.Error = &buildErr
	}

	// Optional references.
	if b.SandboxID.Valid {
		sandboxID := id.FormatSandboxID(b.SandboxID)
		out.SandboxID = &sandboxID
	}
	if b.HostID.Valid {
		hostID := id.FormatHostID(b.HostID)
		out.HostID = &hostID
	}

	// Timestamps.
	if b.CreatedAt.Valid {
		out.CreatedAt = b.CreatedAt.Time.Format(time.RFC3339)
	}
	if b.StartedAt.Valid {
		started := b.StartedAt.Time.Format(time.RFC3339)
		out.StartedAt = &started
	}
	if b.CompletedAt.Valid {
		completed := b.CompletedAt.Time.Format(time.RFC3339)
		out.CompletedAt = &completed
	}
	return out
}
// Create handles POST /v1/admin/builds.
// The body must be JSON with a non-empty, safe template name and at least one
// recipe command. On success the new build is returned with status 201.
func (h *buildHandler) Create(w http.ResponseWriter, r *http.Request) {
	// reject answers with a 400 invalid_request carrying msg.
	reject := func(msg string) {
		writeError(w, http.StatusBadRequest, "invalid_request", msg)
	}

	var body createBuildRequest
	if decodeErr := json.NewDecoder(r.Body).Decode(&body); decodeErr != nil {
		reject("invalid JSON body")
		return
	}

	if body.Name == "" {
		reject("name is required")
		return
	}
	if nameErr := validate.SafeName(body.Name); nameErr != nil {
		reject(fmt.Sprintf("invalid template name: %s", nameErr))
		return
	}
	if len(body.Recipe) == 0 {
		reject("recipe must contain at least one command")
		return
	}

	params := service.BuildCreateParams{
		Name:         body.Name,
		BaseTemplate: body.BaseTemplate,
		Recipe:       body.Recipe,
		Healthcheck:  body.Healthcheck,
		VCPUs:        body.VCPUs,
		MemoryMB:     body.MemoryMB,
		SkipPrePost:  body.SkipPrePost,
	}
	created, createErr := h.svc.Create(r.Context(), params)
	if createErr != nil {
		slog.Error("failed to create build", "error", createErr)
		writeError(w, http.StatusInternalServerError, "build_error", "failed to create build")
		return
	}

	writeJSON(w, http.StatusCreated, buildToResponse(created))
}
// List handles GET /v1/admin/builds and returns every build known to the
// build service as a JSON array.
func (h *buildHandler) List(w http.ResponseWriter, r *http.Request) {
	all, listErr := h.svc.List(r.Context())
	if listErr != nil {
		writeError(w, http.StatusInternalServerError, "db_error", "failed to list builds")
		return
	}

	out := make([]buildResponse, 0, len(all))
	for _, build := range all {
		out = append(out, buildToResponse(build))
	}
	writeJSON(w, http.StatusOK, out)
}
// Get handles GET /v1/admin/builds/{id}.
// A malformed ID yields 400; any failure to load the build yields 404.
func (h *buildHandler) Get(w http.ResponseWriter, r *http.Request) {
	buildID, parseErr := id.ParseBuildID(chi.URLParam(r, "id"))
	if parseErr != nil {
		writeError(w, http.StatusBadRequest, "invalid_request", "invalid build ID")
		return
	}

	found, getErr := h.svc.Get(r.Context(), buildID)
	if getErr != nil {
		writeError(w, http.StatusNotFound, "not_found", "build not found")
		return
	}
	writeJSON(w, http.StatusOK, buildToResponse(found))
}
// ListTemplates handles GET /v1/admin/templates — returns all templates across all teams.
// Each element carries the template's name, type, sizing, owning team ID and,
// when set, its creation time in RFC3339.
func (h *buildHandler) ListTemplates(w http.ResponseWriter, r *http.Request) {
	type templateResponse struct {
		Name      string `json:"name"`
		Type      string `json:"type"`
		VCPUs     int32  `json:"vcpus"`
		MemoryMB  int32  `json:"memory_mb"`
		SizeBytes int64  `json:"size_bytes"`
		TeamID    string `json:"team_id"`
		CreatedAt string `json:"created_at"`
	}

	rows, listErr := h.db.ListTemplates(r.Context())
	if listErr != nil {
		writeError(w, http.StatusInternalServerError, "db_error", "failed to list templates")
		return
	}

	out := make([]templateResponse, 0, len(rows))
	for _, row := range rows {
		item := templateResponse{
			Name:      row.Name,
			Type:      row.Type,
			VCPUs:     row.Vcpus,
			MemoryMB:  row.MemoryMb,
			SizeBytes: row.SizeBytes,
			TeamID:    id.FormatTeamID(row.TeamID),
		}
		if row.CreatedAt.Valid {
			item.CreatedAt = row.CreatedAt.Time.Format(time.RFC3339)
		}
		out = append(out, item)
	}
	writeJSON(w, http.StatusOK, out)
}
// DeleteTemplate handles DELETE /v1/admin/templates/{name}.
//
// The name must pass validate.SafeName (400 otherwise) and must resolve to a
// platform template (404 otherwise). The minimal template is protected and
// answers 403. Deletion is then broadcast to every online host on a
// best-effort basis: failures to list hosts, to obtain a host's agent client,
// or to delete on a host are logged and do not abort the request. Finally the
// template record is removed from the DB and 204 is returned.
func (h *buildHandler) DeleteTemplate(w http.ResponseWriter, r *http.Request) {
	name := chi.URLParam(r, "name")
	if err := validate.SafeName(name); err != nil {
		writeError(w, http.StatusBadRequest, "invalid_request", fmt.Sprintf("invalid template name: %s", err))
		return
	}
	ctx := r.Context()

	tmpl, err := h.db.GetPlatformTemplateByName(ctx, name)
	if err != nil {
		writeError(w, http.StatusNotFound, "not_found", "template not found")
		return
	}
	if layout.IsMinimal(tmpl.TeamID, tmpl.ID) {
		writeError(w, http.StatusForbidden, "forbidden", "the minimal template cannot be deleted")
		return
	}

	// Broadcast delete to all online hosts.
	hosts, err := h.db.ListActiveHosts(ctx)
	if err != nil {
		// Still best-effort, but make it visible: with no host list the
		// template files stay on every host after the record is gone.
		slog.Warn("admin: failed to list hosts for template delete", "name", name, "error", err)
	}
	for _, host := range hosts {
		if host.Status != "online" {
			continue
		}
		agent, err := h.pool.GetForHost(host)
		if err != nil {
			slog.Warn("admin: no agent client for host, skipping template delete", "host_id", id.FormatHostID(host.ID), "name", name, "error", err)
			continue
		}
		if _, err := agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{
			TeamId:     formatUUIDForRPC(tmpl.TeamID),
			TemplateId: formatUUIDForRPC(tmpl.ID),
		})); err != nil {
			// A host that never had the template is not worth a warning.
			if connect.CodeOf(err) != connect.CodeNotFound {
				slog.Warn("admin: failed to delete template on host", "host_id", id.FormatHostID(host.ID), "name", name, "error", err)
			}
		}
	}

	if err := h.db.DeleteTemplate(ctx, tmpl.ID); err != nil {
		writeError(w, http.StatusInternalServerError, "db_error", "failed to delete template record")
		return
	}
	w.WriteHeader(http.StatusNoContent)
}
// Cancel handles POST /v1/admin/builds/{id}/cancel.
// A malformed ID and a cancellation refused by the service both answer 400
// (the latter with the service's error text); success answers 204.
func (h *buildHandler) Cancel(w http.ResponseWriter, r *http.Request) {
	buildID, parseErr := id.ParseBuildID(chi.URLParam(r, "id"))
	if parseErr != nil {
		writeError(w, http.StatusBadRequest, "invalid_request", "invalid build ID")
		return
	}

	cancelErr := h.svc.Cancel(r.Context(), buildID)
	if cancelErr != nil {
		writeError(w, http.StatusBadRequest, "invalid_request", cancelErr.Error())
		return
	}
	w.WriteHeader(http.StatusNoContent)
}

View File

@ -0,0 +1,242 @@
package api
import (
"errors"
"net/http"
"time"
"github.com/go-chi/chi/v5"
"github.com/jackc/pgx/v5"
"git.omukk.dev/wrenn/sandbox/internal/audit"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/channels"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
)
// channelHandler serves the /v1/channels endpoints. Channel operations go
// through channels.Service; mutating operations are recorded via the audit logger.
type channelHandler struct {
	svc   *channels.Service
	audit *audit.AuditLogger
}

// newChannelHandler returns a channelHandler using svc and the audit logger al.
func newChannelHandler(svc *channels.Service, al *audit.AuditLogger) *channelHandler {
	handler := new(channelHandler)
	handler.svc = svc
	handler.audit = al
	return handler
}
// createChannelRequest is the JSON body of POST /v1/channels. All four fields
// are forwarded to the channel service as given.
type createChannelRequest struct {
Name string `json:"name"`
Provider string `json:"provider"`
Config map[string]string `json:"config"`
Events []string `json:"events"`
}
// updateChannelRequest is the JSON body of PATCH /v1/channels/{id}; Name and
// Events are passed to the channel service's Update.
type updateChannelRequest struct {
Name string `json:"name"`
Events []string `json:"events"`
}
// rotateConfigRequest carries a replacement provider config for a channel.
// NOTE(review): the handler that decodes it is not visible in this part of the
// file — presumably a config-rotation endpoint; confirm against its caller.
type rotateConfigRequest struct {
Config map[string]string `json:"config"`
}
// testChannelRequest is the JSON body of POST /v1/channels/test: a provider
// and config supplied directly in the request.
type testChannelRequest struct {
Provider string `json:"provider"`
Config map[string]string `json:"config"`
}
// channelResponse is the JSON representation of a channel returned by the
// channel endpoints.
type channelResponse struct {
ID string `json:"id"`
TeamID string `json:"team_id"`
Name string `json:"name"`
Provider string `json:"provider"`
Events []string `json:"events"`
// CreatedAt and UpdatedAt are RFC3339 timestamps; they stay empty when the
// stored value is not valid.
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
// Secret is only set by Create, and only when the service returned a
// plaintext secret; otherwise it is omitted.
Secret *string `json:"secret,omitempty"`
}
// channelToResponse converts a stored channel into its API shape: IDs are
// formatted via the id package and valid timestamps are rendered as RFC3339.
// The Secret field is never populated here.
func channelToResponse(ch db.Channel) channelResponse {
	var out channelResponse
	out.ID = id.FormatChannelID(ch.ID)
	out.TeamID = id.FormatTeamID(ch.TeamID)
	out.Name = ch.Name
	out.Provider = ch.Provider
	out.Events = ch.EventTypes

	if created := ch.CreatedAt; created.Valid {
		out.CreatedAt = created.Time.Format(time.RFC3339)
	}
	if updated := ch.UpdatedAt; updated.Valid {
		out.UpdatedAt = updated.Time.Format(time.RFC3339)
	}
	return out
}
// Create handles POST /v1/channels.
// The channel is created for the caller's team, the creation is audit-logged,
// and the new channel is returned with status 201. If the service produced a
// plaintext secret it is included in this response.
func (h *channelHandler) Create(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	ac := auth.MustFromContext(ctx)

	var body createChannelRequest
	if decodeErr := decodeJSON(r, &body); decodeErr != nil {
		writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
		return
	}

	params := channels.CreateParams{
		TeamID:   ac.TeamID,
		Name:     body.Name,
		Provider: body.Provider,
		Config:   body.Config,
		Events:   body.Events,
	}
	created, createErr := h.svc.Create(ctx, params)
	if createErr != nil {
		httpStatus, errCode, errMsg := serviceErrToHTTP(createErr)
		writeError(w, httpStatus, errCode, errMsg)
		return
	}

	h.audit.LogChannelCreate(ctx, ac, created.Channel.ID, created.Channel.Name, created.Channel.Provider)

	out := channelToResponse(created.Channel)
	if secret := created.PlaintextSecret; secret != "" {
		out.Secret = &secret
	}
	writeJSON(w, http.StatusCreated, out)
}
// List handles GET /v1/channels and returns the caller's team's channels as a
// JSON array.
func (h *channelHandler) List(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	ac := auth.MustFromContext(ctx)

	teamChannels, listErr := h.svc.List(ctx, ac.TeamID)
	if listErr != nil {
		writeError(w, http.StatusInternalServerError, "db_error", "failed to list channels")
		return
	}

	out := make([]channelResponse, 0, len(teamChannels))
	for _, channel := range teamChannels {
		out = append(out, channelToResponse(channel))
	}
	writeJSON(w, http.StatusOK, out)
}
// Get handles GET /v1/channels/{id}.
//
// Responds 400 when the path ID does not parse, 404 when the service reports
// pgx.ErrNoRows, 500 for any other service error, and 200 with the channel
// otherwise.
func (h *channelHandler) Get(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	ac := auth.MustFromContext(ctx)

	channelID, err := id.ParseChannelID(chi.URLParam(r, "id"))
	if err != nil {
		writeError(w, http.StatusBadRequest, "invalid_request", "invalid channel ID")
		return
	}

	row, err := h.svc.Get(ctx, channelID, ac.TeamID)
	switch {
	case err == nil:
		writeJSON(w, http.StatusOK, channelToResponse(row))
	case errors.Is(err, pgx.ErrNoRows):
		writeError(w, http.StatusNotFound, "not_found", "channel not found")
	default:
		writeError(w, http.StatusInternalServerError, "db_error", "failed to get channel")
	}
}
// Update handles PATCH /v1/channels/{id}.
//
// It parses the channel ID from the path, decodes the JSON body, and passes
// the name and events to the channel service. Service errors are translated
// with serviceErrToHTTP. On success an audit entry is recorded and the
// updated channel is returned with 200.
func (h *channelHandler) Update(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	ac := auth.MustFromContext(ctx)

	channelID, err := id.ParseChannelID(chi.URLParam(r, "id"))
	if err != nil {
		writeError(w, http.StatusBadRequest, "invalid_request", "invalid channel ID")
		return
	}

	var body updateChannelRequest
	if decodeErr := decodeJSON(r, &body); decodeErr != nil {
		writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
		return
	}

	updated, err := h.svc.Update(ctx, channelID, ac.TeamID, body.Name, body.Events)
	if err != nil {
		status, code, msg := serviceErrToHTTP(err)
		writeError(w, status, code, msg)
		return
	}

	h.audit.LogChannelUpdate(ctx, ac, channelID)
	writeJSON(w, http.StatusOK, channelToResponse(updated))
}
// Test handles POST /v1/channels/test.
//
// It decodes a provider and config from the JSON body and hands them to the
// channel service's Test method. Service errors are translated with
// serviceErrToHTTP; success yields 200 with {"status": "ok"}.
func (h *channelHandler) Test(w http.ResponseWriter, r *http.Request) {
	var body testChannelRequest
	if decodeErr := decodeJSON(r, &body); decodeErr != nil {
		writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
		return
	}

	testErr := h.svc.Test(r.Context(), body.Provider, body.Config)
	if testErr != nil {
		status, code, msg := serviceErrToHTTP(testErr)
		writeError(w, status, code, msg)
		return
	}

	ok := map[string]string{"status": "ok"}
	writeJSON(w, http.StatusOK, ok)
}
// RotateConfig handles PUT /v1/channels/{id}/config.
//
// It parses the channel ID from the path, decodes the replacement config
// from the JSON body, and passes it to the channel service. Service errors
// are translated with serviceErrToHTTP. On success an audit entry is
// recorded and the channel is returned with 200.
func (h *channelHandler) RotateConfig(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	ac := auth.MustFromContext(ctx)

	channelID, err := id.ParseChannelID(chi.URLParam(r, "id"))
	if err != nil {
		writeError(w, http.StatusBadRequest, "invalid_request", "invalid channel ID")
		return
	}

	var body rotateConfigRequest
	if decodeErr := decodeJSON(r, &body); decodeErr != nil {
		writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
		return
	}

	rotated, err := h.svc.RotateConfig(ctx, channelID, ac.TeamID, body.Config)
	if err != nil {
		status, code, msg := serviceErrToHTTP(err)
		writeError(w, status, code, msg)
		return
	}

	h.audit.LogChannelRotateConfig(ctx, ac, channelID)
	writeJSON(w, http.StatusOK, channelToResponse(rotated))
}
// Delete handles DELETE /v1/channels/{id}.
//
// Responds 400 when the path ID does not parse, 404 when the service reports
// pgx.ErrNoRows (the same mapping the Get handler applies), 500 for any
// other service error, and 204 after a successful delete. The audit entry is
// written only on success.
func (h *channelHandler) Delete(w http.ResponseWriter, r *http.Request) {
	ac := auth.MustFromContext(r.Context())

	channelID, err := id.ParseChannelID(chi.URLParam(r, "id"))
	if err != nil {
		writeError(w, http.StatusBadRequest, "invalid_request", "invalid channel ID")
		return
	}

	if err := h.svc.Delete(r.Context(), channelID, ac.TeamID); err != nil {
		// A missing channel is a client-side condition, not a server fault.
		// NOTE(review): whether svc.Delete ever surfaces pgx.ErrNoRows is not
		// visible from this file — confirm against the service.
		if errors.Is(err, pgx.ErrNoRows) {
			writeError(w, http.StatusNotFound, "not_found", "channel not found")
			return
		}
		writeError(w, http.StatusInternalServerError, "db_error", "failed to delete channel")
		return
	}

	h.audit.LogChannelDelete(r.Context(), ac, channelID)
	w.WriteHeader(http.StatusNoContent)
}

View File

@ -14,17 +14,18 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
)
type execHandler struct {
db *db.Queries
agent hostagentv1connect.HostAgentServiceClient
db *db.Queries
pool *lifecycle.HostClientPool
}
func newExecHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *execHandler {
return &execHandler{db: db, agent: agent}
func newExecHandler(db *db.Queries, pool *lifecycle.HostClientPool) *execHandler {
return &execHandler{db: db, pool: pool}
}
type execRequest struct {
@ -46,10 +47,16 @@ type execResponse struct {
// Exec handles POST /v1/sandboxes/{id}/exec.
func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
@ -73,8 +80,14 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
start := time.Now()
resp, err := h.agent.Exec(ctx, connect.NewRequest(&pb.ExecRequest{
SandboxId: sandboxID,
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
return
}
resp, err := agent.Exec(ctx, connect.NewRequest(&pb.ExecRequest{
SandboxId: sandboxIDStr,
Cmd: req.Cmd,
Args: req.Args,
TimeoutSec: req.TimeoutSec,
@ -95,7 +108,7 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
Valid: true,
},
}); err != nil {
slog.Warn("failed to update last_active_at", "id", sandboxID, "error", err)
slog.Warn("failed to update last_active_at", "id", sandboxIDStr, "error", err)
}
// Use base64 encoding if output contains non-UTF-8 bytes.
@ -106,7 +119,7 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
if !utf8.Valid(stdout) || !utf8.Valid(stderr) {
encoding = "base64"
writeJSON(w, http.StatusOK, execResponse{
SandboxID: sandboxID,
SandboxID: sandboxIDStr,
Cmd: req.Cmd,
Stdout: base64.StdEncoding.EncodeToString(stdout),
Stderr: base64.StdEncoding.EncodeToString(stderr),
@ -118,7 +131,7 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
}
writeJSON(w, http.StatusOK, execResponse{
SandboxID: sandboxID,
SandboxID: sandboxIDStr,
Cmd: req.Cmd,
Stdout: string(stdout),
Stderr: string(stderr),

View File

@ -14,17 +14,18 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
)
type execStreamHandler struct {
db *db.Queries
agent hostagentv1connect.HostAgentServiceClient
db *db.Queries
pool *lifecycle.HostClientPool
}
func newExecStreamHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *execStreamHandler {
return &execStreamHandler{db: db, agent: agent}
func newExecStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool) *execStreamHandler {
return &execStreamHandler{db: db, pool: pool}
}
var upgrader = websocket.Upgrader{
@ -48,10 +49,16 @@ type wsOutMsg struct {
// ExecStream handles WS /v1/sandboxes/{id}/exec/stream.
func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
@ -80,12 +87,18 @@ func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
return
}
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
sendWSError(conn, "sandbox host is not reachable")
return
}
// Open streaming exec to host agent.
streamCtx, cancel := context.WithCancel(ctx)
defer cancel()
stream, err := h.agent.ExecStream(streamCtx, connect.NewRequest(&pb.ExecStreamRequest{
SandboxId: sandboxID,
stream, err := agent.ExecStream(streamCtx, connect.NewRequest(&pb.ExecStreamRequest{
SandboxId: sandboxIDStr,
Cmd: startMsg.Cmd,
Args: startMsg.Args,
}))
@ -151,7 +164,7 @@ func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
Valid: true,
},
}); err != nil {
slog.Warn("failed to update last active after stream exec", "sandbox_id", sandboxID, "error", err)
slog.Warn("failed to update last active after stream exec", "sandbox_id", sandboxIDStr, "error", err)
}
}

View File

@ -11,17 +11,18 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
)
type filesHandler struct {
db *db.Queries
agent hostagentv1connect.HostAgentServiceClient
db *db.Queries
pool *lifecycle.HostClientPool
}
func newFilesHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *filesHandler {
return &filesHandler{db: db, agent: agent}
func newFilesHandler(db *db.Queries, pool *lifecycle.HostClientPool) *filesHandler {
return &filesHandler{db: db, pool: pool}
}
// Upload handles POST /v1/sandboxes/{id}/files/write.
@ -29,10 +30,16 @@ func newFilesHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceCl
// - "path" text field: absolute destination path inside the sandbox
// - "file" file field: binary content to write
func (h *filesHandler) Upload(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
@ -75,8 +82,14 @@ func (h *filesHandler) Upload(w http.ResponseWriter, r *http.Request) {
return
}
if _, err := h.agent.WriteFile(ctx, connect.NewRequest(&pb.WriteFileRequest{
SandboxId: sandboxID,
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
return
}
if _, err := agent.WriteFile(ctx, connect.NewRequest(&pb.WriteFileRequest{
SandboxId: sandboxIDStr,
Path: filePath,
Content: content,
})); err != nil {
@ -95,10 +108,16 @@ type readFileRequest struct {
// Download handles POST /v1/sandboxes/{id}/files/read.
// Accepts JSON body with path, returns raw file content with Content-Disposition.
func (h *filesHandler) Download(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
@ -120,8 +139,14 @@ func (h *filesHandler) Download(w http.ResponseWriter, r *http.Request) {
return
}
resp, err := h.agent.ReadFile(ctx, connect.NewRequest(&pb.ReadFileRequest{
SandboxId: sandboxID,
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
return
}
resp, err := agent.ReadFile(ctx, connect.NewRequest(&pb.ReadFileRequest{
SandboxId: sandboxIDStr,
Path: req.Path,
}))
if err != nil {

View File

@ -12,27 +12,34 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
)
type filesStreamHandler struct {
db *db.Queries
agent hostagentv1connect.HostAgentServiceClient
db *db.Queries
pool *lifecycle.HostClientPool
}
func newFilesStreamHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *filesStreamHandler {
return &filesStreamHandler{db: db, agent: agent}
func newFilesStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool) *filesStreamHandler {
return &filesStreamHandler{db: db, pool: pool}
}
// StreamUpload handles POST /v1/sandboxes/{id}/files/stream/write.
// Expects multipart/form-data with "path" text field and "file" file field.
// Streams file content directly from the request body to the host agent without buffering.
func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
@ -88,14 +95,20 @@ func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request
}
defer filePart.Close()
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
return
}
// Open client-streaming RPC to host agent.
stream := h.agent.WriteFileStream(ctx)
stream := agent.WriteFileStream(ctx)
// Send metadata first.
if err := stream.Send(&pb.WriteFileStreamRequest{
Content: &pb.WriteFileStreamRequest_Meta{
Meta: &pb.WriteFileStreamMeta{
SandboxId: sandboxID,
SandboxId: sandboxIDStr,
Path: filePath,
},
},
@ -140,10 +153,16 @@ func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request
// StreamDownload handles POST /v1/sandboxes/{id}/files/stream/read.
// Accepts JSON body with path, streams file content back without buffering.
func (h *filesStreamHandler) StreamDownload(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
@ -164,9 +183,15 @@ func (h *filesStreamHandler) StreamDownload(w http.ResponseWriter, r *http.Reque
return
}
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
return
}
// Open server-streaming RPC to host agent.
stream, err := h.agent.ReadFileStream(ctx, connect.NewRequest(&pb.ReadFileStreamRequest{
SandboxId: sandboxID,
stream, err := agent.ReadFileStream(ctx, connect.NewRequest(&pb.ReadFileStreamRequest{
SandboxId: sandboxIDStr,
Path: req.Path,
}))
if err != nil {

View File

@ -1,23 +1,30 @@
package api
import (
"errors"
"log/slog"
"net/http"
"time"
"github.com/go-chi/chi/v5"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/audit"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/service"
)
type hostHandler struct {
svc *service.HostService
queries *db.Queries
audit *audit.AuditLogger
}
func newHostHandler(svc *service.HostService, queries *db.Queries) *hostHandler {
return &hostHandler{svc: svc, queries: queries}
func newHostHandler(svc *service.HostService, queries *db.Queries, al *audit.AuditLogger) *hostHandler {
return &hostHandler{svc: svc, queries: queries, audit: al}
}
// Request/response types.
@ -34,6 +41,24 @@ type createHostResponse struct {
RegistrationToken string `json:"registration_token"`
}
// refreshTokenRequest is the JSON body decoded by the RefreshToken handler.
type refreshTokenRequest struct {
	RefreshToken string `json:"refresh_token"`
}
// refreshTokenResponse is the JSON body written by the RefreshToken handler.
// The three PEM fields carry omitempty and are dropped from the encoded
// object when empty.
type refreshTokenResponse struct {
	Host         hostResponse `json:"host"`
	Token        string       `json:"token"`
	RefreshToken string       `json:"refresh_token"`
	CertPEM      string       `json:"cert_pem,omitempty"`
	KeyPEM       string       `json:"key_pem,omitempty"`
	CACertPEM    string       `json:"ca_cert_pem,omitempty"`
}
// deletePreviewResponse is the JSON body written by the DeletePreview
// handler: the host in question plus the sandbox IDs reported by the
// service's delete preview.
type deletePreviewResponse struct {
	Host       hostResponse `json:"host"`
	SandboxIDs []string     `json:"sandbox_ids"`
}
type registerHostRequest struct {
Token string `json:"token"`
Arch string `json:"arch,omitempty"`
@ -44,8 +69,12 @@ type registerHostRequest struct {
}
type registerHostResponse struct {
Host hostResponse `json:"host"`
Token string `json:"token"`
Host hostResponse `json:"host"`
Token string `json:"token"`
RefreshToken string `json:"refresh_token"`
CertPEM string `json:"cert_pem,omitempty"`
KeyPEM string `json:"key_pem,omitempty"`
CACertPEM string `json:"ca_cert_pem,omitempty"`
}
type addTagRequest struct {
@ -56,6 +85,7 @@ type hostResponse struct {
ID string `json:"id"`
Type string `json:"type"`
TeamID *string `json:"team_id,omitempty"`
TeamName *string `json:"team_name,omitempty"`
Provider *string `json:"provider,omitempty"`
AvailabilityZone *string `json:"availability_zone,omitempty"`
Arch *string `json:"arch,omitempty"`
@ -72,34 +102,35 @@ type hostResponse struct {
func hostToResponse(h db.Host) hostResponse {
resp := hostResponse{
ID: h.ID,
ID: id.FormatHostID(h.ID),
Type: h.Type,
Status: h.Status,
CreatedBy: h.CreatedBy,
CreatedBy: id.FormatUserID(h.CreatedBy),
}
if h.TeamID.Valid {
resp.TeamID = &h.TeamID.String
s := id.FormatTeamID(h.TeamID)
resp.TeamID = &s
}
if h.Provider.Valid {
resp.Provider = &h.Provider.String
if h.Provider != "" {
resp.Provider = &h.Provider
}
if h.AvailabilityZone.Valid {
resp.AvailabilityZone = &h.AvailabilityZone.String
if h.AvailabilityZone != "" {
resp.AvailabilityZone = &h.AvailabilityZone
}
if h.Arch.Valid {
resp.Arch = &h.Arch.String
if h.Arch != "" {
resp.Arch = &h.Arch
}
if h.CpuCores.Valid {
resp.CPUCores = &h.CpuCores.Int32
if h.CpuCores != 0 {
resp.CPUCores = &h.CpuCores
}
if h.MemoryMb.Valid {
resp.MemoryMB = &h.MemoryMb.Int32
if h.MemoryMb != 0 {
resp.MemoryMB = &h.MemoryMb
}
if h.DiskGb.Valid {
resp.DiskGB = &h.DiskGb.Int32
if h.DiskGb != 0 {
resp.DiskGB = &h.DiskGb
}
if h.Address.Valid {
resp.Address = &h.Address.String
if h.Address != "" {
resp.Address = &h.Address
}
if h.LastHeartbeatAt.Valid {
s := h.LastHeartbeatAt.Time.Format(time.RFC3339)
@ -112,7 +143,7 @@ func hostToResponse(h db.Host) hostResponse {
}
// isAdmin fetches the user record and returns whether they are an admin.
func (h *hostHandler) isAdmin(r *http.Request, userID string) bool {
func (h *hostHandler) isAdmin(r *http.Request, userID pgtype.UUID) bool {
user, err := h.queries.GetUserByID(r.Context(), userID)
if err != nil {
return false
@ -130,20 +161,32 @@ func (h *hostHandler) Create(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
result, err := h.svc.Create(r.Context(), service.HostCreateParams{
Type: req.Type,
TeamID: req.TeamID,
Provider: req.Provider,
AvailabilityZone: req.AvailabilityZone,
RequestingUserID: ac.UserID,
IsRequestorAdmin: h.isAdmin(r, ac.UserID),
})
// Parse optional team ID from request body.
var params service.HostCreateParams
params.Type = req.Type
params.Provider = req.Provider
params.AvailabilityZone = req.AvailabilityZone
params.RequestingUserID = ac.UserID
params.IsRequestorAdmin = h.isAdmin(r, ac.UserID)
if req.TeamID != "" {
teamID, err := id.ParseTeamID(req.TeamID)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid team_id")
return
}
params.TeamID = teamID
}
result, err := h.svc.Create(r.Context(), params)
if err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
// Log audit for the owning team (BYOC hosts have a team; shared hosts use caller's team).
h.audit.LogHostCreate(r.Context(), ac, result.Host.ID, result.Host.TeamID)
writeJSON(w, http.StatusCreated, createHostResponse{
Host: hostToResponse(result.Host),
RegistrationToken: result.RegistrationToken,
@ -153,16 +196,50 @@ func (h *hostHandler) Create(w http.ResponseWriter, r *http.Request) {
// List handles GET /v1/hosts.
func (h *hostHandler) List(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
admin := h.isAdmin(r, ac.UserID)
hosts, err := h.svc.List(r.Context(), ac.TeamID, h.isAdmin(r, ac.UserID))
hosts, err := h.svc.List(r.Context(), ac.TeamID, admin)
if err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to list hosts")
return
}
// Collect unique team IDs so we can fetch team names in one pass.
var teamNames map[string]string
if admin {
seen := make(map[string]struct{})
for _, host := range hosts {
if host.TeamID.Valid {
key := id.FormatTeamID(host.TeamID)
seen[key] = struct{}{}
}
}
if len(seen) > 0 {
teamNames = make(map[string]string, len(seen))
for _, host := range hosts {
if !host.TeamID.Valid {
continue
}
key := id.FormatTeamID(host.TeamID)
if _, ok := teamNames[key]; ok {
continue
}
if team, err := h.queries.GetTeam(r.Context(), host.TeamID); err == nil {
teamNames[key] = team.Name
}
}
}
}
resp := make([]hostResponse, len(hosts))
for i, host := range hosts {
resp[i] = hostToResponse(host)
if host.TeamID.Valid {
key := id.FormatTeamID(host.TeamID)
if name, ok := teamNames[key]; ok {
resp[i].TeamName = &name
}
}
}
writeJSON(w, http.StatusOK, resp)
@ -170,9 +247,15 @@ func (h *hostHandler) List(w http.ResponseWriter, r *http.Request) {
// Get handles GET /v1/hosts/{id}.
func (h *hostHandler) Get(w http.ResponseWriter, r *http.Request) {
hostID := chi.URLParam(r, "id")
hostIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
hostID, err := id.ParseHostID(hostIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
return
}
host, err := h.svc.Get(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID))
if err != nil {
status, code, msg := serviceErrToHTTP(err)
@ -183,25 +266,86 @@ func (h *hostHandler) Get(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, hostToResponse(host))
}
// Delete handles DELETE /v1/hosts/{id}.
func (h *hostHandler) Delete(w http.ResponseWriter, r *http.Request) {
hostID := chi.URLParam(r, "id")
// DeletePreview handles GET /v1/hosts/{id}/delete-preview.
// Returns what would be affected without making changes, for confirmation UI.
func (h *hostHandler) DeletePreview(w http.ResponseWriter, r *http.Request) {
hostIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
if err := h.svc.Delete(r.Context(), hostID, ac.UserID, ac.TeamID, h.isAdmin(r, ac.UserID)); err != nil {
hostID, err := id.ParseHostID(hostIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
return
}
preview, err := h.svc.DeletePreview(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID))
if err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
w.WriteHeader(http.StatusNoContent)
writeJSON(w, http.StatusOK, deletePreviewResponse{
Host: hostToResponse(preview.Host),
SandboxIDs: preview.SandboxIDs,
})
}
// Delete handles DELETE /v1/hosts/{id}.
//
// Without ?force=true the service may refuse with a HostHasSandboxesError,
// which is reported as a structured 409 listing the affected sandbox IDs.
// With ?force=true the force flag is forwarded to the service.
// A successful delete is audited and answered with 204; every other error is
// translated with serviceErrToHTTP.
func (h *hostHandler) Delete(w http.ResponseWriter, r *http.Request) {
	rawID := chi.URLParam(r, "id")
	ac := auth.MustFromContext(r.Context())
	force := r.URL.Query().Get("force") == "true"

	hostID, err := id.ParseHostID(rawID)
	if err != nil {
		writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
		return
	}

	// The row is gone after a successful delete, so read it first to have the
	// team_id available for the audit entry. A failed lookup is only logged;
	// the delete still proceeds.
	hostBefore, lookupErr := h.queries.GetHost(r.Context(), hostID)
	if lookupErr != nil {
		slog.Warn("audit: could not fetch host before delete", "host_id", rawID, "error", lookupErr)
	}

	admin := h.isAdmin(r, ac.UserID)
	err = h.svc.Delete(r.Context(), hostID, ac.UserID, ac.TeamID, admin, force)

	var activeErr *service.HostHasSandboxesError
	switch {
	case err == nil:
		h.audit.LogHostDelete(r.Context(), ac, hostID, hostBefore.TeamID)
		w.WriteHeader(http.StatusNoContent)
	case errors.As(err, &activeErr):
		// Structured 409 so the client can show which sandboxes block the delete.
		writeJSON(w, http.StatusConflict, map[string]any{
			"error": map[string]any{
				"code":        "has_active_sandboxes",
				"message":     "host has active sandboxes; use ?force=true to destroy them and delete the host",
				"sandbox_ids": activeErr.SandboxIDs,
			},
		})
	default:
		status, code, msg := serviceErrToHTTP(err)
		writeError(w, status, code, msg)
	}
}
// RegenerateToken handles POST /v1/hosts/{id}/token.
func (h *hostHandler) RegenerateToken(w http.ResponseWriter, r *http.Request) {
hostID := chi.URLParam(r, "id")
hostIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
hostID, err := id.ParseHostID(hostIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
return
}
result, err := h.svc.RegenerateToken(r.Context(), hostID, ac.UserID, ac.TeamID, h.isAdmin(r, ac.UserID))
if err != nil {
status, code, msg := serviceErrToHTTP(err)
@ -247,36 +391,61 @@ func (h *hostHandler) Register(w http.ResponseWriter, r *http.Request) {
}
writeJSON(w, http.StatusCreated, registerHostResponse{
Host: hostToResponse(result.Host),
Token: result.JWT,
Host: hostToResponse(result.Host),
Token: result.JWT,
RefreshToken: result.RefreshToken,
CertPEM: result.CertPEM,
KeyPEM: result.KeyPEM,
CACertPEM: result.CACertPEM,
})
}
// Heartbeat handles POST /v1/hosts/{id}/heartbeat (host-token-authenticated).
func (h *hostHandler) Heartbeat(w http.ResponseWriter, r *http.Request) {
hostID := chi.URLParam(r, "id")
hostIDStr := chi.URLParam(r, "id")
hc := auth.MustHostFromContext(r.Context())
hostID, err := id.ParseHostID(hostIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
return
}
// Prevent a host from heartbeating for a different host.
if hostID != hc.HostID {
writeError(w, http.StatusForbidden, "forbidden", "host ID mismatch")
return
}
// Capture pre-heartbeat status to detect unreachable → online transition.
prevHost, _ := h.queries.GetHost(r.Context(), hc.HostID)
if err := h.svc.Heartbeat(r.Context(), hc.HostID); err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to update heartbeat")
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
// Log marked_up if the host just recovered from unreachable.
if prevHost.Status == "unreachable" {
h.audit.LogHostMarkedUp(r.Context(), prevHost.TeamID, hc.HostID)
}
w.WriteHeader(http.StatusNoContent)
}
// AddTag handles POST /v1/hosts/{id}/tags.
func (h *hostHandler) AddTag(w http.ResponseWriter, r *http.Request) {
hostID := chi.URLParam(r, "id")
hostIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
admin := h.isAdmin(r, ac.UserID)
hostID, err := id.ParseHostID(hostIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
return
}
var req addTagRequest
if err := decodeJSON(r, &req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
@ -298,10 +467,16 @@ func (h *hostHandler) AddTag(w http.ResponseWriter, r *http.Request) {
// RemoveTag handles DELETE /v1/hosts/{id}/tags/{tag}.
func (h *hostHandler) RemoveTag(w http.ResponseWriter, r *http.Request) {
hostID := chi.URLParam(r, "id")
hostIDStr := chi.URLParam(r, "id")
tag := chi.URLParam(r, "tag")
ac := auth.MustFromContext(r.Context())
hostID, err := id.ParseHostID(hostIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
return
}
if err := h.svc.RemoveTag(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID), tag); err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
@ -311,11 +486,47 @@ func (h *hostHandler) RemoveTag(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}
// RefreshToken handles POST /v1/hosts/auth/refresh (unauthenticated).
//
// The host agent posts its refresh token; the handler rejects an undecodable
// body or an empty token with 400, forwards the token to the host service,
// and on success responds 200 with the host, a new JWT, the refresh token
// returned by the service, and the PEM fields from the service result.
func (h *hostHandler) RefreshToken(w http.ResponseWriter, r *http.Request) {
	var body refreshTokenRequest
	if decodeErr := decodeJSON(r, &body); decodeErr != nil {
		writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
		return
	}

	token := body.RefreshToken
	if token == "" {
		writeError(w, http.StatusBadRequest, "invalid_request", "refresh_token is required")
		return
	}

	refreshed, err := h.svc.Refresh(r.Context(), token)
	if err != nil {
		status, code, msg := serviceErrToHTTP(err)
		writeError(w, status, code, msg)
		return
	}

	out := refreshTokenResponse{
		Host:         hostToResponse(refreshed.Host),
		Token:        refreshed.JWT,
		RefreshToken: refreshed.RefreshToken,
		CertPEM:      refreshed.CertPEM,
		KeyPEM:       refreshed.KeyPEM,
		CACertPEM:    refreshed.CACertPEM,
	}
	writeJSON(w, http.StatusOK, out)
}
// ListTags handles GET /v1/hosts/{id}/tags.
func (h *hostHandler) ListTags(w http.ResponseWriter, r *http.Request) {
hostID := chi.URLParam(r, "id")
hostIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
hostID, err := id.ParseHostID(hostIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
return
}
tags, err := h.svc.ListTags(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID))
if err != nil {
status, code, msg := serviceErrToHTTP(err)

View File

@ -0,0 +1,156 @@
package api
import (
"context"
"net/http"
"time"
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
)
// sandboxMetricsHandler serves sandbox metrics requests. It holds the query
// handle used for sandbox lookups and the host client pool used to reach the
// host agent that owns a sandbox.
type sandboxMetricsHandler struct {
	db   *db.Queries
	pool *lifecycle.HostClientPool
}
// newSandboxMetricsHandler returns a sandboxMetricsHandler wired to the given
// query handle and host client pool.
func newSandboxMetricsHandler(db *db.Queries, pool *lifecycle.HostClientPool) *sandboxMetricsHandler {
	handler := new(sandboxMetricsHandler)
	handler.db = db
	handler.pool = pool
	return handler
}
// metricPointResponse is one sample in a metrics response: a Unix timestamp
// plus the CPU, memory, and disk readings taken at that time.
type metricPointResponse struct {
	TimestampUnix int64   `json:"timestamp_unix"`
	CPUPct        float64 `json:"cpu_pct"`
	MemBytes      int64   `json:"mem_bytes"`
	DiskBytes     int64   `json:"disk_bytes"`
}
// metricsResponse is the JSON body written for a sandbox metrics request: the
// sandbox ID and range that were asked for, and the matching samples.
type metricsResponse struct {
	SandboxID string                `json:"sandbox_id"`
	Range     string                `json:"range"`
	Points    []metricPointResponse `json:"points"`
}
// GetMetrics handles GET /v1/sandboxes/{id}/metrics?range=5m|10m|1h|2h|6h|12h|24h.
//
// The range query parameter defaults to "10m" when absent. The handler
// responds 400 for an unparseable sandbox ID or an unsupported range, and
// 404 when the sandbox cannot be loaded for the caller's team. Where the
// metrics come from depends on the sandbox status:
//   - "running": fetched from the host agent (getFromAgent)
//   - "paused":  read from the database (getFromDB)
//   - any other status: 404, with the status named in the message
func (h *sandboxMetricsHandler) GetMetrics(w http.ResponseWriter, r *http.Request) {
	sandboxIDStr := chi.URLParam(r, "id")
	ctx := r.Context()
	ac := auth.MustFromContext(ctx)

	sandboxID, err := id.ParseSandboxID(sandboxIDStr)
	if err != nil {
		writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
		return
	}

	rangeTier := r.URL.Query().Get("range")
	if rangeTier == "" {
		rangeTier = "10m"
	}
	// A switch over the fixed set of tiers avoids allocating a lookup map on
	// every request.
	switch rangeTier {
	case "5m", "10m", "1h", "2h", "6h", "12h", "24h":
		// supported
	default:
		writeError(w, http.StatusBadRequest, "invalid_request", "range must be one of: 5m, 10m, 1h, 2h, 6h, 12h, 24h")
		return
	}

	sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
	if err != nil {
		writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
		return
	}

	switch sb.Status {
	case "running":
		h.getFromAgent(w, r, sandboxIDStr, rangeTier, sb.HostID)
	case "paused":
		h.getFromDB(ctx, w, sandboxIDStr, sandboxID, rangeTier)
	default:
		writeError(w, http.StatusNotFound, "not_found", "metrics not available for sandbox in state: "+sb.Status)
	}
}
// getFromAgent answers a metrics request by asking the host agent for the
// given host. GetMetrics uses it for running sandboxes.
//
// Responds 503 when no agent client can be obtained for the host, maps agent
// RPC failures through agentErrToHTTP, and otherwise writes the returned
// points as a metricsResponse.
func (h *sandboxMetricsHandler) getFromAgent(w http.ResponseWriter, r *http.Request, sandboxIDStr, rangeTier string, hostID pgtype.UUID) {
	ctx := r.Context()

	client, err := agentForHost(ctx, h.db, h.pool, hostID)
	if err != nil {
		writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
		return
	}

	// The user-facing range string is forwarded to the agent unchanged.
	rpcReq := connect.NewRequest(&pb.GetSandboxMetricsRequest{
		SandboxId: sandboxIDStr,
		Range:     rangeTier,
	})
	reply, err := client.GetSandboxMetrics(ctx, rpcReq)
	if err != nil {
		status, code, msg := agentErrToHTTP(err)
		writeError(w, status, code, msg)
		return
	}

	body := metricsResponse{
		SandboxID: sandboxIDStr,
		Range:     rangeTier,
		Points:    make([]metricPointResponse, 0, len(reply.Msg.Points)),
	}
	for _, sample := range reply.Msg.Points {
		body.Points = append(body.Points, metricPointResponse{
			TimestampUnix: sample.TimestampUnix,
			CPUPct:        sample.CpuPct,
			MemBytes:      sample.MemBytes,
			DiskBytes:     sample.DiskBytes,
		})
	}
	writeJSON(w, http.StatusOK, body)
}
// rangeToDB translates a user-facing range value into the stored tier to read
// and the look-back window (cutoff) applied to it. Several ranges share one
// tier and differ only in how far back they look.
var rangeToDB = map[string]struct {
	tier   string
	cutoff time.Duration
}{
	"5m":  {tier: "10m", cutoff: 5 * time.Minute},
	"10m": {tier: "10m", cutoff: 10 * time.Minute},
	"1h":  {tier: "2h", cutoff: time.Hour},
	"2h":  {tier: "2h", cutoff: 2 * time.Hour},
	"6h":  {tier: "24h", cutoff: 6 * time.Hour},
	"12h": {tier: "24h", cutoff: 12 * time.Hour},
	"24h": {tier: "24h", cutoff: 24 * time.Hour},
}
// getFromDB answers a metrics request from stored metric points. GetMetrics
// uses it for paused sandboxes.
//
// rangeTier is translated through rangeToDB into the stored tier to read and a
// cutoff; the query receives Ts = now - cutoff as Unix seconds (presumably a
// lower bound on the row timestamp — confirm against the query definition).
//
// Responds 400 when rangeTier has no rangeToDB entry, 500 when the query
// fails, and otherwise writes the rows as a metricsResponse.
func (h *sandboxMetricsHandler) getFromDB(ctx context.Context, w http.ResponseWriter, sandboxIDStr string, sandboxID pgtype.UUID, rangeTier string) {
	mapping, ok := rangeToDB[rangeTier]
	if !ok {
		// Without this guard an unmapped range would query tier "" with a
		// cutoff of "now" and report an empty series as a success.
		writeError(w, http.StatusBadRequest, "invalid_request", "unsupported metrics range")
		return
	}

	rows, err := h.db.GetSandboxMetricPoints(ctx, db.GetSandboxMetricPointsParams{
		SandboxID: sandboxID,
		Tier:      mapping.tier,
		Ts:        time.Now().Add(-mapping.cutoff).Unix(),
	})
	if err != nil {
		writeError(w, http.StatusInternalServerError, "internal_error", "failed to read metrics")
		return
	}

	points := make([]metricPointResponse, len(rows))
	for i, row := range rows {
		points[i] = metricPointResponse{
			TimestampUnix: row.Ts,
			CPUPct:        row.CpuPct,
			MemBytes:      row.MemBytes,
			DiskBytes:     row.DiskBytes,
		}
	}

	writeJSON(w, http.StatusOK, metricsResponse{
		SandboxID: sandboxIDStr,
		Range:     rangeTier,
		Points:    points,
	})
}

View File

@ -150,19 +150,19 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
redirectWithError(w, r, redirectBase, "db_error")
return
}
team, err := h.db.GetDefaultTeamForUser(ctx, user.ID)
team, role, err := loginTeam(ctx, h.db, user.ID)
if err != nil {
slog.Error("oauth login: failed to get team", "error", err)
redirectWithError(w, r, redirectBase, "db_error")
return
}
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email)
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, user.IsAdmin)
if err != nil {
slog.Error("oauth login: failed to sign jwt", "error", err)
redirectWithError(w, r, redirectBase, "internal_error")
return
}
redirectWithToken(w, r, redirectBase, token, user.ID, team.ID, user.Email)
redirectWithToken(w, r, redirectBase, token, id.FormatUserID(user.ID), id.FormatTeamID(team.ID), user.Email, user.Name)
return
}
if !errors.Is(err, pgx.ErrNoRows) {
@ -199,6 +199,7 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
_, err = qtx.InsertUserOAuth(ctx, db.InsertUserOAuthParams{
ID: userID,
Email: email,
Name: profile.Name,
})
if err != nil {
var pgErr *pgconn.PgError
@ -219,6 +220,7 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
if _, err := qtx.InsertTeam(ctx, db.InsertTeamParams{
ID: teamID,
Name: teamName,
Slug: id.NewTeamSlug(),
}); err != nil {
slog.Error("oauth: failed to create team", "error", err)
redirectWithError(w, r, redirectBase, "db_error")
@ -253,14 +255,14 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
return
}
token, err := auth.SignJWT(h.jwtSecret, userID, teamID, email)
token, err := auth.SignJWT(h.jwtSecret, userID, teamID, email, profile.Name, "owner", false)
if err != nil {
slog.Error("oauth: failed to sign jwt", "error", err)
redirectWithError(w, r, redirectBase, "internal_error")
return
}
redirectWithToken(w, r, redirectBase, token, userID, teamID, email)
redirectWithToken(w, r, redirectBase, token, id.FormatUserID(userID), id.FormatTeamID(teamID), email, profile.Name)
}
// retryAsLogin handles the race where a concurrent request already created the user.
@ -282,29 +284,39 @@ func (h *oauthHandler) retryAsLogin(w http.ResponseWriter, r *http.Request, prov
redirectWithError(w, r, redirectBase, "db_error")
return
}
team, err := h.db.GetDefaultTeamForUser(ctx, user.ID)
team, role, err := loginTeam(ctx, h.db, user.ID)
if err != nil {
slog.Error("oauth: retry login: failed to get team", "error", err)
redirectWithError(w, r, redirectBase, "db_error")
return
}
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email)
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, user.IsAdmin)
if err != nil {
slog.Error("oauth: retry login: failed to sign jwt", "error", err)
redirectWithError(w, r, redirectBase, "internal_error")
return
}
redirectWithToken(w, r, redirectBase, token, user.ID, team.ID, user.Email)
redirectWithToken(w, r, redirectBase, token, id.FormatUserID(user.ID), id.FormatTeamID(team.ID), user.Email, user.Name)
}
func redirectWithToken(w http.ResponseWriter, r *http.Request, base, token, userID, teamID, email string) {
u := base + "?" + url.Values{
"token": {token},
"user_id": {userID},
"team_id": {teamID},
"email": {email},
}.Encode()
http.Redirect(w, r, u, http.StatusFound)
// redirectWithToken hands the login result to the frontend and redirects to
// base with 302 Found.
//
// The auth data travels in short-lived cookies rather than URL query
// parameters, which keeps the token out of server access logs, Referer headers
// and browser history. Each cookie is scoped to /auth/, expires after 60
// seconds, is SameSite=Lax, and is marked Secure according to isSecure(r).
//
// NOTE(review): net/http sanitises cookie values when writing the header and
// drops bytes outside printable ASCII, so a non-ASCII email or display name may
// reach the frontend altered — confirm whether these values should be encoded.
func redirectWithToken(w http.ResponseWriter, r *http.Request, base, token, userID, teamID, email, name string) {
	fields := [...]struct{ cookie, value string }{
		{"wrenn_oauth_token", token},
		{"wrenn_oauth_user_id", userID},
		{"wrenn_oauth_team_id", teamID},
		{"wrenn_oauth_email", email},
		{"wrenn_oauth_name", name},
	}
	for _, f := range fields {
		http.SetCookie(w, &http.Cookie{
			Name:     f.cookie,
			Value:    f.value,
			Path:     "/auth/",
			MaxAge:   60,
			HttpOnly: false, // frontend JS must read these
			SameSite: http.SameSiteLaxMode,
			Secure:   isSecure(r),
		})
	}
	http.Redirect(w, r, base, http.StatusFound)
}
func redirectWithError(w http.ResponseWriter, r *http.Request, base, code string) {

View File

@ -7,17 +7,20 @@ import (
"github.com/go-chi/chi/v5"
"git.omukk.dev/wrenn/sandbox/internal/audit"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/service"
)
type sandboxHandler struct {
svc *service.SandboxService
svc *service.SandboxService
audit *audit.AuditLogger
}
func newSandboxHandler(svc *service.SandboxService) *sandboxHandler {
return &sandboxHandler{svc: svc}
func newSandboxHandler(svc *service.SandboxService, al *audit.AuditLogger) *sandboxHandler {
return &sandboxHandler{svc: svc, audit: al}
}
type createSandboxRequest struct {
@ -44,7 +47,7 @@ type sandboxResponse struct {
func sandboxToResponse(sb db.Sandbox) sandboxResponse {
resp := sandboxResponse{
ID: sb.ID,
ID: id.FormatSandboxID(sb.ID),
Status: sb.Status,
Template: sb.Template,
VCPUs: sb.Vcpus,
@ -79,6 +82,10 @@ func (h *sandboxHandler) Create(w http.ResponseWriter, r *http.Request) {
}
ac := auth.MustFromContext(r.Context())
if !ac.TeamID.Valid {
writeError(w, http.StatusForbidden, "no_team", "no active team context; re-authenticate")
return
}
sb, err := h.svc.Create(r.Context(), service.SandboxCreateParams{
TeamID: ac.TeamID,
@ -93,6 +100,7 @@ func (h *sandboxHandler) Create(w http.ResponseWriter, r *http.Request) {
return
}
h.audit.LogSandboxCreate(r.Context(), ac, sb.ID, sb.Template)
writeJSON(w, http.StatusCreated, sandboxToResponse(sb))
}
@ -115,9 +123,15 @@ func (h *sandboxHandler) List(w http.ResponseWriter, r *http.Request) {
// Get handles GET /v1/sandboxes/{id}.
func (h *sandboxHandler) Get(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.svc.Get(r.Context(), sandboxID, ac.TeamID)
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
@ -129,9 +143,15 @@ func (h *sandboxHandler) Get(w http.ResponseWriter, r *http.Request) {
// Pause handles POST /v1/sandboxes/{id}/pause.
func (h *sandboxHandler) Pause(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.svc.Pause(r.Context(), sandboxID, ac.TeamID)
if err != nil {
status, code, msg := serviceErrToHTTP(err)
@ -139,14 +159,21 @@ func (h *sandboxHandler) Pause(w http.ResponseWriter, r *http.Request) {
return
}
h.audit.LogSandboxPause(r.Context(), ac, sandboxID)
writeJSON(w, http.StatusOK, sandboxToResponse(sb))
}
// Resume handles POST /v1/sandboxes/{id}/resume.
func (h *sandboxHandler) Resume(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.svc.Resume(r.Context(), sandboxID, ac.TeamID)
if err != nil {
status, code, msg := serviceErrToHTTP(err)
@ -154,14 +181,21 @@ func (h *sandboxHandler) Resume(w http.ResponseWriter, r *http.Request) {
return
}
h.audit.LogSandboxResume(r.Context(), ac, sandboxID)
writeJSON(w, http.StatusOK, sandboxToResponse(sb))
}
// Ping handles POST /v1/sandboxes/{id}/ping.
func (h *sandboxHandler) Ping(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
if err := h.svc.Ping(r.Context(), sandboxID, ac.TeamID); err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
@ -173,14 +207,21 @@ func (h *sandboxHandler) Ping(w http.ResponseWriter, r *http.Request) {
// Destroy handles DELETE /v1/sandboxes/{id}.
func (h *sandboxHandler) Destroy(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
if err := h.svc.Destroy(r.Context(), sandboxID, ac.TeamID); err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
h.audit.LogSandboxDestroy(r.Context(), ac, sandboxID)
w.WriteHeader(http.StatusNoContent)
}

View File

@ -1,6 +1,7 @@
package api
import (
"context"
"encoding/json"
"fmt"
"log/slog"
@ -9,25 +10,57 @@ import (
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/audit"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/layout"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
"git.omukk.dev/wrenn/sandbox/internal/service"
"git.omukk.dev/wrenn/sandbox/internal/validate"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
)
type snapshotHandler struct {
svc *service.TemplateService
db *db.Queries
agent hostagentv1connect.HostAgentServiceClient
pool *lifecycle.HostClientPool
audit *audit.AuditLogger
}
func newSnapshotHandler(svc *service.TemplateService, db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *snapshotHandler {
return &snapshotHandler{svc: svc, db: db, agent: agent}
func newSnapshotHandler(svc *service.TemplateService, db *db.Queries, pool *lifecycle.HostClientPool, al *audit.AuditLogger) *snapshotHandler {
return &snapshotHandler{svc: svc, db: db, pool: pool, audit: al}
}
// deleteSnapshotBroadcast attempts to delete snapshot files on all online hosts.
// Snapshots aren't currently host-tracked in the DB, so we broadcast to all hosts
// and ignore NotFound errors.
//
// The broadcast is best-effort: a host whose client cannot be obtained, or
// whose DeleteSnapshot call fails with anything other than NotFound, is logged
// and skipped. The only error returned is a failure to list hosts.
func (h *snapshotHandler) deleteSnapshotBroadcast(ctx context.Context, teamID, templateID pgtype.UUID) error {
	hosts, err := h.db.ListActiveHosts(ctx)
	if err != nil {
		return fmt.Errorf("list hosts: %w", err)
	}

	// The RPC identifiers are the same for every host; format them once.
	rpcTeamID := formatUUIDForRPC(teamID)
	rpcTemplateID := formatUUIDForRPC(templateID)

	for _, host := range hosts {
		if host.Status != "online" {
			continue
		}
		agent, err := h.pool.GetForHost(host)
		if err != nil {
			// Record the skip so files left behind on this host can be traced.
			slog.Warn("snapshot: no agent client for host, skipping delete", "host_id", id.FormatHostID(host.ID), "error", err)
			continue
		}
		_, err = agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{
			TeamId:     rpcTeamID,
			TemplateId: rpcTemplateID,
		}))
		if err != nil && connect.CodeOf(err) != connect.CodeNotFound {
			slog.Warn("snapshot: failed to delete on host", "host_id", id.FormatHostID(host.ID), "error", err)
		}
	}
	return nil
}
type createSnapshotRequest struct {
@ -42,6 +75,7 @@ type snapshotResponse struct {
MemoryMB *int32 `json:"memory_mb,omitempty"`
SizeBytes int64 `json:"size_bytes"`
CreatedAt string `json:"created_at"`
Platform bool `json:"platform"`
}
func templateToResponse(t db.Template) snapshotResponse {
@ -49,12 +83,13 @@ func templateToResponse(t db.Template) snapshotResponse {
Name: t.Name,
Type: t.Type,
SizeBytes: t.SizeBytes,
Platform: t.TeamID == id.PlatformTeamID,
}
if t.Vcpus.Valid {
resp.VCPUs = &t.Vcpus.Int32
if t.Vcpus != 0 {
resp.VCPUs = &t.Vcpus
}
if t.MemoryMb.Valid {
resp.MemoryMB = &t.MemoryMb.Int32
if t.MemoryMb != 0 {
resp.MemoryMB = &t.MemoryMb
}
if t.CreatedAt.Valid {
resp.CreatedAt = t.CreatedAt.Time.Format(time.RFC3339)
@ -75,6 +110,12 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
return
}
sandboxID, err := id.ParseSandboxID(req.SandboxID)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox_id")
return
}
if req.Name == "" {
req.Name = id.NewSnapshotName()
}
@ -87,16 +128,21 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(ctx)
overwrite := r.URL.Query().Get("overwrite") == "true"
// Check for global name collision.
if _, err := h.db.GetPlatformTemplateByName(ctx, req.Name); err == nil {
writeError(w, http.StatusConflict, "name_reserved", "template name is reserved by a global template")
return
}
// Check if name already exists for this team.
if _, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err == nil {
if existing, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err == nil {
if !overwrite {
writeError(w, http.StatusConflict, "already_exists", "snapshot name already exists; use ?overwrite=true to replace")
return
}
// Delete old files from the agent before removing the DB record.
if _, err := h.agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{Name: req.Name})); err != nil {
status, code, msg := agentErrToHTTP(err)
writeError(w, status, code, "failed to delete existing snapshot files: "+msg)
// Delete old snapshot files from all hosts before removing the DB record.
if err := h.deleteSnapshotBroadcast(ctx, existing.TeamID, existing.ID); err != nil {
writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete existing snapshot files")
return
}
if err := h.db.DeleteTemplateByTeam(ctx, db.DeleteTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err != nil {
@ -106,7 +152,7 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
}
// Verify sandbox exists, belongs to team, and is running or paused.
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: req.SandboxID, TeamID: ac.TeamID})
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
@ -116,30 +162,59 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
return
}
resp, err := h.agent.CreateSnapshot(ctx, connect.NewRequest(&pb.CreateSnapshotRequest{
SandboxId: req.SandboxID,
Name: req.Name,
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
return
}
// Pre-mark sandbox as "paused" in DB BEFORE issuing the snapshot RPC.
// The host agent's CreateSnapshot removes the sandbox from its in-memory
// map immediately; if the reconciler fires during the flatten window and
// the DB still says "running", it will mark the sandbox "stopped".
if sb.Status == "running" {
if _, err := h.db.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
ID: sandboxID, Status: "paused",
}); err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to update sandbox status")
return
}
}
// Use a detached context with a generous timeout so the snapshot completes
// even if the client disconnects (the flatten step can take 10-20s).
snapCtx, snapCancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer snapCancel()
// Generate the new template ID upfront so the host agent knows where to store files.
newTemplateID := id.NewTemplateID()
resp, err := agent.CreateSnapshot(snapCtx, connect.NewRequest(&pb.CreateSnapshotRequest{
SandboxId: req.SandboxID,
Name: req.Name,
TeamId: formatUUIDForRPC(ac.TeamID),
TemplateId: formatUUIDForRPC(newTemplateID),
}))
if err != nil {
// Snapshot failed — revert status back to what it was.
if sb.Status == "running" {
if _, dbErr := h.db.UpdateSandboxStatus(snapCtx, db.UpdateSandboxStatusParams{
ID: sandboxID, Status: "running",
}); dbErr != nil {
slog.Error("failed to revert sandbox status after snapshot error", "sandbox_id", req.SandboxID, "error", dbErr)
}
}
status, code, msg := agentErrToHTTP(err)
writeError(w, status, code, msg)
return
}
// Mark sandbox as paused (if it was running, it got paused by the snapshot).
if sb.Status != "paused" {
if _, err := h.db.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
ID: req.SandboxID, Status: "paused",
}); err != nil {
slog.Error("failed to update sandbox status after snapshot", "sandbox_id", req.SandboxID, "error", err)
}
}
tmpl, err := h.db.InsertTemplate(ctx, db.InsertTemplateParams{
tmpl, err := h.db.InsertTemplate(snapCtx, db.InsertTemplateParams{
ID: newTemplateID,
Name: req.Name,
Type: "snapshot",
Vcpus: pgtype.Int4{Int32: sb.Vcpus, Valid: true},
MemoryMb: pgtype.Int4{Int32: sb.MemoryMb, Valid: true},
Vcpus: sb.Vcpus,
MemoryMb: sb.MemoryMb,
SizeBytes: resp.Msg.SizeBytes,
TeamID: ac.TeamID,
})
@ -149,6 +224,12 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
return
}
h.audit.LogSnapshotCreate(snapCtx, ac, req.Name)
if ctx.Err() != nil {
slog.Info("snapshot created but client disconnected before response", "name", req.Name)
return
}
writeJSON(w, http.StatusCreated, templateToResponse(tmpl))
}
@ -181,16 +262,23 @@ func (h *snapshotHandler) Delete(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ac := auth.MustFromContext(ctx)
if _, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: name, TeamID: ac.TeamID}); err != nil {
tmpl, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: name, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "template not found")
return
}
// Platform templates can only be deleted by admins via /v1/admin/templates.
if tmpl.TeamID == id.PlatformTeamID {
writeError(w, http.StatusForbidden, "forbidden", "platform templates cannot be deleted here")
return
}
if layout.IsMinimal(tmpl.TeamID, tmpl.ID) {
writeError(w, http.StatusForbidden, "forbidden", "the minimal template cannot be deleted")
return
}
if _, err := h.agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{
Name: name,
})); err != nil {
status, code, msg := agentErrToHTTP(err)
writeError(w, status, code, "failed to delete snapshot files: "+msg)
if err := h.deleteSnapshotBroadcast(ctx, tmpl.TeamID, tmpl.ID); err != nil {
writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete snapshot files")
return
}
@ -199,5 +287,6 @@ func (h *snapshotHandler) Delete(w http.ResponseWriter, r *http.Request) {
return
}
h.audit.LogSnapshotDelete(r.Context(), ac, name)
w.WriteHeader(http.StatusNoContent)
}

View File

@ -0,0 +1,95 @@
package api
import (
"log/slog"
"net/http"
"time"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/service"
)
// statsHandler serves the sandbox stats endpoint on top of the stats service.
type statsHandler struct {
	svc *service.StatsService // source of current, peak and time-series stats
}
// newStatsHandler builds a statsHandler backed by the given stats service.
func newStatsHandler(svc *service.StatsService) *statsHandler {
	h := new(statsHandler)
	h.svc = svc
	return h
}
// statsCurrentResponse is the "current" section of the stats response; GetStats
// fills it from the current values returned by the stats service.
type statsCurrentResponse struct {
	RunningCount     int32 `json:"running_count"`
	VCPUsReserved    int32 `json:"vcpus_reserved"`
	MemoryMBReserved int32 `json:"memory_mb_reserved"`
}
// statsPeaksResponse is the "peaks" section of the stats response; GetStats
// fills it from the peak values returned by the stats service.
type statsPeaksResponse struct {
	RunningCount int32 `json:"running_count"`
	VCPUs        int32 `json:"vcpus"`
	MemoryMB     int32 `json:"memory_mb"`
}
// statsSeriesResponse is the "series" section of the stats response. The four
// slices are parallel: index i of each describes the same series point.
type statsSeriesResponse struct {
	Labels   []string `json:"labels"`    // point bucket time, formatted as RFC 3339 in UTC
	Running  []int32  `json:"running"`   // running count per point
	VCPUs    []int32  `json:"vcpus"`     // reserved vCPUs per point
	MemoryMB []int32  `json:"memory_mb"` // reserved memory (MB) per point
}
// statsResponse is the JSON body returned by the sandbox stats endpoint.
type statsResponse struct {
	Range   string               `json:"range"` // the range value the stats were computed for
	Current statsCurrentResponse `json:"current"`
	Peaks   statsPeaksResponse   `json:"peaks"`
	Series  statsSeriesResponse  `json:"series"`
}
// GetStats handles GET /v1/sandboxes/stats?range=5m|1h|6h|24h|30d
//
// The range defaults to service.Range1h when absent. Stats are fetched for the
// caller's team and returned as current values, peaks, and a time series whose
// labels are the bucket times in RFC 3339 (UTC). Responds 400 for a range the
// service rejects and 500 when the service call fails.
func (h *statsHandler) GetStats(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	ac := auth.MustFromContext(ctx)

	rangeParam := r.URL.Query().Get("range")
	if rangeParam == "" {
		rangeParam = string(service.Range1h)
	}
	tr := service.TimeRange(rangeParam)
	if !service.ValidRange(tr) {
		writeError(w, http.StatusBadRequest, "invalid_request", "range must be one of: 5m, 1h, 6h, 24h, 30d")
		return
	}

	current, peaks, series, err := h.svc.GetStats(ctx, ac.TeamID, tr)
	if err != nil {
		slog.Error("stats handler: get stats failed", "team_id", ac.TeamID, "error", err)
		writeError(w, http.StatusInternalServerError, "internal_error", "failed to retrieve stats")
		return
	}

	// Build the parallel series slices first, one entry per returned point.
	n := len(series)
	out := statsSeriesResponse{
		Labels:   make([]string, 0, n),
		Running:  make([]int32, 0, n),
		VCPUs:    make([]int32, 0, n),
		MemoryMB: make([]int32, 0, n),
	}
	for _, pt := range series {
		out.Labels = append(out.Labels, pt.Bucket.UTC().Format(time.RFC3339))
		out.Running = append(out.Running, pt.RunningCount)
		out.VCPUs = append(out.VCPUs, pt.VCPUsReserved)
		out.MemoryMB = append(out.MemoryMB, pt.MemoryMBReserved)
	}

	writeJSON(w, http.StatusOK, statsResponse{
		Range: rangeParam,
		Current: statsCurrentResponse{
			RunningCount:     current.RunningCount,
			VCPUsReserved:    current.VCPUsReserved,
			MemoryMBReserved: current.MemoryMBReserved,
		},
		Peaks: statsPeaksResponse{
			RunningCount: peaks.RunningCount,
			VCPUs:        peaks.VCPUs,
			MemoryMB:     peaks.MemoryMB,
		},
		Series: out,
	})
}

View File

@ -0,0 +1,390 @@
package api
import (
"log/slog"
"net/http"
"strings"
"time"
"github.com/go-chi/chi/v5"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/audit"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/service"
)
// teamHandler serves the /v1/teams endpoints. Team operations go through svc;
// some mutating handlers also record an entry through audit.
type teamHandler struct {
	svc   *service.TeamService
	audit *audit.AuditLogger
}
// newTeamHandler builds a teamHandler from the team service and audit logger.
func newTeamHandler(svc *service.TeamService, al *audit.AuditLogger) *teamHandler {
	h := new(teamHandler)
	h.svc = svc
	h.audit = al
	return h
}
// teamResponse is the JSON shape for a team.
type teamResponse struct {
	ID        string `json:"id"` // team ID as formatted by id.FormatTeamID
	Name      string `json:"name"`
	Slug      string `json:"slug"`
	IsByoc    bool   `json:"is_byoc"`
	CreatedAt string `json:"created_at"` // RFC 3339; empty when the stored timestamp is not valid
}
// teamWithRoleResponse includes the calling user's role. teamResponse is
// embedded, so its fields appear at the top level of the JSON object next to
// "role".
type teamWithRoleResponse struct {
	teamResponse
	Role string `json:"role"`
}
// memberResponse is the JSON shape for a team member.
type memberResponse struct {
	UserID   string `json:"user_id"`
	Name     string `json:"name"`
	Email    string `json:"email"`
	Role     string `json:"role"`
	JoinedAt string `json:"joined_at,omitempty"` // RFC 3339; left out of the JSON when empty
}
// teamToResponse converts a team row into its JSON shape. CreatedAt is
// formatted as RFC 3339 and left empty when the stored timestamp is not valid.
func teamToResponse(t db.Team) teamResponse {
	var createdAt string
	if t.CreatedAt.Valid {
		createdAt = t.CreatedAt.Time.Format(time.RFC3339)
	}
	return teamResponse{
		ID:        id.FormatTeamID(t.ID),
		Name:      t.Name,
		Slug:      t.Slug,
		IsByoc:    t.IsByoc,
		CreatedAt: createdAt,
	}
}
// memberInfoToResponse converts a service-level member record into its JSON
// shape.
//
// JoinedAt is formatted as RFC 3339. A zero join time is left empty so the
// field's omitempty tag drops it, instead of emitting the placeholder
// "0001-01-01T00:00:00Z".
func memberInfoToResponse(m service.MemberInfo) memberResponse {
	resp := memberResponse{
		UserID: m.UserID,
		Name:   m.Name,
		Email:  m.Email,
		Role:   m.Role,
	}
	if !m.JoinedAt.IsZero() {
		resp.JoinedAt = m.JoinedAt.Format(time.RFC3339)
	}
	return resp
}
// requireTeamAccess is the guard shared by the team-scoped handlers: the team
// in the caller's auth context must equal the {id} URL parameter before any DB
// call is made.
//
// On success it returns the parsed team ID and true. Otherwise it writes the
// error response itself and returns false: 400 when {id} is not a valid team
// ID, 403 when it differs from the caller's team.
func requireTeamAccess(w http.ResponseWriter, r *http.Request, ac auth.AuthContext) (pgtype.UUID, bool) {
	requested, err := id.ParseTeamID(chi.URLParam(r, "id"))
	if err != nil {
		writeError(w, http.StatusBadRequest, "invalid_request", "invalid team ID")
		return pgtype.UUID{}, false
	}
	if requested != ac.TeamID {
		writeError(w, http.StatusForbidden, "forbidden", "JWT team does not match requested team; use switch-team first")
		return pgtype.UUID{}, false
	}
	return requested, true
}
// List handles GET /v1/teams
// Returns all teams the authenticated user belongs to, each with the user's
// role in that team.
func (h *teamHandler) List(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	ac := auth.MustFromContext(ctx)

	teams, err := h.svc.ListTeamsForUser(ctx, ac.UserID)
	if err != nil {
		status, code, msg := serviceErrToHTTP(err)
		writeError(w, status, code, msg)
		return
	}

	out := make([]teamWithRoleResponse, 0, len(teams))
	for _, entry := range teams {
		out = append(out, teamWithRoleResponse{
			teamResponse: teamToResponse(entry.Team),
			Role:         entry.Role,
		})
	}
	writeJSON(w, http.StatusOK, out)
}
// Create handles POST /v1/teams
// Creates a new team owned by the authenticated user. The requested name is
// trimmed of surrounding whitespace before it is passed to the service.
func (h *teamHandler) Create(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	ac := auth.MustFromContext(ctx)

	var body struct {
		Name string `json:"name"`
	}
	if err := decodeJSON(r, &body); err != nil {
		writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
		return
	}
	name := strings.TrimSpace(body.Name)

	created, err := h.svc.CreateTeam(ctx, ac.UserID, name)
	if err != nil {
		status, code, msg := serviceErrToHTTP(err)
		writeError(w, status, code, msg)
		return
	}

	resp := teamWithRoleResponse{
		teamResponse: teamToResponse(created.Team),
		Role:         created.Role,
	}
	writeJSON(w, http.StatusCreated, resp)
}
// Get handles GET /v1/teams/{id}
// Returns team info and member list as {"team": ..., "members": [...]}.
func (h *teamHandler) Get(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	teamID, ok := requireTeamAccess(w, r, auth.MustFromContext(ctx))
	if !ok {
		return
	}

	team, err := h.svc.GetTeam(ctx, teamID)
	if err != nil {
		status, code, msg := serviceErrToHTTP(err)
		writeError(w, status, code, msg)
		return
	}

	members, err := h.svc.GetMembers(ctx, teamID)
	if err != nil {
		status, code, msg := serviceErrToHTTP(err)
		writeError(w, status, code, msg)
		return
	}

	memberList := make([]memberResponse, 0, len(members))
	for _, m := range members {
		memberList = append(memberList, memberInfoToResponse(m))
	}

	writeJSON(w, http.StatusOK, map[string]any{
		"team":    teamToResponse(team),
		"members": memberList,
	})
}
// Rename handles PATCH /v1/teams/{id}
// Renames the team. Requires admin or owner role (verified from DB).
//
// The current name is read first so the audit entry can record both old and
// new names. If that read fails, a warning is logged and the rename still
// proceeds; the audit entry then carries an empty old name.
func (h *teamHandler) Rename(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	ac := auth.MustFromContext(ctx)
	teamID, ok := requireTeamAccess(w, r, ac)
	if !ok {
		return
	}

	var body struct {
		Name string `json:"name"`
	}
	if err := decodeJSON(r, &body); err != nil {
		writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
		return
	}
	newName := strings.TrimSpace(body.Name)

	// Fetch old name for audit log before renaming.
	previous, err := h.svc.GetTeam(ctx, teamID)
	if err != nil {
		slog.Warn("audit: could not fetch old team name for rename log", "team_id", id.FormatTeamID(teamID), "error", err)
	}

	if err := h.svc.RenameTeam(ctx, teamID, ac.UserID, newName); err != nil {
		status, code, msg := serviceErrToHTTP(err)
		writeError(w, status, code, msg)
		return
	}

	h.audit.LogTeamRename(ctx, ac, teamID, previous.Name, newName)
	w.WriteHeader(http.StatusNoContent)
}
// Delete handles DELETE /v1/teams/{id}
// Soft-deletes the team and destroys active sandboxes. Owner only.
// Responds 204 on success.
func (h *teamHandler) Delete(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	ac := auth.MustFromContext(ctx)
	teamID, ok := requireTeamAccess(w, r, ac)
	if !ok {
		return
	}

	err := h.svc.DeleteTeam(ctx, teamID, ac.UserID)
	if err != nil {
		status, code, msg := serviceErrToHTTP(err)
		writeError(w, status, code, msg)
		return
	}

	w.WriteHeader(http.StatusNoContent)
}
// ListMembers handles GET /v1/teams/{id}/members
// Returns the team's members as a JSON array.
func (h *teamHandler) ListMembers(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	teamID, ok := requireTeamAccess(w, r, auth.MustFromContext(ctx))
	if !ok {
		return
	}

	members, err := h.svc.GetMembers(ctx, teamID)
	if err != nil {
		status, code, msg := serviceErrToHTTP(err)
		writeError(w, status, code, msg)
		return
	}

	out := make([]memberResponse, 0, len(members))
	for _, m := range members {
		out = append(out, memberInfoToResponse(m))
	}
	writeJSON(w, http.StatusOK, out)
}
// AddMember handles POST /v1/teams/{id}/members
// Adds a user by email. Requires admin or owner (verified from DB).
//
// The email is lower-cased and trimmed before use; an empty result is rejected
// with 400. On success the new member is returned with 201 and an audit entry
// is recorded. If the member's user ID cannot be parsed back for the audit
// logger, the member is still added and the missing audit entry is logged as a
// warning rather than dropped silently.
func (h *teamHandler) AddMember(w http.ResponseWriter, r *http.Request) {
	ac := auth.MustFromContext(r.Context())
	teamID, ok := requireTeamAccess(w, r, ac)
	if !ok {
		return
	}

	var req struct {
		Email string `json:"email"`
	}
	if err := decodeJSON(r, &req); err != nil {
		writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
		return
	}
	req.Email = strings.TrimSpace(strings.ToLower(req.Email))
	if req.Email == "" {
		writeError(w, http.StatusBadRequest, "invalid_request", "email is required")
		return
	}

	member, err := h.svc.AddMember(r.Context(), teamID, ac.UserID, req.Email)
	if err != nil {
		status, code, msg := serviceErrToHTTP(err)
		writeError(w, status, code, msg)
		return
	}

	// member.UserID is already formatted with prefix; parse it back for the audit logger.
	if targetUserID, parseErr := id.ParseUserID(member.UserID); parseErr != nil {
		slog.Warn("audit: could not parse added member ID, skipping member add log", "user_id", member.UserID, "error", parseErr)
	} else {
		h.audit.LogMemberAdd(r.Context(), ac, targetUserID, member.Email, member.Role)
	}

	writeJSON(w, http.StatusCreated, memberInfoToResponse(member))
}
// RemoveMember handles DELETE /v1/teams/{id}/members/{uid}
// Removes a member. Requires admin or owner (verified from DB). Owner cannot be removed.
// Responds 400 for an unparsable {uid} and 204 on success, after recording an
// audit entry.
func (h *teamHandler) RemoveMember(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	ac := auth.MustFromContext(ctx)
	teamID, ok := requireTeamAccess(w, r, ac)
	if !ok {
		return
	}

	target, err := id.ParseUserID(chi.URLParam(r, "uid"))
	if err != nil {
		writeError(w, http.StatusBadRequest, "invalid_request", "invalid user ID")
		return
	}

	err = h.svc.RemoveMember(ctx, teamID, ac.UserID, target)
	if err != nil {
		status, code, msg := serviceErrToHTTP(err)
		writeError(w, status, code, msg)
		return
	}

	h.audit.LogMemberRemove(ctx, ac, target)
	w.WriteHeader(http.StatusNoContent)
}
// UpdateMemberRole handles PATCH /v1/teams/{id}/members/{uid}.
// The JSON body carries the new role for the member identified by {uid}.
// The role string is passed to the team service unmodified; validation of
// the role and the rule that the owner's role cannot be changed are left to
// the service. Responds 204 No Content on success and records an audit entry.
func (h *teamHandler) UpdateMemberRole(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	ac := auth.MustFromContext(ctx)
	teamID, allowed := requireTeamAccess(w, r, ac)
	if !allowed {
		return
	}

	memberID, err := id.ParseUserID(chi.URLParam(r, "uid"))
	if err != nil {
		writeError(w, http.StatusBadRequest, "invalid_request", "invalid user ID")
		return
	}

	var body struct {
		Role string `json:"role"`
	}
	if decodeErr := decodeJSON(r, &body); decodeErr != nil {
		writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
		return
	}

	err = h.svc.UpdateMemberRole(ctx, teamID, ac.UserID, memberID, body.Role)
	if err != nil {
		status, code, msg := serviceErrToHTTP(err)
		writeError(w, status, code, msg)
		return
	}

	h.audit.LogMemberRoleUpdate(ctx, ac, memberID, body.Role)
	w.WriteHeader(http.StatusNoContent)
}
// Leave handles POST /v1/teams/{id}/leave.
// It asks the team service to remove the calling user from the team (the
// service is responsible for refusing when the caller is the owner).
// Responds 204 No Content on success and records an audit entry.
func (h *teamHandler) Leave(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	ac := auth.MustFromContext(ctx)

	teamID, allowed := requireTeamAccess(w, r, ac)
	if !allowed {
		return
	}

	err := h.svc.LeaveTeam(ctx, teamID, ac.UserID)
	if err != nil {
		status, code, msg := serviceErrToHTTP(err)
		writeError(w, status, code, msg)
		return
	}

	h.audit.LogMemberLeave(ctx, ac)
	w.WriteHeader(http.StatusNoContent)
}
// SetBYOC handles PUT /v1/admin/teams/{id}/byoc (admin only).
// The JSON body's "enabled" flag switches the BYOC feature on or off for the
// team named by the {id} path parameter. Responds 204 No Content on success.
// A body that omits "enabled" decodes to false, i.e. disables the flag.
func (h *teamHandler) SetBYOC(w http.ResponseWriter, r *http.Request) {
	teamID, err := id.ParseTeamID(chi.URLParam(r, "id"))
	if err != nil {
		writeError(w, http.StatusBadRequest, "invalid_request", "invalid team ID")
		return
	}

	var body struct {
		Enabled bool `json:"enabled"`
	}
	if decodeErr := decodeJSON(r, &body); decodeErr != nil {
		writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
		return
	}

	err = h.svc.SetBYOC(r.Context(), teamID, body.Enabled)
	if err != nil {
		status, code, msg := serviceErrToHTTP(err)
		writeError(w, status, code, msg)
		return
	}
	w.WriteHeader(http.StatusNoContent)
}

View File

@ -0,0 +1,52 @@
package api
import (
"net/http"
"strings"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
)
// usersHandler serves the user-search endpoint. It talks to the generated
// query layer directly rather than going through a service type.
type usersHandler struct {
// db is the query handle used to look users up.
db *db.Queries
}
// newUsersHandler builds a usersHandler backed by the given query handle.
func newUsersHandler(db *db.Queries) *usersHandler {
	h := new(usersHandler)
	h.db = db
	return h
}
// Search handles GET /v1/users/search?email=<prefix>
// Returns up to 10 users whose email starts with the given prefix.
// The prefix must be at least 3 characters long and contain "@".
//
// The prefix is user-supplied and ends up inside a SQL LIKE pattern, so all
// LIKE metacharacters are escaped before it is handed to the query. Responds
// 200 OK with a JSON array of {user_id, email} objects (an empty array when
// nothing matches), 400 on an unusable prefix, and 500 if the query fails.
func (h *usersHandler) Search(w http.ResponseWriter, r *http.Request) {
	auth.MustFromContext(r.Context()) // ensure authenticated

	prefix := strings.TrimSpace(r.URL.Query().Get("email"))
	if len(prefix) < 3 || !strings.Contains(prefix, "@") {
		writeError(w, http.StatusBadRequest, "invalid_request", "email prefix must be at least 3 characters and contain '@'")
		return
	}

	// Escape LIKE metacharacters to prevent pattern injection. The backslash
	// is the escape character itself and must be escaped too: otherwise an
	// input such as `a@\%` would turn into `a@\\%` (a literal backslash
	// followed by a live wildcard), and a trailing backslash could escape
	// whatever the query appends to the pattern. strings.Replacer works in a
	// single pass, so the backslashes it inserts are never re-escaped.
	escaped := strings.NewReplacer(`\`, `\\`, "%", `\%`, "_", `\_`).Replace(prefix)

	results, err := h.db.SearchUsersByEmailPrefix(r.Context(), pgtype.Text{String: escaped, Valid: true})
	if err != nil {
		writeError(w, http.StatusInternalServerError, "internal", "search failed")
		return
	}

	type userResult struct {
		UserID string `json:"user_id"`
		Email  string `json:"email"`
	}
	resp := make([]userResult, len(results))
	for i, u := range results {
		resp[i] = userResult{UserID: id.FormatUserID(u.ID), Email: u.Email}
	}
	writeJSON(w, http.StatusOK, resp)
}

View File

@ -0,0 +1,216 @@
package api
import (
"context"
"log/slog"
"time"
"connectrpc.com/connect"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/audit"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
)
// unreachableThreshold is the maximum heartbeat silence tolerated before a
// host is considered unreachable: three missed heartbeats at the 30-second
// heartbeat period, i.e. 90 seconds.
const unreachableThreshold = 3 * 30 * time.Second
// HostMonitor runs on a fixed interval and performs two duties:
//
// 1. Passive check: marks hosts whose last_heartbeat_at is stale as
// "unreachable" and marks their active sandboxes as "missing".
//
// 2. Active reconciliation: for each online host, calls ListSandboxes and
// reconciles DB state against live host state — restoring "missing"
// sandboxes that are actually alive, and stopping orphaned ones.
type HostMonitor struct {
// db is the query handle used to read and update host and sandbox rows.
db *db.Queries
// pool supplies the host agent client for a given host record.
pool *lifecycle.HostClientPool
// audit records host-down and sandbox auto-pause events.
audit *audit.AuditLogger
// interval is the period between monitor passes.
interval time.Duration
}
// NewHostMonitor returns a HostMonitor wired to the given query handle, host
// client pool and audit logger, configured to run once per interval. The
// monitor does nothing until Start is called.
func NewHostMonitor(queries *db.Queries, pool *lifecycle.HostClientPool, al *audit.AuditLogger, interval time.Duration) *HostMonitor {
	m := new(HostMonitor)
	m.db, m.pool = queries, pool
	m.audit, m.interval = al, interval
	return m
}
// Start launches the monitor loop in a background goroutine and returns
// immediately. The loop ends when ctx is cancelled.
func (m *HostMonitor) Start(ctx context.Context) {
	loop := func() {
		tick := time.NewTicker(m.interval)
		defer tick.Stop()

		// First pass happens right away so the control plane reconciles host
		// and sandbox state at startup instead of one interval later.
		m.run(ctx)

		for {
			select {
			case <-tick.C:
				m.run(ctx)
			case <-ctx.Done():
				return
			}
		}
	}
	go loop()
}
// run performs one monitor pass: it lists the active hosts and checks each
// one in turn. A failure to list hosts is logged and the pass is abandoned.
func (m *HostMonitor) run(ctx context.Context) {
	activeHosts, err := m.db.ListActiveHosts(ctx)
	if err != nil {
		slog.Warn("host monitor: failed to list hosts", "error", err)
		return
	}

	for i := range activeHosts {
		m.checkHost(ctx, activeHosts[i])
	}
}
// checkHost evaluates a single host in two phases. Every failure is logged
// and swallowed; nothing is returned to the caller.
//
// Passive phase: a host with no recorded heartbeat, or whose last heartbeat
// is older than unreachableThreshold, and which is not already "unreachable",
// is marked unreachable, its sandboxes are marked missing, an audit entry is
// written, and the function returns.
//
// Active phase (only for hosts whose status is "online"): the host agent is
// asked for its sandbox list, then
// - sandboxes that are "missing" in the DB are restored to running if the
// host reports them alive, and set to "stopped" otherwise;
// - sandboxes that are "running" in the DB but not reported alive are set
// to "paused" if the host lists them as auto-paused, else "stopped".
func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) {
// --- Passive phase: check heartbeat staleness ---
// A host that has never sent a heartbeat (invalid timestamp) counts as stale.
stale := !host.LastHeartbeatAt.Valid ||
time.Since(host.LastHeartbeatAt.Time) > unreachableThreshold
if stale && host.Status != "unreachable" {
slog.Info("host monitor: marking host unreachable", "host_id", id.FormatHostID(host.ID),
"last_heartbeat", host.LastHeartbeatAt.Time)
// The two updates are attempted independently: a failure of the first
// does not prevent the second, and the audit entry is written either way.
if err := m.db.MarkHostUnreachable(ctx, host.ID); err != nil {
slog.Warn("host monitor: failed to mark host unreachable", "host_id", id.FormatHostID(host.ID), "error", err)
}
if err := m.db.MarkSandboxesMissingByHost(ctx, host.ID); err != nil {
slog.Warn("host monitor: failed to mark sandboxes missing", "host_id", id.FormatHostID(host.ID), "error", err)
}
m.audit.LogHostMarkedDown(ctx, host.TeamID, host.ID)
return
}
// --- Active reconciliation: only for online hosts ---
// This also covers hosts that are stale but already marked "unreachable".
if host.Status != "online" {
return
}
agent, err := m.pool.GetForHost(host)
if err != nil {
// Host has no address yet (e.g., just registered) — skip.
return
}
resp, err := agent.ListSandboxes(ctx, connect.NewRequest(&pb.ListSandboxesRequest{}))
if err != nil {
// RPC failure is a transient condition; the passive phase will catch it
// if heartbeats stop arriving.
slog.Debug("host monitor: ListSandboxes failed (transient)", "host_id", id.FormatHostID(host.ID), "error", err)
return
}
// Build set of sandbox IDs alive on the host.
// The host agent returns sandbox IDs as strings (formatted with prefix).
alive := make(map[string]struct{}, len(resp.Msg.Sandboxes))
for _, sb := range resp.Msg.Sandboxes {
alive[sb.SandboxId] = struct{}{}
}
// Set of sandbox IDs the host reports as auto-paused.
autoPaused := make(map[string]struct{}, len(resp.Msg.AutoPausedSandboxIds))
for _, apID := range resp.Msg.AutoPausedSandboxIds {
autoPaused[apID] = struct{}{}
}
// --- Restore sandboxes that are "missing" in DB but alive on host ---
// This handles the case where CP marked them missing due to a transient
// heartbeat gap, but the host was actually fine.
missingSandboxes, err := m.db.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
HostID: host.ID,
Column2: []string{"missing"},
})
if err != nil {
// Not fatal: the running-sandbox pass below still executes.
slog.Warn("host monitor: failed to list missing sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
} else {
var toRestore []pgtype.UUID
var toStop []pgtype.UUID
// NOTE(review): missing sandboxes are not checked against autoPaused here,
// so one that the host auto-paused would be set to "stopped" rather than
// "paused" — confirm this is intended.
for _, sb := range missingSandboxes {
sbIDStr := id.FormatSandboxID(sb.ID)
if _, ok := alive[sbIDStr]; ok {
toRestore = append(toRestore, sb.ID)
} else {
toStop = append(toStop, sb.ID)
}
}
if len(toRestore) > 0 {
slog.Info("host monitor: restoring missing sandboxes", "host_id", id.FormatHostID(host.ID), "count", len(toRestore))
if err := m.db.BulkRestoreRunning(ctx, toRestore); err != nil {
slog.Warn("host monitor: failed to restore missing sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
}
}
if len(toStop) > 0 {
slog.Info("host monitor: stopping confirmed-dead missing sandboxes", "host_id", id.FormatHostID(host.ID), "count", len(toStop))
if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
Column1: toStop,
Status: "stopped",
}); err != nil {
slog.Warn("host monitor: failed to stop missing sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
}
}
}
// --- Find running sandboxes in DB that are no longer alive on the host ---
// This query runs after the restore step above, so sandboxes just restored
// to running are presumably included; they are in the alive set and skipped.
runningSandboxes, err := m.db.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
HostID: host.ID,
Column2: []string{"running"},
})
if err != nil {
slog.Warn("host monitor: failed to list running sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
return
}
var toPause, toStop []pgtype.UUID
// sbTeamID remembers each sandbox's team so the audit entries below can be attributed.
sbTeamID := make(map[pgtype.UUID]pgtype.UUID, len(runningSandboxes))
for _, sb := range runningSandboxes {
sbIDStr := id.FormatSandboxID(sb.ID)
sbTeamID[sb.ID] = sb.TeamID
if _, ok := alive[sbIDStr]; ok {
continue
}
if _, ok := autoPaused[sbIDStr]; ok {
toPause = append(toPause, sb.ID)
} else {
toStop = append(toStop, sb.ID)
}
}
if len(toPause) > 0 {
slog.Info("host monitor: marking auto-paused sandboxes", "host_id", id.FormatHostID(host.ID), "count", len(toPause))
if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
Column1: toPause,
Status: "paused",
}); err != nil {
slog.Warn("host monitor: failed to mark paused", "host_id", id.FormatHostID(host.ID), "error", err)
}
// NOTE(review): audit entries are written even when the bulk update above
// failed — confirm that is the desired behavior.
for _, sbID := range toPause {
m.audit.LogSandboxAutoPause(ctx, sbTeamID[sbID], sbID)
}
}
if len(toStop) > 0 {
slog.Info("host monitor: marking orphaned sandboxes stopped", "host_id", id.FormatHostID(host.ID), "count", len(toStop))
if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
Column1: toStop,
Status: "stopped",
}); err != nil {
slog.Warn("host monitor: failed to mark stopped", "host_id", id.FormatHostID(host.ID), "error", err)
}
}
}

View File

@ -0,0 +1,68 @@
package api
import (
"context"
"log/slog"
"time"
"git.omukk.dev/wrenn/sandbox/internal/db"
)
// MetricsSampler records per-team sandbox resource usage to
// sandbox_metrics_snapshots every interval. It also prunes rows older than
// 60 days on each tick to keep the table bounded.
type MetricsSampler struct {
// db is the query handle used for sampling, inserting and pruning.
db *db.Queries
// interval is the period between sampling passes.
interval time.Duration
}
// NewMetricsSampler returns a MetricsSampler that uses the given query handle
// and samples once per interval. Sampling begins when Start is called.
func NewMetricsSampler(queries *db.Queries, interval time.Duration) *MetricsSampler {
	s := new(MetricsSampler)
	s.db = queries
	s.interval = interval
	return s
}
// Start launches the sampler loop in a background goroutine and returns
// immediately. The loop ends when ctx is cancelled.
func (s *MetricsSampler) Start(ctx context.Context) {
	loop := func() {
		tick := time.NewTicker(s.interval)
		defer tick.Stop()

		// Take the first sample at startup rather than one interval later.
		s.run(ctx)

		for {
			select {
			case <-tick.C:
				s.run(ctx)
			case <-ctx.Done():
				return
			}
		}
	}
	go loop()
}
// run executes one tick: old rows are pruned first, then a fresh sample is
// taken. A sampling failure is logged and otherwise ignored.
func (s *MetricsSampler) run(ctx context.Context) {
	s.prune(ctx)

	err := s.sample(ctx)
	if err == nil {
		return
	}
	slog.Warn("metrics sampler: sample failed", "error", err)
}
// sample reads the current metrics rows and stores each one as a snapshot.
// It returns an error only when the read fails; a failed insert is logged
// and the remaining rows are still processed.
func (s *MetricsSampler) sample(ctx context.Context) error {
	snapshots, err := s.db.SampleSandboxMetrics(ctx)
	if err != nil {
		return err
	}

	for _, snap := range snapshots {
		// The sampled row converts directly to the insert parameter struct.
		params := db.InsertMetricsSnapshotParams(snap)
		if insertErr := s.db.InsertMetricsSnapshot(ctx, params); insertErr != nil {
			slog.Warn("metrics sampler: insert snapshot failed", "team_id", snap.TeamID, "error", insertErr)
		}
	}
	return nil
}
// prune deletes old metrics rows via PruneOldMetrics. Failures are logged
// and never propagated, so a failed prune does not block sampling.
func (s *MetricsSampler) prune(ctx context.Context) {
	err := s.db.PruneOldMetrics(ctx)
	if err == nil {
		return
	}
	slog.Warn("metrics sampler: prune failed", "error", err)
}

View File

@ -12,6 +12,9 @@ import (
"time"
"connectrpc.com/connect"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/id"
)
type errorResponse struct {
@ -35,6 +38,11 @@ func writeError(w http.ResponseWriter, status int, code, message string) {
})
}
// formatUUIDForRPC renders a pgtype.UUID as the string form used in RPC
// messages. It is a thin wrapper over id.UUIDString, which defines the
// exact format.
func formatUUIDForRPC(u pgtype.UUID) string {
	rendered := id.UUIDString(u)
	return rendered
}
// agentErrToHTTP maps a Connect RPC error to an HTTP status, error code, and message.
func agentErrToHTTP(err error) (int, string, string) {
switch connect.CodeOf(err) {
@ -87,8 +95,12 @@ func serviceErrToHTTP(err error) (int, string, string) {
return http.StatusNotFound, "not_found", msg
case strings.Contains(msg, "not running"), strings.Contains(msg, "not paused"):
return http.StatusConflict, "invalid_state", msg
case strings.Contains(msg, "conflict:"):
return http.StatusConflict, "conflict", msg
case strings.Contains(msg, "forbidden"):
return http.StatusForbidden, "forbidden", msg
case strings.Contains(msg, "invalid or expired"):
return http.StatusUnauthorized, "unauthorized", msg
case strings.Contains(msg, "invalid"):
return http.StatusBadRequest, "invalid_request", msg
default:

View File

@ -0,0 +1,30 @@
package api
import (
"net/http"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
)
// requireAdmin is middleware that lets a request through only when the
// authenticated user is a platform admin according to the database.
// It must be mounted after requireJWT, since it reads the AuthContext that
// middleware stores. The JWT's is_admin claim is deliberately not trusted
// here (it exists for the UI); the user row is re-read on every request.
//
// Responds 401 when no AuthContext is present and 403 when the user cannot
// be loaded or is not an admin.
func requireAdmin(queries *db.Queries) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		guard := func(w http.ResponseWriter, r *http.Request) {
			ac, authenticated := auth.FromContext(r.Context())
			if !authenticated {
				writeError(w, http.StatusUnauthorized, "unauthorized", "authentication required")
				return
			}

			// A lookup failure is treated the same as "not an admin".
			user, err := queries.GetUserByID(r.Context(), ac.UserID)
			if err == nil && user.IsAdmin {
				next.ServeHTTP(w, r)
				return
			}
			writeError(w, http.StatusForbidden, "forbidden", "admin access required")
		}
		return http.HandlerFunc(guard)
	}
}

View File

@ -1,38 +0,0 @@
package api
import (
"log/slog"
"net/http"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
)
// requireAPIKey is middleware that authenticates a request by its X-API-Key
// header. The key is hashed with auth.HashAPIKey and the hash is looked up
// in the database; on a match the key's TeamID is stored in the request
// context for downstream handlers. A missing header or unknown key yields
// 401 Unauthorized.
func requireAPIKey(queries *db.Queries) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		authenticate := func(w http.ResponseWriter, r *http.Request) {
			apiKey := r.Header.Get("X-API-Key")
			if apiKey == "" {
				writeError(w, http.StatusUnauthorized, "unauthorized", "X-API-Key header required")
				return
			}

			keyRow, err := queries.GetAPIKeyByHash(r.Context(), auth.HashAPIKey(apiKey))
			if err != nil {
				writeError(w, http.StatusUnauthorized, "unauthorized", "invalid API key")
				return
			}

			// Recording last_used is best-effort: a failure is logged and the
			// request proceeds.
			if touchErr := queries.UpdateAPIKeyLastUsed(r.Context(), keyRow.ID); touchErr != nil {
				slog.Warn("failed to update api key last_used", "key_id", keyRow.ID, "error", touchErr)
			}

			authed := auth.WithAuthContext(r.Context(), auth.AuthContext{TeamID: keyRow.TeamID})
			next.ServeHTTP(w, r.WithContext(authed))
		}
		return http.HandlerFunc(authenticate)
	}
}

View File

@ -7,6 +7,7 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
)
// requireAPIKeyOrJWT accepts either X-API-Key header or Authorization: Bearer JWT.
@ -19,15 +20,20 @@ func requireAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler
hash := auth.HashAPIKey(key)
row, err := queries.GetAPIKeyByHash(r.Context(), hash)
if err != nil {
slog.Warn("api key auth failed", "prefix", auth.APIKeyPrefix(key), "ip", r.RemoteAddr)
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid API key")
return
}
if err := queries.UpdateAPIKeyLastUsed(r.Context(), row.ID); err != nil {
slog.Warn("failed to update api key last_used", "key_id", row.ID, "error", err)
slog.Warn("failed to update api key last_used", "key_id", id.FormatAPIKeyID(row.ID), "error", err)
}
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{TeamID: row.TeamID})
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
TeamID: row.TeamID,
APIKeyID: row.ID,
APIKeyName: row.Name,
})
next.ServeHTTP(w, r.WithContext(ctx))
return
}
@ -37,14 +43,28 @@ func requireAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler
tokenStr := strings.TrimPrefix(header, "Bearer ")
claims, err := auth.VerifyJWT(jwtSecret, tokenStr)
if err != nil {
slog.Warn("jwt auth failed", "error", err, "ip", r.RemoteAddr)
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid or expired token")
return
}
teamID, err := id.ParseTeamID(claims.TeamID)
if err != nil {
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid team ID in token")
return
}
userID, err := id.ParseUserID(claims.Subject)
if err != nil {
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid user ID in token")
return
}
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
TeamID: claims.TeamID,
UserID: claims.Subject,
TeamID: teamID,
UserID: userID,
Email: claims.Email,
Name: claims.Name,
Role: claims.Role,
})
next.ServeHTTP(w, r.WithContext(ctx))
return

View File

@ -4,6 +4,7 @@ import (
"net/http"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/id"
)
// requireHostToken validates the X-Host-Token header containing a host JWT,
@ -23,7 +24,13 @@ func requireHostToken(secret []byte) func(http.Handler) http.Handler {
return
}
ctx := auth.WithHostContext(r.Context(), auth.HostContext{HostID: claims.HostID})
hostID, err := id.ParseHostID(claims.HostID)
if err != nil {
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid host ID in token")
return
}
ctx := auth.WithHostContext(r.Context(), auth.HostContext{HostID: hostID})
next.ServeHTTP(w, r.WithContext(ctx))
})
}

View File

@ -5,6 +5,7 @@ import (
"strings"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/id"
)
// requireJWT validates the Authorization: Bearer <token> header, verifies the JWT
@ -25,11 +26,26 @@ func requireJWT(secret []byte) func(http.Handler) http.Handler {
return
}
teamID, err := id.ParseTeamID(claims.TeamID)
if err != nil {
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid team ID in token")
return
}
userID, err := id.ParseUserID(claims.Subject)
if err != nil {
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid user ID in token")
return
}
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
TeamID: claims.TeamID,
UserID: claims.Subject,
Email: claims.Email,
TeamID: teamID,
UserID: userID,
Email: claims.Email,
Name: claims.Name,
Role: claims.Role,
IsAdmin: claims.IsAdmin,
})
next.ServeHTTP(w, r.WithContext(ctx))
})
}

File diff suppressed because it is too large Load Diff

View File

@ -1,126 +0,0 @@
package api
import (
"context"
"log/slog"
"time"
"connectrpc.com/connect"
"git.omukk.dev/wrenn/sandbox/internal/db"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
)
// Reconciler periodically compares the host agent's sandbox list with the DB
// and marks sandboxes that no longer exist on the host as stopped.
type Reconciler struct {
// db is the query handle used to list and update sandbox rows.
db *db.Queries
// agent is the host agent client queried for the live sandbox list.
agent hostagentv1connect.HostAgentServiceClient
// hostID selects which host's sandbox rows are reconciled.
hostID string
// interval is the period between reconciliation passes.
interval time.Duration
}
// NewReconciler returns a Reconciler for the given host, using the supplied
// query handle and host agent client and reconciling once per interval.
// Nothing runs until Start is called.
func NewReconciler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient, hostID string, interval time.Duration) *Reconciler {
	rc := new(Reconciler)
	rc.db, rc.agent = db, agent
	rc.hostID, rc.interval = hostID, interval
	return rc
}
// Start launches the reconciliation loop in a background goroutine and
// returns immediately. The first pass happens after one full interval; the
// loop ends when ctx is cancelled.
func (rc *Reconciler) Start(ctx context.Context) {
	loop := func() {
		tick := time.NewTicker(rc.interval)
		defer tick.Stop()

		for {
			select {
			case <-tick.C:
				rc.reconcile(ctx)
			case <-ctx.Done():
				return
			}
		}
	}
	go loop()
}
// reconcile performs one pass: it fetches the host agent's sandbox list and,
// for every sandbox the DB considers "running" on this host that the agent
// no longer reports, sets its DB status to "paused" (if the agent lists it
// as auto-paused) or "stopped" (otherwise). All failures are logged and the
// pass is abandoned or continued as noted below; nothing is returned.
func (rc *Reconciler) reconcile(ctx context.Context) {
	// One RPC yields both the live sandbox list and the IDs the TTL reaper
	// auto-paused since the previous call.
	resp, err := rc.agent.ListSandboxes(ctx, connect.NewRequest(&pb.ListSandboxesRequest{}))
	if err != nil {
		slog.Warn("reconciler: failed to list sandboxes from host agent", "error", err)
		return
	}

	// Sandboxes currently alive on the host.
	onHost := make(map[string]struct{}, len(resp.Msg.Sandboxes))
	for _, sb := range resp.Msg.Sandboxes {
		onHost[sb.SandboxId] = struct{}{}
	}

	// Sandboxes the host reports as auto-paused.
	autoPaused := make(map[string]struct{}, len(resp.Msg.AutoPausedSandboxIds))
	for _, sandboxID := range resp.Msg.AutoPausedSandboxIds {
		autoPaused[sandboxID] = struct{}{}
	}

	// Only "running" rows are considered. Paused sandboxes are expected to be
	// absent from the host agent because pause = snapshot + destroy resources.
	running, err := rc.db.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
		HostID:  rc.hostID,
		Column2: []string{"running"},
	})
	if err != nil {
		slog.Warn("reconciler: failed to list DB sandboxes", "error", err)
		return
	}

	// Classify every running-in-DB sandbox that is gone from the host:
	// auto-paused by the TTL reaper, or crashed/orphaned.
	var toPause, toStop []string
	for _, sb := range running {
		if _, alive := onHost[sb.ID]; alive {
			continue
		}
		if _, paused := autoPaused[sb.ID]; paused {
			toPause = append(toPause, sb.ID)
		} else {
			toStop = append(toStop, sb.ID)
		}
	}

	if len(toPause) > 0 {
		slog.Info("reconciler: marking auto-paused sandboxes", "count", len(toPause), "ids", toPause)
		pauseErr := rc.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
			Column1: toPause,
			Status:  "paused",
		})
		if pauseErr != nil {
			slog.Warn("reconciler: failed to mark auto-paused sandboxes", "error", pauseErr)
		}
	}

	if len(toStop) > 0 {
		slog.Info("reconciler: marking stale sandboxes as stopped", "count", len(toStop), "ids", toStop)
		stopErr := rc.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
			Column1: toStop,
			Status:  "stopped",
		})
		if stopErr != nil {
			slog.Warn("reconciler: failed to update stale sandboxes", "error", stopErr)
		}
	}
}

View File

@ -9,10 +9,14 @@ import (
"github.com/jackc/pgx/v5/pgxpool"
"github.com/redis/go-redis/v9"
"git.omukk.dev/wrenn/sandbox/internal/audit"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/auth/oauth"
"git.omukk.dev/wrenn/sandbox/internal/channels"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
"git.omukk.dev/wrenn/sandbox/internal/scheduler"
"git.omukk.dev/wrenn/sandbox/internal/service"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
)
//go:embed openapi.yaml
@ -20,30 +24,54 @@ var openapiYAML []byte
// Server is the control plane HTTP server.
type Server struct {
router chi.Router
router chi.Router
BuildSvc *service.BuildService
}
// New constructs the chi router and registers all routes.
func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, pool *pgxpool.Pool, rdb *redis.Client, jwtSecret []byte, oauthRegistry *oauth.Registry, oauthRedirectURL string) *Server {
func New(
queries *db.Queries,
pool *lifecycle.HostClientPool,
sched scheduler.HostScheduler,
pgPool *pgxpool.Pool,
rdb *redis.Client,
jwtSecret []byte,
oauthRegistry *oauth.Registry,
oauthRedirectURL string,
ca *auth.CA,
al *audit.AuditLogger,
channelSvc *channels.Service,
) *Server {
r := chi.NewRouter()
r.Use(requestLogger())
// Shared service layer.
sandboxSvc := &service.SandboxService{DB: queries, Agent: agent}
sandboxSvc := &service.SandboxService{DB: queries, Pool: pool, Scheduler: sched}
apiKeySvc := &service.APIKeyService{DB: queries}
templateSvc := &service.TemplateService{DB: queries}
hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret}
hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret, Pool: pool, CA: ca}
teamSvc := &service.TeamService{DB: queries, Pool: pgPool, HostPool: pool}
auditSvc := &service.AuditService{DB: queries}
statsSvc := &service.StatsService{DB: queries, Pool: pgPool}
buildSvc := &service.BuildService{DB: queries, Redis: rdb, Pool: pool, Scheduler: sched}
sandbox := newSandboxHandler(sandboxSvc)
exec := newExecHandler(queries, agent)
execStream := newExecStreamHandler(queries, agent)
files := newFilesHandler(queries, agent)
filesStream := newFilesStreamHandler(queries, agent)
snapshots := newSnapshotHandler(templateSvc, queries, agent)
authH := newAuthHandler(queries, pool, jwtSecret)
oauthH := newOAuthHandler(queries, pool, jwtSecret, oauthRegistry, oauthRedirectURL)
apiKeys := newAPIKeyHandler(apiKeySvc)
hostH := newHostHandler(hostSvc, queries)
sandbox := newSandboxHandler(sandboxSvc, al)
exec := newExecHandler(queries, pool)
execStream := newExecStreamHandler(queries, pool)
files := newFilesHandler(queries, pool)
filesStream := newFilesStreamHandler(queries, pool)
snapshots := newSnapshotHandler(templateSvc, queries, pool, al)
authH := newAuthHandler(queries, pgPool, jwtSecret)
oauthH := newOAuthHandler(queries, pgPool, jwtSecret, oauthRegistry, oauthRedirectURL)
apiKeys := newAPIKeyHandler(apiKeySvc, al)
hostH := newHostHandler(hostSvc, queries, al)
teamH := newTeamHandler(teamSvc, al)
usersH := newUsersHandler(queries)
auditH := newAuditHandler(auditSvc)
statsH := newStatsHandler(statsSvc)
metricsH := newSandboxMetricsHandler(queries, pool)
buildH := newBuildHandler(buildSvc, queries, pool)
channelH := newChannelHandler(channelSvc, al)
// OpenAPI spec and docs.
r.Get("/openapi.yaml", serveOpenAPI)
@ -55,6 +83,9 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p
r.Get("/auth/oauth/{provider}", oauthH.Redirect)
r.Get("/auth/oauth/{provider}/callback", oauthH.Callback)
// JWT-authenticated: switch active team.
r.With(requireJWT(jwtSecret)).Post("/v1/auth/switch-team", authH.SwitchTeam)
// JWT-authenticated: API key management.
r.Route("/v1/api-keys", func(r chi.Router) {
r.Use(requireJWT(jwtSecret))
@ -63,11 +94,32 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p
r.Delete("/{id}", apiKeys.Delete)
})
// JWT-authenticated: team management.
r.Route("/v1/teams", func(r chi.Router) {
r.Use(requireJWT(jwtSecret))
r.Get("/", teamH.List)
r.Post("/", teamH.Create)
r.Route("/{id}", func(r chi.Router) {
r.Get("/", teamH.Get)
r.Patch("/", teamH.Rename)
r.Delete("/", teamH.Delete)
r.Get("/members", teamH.ListMembers)
r.Post("/members", teamH.AddMember)
r.Patch("/members/{uid}", teamH.UpdateMemberRole)
r.Delete("/members/{uid}", teamH.RemoveMember)
r.Post("/leave", teamH.Leave)
})
})
// JWT-authenticated: user search (for add-member UI).
r.With(requireJWT(jwtSecret)).Get("/v1/users/search", usersH.Search)
// Sandbox lifecycle: accepts API key or JWT bearer token.
r.Route("/v1/sandboxes", func(r chi.Router) {
r.Use(requireAPIKeyOrJWT(queries, jwtSecret))
r.Post("/", sandbox.Create)
r.Get("/", sandbox.List)
r.Get("/stats", statsH.GetStats)
r.Route("/{id}", func(r chi.Router) {
r.Get("/", sandbox.Get)
@ -81,6 +133,7 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p
r.Post("/files/read", files.Download)
r.Post("/files/stream/write", filesStream.StreamUpload)
r.Post("/files/stream/read", filesStream.StreamDownload)
r.Get("/metrics", metricsH.GetMetrics)
})
})
@ -97,6 +150,9 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p
// Unauthenticated: one-time registration token.
r.Post("/register", hostH.Register)
// Unauthenticated: refresh token exchange.
r.Post("/auth/refresh", hostH.RefreshToken)
// Host-token-authenticated: heartbeat.
r.With(requireHostToken(jwtSecret)).Post("/{id}/heartbeat", hostH.Heartbeat)
@ -108,6 +164,7 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p
r.Route("/{id}", func(r chi.Router) {
r.Get("/", hostH.Get)
r.Delete("/", hostH.Delete)
r.Get("/delete-preview", hostH.DeletePreview)
r.Post("/token", hostH.RegenerateToken)
r.Get("/tags", hostH.ListTags)
r.Post("/tags", hostH.AddTag)
@ -116,7 +173,37 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p
})
})
return &Server{router: r}
// JWT-authenticated: notification channels.
r.Route("/v1/channels", func(r chi.Router) {
r.Use(requireJWT(jwtSecret))
r.Post("/", channelH.Create)
r.Get("/", channelH.List)
r.Post("/test", channelH.Test)
r.Route("/{id}", func(r chi.Router) {
r.Get("/", channelH.Get)
r.Patch("/", channelH.Update)
r.Delete("/", channelH.Delete)
r.Put("/config", channelH.RotateConfig)
})
})
// JWT-authenticated: audit log.
r.With(requireJWT(jwtSecret)).Get("/v1/audit-logs", auditH.List)
// Platform admin routes — require JWT + DB-validated admin status.
r.Route("/v1/admin", func(r chi.Router) {
r.Use(requireJWT(jwtSecret))
r.Use(requireAdmin(queries))
r.Put("/teams/{id}/byoc", teamH.SetBYOC)
r.Get("/templates", buildH.ListTemplates)
r.Delete("/templates/{name}", buildH.DeleteTemplate)
r.Post("/builds", buildH.Create)
r.Get("/builds", buildH.List)
r.Get("/builds/{id}", buildH.Get)
r.Post("/builds/{id}/cancel", buildH.Cancel)
})
return &Server{router: r, BuildSvc: buildSvc}
}
// Handler returns the HTTP handler.
@ -137,7 +224,7 @@ func serveDocs(w http.ResponseWriter, r *http.Request) {
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Wrenn Sandbox API</title>
<link rel="stylesheet" href="https://unpkg.com/swagger-ui-dist@5/swagger-ui.css">
<link rel="stylesheet" href="https://unpkg.com/swagger-ui-dist@5.18.2/swagger-ui.css" integrity="sha384-rcbEi6xgdPk0iWkAQzT2F3FeBJXdG+ydrawGlfHAFIZG7wU6aKbQaRewysYpmrlW" crossorigin="anonymous">
<style>
body { margin: 0; background: #fafafa; }
.swagger-ui .topbar { display: none; }
@ -145,7 +232,7 @@ func serveDocs(w http.ResponseWriter, r *http.Request) {
</head>
<body>
<div id="swagger-ui"></div>
<script src="https://unpkg.com/swagger-ui-dist@5/swagger-ui-bundle.js"></script>
<script src="https://unpkg.com/swagger-ui-dist@5.18.2/swagger-ui-bundle.js" integrity="sha384-NXtFPpN61oWCuN4D42K6Zd5Rt2+uxeIT36R7kpXBuY9tLnZorzrJ4ykpqwJfgjpZ" crossorigin="anonymous"></script>
<script>
SwaggerUIBundle({
url: "/openapi.yaml",

569
internal/audit/logger.go Normal file
View File

@ -0,0 +1,569 @@
package audit
import (
"context"
"encoding/json"
"log/slog"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/events"
"git.omukk.dev/wrenn/sandbox/internal/id"
)
// AuditLogger writes audit log entries for user-initiated and system events.
// All methods are fire-and-forget: failures are logged via slog and never
// propagated to the caller.
type AuditLogger struct {
// db is the query handle used to insert audit log rows.
db *db.Queries
pub events.EventPublisher // optional — nil disables event publishing
}
// New returns an AuditLogger that only writes to the database; with no
// publisher set, event publishing is disabled.
func New(queries *db.Queries) *AuditLogger {
	l := new(AuditLogger)
	l.db = queries
	return l
}
// NewWithPublisher returns an AuditLogger that writes to the database and
// additionally forwards events to pub.
func NewWithPublisher(queries *db.Queries, pub events.EventPublisher) *AuditLogger {
	l := new(AuditLogger)
	l.db = queries
	l.pub = pub
	return l
}
// publish forwards e to the configured publisher. When no publisher was
// supplied at construction time the event is dropped.
func (l *AuditLogger) publish(ctx context.Context, e events.Event) {
	if l.pub == nil {
		return
	}
	l.pub.Publish(ctx, e)
}
// actorToEvent builds an events.Actor from the same type/ID/name triple that
// actorFields derives for audit rows.
func actorToEvent(ac auth.AuthContext) events.Actor {
	kind, ref, display := actorFields(ac)
	var actor events.Actor
	actor.Type = events.ActorKind(kind)
	actor.ID = ref
	actor.Name = display
	return actor
}
// systemActor returns the actor used for system-initiated events: only the
// type is set (events.ActorSystem); ID and name stay empty.
func systemActor() events.Actor {
	var actor events.Actor
	actor.Type = events.ActorSystem
	return actor
}
// actorFields derives actor_type, actor_id and actor_name from an AuthContext.
// A valid UserID takes precedence over a valid APIKeyID; when neither is valid
// the actor is reported as "system" with empty ID and name. actor_id is the
// prefixed string form stored in the TEXT column.
func actorFields(ac auth.AuthContext) (actorType, actorID, actorName string) {
	switch {
	case ac.UserID.Valid:
		return "user", id.FormatUserID(ac.UserID), ac.Name
	case ac.APIKeyID.Valid:
		return "api_key", id.FormatAPIKeyID(ac.APIKeyID), ac.APIKeyName
	default:
		return "system", "", ""
	}
}
// write inserts one audit log row. Insert failures are reported with
// slog.Warn and never returned to the caller.
func (l *AuditLogger) write(ctx context.Context, p db.InsertAuditLogParams) {
	err := l.db.InsertAuditLog(ctx, p)
	if err == nil {
		return
	}
	slog.Warn("audit: failed to write log entry",
		"action", p.Action,
		"resource_type", p.ResourceType,
		"error", err,
	)
}
// marshalMeta encodes audit metadata as JSON for the metadata column.
//
// An empty or nil map yields "{}". If encoding fails (for example because a
// value has a type encoding/json cannot represent) the failure is logged and
// "{}" is returned, so the audit row is still written rather than dropped.
func marshalMeta(meta map[string]any) []byte {
	if len(meta) == 0 {
		return []byte("{}")
	}
	b, err := json.Marshal(meta)
	if err != nil {
		// Keep the fire-and-forget contract, but do not lose the metadata
		// silently: surface the problem the same way write does.
		slog.Warn("audit: failed to marshal metadata", "error", err)
		return []byte("{}")
	}
	return b
}
// optText maps an empty string to an invalid (NULL) pgtype.Text and any other
// string to a valid pgtype.Text carrying it.
func optText(s string) pgtype.Text {
	var out pgtype.Text
	if s != "" {
		out.String = s
		out.Valid = true
	}
	return out
}
// --- Sandbox events (scope: team) ---
// LogSandboxCreate writes a team-scoped "create" row (status "success") for
// the sandbox, with the template name in the metadata, then publishes a
// CapsuleCreated event attributed to the caller.
func (l *AuditLogger) LogSandboxCreate(ctx context.Context, ac auth.AuthContext, sandboxID pgtype.UUID, template string) {
	aType, aID, aName := actorFields(ac)
	sandboxRef := id.FormatSandboxID(sandboxID)

	entry := db.InsertAuditLogParams{
		ID:           id.NewAuditLogID(),
		TeamID:       ac.TeamID,
		ActorType:    aType,
		ActorID:      optText(aID),
		ActorName:    aName,
		ResourceType: "sandbox",
		ResourceID:   optText(sandboxRef),
		Action:       "create",
		Scope:        "team",
		Status:       "success",
		Metadata:     marshalMeta(map[string]any{"template": template}),
	}
	l.write(ctx, entry)

	l.publish(ctx, events.Event{
		Event:     events.CapsuleCreated,
		Timestamp: events.Now(),
		TeamID:    id.FormatTeamID(ac.TeamID),
		Actor:     actorToEvent(ac),
		Resource:  events.Resource{ID: sandboxRef, Type: "sandbox"},
	})
}
// LogSandboxPause writes a team-scoped "pause" row (status "success") for the
// sandbox on behalf of the caller, then publishes a CapsulePaused event.
func (l *AuditLogger) LogSandboxPause(ctx context.Context, ac auth.AuthContext, sandboxID pgtype.UUID) {
	aType, aID, aName := actorFields(ac)
	sandboxRef := id.FormatSandboxID(sandboxID)

	entry := db.InsertAuditLogParams{
		ID:           id.NewAuditLogID(),
		TeamID:       ac.TeamID,
		ActorType:    aType,
		ActorID:      optText(aID),
		ActorName:    aName,
		ResourceType: "sandbox",
		ResourceID:   optText(sandboxRef),
		Action:       "pause",
		Scope:        "team",
		Status:       "success",
		Metadata:     []byte("{}"),
	}
	l.write(ctx, entry)

	l.publish(ctx, events.Event{
		Event:     events.CapsulePaused,
		Timestamp: events.Now(),
		TeamID:    id.FormatTeamID(ac.TeamID),
		Actor:     actorToEvent(ac),
		Resource:  events.Resource{ID: sandboxRef, Type: "sandbox"},
	})
}
// LogSandboxAutoPause records a pause that was initiated by the system (TTL or
// host reconciler) rather than a caller: the row has actor type "system", no
// actor ID or name, and status "info". A CapsulePaused event with the system
// actor is published afterwards.
func (l *AuditLogger) LogSandboxAutoPause(ctx context.Context, teamID, sandboxID pgtype.UUID) {
	sandboxRef := id.FormatSandboxID(sandboxID)

	entry := db.InsertAuditLogParams{
		ID:           id.NewAuditLogID(),
		TeamID:       teamID,
		ActorType:    "system",
		ActorID:      pgtype.Text{},
		ActorName:    "",
		ResourceType: "sandbox",
		ResourceID:   optText(sandboxRef),
		Action:       "pause",
		Scope:        "team",
		Status:       "info",
		Metadata:     []byte("{}"),
	}
	l.write(ctx, entry)

	l.publish(ctx, events.Event{
		Event:     events.CapsulePaused,
		Timestamp: events.Now(),
		TeamID:    id.FormatTeamID(teamID),
		Actor:     systemActor(),
		Resource:  events.Resource{ID: sandboxRef, Type: "sandbox"},
	})
}
// LogSandboxResume writes a team-scoped "resume" row (status "success") for
// the sandbox, then publishes a CapsuleRunning event attributed to the caller.
func (l *AuditLogger) LogSandboxResume(ctx context.Context, ac auth.AuthContext, sandboxID pgtype.UUID) {
	aType, aID, aName := actorFields(ac)
	sandboxRef := id.FormatSandboxID(sandboxID)

	entry := db.InsertAuditLogParams{
		ID:           id.NewAuditLogID(),
		TeamID:       ac.TeamID,
		ActorType:    aType,
		ActorID:      optText(aID),
		ActorName:    aName,
		ResourceType: "sandbox",
		ResourceID:   optText(sandboxRef),
		Action:       "resume",
		Scope:        "team",
		Status:       "success",
		Metadata:     []byte("{}"),
	}
	l.write(ctx, entry)

	l.publish(ctx, events.Event{
		Event:     events.CapsuleRunning,
		Timestamp: events.Now(),
		TeamID:    id.FormatTeamID(ac.TeamID),
		Actor:     actorToEvent(ac),
		Resource:  events.Resource{ID: sandboxRef, Type: "sandbox"},
	})
}
// LogSandboxDestroy writes a team-scoped "destroy" row (status "warning") for
// the sandbox, then publishes a CapsuleDestroyed event attributed to the caller.
func (l *AuditLogger) LogSandboxDestroy(ctx context.Context, ac auth.AuthContext, sandboxID pgtype.UUID) {
	aType, aID, aName := actorFields(ac)
	sandboxRef := id.FormatSandboxID(sandboxID)

	entry := db.InsertAuditLogParams{
		ID:           id.NewAuditLogID(),
		TeamID:       ac.TeamID,
		ActorType:    aType,
		ActorID:      optText(aID),
		ActorName:    aName,
		ResourceType: "sandbox",
		ResourceID:   optText(sandboxRef),
		Action:       "destroy",
		Scope:        "team",
		Status:       "warning",
		Metadata:     []byte("{}"),
	}
	l.write(ctx, entry)

	l.publish(ctx, events.Event{
		Event:     events.CapsuleDestroyed,
		Timestamp: events.Now(),
		TeamID:    id.FormatTeamID(ac.TeamID),
		Actor:     actorToEvent(ac),
		Resource:  events.Resource{ID: sandboxRef, Type: "sandbox"},
	})
}
// --- Snapshot events (scope: team) ---
// LogSnapshotCreate writes a team-scoped "create" row (status "success") for
// a snapshot, using the snapshot name as the resource ID, then publishes a
// SnapshotCreated event attributed to the caller.
func (l *AuditLogger) LogSnapshotCreate(ctx context.Context, ac auth.AuthContext, name string) {
	aType, aID, aName := actorFields(ac)

	entry := db.InsertAuditLogParams{
		ID:           id.NewAuditLogID(),
		TeamID:       ac.TeamID,
		ActorType:    aType,
		ActorID:      optText(aID),
		ActorName:    aName,
		ResourceType: "snapshot",
		ResourceID:   optText(name),
		Action:       "create",
		Scope:        "team",
		Status:       "success",
		Metadata:     []byte("{}"),
	}
	l.write(ctx, entry)

	snapshot := events.Resource{ID: name, Type: "snapshot"}
	l.publish(ctx, events.Event{
		Event:     events.SnapshotCreated,
		Timestamp: events.Now(),
		TeamID:    id.FormatTeamID(ac.TeamID),
		Actor:     actorToEvent(ac),
		Resource:  snapshot,
	})
}
// LogSnapshotDelete writes a team-scoped "delete" row (status "warning") for
// a snapshot, using the snapshot name as the resource ID, then publishes a
// SnapshotDeleted event attributed to the caller.
func (l *AuditLogger) LogSnapshotDelete(ctx context.Context, ac auth.AuthContext, name string) {
	aType, aID, aName := actorFields(ac)

	entry := db.InsertAuditLogParams{
		ID:           id.NewAuditLogID(),
		TeamID:       ac.TeamID,
		ActorType:    aType,
		ActorID:      optText(aID),
		ActorName:    aName,
		ResourceType: "snapshot",
		ResourceID:   optText(name),
		Action:       "delete",
		Scope:        "team",
		Status:       "warning",
		Metadata:     []byte("{}"),
	}
	l.write(ctx, entry)

	snapshot := events.Resource{ID: name, Type: "snapshot"}
	l.publish(ctx, events.Event{
		Event:     events.SnapshotDeleted,
		Timestamp: events.Now(),
		TeamID:    id.FormatTeamID(ac.TeamID),
		Actor:     actorToEvent(ac),
		Resource:  snapshot,
	})
}
// --- Team events (scope: team) ---
// LogTeamRename writes a team-scoped "rename" row (status "info") whose
// metadata holds the old and new names. The row is filed under the caller's
// team (ac.TeamID); teamID only supplies the resource ID. No event is
// published.
func (l *AuditLogger) LogTeamRename(ctx context.Context, ac auth.AuthContext, teamID pgtype.UUID, oldName, newName string) {
	aType, aID, aName := actorFields(ac)
	names := map[string]any{"old_name": oldName, "new_name": newName}

	entry := db.InsertAuditLogParams{
		ID:           id.NewAuditLogID(),
		TeamID:       ac.TeamID,
		ActorType:    aType,
		ActorID:      optText(aID),
		ActorName:    aName,
		ResourceType: "team",
		ResourceID:   optText(id.FormatTeamID(teamID)),
		Action:       "rename",
		Scope:        "team",
		Status:       "info",
		Metadata:     marshalMeta(names),
	}
	l.write(ctx, entry)
}
// --- Channel events (scope: team) ---
// LogChannelCreate writes a team-scoped "create" row (status "success") for a
// notification channel, recording its name and provider in the metadata.
func (l *AuditLogger) LogChannelCreate(ctx context.Context, ac auth.AuthContext, channelID pgtype.UUID, name, provider string) {
	aType, aID, aName := actorFields(ac)
	details := map[string]any{"name": name, "provider": provider}

	entry := db.InsertAuditLogParams{
		ID:           id.NewAuditLogID(),
		TeamID:       ac.TeamID,
		ActorType:    aType,
		ActorID:      optText(aID),
		ActorName:    aName,
		ResourceType: "channel",
		ResourceID:   optText(id.FormatChannelID(channelID)),
		Action:       "create",
		Scope:        "team",
		Status:       "success",
		Metadata:     marshalMeta(details),
	}
	l.write(ctx, entry)
}
// LogChannelUpdate writes a team-scoped "update" row (status "info") for a
// notification channel.
func (l *AuditLogger) LogChannelUpdate(ctx context.Context, ac auth.AuthContext, channelID pgtype.UUID) {
	aType, aID, aName := actorFields(ac)

	entry := db.InsertAuditLogParams{
		ID:           id.NewAuditLogID(),
		TeamID:       ac.TeamID,
		ActorType:    aType,
		ActorID:      optText(aID),
		ActorName:    aName,
		ResourceType: "channel",
		ResourceID:   optText(id.FormatChannelID(channelID)),
		Action:       "update",
		Scope:        "team",
		Status:       "info",
		Metadata:     []byte("{}"),
	}
	l.write(ctx, entry)
}
// LogChannelRotateConfig writes a team-scoped "rotate_config" row (status
// "info") for a notification channel.
func (l *AuditLogger) LogChannelRotateConfig(ctx context.Context, ac auth.AuthContext, channelID pgtype.UUID) {
	aType, aID, aName := actorFields(ac)

	entry := db.InsertAuditLogParams{
		ID:           id.NewAuditLogID(),
		TeamID:       ac.TeamID,
		ActorType:    aType,
		ActorID:      optText(aID),
		ActorName:    aName,
		ResourceType: "channel",
		ResourceID:   optText(id.FormatChannelID(channelID)),
		Action:       "rotate_config",
		Scope:        "team",
		Status:       "info",
		Metadata:     []byte("{}"),
	}
	l.write(ctx, entry)
}
// LogChannelDelete writes a team-scoped "delete" row (status "warning") for a
// notification channel.
func (l *AuditLogger) LogChannelDelete(ctx context.Context, ac auth.AuthContext, channelID pgtype.UUID) {
	aType, aID, aName := actorFields(ac)

	entry := db.InsertAuditLogParams{
		ID:           id.NewAuditLogID(),
		TeamID:       ac.TeamID,
		ActorType:    aType,
		ActorID:      optText(aID),
		ActorName:    aName,
		ResourceType: "channel",
		ResourceID:   optText(id.FormatChannelID(channelID)),
		Action:       "delete",
		Scope:        "team",
		Status:       "warning",
		Metadata:     []byte("{}"),
	}
	l.write(ctx, entry)
}
// --- API key events (scope: team) ---
// LogAPIKeyCreate writes a team-scoped "create" row (status "success") for an
// API key, recording the key's name in the metadata.
func (l *AuditLogger) LogAPIKeyCreate(ctx context.Context, ac auth.AuthContext, keyID pgtype.UUID, keyName string) {
	aType, aID, aName := actorFields(ac)
	details := map[string]any{"name": keyName}

	entry := db.InsertAuditLogParams{
		ID:           id.NewAuditLogID(),
		TeamID:       ac.TeamID,
		ActorType:    aType,
		ActorID:      optText(aID),
		ActorName:    aName,
		ResourceType: "api_key",
		ResourceID:   optText(id.FormatAPIKeyID(keyID)),
		Action:       "create",
		Scope:        "team",
		Status:       "success",
		Metadata:     marshalMeta(details),
	}
	l.write(ctx, entry)
}
// LogAPIKeyRevoke writes a team-scoped "revoke" row (status "warning") for an
// API key.
func (l *AuditLogger) LogAPIKeyRevoke(ctx context.Context, ac auth.AuthContext, keyID pgtype.UUID) {
	aType, aID, aName := actorFields(ac)

	entry := db.InsertAuditLogParams{
		ID:           id.NewAuditLogID(),
		TeamID:       ac.TeamID,
		ActorType:    aType,
		ActorID:      optText(aID),
		ActorName:    aName,
		ResourceType: "api_key",
		ResourceID:   optText(id.FormatAPIKeyID(keyID)),
		Action:       "revoke",
		Scope:        "team",
		Status:       "warning",
		Metadata:     []byte("{}"),
	}
	l.write(ctx, entry)
}
// --- Member events (scope: admin) ---
// LogMemberAdd writes an admin-scoped "add" row (status "success") for the
// target user, recording the target's email and role in the metadata.
func (l *AuditLogger) LogMemberAdd(ctx context.Context, ac auth.AuthContext, targetUserID pgtype.UUID, targetEmail, role string) {
	aType, aID, aName := actorFields(ac)
	details := map[string]any{"email": targetEmail, "role": role}

	entry := db.InsertAuditLogParams{
		ID:           id.NewAuditLogID(),
		TeamID:       ac.TeamID,
		ActorType:    aType,
		ActorID:      optText(aID),
		ActorName:    aName,
		ResourceType: "member",
		ResourceID:   optText(id.FormatUserID(targetUserID)),
		Action:       "add",
		Scope:        "admin",
		Status:       "success",
		Metadata:     marshalMeta(details),
	}
	l.write(ctx, entry)
}
// LogMemberRemove writes an admin-scoped "remove" row (status "warning") for
// the target user.
func (l *AuditLogger) LogMemberRemove(ctx context.Context, ac auth.AuthContext, targetUserID pgtype.UUID) {
	aType, aID, aName := actorFields(ac)

	entry := db.InsertAuditLogParams{
		ID:           id.NewAuditLogID(),
		TeamID:       ac.TeamID,
		ActorType:    aType,
		ActorID:      optText(aID),
		ActorName:    aName,
		ResourceType: "member",
		ResourceID:   optText(id.FormatUserID(targetUserID)),
		Action:       "remove",
		Scope:        "admin",
		Status:       "warning",
		Metadata:     []byte("{}"),
	}
	l.write(ctx, entry)
}
// LogMemberLeave writes an admin-scoped "leave" row (status "info"). The
// resource is the caller's own user ID when the context carries a valid
// UserID; otherwise the resource ID is stored as NULL.
func (l *AuditLogger) LogMemberLeave(ctx context.Context, ac auth.AuthContext) {
	aType, aID, aName := actorFields(ac)

	var self pgtype.Text
	if ac.UserID.Valid {
		self = optText(id.FormatUserID(ac.UserID))
	}

	entry := db.InsertAuditLogParams{
		ID:           id.NewAuditLogID(),
		TeamID:       ac.TeamID,
		ActorType:    aType,
		ActorID:      optText(aID),
		ActorName:    aName,
		ResourceType: "member",
		ResourceID:   self,
		Action:       "leave",
		Scope:        "admin",
		Status:       "info",
		Metadata:     []byte("{}"),
	}
	l.write(ctx, entry)
}
// LogMemberRoleUpdate writes an admin-scoped "role_update" row (status
// "info") for the target user, recording the new role in the metadata.
func (l *AuditLogger) LogMemberRoleUpdate(ctx context.Context, ac auth.AuthContext, targetUserID pgtype.UUID, newRole string) {
	aType, aID, aName := actorFields(ac)
	details := map[string]any{"new_role": newRole}

	entry := db.InsertAuditLogParams{
		ID:           id.NewAuditLogID(),
		TeamID:       ac.TeamID,
		ActorType:    aType,
		ActorID:      optText(aID),
		ActorName:    aName,
		ResourceType: "member",
		ResourceID:   optText(id.FormatUserID(targetUserID)),
		Action:       "role_update",
		Scope:        "admin",
		Status:       "info",
		Metadata:     marshalMeta(details),
	}
	l.write(ctx, entry)
}
// --- Host events (scope: admin) ---
// LogHostCreate writes an admin-scoped "create" row (status "success") for a
// host. The row is filed under teamID when it is valid; for shared hosts with
// no owning team it falls back to the caller's team. If neither is valid
// nothing is written.
func (l *AuditLogger) LogHostCreate(ctx context.Context, ac auth.AuthContext, hostID, teamID pgtype.UUID) {
	owner := teamID
	if !owner.Valid {
		if !ac.TeamID.Valid {
			return
		}
		owner = ac.TeamID
	}

	aType, aID, aName := actorFields(ac)
	entry := db.InsertAuditLogParams{
		ID:           id.NewAuditLogID(),
		TeamID:       owner,
		ActorType:    aType,
		ActorID:      optText(aID),
		ActorName:    aName,
		ResourceType: "host",
		ResourceID:   optText(id.FormatHostID(hostID)),
		Action:       "create",
		Scope:        "admin",
		Status:       "success",
		Metadata:     []byte("{}"),
	}
	l.write(ctx, entry)
}
// LogHostDelete writes an admin-scoped "delete" row (status "warning") for a
// host. The row is filed under teamID when it is valid, otherwise under the
// caller's team; if neither is valid nothing is written.
func (l *AuditLogger) LogHostDelete(ctx context.Context, ac auth.AuthContext, hostID, teamID pgtype.UUID) {
	owner := teamID
	if !owner.Valid {
		if !ac.TeamID.Valid {
			return
		}
		owner = ac.TeamID
	}

	aType, aID, aName := actorFields(ac)
	entry := db.InsertAuditLogParams{
		ID:           id.NewAuditLogID(),
		TeamID:       owner,
		ActorType:    aType,
		ActorID:      optText(aID),
		ActorName:    aName,
		ResourceType: "host",
		ResourceID:   optText(id.FormatHostID(hostID)),
		Action:       "delete",
		Scope:        "admin",
		Status:       "warning",
		Metadata:     []byte("{}"),
	}
	l.write(ctx, entry)
}
// LogHostMarkedDown records a system-initiated "marked_down" transition for a
// host (status "error") and publishes a HostDown event with the system actor.
// The row is scoped to "team" so BYOC team members can see when their hosts go
// down. Nothing is written or published when teamID is not valid.
func (l *AuditLogger) LogHostMarkedDown(ctx context.Context, teamID, hostID pgtype.UUID) {
	if !teamID.Valid {
		return
	}
	hostRef := id.FormatHostID(hostID)

	entry := db.InsertAuditLogParams{
		ID:           id.NewAuditLogID(),
		TeamID:       teamID,
		ActorType:    "system",
		ActorID:      pgtype.Text{},
		ActorName:    "",
		ResourceType: "host",
		ResourceID:   optText(hostRef),
		Action:       "marked_down",
		Scope:        "team",
		Status:       "error",
		Metadata:     []byte("{}"),
	}
	l.write(ctx, entry)

	l.publish(ctx, events.Event{
		Event:     events.HostDown,
		Timestamp: events.Now(),
		TeamID:    id.FormatTeamID(teamID),
		Actor:     systemActor(),
		Resource:  events.Resource{ID: hostRef, Type: "host"},
	})
}
// LogHostMarkedUp records a system-initiated "marked_up" transition for a
// host (status "success") and publishes a HostUp event with the system actor.
// The row is scoped to "team" so BYOC team members can see when their hosts
// recover. Nothing is written or published when teamID is not valid.
func (l *AuditLogger) LogHostMarkedUp(ctx context.Context, teamID, hostID pgtype.UUID) {
	if !teamID.Valid {
		return
	}
	hostRef := id.FormatHostID(hostID)

	entry := db.InsertAuditLogParams{
		ID:           id.NewAuditLogID(),
		TeamID:       teamID,
		ActorType:    "system",
		ActorID:      pgtype.Text{},
		ActorName:    "",
		ResourceType: "host",
		ResourceID:   optText(hostRef),
		Action:       "marked_up",
		Scope:        "team",
		Status:       "success",
		Metadata:     []byte("{}"),
	}
	l.write(ctx, entry)

	l.publish(ctx, events.Event{
		Event:     events.HostUp,
		Timestamp: events.Now(),
		TeamID:    id.FormatTeamID(teamID),
		Actor:     systemActor(),
		Resource:  events.Resource{ID: hostRef, Type: "host"},
	})
}

251
internal/auth/cert.go Normal file
View File

@ -0,0 +1,251 @@
package auth
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"net"
"sync/atomic"
"time"
)
// CPCertRenewInterval is how often the control plane should renew its client
// certificate. It is set to half the cert TTL so there is always a wide safety
// margin before expiry.
const CPCertRenewInterval = cpCertTTL / 2
// Validity periods applied to issued leaf certificates.
const (
// hostCertTTL is the lifetime of host agent certificates issued by IssueHostCert.
hostCertTTL = 7 * 24 * time.Hour
// cpCertTTL is the lifetime of control plane client certificates issued by IssueCPClientCert.
cpCertTTL = 24 * time.Hour
)
// CA holds a parsed certificate authority ready to issue leaf certificates.
type CA struct {
	Cert *x509.Certificate
	Key  *ecdsa.PrivateKey
	PEM  string // PEM-encoded certificate for embedding in register/refresh responses
}

// ParseCA parses PEM-encoded CA certificate and private key strings.
//
// The key must be ECDSA and may be encoded either as SEC1 ("EC PRIVATE KEY")
// or as PKCS#8 ("PRIVATE KEY"). An error is returned when either PEM block
// cannot be decoded or parsed, or when the private key does not belong to the
// certificate, so a misconfigured pair is rejected here instead of producing
// failures later at issuance time.
func ParseCA(certPEM, keyPEM string) (*CA, error) {
	certBlock, _ := pem.Decode([]byte(certPEM))
	if certBlock == nil {
		return nil, fmt.Errorf("failed to decode CA certificate PEM")
	}
	cert, err := x509.ParseCertificate(certBlock.Bytes)
	if err != nil {
		return nil, fmt.Errorf("parse CA certificate: %w", err)
	}

	keyBlock, _ := pem.Decode([]byte(keyPEM))
	if keyBlock == nil {
		return nil, fmt.Errorf("failed to decode CA key PEM")
	}
	key, err := parseCAKey(keyBlock.Bytes)
	if err != nil {
		return nil, fmt.Errorf("parse CA private key: %w", err)
	}

	// Signing with a key that does not match the certificate can never yield
	// leaf certificates that chain to it.
	pub, ok := cert.PublicKey.(*ecdsa.PublicKey)
	if !ok || !pub.Equal(&key.PublicKey) {
		return nil, fmt.Errorf("CA private key does not match CA certificate")
	}

	return &CA{Cert: cert, Key: key, PEM: certPEM}, nil
}

// parseCAKey decodes an ECDSA private key from DER, trying SEC1 first and
// falling back to PKCS#8. When both fail, the SEC1 error is returned.
func parseCAKey(der []byte) (*ecdsa.PrivateKey, error) {
	key, sec1Err := x509.ParseECPrivateKey(der)
	if sec1Err == nil {
		return key, nil
	}
	parsed, err := x509.ParsePKCS8PrivateKey(der)
	if err != nil {
		return nil, sec1Err
	}
	ecKey, ok := parsed.(*ecdsa.PrivateKey)
	if !ok {
		return nil, fmt.Errorf("PKCS#8 key has type %T, want ECDSA", parsed)
	}
	return ecKey, nil
}
// HostCert holds all material returned when issuing a leaf cert for a host agent.
type HostCert struct {
// CertPEM is the issued certificate as a PEM "CERTIFICATE" block.
CertPEM string
// KeyPEM is the generated private key as a PEM "EC PRIVATE KEY" block.
KeyPEM string
Fingerprint string // hex-encoded SHA-256 of DER bytes, stored in hosts.cert_fingerprint
ExpiresAt time.Time // stored in hosts.cert_expires_at
// TLSCert is the key pair built from CertPEM and KeyPEM.
TLSCert tls.Certificate
}
// IssueHostCert creates a fresh ECDSA P-256 key pair and a server-auth
// certificate for a host agent, signed by ca and valid for hostCertTTL.
//
// hostID is used as the certificate's common name. hostAddr may be
// "host:port" or a bare host; the host part becomes an IP SAN when it parses
// as an IP address and a DNS SAN otherwise, so TLS clients can verify the
// connection without disabling hostname checking.
//
// The returned HostCert carries the PEM-encoded certificate and key, the
// hex-encoded SHA-256 fingerprint of the certificate DER, the expiry time,
// and a ready-to-use tls.Certificate.
func IssueHostCert(ca *CA, hostID, hostAddr string) (HostCert, error) {
	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		return HostCert{}, fmt.Errorf("generate host key: %w", err)
	}

	serialNumber, err := randomSerial()
	if err != nil {
		return HostCert{}, err
	}

	issuedAt := time.Now()
	notAfter := issuedAt.Add(hostCertTTL)
	leaf := x509.Certificate{
		SerialNumber: serialNumber,
		Subject:      pkix.Name{CommonName: hostID},
		NotBefore:    issuedAt.Add(-time.Minute), // small clock-skew tolerance
		NotAfter:     notAfter,
		KeyUsage:     x509.KeyUsageDigitalSignature,
		ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
	}

	// Strip the port when the address is "host:port"; otherwise use it as-is.
	sanHost := hostAddr
	if h, _, splitErr := net.SplitHostPort(hostAddr); splitErr == nil {
		sanHost = h
	}
	if addr := net.ParseIP(sanHost); addr == nil {
		leaf.DNSNames = []string{sanHost}
	} else {
		leaf.IPAddresses = []net.IP{addr}
	}

	der, err := x509.CreateCertificate(rand.Reader, &leaf, ca.Cert, &priv.PublicKey, ca.Key)
	if err != nil {
		return HostCert{}, fmt.Errorf("create host certificate: %w", err)
	}
	certBytes := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})

	keyDER, err := x509.MarshalECPrivateKey(priv)
	if err != nil {
		return HostCert{}, fmt.Errorf("marshal host key: %w", err)
	}
	keyBytes := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})

	pair, err := tls.X509KeyPair(certBytes, keyBytes)
	if err != nil {
		return HostCert{}, fmt.Errorf("build TLS certificate: %w", err)
	}

	issued := HostCert{
		CertPEM:     string(certBytes),
		KeyPEM:      string(keyBytes),
		Fingerprint: fmt.Sprintf("%x", sha256.Sum256(der)),
		ExpiresAt:   notAfter,
		TLSCert:     pair,
	}
	return issued, nil
}
// IssueCPClientCert creates a fresh ECDSA P-256 key pair and a short-lived
// client-auth certificate (valid for cpCertTTL, common name "wrenn-cp"),
// signed by ca, for the control plane to present during mTLS handshakes with
// host agents. CPCertStore.Refresh calls it for every issuance.
func IssueCPClientCert(ca *CA) (tls.Certificate, error) {
	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		return tls.Certificate{}, fmt.Errorf("generate CP client key: %w", err)
	}

	serialNumber, err := randomSerial()
	if err != nil {
		return tls.Certificate{}, err
	}

	issuedAt := time.Now()
	leaf := x509.Certificate{
		SerialNumber: serialNumber,
		Subject:      pkix.Name{CommonName: "wrenn-cp"},
		NotBefore:    issuedAt.Add(-time.Minute), // small clock-skew tolerance
		NotAfter:     issuedAt.Add(cpCertTTL),
		KeyUsage:     x509.KeyUsageDigitalSignature,
		ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
	}

	der, err := x509.CreateCertificate(rand.Reader, &leaf, ca.Cert, &priv.PublicKey, ca.Key)
	if err != nil {
		return tls.Certificate{}, fmt.Errorf("create CP client certificate: %w", err)
	}

	keyDER, err := x509.MarshalECPrivateKey(priv)
	if err != nil {
		return tls.Certificate{}, fmt.Errorf("marshal CP client key: %w", err)
	}

	certBytes := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
	keyBytes := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
	return tls.X509KeyPair(certBytes, keyBytes)
}
// AgentTLSConfigFromPEM builds the host agent's server-side tls.Config from
// the PEM-encoded CA certificate alone (the agent has no CA private key).
//
// The config requires and verifies client certificates against the CA, serves
// its own certificate through getCert, and refuses anything below TLS 1.3.
// It returns nil when caCertPEM contains no parseable certificate.
func AgentTLSConfigFromPEM(caCertPEM string, getCert func(*tls.ClientHelloInfo) (*tls.Certificate, error)) *tls.Config {
	clientCAs := x509.NewCertPool()
	if ok := clientCAs.AppendCertsFromPEM([]byte(caCertPEM)); ok {
		cfg := new(tls.Config)
		cfg.ClientAuth = tls.RequireAndVerifyClientCert
		cfg.ClientCAs = clientCAs
		cfg.GetCertificate = getCert
		cfg.MinVersion = tls.VersionTLS13
		return cfg
	}
	return nil
}
// CPCertStore provides lock-free read/write access to the control plane's
// current client TLS certificate. It is used with tls.Config.GetClientCertificate
// to enable hot-swap without restarting the HTTP client.
//
// The zero value is not usable; use NewCPCertStore to create one.
type CPCertStore struct {
// ptr holds the most recently issued certificate; Refresh stores into it and
// GetClientCertificate loads from it.
ptr atomic.Pointer[tls.Certificate]
// ca is the authority Refresh uses to issue replacement certificates.
ca *CA
}
// NewCPCertStore creates a store bound to ca and immediately issues the first
// CP client certificate into it. If that first issuance fails, the error is
// returned and no store is created.
func NewCPCertStore(ca *CA) (*CPCertStore, error) {
	store := new(CPCertStore)
	store.ca = ca

	err := store.Refresh()
	if err != nil {
		return nil, err
	}
	return store, nil
}
// Refresh issues a new CP client certificate from the store's CA and swaps it
// in atomically. On failure the previously stored certificate stays in place
// and the wrapped issuance error is returned.
func (s *CPCertStore) Refresh() error {
	issued, err := IssueCPClientCert(s.ca)
	if err == nil {
		s.ptr.Store(&issued)
		return nil
	}
	return fmt.Errorf("renew CP client certificate: %w", err)
}
// GetClientCertificate has the signature required by
// tls.Config.GetClientCertificate. It returns whichever certificate was stored
// most recently, or an error if none has been stored yet.
func (s *CPCertStore) GetClientCertificate(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
	if current := s.ptr.Load(); current != nil {
		return current, nil
	}
	return nil, fmt.Errorf("no CP client certificate available")
}
// CPClientTLSConfig builds the tls.Config for the control plane's outbound
// HTTP client: servers are verified against ca only, the client certificate
// is fetched from certStore on each handshake (so renewals need no new config
// or transport), and TLS 1.3 is the minimum version.
func CPClientTLSConfig(ca *CA, certStore *CPCertStore) *tls.Config {
	roots := x509.NewCertPool()
	roots.AddCert(ca.Cert)

	cfg := new(tls.Config)
	cfg.RootCAs = roots
	cfg.GetClientCertificate = certStore.GetClientCertificate
	cfg.MinVersion = tls.VersionTLS13
	return cfg
}
// randomSerial draws a certificate serial number uniformly from [0, 2^128)
// using crypto/rand.
func randomSerial() (*big.Int, error) {
	limit := new(big.Int).Lsh(big.NewInt(1), 128)
	n, err := rand.Int(rand.Reader, limit)
	if err == nil {
		return n, nil
	}
	return nil, fmt.Errorf("generate serial number: %w", err)
}

View File

@ -1,6 +1,10 @@
package auth
import "context"
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
type contextKey int
@ -8,9 +12,14 @@ const authCtxKey contextKey = 0
// AuthContext is stamped into request context by auth middleware.
type AuthContext struct {
TeamID string
UserID string // empty when authenticated via API key
Email string // empty when authenticated via API key
TeamID pgtype.UUID
UserID pgtype.UUID // zero value (Valid=false) when authenticated via API key
Email string // empty when authenticated via API key
Name string // empty when authenticated via API key
Role string // owner, admin, or member; empty when authenticated via API key
IsAdmin bool // platform-level admin; always false when authenticated via API key
APIKeyID pgtype.UUID // populated when authenticated via API key; zero value for JWT auth
APIKeyName string // display name of the key, snapshotted at auth time; empty for JWT auth
}
// WithAuthContext returns a new context with the given AuthContext.
@ -38,7 +47,7 @@ const hostCtxKey contextKey = 1
// HostContext is stamped into request context by host token middleware.
type HostContext struct {
HostID string
HostID pgtype.UUID
}
// WithHostContext returns a new context with the given HostContext.

View File

@ -5,27 +5,37 @@ import (
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/id"
)
const jwtExpiry = 6 * time.Hour
const hostJWTExpiry = 8760 * time.Hour // 1 year
const hostJWTExpiry = 7 * 24 * time.Hour // 7 days; host refreshes via refresh token
const HostRefreshTokenExpiry = 60 * 24 * time.Hour // 60 days; exported for service layer
// Claims are the JWT payload for user tokens.
type Claims struct {
Type string `json:"typ,omitempty"` // empty for user tokens; used to reject host tokens
TeamID string `json:"team_id"`
Email string `json:"email"`
Type string `json:"typ,omitempty"` // empty for user tokens; used to reject host tokens
TeamID string `json:"team_id"`
Role string `json:"role"` // owner, admin, or member within TeamID
Email string `json:"email"`
Name string `json:"name"`
IsAdmin bool `json:"is_admin,omitempty"` // platform-level admin flag
jwt.RegisteredClaims
}
// SignJWT signs a new 6-hour JWT for the given user.
func SignJWT(secret []byte, userID, teamID, email string) (string, error) {
func SignJWT(secret []byte, userID, teamID pgtype.UUID, email, name, role string, isAdmin bool) (string, error) {
now := time.Now()
claims := Claims{
TeamID: teamID,
Email: email,
TeamID: id.FormatTeamID(teamID),
Role: role,
Email: email,
Name: name,
IsAdmin: isAdmin,
RegisteredClaims: jwt.RegisteredClaims{
Subject: userID,
Subject: id.FormatUserID(userID),
IssuedAt: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(now.Add(jwtExpiry)),
},
@ -63,14 +73,15 @@ type HostClaims struct {
jwt.RegisteredClaims
}
// SignHostJWT signs a long-lived (1 year) JWT for a registered host agent.
func SignHostJWT(secret []byte, hostID string) (string, error) {
// SignHostJWT signs a long-lived (7-day) JWT for a registered host agent.
func SignHostJWT(secret []byte, hostID pgtype.UUID) (string, error) {
formatted := id.FormatHostID(hostID)
now := time.Now()
claims := HostClaims{
Type: "host",
HostID: hostID,
HostID: formatted,
RegisteredClaims: jwt.RegisteredClaims{
Subject: hostID,
Subject: formatted,
IssuedAt: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(now.Add(hostJWTExpiry)),
},

View File

@ -0,0 +1,63 @@
package channels
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"fmt"
"io"
)
// EncryptSecret seals plaintext with AES-256-GCM under key, using a fresh
// random nonce for every call. The result is the standard base64 encoding of
// the nonce followed by the ciphertext (which includes the GCM tag).
func EncryptSecret(key [32]byte, plaintext string) (string, error) {
	blockCipher, err := aes.NewCipher(key[:])
	if err != nil {
		return "", fmt.Errorf("aes cipher: %w", err)
	}

	aead, err := cipher.NewGCM(blockCipher)
	if err != nil {
		return "", fmt.Errorf("gcm: %w", err)
	}

	iv := make([]byte, aead.NonceSize())
	_, err = io.ReadFull(rand.Reader, iv)
	if err != nil {
		return "", fmt.Errorf("nonce: %w", err)
	}

	// Seal appends to its first argument, so the output starts with the nonce.
	sealed := aead.Seal(iv, iv, []byte(plaintext), nil)
	encoded := base64.StdEncoding.EncodeToString(sealed)
	return encoded, nil
}
// DecryptSecret reverses EncryptSecret: it base64-decodes encoded, splits off
// the leading nonce, and opens the remainder with AES-256-GCM under key.
// It fails when the input is not valid base64, is shorter than a nonce, or
// does not authenticate under key.
func DecryptSecret(key [32]byte, encoded string) (string, error) {
	raw, err := base64.StdEncoding.DecodeString(encoded)
	if err != nil {
		return "", fmt.Errorf("base64 decode: %w", err)
	}

	blockCipher, err := aes.NewCipher(key[:])
	if err != nil {
		return "", fmt.Errorf("aes cipher: %w", err)
	}

	aead, err := cipher.NewGCM(blockCipher)
	if err != nil {
		return "", fmt.Errorf("gcm: %w", err)
	}

	ivLen := aead.NonceSize()
	if len(raw) < ivLen {
		return "", fmt.Errorf("ciphertext too short")
	}

	opened, err := aead.Open(nil, raw[:ivLen], raw[ivLen:], nil)
	if err != nil {
		return "", fmt.Errorf("decrypt: %w", err)
	}
	return string(opened), nil
}

View File

@ -0,0 +1,36 @@
package channels
import (
"context"
"encoding/json"
"fmt"
"github.com/containrrr/shoutrrr"
"git.omukk.dev/wrenn/sandbox/internal/events"
)
// Deliver sends event e to a single provider using the given config.
//
// The event is always JSON-encoded first. The "webhook" provider posts that
// payload through a WebhookDelivery using config["url"] and config["secret"];
// every other provider is sent the FormatMessage text via shoutrrr, using the
// URL built by ShoutrrrURL.
func Deliver(ctx context.Context, provider string, config map[string]string, e events.Event) error {
	body, err := json.Marshal(e)
	if err != nil {
		return fmt.Errorf("marshal event: %w", err)
	}

	if provider != "webhook" {
		target, urlErr := ShoutrrrURL(provider, config)
		if urlErr != nil {
			return fmt.Errorf("build shoutrrr URL: %w", urlErr)
		}
		text := FormatMessage(e)
		if sendErr := shoutrrr.Send(target, text); sendErr != nil {
			return fmt.Errorf("shoutrrr send: %w", sendErr)
		}
		return nil
	}

	return NewWebhookDelivery().Deliver(ctx, config["url"], config["secret"], body)
}

View File

@ -0,0 +1,183 @@
package channels
import (
"context"
"encoding/json"
"log/slog"
"time"
"github.com/redis/go-redis/v9"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/events"
"git.omukk.dev/wrenn/sandbox/internal/id"
)
// Redis stream consumer-group identity used by the Dispatcher.
const (
// groupName is the consumer group created on, read from, and acked against the stream.
groupName = "wrenn-channels-v1"
// consumerName is the consumer name passed to XReadGroup.
consumerName = "cp-0"
)
// Dispatcher consumes events from the Redis stream and delivers them
// to matching notification channels.
type Dispatcher struct {
// rdb is the Redis client used to create the group, read the stream and ack messages.
rdb *redis.Client
// db looks up the channels subscribed to an event.
db *db.Queries
// encKey decrypts the per-channel config values (see decryptConfig).
encKey [32]byte
// webhook is initialised by NewDispatcher.
// NOTE(review): nothing in the visible dispatch path reads this field — confirm whether it is still needed.
webhook *WebhookDelivery
}
// NewDispatcher wires a Dispatcher to the Redis client, the query handle and
// the key used to decrypt channel configs, and gives it a WebhookDelivery.
func NewDispatcher(rdb *redis.Client, queries *db.Queries, encKey [32]byte) *Dispatcher {
	d := new(Dispatcher)
	d.rdb = rdb
	d.db = queries
	d.encKey = encKey
	d.webhook = NewWebhookDelivery()
	return d
}
// Start launches the consumer loop in a new goroutine and returns
// immediately. The goroutine exits once ctx is cancelled.
func (d *Dispatcher) Start(ctx context.Context) {
go d.run(ctx)
}
// run is the consumer loop behind Start. It makes sure the consumer group
// exists (creating the stream if needed, starting from new messages only) and
// then repeatedly reads batches of up to 10 messages, blocking up to five
// seconds per read, handing each message to handleMessage.
//
// The loop ends when ctx is cancelled, or immediately if the consumer group
// cannot be created. Read errors other than an empty read are logged and
// retried after a one-second pause; that pause is abandoned as soon as ctx is
// cancelled so shutdown is not delayed.
func (d *Dispatcher) run(ctx context.Context) {
	// Create consumer group idempotently. "$" means only new messages.
	err := d.rdb.XGroupCreateMkStream(ctx, streamKey, groupName, "$").Err()
	if err != nil && !isGroupExistsError(err) {
		slog.Error("channels: failed to create consumer group", "error", err)
		return
	}

	for {
		if ctx.Err() != nil {
			return
		}

		streams, err := d.rdb.XReadGroup(ctx, &redis.XReadGroupArgs{
			Group:    groupName,
			Consumer: consumerName,
			Streams:  []string{streamKey, ">"},
			Count:    10,
			Block:    5 * time.Second,
		}).Result()
		if err != nil {
			if ctx.Err() != nil {
				return
			}
			if err == redis.Nil {
				// Block timeout elapsed with nothing to read.
				continue
			}
			slog.Warn("channels: xreadgroup error", "error", err)
			// Back off before retrying, but wake up right away on shutdown.
			select {
			case <-ctx.Done():
				return
			case <-time.After(1 * time.Second):
			}
			continue
		}

		for _, stream := range streams {
			for _, msg := range stream.Messages {
				d.handleMessage(ctx, msg)
			}
		}
	}
}
// handleMessage processes one stream message: it decodes the JSON event in
// the "payload" field, looks up the channels of the event's team that match
// the event type, and dispatches the event to each of them.
//
// The message is acknowledged when the function returns, whatever the
// outcome, so malformed or undeliverable messages are not read again.
func (d *Dispatcher) handleMessage(ctx context.Context, msg redis.XMessage) {
	ack := func() {
		ackErr := d.rdb.XAck(ctx, streamKey, groupName, msg.ID).Err()
		if ackErr == nil {
			return
		}
		slog.Warn("channels: xack failed", "id", msg.ID, "error", ackErr)
	}
	defer ack()

	raw, isString := msg.Values["payload"].(string)
	if !isString {
		slog.Warn("channels: message missing payload", "id", msg.ID)
		return
	}

	var evt events.Event
	if decodeErr := json.Unmarshal([]byte(raw), &evt); decodeErr != nil {
		slog.Warn("channels: failed to unmarshal event", "id", msg.ID, "error", decodeErr)
		return
	}

	team, parseErr := id.ParseTeamID(evt.TeamID)
	if parseErr != nil {
		slog.Warn("channels: invalid team ID in event", "team_id", evt.TeamID, "error", parseErr)
		return
	}

	query := db.ListChannelsForEventParams{
		TeamID:    team,
		EventType: evt.Event,
	}
	targets, listErr := d.db.ListChannelsForEvent(ctx, query)
	if listErr != nil {
		slog.Warn("channels: failed to list channels for event", "event", evt.Event, "error", listErr)
		return
	}

	for i := range targets {
		d.dispatch(ctx, targets[i], evt)
	}
}
// retryDelays defines the wait durations before each retry attempt.
// retryDeliver makes one retry per entry, so the slice length is the number
// of retries after the initial delivery attempt.
var retryDelays = []time.Duration{10 * time.Second, 30 * time.Second}
// dispatch delivers e to a single channel. The channel's stored config is
// decrypted first; if that fails the event is dropped for this channel. A
// failed first delivery is logged and retried in a background goroutine.
func (d *Dispatcher) dispatch(ctx context.Context, ch db.Channel, e events.Event) {
	chID := id.FormatChannelID(ch.ID)

	cfg, err := d.decryptConfig(ch.Config)
	if err != nil {
		slog.Warn("channels: failed to decrypt config",
			"channel_id", chID, "error", err)
		return
	}

	deliverErr := Deliver(ctx, ch.Provider, cfg, e)
	if deliverErr == nil {
		return
	}
	slog.Warn("channels: delivery failed, scheduling retries",
		"channel_id", chID, "provider", ch.Provider, "error", deliverErr)
	go d.retryDeliver(ctx, ch.Provider, cfg, e, chID)
}
// retryDeliver re-attempts a failed delivery, waiting for each duration in
// retryDelays before the corresponding attempt. It returns as soon as an
// attempt succeeds or ctx is cancelled, and logs an error once every retry has
// failed. The logged "attempt" number counts the initial delivery as attempt 1,
// so the first retry is attempt 2.
func (d *Dispatcher) retryDeliver(ctx context.Context, provider string, config map[string]string, e events.Event, chID string) {
	delays := retryDelays
	for attempt := 0; attempt < len(delays); attempt++ {
		select {
		case <-time.After(delays[attempt]):
		case <-ctx.Done():
			return
		}

		err := Deliver(ctx, provider, config, e)
		if err == nil {
			return
		}
		slog.Warn("channels: retry delivery failed",
			"channel_id", chID, "provider", provider,
			"attempt", attempt+2, "error", err)
	}

	slog.Error("channels: delivery failed after all retries",
		"channel_id", chID, "provider", provider, "event", e.Event)
}
// decryptConfig decodes configJSON as a JSON object of string values and
// decrypts every value with the dispatcher's encryption key. It returns the
// decrypted key/value pairs, or the first decode or decrypt error.
func (d *Dispatcher) decryptConfig(configJSON []byte) (map[string]string, error) {
	var sealed map[string]string
	if err := json.Unmarshal(configJSON, &sealed); err != nil {
		return nil, err
	}

	opened := make(map[string]string, len(sealed))
	for field, ciphertext := range sealed {
		value, err := DecryptSecret(d.encKey, ciphertext)
		if err != nil {
			return nil, err
		}
		opened[field] = value
	}
	return opened, nil
}
// isGroupExistsError reports whether err is the Redis BUSYGROUP error, which
// is reported when a consumer group being created already exists.
//
// A Redis error reply is an error code followed by human-readable text. Only
// the code is matched, so the check does not depend on the exact wording of
// the message.
func isGroupExistsError(err error) bool {
	const code = "BUSYGROUP"
	if err == nil {
		return false
	}
	msg := err.Error()
	// Manual prefix comparison keeps this helper free of extra imports.
	return len(msg) >= len(code) && msg[:len(code)] == code
}

View File

@ -0,0 +1,65 @@
package channels
import (
"fmt"
"strings"
"git.omukk.dev/wrenn/sandbox/internal/events"
)
// FormatMessage renders an event as a human-readable notification. The first
// line is the event summary; after a blank line follow labelled lines for the
// event type, resource, actor, team, and timestamp.
func FormatMessage(e events.Event) string {
	summary := formatSummary(e)

	var msg strings.Builder
	msg.WriteString(summary)
	fmt.Fprintf(&msg, "\n\nEvent: %s", e.Event)
	fmt.Fprintf(&msg, "\nResource: %s %s", e.Resource.Type, e.Resource.ID)
	actor := formatActor(e.Actor)
	fmt.Fprintf(&msg, "\nActor: %s", actor)
	fmt.Fprintf(&msg, "\nTeam: %s", e.TeamID)
	fmt.Fprintf(&msg, "\nTime: %s", e.Timestamp)

	return msg.String()
}
// formatSummary returns the one-line headline for an event. Known event types
// get a fixed sentence built around the resource ID; any other event type
// falls back to "<resource type> <resource id>".
func formatSummary(e events.Event) string {
	var sentence string
	switch e.Event {
	case events.CapsuleCreated:
		sentence = "Capsule %s created"
	case events.CapsuleRunning:
		sentence = "Capsule %s is running"
	case events.CapsulePaused:
		sentence = "Capsule %s paused"
	case events.CapsuleDestroyed:
		sentence = "Capsule %s destroyed"
	case events.SnapshotCreated:
		sentence = "Template snapshot %s created"
	case events.SnapshotDeleted:
		sentence = "Template snapshot %s deleted"
	case events.HostUp:
		sentence = "Host %s is up"
	case events.HostDown:
		sentence = "Host %s is down"
	default:
		return fmt.Sprintf("%s %s", e.Resource.Type, e.Resource.ID)
	}
	return fmt.Sprintf(sentence, e.Resource.ID)
}
// formatActor renders the actor of an event for display:
//   - system actors are shown as "system";
//   - users as "Name (ID)", or just the ID when no name is set;
//   - API keys as "api_key Name (ID)", or "api_key ID" when no name is set;
//   - any other actor type as the raw type string.
func formatActor(a events.Actor) string {
	hasName := a.Name != ""

	switch {
	case a.Type == events.ActorSystem:
		return "system"
	case a.Type == events.ActorUser && hasName:
		return fmt.Sprintf("%s (%s)", a.Name, a.ID)
	case a.Type == events.ActorUser:
		return a.ID
	case a.Type == events.ActorAPIKey && hasName:
		return fmt.Sprintf("api_key %s (%s)", a.Name, a.ID)
	case a.Type == events.ActorAPIKey:
		return fmt.Sprintf("api_key %s", a.ID)
	}
	return string(a.Type)
}

View File

@ -0,0 +1,44 @@
package channels
import (
"context"
"encoding/json"
"log/slog"
"github.com/redis/go-redis/v9"
"git.omukk.dev/wrenn/sandbox/internal/events"
)
// streamKey is the Redis stream key that Publisher appends events to.
const streamKey = "wrenn:events"
// Publisher appends events to the Redis stream identified by streamKey.
type Publisher struct {
	rdb *redis.Client
}

// NewPublisher returns a Publisher that writes through the given Redis client.
func NewPublisher(rdb *redis.Client) *Publisher {
	pub := new(Publisher)
	pub.rdb = rdb
	return pub
}
// Publish JSON-encodes the event and appends it to the stream as a single
// "payload" field. The stream is trimmed approximately to 10000 entries.
// Publishing is fire-and-forget: encode and Redis failures are logged and not
// returned to the caller.
func (p *Publisher) Publish(ctx context.Context, e events.Event) {
	encoded, err := json.Marshal(e)
	if err != nil {
		slog.Warn("channels: failed to marshal event", "event", e.Event, "error", err)
		return
	}

	entry := &redis.XAddArgs{
		Stream: streamKey,
		MaxLen: 10000,
		Approx: true,
		Values: map[string]any{
			"payload": string(encoded),
		},
	}
	addErr := p.rdb.XAdd(ctx, entry).Err()
	if addErr != nil {
		slog.Warn("channels: failed to publish event", "event", e.Event, "error", addErr)
	}
}

View File

@ -0,0 +1,298 @@
package channels
import (
"context"
"crypto/rand"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"strings"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/events"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/validate"
)
// validProviders is the set of provider names accepted by Service.Create and
// Service.Test; any other provider is rejected as unsupported.
var validProviders = map[string]bool{
	"discord":    true,
	"slack":      true,
	"teams":      true,
	"googlechat": true,
	"telegram":   true,
	"matrix":     true,
	"webhook":    true,
}
// requiredFields lists, per provider, the config keys that must be present
// with a non-empty value. Create, RotateConfig and Test reject a config that
// is missing any of them.
var requiredFields = map[string][]string{
	"discord":    {"webhook_url"},
	"slack":      {"webhook_url"},
	"teams":      {"webhook_url"},
	"googlechat": {"webhook_url"},
	"telegram":   {"bot_token", "chat_id"},
	"matrix":     {"homeserver_url", "access_token", "room_id"},
	"webhook":    {"url"},
}
// validEvents is the set of event type strings a channel may subscribe to.
// It is filled from events.AllEventTypes at package initialisation.
var validEvents map[string]bool

// init builds validEvents from events.AllEventTypes.
func init() {
	known := events.AllEventTypes
	validEvents = make(map[string]bool, len(known))
	for i := range known {
		validEvents[known[i]] = true
	}
}
// Service handles channel CRUD operations.
type Service struct {
	DB     *db.Queries // query layer used for all channel persistence
	EncKey [32]byte    // key passed to EncryptSecret for every stored config value
}
// CreateParams holds the parameters for creating a channel.
type CreateParams struct {
	TeamID   pgtype.UUID       // team that will own the channel
	Name     string            // display name; normalised by cleanName before storage
	Provider string            // must be a key of validProviders
	Config   map[string]string // plaintext provider settings; each value is encrypted before storage
	Events   []string          // event types to subscribe to; must be non-empty and all known
}
// CreateResult holds the result of creating a channel.
type CreateResult struct {
	Channel         db.Channel // the stored row; its Config holds the encrypted values
	PlaintextSecret string     // non-empty only for webhook provider
}
// Create validates the parameters, encrypts every config value, and stores a
// new notification channel.
//
// Validation errors are prefixed "invalid:"; a duplicate channel name is
// reported with a "conflict:" prefix. For the webhook provider a signing
// secret is generated when the config has none, and the secret in use is
// returned in CreateResult.PlaintextSecret. Note that a generated secret is
// also written into p.Config, which is the caller's map.
func (s *Service) Create(ctx context.Context, p CreateParams) (CreateResult, error) {
	var none CreateResult

	normalised, err := cleanName(p.Name)
	if err != nil {
		return none, err
	}
	p.Name = normalised

	if known := validProviders[p.Provider]; !known {
		return none, fmt.Errorf("invalid: unsupported provider %q", p.Provider)
	}

	if len(p.Events) == 0 {
		return none, fmt.Errorf("invalid: at least one event type is required")
	}
	for i := range p.Events {
		if et := p.Events[i]; !validEvents[et] {
			return none, fmt.Errorf("invalid: unknown event type %q", et)
		}
	}

	// Every provider-specific required key must carry a non-empty value.
	for _, key := range requiredFields[p.Provider] {
		if len(p.Config[key]) == 0 {
			return none, fmt.Errorf("invalid: %s is required for %s", key, p.Provider)
		}
	}

	// Webhooks always end up with a signing secret: the supplied one, or a
	// freshly generated one when none was given.
	var plaintextSecret string
	if p.Provider == "webhook" {
		plaintextSecret = p.Config["secret"]
		if plaintextSecret == "" {
			plaintextSecret = generateSecret()
			p.Config["secret"] = plaintextSecret
		}
	}

	// Values are encrypted individually; keys stay in the clear.
	sealed := make(map[string]string, len(p.Config))
	for key, value := range p.Config {
		ciphertext, encErr := EncryptSecret(s.EncKey, value)
		if encErr != nil {
			return none, fmt.Errorf("encrypt config field %s: %w", key, encErr)
		}
		sealed[key] = ciphertext
	}
	configJSON, err := json.Marshal(sealed)
	if err != nil {
		return none, fmt.Errorf("marshal config: %w", err)
	}

	row := db.InsertChannelParams{
		ID:         id.NewChannelID(),
		TeamID:     p.TeamID,
		Name:       p.Name,
		Provider:   p.Provider,
		Config:     configJSON,
		EventTypes: p.Events,
	}
	ch, err := s.DB.InsertChannel(ctx, row)
	if err == nil {
		return CreateResult{Channel: ch, PlaintextSecret: plaintextSecret}, nil
	}

	// SQLSTATE 23505 is PostgreSQL's unique_violation.
	var pgErr *pgconn.PgError
	if errors.As(err, &pgErr) && pgErr.Code == "23505" {
		return none, fmt.Errorf("conflict: channel name %q already exists", p.Name)
	}
	return none, fmt.Errorf("insert channel: %w", err)
}
// List returns the channels owned by the given team, as produced by the
// ListChannelsByTeam query.
func (s *Service) List(ctx context.Context, teamID pgtype.UUID) ([]db.Channel, error) {
	channels, err := s.DB.ListChannelsByTeam(ctx, teamID)
	return channels, err
}
// Get looks up one channel by ID. The lookup is scoped to teamID, so a channel
// owned by another team is not returned.
func (s *Service) Get(ctx context.Context, channelID, teamID pgtype.UUID) (db.Channel, error) {
	key := db.GetChannelByTeamParams{
		ID:     channelID,
		TeamID: teamID,
	}
	return s.DB.GetChannelByTeam(ctx, key)
}
// Update changes a channel's name and subscribed event types. The name is
// normalised with cleanName and the event list must be non-empty and contain
// only known event types.
//
// It returns "channel not found" when no row matches channelID within teamID,
// and a "conflict:" error when the new name is already taken.
func (s *Service) Update(ctx context.Context, channelID, teamID pgtype.UUID, name string, eventTypes []string) (db.Channel, error) {
	var none db.Channel

	normalised, err := cleanName(name)
	if err != nil {
		return none, err
	}
	name = normalised

	if len(eventTypes) == 0 {
		return none, fmt.Errorf("invalid: at least one event type is required")
	}
	for i := range eventTypes {
		if et := eventTypes[i]; !validEvents[et] {
			return none, fmt.Errorf("invalid: unknown event type %q", et)
		}
	}

	changes := db.UpdateChannelParams{
		ID:         channelID,
		TeamID:     teamID,
		Name:       name,
		EventTypes: eventTypes,
	}
	ch, err := s.DB.UpdateChannel(ctx, changes)
	if err == nil {
		return ch, nil
	}

	// SQLSTATE 23505 is PostgreSQL's unique_violation.
	var pgErr *pgconn.PgError
	switch {
	case errors.Is(err, pgx.ErrNoRows):
		return none, fmt.Errorf("channel not found")
	case errors.As(err, &pgErr) && pgErr.Code == "23505":
		return none, fmt.Errorf("conflict: channel name %q already exists", name)
	default:
		return none, fmt.Errorf("update channel: %w", err)
	}
}
// RotateConfig replaces the stored config of an existing channel with a new,
// freshly encrypted one.
//
// The channel is loaded first so the new config can be validated against the
// required fields of its provider. For webhook channels a signing secret is
// generated when config has none; the generated value is written into the
// caller's config map. It returns "channel not found" when the channel does
// not exist within teamID.
func (s *Service) RotateConfig(ctx context.Context, channelID, teamID pgtype.UUID, config map[string]string) (db.Channel, error) {
	var none db.Channel

	// The provider is fixed at creation time; fetch it for validation.
	current, err := s.DB.GetChannelByTeam(ctx, db.GetChannelByTeamParams{ID: channelID, TeamID: teamID})
	switch {
	case err == nil:
	case errors.Is(err, pgx.ErrNoRows):
		return none, fmt.Errorf("channel not found")
	default:
		return none, fmt.Errorf("get channel: %w", err)
	}
	provider := current.Provider

	for _, key := range requiredFields[provider] {
		if len(config[key]) == 0 {
			return none, fmt.Errorf("invalid: %s is required for %s", key, provider)
		}
	}

	if provider == "webhook" {
		if config["secret"] == "" {
			config["secret"] = generateSecret()
		}
	}

	// Values are encrypted individually; keys stay in the clear.
	sealed := make(map[string]string, len(config))
	for key, value := range config {
		ciphertext, encErr := EncryptSecret(s.EncKey, value)
		if encErr != nil {
			return none, fmt.Errorf("encrypt config field %s: %w", key, encErr)
		}
		sealed[key] = ciphertext
	}
	configJSON, err := json.Marshal(sealed)
	if err != nil {
		return none, fmt.Errorf("marshal config: %w", err)
	}

	updated, err := s.DB.UpdateChannelConfig(ctx, db.UpdateChannelConfigParams{
		ID:     channelID,
		TeamID: teamID,
		Config: configJSON,
	})
	switch {
	case err == nil:
		return updated, nil
	case errors.Is(err, pgx.ErrNoRows):
		return none, fmt.Errorf("channel not found")
	default:
		return none, fmt.Errorf("update channel config: %w", err)
	}
}
// Test checks that provider is supported and that config carries every
// required field, then delivers a synthetic "channel.test" event with it.
// Nothing is written to the database. For the webhook provider a temporary
// secret is generated when config has none; it is stored in the caller's map.
func (s *Service) Test(ctx context.Context, provider string, config map[string]string) error {
	if supported := validProviders[provider]; !supported {
		return fmt.Errorf("invalid: unsupported provider %q", provider)
	}

	for _, key := range requiredFields[provider] {
		if len(config[key]) == 0 {
			return fmt.Errorf("invalid: %s is required for %s", key, provider)
		}
	}

	if provider == "webhook" {
		if config["secret"] == "" {
			config["secret"] = generateSecret()
		}
	}

	probe := events.Event{
		Event:     "channel.test",
		Timestamp: events.Now(),
		TeamID:    "test",
		Actor:     events.Actor{Type: events.ActorSystem},
		Resource:  events.Resource{ID: "test", Type: "channel"},
	}
	return Deliver(ctx, provider, config, probe)
}
// Delete removes the channel with the given ID. The delete is scoped to
// teamID, so a channel owned by another team is left untouched.
func (s *Service) Delete(ctx context.Context, channelID, teamID pgtype.UUID) error {
	target := db.DeleteChannelByTeamParams{
		ID:     channelID,
		TeamID: teamID,
	}
	return s.DB.DeleteChannelByTeam(ctx, target)
}
// cleanName normalises a channel name and validates the result. Surrounding
// whitespace is trimmed, the name is lowercased, and each remaining space
// becomes a hyphen. The normalised name must then pass validate.SafeName;
// a failure is returned with an "invalid:" prefix.
func cleanName(name string) (string, error) {
	normalised := strings.ReplaceAll(strings.ToLower(strings.TrimSpace(name)), " ", "-")

	err := validate.SafeName(normalised)
	if err != nil {
		return "", fmt.Errorf("invalid: %w", err)
	}
	return normalised, nil
}
// generateSecret returns 32 bytes from crypto/rand encoded as a 64-character
// lowercase hex string. It panics if the random source fails.
func generateSecret() string {
	var raw [32]byte
	_, err := rand.Read(raw[:])
	if err != nil {
		panic(fmt.Sprintf("crypto/rand failed: %v", err))
	}
	return hex.EncodeToString(raw[:])
}

View File

@ -0,0 +1,119 @@
package channels
import (
"fmt"
"net/url"
"regexp"
"strings"
)
// ShoutrrrURL turns structured provider config into a shoutrrr service URL by
// delegating to the builder for the given provider. Providers without a
// builder yield an "unsupported shoutrrr provider" error.
func ShoutrrrURL(provider string, config map[string]string) (string, error) {
	var build func(map[string]string) (string, error)

	switch provider {
	case "discord":
		build = discordURL
	case "slack":
		build = slackURL
	case "teams":
		build = teamsURL
	case "googlechat":
		build = googlechatURL
	case "telegram":
		build = telegramURL
	case "matrix":
		build = matrixURL
	default:
		return "", fmt.Errorf("unsupported shoutrrr provider: %s", provider)
	}
	return build(config)
}
// discordURL converts a Discord webhook URL of the form
// https://discord.com/api/webhooks/{id}/{token} into the shoutrrr form
// discord://{token}@{id}?splitLines=No.
//
// It returns an error when config["webhook_url"] cannot be parsed, when its
// path does not start with /api/webhooks/, or when the webhook ID or token
// segment is empty.
func discordURL(config map[string]string) (string, error) {
	u, err := url.Parse(config["webhook_url"])
	if err != nil {
		return "", fmt.Errorf("invalid discord webhook URL: %w", err)
	}

	// Path: /api/webhooks/{id}/{token}
	parts := strings.Split(strings.TrimPrefix(u.Path, "/"), "/")
	if len(parts) < 4 || parts[0] != "api" || parts[1] != "webhooks" {
		return "", fmt.Errorf("unexpected discord webhook URL format")
	}

	webhookID, token := parts[2], parts[3]
	// Reject paths such as /api/webhooks// that have the right shape but
	// carry no credentials; they would otherwise produce "discord://@".
	if webhookID == "" || token == "" {
		return "", fmt.Errorf("unexpected discord webhook URL format")
	}
	return fmt.Sprintf("discord://%s@%s?splitLines=No", token, webhookID), nil
}
// slackURL converts a Slack webhook URL of the form
// https://hooks.slack.com/services/{A}/{B}/{C} into the shoutrrr form
// slack://hook:{A}-{B}-{C}@webhook.
//
// It returns an error when config["webhook_url"] cannot be parsed, when its
// path does not start with /services/, or when any of the three token
// segments is empty.
func slackURL(config map[string]string) (string, error) {
	u, err := url.Parse(config["webhook_url"])
	if err != nil {
		return "", fmt.Errorf("invalid slack webhook URL: %w", err)
	}

	// Path: /services/TXXXXX/BXXXXX/XXXXXXXX
	parts := strings.Split(strings.TrimPrefix(u.Path, "/"), "/")
	if len(parts) < 4 || parts[0] != "services" {
		return "", fmt.Errorf("unexpected slack webhook URL format")
	}

	// Reject paths such as /services/T// that have the right shape but are
	// missing part of the token.
	for _, segment := range parts[1:4] {
		if segment == "" {
			return "", fmt.Errorf("unexpected slack webhook URL format")
		}
	}
	return fmt.Sprintf("slack://hook:%s-%s-%s@webhook", parts[1], parts[2], parts[3]), nil
}
// teamsWebhookRe captures the four identifiers of a Teams webhook URL:
// group, tenant, alt ID, and group owner.
// Format: https://<host>/<path>/{group}@{tenant}/IncomingWebhook/{altID}/{groupOwner}
var teamsWebhookRe = regexp.MustCompile(`([0-9a-f-]{36})@([0-9a-f-]{36})/[^/]+/([0-9a-f]{32})/([0-9a-f-]{36})`)

// teamsURL converts a Teams webhook URL into the shoutrrr form
// teams://Group@Tenant/AltID/GroupOwner. It returns an error when
// config["webhook_url"] is empty or does not match teamsWebhookRe.
func teamsURL(config map[string]string) (string, error) {
	raw := config["webhook_url"]
	if len(raw) == 0 {
		return "", fmt.Errorf("teams webhook_url is required")
	}

	// A match yields the full match plus the four capture groups.
	match := teamsWebhookRe.FindStringSubmatch(raw)
	if len(match) != 5 {
		return "", fmt.Errorf("unexpected teams webhook URL format")
	}

	return fmt.Sprintf("teams://%s@%s/%s/%s", match[1], match[2], match[3], match[4]), nil
}
// googlechatURL rewrites a Google Chat webhook URL into shoutrrr form by
// swapping the scheme for "googlechat"; host, path and query are kept.
// Input:  https://chat.googleapis.com/v1/spaces/SPACE/messages?key=KEY&token=TOKEN
// Output: googlechat://chat.googleapis.com/v1/spaces/SPACE/messages?key=KEY&token=TOKEN
//
// It returns an error when config["webhook_url"] is empty, cannot be parsed,
// or has a host other than chat.googleapis.com.
func googlechatURL(config map[string]string) (string, error) {
	raw := config["webhook_url"]
	if len(raw) == 0 {
		return "", fmt.Errorf("googlechat webhook_url is required")
	}

	parsed, err := url.Parse(raw)
	switch {
	case err != nil:
		return "", fmt.Errorf("invalid googlechat webhook URL: %w", err)
	case parsed.Host != "chat.googleapis.com":
		return "", fmt.Errorf("unexpected googlechat webhook URL host: %s", parsed.Host)
	}

	parsed.Scheme = "googlechat"
	return parsed.String(), nil
}
// telegramURL builds telegram://{bot_token}@telegram/?chats={chat_id} from the
// config. Both values are inserted verbatim. It returns an error when either
// bot_token or chat_id is empty.
func telegramURL(config map[string]string) (string, error) {
	botToken, chat := config["bot_token"], config["chat_id"]
	if len(botToken) == 0 || len(chat) == 0 {
		return "", fmt.Errorf("telegram bot_token and chat_id are required")
	}
	return fmt.Sprintf("telegram://%s@telegram/?chats=%s", botToken, chat), nil
}
// matrixURL builds matrix://:{access_token}@{homeserver}/{room_id} with an
// empty user name. A leading "https://" or "http://" is stripped from the
// homeserver URL, and the token and room ID are path-escaped. It returns an
// error when homeserver_url, access_token or room_id is empty.
func matrixURL(config map[string]string) (string, error) {
	homeserver, token, roomID := config["homeserver_url"], config["access_token"], config["room_id"]
	if len(homeserver) == 0 || len(token) == 0 || len(roomID) == 0 {
		return "", fmt.Errorf("matrix homeserver_url, access_token, and room_id are required")
	}

	// Drop the scheme: "https://" first, then "http://".
	host := strings.TrimPrefix(homeserver, "https://")
	host = strings.TrimPrefix(host, "http://")

	// Room IDs often start with "!", which must be escaped in a URL path.
	escapedToken := url.PathEscape(token)
	escapedRoom := url.PathEscape(roomID)
	return fmt.Sprintf("matrix://:%s@%s/%s", escapedToken, host, escapedRoom), nil
}

View File

@ -0,0 +1,62 @@
package channels
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"fmt"
"net/http"
"strings"
"time"
"github.com/google/uuid"
)
// WebhookDelivery POSTs HMAC-signed event payloads to webhook URLs.
type WebhookDelivery struct {
	client *http.Client
}

// NewWebhookDelivery returns a WebhookDelivery whose HTTP client has a
// 10-second timeout and does not follow redirects: a redirect response is
// handed back to the caller as-is.
func NewWebhookDelivery() *WebhookDelivery {
	noRedirects := func(_ *http.Request, _ []*http.Request) error {
		return http.ErrUseLastResponse
	}

	httpClient := &http.Client{
		Timeout:       10 * time.Second,
		CheckRedirect: noRedirects,
	}
	return &WebhookDelivery{client: httpClient}
}
// Deliver POSTs payload to targetURL as application/json, signed with secret.
//
// The signature is the hex HMAC-SHA256 of "<timestamp>.<payload>", prefixed
// with "sha256=". It is sent together with a fresh delivery UUID and the
// RFC 3339 UTC timestamp in request headers. Any response status outside
// 200-299 is reported as an error.
func (d *WebhookDelivery) Deliver(ctx context.Context, targetURL, secret string, payload []byte) error {
	sentAt := time.Now().UTC().Format(time.RFC3339)
	deliveryID := uuid.New().String()
	body := string(payload)

	// The timestamp is part of the signed data, joined to the body by ".".
	signer := hmac.New(sha256.New, []byte(secret))
	signer.Write([]byte(sentAt + "." + body))
	signature := "sha256=" + hex.EncodeToString(signer.Sum(nil))

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, strings.NewReader(body))
	if err != nil {
		return fmt.Errorf("create request: %w", err)
	}
	hdr := req.Header
	hdr.Set("Content-Type", "application/json")
	hdr.Set("X-WRENN-SIGNATURE", signature)
	hdr.Set("X-Wrenn-Delivery", deliveryID)
	hdr.Set("X-Wrenn-Timestamp", sentAt)

	resp, err := d.client.Do(req)
	if err != nil {
		return fmt.Errorf("http post: %w", err)
	}
	// Only the status code is needed; the body is not read.
	resp.Body.Close()

	if code := resp.StatusCode; code >= 200 && code < 300 {
		return nil
	}
	return fmt.Errorf("webhook returned %d", resp.StatusCode)
}

View File

@ -1,24 +1,32 @@
package config
import (
"encoding/hex"
"os"
"strings"
"github.com/joho/godotenv"
)
// Config holds the control plane configuration.
type Config struct {
DatabaseURL string
RedisURL string
ListenAddr string
HostAgentAddr string
JWTSecret string
DatabaseURL string
RedisURL string
ListenAddr string
JWTSecret string
// mTLS — CP→Agent channel. Both must be set to enable mTLS; omitting either
// disables cert issuance and leaves agent connections on plain HTTP (dev mode).
CACert string // WRENN_CA_CERT — PEM-encoded internal CA certificate
CAKey string // WRENN_CA_KEY — PEM-encoded internal CA private key
OAuthGitHubClientID string
OAuthGitHubClientSecret string
OAuthRedirectURL string
CPPublicURL string
// Channels — encryption for channel secrets (AES-256-GCM).
EncryptionKeyHex string // WRENN_ENCRYPTION_KEY raw hex string (for validation)
EncryptionKey [32]byte // parsed 32-byte key
}
// Load reads configuration from a .env file (if present) and environment variables.
@ -28,21 +36,27 @@ func Load() Config {
_ = godotenv.Load()
cfg := Config{
DatabaseURL: envOrDefault("DATABASE_URL", "postgres://wrenn:wrenn@localhost:5432/wrenn?sslmode=disable"),
RedisURL: envOrDefault("REDIS_URL", "redis://localhost:6379/0"),
ListenAddr: envOrDefault("CP_LISTEN_ADDR", ":8080"),
HostAgentAddr: envOrDefault("CP_HOST_AGENT_ADDR", "http://localhost:50051"),
JWTSecret: os.Getenv("JWT_SECRET"),
DatabaseURL: envOrDefault("DATABASE_URL", "postgres://wrenn:wrenn@localhost:5432/wrenn?sslmode=disable"),
RedisURL: envOrDefault("REDIS_URL", "redis://localhost:6379/0"),
ListenAddr: envOrDefault("WRENN_CP_LISTEN_ADDR", ":8080"),
JWTSecret: os.Getenv("JWT_SECRET"),
CACert: os.Getenv("WRENN_CA_CERT"),
CAKey: os.Getenv("WRENN_CA_KEY"),
OAuthGitHubClientID: os.Getenv("OAUTH_GITHUB_CLIENT_ID"),
OAuthGitHubClientSecret: os.Getenv("OAUTH_GITHUB_CLIENT_SECRET"),
OAuthRedirectURL: envOrDefault("OAUTH_REDIRECT_URL", "https://app.wrenn.dev"),
CPPublicURL: os.Getenv("CP_PUBLIC_URL"),
EncryptionKeyHex: os.Getenv("WRENN_ENCRYPTION_KEY"),
}
// Ensure the host agent address has a scheme.
if !strings.HasPrefix(cfg.HostAgentAddr, "http://") && !strings.HasPrefix(cfg.HostAgentAddr, "https://") {
cfg.HostAgentAddr = "http://" + cfg.HostAgentAddr
if cfg.EncryptionKeyHex != "" {
b, err := hex.DecodeString(cfg.EncryptionKeyHex)
if err == nil && len(b) == 32 {
copy(cfg.EncryptionKey[:], b)
}
}
return cfg

View File

@ -16,8 +16,8 @@ DELETE FROM team_api_keys WHERE id = $1 AND team_id = $2
`
type DeleteAPIKeyParams struct {
ID string `json:"id"`
TeamID string `json:"team_id"`
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
}
func (q *Queries) DeleteAPIKey(ctx context.Context, arg DeleteAPIKeyParams) error {
@ -52,12 +52,12 @@ RETURNING id, team_id, name, key_hash, key_prefix, created_by, created_at, last_
`
type InsertAPIKeyParams struct {
ID string `json:"id"`
TeamID string `json:"team_id"`
Name string `json:"name"`
KeyHash string `json:"key_hash"`
KeyPrefix string `json:"key_prefix"`
CreatedBy string `json:"created_by"`
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
Name string `json:"name"`
KeyHash string `json:"key_hash"`
KeyPrefix string `json:"key_prefix"`
CreatedBy pgtype.UUID `json:"created_by"`
}
func (q *Queries) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (TeamApiKey, error) {
@ -87,7 +87,7 @@ const listAPIKeysByTeam = `-- name: ListAPIKeysByTeam :many
SELECT id, team_id, name, key_hash, key_prefix, created_by, created_at, last_used FROM team_api_keys WHERE team_id = $1 ORDER BY created_at DESC
`
func (q *Queries) ListAPIKeysByTeam(ctx context.Context, teamID string) ([]TeamApiKey, error) {
func (q *Queries) ListAPIKeysByTeam(ctx context.Context, teamID pgtype.UUID) ([]TeamApiKey, error) {
rows, err := q.db.Query(ctx, listAPIKeysByTeam, teamID)
if err != nil {
return nil, err
@ -126,18 +126,18 @@ ORDER BY k.created_at DESC
`
type ListAPIKeysByTeamWithCreatorRow struct {
ID string `json:"id"`
TeamID string `json:"team_id"`
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
Name string `json:"name"`
KeyHash string `json:"key_hash"`
KeyPrefix string `json:"key_prefix"`
CreatedBy string `json:"created_by"`
CreatedBy pgtype.UUID `json:"created_by"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
LastUsed pgtype.Timestamptz `json:"last_used"`
CreatorEmail string `json:"creator_email"`
}
func (q *Queries) ListAPIKeysByTeamWithCreator(ctx context.Context, teamID string) ([]ListAPIKeysByTeamWithCreatorRow, error) {
func (q *Queries) ListAPIKeysByTeamWithCreator(ctx context.Context, teamID pgtype.UUID) ([]ListAPIKeysByTeamWithCreatorRow, error) {
rows, err := q.db.Query(ctx, listAPIKeysByTeamWithCreator, teamID)
if err != nil {
return nil, err
@ -171,7 +171,7 @@ const updateAPIKeyLastUsed = `-- name: UpdateAPIKeyLastUsed :exec
UPDATE team_api_keys SET last_used = NOW() WHERE id = $1
`
func (q *Queries) UpdateAPIKeyLastUsed(ctx context.Context, id string) error {
func (q *Queries) UpdateAPIKeyLastUsed(ctx context.Context, id pgtype.UUID) error {
_, err := q.db.Exec(ctx, updateAPIKeyLastUsed, id)
return err
}

111
internal/db/audit.sql.go Normal file
View File

@ -0,0 +1,111 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: audit.sql
package db
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
// insertAuditLog is the SQL for InsertAuditLog; created_at is not supplied.
const insertAuditLog = `-- name: InsertAuditLog :exec
INSERT INTO audit_logs (id, team_id, actor_type, actor_id, actor_name, resource_type, resource_id, action, scope, status, metadata)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
`

// InsertAuditLogParams carries the column values for one audit_logs row, in
// the order of the statement's placeholders.
type InsertAuditLogParams struct {
	ID           pgtype.UUID `json:"id"`
	TeamID       pgtype.UUID `json:"team_id"`
	ActorType    string      `json:"actor_type"`
	ActorID      pgtype.Text `json:"actor_id"`
	ActorName    string      `json:"actor_name"`
	ResourceType string      `json:"resource_type"`
	ResourceID   pgtype.Text `json:"resource_id"`
	Action       string      `json:"action"`
	Scope        string      `json:"scope"`
	Status       string      `json:"status"`
	Metadata     []byte      `json:"metadata"`
}

// InsertAuditLog inserts one audit log row and returns only the execution
// error; the command result is discarded.
func (q *Queries) InsertAuditLog(ctx context.Context, arg InsertAuditLogParams) error {
	_, err := q.db.Exec(ctx, insertAuditLog,
		arg.ID,
		arg.TeamID,
		arg.ActorType,
		arg.ActorID,
		arg.ActorName,
		arg.ResourceType,
		arg.ResourceID,
		arg.Action,
		arg.Scope,
		arg.Status,
		arg.Metadata,
	)
	return err
}
// listAuditLogs is the SQL for ListAuditLogs. Rows are ordered newest first
// (created_at DESC, id DESC) and paged with a (created_at, id) cursor: when $5
// is NULL no cursor is applied, otherwise only rows strictly before the cursor
// position are returned.
const listAuditLogs = `-- name: ListAuditLogs :many
SELECT id, team_id, actor_type, actor_id, actor_name, resource_type, resource_id, action, scope, status, metadata, created_at FROM audit_logs
WHERE team_id = $1
AND scope = ANY($2::text[])
AND (cardinality($3::text[]) = 0 OR resource_type = ANY($3::text[]))
AND (cardinality($4::text[]) = 0 OR action = ANY($4::text[]))
AND ($5::timestamptz IS NULL OR created_at < $5
OR (created_at = $5 AND id < $6))
ORDER BY created_at DESC, id DESC
LIMIT $7
`

// ListAuditLogsParams holds the arguments of the listAuditLogs statement. The
// ColumnN names are positional; per the SQL they mean:
//   - Column2: scopes to include ($2)
//   - Column3: resource types to include; empty means no filter ($3)
//   - Column4: actions to include; empty means no filter ($4)
//   - Column5: cursor created_at; NULL means start from the newest row ($5)
//   - ID:      cursor id, used to break ties on equal created_at ($6)
//   - Limit:   maximum number of rows ($7)
type ListAuditLogsParams struct {
	TeamID  pgtype.UUID        `json:"team_id"`
	Column2 []string           `json:"column_2"`
	Column3 []string           `json:"column_3"`
	Column4 []string           `json:"column_4"`
	Column5 pgtype.Timestamptz `json:"column_5"`
	ID      pgtype.UUID        `json:"id"`
	Limit   int32              `json:"limit"`
}

// ListAuditLogs returns one page of a team's audit log rows, newest first.
// The returned slice is nil when no row matches.
func (q *Queries) ListAuditLogs(ctx context.Context, arg ListAuditLogsParams) ([]AuditLog, error) {
	rows, err := q.db.Query(ctx, listAuditLogs,
		arg.TeamID,
		arg.Column2,
		arg.Column3,
		arg.Column4,
		arg.Column5,
		arg.ID,
		arg.Limit,
	)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []AuditLog
	for rows.Next() {
		var i AuditLog
		if err := rows.Scan(
			&i.ID,
			&i.TeamID,
			&i.ActorType,
			&i.ActorID,
			&i.ActorName,
			&i.ResourceType,
			&i.ResourceID,
			&i.Action,
			&i.Scope,
			&i.Status,
			&i.Metadata,
			&i.CreatedAt,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	// Surface any error that ended the iteration early.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}

225
internal/db/channels.sql.go Normal file
View File

@ -0,0 +1,225 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: channels.sql
package db
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
// deleteChannelByTeam is the SQL for DeleteChannelByTeam.
const deleteChannelByTeam = `-- name: DeleteChannelByTeam :exec
DELETE FROM channels WHERE id = $1 AND team_id = $2
`

// DeleteChannelByTeamParams identifies the channel to delete and the team it
// must belong to.
type DeleteChannelByTeamParams struct {
	ID     pgtype.UUID `json:"id"`
	TeamID pgtype.UUID `json:"team_id"`
}

// DeleteChannelByTeam deletes the channel matching both ID and team. The
// command result is discarded, so this function does not report whether a row
// was actually deleted.
func (q *Queries) DeleteChannelByTeam(ctx context.Context, arg DeleteChannelByTeamParams) error {
	_, err := q.db.Exec(ctx, deleteChannelByTeam, arg.ID, arg.TeamID)
	return err
}
// getChannelByTeam is the SQL for GetChannelByTeam.
const getChannelByTeam = `-- name: GetChannelByTeam :one
SELECT id, team_id, name, provider, config, event_types, created_at, updated_at FROM channels WHERE id = $1 AND team_id = $2
`

// GetChannelByTeamParams identifies a channel and the team it must belong to.
type GetChannelByTeamParams struct {
	ID     pgtype.UUID `json:"id"`
	TeamID pgtype.UUID `json:"team_id"`
}

// GetChannelByTeam fetches the channel matching both ID and team. The error
// from scanning the row is returned as-is together with whatever was scanned.
func (q *Queries) GetChannelByTeam(ctx context.Context, arg GetChannelByTeamParams) (Channel, error) {
	row := q.db.QueryRow(ctx, getChannelByTeam, arg.ID, arg.TeamID)
	var i Channel
	err := row.Scan(
		&i.ID,
		&i.TeamID,
		&i.Name,
		&i.Provider,
		&i.Config,
		&i.EventTypes,
		&i.CreatedAt,
		&i.UpdatedAt,
	)
	return i, err
}
// insertChannel is the SQL for InsertChannel; created_at and updated_at are
// not supplied by the statement.
const insertChannel = `-- name: InsertChannel :one
INSERT INTO channels (id, team_id, name, provider, config, event_types)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id, team_id, name, provider, config, event_types, created_at, updated_at
`

// InsertChannelParams carries the column values for a new channels row.
type InsertChannelParams struct {
	ID         pgtype.UUID `json:"id"`
	TeamID     pgtype.UUID `json:"team_id"`
	Name       string      `json:"name"`
	Provider   string      `json:"provider"`
	Config     []byte      `json:"config"`
	EventTypes []string    `json:"event_types"`
}

// InsertChannel inserts a channel and returns the stored row, including the
// created_at and updated_at values reported by the database.
func (q *Queries) InsertChannel(ctx context.Context, arg InsertChannelParams) (Channel, error) {
	row := q.db.QueryRow(ctx, insertChannel,
		arg.ID,
		arg.TeamID,
		arg.Name,
		arg.Provider,
		arg.Config,
		arg.EventTypes,
	)
	var i Channel
	err := row.Scan(
		&i.ID,
		&i.TeamID,
		&i.Name,
		&i.Provider,
		&i.Config,
		&i.EventTypes,
		&i.CreatedAt,
		&i.UpdatedAt,
	)
	return i, err
}
// listChannelsByTeam is the SQL for ListChannelsByTeam.
const listChannelsByTeam = `-- name: ListChannelsByTeam :many
SELECT id, team_id, name, provider, config, event_types, created_at, updated_at FROM channels WHERE team_id = $1 ORDER BY created_at DESC
`

// ListChannelsByTeam returns all channels of a team, newest first. The
// returned slice is nil when the team has no channels.
func (q *Queries) ListChannelsByTeam(ctx context.Context, teamID pgtype.UUID) ([]Channel, error) {
	rows, err := q.db.Query(ctx, listChannelsByTeam, teamID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []Channel
	for rows.Next() {
		var i Channel
		if err := rows.Scan(
			&i.ID,
			&i.TeamID,
			&i.Name,
			&i.Provider,
			&i.Config,
			&i.EventTypes,
			&i.CreatedAt,
			&i.UpdatedAt,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	// Surface any error that ended the iteration early.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}
// listChannelsForEvent is the SQL for ListChannelsForEvent. A channel matches
// when the given event type is an element of its event_types array.
const listChannelsForEvent = `-- name: ListChannelsForEvent :many
SELECT id, team_id, name, provider, config, event_types, created_at, updated_at FROM channels
WHERE team_id = $1
AND $2::text = ANY(event_types)
ORDER BY created_at
`

// ListChannelsForEventParams selects a team and the event type its channels
// must be subscribed to.
type ListChannelsForEventParams struct {
	TeamID    pgtype.UUID `json:"team_id"`
	EventType string      `json:"event_type"`
}

// ListChannelsForEvent returns the team's channels that subscribe to the
// given event type, oldest first. The returned slice is nil when none match.
func (q *Queries) ListChannelsForEvent(ctx context.Context, arg ListChannelsForEventParams) ([]Channel, error) {
	rows, err := q.db.Query(ctx, listChannelsForEvent, arg.TeamID, arg.EventType)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []Channel
	for rows.Next() {
		var i Channel
		if err := rows.Scan(
			&i.ID,
			&i.TeamID,
			&i.Name,
			&i.Provider,
			&i.Config,
			&i.EventTypes,
			&i.CreatedAt,
			&i.UpdatedAt,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	// Surface any error that ended the iteration early.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}
// updateChannel is the SQL for UpdateChannel. It also sets updated_at to NOW().
const updateChannel = `-- name: UpdateChannel :one
UPDATE channels SET name = $3, event_types = $4, updated_at = NOW()
WHERE id = $1 AND team_id = $2
RETURNING id, team_id, name, provider, config, event_types, created_at, updated_at
`

// UpdateChannelParams identifies a channel within a team and carries its new
// name and event types.
type UpdateChannelParams struct {
	ID         pgtype.UUID `json:"id"`
	TeamID     pgtype.UUID `json:"team_id"`
	Name       string      `json:"name"`
	EventTypes []string    `json:"event_types"`
}

// UpdateChannel replaces the name and event types of the channel matching
// both ID and team, and returns the updated row. Provider and config are left
// unchanged by the statement.
func (q *Queries) UpdateChannel(ctx context.Context, arg UpdateChannelParams) (Channel, error) {
	row := q.db.QueryRow(ctx, updateChannel,
		arg.ID,
		arg.TeamID,
		arg.Name,
		arg.EventTypes,
	)
	var i Channel
	err := row.Scan(
		&i.ID,
		&i.TeamID,
		&i.Name,
		&i.Provider,
		&i.Config,
		&i.EventTypes,
		&i.CreatedAt,
		&i.UpdatedAt,
	)
	return i, err
}
// updateChannelConfig is the SQL for UpdateChannelConfig. It also sets
// updated_at to NOW().
const updateChannelConfig = `-- name: UpdateChannelConfig :one
UPDATE channels SET config = $3, updated_at = NOW()
WHERE id = $1 AND team_id = $2
RETURNING id, team_id, name, provider, config, event_types, created_at, updated_at
`

// UpdateChannelConfigParams identifies a channel within a team and carries
// its replacement config.
type UpdateChannelConfigParams struct {
	ID     pgtype.UUID `json:"id"`
	TeamID pgtype.UUID `json:"team_id"`
	Config []byte      `json:"config"`
}

// UpdateChannelConfig replaces the config of the channel matching both ID and
// team, and returns the updated row.
func (q *Queries) UpdateChannelConfig(ctx context.Context, arg UpdateChannelConfigParams) (Channel, error) {
	row := q.db.QueryRow(ctx, updateChannelConfig, arg.ID, arg.TeamID, arg.Config)
	var i Channel
	err := row.Scan(
		&i.ID,
		&i.TeamID,
		&i.Name,
		&i.Provider,
		&i.Config,
		&i.EventTypes,
		&i.CreatedAt,
		&i.UpdatedAt,
	)
	return i, err
}

View File

@ -0,0 +1,92 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: host_refresh_tokens.sql
package db
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
// deleteExpiredHostRefreshTokens is the SQL for DeleteExpiredHostRefreshTokens.
const deleteExpiredHostRefreshTokens = `-- name: DeleteExpiredHostRefreshTokens :exec
DELETE FROM host_refresh_tokens
WHERE expires_at < NOW() OR revoked_at IS NOT NULL
`

// DeleteExpiredHostRefreshTokens deletes every host refresh token that has
// expired. Despite the name, it also deletes tokens that have been revoked,
// whether or not they have expired.
func (q *Queries) DeleteExpiredHostRefreshTokens(ctx context.Context) error {
	_, err := q.db.Exec(ctx, deleteExpiredHostRefreshTokens)
	return err
}
// getHostRefreshTokenByHash is the SQL for GetHostRefreshTokenByHash. Only
// tokens that are neither revoked nor expired can match.
const getHostRefreshTokenByHash = `-- name: GetHostRefreshTokenByHash :one
SELECT id, host_id, token_hash, expires_at, created_at, revoked_at FROM host_refresh_tokens
WHERE token_hash = $1 AND revoked_at IS NULL AND expires_at > NOW()
`

// GetHostRefreshTokenByHash looks up a still-valid host refresh token by its
// hash. Revoked or expired tokens are treated as absent by the query.
func (q *Queries) GetHostRefreshTokenByHash(ctx context.Context, tokenHash string) (HostRefreshToken, error) {
	row := q.db.QueryRow(ctx, getHostRefreshTokenByHash, tokenHash)
	var i HostRefreshToken
	err := row.Scan(
		&i.ID,
		&i.HostID,
		&i.TokenHash,
		&i.ExpiresAt,
		&i.CreatedAt,
		&i.RevokedAt,
	)
	return i, err
}
// insertHostRefreshToken creates a refresh token row and returns it in full,
// including the columns the statement does not set (created_at, revoked_at).
const insertHostRefreshToken = `-- name: InsertHostRefreshToken :one
INSERT INTO host_refresh_tokens (id, host_id, token_hash, expires_at)
VALUES ($1, $2, $3, $4)
RETURNING id, host_id, token_hash, expires_at, created_at, revoked_at
`

// InsertHostRefreshTokenParams carries the bind values for
// InsertHostRefreshToken, in placeholder order.
type InsertHostRefreshTokenParams struct {
	ID        pgtype.UUID        `json:"id"`         // $1
	HostID    pgtype.UUID        `json:"host_id"`    // $2
	TokenHash string             `json:"token_hash"` // $3
	ExpiresAt pgtype.Timestamptz `json:"expires_at"` // $4
}

// InsertHostRefreshToken stores a new refresh token for a host and returns
// the inserted row as read back from the database.
func (q *Queries) InsertHostRefreshToken(ctx context.Context, arg InsertHostRefreshTokenParams) (HostRefreshToken, error) {
	var tok HostRefreshToken
	err := q.db.QueryRow(ctx, insertHostRefreshToken,
		arg.ID,
		arg.HostID,
		arg.TokenHash,
		arg.ExpiresAt,
	).Scan(
		&tok.ID,
		&tok.HostID,
		&tok.TokenHash,
		&tok.ExpiresAt,
		&tok.CreatedAt,
		&tok.RevokedAt,
	)
	return tok, err
}
// revokeHostRefreshToken stamps revoked_at on a single token row. There is no
// "revoked_at IS NULL" guard, so re-revoking overwrites the earlier timestamp.
const revokeHostRefreshToken = `-- name: RevokeHostRefreshToken :exec
UPDATE host_refresh_tokens SET revoked_at = NOW() WHERE id = $1
`

// RevokeHostRefreshToken marks the refresh token with the given id as revoked.
// Only the execution error is reported; the number of affected rows is not.
func (q *Queries) RevokeHostRefreshToken(ctx context.Context, id pgtype.UUID) error {
	if _, err := q.db.Exec(ctx, revokeHostRefreshToken, id); err != nil {
		return err
	}
	return nil
}
// revokeHostRefreshTokensByHost stamps revoked_at on every not-yet-revoked
// token belonging to one host; already revoked rows keep their timestamp.
const revokeHostRefreshTokensByHost = `-- name: RevokeHostRefreshTokensByHost :exec
UPDATE host_refresh_tokens SET revoked_at = NOW()
WHERE host_id = $1 AND revoked_at IS NULL
`

// RevokeHostRefreshTokensByHost revokes all currently active refresh tokens
// of the host identified by hostID. Only the execution error is reported.
func (q *Queries) RevokeHostRefreshTokensByHost(ctx context.Context, hostID pgtype.UUID) error {
	if _, err := q.db.Exec(ctx, revokeHostRefreshTokensByHost, hostID); err != nil {
		return err
	}
	return nil
}

View File

@ -16,8 +16,8 @@ INSERT INTO host_tags (host_id, tag) VALUES ($1, $2) ON CONFLICT DO NOTHING
`
type AddHostTagParams struct {
HostID string `json:"host_id"`
Tag string `json:"tag"`
HostID pgtype.UUID `json:"host_id"`
Tag string `json:"tag"`
}
func (q *Queries) AddHostTag(ctx context.Context, arg AddHostTagParams) error {
@ -29,16 +29,16 @@ const deleteHost = `-- name: DeleteHost :exec
DELETE FROM hosts WHERE id = $1
`
func (q *Queries) DeleteHost(ctx context.Context, id string) error {
func (q *Queries) DeleteHost(ctx context.Context, id pgtype.UUID) error {
_, err := q.db.Exec(ctx, deleteHost, id)
return err
}
const getHost = `-- name: GetHost :one
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE id = $1
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE id = $1
`
func (q *Queries) GetHost(ctx context.Context, id string) (Host, error) {
func (q *Queries) GetHost(ctx context.Context, id pgtype.UUID) (Host, error) {
row := q.db.QueryRow(ctx, getHost, id)
var i Host
err := row.Scan(
@ -59,18 +59,18 @@ func (q *Queries) GetHost(ctx context.Context, id string) (Host, error) {
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
&i.MtlsEnabled,
&i.CertExpiresAt,
)
return i, err
}
const getHostByTeam = `-- name: GetHostByTeam :one
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE id = $1 AND team_id = $2
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE id = $1 AND team_id = $2
`
type GetHostByTeamParams struct {
ID string `json:"id"`
TeamID pgtype.Text `json:"team_id"`
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
}
func (q *Queries) GetHostByTeam(ctx context.Context, arg GetHostByTeamParams) (Host, error) {
@ -94,7 +94,7 @@ func (q *Queries) GetHostByTeam(ctx context.Context, arg GetHostByTeamParams) (H
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
&i.MtlsEnabled,
&i.CertExpiresAt,
)
return i, err
}
@ -103,7 +103,7 @@ const getHostTags = `-- name: GetHostTags :many
SELECT tag FROM host_tags WHERE host_id = $1 ORDER BY tag
`
func (q *Queries) GetHostTags(ctx context.Context, hostID string) ([]string, error) {
func (q *Queries) GetHostTags(ctx context.Context, hostID pgtype.UUID) ([]string, error) {
rows, err := q.db.Query(ctx, getHostTags, hostID)
if err != nil {
return nil, err
@ -127,7 +127,7 @@ const getHostTokensByHost = `-- name: GetHostTokensByHost :many
SELECT id, host_id, created_by, created_at, expires_at, used_at FROM host_tokens WHERE host_id = $1 ORDER BY created_at DESC
`
func (q *Queries) GetHostTokensByHost(ctx context.Context, hostID string) ([]HostToken, error) {
func (q *Queries) GetHostTokensByHost(ctx context.Context, hostID pgtype.UUID) ([]HostToken, error) {
rows, err := q.db.Query(ctx, getHostTokensByHost, hostID)
if err != nil {
return nil, err
@ -157,16 +157,16 @@ func (q *Queries) GetHostTokensByHost(ctx context.Context, hostID string) ([]Hos
const insertHost = `-- name: InsertHost :one
INSERT INTO hosts (id, type, team_id, provider, availability_zone, created_by)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled
RETURNING id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at
`
type InsertHostParams struct {
ID string `json:"id"`
ID pgtype.UUID `json:"id"`
Type string `json:"type"`
TeamID pgtype.Text `json:"team_id"`
Provider pgtype.Text `json:"provider"`
AvailabilityZone pgtype.Text `json:"availability_zone"`
CreatedBy string `json:"created_by"`
TeamID pgtype.UUID `json:"team_id"`
Provider string `json:"provider"`
AvailabilityZone string `json:"availability_zone"`
CreatedBy pgtype.UUID `json:"created_by"`
}
func (q *Queries) InsertHost(ctx context.Context, arg InsertHostParams) (Host, error) {
@ -197,7 +197,7 @@ func (q *Queries) InsertHost(ctx context.Context, arg InsertHostParams) (Host, e
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
&i.MtlsEnabled,
&i.CertExpiresAt,
)
return i, err
}
@ -209,9 +209,9 @@ RETURNING id, host_id, created_by, created_at, expires_at, used_at
`
type InsertHostTokenParams struct {
ID string `json:"id"`
HostID string `json:"host_id"`
CreatedBy string `json:"created_by"`
ID pgtype.UUID `json:"id"`
HostID pgtype.UUID `json:"host_id"`
CreatedBy pgtype.UUID `json:"created_by"`
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
}
@ -234,8 +234,52 @@ func (q *Queries) InsertHostToken(ctx context.Context, arg InsertHostTokenParams
return i, err
}
// listActiveHosts selects every host whose status is neither 'pending' nor
// 'offline', oldest first (ascending created_at).
const listActiveHosts = `-- name: ListActiveHosts :many
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE status NOT IN ('pending', 'offline') ORDER BY created_at
`

// ListActiveHosts returns all hosts that have completed registration, i.e.
// whose status is anything other than pending or offline.
//
// The result is a nil slice when no host matches. A scan or iteration error
// discards any rows read so far and is returned with a nil slice.
func (q *Queries) ListActiveHosts(ctx context.Context) ([]Host, error) {
	rows, err := q.db.Query(ctx, listActiveHosts)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var hosts []Host
	for rows.Next() {
		var h Host
		scanErr := rows.Scan(
			&h.ID,
			&h.Type,
			&h.TeamID,
			&h.Provider,
			&h.AvailabilityZone,
			&h.Arch,
			&h.CpuCores,
			&h.MemoryMb,
			&h.DiskGb,
			&h.Address,
			&h.Status,
			&h.LastHeartbeatAt,
			&h.Metadata,
			&h.CreatedBy,
			&h.CreatedAt,
			&h.UpdatedAt,
			&h.CertFingerprint,
			&h.CertExpiresAt,
		)
		if scanErr != nil {
			return nil, scanErr
		}
		hosts = append(hosts, h)
	}
	// rows.Next returning false can also mean the iteration failed.
	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}
	return hosts, nil
}
const listHosts = `-- name: ListHosts :many
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts ORDER BY created_at DESC
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts ORDER BY created_at DESC
`
func (q *Queries) ListHosts(ctx context.Context) ([]Host, error) {
@ -265,7 +309,7 @@ func (q *Queries) ListHosts(ctx context.Context) ([]Host, error) {
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
&i.MtlsEnabled,
&i.CertExpiresAt,
); err != nil {
return nil, err
}
@ -278,7 +322,7 @@ func (q *Queries) ListHosts(ctx context.Context) ([]Host, error) {
}
const listHostsByStatus = `-- name: ListHostsByStatus :many
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE status = $1 ORDER BY created_at DESC
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE status = $1 ORDER BY created_at DESC
`
func (q *Queries) ListHostsByStatus(ctx context.Context, status string) ([]Host, error) {
@ -308,7 +352,7 @@ func (q *Queries) ListHostsByStatus(ctx context.Context, status string) ([]Host,
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
&i.MtlsEnabled,
&i.CertExpiresAt,
); err != nil {
return nil, err
}
@ -321,7 +365,7 @@ func (q *Queries) ListHostsByStatus(ctx context.Context, status string) ([]Host,
}
const listHostsByTag = `-- name: ListHostsByTag :many
SELECT h.id, h.type, h.team_id, h.provider, h.availability_zone, h.arch, h.cpu_cores, h.memory_mb, h.disk_gb, h.address, h.status, h.last_heartbeat_at, h.metadata, h.created_by, h.created_at, h.updated_at, h.cert_fingerprint, h.mtls_enabled FROM hosts h
SELECT h.id, h.type, h.team_id, h.provider, h.availability_zone, h.arch, h.cpu_cores, h.memory_mb, h.disk_gb, h.address, h.status, h.last_heartbeat_at, h.metadata, h.created_by, h.created_at, h.updated_at, h.cert_fingerprint, h.cert_expires_at FROM hosts h
JOIN host_tags ht ON ht.host_id = h.id
WHERE ht.tag = $1
ORDER BY h.created_at DESC
@ -354,7 +398,7 @@ func (q *Queries) ListHostsByTag(ctx context.Context, tag string) ([]Host, error
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
&i.MtlsEnabled,
&i.CertExpiresAt,
); err != nil {
return nil, err
}
@ -367,10 +411,10 @@ func (q *Queries) ListHostsByTag(ctx context.Context, tag string) ([]Host, error
}
const listHostsByTeam = `-- name: ListHostsByTeam :many
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE team_id = $1 AND type = 'byoc' ORDER BY created_at DESC
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE team_id = $1 AND type = 'byoc' ORDER BY created_at DESC
`
func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.Text) ([]Host, error) {
func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.UUID) ([]Host, error) {
rows, err := q.db.Query(ctx, listHostsByTeam, teamID)
if err != nil {
return nil, err
@ -397,7 +441,7 @@ func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.Text) ([]Ho
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
&i.MtlsEnabled,
&i.CertExpiresAt,
); err != nil {
return nil, err
}
@ -410,7 +454,7 @@ func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.Text) ([]Ho
}
const listHostsByType = `-- name: ListHostsByType :many
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE type = $1 ORDER BY created_at DESC
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE type = $1 ORDER BY created_at DESC
`
func (q *Queries) ListHostsByType(ctx context.Context, type_ string) ([]Host, error) {
@ -440,7 +484,7 @@ func (q *Queries) ListHostsByType(ctx context.Context, type_ string) ([]Host, er
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
&i.MtlsEnabled,
&i.CertExpiresAt,
); err != nil {
return nil, err
}
@ -456,31 +500,44 @@ const markHostTokenUsed = `-- name: MarkHostTokenUsed :exec
UPDATE host_tokens SET used_at = NOW() WHERE id = $1
`
func (q *Queries) MarkHostTokenUsed(ctx context.Context, id string) error {
func (q *Queries) MarkHostTokenUsed(ctx context.Context, id pgtype.UUID) error {
_, err := q.db.Exec(ctx, markHostTokenUsed, id)
return err
}
// markHostUnreachable sets a host's status to 'unreachable' and bumps
// updated_at, regardless of the host's previous status.
const markHostUnreachable = `-- name: MarkHostUnreachable :exec
UPDATE hosts SET status = 'unreachable', updated_at = NOW() WHERE id = $1
`

// MarkHostUnreachable flags the host with the given id as unreachable.
// Only the execution error is reported; a non-existent id is not an error.
func (q *Queries) MarkHostUnreachable(ctx context.Context, id pgtype.UUID) error {
	if _, err := q.db.Exec(ctx, markHostUnreachable, id); err != nil {
		return err
	}
	return nil
}
const registerHost = `-- name: RegisterHost :execrows
UPDATE hosts
SET arch = $2,
cpu_cores = $3,
memory_mb = $4,
disk_gb = $5,
address = $6,
status = 'online',
SET arch = $2,
cpu_cores = $3,
memory_mb = $4,
disk_gb = $5,
address = $6,
cert_fingerprint = $7,
cert_expires_at = $8,
status = 'online',
last_heartbeat_at = NOW(),
updated_at = NOW()
updated_at = NOW()
WHERE id = $1 AND status = 'pending'
`
type RegisterHostParams struct {
ID string `json:"id"`
Arch pgtype.Text `json:"arch"`
CpuCores pgtype.Int4 `json:"cpu_cores"`
MemoryMb pgtype.Int4 `json:"memory_mb"`
DiskGb pgtype.Int4 `json:"disk_gb"`
Address pgtype.Text `json:"address"`
ID pgtype.UUID `json:"id"`
Arch string `json:"arch"`
CpuCores int32 `json:"cpu_cores"`
MemoryMb int32 `json:"memory_mb"`
DiskGb int32 `json:"disk_gb"`
Address string `json:"address"`
CertFingerprint string `json:"cert_fingerprint"`
CertExpiresAt pgtype.Timestamptz `json:"cert_expires_at"`
}
func (q *Queries) RegisterHost(ctx context.Context, arg RegisterHostParams) (int64, error) {
@ -491,6 +548,8 @@ func (q *Queries) RegisterHost(ctx context.Context, arg RegisterHostParams) (int
arg.MemoryMb,
arg.DiskGb,
arg.Address,
arg.CertFingerprint,
arg.CertExpiresAt,
)
if err != nil {
return 0, err
@ -503,8 +562,8 @@ DELETE FROM host_tags WHERE host_id = $1 AND tag = $2
`
type RemoveHostTagParams struct {
HostID string `json:"host_id"`
Tag string `json:"tag"`
HostID pgtype.UUID `json:"host_id"`
Tag string `json:"tag"`
}
func (q *Queries) RemoveHostTag(ctx context.Context, arg RemoveHostTagParams) error {
@ -512,22 +571,59 @@ func (q *Queries) RemoveHostTag(ctx context.Context, arg RemoveHostTagParams) er
return err
}
// updateHostCert records a host's certificate fingerprint and expiry and
// bumps updated_at.
const updateHostCert = `-- name: UpdateHostCert :exec
UPDATE hosts
SET cert_fingerprint = $2,
cert_expires_at = $3,
updated_at = NOW()
WHERE id = $1
`

// UpdateHostCertParams carries the bind values for UpdateHostCert, in
// placeholder order.
type UpdateHostCertParams struct {
	ID              pgtype.UUID        `json:"id"`               // $1
	CertFingerprint string             `json:"cert_fingerprint"` // $2
	CertExpiresAt   pgtype.Timestamptz `json:"cert_expires_at"`  // $3
}

// UpdateHostCert stores the certificate fingerprint and expiry for the host
// identified by arg.ID. Only the execution error is reported.
func (q *Queries) UpdateHostCert(ctx context.Context, arg UpdateHostCertParams) error {
	if _, err := q.db.Exec(ctx, updateHostCert, arg.ID, arg.CertFingerprint, arg.CertExpiresAt); err != nil {
		return err
	}
	return nil
}
const updateHostHeartbeat = `-- name: UpdateHostHeartbeat :exec
UPDATE hosts SET last_heartbeat_at = NOW(), updated_at = NOW() WHERE id = $1
`
func (q *Queries) UpdateHostHeartbeat(ctx context.Context, id string) error {
func (q *Queries) UpdateHostHeartbeat(ctx context.Context, id pgtype.UUID) error {
_, err := q.db.Exec(ctx, updateHostHeartbeat, id)
return err
}
// updateHostHeartbeatAndStatus refreshes last_heartbeat_at and updated_at.
// The CASE expression flips status to 'online' only when it is currently
// 'unreachable'; every other status is left as it is.
const updateHostHeartbeatAndStatus = `-- name: UpdateHostHeartbeatAndStatus :execrows
UPDATE hosts
SET last_heartbeat_at = NOW(),
status = CASE WHEN status = 'unreachable' THEN 'online' ELSE status END,
updated_at = NOW()
WHERE id = $1
`

// UpdateHostHeartbeatAndStatus records a heartbeat for the host and moves an
// unreachable host back to online.
//
// It returns the number of rows affected. A result of 0 means no host with
// that id exists (e.g. it was deleted), which the caller treats as 404.
func (q *Queries) UpdateHostHeartbeatAndStatus(ctx context.Context, id pgtype.UUID) (int64, error) {
	tag, err := q.db.Exec(ctx, updateHostHeartbeatAndStatus, id)
	if err != nil {
		return 0, err
	}
	affected := tag.RowsAffected()
	return affected, nil
}
const updateHostStatus = `-- name: UpdateHostStatus :exec
UPDATE hosts SET status = $2, updated_at = NOW() WHERE id = $1
`
type UpdateHostStatusParams struct {
ID string `json:"id"`
Status string `json:"status"`
ID pgtype.UUID `json:"id"`
Status string `json:"status"`
}
func (q *Queries) UpdateHostStatus(ctx context.Context, arg UpdateHostStatusParams) error {

250
internal/db/metrics.sql.go Normal file
View File

@ -0,0 +1,250 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: metrics.sql
package db
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
// deleteSandboxMetricPoints removes every metric point of one sandbox,
// across all tiers.
const deleteSandboxMetricPoints = `-- name: DeleteSandboxMetricPoints :exec
DELETE FROM sandbox_metric_points
WHERE sandbox_id = $1
`

// DeleteSandboxMetricPoints deletes all stored metric points for sandboxID.
// Only the execution error is reported.
func (q *Queries) DeleteSandboxMetricPoints(ctx context.Context, sandboxID pgtype.UUID) error {
	if _, err := q.db.Exec(ctx, deleteSandboxMetricPoints, sandboxID); err != nil {
		return err
	}
	return nil
}
// deleteSandboxMetricPointsByTier removes the metric points of one sandbox
// that belong to a single tier.
const deleteSandboxMetricPointsByTier = `-- name: DeleteSandboxMetricPointsByTier :exec
DELETE FROM sandbox_metric_points
WHERE sandbox_id = $1 AND tier = $2
`

// DeleteSandboxMetricPointsByTierParams carries the bind values for
// DeleteSandboxMetricPointsByTier, in placeholder order.
type DeleteSandboxMetricPointsByTierParams struct {
	SandboxID pgtype.UUID `json:"sandbox_id"` // $1
	Tier      string      `json:"tier"`       // $2
}

// DeleteSandboxMetricPointsByTier deletes the metric points stored for
// arg.SandboxID in tier arg.Tier. Only the execution error is reported.
func (q *Queries) DeleteSandboxMetricPointsByTier(ctx context.Context, arg DeleteSandboxMetricPointsByTierParams) error {
	if _, err := q.db.Exec(ctx, deleteSandboxMetricPointsByTier, arg.SandboxID, arg.Tier); err != nil {
		return err
	}
	return nil
}
// getLiveMetrics aggregates one team's current usage straight from the
// sandboxes table. All three results are cast to INTEGER and default to 0.
const getLiveMetrics = `-- name: GetLiveMetrics :one
SELECT
(COUNT(*) FILTER (WHERE status IN ('running', 'starting')))::INTEGER AS running_count,
(COALESCE(SUM(vcpus) FILTER (WHERE status IN ('running', 'starting')), 0))::INTEGER AS vcpus_reserved,
(COALESCE(SUM(memory_mb) FILTER (WHERE status IN ('running', 'starting')), 0)
+ COALESCE(SUM(CEIL(memory_mb::NUMERIC / 2)) FILTER (WHERE status = 'paused'), 0))::INTEGER AS memory_mb_reserved
FROM sandboxes
WHERE team_id = $1
`

// GetLiveMetricsRow is the single aggregate row produced by GetLiveMetrics.
type GetLiveMetricsRow struct {
	RunningCount     int32 `json:"running_count"`      // sandboxes in running or starting state
	VcpusReserved    int32 `json:"vcpus_reserved"`     // sum of vcpus over running + starting
	MemoryMbReserved int32 `json:"memory_mb_reserved"` // running + starting memory, plus ceil(half) of each paused sandbox
}

// GetLiveMetrics computes the real-time resource figures for teamID by
// reading the sandboxes table directly.
//
//   - CPU reserved counts running and starting sandboxes only; paused
//     sandboxes contribute nothing.
//   - RAM reserved is the full memory of running and starting sandboxes plus,
//     for every paused sandbox, half its memory rounded up per sandbox.
func (q *Queries) GetLiveMetrics(ctx context.Context, teamID pgtype.UUID) (GetLiveMetricsRow, error) {
	var live GetLiveMetricsRow
	err := q.db.QueryRow(ctx, getLiveMetrics, teamID).Scan(
		&live.RunningCount,
		&live.VcpusReserved,
		&live.MemoryMbReserved,
	)
	return live, err
}
// getPeakMetrics takes the per-column maximum over a team's snapshots sampled
// within the last 30 days. COALESCE turns "no snapshots" into zeros.
const getPeakMetrics = `-- name: GetPeakMetrics :one
SELECT
COALESCE(MAX(running_count), 0)::INTEGER AS peak_running_count,
COALESCE(MAX(vcpus_reserved), 0)::INTEGER AS peak_vcpus,
COALESCE(MAX(memory_mb_reserved), 0)::INTEGER AS peak_memory_mb
FROM sandbox_metrics_snapshots
WHERE team_id = $1
AND sampled_at > NOW() - INTERVAL '30 days'
`

// GetPeakMetricsRow is the single aggregate row produced by GetPeakMetrics.
// Each field is an independent maximum; the three peaks need not come from
// the same snapshot.
type GetPeakMetricsRow struct {
	PeakRunningCount int32 `json:"peak_running_count"`
	PeakVcpus        int32 `json:"peak_vcpus"`
	PeakMemoryMb     int32 `json:"peak_memory_mb"`
}

// GetPeakMetrics returns the 30-day peak running count, vCPU reservation and
// memory reservation recorded in sandbox_metrics_snapshots for teamID.
// All values are 0 when the team has no snapshot in that window.
func (q *Queries) GetPeakMetrics(ctx context.Context, teamID pgtype.UUID) (GetPeakMetricsRow, error) {
	var peak GetPeakMetricsRow
	err := q.db.QueryRow(ctx, getPeakMetrics, teamID).Scan(
		&peak.PeakRunningCount,
		&peak.PeakVcpus,
		&peak.PeakMemoryMb,
	)
	return peak, err
}
// getSandboxMetricPoints reads one sandbox's points in one tier from a lower
// ts bound (inclusive) onwards, oldest first.
const getSandboxMetricPoints = `-- name: GetSandboxMetricPoints :many
SELECT ts, cpu_pct, mem_bytes, disk_bytes
FROM sandbox_metric_points
WHERE sandbox_id = $1 AND tier = $2 AND ts >= $3
ORDER BY ts ASC
`

// GetSandboxMetricPointsParams carries the bind values for
// GetSandboxMetricPoints, in placeholder order.
type GetSandboxMetricPointsParams struct {
	SandboxID pgtype.UUID `json:"sandbox_id"` // $1
	Tier      string      `json:"tier"`       // $2
	Ts        int64       `json:"ts"`         // $3, inclusive lower bound on ts
}

// GetSandboxMetricPointsRow is one sample returned by GetSandboxMetricPoints.
type GetSandboxMetricPointsRow struct {
	Ts        int64   `json:"ts"`
	CpuPct    float64 `json:"cpu_pct"`
	MemBytes  int64   `json:"mem_bytes"`
	DiskBytes int64   `json:"disk_bytes"`
}

// GetSandboxMetricPoints returns the metric points of arg.SandboxID in tier
// arg.Tier whose ts is at least arg.Ts, ordered by ascending ts.
//
// The result is a nil slice when nothing matches. A scan or iteration error
// discards any rows read so far and is returned with a nil slice.
func (q *Queries) GetSandboxMetricPoints(ctx context.Context, arg GetSandboxMetricPointsParams) ([]GetSandboxMetricPointsRow, error) {
	rows, err := q.db.Query(ctx, getSandboxMetricPoints, arg.SandboxID, arg.Tier, arg.Ts)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var points []GetSandboxMetricPointsRow
	for rows.Next() {
		var p GetSandboxMetricPointsRow
		scanErr := rows.Scan(&p.Ts, &p.CpuPct, &p.MemBytes, &p.DiskBytes)
		if scanErr != nil {
			return nil, scanErr
		}
		points = append(points, p)
	}
	// rows.Next returning false can also mean the iteration failed.
	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}
	return points, nil
}
// insertMetricsSnapshot appends one per-team usage snapshot. Only the four
// listed columns are supplied by the statement.
const insertMetricsSnapshot = `-- name: InsertMetricsSnapshot :exec
INSERT INTO sandbox_metrics_snapshots (team_id, running_count, vcpus_reserved, memory_mb_reserved)
VALUES ($1, $2, $3, $4)
`

// InsertMetricsSnapshotParams carries the bind values for
// InsertMetricsSnapshot, in placeholder order.
type InsertMetricsSnapshotParams struct {
	TeamID           pgtype.UUID `json:"team_id"`            // $1
	RunningCount     int32       `json:"running_count"`      // $2
	VcpusReserved    int32       `json:"vcpus_reserved"`     // $3
	MemoryMbReserved int32       `json:"memory_mb_reserved"` // $4
}

// InsertMetricsSnapshot records a usage snapshot for arg.TeamID.
// Only the execution error is reported.
func (q *Queries) InsertMetricsSnapshot(ctx context.Context, arg InsertMetricsSnapshotParams) error {
	if _, err := q.db.Exec(ctx, insertMetricsSnapshot,
		arg.TeamID,
		arg.RunningCount,
		arg.VcpusReserved,
		arg.MemoryMbReserved,
	); err != nil {
		return err
	}
	return nil
}
// insertSandboxMetricPoint stores one metric sample. A row that already
// exists for the same (sandbox_id, tier, ts) is left untouched.
const insertSandboxMetricPoint = `-- name: InsertSandboxMetricPoint :exec
INSERT INTO sandbox_metric_points (sandbox_id, tier, ts, cpu_pct, mem_bytes, disk_bytes)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (sandbox_id, tier, ts) DO NOTHING
`

// InsertSandboxMetricPointParams carries the bind values for
// InsertSandboxMetricPoint, in placeholder order.
type InsertSandboxMetricPointParams struct {
	SandboxID pgtype.UUID `json:"sandbox_id"` // $1
	Tier      string      `json:"tier"`       // $2
	Ts        int64       `json:"ts"`         // $3
	CpuPct    float64     `json:"cpu_pct"`    // $4
	MemBytes  int64       `json:"mem_bytes"`  // $5
	DiskBytes int64       `json:"disk_bytes"` // $6
}

// InsertSandboxMetricPoint writes a single metric sample for a sandbox.
//
// Inserting a duplicate (same sandbox, tier and ts) is a silent no-op because
// of the ON CONFLICT clause; it is not reported as an error.
func (q *Queries) InsertSandboxMetricPoint(ctx context.Context, arg InsertSandboxMetricPointParams) error {
	if _, err := q.db.Exec(ctx, insertSandboxMetricPoint,
		arg.SandboxID,
		arg.Tier,
		arg.Ts,
		arg.CpuPct,
		arg.MemBytes,
		arg.DiskBytes,
	); err != nil {
		return err
	}
	return nil
}
// pruneOldMetrics drops team-level snapshots sampled more than 60 days ago.
const pruneOldMetrics = `-- name: PruneOldMetrics :exec
DELETE FROM sandbox_metrics_snapshots
WHERE sampled_at < NOW() - INTERVAL '60 days'
`

// PruneOldMetrics deletes rows from sandbox_metrics_snapshots whose
// sampled_at is older than 60 days. Only the execution error is reported.
func (q *Queries) PruneOldMetrics(ctx context.Context) error {
	if _, err := q.db.Exec(ctx, pruneOldMetrics); err != nil {
		return err
	}
	return nil
}
// pruneSandboxMetricPoints drops per-sandbox metric points whose ts (compared
// as epoch seconds) is more than 30 days in the past.
const pruneSandboxMetricPoints = `-- name: PruneSandboxMetricPoints :exec
DELETE FROM sandbox_metric_points
WHERE ts < EXTRACT(EPOCH FROM NOW() - INTERVAL '30 days')::BIGINT
`

// PruneSandboxMetricPoints removes metric points older than 30 days.
//
// NOTE(review): the statement filters on ts only. It does not look at the
// sandbox's state, so points of live sandboxes are pruned too, not just those
// of destroyed ones. Confirm that is intended.
func (q *Queries) PruneSandboxMetricPoints(ctx context.Context) error {
	if _, err := q.db.Exec(ctx, pruneSandboxMetricPoints); err != nil {
		return err
	}
	return nil
}
// sampleSandboxMetrics computes the same three aggregates as the live-metrics
// query, but for every team at once (GROUP BY team_id, no WHERE filter).
const sampleSandboxMetrics = `-- name: SampleSandboxMetrics :many
SELECT
team_id,
(COUNT(*) FILTER (WHERE status IN ('running', 'starting')))::INTEGER AS running_count,
(COALESCE(SUM(vcpus) FILTER (WHERE status IN ('running', 'starting')), 0))::INTEGER AS vcpus_reserved,
(COALESCE(SUM(memory_mb) FILTER (WHERE status IN ('running', 'starting')), 0)
+ COALESCE(SUM(CEIL(memory_mb::NUMERIC / 2)) FILTER (WHERE status = 'paused'), 0))::INTEGER AS memory_mb_reserved
FROM sandboxes
GROUP BY team_id
`

// SampleSandboxMetricsRow is one per-team aggregate produced by
// SampleSandboxMetrics.
type SampleSandboxMetricsRow struct {
	TeamID           pgtype.UUID `json:"team_id"`
	RunningCount     int32       `json:"running_count"`      // sandboxes in running or starting state
	VcpusReserved    int32       `json:"vcpus_reserved"`     // sum of vcpus over running + starting
	MemoryMbReserved int32       `json:"memory_mb_reserved"` // running + starting memory, plus ceil(half) of each paused sandbox
}

// SampleSandboxMetrics aggregates per-team resource usage from the live
// sandboxes table.
//
// Every team that owns at least one sandbox row gets a result, whatever the
// sandbox status. Teams whose sandboxes are all stopped therefore yield a
// zero-valued row, which keeps time-series charts continuous instead of
// trailing off into empty space.
//
//   - CPU reserved counts running and starting sandboxes only; paused
//     sandboxes contribute nothing.
//   - RAM reserved is the full memory of running and starting sandboxes plus,
//     for every paused sandbox, half its memory rounded up per sandbox.
//
// A scan or iteration error discards any rows read so far and is returned
// with a nil slice.
func (q *Queries) SampleSandboxMetrics(ctx context.Context) ([]SampleSandboxMetricsRow, error) {
	rows, err := q.db.Query(ctx, sampleSandboxMetrics)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var samples []SampleSandboxMetricsRow
	for rows.Next() {
		var s SampleSandboxMetricsRow
		scanErr := rows.Scan(
			&s.TeamID,
			&s.RunningCount,
			&s.VcpusReserved,
			&s.MemoryMbReserved,
		)
		if scanErr != nil {
			return nil, scanErr
		}
		samples = append(samples, s)
	}
	// rows.Next returning false can also mean the iteration failed.
	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}
	return samples, nil
}

View File

@ -9,42 +9,77 @@ import (
)
type AdminPermission struct {
ID string `json:"id"`
UserID string `json:"user_id"`
ID pgtype.UUID `json:"id"`
UserID pgtype.UUID `json:"user_id"`
Permission string `json:"permission"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
}
// AuditLog is the row model for a single audit-log entry.
// NOTE(review): presumably maps to an audit_logs table; no query using it is
// visible here, so confirm against the schema.
// ActorID and ResourceID are pgtype.Text rather than plain string, unlike the
// other text fields of this struct.
type AuditLog struct {
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
ActorType string `json:"actor_type"`
ActorID pgtype.Text `json:"actor_id"`
ActorName string `json:"actor_name"`
ResourceType string `json:"resource_type"`
ResourceID pgtype.Text `json:"resource_id"`
Action string `json:"action"`
Scope string `json:"scope"`
Status string `json:"status"`
// Metadata is carried as raw bytes; its encoding is not visible here.
Metadata []byte `json:"metadata"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
}
// Channel is the row model for the channels table. Its fields correspond, in
// order, to the columns id, team_id, name, provider, config, event_types,
// created_at and updated_at returned by the channel queries.
type Channel struct {
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
Name string `json:"name"`
Provider string `json:"provider"`
// Config is carried as raw bytes; its encoding is not visible here.
Config []byte `json:"config"`
EventTypes []string `json:"event_types"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
}
type Host struct {
ID string `json:"id"`
ID pgtype.UUID `json:"id"`
Type string `json:"type"`
TeamID pgtype.Text `json:"team_id"`
Provider pgtype.Text `json:"provider"`
AvailabilityZone pgtype.Text `json:"availability_zone"`
Arch pgtype.Text `json:"arch"`
CpuCores pgtype.Int4 `json:"cpu_cores"`
MemoryMb pgtype.Int4 `json:"memory_mb"`
DiskGb pgtype.Int4 `json:"disk_gb"`
Address pgtype.Text `json:"address"`
TeamID pgtype.UUID `json:"team_id"`
Provider string `json:"provider"`
AvailabilityZone string `json:"availability_zone"`
Arch string `json:"arch"`
CpuCores int32 `json:"cpu_cores"`
MemoryMb int32 `json:"memory_mb"`
DiskGb int32 `json:"disk_gb"`
Address string `json:"address"`
Status string `json:"status"`
LastHeartbeatAt pgtype.Timestamptz `json:"last_heartbeat_at"`
Metadata []byte `json:"metadata"`
CreatedBy string `json:"created_by"`
CreatedBy pgtype.UUID `json:"created_by"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
CertFingerprint pgtype.Text `json:"cert_fingerprint"`
MtlsEnabled bool `json:"mtls_enabled"`
CertFingerprint string `json:"cert_fingerprint"`
CertExpiresAt pgtype.Timestamptz `json:"cert_expires_at"`
}
// HostRefreshToken is the row model for the host_refresh_tokens table.
type HostRefreshToken struct {
ID pgtype.UUID `json:"id"`
HostID pgtype.UUID `json:"host_id"`
TokenHash string `json:"token_hash"`
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
// RevokedAt is NULL in the database until the token is revoked; the revoke
// queries set it to NOW().
RevokedAt pgtype.Timestamptz `json:"revoked_at"`
}
type HostTag struct {
HostID string `json:"host_id"`
Tag string `json:"tag"`
HostID pgtype.UUID `json:"host_id"`
Tag string `json:"tag"`
}
type HostToken struct {
ID string `json:"id"`
HostID string `json:"host_id"`
CreatedBy string `json:"created_by"`
ID pgtype.UUID `json:"id"`
HostID pgtype.UUID `json:"host_id"`
CreatedBy pgtype.UUID `json:"created_by"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
UsedAt pgtype.Timestamptz `json:"used_at"`
@ -53,42 +88,65 @@ type HostToken struct {
type OauthProvider struct {
Provider string `json:"provider"`
ProviderID string `json:"provider_id"`
UserID string `json:"user_id"`
UserID pgtype.UUID `json:"user_id"`
Email string `json:"email"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
}
type Sandbox struct {
ID string `json:"id"`
HostID string `json:"host_id"`
Template string `json:"template"`
Status string `json:"status"`
Vcpus int32 `json:"vcpus"`
MemoryMb int32 `json:"memory_mb"`
TimeoutSec int32 `json:"timeout_sec"`
GuestIp string `json:"guest_ip"`
HostIp string `json:"host_ip"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
StartedAt pgtype.Timestamptz `json:"started_at"`
LastActiveAt pgtype.Timestamptz `json:"last_active_at"`
LastUpdated pgtype.Timestamptz `json:"last_updated"`
TeamID string `json:"team_id"`
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
HostID pgtype.UUID `json:"host_id"`
Template string `json:"template"`
Status string `json:"status"`
Vcpus int32 `json:"vcpus"`
MemoryMb int32 `json:"memory_mb"`
TimeoutSec int32 `json:"timeout_sec"`
DiskSizeMb int32 `json:"disk_size_mb"`
GuestIp string `json:"guest_ip"`
HostIp string `json:"host_ip"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
StartedAt pgtype.Timestamptz `json:"started_at"`
LastActiveAt pgtype.Timestamptz `json:"last_active_at"`
LastUpdated pgtype.Timestamptz `json:"last_updated"`
TemplateID pgtype.UUID `json:"template_id"`
TemplateTeamID pgtype.UUID `json:"template_team_id"`
}
// SandboxMetricPoint is the row model for the sandbox_metric_points table:
// one resource sample of one sandbox in one tier.
type SandboxMetricPoint struct {
SandboxID pgtype.UUID `json:"sandbox_id"`
Tier string `json:"tier"`
// Ts is an integer timestamp. NOTE(review): the prune query compares it to
// EXTRACT(EPOCH ...)::BIGINT, which suggests Unix seconds; confirm.
Ts int64 `json:"ts"`
CpuPct float64 `json:"cpu_pct"`
MemBytes int64 `json:"mem_bytes"`
DiskBytes int64 `json:"disk_bytes"`
}
// SandboxMetricsSnapshot is the row model for the sandbox_metrics_snapshots
// table: one per-team usage sample.
type SandboxMetricsSnapshot struct {
// ID and SampledAt are not supplied by the snapshot insert statement.
// NOTE(review): presumably filled by column defaults; confirm in the schema.
ID int64 `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
SampledAt pgtype.Timestamptz `json:"sampled_at"`
RunningCount int32 `json:"running_count"`
VcpusReserved int32 `json:"vcpus_reserved"`
MemoryMbReserved int32 `json:"memory_mb_reserved"`
}
type Team struct {
ID string `json:"id"`
ID pgtype.UUID `json:"id"`
Name string `json:"name"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
Slug string `json:"slug"`
IsByoc bool `json:"is_byoc"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
DeletedAt pgtype.Timestamptz `json:"deleted_at"`
}
type TeamApiKey struct {
ID string `json:"id"`
TeamID string `json:"team_id"`
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
Name string `json:"name"`
KeyHash string `json:"key_hash"`
KeyPrefix string `json:"key_prefix"`
CreatedBy string `json:"created_by"`
CreatedBy pgtype.UUID `json:"created_by"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
LastUsed pgtype.Timestamptz `json:"last_used"`
}
@ -96,25 +154,50 @@ type TeamApiKey struct {
type Template struct {
Name string `json:"name"`
Type string `json:"type"`
Vcpus pgtype.Int4 `json:"vcpus"`
MemoryMb pgtype.Int4 `json:"memory_mb"`
Vcpus int32 `json:"vcpus"`
MemoryMb int32 `json:"memory_mb"`
SizeBytes int64 `json:"size_bytes"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
TeamID string `json:"team_id"`
TeamID pgtype.UUID `json:"team_id"`
ID pgtype.UUID `json:"id"`
}
// TemplateBuild is the row model describing one template build.
// NOTE(review): presumably maps to a template_builds table; no query using it
// is visible here, so confirm field semantics against the schema.
type TemplateBuild struct {
ID pgtype.UUID `json:"id"`
Name string `json:"name"`
BaseTemplate string `json:"base_template"`
// Recipe and Logs are carried as raw bytes; their encoding is not visible here.
Recipe []byte `json:"recipe"`
Healthcheck string `json:"healthcheck"`
Vcpus int32 `json:"vcpus"`
MemoryMb int32 `json:"memory_mb"`
Status string `json:"status"`
CurrentStep int32 `json:"current_step"`
TotalSteps int32 `json:"total_steps"`
Logs []byte `json:"logs"`
Error string `json:"error"`
SandboxID pgtype.UUID `json:"sandbox_id"`
HostID pgtype.UUID `json:"host_id"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
StartedAt pgtype.Timestamptz `json:"started_at"`
CompletedAt pgtype.Timestamptz `json:"completed_at"`
TemplateID pgtype.UUID `json:"template_id"`
TeamID pgtype.UUID `json:"team_id"`
SkipPrePost bool `json:"skip_pre_post"`
}
type User struct {
ID string `json:"id"`
ID pgtype.UUID `json:"id"`
Email string `json:"email"`
PasswordHash pgtype.Text `json:"password_hash"`
Name string `json:"name"`
IsAdmin bool `json:"is_admin"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
IsAdmin bool `json:"is_admin"`
}
type UsersTeam struct {
UserID string `json:"user_id"`
TeamID string `json:"team_id"`
UserID pgtype.UUID `json:"user_id"`
TeamID pgtype.UUID `json:"team_id"`
IsDefault bool `json:"is_default"`
Role string `json:"role"`
CreatedAt pgtype.Timestamptz `json:"created_at"`

View File

@ -7,6 +7,8 @@ package db
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
const getOAuthProvider = `-- name: GetOAuthProvider :one
@ -38,10 +40,10 @@ VALUES ($1, $2, $3, $4)
`
type InsertOAuthProviderParams struct {
Provider string `json:"provider"`
ProviderID string `json:"provider_id"`
UserID string `json:"user_id"`
Email string `json:"email"`
Provider string `json:"provider"`
ProviderID string `json:"provider_id"`
UserID pgtype.UUID `json:"user_id"`
Email string `json:"email"`
}
func (q *Queries) InsertOAuthProvider(ctx context.Context, arg InsertOAuthProviderParams) error {

View File

@ -11,16 +11,30 @@ import (
"github.com/jackc/pgx/v5/pgtype"
)
const bulkRestoreRunning = `-- name: BulkRestoreRunning :exec
UPDATE sandboxes
SET status = 'running',
last_updated = NOW()
WHERE id = ANY($1::uuid[]) AND status = 'missing'
`

// BulkRestoreRunning sets status back to 'running' (and bumps last_updated)
// for every sandbox whose id is in dollar_1 and whose status is 'missing';
// rows in any other state are left untouched.
//
// Called by the reconciler when a host comes back online and its sandboxes
// are confirmed alive.
func (q *Queries) BulkRestoreRunning(ctx context.Context, dollar_1 []pgtype.UUID) error {
	if _, execErr := q.db.Exec(ctx, bulkRestoreRunning, dollar_1); execErr != nil {
		return execErr
	}
	return nil
}
const bulkUpdateStatusByIDs = `-- name: BulkUpdateStatusByIDs :exec
UPDATE sandboxes
SET status = $2,
last_updated = NOW()
WHERE id = ANY($1::text[])
WHERE id = ANY($1::uuid[])
`
type BulkUpdateStatusByIDsParams struct {
Column1 []string `json:"column_1"`
Status string `json:"status"`
Column1 []pgtype.UUID `json:"column_1"`
Status string `json:"status"`
}
func (q *Queries) BulkUpdateStatusByIDs(ctx context.Context, arg BulkUpdateStatusByIDsParams) error {
@ -29,38 +43,41 @@ func (q *Queries) BulkUpdateStatusByIDs(ctx context.Context, arg BulkUpdateStatu
}
const getSandbox = `-- name: GetSandbox :one
SELECT id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id FROM sandboxes WHERE id = $1
SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes WHERE id = $1
`
func (q *Queries) GetSandbox(ctx context.Context, id string) (Sandbox, error) {
func (q *Queries) GetSandbox(ctx context.Context, id pgtype.UUID) (Sandbox, error) {
row := q.db.QueryRow(ctx, getSandbox, id)
var i Sandbox
err := row.Scan(
&i.ID,
&i.TeamID,
&i.HostID,
&i.Template,
&i.Status,
&i.Vcpus,
&i.MemoryMb,
&i.TimeoutSec,
&i.DiskSizeMb,
&i.GuestIp,
&i.HostIp,
&i.CreatedAt,
&i.StartedAt,
&i.LastActiveAt,
&i.LastUpdated,
&i.TeamID,
&i.TemplateID,
&i.TemplateTeamID,
)
return i, err
}
const getSandboxByTeam = `-- name: GetSandboxByTeam :one
SELECT id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id FROM sandboxes WHERE id = $1 AND team_id = $2
SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes WHERE id = $1 AND team_id = $2
`
type GetSandboxByTeamParams struct {
ID string `json:"id"`
TeamID string `json:"team_id"`
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
}
func (q *Queries) GetSandboxByTeam(ctx context.Context, arg GetSandboxByTeamParams) (Sandbox, error) {
@ -68,38 +85,70 @@ func (q *Queries) GetSandboxByTeam(ctx context.Context, arg GetSandboxByTeamPara
var i Sandbox
err := row.Scan(
&i.ID,
&i.TeamID,
&i.HostID,
&i.Template,
&i.Status,
&i.Vcpus,
&i.MemoryMb,
&i.TimeoutSec,
&i.DiskSizeMb,
&i.GuestIp,
&i.HostIp,
&i.CreatedAt,
&i.StartedAt,
&i.LastActiveAt,
&i.LastUpdated,
&i.TeamID,
&i.TemplateID,
&i.TemplateTeamID,
)
return i, err
}
const getSandboxProxyTarget = `-- name: GetSandboxProxyTarget :one
SELECT s.status, h.address AS host_address
FROM sandboxes s
JOIN hosts h ON h.id = s.host_id
WHERE s.id = $1 AND s.team_id = $2
`

// GetSandboxProxyTargetParams identifies the sandbox to resolve. Both the
// sandbox id and the owning team id must match.
type GetSandboxProxyTargetParams struct {
	ID     pgtype.UUID `json:"id"`
	TeamID pgtype.UUID `json:"team_id"`
}

// GetSandboxProxyTargetRow is the joined result: the sandbox's status plus the
// address column of the host it is placed on.
type GetSandboxProxyTargetRow struct {
	Status      string `json:"status"`
	HostAddress string `json:"host_address"`
}

// GetSandboxProxyTarget returns the sandbox status and its host's address in
// one query by joining sandboxes to hosts.
// Used by SandboxProxyWrapper to avoid two round-trips.
//
// The row value and the Scan error are returned together, unchanged.
func (q *Queries) GetSandboxProxyTarget(ctx context.Context, arg GetSandboxProxyTargetParams) (GetSandboxProxyTargetRow, error) {
	var target GetSandboxProxyTargetRow
	scanErr := q.db.QueryRow(ctx, getSandboxProxyTarget, arg.ID, arg.TeamID).Scan(
		&target.Status,
		&target.HostAddress,
	)
	return target, scanErr
}
const insertSandbox = `-- name: InsertSandbox :one
INSERT INTO sandboxes (id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
RETURNING id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id
INSERT INTO sandboxes (id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, template_id, template_team_id)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
RETURNING id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id
`
type InsertSandboxParams struct {
ID string `json:"id"`
TeamID string `json:"team_id"`
HostID string `json:"host_id"`
Template string `json:"template"`
Status string `json:"status"`
Vcpus int32 `json:"vcpus"`
MemoryMb int32 `json:"memory_mb"`
TimeoutSec int32 `json:"timeout_sec"`
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
HostID pgtype.UUID `json:"host_id"`
Template string `json:"template"`
Status string `json:"status"`
Vcpus int32 `json:"vcpus"`
MemoryMb int32 `json:"memory_mb"`
TimeoutSec int32 `json:"timeout_sec"`
DiskSizeMb int32 `json:"disk_size_mb"`
TemplateID pgtype.UUID `json:"template_id"`
TemplateTeamID pgtype.UUID `json:"template_team_id"`
}
func (q *Queries) InsertSandbox(ctx context.Context, arg InsertSandboxParams) (Sandbox, error) {
@ -112,29 +161,79 @@ func (q *Queries) InsertSandbox(ctx context.Context, arg InsertSandboxParams) (S
arg.Vcpus,
arg.MemoryMb,
arg.TimeoutSec,
arg.DiskSizeMb,
arg.TemplateID,
arg.TemplateTeamID,
)
var i Sandbox
err := row.Scan(
&i.ID,
&i.TeamID,
&i.HostID,
&i.Template,
&i.Status,
&i.Vcpus,
&i.MemoryMb,
&i.TimeoutSec,
&i.DiskSizeMb,
&i.GuestIp,
&i.HostIp,
&i.CreatedAt,
&i.StartedAt,
&i.LastActiveAt,
&i.LastUpdated,
&i.TeamID,
&i.TemplateID,
&i.TemplateTeamID,
)
return i, err
}
const listActiveSandboxesByTeam = `-- name: ListActiveSandboxesByTeam :many
SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes
WHERE team_id = $1 AND status IN ('running', 'paused', 'starting')
ORDER BY created_at DESC
`

// ListActiveSandboxesByTeam returns the team's sandboxes whose status is
// 'running', 'paused' or 'starting', newest first (created_at DESC).
//
// A query, scan or iteration error aborts the call and yields (nil, err).
// When no rows match, the returned slice is nil.
func (q *Queries) ListActiveSandboxesByTeam(ctx context.Context, teamID pgtype.UUID) ([]Sandbox, error) {
	rows, queryErr := q.db.Query(ctx, listActiveSandboxesByTeam, teamID)
	if queryErr != nil {
		return nil, queryErr
	}
	defer rows.Close()

	// Deliberately left nil so an empty result stays a nil slice.
	var sandboxes []Sandbox
	for rows.Next() {
		var sb Sandbox
		// Destination order must match the SELECT column list above.
		scanErr := rows.Scan(
			&sb.ID,
			&sb.TeamID,
			&sb.HostID,
			&sb.Template,
			&sb.Status,
			&sb.Vcpus,
			&sb.MemoryMb,
			&sb.TimeoutSec,
			&sb.DiskSizeMb,
			&sb.GuestIp,
			&sb.HostIp,
			&sb.CreatedAt,
			&sb.StartedAt,
			&sb.LastActiveAt,
			&sb.LastUpdated,
			&sb.TemplateID,
			&sb.TemplateTeamID,
		)
		if scanErr != nil {
			return nil, scanErr
		}
		sandboxes = append(sandboxes, sb)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}
	return sandboxes, nil
}
const listSandboxes = `-- name: ListSandboxes :many
SELECT id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id FROM sandboxes ORDER BY created_at DESC
SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes ORDER BY created_at DESC
`
func (q *Queries) ListSandboxes(ctx context.Context) ([]Sandbox, error) {
@ -148,19 +247,22 @@ func (q *Queries) ListSandboxes(ctx context.Context) ([]Sandbox, error) {
var i Sandbox
if err := rows.Scan(
&i.ID,
&i.TeamID,
&i.HostID,
&i.Template,
&i.Status,
&i.Vcpus,
&i.MemoryMb,
&i.TimeoutSec,
&i.DiskSizeMb,
&i.GuestIp,
&i.HostIp,
&i.CreatedAt,
&i.StartedAt,
&i.LastActiveAt,
&i.LastUpdated,
&i.TeamID,
&i.TemplateID,
&i.TemplateTeamID,
); err != nil {
return nil, err
}
@ -173,14 +275,14 @@ func (q *Queries) ListSandboxes(ctx context.Context) ([]Sandbox, error) {
}
const listSandboxesByHostAndStatus = `-- name: ListSandboxesByHostAndStatus :many
SELECT id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id FROM sandboxes
SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes
WHERE host_id = $1 AND status = ANY($2::text[])
ORDER BY created_at DESC
`
type ListSandboxesByHostAndStatusParams struct {
HostID string `json:"host_id"`
Column2 []string `json:"column_2"`
HostID pgtype.UUID `json:"host_id"`
Column2 []string `json:"column_2"`
}
func (q *Queries) ListSandboxesByHostAndStatus(ctx context.Context, arg ListSandboxesByHostAndStatusParams) ([]Sandbox, error) {
@ -194,19 +296,22 @@ func (q *Queries) ListSandboxesByHostAndStatus(ctx context.Context, arg ListSand
var i Sandbox
if err := rows.Scan(
&i.ID,
&i.TeamID,
&i.HostID,
&i.Template,
&i.Status,
&i.Vcpus,
&i.MemoryMb,
&i.TimeoutSec,
&i.DiskSizeMb,
&i.GuestIp,
&i.HostIp,
&i.CreatedAt,
&i.StartedAt,
&i.LastActiveAt,
&i.LastUpdated,
&i.TeamID,
&i.TemplateID,
&i.TemplateTeamID,
); err != nil {
return nil, err
}
@ -219,12 +324,12 @@ func (q *Queries) ListSandboxesByHostAndStatus(ctx context.Context, arg ListSand
}
const listSandboxesByTeam = `-- name: ListSandboxesByTeam :many
SELECT id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id FROM sandboxes
SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes
WHERE team_id = $1 AND status NOT IN ('stopped', 'error')
ORDER BY created_at DESC
`
func (q *Queries) ListSandboxesByTeam(ctx context.Context, teamID string) ([]Sandbox, error) {
func (q *Queries) ListSandboxesByTeam(ctx context.Context, teamID pgtype.UUID) ([]Sandbox, error) {
rows, err := q.db.Query(ctx, listSandboxesByTeam, teamID)
if err != nil {
return nil, err
@ -235,19 +340,22 @@ func (q *Queries) ListSandboxesByTeam(ctx context.Context, teamID string) ([]San
var i Sandbox
if err := rows.Scan(
&i.ID,
&i.TeamID,
&i.HostID,
&i.Template,
&i.Status,
&i.Vcpus,
&i.MemoryMb,
&i.TimeoutSec,
&i.DiskSizeMb,
&i.GuestIp,
&i.HostIp,
&i.CreatedAt,
&i.StartedAt,
&i.LastActiveAt,
&i.LastUpdated,
&i.TeamID,
&i.TemplateID,
&i.TemplateTeamID,
); err != nil {
return nil, err
}
@ -259,6 +367,21 @@ func (q *Queries) ListSandboxesByTeam(ctx context.Context, teamID string) ([]San
return items, nil
}
const markSandboxesMissingByHost = `-- name: MarkSandboxesMissingByHost :exec
UPDATE sandboxes
SET status = 'missing',
last_updated = NOW()
WHERE host_id = $1 AND status IN ('running', 'starting', 'pending')
`

// MarkSandboxesMissingByHost sets status to 'missing' (and bumps last_updated)
// for the sandboxes on hostID that are currently 'running', 'starting' or
// 'pending'.
//
// Called when the host monitor marks a host unreachable, so users see that the
// sandbox is not currently reachable without the record being permanently lost.
func (q *Queries) MarkSandboxesMissingByHost(ctx context.Context, hostID pgtype.UUID) error {
	if _, execErr := q.db.Exec(ctx, markSandboxesMissingByHost, hostID); execErr != nil {
		return execErr
	}
	return nil
}
const updateLastActive = `-- name: UpdateLastActive :exec
UPDATE sandboxes
SET last_active_at = $2,
@ -267,7 +390,7 @@ WHERE id = $1
`
type UpdateLastActiveParams struct {
ID string `json:"id"`
ID pgtype.UUID `json:"id"`
LastActiveAt pgtype.Timestamptz `json:"last_active_at"`
}
@ -285,11 +408,11 @@ SET status = 'running',
last_active_at = $4,
last_updated = NOW()
WHERE id = $1
RETURNING id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id
RETURNING id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id
`
type UpdateSandboxRunningParams struct {
ID string `json:"id"`
ID pgtype.UUID `json:"id"`
HostIp string `json:"host_ip"`
GuestIp string `json:"guest_ip"`
StartedAt pgtype.Timestamptz `json:"started_at"`
@ -305,19 +428,22 @@ func (q *Queries) UpdateSandboxRunning(ctx context.Context, arg UpdateSandboxRun
var i Sandbox
err := row.Scan(
&i.ID,
&i.TeamID,
&i.HostID,
&i.Template,
&i.Status,
&i.Vcpus,
&i.MemoryMb,
&i.TimeoutSec,
&i.DiskSizeMb,
&i.GuestIp,
&i.HostIp,
&i.CreatedAt,
&i.StartedAt,
&i.LastActiveAt,
&i.LastUpdated,
&i.TeamID,
&i.TemplateID,
&i.TemplateTeamID,
)
return i, err
}
@ -327,12 +453,12 @@ UPDATE sandboxes
SET status = $2,
last_updated = NOW()
WHERE id = $1
RETURNING id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id
RETURNING id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id
`
type UpdateSandboxStatusParams struct {
ID string `json:"id"`
Status string `json:"status"`
ID pgtype.UUID `json:"id"`
Status string `json:"status"`
}
func (q *Queries) UpdateSandboxStatus(ctx context.Context, arg UpdateSandboxStatusParams) (Sandbox, error) {
@ -340,19 +466,22 @@ func (q *Queries) UpdateSandboxStatus(ctx context.Context, arg UpdateSandboxStat
var i Sandbox
err := row.Scan(
&i.ID,
&i.TeamID,
&i.HostID,
&i.Template,
&i.Status,
&i.Vcpus,
&i.MemoryMb,
&i.TimeoutSec,
&i.DiskSizeMb,
&i.GuestIp,
&i.HostIp,
&i.CreatedAt,
&i.StartedAt,
&i.LastActiveAt,
&i.LastUpdated,
&i.TeamID,
&i.TemplateID,
&i.TemplateTeamID,
)
return i, err
}

View File

@ -7,10 +7,26 @@ package db
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
const deleteTeamMember = `-- name: DeleteTeamMember :exec
DELETE FROM users_teams WHERE team_id = $1 AND user_id = $2
`

// DeleteTeamMemberParams names the membership row to remove: the
// (team_id, user_id) pair in users_teams.
type DeleteTeamMemberParams struct {
	TeamID pgtype.UUID `json:"team_id"`
	UserID pgtype.UUID `json:"user_id"`
}

// DeleteTeamMember deletes the users_teams row matching arg.TeamID and
// arg.UserID. Only the Exec error is reported; the command tag is discarded,
// so a delete that matches no row is not distinguished here.
func (q *Queries) DeleteTeamMember(ctx context.Context, arg DeleteTeamMemberParams) error {
	if _, execErr := q.db.Exec(ctx, deleteTeamMember, arg.TeamID, arg.UserID); execErr != nil {
		return execErr
	}
	return nil
}
const getBYOCTeams = `-- name: GetBYOCTeams :many
SELECT id, name, created_at, is_byoc FROM teams WHERE is_byoc = TRUE ORDER BY created_at
SELECT id, name, slug, is_byoc, created_at, deleted_at FROM teams WHERE is_byoc = TRUE AND deleted_at IS NULL ORDER BY created_at
`
func (q *Queries) GetBYOCTeams(ctx context.Context) ([]Team, error) {
@ -25,8 +41,10 @@ func (q *Queries) GetBYOCTeams(ctx context.Context) ([]Team, error) {
if err := rows.Scan(
&i.ID,
&i.Name,
&i.CreatedAt,
&i.Slug,
&i.IsByoc,
&i.CreatedAt,
&i.DeletedAt,
); err != nil {
return nil, err
}
@ -39,47 +57,111 @@ func (q *Queries) GetBYOCTeams(ctx context.Context) ([]Team, error) {
}
const getDefaultTeamForUser = `-- name: GetDefaultTeamForUser :one
SELECT t.id, t.name, t.created_at, t.is_byoc FROM teams t
SELECT t.id, t.name, t.slug, t.is_byoc, t.created_at, t.deleted_at FROM teams t
JOIN users_teams ut ON ut.team_id = t.id
WHERE ut.user_id = $1 AND ut.is_default = TRUE
WHERE ut.user_id = $1 AND ut.is_default = TRUE AND t.deleted_at IS NULL
LIMIT 1
`
func (q *Queries) GetDefaultTeamForUser(ctx context.Context, userID string) (Team, error) {
func (q *Queries) GetDefaultTeamForUser(ctx context.Context, userID pgtype.UUID) (Team, error) {
row := q.db.QueryRow(ctx, getDefaultTeamForUser, userID)
var i Team
err := row.Scan(
&i.ID,
&i.Name,
&i.CreatedAt,
&i.Slug,
&i.IsByoc,
&i.CreatedAt,
&i.DeletedAt,
)
return i, err
}
const getTeam = `-- name: GetTeam :one
SELECT id, name, created_at, is_byoc FROM teams WHERE id = $1
SELECT id, name, slug, is_byoc, created_at, deleted_at FROM teams WHERE id = $1
`
func (q *Queries) GetTeam(ctx context.Context, id string) (Team, error) {
func (q *Queries) GetTeam(ctx context.Context, id pgtype.UUID) (Team, error) {
row := q.db.QueryRow(ctx, getTeam, id)
var i Team
err := row.Scan(
&i.ID,
&i.Name,
&i.CreatedAt,
&i.Slug,
&i.IsByoc,
&i.CreatedAt,
&i.DeletedAt,
)
return i, err
}
const getTeamBySlug = `-- name: GetTeamBySlug :one
SELECT id, name, slug, is_byoc, created_at, deleted_at FROM teams WHERE slug = $1 AND deleted_at IS NULL
`

// GetTeamBySlug looks up a team by its slug, considering only rows that have
// not been soft-deleted (deleted_at IS NULL).
//
// The Team value and the Scan error are returned together, unchanged.
func (q *Queries) GetTeamBySlug(ctx context.Context, slug string) (Team, error) {
	var team Team
	// Destination order must match the SELECT column list above.
	scanErr := q.db.QueryRow(ctx, getTeamBySlug, slug).Scan(
		&team.ID,
		&team.Name,
		&team.Slug,
		&team.IsByoc,
		&team.CreatedAt,
		&team.DeletedAt,
	)
	return team, scanErr
}
const getTeamMembers = `-- name: GetTeamMembers :many
SELECT u.id, u.name, u.email, ut.role, ut.created_at AS joined_at
FROM users_teams ut
JOIN users u ON u.id = ut.user_id
WHERE ut.team_id = $1
ORDER BY ut.created_at
`

// GetTeamMembersRow is one member of a team: the user's id, name and email
// joined with their role and the time the membership row was created.
type GetTeamMembersRow struct {
	ID       pgtype.UUID        `json:"id"`
	Name     string             `json:"name"`
	Email    string             `json:"email"`
	Role     string             `json:"role"`
	JoinedAt pgtype.Timestamptz `json:"joined_at"`
}

// GetTeamMembers lists the users belonging to teamID, ordered by when each
// membership was created (oldest first).
//
// A query, scan or iteration error aborts the call and yields (nil, err).
// When the team has no members, the returned slice is nil.
func (q *Queries) GetTeamMembers(ctx context.Context, teamID pgtype.UUID) ([]GetTeamMembersRow, error) {
	rows, queryErr := q.db.Query(ctx, getTeamMembers, teamID)
	if queryErr != nil {
		return nil, queryErr
	}
	defer rows.Close()

	// Deliberately left nil so an empty result stays a nil slice.
	var members []GetTeamMembersRow
	for rows.Next() {
		var member GetTeamMembersRow
		scanErr := rows.Scan(
			&member.ID,
			&member.Name,
			&member.Email,
			&member.Role,
			&member.JoinedAt,
		)
		if scanErr != nil {
			return nil, scanErr
		}
		members = append(members, member)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}
	return members, nil
}
const getTeamMembership = `-- name: GetTeamMembership :one
SELECT user_id, team_id, is_default, role, created_at FROM users_teams WHERE user_id = $1 AND team_id = $2
`
type GetTeamMembershipParams struct {
UserID string `json:"user_id"`
TeamID string `json:"team_id"`
UserID pgtype.UUID `json:"user_id"`
TeamID pgtype.UUID `json:"team_id"`
}
func (q *Queries) GetTeamMembership(ctx context.Context, arg GetTeamMembershipParams) (UsersTeam, error) {
@ -95,25 +177,74 @@ func (q *Queries) GetTeamMembership(ctx context.Context, arg GetTeamMembershipPa
return i, err
}
const getTeamsForUser = `-- name: GetTeamsForUser :many
SELECT t.id, t.name, t.slug, t.is_byoc, t.created_at, t.deleted_at, ut.role
FROM teams t
JOIN users_teams ut ON ut.team_id = t.id
WHERE ut.user_id = $1 AND t.deleted_at IS NULL
ORDER BY ut.created_at
`

// GetTeamsForUserRow is a team's columns plus the role the queried user holds
// in that team.
type GetTeamsForUserRow struct {
	ID        pgtype.UUID        `json:"id"`
	Name      string             `json:"name"`
	Slug      string             `json:"slug"`
	IsByoc    bool               `json:"is_byoc"`
	CreatedAt pgtype.Timestamptz `json:"created_at"`
	DeletedAt pgtype.Timestamptz `json:"deleted_at"`
	Role      string             `json:"role"`
}

// GetTeamsForUser lists every non-deleted team userID is a member of, ordered
// by when each membership was created (oldest first).
//
// A query, scan or iteration error aborts the call and yields (nil, err).
// When the user has no teams, the returned slice is nil.
func (q *Queries) GetTeamsForUser(ctx context.Context, userID pgtype.UUID) ([]GetTeamsForUserRow, error) {
	rows, queryErr := q.db.Query(ctx, getTeamsForUser, userID)
	if queryErr != nil {
		return nil, queryErr
	}
	defer rows.Close()

	// Deliberately left nil so an empty result stays a nil slice.
	var teams []GetTeamsForUserRow
	for rows.Next() {
		var team GetTeamsForUserRow
		scanErr := rows.Scan(
			&team.ID,
			&team.Name,
			&team.Slug,
			&team.IsByoc,
			&team.CreatedAt,
			&team.DeletedAt,
			&team.Role,
		)
		if scanErr != nil {
			return nil, scanErr
		}
		teams = append(teams, team)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}
	return teams, nil
}
const insertTeam = `-- name: InsertTeam :one
INSERT INTO teams (id, name)
VALUES ($1, $2)
RETURNING id, name, created_at, is_byoc
INSERT INTO teams (id, name, slug)
VALUES ($1, $2, $3)
RETURNING id, name, slug, is_byoc, created_at, deleted_at
`
type InsertTeamParams struct {
ID string `json:"id"`
Name string `json:"name"`
ID pgtype.UUID `json:"id"`
Name string `json:"name"`
Slug string `json:"slug"`
}
func (q *Queries) InsertTeam(ctx context.Context, arg InsertTeamParams) (Team, error) {
row := q.db.QueryRow(ctx, insertTeam, arg.ID, arg.Name)
row := q.db.QueryRow(ctx, insertTeam, arg.ID, arg.Name, arg.Slug)
var i Team
err := row.Scan(
&i.ID,
&i.Name,
&i.CreatedAt,
&i.Slug,
&i.IsByoc,
&i.CreatedAt,
&i.DeletedAt,
)
return i, err
}
@ -124,10 +255,10 @@ VALUES ($1, $2, $3, $4)
`
type InsertTeamMemberParams struct {
UserID string `json:"user_id"`
TeamID string `json:"team_id"`
IsDefault bool `json:"is_default"`
Role string `json:"role"`
UserID pgtype.UUID `json:"user_id"`
TeamID pgtype.UUID `json:"team_id"`
IsDefault bool `json:"is_default"`
Role string `json:"role"`
}
func (q *Queries) InsertTeamMember(ctx context.Context, arg InsertTeamMemberParams) error {
@ -145,11 +276,49 @@ UPDATE teams SET is_byoc = $2 WHERE id = $1
`
type SetTeamBYOCParams struct {
ID string `json:"id"`
IsByoc bool `json:"is_byoc"`
ID pgtype.UUID `json:"id"`
IsByoc bool `json:"is_byoc"`
}
// SetTeamBYOC runs the setTeamBYOC statement with arg.ID and arg.IsByoc and
// reports only the Exec error; the command tag is discarded.
func (q *Queries) SetTeamBYOC(ctx context.Context, arg SetTeamBYOCParams) error {
	if _, execErr := q.db.Exec(ctx, setTeamBYOC, arg.ID, arg.IsByoc); execErr != nil {
		return execErr
	}
	return nil
}
const softDeleteTeam = `-- name: SoftDeleteTeam :exec
UPDATE teams SET deleted_at = NOW() WHERE id = $1
`

// SoftDeleteTeam stamps deleted_at with the current time on the team row with
// the given id instead of removing the row. The statement does not check the
// previous value of deleted_at, so calling it again overwrites the timestamp.
func (q *Queries) SoftDeleteTeam(ctx context.Context, id pgtype.UUID) error {
	if _, execErr := q.db.Exec(ctx, softDeleteTeam, id); execErr != nil {
		return execErr
	}
	return nil
}
const updateMemberRole = `-- name: UpdateMemberRole :exec
UPDATE users_teams SET role = $3 WHERE team_id = $1 AND user_id = $2
`

// UpdateMemberRoleParams selects a membership row by (team_id, user_id) and
// carries the role value to store on it.
type UpdateMemberRoleParams struct {
	TeamID pgtype.UUID `json:"team_id"`
	UserID pgtype.UUID `json:"user_id"`
	Role   string      `json:"role"`
}

// UpdateMemberRole overwrites the role column of the users_teams row matching
// arg.TeamID and arg.UserID. Only the Exec error is reported; an update that
// matches no row is not distinguished here.
func (q *Queries) UpdateMemberRole(ctx context.Context, arg UpdateMemberRoleParams) error {
	if _, execErr := q.db.Exec(ctx, updateMemberRole, arg.TeamID, arg.UserID, arg.Role); execErr != nil {
		return execErr
	}
	return nil
}
const updateTeamName = `-- name: UpdateTeamName :exec
UPDATE teams SET name = $2 WHERE id = $1 AND deleted_at IS NULL
`

// UpdateTeamNameParams carries the team id and the new name to store.
type UpdateTeamNameParams struct {
	ID   pgtype.UUID `json:"id"`
	Name string      `json:"name"`
}

// UpdateTeamName renames the team identified by arg.ID. Soft-deleted teams
// (deleted_at set) are excluded by the WHERE clause and are left unchanged.
// Only the Exec error is reported.
func (q *Queries) UpdateTeamName(ctx context.Context, arg UpdateTeamNameParams) error {
	if _, execErr := q.db.Exec(ctx, updateTeamName, arg.ID, arg.Name); execErr != nil {
		return execErr
	}
	return nil
}

View File

@ -0,0 +1,241 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: template_builds.sql
package db
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
const getTemplateBuild = `-- name: GetTemplateBuild :one
SELECT id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at, template_id, team_id, skip_pre_post FROM template_builds WHERE id = $1
`

// GetTemplateBuild loads the template_builds row with the given id.
//
// The TemplateBuild value and the Scan error are returned together, unchanged.
func (q *Queries) GetTemplateBuild(ctx context.Context, id pgtype.UUID) (TemplateBuild, error) {
	var build TemplateBuild
	// Destination order must match the SELECT column list above.
	scanErr := q.db.QueryRow(ctx, getTemplateBuild, id).Scan(
		&build.ID,
		&build.Name,
		&build.BaseTemplate,
		&build.Recipe,
		&build.Healthcheck,
		&build.Vcpus,
		&build.MemoryMb,
		&build.Status,
		&build.CurrentStep,
		&build.TotalSteps,
		&build.Logs,
		&build.Error,
		&build.SandboxID,
		&build.HostID,
		&build.CreatedAt,
		&build.StartedAt,
		&build.CompletedAt,
		&build.TemplateID,
		&build.TeamID,
		&build.SkipPrePost,
	)
	return build, scanErr
}
const insertTemplateBuild = `-- name: InsertTemplateBuild :one
INSERT INTO template_builds (id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, total_steps, template_id, team_id, skip_pre_post)
VALUES ($1, $2, $3, $4, $5, $6, $7, 'pending', $8, $9, $10, $11)
RETURNING id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at, template_id, team_id, skip_pre_post
`

// InsertTemplateBuildParams holds the caller-supplied columns of a new build.
// Status is not among them: the statement always inserts 'pending'.
type InsertTemplateBuildParams struct {
	ID           pgtype.UUID `json:"id"`
	Name         string      `json:"name"`
	BaseTemplate string      `json:"base_template"`
	Recipe       []byte      `json:"recipe"`
	Healthcheck  string      `json:"healthcheck"`
	Vcpus        int32       `json:"vcpus"`
	MemoryMb     int32       `json:"memory_mb"`
	TotalSteps   int32       `json:"total_steps"`
	TemplateID   pgtype.UUID `json:"template_id"`
	TeamID       pgtype.UUID `json:"team_id"`
	SkipPrePost  bool        `json:"skip_pre_post"`
}

// InsertTemplateBuild creates a template_builds row in status 'pending' from
// arg and returns the full row as stored (RETURNING), including the columns
// the statement does not set explicitly.
//
// The TemplateBuild value and the Scan error are returned together, unchanged.
func (q *Queries) InsertTemplateBuild(ctx context.Context, arg InsertTemplateBuildParams) (TemplateBuild, error) {
	// Argument order must match the $1..$11 placeholders in the statement.
	inserted := q.db.QueryRow(ctx, insertTemplateBuild,
		arg.ID,
		arg.Name,
		arg.BaseTemplate,
		arg.Recipe,
		arg.Healthcheck,
		arg.Vcpus,
		arg.MemoryMb,
		arg.TotalSteps,
		arg.TemplateID,
		arg.TeamID,
		arg.SkipPrePost,
	)

	var build TemplateBuild
	// Destination order must match the RETURNING column list.
	scanErr := inserted.Scan(
		&build.ID,
		&build.Name,
		&build.BaseTemplate,
		&build.Recipe,
		&build.Healthcheck,
		&build.Vcpus,
		&build.MemoryMb,
		&build.Status,
		&build.CurrentStep,
		&build.TotalSteps,
		&build.Logs,
		&build.Error,
		&build.SandboxID,
		&build.HostID,
		&build.CreatedAt,
		&build.StartedAt,
		&build.CompletedAt,
		&build.TemplateID,
		&build.TeamID,
		&build.SkipPrePost,
	)
	return build, scanErr
}
const listTemplateBuilds = `-- name: ListTemplateBuilds :many
SELECT id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at, template_id, team_id, skip_pre_post FROM template_builds ORDER BY created_at DESC
`

// ListTemplateBuilds returns every template_builds row, newest first
// (created_at DESC). The query has no filter and no limit.
//
// A query, scan or iteration error aborts the call and yields (nil, err).
// When the table is empty, the returned slice is nil.
func (q *Queries) ListTemplateBuilds(ctx context.Context) ([]TemplateBuild, error) {
	rows, queryErr := q.db.Query(ctx, listTemplateBuilds)
	if queryErr != nil {
		return nil, queryErr
	}
	defer rows.Close()

	// Deliberately left nil so an empty result stays a nil slice.
	var builds []TemplateBuild
	for rows.Next() {
		var build TemplateBuild
		// Destination order must match the SELECT column list above.
		scanErr := rows.Scan(
			&build.ID,
			&build.Name,
			&build.BaseTemplate,
			&build.Recipe,
			&build.Healthcheck,
			&build.Vcpus,
			&build.MemoryMb,
			&build.Status,
			&build.CurrentStep,
			&build.TotalSteps,
			&build.Logs,
			&build.Error,
			&build.SandboxID,
			&build.HostID,
			&build.CreatedAt,
			&build.StartedAt,
			&build.CompletedAt,
			&build.TemplateID,
			&build.TeamID,
			&build.SkipPrePost,
		)
		if scanErr != nil {
			return nil, scanErr
		}
		builds = append(builds, build)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}
	return builds, nil
}
const updateBuildError = `-- name: UpdateBuildError :exec
UPDATE template_builds
SET error = $2, status = 'failed', completed_at = NOW()
WHERE id = $1
`

// UpdateBuildErrorParams carries the build id and the error text to record.
type UpdateBuildErrorParams struct {
	ID    pgtype.UUID `json:"id"`
	Error string      `json:"error"`
}

// UpdateBuildError records a failure on the build identified by arg.ID: it
// stores arg.Error, forces status to 'failed' and stamps completed_at with the
// current time, all in one statement. Only the Exec error is reported.
func (q *Queries) UpdateBuildError(ctx context.Context, arg UpdateBuildErrorParams) error {
	if _, execErr := q.db.Exec(ctx, updateBuildError, arg.ID, arg.Error); execErr != nil {
		return execErr
	}
	return nil
}
const updateBuildProgress = `-- name: UpdateBuildProgress :exec
UPDATE template_builds
SET current_step = $2, logs = $3
WHERE id = $1
`

// UpdateBuildProgressParams carries the build id together with the step
// counter and log payload to store.
type UpdateBuildProgressParams struct {
	ID          pgtype.UUID `json:"id"`
	CurrentStep int32       `json:"current_step"`
	Logs        []byte      `json:"logs"`
}

// UpdateBuildProgress overwrites current_step and logs on the build identified
// by arg.ID. The logs column is replaced with arg.Logs as given, not appended
// to. Only the Exec error is reported.
func (q *Queries) UpdateBuildProgress(ctx context.Context, arg UpdateBuildProgressParams) error {
	if _, execErr := q.db.Exec(ctx, updateBuildProgress, arg.ID, arg.CurrentStep, arg.Logs); execErr != nil {
		return execErr
	}
	return nil
}
const updateBuildSandbox = `-- name: UpdateBuildSandbox :exec
UPDATE template_builds
SET sandbox_id = $2, host_id = $3
WHERE id = $1
`

// UpdateBuildSandboxParams carries the build id plus the sandbox and host ids
// to associate with it.
type UpdateBuildSandboxParams struct {
	ID        pgtype.UUID `json:"id"`
	SandboxID pgtype.UUID `json:"sandbox_id"`
	HostID    pgtype.UUID `json:"host_id"`
}

// UpdateBuildSandbox sets sandbox_id and host_id on the build identified by
// arg.ID. Only the Exec error is reported.
func (q *Queries) UpdateBuildSandbox(ctx context.Context, arg UpdateBuildSandboxParams) error {
	if _, execErr := q.db.Exec(ctx, updateBuildSandbox, arg.ID, arg.SandboxID, arg.HostID); execErr != nil {
		return execErr
	}
	return nil
}
const updateBuildStatus = `-- name: UpdateBuildStatus :one
UPDATE template_builds
SET status = $2,
started_at = CASE WHEN $2 = 'running' AND started_at IS NULL THEN NOW() ELSE started_at END,
completed_at = CASE WHEN $2 IN ('success', 'failed', 'cancelled') THEN NOW() ELSE completed_at END
WHERE id = $1
RETURNING id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at, template_id, team_id, skip_pre_post
`

// UpdateBuildStatusParams carries the build id and the status to move it to.
type UpdateBuildStatusParams struct {
	ID     pgtype.UUID `json:"id"`
	Status string      `json:"status"`
}

// UpdateBuildStatus sets the status of the build identified by arg.ID and
// returns the updated row. The statement also maintains the timestamps:
//
//   - started_at is set to NOW() when the new status is 'running' and
//     started_at is still NULL; otherwise it is kept.
//   - completed_at is set to NOW() when the new status is 'success', 'failed'
//     or 'cancelled'; otherwise it is kept.
//
// The TemplateBuild value and the Scan error are returned together, unchanged.
func (q *Queries) UpdateBuildStatus(ctx context.Context, arg UpdateBuildStatusParams) (TemplateBuild, error) {
	var build TemplateBuild
	// Destination order must match the RETURNING column list.
	scanErr := q.db.QueryRow(ctx, updateBuildStatus, arg.ID, arg.Status).Scan(
		&build.ID,
		&build.Name,
		&build.BaseTemplate,
		&build.Recipe,
		&build.Healthcheck,
		&build.Vcpus,
		&build.MemoryMb,
		&build.Status,
		&build.CurrentStep,
		&build.TotalSteps,
		&build.Logs,
		&build.Error,
		&build.SandboxID,
		&build.HostID,
		&build.CreatedAt,
		&build.StartedAt,
		&build.CompletedAt,
		&build.TemplateID,
		&build.TeamID,
		&build.SkipPrePost,
	)
	return build, scanErr
}

View File

@ -12,11 +12,11 @@ import (
)
const deleteTemplate = `-- name: DeleteTemplate :exec
DELETE FROM templates WHERE name = $1
DELETE FROM templates WHERE id = $1
`
func (q *Queries) DeleteTemplate(ctx context.Context, name string) error {
_, err := q.db.Exec(ctx, deleteTemplate, name)
func (q *Queries) DeleteTemplate(ctx context.Context, id pgtype.UUID) error {
_, err := q.db.Exec(ctx, deleteTemplate, id)
return err
}
@ -25,8 +25,8 @@ DELETE FROM templates WHERE name = $1 AND team_id = $2
`
type DeleteTemplateByTeamParams struct {
Name string `json:"name"`
TeamID string `json:"team_id"`
Name string `json:"name"`
TeamID pgtype.UUID `json:"team_id"`
}
func (q *Queries) DeleteTemplateByTeam(ctx context.Context, arg DeleteTemplateByTeamParams) error {
@ -34,12 +34,23 @@ func (q *Queries) DeleteTemplateByTeam(ctx context.Context, arg DeleteTemplateBy
return err
}
const getTemplate = `-- name: GetTemplate :one
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE name = $1
const deleteTemplatesByTeam = `-- name: DeleteTemplatesByTeam :exec
DELETE FROM templates WHERE team_id = $1
`
func (q *Queries) GetTemplate(ctx context.Context, name string) (Template, error) {
row := q.db.QueryRow(ctx, getTemplate, name)
// Bulk delete all templates owned by a team (for team soft-delete cleanup).
func (q *Queries) DeleteTemplatesByTeam(ctx context.Context, teamID pgtype.UUID) error {
_, err := q.db.Exec(ctx, deleteTemplatesByTeam, teamID)
return err
}
const getPlatformTemplateByName = `-- name: GetPlatformTemplateByName :one
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE team_id = '00000000-0000-0000-0000-000000000000' AND name = $1
`
// Check if a global (platform) template exists with the given name.
func (q *Queries) GetPlatformTemplateByName(ctx context.Context, name string) (Template, error) {
row := q.db.QueryRow(ctx, getPlatformTemplateByName, name)
var i Template
err := row.Scan(
&i.Name,
@ -49,19 +60,67 @@ func (q *Queries) GetTemplate(ctx context.Context, name string) (Template, error
&i.SizeBytes,
&i.CreatedAt,
&i.TeamID,
&i.ID,
)
return i, err
}
const getTemplate = `-- name: GetTemplate :one
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE id = $1
`

// GetTemplate loads the templates row with the given id. The lookup is by id
// only; no team filter is applied.
//
// The Template value and the Scan error are returned together, unchanged.
func (q *Queries) GetTemplate(ctx context.Context, id pgtype.UUID) (Template, error) {
	var tmpl Template
	// Destination order must match the SELECT column list (id comes last).
	scanErr := q.db.QueryRow(ctx, getTemplate, id).Scan(
		&tmpl.Name,
		&tmpl.Type,
		&tmpl.Vcpus,
		&tmpl.MemoryMb,
		&tmpl.SizeBytes,
		&tmpl.CreatedAt,
		&tmpl.TeamID,
		&tmpl.ID,
	)
	return tmpl, scanErr
}
const getTemplateByName = `-- name: GetTemplateByName :one
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE team_id = $1 AND name = $2
`

// GetTemplateByNameParams identifies a template by its owning team and name.
type GetTemplateByNameParams struct {
	TeamID pgtype.UUID `json:"team_id"`
	Name   string      `json:"name"`
}

// GetTemplateByName looks up a template by team_id and name. The team must
// match exactly; there is no fallback to global templates.
//
// The Template value and the Scan error are returned together, unchanged.
func (q *Queries) GetTemplateByName(ctx context.Context, arg GetTemplateByNameParams) (Template, error) {
	var tmpl Template
	// Destination order must match the SELECT column list (id comes last).
	scanErr := q.db.QueryRow(ctx, getTemplateByName, arg.TeamID, arg.Name).Scan(
		&tmpl.Name,
		&tmpl.Type,
		&tmpl.Vcpus,
		&tmpl.MemoryMb,
		&tmpl.SizeBytes,
		&tmpl.CreatedAt,
		&tmpl.TeamID,
		&tmpl.ID,
	)
	return tmpl, scanErr
}
const getTemplateByTeam = `-- name: GetTemplateByTeam :one
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE name = $1 AND team_id = $2
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE name = $1 AND (team_id = $2 OR team_id = '00000000-0000-0000-0000-000000000000')
`
type GetTemplateByTeamParams struct {
Name string `json:"name"`
TeamID string `json:"team_id"`
Name string `json:"name"`
TeamID pgtype.UUID `json:"team_id"`
}
// Platform templates (team_id = 00000000-...) are visible to all teams.
func (q *Queries) GetTemplateByTeam(ctx context.Context, arg GetTemplateByTeamParams) (Template, error) {
row := q.db.QueryRow(ctx, getTemplateByTeam, arg.Name, arg.TeamID)
var i Template
@ -73,27 +132,30 @@ func (q *Queries) GetTemplateByTeam(ctx context.Context, arg GetTemplateByTeamPa
&i.SizeBytes,
&i.CreatedAt,
&i.TeamID,
&i.ID,
)
return i, err
}
const insertTemplate = `-- name: InsertTemplate :one
INSERT INTO templates (name, type, vcpus, memory_mb, size_bytes, team_id)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING name, type, vcpus, memory_mb, size_bytes, created_at, team_id
INSERT INTO templates (id, name, type, vcpus, memory_mb, size_bytes, team_id)
VALUES ($1, $2, $3, $4, $5, $6, $7)
RETURNING name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id
`
type InsertTemplateParams struct {
ID pgtype.UUID `json:"id"`
Name string `json:"name"`
Type string `json:"type"`
Vcpus pgtype.Int4 `json:"vcpus"`
MemoryMb pgtype.Int4 `json:"memory_mb"`
Vcpus int32 `json:"vcpus"`
MemoryMb int32 `json:"memory_mb"`
SizeBytes int64 `json:"size_bytes"`
TeamID string `json:"team_id"`
TeamID pgtype.UUID `json:"team_id"`
}
func (q *Queries) InsertTemplate(ctx context.Context, arg InsertTemplateParams) (Template, error) {
row := q.db.QueryRow(ctx, insertTemplate,
arg.ID,
arg.Name,
arg.Type,
arg.Vcpus,
@ -110,12 +172,13 @@ func (q *Queries) InsertTemplate(ctx context.Context, arg InsertTemplateParams)
&i.SizeBytes,
&i.CreatedAt,
&i.TeamID,
&i.ID,
)
return i, err
}
const listTemplates = `-- name: ListTemplates :many
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates ORDER BY created_at DESC
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates ORDER BY created_at DESC
`
func (q *Queries) ListTemplates(ctx context.Context) ([]Template, error) {
@ -135,6 +198,7 @@ func (q *Queries) ListTemplates(ctx context.Context) ([]Template, error) {
&i.SizeBytes,
&i.CreatedAt,
&i.TeamID,
&i.ID,
); err != nil {
return nil, err
}
@ -147,10 +211,11 @@ func (q *Queries) ListTemplates(ctx context.Context) ([]Template, error) {
}
const listTemplatesByTeam = `-- name: ListTemplatesByTeam :many
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE team_id = $1 ORDER BY created_at DESC
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE (team_id = $1 OR team_id = '00000000-0000-0000-0000-000000000000') ORDER BY created_at DESC
`
func (q *Queries) ListTemplatesByTeam(ctx context.Context, teamID string) ([]Template, error) {
// Platform templates are visible to all teams.
func (q *Queries) ListTemplatesByTeam(ctx context.Context, teamID pgtype.UUID) ([]Template, error) {
rows, err := q.db.Query(ctx, listTemplatesByTeam, teamID)
if err != nil {
return nil, err
@ -167,6 +232,7 @@ func (q *Queries) ListTemplatesByTeam(ctx context.Context, teamID string) ([]Tem
&i.SizeBytes,
&i.CreatedAt,
&i.TeamID,
&i.ID,
); err != nil {
return nil, err
}
@ -179,14 +245,15 @@ func (q *Queries) ListTemplatesByTeam(ctx context.Context, teamID string) ([]Tem
}
const listTemplatesByTeamAndType = `-- name: ListTemplatesByTeamAndType :many
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE team_id = $1 AND type = $2 ORDER BY created_at DESC
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE (team_id = $1 OR team_id = '00000000-0000-0000-0000-000000000000') AND type = $2 ORDER BY created_at DESC
`
type ListTemplatesByTeamAndTypeParams struct {
TeamID string `json:"team_id"`
Type string `json:"type"`
TeamID pgtype.UUID `json:"team_id"`
Type string `json:"type"`
}
// Platform templates are visible to all teams.
func (q *Queries) ListTemplatesByTeamAndType(ctx context.Context, arg ListTemplatesByTeamAndTypeParams) ([]Template, error) {
rows, err := q.db.Query(ctx, listTemplatesByTeamAndType, arg.TeamID, arg.Type)
if err != nil {
@ -204,6 +271,41 @@ func (q *Queries) ListTemplatesByTeamAndType(ctx context.Context, arg ListTempla
&i.SizeBytes,
&i.CreatedAt,
&i.TeamID,
&i.ID,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listTemplatesByTeamOnly = `-- name: ListTemplatesByTeamOnly :many
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE team_id = $1 ORDER BY created_at DESC
`
// List templates owned by a specific team (NOT including platform templates).
func (q *Queries) ListTemplatesByTeamOnly(ctx context.Context, teamID pgtype.UUID) ([]Template, error) {
rows, err := q.db.Query(ctx, listTemplatesByTeamOnly, teamID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Template
for rows.Next() {
var i Template
if err := rows.Scan(
&i.Name,
&i.Type,
&i.Vcpus,
&i.MemoryMb,
&i.SizeBytes,
&i.CreatedAt,
&i.TeamID,
&i.ID,
); err != nil {
return nil, err
}
@ -216,7 +318,7 @@ func (q *Queries) ListTemplatesByTeamAndType(ctx context.Context, arg ListTempla
}
const listTemplatesByType = `-- name: ListTemplatesByType :many
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE type = $1 ORDER BY created_at DESC
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE type = $1 ORDER BY created_at DESC
`
func (q *Queries) ListTemplatesByType(ctx context.Context, type_ string) ([]Template, error) {
@ -236,6 +338,7 @@ func (q *Queries) ListTemplatesByType(ctx context.Context, type_ string) ([]Temp
&i.SizeBytes,
&i.CreatedAt,
&i.TeamID,
&i.ID,
); err != nil {
return nil, err
}

View File

@ -16,8 +16,8 @@ DELETE FROM admin_permissions WHERE user_id = $1 AND permission = $2
`
type DeleteAdminPermissionParams struct {
UserID string `json:"user_id"`
Permission string `json:"permission"`
UserID pgtype.UUID `json:"user_id"`
Permission string `json:"permission"`
}
func (q *Queries) DeleteAdminPermission(ctx context.Context, arg DeleteAdminPermissionParams) error {
@ -29,7 +29,7 @@ const getAdminPermissions = `-- name: GetAdminPermissions :many
SELECT id, user_id, permission, created_at FROM admin_permissions WHERE user_id = $1 ORDER BY permission
`
func (q *Queries) GetAdminPermissions(ctx context.Context, userID string) ([]AdminPermission, error) {
func (q *Queries) GetAdminPermissions(ctx context.Context, userID pgtype.UUID) ([]AdminPermission, error) {
rows, err := q.db.Query(ctx, getAdminPermissions, userID)
if err != nil {
return nil, err
@ -55,7 +55,7 @@ func (q *Queries) GetAdminPermissions(ctx context.Context, userID string) ([]Adm
}
const getAdminUsers = `-- name: GetAdminUsers :many
SELECT id, email, password_hash, created_at, updated_at, is_admin FROM users WHERE is_admin = TRUE ORDER BY created_at
SELECT id, email, password_hash, name, is_admin, created_at, updated_at FROM users WHERE is_admin = TRUE ORDER BY created_at
`
func (q *Queries) GetAdminUsers(ctx context.Context) ([]User, error) {
@ -71,9 +71,10 @@ func (q *Queries) GetAdminUsers(ctx context.Context) ([]User, error) {
&i.ID,
&i.Email,
&i.PasswordHash,
&i.Name,
&i.IsAdmin,
&i.CreatedAt,
&i.UpdatedAt,
&i.IsAdmin,
); err != nil {
return nil, err
}
@ -86,7 +87,7 @@ func (q *Queries) GetAdminUsers(ctx context.Context) ([]User, error) {
}
const getUserByEmail = `-- name: GetUserByEmail :one
SELECT id, email, password_hash, created_at, updated_at, is_admin FROM users WHERE email = $1
SELECT id, email, password_hash, name, is_admin, created_at, updated_at FROM users WHERE email = $1
`
func (q *Queries) GetUserByEmail(ctx context.Context, email string) (User, error) {
@ -96,27 +97,29 @@ func (q *Queries) GetUserByEmail(ctx context.Context, email string) (User, error
&i.ID,
&i.Email,
&i.PasswordHash,
&i.Name,
&i.IsAdmin,
&i.CreatedAt,
&i.UpdatedAt,
&i.IsAdmin,
)
return i, err
}
const getUserByID = `-- name: GetUserByID :one
SELECT id, email, password_hash, created_at, updated_at, is_admin FROM users WHERE id = $1
SELECT id, email, password_hash, name, is_admin, created_at, updated_at FROM users WHERE id = $1
`
func (q *Queries) GetUserByID(ctx context.Context, id string) (User, error) {
func (q *Queries) GetUserByID(ctx context.Context, id pgtype.UUID) (User, error) {
row := q.db.QueryRow(ctx, getUserByID, id)
var i User
err := row.Scan(
&i.ID,
&i.Email,
&i.PasswordHash,
&i.Name,
&i.IsAdmin,
&i.CreatedAt,
&i.UpdatedAt,
&i.IsAdmin,
)
return i, err
}
@ -128,8 +131,8 @@ SELECT EXISTS(
`
type HasAdminPermissionParams struct {
UserID string `json:"user_id"`
Permission string `json:"permission"`
UserID pgtype.UUID `json:"user_id"`
Permission string `json:"permission"`
}
func (q *Queries) HasAdminPermission(ctx context.Context, arg HasAdminPermissionParams) (bool, error) {
@ -145,9 +148,9 @@ VALUES ($1, $2, $3)
`
type InsertAdminPermissionParams struct {
ID string `json:"id"`
UserID string `json:"user_id"`
Permission string `json:"permission"`
ID pgtype.UUID `json:"id"`
UserID pgtype.UUID `json:"user_id"`
Permission string `json:"permission"`
}
func (q *Queries) InsertAdminPermission(ctx context.Context, arg InsertAdminPermissionParams) error {
@ -156,66 +159,118 @@ func (q *Queries) InsertAdminPermission(ctx context.Context, arg InsertAdminPerm
}
const insertUser = `-- name: InsertUser :one
INSERT INTO users (id, email, password_hash)
VALUES ($1, $2, $3)
RETURNING id, email, password_hash, created_at, updated_at, is_admin
INSERT INTO users (id, email, password_hash, name)
VALUES ($1, $2, $3, $4)
RETURNING id, email, password_hash, name, is_admin, created_at, updated_at
`
type InsertUserParams struct {
ID string `json:"id"`
ID pgtype.UUID `json:"id"`
Email string `json:"email"`
PasswordHash pgtype.Text `json:"password_hash"`
Name string `json:"name"`
}
func (q *Queries) InsertUser(ctx context.Context, arg InsertUserParams) (User, error) {
row := q.db.QueryRow(ctx, insertUser, arg.ID, arg.Email, arg.PasswordHash)
row := q.db.QueryRow(ctx, insertUser,
arg.ID,
arg.Email,
arg.PasswordHash,
arg.Name,
)
var i User
err := row.Scan(
&i.ID,
&i.Email,
&i.PasswordHash,
&i.Name,
&i.IsAdmin,
&i.CreatedAt,
&i.UpdatedAt,
&i.IsAdmin,
)
return i, err
}
const insertUserOAuth = `-- name: InsertUserOAuth :one
INSERT INTO users (id, email)
VALUES ($1, $2)
RETURNING id, email, password_hash, created_at, updated_at, is_admin
INSERT INTO users (id, email, name)
VALUES ($1, $2, $3)
RETURNING id, email, password_hash, name, is_admin, created_at, updated_at
`
type InsertUserOAuthParams struct {
ID string `json:"id"`
Email string `json:"email"`
ID pgtype.UUID `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
}
func (q *Queries) InsertUserOAuth(ctx context.Context, arg InsertUserOAuthParams) (User, error) {
row := q.db.QueryRow(ctx, insertUserOAuth, arg.ID, arg.Email)
row := q.db.QueryRow(ctx, insertUserOAuth, arg.ID, arg.Email, arg.Name)
var i User
err := row.Scan(
&i.ID,
&i.Email,
&i.PasswordHash,
&i.Name,
&i.IsAdmin,
&i.CreatedAt,
&i.UpdatedAt,
&i.IsAdmin,
)
return i, err
}
const searchUsersByEmailPrefix = `-- name: SearchUsersByEmailPrefix :many
SELECT id, email FROM users WHERE email LIKE $1 || '%' ORDER BY email LIMIT 10
`

// SearchUsersByEmailPrefixRow is one result row of the
// searchUsersByEmailPrefix query (the id and email columns).
type SearchUsersByEmailPrefixRow struct {
	ID    pgtype.UUID `json:"id"`
	Email string      `json:"email"`
}

// SearchUsersByEmailPrefix returns up to 10 users, ordered by email, whose
// email matches the pattern `$1 || '%'`. The result is nil when no row matches.
//
// NOTE(review): the argument is concatenated into a LIKE pattern by the query,
// so `%` and `_` inside it act as wildcards — confirm callers are fine with that.
func (q *Queries) SearchUsersByEmailPrefix(ctx context.Context, dollar_1 pgtype.Text) ([]SearchUsersByEmailPrefixRow, error) {
	rows, err := q.db.Query(ctx, searchUsersByEmailPrefix, dollar_1)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var matches []SearchUsersByEmailPrefixRow
	for rows.Next() {
		var row SearchUsersByEmailPrefixRow
		if scanErr := rows.Scan(&row.ID, &row.Email); scanErr != nil {
			return nil, scanErr
		}
		matches = append(matches, row)
	}
	// Surface any error that ended the iteration early.
	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}
	return matches, nil
}
const setUserAdmin = `-- name: SetUserAdmin :exec
UPDATE users SET is_admin = $2, updated_at = NOW() WHERE id = $1
`
type SetUserAdminParams struct {
ID string `json:"id"`
IsAdmin bool `json:"is_admin"`
ID pgtype.UUID `json:"id"`
IsAdmin bool `json:"is_admin"`
}
func (q *Queries) SetUserAdmin(ctx context.Context, arg SetUserAdminParams) error {
_, err := q.db.Exec(ctx, setUserAdmin, arg.ID, arg.IsAdmin)
return err
}
const updateUserName = `-- name: UpdateUserName :exec
UPDATE users SET name = $2, updated_at = NOW() WHERE id = $1
`

// UpdateUserNameParams carries the arguments of the updateUserName query:
// ID binds to $1 and Name binds to $2.
type UpdateUserNameParams struct {
	ID   pgtype.UUID `json:"id"`
	Name string      `json:"name"`
}

// UpdateUserName executes the updateUserName statement, which sets the name
// and bumps updated_at for the user with the given ID. Only the execution
// error is reported; the command result is discarded.
func (q *Queries) UpdateUserName(ctx context.Context, arg UpdateUserNameParams) error {
	if _, err := q.db.Exec(ctx, updateUserName, arg.ID, arg.Name); err != nil {
		return err
	}
	return nil
}

View File

@ -116,9 +116,10 @@ type SnapshotDevice struct {
// writable CoW layer.
//
// The origin loop device must already exist (from LoopRegistry.Acquire).
func CreateSnapshot(name, originLoopDev, cowPath string, originSizeBytes int64) (*SnapshotDevice, error) {
// Create sparse CoW file sized to match the origin.
if err := createSparseFile(cowPath, originSizeBytes); err != nil {
func CreateSnapshot(name, originLoopDev, cowPath string, originSizeBytes, cowSizeBytes int64) (*SnapshotDevice, error) {
// Create sparse CoW file. The logical size limits how many blocks can be
// modified; because the file is sparse, only written blocks use real disk.
if err := createSparseFile(cowPath, cowSizeBytes); err != nil {
return nil, fmt.Errorf("create cow file: %w", err)
}
@ -128,6 +129,9 @@ func CreateSnapshot(name, originLoopDev, cowPath string, originSizeBytes int64)
return nil, fmt.Errorf("losetup cow: %w", err)
}
// The dm-snapshot virtual device size must match the origin — the snapshot
// target maps 1:1 onto origin sectors. The CoW file just needs enough
// space to store all modified blocks (it's sparse, so 20GB costs nothing).
sectors := originSizeBytes / 512
if err := dmsetupCreate(name, originLoopDev, cowLoopDev, sectors); err != nil {
if detachErr := losetupDetach(cowLoopDev); detachErr != nil {
@ -220,6 +224,7 @@ func FlattenSnapshot(dmDevPath, outputPath string) error {
"if="+dmDevPath,
"of="+outputPath,
"bs=4M",
"conv=sparse",
"status=none",
)
if out, err := cmd.CombinedOutput(); err != nil {

View File

@ -3,14 +3,12 @@ package envdclient
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"mime/multipart"
"net/http"
"net/url"
"time"
"connectrpc.com/connect"
@ -49,35 +47,6 @@ func (c *Client) BaseURL() string {
return c.base
}
// Init calls POST /init on envd to sync the guest clock with the host.
// This is important after snapshot resume where the guest clock is frozen.
func (c *Client) Init(ctx context.Context) error {
now := time.Now().UTC()
body, err := json.Marshal(map[string]any{"timestamp": now})
if err != nil {
return fmt.Errorf("marshal init body: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.base+"/init", bytes.NewReader(body))
if err != nil {
return fmt.Errorf("create init request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return fmt.Errorf("init request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusNoContent {
respBody, _ := io.ReadAll(resp.Body)
return fmt.Errorf("init: status %d: %s", resp.StatusCode, string(respBody))
}
return nil
}
// ExecResult holds the output of a command execution.
type ExecResult struct {
Stdout []byte

73
internal/events/event.go Normal file
View File

@ -0,0 +1,73 @@
package events
import (
"context"
"time"
)
// EventPublisher pushes events onto the notification stream.
// Satisfied by *channels.Publisher.
//
// Publish has no return value, so an implementation cannot report a delivery
// failure to the caller through this interface.
type EventPublisher interface {
Publish(ctx context.Context, e Event)
}
// ActorKind identifies what initiated an event. It is a string type, so its
// values appear verbatim in the JSON encoding of Actor.Type.
type ActorKind string
// Known actor kinds.
const (
ActorUser ActorKind = "user"
ActorAPIKey ActorKind = "api_key"
ActorSystem ActorKind = "system"
)
// Actor describes who triggered an event.
//
// Type is always serialized; ID and Name are dropped from the JSON output
// when they are empty (omitempty).
type Actor struct {
Type ActorKind `json:"type"`
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
}
// Resource identifies the object the event relates to.
// Both fields are always present in the JSON encoding, even when empty.
type Resource struct {
ID string `json:"id"`
Type string `json:"type"`
}
// Event is the canonical notification payload published to the Redis stream
// and delivered to channel subscribers.
//
// All fields are plain strings or nested structs and are always serialized.
type Event struct {
// Event is the event type string. NOTE(review): presumably one of the event
// type constants in this package; the type itself does not enforce that.
Event string `json:"event"`
// Timestamp is a pre-formatted time string. NOTE(review): presumably the
// value returned by Now (RFC 3339, UTC) — confirm with the publishers.
Timestamp string `json:"timestamp"`
TeamID string `json:"team_id"`
Actor Actor `json:"actor"`
Resource Resource `json:"resource"`
}
// Event type constants. These are untyped string constants; each value is a
// dot-separated name of the form "<resource>.<what happened>".
const (
// Capsule lifecycle events.
CapsuleCreated = "capsule.created"
CapsuleRunning = "capsule.running"
CapsulePaused = "capsule.paused"
CapsuleDestroyed = "capsule.destroyed"
// Template snapshot events.
SnapshotCreated = "template.snapshot.created"
SnapshotDeleted = "template.snapshot.deleted"
// Host availability events.
HostUp = "host.up"
HostDown = "host.down"
)
// AllEventTypes is the complete set of valid event type strings.
//
// It is an exported, mutable slice: callers that need to modify the list
// should copy it first. Keep it in sync when adding a new event type constant.
var AllEventTypes = []string{
CapsuleCreated,
CapsuleRunning,
CapsulePaused,
CapsuleDestroyed,
SnapshotCreated,
SnapshotDeleted,
HostUp,
HostDown,
}
// Now reports the current wall-clock time, converted to UTC and rendered in
// the time.RFC3339 layout, for use as an event timestamp.
func Now() string {
	utcNow := time.Now().UTC()
	return utcNow.Format(time.RFC3339)
}

View File

@ -0,0 +1,42 @@
package hostagent
import (
"crypto/tls"
"fmt"
"sync/atomic"
)
// CertStore keeps the agent's current TLS certificate behind an atomic
// pointer, so readers and writers never take a lock. Plug GetCert into
// tls.Config.GetCertificate to hot-swap the agent's certificate on JWT
// refresh without restarting the server.
//
// The zero value is ready to use; until a certificate has been stored,
// GetCert reports an error.
type CertStore struct {
	ptr atomic.Pointer[tls.Certificate]
}

// Store atomically swaps in cert as the current certificate. Storing nil
// puts the store back into the "no certificate" state.
func (s *CertStore) Store(cert *tls.Certificate) {
	s.ptr.Store(cert)
}

// ParseAndStore builds a certificate from the PEM-encoded certPEM and keyPEM
// and atomically makes it the current one. When parsing fails, the error is
// returned and the previously stored certificate stays in place.
func (s *CertStore) ParseAndStore(certPEM, keyPEM string) error {
	pair, err := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM))
	if err != nil {
		return fmt.Errorf("parse TLS key pair: %w", err)
	}
	s.Store(&pair)
	return nil
}

// GetCert has the signature required by tls.Config.GetCertificate. It ignores
// hello and returns the currently stored certificate, or an error when none
// has been stored yet.
func (s *CertStore) GetCert(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
	if current := s.ptr.Load(); current != nil {
		return current, nil
	}
	return nil, fmt.Errorf("no TLS certificate available")
}

View File

@ -0,0 +1,89 @@
package hostagent
import (
"fmt"
"log/slog"
"net/http"
"net/http/httputil"
"strconv"
"strings"
"git.omukk.dev/wrenn/sandbox/internal/sandbox"
)
// ProxyHandler reverse-proxies HTTP requests to services running inside
// sandboxes. It handles requests of the form:
//
// /proxy/{sandbox_id}/{port}/{path...}
//
// The sandbox's HostIP (routable on this machine) is used as the upstream.
// This supports any protocol that rides on HTTP, including WebSocket upgrades.
type ProxyHandler struct {
// mgr resolves a sandbox ID to its upstream IP and tracks the proxied
// connection for the duration of each request.
mgr *sandbox.Manager
// transport performs the upstream round trips; NewProxyHandler sets it to
// http.DefaultTransport.
transport http.RoundTripper
}
// NewProxyHandler builds a ProxyHandler that resolves sandboxes through mgr
// and sends upstream requests over http.DefaultTransport.
func NewProxyHandler(mgr *sandbox.Manager) *ProxyHandler {
	h := new(ProxyHandler)
	h.mgr = mgr
	h.transport = http.DefaultTransport
	return h
}
// ServeHTTP implements http.Handler.
//
// It parses /proxy/{sandbox_id}/{port}/{path...}, validates the port, asks the
// sandbox manager for the sandbox's upstream IP, and reverse-proxies the
// request to http://{ip}:{port}/{path...} with the original query string.
//
// Responses produced by the handler itself:
//   - 400 when the path does not start with /proxy/, lacks a port segment, or
//     the port is not an integer in 1..65535;
//   - 503 when the manager cannot provide a connection for the sandbox;
//   - 502 when the upstream round trip fails.
func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// Strip the fixed prefix; what remains is "{sandbox_id}/{port}/...".
	rest, ok := strings.CutPrefix(r.URL.Path, "/proxy/")
	if !ok {
		http.Error(w, "invalid proxy path", http.StatusBadRequest)
		return
	}

	// The first segment is the sandbox ID; a port segment must follow it.
	sandboxID, afterID, ok := strings.Cut(rest, "/")
	if !ok {
		http.Error(w, "expected /proxy/{sandbox_id}/{port}/...", http.StatusBadRequest)
		return
	}
	// Everything after the port (possibly nothing) is the upstream path.
	port, remainder, _ := strings.Cut(afterID, "/")

	// Validate port is a number in the valid range.
	portNum, err := strconv.Atoi(port)
	if err != nil || portNum < 1 || portNum > 65535 {
		http.Error(w, "invalid port", http.StatusBadRequest)
		return
	}

	hostIP, tracker, ok := h.mgr.AcquireProxyConn(sandboxID)
	if !ok {
		http.Error(w, "sandbox is not available", http.StatusServiceUnavailable)
		return
	}
	// Hold the connection tracker until the proxied exchange has finished.
	defer tracker.Release()

	// NOTE(review): "%s:%d" does not bracket IPv6 literals; confirm hostIP is
	// always IPv4, otherwise build the address with net.JoinHostPort.
	targetHost := fmt.Sprintf("%s:%d", hostIP, portNum)

	// rewrite points the outgoing request at the sandbox service, replacing
	// the /proxy/{id}/{port} prefix with the bare remainder of the path.
	rewrite := func(out *http.Request) {
		out.URL.Scheme = "http"
		out.URL.Host = targetHost
		out.URL.Path = "/" + remainder
		out.URL.RawQuery = r.URL.RawQuery
		out.Host = targetHost
	}

	// upstreamFailed logs the failure and reports it to the client as a 502.
	upstreamFailed := func(rw http.ResponseWriter, _ *http.Request, proxyErr error) {
		slog.Debug("proxy error", "sandbox_id", sandboxID, "port", port, "error", proxyErr)
		http.Error(rw, "proxy error: "+proxyErr.Error(), http.StatusBadGateway)
	}

	rp := &httputil.ReverseProxy{
		Transport:    h.transport,
		Director:     rewrite,
		ErrorHandler: upstreamFailed,
	}
	rp.ServeHTTP(w, r)
}

View File

@ -17,11 +17,24 @@ import (
"golang.org/x/sys/unix"
)
// TokenFile is the JSON format persisted to WRENN_DIR/host-credentials.json.
// It holds all credentials the agent needs: the host JWT, refresh token, and
// (when mTLS is enabled) the TLS certificate material for the agent's server.
type TokenFile struct {
// HostID is the host identifier taken from the JWT's host_id claim.
HostID string `json:"host_id"`
// JWT is the host token sent to the control plane.
JWT string `json:"jwt"`
// RefreshToken is exchanged for new credentials. It is empty for files in
// the legacy format, which contain only a raw JWT.
RefreshToken string `json:"refresh_token"`
// mTLS fields — empty when the CP has no CA configured.
// They are omitted from the JSON output when empty.
CertPEM string `json:"cert_pem,omitempty"`
KeyPEM string `json:"key_pem,omitempty"`
CACertPEM string `json:"ca_cert_pem,omitempty"`
}
// RegistrationConfig holds the configuration for host registration.
type RegistrationConfig struct {
CPURL string // Control plane base URL (e.g., http://localhost:8000)
RegistrationToken string // One-time registration token from the control plane
TokenFile string // Path to persist the host JWT after registration
TokenFile string // Path to persist the credentials after registration
Address string // Externally-reachable address (ip:port) for this host
}
@ -34,9 +47,18 @@ type registerRequest struct {
Address string `json:"address"`
}
type registerResponse struct {
Host json.RawMessage `json:"host"`
Token string `json:"token"`
// authResponse is the shared JSON shape for both register and refresh responses.
type authResponse struct {
Host json.RawMessage `json:"host"`
Token string `json:"token"`
RefreshToken string `json:"refresh_token"`
CertPEM string `json:"cert_pem,omitempty"`
KeyPEM string `json:"key_pem,omitempty"`
CACertPEM string `json:"ca_cert_pem,omitempty"`
}
// refreshRequest is the JSON body sent to the control plane's
// /v1/hosts/auth/refresh endpoint when exchanging a refresh token.
type refreshRequest struct {
RefreshToken string `json:"refresh_token"`
}
type errorResponse struct {
@ -46,20 +68,47 @@ type errorResponse struct {
} `json:"error"`
}
// Register calls the control plane to register this host agent and persists
// the returned JWT to disk. Returns the host JWT token string.
func Register(ctx context.Context, cfg RegistrationConfig) (string, error) {
// Check if we already have a saved token.
if data, err := os.ReadFile(cfg.TokenFile); err == nil {
token := strings.TrimSpace(string(data))
if token != "" {
slog.Info("loaded existing host token", "file", cfg.TokenFile)
return token, nil
}
// LoadTokenFile reads the credentials file at path and decodes it.
//
// Two on-disk formats are accepted:
//   - the current JSON object, decoded into a TokenFile;
//   - the legacy format, where the file holds only the raw JWT. In that case
//     the result carries the JWT and the host ID extracted from it, and no
//     refresh token or certificate material.
//
// A read error is returned unwrapped so callers can inspect it (for example
// for a missing file); a JSON decode error is wrapped with context.
func LoadTokenFile(path string) (*TokenFile, error) {
	raw, err := os.ReadFile(path)
	if err != nil {
		return nil, err
	}

	// Anything that does not start with '{' is treated as a legacy raw JWT.
	if legacy := strings.TrimSpace(string(raw)); !strings.HasPrefix(legacy, "{") {
		// Best effort: a JWT whose host ID cannot be extracted still loads,
		// just with an empty HostID.
		hostID, _ := hostIDFromJWT(legacy)
		return &TokenFile{HostID: hostID, JWT: legacy}, nil
	}

	creds := new(TokenFile)
	if err := json.Unmarshal(raw, creds); err != nil {
		return nil, fmt.Errorf("parse credentials file: %w", err)
	}
	return creds, nil
}
// saveTokenFile writes the credentials file as JSON with 0600 permissions.
//
// The data is first written to a temporary sibling file (path + ".tmp") and
// then renamed over path. Writing in place would truncate the existing file
// before the new content is on disk, so a crash or a full disk mid-write could
// destroy the only copy of the (rotating) refresh token and force the host to
// re-register. With the rename, path always holds either the old or the new
// credentials.
//
// On failure the temporary file is removed on a best-effort basis and the
// existing credentials file is left untouched.
func saveTokenFile(path string, tf TokenFile) error {
	data, err := json.MarshalIndent(tf, "", " ")
	if err != nil {
		return fmt.Errorf("marshal credentials file: %w", err)
	}

	tmpPath := path + ".tmp"
	if err := os.WriteFile(tmpPath, data, 0600); err != nil {
		// Best effort: do not leave a partial temp file behind.
		_ = os.Remove(tmpPath)
		return fmt.Errorf("write credentials temp file: %w", err)
	}
	// A pre-existing temp file keeps its old mode under os.WriteFile, so
	// enforce 0600 explicitly before it becomes the credentials file.
	if err := os.Chmod(tmpPath, 0600); err != nil {
		_ = os.Remove(tmpPath)
		return fmt.Errorf("set credentials file permissions: %w", err)
	}
	if err := os.Rename(tmpPath, path); err != nil {
		_ = os.Remove(tmpPath)
		return fmt.Errorf("replace credentials file: %w", err)
	}
	return nil
}
// Register calls the control plane to register this host agent and persists
// the returned credentials to disk. Returns the full TokenFile on success.
func Register(ctx context.Context, cfg RegistrationConfig) (*TokenFile, error) {
// If no explicit registration token was given, reuse the saved credentials.
// A --register flag always overrides the local file so operators can
// force re-registration without manually deleting the credentials file.
if cfg.RegistrationToken == "" {
return "", fmt.Errorf("no saved host token and no registration token provided")
if tf, err := LoadTokenFile(cfg.TokenFile); err == nil && tf.JWT != "" {
slog.Info("loaded existing host credentials", "file", cfg.TokenFile, "host_id", tf.HostID)
return tf, nil
}
return nil, fmt.Errorf("no saved host credentials and no registration token provided (use --register flag)")
}
arch := runtime.GOARCH
@ -78,85 +127,238 @@ func Register(ctx context.Context, cfg RegistrationConfig) (string, error) {
body, err := json.Marshal(reqBody)
if err != nil {
return "", fmt.Errorf("marshal registration request: %w", err)
return nil, fmt.Errorf("marshal registration request: %w", err)
}
url := strings.TrimRight(cfg.CPURL, "/") + "/v1/hosts/register"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return "", fmt.Errorf("create registration request: %w", err)
return nil, fmt.Errorf("create registration request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("registration request failed: %w", err)
return nil, fmt.Errorf("registration request failed: %w", err)
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("read registration response: %w", err)
return nil, fmt.Errorf("read registration response: %w", err)
}
if resp.StatusCode != http.StatusCreated {
var errResp errorResponse
if err := json.Unmarshal(respBody, &errResp); err == nil {
return "", fmt.Errorf("registration failed (%d): %s", resp.StatusCode, errResp.Error.Message)
return nil, fmt.Errorf("registration failed (%d): %s", resp.StatusCode, errResp.Error.Message)
}
return "", fmt.Errorf("registration failed (%d): %s", resp.StatusCode, string(respBody))
return nil, fmt.Errorf("registration failed (%d): %s", resp.StatusCode, string(respBody))
}
var regResp registerResponse
var regResp authResponse
if err := json.Unmarshal(respBody, &regResp); err != nil {
return "", fmt.Errorf("parse registration response: %w", err)
return nil, fmt.Errorf("parse registration response: %w", err)
}
if regResp.Token == "" {
return "", fmt.Errorf("registration response missing token")
return nil, fmt.Errorf("registration response missing token")
}
// Persist the token to disk for subsequent startups.
if err := os.WriteFile(cfg.TokenFile, []byte(regResp.Token), 0600); err != nil {
return "", fmt.Errorf("save host token: %w", err)
hostID, err := hostIDFromJWT(regResp.Token)
if err != nil {
return nil, fmt.Errorf("extract host ID from JWT: %w", err)
}
slog.Info("host registered and token saved", "file", cfg.TokenFile)
return regResp.Token, nil
tf := TokenFile{
HostID: hostID,
JWT: regResp.Token,
RefreshToken: regResp.RefreshToken,
CertPEM: regResp.CertPEM,
KeyPEM: regResp.KeyPEM,
CACertPEM: regResp.CACertPEM,
}
if err := saveTokenFile(cfg.TokenFile, tf); err != nil {
return nil, fmt.Errorf("save host credentials: %w", err)
}
slog.Info("host registered and credentials saved", "file", cfg.TokenFile, "host_id", hostID)
return &tf, nil
}
// RefreshCredentials exchanges the refresh token for a new JWT, rotated refresh
// token, and (when mTLS is enabled) a new TLS certificate. The credentials file
// is updated in place. Returns the updated TokenFile.
//
// The refresh token is read from credentialsFilePath and posted to
// {cpURL}/v1/hosts/auth/refresh. An error is returned, and the credentials
// file is left unchanged, when:
//   - the file cannot be loaded or holds no refresh token;
//   - the request cannot be built or sent, or the body cannot be read;
//   - the control plane answers with anything other than 200 OK;
//   - the response cannot be parsed or carries no token.
//
// Certificate material is replaced only when the response includes a
// certificate; otherwise the stored certificate fields are kept.
func RefreshCredentials(ctx context.Context, cpURL, credentialsFilePath string) (*TokenFile, error) {
	tf, err := LoadTokenFile(credentialsFilePath)
	if err != nil {
		return nil, fmt.Errorf("load credentials file: %w", err)
	}
	if tf.RefreshToken == "" {
		return nil, fmt.Errorf("no refresh token available; host must re-register")
	}

	body, err := json.Marshal(refreshRequest{RefreshToken: tf.RefreshToken})
	if err != nil {
		return nil, fmt.Errorf("marshal refresh request: %w", err)
	}

	url := strings.TrimRight(cpURL, "/") + "/v1/hosts/auth/refresh"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("create refresh request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	client := &http.Client{Timeout: 15 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("refresh request failed: %w", err)
	}
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("read refresh response body: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		// Prefer the structured error message when the body is the JSON
		// error envelope; otherwise report the raw body.
		var errResp errorResponse
		if json.Unmarshal(respBody, &errResp) == nil {
			return nil, fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, errResp.Error.Message)
		}
		return nil, fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, string(respBody))
	}

	var refResp authResponse
	if err := json.Unmarshal(respBody, &refResp); err != nil {
		return nil, fmt.Errorf("parse refresh response: %w", err)
	}
	// Validate before touching tf: persisting an empty JWT would wipe the
	// working credentials on disk. Mirrors the check done at registration.
	if refResp.Token == "" {
		return nil, fmt.Errorf("refresh response missing token")
	}

	tf.JWT = refResp.Token
	// Keep the current refresh token when the response omits one, rather than
	// overwriting it with an empty string and making the next refresh
	// impossible.
	if refResp.RefreshToken != "" {
		tf.RefreshToken = refResp.RefreshToken
	}
	if refResp.CertPEM != "" {
		tf.CertPEM = refResp.CertPEM
		tf.KeyPEM = refResp.KeyPEM
		tf.CACertPEM = refResp.CACertPEM
	}

	if err := saveTokenFile(credentialsFilePath, *tf); err != nil {
		return nil, fmt.Errorf("save refreshed credentials: %w", err)
	}
	slog.Info("host credentials refreshed", "host_id", tf.HostID)
	return tf, nil
}
// StartHeartbeat launches a background goroutine that sends periodic heartbeats
// to the control plane. It runs until the context is cancelled.
func StartHeartbeat(ctx context.Context, cpURL, hostID, hostToken string, interval time.Duration) {
url := strings.TrimRight(cpURL, "/") + "/v1/hosts/" + hostID + "/heartbeat"
//
// On 401/403: the heartbeat loop attempts to refresh credentials. If the refresh
// also fails (expired refresh token), it calls pauseAll and stops.
//
// On repeated network failures (3 consecutive), it calls pauseAll but keeps
// retrying — the connection may recover and the host should resume heartbeating.
//
// onDeleted is called when CP returns 404, meaning this host record was deleted.
// The credentials file is removed before calling onDeleted so subsequent starts
// prompt for a new registration token.
//
// onCredsRefreshed is called after a successful credential refresh (JWT + cert).
// It may be nil. The caller uses it to hot-swap the agent's TLS certificate.
func StartHeartbeat(ctx context.Context, cpURL, credentialsFilePath, hostID string, interval time.Duration, pauseAll func(), onDeleted func(), onCredsRefreshed func(*TokenFile)) {
client := &http.Client{Timeout: 10 * time.Second}
go func() {
ticker := time.NewTicker(interval)
defer ticker.Stop()
consecutiveFailures := 0
pausedDueToFailure := false
currentJWT := ""
// Load the current JWT from the credentials file.
if tf, err := LoadTokenFile(credentialsFilePath); err == nil {
currentJWT = tf.JWT
}
// beat sends one heartbeat. Returns true if the loop should stop.
beat := func() (stop bool) {
url := strings.TrimRight(cpURL, "/") + "/v1/hosts/" + hostID + "/heartbeat"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil)
if err != nil {
slog.Warn("heartbeat: failed to create request", "error", err)
return false
}
req.Header.Set("X-Host-Token", currentJWT)
resp, err := client.Do(req)
if err != nil {
consecutiveFailures++
slog.Warn("heartbeat: request failed", "error", err, "consecutive_failures", consecutiveFailures)
if consecutiveFailures >= 3 && !pausedDueToFailure {
slog.Error("heartbeat: CP unreachable after 3 failures — pausing all sandboxes")
if pauseAll != nil {
pauseAll()
}
pausedDueToFailure = true
}
return false
}
resp.Body.Close()
switch resp.StatusCode {
case http.StatusNoContent:
if consecutiveFailures > 0 || pausedDueToFailure {
slog.Info("heartbeat: CP connection restored")
}
consecutiveFailures = 0
pausedDueToFailure = false
case http.StatusUnauthorized, http.StatusForbidden:
slog.Warn("heartbeat: JWT rejected — attempting credentials refresh")
newCreds, refreshErr := RefreshCredentials(ctx, cpURL, credentialsFilePath)
if refreshErr != nil {
slog.Error("heartbeat: credentials refresh failed — pausing all sandboxes; manual re-registration required",
"error", refreshErr)
if pauseAll != nil && !pausedDueToFailure {
pauseAll()
pausedDueToFailure = true
}
// Stop the heartbeat loop — operator must re-register.
return true
}
currentJWT = newCreds.JWT
slog.Info("heartbeat: credentials refreshed successfully")
if onCredsRefreshed != nil {
onCredsRefreshed(newCreds)
}
case http.StatusNotFound:
slog.Error("heartbeat: host no longer exists in CP — host was deleted; removing credentials file and exiting")
if err := os.Remove(credentialsFilePath); err != nil && !os.IsNotExist(err) {
slog.Warn("heartbeat: failed to remove credentials file", "error", err)
}
if onDeleted != nil {
onDeleted()
}
return true
default:
slog.Warn("heartbeat: unexpected status", "status", resp.StatusCode)
}
return false
}
// Send an immediate heartbeat on startup so the CP sees the host as
// online without waiting for the first ticker tick.
if beat() {
return
}
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil)
if err != nil {
slog.Warn("heartbeat: failed to create request", "error", err)
continue
}
req.Header.Set("X-Host-Token", hostToken)
resp, err := client.Do(req)
if err != nil {
slog.Warn("heartbeat: request failed", "error", err)
continue
}
resp.Body.Close()
if resp.StatusCode != http.StatusNoContent {
slog.Warn("heartbeat: unexpected status", "status", resp.StatusCode)
if beat() {
return
}
}
}
@ -166,6 +368,12 @@ func StartHeartbeat(ctx context.Context, cpURL, hostID, hostToken string, interv
// HostIDFromToken returns the host_id claim carried by a host JWT.
//
// The signature is NOT verified (the agent doesn't have the signing secret);
// the actual decoding is delegated to hostIDFromJWT.
func HostIDFromToken(token string) (string, error) {
	hostID, err := hostIDFromJWT(token)
	return hostID, err
}
// hostIDFromJWT is the internal implementation used by both HostIDFromToken and
// the credentials file loader.
func hostIDFromJWT(token string) (string, error) {
parts := strings.Split(token, ".")
if len(parts) != 3 {
return "", fmt.Errorf("invalid JWT format")

View File

@ -12,6 +12,8 @@ import (
"time"
"connectrpc.com/connect"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgtype"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
@ -22,12 +24,28 @@ import (
// Server implements the HostAgentService Connect RPC handler.
type Server struct {
hostagentv1connect.UnimplementedHostAgentServiceHandler
mgr *sandbox.Manager
mgr *sandbox.Manager
terminate func() // called when the CP requests agent termination
}
// NewServer creates a new host agent RPC server.
func NewServer(mgr *sandbox.Manager) *Server {
return &Server{mgr: mgr}
// terminate is invoked (in a goroutine) when the CP calls the Terminate RPC,
// allowing main to perform a clean shutdown.
func NewServer(mgr *sandbox.Manager, terminate func()) *Server {
return &Server{mgr: mgr, terminate: terminate}
}
// parseUUIDString converts the textual form of a UUID into a pgtype.UUID.
//
// The empty string is accepted and maps to the all-zeros UUID, marked valid.
// Any other input is handed to uuid.Parse; a parse failure is returned wrapped
// together with the offending input.
func parseUUIDString(s string) (pgtype.UUID, error) {
	var out pgtype.UUID
	if s != "" {
		parsed, err := uuid.Parse(s)
		if err != nil {
			return pgtype.UUID{}, fmt.Errorf("invalid UUID %q: %w", s, err)
		}
		out.Bytes = parsed
	}
	out.Valid = true
	return out, nil
}
func (s *Server) CreateSandbox(
@ -36,7 +54,16 @@ func (s *Server) CreateSandbox(
) (*connect.Response[pb.CreateSandboxResponse], error) {
msg := req.Msg
sb, err := s.mgr.Create(ctx, msg.SandboxId, msg.Template, int(msg.Vcpus), int(msg.MemoryMb), int(msg.TimeoutSec))
teamID, err := parseUUIDString(msg.TeamId)
if err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, err)
}
templateID, err := parseUUIDString(msg.TemplateId)
if err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, err)
}
sb, err := s.mgr.Create(ctx, msg.SandboxId, teamID, templateID, int(msg.Vcpus), int(msg.MemoryMb), int(msg.TimeoutSec), int(msg.DiskSizeMb))
if err != nil {
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("create sandbox: %w", err))
}
@ -87,12 +114,21 @@ func (s *Server) CreateSnapshot(
ctx context.Context,
req *connect.Request[pb.CreateSnapshotRequest],
) (*connect.Response[pb.CreateSnapshotResponse], error) {
sizeBytes, err := s.mgr.CreateSnapshot(ctx, req.Msg.SandboxId, req.Msg.Name)
msg := req.Msg
teamID, err := parseUUIDString(msg.TeamId)
if err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, err)
}
templateID, err := parseUUIDString(msg.TemplateId)
if err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, err)
}
sizeBytes, err := s.mgr.CreateSnapshot(ctx, msg.SandboxId, teamID, templateID)
if err != nil {
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("create snapshot: %w", err))
}
return connect.NewResponse(&pb.CreateSnapshotResponse{
Name: req.Msg.Name,
SizeBytes: sizeBytes,
}), nil
}
@ -101,12 +137,45 @@ func (s *Server) DeleteSnapshot(
ctx context.Context,
req *connect.Request[pb.DeleteSnapshotRequest],
) (*connect.Response[pb.DeleteSnapshotResponse], error) {
if err := s.mgr.DeleteSnapshot(req.Msg.Name); err != nil {
msg := req.Msg
teamID, err := parseUUIDString(msg.TeamId)
if err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, err)
}
templateID, err := parseUUIDString(msg.TemplateId)
if err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, err)
}
if err := s.mgr.DeleteSnapshot(teamID, templateID); err != nil {
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("delete snapshot: %w", err))
}
return connect.NewResponse(&pb.DeleteSnapshotResponse{}), nil
}
// FlattenRootfs handles the FlattenRootfs RPC.
//
// The team and template IDs in the request are parsed with parseUUIDString;
// a malformed ID is reported as CodeInvalidArgument. The work itself is done
// by the sandbox manager, whose failures are reported as CodeInternal. On
// success the response carries the size in bytes returned by the manager.
func (s *Server) FlattenRootfs(
	ctx context.Context,
	req *connect.Request[pb.FlattenRootfsRequest],
) (*connect.Response[pb.FlattenRootfsResponse], error) {
	in := req.Msg

	teamID, teamErr := parseUUIDString(in.TeamId)
	if teamErr != nil {
		return nil, connect.NewError(connect.CodeInvalidArgument, teamErr)
	}

	templateID, tmplErr := parseUUIDString(in.TemplateId)
	if tmplErr != nil {
		return nil, connect.NewError(connect.CodeInvalidArgument, tmplErr)
	}

	size, flattenErr := s.mgr.FlattenRootfs(ctx, in.SandboxId, teamID, templateID)
	if flattenErr != nil {
		return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("flatten rootfs: %w", flattenErr))
	}

	resp := &pb.FlattenRootfsResponse{SizeBytes: size}
	return connect.NewResponse(resp), nil
}
func (s *Server) PingSandbox(
ctx context.Context,
req *connect.Request[pb.PingSandboxRequest],
@ -397,7 +466,8 @@ func (s *Server) ListSandboxes(
infos[i] = &pb.SandboxInfo{
SandboxId: sb.ID,
Status: string(sb.Status),
Template: sb.Template,
TeamId: uuid.UUID(sb.TemplateTeamID).String(),
TemplateId: uuid.UUID(sb.TemplateID).String(),
Vcpus: int32(sb.VCPUs),
MemoryMb: int32(sb.MemoryMB),
HostIp: sb.HostIP.String(),
@ -412,3 +482,66 @@ func (s *Server) ListSandboxes(
AutoPausedSandboxIds: s.mgr.DrainAutoPausedIDs(),
}), nil
}
// Terminate handles the Terminate RPC. It logs the request and, when a
// terminate callback was supplied to the server, runs that callback in its own
// goroutine so the RPC can return an (empty) response immediately.
func (s *Server) Terminate(
	_ context.Context,
	_ *connect.Request[pb.TerminateRequest],
) (*connect.Response[pb.TerminateResponse], error) {
	slog.Info("terminate RPC received — scheduling shutdown")

	if shutdown := s.terminate; shutdown != nil {
		go shutdown()
	}

	resp := &pb.TerminateResponse{}
	return connect.NewResponse(resp), nil
}
// GetSandboxMetrics handles the GetSandboxMetrics RPC by asking the sandbox
// manager for the metric points of the requested sandbox and range.
//
// Manager errors are classified by their message text: a message containing
// "not found" becomes CodeNotFound, one containing "invalid range" becomes
// CodeInvalidArgument, and anything else becomes CodeInternal.
func (s *Server) GetSandboxMetrics(
	_ context.Context,
	req *connect.Request[pb.GetSandboxMetricsRequest],
) (*connect.Response[pb.GetSandboxMetricsResponse], error) {
	in := req.Msg

	points, err := s.mgr.GetMetrics(in.SandboxId, in.Range)
	if err != nil {
		code := connect.CodeInternal
		switch text := err.Error(); {
		case strings.Contains(text, "not found"):
			code = connect.CodeNotFound
		case strings.Contains(text, "invalid range"):
			code = connect.CodeInvalidArgument
		}
		return nil, connect.NewError(code, err)
	}

	resp := &pb.GetSandboxMetricsResponse{Points: metricPointsToPB(points)}
	return connect.NewResponse(resp), nil
}
// FlushSandboxMetrics handles the FlushSandboxMetrics RPC. It asks the sandbox
// manager to flush the metrics of the requested sandbox and returns the three
// resulting point series (10m, 2h and 24h) converted to protobuf form.
//
// A manager error whose message contains "not found" is reported as
// CodeNotFound; every other error is reported as CodeInternal.
func (s *Server) FlushSandboxMetrics(
	_ context.Context,
	req *connect.Request[pb.FlushSandboxMetricsRequest],
) (*connect.Response[pb.FlushSandboxMetricsResponse], error) {
	tenMin, twoHour, fullDay, err := s.mgr.FlushMetrics(req.Msg.SandboxId)
	if err != nil {
		code := connect.CodeInternal
		if strings.Contains(err.Error(), "not found") {
			code = connect.CodeNotFound
		}
		return nil, connect.NewError(code, err)
	}

	resp := &pb.FlushSandboxMetricsResponse{
		Points_10M: metricPointsToPB(tenMin),
		Points_2H:  metricPointsToPB(twoHour),
		Points_24H: metricPointsToPB(fullDay),
	}
	return connect.NewResponse(resp), nil
}
// metricPointsToPB converts manager metric points to their protobuf
// counterparts, preserving order. Timestamps are sent as Unix seconds. The
// result is always a non-nil slice with one entry per input point.
func metricPointsToPB(pts []sandbox.MetricPoint) []*pb.MetricPoint {
	converted := make([]*pb.MetricPoint, 0, len(pts))
	for i := range pts {
		src := &pts[i]
		converted = append(converted, &pb.MetricPoint{
			TimestampUnix: src.Timestamp.Unix(),
			CpuPct:        src.CPUPct,
			MemBytes:      src.MemBytes,
			DiskBytes:     src.DiskBytes,
		})
	}
	return converted
}

View File

@ -4,8 +4,171 @@ import (
"crypto/rand"
"encoding/hex"
"fmt"
"math/big"
"strings"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgtype"
)
const (
// base36Alphabet lists the base36 digits in value order: '0'-'9' stand for
// 0-9 and 'a'-'z' for 10-35. Only lowercase letters are part of the alphabet.
base36Alphabet = "0123456789abcdefghijklmnopqrstuvwxyz"
// base36IDLen is the fixed width of an encoded UUID. Shorter values are
// left-padded with '0' by UUIDToBase36.
base36IDLen = 25 // ceil(128 * ln 2 / ln 36) = ceil(24.76) = 25 chars for a full UUID
)
// base36Base is the radix (36) as a big.Int, shared by the base36 helpers so
// it does not have to be re-allocated on every call.
var base36Base = big.NewInt(36)
// --- Generation ---
// newUUID generates a fresh random (v4) UUID and wraps it in a valid
// pgtype.UUID so it can be handed to the database layer directly.
func newUUID() pgtype.UUID {
	var fresh pgtype.UUID
	fresh.Bytes = uuid.New()
	fresh.Valid = true
	return fresh
}
// The constructors below all return a new random UUID from newUUID. They exist
// so call sites state which kind of entity the ID is for.

// NewSandboxID returns a new random ID for a sandbox.
func NewSandboxID() pgtype.UUID {
	return newUUID()
}

// NewUserID returns a new random ID for a user.
func NewUserID() pgtype.UUID {
	return newUUID()
}

// NewTeamID returns a new random ID for a team.
func NewTeamID() pgtype.UUID {
	return newUUID()
}

// NewAPIKeyID returns a new random ID for an API key.
func NewAPIKeyID() pgtype.UUID {
	return newUUID()
}

// NewHostID returns a new random ID for a host.
func NewHostID() pgtype.UUID {
	return newUUID()
}

// NewHostTokenID returns a new random ID for a host token.
func NewHostTokenID() pgtype.UUID {
	return newUUID()
}

// NewRefreshTokenID returns a new random ID for a refresh token.
func NewRefreshTokenID() pgtype.UUID {
	return newUUID()
}

// NewAuditLogID returns a new random ID for an audit log entry.
func NewAuditLogID() pgtype.UUID {
	return newUUID()
}

// NewBuildID returns a new random ID for a build.
func NewBuildID() pgtype.UUID {
	return newUUID()
}

// NewAdminPermissionID returns a new random ID for an admin permission.
func NewAdminPermissionID() pgtype.UUID {
	return newUUID()
}

// NewChannelID returns a new random ID for a channel.
func NewChannelID() pgtype.UUID {
	return newUUID()
}

// NewTemplateID returns a new random ID for a template.
func NewTemplateID() pgtype.UUID {
	return newUUID()
}
// NewSnapshotName returns a fresh snapshot name: the literal "template-"
// followed by the 8 hex chars produced by hex8.
func NewSnapshotName() string {
	const prefix = "template-"
	return prefix + hex8()
}
// NewTeamSlug returns a random team slug shaped "xxxxxx-yyyyyy": six random
// bytes from crypto/rand rendered as two groups of six lowercase hex digits.
// It panics if the system random source fails.
func NewTeamSlug() string {
	var raw [6]byte
	if _, err := rand.Read(raw[:]); err != nil {
		panic(fmt.Sprintf("crypto/rand failed: %v", err))
	}
	encoded := hex.EncodeToString(raw[:])
	return encoded[:6] + "-" + encoded[6:]
}
// NewRegistrationToken returns a random registration token: 32 bytes of
// entropy rendered as 64 hex chars by hexToken.
func NewRegistrationToken() string {
	const entropyBytes = 32
	return hexToken(entropyBytes)
}
// NewRefreshToken returns a random refresh token: 32 bytes of entropy
// rendered as 64 hex chars by hexToken.
func NewRefreshToken() string {
	const entropyBytes = 32
	return hexToken(entropyBytes)
}
// --- Formatting (pgtype.UUID → prefixed string for API/RPC output) ---
// Entity prefixes for the string form of an ID. formatUUID prepends the prefix
// to the base36-encoded UUID, and parseUUID requires it and strips it again.
// NOTE(review): no Format/Parse helper for PrefixAdminPermission is visible in
// this part of the file — confirm whether one is expected.
const (
PrefixSandbox = "cl-"
PrefixUser = "usr-"
PrefixTeam = "team-"
PrefixAPIKey = "key-"
PrefixHost = "host-"
PrefixHostToken = "htok-"
PrefixRefreshToken = "hrt-"
PrefixAuditLog = "log-"
PrefixBuild = "bld-"
PrefixAdminPermission = "perm-"
PrefixChannel = "ch-"
)
// UUIDToBase36 encodes 16 UUID bytes as a 25-char base36 string (0-9a-z).
func UUIDToBase36(b [16]byte) string {
n := new(big.Int).SetBytes(b[:])
buf := make([]byte, base36IDLen)
mod := new(big.Int)
for i := base36IDLen - 1; i >= 0; i-- {
n.DivMod(n, base36Base, mod)
buf[i] = base36Alphabet[mod.Int64()]
}
return string(buf)
}
// base36ToUUID decodes a 25-char base36 string back to 16 UUID bytes.
//
// The input must be exactly base36IDLen characters, all drawn from
// base36Alphabet (lowercase only). Because 36^25 is larger than 2^128, some
// well-formed strings (for example 25 'z' characters) encode a number that
// does not fit in a UUID; those are rejected with an error. Previously such
// input caused a slice-bounds panic, which mattered because the string comes
// from API/RPC callers.
func base36ToUUID(s string) ([16]byte, error) {
	if len(s) != base36IDLen {
		return [16]byte{}, fmt.Errorf("expected %d-char base36 ID, got %d", base36IDLen, len(s))
	}

	const radix = uint32(len(base36Alphabet))

	// out holds the running value as a 128-bit big-endian integer. Each digit
	// is folded in as out = out*radix + digit, computed bytewise from the
	// least significant end, so no heap-allocated big.Int is needed.
	var out [16]byte
	for _, c := range s {
		idx := strings.IndexRune(base36Alphabet, c)
		if idx < 0 {
			return [16]byte{}, fmt.Errorf("invalid base36 character: %c", c)
		}
		carry := uint32(idx)
		for i := len(out) - 1; i >= 0; i-- {
			v := uint32(out[i])*radix + carry
			out[i] = byte(v)
			carry = v >> 8
		}
		// A carry out of the most significant byte means the value no longer
		// fits in 128 bits.
		if carry != 0 {
			return [16]byte{}, fmt.Errorf("base36 ID %q overflows 128 bits", s)
		}
	}
	return out, nil
}
// formatUUID builds the string form of an ID: prefix followed by the base36
// encoding of the UUID bytes. The Valid flag of id is not consulted.
func formatUUID(prefix string, id pgtype.UUID) string {
	encoded := UUIDToBase36(id.Bytes)
	return prefix + encoded
}
// The formatters below render an ID as its entity prefix followed by the
// base36 form of the UUID (see formatUUID).

// FormatSandboxID renders a sandbox ID with the "cl-" prefix.
func FormatSandboxID(id pgtype.UUID) string {
	return formatUUID(PrefixSandbox, id)
}

// FormatUserID renders a user ID with the "usr-" prefix.
func FormatUserID(id pgtype.UUID) string {
	return formatUUID(PrefixUser, id)
}

// FormatTeamID renders a team ID with the "team-" prefix.
func FormatTeamID(id pgtype.UUID) string {
	return formatUUID(PrefixTeam, id)
}

// FormatAPIKeyID renders an API key ID with the "key-" prefix.
func FormatAPIKeyID(id pgtype.UUID) string {
	return formatUUID(PrefixAPIKey, id)
}

// FormatHostID renders a host ID with the "host-" prefix.
func FormatHostID(id pgtype.UUID) string {
	return formatUUID(PrefixHost, id)
}

// FormatHostTokenID renders a host token ID with the "htok-" prefix.
func FormatHostTokenID(id pgtype.UUID) string {
	return formatUUID(PrefixHostToken, id)
}

// FormatRefreshTokenID renders a refresh token ID with the "hrt-" prefix.
func FormatRefreshTokenID(id pgtype.UUID) string {
	return formatUUID(PrefixRefreshToken, id)
}

// FormatAuditLogID renders an audit log ID with the "log-" prefix.
func FormatAuditLogID(id pgtype.UUID) string {
	return formatUUID(PrefixAuditLog, id)
}

// FormatBuildID renders a build ID with the "bld-" prefix.
func FormatBuildID(id pgtype.UUID) string {
	return formatUUID(PrefixBuild, id)
}

// FormatChannelID renders a channel ID with the "ch-" prefix.
func FormatChannelID(id pgtype.UUID) string {
	return formatUUID(PrefixChannel, id)
}
// --- Parsing (prefixed string from API/RPC input → pgtype.UUID) ---
// parseUUID is the inverse of formatUUID: it checks that s starts with prefix,
// strips it, and decodes the remainder as a base36 UUID. The returned UUID is
// marked valid. An error is returned when the prefix is missing or the
// remainder is not a well-formed base36 ID.
func parseUUID(prefix, s string) (pgtype.UUID, error) {
	encoded, hasPrefix := strings.CutPrefix(s, prefix)
	if !hasPrefix {
		return pgtype.UUID{}, fmt.Errorf("invalid ID: expected %q prefix, got %q", prefix, s)
	}

	raw, err := base36ToUUID(encoded)
	if err != nil {
		return pgtype.UUID{}, fmt.Errorf("invalid ID %q: %w", s, err)
	}

	var parsed pgtype.UUID
	parsed.Bytes = raw
	parsed.Valid = true
	return parsed, nil
}
// The parsers below accept the prefixed string form of an ID and return the
// underlying UUID (see parseUUID for the error conditions).

// ParseSandboxID parses a "cl-"-prefixed sandbox ID.
func ParseSandboxID(s string) (pgtype.UUID, error) {
	return parseUUID(PrefixSandbox, s)
}

// ParseUserID parses a "usr-"-prefixed user ID.
func ParseUserID(s string) (pgtype.UUID, error) {
	return parseUUID(PrefixUser, s)
}

// ParseTeamID parses a "team-"-prefixed team ID.
func ParseTeamID(s string) (pgtype.UUID, error) {
	return parseUUID(PrefixTeam, s)
}

// ParseAPIKeyID parses a "key-"-prefixed API key ID.
func ParseAPIKeyID(s string) (pgtype.UUID, error) {
	return parseUUID(PrefixAPIKey, s)
}

// ParseHostID parses a "host-"-prefixed host ID.
func ParseHostID(s string) (pgtype.UUID, error) {
	return parseUUID(PrefixHost, s)
}

// ParseHostTokenID parses an "htok-"-prefixed host token ID.
func ParseHostTokenID(s string) (pgtype.UUID, error) {
	return parseUUID(PrefixHostToken, s)
}

// ParseAuditLogID parses a "log-"-prefixed audit log ID.
func ParseAuditLogID(s string) (pgtype.UUID, error) {
	return parseUUID(PrefixAuditLog, s)
}

// ParseBuildID parses a "bld-"-prefixed build ID.
func ParseBuildID(s string) (pgtype.UUID, error) {
	return parseUUID(PrefixBuild, s)
}

// ParseChannelID parses a "ch-"-prefixed channel ID.
func ParseChannelID(s string) (pgtype.UUID, error) {
	return parseUUID(PrefixChannel, s)
}
// --- Well-known IDs ---
// PlatformTeamID is the all-zeros UUID reserved for platform-owned resources
// (e.g. base templates, shared infrastructure). It has the same byte value as
// MinimalTemplateID; the two differ only in which field they are compared to.
var PlatformTeamID = pgtype.UUID{Bytes: [16]byte{}, Valid: true}
// MinimalTemplateID is the all-zeros UUID sentinel for the built-in "minimal"
// template. When both team_id and template_id are zero, the host agent uses
// the minimal rootfs at WRENN_DIR/images/minimal/.
var MinimalTemplateID = pgtype.UUID{Bytes: [16]byte{}, Valid: true}
// UUIDString renders a pgtype.UUID in the standard hyphenated UUID notation
// (e.g., "6ba7b810-9dad-11d1-80b4-00c04fd430c8"). Used for RPC wire format.
// Only the Bytes field is read; the Valid flag is ignored.
func UUIDString(id pgtype.UUID) string {
	std := uuid.UUID(id.Bytes)
	return std.String()
}
// --- Helpers ---
func hex8() string {
b := make([]byte, 4)
if _, err := rand.Read(b); err != nil {
@ -14,44 +177,8 @@ func hex8() string {
return hex.EncodeToString(b)
}
// NewSandboxID generates a new sandbox ID in the format "sb-" + 8 hex chars.
func NewSandboxID() string {
return "sb-" + hex8()
}
// NewSnapshotName generates a snapshot name in the format "template-" + 8 hex chars.
func NewSnapshotName() string {
return "template-" + hex8()
}
// NewUserID generates a new user ID in the format "usr-" + 8 hex chars.
func NewUserID() string {
return "usr-" + hex8()
}
// NewTeamID generates a new team ID in the format "team-" + 8 hex chars.
func NewTeamID() string {
return "team-" + hex8()
}
// NewAPIKeyID generates a new API key ID in the format "key-" + 8 hex chars.
func NewAPIKeyID() string {
return "key-" + hex8()
}
// NewHostID generates a new host ID in the format "host-" + 8 hex chars.
func NewHostID() string {
return "host-" + hex8()
}
// NewHostTokenID generates a new host token audit ID in the format "htok-" + 8 hex chars.
func NewHostTokenID() string {
return "htok-" + hex8()
}
// NewRegistrationToken generates a 64-char hex token (32 bytes of entropy).
func NewRegistrationToken() string {
b := make([]byte, 32)
func hexToken(nBytes int) string {
b := make([]byte, nBytes)
if _, err := rand.Read(b); err != nil {
panic(fmt.Sprintf("crypto/rand failed: %v", err))
}

118
internal/id/id_test.go Normal file
View File

@ -0,0 +1,118 @@
package id
import (
"testing"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgtype"
)
// TestBase36RoundTrip encodes 1000 random UUIDs and checks that each encoding
// has the fixed width and decodes back to the original bytes.
func TestBase36RoundTrip(t *testing.T) {
	const iterations = 1000
	for n := 0; n < iterations; n++ {
		want := uuid.New()

		text := UUIDToBase36(want)
		if width := len(text); width != base36IDLen {
			t.Fatalf("expected %d chars, got %d: %s", base36IDLen, width, text)
		}

		got, err := base36ToUUID(text)
		if err != nil {
			t.Fatalf("decode failed: %v", err)
		}
		if got != want {
			t.Fatalf("round-trip failed: %v → %s → %v", want, text, got)
		}
	}
}
// TestBase36ZeroUUID checks that the all-zeros UUID encodes to 25 '0' chars
// and that this string decodes back to the all-zeros UUID.
func TestBase36ZeroUUID(t *testing.T) {
	nilUUID := [16]byte{}

	text := UUIDToBase36(nilUUID)
	if text != "0000000000000000000000000" {
		t.Fatalf("zero UUID should encode to all zeros, got %s", text)
	}

	back, err := base36ToUUID(text)
	if err != nil {
		t.Fatalf("decode failed: %v", err)
	}
	if back != nilUUID {
		t.Fatalf("round-trip failed for zero UUID")
	}
}
// TestFormatParseRoundTrip formats a new sandbox ID, checks the prefix and
// total length of the string, and verifies that parsing it yields the same ID.
func TestFormatParseRoundTrip(t *testing.T) {
	const prefixLen = 3

	original := NewSandboxID()
	text := FormatSandboxID(original)

	if head := text[:prefixLen]; head != "cl-" {
		t.Fatalf("expected cl- prefix, got %s", text)
	}
	if wantLen := prefixLen + base36IDLen; len(text) != wantLen {
		t.Fatalf("expected %d chars total, got %d: %s", wantLen, len(text), text)
	}

	roundTripped, err := ParseSandboxID(text)
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}
	if roundTripped != original {
		t.Fatalf("round-trip failed: %v → %s → %v", original, text, roundTripped)
	}
}
// TestBase36InvalidInput checks that malformed input is rejected with an
// error: first a string of the wrong length, then one with a character
// outside the base36 alphabet.
func TestBase36InvalidInput(t *testing.T) {
	rejected := []struct {
		input   string
		failMsg string
	}{
		{input: "abc", failMsg: "expected error for short input"},
		{input: "000000000000000000000000!", failMsg: "expected error for invalid character"},
	}
	for _, tc := range rejected {
		_, err := base36ToUUID(tc.input)
		if err == nil {
			t.Fatal(tc.failMsg)
		}
	}
}
// TestPlatformTeamIDFormats checks that the all-zeros platform team ID
// survives a format/parse round trip.
func TestPlatformTeamIDFormats(t *testing.T) {
	text := FormatTeamID(PlatformTeamID)

	back, err := ParseTeamID(text)
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}
	if back != PlatformTeamID {
		t.Fatalf("platform team ID round-trip failed")
	}
}
// TestMaxUUID checks the largest possible UUID (every byte 0xff): its encoding
// must have the fixed width and decode back to the same bytes.
func TestMaxUUID(t *testing.T) {
	var allOnes [16]byte
	for i := range allOnes {
		allOnes[i] = 0xff
	}

	text := UUIDToBase36(allOnes)
	if len(text) != base36IDLen {
		t.Fatalf("max UUID encoding wrong length: %d", len(text))
	}

	back, err := base36ToUUID(text)
	if err != nil {
		t.Fatalf("decode failed: %v", err)
	}
	if back != allOnes {
		t.Fatalf("round-trip failed for max UUID")
	}
}
// BenchmarkFormatSandboxID measures formatting one fixed random sandbox ID.
// ID generation happens before the timer is reset.
func BenchmarkFormatSandboxID(b *testing.B) {
	var sample pgtype.UUID
	sample.Bytes = uuid.New()
	sample.Valid = true

	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		FormatSandboxID(sample)
	}
}
func BenchmarkParseSandboxID(b *testing.B) {
id := pgtype.UUID{Bytes: uuid.New(), Valid: true}
s := FormatSandboxID(id)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = ParseSandboxID(s)
}
}

58
internal/layout/layout.go Normal file
View File

@ -0,0 +1,58 @@
package layout
import (
"path/filepath"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/id"
)
// IsMinimal reports whether the pair (teamID, templateID) identifies the
// built-in "minimal" template, i.e. the team is the platform team and the
// template is the minimal template (both all-zeros). Only the UUID bytes are
// compared; the Valid flags are ignored.
func IsMinimal(teamID, templateID pgtype.UUID) bool {
	if teamID.Bytes != id.PlatformTeamID.Bytes {
		return false
	}
	return templateID.Bytes == id.MinimalTemplateID.Bytes
}
// TemplateDir returns the on-disk directory that holds a template.
//
//	minimal (zeros, zeros): {wrennDir}/images/minimal
//	all others:             {wrennDir}/images/teams/{base36(teamID)}/{base36(templateID)}
func TemplateDir(wrennDir string, teamID, templateID pgtype.UUID) string {
	switch {
	case IsMinimal(teamID, templateID):
		return filepath.Join(wrennDir, "images", "minimal")
	default:
		teamPart := id.UUIDToBase36(teamID.Bytes)
		templatePart := id.UUIDToBase36(templateID.Bytes)
		return filepath.Join(wrennDir, "images", "teams", teamPart, templatePart)
	}
}
// TemplateRootfs returns the path of the rootfs.ext4 image inside the
// template's directory (see TemplateDir).
func TemplateRootfs(wrennDir string, teamID, templateID pgtype.UUID) string {
	dir := TemplateDir(wrennDir, teamID, templateID)
	return filepath.Join(dir, "rootfs.ext4")
}
// PauseSnapshotDir returns the directory that holds the snapshot files of a
// paused sandbox: {wrennDir}/snapshots/{sandboxID}.
func PauseSnapshotDir(wrennDir, sandboxID string) string {
	const sub = "snapshots"
	return filepath.Join(wrennDir, sub, sandboxID)
}
// SandboxesDir returns the directory that holds the CoW files of running
// sandboxes: {wrennDir}/sandboxes.
func SandboxesDir(wrennDir string) string {
	const sub = "sandboxes"
	return filepath.Join(wrennDir, sub)
}
// KernelPath returns the location of the Firecracker kernel image:
// {wrennDir}/kernels/vmlinux.
func KernelPath(wrennDir string) string {
	const (
		dir  = "kernels"
		file = "vmlinux"
	)
	return filepath.Join(wrennDir, dir, file)
}
// ImagesRoot returns the root directory of all images: {wrennDir}/images.
func ImagesRoot(wrennDir string) string {
	const sub = "images"
	return filepath.Join(wrennDir, sub)
}
// TeamsDir returns the directory under which every team's template
// subdirectories live: {wrennDir}/images/teams.
func TeamsDir(wrennDir string) string {
	const (
		images = "images"
		teams  = "teams"
	)
	return filepath.Join(wrennDir, images, teams)
}

View File

@ -0,0 +1,120 @@
package layout
import (
"path/filepath"
"testing"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/id"
)
// TestIsMinimal checks that IsMinimal is true only when both the team ID and
// the template ID are all-zeros.
func TestIsMinimal(t *testing.T) {
	// lastByteOne is a valid UUID whose only non-zero byte is the final one.
	lastByteOne := pgtype.UUID{Bytes: [16]byte{15: 1}, Valid: true}

	cases := []struct {
		name       string
		teamID     pgtype.UUID
		templateID pgtype.UUID
		want       bool
	}{
		{"both zeros", id.PlatformTeamID, id.MinimalTemplateID, true},
		{"non-zero team", lastByteOne, id.MinimalTemplateID, false},
		{"non-zero template", id.PlatformTeamID, lastByteOne, false},
		{
			"both non-zero",
			pgtype.UUID{Bytes: [16]byte{0: 1}, Valid: true},
			pgtype.UUID{Bytes: [16]byte{0: 2}, Valid: true},
			false,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := IsMinimal(tc.teamID, tc.templateID)
			if got != tc.want {
				t.Errorf("IsMinimal() = %v, want %v", got, tc.want)
			}
		})
	}
}
func TestTemplateDir(t *testing.T) {
wrennDir := "/var/lib/wrenn"
t.Run("minimal", func(t *testing.T) {
got := TemplateDir(wrennDir, id.PlatformTeamID, id.MinimalTemplateID)
want := filepath.Join(wrennDir, "images", "minimal")
if got != want {
t.Errorf("TemplateDir() = %q, want %q", got, want)
}
})
t.Run("team template", func(t *testing.T) {
teamID := pgtype.UUID{Bytes: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, Valid: true}
tmplID := pgtype.UUID{Bytes: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}, Valid: true}
got := TemplateDir(wrennDir, teamID, tmplID)
want := filepath.Join(wrennDir, "images", "teams",
id.UUIDToBase36(teamID.Bytes),
id.UUIDToBase36(tmplID.Bytes))
if got != want {
t.Errorf("TemplateDir() = %q, want %q", got, want)
}
})
t.Run("global template (platform team, non-zero template)", func(t *testing.T) {
tmplID := pgtype.UUID{Bytes: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5}, Valid: true}
got := TemplateDir(wrennDir, id.PlatformTeamID, tmplID)
want := filepath.Join(wrennDir, "images", "teams",
id.UUIDToBase36(id.PlatformTeamID.Bytes),
id.UUIDToBase36(tmplID.Bytes))
if got != want {
t.Errorf("TemplateDir() = %q, want %q", got, want)
}
})
}
// TestTemplateRootfs checks the rootfs path of the minimal template.
func TestTemplateRootfs(t *testing.T) {
	const wrennDir = "/var/lib/wrenn"

	expected := filepath.Join(wrennDir, "images", "minimal", "rootfs.ext4")
	actual := TemplateRootfs(wrennDir, id.PlatformTeamID, id.MinimalTemplateID)
	if actual != expected {
		t.Errorf("TemplateRootfs() = %q, want %q", actual, expected)
	}
}
// TestPauseSnapshotDir checks the snapshot directory computed for a sandbox.
func TestPauseSnapshotDir(t *testing.T) {
	const expected = "/var/lib/wrenn/snapshots/cl-abc123"

	if actual := PauseSnapshotDir("/var/lib/wrenn", "cl-abc123"); actual != expected {
		t.Errorf("PauseSnapshotDir() = %q, want %q", actual, expected)
	}
}
// TestSandboxesDir checks the directory used for running sandboxes.
func TestSandboxesDir(t *testing.T) {
	const expected = "/var/lib/wrenn/sandboxes"

	if actual := SandboxesDir("/var/lib/wrenn"); actual != expected {
		t.Errorf("SandboxesDir() = %q, want %q", actual, expected)
	}
}
// TestKernelPath checks the location of the kernel image.
func TestKernelPath(t *testing.T) {
	const expected = "/var/lib/wrenn/kernels/vmlinux"

	if actual := KernelPath("/var/lib/wrenn"); actual != expected {
		t.Errorf("KernelPath() = %q, want %q", actual, expected)
	}
}

View File

@ -0,0 +1,125 @@
package lifecycle
import (
"crypto/tls"
"fmt"
"net/http"
"strings"
"sync"
"time"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
)
// HostClientPool maintains a cache of Connect RPC clients keyed by host ID.
// Clients are created lazily on first access and evicted when a host is removed
// or goes unreachable. The pool is safe for concurrent use.
type HostClientPool struct {
// mu guards clients: Get reads under the read lock and inserts under the
// write lock; Evict deletes under the write lock.
mu sync.RWMutex
// clients maps a host ID string to its cached client. The address is only
// used when the entry is created, so a changed address requires Evict.
clients map[string]hostagentv1connect.HostAgentServiceClient
// httpClient is the single HTTP client shared by every RPC client the pool
// creates.
httpClient *http.Client
// scheme is prepended to addresses that do not already carry one.
scheme string // "http://" or "https://"
}
// NewHostClientPool returns an empty pool whose clients talk to agents over
// plain HTTP, using an HTTP client with a 10-minute timeout.
// Use NewHostClientPoolTLS when mTLS is required.
func NewHostClientPool() *HostClientPool {
	pool := new(HostClientPool)
	pool.clients = make(map[string]hostagentv1connect.HostAgentServiceClient)
	pool.httpClient = &http.Client{Timeout: 10 * time.Minute}
	pool.scheme = "http://"
	return pool
}
// NewHostClientPoolTLS returns an empty pool whose clients talk to agents over
// mTLS. The HTTP client has a 10-minute timeout and a dedicated transport
// configured with tlsCfg.
//
// tlsCfg should already carry the CP client cert and CA trust anchor
// (use auth.CPClientTLSConfig to construct it).
func NewHostClientPoolTLS(tlsCfg *tls.Config) *HostClientPool {
	secureClient := &http.Client{
		Timeout:   10 * time.Minute,
		Transport: &http.Transport{TLSClientConfig: tlsCfg},
	}
	return &HostClientPool{
		clients:    make(map[string]hostagentv1connect.HostAgentServiceClient),
		httpClient: secureClient,
		scheme:     "https://",
	}
}
// Get returns the Connect RPC client for hostID, creating and caching one on
// first use.
//
// address is the host agent address (ip:port or full URL); the pool's scheme
// is added if it has none. address is only consulted when no client is cached
// for hostID yet.
func (p *HostClientPool) Get(hostID, address string) hostagentv1connect.HostAgentServiceClient {
	// Common case: the client already exists, so a read lock is enough.
	p.mu.RLock()
	cached, found := p.clients[hostID]
	p.mu.RUnlock()
	if found {
		return cached
	}

	p.mu.Lock()
	defer p.mu.Unlock()

	// Another goroutine may have created the client between the two locks.
	if existing, raced := p.clients[hostID]; raced {
		return existing
	}

	created := hostagentv1connect.NewHostAgentServiceClient(p.httpClient, p.ensureScheme(address))
	p.clients[hostID] = created
	return created
}
// GetForHost returns the client for a db.Host record, keyed by the host's
// formatted ID and using the host's recorded address. It returns an error when
// the host has no address recorded yet.
func (p *HostClientPool) GetForHost(h db.Host) (hostagentv1connect.HostAgentServiceClient, error) {
	hostID := id.FormatHostID(h.ID)
	if h.Address == "" {
		return nil, fmt.Errorf("host %s has no address", hostID)
	}
	return p.Get(hostID, h.Address), nil
}
// Evict drops the cached client for hostID, if any, so that the next Get
// builds a new one. Call this when a host's address changes or when a host is
// deleted.
func (p *HostClientPool) Evict(hostID string) {
	p.mu.Lock()
	defer p.mu.Unlock()
	delete(p.clients, hostID)
}
// ensureScheme returns addr unchanged when it already starts with "http://" or
// "https://" (case-sensitive); otherwise it prepends the pool's configured
// scheme.
func (p *HostClientPool) ensureScheme(addr string) string {
	for _, known := range [...]string{"http://", "https://"} {
		if strings.HasPrefix(addr, known) {
			return addr
		}
	}
	return p.scheme + addr
}
// Transport returns the http.RoundTripper behind this pool's HTTP client,
// falling back to http.DefaultTransport when the client has none set. Use it
// for raw HTTP requests to agent addresses that must share the TLS settings of
// the pool's Connect RPC clients (e.g., the sandbox reverse proxy).
func (p *HostClientPool) Transport() http.RoundTripper {
	if rt := p.httpClient.Transport; rt != nil {
		return rt
	}
	return http.DefaultTransport
}
// ResolveAddr returns addr with the pool's configured scheme prepended when it
// has none. It exposes the same normalization Get/GetForHost apply internally,
// for callers that only need a URL matching the pool's transport (e.g., the
// sandbox proxy handler).
func (p *HostClientPool) ResolveAddr(addr string) string {
	resolved := p.ensureScheme(addr)
	return resolved
}
// EnsureScheme returns addr unchanged when it already starts with "http://" or
// "https://" (case-sensitive); otherwise it prepends "http://".
//
// Deprecated: use pool.ResolveAddr which respects the pool's TLS setting.
func EnsureScheme(addr string) string {
	switch {
	case strings.HasPrefix(addr, "http://"), strings.HasPrefix(addr, "https://"):
		return addr
	default:
		return "http://" + addr
	}
}

View File

@ -18,15 +18,16 @@ const (
// Sandbox holds all state for a running sandbox on this host.
type Sandbox struct {
ID string
Status SandboxStatus
Template string
VCPUs int
MemoryMB int
TimeoutSec int
SlotIndex int
HostIP net.IP
RootfsPath string
CreatedAt time.Time
LastActiveAt time.Time
ID string
Status SandboxStatus
TemplateTeamID [16]byte
TemplateID [16]byte
VCPUs int
MemoryMB int
TimeoutSec int
SlotIndex int
HostIP net.IP
RootfsPath string
CreatedAt time.Time
LastActiveAt time.Time
}

View File

@ -5,13 +5,91 @@ import (
"fmt"
"log/slog"
"net"
"os"
"os/exec"
"runtime"
"strings"
"github.com/vishvananda/netlink"
"github.com/vishvananda/netns"
)
const nsPrefix = "wrenn-ns-"
// CleanupStaleNamespaces removes leftover wrenn network namespaces from a
// previous crash. Called once at agent startup.
//
// Every entry in /run/netns whose name starts with nsPrefix is deleted, after
// a best-effort removal of the matching host-side "wrenn-veth-<suffix>" link.
// Finally cleanupStaleIptablesRules is run. If /run/netns cannot be read the
// function returns without doing anything.
func CleanupStaleNamespaces() {
	entries, err := os.ReadDir("/run/netns")
	if err != nil {
		return // no /run/netns or unreadable — nothing to clean
	}

	for _, entry := range entries {
		nsName := entry.Name()
		suffix, isWrenn := strings.CutPrefix(nsName, nsPrefix)
		if !isWrenn {
			continue
		}

		// Also remove the associated veth from the host side.
		if link, linkErr := netlink.LinkByName("wrenn-veth-" + suffix); linkErr == nil {
			_ = netlink.LinkDel(link)
		}

		if delErr := netns.DeleteNamed(nsName); delErr != nil {
			slog.Warn("failed to remove stale namespace", "ns", nsName, "error", delErr)
			continue
		}
		slog.Info("removed stale namespace", "ns", nsName)
	}

	// Clean up any stale wrenn iptables rules referencing old veth interfaces.
	cleanupStaleIptablesRules()
}
// cleanupStaleIptablesRules removes host firewall rules and routes that refer
// to wrenn-veth interfaces.
//
// For the "filter" and "nat" tables it reads `iptables-save -t <table>` and,
// for every "-A ..." line that mentions "wrenn-veth-", runs the matching
// `iptables -t <table> -D ...` command. It does not check whether the
// interface still exists. Rule lines are split on whitespace, so quoted
// arguments containing spaces are not preserved. Afterwards it deletes every
// IPv4 route whose link name starts with "wrenn-veth-". All failures are
// ignored or logged at debug level; cleanup is best-effort.
func cleanupStaleIptablesRules() {
	for _, table := range []string{"filter", "nat"} {
		dump, err := exec.Command("iptables-save", "-t", table).Output()
		if err != nil {
			continue
		}
		for _, line := range strings.Split(string(dump), "\n") {
			if !strings.Contains(line, "wrenn-veth-") {
				continue
			}
			// Lines look like "-A FORWARD -i wrenn-veth-1 -o wlo1 -j ACCEPT"
			// Convert -A to -D to delete the rule.
			spec, isAppend := strings.CutPrefix(line, "-A ")
			if !isAppend {
				continue
			}
			args := append([]string{"-t", table, "-D"}, strings.Fields(spec)...)
			if runErr := exec.Command("iptables", args...).Run(); runErr != nil {
				slog.Debug("failed to remove stale iptables rule", "rule", line, "error", runErr)
			}
		}
	}

	// Also remove stale host routes to 10.11.0.x via wrenn-veth interfaces.
	routes, err := netlink.RouteList(nil, netlink.FAMILY_V4)
	if err != nil {
		return
	}
	for i := range routes {
		route := &routes[i]
		if route.LinkIndex == 0 {
			continue
		}
		link, linkErr := netlink.LinkByIndex(route.LinkIndex)
		if linkErr != nil {
			continue
		}
		if strings.HasPrefix(link.Attrs().Name, "wrenn-veth-") {
			_ = netlink.RouteDel(route)
		}
	}
}
const (
// Fixed addresses inside each network namespace (safe because each
// sandbox gets its own netns).
@ -59,19 +137,20 @@ func NewSlot(index int) *Slot {
hostIP := make(net.IP, 4)
copy(hostIP, hostBaseIP)
hostIP[2] += byte(index / 256)
hostIP[3] += byte(index % 256)
hostIP[2] += byte(index >> 8)
hostIP[3] += byte(index & 0xFF)
vethOffset := index * vrtAddressesPerSlot
vethIP := make(net.IP, 4)
copy(vethIP, vrtBaseIP)
vethIP[2] += byte(vethOffset / 256)
vethIP[3] += byte(vethOffset % 256)
vethIP[2] += byte(vethOffset >> 8)
vethIP[3] += byte(vethOffset & 0xFF)
vpeerOffset := vethOffset + 1
vpeerIP := make(net.IP, 4)
copy(vpeerIP, vrtBaseIP)
vpeerIP[2] += byte((vethOffset + 1) / 256)
vpeerIP[3] += byte((vethOffset + 1) % 256)
vpeerIP[2] += byte(vpeerOffset >> 8)
vpeerIP[3] += byte(vpeerOffset & 0xFF)
return &Slot{
Index: index,
@ -84,8 +163,8 @@ func NewSlot(index int) *Slot {
GuestIP: guestIP,
GuestNetMask: guestNetMask,
TapName: tapName,
NamespaceID: fmt.Sprintf("ns-%d", index),
VethName: fmt.Sprintf("veth-%d", index),
NamespaceID: fmt.Sprintf("wrenn-ns-%d", index),
VethName: fmt.Sprintf("wrenn-veth-%d", index),
}
}

104
internal/recipe/context.go Normal file
View File

@ -0,0 +1,104 @@
package recipe
import (
"regexp"
"slices"
"strings"
)
// ExecContext holds mutable state that persists across recipe steps.
// It is initialized empty and updated by ENV and WORKDIR steps.
type ExecContext struct {
	// WorkDir is the directory that wrapped commands `cd` into before
	// running. An empty string means no `cd` is emitted.
	WorkDir string
	// EnvVars maps variable names to values. Each entry is emitted as a
	// KEY='value' assignment in front of the command, in sorted key order.
	// A nil map behaves like an empty one when building commands.
	EnvVars map[string]string
}
// envRegex recognizes the placeholder forms handled by expandEnv:
//  1. $$ (escaped dollar)
//  2. ${VAR} or ${} (braced variable, possibly empty)
//  3. $VAR (bare variable)
//
// Variable names consist of ASCII letters, digits, and underscores. The
// braced name is capture group 1 and the bare name is capture group 2.
var envRegex = regexp.MustCompile(`\$\$|\$\{([a-zA-Z0-9_]*)\}|\$([a-zA-Z0-9_]+)`)
// WrappedCommand builds the shell command line for a RUN step with the
// accumulated WORKDIR and ENV state applied. The returned string is meant to
// be handed to /bin/sh -c.
//
// When no context is set, cmd is returned untouched. Otherwise the command
// is re-wrapped in its own quoted shell behind the context preamble:
//
//	cd '/the/dir' && KEY='val' /bin/sh -c 'original command'
func (c *ExecContext) WrappedCommand(cmd string) string {
	if preamble := c.shellPrefix(); preamble != "" {
		return preamble + "/bin/sh -c " + shellescape(cmd)
	}
	return cmd
}
// StartCommand builds the shell command line for a START step. The command
// is launched through nohup and backgrounded with a trailing "&", so the
// outer shell returns right away and the build can move on. The background
// process's stdout and stderr are redirected to /dev/null.
//
// Several START steps may be issued back to back to bring up multiple
// background processes before a healthcheck is evaluated.
func (c *ExecContext) StartCommand(cmd string) string {
	const discardAndDetach = " >/dev/null 2>&1 &"

	preamble := c.shellPrefix()
	launch := "nohup /bin/sh -c " + shellescape(cmd)
	return preamble + launch + discardAndDetach
}
// shellPrefix assembles the "cd ... && KEY='val' " preamble that is placed
// in front of a shell command. Environment assignments are emitted in sorted
// key order so the output is deterministic. The result is the empty string
// when neither a working directory nor any variables are set.
func (c *ExecContext) shellPrefix() string {
	var preamble strings.Builder

	if dir := c.WorkDir; dir != "" {
		preamble.WriteString("cd " + shellescape(dir) + " && ")
	}

	names := make([]string, 0, len(c.EnvVars))
	for name := range c.EnvVars {
		names = append(names, name)
	}
	slices.Sort(names)

	for _, name := range names {
		preamble.WriteString(name + "=" + shellescape(c.EnvVars[name]) + " ")
	}
	return preamble.String()
}
// expandEnv substitutes $var and ${var} placeholders in s with the matching
// values from vars. A doubled dollar sign ($$) is an escape that yields a
// single literal $. Placeholders whose name is absent from vars (including
// the empty name in ${}) expand to the empty string.
func expandEnv(s string, vars map[string]string) string {
	substitute := func(token string) string {
		switch {
		case token == "$$":
			return "$"
		case len(token) > 1 && token[1] == '{':
			// Braced form: strip the leading "${" and the trailing "}".
			return vars[token[2:len(token)-1]]
		default:
			// Bare form: strip the leading "$".
			return vars[token[1:]]
		}
	}
	return envRegex.ReplaceAllStringFunc(s, substitute)
}
// shellescape returns s enclosed in single quotes. Each embedded single
// quote is rewritten as '\'' (close quote, escaped quote, reopen quote), the
// POSIX way to carry a literal quote through a single-quoted word.
func shellescape(s string) string {
	const escapedQuote = `'\''`
	pieces := strings.Split(s, "'")
	return "'" + strings.Join(pieces, escapedQuote) + "'"
}

View File

@ -0,0 +1,237 @@
package recipe
import "testing"
// TestExecContext_WrappedCommand pins the exact command string produced for
// RUN steps under different WORKDIR/ENV contexts, including the quoting of
// spaces and embedded single quotes. A case may list several acceptable
// outputs in wantOneOf; otherwise want must match exactly.
func TestExecContext_WrappedCommand(t *testing.T) {
	type testCase struct {
		name      string
		ctx       ExecContext
		cmd       string
		want      string
		wantOneOf []string
	}

	cases := []testCase{
		{name: "no context", ctx: ExecContext{}, cmd: "apt install -y curl", want: "apt install -y curl"},
		{name: "workdir only", ctx: ExecContext{WorkDir: "/app"}, cmd: "npm install", want: "cd '/app' && /bin/sh -c 'npm install'"},
		{name: "env only", ctx: ExecContext{EnvVars: map[string]string{"PORT": "8080"}}, cmd: "node server.js", want: "PORT='8080' /bin/sh -c 'node server.js'"},
		{name: "workdir with space", ctx: ExecContext{WorkDir: "/my project"}, cmd: "make build", want: "cd '/my project' && /bin/sh -c 'make build'"},
		{name: "command with single quotes", ctx: ExecContext{WorkDir: "/app"}, cmd: "echo 'hello'", want: "cd '/app' && /bin/sh -c 'echo '\\''hello'\\'''"},
		{name: "env value with single quotes", ctx: ExecContext{EnvVars: map[string]string{"MSG": "it's fine"}}, cmd: "echo $MSG", want: "MSG='it'\\''s fine' /bin/sh -c 'echo $MSG'"},
		{
			name: "env expansion with pre-expanded PATH",
			ctx:  ExecContext{EnvVars: map[string]string{"PATH": "/usr/bin", "FOO": "/opt/venv/bin:/usr/bin"}},
			cmd:  "make build",
			want: "FOO='/opt/venv/bin:/usr/bin' PATH='/usr/bin' /bin/sh -c 'make build'",
		},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			got := c.ctx.WrappedCommand(c.cmd)

			// Exact-match mode when no alternatives are listed.
			if len(c.wantOneOf) == 0 {
				if got != c.want {
					t.Errorf("WrappedCommand(%q)\n got %q\n want %q", c.cmd, got, c.want)
				}
				return
			}

			// Alternative mode: any listed candidate is acceptable.
			for _, candidate := range c.wantOneOf {
				if got == candidate {
					return
				}
			}
			t.Errorf("WrappedCommand(%q)\n got %q\n want one of %q", c.cmd, got, c.wantOneOf)
		})
	}
}
// TestExecContext_StartCommand pins the exact background-launch command
// string produced for START steps with and without WORKDIR/ENV context.
func TestExecContext_StartCommand(t *testing.T) {
	cases := []struct {
		name string
		ctx  ExecContext
		cmd  string
		want string
	}{
		{name: "no context", ctx: ExecContext{}, cmd: "python3 app.py", want: "nohup /bin/sh -c 'python3 app.py' >/dev/null 2>&1 &"},
		{name: "with workdir", ctx: ExecContext{WorkDir: "/app"}, cmd: "python3 server.py", want: "cd '/app' && nohup /bin/sh -c 'python3 server.py' >/dev/null 2>&1 &"},
		{name: "with env", ctx: ExecContext{EnvVars: map[string]string{"PORT": "9000"}}, cmd: "node index.js", want: "PORT='9000' nohup /bin/sh -c 'node index.js' >/dev/null 2>&1 &"},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			if got := c.ctx.StartCommand(c.cmd); got != c.want {
				t.Errorf("StartCommand(%q)\n got %q\n want %q", c.cmd, got, c.want)
			}
		})
	}
}
// TestExpandEnv covers bare and braced placeholders, unset variables, the
// $$ escape, adjacent placeholders, and malformed inputs that must pass
// through unchanged. Each subtest is named after its input string.
func TestExpandEnv(t *testing.T) {
	path := map[string]string{"PATH": "/usr/bin"}

	cases := []struct {
		s    string
		vars map[string]string
		want string
	}{
		{s: "hello", vars: nil, want: "hello"},
		{s: "$PATH", vars: path, want: "/usr/bin"},
		{s: "${PATH}", vars: path, want: "/usr/bin"},
		{s: "/opt/venv/bin:$PATH", vars: path, want: "/opt/venv/bin:/usr/bin"},
		{s: "${HOME}/code", vars: map[string]string{"HOME": "/root"}, want: "/root/code"},
		{s: "hello $USER", vars: map[string]string{"USER": "admin"}, want: "hello admin"},
		{s: "$UNSET", vars: path, want: ""},
		{s: "${UNSET}", vars: path, want: ""},
		{s: "$$", vars: path, want: "$"},
		{s: "price is $$100", vars: nil, want: "price is $100"},
		{s: "$FOO:$BAR", vars: map[string]string{"FOO": "a", "BAR": "b"}, want: "a:b"},
		{s: "${FOO}_${BAR}", vars: map[string]string{"FOO": "hello", "BAR": "world"}, want: "hello_world"},
		{s: "no vars here", vars: nil, want: "no vars here"},
		{s: "$", vars: nil, want: "$"},
		{s: "${", vars: nil, want: "${"},
		{s: "${}", vars: nil, want: ""},
		{s: "$VAR1$VAR2", vars: map[string]string{"VAR1": "a", "VAR2": "b"}, want: "ab"},
	}

	for _, c := range cases {
		t.Run(c.s, func(t *testing.T) {
			if got := expandEnv(c.s, c.vars); got != c.want {
				t.Errorf("expandEnv(%q, %v)\n got %q\n want %q", c.s, c.vars, got, c.want)
			}
		})
	}
}
// TestShellescape checks single-quote wrapping for plain strings, the empty
// string, and strings containing one or more embedded single quotes.
func TestShellescape(t *testing.T) {
	cases := []struct {
		input string
		want  string
	}{
		{input: "simple", want: "'simple'"},
		{input: "/path/to/dir", want: "'/path/to/dir'"},
		{input: "it's fine", want: "'it'\\''s fine'"},
		{input: "", want: "''"},
		{input: "a'b'c", want: "'a'\\''b'\\''c'"},
	}

	for i := range cases {
		c := cases[i]
		if got := shellescape(c.input); got != c.want {
			t.Errorf("shellescape(%q) = %q, want %q", c.input, got, c.want)
		}
	}
}

185
internal/recipe/executor.go Normal file
View File

@ -0,0 +1,185 @@
package recipe
import (
"context"
"fmt"
"log/slog"
"strings"
"time"
"connectrpc.com/connect"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
)
// DefaultStepTimeout is the fallback timeout for RUN steps that carry no
// explicit --timeout flag.
//
// NOTE(review): Execute does not reference this constant; when it is handed
// a non-positive defaultTimeout it substitutes 10 minutes instead.
// Presumably callers pass DefaultStepTimeout explicitly — confirm at the
// call sites.
const DefaultStepTimeout = 30 * time.Second
// BuildLogEntry is the per-step record stored in template_builds.logs (JSONB).
type BuildLogEntry struct {
	// Step is the step number, counted across phases.
	Step int `json:"step"`
	// Phase is the caller-supplied phase label for this step.
	Phase string `json:"phase"`
	// Cmd is the original instruction text of the step.
	Cmd string `json:"cmd"`
	// Stdout is the captured standard output of a RUN step.
	Stdout string `json:"stdout"`
	// Stderr is the captured standard error of a RUN step, or a description
	// of why the step failed (exec error, unsupported instruction, ...).
	Stderr string `json:"stderr"`
	// Exit is the exit code reported for RUN and START steps.
	Exit int32 `json:"exit"`
	// Ok reports whether the step succeeded.
	Ok bool `json:"ok"`
	// Elapsed is the wall-clock duration of the exec call in milliseconds.
	// It stays zero for steps that do not call the agent (ENV, WORKDIR).
	Elapsed int64 `json:"elapsed_ms"`
}
// ExecFunc is the agent.Exec call signature used by the executor. It matches
// the method on the hostagent Connect RPC client. Taking it as a function
// value lets callers (and tests) supply any implementation with this shape.
type ExecFunc func(ctx context.Context, req *connect.Request[pb.ExecRequest]) (*connect.Response[pb.ExecResponse], error)
// Execute runs steps sequentially against sandboxID using execFn.
//
//   - phase labels the log entries (e.g., "pre-build", "recipe", "post-build").
//   - startStep is the number of the last step already executed by earlier
//     phases (0 for the first phase); the first step run here is numbered
//     startStep+1, so entries are globally numbered across phases.
//   - defaultTimeout applies to RUN steps with no per-step --timeout; a
//     non-positive value falls back to 10 minutes.
//   - bctx is mutated in place as ENV/WORKDIR steps execute, and carries forward
//     into subsequent phases when the caller passes the same pointer. A nil
//     bctx is treated as a fresh, empty context local to this call.
//
// Returns all log entries appended during this call, the number of the last
// step attempted (pass it as startStep for the next phase), and whether all
// steps succeeded. On false the last entry contains failure details; the
// caller is responsible for destroying the sandbox and recording the build
// error. Execution stops at the first failing step.
func Execute(
	ctx context.Context,
	phase string,
	steps []Step,
	sandboxID string,
	startStep int,
	defaultTimeout time.Duration,
	bctx *ExecContext,
	execFn ExecFunc,
) (entries []BuildLogEntry, nextStep int, ok bool) {
	if defaultTimeout <= 0 {
		defaultTimeout = 10 * time.Minute
	}
	// Guard against a nil context so ENV/WORKDIR steps cannot panic.
	if bctx == nil {
		bctx = &ExecContext{}
	}

	step := startStep
	for _, st := range steps {
		step++
		slog.Info("executing build step", "phase", phase, "step", step, "instruction", st.Raw)

		switch st.Kind {
		case KindENV:
			if bctx.EnvVars == nil {
				bctx.EnvVars = make(map[string]string)
			}
			// Expand against the variables defined so far, so later ENV
			// steps can reference earlier ones (e.g. PATH=/x:$PATH).
			bctx.EnvVars[st.Key] = expandEnv(st.Value, bctx.EnvVars)
			entries = append(entries, BuildLogEntry{Step: step, Phase: phase, Cmd: st.Raw, Ok: true})

		case KindWORKDIR:
			bctx.WorkDir = st.Path
			entries = append(entries, BuildLogEntry{Step: step, Phase: phase, Cmd: st.Raw, Ok: true})

		case KindUSER, KindCOPY:
			// Name the instruction from its raw text, without indexing into
			// an empty field list when Raw is blank.
			verb := "instruction"
			if fields := strings.Fields(st.Raw); len(fields) > 0 {
				verb = strings.ToUpper(fields[0])
			}
			entries = append(entries, BuildLogEntry{
				Step:   step,
				Phase:  phase,
				Cmd:    st.Raw,
				Stderr: verb + " is not yet supported",
				Ok:     false,
			})
			return entries, step, false

		case KindSTART:
			entry, succeeded := execStart(ctx, st, sandboxID, phase, step, bctx, execFn)
			entries = append(entries, entry)
			if !succeeded {
				return entries, step, false
			}

		case KindRUN:
			timeout := defaultTimeout
			if st.Timeout > 0 {
				timeout = st.Timeout
			}
			entry, succeeded := execRun(ctx, st, sandboxID, phase, step, timeout, bctx, execFn)
			entries = append(entries, entry)
			if !succeeded {
				return entries, step, false
			}

		default:
			// An unrecognized kind would otherwise consume a step number
			// without leaving any trace in the build log.
			entries = append(entries, BuildLogEntry{
				Step:   step,
				Phase:  phase,
				Cmd:    st.Raw,
				Stderr: fmt.Sprintf("unsupported step kind %d", st.Kind),
				Ok:     false,
			})
			return entries, step, false
		}
	}
	return entries, step, true
}
// execRun executes one RUN step inside the sandbox and waits for it to exit.
//
// The step's shell text is wrapped with the current WORKDIR/ENV context and
// run via /bin/sh -c. timeout bounds the call on the client side through the
// request context and is also forwarded to the agent as whole seconds.
//
// It returns the log entry for the step and whether the step succeeded
// (the RPC returned no error and the command exited with code 0). On an RPC
// error the entry's Stderr carries the error text and no output is recorded.
func execRun(
	ctx context.Context,
	st Step,
	sandboxID, phase string,
	step int,
	timeout time.Duration,
	bctx *ExecContext,
	execFn ExecFunc,
) (BuildLogEntry, bool) {
	execCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	// Round the agent-side timeout UP to whole seconds and clamp it to the
	// int32 range. Plain truncation turned sub-second timeouts (for example
	// --timeout=500ms) into 0. Rounding up is safe because execCtx still
	// enforces the exact deadline on the client side.
	// NOTE(review): assumes the agent does not treat TimeoutSec == 0 as
	// "one-second limit"; confirm the agent-side meaning of 0.
	const maxTimeoutSec = 1<<31 - 1
	secs := timeout / time.Second
	if timeout%time.Second > 0 {
		secs++
	}
	if secs > maxTimeoutSec {
		secs = maxTimeoutSec
	}

	start := time.Now()
	resp, err := execFn(execCtx, connect.NewRequest(&pb.ExecRequest{
		SandboxId:  sandboxID,
		Cmd:        "/bin/sh",
		Args:       []string{"-c", bctx.WrappedCommand(st.Shell)},
		TimeoutSec: int32(secs),
	}))

	entry := BuildLogEntry{
		Step:    step,
		Phase:   phase,
		Cmd:     st.Raw,
		Elapsed: time.Since(start).Milliseconds(),
	}
	if err != nil {
		entry.Stderr = fmt.Sprintf("exec error: %v", err)
		return entry, false
	}

	entry.Stdout = string(resp.Msg.Stdout)
	entry.Stderr = string(resp.Msg.Stderr)
	entry.Exit = resp.Msg.ExitCode
	entry.Ok = resp.Msg.ExitCode == 0
	return entry, entry.Ok
}
// execStart launches one START step in the background inside the sandbox.
//
// The call uses a fixed short timeout: it only needs to cover the shell
// forking the background process and returning. The launched process itself
// is expected to keep running inside the VM after the call completes.
//
// It returns the log entry for the step and whether the launch succeeded
// (the RPC returned no error and the launching shell exited with code 0).
// Output of the launching shell is only recorded on failure, as part of the
// Stderr message.
func execStart(
	ctx context.Context,
	st Step,
	sandboxID, phase string,
	step int,
	bctx *ExecContext,
	execFn ExecFunc,
) (BuildLogEntry, bool) {
	const startTimeout = 10 * time.Second

	execCtx, cancel := context.WithTimeout(ctx, startTimeout)
	defer cancel()

	began := time.Now()
	req := connect.NewRequest(&pb.ExecRequest{
		SandboxId:  sandboxID,
		Cmd:        "/bin/sh",
		Args:       []string{"-c", bctx.StartCommand(st.Shell)},
		TimeoutSec: int32(startTimeout / time.Second),
	})
	resp, err := execFn(execCtx, req)

	entry := BuildLogEntry{
		Step:    step,
		Phase:   phase,
		Cmd:     st.Raw,
		Elapsed: time.Since(began).Milliseconds(),
	}
	if err != nil {
		entry.Stderr = fmt.Sprintf("start error: %v", err)
		return entry, false
	}

	if code := resp.Msg.ExitCode; code != 0 {
		entry.Exit = code
		entry.Stderr = fmt.Sprintf("start failed with exit code %d: %s", code, string(resp.Msg.Stderr))
		return entry, false
	}
	entry.Ok = true
	return entry, true
}

View File

@ -0,0 +1,94 @@
package recipe
import (
"fmt"
"strconv"
"strings"
"time"
)
// HealthcheckConfig holds the parsed configuration for a build healthcheck.
// A healthcheck is a shell command that is executed repeatedly inside the
// sandbox until it succeeds or the retry/timeout budget is exhausted.
//
// Retries of 0 means unlimited retries (bounded only by the overall deadline).
type HealthcheckConfig struct {
	Cmd         string        // shell command text, exactly as written after the flags
	Interval    time.Duration // value of --interval (default 3s)
	Timeout     time.Duration // value of --timeout (default 10s)
	StartPeriod time.Duration // value of --start-period (default 0)
	Retries     int           // 0 = unlimited
}

// ParseHealthcheck parses a healthcheck string with optional flag prefix into
// a HealthcheckConfig. The syntax is:
//
//	[--interval=<duration>] [--timeout=<duration>] [--start-period=<duration>]
//	[--retries=<n>] <command>
//
// Flags must use the form --flag=value. The first token that does not start
// with "--" and everything after it is treated as the command. The command
// is preserved verbatim, so whitespace inside it (for example inside quoted
// arguments) is not altered; only leading and trailing whitespace is
// trimmed. Defaults: interval=3s, timeout=10s, start-period=0,
// retries=0 (unlimited).
//
// An error is returned for an empty input, a flag without '=', an unknown
// flag, an unparsable or negative flag value, or a missing command.
func ParseHealthcheck(s string) (HealthcheckConfig, error) {
	rest := strings.TrimSpace(s)
	if rest == "" {
		return HealthcheckConfig{}, fmt.Errorf("empty healthcheck")
	}

	hc := HealthcheckConfig{
		Interval: 3 * time.Second,
		Timeout:  10 * time.Second,
	}

	// Consume leading --flag=value tokens one at a time. What remains once
	// the next token is not a flag is the command, kept byte-for-byte.
	for strings.HasPrefix(rest, "--") {
		// rest has no leading whitespace, so its first field is a prefix.
		token := strings.Fields(rest)[0]
		rest = strings.TrimSpace(rest[len(token):])

		key, val, found := strings.Cut(token, "=")
		if !found {
			return HealthcheckConfig{}, fmt.Errorf("malformed flag (missing '='): %q", token)
		}

		var err error
		switch key {
		case "--interval":
			hc.Interval, err = parseHealthcheckDuration("interval", val)
		case "--timeout":
			hc.Timeout, err = parseHealthcheckDuration("timeout", val)
		case "--start-period":
			hc.StartPeriod, err = parseHealthcheckDuration("start period", val)
		case "--retries":
			hc.Retries, err = parseHealthcheckRetries(val)
		default:
			err = fmt.Errorf("unknown healthcheck flag: %q", token)
		}
		if err != nil {
			return HealthcheckConfig{}, err
		}
	}

	if rest == "" {
		return HealthcheckConfig{}, fmt.Errorf("healthcheck has no command")
	}
	hc.Cmd = rest
	return hc, nil
}

// parseHealthcheckDuration parses a non-negative duration flag value; label
// names the flag in error messages.
func parseHealthcheckDuration(label, val string) (time.Duration, error) {
	d, err := time.ParseDuration(val)
	if err != nil {
		return 0, fmt.Errorf("parse %s: %w", label, err)
	}
	if d < 0 {
		return 0, fmt.Errorf("parse %s: duration must not be negative: %q", label, val)
	}
	return d, nil
}

// parseHealthcheckRetries parses the non-negative --retries flag value.
func parseHealthcheckRetries(val string) (int, error) {
	n, err := strconv.Atoi(val)
	if err != nil {
		return 0, fmt.Errorf("parse retries: %w", err)
	}
	if n < 0 {
		return 0, fmt.Errorf("parse retries: must not be negative: %d", n)
	}
	return n, nil
}

View File

@ -0,0 +1,126 @@
package recipe
import (
"testing"
"time"
)
// TestParseHealthcheck exercises flag parsing, defaults, command extraction,
// and the error paths of ParseHealthcheck. For successful cases every field
// of the parsed config is compared individually so a failure names the
// offending field.
func TestParseHealthcheck(t *testing.T) {
	cases := []struct {
		name    string
		input   string
		want    HealthcheckConfig
		wantErr bool
	}{
		{
			name:  "plain command",
			input: "curl -f http://localhost:8080",
			want:  HealthcheckConfig{Cmd: "curl -f http://localhost:8080", Interval: 3 * time.Second, Timeout: 10 * time.Second},
		},
		{
			name:  "all flags",
			input: "--interval=5s --timeout=2s --start-period=15s --retries=3 ping -c 1 8.8.8.8",
			want:  HealthcheckConfig{Cmd: "ping -c 1 8.8.8.8", Interval: 5 * time.Second, Timeout: 2 * time.Second, StartPeriod: 15 * time.Second, Retries: 3},
		},
		{
			name:  "partial flags",
			input: "--timeout=5s my-custom-check --verbose",
			want:  HealthcheckConfig{Cmd: "my-custom-check --verbose", Interval: 3 * time.Second, Timeout: 5 * time.Second},
		},
		{
			name:  "retries only",
			input: "--retries=5 test.sh",
			want:  HealthcheckConfig{Cmd: "test.sh", Interval: 3 * time.Second, Timeout: 10 * time.Second, Retries: 5},
		},
		{name: "empty string", input: "", wantErr: true},
		{name: "whitespace only", input: " \t \n ", wantErr: true},
		{name: "flags but no command", input: "--interval=5s --retries=2", wantErr: true},
		{name: "unknown flag", input: "--magic=true my-check", wantErr: true},
		{name: "invalid duration", input: "--interval=5smiles check.sh", wantErr: true},
		{name: "invalid retries", input: "--retries=five check.sh", wantErr: true},
		{
			name:  "command with dashes",
			input: "--interval=2s command-with-dash --flag=value",
			want:  HealthcheckConfig{Cmd: "command-with-dash --flag=value", Interval: 2 * time.Second, Timeout: 10 * time.Second},
		},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			got, err := ParseHealthcheck(c.input)
			if (err != nil) != c.wantErr {
				t.Errorf("ParseHealthcheck() error = %v, wantErr %v", err, c.wantErr)
				return
			}
			if c.wantErr {
				return
			}

			// Compare field by field so each mismatch is reported by name.
			fields := []struct {
				label     string
				got, want any
			}{
				{"Cmd", got.Cmd, c.want.Cmd},
				{"Interval", got.Interval, c.want.Interval},
				{"Timeout", got.Timeout, c.want.Timeout},
				{"StartPeriod", got.StartPeriod, c.want.StartPeriod},
				{"Retries", got.Retries, c.want.Retries},
			}
			for _, f := range fields {
				if f.got != f.want {
					t.Errorf("%s got = %v, want %v", f.label, f.got, f.want)
				}
			}
		})
	}
}

129
internal/recipe/step.go Normal file
View File

@ -0,0 +1,129 @@
package recipe
import (
"fmt"
"strings"
"time"
)
// Kind identifies the instruction type in a recipe line.
type Kind int

// Instruction kinds, one per supported recipe keyword. KindRUN is the zero
// value, so a Step whose Kind was never set reads as a RUN step.
const (
	KindRUN     Kind = iota // Execute a command and wait for it to exit.
	KindSTART               // Start a command in the background (non-blocking).
	KindENV                 // Set an environment variable for subsequent steps.
	KindWORKDIR             // Set the working directory for subsequent steps.
	KindUSER                // Switch the unix user for subsequent steps. (stub)
	KindCOPY                // Copy files into the sandbox. (stub)
)
// Step is the parsed representation of one recipe instruction. Only the
// fields relevant to the step's Kind are populated; the rest keep their zero
// values. All fields are comparable, so two Steps can be compared with ==.
type Step struct {
	Kind    Kind
	Raw     string        // original string, preserved for logging
	Shell   string        // KindRUN, KindSTART: the shell command text
	Timeout time.Duration // KindRUN: 0 means use caller's default
	Key     string        // KindENV: variable name
	Value   string        // KindENV: variable value
	Path    string        // KindWORKDIR: directory path
}
// ParseStep parses a single recipe instruction string into a Step.
// Instructions are Dockerfile-like: a keyword followed by arguments. The
// keyword is matched case-insensitively and may be separated from its
// arguments by spaces or tabs.
//
// Supported syntax:
//
//	RUN <cmd>                 — run command, wait for exit
//	RUN --timeout=<d> <cmd>   — run command with explicit timeout (e.g. --timeout=5m)
//	START <cmd>               — start command in background, return immediately
//	ENV <key>=<value>         — set environment variable
//	WORKDIR <path>            — set working directory
//	USER <name>               — not yet supported
//	COPY <src> <dst>          — not yet supported
//
// USER and COPY are accepted by the parser and returned as stub steps that
// carry only Kind and Raw. An empty input or an unknown keyword yields an
// error.
func ParseStep(s string) (Step, error) {
	s = strings.TrimSpace(s)
	if s == "" {
		return Step{}, fmt.Errorf("empty step")
	}

	// Split off the keyword at the first space OR tab. Splitting on a space
	// only made "RUN<TAB>cmd" look like one unknown keyword.
	keyword, rest := s, ""
	if i := strings.IndexAny(s, " \t"); i >= 0 {
		keyword, rest = s[:i], s[i+1:]
	}
	rest = strings.TrimSpace(rest)

	switch strings.ToUpper(keyword) {
	case "RUN":
		return parseRUN(s, rest)
	case "START":
		return parseSTART(s, rest)
	case "ENV":
		return parseENV(s, rest)
	case "WORKDIR":
		return parseWORKDIR(s, rest)
	case "USER":
		return Step{Kind: KindUSER, Raw: s}, nil
	case "COPY":
		return Step{Kind: KindCOPY, Raw: s}, nil
	default:
		return Step{}, fmt.Errorf("unknown instruction %q (expected RUN, START, ENV, WORKDIR, USER, or COPY)", keyword)
	}
}
// ParseRecipe parses every recipe line in order and stops at the first line
// that fails to parse. The returned error is prefixed with the 1-based line
// number. On success the result has one Step per input line; for an empty
// input it is an empty, non-nil slice.
func ParseRecipe(lines []string) ([]Step, error) {
	parsed := make([]Step, len(lines))
	for idx := range lines {
		step, err := ParseStep(lines[idx])
		if err != nil {
			return nil, fmt.Errorf("recipe line %d: %w", idx+1, err)
		}
		parsed[idx] = step
	}
	return parsed, nil
}
// parseRUN builds a KindRUN step from the text following the RUN keyword.
// raw is the full instruction (kept for logging and error messages); rest is
// the trimmed remainder after the keyword.
//
// An optional leading "--timeout=<duration>" flag sets the step timeout; it
// must be followed by a space and a non-empty command. Without the flag the
// timeout stays 0, which means the caller's default applies.
func parseRUN(raw, rest string) (Step, error) {
	var timeout time.Duration

	if flagged, hasFlag := strings.CutPrefix(rest, "--timeout="); hasFlag {
		value, shell, hasCmd := strings.Cut(flagged, " ")
		shell = strings.TrimSpace(shell)
		if !hasCmd || shell == "" {
			return Step{}, fmt.Errorf("RUN --timeout= flag has no command: %q", raw)
		}
		parsed, err := time.ParseDuration(value)
		if err != nil {
			return Step{}, fmt.Errorf("RUN --timeout= invalid duration %q: %w", value, err)
		}
		timeout, rest = parsed, shell
	}

	if rest == "" {
		return Step{}, fmt.Errorf("RUN requires a command: %q", raw)
	}
	return Step{Kind: KindRUN, Raw: raw, Shell: rest, Timeout: timeout}, nil
}
// parseSTART builds a KindSTART step from the text following the START
// keyword. raw is the full instruction; rest is the command to launch and
// must not be empty.
func parseSTART(raw, rest string) (Step, error) {
	if rest != "" {
		return Step{Kind: KindSTART, Raw: raw, Shell: rest}, nil
	}
	return Step{}, fmt.Errorf("START requires a command: %q", raw)
}
// parseENV builds a KindENV step from the text following the ENV keyword.
// raw is the full instruction; rest must have the form KEY=VALUE. The text
// is split at the first '=', so the value may itself contain '=' characters,
// spaces, or be empty.
//
// The key must be a valid shell variable name (ASCII letters, digits, and
// underscores, not starting with a digit). The key is later written unquoted
// in front of the command as KEY='value', so any other key would produce a
// malformed shell command at execution time; rejecting it here reports the
// problem at parse time instead.
func parseENV(raw, rest string) (Step, error) {
	key, value, found := strings.Cut(rest, "=")
	if !found {
		return Step{}, fmt.Errorf("ENV requires KEY=VALUE format: %q", raw)
	}
	if key == "" {
		return Step{}, fmt.Errorf("ENV key is empty: %q", raw)
	}
	if !isValidEnvKey(key) {
		return Step{}, fmt.Errorf("ENV key %q is not a valid variable name: %q", key, raw)
	}
	return Step{Kind: KindENV, Raw: raw, Key: key, Value: value}, nil
}

// isValidEnvKey reports whether key is a non-empty shell variable name:
// [A-Za-z_][A-Za-z0-9_]*.
func isValidEnvKey(key string) bool {
	for i := 0; i < len(key); i++ {
		c := key[i]
		switch {
		case c == '_', 'a' <= c && c <= 'z', 'A' <= c && c <= 'Z':
			// always allowed
		case '0' <= c && c <= '9' && i > 0:
			// digits are allowed anywhere but the first position
		default:
			return false
		}
	}
	return key != ""
}
// parseWORKDIR builds a KindWORKDIR step from the text following the WORKDIR
// keyword. raw is the full instruction; path is taken verbatim (it may
// contain spaces) and must not be empty.
func parseWORKDIR(raw, path string) (Step, error) {
	if path != "" {
		return Step{Kind: KindWORKDIR, Raw: raw, Path: path}, nil
	}
	return Step{}, fmt.Errorf("WORKDIR requires a path: %q", raw)
}

View File

@ -0,0 +1,208 @@
package recipe
import (
"testing"
"time"
)
// TestParseStep checks that every supported instruction parses into the
// expected Step and that malformed instructions are rejected. Successful
// cases compare the whole Step value, so unexpected extra fields fail too.
func TestParseStep(t *testing.T) {
	cases := []struct {
		name    string
		input   string
		want    Step
		wantErr bool
	}{
		// RUN
		{name: "RUN basic", input: "RUN apt install -y curl", want: Step{Kind: KindRUN, Raw: "RUN apt install -y curl", Shell: "apt install -y curl"}},
		{name: "RUN lowercase", input: "run echo hello", want: Step{Kind: KindRUN, Raw: "run echo hello", Shell: "echo hello"}},
		{name: "RUN with timeout", input: "RUN --timeout=5m npm install", want: Step{Kind: KindRUN, Raw: "RUN --timeout=5m npm install", Shell: "npm install", Timeout: 5 * time.Minute}},
		{name: "RUN with timeout seconds", input: "RUN --timeout=30s make build", want: Step{Kind: KindRUN, Raw: "RUN --timeout=30s make build", Shell: "make build", Timeout: 30 * time.Second}},
		{name: "RUN no command", input: "RUN", wantErr: true},
		{name: "RUN timeout no command", input: "RUN --timeout=5m", wantErr: true},
		{name: "RUN invalid timeout", input: "RUN --timeout=notaduration echo hi", wantErr: true},

		// START
		{name: "START basic", input: "START python3 app.py", want: Step{Kind: KindSTART, Raw: "START python3 app.py", Shell: "python3 app.py"}},
		{name: "START uppercase", input: "START node server.js --port=8080", want: Step{Kind: KindSTART, Raw: "START node server.js --port=8080", Shell: "node server.js --port=8080"}},
		{name: "START no command", input: "START", wantErr: true},

		// ENV
		{name: "ENV basic", input: "ENV FOO=bar", want: Step{Kind: KindENV, Raw: "ENV FOO=bar", Key: "FOO", Value: "bar"}},
		{name: "ENV value with spaces", input: "ENV GREETING=hello world", want: Step{Kind: KindENV, Raw: "ENV GREETING=hello world", Key: "GREETING", Value: "hello world"}},
		{name: "ENV value with equals sign", input: "ENV URL=http://example.com?a=1", want: Step{Kind: KindENV, Raw: "ENV URL=http://example.com?a=1", Key: "URL", Value: "http://example.com?a=1"}},
		{name: "ENV empty value", input: "ENV FOO=", want: Step{Kind: KindENV, Raw: "ENV FOO=", Key: "FOO", Value: ""}},
		{name: "ENV missing equals", input: "ENV FOO", wantErr: true},
		{name: "ENV empty key", input: "ENV =value", wantErr: true},

		// WORKDIR
		{name: "WORKDIR basic", input: "WORKDIR /app", want: Step{Kind: KindWORKDIR, Raw: "WORKDIR /app", Path: "/app"}},
		{name: "WORKDIR with spaces in path", input: "WORKDIR /my project", want: Step{Kind: KindWORKDIR, Raw: "WORKDIR /my project", Path: "/my project"}},
		{name: "WORKDIR empty", input: "WORKDIR", wantErr: true},

		// USER and COPY stubs
		{name: "USER stub", input: "USER www-data", want: Step{Kind: KindUSER, Raw: "USER www-data"}},
		{name: "COPY stub", input: "COPY config.yaml /etc/app/config.yaml", want: Step{Kind: KindCOPY, Raw: "COPY config.yaml /etc/app/config.yaml"}},

		// Unknown keyword
		{name: "unknown keyword", input: "FROBNICATE something", wantErr: true},

		// Empty input
		{name: "empty string", input: "", wantErr: true},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			got, err := ParseStep(c.input)
			switch {
			case c.wantErr && err == nil:
				t.Fatalf("ParseStep(%q) expected error, got %+v", c.input, got)
			case c.wantErr:
				// Error expected and received; nothing further to compare.
			case err != nil:
				t.Fatalf("ParseStep(%q) unexpected error: %v", c.input, err)
			case got != c.want:
				t.Errorf("ParseStep(%q)\n got %+v\n want %+v", c.input, got, c.want)
			}
		})
	}
}
// TestParseRecipe covers multi-line parsing: a valid recipe (spot-checking
// step kinds and a per-step timeout), a recipe containing an invalid line,
// and the empty recipe.
func TestParseRecipe(t *testing.T) {
	t.Run("valid recipe", func(t *testing.T) {
		steps, err := ParseRecipe([]string{
			"RUN apt update",
			"WORKDIR /app",
			"ENV PORT=8080",
			"START python3 server.py",
			"RUN --timeout=2m pip install -r requirements.txt",
		})
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if n := len(steps); n != 5 {
			t.Fatalf("expected 5 steps, got %d", n)
		}

		// Spot-check the kinds of selected steps.
		kindChecks := []struct {
			idx   int
			label string
			kind  Kind
		}{
			{0, "KindRUN", KindRUN},
			{1, "KindWORKDIR", KindWORKDIR},
			{3, "KindSTART", KindSTART},
		}
		for _, kc := range kindChecks {
			if got := steps[kc.idx].Kind; got != kc.kind {
				t.Errorf("step %d: want %s, got %v", kc.idx, kc.label, got)
			}
		}

		if got := steps[4].Timeout; got != 2*time.Minute {
			t.Errorf("step 4: want 2m timeout, got %v", got)
		}
	})

	t.Run("error on invalid line", func(t *testing.T) {
		if _, err := ParseRecipe([]string{"RUN apt update", "BADCMD something"}); err == nil {
			t.Fatal("expected error for invalid line, got nil")
		}
	})

	t.Run("empty recipe", func(t *testing.T) {
		steps, err := ParseRecipe(nil)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if n := len(steps); n != 0 {
			t.Fatalf("expected 0 steps, got %d", n)
		}
	})
}

View File

@ -0,0 +1,85 @@
package sandbox
import (
"sync"
"sync/atomic"
"time"
)
// ConnTracker tracks active proxy connections for a single sandbox and
// provides a drain mechanism for pre-pause graceful shutdown.
// It is safe for concurrent use.
//
// The zero value is ready to use. Because it embeds a WaitGroup and a Mutex,
// a ConnTracker must not be copied after first use.
type ConnTracker struct {
	// draining is set by Drain and cleared by Reset; while it is set,
	// Acquire rejects new connections.
	draining atomic.Bool
	// wg counts in-flight connections: one Add per successful Acquire and
	// one Done per Release.
	wg sync.WaitGroup

	// cancelMu protects cancelDrain so Reset can signal a timed-out Drain
	// goroutine to exit, preventing goroutine leaks on repeated pause failures.
	//
	// NOTE(review): closing cancelDrain wakes a Drain call that is still
	// blocked in its select. The helper goroutine Drain spawns around
	// wg.Wait is not signalled by it and only ends once wg.Wait returns —
	// confirm whether that matches the intent described above.
	cancelMu sync.Mutex
	// cancelDrain is the cancel channel of the most recent Drain call, or
	// nil when Reset has already consumed it (or Drain was never called).
	cancelDrain chan struct{}
}
// Acquire registers one in-flight connection. Returns false if the tracker
// is already draining; the caller must not call Release in that case.
//
// NOTE(review): the wg.Add below can run concurrently with the wg.Wait that
// Drain starts. sync.WaitGroup documents ordering requirements between Add
// and Wait when the counter is zero — confirm this pattern is acceptable
// (for example by running the tests with -race).
func (t *ConnTracker) Acquire() bool {
	// Fast path: refuse immediately once draining has been flagged.
	if t.draining.Load() {
		return false
	}
	t.wg.Add(1)
	// Re-check after Add: Drain may have set draining between our Load
	// and Add. If so, undo the Add and reject the connection.
	if t.draining.Load() {
		t.wg.Done()
		return false
	}
	return true
}
// Release marks one connection as complete. Must be called exactly once
// per successful Acquire.
//
// It decrements the in-flight counter that Drain waits on. An unmatched
// call drives the WaitGroup counter negative, which panics.
func (t *ConnTracker) Release() {
	t.wg.Done()
}
// Drain marks the tracker as draining (all future Acquire calls return
// false) and waits up to timeout for in-flight connections to finish.
//
// Drain returns when the first of these happens: every in-flight connection
// has been released, Reset is called, or timeout elapses. It does not report
// which one occurred. The tracker stays in the draining state afterwards
// until Reset is called.
func (t *ConnTracker) Drain(timeout time.Duration) {
	// Flag draining first so that no new connection is admitted while we
	// wait for the existing ones.
	t.draining.Store(true)

	// Publish a fresh cancel channel so a concurrent Reset can wake us.
	cancel := make(chan struct{})
	t.cancelMu.Lock()
	t.cancelDrain = cancel
	t.cancelMu.Unlock()

	// wg.Wait cannot be used in a select directly, so a helper goroutine
	// turns its completion into a channel close. The helper ends only when
	// wg.Wait returns; neither the timeout nor Reset stops it, so it lingers
	// until every outstanding connection has called Release.
	done := make(chan struct{})
	go func() {
		t.wg.Wait()
		close(done)
	}()

	select {
	case <-done:
	case <-cancel:
		// Reset was called; stop waiting.
	case <-time.After(timeout):
	}
}
// Reset re-enables the tracker after a failed drain. This allows the
// sandbox to accept proxy connections again if the pause operation fails
// and the VM is resumed. It also cancels any lingering Drain goroutine.
//
// Reset may be called repeatedly, and without a preceding Drain: the cancel
// channel is closed at most once and then forgotten.
func (t *ConnTracker) Reset() {
	t.cancelMu.Lock()
	if t.cancelDrain != nil {
		// Non-blocking probe: a receive succeeds only on a closed channel
		// (nothing ever sends on it), so close it only if still open.
		select {
		case <-t.cancelDrain:
			// Already closed.
		default:
			close(t.cancelDrain)
		}
		t.cancelDrain = nil
	}
	t.cancelMu.Unlock()

	// Clear the flag last so Acquire starts admitting connections again.
	t.draining.Store(false)
}

106
internal/sandbox/images.go Normal file
View File

@ -0,0 +1,106 @@
package sandbox
import (
"fmt"
"log/slog"
"os"
"os/exec"
"path/filepath"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/layout"
)
// DefaultDiskSizeMB is the standard disk size for base images. Images smaller
// than this are expanded at startup so that dm-snapshot sandboxes see the full
// size without per-sandbox copies. The expansion is sparse — only metadata
// changes; no physical disk is consumed beyond the original content.
//
// EnsureImageSizes falls back to this value when it is given a non-positive
// target size.
const DefaultDiskSizeMB = 5120 // 5 GB
// EnsureImageSizes grows every template rootfs.ext4 under wrennDir that is
// smaller than targetMB megabytes; a non-positive targetMB selects
// DefaultDiskSizeMB. Images already at or above the target are left alone,
// which makes the call idempotent. It is meant to run once at host agent
// startup, before any sandboxes are created.
//
// The built-in minimal image is handled first, then every
// teams/{team}/{template}/rootfs.ext4. A missing teams directory is not an
// error, and team directories that cannot be listed are skipped. The first
// image that fails to expand aborts the walk and its error is returned.
func EnsureImageSizes(wrennDir string, targetMB int) error {
	if targetMB <= 0 {
		targetMB = DefaultDiskSizeMB
	}
	targetBytes := int64(targetMB) * 1024 * 1024

	// The built-in minimal image comes first.
	builtin := layout.TemplateRootfs(wrennDir, id.PlatformTeamID, id.MinimalTemplateID)
	if err := expandImage(builtin, targetBytes, targetMB); err != nil {
		return err
	}

	// Then every teams/{teamDir}/{templateDir}/rootfs.ext4, two levels deep.
	teamsRoot := layout.TeamsDir(wrennDir)
	teams, err := os.ReadDir(teamsRoot)
	switch {
	case err == nil:
		// proceed with the walk below
	case os.IsNotExist(err):
		return nil // teams dir doesn't exist yet — nothing to expand
	default:
		return fmt.Errorf("read teams dir: %w", err)
	}

	for _, team := range teams {
		if !team.IsDir() {
			continue
		}
		teamDir := filepath.Join(teamsRoot, team.Name())
		templates, err := os.ReadDir(teamDir)
		if err != nil {
			continue
		}
		for _, tmpl := range templates {
			if tmpl.IsDir() {
				image := filepath.Join(teamDir, tmpl.Name(), "rootfs.ext4")
				if err := expandImage(image, targetBytes, targetMB); err != nil {
					return err
				}
			}
		}
	}
	return nil
}
// expandImage expands a single rootfs image if it is smaller than targetBytes.
//
// rootfs is the path of the ext4 image file, targetBytes the desired file
// size, and targetMB the same size in megabytes (used for logging only).
// A file that cannot be stat'ed is skipped without error, as is one that is
// already at least targetBytes long.
//
// Expansion grows the file sparsely, runs "e2fsck -fy" on it, and then
// "resize2fs" to extend the filesystem to the new file size. An error is
// returned if the truncate fails, if e2fsck cannot be run or ends with
// anything other than exit status 0 or 1, or if resize2fs fails.
//
// NOTE(review): when a step after the truncate fails, the file keeps its
// enlarged size, so the next call sees it as "already large enough" and
// skips it even though the filesystem was not grown — confirm whether a
// rollback or a different size check is wanted.
func expandImage(rootfs string, targetBytes int64, targetMB int) error {
	info, err := os.Stat(rootfs)
	if err != nil {
		return nil // not every template dir has a rootfs.ext4
	}
	if info.Size() >= targetBytes {
		return nil // already large enough
	}

	slog.Info("expanding base image",
		"path", rootfs,
		"from_mb", info.Size()/(1024*1024),
		"to_mb", targetMB,
	)

	// Expand the file (sparse — instant, no physical disk used).
	if err := os.Truncate(rootfs, targetBytes); err != nil {
		return fmt.Errorf("truncate %s: %w", rootfs, err)
	}

	// Check filesystem before resize.
	if out, err := exec.Command("e2fsck", "-fy", rootfs).CombinedOutput(); err != nil {
		// e2fsck returns 1 if it fixed errors, which is fine. Every other
		// failure is fatal: a larger exit status, death by signal
		// (ExitCode reports -1), or not being able to start e2fsck at all
		// (the error is then not an *exec.ExitError).
		exitErr, exited := err.(*exec.ExitError)
		if !exited || exitErr.ExitCode() != 1 {
			return fmt.Errorf("e2fsck %s: %s: %w", rootfs, string(out), err)
		}
	}

	// Grow the ext4 filesystem to fill the new file size.
	if out, err := exec.Command("resize2fs", rootfs).CombinedOutput(); err != nil {
		return fmt.Errorf("resize2fs %s: %s: %w", rootfs, string(out), err)
	}

	slog.Info("base image expanded", "path", rootfs, "size_mb", targetMB)
	return nil
}

File diff suppressed because it is too large Load Diff

178
internal/sandbox/metrics.go Normal file
View File

@ -0,0 +1,178 @@
package sandbox
import (
"sync"
"time"
)
// MetricPoint holds one metrics sample.
type MetricPoint struct {
	Timestamp time.Time // time associated with the sample
	CPUPct    float64   // CPU usage; presumably a percentage — TODO confirm scale
	MemBytes  int64     // memory usage in bytes
	DiskBytes int64     // disk usage in bytes
}
// Ring buffer capacity constants.
//
// Each tier's capacity equals its time window divided by its sample period.
// The downsample factors state how many samples of one tier are averaged
// into a single sample of the next coarser tier.
const (
	ring10mCap = 1200 // 500ms × 1200 = 10 min
	ring2hCap  = 240  // 30s × 240 = 2 h
	ring24hCap = 288  // 5min × 288 = 24 h

	downsample2hEvery  = 60 // 60 × 500ms = 30s
	downsample24hEvery = 10 // 10 × 30s = 5min
)
// metricsRing holds three tiered ring buffers with automatic downsampling
// from the finest tier into coarser tiers.
//
// For each tier, idx is the slot the next write goes to and count is the
// number of valid samples, capped at the tier's capacity. All fields are
// guarded by mu.
type metricsRing struct {
	mu sync.Mutex

	// 10-minute tier: 500ms samples.
	buf10m   [ring10mCap]MetricPoint
	idx10m   int
	count10m int

	// 2-hour tier: 30s averages.
	buf2h   [ring2hCap]MetricPoint
	idx2h   int
	count2h int

	// 24-hour tier: 5min averages.
	buf24h   [ring24hCap]MetricPoint
	idx24h   int
	count24h int

	// Accumulators for downsampling. Each collects the samples of one tier
	// until a full batch is available to be averaged into the next tier;
	// the N field is the number of samples collected so far.
	acc500ms  [downsample2hEvery]MetricPoint
	acc500msN int
	acc30s    [downsample24hEvery]MetricPoint
	acc30sN   int
}
// newMetricsRing returns a pointer to a zero-valued metricsRing, i.e. one
// with every tier and accumulator empty.
func newMetricsRing() *metricsRing {
	return new(metricsRing)
}
// Push records one 500ms sample in the finest (10-minute) tier. The sample
// is also collected for downsampling: once a full batch has been gathered,
// the batch average is pushed into the 2-hour tier and the batch restarts.
func (r *metricsRing) Push(p MetricPoint) {
	r.mu.Lock()
	defer r.mu.Unlock()

	// Store in the 10m ring and advance its cursor with wrap-around.
	slot := r.idx10m
	r.buf10m[slot] = p
	r.idx10m = (slot + 1) % ring10mCap
	if r.count10m < ring10mCap {
		r.count10m++
	}

	// Collect the sample for the 2h tier.
	r.acc500ms[r.acc500msN] = p
	r.acc500msN++
	if r.acc500msN != downsample2hEvery {
		return
	}

	// A full batch is available: average it and feed the coarser tier.
	r.push2h(averagePoints(r.acc500ms[:downsample2hEvery]))
	r.acc500msN = 0
}
// push2h stores one averaged point in the 2-hour tier and stages it for the
// 24-hour tier. Once downsample24hEvery points have been staged, their
// average is forwarded to push24h. The caller must hold r.mu.
func (r *metricsRing) push2h(p MetricPoint) {
	slot := r.idx2h
	r.buf2h[slot] = p
	r.idx2h = (slot + 1) % ring2hCap
	if r.count2h < ring2hCap {
		r.count2h++
	}

	// Stage for the 24h tier; flush only when the staging array is full.
	r.acc30s[r.acc30sN] = p
	r.acc30sN++
	if r.acc30sN != downsample24hEvery {
		return
	}

	r.push24h(averagePoints(r.acc30s[:]))
	r.acc30sN = 0
}
// push24h stores one averaged point in the 24-hour tier, overwriting the
// oldest entry once the tier is full. The caller must hold r.mu.
func (r *metricsRing) push24h(p MetricPoint) {
	slot := r.idx24h
	r.buf24h[slot] = p
	r.idx24h = (slot + 1) % ring24hCap

	if r.count24h < ring24hCap {
		r.count24h++
	}
}
// Get10m returns the 10-minute tier points in chronological order.
// The result is a freshly allocated copy made under r.mu, or nil when the
// tier holds no points.
func (r *metricsRing) Get10m() []MetricPoint {
r.mu.Lock()
defer r.mu.Unlock()
return r.readRing(r.buf10m[:], r.idx10m, r.count10m)
}
// Get2h returns the 2-hour tier points in chronological order.
// The result is a freshly allocated copy made under r.mu, or nil when the
// tier holds no points.
func (r *metricsRing) Get2h() []MetricPoint {
r.mu.Lock()
defer r.mu.Unlock()
return r.readRing(r.buf2h[:], r.idx2h, r.count2h)
}
// Get24h returns the 24-hour tier points in chronological order.
// The result is a freshly allocated copy made under r.mu, or nil when the
// tier holds no points.
func (r *metricsRing) Get24h() []MetricPoint {
r.mu.Lock()
defer r.mu.Unlock()
return r.readRing(r.buf24h[:], r.idx24h, r.count24h)
}
// Flush snapshots every tier in chronological order and then empties the
// ring: all write positions, counts and downsampling accumulators are reset
// to zero. The backing arrays are not cleared; stale entries become
// unreachable because the counts are zero.
func (r *metricsRing) Flush() (pts10m, pts2h, pts24h []MetricPoint) {
	r.mu.Lock()
	defer r.mu.Unlock()

	fine := r.readRing(r.buf10m[:], r.idx10m, r.count10m)
	medium := r.readRing(r.buf2h[:], r.idx2h, r.count2h)
	coarse := r.readRing(r.buf24h[:], r.idx24h, r.count24h)

	// Rewind every tier and drop any partially accumulated window.
	r.idx10m, r.idx2h, r.idx24h = 0, 0, 0
	r.count10m, r.count2h, r.count24h = 0, 0, 0
	r.acc500msN, r.acc30sN = 0, 0

	return fine, medium, coarse
}
// readRing copies the count most recent elements of the circular buffer buf
// into a new slice, oldest first. nextIdx is the position the next write
// would go to. It returns nil when count is zero.
func (r *metricsRing) readRing(buf []MetricPoint, nextIdx, count int) []MetricPoint {
	if count == 0 {
		return nil
	}

	size := len(buf)
	oldest := (nextIdx - count + size) % size

	// The live region is at most two contiguous runs: from the oldest entry
	// up to the end of the array, then from the start of the array.
	run := size - oldest
	if run > count {
		run = count
	}

	out := make([]MetricPoint, count)
	copied := copy(out, buf[oldest:oldest+run])
	copy(out[copied:], buf[:count-copied])
	return out
}
// averagePoints computes the arithmetic mean of the CPU, memory and disk
// values of pts. The result's timestamp is that of the last point in pts.
// Memory and disk means are truncated toward zero.
//
// An empty slice yields the zero MetricPoint instead of panicking on the
// last-element lookup or dividing by zero.
func averagePoints(pts []MetricPoint) MetricPoint {
	if len(pts) == 0 {
		return MetricPoint{}
	}

	n := float64(len(pts))
	var cpu float64
	var mem, disk int64
	for _, p := range pts {
		cpu += p.CPUPct
		mem += p.MemBytes
		disk += p.DiskBytes
	}

	return MetricPoint{
		Timestamp: pts[len(pts)-1].Timestamp,
		CPUPct:    cpu / n,
		MemBytes:  int64(float64(mem) / n),
		DiskBytes: int64(float64(disk) / n),
	}
}

83
internal/sandbox/proc.go Normal file
View File

@ -0,0 +1,83 @@
package sandbox
import (
"fmt"
"os"
"strconv"
"strings"
"syscall"
)
// cpuStat holds raw CPU jiffies read from /proc/{pid}/stat.
// Both counters are cumulative values as parsed by readCPUStat.
type cpuStat struct {
// utime is the time spent in user mode, in jiffies.
utime uint64
// stime is the time spent in kernel mode, in jiffies.
stime uint64
}
// readCPUStat reads user and system CPU jiffies from /proc/{pid}/stat.
//
// The man page numbers utime and stime as fields 14 and 15 (1-indexed, counted
// from the start of the line). Because the comm field (field 2) may itself
// contain spaces and parentheses, parsing starts after the last ')' instead;
// counted from there, state is index 0, utime is index 11 and stime index 12.
//
// It returns an error if the file cannot be read, has no closing paren, has
// fewer than 13 fields after the comm, or the counters are not valid unsigned
// integers.
func readCPUStat(pid int) (cpuStat, error) {
	path := fmt.Sprintf("/proc/%d/stat", pid)
	data, err := os.ReadFile(path)
	if err != nil {
		return cpuStat{}, fmt.Errorf("read stat: %w", err)
	}

	// /proc/{pid}/stat format: pid (comm) state fields...
	// The comm field may contain spaces and parens, so find the last ')' first.
	content := string(data)
	idx := strings.LastIndex(content, ")")
	if idx < 0 {
		return cpuStat{}, fmt.Errorf("malformed /proc/%d/stat: no closing paren", pid)
	}

	// Slice from idx+1, which is always within bounds (idx+2 is not when ')'
	// is the final byte). strings.Fields discards the separating space.
	fields := strings.Fields(content[idx+1:])
	if len(fields) < 13 {
		return cpuStat{}, fmt.Errorf("malformed /proc/%d/stat: too few fields (%d)", pid, len(fields))
	}

	utime, err := strconv.ParseUint(fields[11], 10, 64)
	if err != nil {
		return cpuStat{}, fmt.Errorf("parse utime: %w", err)
	}
	stime, err := strconv.ParseUint(fields[12], 10, 64)
	if err != nil {
		return cpuStat{}, fmt.Errorf("parse stime: %w", err)
	}
	return cpuStat{utime: utime, stime: stime}, nil
}
// readMemRSS returns the resident set size of the process, in bytes, taken
// from the VmRSS line of /proc/{pid}/status (the value there is multiplied
// by 1024). It returns an error when the file cannot be read, when the VmRSS
// line is malformed or not a number, or when no VmRSS line is present.
func readMemRSS(pid int) (int64, error) {
	raw, err := os.ReadFile(fmt.Sprintf("/proc/%d/status", pid))
	if err != nil {
		return 0, fmt.Errorf("read status: %w", err)
	}

	for _, entry := range strings.Split(string(raw), "\n") {
		if !strings.HasPrefix(entry, "VmRSS:") {
			continue
		}

		// Expected shape: "VmRSS:  <number> kB".
		cols := strings.Fields(entry)
		if len(cols) < 2 {
			return 0, fmt.Errorf("malformed VmRSS line")
		}
		kib, parseErr := strconv.ParseInt(cols[1], 10, 64)
		if parseErr != nil {
			return 0, fmt.Errorf("parse VmRSS: %w", parseErr)
		}
		return kib * 1024, nil
	}

	return 0, fmt.Errorf("VmRSS not found in /proc/%d/status", pid)
}
// readDiskAllocated reports how many bytes are actually allocated on disk for
// the file at path, as opposed to its apparent size. The figure is the block
// count returned by stat(2) multiplied by 512.
func readDiskAllocated(path string) (int64, error) {
	const bytesPerBlock = 512

	info := new(syscall.Stat_t)
	err := syscall.Stat(path, info)
	if err != nil {
		return 0, fmt.Errorf("stat %s: %w", path, err)
	}
	return info.Blocks * bytesPerBlock, nil
}

View File

@ -0,0 +1,71 @@
package scheduler
import (
"context"
"fmt"
"sync/atomic"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/db"
)
// HostScheduler selects a host for a new sandbox. Implementations may use
// different strategies (round-robin, least-loaded, tag-based, etc.).
// RoundRobinScheduler is the implementation provided in this package.
type HostScheduler interface {
// SelectHost returns a host that can accept a new sandbox.
// For BYOC teams (isByoc=true), only online BYOC hosts belonging to teamID
// are considered. For non-BYOC teams, only online regular (platform) hosts
// are considered. Returns an error if no suitable host is available.
SelectHost(ctx context.Context, teamID pgtype.UUID, isByoc bool) (db.Host, error)
}
// RoundRobinScheduler cycles through eligible online hosts in round-robin order.
// It re-fetches the host list on every call so that newly registered or
// recovered hosts are considered immediately.
type RoundRobinScheduler struct {
// db is used to list the active hosts on every SelectHost call.
db *db.Queries
// counter is incremented once per successful selection; its value modulo
// the number of eligible hosts picks the host. It is shared across all
// teams and request types.
counter atomic.Int64
}
// NewRoundRobinScheduler returns a RoundRobinScheduler that reads its host
// list through queries. The rotation counter starts at zero.
func NewRoundRobinScheduler(queries *db.Queries) *RoundRobinScheduler {
	sched := new(RoundRobinScheduler)
	sched.db = queries
	return sched
}
// SelectHost returns the next eligible online host in round-robin order.
//
// A host is eligible when its status is "online", it has a non-empty address,
// and it matches the request type: for BYOC requests it must be a "byoc" host
// owned by teamID; otherwise it must be a "regular" platform host. An error is
// returned when the host list cannot be fetched or no host is eligible.
func (s *RoundRobinScheduler) SelectHost(ctx context.Context, teamID pgtype.UUID, isByoc bool) (db.Host, error) {
	all, err := s.db.ListActiveHosts(ctx)
	if err != nil {
		return db.Host{}, fmt.Errorf("list hosts: %w", err)
	}

	// usable reports whether h may serve this request.
	usable := func(h db.Host) bool {
		if h.Status != "online" || h.Address == "" {
			return false
		}
		if isByoc {
			// BYOC team: only hosts belonging to this team.
			return h.Type == "byoc" && h.TeamID.Valid && h.TeamID == teamID
		}
		// Non-BYOC team: only platform (regular) hosts.
		return h.Type == "regular"
	}

	var candidates []db.Host
	for _, h := range all {
		if usable(h) {
			candidates = append(candidates, h)
		}
	}

	switch {
	case len(candidates) > 0:
		// Add returns the incremented value; subtract one so the first call
		// starts at index zero.
		turn := s.counter.Add(1) - 1
		return candidates[int(turn%int64(len(candidates)))], nil
	case isByoc:
		return db.Host{}, fmt.Errorf("no online BYOC hosts available for team")
	default:
		return db.Host{}, fmt.Errorf("no online platform hosts available")
	}
}

View File

@ -4,6 +4,8 @@ import (
"context"
"fmt"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
@ -22,7 +24,7 @@ type APIKeyCreateResult struct {
}
// Create generates a new API key for the given team.
func (s *APIKeyService) Create(ctx context.Context, teamID, userID, name string) (APIKeyCreateResult, error) {
func (s *APIKeyService) Create(ctx context.Context, teamID, userID pgtype.UUID, name string) (APIKeyCreateResult, error) {
if name == "" {
name = "Unnamed API Key"
}
@ -48,16 +50,16 @@ func (s *APIKeyService) Create(ctx context.Context, teamID, userID, name string)
}
// List returns all API keys belonging to the given team.
func (s *APIKeyService) List(ctx context.Context, teamID string) ([]db.TeamApiKey, error) {
func (s *APIKeyService) List(ctx context.Context, teamID pgtype.UUID) ([]db.TeamApiKey, error) {
return s.DB.ListAPIKeysByTeam(ctx, teamID)
}
// ListWithCreator returns all API keys for the team, joined with the creator's email.
func (s *APIKeyService) ListWithCreator(ctx context.Context, teamID string) ([]db.ListAPIKeysByTeamWithCreatorRow, error) {
func (s *APIKeyService) ListWithCreator(ctx context.Context, teamID pgtype.UUID) ([]db.ListAPIKeysByTeamWithCreatorRow, error) {
return s.DB.ListAPIKeysByTeamWithCreator(ctx, teamID)
}
// Delete removes an API key by ID, scoped to the given team.
func (s *APIKeyService) Delete(ctx context.Context, keyID, teamID string) error {
func (s *APIKeyService) Delete(ctx context.Context, keyID, teamID pgtype.UUID) error {
return s.DB.DeleteAPIKey(ctx, db.DeleteAPIKeyParams{ID: keyID, TeamID: teamID})
}

113
internal/service/audit.go Normal file
View File

@ -0,0 +1,113 @@
package service
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
)
// auditMaxLimit is the largest page size AuditService.List will request;
// larger limits are clamped down to it.
const auditMaxLimit = 200
// AuditEntry is a single audit log record returned by List.
// IDs are the formatted string forms produced by the id package; nullable
// database columns are flattened to empty strings.
type AuditEntry struct {
ID string
TeamID string
ActorType string
ActorID string // empty for system
ActorName string // empty for system
ResourceType string
ResourceID string // empty when not applicable
Action string
Scope string
Status string // 'success', 'info', 'warning', 'error'
// Metadata is the decoded JSON metadata of the row; it is nil when the row
// has no metadata or the stored JSON could not be decoded.
Metadata map[string]any
CreatedAt time.Time
}
// AuditListParams controls the ListAuditLogs query.
// Before and BeforeID together form the pagination cursor passed to the query.
type AuditListParams struct {
TeamID pgtype.UUID
AdminScoped bool // true → include admin-scoped events; false → team-scoped only
ResourceTypes []string // empty = no filter; multiple values = OR match
Actions []string // empty = no filter; multiple values = OR match
Before time.Time // zero = no cursor (start from latest)
BeforeID pgtype.UUID // tie-breaker: id of the last item at the Before timestamp; zero = no tie-break
Limit int // page size; List uses 50 when this is <= 0 and clamps it to auditMaxLimit
}
// AuditService provides the read side of the audit log.
type AuditService struct {
// DB is the query layer used to fetch audit log rows.
DB *db.Queries
}
// List returns one page of audit log entries for the team in p.
//
// The page size defaults to 50 when p.Limit is not positive and is capped at
// auditMaxLimit. Team-scoped events are always included; admin-scoped events
// are added when p.AdminScoped is set. Nil filter slices are sent to the query
// as empty slices. Row metadata is decoded on a best-effort basis: a row whose
// metadata fails to decode is returned with nil Metadata.
func (s *AuditService) List(ctx context.Context, p AuditListParams) ([]AuditEntry, error) {
	pageSize := p.Limit
	switch {
	case pageSize <= 0:
		pageSize = 50
	case pageSize > auditMaxLimit:
		pageSize = auditMaxLimit
	}

	scopes := []string{"team"}
	if p.AdminScoped {
		scopes = append(scopes, "admin")
	}

	// The zero Timestamptz (Valid == false) means "no cursor".
	var cursor pgtype.Timestamptz
	if !p.Before.IsZero() {
		cursor = pgtype.Timestamptz{Time: p.Before, Valid: true}
	}

	typeFilter := p.ResourceTypes
	if typeFilter == nil {
		typeFilter = []string{}
	}
	actionFilter := p.Actions
	if actionFilter == nil {
		actionFilter = []string{}
	}

	query := db.ListAuditLogsParams{
		TeamID:  p.TeamID,
		Column2: scopes,
		Column3: typeFilter,
		Column4: actionFilter,
		Column5: cursor,
		ID:      p.BeforeID,
		Limit:   int32(pageSize),
	}
	rows, err := s.DB.ListAuditLogs(ctx, query)
	if err != nil {
		return nil, fmt.Errorf("list audit logs: %w", err)
	}

	entries := make([]AuditEntry, 0, len(rows))
	for _, row := range rows {
		var meta map[string]any
		if len(row.Metadata) > 0 {
			// Best effort: undecodable metadata leaves meta nil.
			_ = json.Unmarshal(row.Metadata, &meta)
		}
		entries = append(entries, AuditEntry{
			ID:           id.FormatAuditLogID(row.ID),
			TeamID:       id.FormatTeamID(row.TeamID),
			ActorType:    row.ActorType,
			ActorID:      row.ActorID.String,
			ActorName:    row.ActorName,
			ResourceType: row.ResourceType,
			ResourceID:   row.ResourceID.String,
			Action:       row.Action,
			Scope:        row.Scope,
			Status:       row.Status,
			Metadata:     meta,
			CreatedAt:    row.CreatedAt.Time,
		})
	}
	return entries, nil
}

605
internal/service/build.go Normal file
View File

@ -0,0 +1,605 @@
package service
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"strings"
"sync"
"time"
"connectrpc.com/connect"
"github.com/jackc/pgx/v5/pgtype"
"github.com/redis/go-redis/v9"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
"git.omukk.dev/wrenn/sandbox/internal/recipe"
"git.omukk.dev/wrenn/sandbox/internal/scheduler"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
)
const (
// buildQueueKey is the Redis list that Create pushes build IDs onto and
// that the build workers pop from.
buildQueueKey = "wrenn:build_queue"
// buildCommandTimeout is the default per-step timeout passed to the
// recipe executor for user recipe steps.
buildCommandTimeout = 30 * time.Second
)
// preBuildCmds run before the user recipe to prepare the build environment.
// They are skipped when the build has SkipPrePost set.
var preBuildCmds = []string{
"RUN apt update",
}
// postBuildCmds run after the user recipe to clean up caches and reduce image size.
// They are skipped when the build has SkipPrePost set.
var postBuildCmds = []string{
"RUN apt clean",
"RUN apt autoremove -y",
"RUN rm -rf /var/lib/apt/lists/*",
}
// buildAgentClient is the subset of the host agent client used by the build worker.
// It is the parameter type of the build helpers in this file
// (waitForHealthcheck, destroySandbox, fetchSandboxEnv).
type buildAgentClient interface {
CreateSandbox(ctx context.Context, req *connect.Request[pb.CreateSandboxRequest]) (*connect.Response[pb.CreateSandboxResponse], error)
DestroySandbox(ctx context.Context, req *connect.Request[pb.DestroySandboxRequest]) (*connect.Response[pb.DestroySandboxResponse], error)
Exec(ctx context.Context, req *connect.Request[pb.ExecRequest]) (*connect.Response[pb.ExecResponse], error)
CreateSnapshot(ctx context.Context, req *connect.Request[pb.CreateSnapshotRequest]) (*connect.Response[pb.CreateSnapshotResponse], error)
FlattenRootfs(ctx context.Context, req *connect.Request[pb.FlattenRootfsRequest]) (*connect.Response[pb.FlattenRootfsResponse], error)
}
// BuildService handles template build orchestration.
// Builds are recorded in the database, queued through Redis, and executed by
// the worker goroutines started with StartWorkers.
type BuildService struct {
DB *db.Queries
Redis *redis.Client
Pool *lifecycle.HostClientPool
Scheduler scheduler.HostScheduler
// mu guards cancelMap.
mu sync.Mutex
// cancelMap is created lazily by executeBuild and holds one entry per build
// currently executing in this process; Cancel uses it to signal a build.
cancelMap map[string]context.CancelFunc // buildID → per-build cancel func
}
// BuildCreateParams holds the parameters for creating a template build.
// Create fills in defaults for BaseTemplate, VCPUs and MemoryMB when they are
// unset.
type BuildCreateParams struct {
// Name is stored on the build and later used as the template name.
Name string
// BaseTemplate defaults to "minimal" when empty.
BaseTemplate string
// Recipe is the list of user recipe commands, stored as JSON.
Recipe []string
// Healthcheck, when non-empty, makes the build wait for the healthcheck
// and produce a snapshot template instead of a rootfs-only one.
Healthcheck string
// VCPUs defaults to 1 when not positive.
VCPUs int32
// MemoryMB defaults to 512 when not positive.
MemoryMB int32
// SkipPrePost omits the built-in pre-build and post-build commands.
SkipPrePost bool
}
// Create inserts a new build record and enqueues it to Redis.
//
// Unset parameters are defaulted: BaseTemplate to "minimal", VCPUs to 1 and
// MemoryMB to 512. The total step count covers the user recipe plus the
// built-in pre/post commands unless SkipPrePost is set.
//
// If the record is inserted but the enqueue fails, the build is marked as
// failed so it does not remain pending forever with no worker able to pick it
// up; the enqueue error is still returned to the caller.
func (s *BuildService) Create(ctx context.Context, p BuildCreateParams) (db.TemplateBuild, error) {
	if p.BaseTemplate == "" {
		p.BaseTemplate = "minimal"
	}
	if p.VCPUs <= 0 {
		p.VCPUs = 1
	}
	if p.MemoryMB <= 0 {
		p.MemoryMB = 512
	}

	recipeJSON, err := json.Marshal(p.Recipe)
	if err != nil {
		return db.TemplateBuild{}, fmt.Errorf("marshal recipe: %w", err)
	}

	buildID := id.NewBuildID()
	buildIDStr := id.FormatBuildID(buildID)
	newTemplateID := id.NewTemplateID()

	defaultSteps := len(preBuildCmds) + len(postBuildCmds)
	if p.SkipPrePost {
		defaultSteps = 0
	}

	build, err := s.DB.InsertTemplateBuild(ctx, db.InsertTemplateBuildParams{
		ID:           buildID,
		Name:         p.Name,
		BaseTemplate: p.BaseTemplate,
		Recipe:       recipeJSON,
		Healthcheck:  p.Healthcheck,
		Vcpus:        p.VCPUs,
		MemoryMb:     p.MemoryMB,
		TotalSteps:   int32(len(p.Recipe) + defaultSteps),
		TemplateID:   newTemplateID,
		TeamID:       id.PlatformTeamID,
		SkipPrePost:  p.SkipPrePost,
	})
	if err != nil {
		return db.TemplateBuild{}, fmt.Errorf("insert build: %w", err)
	}

	// Enqueue build ID (as formatted string) to Redis for workers to pick up.
	if err := s.Redis.RPush(ctx, buildQueueKey, buildIDStr).Err(); err != nil {
		// The row exists but no worker will ever see it: record the failure
		// rather than leaving an orphaned pending build.
		s.failBuild(ctx, buildID, fmt.Sprintf("enqueue build: %v", err))
		return db.TemplateBuild{}, fmt.Errorf("enqueue build: %w", err)
	}

	return build, nil
}
// Get returns a single build by ID.
// It delegates directly to the database query and returns its error unchanged.
func (s *BuildService) Get(ctx context.Context, buildID pgtype.UUID) (db.TemplateBuild, error) {
return s.DB.GetTemplateBuild(ctx, buildID)
}
// List returns all builds ordered by creation time.
// It delegates directly to the database query; the ordering is defined by
// that query, not by this method.
func (s *BuildService) List(ctx context.Context) ([]db.TemplateBuild, error) {
return s.DB.ListTemplateBuilds(ctx)
}
// Cancel cancels a pending or running build.
//
// The build is first marked "cancelled" in the database, which is enough for
// a pending build: the worker skips it when it is dequeued. If the build is
// executing in this process, its per-build context is then cancelled so the
// current step aborts. Builds that already finished ("success", "failed" or
// "cancelled") are rejected with an error.
func (s *BuildService) Cancel(ctx context.Context, buildID pgtype.UUID) error {
	build, err := s.DB.GetTemplateBuild(ctx, buildID)
	if err != nil {
		return fmt.Errorf("get build: %w", err)
	}

	if st := build.Status; st == "success" || st == "failed" || st == "cancelled" {
		return fmt.Errorf("build is already %s", st)
	}

	// Persist the cancellation before signalling, so that executeBuild sees
	// the flag even if it has not started yet.
	_, err = s.DB.UpdateBuildStatus(ctx, db.UpdateBuildStatusParams{
		ID: buildID, Status: "cancelled",
	})
	if err != nil {
		return fmt.Errorf("update build status: %w", err)
	}

	// Look up the cancel func under the lock, but invoke it outside.
	key := id.FormatBuildID(buildID)
	s.mu.Lock()
	stop, inFlight := s.cancelMap[key]
	s.mu.Unlock()

	if inFlight {
		stop()
	}
	return nil
}
// StartWorkers launches n worker goroutines that consume build IDs from the
// Redis build queue. All workers share a context derived from ctx; calling
// the returned cancel function cancels it and thereby stops every worker.
func (s *BuildService) StartWorkers(ctx context.Context, n int) context.CancelFunc {
	workerCtx, stopAll := context.WithCancel(ctx)

	for w := 0; w < n; w++ {
		go s.worker(workerCtx, w)
	}

	slog.Info("build workers started", "count", n)
	return stopAll
}
// worker is the loop run by each build worker goroutine. It blocks on the
// Redis build queue, executes each build it dequeues, and returns once ctx is
// cancelled. Redis errors are logged and retried after a one-second backoff;
// the backoff itself is interrupted by cancellation so shutdown is not held up.
func (s *BuildService) worker(ctx context.Context, workerID int) {
	log := slog.With("worker", workerID)
	for {
		// BLPOP blocks until a build ID is available or context is cancelled.
		result, err := s.Redis.BLPop(ctx, 0, buildQueueKey).Result()
		if err != nil {
			if ctx.Err() != nil {
				log.Info("build worker shutting down")
				return
			}
			log.Error("redis BLPOP error", "error", err)

			// Back off before retrying, but stop waiting as soon as the
			// worker is asked to shut down.
			select {
			case <-ctx.Done():
				log.Info("build worker shutting down")
				return
			case <-time.After(time.Second):
			}
			continue
		}

		// result[0] is the key, result[1] is the build ID (formatted string).
		buildIDStr := result[1]
		log.Info("picked up build", "build_id", buildIDStr)
		s.executeBuild(ctx, buildIDStr)
	}
}
// executeBuild runs one queued template build from start to finish.
//
// Steps, in order:
//  1. Parse the build ID, register a per-build cancel func in cancelMap, and
//     load the build record (skipping builds already marked "cancelled").
//  2. Mark the build "running", pick a platform host, and create a build
//     sandbox from the base template.
//  3. Run the pre-build, user recipe and post-build phases inside the sandbox,
//     persisting logs and progress after each phase.
//  4. With a healthcheck: wait for it to pass and take a full snapshot.
//     Without one: flatten the rootfs into an image-only template.
//  5. Insert the template record and mark the build "success".
//
// Any failure destroys the sandbox and records the error via failBuild, except
// when the per-build context has been cancelled: in that case the status is
// left as it is.
//
// NOTE(review): the per-build context is also cancelled when the worker
// context is cancelled (shutdown); in that case nothing here updates the
// status, so the build appears to stay "running" — confirm whether a recovery
// path exists elsewhere.
func (s *BuildService) executeBuild(ctx context.Context, buildIDStr string) {
	log := slog.With("build_id", buildIDStr)

	buildID, err := id.ParseBuildID(buildIDStr)
	if err != nil {
		log.Error("invalid build ID from queue", "error", err)
		return
	}

	// Create a per-build context so this build can be cancelled independently of
	// the worker. Register in cancelMap before fetching the build so that a
	// concurrent Cancel call can always find and signal it.
	buildCtx, buildCancel := context.WithCancel(ctx)
	defer buildCancel()

	s.mu.Lock()
	if s.cancelMap == nil {
		s.cancelMap = make(map[string]context.CancelFunc)
	}
	s.cancelMap[buildIDStr] = buildCancel
	s.mu.Unlock()
	defer func() {
		s.mu.Lock()
		delete(s.cancelMap, buildIDStr)
		s.mu.Unlock()
	}()

	build, err := s.DB.GetTemplateBuild(buildCtx, buildID)
	if err != nil {
		log.Error("failed to fetch build", "error", err)
		return
	}

	// Skip if already cancelled (Cancel was called before we dequeued).
	if build.Status == "cancelled" {
		log.Info("build already cancelled, skipping")
		return
	}

	// Mark as running.
	if _, err := s.DB.UpdateBuildStatus(buildCtx, db.UpdateBuildStatusParams{
		ID: buildID, Status: "running",
	}); err != nil {
		log.Error("failed to update build status", "error", err)
		return
	}

	// Parse user recipe.
	var userRecipe []string
	if err := json.Unmarshal(build.Recipe, &userRecipe); err != nil {
		s.failBuild(buildCtx, buildID, fmt.Sprintf("invalid recipe JSON: %v", err))
		return
	}

	// Pick a platform host and create a sandbox.
	host, err := s.Scheduler.SelectHost(buildCtx, id.PlatformTeamID, false)
	if err != nil {
		s.failBuild(buildCtx, buildID, fmt.Sprintf("no host available: %v", err))
		return
	}

	agent, err := s.Pool.GetForHost(host)
	if err != nil {
		s.failBuild(buildCtx, buildID, fmt.Sprintf("agent client error: %v", err))
		return
	}

	sandboxID := id.NewSandboxID()
	sandboxIDStr := id.FormatSandboxID(sandboxID)
	log = log.With("sandbox_id", sandboxIDStr, "host_id", id.FormatHostID(host.ID))

	// Resolve the base template to UUIDs. "minimal" is the zero sentinel.
	baseTeamID := id.PlatformTeamID
	baseTemplateID := id.MinimalTemplateID
	if build.BaseTemplate != "minimal" {
		baseTmpl, err := s.DB.GetPlatformTemplateByName(buildCtx, build.BaseTemplate)
		if err != nil {
			s.failBuild(buildCtx, buildID, fmt.Sprintf("base template %q not found: %v", build.BaseTemplate, err))
			return
		}
		baseTeamID = baseTmpl.TeamID
		baseTemplateID = baseTmpl.ID
	}

	// The response carries nothing this function needs; only the error matters.
	if _, err := agent.CreateSandbox(buildCtx, connect.NewRequest(&pb.CreateSandboxRequest{
		SandboxId:  sandboxIDStr,
		Template:   build.BaseTemplate,
		TeamId:     id.UUIDString(baseTeamID),
		TemplateId: id.UUIDString(baseTemplateID),
		Vcpus:      build.Vcpus,
		MemoryMb:   build.MemoryMb,
		TimeoutSec: 0,    // no auto-pause for builds
		DiskSizeMb: 5120, // 5 GB for template builds
	})); err != nil {
		s.failBuild(buildCtx, buildID, fmt.Sprintf("create sandbox failed: %v", err))
		return
	}

	// Record sandbox/host association. This is bookkeeping only, so a failure
	// is logged rather than aborting the build.
	if err := s.DB.UpdateBuildSandbox(buildCtx, db.UpdateBuildSandboxParams{
		ID:        buildID,
		SandboxID: sandboxID,
		HostID:    host.ID,
	}); err != nil {
		log.Warn("failed to record build sandbox association", "error", err)
	}

	// Parse recipe steps. preBuildCmds and postBuildCmds are hardcoded and always
	// valid; panic on error is appropriate here since it would be a programmer mistake.
	preBuildSteps, err := recipe.ParseRecipe(preBuildCmds)
	if err != nil {
		panic(fmt.Sprintf("invalid pre-build recipe: %v", err))
	}
	userRecipeSteps, err := recipe.ParseRecipe(userRecipe)
	if err != nil {
		s.destroySandbox(buildCtx, agent, sandboxIDStr)
		s.failBuild(buildCtx, buildID, fmt.Sprintf("recipe parse error: %v", err))
		return
	}
	postBuildSteps, err := recipe.ParseRecipe(postBuildCmds)
	if err != nil {
		panic(fmt.Sprintf("invalid post-build recipe: %v", err))
	}

	var logs []recipe.BuildLogEntry
	step := 0

	envVars, err := s.fetchSandboxEnv(buildCtx, agent, sandboxIDStr)
	if err != nil {
		log.Warn("failed to fetch sandbox env, using defaults", "error", err)
		envVars = map[string]string{
			"PATH": "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin",
			"HOME": "/root",
		}
	}
	bctx := &recipe.ExecContext{EnvVars: envVars}

	// runPhase executes one phase, appends its log entries, persists progress,
	// and on failure tears down the sandbox and records the error. It reports
	// whether the build may continue.
	runPhase := func(phase string, steps []recipe.Step, defaultTimeout time.Duration) bool {
		newEntries, nextStep, ok := recipe.Execute(buildCtx, phase, steps, sandboxIDStr, step, defaultTimeout, bctx, agent.Exec)
		logs = append(logs, newEntries...)
		step = nextStep
		s.updateLogs(buildCtx, buildID, step, logs)
		if !ok {
			s.destroySandbox(buildCtx, agent, sandboxIDStr)
			// If the build was cancelled, status is already set — don't overwrite with "failed".
			if buildCtx.Err() != nil {
				return false
			}
			// Derive the failure reason from the last log entry when there
			// is one; a failed phase with no entries must not panic.
			reason := "no step output recorded"
			if len(newEntries) > 0 {
				last := newEntries[len(newEntries)-1]
				reason = last.Stderr
				if reason == "" {
					reason = fmt.Sprintf("exit code %d", last.Exit)
				}
			}
			s.failBuild(buildCtx, buildID, fmt.Sprintf("%s step %d failed: %s", phase, step, reason))
		}
		return ok
	}

	if !build.SkipPrePost {
		if !runPhase("pre-build", preBuildSteps, 0) {
			return
		}
	}
	if !runPhase("recipe", userRecipeSteps, buildCommandTimeout) {
		return
	}
	if !build.SkipPrePost {
		if !runPhase("post-build", postBuildSteps, 0) {
			return
		}
	}

	// Healthcheck or direct snapshot.
	var sizeBytes int64
	if build.Healthcheck != "" {
		hc, err := recipe.ParseHealthcheck(build.Healthcheck)
		if err != nil {
			s.destroySandbox(buildCtx, agent, sandboxIDStr)
			s.failBuild(buildCtx, buildID, fmt.Sprintf("invalid healthcheck: %v", err))
			return
		}
		log.Info("running healthcheck", "cmd", hc.Cmd, "interval", hc.Interval, "timeout", hc.Timeout, "start_period", hc.StartPeriod, "retries", hc.Retries)
		if err := s.waitForHealthcheck(buildCtx, agent, sandboxIDStr, hc); err != nil {
			s.destroySandbox(buildCtx, agent, sandboxIDStr)
			if buildCtx.Err() != nil {
				return
			}
			s.failBuild(buildCtx, buildID, fmt.Sprintf("healthcheck failed: %v", err))
			return
		}

		// Healthcheck passed → full snapshot (with memory/CPU state).
		log.Info("healthcheck passed, creating snapshot")
		snapResp, err := agent.CreateSnapshot(buildCtx, connect.NewRequest(&pb.CreateSnapshotRequest{
			SandboxId:  sandboxIDStr,
			Name:       build.Name,
			TeamId:     id.UUIDString(build.TeamID),
			TemplateId: id.UUIDString(build.TemplateID),
		}))
		if err != nil {
			s.destroySandbox(buildCtx, agent, sandboxIDStr)
			if buildCtx.Err() != nil {
				return
			}
			s.failBuild(buildCtx, buildID, fmt.Sprintf("create snapshot failed: %v", err))
			return
		}
		sizeBytes = snapResp.Msg.SizeBytes
	} else {
		// No healthcheck → image-only template (rootfs only).
		log.Info("no healthcheck, flattening rootfs")
		flatResp, err := agent.FlattenRootfs(buildCtx, connect.NewRequest(&pb.FlattenRootfsRequest{
			SandboxId:  sandboxIDStr,
			Name:       build.Name,
			TeamId:     id.UUIDString(build.TeamID),
			TemplateId: id.UUIDString(build.TemplateID),
		}))
		if err != nil {
			s.destroySandbox(buildCtx, agent, sandboxIDStr)
			if buildCtx.Err() != nil {
				return
			}
			s.failBuild(buildCtx, buildID, fmt.Sprintf("flatten rootfs failed: %v", err))
			return
		}
		sizeBytes = flatResp.Msg.SizeBytes
	}

	// Insert into templates table as a global (platform) template.
	templateType := "base"
	if build.Healthcheck != "" {
		templateType = "snapshot"
	}
	if _, err := s.DB.InsertTemplate(buildCtx, db.InsertTemplateParams{
		ID:        build.TemplateID,
		Name:      build.Name,
		Type:      templateType,
		Vcpus:     build.Vcpus,
		MemoryMb:  build.MemoryMb,
		SizeBytes: sizeBytes,
		TeamID:    id.PlatformTeamID,
	}); err != nil {
		log.Error("failed to insert template record", "error", err)
		// Build succeeded on disk, just DB record failed — don't mark as failed.
	}

	// For CreateSnapshot, the sandbox is already destroyed by the snapshot process.
	// For FlattenRootfs, the sandbox is already destroyed by the flatten process.
	// No additional destroy needed.

	// Mark build as success.
	if _, err := s.DB.UpdateBuildStatus(buildCtx, db.UpdateBuildStatusParams{
		ID: buildID, Status: "success",
	}); err != nil {
		log.Error("failed to mark build as success", "error", err)
	}

	log.Info("template build completed successfully", "name", build.Name)
}
// waitForHealthcheck probes the sandbox with the healthcheck command once per
// hc.Interval until a probe exits with code 0.
//
// Each probe runs "/bin/sh -c <cmd>" with a timeout of hc.Timeout. A probe
// fails when the exec call errors or the command exits non-zero. Failures
// observed before hc.StartPeriod has elapsed are not counted. When hc.Retries
// is positive, the function gives up once that many failures have been
// counted, or once an overall deadline of
// StartPeriod + (Retries+1)*Interval passes. When hc.Retries is zero the
// number of attempts is unlimited and only ctx can end the wait.
//
// It returns nil on the first successful probe, ctx.Err() on cancellation,
// and a descriptive error otherwise.
func (s *BuildService) waitForHealthcheck(ctx context.Context, agent buildAgentClient, sandboxIDStr string, hc recipe.HealthcheckConfig) error {
	probe := time.NewTicker(hc.Interval)
	defer probe.Stop()

	// A nil channel never fires, which is what unlimited retries need.
	var giveUp <-chan time.Time
	if hc.Retries > 0 {
		budget := time.NewTimer(hc.StartPeriod + time.Duration(hc.Retries+1)*hc.Interval)
		defer budget.Stop()
		giveUp = budget.C
	}

	startedAt := time.Now()
	failCount := 0

	// exhausted records one failed probe and reports whether the retry
	// budget is used up. Failures inside the start period are ignored.
	exhausted := func() bool {
		if time.Since(startedAt) < hc.StartPeriod {
			return false
		}
		failCount++
		return hc.Retries > 0 && failCount >= hc.Retries
	}

	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-giveUp:
			return fmt.Errorf("healthcheck timed out: exceeded %d attempts over %s", failCount, time.Since(startedAt))
		case <-probe.C:
		}

		execCtx, cancel := context.WithTimeout(ctx, hc.Timeout)
		resp, err := agent.Exec(execCtx, connect.NewRequest(&pb.ExecRequest{
			SandboxId:  sandboxIDStr,
			Cmd:        "/bin/sh",
			Args:       []string{"-c", hc.Cmd},
			TimeoutSec: int32(hc.Timeout.Seconds()),
		}))
		cancel()

		switch {
		case err != nil:
			slog.Debug("healthcheck exec error (retrying)", "error", err)
			if exhausted() {
				return fmt.Errorf("healthcheck failed after %d retries: exec error: %w", failCount, err)
			}
		case resp.Msg.ExitCode == 0:
			return nil
		default:
			slog.Debug("healthcheck failed (retrying)", "exit_code", resp.Msg.ExitCode)
			if exhausted() {
				return fmt.Errorf("healthcheck failed after %d retries: exit code %d", failCount, resp.Msg.ExitCode)
			}
		}
	}
}
// updateLogs persists the build's current step number and its accumulated
// log entries (as JSON). It is best-effort: marshalling or database failures
// are logged as warnings and otherwise ignored.
func (s *BuildService) updateLogs(ctx context.Context, buildID pgtype.UUID, step int, logs []recipe.BuildLogEntry) {
	encoded, err := json.Marshal(logs)
	if err != nil {
		slog.Warn("failed to marshal build logs", "error", err)
		return
	}

	progress := db.UpdateBuildProgressParams{
		ID:          buildID,
		CurrentStep: int32(step),
		Logs:        encoded,
	}
	if err = s.DB.UpdateBuildProgress(ctx, progress); err != nil {
		slog.Warn("failed to update build progress", "error", err)
	}
}
// failBuild logs the failure and stores errMsg on the build record. The
// caller's context is deliberately ignored: the write uses its own 10-second
// context so that it still goes through when the caller's context has been
// cancelled (for example during shutdown).
func (s *BuildService) failBuild(_ context.Context, buildID pgtype.UUID, errMsg string) {
	slog.Error("build failed", "build_id", id.FormatBuildID(buildID), "error", errMsg)

	// Use a detached context so DB writes survive parent context cancellation (e.g. shutdown).
	dbCtx, release := context.WithTimeout(context.Background(), 10*time.Second)
	defer release()

	params := db.UpdateBuildErrorParams{
		ID:    buildID,
		Error: errMsg,
	}
	err := s.DB.UpdateBuildError(dbCtx, params)
	if err == nil {
		return
	}
	slog.Error("failed to update build error", "build_id", id.FormatBuildID(buildID), "error", err)
}
// destroySandbox asks the host agent to destroy the build sandbox. The
// caller's context is deliberately ignored: the request uses its own
// 30-second context so cleanup still runs during shutdown. A failure is
// logged as a warning and not returned.
func (s *BuildService) destroySandbox(_ context.Context, agent buildAgentClient, sandboxIDStr string) {
	// Use a detached context so cleanup succeeds even during shutdown.
	cleanupCtx, release := context.WithTimeout(context.Background(), 30*time.Second)
	defer release()

	req := connect.NewRequest(&pb.DestroySandboxRequest{
		SandboxId: sandboxIDStr,
	})
	_, err := agent.DestroySandbox(cleanupCtx, req)
	if err == nil {
		return
	}
	slog.Warn("failed to destroy build sandbox", "sandbox_id", sandboxIDStr, "error", err)
}
// fetchSandboxEnv runs "/bin/sh -c env" inside the given sandbox through the
// build agent (with a 10-second exec timeout) and returns the resulting
// environment as a map. It returns an error when the exec call fails or the
// command exits with a non-zero code.
func (s *BuildService) fetchSandboxEnv(ctx context.Context,
	agent buildAgentClient, sandboxIDStr string) (map[string]string, error) {
	req := connect.NewRequest(&pb.ExecRequest{
		SandboxId:  sandboxIDStr,
		Cmd:        "/bin/sh",
		Args:       []string{"-c", "env"},
		TimeoutSec: 10,
	})

	resp, err := agent.Exec(ctx, req)
	if err != nil {
		return nil, fmt.Errorf("fetch env: %w", err)
	}

	if code := resp.Msg.ExitCode; code != 0 {
		return nil, fmt.Errorf("fetch env: command exited with code %d",
			code)
	}

	output := string(resp.Msg.Stdout)
	return parseSandboxEnv(output), nil
}
// parseSandboxEnv turns the newline-separated output of an 'env' command into
// a key/value map.
//
// Each line is trimmed of surrounding whitespace. Blank lines and lines
// without '=' are skipped. Only the first '=' separates key from value, so
// values may themselves contain '='. Later duplicates overwrite earlier ones.
// The returned map is never nil.
func parseSandboxEnv(raw string) map[string]string {
	env := make(map[string]string)

	// Consume raw one line at a time. A trailing empty remainder is simply
	// not visited, which matches skipping blank lines.
	for rest := raw; rest != ""; {
		var entry string
		entry, rest, _ = strings.Cut(rest, "\n")

		entry = strings.TrimSpace(entry)
		if entry == "" {
			continue
		}

		key, value, found := strings.Cut(entry, "=")
		if !found {
			continue
		}
		env[key] = value
	}
	return env
}

View File

@ -2,12 +2,14 @@ package service
import (
"context"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"log/slog"
"time"
"connectrpc.com/connect"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/redis/go-redis/v9"
@ -15,6 +17,8 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
)
// HostService provides host management operations.
@ -22,15 +26,17 @@ type HostService struct {
DB *db.Queries
Redis *redis.Client
JWT []byte
Pool *lifecycle.HostClientPool
CA *auth.CA // nil disables mTLS cert issuance (dev/test environments)
}
// HostCreateParams holds the parameters for creating a host.
type HostCreateParams struct {
Type string
TeamID string // required for BYOC, empty for regular
TeamID pgtype.UUID // required for BYOC, zero value for regular
Provider string
AvailabilityZone string
RequestingUserID string
RequestingUserID pgtype.UUID
IsRequestorAdmin bool
}
@ -50,10 +56,34 @@ type HostRegisterParams struct {
Address string
}
// HostRegisterResult holds the registered host and its long-lived JWT.
// HostRegisterResult holds the registered host, its short-lived JWT, a long-lived
// refresh token, and optionally the host's mTLS certificate material.
type HostRegisterResult struct {
Host db.Host
JWT string
Host db.Host
JWT string
RefreshToken string
// mTLS cert material — empty when CA is not configured.
CertPEM string
KeyPEM string
CACertPEM string
}
// HostRefreshResult is the value returned by HostService.Refresh after a
// refresh token has been validated and rotated. It carries the host record, a
// freshly signed JWT, the replacement refresh token, and — only when the
// service has a CA configured — newly issued mTLS certificate material.
type HostRefreshResult struct {
// Host is the host record the presented refresh token belongs to.
Host db.Host
// JWT is the newly signed host JWT.
JWT string
// RefreshToken is the new opaque refresh token that replaces the one presented.
RefreshToken string
// mTLS cert material — empty when CA is not configured.
// CertPEM and KeyPEM come from the renewed host certificate; CACertPEM is
// the CA's own PEM.
CertPEM string
KeyPEM string
CACertPEM string
}
// HostDeletePreview describes what will be affected by deleting a host.
// It is produced by HostService.DeletePreview, which makes no changes.
type HostDeletePreview struct {
// Host is the host record that would be deleted.
Host db.Host
// SandboxIDs holds the formatted IDs of the host's sandboxes whose status is
// one of "pending", "starting", "running" or "missing".
SandboxIDs []string
}
// regTokenPayload is the JSON stored in Redis for registration tokens.
@ -64,6 +94,14 @@ type regTokenPayload struct {
const regTokenTTL = time.Hour
// requireAdminOrOwner checks whether a team role is allowed to manage BYOC
// hosts. The roles "owner" and "admin" (exact, case-sensitive match) yield nil;
// every other value yields a "forbidden" error.
func requireAdminOrOwner(role string) error {
	switch role {
	case "owner", "admin":
		return nil
	default:
		return fmt.Errorf("forbidden: only team owners and admins can manage BYOC hosts")
	}
}
// Create creates a new host record and generates a one-time registration token.
func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreateResult, error) {
if p.Type != "regular" && p.Type != "byoc" {
@ -75,8 +113,8 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat
return HostCreateResult{}, fmt.Errorf("forbidden: only admins can create regular hosts")
}
} else {
// BYOC: admin or team owner.
if p.TeamID == "" {
// BYOC: platform admin, or team owner/admin.
if !p.TeamID.Valid {
return HostCreateResult{}, fmt.Errorf("invalid request: team_id is required for BYOC hosts")
}
if !p.IsRequestorAdmin {
@ -90,40 +128,31 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat
if err != nil {
return HostCreateResult{}, fmt.Errorf("check team membership: %w", err)
}
if membership.Role != "owner" {
return HostCreateResult{}, fmt.Errorf("forbidden: only team owners can create BYOC hosts")
if err := requireAdminOrOwner(membership.Role); err != nil {
return HostCreateResult{}, err
}
}
}
// Validate team exists for BYOC hosts.
if p.TeamID != "" {
if _, err := s.DB.GetTeam(ctx, p.TeamID); err != nil {
// Validate team exists, is not deleted, and has BYOC enabled.
if p.TeamID.Valid {
team, err := s.DB.GetTeam(ctx, p.TeamID)
if err != nil || team.DeletedAt.Valid {
return HostCreateResult{}, fmt.Errorf("invalid request: team not found")
}
if !team.IsByoc {
return HostCreateResult{}, fmt.Errorf("forbidden: BYOC is not enabled for this team")
}
}
hostID := id.NewHostID()
var teamID pgtype.Text
if p.TeamID != "" {
teamID = pgtype.Text{String: p.TeamID, Valid: true}
}
var provider pgtype.Text
if p.Provider != "" {
provider = pgtype.Text{String: p.Provider, Valid: true}
}
var az pgtype.Text
if p.AvailabilityZone != "" {
az = pgtype.Text{String: p.AvailabilityZone, Valid: true}
}
host, err := s.DB.InsertHost(ctx, db.InsertHostParams{
ID: hostID,
Type: p.Type,
TeamID: teamID,
Provider: provider,
AvailabilityZone: az,
TeamID: p.TeamID,
Provider: p.Provider,
AvailabilityZone: p.AvailabilityZone,
CreatedBy: p.RequestingUserID,
})
if err != nil {
@ -135,8 +164,8 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat
tokenID := id.NewHostTokenID()
payload, _ := json.Marshal(regTokenPayload{
HostID: hostID,
TokenID: tokenID,
HostID: id.FormatHostID(hostID),
TokenID: id.FormatHostTokenID(tokenID),
})
if err := s.Redis.Set(ctx, "host:reg:"+token, payload, regTokenTTL).Err(); err != nil {
return HostCreateResult{}, fmt.Errorf("store registration token: %w", err)
@ -149,7 +178,7 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat
CreatedBy: p.RequestingUserID,
ExpiresAt: pgtype.Timestamptz{Time: now.Add(regTokenTTL), Valid: true},
}); err != nil {
slog.Warn("failed to insert host token audit record", "host_id", hostID, "error", err)
slog.Warn("failed to insert host token audit record", "host_id", id.FormatHostID(hostID), "error", err)
}
return HostCreateResult{Host: host, RegistrationToken: token}, nil
@ -158,7 +187,7 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat
// RegenerateToken issues a new registration token for a host still in "pending"
// status. This allows retry when a previous registration attempt failed after
// the original token was consumed.
func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamID string, isAdmin bool) (HostCreateResult, error) {
func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamID pgtype.UUID, isAdmin bool) (HostCreateResult, error) {
host, err := s.DB.GetHost(ctx, hostID)
if err != nil {
return HostCreateResult{}, fmt.Errorf("host not found: %w", err)
@ -167,12 +196,11 @@ func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamI
return HostCreateResult{}, fmt.Errorf("invalid state: can only regenerate token for pending hosts (status: %s)", host.Status)
}
// Same permission model as Create/Delete.
if !isAdmin {
if host.Type != "byoc" {
return HostCreateResult{}, fmt.Errorf("forbidden: only admins can manage regular hosts")
}
if !host.TeamID.Valid || host.TeamID.String != teamID {
if !host.TeamID.Valid || host.TeamID != teamID {
return HostCreateResult{}, fmt.Errorf("forbidden: host does not belong to your team")
}
membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
@ -185,8 +213,8 @@ func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamI
if err != nil {
return HostCreateResult{}, fmt.Errorf("check team membership: %w", err)
}
if membership.Role != "owner" {
return HostCreateResult{}, fmt.Errorf("forbidden: only team owners can regenerate tokens")
if err := requireAdminOrOwner(membership.Role); err != nil {
return HostCreateResult{}, err
}
}
@ -194,8 +222,8 @@ func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamI
tokenID := id.NewHostTokenID()
payload, _ := json.Marshal(regTokenPayload{
HostID: hostID,
TokenID: tokenID,
HostID: id.FormatHostID(hostID),
TokenID: id.FormatHostTokenID(tokenID),
})
if err := s.Redis.Set(ctx, "host:reg:"+token, payload, regTokenTTL).Err(); err != nil {
return HostCreateResult{}, fmt.Errorf("store registration token: %w", err)
@ -208,14 +236,14 @@ func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamI
CreatedBy: userID,
ExpiresAt: pgtype.Timestamptz{Time: now.Add(regTokenTTL), Valid: true},
}); err != nil {
slog.Warn("failed to insert host token audit record", "host_id", hostID, "error", err)
slog.Warn("failed to insert host token audit record", "host_id", id.FormatHostID(hostID), "error", err)
}
return HostCreateResult{Host: host, RegistrationToken: token}, nil
}
// Register validates a one-time registration token, updates the host with
// machine specs, and returns a long-lived host JWT.
// machine specs, and returns a short-lived host JWT plus a long-lived refresh token.
func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostRegisterResult, error) {
// Atomic consume: GetDel returns the value and deletes in one operation,
// preventing concurrent requests from consuming the same token.
@ -232,24 +260,44 @@ func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostR
return HostRegisterResult{}, fmt.Errorf("corrupted registration token")
}
if _, err := s.DB.GetHost(ctx, payload.HostID); err != nil {
hostID, err := id.ParseHostID(payload.HostID)
if err != nil {
return HostRegisterResult{}, fmt.Errorf("corrupted registration token: %w", err)
}
tokenID, err := id.ParseHostTokenID(payload.TokenID)
if err != nil {
return HostRegisterResult{}, fmt.Errorf("corrupted registration token: %w", err)
}
if _, err := s.DB.GetHost(ctx, hostID); err != nil {
return HostRegisterResult{}, fmt.Errorf("host not found: %w", err)
}
// Sign JWT before mutating DB — if signing fails, the host stays pending.
hostJWT, err := auth.SignHostJWT(s.JWT, payload.HostID)
hostJWT, err := auth.SignHostJWT(s.JWT, hostID)
if err != nil {
return HostRegisterResult{}, fmt.Errorf("sign host token: %w", err)
}
// Issue mTLS certificate if CA is configured.
var hc auth.HostCert
if s.CA != nil {
hc, err = auth.IssueHostCert(s.CA, id.FormatHostID(hostID), p.Address)
if err != nil {
return HostRegisterResult{}, fmt.Errorf("issue host cert: %w", err)
}
}
// Atomically update only if still pending (defense-in-depth against races).
rowsAffected, err := s.DB.RegisterHost(ctx, db.RegisterHostParams{
ID: payload.HostID,
Arch: pgtype.Text{String: p.Arch, Valid: p.Arch != ""},
CpuCores: pgtype.Int4{Int32: p.CPUCores, Valid: p.CPUCores > 0},
MemoryMb: pgtype.Int4{Int32: p.MemoryMB, Valid: p.MemoryMB > 0},
DiskGb: pgtype.Int4{Int32: p.DiskGB, Valid: p.DiskGB > 0},
Address: pgtype.Text{String: p.Address, Valid: p.Address != ""},
ID: hostID,
Arch: p.Arch,
CpuCores: p.CPUCores,
MemoryMb: p.MemoryMB,
DiskGb: p.DiskGB,
Address: p.Address,
CertFingerprint: hc.Fingerprint,
CertExpiresAt: pgtype.Timestamptz{Time: hc.ExpiresAt, Valid: s.CA != nil},
})
if err != nil {
return HostRegisterResult{}, fmt.Errorf("register host: %w", err)
@ -259,82 +307,293 @@ func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostR
}
// Mark audit trail.
if err := s.DB.MarkHostTokenUsed(ctx, payload.TokenID); err != nil {
if err := s.DB.MarkHostTokenUsed(ctx, tokenID); err != nil {
slog.Warn("failed to mark host token used", "token_id", payload.TokenID, "error", err)
}
// Issue a long-lived refresh token.
refreshToken, err := s.issueRefreshToken(ctx, hostID)
if err != nil {
return HostRegisterResult{}, fmt.Errorf("issue refresh token: %w", err)
}
// Re-fetch the host to get the updated state.
host, err := s.DB.GetHost(ctx, payload.HostID)
host, err := s.DB.GetHost(ctx, hostID)
if err != nil {
return HostRegisterResult{}, fmt.Errorf("fetch updated host: %w", err)
}
return HostRegisterResult{Host: host, JWT: hostJWT}, nil
result := HostRegisterResult{Host: host, JWT: hostJWT, RefreshToken: refreshToken}
if s.CA != nil {
result.CertPEM = hc.CertPEM
result.KeyPEM = hc.KeyPEM
result.CACertPEM = s.CA.PEM
}
return result, nil
}
// Heartbeat updates the last heartbeat timestamp for a host.
func (s *HostService) Heartbeat(ctx context.Context, hostID string) error {
return s.DB.UpdateHostHeartbeat(ctx, hostID)
// Refresh exchanges a valid refresh token for a new host JWT and a rotated
// refresh token. When the service has a CA configured it also issues a new
// mTLS certificate for the host and records its fingerprint and expiry.
//
// The presented token is looked up by its hash; a missing row is reported as
// "invalid or expired refresh token". On any failure the zero
// HostRefreshResult is returned together with the error.
func (s *HostService) Refresh(ctx context.Context, refreshToken string) (HostRefreshResult, error) {
	stored, err := s.DB.GetHostRefreshTokenByHash(ctx, hashToken(refreshToken))
	switch {
	case errors.Is(err, pgx.ErrNoRows):
		return HostRefreshResult{}, fmt.Errorf("invalid or expired refresh token")
	case err != nil:
		return HostRefreshResult{}, fmt.Errorf("lookup refresh token: %w", err)
	}

	host, err := s.DB.GetHost(ctx, stored.HostID)
	if err != nil {
		return HostRefreshResult{}, fmt.Errorf("host not found: %w", err)
	}

	signedJWT, err := auth.SignHostJWT(s.JWT, host.ID)
	if err != nil {
		return HostRefreshResult{}, fmt.Errorf("sign host JWT: %w", err)
	}

	out := HostRefreshResult{Host: host, JWT: signedJWT}

	// With a CA present, renew the host certificate and persist its
	// fingerprint/expiry before touching the refresh tokens.
	if s.CA != nil {
		cert, err := auth.IssueHostCert(s.CA, id.FormatHostID(host.ID), host.Address)
		if err != nil {
			return HostRefreshResult{}, fmt.Errorf("renew host cert: %w", err)
		}
		certParams := db.UpdateHostCertParams{
			ID:              host.ID,
			CertFingerprint: cert.Fingerprint,
			CertExpiresAt:   pgtype.Timestamptz{Time: cert.ExpiresAt, Valid: true},
		}
		if err := s.DB.UpdateHostCert(ctx, certParams); err != nil {
			return HostRefreshResult{}, fmt.Errorf("update host cert: %w", err)
		}
		out.CertPEM = cert.CertPEM
		out.KeyPEM = cert.KeyPEM
		out.CACertPEM = s.CA.PEM
	}

	// Rotation order matters: persist the replacement token first, then revoke
	// the old one. A crash between the two calls leaves the host holding two
	// valid tokens instead of none.
	rotated, err := s.issueRefreshToken(ctx, host.ID)
	if err != nil {
		return HostRefreshResult{}, fmt.Errorf("issue new refresh token: %w", err)
	}
	if err := s.DB.RevokeHostRefreshToken(ctx, stored.ID); err != nil {
		return HostRefreshResult{}, fmt.Errorf("revoke old refresh token: %w", err)
	}

	out.RefreshToken = rotated
	return out, nil
}
// issueRefreshToken generates a new opaque refresh token for hostID and stores
// a record for it containing the token's hash (not the token itself) and an
// expiry of now + auth.HostRefreshTokenExpiry. The plain token is returned to
// the caller; an insert failure is wrapped and returned with an empty string.
func (s *HostService) issueRefreshToken(ctx context.Context, hostID pgtype.UUID) (string, error) {
	token := id.NewRefreshToken()

	params := db.InsertHostRefreshTokenParams{
		ID:        id.NewRefreshTokenID(),
		HostID:    hostID,
		TokenHash: hashToken(token),
		ExpiresAt: pgtype.Timestamptz{
			Time:  time.Now().Add(auth.HostRefreshTokenExpiry),
			Valid: true,
		},
	}

	_, err := s.DB.InsertHostRefreshToken(ctx, params)
	if err != nil {
		return "", fmt.Errorf("insert refresh token: %w", err)
	}
	return token, nil
}
// hashToken computes the SHA-256 digest of token and renders it as a
// lowercase hexadecimal string (64 characters).
func hashToken(token string) string {
	digest := sha256.Sum256([]byte(token))
	return fmt.Sprintf("%x", digest[:])
}
// Heartbeat records a heartbeat for the host via UpdateHostHeartbeatAndStatus.
// NOTE(review): per the original comment the query also moves an 'unreachable'
// host back to 'online'; the query itself is not visible here.
//
// A query error is returned unchanged. If the query reports zero affected
// rows, a "host not found" error is returned (e.g. the host was deleted);
// callers are expected to surface that as 404.
func (s *HostService) Heartbeat(ctx context.Context, hostID pgtype.UUID) error {
	updated, err := s.DB.UpdateHostHeartbeatAndStatus(ctx, hostID)
	switch {
	case err != nil:
		return err
	case updated == 0:
		return fmt.Errorf("host not found")
	default:
		return nil
	}
}
// List returns hosts visible to the caller.
// Admins see all hosts; non-admins see only BYOC hosts belonging to their team.
func (s *HostService) List(ctx context.Context, teamID string, isAdmin bool) ([]db.Host, error) {
func (s *HostService) List(ctx context.Context, teamID pgtype.UUID, isAdmin bool) ([]db.Host, error) {
if isAdmin {
return s.DB.ListHosts(ctx)
}
return s.DB.ListHostsByTeam(ctx, pgtype.Text{String: teamID, Valid: true})
return s.DB.ListHostsByTeam(ctx, teamID)
}
// Get returns a single host, enforcing access control.
func (s *HostService) Get(ctx context.Context, hostID, teamID string, isAdmin bool) (db.Host, error) {
func (s *HostService) Get(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool) (db.Host, error) {
host, err := s.DB.GetHost(ctx, hostID)
if err != nil {
return db.Host{}, fmt.Errorf("host not found: %w", err)
}
if !isAdmin {
if !host.TeamID.Valid || host.TeamID.String != teamID {
if !host.TeamID.Valid || host.TeamID != teamID {
return db.Host{}, fmt.Errorf("host not found")
}
}
return host, nil
}
// Delete removes a host. Admins can delete any host. Team owners can delete
// BYOC hosts belonging to their team.
func (s *HostService) Delete(ctx context.Context, hostID, userID, teamID string, isAdmin bool) error {
host, err := s.DB.GetHost(ctx, hostID)
// DeletePreview returns what would be affected by deleting the host, without
// making any changes. Use this to show the user a confirmation prompt.
func (s *HostService) DeletePreview(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool) (HostDeletePreview, error) {
host, err := s.checkDeletePermission(ctx, hostID, pgtype.UUID{}, teamID, isAdmin)
if err != nil {
return fmt.Errorf("host not found: %w", err)
return HostDeletePreview{}, err
}
if !isAdmin {
if host.Type != "byoc" {
return fmt.Errorf("forbidden: only admins can delete regular hosts")
sandboxes, err := s.DB.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
HostID: hostID,
Column2: []string{"pending", "starting", "running", "missing"},
})
if err != nil {
return HostDeletePreview{}, fmt.Errorf("list sandboxes: %w", err)
}
ids := make([]string, len(sandboxes))
for i, sb := range sandboxes {
ids[i] = id.FormatSandboxID(sb.ID)
}
return HostDeletePreview{Host: host, SandboxIDs: ids}, nil
}
// Delete removes a host. Without force it returns an error listing active
// sandboxes so the caller can present a confirmation. With force it gracefully
// destroys all running sandboxes before deleting the host record.
func (s *HostService) Delete(ctx context.Context, hostID, userID, teamID pgtype.UUID, isAdmin, force bool) error {
host, err := s.checkDeletePermission(ctx, hostID, userID, teamID, isAdmin)
if err != nil {
return err
}
sandboxes, err := s.DB.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
HostID: hostID,
Column2: []string{"pending", "starting", "running", "missing"},
})
if err != nil {
return fmt.Errorf("list sandboxes: %w", err)
}
if len(sandboxes) > 0 && !force {
ids := make([]string, len(sandboxes))
for i, sb := range sandboxes {
ids[i] = id.FormatSandboxID(sb.ID)
}
if !host.TeamID.Valid || host.TeamID.String != teamID {
return fmt.Errorf("forbidden: host does not belong to your team")
return &HostHasSandboxesError{SandboxIDs: ids}
}
hostIDStr := id.FormatHostID(hostID)
// Gracefully destroy running sandboxes and terminate the agent (best-effort).
if host.Address != "" {
agent, err := s.Pool.GetForHost(host)
if err == nil {
for _, sb := range sandboxes {
if sb.Status == "running" || sb.Status == "starting" {
_, rpcErr := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
SandboxId: id.FormatSandboxID(sb.ID),
}))
if rpcErr != nil && connect.CodeOf(rpcErr) != connect.CodeNotFound {
slog.Warn("delete host: failed to destroy sandbox on agent", "sandbox_id", id.FormatSandboxID(sb.ID), "error", rpcErr)
}
}
}
// Tell the agent to shut itself down immediately.
if _, rpcErr := agent.Terminate(ctx, connect.NewRequest(&pb.TerminateRequest{})); rpcErr != nil {
slog.Warn("delete host: failed to send Terminate to agent", "host_id", hostIDStr, "error", rpcErr)
}
}
membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
UserID: userID,
TeamID: teamID,
})
if errors.Is(err, pgx.ErrNoRows) {
return fmt.Errorf("forbidden: not a member of the specified team")
}
// Mark all affected sandboxes as stopped in DB.
if len(sandboxes) > 0 {
sbIDs := make([]pgtype.UUID, len(sandboxes))
for i, sb := range sandboxes {
sbIDs[i] = sb.ID
}
if err != nil {
return fmt.Errorf("check team membership: %w", err)
}
if membership.Role != "owner" {
return fmt.Errorf("forbidden: only team owners can delete BYOC hosts")
if err := s.DB.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
Column1: sbIDs,
Status: "stopped",
}); err != nil {
slog.Warn("delete host: failed to mark sandboxes stopped", "host_id", hostIDStr, "error", err)
}
}
// Revoke all refresh tokens for this host.
if err := s.DB.RevokeHostRefreshTokensByHost(ctx, hostID); err != nil {
slog.Warn("delete host: failed to revoke refresh tokens", "host_id", hostIDStr, "error", err)
}
// Evict the client from the pool so no further RPCs are sent.
if s.Pool != nil {
s.Pool.Evict(id.FormatHostID(hostID))
}
return s.DB.DeleteHost(ctx, hostID)
}
// checkDeletePermission loads the host and decides whether the caller may
// delete it, returning the host record when permitted.
//
// Rules, evaluated in order:
//   - platform admins may delete any host;
//   - non-admins may only act on hosts of type "byoc";
//   - the host must belong to the caller's team (teamID);
//   - when userID is valid, the user must be a member of that team with the
//     "owner" or "admin" role. When userID is not valid the membership check
//     is skipped.
//
// On any failure the zero db.Host is returned together with the error.
func (s *HostService) checkDeletePermission(ctx context.Context, hostID, userID, teamID pgtype.UUID, isAdmin bool) (db.Host, error) {
	host, err := s.DB.GetHost(ctx, hostID)
	if err != nil {
		return db.Host{}, fmt.Errorf("host not found: %w", err)
	}

	switch {
	case isAdmin:
		return host, nil
	case host.Type != "byoc":
		return db.Host{}, fmt.Errorf("forbidden: only admins can delete regular hosts")
	case !host.TeamID.Valid || host.TeamID != teamID:
		return db.Host{}, fmt.Errorf("forbidden: host does not belong to your team")
	case !userID.Valid:
		// No user supplied: team ownership of the host is sufficient.
		return host, nil
	}

	membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
		UserID: userID,
		TeamID: teamID,
	})
	switch {
	case errors.Is(err, pgx.ErrNoRows):
		return db.Host{}, fmt.Errorf("forbidden: not a member of the specified team")
	case err != nil:
		return db.Host{}, fmt.Errorf("check team membership: %w", err)
	}

	if roleErr := requireAdminOrOwner(membership.Role); roleErr != nil {
		return db.Host{}, roleErr
	}
	return host, nil
}
// AddTag adds a tag to a host.
func (s *HostService) AddTag(ctx context.Context, hostID, teamID string, isAdmin bool, tag string) error {
func (s *HostService) AddTag(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool, tag string) error {
if _, err := s.Get(ctx, hostID, teamID, isAdmin); err != nil {
return err
}
@ -342,7 +601,7 @@ func (s *HostService) AddTag(ctx context.Context, hostID, teamID string, isAdmin
}
// RemoveTag removes a tag from a host.
func (s *HostService) RemoveTag(ctx context.Context, hostID, teamID string, isAdmin bool, tag string) error {
func (s *HostService) RemoveTag(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool, tag string) error {
if _, err := s.Get(ctx, hostID, teamID, isAdmin); err != nil {
return err
}
@ -350,9 +609,20 @@ func (s *HostService) RemoveTag(ctx context.Context, hostID, teamID string, isAd
}
// ListTags returns all tags for a host.
func (s *HostService) ListTags(ctx context.Context, hostID, teamID string, isAdmin bool) ([]string, error) {
func (s *HostService) ListTags(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool) ([]string, error) {
if _, err := s.Get(ctx, hostID, teamID, isAdmin); err != nil {
return nil, err
}
return s.DB.GetHostTags(ctx, hostID)
}
// HostHasSandboxesError reports that a host could not be deleted because it
// still has active sandboxes and the delete was not forced. Delete returns it
// so the caller can show the affected sandboxes and, once the user confirms,
// call Delete again with force=true.
type HostHasSandboxesError struct {
	// SandboxIDs lists the formatted IDs of the active sandboxes.
	SandboxIDs []string
}

// Error implements the error interface, stating how many sandboxes are active
// followed by their IDs.
func (e *HostHasSandboxesError) Error() string {
	active := e.SandboxIDs
	return fmt.Sprintf("host has %d active sandbox(es): %v", len(active), active)
}

View File

@ -11,29 +11,60 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
"git.omukk.dev/wrenn/sandbox/internal/scheduler"
"git.omukk.dev/wrenn/sandbox/internal/validate"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
)
// SandboxService provides sandbox lifecycle operations shared between the
// REST API and the dashboard.
type SandboxService struct {
DB *db.Queries
Agent hostagentv1connect.HostAgentServiceClient
DB *db.Queries
Pool *lifecycle.HostClientPool
Scheduler scheduler.HostScheduler
}
// SandboxCreateParams holds the parameters for creating a sandbox.
type SandboxCreateParams struct {
TeamID string
TeamID pgtype.UUID
Template string
VCPUs int32
MemoryMB int32
TimeoutSec int32
DiskSizeMB int32
}
// Create creates a new sandbox: inserts a pending DB record, calls the host agent,
// and updates the record to running. Returns the sandbox DB row.
// agentForSandbox resolves the host agent client responsible for a sandbox.
// It loads the sandbox row, then the host row referenced by the sandbox's
// HostID, and finally asks the client pool for a client to that host.
// The sandbox row is returned alongside the client; on failure the client is
// nil and the sandbox is the zero value.
func (s *SandboxService) agentForSandbox(ctx context.Context, sandboxID pgtype.UUID) (hostagentClient, db.Sandbox, error) {
	var none db.Sandbox

	sandbox, err := s.DB.GetSandbox(ctx, sandboxID)
	if err != nil {
		return nil, none, fmt.Errorf("sandbox not found: %w", err)
	}

	host, err := s.DB.GetHost(ctx, sandbox.HostID)
	if err != nil {
		return nil, none, fmt.Errorf("host not found for sandbox: %w", err)
	}

	client, err := s.Pool.GetForHost(host)
	if err != nil {
		return nil, none, fmt.Errorf("get agent client: %w", err)
	}

	return client, sandbox, nil
}
// hostagentClient is a local alias to avoid the full package path in signatures.
// It names an interface listing the host agent RPCs for sandbox lifecycle
// (create, destroy, pause, resume, ping) and metrics (get, flush). It is the
// client type returned by agentForSandbox and accepted by
// flushAndPersistMetrics.
type hostagentClient = interface {
// Sandbox lifecycle RPCs.
CreateSandbox(ctx context.Context, req *connect.Request[pb.CreateSandboxRequest]) (*connect.Response[pb.CreateSandboxResponse], error)
DestroySandbox(ctx context.Context, req *connect.Request[pb.DestroySandboxRequest]) (*connect.Response[pb.DestroySandboxResponse], error)
PauseSandbox(ctx context.Context, req *connect.Request[pb.PauseSandboxRequest]) (*connect.Response[pb.PauseSandboxResponse], error)
ResumeSandbox(ctx context.Context, req *connect.Request[pb.ResumeSandboxRequest]) (*connect.Response[pb.ResumeSandboxResponse], error)
PingSandbox(ctx context.Context, req *connect.Request[pb.PingSandboxRequest]) (*connect.Response[pb.PingSandboxResponse], error)
// Sandbox metrics RPCs.
GetSandboxMetrics(ctx context.Context, req *connect.Request[pb.GetSandboxMetricsRequest]) (*connect.Response[pb.GetSandboxMetricsResponse], error)
FlushSandboxMetrics(ctx context.Context, req *connect.Request[pb.FlushSandboxMetricsRequest]) (*connect.Response[pb.FlushSandboxMetricsResponse], error)
}
// Create creates a new sandbox: picks a host via the scheduler, inserts a pending
// DB record, calls the host agent, and updates the record to running.
func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db.Sandbox, error) {
if p.Template == "" {
p.Template = "minimal"
@ -47,44 +78,82 @@ func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db.
if p.MemoryMB <= 0 {
p.MemoryMB = 512
}
if p.DiskSizeMB <= 0 {
p.DiskSizeMB = 5120 // 5 GB default
}
// If the template is a snapshot, use its baked-in vcpus/memory.
if tmpl, err := s.DB.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: p.Template, TeamID: p.TeamID}); err == nil && tmpl.Type == "snapshot" {
if tmpl.Vcpus.Valid {
p.VCPUs = tmpl.Vcpus.Int32
// Resolve template name → (teamID, templateID).
templateTeamID := id.PlatformTeamID
templateID := id.MinimalTemplateID
if p.Template != "minimal" {
tmpl, err := s.DB.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: p.Template, TeamID: p.TeamID})
if err != nil {
return db.Sandbox{}, fmt.Errorf("template %q not found: %w", p.Template, err)
}
if tmpl.MemoryMb.Valid {
p.MemoryMB = tmpl.MemoryMb.Int32
templateTeamID = tmpl.TeamID
templateID = tmpl.ID
// If the template is a snapshot, use its baked-in vcpus/memory.
if tmpl.Type == "snapshot" {
p.VCPUs = tmpl.Vcpus
p.MemoryMB = tmpl.MemoryMb
}
}
if !p.TeamID.Valid {
return db.Sandbox{}, fmt.Errorf("invalid request: team_id is required")
}
// Determine whether this team uses BYOC hosts or platform hosts.
team, err := s.DB.GetTeam(ctx, p.TeamID)
if err != nil {
return db.Sandbox{}, fmt.Errorf("team not found: %w", err)
}
// Pick a host for this sandbox.
host, err := s.Scheduler.SelectHost(ctx, p.TeamID, team.IsByoc)
if err != nil {
return db.Sandbox{}, fmt.Errorf("select host: %w", err)
}
agent, err := s.Pool.GetForHost(host)
if err != nil {
return db.Sandbox{}, fmt.Errorf("get agent client: %w", err)
}
sandboxID := id.NewSandboxID()
sandboxIDStr := id.FormatSandboxID(sandboxID)
if _, err := s.DB.InsertSandbox(ctx, db.InsertSandboxParams{
ID: sandboxID,
TeamID: p.TeamID,
HostID: "default",
Template: p.Template,
Status: "pending",
Vcpus: p.VCPUs,
MemoryMb: p.MemoryMB,
TimeoutSec: p.TimeoutSec,
ID: sandboxID,
TeamID: p.TeamID,
HostID: host.ID,
Template: p.Template,
Status: "pending",
Vcpus: p.VCPUs,
MemoryMb: p.MemoryMB,
TimeoutSec: p.TimeoutSec,
DiskSizeMb: p.DiskSizeMB,
TemplateID: templateID,
TemplateTeamID: templateTeamID,
}); err != nil {
return db.Sandbox{}, fmt.Errorf("insert sandbox: %w", err)
}
resp, err := s.Agent.CreateSandbox(ctx, connect.NewRequest(&pb.CreateSandboxRequest{
SandboxId: sandboxID,
resp, err := agent.CreateSandbox(ctx, connect.NewRequest(&pb.CreateSandboxRequest{
SandboxId: sandboxIDStr,
Template: p.Template,
TeamId: id.UUIDString(templateTeamID),
TemplateId: id.UUIDString(templateID),
Vcpus: p.VCPUs,
MemoryMb: p.MemoryMB,
TimeoutSec: p.TimeoutSec,
DiskSizeMb: p.DiskSizeMB,
}))
if err != nil {
if _, dbErr := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
ID: sandboxID, Status: "error",
}); dbErr != nil {
slog.Warn("failed to update sandbox status to error", "id", sandboxID, "error", dbErr)
slog.Warn("failed to update sandbox status to error", "id", sandboxIDStr, "error", dbErr)
}
return db.Sandbox{}, fmt.Errorf("agent create: %w", err)
}
@ -107,17 +176,17 @@ func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db.
}
// List returns active sandboxes (excludes stopped/error) belonging to the given team.
func (s *SandboxService) List(ctx context.Context, teamID string) ([]db.Sandbox, error) {
func (s *SandboxService) List(ctx context.Context, teamID pgtype.UUID) ([]db.Sandbox, error) {
return s.DB.ListSandboxesByTeam(ctx, teamID)
}
// Get returns a single sandbox by ID, scoped to the given team.
func (s *SandboxService) Get(ctx context.Context, sandboxID, teamID string) (db.Sandbox, error) {
func (s *SandboxService) Get(ctx context.Context, sandboxID, teamID pgtype.UUID) (db.Sandbox, error) {
return s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
}
// Pause snapshots and freezes a running sandbox to disk.
func (s *SandboxService) Pause(ctx context.Context, sandboxID, teamID string) (db.Sandbox, error) {
func (s *SandboxService) Pause(ctx context.Context, sandboxID, teamID pgtype.UUID) (db.Sandbox, error) {
sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
if err != nil {
return db.Sandbox{}, fmt.Errorf("sandbox not found: %w", err)
@ -126,23 +195,45 @@ func (s *SandboxService) Pause(ctx context.Context, sandboxID, teamID string) (d
return db.Sandbox{}, fmt.Errorf("sandbox is not running (status: %s)", sb.Status)
}
if _, err := s.Agent.PauseSandbox(ctx, connect.NewRequest(&pb.PauseSandboxRequest{
SandboxId: sandboxID,
agent, _, err := s.agentForSandbox(ctx, sandboxID)
if err != nil {
return db.Sandbox{}, err
}
sandboxIDStr := id.FormatSandboxID(sandboxID)
// Pre-mark as "paused" in DB before the RPC so the reconciler does not
// mark the sandbox "stopped" while the host agent processes the pause.
if _, err := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
ID: sandboxID, Status: "paused",
}); err != nil {
return db.Sandbox{}, fmt.Errorf("pre-mark paused: %w", err)
}
// Flush all metrics tiers before pausing so data survives in DB.
s.flushAndPersistMetrics(ctx, agent, sandboxID, true)
if _, err := agent.PauseSandbox(ctx, connect.NewRequest(&pb.PauseSandboxRequest{
SandboxId: sandboxIDStr,
})); err != nil {
// Revert status on failure.
if _, dbErr := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
ID: sandboxID, Status: "running",
}); dbErr != nil {
slog.Warn("failed to revert sandbox status after pause error", "sandbox_id", sandboxIDStr, "error", dbErr)
}
return db.Sandbox{}, fmt.Errorf("agent pause: %w", err)
}
sb, err = s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
ID: sandboxID, Status: "paused",
})
sb, err = s.DB.GetSandbox(ctx, sandboxID)
if err != nil {
return db.Sandbox{}, fmt.Errorf("update status: %w", err)
return db.Sandbox{}, fmt.Errorf("get sandbox after pause: %w", err)
}
return sb, nil
}
// Resume restores a paused sandbox from snapshot.
func (s *SandboxService) Resume(ctx context.Context, sandboxID, teamID string) (db.Sandbox, error) {
func (s *SandboxService) Resume(ctx context.Context, sandboxID, teamID pgtype.UUID) (db.Sandbox, error) {
sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
if err != nil {
return db.Sandbox{}, fmt.Errorf("sandbox not found: %w", err)
@ -151,8 +242,15 @@ func (s *SandboxService) Resume(ctx context.Context, sandboxID, teamID string) (
return db.Sandbox{}, fmt.Errorf("sandbox is not paused (status: %s)", sb.Status)
}
resp, err := s.Agent.ResumeSandbox(ctx, connect.NewRequest(&pb.ResumeSandboxRequest{
SandboxId: sandboxID,
agent, _, err := s.agentForSandbox(ctx, sandboxID)
if err != nil {
return db.Sandbox{}, err
}
sandboxIDStr := id.FormatSandboxID(sandboxID)
resp, err := agent.ResumeSandbox(ctx, connect.NewRequest(&pb.ResumeSandboxRequest{
SandboxId: sandboxIDStr,
TimeoutSec: sb.TimeoutSec,
}))
if err != nil {
@ -176,18 +274,41 @@ func (s *SandboxService) Resume(ctx context.Context, sandboxID, teamID string) (
}
// Destroy stops a sandbox and marks it as stopped.
func (s *SandboxService) Destroy(ctx context.Context, sandboxID, teamID string) error {
if _, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID}); err != nil {
func (s *SandboxService) Destroy(ctx context.Context, sandboxID, teamID pgtype.UUID) error {
sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
if err != nil {
return fmt.Errorf("sandbox not found: %w", err)
}
agent, _, err := s.agentForSandbox(ctx, sandboxID)
if err != nil {
return err
}
sandboxIDStr := id.FormatSandboxID(sandboxID)
// If running, flush 24h tier metrics for analytics before destroying.
if sb.Status == "running" {
s.flushAndPersistMetrics(ctx, agent, sandboxID, false)
}
// Destroy on host agent. A not-found response is fine — sandbox is already gone.
if _, err := s.Agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
SandboxId: sandboxID,
if _, err := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
SandboxId: sandboxIDStr,
})); err != nil && connect.CodeOf(err) != connect.CodeNotFound {
return fmt.Errorf("agent destroy: %w", err)
}
// For a paused sandbox, only keep 24h tier; remove the finer-grained tiers.
if sb.Status == "paused" {
_ = s.DB.DeleteSandboxMetricPointsByTier(ctx, db.DeleteSandboxMetricPointsByTierParams{
SandboxID: sandboxID, Tier: "10m",
})
_ = s.DB.DeleteSandboxMetricPointsByTier(ctx, db.DeleteSandboxMetricPointsByTierParams{
SandboxID: sandboxID, Tier: "2h",
})
}
if _, err := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
ID: sandboxID, Status: "stopped",
}); err != nil {
@ -196,8 +317,45 @@ func (s *SandboxService) Destroy(ctx context.Context, sandboxID, teamID string)
return nil
}
// flushAndPersistMetrics asks the host agent to flush the sandbox's metrics
// and writes the returned points to the DB via persistMetricPoints.
//
// The 24h tier is always stored. The 10m and 2h tiers are stored only when
// allTiers is true. A failed flush is logged and otherwise ignored, so this
// never blocks the caller's main operation.
func (s *SandboxService) flushAndPersistMetrics(ctx context.Context, agent hostagentClient, sandboxID pgtype.UUID, allTiers bool) {
	formattedID := id.FormatSandboxID(sandboxID)

	flushReq := connect.NewRequest(&pb.FlushSandboxMetricsRequest{
		SandboxId: formattedID,
	})
	resp, err := agent.FlushSandboxMetrics(ctx, flushReq)
	if err != nil {
		slog.Warn("flush metrics failed (best-effort)", "sandbox_id", formattedID, "error", err)
		return
	}

	flushed := resp.Msg
	if allTiers {
		s.persistMetricPoints(ctx, sandboxID, "10m", flushed.Points_10M)
		s.persistMetricPoints(ctx, sandboxID, "2h", flushed.Points_2H)
	}
	s.persistMetricPoints(ctx, sandboxID, "24h", flushed.Points_24H)
}
// persistMetricPoints inserts every point into the DB under the given tier.
// Each point is inserted individually; a failed insert is logged and the
// remaining points are still attempted.
func (s *SandboxService) persistMetricPoints(ctx context.Context, sandboxID pgtype.UUID, tier string, points []*pb.MetricPoint) {
	formattedID := id.FormatSandboxID(sandboxID)
	for i := range points {
		pt := points[i]
		row := db.InsertSandboxMetricPointParams{
			SandboxID: sandboxID,
			Tier:      tier,
			Ts:        pt.TimestampUnix,
			CpuPct:    pt.CpuPct,
			MemBytes:  pt.MemBytes,
			DiskBytes: pt.DiskBytes,
		}
		err := s.DB.InsertSandboxMetricPoint(ctx, row)
		if err == nil {
			continue
		}
		slog.Warn("persist metric point failed", "sandbox_id", formattedID, "tier", tier, "error", err)
	}
}
// Ping resets the inactivity timer for a running sandbox.
func (s *SandboxService) Ping(ctx context.Context, sandboxID, teamID string) error {
func (s *SandboxService) Ping(ctx context.Context, sandboxID, teamID pgtype.UUID) error {
sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
if err != nil {
return fmt.Errorf("sandbox not found: %w", err)
@ -206,8 +364,15 @@ func (s *SandboxService) Ping(ctx context.Context, sandboxID, teamID string) err
return fmt.Errorf("sandbox is not running (status: %s)", sb.Status)
}
if _, err := s.Agent.PingSandbox(ctx, connect.NewRequest(&pb.PingSandboxRequest{
SandboxId: sandboxID,
agent, _, err := s.agentForSandbox(ctx, sandboxID)
if err != nil {
return err
}
sandboxIDStr := id.FormatSandboxID(sandboxID)
if _, err := agent.PingSandbox(ctx, connect.NewRequest(&pb.PingSandboxRequest{
SandboxId: sandboxIDStr,
})); err != nil {
return fmt.Errorf("agent ping: %w", err)
}
@ -219,7 +384,7 @@ func (s *SandboxService) Ping(ctx context.Context, sandboxID, teamID string) err
Valid: true,
},
}); err != nil {
slog.Warn("ping: failed to update last_active_at", "sandbox_id", sandboxID, "error", err)
slog.Warn("ping: failed to update last_active_at", "sandbox_id", sandboxIDStr, "error", err)
}
return nil
}

160
internal/service/stats.go Normal file
View File

@ -0,0 +1,160 @@
package service
import (
"context"
"errors"
"fmt"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxpool"
"git.omukk.dev/wrenn/sandbox/internal/db"
)
// TimeRange identifies a chart time window.
type TimeRange string
const (
Range5m TimeRange = "5m"
Range1h TimeRange = "1h"
Range6h TimeRange = "6h"
Range24h TimeRange = "24h"
Range30d TimeRange = "30d"
)
type rangeConfig struct {
bucketSec int // bucket width in seconds for time-series aggregation
intervalLiteral string // PostgreSQL interval literal for the lookback window
}
// rangeConfigs maps each supported TimeRange to the bucket width and lookback
// interval used for its time-series query. A TimeRange is considered valid
// only if it has an entry here.
var rangeConfigs = map[TimeRange]rangeConfig{
Range5m: {bucketSec: 3, intervalLiteral: "5 minutes"},
Range1h: {bucketSec: 30, intervalLiteral: "1 hour"},
Range6h: {bucketSec: 180, intervalLiteral: "6 hours"},
Range24h: {bucketSec: 720, intervalLiteral: "24 hours"},
Range30d: {bucketSec: 21600, intervalLiteral: "30 days"},
}
// ValidRange reports whether r has an entry in rangeConfigs.
func ValidRange(r TimeRange) bool {
	if _, known := rangeConfigs[r]; known {
		return true
	}
	return false
}
// StatPoint is one bucketed data point in the time-series.
type StatPoint struct {
Bucket time.Time
RunningCount int32
VCPUsReserved int32
MemoryMBReserved int32
}
// CurrentStats holds the live values for a team, read directly from sandboxes.
type CurrentStats struct {
RunningCount int32
VCPUsReserved int32
MemoryMBReserved int32
}
// PeakStats holds the 30-day maximum values for a team.
type PeakStats struct {
RunningCount int32
VCPUs int32
MemoryMB int32
}
// StatsService computes sandbox metrics for the dashboard.
type StatsService struct {
DB *db.Queries
Pool *pgxpool.Pool
}
// GetStats returns, for the given team, the live stats, the peak stats, and a
// bucketed time-series covering the requested range.
//
// An unknown range is rejected before any query runs. If the peak query yields
// no rows, the peaks are returned as zeros rather than as an error. On any
// error all three result values are zero/nil.
func (s *StatsService) GetStats(ctx context.Context, teamID pgtype.UUID, r TimeRange) (CurrentStats, PeakStats, []StatPoint, error) {
	// fail produces the uniform error return: zero values plus the error.
	fail := func(err error) (CurrentStats, PeakStats, []StatPoint, error) {
		return CurrentStats{}, PeakStats{}, nil, err
	}

	cfg, known := rangeConfigs[r]
	if !known {
		return fail(fmt.Errorf("unknown range: %s", r))
	}

	// Live values come straight from the sandboxes rather than from snapshots.
	live, err := s.DB.GetLiveMetrics(ctx, teamID)
	if err != nil {
		return fail(fmt.Errorf("get live metrics: %w", err))
	}

	// Peaks: a missing row is tolerated and leaves the zero value in place.
	var peaks PeakStats
	peakRow, err := s.DB.GetPeakMetrics(ctx, teamID)
	switch {
	case err == nil:
		peaks = PeakStats{
			RunningCount: peakRow.PeakRunningCount,
			VCPUs:        peakRow.PeakVcpus,
			MemoryMB:     peakRow.PeakMemoryMb,
		}
	case !errors.Is(err, pgx.ErrNoRows):
		return fail(fmt.Errorf("get peak metrics: %w", err))
	}

	// Time-series with a per-range bucket width, run directly through the pool.
	series, err := s.queryTimeSeries(ctx, teamID, cfg)
	if err != nil {
		return fail(fmt.Errorf("get time series: %w", err))
	}

	current := CurrentStats{
		RunningCount:     live.RunningCount,
		VCPUsReserved:    live.VcpusReserved,
		MemoryMBReserved: live.MemoryMbReserved,
	}
	return current, peaks, series, nil
}
// timeSeriesSQL uses an epoch-floor trick to bucket rows by an arbitrary
// integer number of seconds without requiring TimescaleDB.
//
// MAX is used instead of AVG so that short-lived running states are not
// averaged down to zero within a bucket. For capacity metrics the peak
// value in each bucket is what matters — AVG with ::INTEGER rounding
// caused running_count, vcpus, and memory to become inconsistent with
// each other (e.g. running=0 but vcpus=1).
//
// $1 = bucket width in seconds (integer)
// $2 = team_id
// $3 = lookback interval literal (e.g. '1 hour')
const timeSeriesSQL = `
SELECT
to_timestamp(floor(extract(epoch FROM sampled_at) / $1) * $1) AS bucket,
MAX(running_count) AS running_count,
MAX(vcpus_reserved) AS vcpus_reserved,
MAX(memory_mb_reserved) AS memory_mb_reserved
FROM sandbox_metrics_snapshots
WHERE team_id = $2
AND sampled_at >= NOW() - $3::INTERVAL
GROUP BY bucket
ORDER BY bucket ASC
`
// queryTimeSeries runs timeSeriesSQL for the team using the bucket width and
// lookback interval from cfg, and returns one StatPoint per result row in the
// order the query produced them. The returned slice is nil when the query
// yields no rows.
func (s *StatsService) queryTimeSeries(ctx context.Context, teamID pgtype.UUID, cfg rangeConfig) ([]StatPoint, error) {
	rows, err := s.Pool.Query(ctx, timeSeriesSQL, cfg.bucketSec, teamID, cfg.intervalLiteral)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var series []StatPoint
	for rows.Next() {
		var pt StatPoint
		scanErr := rows.Scan(&pt.Bucket, &pt.RunningCount, &pt.VCPUsReserved, &pt.MemoryMBReserved)
		if scanErr != nil {
			return nil, scanErr
		}
		series = append(series, pt)
	}
	// rows.Err surfaces any error that ended iteration early.
	return series, rows.Err()
}

443
internal/service/team.go Normal file
View File

@ -0,0 +1,443 @@
package service
import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"regexp"
	"time"

	"connectrpc.com/connect"
	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgtype"
	"github.com/jackc/pgx/v5/pgxpool"

	"git.omukk.dev/wrenn/sandbox/internal/db"
	"git.omukk.dev/wrenn/sandbox/internal/id"
	"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
	pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
)
var teamNameRE = regexp.MustCompile(`^[A-Za-z0-9 _\-@']{1,128}$`)
// TeamService provides team management operations.
type TeamService struct {
DB *db.Queries
Pool *pgxpool.Pool
HostPool *lifecycle.HostClientPool
}
// TeamWithRole pairs a team with the calling user's role in it.
type TeamWithRole struct {
db.Team
Role string `json:"role"`
}
// MemberInfo is a team member with resolved user details.
type MemberInfo struct {
UserID string `json:"user_id"`
Name string `json:"name"`
Email string `json:"email"`
Role string `json:"role"`
JoinedAt time.Time `json:"joined_at"`
}
// callerRole fetches the calling user's role in the given team from the DB.
//
// If the caller has no membership row, the returned error's message starts
// with "forbidden". Any other lookup failure is wrapped with "get membership".
func (s *TeamService) callerRole(ctx context.Context, teamID, callerUserID pgtype.UUID) (string, error) {
	m, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
		UserID: callerUserID,
		TeamID: teamID,
	})
	if err != nil {
		// errors.Is also matches a wrapped ErrNoRows, which == would miss.
		if errors.Is(err, pgx.ErrNoRows) {
			return "", fmt.Errorf("forbidden: not a member of this team")
		}
		return "", fmt.Errorf("get membership: %w", err)
	}
	return m.Role, nil
}
// requireAdmin returns nil for the "owner" and "admin" roles and a
// "forbidden" error for every other role string.
func requireAdmin(role string) error {
	switch role {
	case "owner", "admin":
		return nil
	default:
		return fmt.Errorf("forbidden: admin or owner role required")
	}
}
// GetTeam returns the team by ID.
//
// A missing row and a soft-deleted team (DeletedAt set) both produce the same
// "team not found" error, so callers cannot tell a deleted team from one that
// never existed. Any other lookup failure is wrapped with "get team".
func (s *TeamService) GetTeam(ctx context.Context, teamID pgtype.UUID) (db.Team, error) {
	team, err := s.DB.GetTeam(ctx, teamID)
	if err != nil {
		// errors.Is also matches a wrapped ErrNoRows, which == would miss.
		if errors.Is(err, pgx.ErrNoRows) {
			return db.Team{}, fmt.Errorf("team not found")
		}
		return db.Team{}, fmt.Errorf("get team: %w", err)
	}
	if team.DeletedAt.Valid {
		return db.Team{}, fmt.Errorf("team not found")
	}
	return team, nil
}
// ListTeamsForUser returns the teams produced by the GetTeamsForUser query for
// the given user, each paired with the user's role in that team.
func (s *TeamService) ListTeamsForUser(ctx context.Context, userID pgtype.UUID) ([]TeamWithRole, error) {
	memberships, err := s.DB.GetTeamsForUser(ctx, userID)
	if err != nil {
		return nil, fmt.Errorf("list teams: %w", err)
	}

	teams := make([]TeamWithRole, 0, len(memberships))
	for _, m := range memberships {
		// Copy the team columns of the joined row into a db.Team.
		team := db.Team{
			ID:        m.ID,
			Name:      m.Name,
			CreatedAt: m.CreatedAt,
			IsByoc:    m.IsByoc,
			Slug:      m.Slug,
			DeletedAt: m.DeletedAt,
		}
		teams = append(teams, TeamWithRole{Team: team, Role: m.Role})
	}
	return teams, nil
}
// CreateTeam creates a new team owned by the given user.
//
// The name must match teamNameRE. The team row and the owner's membership row
// are inserted in a single transaction, so a failure in either leaves no
// partial team behind. The new membership is not marked as the user's default.
func (s *TeamService) CreateTeam(ctx context.Context, ownerUserID pgtype.UUID, name string) (TeamWithRole, error) {
	if !teamNameRE.MatchString(name) {
		// The message lists every character class teamNameRE accepts.
		return TeamWithRole{}, fmt.Errorf("invalid team name: must be 1-128 characters, A-Z a-z 0-9 space _ - @ '")
	}

	tx, err := s.Pool.Begin(ctx)
	if err != nil {
		return TeamWithRole{}, fmt.Errorf("begin tx: %w", err)
	}
	// Rollback after a successful Commit is a no-op, so the error is ignored.
	defer tx.Rollback(ctx) //nolint:errcheck

	qtx := s.DB.WithTx(tx)

	teamID := id.NewTeamID()
	team, err := qtx.InsertTeam(ctx, db.InsertTeamParams{
		ID:   teamID,
		Name: name,
		Slug: id.NewTeamSlug(),
	})
	if err != nil {
		return TeamWithRole{}, fmt.Errorf("insert team: %w", err)
	}

	if err := qtx.InsertTeamMember(ctx, db.InsertTeamMemberParams{
		UserID:    ownerUserID,
		TeamID:    teamID,
		IsDefault: false,
		Role:      "owner",
	}); err != nil {
		return TeamWithRole{}, fmt.Errorf("insert owner: %w", err)
	}

	if err := tx.Commit(ctx); err != nil {
		return TeamWithRole{}, fmt.Errorf("commit: %w", err)
	}

	return TeamWithRole{Team: team, Role: "owner"}, nil
}
// RenameTeam updates the team name. The new name must match teamNameRE, and
// the caller must be admin or owner (role is read from the DB, not trusted
// from the request).
func (s *TeamService) RenameTeam(ctx context.Context, teamID, callerUserID pgtype.UUID, newName string) error {
	if !teamNameRE.MatchString(newName) {
		// The message lists every character class teamNameRE accepts.
		return fmt.Errorf("invalid team name: must be 1-128 characters, A-Z a-z 0-9 space _ - @ '")
	}

	role, err := s.callerRole(ctx, teamID, callerUserID)
	if err != nil {
		return err
	}
	if err := requireAdmin(role); err != nil {
		return err
	}

	if err := s.DB.UpdateTeamName(ctx, db.UpdateTeamNameParams{ID: teamID, Name: newName}); err != nil {
		return fmt.Errorf("update name: %w", err)
	}
	return nil
}
// DeleteTeam soft-deletes the team and stops its active sandboxes.
// Caller must be owner (verified from DB).
//
// Steps, in order:
//  1. Every sandbox returned by ListActiveSandboxesByTeam is destroyed on its
//     host agent. This is best-effort: failures are logged and do not abort.
//  2. Those sandboxes are marked "stopped" in the DB. A failure here aborts
//     the deletion so no "running" records are left behind for a deleted team.
//  3. The team's deleted_at is set.
//  4. Only once the soft delete has succeeded, team-owned templates are
//     removed from hosts and from the DB by a background goroutine.
//
// Sandbox rows are kept; template files and records are removed in step 4.
func (s *TeamService) DeleteTeam(ctx context.Context, teamID, callerUserID pgtype.UUID) error {
	role, err := s.callerRole(ctx, teamID, callerUserID)
	if err != nil {
		return err
	}
	if role != "owner" {
		return fmt.Errorf("forbidden: only the owner can delete a team")
	}

	// Collect active sandboxes and stop them.
	sandboxes, err := s.DB.ListActiveSandboxesByTeam(ctx, teamID)
	if err != nil {
		return fmt.Errorf("list active sandboxes: %w", err)
	}

	stopIDs := make([]pgtype.UUID, 0, len(sandboxes))
	for _, sb := range sandboxes {
		// The sandbox is marked stopped regardless of whether the agent call
		// succeeds; the host-side destroy below is best-effort.
		stopIDs = append(stopIDs, sb.ID)
		sandboxIDStr := id.FormatSandboxID(sb.ID)

		host, err := s.DB.GetHost(ctx, sb.HostID)
		if err != nil {
			slog.Warn("team delete: failed to look up host for sandbox", "sandbox_id", sandboxIDStr, "error", err)
			continue
		}
		agent, err := s.HostPool.GetForHost(host)
		if err != nil {
			slog.Warn("team delete: failed to get agent client for sandbox", "sandbox_id", sandboxIDStr, "error", err)
			continue
		}
		// A not-found response means the sandbox is already gone on the host.
		if _, err := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
			SandboxId: sandboxIDStr,
		})); err != nil && connect.CodeOf(err) != connect.CodeNotFound {
			slog.Warn("team delete: failed to destroy sandbox", "sandbox_id", sandboxIDStr, "error", err)
		}
	}

	if len(stopIDs) > 0 {
		if err := s.DB.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
			Column1: stopIDs,
			Status:  "stopped",
		}); err != nil {
			// Do not proceed to soft-delete if sandbox statuses couldn't be updated,
			// as that would leave orphaned "running" records for a deleted team.
			return fmt.Errorf("update sandbox statuses: %w", err)
		}
	}

	if err := s.DB.SoftDeleteTeam(ctx, teamID); err != nil {
		return fmt.Errorf("soft delete team: %w", err)
	}

	// Start template cleanup only after the team is actually deleted, so a
	// failed soft delete never destroys the templates of a still-live team.
	// A fresh background context is used because the cleanup outlives the request.
	go s.cleanupTeamTemplates(context.Background(), teamID)

	return nil
}
// cleanupTeamTemplates removes the team's templates from every host whose
// status is "online", then deletes the template rows from the DB.
//
// It is best-effort: listing failures abort with a log line, hosts without a
// usable agent client are skipped, and per-host delete failures (other than
// not-found) are logged without stopping the sweep. Nothing is returned to
// the caller.
func (s *TeamService) cleanupTeamTemplates(ctx context.Context, teamID pgtype.UUID) {
	templates, err := s.DB.ListTemplatesByTeamOnly(ctx, teamID)
	if err != nil {
		slog.Warn("team delete: failed to list templates for cleanup", "team_id", id.FormatTeamID(teamID), "error", err)
		return
	}
	if len(templates) == 0 {
		return
	}

	hosts, err := s.DB.ListActiveHosts(ctx)
	if err != nil {
		slog.Warn("team delete: failed to list hosts for template cleanup", "error", err)
		return
	}

	for _, tmpl := range templates {
		for _, host := range hosts {
			if host.Status != "online" {
				continue
			}
			agent, agentErr := s.HostPool.GetForHost(host)
			if agentErr != nil {
				continue
			}

			deleteReq := connect.NewRequest(&pb.DeleteSnapshotRequest{
				TeamId:     id.UUIDString(tmpl.TeamID),
				TemplateId: id.UUIDString(tmpl.ID),
			})
			_, delErr := agent.DeleteSnapshot(ctx, deleteReq)
			// Not-found means the host never had (or no longer has) the template.
			if delErr == nil || connect.CodeOf(delErr) == connect.CodeNotFound {
				continue
			}
			slog.Warn("team delete: failed to delete template on host",
				"host_id", id.FormatHostID(host.ID),
				"template", tmpl.Name,
				"error", delErr,
			)
		}
	}

	// Remove DB records.
	if err := s.DB.DeleteTemplatesByTeam(ctx, teamID); err != nil {
		slog.Warn("team delete: failed to delete template records", "team_id", id.FormatTeamID(teamID), "error", err)
	}
}
// GetMembers returns every member of the team with their formatted user ID,
// name, email, role, and join time. JoinedAt is left as the zero time when
// the stored value is not valid.
func (s *TeamService) GetMembers(ctx context.Context, teamID pgtype.UUID) ([]MemberInfo, error) {
	rows, err := s.DB.GetTeamMembers(ctx, teamID)
	if err != nil {
		return nil, fmt.Errorf("get members: %w", err)
	}

	members := make([]MemberInfo, 0, len(rows))
	for _, row := range rows {
		info := MemberInfo{
			UserID: id.FormatUserID(row.ID),
			Name:   row.Name,
			Email:  row.Email,
			Role:   row.Role,
		}
		if row.JoinedAt.Valid {
			info.JoinedAt = row.JoinedAt.Time
		}
		members = append(members, info)
	}
	return members, nil
}
// AddMember adds an existing user (looked up by email) to the team with the
// "member" role. Caller must be admin or owner (verified from DB).
//
// Errors: "user not found" when no account has that email, "invalid" when the
// user already belongs to the team, and wrapped DB errors otherwise. The
// returned MemberInfo does not populate JoinedAt.
func (s *TeamService) AddMember(ctx context.Context, teamID, callerUserID pgtype.UUID, email string) (MemberInfo, error) {
	role, err := s.callerRole(ctx, teamID, callerUserID)
	if err != nil {
		return MemberInfo{}, err
	}
	if err := requireAdmin(role); err != nil {
		return MemberInfo{}, err
	}

	target, err := s.DB.GetUserByEmail(ctx, email)
	if err != nil {
		// errors.Is also matches a wrapped ErrNoRows, which == would miss.
		if errors.Is(err, pgx.ErrNoRows) {
			return MemberInfo{}, fmt.Errorf("user not found: no account with that email")
		}
		return MemberInfo{}, fmt.Errorf("look up user: %w", err)
	}

	// Check if already a member: here "no rows" is the expected, good outcome.
	_, memberCheckErr := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
		UserID: target.ID,
		TeamID: teamID,
	})
	switch {
	case memberCheckErr == nil:
		return MemberInfo{}, fmt.Errorf("invalid: user is already a member of this team")
	case !errors.Is(memberCheckErr, pgx.ErrNoRows):
		return MemberInfo{}, fmt.Errorf("check membership: %w", memberCheckErr)
	}

	if err := s.DB.InsertTeamMember(ctx, db.InsertTeamMemberParams{
		UserID:    target.ID,
		TeamID:    teamID,
		IsDefault: false,
		Role:      "member",
	}); err != nil {
		return MemberInfo{}, fmt.Errorf("insert member: %w", err)
	}

	return MemberInfo{UserID: id.FormatUserID(target.ID), Name: target.Name, Email: target.Email, Role: "member"}, nil
}
// RemoveMember removes a user from the team.
// Caller must be admin or owner (verified from DB). A target whose role is
// "owner" cannot be removed. A target with no membership row yields a
// "not found" error.
func (s *TeamService) RemoveMember(ctx context.Context, teamID, callerUserID, targetUserID pgtype.UUID) error {
	callerRole, err := s.callerRole(ctx, teamID, callerUserID)
	if err != nil {
		return err
	}
	if err := requireAdmin(callerRole); err != nil {
		return err
	}

	targetMembership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
		UserID: targetUserID,
		TeamID: teamID,
	})
	if err != nil {
		// errors.Is also matches a wrapped ErrNoRows, which == would miss.
		if errors.Is(err, pgx.ErrNoRows) {
			return fmt.Errorf("not found: user is not a member of this team")
		}
		return fmt.Errorf("get target membership: %w", err)
	}
	if targetMembership.Role == "owner" {
		return fmt.Errorf("forbidden: the owner cannot be removed from the team")
	}

	if err := s.DB.DeleteTeamMember(ctx, db.DeleteTeamMemberParams{
		TeamID: teamID,
		UserID: targetUserID,
	}); err != nil {
		return fmt.Errorf("delete member: %w", err)
	}
	return nil
}
// UpdateMemberRole changes a member's role to admin or member.
// Caller must be admin or owner (verified from DB). A target whose current
// role is "owner" cannot be changed. Valid target roles: "admin", "member";
// anything else is rejected before any DB access.
func (s *TeamService) UpdateMemberRole(ctx context.Context, teamID, callerUserID, targetUserID pgtype.UUID, newRole string) error {
	if newRole != "admin" && newRole != "member" {
		return fmt.Errorf("invalid: role must be admin or member")
	}

	callerRole, err := s.callerRole(ctx, teamID, callerUserID)
	if err != nil {
		return err
	}
	if err := requireAdmin(callerRole); err != nil {
		return err
	}

	targetMembership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
		UserID: targetUserID,
		TeamID: teamID,
	})
	if err != nil {
		// errors.Is also matches a wrapped ErrNoRows, which == would miss.
		if errors.Is(err, pgx.ErrNoRows) {
			return fmt.Errorf("not found: user is not a member of this team")
		}
		return fmt.Errorf("get target membership: %w", err)
	}
	if targetMembership.Role == "owner" {
		return fmt.Errorf("forbidden: the owner's role cannot be changed")
	}

	if err := s.DB.UpdateMemberRole(ctx, db.UpdateMemberRoleParams{
		TeamID: teamID,
		UserID: targetUserID,
		Role:   newRole,
	}); err != nil {
		return fmt.Errorf("update role: %w", err)
	}
	return nil
}
// LeaveTeam deletes the calling user's own membership in the team.
// A caller whose role is "owner" is refused and told to delete the team
// instead; a caller who is not a member gets the error from callerRole.
func (s *TeamService) LeaveTeam(ctx context.Context, teamID, callerUserID pgtype.UUID) error {
	membershipRole, err := s.callerRole(ctx, teamID, callerUserID)
	if err != nil {
		return err
	}
	if membershipRole == "owner" {
		return fmt.Errorf("forbidden: the owner cannot leave the team; delete the team instead")
	}

	membership := db.DeleteTeamMemberParams{
		TeamID: teamID,
		UserID: callerUserID,
	}
	if err := s.DB.DeleteTeamMember(ctx, membership); err != nil {
		return fmt.Errorf("leave team: %w", err)
	}
	return nil
}
// SetBYOC turns on the BYOC flag for a team. The transition is one-way:
// passing enabled=false is always rejected, and enabling an already-enabled
// team is a no-op that returns nil.
//
// This method performs no authorization itself; the caller must have
// verified admin status before invoking it.
func (s *TeamService) SetBYOC(ctx context.Context, teamID pgtype.UUID, enabled bool) error {
	team, err := s.DB.GetTeam(ctx, teamID)
	if err != nil {
		return fmt.Errorf("team not found: %w", err)
	}

	// The checks run in this order: deleted team, disable request, no-op.
	switch {
	case team.DeletedAt.Valid:
		return fmt.Errorf("team not found")
	case !enabled:
		return fmt.Errorf("invalid request: BYOC cannot be disabled once enabled")
	case team.IsByoc:
		// Already enabled — idempotent, nothing to write.
		return nil
	}

	update := db.SetTeamBYOCParams{ID: teamID, IsByoc: true}
	if err := s.DB.SetTeamBYOC(ctx, update); err != nil {
		return fmt.Errorf("set byoc: %w", err)
	}
	return nil
}

View File

@ -3,6 +3,8 @@ package service
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/db"
)
@ -14,7 +16,7 @@ type TemplateService struct {
// List returns all templates belonging to the given team. If typeFilter is
// non-empty, only templates of that type ("base" or "snapshot") are returned.
func (s *TemplateService) List(ctx context.Context, teamID, typeFilter string) ([]db.Template, error) {
func (s *TemplateService) List(ctx context.Context, teamID pgtype.UUID, typeFilter string) ([]db.Template, error) {
if typeFilter != "" {
return s.DB.ListTemplatesByTeamAndType(ctx, db.ListTemplatesByTeamAndTypeParams{
TeamID: teamID,

View File

@ -4,6 +4,7 @@
package snapshot
import (
"context"
"fmt"
"io"
"os"
@ -172,6 +173,99 @@ func ProcessMemfileWithParent(memfilePath, diffPath, headerPath string, parentHe
return header, nil
}
// MergeDiffs consolidates multiple generation diff files into a single diff
// file and resets the generation counter to 0. This is a pure file-level
// operation — no Firecracker involvement.
//
// It reads each non-nil block from the appropriate diff file (as mapped by
// the header), writes them all sequentially into a single new diff file,
// and produces a fresh header pointing only at that file.
//
// diffFiles maps each generation's build ID (string form) to the path of that
// generation's diff file. The merged diff is written to mergedDiffPath and the
// serialized header to headerPath. The header is written only after the merged
// diff has been closed successfully, so a header is never produced for a diff
// file whose data failed to reach disk.
func MergeDiffs(header *Header, diffFiles map[string]string, mergedDiffPath, headerPath string) (*Header, error) {
	blockSize := int64(header.Metadata.BlockSize)
	mergedBuildID := uuid.New()

	// Open all source diff files. The deferred cleanup is registered first so
	// that a failed open still closes every file opened before it.
	sources := make(map[string]*os.File, len(diffFiles))
	defer func() {
		for _, f := range sources {
			f.Close()
		}
	}()
	for build, path := range diffFiles {
		f, err := os.Open(path)
		if err != nil {
			return nil, fmt.Errorf("open diff file for build %s: %w", build, err)
		}
		sources[build] = f
	}

	dst, err := os.Create(mergedDiffPath)
	if err != nil {
		return nil, fmt.Errorf("create merged diff file: %w", err)
	}
	// Safety net for early returns. The success path closes dst explicitly
	// below; the extra Close then just reports "already closed", which is ignored.
	defer dst.Close()

	totalBlocks := TotalBlocks(int64(header.Metadata.Size), blockSize)
	dirty := make([]bool, totalBlocks)
	empty := make([]bool, totalBlocks)
	buf := make([]byte, blockSize)

	for i := int64(0); i < totalBlocks; i++ {
		offset := i * blockSize
		mappedOffset, _, buildID, err := header.GetShiftedMapping(context.Background(), offset)
		if err != nil {
			return nil, fmt.Errorf("lookup block %d: %w", i, err)
		}

		// Blocks mapped to the nil build ID carry no data; record and skip them.
		if *buildID == uuid.Nil {
			empty[i] = true
			continue
		}

		src, ok := sources[buildID.String()]
		if !ok {
			return nil, fmt.Errorf("no diff file for build %s (block %d)", buildID, i)
		}

		if _, err := src.ReadAt(buf, mappedOffset); err != nil {
			return nil, fmt.Errorf("read block %d from build %s: %w", i, buildID, err)
		}

		dirty[i] = true
		if _, err := dst.Write(buf); err != nil {
			return nil, fmt.Errorf("write merged block %d: %w", i, err)
		}
	}

	// Close before publishing the header: a Close error can be the only
	// report of a failed write-back, and must not be silently dropped.
	if err := dst.Close(); err != nil {
		return nil, fmt.Errorf("close merged diff file: %w", err)
	}

	// Build fresh header with generation 0.
	dirtyMappings := CreateMapping(mergedBuildID, dirty, blockSize)
	emptyMappings := CreateMapping(uuid.Nil, empty, blockSize)
	merged := MergeMappings(dirtyMappings, emptyMappings)
	normalized := NormalizeMappings(merged)

	metadata := NewMetadata(mergedBuildID, uint64(blockSize), header.Metadata.Size)
	newHeader, err := NewHeader(metadata, normalized)
	if err != nil {
		return nil, fmt.Errorf("create merged header: %w", err)
	}

	headerData, err := Serialize(metadata, normalized)
	if err != nil {
		return nil, fmt.Errorf("serialize merged header: %w", err)
	}
	if err := os.WriteFile(headerPath, headerData, 0644); err != nil {
		return nil, fmt.Errorf("write merged header: %w", err)
	}

	return newHeader, nil
}
// isZeroBlock checks if a block is entirely zero bytes.
func isZeroBlock(block []byte) bool {
// Fast path: compare 8 bytes at a time.

View File

@ -11,7 +11,7 @@ func TestSafeName(t *testing.T) {
{"simple", "minimal", false},
{"with-dash", "template-abc123", false},
{"with-dot", "my-snapshot.v2", false},
{"sandbox-id", "sb-12345678", false},
{"sandbox-id", "cl-12345678", false},
{"single-char", "a", false},
{"numbers", "123", false},
{"max-length", "abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz01", false},

View File

@ -4,9 +4,13 @@ import "fmt"
// VMConfig holds the configuration for creating a Firecracker microVM.
type VMConfig struct {
// SandboxID is the unique identifier for this sandbox (e.g., "sb-a1b2c3d4").
// SandboxID is the unique identifier for this sandbox (e.g., "cl-a1b2c3d4").
SandboxID string
// TemplateID is the template UUID string used to populate MMDS metadata
// so that envd can read WRENN_TEMPLATE_ID from inside the guest.
TemplateID string
// KernelPath is the path to the uncompressed Linux kernel (vmlinux).
KernelPath string
@ -91,7 +95,7 @@ func (c *VMConfig) kernelArgs() string {
)
return fmt.Sprintf(
"console=ttyS0 reboot=k panic=1 pci=off quiet loglevel=1 init=%s %s",
"console=ttyS0 reboot=k panic=1 pci=off quiet loglevel=1 clocksource=kvm-clock init=%s %s",
c.InitPath, ipArg,
)
}

View File

@ -101,6 +101,31 @@ func (c *fcClient) setMachineConfig(ctx context.Context, vcpus, memMB int) error
})
}
// setMMDSConfig sends a PUT to /mmds/config selecting MMDS version "V2" and
// binding it to the given network interface.
// Must be called before startVM.
func (c *fcClient) setMMDSConfig(ctx context.Context, ifaceID string) error {
	payload := map[string]any{
		"version":            "V2",
		"network_interfaces": []string{ifaceID},
	}
	return c.do(ctx, http.MethodPut, "/mmds/config", payload)
}
// mmdsMetadata is the metadata payload written to the Firecracker MMDS store.
// envd reads this via PollForMMDSOpts to populate WRENN_SANDBOX_ID and WRENN_TEMPLATE_ID.
//
// Note that the JSON keys differ from the Go field names.
type mmdsMetadata struct {
// SandboxID is serialized under the JSON key "instanceID".
SandboxID string `json:"instanceID"`
// TemplateID is serialized under the JSON key "envID".
TemplateID string `json:"envID"`
}
// setMMDS sends a PUT to /mmds carrying the sandbox and template IDs as an
// mmdsMetadata payload.
// Can be called after the VM has started.
func (c *fcClient) setMMDS(ctx context.Context, sandboxID, templateID string) error {
	meta := mmdsMetadata{
		SandboxID:  sandboxID,
		TemplateID: templateID,
	}
	return c.do(ctx, http.MethodPut, "/mmds", meta)
}
// startVM issues the InstanceStart action.
func (c *fcClient) startVM(ctx context.Context) error {
return c.do(ctx, http.MethodPut, "/actions", map[string]string{

View File

@ -71,6 +71,13 @@ func (m *Manager) Create(ctx context.Context, cfg VMConfig) (*VM, error) {
return nil, fmt.Errorf("start VM: %w", err)
}
// Step 5: Push sandbox metadata into MMDS so envd can read
// WRENN_SANDBOX_ID and WRENN_TEMPLATE_ID from inside the guest.
if err := client.setMMDS(ctx, cfg.SandboxID, cfg.TemplateID); err != nil {
_ = proc.stop()
return nil, fmt.Errorf("set MMDS metadata: %w", err)
}
vm := &VM{
Config: cfg,
process: proc,
@ -108,6 +115,12 @@ func configureVM(ctx context.Context, client *fcClient, cfg *VMConfig) error {
return fmt.Errorf("set machine config: %w", err)
}
// MMDS config — enable V2 token access on eth0 so that envd can read
// WRENN_SANDBOX_ID and WRENN_TEMPLATE_ID from inside the guest.
if err := client.setMMDSConfig(ctx, "eth0"); err != nil {
return fmt.Errorf("set MMDS config: %w", err)
}
return nil
}
@ -238,6 +251,12 @@ func (m *Manager) CreateFromSnapshot(ctx context.Context, cfg VMConfig, snapPath
return nil, fmt.Errorf("resume VM: %w", err)
}
// Step 5: Push sandbox metadata into MMDS.
if err := client.setMMDS(ctx, cfg.SandboxID, cfg.TemplateID); err != nil {
_ = proc.stop()
return nil, fmt.Errorf("set MMDS metadata: %w", err)
}
vm := &VM{
Config: cfg,
process: proc,
@ -250,6 +269,12 @@ func (m *Manager) CreateFromSnapshot(ctx context.Context, cfg VMConfig, snapPath
return vm, nil
}
// PID returns the process ID of the unshare wrapper process.
// The actual Firecracker process is a direct child of this PID.
func (v *VM) PID() int {
	wrapper := v.process.cmd.Process
	return wrapper.Pid
}
// Get returns a running VM by sandbox ID.
func (m *Manager) Get(sandboxID string) (*VM, bool) {
vm, ok := m.vms[sandboxID]