1
0
forked from wrenn/wrenn
Co-authored-by: Tasnim Kabir Sadik <tksadik92@gmail.com>
Reviewed-on: wrenn/sandbox#8
This commit is contained in:
2026-04-09 19:24:49 +00:00
parent 32e5a5a715
commit d3e4812e46
199 changed files with 24552 additions and 2776 deletions

View File

@ -4,6 +4,8 @@ import (
"context"
"fmt"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
@ -22,7 +24,7 @@ type APIKeyCreateResult struct {
}
// Create generates a new API key for the given team.
func (s *APIKeyService) Create(ctx context.Context, teamID, userID, name string) (APIKeyCreateResult, error) {
func (s *APIKeyService) Create(ctx context.Context, teamID, userID pgtype.UUID, name string) (APIKeyCreateResult, error) {
if name == "" {
name = "Unnamed API Key"
}
@ -48,16 +50,16 @@ func (s *APIKeyService) Create(ctx context.Context, teamID, userID, name string)
}
// List returns all API keys belonging to the given team.
func (s *APIKeyService) List(ctx context.Context, teamID string) ([]db.TeamApiKey, error) {
func (s *APIKeyService) List(ctx context.Context, teamID pgtype.UUID) ([]db.TeamApiKey, error) {
return s.DB.ListAPIKeysByTeam(ctx, teamID)
}
// ListWithCreator returns all API keys for the team, joined with the creator's email.
func (s *APIKeyService) ListWithCreator(ctx context.Context, teamID string) ([]db.ListAPIKeysByTeamWithCreatorRow, error) {
func (s *APIKeyService) ListWithCreator(ctx context.Context, teamID pgtype.UUID) ([]db.ListAPIKeysByTeamWithCreatorRow, error) {
return s.DB.ListAPIKeysByTeamWithCreator(ctx, teamID)
}
// Delete removes an API key by ID, scoped to the given team.
func (s *APIKeyService) Delete(ctx context.Context, keyID, teamID string) error {
func (s *APIKeyService) Delete(ctx context.Context, keyID, teamID pgtype.UUID) error {
return s.DB.DeleteAPIKey(ctx, db.DeleteAPIKeyParams{ID: keyID, TeamID: teamID})
}

113
internal/service/audit.go Normal file
View File

@ -0,0 +1,113 @@
package service
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
)
const auditMaxLimit = 200
// AuditEntry is a single audit log record returned by List.
type AuditEntry struct {
ID string
TeamID string
ActorType string
ActorID string // empty for system
ActorName string // empty for system
ResourceType string
ResourceID string // empty when not applicable
Action string
Scope string
Status string // 'success', 'info', 'warning', 'error'
Metadata map[string]any
CreatedAt time.Time
}
// AuditListParams controls the ListAuditLogs query.
type AuditListParams struct {
TeamID pgtype.UUID
AdminScoped bool // true → include admin-scoped events; false → team-scoped only
ResourceTypes []string // empty = no filter; multiple values = OR match
Actions []string // empty = no filter; multiple values = OR match
Before time.Time // zero = no cursor (start from latest)
BeforeID pgtype.UUID // tie-breaker: id of the last item at the Before timestamp; zero = no tie-break
Limit int // clamped to auditMaxLimit by the handler
}
// AuditService provides the read side of the audit log.
type AuditService struct {
DB *db.Queries
}
// List returns a page of audit log entries for the given team.
func (s *AuditService) List(ctx context.Context, p AuditListParams) ([]AuditEntry, error) {
limit := p.Limit
if limit <= 0 {
limit = 50
}
if limit > auditMaxLimit {
limit = auditMaxLimit
}
scopes := []string{"team"}
if p.AdminScoped {
scopes = append(scopes, "admin")
}
var before pgtype.Timestamptz
if !p.Before.IsZero() {
before = pgtype.Timestamptz{Time: p.Before, Valid: true}
}
resourceTypes := p.ResourceTypes
if resourceTypes == nil {
resourceTypes = []string{}
}
actions := p.Actions
if actions == nil {
actions = []string{}
}
rows, err := s.DB.ListAuditLogs(ctx, db.ListAuditLogsParams{
TeamID: p.TeamID,
Column2: scopes,
Column3: resourceTypes,
Column4: actions,
Column5: before,
ID: p.BeforeID,
Limit: int32(limit),
})
if err != nil {
return nil, fmt.Errorf("list audit logs: %w", err)
}
entries := make([]AuditEntry, len(rows))
for i, row := range rows {
var meta map[string]any
if len(row.Metadata) > 0 {
_ = json.Unmarshal(row.Metadata, &meta)
}
entries[i] = AuditEntry{
ID: id.FormatAuditLogID(row.ID),
TeamID: id.FormatTeamID(row.TeamID),
ActorType: row.ActorType,
ActorID: row.ActorID.String,
ActorName: row.ActorName,
ResourceType: row.ResourceType,
ResourceID: row.ResourceID.String,
Action: row.Action,
Scope: row.Scope,
Status: row.Status,
Metadata: meta,
CreatedAt: row.CreatedAt.Time,
}
}
return entries, nil
}

605
internal/service/build.go Normal file
View File

@ -0,0 +1,605 @@
package service
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"strings"
"sync"
"time"
"connectrpc.com/connect"
"github.com/jackc/pgx/v5/pgtype"
"github.com/redis/go-redis/v9"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
"git.omukk.dev/wrenn/sandbox/internal/recipe"
"git.omukk.dev/wrenn/sandbox/internal/scheduler"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
)
const (
buildQueueKey = "wrenn:build_queue"
buildCommandTimeout = 30 * time.Second
)
// preBuildCmds run before the user recipe to prepare the build environment.
var preBuildCmds = []string{
"RUN apt update",
}
// postBuildCmds run after the user recipe to clean up caches and reduce image size.
var postBuildCmds = []string{
"RUN apt clean",
"RUN apt autoremove -y",
"RUN rm -rf /var/lib/apt/lists/*",
}
// buildAgentClient is the subset of the host agent client used by the build worker.
type buildAgentClient interface {
CreateSandbox(ctx context.Context, req *connect.Request[pb.CreateSandboxRequest]) (*connect.Response[pb.CreateSandboxResponse], error)
DestroySandbox(ctx context.Context, req *connect.Request[pb.DestroySandboxRequest]) (*connect.Response[pb.DestroySandboxResponse], error)
Exec(ctx context.Context, req *connect.Request[pb.ExecRequest]) (*connect.Response[pb.ExecResponse], error)
CreateSnapshot(ctx context.Context, req *connect.Request[pb.CreateSnapshotRequest]) (*connect.Response[pb.CreateSnapshotResponse], error)
FlattenRootfs(ctx context.Context, req *connect.Request[pb.FlattenRootfsRequest]) (*connect.Response[pb.FlattenRootfsResponse], error)
}
// BuildService handles template build orchestration.
type BuildService struct {
DB *db.Queries
Redis *redis.Client
Pool *lifecycle.HostClientPool
Scheduler scheduler.HostScheduler
mu sync.Mutex
cancelMap map[string]context.CancelFunc // buildID → per-build cancel func
}
// BuildCreateParams holds the parameters for creating a template build.
type BuildCreateParams struct {
Name string
BaseTemplate string
Recipe []string
Healthcheck string
VCPUs int32
MemoryMB int32
SkipPrePost bool
}
// Create inserts a new build record and enqueues it to Redis.
func (s *BuildService) Create(ctx context.Context, p BuildCreateParams) (db.TemplateBuild, error) {
if p.BaseTemplate == "" {
p.BaseTemplate = "minimal"
}
if p.VCPUs <= 0 {
p.VCPUs = 1
}
if p.MemoryMB <= 0 {
p.MemoryMB = 512
}
recipeJSON, err := json.Marshal(p.Recipe)
if err != nil {
return db.TemplateBuild{}, fmt.Errorf("marshal recipe: %w", err)
}
buildID := id.NewBuildID()
buildIDStr := id.FormatBuildID(buildID)
newTemplateID := id.NewTemplateID()
defaultSteps := len(preBuildCmds) + len(postBuildCmds)
if p.SkipPrePost {
defaultSteps = 0
}
build, err := s.DB.InsertTemplateBuild(ctx, db.InsertTemplateBuildParams{
ID: buildID,
Name: p.Name,
BaseTemplate: p.BaseTemplate,
Recipe: recipeJSON,
Healthcheck: p.Healthcheck,
Vcpus: p.VCPUs,
MemoryMb: p.MemoryMB,
TotalSteps: int32(len(p.Recipe) + defaultSteps),
TemplateID: newTemplateID,
TeamID: id.PlatformTeamID,
SkipPrePost: p.SkipPrePost,
})
if err != nil {
return db.TemplateBuild{}, fmt.Errorf("insert build: %w", err)
}
// Enqueue build ID (as formatted string) to Redis for workers to pick up.
if err := s.Redis.RPush(ctx, buildQueueKey, buildIDStr).Err(); err != nil {
return db.TemplateBuild{}, fmt.Errorf("enqueue build: %w", err)
}
return build, nil
}
// Get returns a single build by ID.
func (s *BuildService) Get(ctx context.Context, buildID pgtype.UUID) (db.TemplateBuild, error) {
return s.DB.GetTemplateBuild(ctx, buildID)
}
// List returns all builds ordered by creation time.
func (s *BuildService) List(ctx context.Context) ([]db.TemplateBuild, error) {
return s.DB.ListTemplateBuilds(ctx)
}
// Cancel cancels a pending or running build. For pending builds the status is
// updated in the DB and the worker skips it when dequeued. For running builds
// the per-build context is cancelled, which causes the current exec step to
// abort; executeBuild then detects the cancellation and records the status.
func (s *BuildService) Cancel(ctx context.Context, buildID pgtype.UUID) error {
build, err := s.DB.GetTemplateBuild(ctx, buildID)
if err != nil {
return fmt.Errorf("get build: %w", err)
}
switch build.Status {
case "success", "failed", "cancelled":
return fmt.Errorf("build is already %s", build.Status)
}
// Mark cancelled in DB first. This handles both pending builds (which haven't
// been picked up yet) and acts as a flag for executeBuild to check on start.
if _, err := s.DB.UpdateBuildStatus(ctx, db.UpdateBuildStatusParams{
ID: buildID, Status: "cancelled",
}); err != nil {
return fmt.Errorf("update build status: %w", err)
}
// If the build is currently running, signal its context.
buildIDStr := id.FormatBuildID(buildID)
s.mu.Lock()
cancel, running := s.cancelMap[buildIDStr]
s.mu.Unlock()
if running {
cancel()
}
return nil
}
// StartWorkers launches n goroutines that consume from the Redis build queue.
// The returned cancel function stops all workers.
func (s *BuildService) StartWorkers(ctx context.Context, n int) context.CancelFunc {
ctx, cancel := context.WithCancel(ctx)
for i := range n {
go s.worker(ctx, i)
}
slog.Info("build workers started", "count", n)
return cancel
}
func (s *BuildService) worker(ctx context.Context, workerID int) {
log := slog.With("worker", workerID)
for {
// BLPOP blocks until a build ID is available or context is cancelled.
result, err := s.Redis.BLPop(ctx, 0, buildQueueKey).Result()
if err != nil {
if ctx.Err() != nil {
log.Info("build worker shutting down")
return
}
log.Error("redis BLPOP error", "error", err)
time.Sleep(time.Second)
continue
}
// result[0] is the key, result[1] is the build ID (formatted string).
buildIDStr := result[1]
log.Info("picked up build", "build_id", buildIDStr)
s.executeBuild(ctx, buildIDStr)
}
}
func (s *BuildService) executeBuild(ctx context.Context, buildIDStr string) {
log := slog.With("build_id", buildIDStr)
buildID, err := id.ParseBuildID(buildIDStr)
if err != nil {
log.Error("invalid build ID from queue", "error", err)
return
}
// Create a per-build context so this build can be cancelled independently of
// the worker. Register in cancelMap before fetching the build so that a
// concurrent Cancel call can always find and signal it.
buildCtx, buildCancel := context.WithCancel(ctx)
defer buildCancel()
s.mu.Lock()
if s.cancelMap == nil {
s.cancelMap = make(map[string]context.CancelFunc)
}
s.cancelMap[buildIDStr] = buildCancel
s.mu.Unlock()
defer func() {
s.mu.Lock()
delete(s.cancelMap, buildIDStr)
s.mu.Unlock()
}()
build, err := s.DB.GetTemplateBuild(buildCtx, buildID)
if err != nil {
log.Error("failed to fetch build", "error", err)
return
}
// Skip if already cancelled (Cancel was called before we dequeued).
if build.Status == "cancelled" {
log.Info("build already cancelled, skipping")
return
}
// Mark as running.
if _, err := s.DB.UpdateBuildStatus(buildCtx, db.UpdateBuildStatusParams{
ID: buildID, Status: "running",
}); err != nil {
log.Error("failed to update build status", "error", err)
return
}
// Parse user recipe.
var userRecipe []string
if err := json.Unmarshal(build.Recipe, &userRecipe); err != nil {
s.failBuild(buildCtx, buildID, fmt.Sprintf("invalid recipe JSON: %v", err))
return
}
// Pick a platform host and create a sandbox.
host, err := s.Scheduler.SelectHost(buildCtx, id.PlatformTeamID, false)
if err != nil {
s.failBuild(buildCtx, buildID, fmt.Sprintf("no host available: %v", err))
return
}
agent, err := s.Pool.GetForHost(host)
if err != nil {
s.failBuild(buildCtx, buildID, fmt.Sprintf("agent client error: %v", err))
return
}
sandboxID := id.NewSandboxID()
sandboxIDStr := id.FormatSandboxID(sandboxID)
log = log.With("sandbox_id", sandboxIDStr, "host_id", id.FormatHostID(host.ID))
// Resolve the base template to UUIDs. "minimal" is the zero sentinel.
baseTeamID := id.PlatformTeamID
baseTemplateID := id.MinimalTemplateID
if build.BaseTemplate != "minimal" {
baseTmpl, err := s.DB.GetPlatformTemplateByName(buildCtx, build.BaseTemplate)
if err != nil {
s.failBuild(buildCtx, buildID, fmt.Sprintf("base template %q not found: %v", build.BaseTemplate, err))
return
}
baseTeamID = baseTmpl.TeamID
baseTemplateID = baseTmpl.ID
}
resp, err := agent.CreateSandbox(buildCtx, connect.NewRequest(&pb.CreateSandboxRequest{
SandboxId: sandboxIDStr,
Template: build.BaseTemplate,
TeamId: id.UUIDString(baseTeamID),
TemplateId: id.UUIDString(baseTemplateID),
Vcpus: build.Vcpus,
MemoryMb: build.MemoryMb,
TimeoutSec: 0, // no auto-pause for builds
DiskSizeMb: 5120, // 5 GB for template builds
}))
if err != nil {
s.failBuild(buildCtx, buildID, fmt.Sprintf("create sandbox failed: %v", err))
return
}
_ = resp
// Record sandbox/host association.
_ = s.DB.UpdateBuildSandbox(buildCtx, db.UpdateBuildSandboxParams{
ID: buildID,
SandboxID: sandboxID,
HostID: host.ID,
})
// Parse recipe steps. preBuildCmds and postBuildCmds are hardcoded and always
// valid; panic on error is appropriate here since it would be a programmer mistake.
preBuildSteps, err := recipe.ParseRecipe(preBuildCmds)
if err != nil {
panic(fmt.Sprintf("invalid pre-build recipe: %v", err))
}
userRecipeSteps, err := recipe.ParseRecipe(userRecipe)
if err != nil {
s.destroySandbox(buildCtx, agent, sandboxIDStr)
s.failBuild(buildCtx, buildID, fmt.Sprintf("recipe parse error: %v", err))
return
}
postBuildSteps, err := recipe.ParseRecipe(postBuildCmds)
if err != nil {
panic(fmt.Sprintf("invalid post-build recipe: %v", err))
}
var logs []recipe.BuildLogEntry
step := 0
envVars, err := s.fetchSandboxEnv(buildCtx, agent, sandboxIDStr)
if err != nil {
log.Warn("failed to fetch sandbox env, using defaults", "error", err)
envVars = map[string]string{
"PATH": "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin",
"HOME": "/root",
}
}
bctx := &recipe.ExecContext{EnvVars: envVars}
runPhase := func(phase string, steps []recipe.Step, defaultTimeout time.Duration) bool {
newEntries, nextStep, ok := recipe.Execute(buildCtx, phase, steps, sandboxIDStr, step, defaultTimeout, bctx, agent.Exec)
logs = append(logs, newEntries...)
step = nextStep
s.updateLogs(buildCtx, buildID, step, logs)
if !ok {
s.destroySandbox(buildCtx, agent, sandboxIDStr)
// If the build was cancelled, status is already set — don't overwrite with "failed".
if buildCtx.Err() != nil {
return false
}
last := newEntries[len(newEntries)-1]
reason := last.Stderr
if reason == "" {
reason = fmt.Sprintf("exit code %d", last.Exit)
}
s.failBuild(buildCtx, buildID, fmt.Sprintf("%s step %d failed: %s", phase, step, reason))
}
return ok
}
if !build.SkipPrePost {
if !runPhase("pre-build", preBuildSteps, 0) {
return
}
}
if !runPhase("recipe", userRecipeSteps, buildCommandTimeout) {
return
}
if !build.SkipPrePost {
if !runPhase("post-build", postBuildSteps, 0) {
return
}
}
// Healthcheck or direct snapshot.
var sizeBytes int64
if build.Healthcheck != "" {
hc, err := recipe.ParseHealthcheck(build.Healthcheck)
if err != nil {
s.destroySandbox(buildCtx, agent, sandboxIDStr)
s.failBuild(buildCtx, buildID, fmt.Sprintf("invalid healthcheck: %v", err))
return
}
log.Info("running healthcheck", "cmd", hc.Cmd, "interval", hc.Interval, "timeout", hc.Timeout, "start_period", hc.StartPeriod, "retries", hc.Retries)
if err := s.waitForHealthcheck(buildCtx, agent, sandboxIDStr, hc); err != nil {
s.destroySandbox(buildCtx, agent, sandboxIDStr)
if buildCtx.Err() != nil {
return
}
s.failBuild(buildCtx, buildID, fmt.Sprintf("healthcheck failed: %v", err))
return
}
// Healthcheck passed → full snapshot (with memory/CPU state).
log.Info("healthcheck passed, creating snapshot")
snapResp, err := agent.CreateSnapshot(buildCtx, connect.NewRequest(&pb.CreateSnapshotRequest{
SandboxId: sandboxIDStr,
Name: build.Name,
TeamId: id.UUIDString(build.TeamID),
TemplateId: id.UUIDString(build.TemplateID),
}))
if err != nil {
s.destroySandbox(buildCtx, agent, sandboxIDStr)
if buildCtx.Err() != nil {
return
}
s.failBuild(buildCtx, buildID, fmt.Sprintf("create snapshot failed: %v", err))
return
}
sizeBytes = snapResp.Msg.SizeBytes
} else {
// No healthcheck → image-only template (rootfs only).
log.Info("no healthcheck, flattening rootfs")
flatResp, err := agent.FlattenRootfs(buildCtx, connect.NewRequest(&pb.FlattenRootfsRequest{
SandboxId: sandboxIDStr,
Name: build.Name,
TeamId: id.UUIDString(build.TeamID),
TemplateId: id.UUIDString(build.TemplateID),
}))
if err != nil {
s.destroySandbox(buildCtx, agent, sandboxIDStr)
if buildCtx.Err() != nil {
return
}
s.failBuild(buildCtx, buildID, fmt.Sprintf("flatten rootfs failed: %v", err))
return
}
sizeBytes = flatResp.Msg.SizeBytes
}
// Insert into templates table as a global (platform) template.
templateType := "base"
if build.Healthcheck != "" {
templateType = "snapshot"
}
if _, err := s.DB.InsertTemplate(buildCtx, db.InsertTemplateParams{
ID: build.TemplateID,
Name: build.Name,
Type: templateType,
Vcpus: build.Vcpus,
MemoryMb: build.MemoryMb,
SizeBytes: sizeBytes,
TeamID: id.PlatformTeamID,
}); err != nil {
log.Error("failed to insert template record", "error", err)
// Build succeeded on disk, just DB record failed — don't mark as failed.
}
// For CreateSnapshot, the sandbox is already destroyed by the snapshot process.
// For FlattenRootfs, the sandbox is already destroyed by the flatten process.
// No additional destroy needed.
// Mark build as success.
if _, err := s.DB.UpdateBuildStatus(buildCtx, db.UpdateBuildStatusParams{
ID: buildID, Status: "success",
}); err != nil {
log.Error("failed to mark build as success", "error", err)
}
log.Info("template build completed successfully", "name", build.Name)
}
// waitForHealthcheck repeatedly executes the healthcheck command inside the
// sandbox according to the config's interval, timeout, start-period, and
// retries.
// During the start period, failures are not counted toward the retry budget.
// Returns nil on the first successful check, or an error if retries are
// exhausted, the deadline passes, or the context is cancelled.
func (s *BuildService) waitForHealthcheck(ctx context.Context, agent buildAgentClient, sandboxIDStr string, hc recipe.HealthcheckConfig) error {
ticker := time.NewTicker(hc.Interval)
defer ticker.Stop()
// When retries > 0, set a deadline based on the retry budget.
// When retries == 0 (unlimited), rely solely on the parent context deadline.
var deadlineCh <-chan time.Time
if hc.Retries > 0 {
deadline := time.NewTimer(hc.StartPeriod + time.Duration(hc.Retries+1)*hc.Interval)
defer deadline.Stop()
deadlineCh = deadline.C
}
startedAt := time.Now()
failCount := 0
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-deadlineCh:
return fmt.Errorf("healthcheck timed out: exceeded %d attempts over %s", failCount, time.Since(startedAt))
case <-ticker.C:
execCtx, cancel := context.WithTimeout(ctx, hc.Timeout)
resp, err := agent.Exec(execCtx, connect.NewRequest(&pb.ExecRequest{
SandboxId: sandboxIDStr,
Cmd: "/bin/sh",
Args: []string{"-c", hc.Cmd},
TimeoutSec: int32(hc.Timeout.Seconds()),
}))
cancel()
if err != nil {
slog.Debug("healthcheck exec error (retrying)", "error", err)
if time.Since(startedAt) >= hc.StartPeriod {
failCount++
if hc.Retries > 0 && failCount >= hc.Retries {
return fmt.Errorf("healthcheck failed after %d retries: exec error: %w", failCount, err)
}
}
continue
}
if resp.Msg.ExitCode == 0 {
return nil
}
slog.Debug("healthcheck failed (retrying)", "exit_code", resp.Msg.ExitCode)
if time.Since(startedAt) >= hc.StartPeriod {
failCount++
if hc.Retries > 0 && failCount >= hc.Retries {
return fmt.Errorf("healthcheck failed after %d retries: exit code %d", failCount, resp.Msg.ExitCode)
}
}
}
}
}
func (s *BuildService) updateLogs(ctx context.Context, buildID pgtype.UUID, step int, logs []recipe.BuildLogEntry) {
logsJSON, err := json.Marshal(logs)
if err != nil {
slog.Warn("failed to marshal build logs", "error", err)
return
}
if err := s.DB.UpdateBuildProgress(ctx, db.UpdateBuildProgressParams{
ID: buildID,
CurrentStep: int32(step),
Logs: logsJSON,
}); err != nil {
slog.Warn("failed to update build progress", "error", err)
}
}
func (s *BuildService) failBuild(_ context.Context, buildID pgtype.UUID, errMsg string) {
slog.Error("build failed", "build_id", id.FormatBuildID(buildID), "error", errMsg)
// Use a detached context so DB writes survive parent context cancellation (e.g. shutdown).
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := s.DB.UpdateBuildError(ctx, db.UpdateBuildErrorParams{
ID: buildID,
Error: errMsg,
}); err != nil {
slog.Error("failed to update build error", "build_id", id.FormatBuildID(buildID), "error", err)
}
}
func (s *BuildService) destroySandbox(_ context.Context, agent buildAgentClient, sandboxIDStr string) {
// Use a detached context so cleanup succeeds even during shutdown.
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if _, err := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
SandboxId: sandboxIDStr,
})); err != nil {
slog.Warn("failed to destroy build sandbox", "sandbox_id", sandboxIDStr, "error", err)
}
}
// fetchSandboxEnv executes the 'env' command inside the specified sandbox via
// the build agent and returns environment variables
func (s *BuildService) fetchSandboxEnv(ctx context.Context,
agent buildAgentClient, sandboxIDStr string) (map[string]string, error) {
resp, err := agent.Exec(ctx, connect.NewRequest(&pb.ExecRequest{
SandboxId: sandboxIDStr,
Cmd: "/bin/sh",
Args: []string{"-c", "env"},
TimeoutSec: 10,
}))
if err != nil {
return nil, fmt.Errorf("fetch env: %w", err)
}
if resp.Msg.ExitCode != 0 {
return nil, fmt.Errorf("fetch env: command exited with code %d",
resp.Msg.ExitCode)
}
return parseSandboxEnv(string(resp.Msg.Stdout)), nil
}
// parseSandboxEnv converts the raw newline-separated output of an 'env'
// command into a map.
// It skips empty lines and malformed entries, and correctly handles values
// containing '='.
func parseSandboxEnv(raw string) map[string]string {
envVars := make(map[string]string)
for line := range strings.SplitSeq(raw, "\n") {
line = strings.TrimSpace(line)
if line == "" {
continue
}
parts := strings.SplitN(line, "=", 2)
if len(parts) != 2 {
continue
}
envVars[parts[0]] = parts[1]
}
return envVars
}

View File

@ -2,12 +2,14 @@ package service
import (
"context"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"log/slog"
"time"
"connectrpc.com/connect"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/redis/go-redis/v9"
@ -15,6 +17,8 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
)
// HostService provides host management operations.
@ -22,15 +26,17 @@ type HostService struct {
DB *db.Queries
Redis *redis.Client
JWT []byte
Pool *lifecycle.HostClientPool
CA *auth.CA // nil disables mTLS cert issuance (dev/test environments)
}
// HostCreateParams holds the parameters for creating a host.
type HostCreateParams struct {
Type string
TeamID string // required for BYOC, empty for regular
TeamID pgtype.UUID // required for BYOC, zero value for regular
Provider string
AvailabilityZone string
RequestingUserID string
RequestingUserID pgtype.UUID
IsRequestorAdmin bool
}
@ -50,10 +56,34 @@ type HostRegisterParams struct {
Address string
}
// HostRegisterResult holds the registered host and its long-lived JWT.
// HostRegisterResult holds the registered host, its short-lived JWT, a long-lived
// refresh token, and optionally the host's mTLS certificate material.
type HostRegisterResult struct {
Host db.Host
JWT string
Host db.Host
JWT string
RefreshToken string
// mTLS cert material — empty when CA is not configured.
CertPEM string
KeyPEM string
CACertPEM string
}
// HostRefreshResult holds a new JWT and rotated refresh token after a successful
// refresh, plus refreshed mTLS certificate material when CA is configured.
type HostRefreshResult struct {
Host db.Host
JWT string
RefreshToken string
// mTLS cert material — empty when CA is not configured.
CertPEM string
KeyPEM string
CACertPEM string
}
// HostDeletePreview describes what will be affected by deleting a host.
type HostDeletePreview struct {
Host db.Host
SandboxIDs []string
}
// regTokenPayload is the JSON stored in Redis for registration tokens.
@ -64,6 +94,14 @@ type regTokenPayload struct {
const regTokenTTL = time.Hour
// requireAdminOrOwner returns nil iff the role is "owner" or "admin".
func requireAdminOrOwner(role string) error {
if role == "owner" || role == "admin" {
return nil
}
return fmt.Errorf("forbidden: only team owners and admins can manage BYOC hosts")
}
// Create creates a new host record and generates a one-time registration token.
func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreateResult, error) {
if p.Type != "regular" && p.Type != "byoc" {
@ -75,8 +113,8 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat
return HostCreateResult{}, fmt.Errorf("forbidden: only admins can create regular hosts")
}
} else {
// BYOC: admin or team owner.
if p.TeamID == "" {
// BYOC: platform admin, or team owner/admin.
if !p.TeamID.Valid {
return HostCreateResult{}, fmt.Errorf("invalid request: team_id is required for BYOC hosts")
}
if !p.IsRequestorAdmin {
@ -90,40 +128,31 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat
if err != nil {
return HostCreateResult{}, fmt.Errorf("check team membership: %w", err)
}
if membership.Role != "owner" {
return HostCreateResult{}, fmt.Errorf("forbidden: only team owners can create BYOC hosts")
if err := requireAdminOrOwner(membership.Role); err != nil {
return HostCreateResult{}, err
}
}
}
// Validate team exists for BYOC hosts.
if p.TeamID != "" {
if _, err := s.DB.GetTeam(ctx, p.TeamID); err != nil {
// Validate team exists, is not deleted, and has BYOC enabled.
if p.TeamID.Valid {
team, err := s.DB.GetTeam(ctx, p.TeamID)
if err != nil || team.DeletedAt.Valid {
return HostCreateResult{}, fmt.Errorf("invalid request: team not found")
}
if !team.IsByoc {
return HostCreateResult{}, fmt.Errorf("forbidden: BYOC is not enabled for this team")
}
}
hostID := id.NewHostID()
var teamID pgtype.Text
if p.TeamID != "" {
teamID = pgtype.Text{String: p.TeamID, Valid: true}
}
var provider pgtype.Text
if p.Provider != "" {
provider = pgtype.Text{String: p.Provider, Valid: true}
}
var az pgtype.Text
if p.AvailabilityZone != "" {
az = pgtype.Text{String: p.AvailabilityZone, Valid: true}
}
host, err := s.DB.InsertHost(ctx, db.InsertHostParams{
ID: hostID,
Type: p.Type,
TeamID: teamID,
Provider: provider,
AvailabilityZone: az,
TeamID: p.TeamID,
Provider: p.Provider,
AvailabilityZone: p.AvailabilityZone,
CreatedBy: p.RequestingUserID,
})
if err != nil {
@ -135,8 +164,8 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat
tokenID := id.NewHostTokenID()
payload, _ := json.Marshal(regTokenPayload{
HostID: hostID,
TokenID: tokenID,
HostID: id.FormatHostID(hostID),
TokenID: id.FormatHostTokenID(tokenID),
})
if err := s.Redis.Set(ctx, "host:reg:"+token, payload, regTokenTTL).Err(); err != nil {
return HostCreateResult{}, fmt.Errorf("store registration token: %w", err)
@ -149,7 +178,7 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat
CreatedBy: p.RequestingUserID,
ExpiresAt: pgtype.Timestamptz{Time: now.Add(regTokenTTL), Valid: true},
}); err != nil {
slog.Warn("failed to insert host token audit record", "host_id", hostID, "error", err)
slog.Warn("failed to insert host token audit record", "host_id", id.FormatHostID(hostID), "error", err)
}
return HostCreateResult{Host: host, RegistrationToken: token}, nil
@ -158,7 +187,7 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat
// RegenerateToken issues a new registration token for a host still in "pending"
// status. This allows retry when a previous registration attempt failed after
// the original token was consumed.
func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamID string, isAdmin bool) (HostCreateResult, error) {
func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamID pgtype.UUID, isAdmin bool) (HostCreateResult, error) {
host, err := s.DB.GetHost(ctx, hostID)
if err != nil {
return HostCreateResult{}, fmt.Errorf("host not found: %w", err)
@ -167,12 +196,11 @@ func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamI
return HostCreateResult{}, fmt.Errorf("invalid state: can only regenerate token for pending hosts (status: %s)", host.Status)
}
// Same permission model as Create/Delete.
if !isAdmin {
if host.Type != "byoc" {
return HostCreateResult{}, fmt.Errorf("forbidden: only admins can manage regular hosts")
}
if !host.TeamID.Valid || host.TeamID.String != teamID {
if !host.TeamID.Valid || host.TeamID != teamID {
return HostCreateResult{}, fmt.Errorf("forbidden: host does not belong to your team")
}
membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
@ -185,8 +213,8 @@ func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamI
if err != nil {
return HostCreateResult{}, fmt.Errorf("check team membership: %w", err)
}
if membership.Role != "owner" {
return HostCreateResult{}, fmt.Errorf("forbidden: only team owners can regenerate tokens")
if err := requireAdminOrOwner(membership.Role); err != nil {
return HostCreateResult{}, err
}
}
@ -194,8 +222,8 @@ func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamI
tokenID := id.NewHostTokenID()
payload, _ := json.Marshal(regTokenPayload{
HostID: hostID,
TokenID: tokenID,
HostID: id.FormatHostID(hostID),
TokenID: id.FormatHostTokenID(tokenID),
})
if err := s.Redis.Set(ctx, "host:reg:"+token, payload, regTokenTTL).Err(); err != nil {
return HostCreateResult{}, fmt.Errorf("store registration token: %w", err)
@ -208,14 +236,14 @@ func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamI
CreatedBy: userID,
ExpiresAt: pgtype.Timestamptz{Time: now.Add(regTokenTTL), Valid: true},
}); err != nil {
slog.Warn("failed to insert host token audit record", "host_id", hostID, "error", err)
slog.Warn("failed to insert host token audit record", "host_id", id.FormatHostID(hostID), "error", err)
}
return HostCreateResult{Host: host, RegistrationToken: token}, nil
}
// Register validates a one-time registration token, updates the host with
// machine specs, and returns a long-lived host JWT.
// machine specs, and returns a short-lived host JWT plus a long-lived refresh token.
func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostRegisterResult, error) {
// Atomic consume: GetDel returns the value and deletes in one operation,
// preventing concurrent requests from consuming the same token.
@ -232,24 +260,44 @@ func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostR
return HostRegisterResult{}, fmt.Errorf("corrupted registration token")
}
if _, err := s.DB.GetHost(ctx, payload.HostID); err != nil {
hostID, err := id.ParseHostID(payload.HostID)
if err != nil {
return HostRegisterResult{}, fmt.Errorf("corrupted registration token: %w", err)
}
tokenID, err := id.ParseHostTokenID(payload.TokenID)
if err != nil {
return HostRegisterResult{}, fmt.Errorf("corrupted registration token: %w", err)
}
if _, err := s.DB.GetHost(ctx, hostID); err != nil {
return HostRegisterResult{}, fmt.Errorf("host not found: %w", err)
}
// Sign JWT before mutating DB — if signing fails, the host stays pending.
hostJWT, err := auth.SignHostJWT(s.JWT, payload.HostID)
hostJWT, err := auth.SignHostJWT(s.JWT, hostID)
if err != nil {
return HostRegisterResult{}, fmt.Errorf("sign host token: %w", err)
}
// Issue mTLS certificate if CA is configured.
var hc auth.HostCert
if s.CA != nil {
hc, err = auth.IssueHostCert(s.CA, id.FormatHostID(hostID), p.Address)
if err != nil {
return HostRegisterResult{}, fmt.Errorf("issue host cert: %w", err)
}
}
// Atomically update only if still pending (defense-in-depth against races).
rowsAffected, err := s.DB.RegisterHost(ctx, db.RegisterHostParams{
ID: payload.HostID,
Arch: pgtype.Text{String: p.Arch, Valid: p.Arch != ""},
CpuCores: pgtype.Int4{Int32: p.CPUCores, Valid: p.CPUCores > 0},
MemoryMb: pgtype.Int4{Int32: p.MemoryMB, Valid: p.MemoryMB > 0},
DiskGb: pgtype.Int4{Int32: p.DiskGB, Valid: p.DiskGB > 0},
Address: pgtype.Text{String: p.Address, Valid: p.Address != ""},
ID: hostID,
Arch: p.Arch,
CpuCores: p.CPUCores,
MemoryMb: p.MemoryMB,
DiskGb: p.DiskGB,
Address: p.Address,
CertFingerprint: hc.Fingerprint,
CertExpiresAt: pgtype.Timestamptz{Time: hc.ExpiresAt, Valid: s.CA != nil},
})
if err != nil {
return HostRegisterResult{}, fmt.Errorf("register host: %w", err)
@ -259,82 +307,293 @@ func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostR
}
// Mark audit trail.
if err := s.DB.MarkHostTokenUsed(ctx, payload.TokenID); err != nil {
if err := s.DB.MarkHostTokenUsed(ctx, tokenID); err != nil {
slog.Warn("failed to mark host token used", "token_id", payload.TokenID, "error", err)
}
// Issue a long-lived refresh token.
refreshToken, err := s.issueRefreshToken(ctx, hostID)
if err != nil {
return HostRegisterResult{}, fmt.Errorf("issue refresh token: %w", err)
}
// Re-fetch the host to get the updated state.
host, err := s.DB.GetHost(ctx, payload.HostID)
host, err := s.DB.GetHost(ctx, hostID)
if err != nil {
return HostRegisterResult{}, fmt.Errorf("fetch updated host: %w", err)
}
return HostRegisterResult{Host: host, JWT: hostJWT}, nil
result := HostRegisterResult{Host: host, JWT: hostJWT, RefreshToken: refreshToken}
if s.CA != nil {
result.CertPEM = hc.CertPEM
result.KeyPEM = hc.KeyPEM
result.CACertPEM = s.CA.PEM
}
return result, nil
}
// Heartbeat updates the last heartbeat timestamp for a host.
func (s *HostService) Heartbeat(ctx context.Context, hostID string) error {
return s.DB.UpdateHostHeartbeat(ctx, hostID)
// Refresh validates a refresh token, rotates it (revokes old, issues new),
// and returns a fresh JWT plus the new refresh token.
func (s *HostService) Refresh(ctx context.Context, refreshToken string) (HostRefreshResult, error) {
hash := hashToken(refreshToken)
row, err := s.DB.GetHostRefreshTokenByHash(ctx, hash)
if errors.Is(err, pgx.ErrNoRows) {
return HostRefreshResult{}, fmt.Errorf("invalid or expired refresh token")
}
if err != nil {
return HostRefreshResult{}, fmt.Errorf("lookup refresh token: %w", err)
}
host, err := s.DB.GetHost(ctx, row.HostID)
if err != nil {
return HostRefreshResult{}, fmt.Errorf("host not found: %w", err)
}
// Sign new JWT.
hostJWT, err := auth.SignHostJWT(s.JWT, host.ID)
if err != nil {
return HostRefreshResult{}, fmt.Errorf("sign host JWT: %w", err)
}
// Renew mTLS certificate if CA is configured.
var hc auth.HostCert
if s.CA != nil {
hc, err = auth.IssueHostCert(s.CA, id.FormatHostID(host.ID), host.Address)
if err != nil {
return HostRefreshResult{}, fmt.Errorf("renew host cert: %w", err)
}
if err := s.DB.UpdateHostCert(ctx, db.UpdateHostCertParams{
ID: host.ID,
CertFingerprint: hc.Fingerprint,
CertExpiresAt: pgtype.Timestamptz{Time: hc.ExpiresAt, Valid: true},
}); err != nil {
return HostRefreshResult{}, fmt.Errorf("update host cert: %w", err)
}
}
// Issue-then-revoke rotation: insert new token first so a crash between
// the two DB calls leaves the host with two valid tokens rather than zero.
newRefreshToken, err := s.issueRefreshToken(ctx, host.ID)
if err != nil {
return HostRefreshResult{}, fmt.Errorf("issue new refresh token: %w", err)
}
// Revoke old refresh token after the new one is safely persisted.
if err := s.DB.RevokeHostRefreshToken(ctx, row.ID); err != nil {
return HostRefreshResult{}, fmt.Errorf("revoke old refresh token: %w", err)
}
result := HostRefreshResult{Host: host, JWT: hostJWT, RefreshToken: newRefreshToken}
if s.CA != nil {
result.CertPEM = hc.CertPEM
result.KeyPEM = hc.KeyPEM
result.CACertPEM = s.CA.PEM
}
return result, nil
}
// issueRefreshToken creates a new refresh token record in the DB and returns
// the opaque token string.
func (s *HostService) issueRefreshToken(ctx context.Context, hostID pgtype.UUID) (string, error) {
token := id.NewRefreshToken()
hash := hashToken(token)
now := time.Now()
if _, err := s.DB.InsertHostRefreshToken(ctx, db.InsertHostRefreshTokenParams{
ID: id.NewRefreshTokenID(),
HostID: hostID,
TokenHash: hash,
ExpiresAt: pgtype.Timestamptz{Time: now.Add(auth.HostRefreshTokenExpiry), Valid: true},
}); err != nil {
return "", fmt.Errorf("insert refresh token: %w", err)
}
return token, nil
}
// hashToken returns the hex-encoded SHA-256 hash of the token.
func hashToken(token string) string {
h := sha256.Sum256([]byte(token))
return fmt.Sprintf("%x", h)
}
// Heartbeat updates the last heartbeat timestamp for a host and transitions
// any 'unreachable' host back to 'online'. Returns a "host not found" error
// (which becomes 404) if the host record no longer exists (e.g., was deleted).
func (s *HostService) Heartbeat(ctx context.Context, hostID pgtype.UUID) error {
n, err := s.DB.UpdateHostHeartbeatAndStatus(ctx, hostID)
if err != nil {
return err
}
if n == 0 {
return fmt.Errorf("host not found")
}
return nil
}
// List returns hosts visible to the caller.
// Admins see all hosts; non-admins see only BYOC hosts belonging to their team.
func (s *HostService) List(ctx context.Context, teamID string, isAdmin bool) ([]db.Host, error) {
func (s *HostService) List(ctx context.Context, teamID pgtype.UUID, isAdmin bool) ([]db.Host, error) {
if isAdmin {
return s.DB.ListHosts(ctx)
}
return s.DB.ListHostsByTeam(ctx, pgtype.Text{String: teamID, Valid: true})
return s.DB.ListHostsByTeam(ctx, teamID)
}
// Get returns a single host, enforcing access control.
func (s *HostService) Get(ctx context.Context, hostID, teamID string, isAdmin bool) (db.Host, error) {
func (s *HostService) Get(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool) (db.Host, error) {
host, err := s.DB.GetHost(ctx, hostID)
if err != nil {
return db.Host{}, fmt.Errorf("host not found: %w", err)
}
if !isAdmin {
if !host.TeamID.Valid || host.TeamID.String != teamID {
if !host.TeamID.Valid || host.TeamID != teamID {
return db.Host{}, fmt.Errorf("host not found")
}
}
return host, nil
}
// Delete removes a host. Admins can delete any host. Team owners can delete
// BYOC hosts belonging to their team.
func (s *HostService) Delete(ctx context.Context, hostID, userID, teamID string, isAdmin bool) error {
host, err := s.DB.GetHost(ctx, hostID)
// DeletePreview returns what would be affected by deleting the host, without
// making any changes. Use this to show the user a confirmation prompt.
func (s *HostService) DeletePreview(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool) (HostDeletePreview, error) {
host, err := s.checkDeletePermission(ctx, hostID, pgtype.UUID{}, teamID, isAdmin)
if err != nil {
return fmt.Errorf("host not found: %w", err)
return HostDeletePreview{}, err
}
if !isAdmin {
if host.Type != "byoc" {
return fmt.Errorf("forbidden: only admins can delete regular hosts")
sandboxes, err := s.DB.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
HostID: hostID,
Column2: []string{"pending", "starting", "running", "missing"},
})
if err != nil {
return HostDeletePreview{}, fmt.Errorf("list sandboxes: %w", err)
}
ids := make([]string, len(sandboxes))
for i, sb := range sandboxes {
ids[i] = id.FormatSandboxID(sb.ID)
}
return HostDeletePreview{Host: host, SandboxIDs: ids}, nil
}
// Delete removes a host. Without force it returns an error listing active
// sandboxes so the caller can present a confirmation. With force it gracefully
// destroys all running sandboxes before deleting the host record.
func (s *HostService) Delete(ctx context.Context, hostID, userID, teamID pgtype.UUID, isAdmin, force bool) error {
host, err := s.checkDeletePermission(ctx, hostID, userID, teamID, isAdmin)
if err != nil {
return err
}
sandboxes, err := s.DB.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
HostID: hostID,
Column2: []string{"pending", "starting", "running", "missing"},
})
if err != nil {
return fmt.Errorf("list sandboxes: %w", err)
}
if len(sandboxes) > 0 && !force {
ids := make([]string, len(sandboxes))
for i, sb := range sandboxes {
ids[i] = id.FormatSandboxID(sb.ID)
}
if !host.TeamID.Valid || host.TeamID.String != teamID {
return fmt.Errorf("forbidden: host does not belong to your team")
return &HostHasSandboxesError{SandboxIDs: ids}
}
hostIDStr := id.FormatHostID(hostID)
// Gracefully destroy running sandboxes and terminate the agent (best-effort).
if host.Address != "" {
agent, err := s.Pool.GetForHost(host)
if err == nil {
for _, sb := range sandboxes {
if sb.Status == "running" || sb.Status == "starting" {
_, rpcErr := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
SandboxId: id.FormatSandboxID(sb.ID),
}))
if rpcErr != nil && connect.CodeOf(rpcErr) != connect.CodeNotFound {
slog.Warn("delete host: failed to destroy sandbox on agent", "sandbox_id", id.FormatSandboxID(sb.ID), "error", rpcErr)
}
}
}
// Tell the agent to shut itself down immediately.
if _, rpcErr := agent.Terminate(ctx, connect.NewRequest(&pb.TerminateRequest{})); rpcErr != nil {
slog.Warn("delete host: failed to send Terminate to agent", "host_id", hostIDStr, "error", rpcErr)
}
}
membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
UserID: userID,
TeamID: teamID,
})
if errors.Is(err, pgx.ErrNoRows) {
return fmt.Errorf("forbidden: not a member of the specified team")
}
// Mark all affected sandboxes as stopped in DB.
if len(sandboxes) > 0 {
sbIDs := make([]pgtype.UUID, len(sandboxes))
for i, sb := range sandboxes {
sbIDs[i] = sb.ID
}
if err != nil {
return fmt.Errorf("check team membership: %w", err)
}
if membership.Role != "owner" {
return fmt.Errorf("forbidden: only team owners can delete BYOC hosts")
if err := s.DB.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
Column1: sbIDs,
Status: "stopped",
}); err != nil {
slog.Warn("delete host: failed to mark sandboxes stopped", "host_id", hostIDStr, "error", err)
}
}
// Revoke all refresh tokens for this host.
if err := s.DB.RevokeHostRefreshTokensByHost(ctx, hostID); err != nil {
slog.Warn("delete host: failed to revoke refresh tokens", "host_id", hostIDStr, "error", err)
}
// Evict the client from the pool so no further RPCs are sent.
if s.Pool != nil {
s.Pool.Evict(id.FormatHostID(hostID))
}
return s.DB.DeleteHost(ctx, hostID)
}
// checkDeletePermission verifies the caller has permission to delete the given
// host and returns the host record on success.
func (s *HostService) checkDeletePermission(ctx context.Context, hostID, userID, teamID pgtype.UUID, isAdmin bool) (db.Host, error) {
host, err := s.DB.GetHost(ctx, hostID)
if err != nil {
return db.Host{}, fmt.Errorf("host not found: %w", err)
}
if isAdmin {
return host, nil
}
if host.Type != "byoc" {
return db.Host{}, fmt.Errorf("forbidden: only admins can delete regular hosts")
}
if !host.TeamID.Valid || host.TeamID != teamID {
return db.Host{}, fmt.Errorf("forbidden: host does not belong to your team")
}
if userID.Valid {
membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
UserID: userID,
TeamID: teamID,
})
if errors.Is(err, pgx.ErrNoRows) {
return db.Host{}, fmt.Errorf("forbidden: not a member of the specified team")
}
if err != nil {
return db.Host{}, fmt.Errorf("check team membership: %w", err)
}
if err := requireAdminOrOwner(membership.Role); err != nil {
return db.Host{}, err
}
}
return host, nil
}
// AddTag adds a tag to a host.
func (s *HostService) AddTag(ctx context.Context, hostID, teamID string, isAdmin bool, tag string) error {
func (s *HostService) AddTag(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool, tag string) error {
if _, err := s.Get(ctx, hostID, teamID, isAdmin); err != nil {
return err
}
@ -342,7 +601,7 @@ func (s *HostService) AddTag(ctx context.Context, hostID, teamID string, isAdmin
}
// RemoveTag removes a tag from a host.
func (s *HostService) RemoveTag(ctx context.Context, hostID, teamID string, isAdmin bool, tag string) error {
func (s *HostService) RemoveTag(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool, tag string) error {
if _, err := s.Get(ctx, hostID, teamID, isAdmin); err != nil {
return err
}
@ -350,9 +609,20 @@ func (s *HostService) RemoveTag(ctx context.Context, hostID, teamID string, isAd
}
// ListTags returns all tags for a host.
func (s *HostService) ListTags(ctx context.Context, hostID, teamID string, isAdmin bool) ([]string, error) {
func (s *HostService) ListTags(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool) ([]string, error) {
if _, err := s.Get(ctx, hostID, teamID, isAdmin); err != nil {
return nil, err
}
return s.DB.GetHostTags(ctx, hostID)
}
// HostHasSandboxesError is returned by Delete when the host has active sandboxes
// and force was not set. The caller should present the list to the user and
// re-call Delete with force=true if they confirm.
type HostHasSandboxesError struct {
SandboxIDs []string
}
func (e *HostHasSandboxesError) Error() string {
return fmt.Sprintf("host has %d active sandbox(es): %v", len(e.SandboxIDs), e.SandboxIDs)
}

View File

@ -11,29 +11,60 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
"git.omukk.dev/wrenn/sandbox/internal/scheduler"
"git.omukk.dev/wrenn/sandbox/internal/validate"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
)
// SandboxService provides sandbox lifecycle operations shared between the
// REST API and the dashboard.
type SandboxService struct {
DB *db.Queries
Agent hostagentv1connect.HostAgentServiceClient
DB *db.Queries
Pool *lifecycle.HostClientPool
Scheduler scheduler.HostScheduler
}
// SandboxCreateParams holds the parameters for creating a sandbox.
type SandboxCreateParams struct {
TeamID string
TeamID pgtype.UUID
Template string
VCPUs int32
MemoryMB int32
TimeoutSec int32
DiskSizeMB int32
}
// Create creates a new sandbox: inserts a pending DB record, calls the host agent,
// and updates the record to running. Returns the sandbox DB row.
// agentForSandbox looks up the host for the given sandbox and returns a client.
func (s *SandboxService) agentForSandbox(ctx context.Context, sandboxID pgtype.UUID) (hostagentClient, db.Sandbox, error) {
sb, err := s.DB.GetSandbox(ctx, sandboxID)
if err != nil {
return nil, db.Sandbox{}, fmt.Errorf("sandbox not found: %w", err)
}
host, err := s.DB.GetHost(ctx, sb.HostID)
if err != nil {
return nil, db.Sandbox{}, fmt.Errorf("host not found for sandbox: %w", err)
}
agent, err := s.Pool.GetForHost(host)
if err != nil {
return nil, db.Sandbox{}, fmt.Errorf("get agent client: %w", err)
}
return agent, sb, nil
}
// hostagentClient is a local alias to avoid the full package path in signatures.
type hostagentClient = interface {
CreateSandbox(ctx context.Context, req *connect.Request[pb.CreateSandboxRequest]) (*connect.Response[pb.CreateSandboxResponse], error)
DestroySandbox(ctx context.Context, req *connect.Request[pb.DestroySandboxRequest]) (*connect.Response[pb.DestroySandboxResponse], error)
PauseSandbox(ctx context.Context, req *connect.Request[pb.PauseSandboxRequest]) (*connect.Response[pb.PauseSandboxResponse], error)
ResumeSandbox(ctx context.Context, req *connect.Request[pb.ResumeSandboxRequest]) (*connect.Response[pb.ResumeSandboxResponse], error)
PingSandbox(ctx context.Context, req *connect.Request[pb.PingSandboxRequest]) (*connect.Response[pb.PingSandboxResponse], error)
GetSandboxMetrics(ctx context.Context, req *connect.Request[pb.GetSandboxMetricsRequest]) (*connect.Response[pb.GetSandboxMetricsResponse], error)
FlushSandboxMetrics(ctx context.Context, req *connect.Request[pb.FlushSandboxMetricsRequest]) (*connect.Response[pb.FlushSandboxMetricsResponse], error)
}
// Create creates a new sandbox: picks a host via the scheduler, inserts a pending
// DB record, calls the host agent, and updates the record to running.
func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db.Sandbox, error) {
if p.Template == "" {
p.Template = "minimal"
@ -47,44 +78,82 @@ func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db.
if p.MemoryMB <= 0 {
p.MemoryMB = 512
}
if p.DiskSizeMB <= 0 {
p.DiskSizeMB = 5120 // 5 GB default
}
// If the template is a snapshot, use its baked-in vcpus/memory.
if tmpl, err := s.DB.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: p.Template, TeamID: p.TeamID}); err == nil && tmpl.Type == "snapshot" {
if tmpl.Vcpus.Valid {
p.VCPUs = tmpl.Vcpus.Int32
// Resolve template name → (teamID, templateID).
templateTeamID := id.PlatformTeamID
templateID := id.MinimalTemplateID
if p.Template != "minimal" {
tmpl, err := s.DB.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: p.Template, TeamID: p.TeamID})
if err != nil {
return db.Sandbox{}, fmt.Errorf("template %q not found: %w", p.Template, err)
}
if tmpl.MemoryMb.Valid {
p.MemoryMB = tmpl.MemoryMb.Int32
templateTeamID = tmpl.TeamID
templateID = tmpl.ID
// If the template is a snapshot, use its baked-in vcpus/memory.
if tmpl.Type == "snapshot" {
p.VCPUs = tmpl.Vcpus
p.MemoryMB = tmpl.MemoryMb
}
}
if !p.TeamID.Valid {
return db.Sandbox{}, fmt.Errorf("invalid request: team_id is required")
}
// Determine whether this team uses BYOC hosts or platform hosts.
team, err := s.DB.GetTeam(ctx, p.TeamID)
if err != nil {
return db.Sandbox{}, fmt.Errorf("team not found: %w", err)
}
// Pick a host for this sandbox.
host, err := s.Scheduler.SelectHost(ctx, p.TeamID, team.IsByoc)
if err != nil {
return db.Sandbox{}, fmt.Errorf("select host: %w", err)
}
agent, err := s.Pool.GetForHost(host)
if err != nil {
return db.Sandbox{}, fmt.Errorf("get agent client: %w", err)
}
sandboxID := id.NewSandboxID()
sandboxIDStr := id.FormatSandboxID(sandboxID)
if _, err := s.DB.InsertSandbox(ctx, db.InsertSandboxParams{
ID: sandboxID,
TeamID: p.TeamID,
HostID: "default",
Template: p.Template,
Status: "pending",
Vcpus: p.VCPUs,
MemoryMb: p.MemoryMB,
TimeoutSec: p.TimeoutSec,
ID: sandboxID,
TeamID: p.TeamID,
HostID: host.ID,
Template: p.Template,
Status: "pending",
Vcpus: p.VCPUs,
MemoryMb: p.MemoryMB,
TimeoutSec: p.TimeoutSec,
DiskSizeMb: p.DiskSizeMB,
TemplateID: templateID,
TemplateTeamID: templateTeamID,
}); err != nil {
return db.Sandbox{}, fmt.Errorf("insert sandbox: %w", err)
}
resp, err := s.Agent.CreateSandbox(ctx, connect.NewRequest(&pb.CreateSandboxRequest{
SandboxId: sandboxID,
resp, err := agent.CreateSandbox(ctx, connect.NewRequest(&pb.CreateSandboxRequest{
SandboxId: sandboxIDStr,
Template: p.Template,
TeamId: id.UUIDString(templateTeamID),
TemplateId: id.UUIDString(templateID),
Vcpus: p.VCPUs,
MemoryMb: p.MemoryMB,
TimeoutSec: p.TimeoutSec,
DiskSizeMb: p.DiskSizeMB,
}))
if err != nil {
if _, dbErr := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
ID: sandboxID, Status: "error",
}); dbErr != nil {
slog.Warn("failed to update sandbox status to error", "id", sandboxID, "error", dbErr)
slog.Warn("failed to update sandbox status to error", "id", sandboxIDStr, "error", dbErr)
}
return db.Sandbox{}, fmt.Errorf("agent create: %w", err)
}
@ -107,17 +176,17 @@ func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db.
}
// List returns active sandboxes (excludes stopped/error) belonging to the given team.
func (s *SandboxService) List(ctx context.Context, teamID string) ([]db.Sandbox, error) {
func (s *SandboxService) List(ctx context.Context, teamID pgtype.UUID) ([]db.Sandbox, error) {
return s.DB.ListSandboxesByTeam(ctx, teamID)
}
// Get returns a single sandbox by ID, scoped to the given team.
func (s *SandboxService) Get(ctx context.Context, sandboxID, teamID string) (db.Sandbox, error) {
func (s *SandboxService) Get(ctx context.Context, sandboxID, teamID pgtype.UUID) (db.Sandbox, error) {
return s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
}
// Pause snapshots and freezes a running sandbox to disk.
func (s *SandboxService) Pause(ctx context.Context, sandboxID, teamID string) (db.Sandbox, error) {
func (s *SandboxService) Pause(ctx context.Context, sandboxID, teamID pgtype.UUID) (db.Sandbox, error) {
sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
if err != nil {
return db.Sandbox{}, fmt.Errorf("sandbox not found: %w", err)
@ -126,23 +195,45 @@ func (s *SandboxService) Pause(ctx context.Context, sandboxID, teamID string) (d
return db.Sandbox{}, fmt.Errorf("sandbox is not running (status: %s)", sb.Status)
}
if _, err := s.Agent.PauseSandbox(ctx, connect.NewRequest(&pb.PauseSandboxRequest{
SandboxId: sandboxID,
agent, _, err := s.agentForSandbox(ctx, sandboxID)
if err != nil {
return db.Sandbox{}, err
}
sandboxIDStr := id.FormatSandboxID(sandboxID)
// Pre-mark as "paused" in DB before the RPC so the reconciler does not
// mark the sandbox "stopped" while the host agent processes the pause.
if _, err := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
ID: sandboxID, Status: "paused",
}); err != nil {
return db.Sandbox{}, fmt.Errorf("pre-mark paused: %w", err)
}
// Flush all metrics tiers before pausing so data survives in DB.
s.flushAndPersistMetrics(ctx, agent, sandboxID, true)
if _, err := agent.PauseSandbox(ctx, connect.NewRequest(&pb.PauseSandboxRequest{
SandboxId: sandboxIDStr,
})); err != nil {
// Revert status on failure.
if _, dbErr := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
ID: sandboxID, Status: "running",
}); dbErr != nil {
slog.Warn("failed to revert sandbox status after pause error", "sandbox_id", sandboxIDStr, "error", dbErr)
}
return db.Sandbox{}, fmt.Errorf("agent pause: %w", err)
}
sb, err = s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
ID: sandboxID, Status: "paused",
})
sb, err = s.DB.GetSandbox(ctx, sandboxID)
if err != nil {
return db.Sandbox{}, fmt.Errorf("update status: %w", err)
return db.Sandbox{}, fmt.Errorf("get sandbox after pause: %w", err)
}
return sb, nil
}
// Resume restores a paused sandbox from snapshot.
func (s *SandboxService) Resume(ctx context.Context, sandboxID, teamID string) (db.Sandbox, error) {
func (s *SandboxService) Resume(ctx context.Context, sandboxID, teamID pgtype.UUID) (db.Sandbox, error) {
sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
if err != nil {
return db.Sandbox{}, fmt.Errorf("sandbox not found: %w", err)
@ -151,8 +242,15 @@ func (s *SandboxService) Resume(ctx context.Context, sandboxID, teamID string) (
return db.Sandbox{}, fmt.Errorf("sandbox is not paused (status: %s)", sb.Status)
}
resp, err := s.Agent.ResumeSandbox(ctx, connect.NewRequest(&pb.ResumeSandboxRequest{
SandboxId: sandboxID,
agent, _, err := s.agentForSandbox(ctx, sandboxID)
if err != nil {
return db.Sandbox{}, err
}
sandboxIDStr := id.FormatSandboxID(sandboxID)
resp, err := agent.ResumeSandbox(ctx, connect.NewRequest(&pb.ResumeSandboxRequest{
SandboxId: sandboxIDStr,
TimeoutSec: sb.TimeoutSec,
}))
if err != nil {
@ -176,18 +274,41 @@ func (s *SandboxService) Resume(ctx context.Context, sandboxID, teamID string) (
}
// Destroy stops a sandbox and marks it as stopped.
func (s *SandboxService) Destroy(ctx context.Context, sandboxID, teamID string) error {
if _, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID}); err != nil {
func (s *SandboxService) Destroy(ctx context.Context, sandboxID, teamID pgtype.UUID) error {
sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
if err != nil {
return fmt.Errorf("sandbox not found: %w", err)
}
agent, _, err := s.agentForSandbox(ctx, sandboxID)
if err != nil {
return err
}
sandboxIDStr := id.FormatSandboxID(sandboxID)
// If running, flush 24h tier metrics for analytics before destroying.
if sb.Status == "running" {
s.flushAndPersistMetrics(ctx, agent, sandboxID, false)
}
// Destroy on host agent. A not-found response is fine — sandbox is already gone.
if _, err := s.Agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
SandboxId: sandboxID,
if _, err := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
SandboxId: sandboxIDStr,
})); err != nil && connect.CodeOf(err) != connect.CodeNotFound {
return fmt.Errorf("agent destroy: %w", err)
}
// For a paused sandbox, only keep 24h tier; remove the finer-grained tiers.
if sb.Status == "paused" {
_ = s.DB.DeleteSandboxMetricPointsByTier(ctx, db.DeleteSandboxMetricPointsByTierParams{
SandboxID: sandboxID, Tier: "10m",
})
_ = s.DB.DeleteSandboxMetricPointsByTier(ctx, db.DeleteSandboxMetricPointsByTierParams{
SandboxID: sandboxID, Tier: "2h",
})
}
if _, err := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
ID: sandboxID, Status: "stopped",
}); err != nil {
@ -196,8 +317,45 @@ func (s *SandboxService) Destroy(ctx context.Context, sandboxID, teamID string)
return nil
}
// flushAndPersistMetrics calls FlushSandboxMetrics on the agent and stores
// the returned data to DB. If allTiers is true, all three tiers are saved;
// otherwise only the 24h tier (for post-destroy analytics).
func (s *SandboxService) flushAndPersistMetrics(ctx context.Context, agent hostagentClient, sandboxID pgtype.UUID, allTiers bool) {
sandboxIDStr := id.FormatSandboxID(sandboxID)
resp, err := agent.FlushSandboxMetrics(ctx, connect.NewRequest(&pb.FlushSandboxMetricsRequest{
SandboxId: sandboxIDStr,
}))
if err != nil {
slog.Warn("flush metrics failed (best-effort)", "sandbox_id", sandboxIDStr, "error", err)
return
}
msg := resp.Msg
if allTiers {
s.persistMetricPoints(ctx, sandboxID, "10m", msg.Points_10M)
s.persistMetricPoints(ctx, sandboxID, "2h", msg.Points_2H)
}
s.persistMetricPoints(ctx, sandboxID, "24h", msg.Points_24H)
}
func (s *SandboxService) persistMetricPoints(ctx context.Context, sandboxID pgtype.UUID, tier string, points []*pb.MetricPoint) {
sandboxIDStr := id.FormatSandboxID(sandboxID)
for _, p := range points {
if err := s.DB.InsertSandboxMetricPoint(ctx, db.InsertSandboxMetricPointParams{
SandboxID: sandboxID,
Tier: tier,
Ts: p.TimestampUnix,
CpuPct: p.CpuPct,
MemBytes: p.MemBytes,
DiskBytes: p.DiskBytes,
}); err != nil {
slog.Warn("persist metric point failed", "sandbox_id", sandboxIDStr, "tier", tier, "error", err)
}
}
}
// Ping resets the inactivity timer for a running sandbox.
func (s *SandboxService) Ping(ctx context.Context, sandboxID, teamID string) error {
func (s *SandboxService) Ping(ctx context.Context, sandboxID, teamID pgtype.UUID) error {
sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
if err != nil {
return fmt.Errorf("sandbox not found: %w", err)
@ -206,8 +364,15 @@ func (s *SandboxService) Ping(ctx context.Context, sandboxID, teamID string) err
return fmt.Errorf("sandbox is not running (status: %s)", sb.Status)
}
if _, err := s.Agent.PingSandbox(ctx, connect.NewRequest(&pb.PingSandboxRequest{
SandboxId: sandboxID,
agent, _, err := s.agentForSandbox(ctx, sandboxID)
if err != nil {
return err
}
sandboxIDStr := id.FormatSandboxID(sandboxID)
if _, err := agent.PingSandbox(ctx, connect.NewRequest(&pb.PingSandboxRequest{
SandboxId: sandboxIDStr,
})); err != nil {
return fmt.Errorf("agent ping: %w", err)
}
@ -219,7 +384,7 @@ func (s *SandboxService) Ping(ctx context.Context, sandboxID, teamID string) err
Valid: true,
},
}); err != nil {
slog.Warn("ping: failed to update last_active_at", "sandbox_id", sandboxID, "error", err)
slog.Warn("ping: failed to update last_active_at", "sandbox_id", sandboxIDStr, "error", err)
}
return nil
}

160
internal/service/stats.go Normal file
View File

@ -0,0 +1,160 @@
package service
import (
"context"
"errors"
"fmt"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxpool"
"git.omukk.dev/wrenn/sandbox/internal/db"
)
// TimeRange identifies a chart time window.
type TimeRange string
const (
Range5m TimeRange = "5m"
Range1h TimeRange = "1h"
Range6h TimeRange = "6h"
Range24h TimeRange = "24h"
Range30d TimeRange = "30d"
)
type rangeConfig struct {
bucketSec int // bucket width in seconds for time-series aggregation
intervalLiteral string // PostgreSQL interval literal for the lookback window
}
var rangeConfigs = map[TimeRange]rangeConfig{
Range5m: {bucketSec: 3, intervalLiteral: "5 minutes"},
Range1h: {bucketSec: 30, intervalLiteral: "1 hour"},
Range6h: {bucketSec: 180, intervalLiteral: "6 hours"},
Range24h: {bucketSec: 720, intervalLiteral: "24 hours"},
Range30d: {bucketSec: 21600, intervalLiteral: "30 days"},
}
// ValidRange returns true if r is a known TimeRange value.
func ValidRange(r TimeRange) bool {
_, ok := rangeConfigs[r]
return ok
}
// StatPoint is one bucketed data point in the time-series.
type StatPoint struct {
Bucket time.Time
RunningCount int32
VCPUsReserved int32
MemoryMBReserved int32
}
// CurrentStats holds the live values for a team, read directly from sandboxes.
type CurrentStats struct {
RunningCount int32
VCPUsReserved int32
MemoryMBReserved int32
}
// PeakStats holds the 30-day maximum values for a team.
type PeakStats struct {
RunningCount int32
VCPUs int32
MemoryMB int32
}
// StatsService computes sandbox metrics for the dashboard.
type StatsService struct {
DB *db.Queries
Pool *pgxpool.Pool
}
// GetStats returns current stats, 30-day peaks, and a time-series for the
// given team and time range. If no snapshots exist yet, zeros are returned.
func (s *StatsService) GetStats(ctx context.Context, teamID pgtype.UUID, r TimeRange) (CurrentStats, PeakStats, []StatPoint, error) {
cfg, ok := rangeConfigs[r]
if !ok {
return CurrentStats{}, PeakStats{}, nil, fmt.Errorf("unknown range: %s", r)
}
// Current live values — read directly from sandboxes so we always reflect
// the true state even when no capsules are running.
cur, err := s.DB.GetLiveMetrics(ctx, teamID)
if err != nil {
return CurrentStats{}, PeakStats{}, nil, fmt.Errorf("get live metrics: %w", err)
}
current := CurrentStats{
RunningCount: cur.RunningCount,
VCPUsReserved: cur.VcpusReserved,
MemoryMBReserved: cur.MemoryMbReserved,
}
// 30-day peaks.
var peaks PeakStats
pk, err := s.DB.GetPeakMetrics(ctx, teamID)
if err != nil && !errors.Is(err, pgx.ErrNoRows) {
return CurrentStats{}, PeakStats{}, nil, fmt.Errorf("get peak metrics: %w", err)
}
if err == nil {
peaks = PeakStats{
RunningCount: pk.PeakRunningCount,
VCPUs: pk.PeakVcpus,
MemoryMB: pk.PeakMemoryMb,
}
}
// Time-series — dynamic bucket width, executed via pgx directly.
series, err := s.queryTimeSeries(ctx, teamID, cfg)
if err != nil {
return CurrentStats{}, PeakStats{}, nil, fmt.Errorf("get time series: %w", err)
}
return current, peaks, series, nil
}
// timeSeriesSQL uses an epoch-floor trick to bucket rows by an arbitrary
// integer number of seconds without requiring TimescaleDB.
//
// MAX is used instead of AVG so that short-lived running states are not
// averaged down to zero within a bucket. For capacity metrics the peak
// value in each bucket is what matters — AVG with ::INTEGER rounding
// caused running_count, vcpus, and memory to become inconsistent with
// each other (e.g. running=0 but vcpus=1).
//
// $1 = bucket width in seconds (integer)
// $2 = team_id
// $3 = lookback interval literal (e.g. '1 hour')
const timeSeriesSQL = `
SELECT
to_timestamp(floor(extract(epoch FROM sampled_at) / $1) * $1) AS bucket,
MAX(running_count) AS running_count,
MAX(vcpus_reserved) AS vcpus_reserved,
MAX(memory_mb_reserved) AS memory_mb_reserved
FROM sandbox_metrics_snapshots
WHERE team_id = $2
AND sampled_at >= NOW() - $3::INTERVAL
GROUP BY bucket
ORDER BY bucket ASC
`
func (s *StatsService) queryTimeSeries(ctx context.Context, teamID pgtype.UUID, cfg rangeConfig) ([]StatPoint, error) {
rows, err := s.Pool.Query(ctx, timeSeriesSQL, cfg.bucketSec, teamID, cfg.intervalLiteral)
if err != nil {
return nil, err
}
defer rows.Close()
var points []StatPoint
for rows.Next() {
var p StatPoint
var bucket time.Time
if err := rows.Scan(&bucket, &p.RunningCount, &p.VCPUsReserved, &p.MemoryMBReserved); err != nil {
return nil, err
}
p.Bucket = bucket
points = append(points, p)
}
return points, rows.Err()
}

443
internal/service/team.go Normal file
View File

@ -0,0 +1,443 @@
package service
import (
"context"
"fmt"
"log/slog"
"regexp"
"time"
"connectrpc.com/connect"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxpool"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
)
var teamNameRE = regexp.MustCompile(`^[A-Za-z0-9 _\-@']{1,128}$`)
// TeamService provides team management operations.
type TeamService struct {
DB *db.Queries
Pool *pgxpool.Pool
HostPool *lifecycle.HostClientPool
}
// TeamWithRole pairs a team with the calling user's role in it.
type TeamWithRole struct {
db.Team
Role string `json:"role"`
}
// MemberInfo is a team member with resolved user details.
type MemberInfo struct {
UserID string `json:"user_id"`
Name string `json:"name"`
Email string `json:"email"`
Role string `json:"role"`
JoinedAt time.Time `json:"joined_at"`
}
// callerRole fetches the calling user's role in the given team from DB.
// Returns an error wrapping "forbidden" if the caller is not a member.
func (s *TeamService) callerRole(ctx context.Context, teamID, callerUserID pgtype.UUID) (string, error) {
m, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
UserID: callerUserID,
TeamID: teamID,
})
if err != nil {
if err == pgx.ErrNoRows {
return "", fmt.Errorf("forbidden: not a member of this team")
}
return "", fmt.Errorf("get membership: %w", err)
}
return m.Role, nil
}
// requireAdmin returns an error if the caller is not an admin or owner.
func requireAdmin(role string) error {
if role != "owner" && role != "admin" {
return fmt.Errorf("forbidden: admin or owner role required")
}
return nil
}
// GetTeam returns the team by ID. Returns an error if the team is deleted or not found.
func (s *TeamService) GetTeam(ctx context.Context, teamID pgtype.UUID) (db.Team, error) {
team, err := s.DB.GetTeam(ctx, teamID)
if err != nil {
if err == pgx.ErrNoRows {
return db.Team{}, fmt.Errorf("team not found")
}
return db.Team{}, fmt.Errorf("get team: %w", err)
}
if team.DeletedAt.Valid {
return db.Team{}, fmt.Errorf("team not found")
}
return team, nil
}
// ListTeamsForUser returns all active teams the user belongs to, with their role in each.
func (s *TeamService) ListTeamsForUser(ctx context.Context, userID pgtype.UUID) ([]TeamWithRole, error) {
rows, err := s.DB.GetTeamsForUser(ctx, userID)
if err != nil {
return nil, fmt.Errorf("list teams: %w", err)
}
result := make([]TeamWithRole, len(rows))
for i, r := range rows {
result[i] = TeamWithRole{
Team: db.Team{ID: r.ID, Name: r.Name, CreatedAt: r.CreatedAt, IsByoc: r.IsByoc, Slug: r.Slug, DeletedAt: r.DeletedAt},
Role: r.Role,
}
}
return result, nil
}
// CreateTeam creates a new team owned by the given user.
func (s *TeamService) CreateTeam(ctx context.Context, ownerUserID pgtype.UUID, name string) (TeamWithRole, error) {
if !teamNameRE.MatchString(name) {
return TeamWithRole{}, fmt.Errorf("invalid team name: must be 1-128 characters, A-Z a-z 0-9 space _")
}
tx, err := s.Pool.Begin(ctx)
if err != nil {
return TeamWithRole{}, fmt.Errorf("begin tx: %w", err)
}
defer tx.Rollback(ctx) //nolint:errcheck
qtx := s.DB.WithTx(tx)
teamID := id.NewTeamID()
team, err := qtx.InsertTeam(ctx, db.InsertTeamParams{
ID: teamID,
Name: name,
Slug: id.NewTeamSlug(),
})
if err != nil {
return TeamWithRole{}, fmt.Errorf("insert team: %w", err)
}
if err := qtx.InsertTeamMember(ctx, db.InsertTeamMemberParams{
UserID: ownerUserID,
TeamID: teamID,
IsDefault: false,
Role: "owner",
}); err != nil {
return TeamWithRole{}, fmt.Errorf("insert owner: %w", err)
}
if err := tx.Commit(ctx); err != nil {
return TeamWithRole{}, fmt.Errorf("commit: %w", err)
}
return TeamWithRole{Team: team, Role: "owner"}, nil
}
// RenameTeam updates the team name. Caller must be admin or owner (verified from DB).
func (s *TeamService) RenameTeam(ctx context.Context, teamID, callerUserID pgtype.UUID, newName string) error {
if !teamNameRE.MatchString(newName) {
return fmt.Errorf("invalid team name: must be 1-128 characters, A-Z a-z 0-9 space _")
}
role, err := s.callerRole(ctx, teamID, callerUserID)
if err != nil {
return err
}
if err := requireAdmin(role); err != nil {
return err
}
if err := s.DB.UpdateTeamName(ctx, db.UpdateTeamNameParams{ID: teamID, Name: newName}); err != nil {
return fmt.Errorf("update name: %w", err)
}
return nil
}
// DeleteTeam soft-deletes the team and destroys all running/paused/starting sandboxes.
// Caller must be owner (verified from DB). All DB records (sandboxes, keys, templates)
// are preserved; only the team's deleted_at is set and active VMs are stopped.
func (s *TeamService) DeleteTeam(ctx context.Context, teamID, callerUserID pgtype.UUID) error {
role, err := s.callerRole(ctx, teamID, callerUserID)
if err != nil {
return err
}
if role != "owner" {
return fmt.Errorf("forbidden: only the owner can delete a team")
}
// Collect active sandboxes and stop them.
sandboxes, err := s.DB.ListActiveSandboxesByTeam(ctx, teamID)
if err != nil {
return fmt.Errorf("list active sandboxes: %w", err)
}
var stopIDs []pgtype.UUID
for _, sb := range sandboxes {
host, hostErr := s.DB.GetHost(ctx, sb.HostID)
if hostErr == nil {
agent, agentErr := s.HostPool.GetForHost(host)
if agentErr == nil {
if _, err := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
SandboxId: id.FormatSandboxID(sb.ID),
})); err != nil && connect.CodeOf(err) != connect.CodeNotFound {
slog.Warn("team delete: failed to destroy sandbox", "sandbox_id", id.FormatSandboxID(sb.ID), "error", err)
}
}
}
stopIDs = append(stopIDs, sb.ID)
}
if len(stopIDs) > 0 {
if err := s.DB.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
Column1: stopIDs,
Status: "stopped",
}); err != nil {
// Do not proceed to soft-delete if sandbox statuses couldn't be updated,
// as that would leave orphaned "running" records for a deleted team.
return fmt.Errorf("update sandbox statuses: %w", err)
}
}
// Clean up team-owned templates from all hosts in the background.
go s.cleanupTeamTemplates(context.Background(), teamID)
if err := s.DB.SoftDeleteTeam(ctx, teamID); err != nil {
return fmt.Errorf("soft delete team: %w", err)
}
return nil
}
// cleanupTeamTemplates deletes all template files for a team from all online hosts,
// then removes the DB records. Called asynchronously during team deletion.
func (s *TeamService) cleanupTeamTemplates(ctx context.Context, teamID pgtype.UUID) {
templates, err := s.DB.ListTemplatesByTeamOnly(ctx, teamID)
if err != nil {
slog.Warn("team delete: failed to list templates for cleanup", "team_id", id.FormatTeamID(teamID), "error", err)
return
}
if len(templates) == 0 {
return
}
hosts, err := s.DB.ListActiveHosts(ctx)
if err != nil {
slog.Warn("team delete: failed to list hosts for template cleanup", "error", err)
return
}
for _, tmpl := range templates {
for _, host := range hosts {
if host.Status != "online" {
continue
}
agent, err := s.HostPool.GetForHost(host)
if err != nil {
continue
}
if _, err := agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{
TeamId: id.UUIDString(tmpl.TeamID),
TemplateId: id.UUIDString(tmpl.ID),
})); err != nil && connect.CodeOf(err) != connect.CodeNotFound {
slog.Warn("team delete: failed to delete template on host",
"host_id", id.FormatHostID(host.ID),
"template", tmpl.Name,
"error", err,
)
}
}
}
// Remove DB records.
if err := s.DB.DeleteTemplatesByTeam(ctx, teamID); err != nil {
slog.Warn("team delete: failed to delete template records", "team_id", id.FormatTeamID(teamID), "error", err)
}
}
// GetMembers returns all members of the team with their emails and roles.
func (s *TeamService) GetMembers(ctx context.Context, teamID pgtype.UUID) ([]MemberInfo, error) {
rows, err := s.DB.GetTeamMembers(ctx, teamID)
if err != nil {
return nil, fmt.Errorf("get members: %w", err)
}
members := make([]MemberInfo, len(rows))
for i, r := range rows {
var joinedAt time.Time
if r.JoinedAt.Valid {
joinedAt = r.JoinedAt.Time
}
members[i] = MemberInfo{
UserID: id.FormatUserID(r.ID),
Name: r.Name,
Email: r.Email,
Role: r.Role,
JoinedAt: joinedAt,
}
}
return members, nil
}
// AddMember adds an existing user (looked up by email) to the team as a member.
// Caller must be admin or owner (verified from DB).
func (s *TeamService) AddMember(ctx context.Context, teamID, callerUserID pgtype.UUID, email string) (MemberInfo, error) {
role, err := s.callerRole(ctx, teamID, callerUserID)
if err != nil {
return MemberInfo{}, err
}
if err := requireAdmin(role); err != nil {
return MemberInfo{}, err
}
target, err := s.DB.GetUserByEmail(ctx, email)
if err != nil {
if err == pgx.ErrNoRows {
return MemberInfo{}, fmt.Errorf("user not found: no account with that email")
}
return MemberInfo{}, fmt.Errorf("look up user: %w", err)
}
// Check if already a member.
_, memberCheckErr := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
UserID: target.ID,
TeamID: teamID,
})
if memberCheckErr == nil {
return MemberInfo{}, fmt.Errorf("invalid: user is already a member of this team")
} else if memberCheckErr != pgx.ErrNoRows {
return MemberInfo{}, fmt.Errorf("check membership: %w", memberCheckErr)
}
if err := s.DB.InsertTeamMember(ctx, db.InsertTeamMemberParams{
UserID: target.ID,
TeamID: teamID,
IsDefault: false,
Role: "member",
}); err != nil {
return MemberInfo{}, fmt.Errorf("insert member: %w", err)
}
return MemberInfo{UserID: id.FormatUserID(target.ID), Name: target.Name, Email: target.Email, Role: "member"}, nil
}
// RemoveMember removes a user from the team.
// Caller must be admin or owner (verified from DB). Owner cannot be removed.
func (s *TeamService) RemoveMember(ctx context.Context, teamID, callerUserID, targetUserID pgtype.UUID) error {
callerRole, err := s.callerRole(ctx, teamID, callerUserID)
if err != nil {
return err
}
if err := requireAdmin(callerRole); err != nil {
return err
}
targetMembership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
UserID: targetUserID,
TeamID: teamID,
})
if err != nil {
if err == pgx.ErrNoRows {
return fmt.Errorf("not found: user is not a member of this team")
}
return fmt.Errorf("get target membership: %w", err)
}
if targetMembership.Role == "owner" {
return fmt.Errorf("forbidden: the owner cannot be removed from the team")
}
if err := s.DB.DeleteTeamMember(ctx, db.DeleteTeamMemberParams{
TeamID: teamID,
UserID: targetUserID,
}); err != nil {
return fmt.Errorf("delete member: %w", err)
}
return nil
}
// UpdateMemberRole changes a member's role to admin or member.
// Caller must be admin or owner (verified from DB). Owner's role cannot be changed.
// Valid target roles: "admin", "member".
func (s *TeamService) UpdateMemberRole(ctx context.Context, teamID, callerUserID, targetUserID pgtype.UUID, newRole string) error {
if newRole != "admin" && newRole != "member" {
return fmt.Errorf("invalid: role must be admin or member")
}
callerRole, err := s.callerRole(ctx, teamID, callerUserID)
if err != nil {
return err
}
if err := requireAdmin(callerRole); err != nil {
return err
}
targetMembership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
UserID: targetUserID,
TeamID: teamID,
})
if err != nil {
if err == pgx.ErrNoRows {
return fmt.Errorf("not found: user is not a member of this team")
}
return fmt.Errorf("get target membership: %w", err)
}
if targetMembership.Role == "owner" {
return fmt.Errorf("forbidden: the owner's role cannot be changed")
}
if err := s.DB.UpdateMemberRole(ctx, db.UpdateMemberRoleParams{
TeamID: teamID,
UserID: targetUserID,
Role: newRole,
}); err != nil {
return fmt.Errorf("update role: %w", err)
}
return nil
}
// LeaveTeam removes the calling user from the team.
// The owner cannot leave; they must delete the team instead.
func (s *TeamService) LeaveTeam(ctx context.Context, teamID, callerUserID pgtype.UUID) error {
role, err := s.callerRole(ctx, teamID, callerUserID)
if err != nil {
return err
}
if role == "owner" {
return fmt.Errorf("forbidden: the owner cannot leave the team; delete the team instead")
}
if err := s.DB.DeleteTeamMember(ctx, db.DeleteTeamMemberParams{
TeamID: teamID,
UserID: callerUserID,
}); err != nil {
return fmt.Errorf("leave team: %w", err)
}
return nil
}
// SetBYOC enables the BYOC feature flag for a team. Once enabled, BYOC cannot
// be disabled — it is a one-way transition.
// Admin-only — the caller must verify admin status before invoking this.
func (s *TeamService) SetBYOC(ctx context.Context, teamID pgtype.UUID, enabled bool) error {
team, err := s.DB.GetTeam(ctx, teamID)
if err != nil {
return fmt.Errorf("team not found: %w", err)
}
if team.DeletedAt.Valid {
return fmt.Errorf("team not found")
}
if !enabled {
return fmt.Errorf("invalid request: BYOC cannot be disabled once enabled")
}
if team.IsByoc {
// Already enabled — idempotent, no-op.
return nil
}
if err := s.DB.SetTeamBYOC(ctx, db.SetTeamBYOCParams{ID: teamID, IsByoc: true}); err != nil {
return fmt.Errorf("set byoc: %w", err)
}
return nil
}

View File

@ -3,6 +3,8 @@ package service
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/sandbox/internal/db"
)
@ -14,7 +16,7 @@ type TemplateService struct {
// List returns all templates belonging to the given team. If typeFilter is
// non-empty, only templates of that type ("base" or "snapshot") are returned.
func (s *TemplateService) List(ctx context.Context, teamID, typeFilter string) ([]db.Template, error) {
func (s *TemplateService) List(ctx context.Context, teamID pgtype.UUID, typeFilter string) ([]db.Template, error) {
if typeFilter != "" {
return s.DB.ListTemplatesByTeamAndType(ctx, db.ListTemplatesByTeamAndTypeParams{
TeamID: teamID,