1
0
forked from wrenn/wrenn
Co-authored-by: Tasnim Kabir Sadik <tksadik92@gmail.com>
Reviewed-on: wrenn/sandbox#8
This commit is contained in:
2026-04-09 19:24:49 +00:00
parent 32e5a5a715
commit d3e4812e46
199 changed files with 24552 additions and 2776 deletions

View File

@ -0,0 +1,22 @@
package api
import (
"context"
"fmt"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
)
// agentForHost looks up the host record and returns a Connect RPC client for it.
// Returns an error if the host is not found or has no address.
func agentForHost(ctx context.Context, queries *db.Queries, pool *lifecycle.HostClientPool, hostID pgtype.UUID) (hostagentv1connect.HostAgentServiceClient, error) {
host, err := queries.GetHost(ctx, hostID)
if err != nil {
return nil, fmt.Errorf("host not found: %w", err)
}
return pool.GetForHost(host)
}

View File

@ -0,0 +1,230 @@
package api
import (
"context"
"errors"
"fmt"
"log/slog"
"net/http"
"net/http/httputil"
"net/url"
"path"
"regexp"
"strconv"
"strings"
"sync"
"time"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
)
// Sentinel errors returned by proxyTarget, used to map to HTTP status codes
// without relying on error message text.
var (
errProxySandboxNotFound = errors.New("sandbox not found")
errProxyNoHostAddress = errors.New("host agent has no address")
)
const proxyCacheTTL = 120 * time.Second
// sandboxHostPattern matches hostnames like "49999-cl-abcd1234.localhost" or
// "49999-cl-abcd1234.example.com". Captures: port, sandbox ID.
var sandboxHostPattern = regexp.MustCompile(`^(\d+)-(cl-[0-9a-z]+)\.`)
// errProxySandboxNotRunning carries the sandbox status so callers can include
// it in the HTTP response without parsing error strings.
type errProxySandboxNotRunning struct{ status string }
func (e errProxySandboxNotRunning) Error() string {
return fmt.Sprintf("sandbox is not running (status: %s)", e.status)
}
// proxyCacheEntry caches the resolved agent URL for a (sandbox, team) pair.
// The *httputil.ReverseProxy is built per-request (cheap) so the Director closure
// can capture the correct port without the cache key needing to include it.
type proxyCacheEntry struct {
agentURL *url.URL
expiresAt time.Time
}
// proxyCacheKey is a fixed-size key from two UUIDs, avoids string allocation.
type proxyCacheKey [32]byte
func makeProxyCacheKey(sandboxID, teamID pgtype.UUID) proxyCacheKey {
var k proxyCacheKey
copy(k[:16], sandboxID.Bytes[:])
copy(k[16:], teamID.Bytes[:])
return k
}
// SandboxProxyWrapper wraps an existing HTTP handler and intercepts requests
// whose Host header matches the {port}-{sandbox_id}.{domain} pattern. Matching
// requests are reverse-proxied through the host agent that owns the sandbox.
// All other requests are passed through to the inner handler.
//
// Authentication is via X-API-Key header only (no JWT). The API key's team
// must own the sandbox.
type SandboxProxyWrapper struct {
inner http.Handler
db *db.Queries
pool *lifecycle.HostClientPool
transport http.RoundTripper
cacheMu sync.Mutex
cache map[proxyCacheKey]proxyCacheEntry
}
// NewSandboxProxyWrapper creates a new proxy wrapper.
func NewSandboxProxyWrapper(inner http.Handler, queries *db.Queries, pool *lifecycle.HostClientPool) *SandboxProxyWrapper {
return &SandboxProxyWrapper{
inner: inner,
db: queries,
pool: pool,
transport: pool.Transport(),
cache: make(map[proxyCacheKey]proxyCacheEntry),
}
}
// proxyTarget looks up the cached agent URL for (sandboxID, teamID).
// On a miss it queries the DB, resolves the address, and populates the cache.
// The *httputil.ReverseProxy is built by the caller so the Director closure
// captures the correct port without the cache key needing to include it.
func (h *SandboxProxyWrapper) proxyTarget(ctx context.Context, sandboxID, teamID pgtype.UUID) (*url.URL, error) {
cacheKey := makeProxyCacheKey(sandboxID, teamID)
h.cacheMu.Lock()
entry, ok := h.cache[cacheKey]
h.cacheMu.Unlock()
if ok && time.Now().Before(entry.expiresAt) {
return entry.agentURL, nil
}
// Cache miss or expired — query DB.
target, err := h.db.GetSandboxProxyTarget(ctx, db.GetSandboxProxyTargetParams{
ID: sandboxID,
TeamID: teamID,
})
if err != nil {
return nil, errProxySandboxNotFound
}
if target.Status != "running" {
return nil, errProxySandboxNotRunning{status: target.Status}
}
if target.HostAddress == "" {
return nil, errProxyNoHostAddress
}
agentURL, err := url.Parse(h.pool.ResolveAddr(target.HostAddress))
if err != nil {
return nil, fmt.Errorf("invalid host agent address: %w", err)
}
h.cacheMu.Lock()
h.cache[cacheKey] = proxyCacheEntry{
agentURL: agentURL,
expiresAt: time.Now().Add(proxyCacheTTL),
}
h.cacheMu.Unlock()
return agentURL, nil
}
// evictProxyCache removes the cached entry for a (sandbox, team) pair.
// Called on 502 so a stopped/moved sandbox is re-resolved on the next request.
func (h *SandboxProxyWrapper) evictProxyCache(sandboxID, teamID pgtype.UUID) {
h.cacheMu.Lock()
delete(h.cache, makeProxyCacheKey(sandboxID, teamID))
h.cacheMu.Unlock()
}
func (h *SandboxProxyWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request) {
host := r.Host
// Strip port from Host header (e.g. "49999-cl-abcd1234.localhost:8000" → "49999-cl-abcd1234.localhost")
if colonIdx := strings.LastIndex(host, ":"); colonIdx != -1 {
host = host[:colonIdx]
}
matches := sandboxHostPattern.FindStringSubmatch(host)
if matches == nil {
h.inner.ServeHTTP(w, r)
return
}
port := matches[1]
sandboxIDStr := matches[2]
// Validate port.
portNum, err := strconv.Atoi(port)
if err != nil || portNum < 1 || portNum > 65535 {
http.Error(w, "invalid port", http.StatusBadRequest)
return
}
// Authenticate: require API key or JWT, extract team ID.
teamID, err := h.authenticateRequest(r)
if err != nil {
writeError(w, http.StatusUnauthorized, "unauthorized", err.Error())
return
}
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
http.Error(w, "invalid sandbox ID", http.StatusBadRequest)
return
}
agentURL, err := h.proxyTarget(r.Context(), sandboxID, teamID)
if err != nil {
switch {
case errors.Is(err, errProxySandboxNotFound):
http.Error(w, err.Error(), http.StatusNotFound)
case errors.As(err, new(errProxySandboxNotRunning)):
http.Error(w, err.Error(), http.StatusConflict)
default:
http.Error(w, err.Error(), http.StatusServiceUnavailable)
}
return
}
proxy := &httputil.ReverseProxy{
Transport: h.transport,
Director: func(req *http.Request) {
req.URL.Scheme = agentURL.Scheme
req.URL.Host = agentURL.Host
req.URL.Path = path.Join("/proxy", sandboxIDStr, port, path.Clean("/"+req.URL.Path))
req.Host = agentURL.Host
},
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
slog.Debug("sandbox proxy error",
"sandbox_id", sandboxIDStr,
"port", port,
"error", err,
)
h.evictProxyCache(sandboxID, teamID)
http.Error(w, "proxy error: "+err.Error(), http.StatusBadGateway)
},
}
proxy.ServeHTTP(w, r)
}
// authenticateRequest validates the request's API key and returns the team ID.
// Only API key authentication is supported for sandbox proxy requests (not JWT).
func (h *SandboxProxyWrapper) authenticateRequest(r *http.Request) (pgtype.UUID, error) {
key := r.Header.Get("X-API-Key")
if key == "" {
return pgtype.UUID{}, fmt.Errorf("X-API-Key header required")
}
hash := auth.HashAPIKey(key)
row, err := h.db.GetAPIKeyByHash(r.Context(), hash)
if err != nil {
return pgtype.UUID{}, fmt.Errorf("invalid API key")
}
return row.TeamID, nil
}

View File

@ -6,17 +6,20 @@ import (
"github.com/go-chi/chi/v5"
"git.omukk.dev/wrenn/sandbox/internal/audit"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/service"
)
type apiKeyHandler struct {
svc *service.APIKeyService
svc *service.APIKeyService
audit *audit.AuditLogger
}
func newAPIKeyHandler(svc *service.APIKeyService) *apiKeyHandler {
return &apiKeyHandler{svc: svc}
func newAPIKeyHandler(svc *service.APIKeyService, al *audit.AuditLogger) *apiKeyHandler {
return &apiKeyHandler{svc: svc, audit: al}
}
type createAPIKeyRequest struct {
@ -37,11 +40,11 @@ type apiKeyResponse struct {
func apiKeyToResponse(k db.TeamApiKey) apiKeyResponse {
resp := apiKeyResponse{
ID: k.ID,
TeamID: k.TeamID,
ID: id.FormatAPIKeyID(k.ID),
TeamID: id.FormatTeamID(k.TeamID),
Name: k.Name,
KeyPrefix: k.KeyPrefix,
CreatedBy: k.CreatedBy,
CreatedBy: id.FormatUserID(k.CreatedBy),
}
if k.CreatedAt.Valid {
resp.CreatedAt = k.CreatedAt.Time.Format(time.RFC3339)
@ -55,11 +58,11 @@ func apiKeyToResponse(k db.TeamApiKey) apiKeyResponse {
func apiKeyWithCreatorToResponse(k db.ListAPIKeysByTeamWithCreatorRow) apiKeyResponse {
resp := apiKeyResponse{
ID: k.ID,
TeamID: k.TeamID,
ID: id.FormatAPIKeyID(k.ID),
TeamID: id.FormatTeamID(k.TeamID),
Name: k.Name,
KeyPrefix: k.KeyPrefix,
CreatedBy: k.CreatedBy,
CreatedBy: id.FormatUserID(k.CreatedBy),
CreatorEmail: k.CreatorEmail,
}
if k.CreatedAt.Valid {
@ -91,6 +94,7 @@ func (h *apiKeyHandler) Create(w http.ResponseWriter, r *http.Request) {
resp := apiKeyToResponse(result.Row)
resp.Key = &result.Plaintext
h.audit.LogAPIKeyCreate(r.Context(), ac, result.Row.ID, result.Row.Name)
writeJSON(w, http.StatusCreated, resp)
}
@ -115,12 +119,19 @@ func (h *apiKeyHandler) List(w http.ResponseWriter, r *http.Request) {
// Delete handles DELETE /v1/api-keys/{id}.
func (h *apiKeyHandler) Delete(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
keyID := chi.URLParam(r, "id")
keyIDStr := chi.URLParam(r, "id")
keyID, err := id.ParseAPIKeyID(keyIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid API key ID")
return
}
if err := h.svc.Delete(r.Context(), keyID, ac.TeamID); err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to delete API key")
return
}
h.audit.LogAPIKeyRevoke(r.Context(), ac, keyID)
w.WriteHeader(http.StatusNoContent)
}

View File

@ -0,0 +1,148 @@
package api
import (
"net/http"
"strconv"
"strings"
"time"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/service"
)
type auditHandler struct {
svc *service.AuditService
}
func newAuditHandler(svc *service.AuditService) *auditHandler {
return &auditHandler{svc: svc}
}
type auditLogResponse struct {
ID string `json:"id"`
ActorType string `json:"actor_type"`
ActorID string `json:"actor_id,omitempty"`
ActorName string `json:"actor_name,omitempty"`
ResourceType string `json:"resource_type"`
ResourceID string `json:"resource_id,omitempty"`
Action string `json:"action"`
Scope string `json:"scope"`
Status string `json:"status"`
Metadata map[string]any `json:"metadata,omitempty"`
CreatedAt string `json:"created_at"`
}
// List handles GET /v1/audit-logs.
// Query params:
// - before: RFC3339 timestamp cursor (exclusive); omit to start from latest
// - limit: page size, default 50, max 200
// - resource_type: filter by resource type (sandbox, snapshot, team, api_key, member, host)
// - action: filter by action verb
//
// Members see only team-scoped events; admins/owners see all.
func (h *auditHandler) List(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
// Parse ?before cursor.
var before time.Time
if s := r.URL.Query().Get("before"); s != "" {
var err error
before, err = time.Parse(time.RFC3339, s)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "before must be an RFC3339 timestamp")
return
}
}
// Parse ?limit.
limit := 50
if s := r.URL.Query().Get("limit"); s != "" {
n, err := strconv.Atoi(s)
if err != nil || n < 1 {
writeError(w, http.StatusBadRequest, "invalid_request", "limit must be a positive integer")
return
}
limit = n
}
// Parse ?before_id cursor (UUID).
var beforeID pgtype.UUID
if s := r.URL.Query().Get("before_id"); s != "" {
parsed, err := id.ParseAuditLogID(s)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "before_id must be a valid audit log ID")
return
}
beforeID = parsed
}
entries, err := h.svc.List(r.Context(), service.AuditListParams{
TeamID: ac.TeamID,
AdminScoped: ac.Role == "owner" || ac.Role == "admin",
ResourceTypes: parseMultiParam(r.URL.Query()["resource_type"]),
Actions: parseMultiParam(r.URL.Query()["action"]),
Before: before,
BeforeID: beforeID,
Limit: limit,
})
if err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to list audit logs")
return
}
items := make([]auditLogResponse, len(entries))
for i, e := range entries {
items[i] = auditLogResponse{
ID: e.ID,
ActorType: e.ActorType,
ActorID: e.ActorID,
ActorName: e.ActorName,
ResourceType: e.ResourceType,
ResourceID: e.ResourceID,
Action: e.Action,
Scope: e.Scope,
Status: e.Status,
Metadata: e.Metadata,
CreatedAt: e.CreatedAt.UTC().Format(time.RFC3339),
}
}
resp := map[string]any{"items": items}
if len(items) > 0 {
last := entries[len(entries)-1]
resp["next_before"] = last.CreatedAt.UTC().Format(time.RFC3339)
resp["next_before_id"] = last.ID
}
writeJSON(w, http.StatusOK, resp)
}
// parseMultiParam flattens repeated params and comma-separated values into a
// single deduplicated slice. Empty strings are dropped. Returns nil (no filter)
// when no values are present.
//
// Both ?resource_type=sandbox&resource_type=snapshot
// and ?resource_type=sandbox,snapshot are accepted.
func parseMultiParam(values []string) []string {
if len(values) == 0 {
return nil
}
seen := make(map[string]struct{})
var out []string
for _, v := range values {
for _, part := range strings.Split(v, ",") {
part = strings.TrimSpace(part)
if part == "" {
continue
}
if _, ok := seen[part]; !ok {
seen[part] = struct{}{}
out = append(out, part)
}
}
}
return out
}

View File

@ -1,7 +1,9 @@
package api
import (
"context"
"errors"
"log/slog"
"net/http"
"strings"
@ -15,6 +17,45 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/id"
)
// loginTeam returns the team and role to stamp into a login JWT.
// It prefers the user's default team; if none is flagged as default it falls
// back to the earliest-joined team. Returns pgx.ErrNoRows when the user has
// no team memberships at all.
func loginTeam(ctx context.Context, q *db.Queries, userID pgtype.UUID) (db.Team, string, error) {
team, err := q.GetDefaultTeamForUser(ctx, userID)
if err == nil {
membership, err := q.GetTeamMembership(ctx, db.GetTeamMembershipParams{UserID: userID, TeamID: team.ID})
if err != nil {
return db.Team{}, "", err
}
return team, membership.Role, nil
}
if !errors.Is(err, pgx.ErrNoRows) {
return db.Team{}, "", err
}
// No default set — fall back to earliest-joined team.
rows, err := q.GetTeamsForUser(ctx, userID)
if err != nil {
return db.Team{}, "", err
}
if len(rows) == 0 {
return db.Team{}, "", pgx.ErrNoRows
}
first := rows[0]
return db.Team{
ID: first.ID,
Name: first.Name,
Slug: first.Slug,
IsByoc: first.IsByoc,
CreatedAt: first.CreatedAt,
DeletedAt: first.DeletedAt,
}, first.Role, nil
}
type switchTeamRequest struct {
TeamID string `json:"team_id"`
}
type authHandler struct {
db *db.Queries
pool *pgxpool.Pool
@ -28,6 +69,7 @@ func newAuthHandler(db *db.Queries, pool *pgxpool.Pool, jwtSecret []byte) *authH
type signupRequest struct {
Email string `json:"email"`
Password string `json:"password"`
Name string `json:"name"`
}
type loginRequest struct {
@ -40,6 +82,7 @@ type authResponse struct {
UserID string `json:"user_id"`
TeamID string `json:"team_id"`
Email string `json:"email"`
Name string `json:"name"`
}
// Signup handles POST /v1/auth/signup.
@ -51,6 +94,7 @@ func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) {
}
req.Email = strings.TrimSpace(strings.ToLower(req.Email))
req.Name = strings.TrimSpace(req.Name)
if !strings.Contains(req.Email, "@") || len(req.Email) < 3 {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid email address")
return
@ -59,6 +103,10 @@ func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusBadRequest, "invalid_request", "password must be at least 8 characters")
return
}
if req.Name == "" || len(req.Name) > 100 {
writeError(w, http.StatusBadRequest, "invalid_request", "name must be between 1 and 100 characters")
return
}
ctx := r.Context()
@ -83,6 +131,7 @@ func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) {
ID: userID,
Email: req.Email,
PasswordHash: pgtype.Text{String: passwordHash, Valid: true},
Name: req.Name,
})
if err != nil {
var pgErr *pgconn.PgError
@ -98,7 +147,8 @@ func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) {
teamID := id.NewTeamID()
if _, err := qtx.InsertTeam(ctx, db.InsertTeamParams{
ID: teamID,
Name: req.Email + "'s Team",
Name: req.Name + "'s Team",
Slug: id.NewTeamSlug(),
}); err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to create team")
return
@ -119,7 +169,7 @@ func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) {
return
}
token, err := auth.SignJWT(h.jwtSecret, userID, teamID, req.Email)
token, err := auth.SignJWT(h.jwtSecret, userID, teamID, req.Email, req.Name, "owner", false)
if err != nil {
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
return
@ -127,9 +177,10 @@ func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusCreated, authResponse{
Token: token,
UserID: userID,
TeamID: teamID,
UserID: id.FormatUserID(userID),
TeamID: id.FormatTeamID(teamID),
Email: req.Email,
Name: req.Name,
})
}
@ -152,6 +203,7 @@ func (h *authHandler) Login(w http.ResponseWriter, r *http.Request) {
user, err := h.db.GetUserByEmail(ctx, req.Email)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
slog.Warn("login failed: unknown email", "email", req.Email, "ip", r.RemoteAddr)
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid email or password")
return
}
@ -160,21 +212,27 @@ func (h *authHandler) Login(w http.ResponseWriter, r *http.Request) {
}
if !user.PasswordHash.Valid {
slog.Warn("login failed: no password set", "email", req.Email, "ip", r.RemoteAddr)
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid email or password")
return
}
if err := auth.CheckPassword(user.PasswordHash.String, req.Password); err != nil {
slog.Warn("login failed: wrong password", "email", req.Email, "ip", r.RemoteAddr)
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid email or password")
return
}
team, err := h.db.GetDefaultTeamForUser(ctx, user.ID)
team, role, err := loginTeam(ctx, h.db, user.ID)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
writeError(w, http.StatusForbidden, "no_team", "user is not a member of any team")
return
}
writeError(w, http.StatusInternalServerError, "db_error", "failed to look up team")
return
}
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email)
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, user.IsAdmin)
if err != nil {
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
return
@ -182,8 +240,85 @@ func (h *authHandler) Login(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, authResponse{
Token: token,
UserID: user.ID,
TeamID: team.ID,
UserID: id.FormatUserID(user.ID),
TeamID: id.FormatTeamID(team.ID),
Email: user.Email,
Name: user.Name,
})
}
// SwitchTeam handles POST /v1/auth/switch-team.
// Verifies from DB that the user is a member of the target team, then re-issues
// a JWT scoped to that team. The JWT's team_id is used as a pre-filter on all
// subsequent team-scoped requests; DB is the source of truth for actual permissions.
func (h *authHandler) SwitchTeam(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
var req switchTeamRequest
if err := decodeJSON(r, &req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
return
}
if req.TeamID == "" {
writeError(w, http.StatusBadRequest, "invalid_request", "team_id is required")
return
}
teamID, err := id.ParseTeamID(req.TeamID)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid team_id")
return
}
ctx := r.Context()
// Verify team exists and is not deleted.
team, err := h.db.GetTeam(ctx, teamID)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
writeError(w, http.StatusNotFound, "not_found", "team not found")
return
}
writeError(w, http.StatusInternalServerError, "db_error", "failed to look up team")
return
}
if team.DeletedAt.Valid {
writeError(w, http.StatusNotFound, "not_found", "team not found")
return
}
// Verify membership from DB — JWT role is not trusted here.
membership, err := h.db.GetTeamMembership(ctx, db.GetTeamMembershipParams{
UserID: ac.UserID,
TeamID: teamID,
})
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
writeError(w, http.StatusForbidden, "forbidden", "not a member of this team")
return
}
writeError(w, http.StatusInternalServerError, "db_error", "failed to look up membership")
return
}
// Fetch current name from DB — JWT name is not trusted here (may be stale or empty for old tokens).
user, err := h.db.GetUserByID(ctx, ac.UserID)
if err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to look up user")
return
}
token, err := auth.SignJWT(h.jwtSecret, ac.UserID, teamID, ac.Email, user.Name, membership.Role, user.IsAdmin)
if err != nil {
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
return
}
writeJSON(w, http.StatusOK, authResponse{
Token: token,
UserID: id.FormatUserID(ac.UserID),
TeamID: id.FormatTeamID(teamID),
Email: ac.Email,
Name: user.Name,
})
}

View File

@ -0,0 +1,276 @@
package api
import (
"encoding/json"
"fmt"
"log/slog"
"net/http"
"time"
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/layout"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
"git.omukk.dev/wrenn/sandbox/internal/service"
"git.omukk.dev/wrenn/sandbox/internal/validate"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
)
type buildHandler struct {
svc *service.BuildService
db *db.Queries
pool *lifecycle.HostClientPool
}
func newBuildHandler(svc *service.BuildService, db *db.Queries, pool *lifecycle.HostClientPool) *buildHandler {
return &buildHandler{svc: svc, db: db, pool: pool}
}
type createBuildRequest struct {
Name string `json:"name"`
BaseTemplate string `json:"base_template"`
Recipe []string `json:"recipe"`
Healthcheck string `json:"healthcheck"`
VCPUs int32 `json:"vcpus"`
MemoryMB int32 `json:"memory_mb"`
SkipPrePost bool `json:"skip_pre_post"`
}
type buildResponse struct {
ID string `json:"id"`
Name string `json:"name"`
BaseTemplate string `json:"base_template"`
Recipe json.RawMessage `json:"recipe"`
Healthcheck *string `json:"healthcheck,omitempty"`
VCPUs int32 `json:"vcpus"`
MemoryMB int32 `json:"memory_mb"`
Status string `json:"status"`
CurrentStep int32 `json:"current_step"`
TotalSteps int32 `json:"total_steps"`
Logs json.RawMessage `json:"logs"`
Error *string `json:"error,omitempty"`
SandboxID *string `json:"sandbox_id,omitempty"`
HostID *string `json:"host_id,omitempty"`
CreatedAt string `json:"created_at"`
StartedAt *string `json:"started_at,omitempty"`
CompletedAt *string `json:"completed_at,omitempty"`
}
func buildToResponse(b db.TemplateBuild) buildResponse {
resp := buildResponse{
ID: id.FormatBuildID(b.ID),
Name: b.Name,
BaseTemplate: b.BaseTemplate,
Recipe: b.Recipe,
VCPUs: b.Vcpus,
MemoryMB: b.MemoryMb,
Status: b.Status,
CurrentStep: b.CurrentStep,
TotalSteps: b.TotalSteps,
Logs: b.Logs,
}
if b.Healthcheck != "" {
resp.Healthcheck = &b.Healthcheck
}
if b.Error != "" {
resp.Error = &b.Error
}
if b.SandboxID.Valid {
s := id.FormatSandboxID(b.SandboxID)
resp.SandboxID = &s
}
if b.HostID.Valid {
s := id.FormatHostID(b.HostID)
resp.HostID = &s
}
if b.CreatedAt.Valid {
resp.CreatedAt = b.CreatedAt.Time.Format(time.RFC3339)
}
if b.StartedAt.Valid {
s := b.StartedAt.Time.Format(time.RFC3339)
resp.StartedAt = &s
}
if b.CompletedAt.Valid {
s := b.CompletedAt.Time.Format(time.RFC3339)
resp.CompletedAt = &s
}
return resp
}
// Create handles POST /v1/admin/builds.
func (h *buildHandler) Create(w http.ResponseWriter, r *http.Request) {
var req createBuildRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
return
}
if req.Name == "" {
writeError(w, http.StatusBadRequest, "invalid_request", "name is required")
return
}
if err := validate.SafeName(req.Name); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", fmt.Sprintf("invalid template name: %s", err))
return
}
if len(req.Recipe) == 0 {
writeError(w, http.StatusBadRequest, "invalid_request", "recipe must contain at least one command")
return
}
build, err := h.svc.Create(r.Context(), service.BuildCreateParams{
Name: req.Name,
BaseTemplate: req.BaseTemplate,
Recipe: req.Recipe,
Healthcheck: req.Healthcheck,
VCPUs: req.VCPUs,
MemoryMB: req.MemoryMB,
SkipPrePost: req.SkipPrePost,
})
if err != nil {
slog.Error("failed to create build", "error", err)
writeError(w, http.StatusInternalServerError, "build_error", "failed to create build")
return
}
writeJSON(w, http.StatusCreated, buildToResponse(build))
}
// List handles GET /v1/admin/builds.
func (h *buildHandler) List(w http.ResponseWriter, r *http.Request) {
builds, err := h.svc.List(r.Context())
if err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to list builds")
return
}
resp := make([]buildResponse, len(builds))
for i, b := range builds {
resp[i] = buildToResponse(b)
}
writeJSON(w, http.StatusOK, resp)
}
// Get handles GET /v1/admin/builds/{id}.
func (h *buildHandler) Get(w http.ResponseWriter, r *http.Request) {
buildIDStr := chi.URLParam(r, "id")
buildID, err := id.ParseBuildID(buildIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid build ID")
return
}
build, err := h.svc.Get(r.Context(), buildID)
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "build not found")
return
}
writeJSON(w, http.StatusOK, buildToResponse(build))
}
// ListTemplates handles GET /v1/admin/templates — returns all templates across all teams.
func (h *buildHandler) ListTemplates(w http.ResponseWriter, r *http.Request) {
templates, err := h.db.ListTemplates(r.Context())
if err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to list templates")
return
}
type templateResponse struct {
Name string `json:"name"`
Type string `json:"type"`
VCPUs int32 `json:"vcpus"`
MemoryMB int32 `json:"memory_mb"`
SizeBytes int64 `json:"size_bytes"`
TeamID string `json:"team_id"`
CreatedAt string `json:"created_at"`
}
resp := make([]templateResponse, len(templates))
for i, t := range templates {
resp[i] = templateResponse{
Name: t.Name,
Type: t.Type,
VCPUs: t.Vcpus,
MemoryMB: t.MemoryMb,
SizeBytes: t.SizeBytes,
TeamID: id.FormatTeamID(t.TeamID),
}
if t.CreatedAt.Valid {
resp[i].CreatedAt = t.CreatedAt.Time.Format(time.RFC3339)
}
}
writeJSON(w, http.StatusOK, resp)
}
// DeleteTemplate handles DELETE /v1/admin/templates/{name}.
func (h *buildHandler) DeleteTemplate(w http.ResponseWriter, r *http.Request) {
name := chi.URLParam(r, "name")
if err := validate.SafeName(name); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", fmt.Sprintf("invalid template name: %s", err))
return
}
ctx := r.Context()
tmpl, err := h.db.GetPlatformTemplateByName(ctx, name)
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "template not found")
return
}
if layout.IsMinimal(tmpl.TeamID, tmpl.ID) {
writeError(w, http.StatusForbidden, "forbidden", "the minimal template cannot be deleted")
return
}
// Broadcast delete to all online hosts.
hosts, _ := h.db.ListActiveHosts(ctx)
for _, host := range hosts {
if host.Status != "online" {
continue
}
agent, err := h.pool.GetForHost(host)
if err != nil {
continue
}
if _, err := agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{
TeamId: formatUUIDForRPC(tmpl.TeamID),
TemplateId: formatUUIDForRPC(tmpl.ID),
})); err != nil {
if connect.CodeOf(err) != connect.CodeNotFound {
slog.Warn("admin: failed to delete template on host", "host_id", id.FormatHostID(host.ID), "name", name, "error", err)
}
}
}
if err := h.db.DeleteTemplate(ctx, tmpl.ID); err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to delete template record")
return
}
w.WriteHeader(http.StatusNoContent)
}
// Cancel handles POST /v1/admin/builds/{id}/cancel.
func (h *buildHandler) Cancel(w http.ResponseWriter, r *http.Request) {
buildIDStr := chi.URLParam(r, "id")
buildID, err := id.ParseBuildID(buildIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid build ID")
return
}
if err := h.svc.Cancel(r.Context(), buildID); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", err.Error())
return
}
w.WriteHeader(http.StatusNoContent)
}

View File

@ -0,0 +1,242 @@
package api
import (
"errors"
"net/http"
"time"
"github.com/go-chi/chi/v5"
"github.com/jackc/pgx/v5"
"git.omukk.dev/wrenn/sandbox/internal/audit"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/channels"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
)
type channelHandler struct {
svc *channels.Service
audit *audit.AuditLogger
}
func newChannelHandler(svc *channels.Service, al *audit.AuditLogger) *channelHandler {
return &channelHandler{svc: svc, audit: al}
}
type createChannelRequest struct {
Name string `json:"name"`
Provider string `json:"provider"`
Config map[string]string `json:"config"`
Events []string `json:"events"`
}
type updateChannelRequest struct {
Name string `json:"name"`
Events []string `json:"events"`
}
type rotateConfigRequest struct {
Config map[string]string `json:"config"`
}
type testChannelRequest struct {
Provider string `json:"provider"`
Config map[string]string `json:"config"`
}
type channelResponse struct {
ID string `json:"id"`
TeamID string `json:"team_id"`
Name string `json:"name"`
Provider string `json:"provider"`
Events []string `json:"events"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
Secret *string `json:"secret,omitempty"`
}
func channelToResponse(ch db.Channel) channelResponse {
resp := channelResponse{
ID: id.FormatChannelID(ch.ID),
TeamID: id.FormatTeamID(ch.TeamID),
Name: ch.Name,
Provider: ch.Provider,
Events: ch.EventTypes,
}
if ch.CreatedAt.Valid {
resp.CreatedAt = ch.CreatedAt.Time.Format(time.RFC3339)
}
if ch.UpdatedAt.Valid {
resp.UpdatedAt = ch.UpdatedAt.Time.Format(time.RFC3339)
}
return resp
}
// Create handles POST /v1/channels.
func (h *channelHandler) Create(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
var req createChannelRequest
if err := decodeJSON(r, &req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
return
}
result, err := h.svc.Create(r.Context(), channels.CreateParams{
TeamID: ac.TeamID,
Name: req.Name,
Provider: req.Provider,
Config: req.Config,
Events: req.Events,
})
if err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
h.audit.LogChannelCreate(r.Context(), ac, result.Channel.ID, result.Channel.Name, result.Channel.Provider)
resp := channelToResponse(result.Channel)
if result.PlaintextSecret != "" {
resp.Secret = &result.PlaintextSecret
}
writeJSON(w, http.StatusCreated, resp)
}
// List handles GET /v1/channels.
func (h *channelHandler) List(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
chs, err := h.svc.List(r.Context(), ac.TeamID)
if err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to list channels")
return
}
resp := make([]channelResponse, len(chs))
for i, ch := range chs {
resp[i] = channelToResponse(ch)
}
writeJSON(w, http.StatusOK, resp)
}
// Get handles GET /v1/channels/{id}.
func (h *channelHandler) Get(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
channelIDStr := chi.URLParam(r, "id")
channelID, err := id.ParseChannelID(channelIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid channel ID")
return
}
ch, err := h.svc.Get(r.Context(), channelID, ac.TeamID)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
writeError(w, http.StatusNotFound, "not_found", "channel not found")
} else {
writeError(w, http.StatusInternalServerError, "db_error", "failed to get channel")
}
return
}
writeJSON(w, http.StatusOK, channelToResponse(ch))
}
// Update handles PATCH /v1/channels/{id}.
func (h *channelHandler) Update(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
channelIDStr := chi.URLParam(r, "id")
channelID, err := id.ParseChannelID(channelIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid channel ID")
return
}
var req updateChannelRequest
if err := decodeJSON(r, &req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
return
}
ch, err := h.svc.Update(r.Context(), channelID, ac.TeamID, req.Name, req.Events)
if err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
h.audit.LogChannelUpdate(r.Context(), ac, channelID)
writeJSON(w, http.StatusOK, channelToResponse(ch))
}
// Test handles POST /v1/channels/test.
func (h *channelHandler) Test(w http.ResponseWriter, r *http.Request) {
var req testChannelRequest
if err := decodeJSON(r, &req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
return
}
if err := h.svc.Test(r.Context(), req.Provider, req.Config); err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
writeJSON(w, http.StatusOK, map[string]string{"status": "ok"})
}
// RotateConfig handles PUT /v1/channels/{id}/config.
func (h *channelHandler) RotateConfig(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
channelIDStr := chi.URLParam(r, "id")
channelID, err := id.ParseChannelID(channelIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid channel ID")
return
}
var req rotateConfigRequest
if err := decodeJSON(r, &req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
return
}
ch, err := h.svc.RotateConfig(r.Context(), channelID, ac.TeamID, req.Config)
if err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
h.audit.LogChannelRotateConfig(r.Context(), ac, channelID)
writeJSON(w, http.StatusOK, channelToResponse(ch))
}
// Delete handles DELETE /v1/channels/{id}.
func (h *channelHandler) Delete(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
channelIDStr := chi.URLParam(r, "id")
channelID, err := id.ParseChannelID(channelIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid channel ID")
return
}
if err := h.svc.Delete(r.Context(), channelID, ac.TeamID); err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to delete channel")
return
}
h.audit.LogChannelDelete(r.Context(), ac, channelID)
w.WriteHeader(http.StatusNoContent)
}

View File

@ -14,17 +14,18 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
)
type execHandler struct {
db *db.Queries
agent hostagentv1connect.HostAgentServiceClient
db *db.Queries
pool *lifecycle.HostClientPool
}
func newExecHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *execHandler {
return &execHandler{db: db, agent: agent}
func newExecHandler(db *db.Queries, pool *lifecycle.HostClientPool) *execHandler {
return &execHandler{db: db, pool: pool}
}
type execRequest struct {
@ -46,10 +47,16 @@ type execResponse struct {
// Exec handles POST /v1/sandboxes/{id}/exec.
func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
@ -73,8 +80,14 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
start := time.Now()
resp, err := h.agent.Exec(ctx, connect.NewRequest(&pb.ExecRequest{
SandboxId: sandboxID,
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
return
}
resp, err := agent.Exec(ctx, connect.NewRequest(&pb.ExecRequest{
SandboxId: sandboxIDStr,
Cmd: req.Cmd,
Args: req.Args,
TimeoutSec: req.TimeoutSec,
@ -95,7 +108,7 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
Valid: true,
},
}); err != nil {
slog.Warn("failed to update last_active_at", "id", sandboxID, "error", err)
slog.Warn("failed to update last_active_at", "id", sandboxIDStr, "error", err)
}
// Use base64 encoding if output contains non-UTF-8 bytes.
@ -106,7 +119,7 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
if !utf8.Valid(stdout) || !utf8.Valid(stderr) {
encoding = "base64"
writeJSON(w, http.StatusOK, execResponse{
SandboxID: sandboxID,
SandboxID: sandboxIDStr,
Cmd: req.Cmd,
Stdout: base64.StdEncoding.EncodeToString(stdout),
Stderr: base64.StdEncoding.EncodeToString(stderr),
@ -118,7 +131,7 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
}
writeJSON(w, http.StatusOK, execResponse{
SandboxID: sandboxID,
SandboxID: sandboxIDStr,
Cmd: req.Cmd,
Stdout: string(stdout),
Stderr: string(stderr),

View File

@ -14,17 +14,18 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
)
type execStreamHandler struct {
db *db.Queries
agent hostagentv1connect.HostAgentServiceClient
db *db.Queries
pool *lifecycle.HostClientPool
}
func newExecStreamHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *execStreamHandler {
return &execStreamHandler{db: db, agent: agent}
func newExecStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool) *execStreamHandler {
return &execStreamHandler{db: db, pool: pool}
}
var upgrader = websocket.Upgrader{
@ -48,10 +49,16 @@ type wsOutMsg struct {
// ExecStream handles WS /v1/sandboxes/{id}/exec/stream.
func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
@ -80,12 +87,18 @@ func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
return
}
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
sendWSError(conn, "sandbox host is not reachable")
return
}
// Open streaming exec to host agent.
streamCtx, cancel := context.WithCancel(ctx)
defer cancel()
stream, err := h.agent.ExecStream(streamCtx, connect.NewRequest(&pb.ExecStreamRequest{
SandboxId: sandboxID,
stream, err := agent.ExecStream(streamCtx, connect.NewRequest(&pb.ExecStreamRequest{
SandboxId: sandboxIDStr,
Cmd: startMsg.Cmd,
Args: startMsg.Args,
}))
@ -151,7 +164,7 @@ func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
Valid: true,
},
}); err != nil {
slog.Warn("failed to update last active after stream exec", "sandbox_id", sandboxID, "error", err)
slog.Warn("failed to update last active after stream exec", "sandbox_id", sandboxIDStr, "error", err)
}
}

View File

@ -11,17 +11,18 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
)
type filesHandler struct {
db *db.Queries
agent hostagentv1connect.HostAgentServiceClient
db *db.Queries
pool *lifecycle.HostClientPool
}
func newFilesHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *filesHandler {
return &filesHandler{db: db, agent: agent}
func newFilesHandler(db *db.Queries, pool *lifecycle.HostClientPool) *filesHandler {
return &filesHandler{db: db, pool: pool}
}
// Upload handles POST /v1/sandboxes/{id}/files/write.
@ -29,10 +30,16 @@ func newFilesHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceCl
// - "path" text field: absolute destination path inside the sandbox
// - "file" file field: binary content to write
func (h *filesHandler) Upload(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
@ -75,8 +82,14 @@ func (h *filesHandler) Upload(w http.ResponseWriter, r *http.Request) {
return
}
if _, err := h.agent.WriteFile(ctx, connect.NewRequest(&pb.WriteFileRequest{
SandboxId: sandboxID,
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
return
}
if _, err := agent.WriteFile(ctx, connect.NewRequest(&pb.WriteFileRequest{
SandboxId: sandboxIDStr,
Path: filePath,
Content: content,
})); err != nil {
@ -95,10 +108,16 @@ type readFileRequest struct {
// Download handles POST /v1/sandboxes/{id}/files/read.
// Accepts JSON body with path, returns raw file content with Content-Disposition.
func (h *filesHandler) Download(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
@ -120,8 +139,14 @@ func (h *filesHandler) Download(w http.ResponseWriter, r *http.Request) {
return
}
resp, err := h.agent.ReadFile(ctx, connect.NewRequest(&pb.ReadFileRequest{
SandboxId: sandboxID,
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
return
}
resp, err := agent.ReadFile(ctx, connect.NewRequest(&pb.ReadFileRequest{
SandboxId: sandboxIDStr,
Path: req.Path,
}))
if err != nil {

View File

@ -12,27 +12,34 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
)
type filesStreamHandler struct {
db *db.Queries
agent hostagentv1connect.HostAgentServiceClient
db *db.Queries
pool *lifecycle.HostClientPool
}
func newFilesStreamHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *filesStreamHandler {
return &filesStreamHandler{db: db, agent: agent}
func newFilesStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool) *filesStreamHandler {
return &filesStreamHandler{db: db, pool: pool}
}
// StreamUpload handles POST /v1/sandboxes/{id}/files/stream/write.
// Expects multipart/form-data with "path" text field and "file" file field.
// Streams file content directly from the request body to the host agent without buffering.
func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
@ -88,14 +95,20 @@ func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request
}
defer filePart.Close()
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
return
}
// Open client-streaming RPC to host agent.
stream := h.agent.WriteFileStream(ctx)
stream := agent.WriteFileStream(ctx)
// Send metadata first.
if err := stream.Send(&pb.WriteFileStreamRequest{
Content: &pb.WriteFileStreamRequest_Meta{
Meta: &pb.WriteFileStreamMeta{
SandboxId: sandboxID,
SandboxId: sandboxIDStr,
Path: filePath,
},
},
@ -140,10 +153,16 @@ func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request
// StreamDownload handles POST /v1/sandboxes/{id}/files/stream/read.
// Accepts JSON body with path, streams file content back without buffering.
func (h *filesStreamHandler) StreamDownload(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
@ -164,9 +183,15 @@ func (h *filesStreamHandler) StreamDownload(w http.ResponseWriter, r *http.Reque
return
}
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
return
}
// Open server-streaming RPC to host agent.
stream, err := h.agent.ReadFileStream(ctx, connect.NewRequest(&pb.ReadFileStreamRequest{
SandboxId: sandboxID,
stream, err := agent.ReadFileStream(ctx, connect.NewRequest(&pb.ReadFileStreamRequest{
SandboxId: sandboxIDStr,
Path: req.Path,
}))
if err != nil {

View File

@ -1,23 +1,30 @@
package api
import (
"errors"
"log/slog"
"net/http"
"time"
"github.com/go-chi/chi/v5"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/audit"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/service"
)
type hostHandler struct {
svc *service.HostService
queries *db.Queries
audit *audit.AuditLogger
}
func newHostHandler(svc *service.HostService, queries *db.Queries) *hostHandler {
return &hostHandler{svc: svc, queries: queries}
func newHostHandler(svc *service.HostService, queries *db.Queries, al *audit.AuditLogger) *hostHandler {
return &hostHandler{svc: svc, queries: queries, audit: al}
}
// Request/response types.
@ -34,6 +41,24 @@ type createHostResponse struct {
RegistrationToken string `json:"registration_token"`
}
type refreshTokenRequest struct {
RefreshToken string `json:"refresh_token"`
}
type refreshTokenResponse struct {
Host hostResponse `json:"host"`
Token string `json:"token"`
RefreshToken string `json:"refresh_token"`
CertPEM string `json:"cert_pem,omitempty"`
KeyPEM string `json:"key_pem,omitempty"`
CACertPEM string `json:"ca_cert_pem,omitempty"`
}
type deletePreviewResponse struct {
Host hostResponse `json:"host"`
SandboxIDs []string `json:"sandbox_ids"`
}
type registerHostRequest struct {
Token string `json:"token"`
Arch string `json:"arch,omitempty"`
@ -44,8 +69,12 @@ type registerHostRequest struct {
}
type registerHostResponse struct {
Host hostResponse `json:"host"`
Token string `json:"token"`
Host hostResponse `json:"host"`
Token string `json:"token"`
RefreshToken string `json:"refresh_token"`
CertPEM string `json:"cert_pem,omitempty"`
KeyPEM string `json:"key_pem,omitempty"`
CACertPEM string `json:"ca_cert_pem,omitempty"`
}
type addTagRequest struct {
@ -56,6 +85,7 @@ type hostResponse struct {
ID string `json:"id"`
Type string `json:"type"`
TeamID *string `json:"team_id,omitempty"`
TeamName *string `json:"team_name,omitempty"`
Provider *string `json:"provider,omitempty"`
AvailabilityZone *string `json:"availability_zone,omitempty"`
Arch *string `json:"arch,omitempty"`
@ -72,34 +102,35 @@ type hostResponse struct {
func hostToResponse(h db.Host) hostResponse {
resp := hostResponse{
ID: h.ID,
ID: id.FormatHostID(h.ID),
Type: h.Type,
Status: h.Status,
CreatedBy: h.CreatedBy,
CreatedBy: id.FormatUserID(h.CreatedBy),
}
if h.TeamID.Valid {
resp.TeamID = &h.TeamID.String
s := id.FormatTeamID(h.TeamID)
resp.TeamID = &s
}
if h.Provider.Valid {
resp.Provider = &h.Provider.String
if h.Provider != "" {
resp.Provider = &h.Provider
}
if h.AvailabilityZone.Valid {
resp.AvailabilityZone = &h.AvailabilityZone.String
if h.AvailabilityZone != "" {
resp.AvailabilityZone = &h.AvailabilityZone
}
if h.Arch.Valid {
resp.Arch = &h.Arch.String
if h.Arch != "" {
resp.Arch = &h.Arch
}
if h.CpuCores.Valid {
resp.CPUCores = &h.CpuCores.Int32
if h.CpuCores != 0 {
resp.CPUCores = &h.CpuCores
}
if h.MemoryMb.Valid {
resp.MemoryMB = &h.MemoryMb.Int32
if h.MemoryMb != 0 {
resp.MemoryMB = &h.MemoryMb
}
if h.DiskGb.Valid {
resp.DiskGB = &h.DiskGb.Int32
if h.DiskGb != 0 {
resp.DiskGB = &h.DiskGb
}
if h.Address.Valid {
resp.Address = &h.Address.String
if h.Address != "" {
resp.Address = &h.Address
}
if h.LastHeartbeatAt.Valid {
s := h.LastHeartbeatAt.Time.Format(time.RFC3339)
@ -112,7 +143,7 @@ func hostToResponse(h db.Host) hostResponse {
}
// isAdmin fetches the user record and returns whether they are an admin.
func (h *hostHandler) isAdmin(r *http.Request, userID string) bool {
func (h *hostHandler) isAdmin(r *http.Request, userID pgtype.UUID) bool {
user, err := h.queries.GetUserByID(r.Context(), userID)
if err != nil {
return false
@ -130,20 +161,32 @@ func (h *hostHandler) Create(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
result, err := h.svc.Create(r.Context(), service.HostCreateParams{
Type: req.Type,
TeamID: req.TeamID,
Provider: req.Provider,
AvailabilityZone: req.AvailabilityZone,
RequestingUserID: ac.UserID,
IsRequestorAdmin: h.isAdmin(r, ac.UserID),
})
// Parse optional team ID from request body.
var params service.HostCreateParams
params.Type = req.Type
params.Provider = req.Provider
params.AvailabilityZone = req.AvailabilityZone
params.RequestingUserID = ac.UserID
params.IsRequestorAdmin = h.isAdmin(r, ac.UserID)
if req.TeamID != "" {
teamID, err := id.ParseTeamID(req.TeamID)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid team_id")
return
}
params.TeamID = teamID
}
result, err := h.svc.Create(r.Context(), params)
if err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
// Log audit for the owning team (BYOC hosts have a team; shared hosts use caller's team).
h.audit.LogHostCreate(r.Context(), ac, result.Host.ID, result.Host.TeamID)
writeJSON(w, http.StatusCreated, createHostResponse{
Host: hostToResponse(result.Host),
RegistrationToken: result.RegistrationToken,
@ -153,16 +196,50 @@ func (h *hostHandler) Create(w http.ResponseWriter, r *http.Request) {
// List handles GET /v1/hosts.
func (h *hostHandler) List(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
admin := h.isAdmin(r, ac.UserID)
hosts, err := h.svc.List(r.Context(), ac.TeamID, h.isAdmin(r, ac.UserID))
hosts, err := h.svc.List(r.Context(), ac.TeamID, admin)
if err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to list hosts")
return
}
// Collect unique team IDs so we can fetch team names in one pass.
var teamNames map[string]string
if admin {
seen := make(map[string]struct{})
for _, host := range hosts {
if host.TeamID.Valid {
key := id.FormatTeamID(host.TeamID)
seen[key] = struct{}{}
}
}
if len(seen) > 0 {
teamNames = make(map[string]string, len(seen))
for _, host := range hosts {
if !host.TeamID.Valid {
continue
}
key := id.FormatTeamID(host.TeamID)
if _, ok := teamNames[key]; ok {
continue
}
if team, err := h.queries.GetTeam(r.Context(), host.TeamID); err == nil {
teamNames[key] = team.Name
}
}
}
}
resp := make([]hostResponse, len(hosts))
for i, host := range hosts {
resp[i] = hostToResponse(host)
if host.TeamID.Valid {
key := id.FormatTeamID(host.TeamID)
if name, ok := teamNames[key]; ok {
resp[i].TeamName = &name
}
}
}
writeJSON(w, http.StatusOK, resp)
@ -170,9 +247,15 @@ func (h *hostHandler) List(w http.ResponseWriter, r *http.Request) {
// Get handles GET /v1/hosts/{id}.
func (h *hostHandler) Get(w http.ResponseWriter, r *http.Request) {
hostID := chi.URLParam(r, "id")
hostIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
hostID, err := id.ParseHostID(hostIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
return
}
host, err := h.svc.Get(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID))
if err != nil {
status, code, msg := serviceErrToHTTP(err)
@ -183,25 +266,86 @@ func (h *hostHandler) Get(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, hostToResponse(host))
}
// Delete handles DELETE /v1/hosts/{id}.
func (h *hostHandler) Delete(w http.ResponseWriter, r *http.Request) {
hostID := chi.URLParam(r, "id")
// DeletePreview handles GET /v1/hosts/{id}/delete-preview.
// Returns what would be affected without making changes, for confirmation UI.
func (h *hostHandler) DeletePreview(w http.ResponseWriter, r *http.Request) {
hostIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
if err := h.svc.Delete(r.Context(), hostID, ac.UserID, ac.TeamID, h.isAdmin(r, ac.UserID)); err != nil {
hostID, err := id.ParseHostID(hostIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
return
}
preview, err := h.svc.DeletePreview(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID))
if err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
w.WriteHeader(http.StatusNoContent)
writeJSON(w, http.StatusOK, deletePreviewResponse{
Host: hostToResponse(preview.Host),
SandboxIDs: preview.SandboxIDs,
})
}
// Delete handles DELETE /v1/hosts/{id}.
// Without ?force=true: returns 409 with affected sandbox IDs if any are active.
// With ?force=true: gracefully stops all sandboxes then deletes the host.
func (h *hostHandler) Delete(w http.ResponseWriter, r *http.Request) {
hostIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
force := r.URL.Query().Get("force") == "true"
hostID, err := id.ParseHostID(hostIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
return
}
// Fetch host before deletion to capture team_id for audit.
deletedHost, hostErr := h.queries.GetHost(r.Context(), hostID)
if hostErr != nil {
slog.Warn("audit: could not fetch host before delete", "host_id", hostIDStr, "error", hostErr)
}
err = h.svc.Delete(r.Context(), hostID, ac.UserID, ac.TeamID, h.isAdmin(r, ac.UserID), force)
if err == nil {
h.audit.LogHostDelete(r.Context(), ac, hostID, deletedHost.TeamID)
w.WriteHeader(http.StatusNoContent)
return
}
// Check if it's a "has running sandboxes" error and return a structured 409.
var hasSandboxes *service.HostHasSandboxesError
if errors.As(err, &hasSandboxes) {
writeJSON(w, http.StatusConflict, map[string]any{
"error": map[string]any{
"code": "has_active_sandboxes",
"message": "host has active sandboxes; use ?force=true to destroy them and delete the host",
"sandbox_ids": hasSandboxes.SandboxIDs,
},
})
return
}
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
}
// RegenerateToken handles POST /v1/hosts/{id}/token.
func (h *hostHandler) RegenerateToken(w http.ResponseWriter, r *http.Request) {
hostID := chi.URLParam(r, "id")
hostIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
hostID, err := id.ParseHostID(hostIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
return
}
result, err := h.svc.RegenerateToken(r.Context(), hostID, ac.UserID, ac.TeamID, h.isAdmin(r, ac.UserID))
if err != nil {
status, code, msg := serviceErrToHTTP(err)
@ -247,36 +391,61 @@ func (h *hostHandler) Register(w http.ResponseWriter, r *http.Request) {
}
writeJSON(w, http.StatusCreated, registerHostResponse{
Host: hostToResponse(result.Host),
Token: result.JWT,
Host: hostToResponse(result.Host),
Token: result.JWT,
RefreshToken: result.RefreshToken,
CertPEM: result.CertPEM,
KeyPEM: result.KeyPEM,
CACertPEM: result.CACertPEM,
})
}
// Heartbeat handles POST /v1/hosts/{id}/heartbeat (host-token-authenticated).
func (h *hostHandler) Heartbeat(w http.ResponseWriter, r *http.Request) {
hostID := chi.URLParam(r, "id")
hostIDStr := chi.URLParam(r, "id")
hc := auth.MustHostFromContext(r.Context())
hostID, err := id.ParseHostID(hostIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
return
}
// Prevent a host from heartbeating for a different host.
if hostID != hc.HostID {
writeError(w, http.StatusForbidden, "forbidden", "host ID mismatch")
return
}
// Capture pre-heartbeat status to detect unreachable → online transition.
prevHost, _ := h.queries.GetHost(r.Context(), hc.HostID)
if err := h.svc.Heartbeat(r.Context(), hc.HostID); err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to update heartbeat")
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
// Log marked_up if the host just recovered from unreachable.
if prevHost.Status == "unreachable" {
h.audit.LogHostMarkedUp(r.Context(), prevHost.TeamID, hc.HostID)
}
w.WriteHeader(http.StatusNoContent)
}
// AddTag handles POST /v1/hosts/{id}/tags.
func (h *hostHandler) AddTag(w http.ResponseWriter, r *http.Request) {
hostID := chi.URLParam(r, "id")
hostIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
admin := h.isAdmin(r, ac.UserID)
hostID, err := id.ParseHostID(hostIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
return
}
var req addTagRequest
if err := decodeJSON(r, &req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
@ -298,10 +467,16 @@ func (h *hostHandler) AddTag(w http.ResponseWriter, r *http.Request) {
// RemoveTag handles DELETE /v1/hosts/{id}/tags/{tag}.
func (h *hostHandler) RemoveTag(w http.ResponseWriter, r *http.Request) {
hostID := chi.URLParam(r, "id")
hostIDStr := chi.URLParam(r, "id")
tag := chi.URLParam(r, "tag")
ac := auth.MustFromContext(r.Context())
hostID, err := id.ParseHostID(hostIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
return
}
if err := h.svc.RemoveTag(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID), tag); err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
@ -311,11 +486,47 @@ func (h *hostHandler) RemoveTag(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}
// RefreshToken handles POST /v1/hosts/auth/refresh (unauthenticated).
// The host agent sends its refresh token to receive a new JWT and rotated refresh token.
func (h *hostHandler) RefreshToken(w http.ResponseWriter, r *http.Request) {
var req refreshTokenRequest
if err := decodeJSON(r, &req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
return
}
if req.RefreshToken == "" {
writeError(w, http.StatusBadRequest, "invalid_request", "refresh_token is required")
return
}
result, err := h.svc.Refresh(r.Context(), req.RefreshToken)
if err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
writeJSON(w, http.StatusOK, refreshTokenResponse{
Host: hostToResponse(result.Host),
Token: result.JWT,
RefreshToken: result.RefreshToken,
CertPEM: result.CertPEM,
KeyPEM: result.KeyPEM,
CACertPEM: result.CACertPEM,
})
}
// ListTags handles GET /v1/hosts/{id}/tags.
func (h *hostHandler) ListTags(w http.ResponseWriter, r *http.Request) {
hostID := chi.URLParam(r, "id")
hostIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
hostID, err := id.ParseHostID(hostIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
return
}
tags, err := h.svc.ListTags(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID))
if err != nil {
status, code, msg := serviceErrToHTTP(err)

View File

@ -0,0 +1,156 @@
package api
import (
"context"
"net/http"
"time"
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
)
type sandboxMetricsHandler struct {
db *db.Queries
pool *lifecycle.HostClientPool
}
func newSandboxMetricsHandler(db *db.Queries, pool *lifecycle.HostClientPool) *sandboxMetricsHandler {
return &sandboxMetricsHandler{db: db, pool: pool}
}
type metricPointResponse struct {
TimestampUnix int64 `json:"timestamp_unix"`
CPUPct float64 `json:"cpu_pct"`
MemBytes int64 `json:"mem_bytes"`
DiskBytes int64 `json:"disk_bytes"`
}
type metricsResponse struct {
SandboxID string `json:"sandbox_id"`
Range string `json:"range"`
Points []metricPointResponse `json:"points"`
}
// GetMetrics handles GET /v1/sandboxes/{id}/metrics?range=10m|2h|24h.
func (h *sandboxMetricsHandler) GetMetrics(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
rangeTier := r.URL.Query().Get("range")
if rangeTier == "" {
rangeTier = "10m"
}
validRanges := map[string]bool{"5m": true, "10m": true, "1h": true, "2h": true, "6h": true, "12h": true, "24h": true}
if !validRanges[rangeTier] {
writeError(w, http.StatusBadRequest, "invalid_request", "range must be one of: 5m, 10m, 1h, 2h, 6h, 12h, 24h")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
switch sb.Status {
case "running":
h.getFromAgent(w, r, sandboxIDStr, rangeTier, sb.HostID)
case "paused":
h.getFromDB(ctx, w, sandboxIDStr, sandboxID, rangeTier)
default:
writeError(w, http.StatusNotFound, "not_found", "metrics not available for sandbox in state: "+sb.Status)
}
}
func (h *sandboxMetricsHandler) getFromAgent(w http.ResponseWriter, r *http.Request, sandboxIDStr, rangeTier string, hostID pgtype.UUID) {
ctx := r.Context()
agent, err := agentForHost(ctx, h.db, h.pool, hostID)
if err != nil {
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
return
}
resp, err := agent.GetSandboxMetrics(ctx, connect.NewRequest(&pb.GetSandboxMetricsRequest{
SandboxId: sandboxIDStr,
Range: rangeTier,
}))
if err != nil {
status, code, msg := agentErrToHTTP(err)
writeError(w, status, code, msg)
return
}
points := make([]metricPointResponse, len(resp.Msg.Points))
for i, p := range resp.Msg.Points {
points[i] = metricPointResponse{
TimestampUnix: p.TimestampUnix,
CPUPct: p.CpuPct,
MemBytes: p.MemBytes,
DiskBytes: p.DiskBytes,
}
}
writeJSON(w, http.StatusOK, metricsResponse{
SandboxID: sandboxIDStr,
Range: rangeTier,
Points: points,
})
}
// rangeToDB maps a user-facing range filter to the DB tier and cutoff duration.
var rangeToDB = map[string]struct {
tier string
cutoff time.Duration
}{
"5m": {"10m", 5 * time.Minute},
"10m": {"10m", 10 * time.Minute},
"1h": {"2h", 1 * time.Hour},
"2h": {"2h", 2 * time.Hour},
"6h": {"24h", 6 * time.Hour},
"12h": {"24h", 12 * time.Hour},
"24h": {"24h", 24 * time.Hour},
}
func (h *sandboxMetricsHandler) getFromDB(ctx context.Context, w http.ResponseWriter, sandboxIDStr string, sandboxID pgtype.UUID, rangeTier string) {
mapping := rangeToDB[rangeTier]
rows, err := h.db.GetSandboxMetricPoints(ctx, db.GetSandboxMetricPointsParams{
SandboxID: sandboxID,
Tier: mapping.tier,
Ts: time.Now().Add(-mapping.cutoff).Unix(),
})
if err != nil {
writeError(w, http.StatusInternalServerError, "internal_error", "failed to read metrics")
return
}
points := make([]metricPointResponse, len(rows))
for i, row := range rows {
points[i] = metricPointResponse{
TimestampUnix: row.Ts,
CPUPct: row.CpuPct,
MemBytes: row.MemBytes,
DiskBytes: row.DiskBytes,
}
}
writeJSON(w, http.StatusOK, metricsResponse{
SandboxID: sandboxIDStr,
Range: rangeTier,
Points: points,
})
}

View File

@ -150,19 +150,19 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
redirectWithError(w, r, redirectBase, "db_error")
return
}
team, err := h.db.GetDefaultTeamForUser(ctx, user.ID)
team, role, err := loginTeam(ctx, h.db, user.ID)
if err != nil {
slog.Error("oauth login: failed to get team", "error", err)
redirectWithError(w, r, redirectBase, "db_error")
return
}
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email)
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, user.IsAdmin)
if err != nil {
slog.Error("oauth login: failed to sign jwt", "error", err)
redirectWithError(w, r, redirectBase, "internal_error")
return
}
redirectWithToken(w, r, redirectBase, token, user.ID, team.ID, user.Email)
redirectWithToken(w, r, redirectBase, token, id.FormatUserID(user.ID), id.FormatTeamID(team.ID), user.Email, user.Name)
return
}
if !errors.Is(err, pgx.ErrNoRows) {
@ -199,6 +199,7 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
_, err = qtx.InsertUserOAuth(ctx, db.InsertUserOAuthParams{
ID: userID,
Email: email,
Name: profile.Name,
})
if err != nil {
var pgErr *pgconn.PgError
@ -219,6 +220,7 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
if _, err := qtx.InsertTeam(ctx, db.InsertTeamParams{
ID: teamID,
Name: teamName,
Slug: id.NewTeamSlug(),
}); err != nil {
slog.Error("oauth: failed to create team", "error", err)
redirectWithError(w, r, redirectBase, "db_error")
@ -253,14 +255,14 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
return
}
token, err := auth.SignJWT(h.jwtSecret, userID, teamID, email)
token, err := auth.SignJWT(h.jwtSecret, userID, teamID, email, profile.Name, "owner", false)
if err != nil {
slog.Error("oauth: failed to sign jwt", "error", err)
redirectWithError(w, r, redirectBase, "internal_error")
return
}
redirectWithToken(w, r, redirectBase, token, userID, teamID, email)
redirectWithToken(w, r, redirectBase, token, id.FormatUserID(userID), id.FormatTeamID(teamID), email, profile.Name)
}
// retryAsLogin handles the race where a concurrent request already created the user.
@ -282,29 +284,39 @@ func (h *oauthHandler) retryAsLogin(w http.ResponseWriter, r *http.Request, prov
redirectWithError(w, r, redirectBase, "db_error")
return
}
team, err := h.db.GetDefaultTeamForUser(ctx, user.ID)
team, role, err := loginTeam(ctx, h.db, user.ID)
if err != nil {
slog.Error("oauth: retry login: failed to get team", "error", err)
redirectWithError(w, r, redirectBase, "db_error")
return
}
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email)
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, user.IsAdmin)
if err != nil {
slog.Error("oauth: retry login: failed to sign jwt", "error", err)
redirectWithError(w, r, redirectBase, "internal_error")
return
}
redirectWithToken(w, r, redirectBase, token, user.ID, team.ID, user.Email)
redirectWithToken(w, r, redirectBase, token, id.FormatUserID(user.ID), id.FormatTeamID(team.ID), user.Email, user.Name)
}
func redirectWithToken(w http.ResponseWriter, r *http.Request, base, token, userID, teamID, email string) {
u := base + "?" + url.Values{
"token": {token},
"user_id": {userID},
"team_id": {teamID},
"email": {email},
}.Encode()
http.Redirect(w, r, u, http.StatusFound)
func redirectWithToken(w http.ResponseWriter, r *http.Request, base, token, userID, teamID, email, name string) {
// Set auth data as short-lived cookies instead of URL query parameters.
// This prevents token leakage via server access logs, Referer headers, and browser history.
for _, c := range []http.Cookie{
{Name: "wrenn_oauth_token", Value: token},
{Name: "wrenn_oauth_user_id", Value: userID},
{Name: "wrenn_oauth_team_id", Value: teamID},
{Name: "wrenn_oauth_email", Value: email},
{Name: "wrenn_oauth_name", Value: name},
} {
c.Path = "/auth/"
c.MaxAge = 60
c.HttpOnly = false // frontend JS must read these
c.SameSite = http.SameSiteLaxMode
c.Secure = isSecure(r)
http.SetCookie(w, &c)
}
http.Redirect(w, r, base, http.StatusFound)
}
func redirectWithError(w http.ResponseWriter, r *http.Request, base, code string) {

View File

@ -7,17 +7,20 @@ import (
"github.com/go-chi/chi/v5"
"git.omukk.dev/wrenn/sandbox/internal/audit"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/service"
)
type sandboxHandler struct {
svc *service.SandboxService
svc *service.SandboxService
audit *audit.AuditLogger
}
func newSandboxHandler(svc *service.SandboxService) *sandboxHandler {
return &sandboxHandler{svc: svc}
func newSandboxHandler(svc *service.SandboxService, al *audit.AuditLogger) *sandboxHandler {
return &sandboxHandler{svc: svc, audit: al}
}
type createSandboxRequest struct {
@ -44,7 +47,7 @@ type sandboxResponse struct {
func sandboxToResponse(sb db.Sandbox) sandboxResponse {
resp := sandboxResponse{
ID: sb.ID,
ID: id.FormatSandboxID(sb.ID),
Status: sb.Status,
Template: sb.Template,
VCPUs: sb.Vcpus,
@ -79,6 +82,10 @@ func (h *sandboxHandler) Create(w http.ResponseWriter, r *http.Request) {
}
ac := auth.MustFromContext(r.Context())
if !ac.TeamID.Valid {
writeError(w, http.StatusForbidden, "no_team", "no active team context; re-authenticate")
return
}
sb, err := h.svc.Create(r.Context(), service.SandboxCreateParams{
TeamID: ac.TeamID,
@ -93,6 +100,7 @@ func (h *sandboxHandler) Create(w http.ResponseWriter, r *http.Request) {
return
}
h.audit.LogSandboxCreate(r.Context(), ac, sb.ID, sb.Template)
writeJSON(w, http.StatusCreated, sandboxToResponse(sb))
}
@ -115,9 +123,15 @@ func (h *sandboxHandler) List(w http.ResponseWriter, r *http.Request) {
// Get handles GET /v1/sandboxes/{id}.
func (h *sandboxHandler) Get(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.svc.Get(r.Context(), sandboxID, ac.TeamID)
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
@ -129,9 +143,15 @@ func (h *sandboxHandler) Get(w http.ResponseWriter, r *http.Request) {
// Pause handles POST /v1/sandboxes/{id}/pause.
func (h *sandboxHandler) Pause(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.svc.Pause(r.Context(), sandboxID, ac.TeamID)
if err != nil {
status, code, msg := serviceErrToHTTP(err)
@ -139,14 +159,21 @@ func (h *sandboxHandler) Pause(w http.ResponseWriter, r *http.Request) {
return
}
h.audit.LogSandboxPause(r.Context(), ac, sandboxID)
writeJSON(w, http.StatusOK, sandboxToResponse(sb))
}
// Resume handles POST /v1/sandboxes/{id}/resume.
func (h *sandboxHandler) Resume(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.svc.Resume(r.Context(), sandboxID, ac.TeamID)
if err != nil {
status, code, msg := serviceErrToHTTP(err)
@ -154,14 +181,21 @@ func (h *sandboxHandler) Resume(w http.ResponseWriter, r *http.Request) {
return
}
h.audit.LogSandboxResume(r.Context(), ac, sandboxID)
writeJSON(w, http.StatusOK, sandboxToResponse(sb))
}
// Ping handles POST /v1/sandboxes/{id}/ping.
func (h *sandboxHandler) Ping(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
if err := h.svc.Ping(r.Context(), sandboxID, ac.TeamID); err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
@ -173,14 +207,21 @@ func (h *sandboxHandler) Ping(w http.ResponseWriter, r *http.Request) {
// Destroy handles DELETE /v1/sandboxes/{id}.
func (h *sandboxHandler) Destroy(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
if err := h.svc.Destroy(r.Context(), sandboxID, ac.TeamID); err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
h.audit.LogSandboxDestroy(r.Context(), ac, sandboxID)
w.WriteHeader(http.StatusNoContent)
}

View File

@ -1,6 +1,7 @@
package api
import (
"context"
"encoding/json"
"fmt"
"log/slog"
@ -9,25 +10,57 @@ import (
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/audit"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/layout"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
"git.omukk.dev/wrenn/sandbox/internal/service"
"git.omukk.dev/wrenn/sandbox/internal/validate"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
)
type snapshotHandler struct {
svc *service.TemplateService
db *db.Queries
agent hostagentv1connect.HostAgentServiceClient
pool *lifecycle.HostClientPool
audit *audit.AuditLogger
}
func newSnapshotHandler(svc *service.TemplateService, db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *snapshotHandler {
return &snapshotHandler{svc: svc, db: db, agent: agent}
func newSnapshotHandler(svc *service.TemplateService, db *db.Queries, pool *lifecycle.HostClientPool, al *audit.AuditLogger) *snapshotHandler {
return &snapshotHandler{svc: svc, db: db, pool: pool, audit: al}
}
// deleteSnapshotBroadcast attempts to delete snapshot files on all online hosts.
// Snapshots aren't currently host-tracked in the DB, so we broadcast to all hosts
// and ignore NotFound errors.
func (h *snapshotHandler) deleteSnapshotBroadcast(ctx context.Context, teamID, templateID pgtype.UUID) error {
hosts, err := h.db.ListActiveHosts(ctx)
if err != nil {
return fmt.Errorf("list hosts: %w", err)
}
for _, host := range hosts {
if host.Status != "online" {
continue
}
agent, err := h.pool.GetForHost(host)
if err != nil {
continue
}
if _, err := agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{
TeamId: formatUUIDForRPC(teamID),
TemplateId: formatUUIDForRPC(templateID),
})); err != nil {
if connect.CodeOf(err) != connect.CodeNotFound {
slog.Warn("snapshot: failed to delete on host", "host_id", id.FormatHostID(host.ID), "error", err)
}
}
}
return nil
}
type createSnapshotRequest struct {
@ -42,6 +75,7 @@ type snapshotResponse struct {
MemoryMB *int32 `json:"memory_mb,omitempty"`
SizeBytes int64 `json:"size_bytes"`
CreatedAt string `json:"created_at"`
Platform bool `json:"platform"`
}
func templateToResponse(t db.Template) snapshotResponse {
@ -49,12 +83,13 @@ func templateToResponse(t db.Template) snapshotResponse {
Name: t.Name,
Type: t.Type,
SizeBytes: t.SizeBytes,
Platform: t.TeamID == id.PlatformTeamID,
}
if t.Vcpus.Valid {
resp.VCPUs = &t.Vcpus.Int32
if t.Vcpus != 0 {
resp.VCPUs = &t.Vcpus
}
if t.MemoryMb.Valid {
resp.MemoryMB = &t.MemoryMb.Int32
if t.MemoryMb != 0 {
resp.MemoryMB = &t.MemoryMb
}
if t.CreatedAt.Valid {
resp.CreatedAt = t.CreatedAt.Time.Format(time.RFC3339)
@ -75,6 +110,12 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
return
}
sandboxID, err := id.ParseSandboxID(req.SandboxID)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox_id")
return
}
if req.Name == "" {
req.Name = id.NewSnapshotName()
}
@ -87,16 +128,21 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(ctx)
overwrite := r.URL.Query().Get("overwrite") == "true"
// Check for global name collision.
if _, err := h.db.GetPlatformTemplateByName(ctx, req.Name); err == nil {
writeError(w, http.StatusConflict, "name_reserved", "template name is reserved by a global template")
return
}
// Check if name already exists for this team.
if _, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err == nil {
if existing, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err == nil {
if !overwrite {
writeError(w, http.StatusConflict, "already_exists", "snapshot name already exists; use ?overwrite=true to replace")
return
}
// Delete old files from the agent before removing the DB record.
if _, err := h.agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{Name: req.Name})); err != nil {
status, code, msg := agentErrToHTTP(err)
writeError(w, status, code, "failed to delete existing snapshot files: "+msg)
// Delete old snapshot files from all hosts before removing the DB record.
if err := h.deleteSnapshotBroadcast(ctx, existing.TeamID, existing.ID); err != nil {
writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete existing snapshot files")
return
}
if err := h.db.DeleteTemplateByTeam(ctx, db.DeleteTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err != nil {
@ -106,7 +152,7 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
}
// Verify sandbox exists, belongs to team, and is running or paused.
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: req.SandboxID, TeamID: ac.TeamID})
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
@ -116,30 +162,59 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
return
}
resp, err := h.agent.CreateSnapshot(ctx, connect.NewRequest(&pb.CreateSnapshotRequest{
SandboxId: req.SandboxID,
Name: req.Name,
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
return
}
// Pre-mark sandbox as "paused" in DB BEFORE issuing the snapshot RPC.
// The host agent's CreateSnapshot removes the sandbox from its in-memory
// map immediately; if the reconciler fires during the flatten window and
// the DB still says "running", it will mark the sandbox "stopped".
if sb.Status == "running" {
if _, err := h.db.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
ID: sandboxID, Status: "paused",
}); err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to update sandbox status")
return
}
}
// Use a detached context with a generous timeout so the snapshot completes
// even if the client disconnects (the flatten step can take 10-20s).
snapCtx, snapCancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer snapCancel()
// Generate the new template ID upfront so the host agent knows where to store files.
newTemplateID := id.NewTemplateID()
resp, err := agent.CreateSnapshot(snapCtx, connect.NewRequest(&pb.CreateSnapshotRequest{
SandboxId: req.SandboxID,
Name: req.Name,
TeamId: formatUUIDForRPC(ac.TeamID),
TemplateId: formatUUIDForRPC(newTemplateID),
}))
if err != nil {
// Snapshot failed — revert status back to what it was.
if sb.Status == "running" {
if _, dbErr := h.db.UpdateSandboxStatus(snapCtx, db.UpdateSandboxStatusParams{
ID: sandboxID, Status: "running",
}); dbErr != nil {
slog.Error("failed to revert sandbox status after snapshot error", "sandbox_id", req.SandboxID, "error", dbErr)
}
}
status, code, msg := agentErrToHTTP(err)
writeError(w, status, code, msg)
return
}
// Mark sandbox as paused (if it was running, it got paused by the snapshot).
if sb.Status != "paused" {
if _, err := h.db.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
ID: req.SandboxID, Status: "paused",
}); err != nil {
slog.Error("failed to update sandbox status after snapshot", "sandbox_id", req.SandboxID, "error", err)
}
}
tmpl, err := h.db.InsertTemplate(ctx, db.InsertTemplateParams{
tmpl, err := h.db.InsertTemplate(snapCtx, db.InsertTemplateParams{
ID: newTemplateID,
Name: req.Name,
Type: "snapshot",
Vcpus: pgtype.Int4{Int32: sb.Vcpus, Valid: true},
MemoryMb: pgtype.Int4{Int32: sb.MemoryMb, Valid: true},
Vcpus: sb.Vcpus,
MemoryMb: sb.MemoryMb,
SizeBytes: resp.Msg.SizeBytes,
TeamID: ac.TeamID,
})
@ -149,6 +224,12 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
return
}
h.audit.LogSnapshotCreate(snapCtx, ac, req.Name)
if ctx.Err() != nil {
slog.Info("snapshot created but client disconnected before response", "name", req.Name)
return
}
writeJSON(w, http.StatusCreated, templateToResponse(tmpl))
}
@ -181,16 +262,23 @@ func (h *snapshotHandler) Delete(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ac := auth.MustFromContext(ctx)
if _, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: name, TeamID: ac.TeamID}); err != nil {
tmpl, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: name, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "template not found")
return
}
// Platform templates can only be deleted by admins via /v1/admin/templates.
if tmpl.TeamID == id.PlatformTeamID {
writeError(w, http.StatusForbidden, "forbidden", "platform templates cannot be deleted here")
return
}
if layout.IsMinimal(tmpl.TeamID, tmpl.ID) {
writeError(w, http.StatusForbidden, "forbidden", "the minimal template cannot be deleted")
return
}
if _, err := h.agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{
Name: name,
})); err != nil {
status, code, msg := agentErrToHTTP(err)
writeError(w, status, code, "failed to delete snapshot files: "+msg)
if err := h.deleteSnapshotBroadcast(ctx, tmpl.TeamID, tmpl.ID); err != nil {
writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete snapshot files")
return
}
@ -199,5 +287,6 @@ func (h *snapshotHandler) Delete(w http.ResponseWriter, r *http.Request) {
return
}
h.audit.LogSnapshotDelete(r.Context(), ac, name)
w.WriteHeader(http.StatusNoContent)
}

View File

@ -0,0 +1,95 @@
package api
import (
"log/slog"
"net/http"
"time"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/service"
)
type statsHandler struct {
svc *service.StatsService
}
func newStatsHandler(svc *service.StatsService) *statsHandler {
return &statsHandler{svc: svc}
}
type statsCurrentResponse struct {
RunningCount int32 `json:"running_count"`
VCPUsReserved int32 `json:"vcpus_reserved"`
MemoryMBReserved int32 `json:"memory_mb_reserved"`
}
type statsPeaksResponse struct {
RunningCount int32 `json:"running_count"`
VCPUs int32 `json:"vcpus"`
MemoryMB int32 `json:"memory_mb"`
}
type statsSeriesResponse struct {
Labels []string `json:"labels"`
Running []int32 `json:"running"`
VCPUs []int32 `json:"vcpus"`
MemoryMB []int32 `json:"memory_mb"`
}
type statsResponse struct {
Range string `json:"range"`
Current statsCurrentResponse `json:"current"`
Peaks statsPeaksResponse `json:"peaks"`
Series statsSeriesResponse `json:"series"`
}
// GetStats handles GET /v1/sandboxes/stats?range=5m|1h|6h|24h|30d
func (h *statsHandler) GetStats(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
rangeParam := r.URL.Query().Get("range")
if rangeParam == "" {
rangeParam = string(service.Range1h)
}
tr := service.TimeRange(rangeParam)
if !service.ValidRange(tr) {
writeError(w, http.StatusBadRequest, "invalid_request", "range must be one of: 5m, 1h, 6h, 24h, 30d")
return
}
current, peaks, series, err := h.svc.GetStats(r.Context(), ac.TeamID, tr)
if err != nil {
slog.Error("stats handler: get stats failed", "team_id", ac.TeamID, "error", err)
writeError(w, http.StatusInternalServerError, "internal_error", "failed to retrieve stats")
return
}
resp := statsResponse{
Range: rangeParam,
Current: statsCurrentResponse{
RunningCount: current.RunningCount,
VCPUsReserved: current.VCPUsReserved,
MemoryMBReserved: current.MemoryMBReserved,
},
Peaks: statsPeaksResponse{
RunningCount: peaks.RunningCount,
VCPUs: peaks.VCPUs,
MemoryMB: peaks.MemoryMB,
},
Series: statsSeriesResponse{
Labels: make([]string, len(series)),
Running: make([]int32, len(series)),
VCPUs: make([]int32, len(series)),
MemoryMB: make([]int32, len(series)),
},
}
for i, pt := range series {
resp.Series.Labels[i] = pt.Bucket.UTC().Format(time.RFC3339)
resp.Series.Running[i] = pt.RunningCount
resp.Series.VCPUs[i] = pt.VCPUsReserved
resp.Series.MemoryMB[i] = pt.MemoryMBReserved
}
writeJSON(w, http.StatusOK, resp)
}

View File

@ -0,0 +1,390 @@
package api
import (
"log/slog"
"net/http"
"strings"
"time"
"github.com/go-chi/chi/v5"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/audit"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/service"
)
type teamHandler struct {
svc *service.TeamService
audit *audit.AuditLogger
}
func newTeamHandler(svc *service.TeamService, al *audit.AuditLogger) *teamHandler {
return &teamHandler{svc: svc, audit: al}
}
// teamResponse is the JSON shape for a team.
type teamResponse struct {
ID string `json:"id"`
Name string `json:"name"`
Slug string `json:"slug"`
IsByoc bool `json:"is_byoc"`
CreatedAt string `json:"created_at"`
}
// teamWithRoleResponse includes the calling user's role.
type teamWithRoleResponse struct {
teamResponse
Role string `json:"role"`
}
type memberResponse struct {
UserID string `json:"user_id"`
Name string `json:"name"`
Email string `json:"email"`
Role string `json:"role"`
JoinedAt string `json:"joined_at,omitempty"`
}
func teamToResponse(t db.Team) teamResponse {
resp := teamResponse{
ID: id.FormatTeamID(t.ID),
Name: t.Name,
Slug: t.Slug,
IsByoc: t.IsByoc,
}
if t.CreatedAt.Valid {
resp.CreatedAt = t.CreatedAt.Time.Format(time.RFC3339)
}
return resp
}
func memberInfoToResponse(m service.MemberInfo) memberResponse {
return memberResponse{
UserID: m.UserID,
Name: m.Name,
Email: m.Email,
Role: m.Role,
JoinedAt: m.JoinedAt.Format(time.RFC3339),
}
}
// requireTeamAccess is an inline check used by every team-scoped handler:
// the JWT team_id must match the URL {id} before any DB call is made.
// Returns false and writes 403 if they don't match.
func requireTeamAccess(w http.ResponseWriter, r *http.Request, ac auth.AuthContext) (pgtype.UUID, bool) {
teamIDStr := chi.URLParam(r, "id")
teamID, err := id.ParseTeamID(teamIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid team ID")
return pgtype.UUID{}, false
}
if ac.TeamID != teamID {
writeError(w, http.StatusForbidden, "forbidden", "JWT team does not match requested team; use switch-team first")
return pgtype.UUID{}, false
}
return teamID, true
}
// List handles GET /v1/teams
// Returns all teams the authenticated user belongs to.
func (h *teamHandler) List(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
teams, err := h.svc.ListTeamsForUser(r.Context(), ac.UserID)
if err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
resp := make([]teamWithRoleResponse, len(teams))
for i, t := range teams {
resp[i] = teamWithRoleResponse{
teamResponse: teamToResponse(t.Team),
Role: t.Role,
}
}
writeJSON(w, http.StatusOK, resp)
}
// Create handles POST /v1/teams
// Creates a new team owned by the authenticated user.
func (h *teamHandler) Create(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
var req struct {
Name string `json:"name"`
}
if err := decodeJSON(r, &req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
return
}
req.Name = strings.TrimSpace(req.Name)
team, err := h.svc.CreateTeam(r.Context(), ac.UserID, req.Name)
if err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
writeJSON(w, http.StatusCreated, teamWithRoleResponse{
teamResponse: teamToResponse(team.Team),
Role: team.Role,
})
}
// Get handles GET /v1/teams/{id}
// Returns team info and member list.
func (h *teamHandler) Get(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
teamID, ok := requireTeamAccess(w, r, ac)
if !ok {
return
}
team, err := h.svc.GetTeam(r.Context(), teamID)
if err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
members, err := h.svc.GetMembers(r.Context(), teamID)
if err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
memberResp := make([]memberResponse, len(members))
for i, m := range members {
memberResp[i] = memberInfoToResponse(m)
}
writeJSON(w, http.StatusOK, map[string]any{
"team": teamToResponse(team),
"members": memberResp,
})
}
// Rename handles PATCH /v1/teams/{id}
// Renames the team. Requires admin or owner role (verified from DB).
func (h *teamHandler) Rename(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
teamID, ok := requireTeamAccess(w, r, ac)
if !ok {
return
}
var req struct {
Name string `json:"name"`
}
if err := decodeJSON(r, &req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
return
}
req.Name = strings.TrimSpace(req.Name)
// Fetch old name for audit log before renaming.
oldTeam, err := h.svc.GetTeam(r.Context(), teamID)
if err != nil {
slog.Warn("audit: could not fetch old team name for rename log", "team_id", id.FormatTeamID(teamID), "error", err)
}
if err := h.svc.RenameTeam(r.Context(), teamID, ac.UserID, req.Name); err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
h.audit.LogTeamRename(r.Context(), ac, teamID, oldTeam.Name, req.Name)
w.WriteHeader(http.StatusNoContent)
}
// Delete handles DELETE /v1/teams/{id}
// Soft-deletes the team and destroys active sandboxes. Owner only.
func (h *teamHandler) Delete(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
teamID, ok := requireTeamAccess(w, r, ac)
if !ok {
return
}
if err := h.svc.DeleteTeam(r.Context(), teamID, ac.UserID); err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
w.WriteHeader(http.StatusNoContent)
}
// ListMembers handles GET /v1/teams/{id}/members
func (h *teamHandler) ListMembers(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
teamID, ok := requireTeamAccess(w, r, ac)
if !ok {
return
}
members, err := h.svc.GetMembers(r.Context(), teamID)
if err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
resp := make([]memberResponse, len(members))
for i, m := range members {
resp[i] = memberInfoToResponse(m)
}
writeJSON(w, http.StatusOK, resp)
}
// AddMember handles POST /v1/teams/{id}/members
// Adds a user by email. Requires admin or owner (verified from DB).
func (h *teamHandler) AddMember(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
teamID, ok := requireTeamAccess(w, r, ac)
if !ok {
return
}
var req struct {
Email string `json:"email"`
}
if err := decodeJSON(r, &req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
return
}
req.Email = strings.TrimSpace(strings.ToLower(req.Email))
if req.Email == "" {
writeError(w, http.StatusBadRequest, "invalid_request", "email is required")
return
}
member, err := h.svc.AddMember(r.Context(), teamID, ac.UserID, req.Email)
if err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
// member.UserID is already formatted with prefix; parse it back for the audit logger.
targetUserID, parseErr := id.ParseUserID(member.UserID)
if parseErr == nil {
h.audit.LogMemberAdd(r.Context(), ac, targetUserID, member.Email, member.Role)
}
writeJSON(w, http.StatusCreated, memberInfoToResponse(member))
}
// RemoveMember handles DELETE /v1/teams/{id}/members/{uid}
// Removes a member. Requires admin or owner (verified from DB). Owner cannot be removed.
func (h *teamHandler) RemoveMember(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
teamID, ok := requireTeamAccess(w, r, ac)
if !ok {
return
}
targetUserIDStr := chi.URLParam(r, "uid")
targetUserID, err := id.ParseUserID(targetUserIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid user ID")
return
}
if err := h.svc.RemoveMember(r.Context(), teamID, ac.UserID, targetUserID); err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
h.audit.LogMemberRemove(r.Context(), ac, targetUserID)
w.WriteHeader(http.StatusNoContent)
}
// UpdateMemberRole handles PATCH /v1/teams/{id}/members/{uid}
// Changes a member's role (admin or member). Owner's role cannot be changed.
func (h *teamHandler) UpdateMemberRole(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
teamID, ok := requireTeamAccess(w, r, ac)
if !ok {
return
}
targetUserIDStr := chi.URLParam(r, "uid")
targetUserID, err := id.ParseUserID(targetUserIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid user ID")
return
}
var req struct {
Role string `json:"role"`
}
if err := decodeJSON(r, &req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
return
}
if err := h.svc.UpdateMemberRole(r.Context(), teamID, ac.UserID, targetUserID, req.Role); err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
h.audit.LogMemberRoleUpdate(r.Context(), ac, targetUserID, req.Role)
w.WriteHeader(http.StatusNoContent)
}
// Leave handles POST /v1/teams/{id}/leave
// Removes the calling user from the team. Owner cannot leave.
func (h *teamHandler) Leave(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
teamID, ok := requireTeamAccess(w, r, ac)
if !ok {
return
}
if err := h.svc.LeaveTeam(r.Context(), teamID, ac.UserID); err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
h.audit.LogMemberLeave(r.Context(), ac)
w.WriteHeader(http.StatusNoContent)
}
// SetBYOC handles PUT /v1/admin/teams/{id}/byoc (admin only).
// Enables or disables the BYOC feature flag for a team.
func (h *teamHandler) SetBYOC(w http.ResponseWriter, r *http.Request) {
teamIDStr := chi.URLParam(r, "id")
teamID, err := id.ParseTeamID(teamIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid team ID")
return
}
var req struct {
Enabled bool `json:"enabled"`
}
if err := decodeJSON(r, &req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
return
}
if err := h.svc.SetBYOC(r.Context(), teamID, req.Enabled); err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
w.WriteHeader(http.StatusNoContent)
}

View File

@ -0,0 +1,52 @@
package api
import (
"net/http"
"strings"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
)
type usersHandler struct {
db *db.Queries
}
func newUsersHandler(db *db.Queries) *usersHandler {
return &usersHandler{db: db}
}
// Search handles GET /v1/users/search?email=<prefix>
// Returns up to 10 users whose email starts with the given prefix.
// The prefix must be at least 3 characters long and contain "@".
func (h *usersHandler) Search(w http.ResponseWriter, r *http.Request) {
auth.MustFromContext(r.Context()) // ensure authenticated
prefix := strings.TrimSpace(r.URL.Query().Get("email"))
if len(prefix) < 3 || !strings.Contains(prefix, "@") {
writeError(w, http.StatusBadRequest, "invalid_request", "email prefix must be at least 3 characters and contain '@'")
return
}
// Escape LIKE metacharacters to prevent pattern injection.
escaped := strings.NewReplacer("%", "\\%", "_", "\\_").Replace(prefix)
results, err := h.db.SearchUsersByEmailPrefix(r.Context(), pgtype.Text{String: escaped, Valid: true})
if err != nil {
writeError(w, http.StatusInternalServerError, "internal", "search failed")
return
}
type userResult struct {
UserID string `json:"user_id"`
Email string `json:"email"`
}
resp := make([]userResult, len(results))
for i, u := range results {
resp[i] = userResult{UserID: id.FormatUserID(u.ID), Email: u.Email}
}
writeJSON(w, http.StatusOK, resp)
}

View File

@ -0,0 +1,216 @@
package api
import (
"context"
"log/slog"
"time"
"connectrpc.com/connect"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/audit"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
)
// unreachableThreshold is how long a host can go without a heartbeat before
// it is considered unreachable (3 missed 30-second heartbeats).
const unreachableThreshold = 90 * time.Second
// HostMonitor runs on a fixed interval and performs two duties:
//
// 1. Passive check: marks hosts whose last_heartbeat_at is stale as
// "unreachable" and marks their active sandboxes as "missing".
//
// 2. Active reconciliation: for each online host, calls ListSandboxes and
// reconciles DB state against live host state — restoring "missing"
// sandboxes that are actually alive, and stopping orphaned ones.
type HostMonitor struct {
db *db.Queries
pool *lifecycle.HostClientPool
audit *audit.AuditLogger
interval time.Duration
}
// NewHostMonitor creates a HostMonitor.
func NewHostMonitor(queries *db.Queries, pool *lifecycle.HostClientPool, al *audit.AuditLogger, interval time.Duration) *HostMonitor {
return &HostMonitor{
db: queries,
pool: pool,
audit: al,
interval: interval,
}
}
// Start runs the monitor loop until the context is cancelled.
func (m *HostMonitor) Start(ctx context.Context) {
go func() {
ticker := time.NewTicker(m.interval)
defer ticker.Stop()
// Run immediately on startup so the CP doesn't wait one full interval
// before reconciling host and sandbox state.
m.run(ctx)
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
m.run(ctx)
}
}
}()
}
func (m *HostMonitor) run(ctx context.Context) {
hosts, err := m.db.ListActiveHosts(ctx)
if err != nil {
slog.Warn("host monitor: failed to list hosts", "error", err)
return
}
for _, host := range hosts {
m.checkHost(ctx, host)
}
}
func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) {
// --- Passive phase: check heartbeat staleness ---
stale := !host.LastHeartbeatAt.Valid ||
time.Since(host.LastHeartbeatAt.Time) > unreachableThreshold
if stale && host.Status != "unreachable" {
slog.Info("host monitor: marking host unreachable", "host_id", id.FormatHostID(host.ID),
"last_heartbeat", host.LastHeartbeatAt.Time)
if err := m.db.MarkHostUnreachable(ctx, host.ID); err != nil {
slog.Warn("host monitor: failed to mark host unreachable", "host_id", id.FormatHostID(host.ID), "error", err)
}
if err := m.db.MarkSandboxesMissingByHost(ctx, host.ID); err != nil {
slog.Warn("host monitor: failed to mark sandboxes missing", "host_id", id.FormatHostID(host.ID), "error", err)
}
m.audit.LogHostMarkedDown(ctx, host.TeamID, host.ID)
return
}
// --- Active reconciliation: only for online hosts ---
if host.Status != "online" {
return
}
agent, err := m.pool.GetForHost(host)
if err != nil {
// Host has no address yet (e.g., just registered) — skip.
return
}
resp, err := agent.ListSandboxes(ctx, connect.NewRequest(&pb.ListSandboxesRequest{}))
if err != nil {
// RPC failure is a transient condition; the passive phase will catch it
// if heartbeats stop arriving.
slog.Debug("host monitor: ListSandboxes failed (transient)", "host_id", id.FormatHostID(host.ID), "error", err)
return
}
// Build set of sandbox IDs alive on the host.
// The host agent returns sandbox IDs as strings (formatted with prefix).
alive := make(map[string]struct{}, len(resp.Msg.Sandboxes))
for _, sb := range resp.Msg.Sandboxes {
alive[sb.SandboxId] = struct{}{}
}
autoPaused := make(map[string]struct{}, len(resp.Msg.AutoPausedSandboxIds))
for _, apID := range resp.Msg.AutoPausedSandboxIds {
autoPaused[apID] = struct{}{}
}
// --- Restore sandboxes that are "missing" in DB but alive on host ---
// This handles the case where CP marked them missing due to a transient
// heartbeat gap, but the host was actually fine.
missingSandboxes, err := m.db.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
HostID: host.ID,
Column2: []string{"missing"},
})
if err != nil {
slog.Warn("host monitor: failed to list missing sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
} else {
var toRestore []pgtype.UUID
var toStop []pgtype.UUID
for _, sb := range missingSandboxes {
sbIDStr := id.FormatSandboxID(sb.ID)
if _, ok := alive[sbIDStr]; ok {
toRestore = append(toRestore, sb.ID)
} else {
toStop = append(toStop, sb.ID)
}
}
if len(toRestore) > 0 {
slog.Info("host monitor: restoring missing sandboxes", "host_id", id.FormatHostID(host.ID), "count", len(toRestore))
if err := m.db.BulkRestoreRunning(ctx, toRestore); err != nil {
slog.Warn("host monitor: failed to restore missing sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
}
}
if len(toStop) > 0 {
slog.Info("host monitor: stopping confirmed-dead missing sandboxes", "host_id", id.FormatHostID(host.ID), "count", len(toStop))
if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
Column1: toStop,
Status: "stopped",
}); err != nil {
slog.Warn("host monitor: failed to stop missing sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
}
}
}
// --- Find running sandboxes in DB that are no longer alive on the host ---
runningSandboxes, err := m.db.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
HostID: host.ID,
Column2: []string{"running"},
})
if err != nil {
slog.Warn("host monitor: failed to list running sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
return
}
var toPause, toStop []pgtype.UUID
sbTeamID := make(map[pgtype.UUID]pgtype.UUID, len(runningSandboxes))
for _, sb := range runningSandboxes {
sbIDStr := id.FormatSandboxID(sb.ID)
sbTeamID[sb.ID] = sb.TeamID
if _, ok := alive[sbIDStr]; ok {
continue
}
if _, ok := autoPaused[sbIDStr]; ok {
toPause = append(toPause, sb.ID)
} else {
toStop = append(toStop, sb.ID)
}
}
if len(toPause) > 0 {
slog.Info("host monitor: marking auto-paused sandboxes", "host_id", id.FormatHostID(host.ID), "count", len(toPause))
if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
Column1: toPause,
Status: "paused",
}); err != nil {
slog.Warn("host monitor: failed to mark paused", "host_id", id.FormatHostID(host.ID), "error", err)
}
for _, sbID := range toPause {
m.audit.LogSandboxAutoPause(ctx, sbTeamID[sbID], sbID)
}
}
if len(toStop) > 0 {
slog.Info("host monitor: marking orphaned sandboxes stopped", "host_id", id.FormatHostID(host.ID), "count", len(toStop))
if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
Column1: toStop,
Status: "stopped",
}); err != nil {
slog.Warn("host monitor: failed to mark stopped", "host_id", id.FormatHostID(host.ID), "error", err)
}
}
}

View File

@ -0,0 +1,68 @@
package api
import (
"context"
"log/slog"
"time"
"git.omukk.dev/wrenn/sandbox/internal/db"
)
// MetricsSampler records per-team sandbox resource usage to
// sandbox_metrics_snapshots every interval. It also prunes rows older than
// 60 days on each tick to keep the table bounded.
type MetricsSampler struct {
db *db.Queries
interval time.Duration
}
// NewMetricsSampler creates a MetricsSampler.
func NewMetricsSampler(queries *db.Queries, interval time.Duration) *MetricsSampler {
return &MetricsSampler{db: queries, interval: interval}
}
// Start runs the sampler loop until the context is cancelled.
func (s *MetricsSampler) Start(ctx context.Context) {
go func() {
ticker := time.NewTicker(s.interval)
defer ticker.Stop()
// Sample immediately on startup.
s.run(ctx)
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
s.run(ctx)
}
}
}()
}
func (s *MetricsSampler) run(ctx context.Context) {
s.prune(ctx)
if err := s.sample(ctx); err != nil {
slog.Warn("metrics sampler: sample failed", "error", err)
}
}
func (s *MetricsSampler) sample(ctx context.Context) error {
rows, err := s.db.SampleSandboxMetrics(ctx)
if err != nil {
return err
}
for _, row := range rows {
if err := s.db.InsertMetricsSnapshot(ctx, db.InsertMetricsSnapshotParams(row)); err != nil {
slog.Warn("metrics sampler: insert snapshot failed", "team_id", row.TeamID, "error", err)
}
}
return nil
}
func (s *MetricsSampler) prune(ctx context.Context) {
if err := s.db.PruneOldMetrics(ctx); err != nil {
slog.Warn("metrics sampler: prune failed", "error", err)
}
}

View File

@ -12,6 +12,9 @@ import (
"time"
"connectrpc.com/connect"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/id"
)
type errorResponse struct {
@ -35,6 +38,11 @@ func writeError(w http.ResponseWriter, status int, code, message string) {
})
}
// formatUUIDForRPC converts a pgtype.UUID to a hex string for RPC messages.
func formatUUIDForRPC(u pgtype.UUID) string {
return id.UUIDString(u)
}
// agentErrToHTTP maps a Connect RPC error to an HTTP status, error code, and message.
func agentErrToHTTP(err error) (int, string, string) {
switch connect.CodeOf(err) {
@ -87,8 +95,12 @@ func serviceErrToHTTP(err error) (int, string, string) {
return http.StatusNotFound, "not_found", msg
case strings.Contains(msg, "not running"), strings.Contains(msg, "not paused"):
return http.StatusConflict, "invalid_state", msg
case strings.Contains(msg, "conflict:"):
return http.StatusConflict, "conflict", msg
case strings.Contains(msg, "forbidden"):
return http.StatusForbidden, "forbidden", msg
case strings.Contains(msg, "invalid or expired"):
return http.StatusUnauthorized, "unauthorized", msg
case strings.Contains(msg, "invalid"):
return http.StatusBadRequest, "invalid_request", msg
default:

View File

@ -0,0 +1,30 @@
package api
import (
"net/http"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
)
// requireAdmin validates that the authenticated user is a platform admin.
// Must run after requireJWT (depends on AuthContext being present).
// Re-validates against the DB — the JWT is_admin claim is for UI only;
// the DB is the source of truth for admin access.
func requireAdmin(queries *db.Queries) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ac, ok := auth.FromContext(r.Context())
if !ok {
writeError(w, http.StatusUnauthorized, "unauthorized", "authentication required")
return
}
user, err := queries.GetUserByID(r.Context(), ac.UserID)
if err != nil || !user.IsAdmin {
writeError(w, http.StatusForbidden, "forbidden", "admin access required")
return
}
next.ServeHTTP(w, r)
})
}
}

View File

@ -1,38 +0,0 @@
package api
import (
"log/slog"
"net/http"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
)
// requireAPIKey validates the X-API-Key header, looks up the SHA-256 hash in DB,
// and stamps TeamID into the request context.
func requireAPIKey(queries *db.Queries) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
key := r.Header.Get("X-API-Key")
if key == "" {
writeError(w, http.StatusUnauthorized, "unauthorized", "X-API-Key header required")
return
}
hash := auth.HashAPIKey(key)
row, err := queries.GetAPIKeyByHash(r.Context(), hash)
if err != nil {
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid API key")
return
}
// Best-effort update of last_used timestamp.
if err := queries.UpdateAPIKeyLastUsed(r.Context(), row.ID); err != nil {
slog.Warn("failed to update api key last_used", "key_id", row.ID, "error", err)
}
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{TeamID: row.TeamID})
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}

View File

@ -7,6 +7,7 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
)
// requireAPIKeyOrJWT accepts either X-API-Key header or Authorization: Bearer JWT.
@ -19,15 +20,20 @@ func requireAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler
hash := auth.HashAPIKey(key)
row, err := queries.GetAPIKeyByHash(r.Context(), hash)
if err != nil {
slog.Warn("api key auth failed", "prefix", auth.APIKeyPrefix(key), "ip", r.RemoteAddr)
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid API key")
return
}
if err := queries.UpdateAPIKeyLastUsed(r.Context(), row.ID); err != nil {
slog.Warn("failed to update api key last_used", "key_id", row.ID, "error", err)
slog.Warn("failed to update api key last_used", "key_id", id.FormatAPIKeyID(row.ID), "error", err)
}
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{TeamID: row.TeamID})
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
TeamID: row.TeamID,
APIKeyID: row.ID,
APIKeyName: row.Name,
})
next.ServeHTTP(w, r.WithContext(ctx))
return
}
@ -37,14 +43,28 @@ func requireAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler
tokenStr := strings.TrimPrefix(header, "Bearer ")
claims, err := auth.VerifyJWT(jwtSecret, tokenStr)
if err != nil {
slog.Warn("jwt auth failed", "error", err, "ip", r.RemoteAddr)
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid or expired token")
return
}
teamID, err := id.ParseTeamID(claims.TeamID)
if err != nil {
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid team ID in token")
return
}
userID, err := id.ParseUserID(claims.Subject)
if err != nil {
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid user ID in token")
return
}
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
TeamID: claims.TeamID,
UserID: claims.Subject,
TeamID: teamID,
UserID: userID,
Email: claims.Email,
Name: claims.Name,
Role: claims.Role,
})
next.ServeHTTP(w, r.WithContext(ctx))
return

View File

@ -4,6 +4,7 @@ import (
"net/http"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/id"
)
// requireHostToken validates the X-Host-Token header containing a host JWT,
@ -23,7 +24,13 @@ func requireHostToken(secret []byte) func(http.Handler) http.Handler {
return
}
ctx := auth.WithHostContext(r.Context(), auth.HostContext{HostID: claims.HostID})
hostID, err := id.ParseHostID(claims.HostID)
if err != nil {
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid host ID in token")
return
}
ctx := auth.WithHostContext(r.Context(), auth.HostContext{HostID: hostID})
next.ServeHTTP(w, r.WithContext(ctx))
})
}

View File

@ -5,6 +5,7 @@ import (
"strings"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/id"
)
// requireJWT validates the Authorization: Bearer <token> header, verifies the JWT
@ -25,11 +26,26 @@ func requireJWT(secret []byte) func(http.Handler) http.Handler {
return
}
teamID, err := id.ParseTeamID(claims.TeamID)
if err != nil {
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid team ID in token")
return
}
userID, err := id.ParseUserID(claims.Subject)
if err != nil {
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid user ID in token")
return
}
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
TeamID: claims.TeamID,
UserID: claims.Subject,
Email: claims.Email,
TeamID: teamID,
UserID: userID,
Email: claims.Email,
Name: claims.Name,
Role: claims.Role,
IsAdmin: claims.IsAdmin,
})
next.ServeHTTP(w, r.WithContext(ctx))
})
}

File diff suppressed because it is too large Load Diff

View File

@ -1,126 +0,0 @@
package api
import (
"context"
"log/slog"
"time"
"connectrpc.com/connect"
"git.omukk.dev/wrenn/sandbox/internal/db"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
)
// Reconciler periodically compares the host agent's sandbox list with the DB
// and marks sandboxes that no longer exist on the host as stopped.
type Reconciler struct {
db *db.Queries
agent hostagentv1connect.HostAgentServiceClient
hostID string
interval time.Duration
}
// NewReconciler creates a new reconciler.
func NewReconciler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient, hostID string, interval time.Duration) *Reconciler {
return &Reconciler{
db: db,
agent: agent,
hostID: hostID,
interval: interval,
}
}
// Start runs the reconciliation loop until the context is cancelled.
func (rc *Reconciler) Start(ctx context.Context) {
go func() {
ticker := time.NewTicker(rc.interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
rc.reconcile(ctx)
}
}
}()
}
func (rc *Reconciler) reconcile(ctx context.Context) {
// Single RPC returns both the running sandbox list and any IDs that
// were auto-paused by the TTL reaper since the last call.
resp, err := rc.agent.ListSandboxes(ctx, connect.NewRequest(&pb.ListSandboxesRequest{}))
if err != nil {
slog.Warn("reconciler: failed to list sandboxes from host agent", "error", err)
return
}
// Build a set of sandbox IDs that are alive on the host.
alive := make(map[string]struct{}, len(resp.Msg.Sandboxes))
for _, sb := range resp.Msg.Sandboxes {
alive[sb.SandboxId] = struct{}{}
}
// Build auto-paused set from the same response.
autoPausedSet := make(map[string]struct{}, len(resp.Msg.AutoPausedSandboxIds))
for _, id := range resp.Msg.AutoPausedSandboxIds {
autoPausedSet[id] = struct{}{}
}
// Get all DB sandboxes for this host that are running.
// Paused sandboxes are excluded: they are expected to not exist on the
// host agent because pause = snapshot + destroy resources.
dbSandboxes, err := rc.db.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
HostID: rc.hostID,
Column2: []string{"running"},
})
if err != nil {
slog.Warn("reconciler: failed to list DB sandboxes", "error", err)
return
}
// Find sandboxes in DB that are no longer on the host.
var stale []string
for _, sb := range dbSandboxes {
if _, ok := alive[sb.ID]; !ok {
stale = append(stale, sb.ID)
}
}
if len(stale) == 0 {
return
}
// Split stale sandboxes into those auto-paused by the TTL reaper vs
// those that crashed/were orphaned.
var toPause, toStop []string
for _, id := range stale {
if _, ok := autoPausedSet[id]; ok {
toPause = append(toPause, id)
} else {
toStop = append(toStop, id)
}
}
if len(toPause) > 0 {
slog.Info("reconciler: marking auto-paused sandboxes", "count", len(toPause), "ids", toPause)
if err := rc.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
Column1: toPause,
Status: "paused",
}); err != nil {
slog.Warn("reconciler: failed to mark auto-paused sandboxes", "error", err)
}
}
if len(toStop) > 0 {
slog.Info("reconciler: marking stale sandboxes as stopped", "count", len(toStop), "ids", toStop)
if err := rc.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
Column1: toStop,
Status: "stopped",
}); err != nil {
slog.Warn("reconciler: failed to update stale sandboxes", "error", err)
}
}
}

View File

@ -9,10 +9,14 @@ import (
"github.com/jackc/pgx/v5/pgxpool"
"github.com/redis/go-redis/v9"
"git.omukk.dev/wrenn/sandbox/internal/audit"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/auth/oauth"
"git.omukk.dev/wrenn/sandbox/internal/channels"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
"git.omukk.dev/wrenn/sandbox/internal/scheduler"
"git.omukk.dev/wrenn/sandbox/internal/service"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
)
//go:embed openapi.yaml
@ -20,30 +24,54 @@ var openapiYAML []byte
// Server is the control plane HTTP server.
type Server struct {
router chi.Router
router chi.Router
BuildSvc *service.BuildService
}
// New constructs the chi router and registers all routes.
func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, pool *pgxpool.Pool, rdb *redis.Client, jwtSecret []byte, oauthRegistry *oauth.Registry, oauthRedirectURL string) *Server {
func New(
queries *db.Queries,
pool *lifecycle.HostClientPool,
sched scheduler.HostScheduler,
pgPool *pgxpool.Pool,
rdb *redis.Client,
jwtSecret []byte,
oauthRegistry *oauth.Registry,
oauthRedirectURL string,
ca *auth.CA,
al *audit.AuditLogger,
channelSvc *channels.Service,
) *Server {
r := chi.NewRouter()
r.Use(requestLogger())
// Shared service layer.
sandboxSvc := &service.SandboxService{DB: queries, Agent: agent}
sandboxSvc := &service.SandboxService{DB: queries, Pool: pool, Scheduler: sched}
apiKeySvc := &service.APIKeyService{DB: queries}
templateSvc := &service.TemplateService{DB: queries}
hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret}
hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret, Pool: pool, CA: ca}
teamSvc := &service.TeamService{DB: queries, Pool: pgPool, HostPool: pool}
auditSvc := &service.AuditService{DB: queries}
statsSvc := &service.StatsService{DB: queries, Pool: pgPool}
buildSvc := &service.BuildService{DB: queries, Redis: rdb, Pool: pool, Scheduler: sched}
sandbox := newSandboxHandler(sandboxSvc)
exec := newExecHandler(queries, agent)
execStream := newExecStreamHandler(queries, agent)
files := newFilesHandler(queries, agent)
filesStream := newFilesStreamHandler(queries, agent)
snapshots := newSnapshotHandler(templateSvc, queries, agent)
authH := newAuthHandler(queries, pool, jwtSecret)
oauthH := newOAuthHandler(queries, pool, jwtSecret, oauthRegistry, oauthRedirectURL)
apiKeys := newAPIKeyHandler(apiKeySvc)
hostH := newHostHandler(hostSvc, queries)
sandbox := newSandboxHandler(sandboxSvc, al)
exec := newExecHandler(queries, pool)
execStream := newExecStreamHandler(queries, pool)
files := newFilesHandler(queries, pool)
filesStream := newFilesStreamHandler(queries, pool)
snapshots := newSnapshotHandler(templateSvc, queries, pool, al)
authH := newAuthHandler(queries, pgPool, jwtSecret)
oauthH := newOAuthHandler(queries, pgPool, jwtSecret, oauthRegistry, oauthRedirectURL)
apiKeys := newAPIKeyHandler(apiKeySvc, al)
hostH := newHostHandler(hostSvc, queries, al)
teamH := newTeamHandler(teamSvc, al)
usersH := newUsersHandler(queries)
auditH := newAuditHandler(auditSvc)
statsH := newStatsHandler(statsSvc)
metricsH := newSandboxMetricsHandler(queries, pool)
buildH := newBuildHandler(buildSvc, queries, pool)
channelH := newChannelHandler(channelSvc, al)
// OpenAPI spec and docs.
r.Get("/openapi.yaml", serveOpenAPI)
@ -55,6 +83,9 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p
r.Get("/auth/oauth/{provider}", oauthH.Redirect)
r.Get("/auth/oauth/{provider}/callback", oauthH.Callback)
// JWT-authenticated: switch active team.
r.With(requireJWT(jwtSecret)).Post("/v1/auth/switch-team", authH.SwitchTeam)
// JWT-authenticated: API key management.
r.Route("/v1/api-keys", func(r chi.Router) {
r.Use(requireJWT(jwtSecret))
@ -63,11 +94,32 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p
r.Delete("/{id}", apiKeys.Delete)
})
// JWT-authenticated: team management.
r.Route("/v1/teams", func(r chi.Router) {
r.Use(requireJWT(jwtSecret))
r.Get("/", teamH.List)
r.Post("/", teamH.Create)
r.Route("/{id}", func(r chi.Router) {
r.Get("/", teamH.Get)
r.Patch("/", teamH.Rename)
r.Delete("/", teamH.Delete)
r.Get("/members", teamH.ListMembers)
r.Post("/members", teamH.AddMember)
r.Patch("/members/{uid}", teamH.UpdateMemberRole)
r.Delete("/members/{uid}", teamH.RemoveMember)
r.Post("/leave", teamH.Leave)
})
})
// JWT-authenticated: user search (for add-member UI).
r.With(requireJWT(jwtSecret)).Get("/v1/users/search", usersH.Search)
// Sandbox lifecycle: accepts API key or JWT bearer token.
r.Route("/v1/sandboxes", func(r chi.Router) {
r.Use(requireAPIKeyOrJWT(queries, jwtSecret))
r.Post("/", sandbox.Create)
r.Get("/", sandbox.List)
r.Get("/stats", statsH.GetStats)
r.Route("/{id}", func(r chi.Router) {
r.Get("/", sandbox.Get)
@ -81,6 +133,7 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p
r.Post("/files/read", files.Download)
r.Post("/files/stream/write", filesStream.StreamUpload)
r.Post("/files/stream/read", filesStream.StreamDownload)
r.Get("/metrics", metricsH.GetMetrics)
})
})
@ -97,6 +150,9 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p
// Unauthenticated: one-time registration token.
r.Post("/register", hostH.Register)
// Unauthenticated: refresh token exchange.
r.Post("/auth/refresh", hostH.RefreshToken)
// Host-token-authenticated: heartbeat.
r.With(requireHostToken(jwtSecret)).Post("/{id}/heartbeat", hostH.Heartbeat)
@ -108,6 +164,7 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p
r.Route("/{id}", func(r chi.Router) {
r.Get("/", hostH.Get)
r.Delete("/", hostH.Delete)
r.Get("/delete-preview", hostH.DeletePreview)
r.Post("/token", hostH.RegenerateToken)
r.Get("/tags", hostH.ListTags)
r.Post("/tags", hostH.AddTag)
@ -116,7 +173,37 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p
})
})
return &Server{router: r}
// JWT-authenticated: notification channels.
r.Route("/v1/channels", func(r chi.Router) {
r.Use(requireJWT(jwtSecret))
r.Post("/", channelH.Create)
r.Get("/", channelH.List)
r.Post("/test", channelH.Test)
r.Route("/{id}", func(r chi.Router) {
r.Get("/", channelH.Get)
r.Patch("/", channelH.Update)
r.Delete("/", channelH.Delete)
r.Put("/config", channelH.RotateConfig)
})
})
// JWT-authenticated: audit log.
r.With(requireJWT(jwtSecret)).Get("/v1/audit-logs", auditH.List)
// Platform admin routes — require JWT + DB-validated admin status.
r.Route("/v1/admin", func(r chi.Router) {
r.Use(requireJWT(jwtSecret))
r.Use(requireAdmin(queries))
r.Put("/teams/{id}/byoc", teamH.SetBYOC)
r.Get("/templates", buildH.ListTemplates)
r.Delete("/templates/{name}", buildH.DeleteTemplate)
r.Post("/builds", buildH.Create)
r.Get("/builds", buildH.List)
r.Get("/builds/{id}", buildH.Get)
r.Post("/builds/{id}/cancel", buildH.Cancel)
})
return &Server{router: r, BuildSvc: buildSvc}
}
// Handler returns the HTTP handler.
@ -137,7 +224,7 @@ func serveDocs(w http.ResponseWriter, r *http.Request) {
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Wrenn Sandbox API</title>
<link rel="stylesheet" href="https://unpkg.com/swagger-ui-dist@5/swagger-ui.css">
<link rel="stylesheet" href="https://unpkg.com/swagger-ui-dist@5.18.2/swagger-ui.css" integrity="sha384-rcbEi6xgdPk0iWkAQzT2F3FeBJXdG+ydrawGlfHAFIZG7wU6aKbQaRewysYpmrlW" crossorigin="anonymous">
<style>
body { margin: 0; background: #fafafa; }
.swagger-ui .topbar { display: none; }
@ -145,7 +232,7 @@ func serveDocs(w http.ResponseWriter, r *http.Request) {
</head>
<body>
<div id="swagger-ui"></div>
<script src="https://unpkg.com/swagger-ui-dist@5/swagger-ui-bundle.js"></script>
<script src="https://unpkg.com/swagger-ui-dist@5.18.2/swagger-ui-bundle.js" integrity="sha384-NXtFPpN61oWCuN4D42K6Zd5Rt2+uxeIT36R7kpXBuY9tLnZorzrJ4ykpqwJfgjpZ" crossorigin="anonymous"></script>
<script>
SwaggerUIBundle({
url: "/openapi.yaml",