1
0
forked from wrenn/wrenn

Switch database IDs from TEXT to native UUID

Consolidate 16 migrations into one with UUID columns for all entity
IDs. TEXT is kept only for polymorphic fields (audit_logs.actor_id,
resource_id) and template names. The id package now generates UUIDs
via google/uuid, with Format*/Parse* helpers for the prefixed wire
format (sb-{uuid}, usr-{uuid}, etc.). Auth context, services, and
handlers pass pgtype.UUID internally; conversion to/from prefixed
strings happens at API and RPC boundaries. Adds PlatformTeamID
(all-zeros UUID) for shared resources.
This commit is contained in:
2026-03-26 16:16:21 +06:00
parent cdd89a7cee
commit 4ddd494160
66 changed files with 1350 additions and 1127 deletions

View File

@ -4,6 +4,8 @@ import (
"context"
"fmt"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
@ -11,7 +13,7 @@ import (
// agentForHost looks up the host record and returns a Connect RPC client for it.
// Returns an error if the host is not found or has no address.
func agentForHost(ctx context.Context, queries *db.Queries, pool *lifecycle.HostClientPool, hostID string) (hostagentv1connect.HostAgentServiceClient, error) {
func agentForHost(ctx context.Context, queries *db.Queries, pool *lifecycle.HostClientPool, hostID pgtype.UUID) (hostagentv1connect.HostAgentServiceClient, error) {
host, err := queries.GetHost(ctx, hostID)
if err != nil {
return nil, fmt.Errorf("host not found: %w", err)

View File

@ -10,14 +10,17 @@ import (
"strconv"
"strings"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
)
// sandboxHostPattern matches hostnames like "49999-sb-abcd1234.localhost" or
// "49999-sb-abcd1234.example.com". Captures: port, sandbox ID.
var sandboxHostPattern = regexp.MustCompile(`^(\d+)-(sb-[0-9a-f]+)\.`)
var sandboxHostPattern = regexp.MustCompile(`^(\d+)-(sb-[0-9a-f-]+)\.`)
// SandboxProxyWrapper wraps an existing HTTP handler and intercepts requests
// whose Host header matches the {port}-{sandbox_id}.{domain} pattern. Matching
@ -57,7 +60,7 @@ func (h *SandboxProxyWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request)
}
port := matches[1]
sandboxID := matches[2]
sandboxIDStr := matches[2]
// Validate port.
portNum, err := strconv.Atoi(port)
@ -73,6 +76,12 @@ func (h *SandboxProxyWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request)
return
}
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
http.Error(w, "invalid sandbox ID", http.StatusBadRequest)
return
}
ctx := r.Context()
// Look up sandbox and verify ownership.
@ -96,13 +105,13 @@ func (h *SandboxProxyWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request)
return
}
if !agentHost.Address.Valid || agentHost.Address.String == "" {
if agentHost.Address == "" {
http.Error(w, "host agent has no address", http.StatusServiceUnavailable)
return
}
agentAddr := lifecycle.EnsureScheme(agentHost.Address.String)
upstreamPath := fmt.Sprintf("/proxy/%s/%s%s", sandboxID, port, r.URL.Path)
agentAddr := lifecycle.EnsureScheme(agentHost.Address)
upstreamPath := fmt.Sprintf("/proxy/%s/%s%s", sandboxIDStr, port, r.URL.Path)
target, err := url.Parse(agentAddr)
if err != nil {
@ -121,7 +130,7 @@ func (h *SandboxProxyWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request)
},
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
slog.Debug("sandbox proxy error",
"sandbox_id", sandboxID,
"sandbox_id", sandboxIDStr,
"port", port,
"error", err,
)
@ -134,16 +143,16 @@ func (h *SandboxProxyWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request)
// authenticateRequest validates the request's API key and returns the team ID.
// Only API key authentication is supported for sandbox proxy requests (not JWT).
func (h *SandboxProxyWrapper) authenticateRequest(r *http.Request) (string, error) {
func (h *SandboxProxyWrapper) authenticateRequest(r *http.Request) (pgtype.UUID, error) {
key := r.Header.Get("X-API-Key")
if key == "" {
return "", fmt.Errorf("X-API-Key header required")
return pgtype.UUID{}, fmt.Errorf("X-API-Key header required")
}
hash := auth.HashAPIKey(key)
row, err := h.db.GetAPIKeyByHash(r.Context(), hash)
if err != nil {
return "", fmt.Errorf("invalid API key")
return pgtype.UUID{}, fmt.Errorf("invalid API key")
}
return row.TeamID, nil
}

View File

@ -9,6 +9,7 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/audit"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/service"
)
@ -39,11 +40,11 @@ type apiKeyResponse struct {
func apiKeyToResponse(k db.TeamApiKey) apiKeyResponse {
resp := apiKeyResponse{
ID: k.ID,
TeamID: k.TeamID,
ID: id.FormatAPIKeyID(k.ID),
TeamID: id.FormatTeamID(k.TeamID),
Name: k.Name,
KeyPrefix: k.KeyPrefix,
CreatedBy: k.CreatedBy,
CreatedBy: id.FormatUserID(k.CreatedBy),
}
if k.CreatedAt.Valid {
resp.CreatedAt = k.CreatedAt.Time.Format(time.RFC3339)
@ -57,11 +58,11 @@ func apiKeyToResponse(k db.TeamApiKey) apiKeyResponse {
func apiKeyWithCreatorToResponse(k db.ListAPIKeysByTeamWithCreatorRow) apiKeyResponse {
resp := apiKeyResponse{
ID: k.ID,
TeamID: k.TeamID,
ID: id.FormatAPIKeyID(k.ID),
TeamID: id.FormatTeamID(k.TeamID),
Name: k.Name,
KeyPrefix: k.KeyPrefix,
CreatedBy: k.CreatedBy,
CreatedBy: id.FormatUserID(k.CreatedBy),
CreatorEmail: k.CreatorEmail,
}
if k.CreatedAt.Valid {
@ -118,7 +119,13 @@ func (h *apiKeyHandler) List(w http.ResponseWriter, r *http.Request) {
// Delete handles DELETE /v1/api-keys/{id}.
func (h *apiKeyHandler) Delete(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
keyID := chi.URLParam(r, "id")
keyIDStr := chi.URLParam(r, "id")
keyID, err := id.ParseAPIKeyID(keyIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid API key ID")
return
}
if err := h.svc.Delete(r.Context(), keyID, ac.TeamID); err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to delete API key")

View File

@ -6,7 +6,10 @@ import (
"strings"
"time"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/service"
)
@ -65,13 +68,24 @@ func (h *auditHandler) List(w http.ResponseWriter, r *http.Request) {
limit = n
}
// Parse ?before_id cursor (UUID).
var beforeID pgtype.UUID
if s := r.URL.Query().Get("before_id"); s != "" {
parsed, err := id.ParseAuditLogID(s)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "before_id must be a valid audit log ID")
return
}
beforeID = parsed
}
entries, err := h.svc.List(r.Context(), service.AuditListParams{
TeamID: ac.TeamID,
AdminScoped: ac.Role == "owner" || ac.Role == "admin",
ResourceTypes: parseMultiParam(r.URL.Query()["resource_type"]),
Actions: parseMultiParam(r.URL.Query()["action"]),
Before: before,
BeforeID: r.URL.Query().Get("before_id"),
BeforeID: beforeID,
Limit: limit,
})
if err != nil {

View File

@ -20,7 +20,7 @@ import (
// It prefers the user's default team; if none is flagged as default it falls
// back to the earliest-joined team. Returns pgx.ErrNoRows when the user has
// no team memberships at all.
func loginTeam(ctx context.Context, q *db.Queries, userID string) (db.Team, string, error) {
func loginTeam(ctx context.Context, q *db.Queries, userID pgtype.UUID) (db.Team, string, error) {
team, err := q.GetDefaultTeamForUser(ctx, userID)
if err == nil {
membership, err := q.GetTeamMembership(ctx, db.GetTeamMembershipParams{UserID: userID, TeamID: team.ID})
@ -176,8 +176,8 @@ func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusCreated, authResponse{
Token: token,
UserID: userID,
TeamID: teamID,
UserID: id.FormatUserID(userID),
TeamID: id.FormatTeamID(teamID),
Email: req.Email,
Name: req.Name,
})
@ -236,8 +236,8 @@ func (h *authHandler) Login(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, authResponse{
Token: token,
UserID: user.ID,
TeamID: team.ID,
UserID: id.FormatUserID(user.ID),
TeamID: id.FormatTeamID(team.ID),
Email: user.Email,
Name: user.Name,
})
@ -260,10 +260,16 @@ func (h *authHandler) SwitchTeam(w http.ResponseWriter, r *http.Request) {
return
}
teamID, err := id.ParseTeamID(req.TeamID)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid team_id")
return
}
ctx := r.Context()
// Verify team exists and is not deleted.
team, err := h.db.GetTeam(ctx, req.TeamID)
team, err := h.db.GetTeam(ctx, teamID)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
writeError(w, http.StatusNotFound, "not_found", "team not found")
@ -280,7 +286,7 @@ func (h *authHandler) SwitchTeam(w http.ResponseWriter, r *http.Request) {
// Verify membership from DB — JWT role is not trusted here.
membership, err := h.db.GetTeamMembership(ctx, db.GetTeamMembershipParams{
UserID: ac.UserID,
TeamID: req.TeamID,
TeamID: teamID,
})
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
@ -298,7 +304,7 @@ func (h *authHandler) SwitchTeam(w http.ResponseWriter, r *http.Request) {
return
}
token, err := auth.SignJWT(h.jwtSecret, ac.UserID, req.TeamID, ac.Email, user.Name, membership.Role, user.IsAdmin)
token, err := auth.SignJWT(h.jwtSecret, ac.UserID, teamID, ac.Email, user.Name, membership.Role, user.IsAdmin)
if err != nil {
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
return
@ -306,8 +312,8 @@ func (h *authHandler) SwitchTeam(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, authResponse{
Token: token,
UserID: ac.UserID,
TeamID: req.TeamID,
UserID: id.FormatUserID(ac.UserID),
TeamID: id.FormatTeamID(teamID),
Email: ac.Email,
Name: user.Name,
})

View File

@ -10,6 +10,7 @@ import (
"github.com/go-chi/chi/v5"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/service"
"git.omukk.dev/wrenn/sandbox/internal/validate"
)
@ -53,7 +54,7 @@ type buildResponse struct {
func buildToResponse(b db.TemplateBuild) buildResponse {
resp := buildResponse{
ID: b.ID,
ID: id.FormatBuildID(b.ID),
Name: b.Name,
BaseTemplate: b.BaseTemplate,
Recipe: b.Recipe,
@ -64,17 +65,19 @@ func buildToResponse(b db.TemplateBuild) buildResponse {
TotalSteps: b.TotalSteps,
Logs: b.Logs,
}
if b.Healthcheck.Valid {
resp.Healthcheck = &b.Healthcheck.String
if b.Healthcheck != "" {
resp.Healthcheck = &b.Healthcheck
}
if b.Error.Valid {
resp.Error = &b.Error.String
if b.Error != "" {
resp.Error = &b.Error
}
if b.SandboxID.Valid {
resp.SandboxID = &b.SandboxID.String
s := id.FormatSandboxID(b.SandboxID)
resp.SandboxID = &s
}
if b.HostID.Valid {
resp.HostID = &b.HostID.String
s := id.FormatHostID(b.HostID)
resp.HostID = &s
}
if b.CreatedAt.Valid {
resp.CreatedAt = b.CreatedAt.Time.Format(time.RFC3339)
@ -146,7 +149,13 @@ func (h *buildHandler) List(w http.ResponseWriter, r *http.Request) {
// Get handles GET /v1/admin/builds/{id}.
func (h *buildHandler) Get(w http.ResponseWriter, r *http.Request) {
buildID := chi.URLParam(r, "id")
buildIDStr := chi.URLParam(r, "id")
buildID, err := id.ParseBuildID(buildIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid build ID")
return
}
build, err := h.svc.Get(r.Context(), buildID)
if err != nil {

View File

@ -14,6 +14,7 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
)
@ -46,10 +47,16 @@ type execResponse struct {
// Exec handles POST /v1/sandboxes/{id}/exec.
func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
@ -80,7 +87,7 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
}
resp, err := agent.Exec(ctx, connect.NewRequest(&pb.ExecRequest{
SandboxId: sandboxID,
SandboxId: sandboxIDStr,
Cmd: req.Cmd,
Args: req.Args,
TimeoutSec: req.TimeoutSec,
@ -101,7 +108,7 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
Valid: true,
},
}); err != nil {
slog.Warn("failed to update last_active_at", "id", sandboxID, "error", err)
slog.Warn("failed to update last_active_at", "id", sandboxIDStr, "error", err)
}
// Use base64 encoding if output contains non-UTF-8 bytes.
@ -112,7 +119,7 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
if !utf8.Valid(stdout) || !utf8.Valid(stderr) {
encoding = "base64"
writeJSON(w, http.StatusOK, execResponse{
SandboxID: sandboxID,
SandboxID: sandboxIDStr,
Cmd: req.Cmd,
Stdout: base64.StdEncoding.EncodeToString(stdout),
Stderr: base64.StdEncoding.EncodeToString(stderr),
@ -124,7 +131,7 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
}
writeJSON(w, http.StatusOK, execResponse{
SandboxID: sandboxID,
SandboxID: sandboxIDStr,
Cmd: req.Cmd,
Stdout: string(stdout),
Stderr: string(stderr),

View File

@ -14,6 +14,7 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
)
@ -48,10 +49,16 @@ type wsOutMsg struct {
// ExecStream handles WS /v1/sandboxes/{id}/exec/stream.
func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
@ -91,7 +98,7 @@ func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
defer cancel()
stream, err := agent.ExecStream(streamCtx, connect.NewRequest(&pb.ExecStreamRequest{
SandboxId: sandboxID,
SandboxId: sandboxIDStr,
Cmd: startMsg.Cmd,
Args: startMsg.Args,
}))
@ -157,7 +164,7 @@ func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
Valid: true,
},
}); err != nil {
slog.Warn("failed to update last active after stream exec", "sandbox_id", sandboxID, "error", err)
slog.Warn("failed to update last active after stream exec", "sandbox_id", sandboxIDStr, "error", err)
}
}

View File

@ -11,6 +11,7 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
)
@ -29,10 +30,16 @@ func newFilesHandler(db *db.Queries, pool *lifecycle.HostClientPool) *filesHandl
// - "path" text field: absolute destination path inside the sandbox
// - "file" file field: binary content to write
func (h *filesHandler) Upload(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
@ -82,7 +89,7 @@ func (h *filesHandler) Upload(w http.ResponseWriter, r *http.Request) {
}
if _, err := agent.WriteFile(ctx, connect.NewRequest(&pb.WriteFileRequest{
SandboxId: sandboxID,
SandboxId: sandboxIDStr,
Path: filePath,
Content: content,
})); err != nil {
@ -101,10 +108,16 @@ type readFileRequest struct {
// Download handles POST /v1/sandboxes/{id}/files/read.
// Accepts JSON body with path, returns raw file content with Content-Disposition.
func (h *filesHandler) Download(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
@ -133,7 +146,7 @@ func (h *filesHandler) Download(w http.ResponseWriter, r *http.Request) {
}
resp, err := agent.ReadFile(ctx, connect.NewRequest(&pb.ReadFileRequest{
SandboxId: sandboxID,
SandboxId: sandboxIDStr,
Path: req.Path,
}))
if err != nil {

View File

@ -12,6 +12,7 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
)
@ -29,10 +30,16 @@ func newFilesStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool) *file
// Expects multipart/form-data with "path" text field and "file" file field.
// Streams file content directly from the request body to the host agent without buffering.
func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
@ -101,7 +108,7 @@ func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request
if err := stream.Send(&pb.WriteFileStreamRequest{
Content: &pb.WriteFileStreamRequest_Meta{
Meta: &pb.WriteFileStreamMeta{
SandboxId: sandboxID,
SandboxId: sandboxIDStr,
Path: filePath,
},
},
@ -146,10 +153,16 @@ func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request
// StreamDownload handles POST /v1/sandboxes/{id}/files/stream/read.
// Accepts JSON body with path, streams file content back without buffering.
func (h *filesStreamHandler) StreamDownload(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
@ -178,7 +191,7 @@ func (h *filesStreamHandler) StreamDownload(w http.ResponseWriter, r *http.Reque
// Open server-streaming RPC to host agent.
stream, err := agent.ReadFileStream(ctx, connect.NewRequest(&pb.ReadFileStreamRequest{
SandboxId: sandboxID,
SandboxId: sandboxIDStr,
Path: req.Path,
}))
if err != nil {

View File

@ -8,9 +8,12 @@ import (
"github.com/go-chi/chi/v5"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/audit"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/service"
)
@ -93,34 +96,35 @@ type hostResponse struct {
func hostToResponse(h db.Host) hostResponse {
resp := hostResponse{
ID: h.ID,
ID: id.FormatHostID(h.ID),
Type: h.Type,
Status: h.Status,
CreatedBy: h.CreatedBy,
CreatedBy: id.FormatUserID(h.CreatedBy),
}
if h.TeamID.Valid {
resp.TeamID = &h.TeamID.String
s := id.FormatTeamID(h.TeamID)
resp.TeamID = &s
}
if h.Provider.Valid {
resp.Provider = &h.Provider.String
if h.Provider != "" {
resp.Provider = &h.Provider
}
if h.AvailabilityZone.Valid {
resp.AvailabilityZone = &h.AvailabilityZone.String
if h.AvailabilityZone != "" {
resp.AvailabilityZone = &h.AvailabilityZone
}
if h.Arch.Valid {
resp.Arch = &h.Arch.String
if h.Arch != "" {
resp.Arch = &h.Arch
}
if h.CpuCores.Valid {
resp.CPUCores = &h.CpuCores.Int32
if h.CpuCores != 0 {
resp.CPUCores = &h.CpuCores
}
if h.MemoryMb.Valid {
resp.MemoryMB = &h.MemoryMb.Int32
if h.MemoryMb != 0 {
resp.MemoryMB = &h.MemoryMb
}
if h.DiskGb.Valid {
resp.DiskGB = &h.DiskGb.Int32
if h.DiskGb != 0 {
resp.DiskGB = &h.DiskGb
}
if h.Address.Valid {
resp.Address = &h.Address.String
if h.Address != "" {
resp.Address = &h.Address
}
if h.LastHeartbeatAt.Valid {
s := h.LastHeartbeatAt.Time.Format(time.RFC3339)
@ -133,7 +137,7 @@ func hostToResponse(h db.Host) hostResponse {
}
// isAdmin fetches the user record and returns whether they are an admin.
func (h *hostHandler) isAdmin(r *http.Request, userID string) bool {
func (h *hostHandler) isAdmin(r *http.Request, userID pgtype.UUID) bool {
user, err := h.queries.GetUserByID(r.Context(), userID)
if err != nil {
return false
@ -151,14 +155,23 @@ func (h *hostHandler) Create(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
result, err := h.svc.Create(r.Context(), service.HostCreateParams{
Type: req.Type,
TeamID: req.TeamID,
Provider: req.Provider,
AvailabilityZone: req.AvailabilityZone,
RequestingUserID: ac.UserID,
IsRequestorAdmin: h.isAdmin(r, ac.UserID),
})
// Parse optional team ID from request body.
var params service.HostCreateParams
params.Type = req.Type
params.Provider = req.Provider
params.AvailabilityZone = req.AvailabilityZone
params.RequestingUserID = ac.UserID
params.IsRequestorAdmin = h.isAdmin(r, ac.UserID)
if req.TeamID != "" {
teamID, err := id.ParseTeamID(req.TeamID)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid team_id")
return
}
params.TeamID = teamID
}
result, err := h.svc.Create(r.Context(), params)
if err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
@ -166,8 +179,7 @@ func (h *hostHandler) Create(w http.ResponseWriter, r *http.Request) {
}
// Log audit for the owning team (BYOC hosts have a team; shared hosts use caller's team).
hostTeamID := result.Host.TeamID.String
h.audit.LogHostCreate(r.Context(), ac, result.Host.ID, hostTeamID)
h.audit.LogHostCreate(r.Context(), ac, result.Host.ID, result.Host.TeamID)
writeJSON(w, http.StatusCreated, createHostResponse{
Host: hostToResponse(result.Host),
@ -192,14 +204,22 @@ func (h *hostHandler) List(w http.ResponseWriter, r *http.Request) {
seen := make(map[string]struct{})
for _, host := range hosts {
if host.TeamID.Valid {
seen[host.TeamID.String] = struct{}{}
key := id.FormatTeamID(host.TeamID)
seen[key] = struct{}{}
}
}
if len(seen) > 0 {
teamNames = make(map[string]string, len(seen))
for id := range seen {
if team, err := h.queries.GetTeam(r.Context(), id); err == nil {
teamNames[id] = team.Name
for _, host := range hosts {
if !host.TeamID.Valid {
continue
}
key := id.FormatTeamID(host.TeamID)
if _, ok := teamNames[key]; ok {
continue
}
if team, err := h.queries.GetTeam(r.Context(), host.TeamID); err == nil {
teamNames[key] = team.Name
}
}
}
@ -209,7 +229,8 @@ func (h *hostHandler) List(w http.ResponseWriter, r *http.Request) {
for i, host := range hosts {
resp[i] = hostToResponse(host)
if host.TeamID.Valid {
if name, ok := teamNames[host.TeamID.String]; ok {
key := id.FormatTeamID(host.TeamID)
if name, ok := teamNames[key]; ok {
resp[i].TeamName = &name
}
}
@ -220,9 +241,15 @@ func (h *hostHandler) List(w http.ResponseWriter, r *http.Request) {
// Get handles GET /v1/hosts/{id}.
func (h *hostHandler) Get(w http.ResponseWriter, r *http.Request) {
hostID := chi.URLParam(r, "id")
hostIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
hostID, err := id.ParseHostID(hostIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
return
}
host, err := h.svc.Get(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID))
if err != nil {
status, code, msg := serviceErrToHTTP(err)
@ -236,9 +263,15 @@ func (h *hostHandler) Get(w http.ResponseWriter, r *http.Request) {
// DeletePreview handles GET /v1/hosts/{id}/delete-preview.
// Returns what would be affected without making changes, for confirmation UI.
func (h *hostHandler) DeletePreview(w http.ResponseWriter, r *http.Request) {
hostID := chi.URLParam(r, "id")
hostIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
hostID, err := id.ParseHostID(hostIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
return
}
preview, err := h.svc.DeletePreview(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID))
if err != nil {
status, code, msg := serviceErrToHTTP(err)
@ -256,19 +289,25 @@ func (h *hostHandler) DeletePreview(w http.ResponseWriter, r *http.Request) {
// Without ?force=true: returns 409 with affected sandbox IDs if any are active.
// With ?force=true: gracefully stops all sandboxes then deletes the host.
func (h *hostHandler) Delete(w http.ResponseWriter, r *http.Request) {
hostID := chi.URLParam(r, "id")
hostIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
force := r.URL.Query().Get("force") == "true"
hostID, err := id.ParseHostID(hostIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
return
}
// Fetch host before deletion to capture team_id for audit.
deletedHost, hostErr := h.queries.GetHost(r.Context(), hostID)
if hostErr != nil {
slog.Warn("audit: could not fetch host before delete", "host_id", hostID, "error", hostErr)
slog.Warn("audit: could not fetch host before delete", "host_id", hostIDStr, "error", hostErr)
}
err := h.svc.Delete(r.Context(), hostID, ac.UserID, ac.TeamID, h.isAdmin(r, ac.UserID), force)
err = h.svc.Delete(r.Context(), hostID, ac.UserID, ac.TeamID, h.isAdmin(r, ac.UserID), force)
if err == nil {
h.audit.LogHostDelete(r.Context(), ac, hostID, deletedHost.TeamID.String)
h.audit.LogHostDelete(r.Context(), ac, hostID, deletedHost.TeamID)
w.WriteHeader(http.StatusNoContent)
return
}
@ -292,9 +331,15 @@ func (h *hostHandler) Delete(w http.ResponseWriter, r *http.Request) {
// RegenerateToken handles POST /v1/hosts/{id}/token.
func (h *hostHandler) RegenerateToken(w http.ResponseWriter, r *http.Request) {
hostID := chi.URLParam(r, "id")
hostIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
hostID, err := id.ParseHostID(hostIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
return
}
result, err := h.svc.RegenerateToken(r.Context(), hostID, ac.UserID, ac.TeamID, h.isAdmin(r, ac.UserID))
if err != nil {
status, code, msg := serviceErrToHTTP(err)
@ -348,9 +393,15 @@ func (h *hostHandler) Register(w http.ResponseWriter, r *http.Request) {
// Heartbeat handles POST /v1/hosts/{id}/heartbeat (host-token-authenticated).
func (h *hostHandler) Heartbeat(w http.ResponseWriter, r *http.Request) {
hostID := chi.URLParam(r, "id")
hostIDStr := chi.URLParam(r, "id")
hc := auth.MustHostFromContext(r.Context())
hostID, err := id.ParseHostID(hostIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
return
}
// Prevent a host from heartbeating for a different host.
if hostID != hc.HostID {
writeError(w, http.StatusForbidden, "forbidden", "host ID mismatch")
@ -368,7 +419,7 @@ func (h *hostHandler) Heartbeat(w http.ResponseWriter, r *http.Request) {
// Log marked_up if the host just recovered from unreachable.
if prevHost.Status == "unreachable" {
h.audit.LogHostMarkedUp(r.Context(), prevHost.TeamID.String, hc.HostID)
h.audit.LogHostMarkedUp(r.Context(), prevHost.TeamID, hc.HostID)
}
w.WriteHeader(http.StatusNoContent)
@ -376,10 +427,16 @@ func (h *hostHandler) Heartbeat(w http.ResponseWriter, r *http.Request) {
// AddTag handles POST /v1/hosts/{id}/tags.
func (h *hostHandler) AddTag(w http.ResponseWriter, r *http.Request) {
hostID := chi.URLParam(r, "id")
hostIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
admin := h.isAdmin(r, ac.UserID)
hostID, err := id.ParseHostID(hostIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
return
}
var req addTagRequest
if err := decodeJSON(r, &req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
@ -401,10 +458,16 @@ func (h *hostHandler) AddTag(w http.ResponseWriter, r *http.Request) {
// RemoveTag handles DELETE /v1/hosts/{id}/tags/{tag}.
func (h *hostHandler) RemoveTag(w http.ResponseWriter, r *http.Request) {
hostID := chi.URLParam(r, "id")
hostIDStr := chi.URLParam(r, "id")
tag := chi.URLParam(r, "tag")
ac := auth.MustFromContext(r.Context())
hostID, err := id.ParseHostID(hostIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
return
}
if err := h.svc.RemoveTag(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID), tag); err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
@ -443,9 +506,15 @@ func (h *hostHandler) RefreshToken(w http.ResponseWriter, r *http.Request) {
// ListTags handles GET /v1/hosts/{id}/tags.
func (h *hostHandler) ListTags(w http.ResponseWriter, r *http.Request) {
hostID := chi.URLParam(r, "id")
hostIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
hostID, err := id.ParseHostID(hostIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid host ID")
return
}
tags, err := h.svc.ListTags(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID))
if err != nil {
status, code, msg := serviceErrToHTTP(err)

View File

@ -7,9 +7,11 @@ import (
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
)
@ -38,10 +40,16 @@ type metricsResponse struct {
// GetMetrics handles GET /v1/sandboxes/{id}/metrics?range=10m|2h|24h.
func (h *sandboxMetricsHandler) GetMetrics(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
rangeTier := r.URL.Query().Get("range")
if rangeTier == "" {
rangeTier = "10m"
@ -60,15 +68,15 @@ func (h *sandboxMetricsHandler) GetMetrics(w http.ResponseWriter, r *http.Reques
switch sb.Status {
case "running":
h.getFromAgent(w, r, sandboxID, rangeTier, sb.HostID)
h.getFromAgent(w, r, sandboxIDStr, rangeTier, sb.HostID)
case "paused":
h.getFromDB(ctx, w, sandboxID, rangeTier)
h.getFromDB(ctx, w, sandboxIDStr, sandboxID, rangeTier)
default:
writeError(w, http.StatusNotFound, "not_found", "metrics not available for sandbox in state: "+sb.Status)
}
}
func (h *sandboxMetricsHandler) getFromAgent(w http.ResponseWriter, r *http.Request, sandboxID, rangeTier, hostID string) {
func (h *sandboxMetricsHandler) getFromAgent(w http.ResponseWriter, r *http.Request, sandboxIDStr, rangeTier string, hostID pgtype.UUID) {
ctx := r.Context()
agent, err := agentForHost(ctx, h.db, h.pool, hostID)
@ -78,7 +86,7 @@ func (h *sandboxMetricsHandler) getFromAgent(w http.ResponseWriter, r *http.Requ
}
resp, err := agent.GetSandboxMetrics(ctx, connect.NewRequest(&pb.GetSandboxMetricsRequest{
SandboxId: sandboxID,
SandboxId: sandboxIDStr,
Range: rangeTier,
}))
if err != nil {
@ -98,7 +106,7 @@ func (h *sandboxMetricsHandler) getFromAgent(w http.ResponseWriter, r *http.Requ
}
writeJSON(w, http.StatusOK, metricsResponse{
SandboxID: sandboxID,
SandboxID: sandboxIDStr,
Range: rangeTier,
Points: points,
})
@ -118,7 +126,7 @@ var rangeToDB = map[string]struct {
"24h": {"24h", 24 * time.Hour},
}
func (h *sandboxMetricsHandler) getFromDB(ctx context.Context, w http.ResponseWriter, sandboxID, rangeTier string) {
func (h *sandboxMetricsHandler) getFromDB(ctx context.Context, w http.ResponseWriter, sandboxIDStr string, sandboxID pgtype.UUID, rangeTier string) {
mapping := rangeToDB[rangeTier]
rows, err := h.db.GetSandboxMetricPoints(ctx, db.GetSandboxMetricPointsParams{
SandboxID: sandboxID,
@ -141,7 +149,7 @@ func (h *sandboxMetricsHandler) getFromDB(ctx context.Context, w http.ResponseWr
}
writeJSON(w, http.StatusOK, metricsResponse{
SandboxID: sandboxID,
SandboxID: sandboxIDStr,
Range: rangeTier,
Points: points,
})

View File

@ -162,7 +162,7 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
redirectWithError(w, r, redirectBase, "internal_error")
return
}
redirectWithToken(w, r, redirectBase, token, user.ID, team.ID, user.Email, user.Name)
redirectWithToken(w, r, redirectBase, token, id.FormatUserID(user.ID), id.FormatTeamID(team.ID), user.Email, user.Name)
return
}
if !errors.Is(err, pgx.ErrNoRows) {
@ -262,7 +262,7 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
return
}
redirectWithToken(w, r, redirectBase, token, userID, teamID, email, profile.Name)
redirectWithToken(w, r, redirectBase, token, id.FormatUserID(userID), id.FormatTeamID(teamID), email, profile.Name)
}
// retryAsLogin handles the race where a concurrent request already created the user.
@ -296,7 +296,7 @@ func (h *oauthHandler) retryAsLogin(w http.ResponseWriter, r *http.Request, prov
redirectWithError(w, r, redirectBase, "internal_error")
return
}
redirectWithToken(w, r, redirectBase, token, user.ID, team.ID, user.Email, user.Name)
redirectWithToken(w, r, redirectBase, token, id.FormatUserID(user.ID), id.FormatTeamID(team.ID), user.Email, user.Name)
}
func redirectWithToken(w http.ResponseWriter, r *http.Request, base, token, userID, teamID, email, name string) {

View File

@ -10,6 +10,7 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/audit"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/service"
)
@ -46,7 +47,7 @@ type sandboxResponse struct {
func sandboxToResponse(sb db.Sandbox) sandboxResponse {
resp := sandboxResponse{
ID: sb.ID,
ID: id.FormatSandboxID(sb.ID),
Status: sb.Status,
Template: sb.Template,
VCPUs: sb.Vcpus,
@ -81,7 +82,7 @@ func (h *sandboxHandler) Create(w http.ResponseWriter, r *http.Request) {
}
ac := auth.MustFromContext(r.Context())
if ac.TeamID == "" {
if !ac.TeamID.Valid {
writeError(w, http.StatusForbidden, "no_team", "no active team context; re-authenticate")
return
}
@ -122,9 +123,15 @@ func (h *sandboxHandler) List(w http.ResponseWriter, r *http.Request) {
// Get handles GET /v1/sandboxes/{id}.
func (h *sandboxHandler) Get(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.svc.Get(r.Context(), sandboxID, ac.TeamID)
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
@ -136,9 +143,15 @@ func (h *sandboxHandler) Get(w http.ResponseWriter, r *http.Request) {
// Pause handles POST /v1/sandboxes/{id}/pause.
func (h *sandboxHandler) Pause(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.svc.Pause(r.Context(), sandboxID, ac.TeamID)
if err != nil {
status, code, msg := serviceErrToHTTP(err)
@ -152,9 +165,15 @@ func (h *sandboxHandler) Pause(w http.ResponseWriter, r *http.Request) {
// Resume handles POST /v1/sandboxes/{id}/resume.
func (h *sandboxHandler) Resume(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.svc.Resume(r.Context(), sandboxID, ac.TeamID)
if err != nil {
status, code, msg := serviceErrToHTTP(err)
@ -168,9 +187,15 @@ func (h *sandboxHandler) Resume(w http.ResponseWriter, r *http.Request) {
// Ping handles POST /v1/sandboxes/{id}/ping.
func (h *sandboxHandler) Ping(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
if err := h.svc.Ping(r.Context(), sandboxID, ac.TeamID); err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
@ -182,9 +207,15 @@ func (h *sandboxHandler) Ping(w http.ResponseWriter, r *http.Request) {
// Destroy handles DELETE /v1/sandboxes/{id}.
func (h *sandboxHandler) Destroy(w http.ResponseWriter, r *http.Request) {
sandboxID := chi.URLParam(r, "id")
sandboxIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
if err := h.svc.Destroy(r.Context(), sandboxID, ac.TeamID); err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)

View File

@ -10,7 +10,6 @@ import (
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/audit"
"git.omukk.dev/wrenn/sandbox/internal/auth"
@ -51,7 +50,7 @@ func (h *snapshotHandler) deleteSnapshotBroadcast(ctx context.Context, name stri
}
if _, err := agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{Name: name})); err != nil {
if connect.CodeOf(err) != connect.CodeNotFound {
slog.Warn("snapshot: failed to delete on host", "host_id", host.ID, "name", name, "error", err)
slog.Warn("snapshot: failed to delete on host", "host_id", id.FormatHostID(host.ID), "name", name, "error", err)
}
}
}
@ -78,11 +77,11 @@ func templateToResponse(t db.Template) snapshotResponse {
Type: t.Type,
SizeBytes: t.SizeBytes,
}
if t.Vcpus.Valid {
resp.VCPUs = &t.Vcpus.Int32
if t.Vcpus != 0 {
resp.VCPUs = &t.Vcpus
}
if t.MemoryMb.Valid {
resp.MemoryMB = &t.MemoryMb.Int32
if t.MemoryMb != 0 {
resp.MemoryMB = &t.MemoryMb
}
if t.CreatedAt.Valid {
resp.CreatedAt = t.CreatedAt.Time.Format(time.RFC3339)
@ -103,6 +102,12 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
return
}
sandboxID, err := id.ParseSandboxID(req.SandboxID)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox_id")
return
}
if req.Name == "" {
req.Name = id.NewSnapshotName()
}
@ -133,7 +138,7 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
}
// Verify sandbox exists, belongs to team, and is running or paused.
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: req.SandboxID, TeamID: ac.TeamID})
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
@ -162,7 +167,7 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
// Mark sandbox as paused (if it was running, it got paused by the snapshot).
if sb.Status != "paused" {
if _, err := h.db.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
ID: req.SandboxID, Status: "paused",
ID: sandboxID, Status: "paused",
}); err != nil {
slog.Error("failed to update sandbox status after snapshot", "sandbox_id", req.SandboxID, "error", err)
}
@ -171,8 +176,8 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
tmpl, err := h.db.InsertTemplate(ctx, db.InsertTemplateParams{
Name: req.Name,
Type: "snapshot",
Vcpus: pgtype.Int4{Int32: sb.Vcpus, Valid: true},
MemoryMb: pgtype.Int4{Int32: sb.MemoryMb, Valid: true},
Vcpus: sb.Vcpus,
MemoryMb: sb.MemoryMb,
SizeBytes: resp.Msg.SizeBytes,
TeamID: ac.TeamID,
})

View File

@ -7,10 +7,12 @@ import (
"time"
"github.com/go-chi/chi/v5"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/audit"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/service"
)
@ -48,7 +50,7 @@ type memberResponse struct {
func teamToResponse(t db.Team) teamResponse {
resp := teamResponse{
ID: t.ID,
ID: id.FormatTeamID(t.ID),
Name: t.Name,
Slug: t.Slug,
IsByoc: t.IsByoc,
@ -72,11 +74,16 @@ func memberInfoToResponse(m service.MemberInfo) memberResponse {
// requireTeamAccess is an inline check used by every team-scoped handler:
// the JWT team_id must match the URL {id} before any DB call is made.
// Returns false and writes 403 if they don't match.
func requireTeamAccess(w http.ResponseWriter, r *http.Request, ac auth.AuthContext) (string, bool) {
teamID := chi.URLParam(r, "id")
func requireTeamAccess(w http.ResponseWriter, r *http.Request, ac auth.AuthContext) (pgtype.UUID, bool) {
teamIDStr := chi.URLParam(r, "id")
teamID, err := id.ParseTeamID(teamIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid team ID")
return pgtype.UUID{}, false
}
if ac.TeamID != teamID {
writeError(w, http.StatusForbidden, "forbidden", "JWT team does not match requested team; use switch-team first")
return "", false
return pgtype.UUID{}, false
}
return teamID, true
}
@ -185,7 +192,7 @@ func (h *teamHandler) Rename(w http.ResponseWriter, r *http.Request) {
// Fetch old name for audit log before renaming.
oldTeam, err := h.svc.GetTeam(r.Context(), teamID)
if err != nil {
slog.Warn("audit: could not fetch old team name for rename log", "team_id", teamID, "error", err)
slog.Warn("audit: could not fetch old team name for rename log", "team_id", id.FormatTeamID(teamID), "error", err)
}
if err := h.svc.RenameTeam(r.Context(), teamID, ac.UserID, req.Name); err != nil {
@ -267,7 +274,11 @@ func (h *teamHandler) AddMember(w http.ResponseWriter, r *http.Request) {
return
}
h.audit.LogMemberAdd(r.Context(), ac, member.UserID, member.Email, member.Role)
// member.UserID is already formatted with prefix; parse it back for the audit logger.
targetUserID, parseErr := id.ParseUserID(member.UserID)
if parseErr == nil {
h.audit.LogMemberAdd(r.Context(), ac, targetUserID, member.Email, member.Role)
}
writeJSON(w, http.StatusCreated, memberInfoToResponse(member))
}
@ -279,7 +290,13 @@ func (h *teamHandler) RemoveMember(w http.ResponseWriter, r *http.Request) {
if !ok {
return
}
targetUserID := chi.URLParam(r, "uid")
targetUserIDStr := chi.URLParam(r, "uid")
targetUserID, err := id.ParseUserID(targetUserIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid user ID")
return
}
if err := h.svc.RemoveMember(r.Context(), teamID, ac.UserID, targetUserID); err != nil {
status, code, msg := serviceErrToHTTP(err)
@ -299,7 +316,13 @@ func (h *teamHandler) UpdateMemberRole(w http.ResponseWriter, r *http.Request) {
if !ok {
return
}
targetUserID := chi.URLParam(r, "uid")
targetUserIDStr := chi.URLParam(r, "uid")
targetUserID, err := id.ParseUserID(targetUserIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid user ID")
return
}
var req struct {
Role string `json:"role"`
@ -341,7 +364,13 @@ func (h *teamHandler) Leave(w http.ResponseWriter, r *http.Request) {
// SetBYOC handles PUT /v1/admin/teams/{id}/byoc (admin only).
// Enables or disables the BYOC feature flag for a team.
func (h *teamHandler) SetBYOC(w http.ResponseWriter, r *http.Request) {
teamID := chi.URLParam(r, "id")
teamIDStr := chi.URLParam(r, "id")
teamID, err := id.ParseTeamID(teamIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid team ID")
return
}
var req struct {
Enabled bool `json:"enabled"`

View File

@ -8,6 +8,7 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
)
type usersHandler struct {
@ -45,7 +46,7 @@ func (h *usersHandler) Search(w http.ResponseWriter, r *http.Request) {
}
resp := make([]userResult, len(results))
for i, u := range results {
resp[i] = userResult{UserID: u.ID, Email: u.Email}
resp[i] = userResult{UserID: id.FormatUserID(u.ID), Email: u.Email}
}
writeJSON(w, http.StatusOK, resp)
}

View File

@ -6,9 +6,11 @@ import (
"time"
"connectrpc.com/connect"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/audit"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
)
@ -82,15 +84,15 @@ func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) {
time.Since(host.LastHeartbeatAt.Time) > unreachableThreshold
if stale && host.Status != "unreachable" {
slog.Info("host monitor: marking host unreachable", "host_id", host.ID,
slog.Info("host monitor: marking host unreachable", "host_id", id.FormatHostID(host.ID),
"last_heartbeat", host.LastHeartbeatAt.Time)
if err := m.db.MarkHostUnreachable(ctx, host.ID); err != nil {
slog.Warn("host monitor: failed to mark host unreachable", "host_id", host.ID, "error", err)
slog.Warn("host monitor: failed to mark host unreachable", "host_id", id.FormatHostID(host.ID), "error", err)
}
if err := m.db.MarkSandboxesMissingByHost(ctx, host.ID); err != nil {
slog.Warn("host monitor: failed to mark sandboxes missing", "host_id", host.ID, "error", err)
slog.Warn("host monitor: failed to mark sandboxes missing", "host_id", id.FormatHostID(host.ID), "error", err)
}
m.audit.LogHostMarkedDown(ctx, host.TeamID.String, host.ID)
m.audit.LogHostMarkedDown(ctx, host.TeamID, host.ID)
return
}
@ -110,19 +112,20 @@ func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) {
if err != nil {
// RPC failure is a transient condition; the passive phase will catch it
// if heartbeats stop arriving.
slog.Debug("host monitor: ListSandboxes failed (transient)", "host_id", host.ID, "error", err)
slog.Debug("host monitor: ListSandboxes failed (transient)", "host_id", id.FormatHostID(host.ID), "error", err)
return
}
// Build set of sandbox IDs alive on the host.
// The host agent returns sandbox IDs as strings (formatted with prefix).
alive := make(map[string]struct{}, len(resp.Msg.Sandboxes))
for _, sb := range resp.Msg.Sandboxes {
alive[sb.SandboxId] = struct{}{}
}
autoPaused := make(map[string]struct{}, len(resp.Msg.AutoPausedSandboxIds))
for _, id := range resp.Msg.AutoPausedSandboxIds {
autoPaused[id] = struct{}{}
for _, apID := range resp.Msg.AutoPausedSandboxIds {
autoPaused[apID] = struct{}{}
}
// --- Restore sandboxes that are "missing" in DB but alive on host ---
@ -134,30 +137,31 @@ func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) {
Column2: []string{"missing"},
})
if err != nil {
slog.Warn("host monitor: failed to list missing sandboxes", "host_id", host.ID, "error", err)
slog.Warn("host monitor: failed to list missing sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
} else {
var toRestore []string
var toStop []string
var toRestore []pgtype.UUID
var toStop []pgtype.UUID
for _, sb := range missingSandboxes {
if _, ok := alive[sb.ID]; ok {
sbIDStr := id.FormatSandboxID(sb.ID)
if _, ok := alive[sbIDStr]; ok {
toRestore = append(toRestore, sb.ID)
} else {
toStop = append(toStop, sb.ID)
}
}
if len(toRestore) > 0 {
slog.Info("host monitor: restoring missing sandboxes", "host_id", host.ID, "count", len(toRestore))
slog.Info("host monitor: restoring missing sandboxes", "host_id", id.FormatHostID(host.ID), "count", len(toRestore))
if err := m.db.BulkRestoreRunning(ctx, toRestore); err != nil {
slog.Warn("host monitor: failed to restore missing sandboxes", "host_id", host.ID, "error", err)
slog.Warn("host monitor: failed to restore missing sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
}
}
if len(toStop) > 0 {
slog.Info("host monitor: stopping confirmed-dead missing sandboxes", "host_id", host.ID, "count", len(toStop))
slog.Info("host monitor: stopping confirmed-dead missing sandboxes", "host_id", id.FormatHostID(host.ID), "count", len(toStop))
if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
Column1: toStop,
Status: "stopped",
}); err != nil {
slog.Warn("host monitor: failed to stop missing sandboxes", "host_id", host.ID, "error", err)
slog.Warn("host monitor: failed to stop missing sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
}
}
}
@ -169,18 +173,19 @@ func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) {
Column2: []string{"running"},
})
if err != nil {
slog.Warn("host monitor: failed to list running sandboxes", "host_id", host.ID, "error", err)
slog.Warn("host monitor: failed to list running sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
return
}
var toPause, toStop []string
sbTeamID := make(map[string]string, len(runningSandboxes))
var toPause, toStop []pgtype.UUID
sbTeamID := make(map[pgtype.UUID]pgtype.UUID, len(runningSandboxes))
for _, sb := range runningSandboxes {
sbIDStr := id.FormatSandboxID(sb.ID)
sbTeamID[sb.ID] = sb.TeamID
if _, ok := alive[sb.ID]; ok {
if _, ok := alive[sbIDStr]; ok {
continue
}
if _, ok := autoPaused[sb.ID]; ok {
if _, ok := autoPaused[sbIDStr]; ok {
toPause = append(toPause, sb.ID)
} else {
toStop = append(toStop, sb.ID)
@ -188,24 +193,24 @@ func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) {
}
if len(toPause) > 0 {
slog.Info("host monitor: marking auto-paused sandboxes", "host_id", host.ID, "count", len(toPause))
slog.Info("host monitor: marking auto-paused sandboxes", "host_id", id.FormatHostID(host.ID), "count", len(toPause))
if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
Column1: toPause,
Status: "paused",
}); err != nil {
slog.Warn("host monitor: failed to mark paused", "host_id", host.ID, "error", err)
slog.Warn("host monitor: failed to mark paused", "host_id", id.FormatHostID(host.ID), "error", err)
}
for _, sbID := range toPause {
m.audit.LogSandboxAutoPause(ctx, sbTeamID[sbID], sbID)
}
}
if len(toStop) > 0 {
slog.Info("host monitor: marking orphaned sandboxes stopped", "host_id", host.ID, "count", len(toStop))
slog.Info("host monitor: marking orphaned sandboxes stopped", "host_id", id.FormatHostID(host.ID), "count", len(toStop))
if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
Column1: toStop,
Status: "stopped",
}); err != nil {
slog.Warn("host monitor: failed to mark stopped", "host_id", host.ID, "error", err)
slog.Warn("host monitor: failed to mark stopped", "host_id", id.FormatHostID(host.ID), "error", err)
}
}
}

View File

@ -7,6 +7,7 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
)
// requireAPIKeyOrJWT accepts either X-API-Key header or Authorization: Bearer JWT.
@ -24,7 +25,7 @@ func requireAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler
}
if err := queries.UpdateAPIKeyLastUsed(r.Context(), row.ID); err != nil {
slog.Warn("failed to update api key last_used", "key_id", row.ID, "error", err)
slog.Warn("failed to update api key last_used", "key_id", id.FormatAPIKeyID(row.ID), "error", err)
}
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
@ -45,9 +46,20 @@ func requireAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler
return
}
teamID, err := id.ParseTeamID(claims.TeamID)
if err != nil {
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid team ID in token")
return
}
userID, err := id.ParseUserID(claims.Subject)
if err != nil {
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid user ID in token")
return
}
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
TeamID: claims.TeamID,
UserID: claims.Subject,
TeamID: teamID,
UserID: userID,
Email: claims.Email,
Name: claims.Name,
Role: claims.Role,

View File

@ -4,6 +4,7 @@ import (
"net/http"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/id"
)
// requireHostToken validates the X-Host-Token header containing a host JWT,
@ -23,7 +24,13 @@ func requireHostToken(secret []byte) func(http.Handler) http.Handler {
return
}
ctx := auth.WithHostContext(r.Context(), auth.HostContext{HostID: claims.HostID})
hostID, err := id.ParseHostID(claims.HostID)
if err != nil {
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid host ID in token")
return
}
ctx := auth.WithHostContext(r.Context(), auth.HostContext{HostID: hostID})
next.ServeHTTP(w, r.WithContext(ctx))
})
}

View File

@ -5,6 +5,7 @@ import (
"strings"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/id"
)
// requireJWT validates the Authorization: Bearer <token> header, verifies the JWT
@ -25,9 +26,20 @@ func requireJWT(secret []byte) func(http.Handler) http.Handler {
return
}
teamID, err := id.ParseTeamID(claims.TeamID)
if err != nil {
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid team ID in token")
return
}
userID, err := id.ParseUserID(claims.Subject)
if err != nil {
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid user ID in token")
return
}
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
TeamID: claims.TeamID,
UserID: claims.Subject,
TeamID: teamID,
UserID: userID,
Email: claims.Email,
Name: claims.Name,
Role: claims.Role,