forked from wrenn/wrenn
v0.0.1 (#8)
Co-authored-by: Tasnim Kabir Sadik <tksadik92@gmail.com> Reviewed-on: wrenn/sandbox#8
This commit is contained in:
@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
@ -22,7 +24,7 @@ type APIKeyCreateResult struct {
|
||||
}
|
||||
|
||||
// Create generates a new API key for the given team.
|
||||
func (s *APIKeyService) Create(ctx context.Context, teamID, userID, name string) (APIKeyCreateResult, error) {
|
||||
func (s *APIKeyService) Create(ctx context.Context, teamID, userID pgtype.UUID, name string) (APIKeyCreateResult, error) {
|
||||
if name == "" {
|
||||
name = "Unnamed API Key"
|
||||
}
|
||||
@ -48,16 +50,16 @@ func (s *APIKeyService) Create(ctx context.Context, teamID, userID, name string)
|
||||
}
|
||||
|
||||
// List returns all API keys belonging to the given team.
|
||||
func (s *APIKeyService) List(ctx context.Context, teamID string) ([]db.TeamApiKey, error) {
|
||||
func (s *APIKeyService) List(ctx context.Context, teamID pgtype.UUID) ([]db.TeamApiKey, error) {
|
||||
return s.DB.ListAPIKeysByTeam(ctx, teamID)
|
||||
}
|
||||
|
||||
// ListWithCreator returns all API keys for the team, joined with the creator's email.
|
||||
func (s *APIKeyService) ListWithCreator(ctx context.Context, teamID string) ([]db.ListAPIKeysByTeamWithCreatorRow, error) {
|
||||
func (s *APIKeyService) ListWithCreator(ctx context.Context, teamID pgtype.UUID) ([]db.ListAPIKeysByTeamWithCreatorRow, error) {
|
||||
return s.DB.ListAPIKeysByTeamWithCreator(ctx, teamID)
|
||||
}
|
||||
|
||||
// Delete removes an API key by ID, scoped to the given team.
|
||||
func (s *APIKeyService) Delete(ctx context.Context, keyID, teamID string) error {
|
||||
func (s *APIKeyService) Delete(ctx context.Context, keyID, teamID pgtype.UUID) error {
|
||||
return s.DB.DeleteAPIKey(ctx, db.DeleteAPIKeyParams{ID: keyID, TeamID: teamID})
|
||||
}
|
||||
|
||||
113
internal/service/audit.go
Normal file
113
internal/service/audit.go
Normal file
@ -0,0 +1,113 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
)
|
||||
|
||||
// auditMaxLimit is the hard cap on the page size accepted by AuditService.List;
// larger requested limits are clamped down to this value.
const auditMaxLimit = 200
|
||||
|
||||
// AuditEntry is a single audit log record returned by List.
// It is the service-level view of a row from the audit log table, with IDs
// rendered as formatted strings and the metadata JSON decoded into a map.
type AuditEntry struct {
	ID        string // formatted audit-log ID (see id.FormatAuditLogID)
	TeamID    string // formatted team ID (see id.FormatTeamID)
	ActorType string
	ActorID   string // empty for system
	ActorName string // empty for system

	ResourceType string
	ResourceID   string // empty when not applicable

	Action string
	Scope  string
	Status string // 'success', 'info', 'warning', 'error'

	// Metadata is the decoded JSON payload attached to the event; nil when the
	// row had no metadata or the stored JSON failed to decode (best-effort).
	Metadata  map[string]any
	CreatedAt time.Time
}
|
||||
|
||||
// AuditListParams controls the ListAuditLogs query.
// Cursor pagination is keyset-based: (Before, BeforeID) identify the last row
// of the previous page; both zero values mean "start from the latest".
type AuditListParams struct {
	TeamID        pgtype.UUID
	AdminScoped   bool        // true → include admin-scoped events; false → team-scoped only
	ResourceTypes []string    // empty = no filter; multiple values = OR match
	Actions       []string    // empty = no filter; multiple values = OR match
	Before        time.Time   // zero = no cursor (start from latest)
	BeforeID      pgtype.UUID // tie-breaker: id of the last item at the Before timestamp; zero = no tie-break
	Limit         int         // clamped to auditMaxLimit by the handler (and again by List)
}
|
||||
|
||||
// AuditService provides the read side of the audit log.
// It only reads; writing audit events happens elsewhere.
type AuditService struct {
	DB *db.Queries // generated query layer; required
}
|
||||
|
||||
// List returns a page of audit log entries for the given team.
|
||||
func (s *AuditService) List(ctx context.Context, p AuditListParams) ([]AuditEntry, error) {
|
||||
limit := p.Limit
|
||||
if limit <= 0 {
|
||||
limit = 50
|
||||
}
|
||||
if limit > auditMaxLimit {
|
||||
limit = auditMaxLimit
|
||||
}
|
||||
|
||||
scopes := []string{"team"}
|
||||
if p.AdminScoped {
|
||||
scopes = append(scopes, "admin")
|
||||
}
|
||||
|
||||
var before pgtype.Timestamptz
|
||||
if !p.Before.IsZero() {
|
||||
before = pgtype.Timestamptz{Time: p.Before, Valid: true}
|
||||
}
|
||||
|
||||
resourceTypes := p.ResourceTypes
|
||||
if resourceTypes == nil {
|
||||
resourceTypes = []string{}
|
||||
}
|
||||
actions := p.Actions
|
||||
if actions == nil {
|
||||
actions = []string{}
|
||||
}
|
||||
|
||||
rows, err := s.DB.ListAuditLogs(ctx, db.ListAuditLogsParams{
|
||||
TeamID: p.TeamID,
|
||||
Column2: scopes,
|
||||
Column3: resourceTypes,
|
||||
Column4: actions,
|
||||
Column5: before,
|
||||
ID: p.BeforeID,
|
||||
Limit: int32(limit),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list audit logs: %w", err)
|
||||
}
|
||||
|
||||
entries := make([]AuditEntry, len(rows))
|
||||
for i, row := range rows {
|
||||
var meta map[string]any
|
||||
if len(row.Metadata) > 0 {
|
||||
_ = json.Unmarshal(row.Metadata, &meta)
|
||||
}
|
||||
entries[i] = AuditEntry{
|
||||
ID: id.FormatAuditLogID(row.ID),
|
||||
TeamID: id.FormatTeamID(row.TeamID),
|
||||
ActorType: row.ActorType,
|
||||
ActorID: row.ActorID.String,
|
||||
ActorName: row.ActorName,
|
||||
ResourceType: row.ResourceType,
|
||||
ResourceID: row.ResourceID.String,
|
||||
Action: row.Action,
|
||||
Scope: row.Scope,
|
||||
Status: row.Status,
|
||||
Metadata: meta,
|
||||
CreatedAt: row.CreatedAt.Time,
|
||||
}
|
||||
}
|
||||
return entries, nil
|
||||
}
|
||||
605
internal/service/build.go
Normal file
605
internal/service/build.go
Normal file
@ -0,0 +1,605 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/recipe"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/scheduler"
|
||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||
)
|
||||
|
||||
const (
	// buildQueueKey is the Redis list workers BLPOP build IDs from.
	buildQueueKey = "wrenn:build_queue"
	// buildCommandTimeout is the default per-step timeout applied to the
	// user-supplied recipe phase (pre/post phases pass 0, i.e. no default).
	buildCommandTimeout = 30 * time.Second
)
|
||||
|
||||
// preBuildCmds run before the user recipe to prepare the build environment.
// They are hardcoded and assumed always parseable by recipe.ParseRecipe.
var preBuildCmds = []string{
	"RUN apt update",
}
|
||||
|
||||
// postBuildCmds run after the user recipe to clean up caches and reduce image size.
// Like preBuildCmds, these are hardcoded and assumed always parseable.
var postBuildCmds = []string{
	"RUN apt clean",
	"RUN apt autoremove -y",
	"RUN rm -rf /var/lib/apt/lists/*",
}
|
||||
|
||||
// buildAgentClient is the subset of the host agent client used by the build
// worker. Defining the consumer-side interface keeps BuildService testable
// without a real agent connection.
type buildAgentClient interface {
	CreateSandbox(ctx context.Context, req *connect.Request[pb.CreateSandboxRequest]) (*connect.Response[pb.CreateSandboxResponse], error)
	DestroySandbox(ctx context.Context, req *connect.Request[pb.DestroySandboxRequest]) (*connect.Response[pb.DestroySandboxResponse], error)
	Exec(ctx context.Context, req *connect.Request[pb.ExecRequest]) (*connect.Response[pb.ExecResponse], error)
	CreateSnapshot(ctx context.Context, req *connect.Request[pb.CreateSnapshotRequest]) (*connect.Response[pb.CreateSnapshotResponse], error)
	FlattenRootfs(ctx context.Context, req *connect.Request[pb.FlattenRootfsRequest]) (*connect.Response[pb.FlattenRootfsResponse], error)
}
|
||||
|
||||
// BuildService handles template build orchestration: creating/enqueueing
// builds, running queue workers, and cancelling in-flight builds.
// Contains a mutex, so it must not be copied; use *BuildService.
type BuildService struct {
	DB        *db.Queries
	Redis     *redis.Client
	Pool      *lifecycle.HostClientPool
	Scheduler scheduler.HostScheduler

	mu        sync.Mutex                    // guards cancelMap
	cancelMap map[string]context.CancelFunc // buildID → per-build cancel func; lazily initialized
}
|
||||
|
||||
// BuildCreateParams holds the parameters for creating a template build.
// Zero values for BaseTemplate, VCPUs and MemoryMB are replaced with defaults
// by Create ("minimal", 1, 512 respectively).
type BuildCreateParams struct {
	Name         string
	BaseTemplate string
	Recipe       []string // user recipe lines, stored as JSON on the build row
	Healthcheck  string   // empty → image-only template (no snapshot)
	VCPUs        int32
	MemoryMB     int32
	SkipPrePost  bool // true → skip the hardcoded pre/post build phases
}
|
||||
|
||||
// Create inserts a new build record and enqueues it to Redis.
|
||||
func (s *BuildService) Create(ctx context.Context, p BuildCreateParams) (db.TemplateBuild, error) {
|
||||
if p.BaseTemplate == "" {
|
||||
p.BaseTemplate = "minimal"
|
||||
}
|
||||
if p.VCPUs <= 0 {
|
||||
p.VCPUs = 1
|
||||
}
|
||||
if p.MemoryMB <= 0 {
|
||||
p.MemoryMB = 512
|
||||
}
|
||||
|
||||
recipeJSON, err := json.Marshal(p.Recipe)
|
||||
if err != nil {
|
||||
return db.TemplateBuild{}, fmt.Errorf("marshal recipe: %w", err)
|
||||
}
|
||||
|
||||
buildID := id.NewBuildID()
|
||||
buildIDStr := id.FormatBuildID(buildID)
|
||||
newTemplateID := id.NewTemplateID()
|
||||
|
||||
defaultSteps := len(preBuildCmds) + len(postBuildCmds)
|
||||
if p.SkipPrePost {
|
||||
defaultSteps = 0
|
||||
}
|
||||
|
||||
build, err := s.DB.InsertTemplateBuild(ctx, db.InsertTemplateBuildParams{
|
||||
ID: buildID,
|
||||
Name: p.Name,
|
||||
BaseTemplate: p.BaseTemplate,
|
||||
Recipe: recipeJSON,
|
||||
Healthcheck: p.Healthcheck,
|
||||
Vcpus: p.VCPUs,
|
||||
MemoryMb: p.MemoryMB,
|
||||
TotalSteps: int32(len(p.Recipe) + defaultSteps),
|
||||
TemplateID: newTemplateID,
|
||||
TeamID: id.PlatformTeamID,
|
||||
SkipPrePost: p.SkipPrePost,
|
||||
})
|
||||
if err != nil {
|
||||
return db.TemplateBuild{}, fmt.Errorf("insert build: %w", err)
|
||||
}
|
||||
|
||||
// Enqueue build ID (as formatted string) to Redis for workers to pick up.
|
||||
if err := s.Redis.RPush(ctx, buildQueueKey, buildIDStr).Err(); err != nil {
|
||||
return db.TemplateBuild{}, fmt.Errorf("enqueue build: %w", err)
|
||||
}
|
||||
|
||||
return build, nil
|
||||
}
|
||||
|
||||
// Get returns a single build by ID.
// Errors from the underlying query (including "not found") are returned unchanged.
func (s *BuildService) Get(ctx context.Context, buildID pgtype.UUID) (db.TemplateBuild, error) {
	return s.DB.GetTemplateBuild(ctx, buildID)
}
|
||||
|
||||
// List returns all builds ordered by creation time (ordering is defined by
// the ListTemplateBuilds query).
func (s *BuildService) List(ctx context.Context) ([]db.TemplateBuild, error) {
	return s.DB.ListTemplateBuilds(ctx)
}
|
||||
|
||||
// Cancel cancels a pending or running build. For pending builds the status is
|
||||
// updated in the DB and the worker skips it when dequeued. For running builds
|
||||
// the per-build context is cancelled, which causes the current exec step to
|
||||
// abort; executeBuild then detects the cancellation and records the status.
|
||||
func (s *BuildService) Cancel(ctx context.Context, buildID pgtype.UUID) error {
|
||||
build, err := s.DB.GetTemplateBuild(ctx, buildID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get build: %w", err)
|
||||
}
|
||||
switch build.Status {
|
||||
case "success", "failed", "cancelled":
|
||||
return fmt.Errorf("build is already %s", build.Status)
|
||||
}
|
||||
|
||||
// Mark cancelled in DB first. This handles both pending builds (which haven't
|
||||
// been picked up yet) and acts as a flag for executeBuild to check on start.
|
||||
if _, err := s.DB.UpdateBuildStatus(ctx, db.UpdateBuildStatusParams{
|
||||
ID: buildID, Status: "cancelled",
|
||||
}); err != nil {
|
||||
return fmt.Errorf("update build status: %w", err)
|
||||
}
|
||||
|
||||
// If the build is currently running, signal its context.
|
||||
buildIDStr := id.FormatBuildID(buildID)
|
||||
s.mu.Lock()
|
||||
cancel, running := s.cancelMap[buildIDStr]
|
||||
s.mu.Unlock()
|
||||
if running {
|
||||
cancel()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// StartWorkers launches n goroutines that consume from the Redis build queue.
|
||||
// The returned cancel function stops all workers.
|
||||
func (s *BuildService) StartWorkers(ctx context.Context, n int) context.CancelFunc {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
for i := range n {
|
||||
go s.worker(ctx, i)
|
||||
}
|
||||
slog.Info("build workers started", "count", n)
|
||||
return cancel
|
||||
}
|
||||
|
||||
func (s *BuildService) worker(ctx context.Context, workerID int) {
|
||||
log := slog.With("worker", workerID)
|
||||
for {
|
||||
// BLPOP blocks until a build ID is available or context is cancelled.
|
||||
result, err := s.Redis.BLPop(ctx, 0, buildQueueKey).Result()
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
log.Info("build worker shutting down")
|
||||
return
|
||||
}
|
||||
log.Error("redis BLPOP error", "error", err)
|
||||
time.Sleep(time.Second)
|
||||
continue
|
||||
}
|
||||
// result[0] is the key, result[1] is the build ID (formatted string).
|
||||
buildIDStr := result[1]
|
||||
log.Info("picked up build", "build_id", buildIDStr)
|
||||
s.executeBuild(ctx, buildIDStr)
|
||||
}
|
||||
}
|
||||
|
||||
// executeBuild runs one template build end-to-end: fetch the row, create a
// sandbox on a scheduled host, execute the pre/user/post recipe phases, run
// the optional healthcheck, snapshot or flatten the result, and record the
// final status. It is called synchronously from a worker goroutine; all
// failure paths log and/or call failBuild rather than returning an error.
func (s *BuildService) executeBuild(ctx context.Context, buildIDStr string) {
	log := slog.With("build_id", buildIDStr)

	buildID, err := id.ParseBuildID(buildIDStr)
	if err != nil {
		log.Error("invalid build ID from queue", "error", err)
		return
	}

	// Create a per-build context so this build can be cancelled independently of
	// the worker. Register in cancelMap before fetching the build so that a
	// concurrent Cancel call can always find and signal it.
	buildCtx, buildCancel := context.WithCancel(ctx)
	defer buildCancel()

	s.mu.Lock()
	if s.cancelMap == nil {
		s.cancelMap = make(map[string]context.CancelFunc)
	}
	s.cancelMap[buildIDStr] = buildCancel
	s.mu.Unlock()
	defer func() {
		s.mu.Lock()
		delete(s.cancelMap, buildIDStr)
		s.mu.Unlock()
	}()

	build, err := s.DB.GetTemplateBuild(buildCtx, buildID)
	if err != nil {
		log.Error("failed to fetch build", "error", err)
		return
	}

	// Skip if already cancelled (Cancel was called before we dequeued).
	if build.Status == "cancelled" {
		log.Info("build already cancelled, skipping")
		return
	}

	// Mark as running.
	if _, err := s.DB.UpdateBuildStatus(buildCtx, db.UpdateBuildStatusParams{
		ID: buildID, Status: "running",
	}); err != nil {
		log.Error("failed to update build status", "error", err)
		return
	}

	// Parse user recipe (stored as a JSON array of strings on the build row).
	var userRecipe []string
	if err := json.Unmarshal(build.Recipe, &userRecipe); err != nil {
		s.failBuild(buildCtx, buildID, fmt.Sprintf("invalid recipe JSON: %v", err))
		return
	}

	// Pick a platform host and create a sandbox.
	host, err := s.Scheduler.SelectHost(buildCtx, id.PlatformTeamID, false)
	if err != nil {
		s.failBuild(buildCtx, buildID, fmt.Sprintf("no host available: %v", err))
		return
	}

	agent, err := s.Pool.GetForHost(host)
	if err != nil {
		s.failBuild(buildCtx, buildID, fmt.Sprintf("agent client error: %v", err))
		return
	}

	sandboxID := id.NewSandboxID()
	sandboxIDStr := id.FormatSandboxID(sandboxID)
	log = log.With("sandbox_id", sandboxIDStr, "host_id", id.FormatHostID(host.ID))

	// Resolve the base template to UUIDs. "minimal" is the zero sentinel.
	baseTeamID := id.PlatformTeamID
	baseTemplateID := id.MinimalTemplateID
	if build.BaseTemplate != "minimal" {
		baseTmpl, err := s.DB.GetPlatformTemplateByName(buildCtx, build.BaseTemplate)
		if err != nil {
			s.failBuild(buildCtx, buildID, fmt.Sprintf("base template %q not found: %v", build.BaseTemplate, err))
			return
		}
		baseTeamID = baseTmpl.TeamID
		baseTemplateID = baseTmpl.ID
	}

	resp, err := agent.CreateSandbox(buildCtx, connect.NewRequest(&pb.CreateSandboxRequest{
		SandboxId:  sandboxIDStr,
		Template:   build.BaseTemplate,
		TeamId:     id.UUIDString(baseTeamID),
		TemplateId: id.UUIDString(baseTemplateID),
		Vcpus:      build.Vcpus,
		MemoryMb:   build.MemoryMb,
		TimeoutSec: 0,    // no auto-pause for builds
		DiskSizeMb: 5120, // 5 GB for template builds
	}))
	if err != nil {
		s.failBuild(buildCtx, buildID, fmt.Sprintf("create sandbox failed: %v", err))
		return
	}
	_ = resp // response payload is not needed; only success matters here

	// Record sandbox/host association. Best-effort: a failure here does not
	// abort the build, so the error is deliberately discarded.
	_ = s.DB.UpdateBuildSandbox(buildCtx, db.UpdateBuildSandboxParams{
		ID:        buildID,
		SandboxID: sandboxID,
		HostID:    host.ID,
	})

	// Parse recipe steps. preBuildCmds and postBuildCmds are hardcoded and always
	// valid; panic on error is appropriate here since it would be a programmer mistake.
	preBuildSteps, err := recipe.ParseRecipe(preBuildCmds)
	if err != nil {
		panic(fmt.Sprintf("invalid pre-build recipe: %v", err))
	}
	userRecipeSteps, err := recipe.ParseRecipe(userRecipe)
	if err != nil {
		s.destroySandbox(buildCtx, agent, sandboxIDStr)
		s.failBuild(buildCtx, buildID, fmt.Sprintf("recipe parse error: %v", err))
		return
	}
	postBuildSteps, err := recipe.ParseRecipe(postBuildCmds)
	if err != nil {
		panic(fmt.Sprintf("invalid post-build recipe: %v", err))
	}

	var logs []recipe.BuildLogEntry
	step := 0

	// Seed the exec environment from the sandbox itself; fall back to a
	// minimal PATH/HOME if the probe fails.
	envVars, err := s.fetchSandboxEnv(buildCtx, agent, sandboxIDStr)
	if err != nil {
		log.Warn("failed to fetch sandbox env, using defaults", "error", err)
		envVars = map[string]string{
			"PATH": "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin",
			"HOME": "/root",
		}
	}
	bctx := &recipe.ExecContext{EnvVars: envVars}

	// runPhase executes one phase of steps, accumulates logs/progress, and on
	// failure tears down the sandbox and records the failure (unless cancelled).
	runPhase := func(phase string, steps []recipe.Step, defaultTimeout time.Duration) bool {
		newEntries, nextStep, ok := recipe.Execute(buildCtx, phase, steps, sandboxIDStr, step, defaultTimeout, bctx, agent.Exec)
		logs = append(logs, newEntries...)
		step = nextStep
		s.updateLogs(buildCtx, buildID, step, logs)
		if !ok {
			s.destroySandbox(buildCtx, agent, sandboxIDStr)
			// If the build was cancelled, status is already set — don't overwrite with "failed".
			if buildCtx.Err() != nil {
				return false
			}
			last := newEntries[len(newEntries)-1]
			reason := last.Stderr
			if reason == "" {
				reason = fmt.Sprintf("exit code %d", last.Exit)
			}
			s.failBuild(buildCtx, buildID, fmt.Sprintf("%s step %d failed: %s", phase, step, reason))
		}
		return ok
	}

	if !build.SkipPrePost {
		if !runPhase("pre-build", preBuildSteps, 0) {
			return
		}
	}
	if !runPhase("recipe", userRecipeSteps, buildCommandTimeout) {
		return
	}
	if !build.SkipPrePost {
		if !runPhase("post-build", postBuildSteps, 0) {
			return
		}
	}

	// Healthcheck or direct snapshot.
	var sizeBytes int64
	if build.Healthcheck != "" {
		hc, err := recipe.ParseHealthcheck(build.Healthcheck)
		if err != nil {
			s.destroySandbox(buildCtx, agent, sandboxIDStr)
			s.failBuild(buildCtx, buildID, fmt.Sprintf("invalid healthcheck: %v", err))
			return
		}
		log.Info("running healthcheck", "cmd", hc.Cmd, "interval", hc.Interval, "timeout", hc.Timeout, "start_period", hc.StartPeriod, "retries", hc.Retries)
		if err := s.waitForHealthcheck(buildCtx, agent, sandboxIDStr, hc); err != nil {
			s.destroySandbox(buildCtx, agent, sandboxIDStr)
			if buildCtx.Err() != nil {
				return
			}
			s.failBuild(buildCtx, buildID, fmt.Sprintf("healthcheck failed: %v", err))
			return
		}

		// Healthcheck passed → full snapshot (with memory/CPU state).
		log.Info("healthcheck passed, creating snapshot")
		snapResp, err := agent.CreateSnapshot(buildCtx, connect.NewRequest(&pb.CreateSnapshotRequest{
			SandboxId:  sandboxIDStr,
			Name:       build.Name,
			TeamId:     id.UUIDString(build.TeamID),
			TemplateId: id.UUIDString(build.TemplateID),
		}))
		if err != nil {
			s.destroySandbox(buildCtx, agent, sandboxIDStr)
			if buildCtx.Err() != nil {
				return
			}
			s.failBuild(buildCtx, buildID, fmt.Sprintf("create snapshot failed: %v", err))
			return
		}
		sizeBytes = snapResp.Msg.SizeBytes
	} else {
		// No healthcheck → image-only template (rootfs only).
		log.Info("no healthcheck, flattening rootfs")
		flatResp, err := agent.FlattenRootfs(buildCtx, connect.NewRequest(&pb.FlattenRootfsRequest{
			SandboxId:  sandboxIDStr,
			Name:       build.Name,
			TeamId:     id.UUIDString(build.TeamID),
			TemplateId: id.UUIDString(build.TemplateID),
		}))
		if err != nil {
			s.destroySandbox(buildCtx, agent, sandboxIDStr)
			if buildCtx.Err() != nil {
				return
			}
			s.failBuild(buildCtx, buildID, fmt.Sprintf("flatten rootfs failed: %v", err))
			return
		}
		sizeBytes = flatResp.Msg.SizeBytes
	}

	// Insert into templates table as a global (platform) template.
	templateType := "base"
	if build.Healthcheck != "" {
		templateType = "snapshot"
	}

	if _, err := s.DB.InsertTemplate(buildCtx, db.InsertTemplateParams{
		ID:        build.TemplateID,
		Name:      build.Name,
		Type:      templateType,
		Vcpus:     build.Vcpus,
		MemoryMb:  build.MemoryMb,
		SizeBytes: sizeBytes,
		TeamID:    id.PlatformTeamID,
	}); err != nil {
		log.Error("failed to insert template record", "error", err)
		// Build succeeded on disk, just DB record failed — don't mark as failed.
	}

	// For CreateSnapshot, the sandbox is already destroyed by the snapshot process.
	// For FlattenRootfs, the sandbox is already destroyed by the flatten process.
	// No additional destroy needed.

	// Mark build as success.
	if _, err := s.DB.UpdateBuildStatus(buildCtx, db.UpdateBuildStatusParams{
		ID: buildID, Status: "success",
	}); err != nil {
		log.Error("failed to mark build as success", "error", err)
	}

	log.Info("template build completed successfully", "name", build.Name)
}
|
||||
|
||||
// waitForHealthcheck repeatedly executes the healthcheck command inside the
|
||||
// sandbox according to the config's interval, timeout, start-period, and
|
||||
// retries.
|
||||
// During the start period, failures are not counted toward the retry budget.
|
||||
// Returns nil on the first successful check, or an error if retries are
|
||||
// exhausted, the deadline passes, or the context is cancelled.
|
||||
func (s *BuildService) waitForHealthcheck(ctx context.Context, agent buildAgentClient, sandboxIDStr string, hc recipe.HealthcheckConfig) error {
|
||||
ticker := time.NewTicker(hc.Interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// When retries > 0, set a deadline based on the retry budget.
|
||||
// When retries == 0 (unlimited), rely solely on the parent context deadline.
|
||||
var deadlineCh <-chan time.Time
|
||||
if hc.Retries > 0 {
|
||||
deadline := time.NewTimer(hc.StartPeriod + time.Duration(hc.Retries+1)*hc.Interval)
|
||||
defer deadline.Stop()
|
||||
deadlineCh = deadline.C
|
||||
}
|
||||
|
||||
startedAt := time.Now()
|
||||
failCount := 0
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-deadlineCh:
|
||||
return fmt.Errorf("healthcheck timed out: exceeded %d attempts over %s", failCount, time.Since(startedAt))
|
||||
case <-ticker.C:
|
||||
execCtx, cancel := context.WithTimeout(ctx, hc.Timeout)
|
||||
resp, err := agent.Exec(execCtx, connect.NewRequest(&pb.ExecRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Cmd: "/bin/sh",
|
||||
Args: []string{"-c", hc.Cmd},
|
||||
TimeoutSec: int32(hc.Timeout.Seconds()),
|
||||
}))
|
||||
cancel()
|
||||
|
||||
if err != nil {
|
||||
slog.Debug("healthcheck exec error (retrying)", "error", err)
|
||||
if time.Since(startedAt) >= hc.StartPeriod {
|
||||
failCount++
|
||||
if hc.Retries > 0 && failCount >= hc.Retries {
|
||||
return fmt.Errorf("healthcheck failed after %d retries: exec error: %w", failCount, err)
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
if resp.Msg.ExitCode == 0 {
|
||||
return nil
|
||||
}
|
||||
slog.Debug("healthcheck failed (retrying)", "exit_code", resp.Msg.ExitCode)
|
||||
if time.Since(startedAt) >= hc.StartPeriod {
|
||||
failCount++
|
||||
if hc.Retries > 0 && failCount >= hc.Retries {
|
||||
return fmt.Errorf("healthcheck failed after %d retries: exit code %d", failCount, resp.Msg.ExitCode)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BuildService) updateLogs(ctx context.Context, buildID pgtype.UUID, step int, logs []recipe.BuildLogEntry) {
|
||||
logsJSON, err := json.Marshal(logs)
|
||||
if err != nil {
|
||||
slog.Warn("failed to marshal build logs", "error", err)
|
||||
return
|
||||
}
|
||||
if err := s.DB.UpdateBuildProgress(ctx, db.UpdateBuildProgressParams{
|
||||
ID: buildID,
|
||||
CurrentStep: int32(step),
|
||||
Logs: logsJSON,
|
||||
}); err != nil {
|
||||
slog.Warn("failed to update build progress", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BuildService) failBuild(_ context.Context, buildID pgtype.UUID, errMsg string) {
|
||||
slog.Error("build failed", "build_id", id.FormatBuildID(buildID), "error", errMsg)
|
||||
// Use a detached context so DB writes survive parent context cancellation (e.g. shutdown).
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := s.DB.UpdateBuildError(ctx, db.UpdateBuildErrorParams{
|
||||
ID: buildID,
|
||||
Error: errMsg,
|
||||
}); err != nil {
|
||||
slog.Error("failed to update build error", "build_id", id.FormatBuildID(buildID), "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BuildService) destroySandbox(_ context.Context, agent buildAgentClient, sandboxIDStr string) {
|
||||
// Use a detached context so cleanup succeeds even during shutdown.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
if _, err := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
})); err != nil {
|
||||
slog.Warn("failed to destroy build sandbox", "sandbox_id", sandboxIDStr, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// fetchSandboxEnv executes the 'env' command inside the specified sandbox via
|
||||
// the build agent and returns environment variables
|
||||
func (s *BuildService) fetchSandboxEnv(ctx context.Context,
|
||||
agent buildAgentClient, sandboxIDStr string) (map[string]string, error) {
|
||||
resp, err := agent.Exec(ctx, connect.NewRequest(&pb.ExecRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Cmd: "/bin/sh",
|
||||
Args: []string{"-c", "env"},
|
||||
TimeoutSec: 10,
|
||||
}))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetch env: %w", err)
|
||||
}
|
||||
|
||||
if resp.Msg.ExitCode != 0 {
|
||||
return nil, fmt.Errorf("fetch env: command exited with code %d",
|
||||
resp.Msg.ExitCode)
|
||||
}
|
||||
|
||||
return parseSandboxEnv(string(resp.Msg.Stdout)), nil
|
||||
}
|
||||
|
||||
// parseSandboxEnv converts the raw newline-separated output of an 'env'
// command into a map. It skips empty lines and malformed entries, and
// correctly handles values containing '=' (only the first '=' splits).
func parseSandboxEnv(raw string) map[string]string {
	vars := make(map[string]string)
	for _, entry := range strings.Split(raw, "\n") {
		entry = strings.TrimSpace(entry)
		if entry == "" {
			continue
		}
		key, value, found := strings.Cut(entry, "=")
		if !found {
			continue // no '=' at all → malformed, skip
		}
		vars[key] = value
	}
	return vars
}
|
||||
@ -2,12 +2,14 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/redis/go-redis/v9"
|
||||
@ -15,6 +17,8 @@ import (
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
|
||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||
)
|
||||
|
||||
// HostService provides host management operations.
|
||||
@ -22,15 +26,17 @@ type HostService struct {
|
||||
DB *db.Queries
|
||||
Redis *redis.Client
|
||||
JWT []byte
|
||||
Pool *lifecycle.HostClientPool
|
||||
CA *auth.CA // nil disables mTLS cert issuance (dev/test environments)
|
||||
}
|
||||
|
||||
// HostCreateParams holds the parameters for creating a host.
|
||||
type HostCreateParams struct {
|
||||
Type string
|
||||
TeamID string // required for BYOC, empty for regular
|
||||
TeamID pgtype.UUID // required for BYOC, zero value for regular
|
||||
Provider string
|
||||
AvailabilityZone string
|
||||
RequestingUserID string
|
||||
RequestingUserID pgtype.UUID
|
||||
IsRequestorAdmin bool
|
||||
}
|
||||
|
||||
@ -50,10 +56,34 @@ type HostRegisterParams struct {
|
||||
Address string
|
||||
}
|
||||
|
||||
// HostRegisterResult holds the registered host and its long-lived JWT.
|
||||
// HostRegisterResult holds the registered host, its short-lived JWT, a long-lived
|
||||
// refresh token, and optionally the host's mTLS certificate material.
|
||||
type HostRegisterResult struct {
|
||||
Host db.Host
|
||||
JWT string
|
||||
Host db.Host
|
||||
JWT string
|
||||
RefreshToken string
|
||||
// mTLS cert material — empty when CA is not configured.
|
||||
CertPEM string
|
||||
KeyPEM string
|
||||
CACertPEM string
|
||||
}
|
||||
|
||||
// HostRefreshResult holds a new JWT and rotated refresh token after a successful
|
||||
// refresh, plus refreshed mTLS certificate material when CA is configured.
|
||||
type HostRefreshResult struct {
|
||||
Host db.Host
|
||||
JWT string
|
||||
RefreshToken string
|
||||
// mTLS cert material — empty when CA is not configured.
|
||||
CertPEM string
|
||||
KeyPEM string
|
||||
CACertPEM string
|
||||
}
|
||||
|
||||
// HostDeletePreview describes what will be affected by deleting a host.
|
||||
type HostDeletePreview struct {
|
||||
Host db.Host
|
||||
SandboxIDs []string
|
||||
}
|
||||
|
||||
// regTokenPayload is the JSON stored in Redis for registration tokens.
|
||||
@ -64,6 +94,14 @@ type regTokenPayload struct {
|
||||
|
||||
const regTokenTTL = time.Hour
|
||||
|
||||
// requireAdminOrOwner returns nil iff the role is "owner" or "admin".
// Any other role yields a forbidden error.
func requireAdminOrOwner(role string) error {
	switch role {
	case "owner", "admin":
		return nil
	default:
		return fmt.Errorf("forbidden: only team owners and admins can manage BYOC hosts")
	}
}
|
||||
|
||||
// Create creates a new host record and generates a one-time registration token.
|
||||
func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreateResult, error) {
|
||||
if p.Type != "regular" && p.Type != "byoc" {
|
||||
@ -75,8 +113,8 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat
|
||||
return HostCreateResult{}, fmt.Errorf("forbidden: only admins can create regular hosts")
|
||||
}
|
||||
} else {
|
||||
// BYOC: admin or team owner.
|
||||
if p.TeamID == "" {
|
||||
// BYOC: platform admin, or team owner/admin.
|
||||
if !p.TeamID.Valid {
|
||||
return HostCreateResult{}, fmt.Errorf("invalid request: team_id is required for BYOC hosts")
|
||||
}
|
||||
if !p.IsRequestorAdmin {
|
||||
@ -90,40 +128,31 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat
|
||||
if err != nil {
|
||||
return HostCreateResult{}, fmt.Errorf("check team membership: %w", err)
|
||||
}
|
||||
if membership.Role != "owner" {
|
||||
return HostCreateResult{}, fmt.Errorf("forbidden: only team owners can create BYOC hosts")
|
||||
if err := requireAdminOrOwner(membership.Role); err != nil {
|
||||
return HostCreateResult{}, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate team exists for BYOC hosts.
|
||||
if p.TeamID != "" {
|
||||
if _, err := s.DB.GetTeam(ctx, p.TeamID); err != nil {
|
||||
// Validate team exists, is not deleted, and has BYOC enabled.
|
||||
if p.TeamID.Valid {
|
||||
team, err := s.DB.GetTeam(ctx, p.TeamID)
|
||||
if err != nil || team.DeletedAt.Valid {
|
||||
return HostCreateResult{}, fmt.Errorf("invalid request: team not found")
|
||||
}
|
||||
if !team.IsByoc {
|
||||
return HostCreateResult{}, fmt.Errorf("forbidden: BYOC is not enabled for this team")
|
||||
}
|
||||
}
|
||||
|
||||
hostID := id.NewHostID()
|
||||
|
||||
var teamID pgtype.Text
|
||||
if p.TeamID != "" {
|
||||
teamID = pgtype.Text{String: p.TeamID, Valid: true}
|
||||
}
|
||||
var provider pgtype.Text
|
||||
if p.Provider != "" {
|
||||
provider = pgtype.Text{String: p.Provider, Valid: true}
|
||||
}
|
||||
var az pgtype.Text
|
||||
if p.AvailabilityZone != "" {
|
||||
az = pgtype.Text{String: p.AvailabilityZone, Valid: true}
|
||||
}
|
||||
|
||||
host, err := s.DB.InsertHost(ctx, db.InsertHostParams{
|
||||
ID: hostID,
|
||||
Type: p.Type,
|
||||
TeamID: teamID,
|
||||
Provider: provider,
|
||||
AvailabilityZone: az,
|
||||
TeamID: p.TeamID,
|
||||
Provider: p.Provider,
|
||||
AvailabilityZone: p.AvailabilityZone,
|
||||
CreatedBy: p.RequestingUserID,
|
||||
})
|
||||
if err != nil {
|
||||
@ -135,8 +164,8 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat
|
||||
tokenID := id.NewHostTokenID()
|
||||
|
||||
payload, _ := json.Marshal(regTokenPayload{
|
||||
HostID: hostID,
|
||||
TokenID: tokenID,
|
||||
HostID: id.FormatHostID(hostID),
|
||||
TokenID: id.FormatHostTokenID(tokenID),
|
||||
})
|
||||
if err := s.Redis.Set(ctx, "host:reg:"+token, payload, regTokenTTL).Err(); err != nil {
|
||||
return HostCreateResult{}, fmt.Errorf("store registration token: %w", err)
|
||||
@ -149,7 +178,7 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat
|
||||
CreatedBy: p.RequestingUserID,
|
||||
ExpiresAt: pgtype.Timestamptz{Time: now.Add(regTokenTTL), Valid: true},
|
||||
}); err != nil {
|
||||
slog.Warn("failed to insert host token audit record", "host_id", hostID, "error", err)
|
||||
slog.Warn("failed to insert host token audit record", "host_id", id.FormatHostID(hostID), "error", err)
|
||||
}
|
||||
|
||||
return HostCreateResult{Host: host, RegistrationToken: token}, nil
|
||||
@ -158,7 +187,7 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat
|
||||
// RegenerateToken issues a new registration token for a host still in "pending"
|
||||
// status. This allows retry when a previous registration attempt failed after
|
||||
// the original token was consumed.
|
||||
func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamID string, isAdmin bool) (HostCreateResult, error) {
|
||||
func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamID pgtype.UUID, isAdmin bool) (HostCreateResult, error) {
|
||||
host, err := s.DB.GetHost(ctx, hostID)
|
||||
if err != nil {
|
||||
return HostCreateResult{}, fmt.Errorf("host not found: %w", err)
|
||||
@ -167,12 +196,11 @@ func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamI
|
||||
return HostCreateResult{}, fmt.Errorf("invalid state: can only regenerate token for pending hosts (status: %s)", host.Status)
|
||||
}
|
||||
|
||||
// Same permission model as Create/Delete.
|
||||
if !isAdmin {
|
||||
if host.Type != "byoc" {
|
||||
return HostCreateResult{}, fmt.Errorf("forbidden: only admins can manage regular hosts")
|
||||
}
|
||||
if !host.TeamID.Valid || host.TeamID.String != teamID {
|
||||
if !host.TeamID.Valid || host.TeamID != teamID {
|
||||
return HostCreateResult{}, fmt.Errorf("forbidden: host does not belong to your team")
|
||||
}
|
||||
membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
|
||||
@ -185,8 +213,8 @@ func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamI
|
||||
if err != nil {
|
||||
return HostCreateResult{}, fmt.Errorf("check team membership: %w", err)
|
||||
}
|
||||
if membership.Role != "owner" {
|
||||
return HostCreateResult{}, fmt.Errorf("forbidden: only team owners can regenerate tokens")
|
||||
if err := requireAdminOrOwner(membership.Role); err != nil {
|
||||
return HostCreateResult{}, err
|
||||
}
|
||||
}
|
||||
|
||||
@ -194,8 +222,8 @@ func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamI
|
||||
tokenID := id.NewHostTokenID()
|
||||
|
||||
payload, _ := json.Marshal(regTokenPayload{
|
||||
HostID: hostID,
|
||||
TokenID: tokenID,
|
||||
HostID: id.FormatHostID(hostID),
|
||||
TokenID: id.FormatHostTokenID(tokenID),
|
||||
})
|
||||
if err := s.Redis.Set(ctx, "host:reg:"+token, payload, regTokenTTL).Err(); err != nil {
|
||||
return HostCreateResult{}, fmt.Errorf("store registration token: %w", err)
|
||||
@ -208,14 +236,14 @@ func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamI
|
||||
CreatedBy: userID,
|
||||
ExpiresAt: pgtype.Timestamptz{Time: now.Add(regTokenTTL), Valid: true},
|
||||
}); err != nil {
|
||||
slog.Warn("failed to insert host token audit record", "host_id", hostID, "error", err)
|
||||
slog.Warn("failed to insert host token audit record", "host_id", id.FormatHostID(hostID), "error", err)
|
||||
}
|
||||
|
||||
return HostCreateResult{Host: host, RegistrationToken: token}, nil
|
||||
}
|
||||
|
||||
// Register validates a one-time registration token, updates the host with
|
||||
// machine specs, and returns a long-lived host JWT.
|
||||
// machine specs, and returns a short-lived host JWT plus a long-lived refresh token.
|
||||
func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostRegisterResult, error) {
|
||||
// Atomic consume: GetDel returns the value and deletes in one operation,
|
||||
// preventing concurrent requests from consuming the same token.
|
||||
@ -232,24 +260,44 @@ func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostR
|
||||
return HostRegisterResult{}, fmt.Errorf("corrupted registration token")
|
||||
}
|
||||
|
||||
if _, err := s.DB.GetHost(ctx, payload.HostID); err != nil {
|
||||
hostID, err := id.ParseHostID(payload.HostID)
|
||||
if err != nil {
|
||||
return HostRegisterResult{}, fmt.Errorf("corrupted registration token: %w", err)
|
||||
}
|
||||
tokenID, err := id.ParseHostTokenID(payload.TokenID)
|
||||
if err != nil {
|
||||
return HostRegisterResult{}, fmt.Errorf("corrupted registration token: %w", err)
|
||||
}
|
||||
|
||||
if _, err := s.DB.GetHost(ctx, hostID); err != nil {
|
||||
return HostRegisterResult{}, fmt.Errorf("host not found: %w", err)
|
||||
}
|
||||
|
||||
// Sign JWT before mutating DB — if signing fails, the host stays pending.
|
||||
hostJWT, err := auth.SignHostJWT(s.JWT, payload.HostID)
|
||||
hostJWT, err := auth.SignHostJWT(s.JWT, hostID)
|
||||
if err != nil {
|
||||
return HostRegisterResult{}, fmt.Errorf("sign host token: %w", err)
|
||||
}
|
||||
|
||||
// Issue mTLS certificate if CA is configured.
|
||||
var hc auth.HostCert
|
||||
if s.CA != nil {
|
||||
hc, err = auth.IssueHostCert(s.CA, id.FormatHostID(hostID), p.Address)
|
||||
if err != nil {
|
||||
return HostRegisterResult{}, fmt.Errorf("issue host cert: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Atomically update only if still pending (defense-in-depth against races).
|
||||
rowsAffected, err := s.DB.RegisterHost(ctx, db.RegisterHostParams{
|
||||
ID: payload.HostID,
|
||||
Arch: pgtype.Text{String: p.Arch, Valid: p.Arch != ""},
|
||||
CpuCores: pgtype.Int4{Int32: p.CPUCores, Valid: p.CPUCores > 0},
|
||||
MemoryMb: pgtype.Int4{Int32: p.MemoryMB, Valid: p.MemoryMB > 0},
|
||||
DiskGb: pgtype.Int4{Int32: p.DiskGB, Valid: p.DiskGB > 0},
|
||||
Address: pgtype.Text{String: p.Address, Valid: p.Address != ""},
|
||||
ID: hostID,
|
||||
Arch: p.Arch,
|
||||
CpuCores: p.CPUCores,
|
||||
MemoryMb: p.MemoryMB,
|
||||
DiskGb: p.DiskGB,
|
||||
Address: p.Address,
|
||||
CertFingerprint: hc.Fingerprint,
|
||||
CertExpiresAt: pgtype.Timestamptz{Time: hc.ExpiresAt, Valid: s.CA != nil},
|
||||
})
|
||||
if err != nil {
|
||||
return HostRegisterResult{}, fmt.Errorf("register host: %w", err)
|
||||
@ -259,82 +307,293 @@ func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostR
|
||||
}
|
||||
|
||||
// Mark audit trail.
|
||||
if err := s.DB.MarkHostTokenUsed(ctx, payload.TokenID); err != nil {
|
||||
if err := s.DB.MarkHostTokenUsed(ctx, tokenID); err != nil {
|
||||
slog.Warn("failed to mark host token used", "token_id", payload.TokenID, "error", err)
|
||||
}
|
||||
|
||||
// Issue a long-lived refresh token.
|
||||
refreshToken, err := s.issueRefreshToken(ctx, hostID)
|
||||
if err != nil {
|
||||
return HostRegisterResult{}, fmt.Errorf("issue refresh token: %w", err)
|
||||
}
|
||||
|
||||
// Re-fetch the host to get the updated state.
|
||||
host, err := s.DB.GetHost(ctx, payload.HostID)
|
||||
host, err := s.DB.GetHost(ctx, hostID)
|
||||
if err != nil {
|
||||
return HostRegisterResult{}, fmt.Errorf("fetch updated host: %w", err)
|
||||
}
|
||||
|
||||
return HostRegisterResult{Host: host, JWT: hostJWT}, nil
|
||||
result := HostRegisterResult{Host: host, JWT: hostJWT, RefreshToken: refreshToken}
|
||||
if s.CA != nil {
|
||||
result.CertPEM = hc.CertPEM
|
||||
result.KeyPEM = hc.KeyPEM
|
||||
result.CACertPEM = s.CA.PEM
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Heartbeat updates the last heartbeat timestamp for a host.
|
||||
func (s *HostService) Heartbeat(ctx context.Context, hostID string) error {
|
||||
return s.DB.UpdateHostHeartbeat(ctx, hostID)
|
||||
// Refresh validates a refresh token, rotates it (revokes old, issues new),
|
||||
// and returns a fresh JWT plus the new refresh token.
|
||||
func (s *HostService) Refresh(ctx context.Context, refreshToken string) (HostRefreshResult, error) {
|
||||
hash := hashToken(refreshToken)
|
||||
|
||||
row, err := s.DB.GetHostRefreshTokenByHash(ctx, hash)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return HostRefreshResult{}, fmt.Errorf("invalid or expired refresh token")
|
||||
}
|
||||
if err != nil {
|
||||
return HostRefreshResult{}, fmt.Errorf("lookup refresh token: %w", err)
|
||||
}
|
||||
|
||||
host, err := s.DB.GetHost(ctx, row.HostID)
|
||||
if err != nil {
|
||||
return HostRefreshResult{}, fmt.Errorf("host not found: %w", err)
|
||||
}
|
||||
|
||||
// Sign new JWT.
|
||||
hostJWT, err := auth.SignHostJWT(s.JWT, host.ID)
|
||||
if err != nil {
|
||||
return HostRefreshResult{}, fmt.Errorf("sign host JWT: %w", err)
|
||||
}
|
||||
|
||||
// Renew mTLS certificate if CA is configured.
|
||||
var hc auth.HostCert
|
||||
if s.CA != nil {
|
||||
hc, err = auth.IssueHostCert(s.CA, id.FormatHostID(host.ID), host.Address)
|
||||
if err != nil {
|
||||
return HostRefreshResult{}, fmt.Errorf("renew host cert: %w", err)
|
||||
}
|
||||
if err := s.DB.UpdateHostCert(ctx, db.UpdateHostCertParams{
|
||||
ID: host.ID,
|
||||
CertFingerprint: hc.Fingerprint,
|
||||
CertExpiresAt: pgtype.Timestamptz{Time: hc.ExpiresAt, Valid: true},
|
||||
}); err != nil {
|
||||
return HostRefreshResult{}, fmt.Errorf("update host cert: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Issue-then-revoke rotation: insert new token first so a crash between
|
||||
// the two DB calls leaves the host with two valid tokens rather than zero.
|
||||
newRefreshToken, err := s.issueRefreshToken(ctx, host.ID)
|
||||
if err != nil {
|
||||
return HostRefreshResult{}, fmt.Errorf("issue new refresh token: %w", err)
|
||||
}
|
||||
|
||||
// Revoke old refresh token after the new one is safely persisted.
|
||||
if err := s.DB.RevokeHostRefreshToken(ctx, row.ID); err != nil {
|
||||
return HostRefreshResult{}, fmt.Errorf("revoke old refresh token: %w", err)
|
||||
}
|
||||
|
||||
result := HostRefreshResult{Host: host, JWT: hostJWT, RefreshToken: newRefreshToken}
|
||||
if s.CA != nil {
|
||||
result.CertPEM = hc.CertPEM
|
||||
result.KeyPEM = hc.KeyPEM
|
||||
result.CACertPEM = s.CA.PEM
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// issueRefreshToken creates a new refresh token record in the DB and returns
|
||||
// the opaque token string.
|
||||
func (s *HostService) issueRefreshToken(ctx context.Context, hostID pgtype.UUID) (string, error) {
|
||||
token := id.NewRefreshToken()
|
||||
hash := hashToken(token)
|
||||
now := time.Now()
|
||||
|
||||
if _, err := s.DB.InsertHostRefreshToken(ctx, db.InsertHostRefreshTokenParams{
|
||||
ID: id.NewRefreshTokenID(),
|
||||
HostID: hostID,
|
||||
TokenHash: hash,
|
||||
ExpiresAt: pgtype.Timestamptz{Time: now.Add(auth.HostRefreshTokenExpiry), Valid: true},
|
||||
}); err != nil {
|
||||
return "", fmt.Errorf("insert refresh token: %w", err)
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// hashToken returns the hex-encoded SHA-256 hash of the token.
|
||||
func hashToken(token string) string {
|
||||
h := sha256.Sum256([]byte(token))
|
||||
return fmt.Sprintf("%x", h)
|
||||
}
|
||||
|
||||
// Heartbeat updates the last heartbeat timestamp for a host and transitions
|
||||
// any 'unreachable' host back to 'online'. Returns a "host not found" error
|
||||
// (which becomes 404) if the host record no longer exists (e.g., was deleted).
|
||||
func (s *HostService) Heartbeat(ctx context.Context, hostID pgtype.UUID) error {
|
||||
n, err := s.DB.UpdateHostHeartbeatAndStatus(ctx, hostID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n == 0 {
|
||||
return fmt.Errorf("host not found")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// List returns hosts visible to the caller.
|
||||
// Admins see all hosts; non-admins see only BYOC hosts belonging to their team.
|
||||
func (s *HostService) List(ctx context.Context, teamID string, isAdmin bool) ([]db.Host, error) {
|
||||
func (s *HostService) List(ctx context.Context, teamID pgtype.UUID, isAdmin bool) ([]db.Host, error) {
|
||||
if isAdmin {
|
||||
return s.DB.ListHosts(ctx)
|
||||
}
|
||||
return s.DB.ListHostsByTeam(ctx, pgtype.Text{String: teamID, Valid: true})
|
||||
return s.DB.ListHostsByTeam(ctx, teamID)
|
||||
}
|
||||
|
||||
// Get returns a single host, enforcing access control.
|
||||
func (s *HostService) Get(ctx context.Context, hostID, teamID string, isAdmin bool) (db.Host, error) {
|
||||
func (s *HostService) Get(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool) (db.Host, error) {
|
||||
host, err := s.DB.GetHost(ctx, hostID)
|
||||
if err != nil {
|
||||
return db.Host{}, fmt.Errorf("host not found: %w", err)
|
||||
}
|
||||
if !isAdmin {
|
||||
if !host.TeamID.Valid || host.TeamID.String != teamID {
|
||||
if !host.TeamID.Valid || host.TeamID != teamID {
|
||||
return db.Host{}, fmt.Errorf("host not found")
|
||||
}
|
||||
}
|
||||
return host, nil
|
||||
}
|
||||
|
||||
// Delete removes a host. Admins can delete any host. Team owners can delete
|
||||
// BYOC hosts belonging to their team.
|
||||
func (s *HostService) Delete(ctx context.Context, hostID, userID, teamID string, isAdmin bool) error {
|
||||
host, err := s.DB.GetHost(ctx, hostID)
|
||||
// DeletePreview returns what would be affected by deleting the host, without
|
||||
// making any changes. Use this to show the user a confirmation prompt.
|
||||
func (s *HostService) DeletePreview(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool) (HostDeletePreview, error) {
|
||||
host, err := s.checkDeletePermission(ctx, hostID, pgtype.UUID{}, teamID, isAdmin)
|
||||
if err != nil {
|
||||
return fmt.Errorf("host not found: %w", err)
|
||||
return HostDeletePreview{}, err
|
||||
}
|
||||
|
||||
if !isAdmin {
|
||||
if host.Type != "byoc" {
|
||||
return fmt.Errorf("forbidden: only admins can delete regular hosts")
|
||||
sandboxes, err := s.DB.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
|
||||
HostID: hostID,
|
||||
Column2: []string{"pending", "starting", "running", "missing"},
|
||||
})
|
||||
if err != nil {
|
||||
return HostDeletePreview{}, fmt.Errorf("list sandboxes: %w", err)
|
||||
}
|
||||
|
||||
ids := make([]string, len(sandboxes))
|
||||
for i, sb := range sandboxes {
|
||||
ids[i] = id.FormatSandboxID(sb.ID)
|
||||
}
|
||||
|
||||
return HostDeletePreview{Host: host, SandboxIDs: ids}, nil
|
||||
}
|
||||
|
||||
// Delete removes a host. Without force it returns an error listing active
|
||||
// sandboxes so the caller can present a confirmation. With force it gracefully
|
||||
// destroys all running sandboxes before deleting the host record.
|
||||
func (s *HostService) Delete(ctx context.Context, hostID, userID, teamID pgtype.UUID, isAdmin, force bool) error {
|
||||
host, err := s.checkDeletePermission(ctx, hostID, userID, teamID, isAdmin)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sandboxes, err := s.DB.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
|
||||
HostID: hostID,
|
||||
Column2: []string{"pending", "starting", "running", "missing"},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("list sandboxes: %w", err)
|
||||
}
|
||||
|
||||
if len(sandboxes) > 0 && !force {
|
||||
ids := make([]string, len(sandboxes))
|
||||
for i, sb := range sandboxes {
|
||||
ids[i] = id.FormatSandboxID(sb.ID)
|
||||
}
|
||||
if !host.TeamID.Valid || host.TeamID.String != teamID {
|
||||
return fmt.Errorf("forbidden: host does not belong to your team")
|
||||
return &HostHasSandboxesError{SandboxIDs: ids}
|
||||
}
|
||||
|
||||
hostIDStr := id.FormatHostID(hostID)
|
||||
|
||||
// Gracefully destroy running sandboxes and terminate the agent (best-effort).
|
||||
if host.Address != "" {
|
||||
agent, err := s.Pool.GetForHost(host)
|
||||
if err == nil {
|
||||
for _, sb := range sandboxes {
|
||||
if sb.Status == "running" || sb.Status == "starting" {
|
||||
_, rpcErr := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
|
||||
SandboxId: id.FormatSandboxID(sb.ID),
|
||||
}))
|
||||
if rpcErr != nil && connect.CodeOf(rpcErr) != connect.CodeNotFound {
|
||||
slog.Warn("delete host: failed to destroy sandbox on agent", "sandbox_id", id.FormatSandboxID(sb.ID), "error", rpcErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Tell the agent to shut itself down immediately.
|
||||
if _, rpcErr := agent.Terminate(ctx, connect.NewRequest(&pb.TerminateRequest{})); rpcErr != nil {
|
||||
slog.Warn("delete host: failed to send Terminate to agent", "host_id", hostIDStr, "error", rpcErr)
|
||||
}
|
||||
}
|
||||
membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
|
||||
UserID: userID,
|
||||
TeamID: teamID,
|
||||
})
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return fmt.Errorf("forbidden: not a member of the specified team")
|
||||
}
|
||||
|
||||
// Mark all affected sandboxes as stopped in DB.
|
||||
if len(sandboxes) > 0 {
|
||||
sbIDs := make([]pgtype.UUID, len(sandboxes))
|
||||
for i, sb := range sandboxes {
|
||||
sbIDs[i] = sb.ID
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("check team membership: %w", err)
|
||||
}
|
||||
if membership.Role != "owner" {
|
||||
return fmt.Errorf("forbidden: only team owners can delete BYOC hosts")
|
||||
if err := s.DB.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
|
||||
Column1: sbIDs,
|
||||
Status: "stopped",
|
||||
}); err != nil {
|
||||
slog.Warn("delete host: failed to mark sandboxes stopped", "host_id", hostIDStr, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Revoke all refresh tokens for this host.
|
||||
if err := s.DB.RevokeHostRefreshTokensByHost(ctx, hostID); err != nil {
|
||||
slog.Warn("delete host: failed to revoke refresh tokens", "host_id", hostIDStr, "error", err)
|
||||
}
|
||||
|
||||
// Evict the client from the pool so no further RPCs are sent.
|
||||
if s.Pool != nil {
|
||||
s.Pool.Evict(id.FormatHostID(hostID))
|
||||
}
|
||||
|
||||
return s.DB.DeleteHost(ctx, hostID)
|
||||
}
|
||||
|
||||
// checkDeletePermission verifies the caller has permission to delete the given
|
||||
// host and returns the host record on success.
|
||||
func (s *HostService) checkDeletePermission(ctx context.Context, hostID, userID, teamID pgtype.UUID, isAdmin bool) (db.Host, error) {
|
||||
host, err := s.DB.GetHost(ctx, hostID)
|
||||
if err != nil {
|
||||
return db.Host{}, fmt.Errorf("host not found: %w", err)
|
||||
}
|
||||
|
||||
if isAdmin {
|
||||
return host, nil
|
||||
}
|
||||
|
||||
if host.Type != "byoc" {
|
||||
return db.Host{}, fmt.Errorf("forbidden: only admins can delete regular hosts")
|
||||
}
|
||||
if !host.TeamID.Valid || host.TeamID != teamID {
|
||||
return db.Host{}, fmt.Errorf("forbidden: host does not belong to your team")
|
||||
}
|
||||
|
||||
if userID.Valid {
|
||||
membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
|
||||
UserID: userID,
|
||||
TeamID: teamID,
|
||||
})
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return db.Host{}, fmt.Errorf("forbidden: not a member of the specified team")
|
||||
}
|
||||
if err != nil {
|
||||
return db.Host{}, fmt.Errorf("check team membership: %w", err)
|
||||
}
|
||||
if err := requireAdminOrOwner(membership.Role); err != nil {
|
||||
return db.Host{}, err
|
||||
}
|
||||
}
|
||||
|
||||
return host, nil
|
||||
}
|
||||
|
||||
// AddTag adds a tag to a host.
|
||||
func (s *HostService) AddTag(ctx context.Context, hostID, teamID string, isAdmin bool, tag string) error {
|
||||
func (s *HostService) AddTag(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool, tag string) error {
|
||||
if _, err := s.Get(ctx, hostID, teamID, isAdmin); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -342,7 +601,7 @@ func (s *HostService) AddTag(ctx context.Context, hostID, teamID string, isAdmin
|
||||
}
|
||||
|
||||
// RemoveTag removes a tag from a host.
|
||||
func (s *HostService) RemoveTag(ctx context.Context, hostID, teamID string, isAdmin bool, tag string) error {
|
||||
func (s *HostService) RemoveTag(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool, tag string) error {
|
||||
if _, err := s.Get(ctx, hostID, teamID, isAdmin); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -350,9 +609,20 @@ func (s *HostService) RemoveTag(ctx context.Context, hostID, teamID string, isAd
|
||||
}
|
||||
|
||||
// ListTags returns all tags for a host.
|
||||
func (s *HostService) ListTags(ctx context.Context, hostID, teamID string, isAdmin bool) ([]string, error) {
|
||||
func (s *HostService) ListTags(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool) ([]string, error) {
|
||||
if _, err := s.Get(ctx, hostID, teamID, isAdmin); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.DB.GetHostTags(ctx, hostID)
|
||||
}
|
||||
|
||||
// HostHasSandboxesError is returned by Delete when the host has active sandboxes
|
||||
// and force was not set. The caller should present the list to the user and
|
||||
// re-call Delete with force=true if they confirm.
|
||||
type HostHasSandboxesError struct {
|
||||
SandboxIDs []string
|
||||
}
|
||||
|
||||
func (e *HostHasSandboxesError) Error() string {
|
||||
return fmt.Sprintf("host has %d active sandbox(es): %v", len(e.SandboxIDs), e.SandboxIDs)
|
||||
}
|
||||
|
||||
@ -11,29 +11,60 @@ import (
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/scheduler"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/validate"
|
||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
// SandboxService provides sandbox lifecycle operations shared between the
|
||||
// REST API and the dashboard.
|
||||
type SandboxService struct {
|
||||
DB *db.Queries
|
||||
Agent hostagentv1connect.HostAgentServiceClient
|
||||
DB *db.Queries
|
||||
Pool *lifecycle.HostClientPool
|
||||
Scheduler scheduler.HostScheduler
|
||||
}
|
||||
|
||||
// SandboxCreateParams holds the parameters for creating a sandbox.
|
||||
type SandboxCreateParams struct {
|
||||
TeamID string
|
||||
TeamID pgtype.UUID
|
||||
Template string
|
||||
VCPUs int32
|
||||
MemoryMB int32
|
||||
TimeoutSec int32
|
||||
DiskSizeMB int32
|
||||
}
|
||||
|
||||
// Create creates a new sandbox: inserts a pending DB record, calls the host agent,
|
||||
// and updates the record to running. Returns the sandbox DB row.
|
||||
// agentForSandbox looks up the host for the given sandbox and returns a client.
|
||||
func (s *SandboxService) agentForSandbox(ctx context.Context, sandboxID pgtype.UUID) (hostagentClient, db.Sandbox, error) {
|
||||
sb, err := s.DB.GetSandbox(ctx, sandboxID)
|
||||
if err != nil {
|
||||
return nil, db.Sandbox{}, fmt.Errorf("sandbox not found: %w", err)
|
||||
}
|
||||
host, err := s.DB.GetHost(ctx, sb.HostID)
|
||||
if err != nil {
|
||||
return nil, db.Sandbox{}, fmt.Errorf("host not found for sandbox: %w", err)
|
||||
}
|
||||
agent, err := s.Pool.GetForHost(host)
|
||||
if err != nil {
|
||||
return nil, db.Sandbox{}, fmt.Errorf("get agent client: %w", err)
|
||||
}
|
||||
return agent, sb, nil
|
||||
}
|
||||
|
||||
// hostagentClient is a local alias to avoid the full package path in signatures.
|
||||
type hostagentClient = interface {
|
||||
CreateSandbox(ctx context.Context, req *connect.Request[pb.CreateSandboxRequest]) (*connect.Response[pb.CreateSandboxResponse], error)
|
||||
DestroySandbox(ctx context.Context, req *connect.Request[pb.DestroySandboxRequest]) (*connect.Response[pb.DestroySandboxResponse], error)
|
||||
PauseSandbox(ctx context.Context, req *connect.Request[pb.PauseSandboxRequest]) (*connect.Response[pb.PauseSandboxResponse], error)
|
||||
ResumeSandbox(ctx context.Context, req *connect.Request[pb.ResumeSandboxRequest]) (*connect.Response[pb.ResumeSandboxResponse], error)
|
||||
PingSandbox(ctx context.Context, req *connect.Request[pb.PingSandboxRequest]) (*connect.Response[pb.PingSandboxResponse], error)
|
||||
GetSandboxMetrics(ctx context.Context, req *connect.Request[pb.GetSandboxMetricsRequest]) (*connect.Response[pb.GetSandboxMetricsResponse], error)
|
||||
FlushSandboxMetrics(ctx context.Context, req *connect.Request[pb.FlushSandboxMetricsRequest]) (*connect.Response[pb.FlushSandboxMetricsResponse], error)
|
||||
}
|
||||
|
||||
// Create creates a new sandbox: picks a host via the scheduler, inserts a pending
|
||||
// DB record, calls the host agent, and updates the record to running.
|
||||
func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db.Sandbox, error) {
|
||||
if p.Template == "" {
|
||||
p.Template = "minimal"
|
||||
@ -47,44 +78,82 @@ func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db.
|
||||
if p.MemoryMB <= 0 {
|
||||
p.MemoryMB = 512
|
||||
}
|
||||
if p.DiskSizeMB <= 0 {
|
||||
p.DiskSizeMB = 5120 // 5 GB default
|
||||
}
|
||||
|
||||
// If the template is a snapshot, use its baked-in vcpus/memory.
|
||||
if tmpl, err := s.DB.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: p.Template, TeamID: p.TeamID}); err == nil && tmpl.Type == "snapshot" {
|
||||
if tmpl.Vcpus.Valid {
|
||||
p.VCPUs = tmpl.Vcpus.Int32
|
||||
// Resolve template name → (teamID, templateID).
|
||||
templateTeamID := id.PlatformTeamID
|
||||
templateID := id.MinimalTemplateID
|
||||
if p.Template != "minimal" {
|
||||
tmpl, err := s.DB.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: p.Template, TeamID: p.TeamID})
|
||||
if err != nil {
|
||||
return db.Sandbox{}, fmt.Errorf("template %q not found: %w", p.Template, err)
|
||||
}
|
||||
if tmpl.MemoryMb.Valid {
|
||||
p.MemoryMB = tmpl.MemoryMb.Int32
|
||||
templateTeamID = tmpl.TeamID
|
||||
templateID = tmpl.ID
|
||||
// If the template is a snapshot, use its baked-in vcpus/memory.
|
||||
if tmpl.Type == "snapshot" {
|
||||
p.VCPUs = tmpl.Vcpus
|
||||
p.MemoryMB = tmpl.MemoryMb
|
||||
}
|
||||
}
|
||||
|
||||
if !p.TeamID.Valid {
|
||||
return db.Sandbox{}, fmt.Errorf("invalid request: team_id is required")
|
||||
}
|
||||
|
||||
// Determine whether this team uses BYOC hosts or platform hosts.
|
||||
team, err := s.DB.GetTeam(ctx, p.TeamID)
|
||||
if err != nil {
|
||||
return db.Sandbox{}, fmt.Errorf("team not found: %w", err)
|
||||
}
|
||||
|
||||
// Pick a host for this sandbox.
|
||||
host, err := s.Scheduler.SelectHost(ctx, p.TeamID, team.IsByoc)
|
||||
if err != nil {
|
||||
return db.Sandbox{}, fmt.Errorf("select host: %w", err)
|
||||
}
|
||||
|
||||
agent, err := s.Pool.GetForHost(host)
|
||||
if err != nil {
|
||||
return db.Sandbox{}, fmt.Errorf("get agent client: %w", err)
|
||||
}
|
||||
|
||||
sandboxID := id.NewSandboxID()
|
||||
sandboxIDStr := id.FormatSandboxID(sandboxID)
|
||||
|
||||
if _, err := s.DB.InsertSandbox(ctx, db.InsertSandboxParams{
|
||||
ID: sandboxID,
|
||||
TeamID: p.TeamID,
|
||||
HostID: "default",
|
||||
Template: p.Template,
|
||||
Status: "pending",
|
||||
Vcpus: p.VCPUs,
|
||||
MemoryMb: p.MemoryMB,
|
||||
TimeoutSec: p.TimeoutSec,
|
||||
ID: sandboxID,
|
||||
TeamID: p.TeamID,
|
||||
HostID: host.ID,
|
||||
Template: p.Template,
|
||||
Status: "pending",
|
||||
Vcpus: p.VCPUs,
|
||||
MemoryMb: p.MemoryMB,
|
||||
TimeoutSec: p.TimeoutSec,
|
||||
DiskSizeMb: p.DiskSizeMB,
|
||||
TemplateID: templateID,
|
||||
TemplateTeamID: templateTeamID,
|
||||
}); err != nil {
|
||||
return db.Sandbox{}, fmt.Errorf("insert sandbox: %w", err)
|
||||
}
|
||||
|
||||
resp, err := s.Agent.CreateSandbox(ctx, connect.NewRequest(&pb.CreateSandboxRequest{
|
||||
SandboxId: sandboxID,
|
||||
resp, err := agent.CreateSandbox(ctx, connect.NewRequest(&pb.CreateSandboxRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Template: p.Template,
|
||||
TeamId: id.UUIDString(templateTeamID),
|
||||
TemplateId: id.UUIDString(templateID),
|
||||
Vcpus: p.VCPUs,
|
||||
MemoryMb: p.MemoryMB,
|
||||
TimeoutSec: p.TimeoutSec,
|
||||
DiskSizeMb: p.DiskSizeMB,
|
||||
}))
|
||||
if err != nil {
|
||||
if _, dbErr := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
|
||||
ID: sandboxID, Status: "error",
|
||||
}); dbErr != nil {
|
||||
slog.Warn("failed to update sandbox status to error", "id", sandboxID, "error", dbErr)
|
||||
slog.Warn("failed to update sandbox status to error", "id", sandboxIDStr, "error", dbErr)
|
||||
}
|
||||
return db.Sandbox{}, fmt.Errorf("agent create: %w", err)
|
||||
}
|
||||
@ -107,17 +176,17 @@ func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db.
|
||||
}
|
||||
|
||||
// List returns active sandboxes (excludes stopped/error) belonging to the given team.
|
||||
func (s *SandboxService) List(ctx context.Context, teamID string) ([]db.Sandbox, error) {
|
||||
func (s *SandboxService) List(ctx context.Context, teamID pgtype.UUID) ([]db.Sandbox, error) {
|
||||
return s.DB.ListSandboxesByTeam(ctx, teamID)
|
||||
}
|
||||
|
||||
// Get returns a single sandbox by ID, scoped to the given team.
|
||||
func (s *SandboxService) Get(ctx context.Context, sandboxID, teamID string) (db.Sandbox, error) {
|
||||
func (s *SandboxService) Get(ctx context.Context, sandboxID, teamID pgtype.UUID) (db.Sandbox, error) {
|
||||
return s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
|
||||
}
|
||||
|
||||
// Pause snapshots and freezes a running sandbox to disk.
|
||||
func (s *SandboxService) Pause(ctx context.Context, sandboxID, teamID string) (db.Sandbox, error) {
|
||||
func (s *SandboxService) Pause(ctx context.Context, sandboxID, teamID pgtype.UUID) (db.Sandbox, error) {
|
||||
sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
|
||||
if err != nil {
|
||||
return db.Sandbox{}, fmt.Errorf("sandbox not found: %w", err)
|
||||
@ -126,23 +195,45 @@ func (s *SandboxService) Pause(ctx context.Context, sandboxID, teamID string) (d
|
||||
return db.Sandbox{}, fmt.Errorf("sandbox is not running (status: %s)", sb.Status)
|
||||
}
|
||||
|
||||
if _, err := s.Agent.PauseSandbox(ctx, connect.NewRequest(&pb.PauseSandboxRequest{
|
||||
SandboxId: sandboxID,
|
||||
agent, _, err := s.agentForSandbox(ctx, sandboxID)
|
||||
if err != nil {
|
||||
return db.Sandbox{}, err
|
||||
}
|
||||
|
||||
sandboxIDStr := id.FormatSandboxID(sandboxID)
|
||||
|
||||
// Pre-mark as "paused" in DB before the RPC so the reconciler does not
|
||||
// mark the sandbox "stopped" while the host agent processes the pause.
|
||||
if _, err := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
|
||||
ID: sandboxID, Status: "paused",
|
||||
}); err != nil {
|
||||
return db.Sandbox{}, fmt.Errorf("pre-mark paused: %w", err)
|
||||
}
|
||||
|
||||
// Flush all metrics tiers before pausing so data survives in DB.
|
||||
s.flushAndPersistMetrics(ctx, agent, sandboxID, true)
|
||||
|
||||
if _, err := agent.PauseSandbox(ctx, connect.NewRequest(&pb.PauseSandboxRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
})); err != nil {
|
||||
// Revert status on failure.
|
||||
if _, dbErr := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
|
||||
ID: sandboxID, Status: "running",
|
||||
}); dbErr != nil {
|
||||
slog.Warn("failed to revert sandbox status after pause error", "sandbox_id", sandboxIDStr, "error", dbErr)
|
||||
}
|
||||
return db.Sandbox{}, fmt.Errorf("agent pause: %w", err)
|
||||
}
|
||||
|
||||
sb, err = s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
|
||||
ID: sandboxID, Status: "paused",
|
||||
})
|
||||
sb, err = s.DB.GetSandbox(ctx, sandboxID)
|
||||
if err != nil {
|
||||
return db.Sandbox{}, fmt.Errorf("update status: %w", err)
|
||||
return db.Sandbox{}, fmt.Errorf("get sandbox after pause: %w", err)
|
||||
}
|
||||
return sb, nil
|
||||
}
|
||||
|
||||
// Resume restores a paused sandbox from snapshot.
|
||||
func (s *SandboxService) Resume(ctx context.Context, sandboxID, teamID string) (db.Sandbox, error) {
|
||||
func (s *SandboxService) Resume(ctx context.Context, sandboxID, teamID pgtype.UUID) (db.Sandbox, error) {
|
||||
sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
|
||||
if err != nil {
|
||||
return db.Sandbox{}, fmt.Errorf("sandbox not found: %w", err)
|
||||
@ -151,8 +242,15 @@ func (s *SandboxService) Resume(ctx context.Context, sandboxID, teamID string) (
|
||||
return db.Sandbox{}, fmt.Errorf("sandbox is not paused (status: %s)", sb.Status)
|
||||
}
|
||||
|
||||
resp, err := s.Agent.ResumeSandbox(ctx, connect.NewRequest(&pb.ResumeSandboxRequest{
|
||||
SandboxId: sandboxID,
|
||||
agent, _, err := s.agentForSandbox(ctx, sandboxID)
|
||||
if err != nil {
|
||||
return db.Sandbox{}, err
|
||||
}
|
||||
|
||||
sandboxIDStr := id.FormatSandboxID(sandboxID)
|
||||
|
||||
resp, err := agent.ResumeSandbox(ctx, connect.NewRequest(&pb.ResumeSandboxRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
TimeoutSec: sb.TimeoutSec,
|
||||
}))
|
||||
if err != nil {
|
||||
@ -176,18 +274,41 @@ func (s *SandboxService) Resume(ctx context.Context, sandboxID, teamID string) (
|
||||
}
|
||||
|
||||
// Destroy stops a sandbox and marks it as stopped.
|
||||
func (s *SandboxService) Destroy(ctx context.Context, sandboxID, teamID string) error {
|
||||
if _, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID}); err != nil {
|
||||
func (s *SandboxService) Destroy(ctx context.Context, sandboxID, teamID pgtype.UUID) error {
|
||||
sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
|
||||
if err != nil {
|
||||
return fmt.Errorf("sandbox not found: %w", err)
|
||||
}
|
||||
|
||||
agent, _, err := s.agentForSandbox(ctx, sandboxID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sandboxIDStr := id.FormatSandboxID(sandboxID)
|
||||
|
||||
// If running, flush 24h tier metrics for analytics before destroying.
|
||||
if sb.Status == "running" {
|
||||
s.flushAndPersistMetrics(ctx, agent, sandboxID, false)
|
||||
}
|
||||
|
||||
// Destroy on host agent. A not-found response is fine — sandbox is already gone.
|
||||
if _, err := s.Agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
|
||||
SandboxId: sandboxID,
|
||||
if _, err := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
})); err != nil && connect.CodeOf(err) != connect.CodeNotFound {
|
||||
return fmt.Errorf("agent destroy: %w", err)
|
||||
}
|
||||
|
||||
// For a paused sandbox, only keep 24h tier; remove the finer-grained tiers.
|
||||
if sb.Status == "paused" {
|
||||
_ = s.DB.DeleteSandboxMetricPointsByTier(ctx, db.DeleteSandboxMetricPointsByTierParams{
|
||||
SandboxID: sandboxID, Tier: "10m",
|
||||
})
|
||||
_ = s.DB.DeleteSandboxMetricPointsByTier(ctx, db.DeleteSandboxMetricPointsByTierParams{
|
||||
SandboxID: sandboxID, Tier: "2h",
|
||||
})
|
||||
}
|
||||
|
||||
if _, err := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
|
||||
ID: sandboxID, Status: "stopped",
|
||||
}); err != nil {
|
||||
@ -196,8 +317,45 @@ func (s *SandboxService) Destroy(ctx context.Context, sandboxID, teamID string)
|
||||
return nil
|
||||
}
|
||||
|
||||
// flushAndPersistMetrics calls FlushSandboxMetrics on the agent and stores
|
||||
// the returned data to DB. If allTiers is true, all three tiers are saved;
|
||||
// otherwise only the 24h tier (for post-destroy analytics).
|
||||
func (s *SandboxService) flushAndPersistMetrics(ctx context.Context, agent hostagentClient, sandboxID pgtype.UUID, allTiers bool) {
|
||||
sandboxIDStr := id.FormatSandboxID(sandboxID)
|
||||
resp, err := agent.FlushSandboxMetrics(ctx, connect.NewRequest(&pb.FlushSandboxMetricsRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
}))
|
||||
if err != nil {
|
||||
slog.Warn("flush metrics failed (best-effort)", "sandbox_id", sandboxIDStr, "error", err)
|
||||
return
|
||||
}
|
||||
msg := resp.Msg
|
||||
|
||||
if allTiers {
|
||||
s.persistMetricPoints(ctx, sandboxID, "10m", msg.Points_10M)
|
||||
s.persistMetricPoints(ctx, sandboxID, "2h", msg.Points_2H)
|
||||
}
|
||||
s.persistMetricPoints(ctx, sandboxID, "24h", msg.Points_24H)
|
||||
}
|
||||
|
||||
func (s *SandboxService) persistMetricPoints(ctx context.Context, sandboxID pgtype.UUID, tier string, points []*pb.MetricPoint) {
|
||||
sandboxIDStr := id.FormatSandboxID(sandboxID)
|
||||
for _, p := range points {
|
||||
if err := s.DB.InsertSandboxMetricPoint(ctx, db.InsertSandboxMetricPointParams{
|
||||
SandboxID: sandboxID,
|
||||
Tier: tier,
|
||||
Ts: p.TimestampUnix,
|
||||
CpuPct: p.CpuPct,
|
||||
MemBytes: p.MemBytes,
|
||||
DiskBytes: p.DiskBytes,
|
||||
}); err != nil {
|
||||
slog.Warn("persist metric point failed", "sandbox_id", sandboxIDStr, "tier", tier, "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Ping resets the inactivity timer for a running sandbox.
|
||||
func (s *SandboxService) Ping(ctx context.Context, sandboxID, teamID string) error {
|
||||
func (s *SandboxService) Ping(ctx context.Context, sandboxID, teamID pgtype.UUID) error {
|
||||
sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
|
||||
if err != nil {
|
||||
return fmt.Errorf("sandbox not found: %w", err)
|
||||
@ -206,8 +364,15 @@ func (s *SandboxService) Ping(ctx context.Context, sandboxID, teamID string) err
|
||||
return fmt.Errorf("sandbox is not running (status: %s)", sb.Status)
|
||||
}
|
||||
|
||||
if _, err := s.Agent.PingSandbox(ctx, connect.NewRequest(&pb.PingSandboxRequest{
|
||||
SandboxId: sandboxID,
|
||||
agent, _, err := s.agentForSandbox(ctx, sandboxID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sandboxIDStr := id.FormatSandboxID(sandboxID)
|
||||
|
||||
if _, err := agent.PingSandbox(ctx, connect.NewRequest(&pb.PingSandboxRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
})); err != nil {
|
||||
return fmt.Errorf("agent ping: %w", err)
|
||||
}
|
||||
@ -219,7 +384,7 @@ func (s *SandboxService) Ping(ctx context.Context, sandboxID, teamID string) err
|
||||
Valid: true,
|
||||
},
|
||||
}); err != nil {
|
||||
slog.Warn("ping: failed to update last_active_at", "sandbox_id", sandboxID, "error", err)
|
||||
slog.Warn("ping: failed to update last_active_at", "sandbox_id", sandboxIDStr, "error", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
160
internal/service/stats.go
Normal file
160
internal/service/stats.go
Normal file
@ -0,0 +1,160 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
)
|
||||
|
||||
// TimeRange identifies a chart time window selectable on the dashboard.
type TimeRange string

// Supported chart windows, ordered from shortest to longest lookback.
const (
	Range5m  TimeRange = "5m"
	Range1h  TimeRange = "1h"
	Range6h  TimeRange = "6h"
	Range24h TimeRange = "24h"
	Range30d TimeRange = "30d"
)

// rangeConfig describes how one TimeRange is aggregated.
type rangeConfig struct {
	bucketSec       int    // bucket width in seconds for time-series aggregation
	intervalLiteral string // PostgreSQL interval literal for the lookback window
}

// rangeConfigs maps every supported window to its aggregation settings.
var rangeConfigs = map[TimeRange]rangeConfig{
	Range5m:  {bucketSec: 3, intervalLiteral: "5 minutes"},
	Range1h:  {bucketSec: 30, intervalLiteral: "1 hour"},
	Range6h:  {bucketSec: 180, intervalLiteral: "6 hours"},
	Range24h: {bucketSec: 720, intervalLiteral: "24 hours"},
	Range30d: {bucketSec: 21600, intervalLiteral: "30 days"},
}

// ValidRange reports whether r is one of the supported TimeRange values.
func ValidRange(r TimeRange) bool {
	_, known := rangeConfigs[r]
	return known
}
|
||||
|
||||
// StatPoint is one bucketed data point in the time-series.
type StatPoint struct {
	Bucket           time.Time // start of the aggregation bucket
	RunningCount     int32     // peak running sandboxes within the bucket (MAX aggregation, see timeSeriesSQL)
	VCPUsReserved    int32     // peak vCPUs reserved within the bucket
	MemoryMBReserved int32     // peak memory (MB) reserved within the bucket
}

// CurrentStats holds the live values for a team, read directly from sandboxes.
type CurrentStats struct {
	RunningCount     int32
	VCPUsReserved    int32
	MemoryMBReserved int32
}

// PeakStats holds the 30-day maximum values for a team.
type PeakStats struct {
	RunningCount int32
	VCPUs        int32
	MemoryMB     int32
}

// StatsService computes sandbox metrics for the dashboard.
type StatsService struct {
	DB   *db.Queries   // generated query layer for live/peak lookups
	Pool *pgxpool.Pool // raw pool for the dynamic-bucket time-series query
}
|
||||
|
||||
// GetStats returns current stats, 30-day peaks, and a time-series for the
// given team and time range. If no snapshots exist yet, zeros are returned
// for the peaks and the series may be empty.
//
// Returns an error when r is not a known TimeRange or any DB query fails.
func (s *StatsService) GetStats(ctx context.Context, teamID pgtype.UUID, r TimeRange) (CurrentStats, PeakStats, []StatPoint, error) {
	cfg, ok := rangeConfigs[r]
	if !ok {
		return CurrentStats{}, PeakStats{}, nil, fmt.Errorf("unknown range: %s", r)
	}

	// Current live values — read directly from sandboxes so we always reflect
	// the true state even when no capsules are running.
	cur, err := s.DB.GetLiveMetrics(ctx, teamID)
	if err != nil {
		return CurrentStats{}, PeakStats{}, nil, fmt.Errorf("get live metrics: %w", err)
	}
	current := CurrentStats{
		RunningCount:     cur.RunningCount,
		VCPUsReserved:    cur.VcpusReserved,
		MemoryMBReserved: cur.MemoryMbReserved,
	}

	// 30-day peaks. pgx.ErrNoRows simply means no snapshots exist yet, so the
	// zero-valued PeakStats is kept instead of treating it as a failure.
	var peaks PeakStats
	pk, err := s.DB.GetPeakMetrics(ctx, teamID)
	if err != nil && !errors.Is(err, pgx.ErrNoRows) {
		return CurrentStats{}, PeakStats{}, nil, fmt.Errorf("get peak metrics: %w", err)
	}
	if err == nil {
		peaks = PeakStats{
			RunningCount: pk.PeakRunningCount,
			VCPUs:        pk.PeakVcpus,
			MemoryMB:     pk.PeakMemoryMb,
		}
	}

	// Time-series — dynamic bucket width, executed via pgx directly.
	series, err := s.queryTimeSeries(ctx, teamID, cfg)
	if err != nil {
		return CurrentStats{}, PeakStats{}, nil, fmt.Errorf("get time series: %w", err)
	}

	return current, peaks, series, nil
}
|
||||
|
||||
// timeSeriesSQL uses an epoch-floor trick to bucket rows by an arbitrary
// integer number of seconds without requiring TimescaleDB.
//
// MAX is used instead of AVG so that short-lived running states are not
// averaged down to zero within a bucket. For capacity metrics the peak
// value in each bucket is what matters — AVG with ::INTEGER rounding
// caused running_count, vcpus, and memory to become inconsistent with
// each other (e.g. running=0 but vcpus=1).
//
// $1 = bucket width in seconds (integer)
// $2 = team_id
// $3 = lookback interval literal (e.g. '1 hour')
const timeSeriesSQL = `
SELECT
	to_timestamp(floor(extract(epoch FROM sampled_at) / $1) * $1) AS bucket,
	MAX(running_count) AS running_count,
	MAX(vcpus_reserved) AS vcpus_reserved,
	MAX(memory_mb_reserved) AS memory_mb_reserved
FROM sandbox_metrics_snapshots
WHERE team_id = $2
AND sampled_at >= NOW() - $3::INTERVAL
GROUP BY bucket
ORDER BY bucket ASC
`

// queryTimeSeries runs timeSeriesSQL for the given team and range config and
// scans the result into StatPoints, ordered by bucket ascending. Rows are
// returned only for buckets that contain at least one snapshot (no gap fill).
func (s *StatsService) queryTimeSeries(ctx context.Context, teamID pgtype.UUID, cfg rangeConfig) ([]StatPoint, error) {
	rows, err := s.Pool.Query(ctx, timeSeriesSQL, cfg.bucketSec, teamID, cfg.intervalLiteral)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var points []StatPoint
	for rows.Next() {
		var p StatPoint
		var bucket time.Time
		if err := rows.Scan(&bucket, &p.RunningCount, &p.VCPUsReserved, &p.MemoryMBReserved); err != nil {
			return nil, err
		}
		p.Bucket = bucket
		points = append(points, p)
	}
	// rows.Err surfaces any error that ended iteration early.
	return points, rows.Err()
}
|
||||
443
internal/service/team.go
Normal file
443
internal/service/team.go
Normal file
@ -0,0 +1,443 @@
|
||||
package service
|
||||
|
||||
import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"regexp"
	"time"

	"connectrpc.com/connect"
	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgtype"
	"github.com/jackc/pgx/v5/pgxpool"

	"git.omukk.dev/wrenn/sandbox/internal/db"
	"git.omukk.dev/wrenn/sandbox/internal/id"
	"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
	pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
)
|
||||
|
||||
// teamNameRE validates team names: 1-128 characters drawn from letters,
// digits, space, underscore, hyphen, at-sign, and apostrophe.
var teamNameRE = regexp.MustCompile(`^[A-Za-z0-9 _\-@']{1,128}$`)

// TeamService provides team management operations.
type TeamService struct {
	DB       *db.Queries               // generated query layer
	Pool     *pgxpool.Pool             // raw pool, used for multi-statement transactions
	HostPool *lifecycle.HostClientPool // per-host agent clients for VM/template cleanup
}

// TeamWithRole pairs a team with the calling user's role in it.
type TeamWithRole struct {
	db.Team
	Role string `json:"role"`
}

// MemberInfo is a team member with resolved user details.
type MemberInfo struct {
	UserID   string    `json:"user_id"`
	Name     string    `json:"name"`
	Email    string    `json:"email"`
	Role     string    `json:"role"`
	JoinedAt time.Time `json:"joined_at"` // zero value when the membership row has no join timestamp
}
|
||||
|
||||
// callerRole fetches the calling user's role in the given team from DB.
|
||||
// Returns an error wrapping "forbidden" if the caller is not a member.
|
||||
func (s *TeamService) callerRole(ctx context.Context, teamID, callerUserID pgtype.UUID) (string, error) {
|
||||
m, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
|
||||
UserID: callerUserID,
|
||||
TeamID: teamID,
|
||||
})
|
||||
if err != nil {
|
||||
if err == pgx.ErrNoRows {
|
||||
return "", fmt.Errorf("forbidden: not a member of this team")
|
||||
}
|
||||
return "", fmt.Errorf("get membership: %w", err)
|
||||
}
|
||||
return m.Role, nil
|
||||
}
|
||||
|
||||
// requireAdmin returns an error if the caller is not an admin or owner.
func requireAdmin(role string) error {
	switch role {
	case "owner", "admin":
		return nil
	default:
		return fmt.Errorf("forbidden: admin or owner role required")
	}
}
|
||||
|
||||
// GetTeam returns the team by ID. Returns an error if the team is deleted or not found.
|
||||
func (s *TeamService) GetTeam(ctx context.Context, teamID pgtype.UUID) (db.Team, error) {
|
||||
team, err := s.DB.GetTeam(ctx, teamID)
|
||||
if err != nil {
|
||||
if err == pgx.ErrNoRows {
|
||||
return db.Team{}, fmt.Errorf("team not found")
|
||||
}
|
||||
return db.Team{}, fmt.Errorf("get team: %w", err)
|
||||
}
|
||||
if team.DeletedAt.Valid {
|
||||
return db.Team{}, fmt.Errorf("team not found")
|
||||
}
|
||||
return team, nil
|
||||
}
|
||||
|
||||
// ListTeamsForUser returns all active teams the user belongs to, with their role in each.
|
||||
func (s *TeamService) ListTeamsForUser(ctx context.Context, userID pgtype.UUID) ([]TeamWithRole, error) {
|
||||
rows, err := s.DB.GetTeamsForUser(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list teams: %w", err)
|
||||
}
|
||||
result := make([]TeamWithRole, len(rows))
|
||||
for i, r := range rows {
|
||||
result[i] = TeamWithRole{
|
||||
Team: db.Team{ID: r.ID, Name: r.Name, CreatedAt: r.CreatedAt, IsByoc: r.IsByoc, Slug: r.Slug, DeletedAt: r.DeletedAt},
|
||||
Role: r.Role,
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// CreateTeam creates a new team owned by the given user.
|
||||
func (s *TeamService) CreateTeam(ctx context.Context, ownerUserID pgtype.UUID, name string) (TeamWithRole, error) {
|
||||
if !teamNameRE.MatchString(name) {
|
||||
return TeamWithRole{}, fmt.Errorf("invalid team name: must be 1-128 characters, A-Z a-z 0-9 space _")
|
||||
}
|
||||
|
||||
tx, err := s.Pool.Begin(ctx)
|
||||
if err != nil {
|
||||
return TeamWithRole{}, fmt.Errorf("begin tx: %w", err)
|
||||
}
|
||||
defer tx.Rollback(ctx) //nolint:errcheck
|
||||
|
||||
qtx := s.DB.WithTx(tx)
|
||||
|
||||
teamID := id.NewTeamID()
|
||||
team, err := qtx.InsertTeam(ctx, db.InsertTeamParams{
|
||||
ID: teamID,
|
||||
Name: name,
|
||||
Slug: id.NewTeamSlug(),
|
||||
})
|
||||
if err != nil {
|
||||
return TeamWithRole{}, fmt.Errorf("insert team: %w", err)
|
||||
}
|
||||
|
||||
if err := qtx.InsertTeamMember(ctx, db.InsertTeamMemberParams{
|
||||
UserID: ownerUserID,
|
||||
TeamID: teamID,
|
||||
IsDefault: false,
|
||||
Role: "owner",
|
||||
}); err != nil {
|
||||
return TeamWithRole{}, fmt.Errorf("insert owner: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return TeamWithRole{}, fmt.Errorf("commit: %w", err)
|
||||
}
|
||||
|
||||
return TeamWithRole{Team: team, Role: "owner"}, nil
|
||||
}
|
||||
|
||||
// RenameTeam updates the team name. Caller must be admin or owner (verified from DB).
|
||||
func (s *TeamService) RenameTeam(ctx context.Context, teamID, callerUserID pgtype.UUID, newName string) error {
|
||||
if !teamNameRE.MatchString(newName) {
|
||||
return fmt.Errorf("invalid team name: must be 1-128 characters, A-Z a-z 0-9 space _")
|
||||
}
|
||||
|
||||
role, err := s.callerRole(ctx, teamID, callerUserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := requireAdmin(role); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.DB.UpdateTeamName(ctx, db.UpdateTeamNameParams{ID: teamID, Name: newName}); err != nil {
|
||||
return fmt.Errorf("update name: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteTeam soft-deletes the team and destroys all running/paused/starting sandboxes.
// Caller must be owner (verified from DB). All DB records (sandboxes, keys, templates)
// are preserved; only the team's deleted_at is set and active VMs are stopped.
func (s *TeamService) DeleteTeam(ctx context.Context, teamID, callerUserID pgtype.UUID) error {
	role, err := s.callerRole(ctx, teamID, callerUserID)
	if err != nil {
		return err
	}
	if role != "owner" {
		return fmt.Errorf("forbidden: only the owner can delete a team")
	}

	// Collect active sandboxes and stop them.
	sandboxes, err := s.DB.ListActiveSandboxesByTeam(ctx, teamID)
	if err != nil {
		return fmt.Errorf("list active sandboxes: %w", err)
	}

	var stopIDs []pgtype.UUID
	for _, sb := range sandboxes {
		// Host/agent resolution is best-effort: if the host row or its agent
		// client is unavailable, the VM cannot be reached, but the sandbox is
		// still marked stopped below so no record stays "running" for a
		// deleted team.
		host, hostErr := s.DB.GetHost(ctx, sb.HostID)
		if hostErr == nil {
			agent, agentErr := s.HostPool.GetForHost(host)
			if agentErr == nil {
				// NotFound means the VM is already gone — treated as success.
				if _, err := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
					SandboxId: id.FormatSandboxID(sb.ID),
				})); err != nil && connect.CodeOf(err) != connect.CodeNotFound {
					slog.Warn("team delete: failed to destroy sandbox", "sandbox_id", id.FormatSandboxID(sb.ID), "error", err)
				}
			}
		}
		stopIDs = append(stopIDs, sb.ID)
	}

	if len(stopIDs) > 0 {
		if err := s.DB.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
			Column1: stopIDs,
			Status:  "stopped",
		}); err != nil {
			// Do not proceed to soft-delete if sandbox statuses couldn't be updated,
			// as that would leave orphaned "running" records for a deleted team.
			return fmt.Errorf("update sandbox statuses: %w", err)
		}
	}

	// Clean up team-owned templates from all hosts in the background.
	// context.Background() is used deliberately so cleanup outlives the request.
	// NOTE(review): this fires before SoftDeleteTeam; if the soft-delete below
	// fails, template files may already be gone — confirm this ordering is intended.
	go s.cleanupTeamTemplates(context.Background(), teamID)

	if err := s.DB.SoftDeleteTeam(ctx, teamID); err != nil {
		return fmt.Errorf("soft delete team: %w", err)
	}
	return nil
}
|
||||
|
||||
// cleanupTeamTemplates deletes all template files for a team from all online hosts,
// then removes the DB records. Called asynchronously during team deletion.
//
// Entirely best-effort: every failure is logged and skipped so one bad host
// cannot block cleanup of the rest.
func (s *TeamService) cleanupTeamTemplates(ctx context.Context, teamID pgtype.UUID) {
	templates, err := s.DB.ListTemplatesByTeamOnly(ctx, teamID)
	if err != nil {
		slog.Warn("team delete: failed to list templates for cleanup", "team_id", id.FormatTeamID(teamID), "error", err)
		return
	}
	if len(templates) == 0 {
		return
	}

	hosts, err := s.DB.ListActiveHosts(ctx)
	if err != nil {
		slog.Warn("team delete: failed to list hosts for template cleanup", "error", err)
		return
	}

	// Templates are not tracked per host, so each template is removed from
	// every online host.
	for _, tmpl := range templates {
		for _, host := range hosts {
			if host.Status != "online" {
				continue
			}
			agent, err := s.HostPool.GetForHost(host)
			if err != nil {
				// No reachable agent for this host — skip it.
				continue
			}
			// NotFound means the snapshot never existed on this host; only
			// other errors are worth logging.
			if _, err := agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{
				TeamId:     id.UUIDString(tmpl.TeamID),
				TemplateId: id.UUIDString(tmpl.ID),
			})); err != nil && connect.CodeOf(err) != connect.CodeNotFound {
				slog.Warn("team delete: failed to delete template on host",
					"host_id", id.FormatHostID(host.ID),
					"template", tmpl.Name,
					"error", err,
				)
			}
		}
	}

	// Remove DB records.
	if err := s.DB.DeleteTemplatesByTeam(ctx, teamID); err != nil {
		slog.Warn("team delete: failed to delete template records", "team_id", id.FormatTeamID(teamID), "error", err)
	}
}
|
||||
|
||||
// GetMembers returns all members of the team with their emails and roles.
|
||||
func (s *TeamService) GetMembers(ctx context.Context, teamID pgtype.UUID) ([]MemberInfo, error) {
|
||||
rows, err := s.DB.GetTeamMembers(ctx, teamID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get members: %w", err)
|
||||
}
|
||||
members := make([]MemberInfo, len(rows))
|
||||
for i, r := range rows {
|
||||
var joinedAt time.Time
|
||||
if r.JoinedAt.Valid {
|
||||
joinedAt = r.JoinedAt.Time
|
||||
}
|
||||
members[i] = MemberInfo{
|
||||
UserID: id.FormatUserID(r.ID),
|
||||
Name: r.Name,
|
||||
Email: r.Email,
|
||||
Role: r.Role,
|
||||
JoinedAt: joinedAt,
|
||||
}
|
||||
}
|
||||
return members, nil
|
||||
}
|
||||
|
||||
// AddMember adds an existing user (looked up by email) to the team as a member.
|
||||
// Caller must be admin or owner (verified from DB).
|
||||
func (s *TeamService) AddMember(ctx context.Context, teamID, callerUserID pgtype.UUID, email string) (MemberInfo, error) {
|
||||
role, err := s.callerRole(ctx, teamID, callerUserID)
|
||||
if err != nil {
|
||||
return MemberInfo{}, err
|
||||
}
|
||||
if err := requireAdmin(role); err != nil {
|
||||
return MemberInfo{}, err
|
||||
}
|
||||
|
||||
target, err := s.DB.GetUserByEmail(ctx, email)
|
||||
if err != nil {
|
||||
if err == pgx.ErrNoRows {
|
||||
return MemberInfo{}, fmt.Errorf("user not found: no account with that email")
|
||||
}
|
||||
return MemberInfo{}, fmt.Errorf("look up user: %w", err)
|
||||
}
|
||||
|
||||
// Check if already a member.
|
||||
_, memberCheckErr := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
|
||||
UserID: target.ID,
|
||||
TeamID: teamID,
|
||||
})
|
||||
if memberCheckErr == nil {
|
||||
return MemberInfo{}, fmt.Errorf("invalid: user is already a member of this team")
|
||||
} else if memberCheckErr != pgx.ErrNoRows {
|
||||
return MemberInfo{}, fmt.Errorf("check membership: %w", memberCheckErr)
|
||||
}
|
||||
|
||||
if err := s.DB.InsertTeamMember(ctx, db.InsertTeamMemberParams{
|
||||
UserID: target.ID,
|
||||
TeamID: teamID,
|
||||
IsDefault: false,
|
||||
Role: "member",
|
||||
}); err != nil {
|
||||
return MemberInfo{}, fmt.Errorf("insert member: %w", err)
|
||||
}
|
||||
|
||||
return MemberInfo{UserID: id.FormatUserID(target.ID), Name: target.Name, Email: target.Email, Role: "member"}, nil
|
||||
}
|
||||
|
||||
// RemoveMember removes a user from the team.
|
||||
// Caller must be admin or owner (verified from DB). Owner cannot be removed.
|
||||
func (s *TeamService) RemoveMember(ctx context.Context, teamID, callerUserID, targetUserID pgtype.UUID) error {
|
||||
callerRole, err := s.callerRole(ctx, teamID, callerUserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := requireAdmin(callerRole); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
targetMembership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
|
||||
UserID: targetUserID,
|
||||
TeamID: teamID,
|
||||
})
|
||||
if err != nil {
|
||||
if err == pgx.ErrNoRows {
|
||||
return fmt.Errorf("not found: user is not a member of this team")
|
||||
}
|
||||
return fmt.Errorf("get target membership: %w", err)
|
||||
}
|
||||
|
||||
if targetMembership.Role == "owner" {
|
||||
return fmt.Errorf("forbidden: the owner cannot be removed from the team")
|
||||
}
|
||||
|
||||
if err := s.DB.DeleteTeamMember(ctx, db.DeleteTeamMemberParams{
|
||||
TeamID: teamID,
|
||||
UserID: targetUserID,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("delete member: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateMemberRole changes a member's role to admin or member.
|
||||
// Caller must be admin or owner (verified from DB). Owner's role cannot be changed.
|
||||
// Valid target roles: "admin", "member".
|
||||
func (s *TeamService) UpdateMemberRole(ctx context.Context, teamID, callerUserID, targetUserID pgtype.UUID, newRole string) error {
|
||||
if newRole != "admin" && newRole != "member" {
|
||||
return fmt.Errorf("invalid: role must be admin or member")
|
||||
}
|
||||
|
||||
callerRole, err := s.callerRole(ctx, teamID, callerUserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := requireAdmin(callerRole); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
targetMembership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
|
||||
UserID: targetUserID,
|
||||
TeamID: teamID,
|
||||
})
|
||||
if err != nil {
|
||||
if err == pgx.ErrNoRows {
|
||||
return fmt.Errorf("not found: user is not a member of this team")
|
||||
}
|
||||
return fmt.Errorf("get target membership: %w", err)
|
||||
}
|
||||
|
||||
if targetMembership.Role == "owner" {
|
||||
return fmt.Errorf("forbidden: the owner's role cannot be changed")
|
||||
}
|
||||
|
||||
if err := s.DB.UpdateMemberRole(ctx, db.UpdateMemberRoleParams{
|
||||
TeamID: teamID,
|
||||
UserID: targetUserID,
|
||||
Role: newRole,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("update role: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LeaveTeam removes the calling user from the team.
|
||||
// The owner cannot leave; they must delete the team instead.
|
||||
func (s *TeamService) LeaveTeam(ctx context.Context, teamID, callerUserID pgtype.UUID) error {
|
||||
role, err := s.callerRole(ctx, teamID, callerUserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if role == "owner" {
|
||||
return fmt.Errorf("forbidden: the owner cannot leave the team; delete the team instead")
|
||||
}
|
||||
|
||||
if err := s.DB.DeleteTeamMember(ctx, db.DeleteTeamMemberParams{
|
||||
TeamID: teamID,
|
||||
UserID: callerUserID,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("leave team: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetBYOC enables the BYOC feature flag for a team. Once enabled, BYOC cannot
|
||||
// be disabled — it is a one-way transition.
|
||||
// Admin-only — the caller must verify admin status before invoking this.
|
||||
func (s *TeamService) SetBYOC(ctx context.Context, teamID pgtype.UUID, enabled bool) error {
|
||||
team, err := s.DB.GetTeam(ctx, teamID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("team not found: %w", err)
|
||||
}
|
||||
if team.DeletedAt.Valid {
|
||||
return fmt.Errorf("team not found")
|
||||
}
|
||||
if !enabled {
|
||||
return fmt.Errorf("invalid request: BYOC cannot be disabled once enabled")
|
||||
}
|
||||
if team.IsByoc {
|
||||
// Already enabled — idempotent, no-op.
|
||||
return nil
|
||||
}
|
||||
if err := s.DB.SetTeamBYOC(ctx, db.SetTeamBYOCParams{ID: teamID, IsByoc: true}); err != nil {
|
||||
return fmt.Errorf("set byoc: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -3,6 +3,8 @@ package service
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
)
|
||||
|
||||
@ -14,7 +16,7 @@ type TemplateService struct {
|
||||
|
||||
// List returns all templates belonging to the given team. If typeFilter is
|
||||
// non-empty, only templates of that type ("base" or "snapshot") are returned.
|
||||
func (s *TemplateService) List(ctx context.Context, teamID, typeFilter string) ([]db.Template, error) {
|
||||
func (s *TemplateService) List(ctx context.Context, teamID pgtype.UUID, typeFilter string) ([]db.Template, error) {
|
||||
if typeFilter != "" {
|
||||
return s.DB.ListTemplatesByTeamAndType(ctx, db.ListTemplatesByTeamAndTypeParams{
|
||||
TeamID: teamID,
|
||||
|
||||
Reference in New Issue
Block a user