1
0
forked from wrenn/wrenn
This commit is contained in:
2026-04-16 19:24:25 +00:00
parent 172413e91e
commit 605ad666a0
239 changed files with 19966 additions and 3454 deletions

65
pkg/service/apikey.go Normal file
View File

@ -0,0 +1,65 @@
package service
import (
"context"
"fmt"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
// APIKeyService provides API key operations shared between the REST API and the dashboard.
// Every method delegates to the query layer held in DB.
type APIKeyService struct {
// DB is the query handle used for all API key reads and writes.
DB *db.Queries
}
// APIKeyCreateResult holds the result of creating an API key, including the
// plaintext key which is only available at creation time.
type APIKeyCreateResult struct {
// Row is the record returned by the insert query.
Row db.TeamApiKey
// Plaintext is the key as produced by auth.GenerateAPIKey. Create hands only
// its hash and prefix to the insert query, so this is the sole copy.
Plaintext string
}
// Create generates a fresh API key for teamID and stores its hash and prefix.
//
// An empty name is replaced with "Unnamed API Key". userID is recorded as the
// key's creator. The returned result carries the inserted row together with
// the plaintext key; on any error the zero result is returned.
func (s *APIKeyService) Create(ctx context.Context, teamID, userID pgtype.UUID, name string) (APIKeyCreateResult, error) {
	if name == "" {
		name = "Unnamed API Key"
	}

	secret, digest, err := auth.GenerateAPIKey()
	if err != nil {
		return APIKeyCreateResult{}, fmt.Errorf("generate key: %w", err)
	}

	params := db.InsertAPIKeyParams{
		ID:        id.NewAPIKeyID(),
		TeamID:    teamID,
		Name:      name,
		KeyHash:   digest,
		KeyPrefix: auth.APIKeyPrefix(secret),
		CreatedBy: userID,
	}
	inserted, err := s.DB.InsertAPIKey(ctx, params)
	if err != nil {
		return APIKeyCreateResult{}, fmt.Errorf("insert key: %w", err)
	}

	return APIKeyCreateResult{Row: inserted, Plaintext: secret}, nil
}
// List fetches every API key row that belongs to teamID. The query result and
// error are returned unchanged.
func (s *APIKeyService) List(ctx context.Context, teamID pgtype.UUID) ([]db.TeamApiKey, error) {
	keys, err := s.DB.ListAPIKeysByTeam(ctx, teamID)
	return keys, err
}
// ListWithCreator fetches the team's API keys joined with the creator's email.
// The query result and error are returned unchanged.
func (s *APIKeyService) ListWithCreator(ctx context.Context, teamID pgtype.UUID) ([]db.ListAPIKeysByTeamWithCreatorRow, error) {
	rows, err := s.DB.ListAPIKeysByTeamWithCreator(ctx, teamID)
	return rows, err
}
// Delete removes the API key identified by keyID. The team ID is passed to the
// query alongside the key ID so the delete is scoped to that team.
func (s *APIKeyService) Delete(ctx context.Context, keyID, teamID pgtype.UUID) error {
	params := db.DeleteAPIKeyParams{
		ID:     keyID,
		TeamID: teamID,
	}
	return s.DB.DeleteAPIKey(ctx, params)
}

113
pkg/service/audit.go Normal file
View File

@ -0,0 +1,113 @@
package service
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
// auditMaxLimit is the largest page size AuditService.List will request;
// larger limits are clamped down to this value.
const auditMaxLimit = 200
// AuditEntry is a single audit log record returned by List.
// Identifier fields are strings: ID and TeamID are produced by the id package
// formatters, while ActorID and ResourceID are copied from nullable columns.
type AuditEntry struct {
ID string // formatted with id.FormatAuditLogID
TeamID string // formatted with id.FormatTeamID
ActorType string
ActorID string // empty for system
ActorName string // empty for system
ResourceType string
ResourceID string // empty when not applicable
Action string
Scope string
Status string // 'success', 'info', 'warning', 'error'
// Metadata is the decoded JSON metadata of the row. It stays nil when the row
// has no metadata or when the stored bytes do not decode into a JSON object.
Metadata map[string]any
CreatedAt time.Time
}
// AuditListParams controls the ListAuditLogs query.
// Before and BeforeID together form the pagination cursor.
type AuditListParams struct {
TeamID pgtype.UUID
AdminScoped bool // true → include admin-scoped events; false → team-scoped only
ResourceTypes []string // empty = no filter; multiple values = OR match
Actions []string // empty = no filter; multiple values = OR match
Before time.Time // zero = no cursor (start from latest)
BeforeID pgtype.UUID // tie-breaker: id of the last item at the Before timestamp; zero = no tie-break
Limit int // <= 0 falls back to 50; values above auditMaxLimit are clamped by List
}
// AuditService provides the read side of the audit log.
type AuditService struct {
// DB is the query handle used to run the audit log listing query.
DB *db.Queries
}
// List returns one page of audit log entries for the team named in p,
// converted from database rows into AuditEntry values.
//
// A non-positive p.Limit falls back to 50 and anything above auditMaxLimit is
// clamped. Team-scoped events are always requested; admin-scoped events are
// added when p.AdminScoped is set. A zero p.Before leaves the timestamp cursor
// unset, and nil filter slices are sent to the query as empty slices.
func (s *AuditService) List(ctx context.Context, p AuditListParams) ([]AuditEntry, error) {
	pageSize := p.Limit
	switch {
	case pageSize <= 0:
		pageSize = 50
	case pageSize > auditMaxLimit:
		pageSize = auditMaxLimit
	}

	scopes := []string{"team"}
	if p.AdminScoped {
		scopes = append(scopes, "admin")
	}

	query := db.ListAuditLogsParams{
		TeamID:  p.TeamID,
		Column2: scopes,
		Column3: []string{},
		Column4: []string{},
		ID:      p.BeforeID,
		Limit:   int32(pageSize),
	}
	if p.ResourceTypes != nil {
		query.Column3 = p.ResourceTypes
	}
	if p.Actions != nil {
		query.Column4 = p.Actions
	}
	if !p.Before.IsZero() {
		query.Column5 = pgtype.Timestamptz{Time: p.Before, Valid: true}
	}

	rows, err := s.DB.ListAuditLogs(ctx, query)
	if err != nil {
		return nil, fmt.Errorf("list audit logs: %w", err)
	}

	entries := make([]AuditEntry, 0, len(rows))
	for i := range rows {
		row := &rows[i]

		var meta map[string]any
		if len(row.Metadata) > 0 {
			// Best effort: a row whose metadata cannot be decoded is still
			// returned, just without its metadata.
			_ = json.Unmarshal(row.Metadata, &meta)
		}

		entries = append(entries, AuditEntry{
			ID:           id.FormatAuditLogID(row.ID),
			TeamID:       id.FormatTeamID(row.TeamID),
			ActorType:    row.ActorType,
			ActorID:      row.ActorID.String,
			ActorName:    row.ActorName,
			ResourceType: row.ResourceType,
			ResourceID:   row.ResourceID.String,
			Action:       row.Action,
			Scope:        row.Scope,
			Status:       row.Status,
			Metadata:     meta,
			CreatedAt:    row.CreatedAt.Time,
		})
	}
	return entries, nil
}

786
pkg/service/build.go Normal file
View File

@ -0,0 +1,786 @@
package service
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"strings"
"sync"
"time"
"connectrpc.com/connect"
"github.com/jackc/pgx/v5/pgtype"
"github.com/redis/go-redis/v9"
"git.omukk.dev/wrenn/wrenn/internal/recipe"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
"git.omukk.dev/wrenn/wrenn/pkg/scheduler"
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
)
const (
// buildQueueKey is the Redis list that Create pushes build IDs onto and that
// the build workers block-pop from.
buildQueueKey = "wrenn:build_queue"
// buildCommandTimeout is the default timeout passed to recipe.Execute for the
// user recipe phase; the pre- and post-build phases pass 0 instead.
buildCommandTimeout = 30 * time.Second
)
// preBuildCmds run before the user recipe to prepare the build environment.
// apt update runs as root first, then USER switches to wrenn-user for the recipe.
// The whole phase is skipped when the build has SkipPrePost set.
var preBuildCmds = []string{
"RUN apt update",
"USER wrenn-user",
"WORKDIR /home/wrenn-user",
}
// postBuildCmds run after the user recipe to clean up caches and reduce image size.
// executeBuild switches the exec user back to root before running them, and
// skips the whole phase when the build has SkipPrePost set. The last command
// removes the uploaded build archive and its extraction directory.
var postBuildCmds = []string{
"RUN apt clean",
"RUN apt autoremove -y",
"RUN rm -rf /var/lib/apt/lists/*",
"RUN rm -rf /tmp/build-files /tmp/build-files.*",
}
// buildAgentClient is the subset of the host agent client used by the build worker.
// Keeping it as a narrow interface lets the build helpers accept any client
// that offers these six RPCs.
type buildAgentClient interface {
// CreateSandbox starts the sandbox the build runs in.
CreateSandbox(ctx context.Context, req *connect.Request[pb.CreateSandboxRequest]) (*connect.Response[pb.CreateSandboxResponse], error)
// DestroySandbox tears the build sandbox down on failure or cancellation.
DestroySandbox(ctx context.Context, req *connect.Request[pb.DestroySandboxRequest]) (*connect.Response[pb.DestroySandboxResponse], error)
// Exec runs a command inside the sandbox (recipe steps, env probe, healthcheck).
Exec(ctx context.Context, req *connect.Request[pb.ExecRequest]) (*connect.Response[pb.ExecResponse], error)
// WriteFile uploads the build archive into the sandbox.
WriteFile(ctx context.Context, req *connect.Request[pb.WriteFileRequest]) (*connect.Response[pb.WriteFileResponse], error)
// CreateSnapshot is used when the build defines a healthcheck.
CreateSnapshot(ctx context.Context, req *connect.Request[pb.CreateSnapshotRequest]) (*connect.Response[pb.CreateSnapshotResponse], error)
// FlattenRootfs is used when the build has no healthcheck.
FlattenRootfs(ctx context.Context, req *connect.Request[pb.FlattenRootfsRequest]) (*connect.Response[pb.FlattenRootfsResponse], error)
}
// BuildService handles template build orchestration.
// Builds are recorded in the database, queued through Redis, and executed by
// worker goroutines started with StartWorkers. The struct holds a mutex and
// must not be copied.
type BuildService struct {
DB *db.Queries
Redis *redis.Client
Pool *lifecycle.HostClientPool
Scheduler scheduler.HostScheduler
// mu guards cancelMap and filesMap. Both maps are created lazily on first write.
mu sync.Mutex
cancelMap map[string]context.CancelFunc // buildID → per-build cancel func
filesMap map[string][]byte // buildID → uploaded archive bytes
}
// BuildCreateParams holds the parameters for creating a template build.
type BuildCreateParams struct {
Name string
BaseTemplate string // empty defaults to "minimal"
Recipe []string // user recipe lines, stored as JSON on the build row
Healthcheck string // empty = no healthcheck; the build then flattens the rootfs instead of snapshotting
VCPUs int32 // <= 0 defaults to 1
MemoryMB int32 // <= 0 defaults to 512
SkipPrePost bool // true = skip the built-in pre- and post-build phases
Archive []byte // Optional tar/tar.gz/zip archive for COPY commands.
// Original filename. NOTE(review): the visible build worker detects the
// archive format from magic bytes and never reads this field — confirm
// whether any caller still relies on it.
ArchiveName string
}
// storeArchive remembers the uploaded archive bytes under buildID so the
// worker that picks the build up can retrieve them. The backing map is created
// on first use.
func (s *BuildService) storeArchive(buildID string, data []byte) {
	s.mu.Lock()
	if s.filesMap == nil {
		s.filesMap = map[string][]byte{buildID: data}
	} else {
		s.filesMap[buildID] = data
	}
	s.mu.Unlock()
}
// takeArchive hands back the archive stored for buildID and forgets it.
// It returns nil when nothing is stored for that build.
func (s *BuildService) takeArchive(buildID string) []byte {
	s.mu.Lock()
	defer s.mu.Unlock()

	archive, stored := s.filesMap[buildID]
	if stored {
		delete(s.filesMap, buildID)
	}
	return archive
}
// Create inserts a new build record and enqueues it to Redis.
//
// Defaults applied to p: BaseTemplate "minimal", VCPUs 1, MemoryMB 512. The
// recorded total step count is the recipe length plus the built-in pre/post
// steps unless p.SkipPrePost is set. The build is always owned by the platform
// team.
//
// If the insert succeeds but the enqueue fails, the stored archive is dropped
// and the build row is marked as failed so it does not remain pending forever
// with nothing in the queue to run it. The enqueue error is returned.
func (s *BuildService) Create(ctx context.Context, p BuildCreateParams) (db.TemplateBuild, error) {
	if p.BaseTemplate == "" {
		p.BaseTemplate = "minimal"
	}
	if p.VCPUs <= 0 {
		p.VCPUs = 1
	}
	if p.MemoryMB <= 0 {
		p.MemoryMB = 512
	}

	recipeJSON, err := json.Marshal(p.Recipe)
	if err != nil {
		return db.TemplateBuild{}, fmt.Errorf("marshal recipe: %w", err)
	}

	buildID := id.NewBuildID()
	buildIDStr := id.FormatBuildID(buildID)
	newTemplateID := id.NewTemplateID()

	defaultSteps := len(preBuildCmds) + len(postBuildCmds)
	if p.SkipPrePost {
		defaultSteps = 0
	}

	build, err := s.DB.InsertTemplateBuild(ctx, db.InsertTemplateBuildParams{
		ID:           buildID,
		Name:         p.Name,
		BaseTemplate: p.BaseTemplate,
		Recipe:       recipeJSON,
		Healthcheck:  p.Healthcheck,
		Vcpus:        p.VCPUs,
		MemoryMb:     p.MemoryMB,
		TotalSteps:   int32(len(p.Recipe) + defaultSteps),
		TemplateID:   newTemplateID,
		TeamID:       id.PlatformTeamID,
		SkipPrePost:  p.SkipPrePost,
	})
	if err != nil {
		return db.TemplateBuild{}, fmt.Errorf("insert build: %w", err)
	}

	// Store archive before enqueue so the worker never dequeues without files.
	if len(p.Archive) > 0 {
		s.storeArchive(buildIDStr, p.Archive)
	}

	if err := s.Redis.RPush(ctx, buildQueueKey, buildIDStr).Err(); err != nil {
		// No worker will ever claim this build: release the archive and record
		// the failure on the row that was just inserted. failBuild writes with
		// its own detached context, so this works even if ctx is already done.
		s.takeArchive(buildIDStr)
		s.failBuild(ctx, buildID, fmt.Sprintf("enqueue build: %v", err))
		return db.TemplateBuild{}, fmt.Errorf("enqueue build: %w", err)
	}

	return build, nil
}
// Get looks up one build by its ID. The query result and error are returned
// unchanged.
func (s *BuildService) Get(ctx context.Context, buildID pgtype.UUID) (db.TemplateBuild, error) {
	build, err := s.DB.GetTemplateBuild(ctx, buildID)
	return build, err
}
// List returns every build as produced by the ListTemplateBuilds query, which
// orders them by creation time. The result and error are returned unchanged.
func (s *BuildService) List(ctx context.Context) ([]db.TemplateBuild, error) {
	builds, err := s.DB.ListTemplateBuilds(ctx)
	return builds, err
}
// Cancel stops a build that is still pending or running.
//
// A build that already reached "success", "failed" or "cancelled" is rejected
// with an error. Otherwise the row is marked "cancelled" first: a pending
// build is then skipped by the worker that eventually dequeues it. If the
// build is executing in this process, its per-build context is cancelled as
// well, which aborts the step that is currently running.
func (s *BuildService) Cancel(ctx context.Context, buildID pgtype.UUID) error {
	build, err := s.DB.GetTemplateBuild(ctx, buildID)
	if err != nil {
		return fmt.Errorf("get build: %w", err)
	}

	if build.Status == "success" || build.Status == "failed" || build.Status == "cancelled" {
		return fmt.Errorf("build is already %s", build.Status)
	}

	// Persist the cancellation before signalling, so a worker that has not yet
	// started the build sees the flag when it fetches the row.
	update := db.UpdateBuildStatusParams{ID: buildID, Status: "cancelled"}
	if _, err := s.DB.UpdateBuildStatus(ctx, update); err != nil {
		return fmt.Errorf("update build status: %w", err)
	}

	// Look the cancel func up under the lock, but invoke it outside of it.
	key := id.FormatBuildID(buildID)
	s.mu.Lock()
	stop, inFlight := s.cancelMap[key]
	s.mu.Unlock()

	if inFlight {
		stop()
	}
	return nil
}
// StartWorkers spawns n worker goroutines that consume the Redis build queue.
// All workers share a context derived from ctx; calling the returned function
// cancels that context and thereby stops every worker.
func (s *BuildService) StartWorkers(ctx context.Context, n int) context.CancelFunc {
	workerCtx, stop := context.WithCancel(ctx)
	for workerID := 0; workerID < n; workerID++ {
		go s.worker(workerCtx, workerID)
	}
	slog.Info("build workers started", "count", n)
	return stop
}
// worker is the loop run by each build worker goroutine. It blocks on the
// Redis build queue, executes each build ID it receives, and returns once ctx
// is cancelled.
//
// Transient Redis errors are logged and retried after a one-second pause. The
// pause is interruptible, so shutdown is not delayed by a worker that is
// backing off.
func (s *BuildService) worker(ctx context.Context, workerID int) {
	log := slog.With("worker", workerID)
	for {
		// BLPOP blocks until a build ID is available or context is cancelled.
		result, err := s.Redis.BLPop(ctx, 0, buildQueueKey).Result()
		if err != nil {
			if ctx.Err() != nil {
				log.Info("build worker shutting down")
				return
			}
			log.Error("redis BLPOP error", "error", err)
			select {
			case <-ctx.Done():
				log.Info("build worker shutting down")
				return
			case <-time.After(time.Second):
			}
			continue
		}

		// result[0] is the key, result[1] is the build ID (formatted string).
		// Guard the index so a malformed reply cannot panic the worker.
		if len(result) < 2 {
			log.Error("unexpected BLPOP reply", "result", result)
			continue
		}
		buildIDStr := result[1]
		log.Info("picked up build", "build_id", buildIDStr)
		s.executeBuild(ctx, buildIDStr)
	}
}
// executeBuild runs one queued template build from start to finish.
//
// Sequence: register a per-build cancel func, load the build row (skipping it
// when already cancelled), mark it running, pick a host, create a sandbox,
// upload the optional archive, run the pre-build / recipe / post-build phases,
// then either run the healthcheck and snapshot the sandbox or flatten its
// rootfs, and finally insert the template record and mark the build as
// success.
//
// Failures are recorded through failBuild. When the per-build context has been
// cancelled the failure is NOT recorded, so a status written by Cancel is not
// overwritten with "failed". Any archive stored for the build is released on
// every exit path.
func (s *BuildService) executeBuild(ctx context.Context, buildIDStr string) {
	// Disk size for build sandboxes (5 GB); the same figure is handed to the
	// scheduler when selecting a host.
	const buildDiskSizeMB = 5120

	log := slog.With("build_id", buildIDStr)

	// Release any uploaded archive no matter how this function exits, so early
	// returns (cancelled build, no host, sandbox creation failure, ...) do not
	// leave the bytes pinned in filesMap. A no-op once the archive was taken.
	defer s.takeArchive(buildIDStr)

	buildID, err := id.ParseBuildID(buildIDStr)
	if err != nil {
		log.Error("invalid build ID from queue", "error", err)
		return
	}

	// Create a per-build context so this build can be cancelled independently of
	// the worker. Register in cancelMap before fetching the build so that a
	// concurrent Cancel call can always find and signal it.
	buildCtx, buildCancel := context.WithCancel(ctx)
	defer buildCancel()

	s.mu.Lock()
	if s.cancelMap == nil {
		s.cancelMap = make(map[string]context.CancelFunc)
	}
	s.cancelMap[buildIDStr] = buildCancel
	s.mu.Unlock()
	defer func() {
		s.mu.Lock()
		delete(s.cancelMap, buildIDStr)
		s.mu.Unlock()
	}()

	// failUnlessCancelled records a failure for steps whose error may simply be
	// the result of the build context being cancelled. It mirrors the check the
	// phase, healthcheck and snapshot paths below already perform.
	failUnlessCancelled := func(msg string) {
		if buildCtx.Err() != nil {
			return
		}
		s.failBuild(buildCtx, buildID, msg)
	}

	build, err := s.DB.GetTemplateBuild(buildCtx, buildID)
	if err != nil {
		log.Error("failed to fetch build", "error", err)
		return
	}

	// Skip if already cancelled (Cancel was called before we dequeued).
	if build.Status == "cancelled" {
		log.Info("build already cancelled, skipping")
		return
	}

	// Mark as running.
	if _, err := s.DB.UpdateBuildStatus(buildCtx, db.UpdateBuildStatusParams{
		ID: buildID, Status: "running",
	}); err != nil {
		log.Error("failed to update build status", "error", err)
		return
	}

	// Parse user recipe.
	var userRecipe []string
	if err := json.Unmarshal(build.Recipe, &userRecipe); err != nil {
		s.failBuild(buildCtx, buildID, fmt.Sprintf("invalid recipe JSON: %v", err))
		return
	}

	// Pick a platform host and create a sandbox.
	host, err := s.Scheduler.SelectHost(buildCtx, id.PlatformTeamID, false, build.MemoryMb, buildDiskSizeMB)
	if err != nil {
		failUnlessCancelled(fmt.Sprintf("no host available: %v", err))
		return
	}

	agent, err := s.Pool.GetForHost(host)
	if err != nil {
		failUnlessCancelled(fmt.Sprintf("agent client error: %v", err))
		return
	}

	sandboxID := id.NewSandboxID()
	sandboxIDStr := id.FormatSandboxID(sandboxID)
	log = log.With("sandbox_id", sandboxIDStr, "host_id", id.FormatHostID(host.ID))

	// Resolve the base template to UUIDs. "minimal" is the zero sentinel.
	baseTeamID := id.PlatformTeamID
	baseTemplateID := id.MinimalTemplateID
	if build.BaseTemplate != "minimal" {
		baseTmpl, err := s.DB.GetPlatformTemplateByName(buildCtx, build.BaseTemplate)
		if err != nil {
			failUnlessCancelled(fmt.Sprintf("base template %q not found: %v", build.BaseTemplate, err))
			return
		}
		baseTeamID = baseTmpl.TeamID
		baseTemplateID = baseTmpl.ID
	}

	resp, err := agent.CreateSandbox(buildCtx, connect.NewRequest(&pb.CreateSandboxRequest{
		SandboxId:  sandboxIDStr,
		Template:   build.BaseTemplate,
		TeamId:     id.UUIDString(baseTeamID),
		TemplateId: id.UUIDString(baseTemplateID),
		Vcpus:      build.Vcpus,
		MemoryMb:   build.MemoryMb,
		TimeoutSec: 0, // no auto-pause for builds
		DiskSizeMb: buildDiskSizeMB,
	}))
	if err != nil {
		failUnlessCancelled(fmt.Sprintf("create sandbox failed: %v", err))
		return
	}

	// Capture sandbox metadata (envd/kernel/firecracker/agent versions).
	sandboxMetadata := resp.Msg.Metadata

	// Record sandbox/host association. Informational only, so a failure is
	// logged and the build continues.
	if err := s.DB.UpdateBuildSandbox(buildCtx, db.UpdateBuildSandboxParams{
		ID:        buildID,
		SandboxID: sandboxID,
		HostID:    host.ID,
	}); err != nil {
		log.Warn("failed to record build sandbox", "error", err)
	}

	// Upload and extract build archive if provided.
	archive := s.takeArchive(buildIDStr)
	if len(archive) > 0 {
		if err := s.uploadAndExtractArchive(buildCtx, agent, sandboxIDStr, archive, buildIDStr); err != nil {
			s.destroySandbox(buildCtx, agent, sandboxIDStr)
			failUnlessCancelled(fmt.Sprintf("archive upload failed: %v", err))
			return
		}
	}

	// Parse recipe steps. preBuildCmds and postBuildCmds are hardcoded and always
	// valid; panic on error is appropriate here since it would be a programmer mistake.
	preBuildSteps, err := recipe.ParseRecipe(preBuildCmds)
	if err != nil {
		panic(fmt.Sprintf("invalid pre-build recipe: %v", err))
	}
	userRecipeSteps, err := recipe.ParseRecipe(userRecipe)
	if err != nil {
		s.destroySandbox(buildCtx, agent, sandboxIDStr)
		s.failBuild(buildCtx, buildID, fmt.Sprintf("recipe parse error: %v", err))
		return
	}
	postBuildSteps, err := recipe.ParseRecipe(postBuildCmds)
	if err != nil {
		panic(fmt.Sprintf("invalid post-build recipe: %v", err))
	}

	var logs []recipe.BuildLogEntry
	step := 0

	envVars, err := s.fetchSandboxEnv(buildCtx, agent, sandboxIDStr)
	if err != nil {
		log.Warn("failed to fetch sandbox env, using defaults", "error", err)
		envVars = map[string]string{
			"PATH": "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin",
			"HOME": "/root",
		}
	}
	bctx := &recipe.ExecContext{EnvVars: envVars, User: "root"}

	// Per-step progress callback for live UI updates.
	progressFn := func(currentStep int, allEntries []recipe.BuildLogEntry) {
		s.updateLogs(buildCtx, buildID, currentStep, allEntries)
	}

	// runPhase executes one phase, folds its log entries into logs, persists
	// progress, and on failure destroys the sandbox and records the error.
	// It reports whether the phase succeeded.
	runPhase := func(phase string, steps []recipe.Step, defaultTimeout time.Duration) bool {
		newEntries, nextStep, ok := recipe.Execute(buildCtx, phase, steps, sandboxIDStr, step, defaultTimeout, bctx, agent.Exec, func(currentStep int, phaseEntries []recipe.BuildLogEntry) {
			// Progress callback: combine prior logs with current phase entries.
			// The three-index slice caps capacity so append always copies and
			// never writes into spare capacity of logs' backing array.
			progressFn(currentStep, append(logs[:len(logs):len(logs)], phaseEntries...))
		})
		logs = append(logs, newEntries...)
		step = nextStep
		s.updateLogs(buildCtx, buildID, step, logs)
		if !ok {
			s.destroySandbox(buildCtx, agent, sandboxIDStr)
			// If the build was cancelled, status is already set — don't overwrite with "failed".
			if buildCtx.Err() != nil {
				return false
			}
			reason := "unknown error"
			if len(newEntries) > 0 {
				last := newEntries[len(newEntries)-1]
				reason = last.Stderr
				if reason == "" {
					reason = fmt.Sprintf("exit code %d", last.Exit)
				}
			}
			s.failBuild(buildCtx, buildID, fmt.Sprintf("%s step %d failed: %s", phase, step, reason))
		}
		return ok
	}

	// Phase 1: Pre-build — refreshes the apt index as root, then switches the
	// exec user and working directory to wrenn-user (see preBuildCmds).
	if !build.SkipPrePost {
		if !runPhase("pre-build", preBuildSteps, 0) {
			return
		}
	}

	// Phase 2: User recipe — starts as wrenn-user (set by USER in pre-build)
	// or root if skip_pre_post.
	if !runPhase("recipe", userRecipeSteps, buildCommandTimeout) {
		return
	}

	// Capture the final user and env vars as template defaults.
	// Filter out user-specific and runtime vars that should be resolved at
	// sandbox creation time, not baked in from the build environment.
	templateDefaultUser := bctx.User
	templateDefaultEnv := filterBuildEnv(bctx.EnvVars)

	// Phase 3: Post-build (as root) — cleanup.
	bctx.User = "root"
	if !build.SkipPrePost {
		if !runPhase("post-build", postBuildSteps, 0) {
			return
		}
	}

	// Healthcheck or direct snapshot.
	var sizeBytes int64
	if build.Healthcheck != "" {
		hc, err := recipe.ParseHealthcheck(build.Healthcheck)
		if err != nil {
			s.destroySandbox(buildCtx, agent, sandboxIDStr)
			s.failBuild(buildCtx, buildID, fmt.Sprintf("invalid healthcheck: %v", err))
			return
		}

		log.Info("running healthcheck", "cmd", hc.Cmd, "interval", hc.Interval, "timeout", hc.Timeout, "start_period", hc.StartPeriod, "retries", hc.Retries)
		if err := s.waitForHealthcheck(buildCtx, agent, sandboxIDStr, hc, templateDefaultUser); err != nil {
			s.destroySandbox(buildCtx, agent, sandboxIDStr)
			if buildCtx.Err() != nil {
				return
			}
			s.failBuild(buildCtx, buildID, fmt.Sprintf("healthcheck failed: %v", err))
			return
		}

		// Healthcheck passed → full snapshot (with memory/CPU state).
		log.Info("healthcheck passed, creating snapshot")
		snapResp, err := agent.CreateSnapshot(buildCtx, connect.NewRequest(&pb.CreateSnapshotRequest{
			SandboxId:  sandboxIDStr,
			Name:       build.Name,
			TeamId:     id.UUIDString(build.TeamID),
			TemplateId: id.UUIDString(build.TemplateID),
		}))
		if err != nil {
			s.destroySandbox(buildCtx, agent, sandboxIDStr)
			if buildCtx.Err() != nil {
				return
			}
			s.failBuild(buildCtx, buildID, fmt.Sprintf("create snapshot failed: %v", err))
			return
		}
		sizeBytes = snapResp.Msg.SizeBytes
	} else {
		// No healthcheck → image-only template (rootfs only).
		log.Info("no healthcheck, flattening rootfs")
		flatResp, err := agent.FlattenRootfs(buildCtx, connect.NewRequest(&pb.FlattenRootfsRequest{
			SandboxId:  sandboxIDStr,
			Name:       build.Name,
			TeamId:     id.UUIDString(build.TeamID),
			TemplateId: id.UUIDString(build.TemplateID),
		}))
		if err != nil {
			s.destroySandbox(buildCtx, agent, sandboxIDStr)
			if buildCtx.Err() != nil {
				return
			}
			s.failBuild(buildCtx, buildID, fmt.Sprintf("flatten rootfs failed: %v", err))
			return
		}
		sizeBytes = flatResp.Msg.SizeBytes
	}

	// Insert into templates table as a global (platform) template.
	templateType := "base"
	if build.Healthcheck != "" {
		templateType = "snapshot"
	}

	// Serialize env vars for DB storage.
	defaultEnvJSON, err := json.Marshal(templateDefaultEnv)
	if err != nil {
		defaultEnvJSON = []byte("{}")
	}

	// Serialize sandbox metadata for DB storage.
	metadataJSON, err := json.Marshal(sandboxMetadata)
	if err != nil || len(sandboxMetadata) == 0 {
		metadataJSON = []byte("{}")
	}

	if _, err := s.DB.InsertTemplate(buildCtx, db.InsertTemplateParams{
		ID:          build.TemplateID,
		Name:        build.Name,
		Type:        templateType,
		Vcpus:       build.Vcpus,
		MemoryMb:    build.MemoryMb,
		SizeBytes:   sizeBytes,
		TeamID:      id.PlatformTeamID,
		DefaultUser: templateDefaultUser,
		DefaultEnv:  defaultEnvJSON,
		Metadata:    metadataJSON,
	}); err != nil {
		log.Error("failed to insert template record", "error", err)
		// Build succeeded on disk, just DB record failed — don't mark as failed.
	}

	// Record defaults and metadata on the build record for inspection.
	// Informational only, so a failure is logged and the build still succeeds.
	if err := s.DB.UpdateBuildDefaults(buildCtx, db.UpdateBuildDefaultsParams{
		ID:          buildID,
		DefaultUser: templateDefaultUser,
		DefaultEnv:  defaultEnvJSON,
		Metadata:    metadataJSON,
	}); err != nil {
		log.Warn("failed to record build defaults", "error", err)
	}

	// For CreateSnapshot, the sandbox is already destroyed by the snapshot process.
	// For FlattenRootfs, the sandbox is already destroyed by the flatten process.
	// No additional destroy needed.

	// Mark build as success.
	if _, err := s.DB.UpdateBuildStatus(buildCtx, db.UpdateBuildStatusParams{
		ID: buildID, Status: "success",
	}); err != nil {
		log.Error("failed to mark build as success", "error", err)
	}

	log.Info("template build completed successfully", "name", build.Name)
}
// waitForHealthcheck polls the healthcheck command inside the sandbox once per
// hc.Interval until it exits with code 0.
//
// Each attempt runs with a timeout of hc.Timeout. Failed attempts that happen
// before hc.StartPeriod has elapsed are not counted. With hc.Retries > 0 the
// call gives up once that many counted failures have occurred, or when an
// overall deadline of StartPeriod + (Retries+1)*Interval passes; with
// hc.Retries == 0 only ctx bounds the wait. It returns nil on the first
// successful check, ctx.Err() on cancellation, and a descriptive error
// otherwise.
func (s *BuildService) waitForHealthcheck(ctx context.Context, agent buildAgentClient, sandboxIDStr string, hc recipe.HealthcheckConfig, user string) error {
	// For a non-root template user, run the check through su so that ~ and the
	// UID match what the template's default user will see.
	probe := hc.Cmd
	if user != "" && user != "root" {
		probe = "su " + recipe.Shellescape(user) + " -s /bin/sh -c " + recipe.Shellescape(hc.Cmd)
	}

	tick := time.NewTicker(hc.Interval)
	defer tick.Stop()

	// A nil channel never fires, which is exactly what unlimited retries need.
	var budgetExpired <-chan time.Time
	if hc.Retries > 0 {
		budget := time.NewTimer(hc.StartPeriod + time.Duration(hc.Retries+1)*hc.Interval)
		defer budget.Stop()
		budgetExpired = budget.C
	}

	began := time.Now()
	failures := 0

	// countFailure registers a failed attempt unless the start period is still
	// running, and reports whether the retry budget is now used up.
	countFailure := func() bool {
		if time.Since(began) < hc.StartPeriod {
			return false
		}
		failures++
		return hc.Retries > 0 && failures >= hc.Retries
	}

	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-budgetExpired:
			return fmt.Errorf("healthcheck timed out: exceeded %d attempts over %s", failures, time.Since(began))
		case <-tick.C:
			// Fall through to run one attempt.
		}

		attemptCtx, release := context.WithTimeout(ctx, hc.Timeout)
		resp, err := agent.Exec(attemptCtx, connect.NewRequest(&pb.ExecRequest{
			SandboxId:  sandboxIDStr,
			Cmd:        "/bin/sh",
			Args:       []string{"-c", probe},
			TimeoutSec: int32(hc.Timeout.Seconds()),
		}))
		release()

		if err != nil {
			slog.Debug("healthcheck exec error (retrying)", "error", err)
			if countFailure() {
				return fmt.Errorf("healthcheck failed after %d retries: exec error: %w", failures, err)
			}
			continue
		}

		exitCode := resp.Msg.ExitCode
		if exitCode == 0 {
			return nil
		}

		slog.Debug("healthcheck failed (retrying)", "exit_code", exitCode)
		if countFailure() {
			return fmt.Errorf("healthcheck failed after %d retries: exit code %d", failures, exitCode)
		}
	}
}
// updateLogs persists the current step number and the JSON-encoded log entries
// on the build row. It is best effort: encoding and database errors are logged
// as warnings and otherwise ignored.
func (s *BuildService) updateLogs(ctx context.Context, buildID pgtype.UUID, step int, logs []recipe.BuildLogEntry) {
	encoded, err := json.Marshal(logs)
	if err != nil {
		slog.Warn("failed to marshal build logs", "error", err)
		return
	}

	progress := db.UpdateBuildProgressParams{
		ID:          buildID,
		CurrentStep: int32(step),
		Logs:        encoded,
	}
	err = s.DB.UpdateBuildProgress(ctx, progress)
	if err != nil {
		slog.Warn("failed to update build progress", "error", err)
	}
}
// failBuild logs errMsg and stores it on the build row via UpdateBuildError.
// The context argument is intentionally unused: the write runs on a fresh
// 10-second context so it still goes through when the caller's context has
// already been cancelled (e.g. during shutdown).
func (s *BuildService) failBuild(_ context.Context, buildID pgtype.UUID, errMsg string) {
	formattedID := id.FormatBuildID(buildID)
	slog.Error("build failed", "build_id", formattedID, "error", errMsg)

	writeCtx, release := context.WithTimeout(context.Background(), 10*time.Second)
	defer release()

	update := db.UpdateBuildErrorParams{
		ID:    buildID,
		Error: errMsg,
	}
	err := s.DB.UpdateBuildError(writeCtx, update)
	if err == nil {
		return
	}
	slog.Error("failed to update build error", "build_id", id.FormatBuildID(buildID), "error", err)
}
// destroySandbox asks the agent to tear down the build sandbox. The context
// argument is intentionally unused: the request runs on a fresh 30-second
// context so cleanup still happens when the caller's context is already
// cancelled. A failure is only logged.
func (s *BuildService) destroySandbox(_ context.Context, agent buildAgentClient, sandboxIDStr string) {
	cleanupCtx, release := context.WithTimeout(context.Background(), 30*time.Second)
	defer release()

	req := connect.NewRequest(&pb.DestroySandboxRequest{
		SandboxId: sandboxIDStr,
	})
	_, err := agent.DestroySandbox(cleanupCtx, req)
	if err != nil {
		slog.Warn("failed to destroy build sandbox", "sandbox_id", sandboxIDStr, "error", err)
	}
}
// fetchSandboxEnv runs `env` through /bin/sh inside the sandbox (10-second
// timeout) and returns the parsed variables. It returns an error when the
// exec call fails or the command exits with a non-zero code.
func (s *BuildService) fetchSandboxEnv(ctx context.Context,
	agent buildAgentClient, sandboxIDStr string) (map[string]string, error) {
	req := connect.NewRequest(&pb.ExecRequest{
		SandboxId:  sandboxIDStr,
		Cmd:        "/bin/sh",
		Args:       []string{"-c", "env"},
		TimeoutSec: 10,
	})

	resp, err := agent.Exec(ctx, req)
	if err != nil {
		return nil, fmt.Errorf("fetch env: %w", err)
	}

	if code := resp.Msg.ExitCode; code != 0 {
		return nil, fmt.Errorf("fetch env: command exited with code %d",
			code)
	}

	output := string(resp.Msg.Stdout)
	return parseSandboxEnv(output), nil
}
// parseSandboxEnv converts the raw newline-separated output of an 'env'
// command into a map.
//
// Each line is trimmed of surrounding whitespace and split at the first '=',
// so values may themselves contain '='. Blank lines, lines without '=', and
// lines with an empty variable name (for example the tail of a multi-line
// value that happens to start with '=') are skipped. The returned map is never
// nil.
func parseSandboxEnv(raw string) map[string]string {
	envVars := make(map[string]string)
	// Walk the output one line at a time without building a slice of all lines.
	for rest := raw; rest != ""; {
		var line string
		line, rest, _ = strings.Cut(rest, "\n")

		key, value, found := strings.Cut(strings.TrimSpace(line), "=")
		if !found || key == "" {
			continue
		}
		envVars[key] = value
	}
	return envVars
}
// uploadAndExtractArchive copies the build archive into the sandbox and
// unpacks it into /tmp/build-files, then makes the extracted tree readable for
// all users.
//
// The format is sniffed from the leading bytes, checked in this order: gzip
// (1f 8b), zip ("PK\x03\x04"), tar ("ustar" at offset 257). Anything else is
// treated as tar.gz. It returns an error when the upload fails, the exec call
// fails, or the extraction command exits non-zero (stderr is included).
func (s *BuildService) uploadAndExtractArchive(
	ctx context.Context,
	agent buildAgentClient,
	sandboxID string,
	archive []byte,
	buildID string,
) error {
	const extractDir = "/tmp/build-files"

	isGzip := len(archive) >= 2 && archive[0] == 0x1f && archive[1] == 0x8b
	isZip := len(archive) >= 4 && string(archive[:4]) == "PK\x03\x04"
	isTar := len(archive) >= 262 && string(archive[257:262]) == "ustar"

	// Start from the tar.gz plan; it serves both gzip input and the fallback.
	archivePath := extractDir + ".tar.gz"
	unpack := "tar xzf " + archivePath + " -C " + extractDir
	switch {
	case isGzip:
		// Keep the tar.gz plan.
	case isZip:
		archivePath = extractDir + ".zip"
		unpack = "unzip -o " + archivePath + " -d " + extractDir
	case isTar:
		archivePath = extractDir + ".tar"
		unpack = "tar xf " + archivePath + " -C " + extractDir
	}

	slog.Info("uploading build archive", "build_id", buildID, "path", archivePath, "size", len(archive))

	// Write archive to VM.
	upload := connect.NewRequest(&pb.WriteFileRequest{
		SandboxId: sandboxID,
		Path:      archivePath,
		Content:   archive,
	})
	if _, err := agent.WriteFile(ctx, upload); err != nil {
		return fmt.Errorf("write archive: %w", err)
	}

	// Create the target directory, extract, and ensure files are readable.
	script := "mkdir -p " + extractDir + " && " + unpack + " && chmod -R a+rX " + extractDir
	resp, err := agent.Exec(ctx, connect.NewRequest(&pb.ExecRequest{
		SandboxId:  sandboxID,
		Cmd:        "/bin/sh",
		Args:       []string{"-c", script},
		TimeoutSec: 120,
	}))
	if err != nil {
		return fmt.Errorf("extract archive: %w", err)
	}
	if code := resp.Msg.ExitCode; code != 0 {
		return fmt.Errorf("extract archive: exit code %d: %s", code, string(resp.Msg.Stderr))
	}
	return nil
}
// runtimeEnvVars is the set of environment variable names that are specific to
// a user or session and therefore must not be stored as template defaults;
// they are resolved at runtime by envd from the actual user and sandbox.
var runtimeEnvVars = map[string]bool{
	"HOME":     true,
	"USER":     true,
	"LOGNAME":  true,
	"SHELL":    true,
	"PWD":      true,
	"OLDPWD":   true,
	"HOSTNAME": true,
	"TERM":     true,
	"SHLVL":    true,
	"_":        true,
	// Per-sandbox identifiers set by envd at boot via MMDS.
	"WRENN_SANDBOX_ID":  true,
	"WRENN_TEMPLATE_ID": true,
}

// filterBuildEnv builds a new map holding every entry of envVars whose name is
// not listed in runtimeEnvVars, so build-time values do not override envd's
// per-user resolution. The input map is left untouched and the result is never
// nil.
func filterBuildEnv(envVars map[string]string) map[string]string {
	kept := make(map[string]string, len(envVars))
	for name, value := range envVars {
		if !runtimeEnvVars[name] {
			kept[name] = value
		}
	}
	return kept
}

628
pkg/service/host.go Normal file
View File

@ -0,0 +1,628 @@
package service
import (
"context"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"log/slog"
"time"
"connectrpc.com/connect"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/redis/go-redis/v9"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
)
// HostService provides host management operations.
type HostService struct {
DB *db.Queries
Redis *redis.Client
// NOTE(review): presumably the secret used to sign host JWTs — its use is
// outside this view; confirm.
JWT []byte
Pool *lifecycle.HostClientPool
CA *auth.CA // nil disables mTLS cert issuance (dev/test environments)
}
// HostCreateParams holds the parameters for creating a host.
type HostCreateParams struct {
Type string // must be "regular" or "byoc"; anything else is rejected by Create
TeamID pgtype.UUID // required for BYOC, zero value for regular
Provider string // stored on the host row as given
AvailabilityZone string // stored on the host row as given
RequestingUserID pgtype.UUID // recorded as the host's creator; also used for the team membership check
// IsRequestorAdmin is required to create regular hosts. For BYOC hosts it
// bypasses the team membership and role check.
IsRequestorAdmin bool
}
// HostCreateResult holds the created host and the one-time registration token.
type HostCreateResult struct {
// Host is the newly inserted host row.
Host db.Host
// RegistrationToken is the one-time token generated for this host.
RegistrationToken string
}
// HostRegisterParams holds the parameters for host agent registration.
// NOTE(review): the code consuming these fields is outside this view; the
// field notes below are read from the names and should be confirmed.
type HostRegisterParams struct {
Token string // presumably the one-time registration token issued by Create
Arch string
CPUCores int32
MemoryMB int32 // presumably megabytes, per the name
DiskGB int32 // presumably gigabytes, per the name
Address string
}
// HostRegisterResult holds the registered host, its short-lived JWT, a long-lived
// refresh token, and optionally the host's mTLS certificate material.
// It has the same fields as HostRefreshResult.
type HostRegisterResult struct {
Host db.Host
JWT string
RefreshToken string
// mTLS cert material — empty when CA is not configured.
CertPEM string
KeyPEM string
CACertPEM string
}
// HostRefreshResult holds a new JWT and rotated refresh token after a successful
// refresh, plus refreshed mTLS certificate material when CA is configured.
// It has the same fields as HostRegisterResult.
type HostRefreshResult struct {
Host db.Host
JWT string
RefreshToken string
// mTLS cert material — empty when CA is not configured.
CertPEM string
KeyPEM string
CACertPEM string
}
// HostDeletePreview describes what will be affected by deleting a host.
type HostDeletePreview struct {
// Host is the host that would be deleted.
Host db.Host
// SandboxIDs lists affected sandboxes as strings.
// NOTE(review): presumably the sandboxes currently placed on the host —
// the code that fills this is outside this view; confirm.
SandboxIDs []string
}
// regTokenPayload is the JSON stored in Redis for registration tokens.
// Both identifiers are kept as strings and serialized under snake_case keys.
type regTokenPayload struct {
HostID string `json:"host_id"`
TokenID string `json:"token_id"`
}
// regTokenTTL is one hour.
// NOTE(review): presumably the expiry applied to registration tokens stored in
// Redis — its use is outside this view; confirm.
const regTokenTTL = time.Hour
// requireAdminOrOwner checks a team role for BYOC host management rights.
// It returns nil for exactly "owner" and "admin" (case-sensitive) and a
// "forbidden" error for every other value, including the empty string.
func requireAdminOrOwner(role string) error {
	switch role {
	case "owner", "admin":
		return nil
	default:
		return fmt.Errorf("forbidden: only team owners and admins can manage BYOC hosts")
	}
}
// Create creates a new host record and generates a one-time registration token.
//
// Authorization rules:
//   - "regular" hosts can only be created by platform admins.
//   - "byoc" hosts need a valid team ID and can be created by platform admins
//     or by an owner/admin of that team.
//
// Whenever a team ID is supplied, the team must exist, must not be deleted, and
// must have BYOC enabled. If token issuance fails after the host row has been
// inserted, the row is left in place; RegenerateToken can issue a replacement
// token while the host is still pending.
func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreateResult, error) {
	if p.Type != "regular" && p.Type != "byoc" {
		return HostCreateResult{}, fmt.Errorf("invalid host type: must be 'regular' or 'byoc'")
	}

	if p.Type == "regular" {
		if !p.IsRequestorAdmin {
			return HostCreateResult{}, fmt.Errorf("forbidden: only admins can create regular hosts")
		}
	} else {
		// BYOC: platform admin, or team owner/admin.
		if !p.TeamID.Valid {
			return HostCreateResult{}, fmt.Errorf("invalid request: team_id is required for BYOC hosts")
		}
		if !p.IsRequestorAdmin {
			if err := s.requireTeamAdminOrOwner(ctx, p.RequestingUserID, p.TeamID); err != nil {
				return HostCreateResult{}, err
			}
		}
	}

	// Validate team exists, is not deleted, and has BYOC enabled.
	if p.TeamID.Valid {
		team, err := s.DB.GetTeam(ctx, p.TeamID)
		if err != nil || team.DeletedAt.Valid {
			return HostCreateResult{}, fmt.Errorf("invalid request: team not found")
		}
		if !team.IsByoc {
			return HostCreateResult{}, fmt.Errorf("forbidden: BYOC is not enabled for this team")
		}
	}

	hostID := id.NewHostID()
	host, err := s.DB.InsertHost(ctx, db.InsertHostParams{
		ID:               hostID,
		Type:             p.Type,
		TeamID:           p.TeamID,
		Provider:         p.Provider,
		AvailabilityZone: p.AvailabilityZone,
		CreatedBy:        p.RequestingUserID,
	})
	if err != nil {
		return HostCreateResult{}, fmt.Errorf("insert host: %w", err)
	}

	token, err := s.issueRegistrationToken(ctx, hostID, p.RequestingUserID)
	if err != nil {
		return HostCreateResult{}, err
	}
	return HostCreateResult{Host: host, RegistrationToken: token}, nil
}

// RegenerateToken issues a new registration token for a host still in "pending"
// status. This allows retry when a previous registration attempt failed after
// the original token was consumed.
//
// Platform admins may regenerate tokens for any pending host. Other callers
// may only do so for BYOC hosts that belong to teamID, and only when they hold
// the owner or admin role in that team.
func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamID pgtype.UUID, isAdmin bool) (HostCreateResult, error) {
	host, err := s.DB.GetHost(ctx, hostID)
	if err != nil {
		return HostCreateResult{}, fmt.Errorf("host not found: %w", err)
	}
	if host.Status != "pending" {
		return HostCreateResult{}, fmt.Errorf("invalid state: can only regenerate token for pending hosts (status: %s)", host.Status)
	}

	if !isAdmin {
		if host.Type != "byoc" {
			return HostCreateResult{}, fmt.Errorf("forbidden: only admins can manage regular hosts")
		}
		if !host.TeamID.Valid || host.TeamID != teamID {
			return HostCreateResult{}, fmt.Errorf("forbidden: host does not belong to your team")
		}
		if err := s.requireTeamAdminOrOwner(ctx, userID, teamID); err != nil {
			return HostCreateResult{}, err
		}
	}

	token, err := s.issueRegistrationToken(ctx, hostID, userID)
	if err != nil {
		return HostCreateResult{}, err
	}
	return HostCreateResult{Host: host, RegistrationToken: token}, nil
}

// requireTeamAdminOrOwner verifies that userID is a member of teamID with the
// owner or admin role. A missing membership row and an insufficient role both
// yield a "forbidden" error; any other lookup failure is wrapped and returned.
func (s *HostService) requireTeamAdminOrOwner(ctx context.Context, userID, teamID pgtype.UUID) error {
	membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
		UserID: userID,
		TeamID: teamID,
	})
	if errors.Is(err, pgx.ErrNoRows) {
		return fmt.Errorf("forbidden: not a member of the specified team")
	}
	if err != nil {
		return fmt.Errorf("check team membership: %w", err)
	}
	return requireAdminOrOwner(membership.Role)
}

// issueRegistrationToken generates a one-time registration token for hostID,
// stores its payload in Redis for regTokenTTL, and writes a matching audit row
// to Postgres attributed to createdBy. Redis is the source of truth for token
// validity, so a failed audit insert is only logged and does not fail the call.
func (s *HostService) issueRegistrationToken(ctx context.Context, hostID, createdBy pgtype.UUID) (string, error) {
	token := id.NewRegistrationToken()
	tokenID := id.NewHostTokenID()

	payload, err := json.Marshal(regTokenPayload{
		HostID:  id.FormatHostID(hostID),
		TokenID: id.FormatHostTokenID(tokenID),
	})
	if err != nil {
		return "", fmt.Errorf("encode registration token: %w", err)
	}
	if err := s.Redis.Set(ctx, "host:reg:"+token, payload, regTokenTTL).Err(); err != nil {
		return "", fmt.Errorf("store registration token: %w", err)
	}

	now := time.Now()
	if _, err := s.DB.InsertHostToken(ctx, db.InsertHostTokenParams{
		ID:        tokenID,
		HostID:    hostID,
		CreatedBy: createdBy,
		ExpiresAt: pgtype.Timestamptz{Time: now.Add(regTokenTTL), Valid: true},
	}); err != nil {
		slog.Warn("failed to insert host token audit record", "host_id", id.FormatHostID(hostID), "error", err)
	}
	return token, nil
}
// Register validates a one-time registration token, updates the host with
// machine specs, and returns a short-lived host JWT plus a long-lived refresh token.
//
// The token is consumed as soon as it is read, so any later failure in this
// function leaves the caller needing a fresh token (see RegenerateToken).
// When a CA is configured, an mTLS certificate is issued for the host and its
// fingerprint and expiry are stored alongside the machine specs.
func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostRegisterResult, error) {
	// Atomic consume: GetDel returns the value and deletes in one operation,
	// preventing concurrent requests from consuming the same token.
	raw, err := s.Redis.GetDel(ctx, "host:reg:"+p.Token).Bytes()
	// errors.Is keeps the "missing key" detection working even if the client
	// error is wrapped, and matches how pgx.ErrNoRows is checked in this file.
	if errors.Is(err, redis.Nil) {
		return HostRegisterResult{}, fmt.Errorf("invalid or expired registration token")
	}
	if err != nil {
		return HostRegisterResult{}, fmt.Errorf("token lookup: %w", err)
	}

	var payload regTokenPayload
	if err := json.Unmarshal(raw, &payload); err != nil {
		return HostRegisterResult{}, fmt.Errorf("corrupted registration token: %w", err)
	}

	hostID, err := id.ParseHostID(payload.HostID)
	if err != nil {
		return HostRegisterResult{}, fmt.Errorf("corrupted registration token: %w", err)
	}
	tokenID, err := id.ParseHostTokenID(payload.TokenID)
	if err != nil {
		return HostRegisterResult{}, fmt.Errorf("corrupted registration token: %w", err)
	}

	if _, err := s.DB.GetHost(ctx, hostID); err != nil {
		return HostRegisterResult{}, fmt.Errorf("host not found: %w", err)
	}

	// Sign JWT before mutating DB — if signing fails, the host stays pending.
	hostJWT, err := auth.SignHostJWT(s.JWT, hostID)
	if err != nil {
		return HostRegisterResult{}, fmt.Errorf("sign host token: %w", err)
	}

	// Issue mTLS certificate if CA is configured.
	var hc auth.HostCert
	if s.CA != nil {
		hc, err = auth.IssueHostCert(s.CA, id.FormatHostID(hostID), p.Address)
		if err != nil {
			return HostRegisterResult{}, fmt.Errorf("issue host cert: %w", err)
		}
	}

	// Atomically update only if still pending (defense-in-depth against races).
	rowsAffected, err := s.DB.RegisterHost(ctx, db.RegisterHostParams{
		ID:              hostID,
		Arch:            p.Arch,
		CpuCores:        p.CPUCores,
		MemoryMb:        p.MemoryMB,
		DiskGb:          p.DiskGB,
		Address:         p.Address,
		CertFingerprint: hc.Fingerprint,
		// Without a CA there is no certificate, so the expiry is stored as NULL.
		CertExpiresAt: pgtype.Timestamptz{Time: hc.ExpiresAt, Valid: s.CA != nil},
	})
	if err != nil {
		return HostRegisterResult{}, fmt.Errorf("register host: %w", err)
	}
	if rowsAffected == 0 {
		return HostRegisterResult{}, fmt.Errorf("host already registered or not found")
	}

	// Mark audit trail (best-effort; registration has already succeeded).
	if err := s.DB.MarkHostTokenUsed(ctx, tokenID); err != nil {
		slog.Warn("failed to mark host token used", "token_id", payload.TokenID, "error", err)
	}

	// Issue a long-lived refresh token.
	refreshToken, err := s.issueRefreshToken(ctx, hostID)
	if err != nil {
		return HostRegisterResult{}, fmt.Errorf("issue refresh token: %w", err)
	}

	// Re-fetch the host to get the updated state.
	host, err := s.DB.GetHost(ctx, hostID)
	if err != nil {
		return HostRegisterResult{}, fmt.Errorf("fetch updated host: %w", err)
	}

	result := HostRegisterResult{Host: host, JWT: hostJWT, RefreshToken: refreshToken}
	if s.CA != nil {
		result.CertPEM = hc.CertPEM
		result.KeyPEM = hc.KeyPEM
		result.CACertPEM = s.CA.PEM
	}
	return result, nil
}
// Refresh exchanges a valid refresh token for a new host JWT and a rotated
// refresh token. When a CA is configured it also re-issues the host's mTLS
// certificate and records the new fingerprint and expiry.
//
// The replacement refresh token is persisted before the presented one is
// revoked, so an interruption between the two writes leaves the host with two
// usable tokens rather than none.
func (s *HostService) Refresh(ctx context.Context, refreshToken string) (HostRefreshResult, error) {
	var none HostRefreshResult

	current, err := s.DB.GetHostRefreshTokenByHash(ctx, hashToken(refreshToken))
	switch {
	case errors.Is(err, pgx.ErrNoRows):
		return none, fmt.Errorf("invalid or expired refresh token")
	case err != nil:
		return none, fmt.Errorf("lookup refresh token: %w", err)
	}

	host, err := s.DB.GetHost(ctx, current.HostID)
	if err != nil {
		return none, fmt.Errorf("host not found: %w", err)
	}

	// Fresh short-lived JWT for the host.
	signed, err := auth.SignHostJWT(s.JWT, host.ID)
	if err != nil {
		return none, fmt.Errorf("sign host JWT: %w", err)
	}

	// Certificate renewal only happens when a CA is available.
	caEnabled := s.CA != nil
	var cert auth.HostCert
	if caEnabled {
		cert, err = auth.IssueHostCert(s.CA, id.FormatHostID(host.ID), host.Address)
		if err != nil {
			return none, fmt.Errorf("renew host cert: %w", err)
		}
		certUpdate := db.UpdateHostCertParams{
			ID:              host.ID,
			CertFingerprint: cert.Fingerprint,
			CertExpiresAt:   pgtype.Timestamptz{Time: cert.ExpiresAt, Valid: true},
		}
		if err := s.DB.UpdateHostCert(ctx, certUpdate); err != nil {
			return none, fmt.Errorf("update host cert: %w", err)
		}
	}

	// Rotation step 1: persist the replacement token.
	rotated, err := s.issueRefreshToken(ctx, host.ID)
	if err != nil {
		return none, fmt.Errorf("issue new refresh token: %w", err)
	}

	// Rotation step 2: retire the token that was just presented.
	if err := s.DB.RevokeHostRefreshToken(ctx, current.ID); err != nil {
		return none, fmt.Errorf("revoke old refresh token: %w", err)
	}

	out := HostRefreshResult{Host: host, JWT: signed, RefreshToken: rotated}
	if caEnabled {
		out.CertPEM = cert.CertPEM
		out.KeyPEM = cert.KeyPEM
		out.CACertPEM = s.CA.PEM
	}
	return out, nil
}
// issueRefreshToken mints a new opaque refresh token for hostID, stores its
// SHA-256 hash with an expiry of auth.HostRefreshTokenExpiry from now, and
// returns the plaintext token. The plaintext itself is never written to the DB.
func (s *HostService) issueRefreshToken(ctx context.Context, hostID pgtype.UUID) (string, error) {
	plaintext := id.NewRefreshToken()
	digest := hashToken(plaintext)
	expiry := time.Now().Add(auth.HostRefreshTokenExpiry)

	record := db.InsertHostRefreshTokenParams{
		ID:        id.NewRefreshTokenID(),
		HostID:    hostID,
		TokenHash: digest,
		ExpiresAt: pgtype.Timestamptz{Time: expiry, Valid: true},
	}
	_, err := s.DB.InsertHostRefreshToken(ctx, record)
	if err != nil {
		return "", fmt.Errorf("insert refresh token: %w", err)
	}
	return plaintext, nil
}
// hashToken computes the SHA-256 digest of token and returns it as a
// lowercase hexadecimal string (64 characters).
func hashToken(token string) string {
	digest := sha256.Sum256([]byte(token))
	return fmt.Sprintf("%x", digest[:])
}
// Heartbeat records a heartbeat for hostID via UpdateHostHeartbeatAndStatus,
// which refreshes the last-heartbeat timestamp and moves an 'unreachable'
// host back to 'online'. A query error is returned unchanged. If no row was
// updated, a "host not found" error is returned (surfaced as 404 by callers),
// which happens when the host record no longer exists.
func (s *HostService) Heartbeat(ctx context.Context, hostID pgtype.UUID) error {
	updated, err := s.DB.UpdateHostHeartbeatAndStatus(ctx, hostID)
	switch {
	case err != nil:
		return err
	case updated == 0:
		return fmt.Errorf("host not found")
	}
	return nil
}
// List returns the hosts the caller is allowed to see: every host for
// platform admins, and only the hosts of teamID for everyone else.
func (s *HostService) List(ctx context.Context, teamID pgtype.UUID, isAdmin bool) ([]db.Host, error) {
	if !isAdmin {
		return s.DB.ListHostsByTeam(ctx, teamID)
	}
	return s.DB.ListHosts(ctx)
}
// Get fetches one host and applies access control. Platform admins can read
// any host. Other callers only see hosts whose team ID is set and equals
// teamID; anything else is reported as "host not found" so that the existence
// of foreign hosts is not revealed.
func (s *HostService) Get(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool) (db.Host, error) {
	host, err := s.DB.GetHost(ctx, hostID)
	if err != nil {
		return db.Host{}, fmt.Errorf("host not found: %w", err)
	}
	if isAdmin {
		return host, nil
	}

	ownedByTeam := host.TeamID.Valid && host.TeamID == teamID
	if !ownedByTeam {
		return db.Host{}, fmt.Errorf("host not found")
	}
	return host, nil
}
// DeletePreview reports what deleting the host would affect without changing
// anything, so the caller can show a confirmation prompt. It runs the same
// permission check as Delete but passes a zero user ID, which skips the
// per-user team role lookup. The preview lists sandboxes on the host whose
// status is pending, starting, running, or missing.
func (s *HostService) DeletePreview(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool) (HostDeletePreview, error) {
	var anonymous pgtype.UUID // zero value: no user-level role check
	host, err := s.checkDeletePermission(ctx, hostID, anonymous, teamID, isAdmin)
	if err != nil {
		return HostDeletePreview{}, err
	}

	query := db.ListSandboxesByHostAndStatusParams{
		HostID:  hostID,
		Column2: []string{"pending", "starting", "running", "missing"},
	}
	active, err := s.DB.ListSandboxesByHostAndStatus(ctx, query)
	if err != nil {
		return HostDeletePreview{}, fmt.Errorf("list sandboxes: %w", err)
	}

	sandboxIDs := make([]string, 0, len(active))
	for _, sb := range active {
		sandboxIDs = append(sandboxIDs, id.FormatSandboxID(sb.ID))
	}
	return HostDeletePreview{Host: host, SandboxIDs: sandboxIDs}, nil
}
// Delete removes a host. Without force it returns a *HostHasSandboxesError
// listing active sandboxes so the caller can present a confirmation. With
// force it gracefully destroys all running sandboxes before deleting the host
// record.
//
// Everything between the permission check and the final DeleteHost call is
// best-effort: agent RPC failures, sandbox status updates, and refresh-token
// revocation are logged but do not abort the deletion. Agent calls are skipped
// when no client pool is configured or the host has no address.
func (s *HostService) Delete(ctx context.Context, hostID, userID, teamID pgtype.UUID, isAdmin, force bool) error {
	host, err := s.checkDeletePermission(ctx, hostID, userID, teamID, isAdmin)
	if err != nil {
		return err
	}

	sandboxes, err := s.DB.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
		HostID:  hostID,
		Column2: []string{"pending", "starting", "running", "missing"},
	})
	if err != nil {
		return fmt.Errorf("list sandboxes: %w", err)
	}

	if len(sandboxes) > 0 && !force {
		ids := make([]string, len(sandboxes))
		for i, sb := range sandboxes {
			ids[i] = id.FormatSandboxID(sb.ID)
		}
		return &HostHasSandboxesError{SandboxIDs: ids}
	}

	hostIDStr := id.FormatHostID(hostID)

	// Gracefully destroy running sandboxes and terminate the agent (best-effort).
	// The pool is nil-checked here for the same reason it is before Evict below:
	// a service without a pool must still be able to delete host records.
	if s.Pool != nil && host.Address != "" {
		agent, err := s.Pool.GetForHost(host)
		if err != nil {
			slog.Warn("delete host: failed to get agent client", "host_id", hostIDStr, "error", err)
		} else {
			for _, sb := range sandboxes {
				// Only sandboxes that may have a live VM are destroyed on the agent.
				if sb.Status == "running" || sb.Status == "starting" {
					_, rpcErr := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
						SandboxId: id.FormatSandboxID(sb.ID),
					}))
					// Not-found means the sandbox is already gone, which is the goal.
					if rpcErr != nil && connect.CodeOf(rpcErr) != connect.CodeNotFound {
						slog.Warn("delete host: failed to destroy sandbox on agent", "sandbox_id", id.FormatSandboxID(sb.ID), "error", rpcErr)
					}
				}
			}
			// Tell the agent to shut itself down immediately.
			if _, rpcErr := agent.Terminate(ctx, connect.NewRequest(&pb.TerminateRequest{})); rpcErr != nil {
				slog.Warn("delete host: failed to send Terminate to agent", "host_id", hostIDStr, "error", rpcErr)
			}
		}
	}

	// Mark all affected sandboxes as stopped in DB.
	if len(sandboxes) > 0 {
		sbIDs := make([]pgtype.UUID, len(sandboxes))
		for i, sb := range sandboxes {
			sbIDs[i] = sb.ID
		}
		if err := s.DB.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
			Column1: sbIDs,
			Status:  "stopped",
		}); err != nil {
			slog.Warn("delete host: failed to mark sandboxes stopped", "host_id", hostIDStr, "error", err)
		}
	}

	// Revoke all refresh tokens for this host.
	if err := s.DB.RevokeHostRefreshTokensByHost(ctx, hostID); err != nil {
		slog.Warn("delete host: failed to revoke refresh tokens", "host_id", hostIDStr, "error", err)
	}

	// Evict the client from the pool so no further RPCs are sent.
	if s.Pool != nil {
		s.Pool.Evict(hostIDStr)
	}

	return s.DB.DeleteHost(ctx, hostID)
}
// checkDeletePermission loads the host and decides whether the caller may
// delete it, returning the host record on success.
//
// Platform admins are always allowed. Everyone else may only act on BYOC hosts
// that belong to teamID. When userID is valid, the user must additionally be an
// owner or admin of that team; when userID is the zero value the role lookup is
// skipped.
func (s *HostService) checkDeletePermission(ctx context.Context, hostID, userID, teamID pgtype.UUID, isAdmin bool) (db.Host, error) {
	var denied db.Host

	host, err := s.DB.GetHost(ctx, hostID)
	if err != nil {
		return denied, fmt.Errorf("host not found: %w", err)
	}

	switch {
	case isAdmin:
		return host, nil
	case host.Type != "byoc":
		return denied, fmt.Errorf("forbidden: only admins can delete regular hosts")
	case !host.TeamID.Valid || host.TeamID != teamID:
		return denied, fmt.Errorf("forbidden: host does not belong to your team")
	case !userID.Valid:
		// No user to check a role for; team ownership is sufficient.
		return host, nil
	}

	membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
		UserID: userID,
		TeamID: teamID,
	})
	switch {
	case errors.Is(err, pgx.ErrNoRows):
		return denied, fmt.Errorf("forbidden: not a member of the specified team")
	case err != nil:
		return denied, fmt.Errorf("check team membership: %w", err)
	}

	if roleErr := requireAdminOrOwner(membership.Role); roleErr != nil {
		return denied, roleErr
	}
	return host, nil
}
// AddTag attaches tag to the host. The caller must be able to read the host
// (see Get); otherwise the access error from Get is returned unchanged.
func (s *HostService) AddTag(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool, tag string) error {
	_, err := s.Get(ctx, hostID, teamID, isAdmin)
	if err != nil {
		return err
	}
	params := db.AddHostTagParams{HostID: hostID, Tag: tag}
	return s.DB.AddHostTag(ctx, params)
}
// RemoveTag detaches tag from the host. The caller must be able to read the
// host (see Get); otherwise the access error from Get is returned unchanged.
func (s *HostService) RemoveTag(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool, tag string) error {
	_, err := s.Get(ctx, hostID, teamID, isAdmin)
	if err != nil {
		return err
	}
	params := db.RemoveHostTagParams{HostID: hostID, Tag: tag}
	return s.DB.RemoveHostTag(ctx, params)
}
// ListTags returns every tag on the host. The caller must be able to read the
// host (see Get); otherwise a nil slice and the access error are returned.
func (s *HostService) ListTags(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool) ([]string, error) {
	_, err := s.Get(ctx, hostID, teamID, isAdmin)
	if err != nil {
		return nil, err
	}
	return s.DB.GetHostTags(ctx, hostID)
}
// HostHasSandboxesError is the error Delete returns when the host still has
// active sandboxes and force was not requested. Callers are expected to show
// SandboxIDs to the user and call Delete again with force=true on confirmation.
type HostHasSandboxesError struct {
	SandboxIDs []string
}

// Error implements the error interface, reporting how many sandboxes are
// active followed by their IDs.
func (e *HostHasSandboxesError) Error() string {
	count := len(e.SandboxIDs)
	return fmt.Sprintf("host has %d active sandbox(es): %v", count, e.SandboxIDs)
}

451
pkg/service/sandbox.go Normal file
View File

@ -0,0 +1,451 @@
package service
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"time"
"connectrpc.com/connect"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
"git.omukk.dev/wrenn/wrenn/pkg/scheduler"
"git.omukk.dev/wrenn/wrenn/pkg/validate"
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
)
// SandboxService provides sandbox lifecycle operations shared between the
// REST API and the dashboard.
type SandboxService struct {
// DB is the query layer for sandbox, host, team, template, and metric records.
DB *db.Queries
// Pool supplies the host agent client for the host a sandbox runs on.
Pool *lifecycle.HostClientPool
// Scheduler picks the host for a new sandbox (see Create).
Scheduler scheduler.HostScheduler
}
// SandboxCreateParams holds the parameters for creating a sandbox.
// Create fills in defaults for Template and for non-positive VCPUs, MemoryMB,
// and DiskSizeMB.
type SandboxCreateParams struct {
// TeamID is the owning team; Create rejects an invalid (unset) value.
TeamID pgtype.UUID
// Template is the template name; empty means "minimal".
Template string
// VCPUs defaults to 1 and MemoryMB to 512 when <= 0. Both are overridden by
// the template's own values when the template is a snapshot.
VCPUs int32
MemoryMB int32
// TimeoutSec is stored and forwarded to the host agent unchanged.
TimeoutSec int32
// DiskSizeMB defaults to 5120 when <= 0.
DiskSizeMB int32
}
// agentForSandbox resolves the host agent client responsible for a sandbox.
// It loads the sandbox, then the host it is placed on, then asks the pool for
// a client to that host. On success it returns the client together with the
// sandbox record; on any failure it returns a nil client and a zero sandbox.
func (s *SandboxService) agentForSandbox(ctx context.Context, sandboxID pgtype.UUID) (hostagentClient, db.Sandbox, error) {
	var none db.Sandbox

	sandbox, err := s.DB.GetSandbox(ctx, sandboxID)
	if err != nil {
		return nil, none, fmt.Errorf("sandbox not found: %w", err)
	}

	owner, err := s.DB.GetHost(ctx, sandbox.HostID)
	if err != nil {
		return nil, none, fmt.Errorf("host not found for sandbox: %w", err)
	}

	client, err := s.Pool.GetForHost(owner)
	if err != nil {
		return nil, none, fmt.Errorf("get agent client: %w", err)
	}
	return client, sandbox, nil
}
// hostagentClient is a local alias to avoid the full package path in signatures.
// It names the host agent RPCs for sandbox lifecycle (create, destroy, pause,
// resume), keep-alive ping, and metrics (get and flush).
type hostagentClient = interface {
CreateSandbox(ctx context.Context, req *connect.Request[pb.CreateSandboxRequest]) (*connect.Response[pb.CreateSandboxResponse], error)
DestroySandbox(ctx context.Context, req *connect.Request[pb.DestroySandboxRequest]) (*connect.Response[pb.DestroySandboxResponse], error)
PauseSandbox(ctx context.Context, req *connect.Request[pb.PauseSandboxRequest]) (*connect.Response[pb.PauseSandboxResponse], error)
ResumeSandbox(ctx context.Context, req *connect.Request[pb.ResumeSandboxRequest]) (*connect.Response[pb.ResumeSandboxResponse], error)
PingSandbox(ctx context.Context, req *connect.Request[pb.PingSandboxRequest]) (*connect.Response[pb.PingSandboxResponse], error)
GetSandboxMetrics(ctx context.Context, req *connect.Request[pb.GetSandboxMetricsRequest]) (*connect.Response[pb.GetSandboxMetricsResponse], error)
FlushSandboxMetrics(ctx context.Context, req *connect.Request[pb.FlushSandboxMetricsRequest]) (*connect.Response[pb.FlushSandboxMetricsResponse], error)
}
// Create creates a new sandbox: picks a host via the scheduler, inserts a pending
// DB record, calls the host agent, and updates the record to running.
//
// Defaults: Template "minimal", 1 vCPU, 512 MB memory, 5120 MB disk. For
// snapshot templates the vCPU and memory values baked into the template
// replace the requested ones. If the agent call fails, the sandbox row is
// moved to "error" (best-effort) and the agent error is returned.
func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db.Sandbox, error) {
	if p.Template == "" {
		p.Template = "minimal"
	}
	if err := validate.SafeName(p.Template); err != nil {
		return db.Sandbox{}, fmt.Errorf("invalid template name: %w", err)
	}
	if p.VCPUs <= 0 {
		p.VCPUs = 1
	}
	if p.MemoryMB <= 0 {
		p.MemoryMB = 512
	}
	if p.DiskSizeMB <= 0 {
		p.DiskSizeMB = 5120 // 5 GB default
	}

	// The team ID feeds the template lookup below, so reject a missing one
	// before it can surface as a misleading "template not found" error.
	if !p.TeamID.Valid {
		return db.Sandbox{}, fmt.Errorf("invalid request: team_id is required")
	}

	// Resolve template name → (teamID, templateID).
	templateTeamID := id.PlatformTeamID
	templateID := id.MinimalTemplateID
	var templateDefaultUser string
	var templateDefaultEnv map[string]string
	if p.Template != "minimal" {
		tmpl, err := s.DB.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: p.Template, TeamID: p.TeamID})
		if err != nil {
			return db.Sandbox{}, fmt.Errorf("template %q not found: %w", p.Template, err)
		}
		templateTeamID = tmpl.TeamID
		templateID = tmpl.ID
		templateDefaultUser = tmpl.DefaultUser
		// Parse default_env JSONB into a map. A malformed value must not block
		// sandbox creation, but it should be visible in the logs.
		if len(tmpl.DefaultEnv) > 0 {
			if err := json.Unmarshal(tmpl.DefaultEnv, &templateDefaultEnv); err != nil {
				slog.Warn("failed to parse template default_env", "template", p.Template, "error", err)
			}
		}
		// If the template is a snapshot, use its baked-in vcpus/memory.
		if tmpl.Type == "snapshot" {
			p.VCPUs = tmpl.Vcpus
			p.MemoryMB = tmpl.MemoryMb
		}
	}

	// Determine whether this team uses BYOC hosts or platform hosts.
	team, err := s.DB.GetTeam(ctx, p.TeamID)
	if err != nil {
		return db.Sandbox{}, fmt.Errorf("team not found: %w", err)
	}

	// Pick a host for this sandbox.
	host, err := s.Scheduler.SelectHost(ctx, p.TeamID, team.IsByoc, p.MemoryMB, p.DiskSizeMB)
	if err != nil {
		return db.Sandbox{}, fmt.Errorf("select host: %w", err)
	}

	agent, err := s.Pool.GetForHost(host)
	if err != nil {
		return db.Sandbox{}, fmt.Errorf("get agent client: %w", err)
	}

	sandboxID := id.NewSandboxID()
	sandboxIDStr := id.FormatSandboxID(sandboxID)

	// Record the sandbox as pending before contacting the agent.
	if _, err := s.DB.InsertSandbox(ctx, db.InsertSandboxParams{
		ID:             sandboxID,
		TeamID:         p.TeamID,
		HostID:         host.ID,
		Template:       p.Template,
		Status:         "pending",
		Vcpus:          p.VCPUs,
		MemoryMb:       p.MemoryMB,
		TimeoutSec:     p.TimeoutSec,
		DiskSizeMb:     p.DiskSizeMB,
		TemplateID:     templateID,
		TemplateTeamID: templateTeamID,
		Metadata:       []byte("{}"),
	}); err != nil {
		return db.Sandbox{}, fmt.Errorf("insert sandbox: %w", err)
	}

	// Note: TeamId carries the template's owning team, not the requesting team.
	resp, err := agent.CreateSandbox(ctx, connect.NewRequest(&pb.CreateSandboxRequest{
		SandboxId:   sandboxIDStr,
		Template:    p.Template,
		TeamId:      id.UUIDString(templateTeamID),
		TemplateId:  id.UUIDString(templateID),
		Vcpus:       p.VCPUs,
		MemoryMb:    p.MemoryMB,
		TimeoutSec:  p.TimeoutSec,
		DiskSizeMb:  p.DiskSizeMB,
		DefaultUser: templateDefaultUser,
		DefaultEnv:  templateDefaultEnv,
	}))
	if err != nil {
		if _, dbErr := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
			ID: sandboxID, Status: "error",
		}); dbErr != nil {
			slog.Warn("failed to update sandbox status to error", "id", sandboxIDStr, "error", dbErr)
		}
		return db.Sandbox{}, fmt.Errorf("agent create: %w", err)
	}

	now := time.Now()
	sb, err := s.DB.UpdateSandboxRunning(ctx, db.UpdateSandboxRunningParams{
		ID:      sandboxID,
		HostIp:  resp.Msg.HostIp,
		GuestIp: "",
		StartedAt: pgtype.Timestamptz{
			Time:  now,
			Valid: true,
		},
	})
	if err != nil {
		return db.Sandbox{}, fmt.Errorf("update sandbox running: %w", err)
	}

	// Store runtime metadata from the agent (envd/kernel/firecracker/agent versions).
	// Metadata is informational, so failures here are logged and not returned.
	if meta := resp.Msg.Metadata; len(meta) > 0 {
		metaJSON, err := json.Marshal(meta)
		if err != nil {
			slog.Warn("failed to encode sandbox metadata", "id", sandboxIDStr, "error", err)
		} else {
			if err := s.DB.UpdateSandboxMetadata(ctx, db.UpdateSandboxMetadataParams{
				ID:       sandboxID,
				Metadata: metaJSON,
			}); err != nil {
				slog.Warn("failed to store sandbox metadata", "id", sandboxIDStr, "error", err)
			}
			sb.Metadata = metaJSON
		}
	}

	return sb, nil
}
// List returns the team's active sandboxes; stopped and errored sandboxes are
// excluded by the underlying ListSandboxesByTeam query.
func (s *SandboxService) List(ctx context.Context, teamID pgtype.UUID) ([]db.Sandbox, error) {
	sandboxes, err := s.DB.ListSandboxesByTeam(ctx, teamID)
	return sandboxes, err
}
// Get looks up one sandbox by ID, restricted to sandboxes owned by teamID.
// The query's result and error are returned as-is.
func (s *SandboxService) Get(ctx context.Context, sandboxID, teamID pgtype.UUID) (db.Sandbox, error) {
	lookup := db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID}
	return s.DB.GetSandboxByTeam(ctx, lookup)
}
// Pause snapshots a running sandbox to disk and freezes it.
//
// The sandbox must belong to teamID and be in "running" status. The DB status
// is switched to "paused" before the agent RPC so that the reconciler does not
// flag the sandbox as "stopped" while the pause is in progress; if the RPC
// fails, the status is switched back to "running" (best-effort). All metric
// tiers are flushed to the DB before pausing.
func (s *SandboxService) Pause(ctx context.Context, sandboxID, teamID pgtype.UUID) (db.Sandbox, error) {
	var zero db.Sandbox

	current, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
	if err != nil {
		return zero, fmt.Errorf("sandbox not found: %w", err)
	}
	if current.Status != "running" {
		return zero, fmt.Errorf("sandbox is not running (status: %s)", current.Status)
	}

	agent, _, err := s.agentForSandbox(ctx, sandboxID)
	if err != nil {
		return zero, err
	}
	sandboxIDStr := id.FormatSandboxID(sandboxID)

	// Claim the "paused" state up front to keep the reconciler away.
	_, err = s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
		ID: sandboxID, Status: "paused",
	})
	if err != nil {
		return zero, fmt.Errorf("pre-mark paused: %w", err)
	}

	// Persist every metrics tier so the data outlives the pause.
	s.flushAndPersistMetrics(ctx, agent, sandboxID, true)

	pauseReq := connect.NewRequest(&pb.PauseSandboxRequest{SandboxId: sandboxIDStr})
	if _, rpcErr := agent.PauseSandbox(ctx, pauseReq); rpcErr != nil {
		// Undo the optimistic status change.
		_, revertErr := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
			ID: sandboxID, Status: "running",
		})
		if revertErr != nil {
			slog.Warn("failed to revert sandbox status after pause error", "sandbox_id", sandboxIDStr, "error", revertErr)
		}
		return zero, fmt.Errorf("agent pause: %w", rpcErr)
	}

	paused, err := s.DB.GetSandbox(ctx, sandboxID)
	if err != nil {
		return zero, fmt.Errorf("get sandbox after pause: %w", err)
	}
	return paused, nil
}
// Resume brings a paused sandbox back from its snapshot.
//
// The sandbox must belong to teamID and be in "paused" status. The template's
// default user and environment (when the template can be loaded) and the
// kernel version recorded in the sandbox metadata are passed to the agent as
// hints. On success the sandbox is marked running and its metadata is replaced
// with what the agent reports.
func (s *SandboxService) Resume(ctx context.Context, sandboxID, teamID pgtype.UUID) (db.Sandbox, error) {
	var zero db.Sandbox

	sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
	if err != nil {
		return zero, fmt.Errorf("sandbox not found: %w", err)
	}
	if sb.Status != "paused" {
		return zero, fmt.Errorf("sandbox is not paused (status: %s)", sb.Status)
	}

	agent, _, err := s.agentForSandbox(ctx, sandboxID)
	if err != nil {
		return zero, err
	}
	sandboxIDStr := id.FormatSandboxID(sandboxID)

	// Template defaults are optional: a failed lookup or a malformed
	// default_env simply leaves them empty.
	var (
		defaultUser string
		defaultEnv  map[string]string
	)
	if sb.TemplateID.Valid {
		if tmpl, tmplErr := s.DB.GetTemplate(ctx, sb.TemplateID); tmplErr == nil {
			defaultUser = tmpl.DefaultUser
			if len(tmpl.DefaultEnv) > 0 {
				_ = json.Unmarshal(tmpl.DefaultEnv, &defaultEnv)
			}
		}
	}

	// Kernel version hint, taken from the metadata stored for this sandbox.
	kernelVersion := ""
	if len(sb.Metadata) > 0 {
		var stored map[string]string
		if json.Unmarshal(sb.Metadata, &stored) == nil {
			kernelVersion = stored["kernel_version"]
		}
	}

	resumeReq := connect.NewRequest(&pb.ResumeSandboxRequest{
		SandboxId:     sandboxIDStr,
		TimeoutSec:    sb.TimeoutSec,
		DefaultUser:   defaultUser,
		DefaultEnv:    defaultEnv,
		KernelVersion: kernelVersion,
	})
	resp, err := agent.ResumeSandbox(ctx, resumeReq)
	if err != nil {
		return zero, fmt.Errorf("agent resume: %w", err)
	}

	startedAt := pgtype.Timestamptz{Time: time.Now(), Valid: true}
	sb, err = s.DB.UpdateSandboxRunning(ctx, db.UpdateSandboxRunningParams{
		ID:        sandboxID,
		HostIp:    resp.Msg.HostIp,
		GuestIp:   "",
		StartedAt: startedAt,
	})
	if err != nil {
		return zero, fmt.Errorf("update status: %w", err)
	}

	// Record the versions actually in use after the resume.
	if reported := resp.Msg.Metadata; len(reported) > 0 {
		encoded, _ := json.Marshal(reported)
		updateErr := s.DB.UpdateSandboxMetadata(ctx, db.UpdateSandboxMetadataParams{
			ID:       sandboxID,
			Metadata: encoded,
		})
		if updateErr != nil {
			slog.Warn("failed to update sandbox metadata after resume", "id", sandboxIDStr, "error", updateErr)
		}
		sb.Metadata = encoded
	}

	return sb, nil
}
// Destroy stops a sandbox and marks it as stopped.
//
// The sandbox must belong to teamID. For a running sandbox the 24h metrics
// tier is flushed to the DB first. A not-found response from the agent is
// treated as success. For a paused sandbox the finer-grained "10m" and "2h"
// metric tiers are removed from the DB; failures there are logged and do not
// prevent the sandbox from being marked stopped.
func (s *SandboxService) Destroy(ctx context.Context, sandboxID, teamID pgtype.UUID) error {
	sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
	if err != nil {
		return fmt.Errorf("sandbox not found: %w", err)
	}

	agent, _, err := s.agentForSandbox(ctx, sandboxID)
	if err != nil {
		return err
	}

	sandboxIDStr := id.FormatSandboxID(sandboxID)

	// If running, flush 24h tier metrics for analytics before destroying.
	if sb.Status == "running" {
		s.flushAndPersistMetrics(ctx, agent, sandboxID, false)
	}

	// Destroy on host agent. A not-found response is fine — sandbox is already gone.
	if _, err := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
		SandboxId: sandboxIDStr,
	})); err != nil && connect.CodeOf(err) != connect.CodeNotFound {
		return fmt.Errorf("agent destroy: %w", err)
	}

	// For a paused sandbox, only keep 24h tier; remove the finer-grained tiers.
	// Cleanup is best-effort, but failures are logged instead of being dropped.
	if sb.Status == "paused" {
		for _, tier := range []string{"10m", "2h"} {
			if err := s.DB.DeleteSandboxMetricPointsByTier(ctx, db.DeleteSandboxMetricPointsByTierParams{
				SandboxID: sandboxID, Tier: tier,
			}); err != nil {
				slog.Warn("destroy: failed to delete metric points", "sandbox_id", sandboxIDStr, "tier", tier, "error", err)
			}
		}
	}

	if _, err := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
		ID: sandboxID, Status: "stopped",
	}); err != nil {
		return fmt.Errorf("update status: %w", err)
	}
	return nil
}
// flushAndPersistMetrics asks the agent to flush the sandbox's metrics and
// writes the returned points to the DB. With allTiers set, the "10m", "2h",
// and "24h" tiers are stored (in that order); otherwise only "24h" is kept,
// which is what post-destroy analytics needs. The whole operation is
// best-effort: a failed flush is logged and nothing is persisted.
func (s *SandboxService) flushAndPersistMetrics(ctx context.Context, agent hostagentClient, sandboxID pgtype.UUID, allTiers bool) {
	sandboxIDStr := id.FormatSandboxID(sandboxID)

	flushReq := connect.NewRequest(&pb.FlushSandboxMetricsRequest{SandboxId: sandboxIDStr})
	resp, err := agent.FlushSandboxMetrics(ctx, flushReq)
	if err != nil {
		slog.Warn("flush metrics failed (best-effort)", "sandbox_id", sandboxIDStr, "error", err)
		return
	}

	type tierBatch struct {
		tier   string
		points []*pb.MetricPoint
	}
	flushed := resp.Msg
	batches := []tierBatch{{tier: "24h", points: flushed.Points_24H}}
	if allTiers {
		batches = []tierBatch{
			{tier: "10m", points: flushed.Points_10M},
			{tier: "2h", points: flushed.Points_2H},
			{tier: "24h", points: flushed.Points_24H},
		}
	}
	for _, batch := range batches {
		s.persistMetricPoints(ctx, sandboxID, batch.tier, batch.points)
	}
}
// persistMetricPoints inserts each metric point for the sandbox under the
// given tier. Points are written one at a time; a failed insert is logged and
// the remaining points are still attempted.
func (s *SandboxService) persistMetricPoints(ctx context.Context, sandboxID pgtype.UUID, tier string, points []*pb.MetricPoint) {
	sandboxIDStr := id.FormatSandboxID(sandboxID)
	for i := range points {
		pt := points[i]
		row := db.InsertSandboxMetricPointParams{
			SandboxID: sandboxID,
			Tier:      tier,
			Ts:        pt.TimestampUnix,
			CpuPct:    pt.CpuPct,
			MemBytes:  pt.MemBytes,
			DiskBytes: pt.DiskBytes,
		}
		insertErr := s.DB.InsertSandboxMetricPoint(ctx, row)
		if insertErr != nil {
			slog.Warn("persist metric point failed", "sandbox_id", sandboxIDStr, "tier", tier, "error", insertErr)
		}
	}
}
// Ping resets the inactivity timer of a running sandbox.
//
// The sandbox must belong to teamID and be in "running" status. After the
// agent acknowledges the ping, last_active_at is set to the current time; a
// failure of that DB update is logged but does not fail the call.
func (s *SandboxService) Ping(ctx context.Context, sandboxID, teamID pgtype.UUID) error {
	target, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
	if err != nil {
		return fmt.Errorf("sandbox not found: %w", err)
	}
	if target.Status != "running" {
		return fmt.Errorf("sandbox is not running (status: %s)", target.Status)
	}

	agent, _, err := s.agentForSandbox(ctx, sandboxID)
	if err != nil {
		return err
	}
	sandboxIDStr := id.FormatSandboxID(sandboxID)

	pingReq := connect.NewRequest(&pb.PingSandboxRequest{SandboxId: sandboxIDStr})
	if _, rpcErr := agent.PingSandbox(ctx, pingReq); rpcErr != nil {
		return fmt.Errorf("agent ping: %w", rpcErr)
	}

	touch := db.UpdateLastActiveParams{
		ID:           sandboxID,
		LastActiveAt: pgtype.Timestamptz{Time: time.Now(), Valid: true},
	}
	if touchErr := s.DB.UpdateLastActive(ctx, touch); touchErr != nil {
		slog.Warn("ping: failed to update last_active_at", "sandbox_id", sandboxIDStr, "error", touchErr)
	}
	return nil
}

160
pkg/service/stats.go Normal file
View File

@ -0,0 +1,160 @@
package service
import (
"context"
"errors"
"fmt"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxpool"
"git.omukk.dev/wrenn/wrenn/pkg/db"
)
// TimeRange identifies a chart time window. Only values that have an entry in
// rangeConfigs are accepted by ValidRange and GetStats.
type TimeRange string
// Supported chart windows. Each constant has a matching entry in rangeConfigs.
const (
Range5m TimeRange = "5m"
Range1h TimeRange = "1h"
Range6h TimeRange = "6h"
Range24h TimeRange = "24h"
Range30d TimeRange = "30d"
)
// rangeConfig holds the two query parameters that vary per TimeRange when the
// time-series is fetched: how wide each bucket is and how far back to look.
type rangeConfig struct {
bucketSec int // bucket width in seconds for time-series aggregation
intervalLiteral string // PostgreSQL interval literal for the lookback window
}
// rangeConfigs maps every supported TimeRange to its bucket width and lookback
// window. The widths are chosen so that a full window spans at most 100
// buckets for 5m (300s / 3s) and 120 buckets for every other range.
var rangeConfigs = map[TimeRange]rangeConfig{
Range5m: {bucketSec: 3, intervalLiteral: "5 minutes"},
Range1h: {bucketSec: 30, intervalLiteral: "1 hour"},
Range6h: {bucketSec: 180, intervalLiteral: "6 hours"},
Range24h: {bucketSec: 720, intervalLiteral: "24 hours"},
Range30d: {bucketSec: 21600, intervalLiteral: "30 days"},
}
// ValidRange reports whether r is one of the supported TimeRange values,
// i.e. whether it has an entry in rangeConfigs.
func ValidRange(r TimeRange) bool {
	if _, known := rangeConfigs[r]; known {
		return true
	}
	return false
}
// StatPoint is one bucketed data point in the time-series.
type StatPoint struct {
// Bucket is the bucket timestamp: the sample time floored to a multiple of
// the bucket width.
Bucket time.Time
// The three values below are the per-bucket maximum of the corresponding
// snapshot column.
RunningCount int32
VCPUsReserved int32
MemoryMBReserved int32
}
// CurrentStats holds the live values for a team, read directly from sandboxes.
// GetStats fills it from the GetLiveMetrics query result.
type CurrentStats struct {
RunningCount int32
VCPUsReserved int32
MemoryMBReserved int32
}
// PeakStats holds the 30-day maximum values for a team.
// GetStats fills it from the GetPeakMetrics query result and leaves every
// field at zero when that query returns no row.
type PeakStats struct {
RunningCount int32
VCPUs int32
MemoryMB int32
}
// StatsService computes sandbox metrics for the dashboard.
type StatsService struct {
// DB runs the live-metrics and peak-metrics queries.
DB *db.Queries
// Pool runs the raw time-series SQL, whose bucket width is a query
// parameter.
Pool *pgxpool.Pool
}
// GetStats returns the live stats, the 30-day peaks, and the bucketed
// time-series for the given team and time range.
//
// An unsupported range is rejected before any query runs. A missing peak row
// is not an error: the peaks are returned as zeros. Any other query failure
// aborts the call and returns zero values alongside the wrapped error.
func (s *StatsService) GetStats(ctx context.Context, teamID pgtype.UUID, r TimeRange) (CurrentStats, PeakStats, []StatPoint, error) {
	var (
		noCurrent CurrentStats
		noPeaks   PeakStats
	)

	cfg, known := rangeConfigs[r]
	if !known {
		return noCurrent, noPeaks, nil, fmt.Errorf("unknown range: %s", r)
	}

	// Live values come straight from the sandboxes so they reflect the true
	// state even when no capsules are running.
	live, err := s.DB.GetLiveMetrics(ctx, teamID)
	if err != nil {
		return noCurrent, noPeaks, nil, fmt.Errorf("get live metrics: %w", err)
	}

	// 30-day peaks; "no rows" simply leaves the zero value in place.
	var peaks PeakStats
	pk, err := s.DB.GetPeakMetrics(ctx, teamID)
	switch {
	case err == nil:
		peaks = PeakStats{
			RunningCount: pk.PeakRunningCount,
			VCPUs:        pk.PeakVcpus,
			MemoryMB:     pk.PeakMemoryMb,
		}
	case !errors.Is(err, pgx.ErrNoRows):
		return noCurrent, noPeaks, nil, fmt.Errorf("get peak metrics: %w", err)
	}

	// Time-series with a per-range bucket width, executed via pgx directly.
	series, err := s.queryTimeSeries(ctx, teamID, cfg)
	if err != nil {
		return noCurrent, noPeaks, nil, fmt.Errorf("get time series: %w", err)
	}

	current := CurrentStats{
		RunningCount:     live.RunningCount,
		VCPUsReserved:    live.VcpusReserved,
		MemoryMBReserved: live.MemoryMbReserved,
	}
	return current, peaks, series, nil
}
// timeSeriesSQL uses an epoch-floor trick to bucket rows by an arbitrary
// integer number of seconds without requiring TimescaleDB.
//
// MAX is used instead of AVG so that short-lived running states are not
// averaged down to zero within a bucket. For capacity metrics the peak
// value in each bucket is what matters — AVG with ::INTEGER rounding
// caused running_count, vcpus, and memory to become inconsistent with
// each other (e.g. running=0 but vcpus=1).
//
// Only buckets that contain at least one snapshot produce a row, so the
// returned series can have gaps. Rows are ordered oldest bucket first.
//
// $1 = bucket width in seconds (integer)
// $2 = team_id
// $3 = lookback interval literal (e.g. '1 hour')
const timeSeriesSQL = `
SELECT
to_timestamp(floor(extract(epoch FROM sampled_at) / $1) * $1) AS bucket,
MAX(running_count) AS running_count,
MAX(vcpus_reserved) AS vcpus_reserved,
MAX(memory_mb_reserved) AS memory_mb_reserved
FROM sandbox_metrics_snapshots
WHERE team_id = $2
AND sampled_at >= NOW() - $3::INTERVAL
GROUP BY bucket
ORDER BY bucket ASC
`
// queryTimeSeries runs timeSeriesSQL for the team with the bucket width and
// lookback window taken from cfg, and collects one StatPoint per returned row.
// A query or scan failure returns a nil slice; an iteration error is returned
// together with whatever points were read before it.
func (s *StatsService) queryTimeSeries(ctx context.Context, teamID pgtype.UUID, cfg rangeConfig) ([]StatPoint, error) {
	rows, err := s.Pool.Query(ctx, timeSeriesSQL, cfg.bucketSec, teamID, cfg.intervalLiteral)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var series []StatPoint
	for rows.Next() {
		var pt StatPoint
		scanErr := rows.Scan(&pt.Bucket, &pt.RunningCount, &pt.VCPUsReserved, &pt.MemoryMBReserved)
		if scanErr != nil {
			return nil, scanErr
		}
		series = append(series, pt)
	}
	return series, rows.Err()
}

544
pkg/service/team.go Normal file
View File

@ -0,0 +1,544 @@
package service
import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"regexp"
	"time"

	"connectrpc.com/connect"
	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgtype"
	"github.com/jackc/pgx/v5/pgxpool"

	"git.omukk.dev/wrenn/wrenn/pkg/db"
	"git.omukk.dev/wrenn/wrenn/pkg/id"
	"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
	pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
)
// teamNameRE matches a valid team name: 1 to 128 characters, each of which is
// an ASCII letter, a digit, a space, an underscore, a hyphen, '@', or an
// apostrophe.
var teamNameRE = regexp.MustCompile(`^[A-Za-z0-9 _\-@']{1,128}$`)
// TeamService provides team management operations.
type TeamService struct {
// DB runs the generated team, membership, sandbox and template queries.
DB *db.Queries
// Pool opens the transaction used when a team and its owner membership are
// created together.
Pool *pgxpool.Pool
// HostPool hands out host-agent clients, used to destroy sandboxes and
// delete template snapshots during team deletion.
HostPool *lifecycle.HostClientPool
}
// TeamWithRole pairs a team with the calling user's role in it.
// The team is embedded, so its fields are promoted onto this type.
type TeamWithRole struct {
db.Team
Role string `json:"role"`
}
// MemberInfo is a team member with resolved user details.
type MemberInfo struct {
// UserID is the user's ID in its formatted string form (id.FormatUserID).
UserID string `json:"user_id"`
Name string `json:"name"`
Email string `json:"email"`
Role string `json:"role"`
// JoinedAt is the zero time when the stored join timestamp is NULL, and
// also in the value returned by AddMember, which does not set it.
JoinedAt time.Time `json:"joined_at"`
}
// callerRole fetches the calling user's role in the given team from the DB.
//
// If the user has no membership row for the team, the returned error message
// starts with "forbidden:". Any other lookup failure is wrapped with
// "get membership:".
func (s *TeamService) callerRole(ctx context.Context, teamID, callerUserID pgtype.UUID) (string, error) {
	m, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
		UserID: callerUserID,
		TeamID: teamID,
	})
	if err != nil {
		// errors.Is rather than == so a wrapped ErrNoRows is still recognized
		// as "not a member" instead of surfacing as an internal error.
		if errors.Is(err, pgx.ErrNoRows) {
			return "", fmt.Errorf("forbidden: not a member of this team")
		}
		return "", fmt.Errorf("get membership: %w", err)
	}
	return m.Role, nil
}
// requireAdmin accepts the roles "owner" and "admin" and rejects every other
// role string with a "forbidden:" error. The comparison is exact and
// case-sensitive.
func requireAdmin(role string) error {
	switch role {
	case "owner", "admin":
		return nil
	default:
		return fmt.Errorf("forbidden: admin or owner role required")
	}
}
// GetTeam returns the team by ID.
//
// A missing row and a soft-deleted team (deleted_at set) are both reported as
// "team not found". Any other lookup failure is wrapped with "get team:".
func (s *TeamService) GetTeam(ctx context.Context, teamID pgtype.UUID) (db.Team, error) {
	team, err := s.DB.GetTeam(ctx, teamID)
	if err != nil {
		// errors.Is rather than == so a wrapped ErrNoRows still maps to
		// "team not found".
		if errors.Is(err, pgx.ErrNoRows) {
			return db.Team{}, fmt.Errorf("team not found")
		}
		return db.Team{}, fmt.Errorf("get team: %w", err)
	}
	if team.DeletedAt.Valid {
		return db.Team{}, fmt.Errorf("team not found")
	}
	return team, nil
}
// ListTeamsForUser returns the teams the user belongs to, each paired with the
// user's role in that team. The result has one entry per row returned by the
// GetTeamsForUser query, in the same order.
func (s *TeamService) ListTeamsForUser(ctx context.Context, userID pgtype.UUID) ([]TeamWithRole, error) {
	rows, err := s.DB.GetTeamsForUser(ctx, userID)
	if err != nil {
		return nil, fmt.Errorf("list teams: %w", err)
	}

	teams := make([]TeamWithRole, len(rows))
	for idx := range rows {
		row := rows[idx]
		team := db.Team{
			ID:        row.ID,
			Name:      row.Name,
			CreatedAt: row.CreatedAt,
			IsByoc:    row.IsByoc,
			Slug:      row.Slug,
			DeletedAt: row.DeletedAt,
		}
		teams[idx] = TeamWithRole{Team: team, Role: row.Role}
	}
	return teams, nil
}
// CreateTeam creates a new team owned by the given user.
//
// The name must match teamNameRE. The team row and the owner's membership row
// are inserted in a single transaction, so either both exist or neither does.
// The owner is added with role "owner" and the team is not marked as the
// user's default.
func (s *TeamService) CreateTeam(ctx context.Context, ownerUserID pgtype.UUID, name string) (TeamWithRole, error) {
	if !teamNameRE.MatchString(name) {
		// The message lists every character class teamNameRE accepts.
		return TeamWithRole{}, fmt.Errorf("invalid team name: must be 1-128 characters, A-Z a-z 0-9 space _ - @ '")
	}

	tx, err := s.Pool.Begin(ctx)
	if err != nil {
		return TeamWithRole{}, fmt.Errorf("begin tx: %w", err)
	}
	// Rollback after a successful Commit is a no-op; its error is irrelevant.
	defer tx.Rollback(ctx) //nolint:errcheck

	qtx := s.DB.WithTx(tx)
	teamID := id.NewTeamID()

	team, err := qtx.InsertTeam(ctx, db.InsertTeamParams{
		ID:   teamID,
		Name: name,
		Slug: id.NewTeamSlug(),
	})
	if err != nil {
		return TeamWithRole{}, fmt.Errorf("insert team: %w", err)
	}

	if err := qtx.InsertTeamMember(ctx, db.InsertTeamMemberParams{
		UserID:    ownerUserID,
		TeamID:    teamID,
		IsDefault: false,
		Role:      "owner",
	}); err != nil {
		return TeamWithRole{}, fmt.Errorf("insert owner: %w", err)
	}

	if err := tx.Commit(ctx); err != nil {
		return TeamWithRole{}, fmt.Errorf("commit: %w", err)
	}
	return TeamWithRole{Team: team, Role: "owner"}, nil
}
// RenameTeam updates the team name.
//
// The new name must match teamNameRE; it is validated before any DB access.
// The caller's role is read from the DB and must be admin or owner.
func (s *TeamService) RenameTeam(ctx context.Context, teamID, callerUserID pgtype.UUID, newName string) error {
	if !teamNameRE.MatchString(newName) {
		// The message lists every character class teamNameRE accepts.
		return fmt.Errorf("invalid team name: must be 1-128 characters, A-Z a-z 0-9 space _ - @ '")
	}

	role, err := s.callerRole(ctx, teamID, callerUserID)
	if err != nil {
		return err
	}
	if err := requireAdmin(role); err != nil {
		return err
	}

	if err := s.DB.UpdateTeamName(ctx, db.UpdateTeamNameParams{ID: teamID, Name: newName}); err != nil {
		return fmt.Errorf("update name: %w", err)
	}
	return nil
}
// DeleteTeam soft-deletes the team on behalf of a user.
//
// The caller's role is read from the DB and must be exactly "owner"; admins
// are rejected. The actual teardown is delegated to deleteTeamCore.
func (s *TeamService) DeleteTeam(ctx context.Context, teamID, callerUserID pgtype.UUID) error {
	role, err := s.callerRole(ctx, teamID, callerUserID)
	switch {
	case err != nil:
		return err
	case role != "owner":
		return fmt.Errorf("forbidden: only the owner can delete a team")
	}
	return s.deleteTeamCore(ctx, teamID)
}
// deleteTeamCore contains the shared team deletion logic. In order, it:
//
//  1. asks the host agents to destroy every active sandbox of the team
//     (best-effort: an unreachable host or failed destroy is skipped/logged),
//  2. marks all of those sandboxes "stopped" in the DB (a failure aborts),
//  3. deletes the team's metric points, metrics snapshots, API keys and
//     channels (best-effort: failures are only logged),
//  4. soft-deletes the team (a failure aborts),
//  5. starts background cleanup of the team's templates.
//
// Template cleanup is started only after the soft-delete has succeeded, so a
// failed soft-delete never leaves a still-active team without its templates.
func (s *TeamService) deleteTeamCore(ctx context.Context, teamID pgtype.UUID) error {
	sandboxes, err := s.DB.ListActiveSandboxesByTeam(ctx, teamID)
	if err != nil {
		return fmt.Errorf("list active sandboxes: %w", err)
	}

	stopIDs := make([]pgtype.UUID, 0, len(sandboxes))
	for _, sb := range sandboxes {
		// Every listed sandbox is marked stopped, even when its host could not
		// be reached or the destroy call failed.
		stopIDs = append(stopIDs, sb.ID)

		host, hostErr := s.DB.GetHost(ctx, sb.HostID)
		if hostErr != nil {
			continue
		}
		agent, agentErr := s.HostPool.GetForHost(host)
		if agentErr != nil {
			continue
		}

		sandboxIDStr := id.FormatSandboxID(sb.ID)
		_, destroyErr := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
			SandboxId: sandboxIDStr,
		}))
		// NotFound means the sandbox is already gone on the host; not worth a warning.
		if destroyErr != nil && connect.CodeOf(destroyErr) != connect.CodeNotFound {
			slog.Warn("team delete: failed to destroy sandbox", "sandbox_id", sandboxIDStr, "error", destroyErr)
		}
	}

	if len(stopIDs) > 0 {
		if err := s.DB.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
			Column1: stopIDs,
			Status:  "stopped",
		}); err != nil {
			// Do not proceed to soft-delete if sandbox statuses couldn't be updated,
			// as that would leave orphaned "running" records for a deleted team.
			return fmt.Errorf("update sandbox statuses: %w", err)
		}
	}

	teamIDStr := id.FormatTeamID(teamID)

	// Delete sandbox metrics for this team.
	if err := s.DB.DeleteMetricPointsByTeam(ctx, teamID); err != nil {
		slog.Warn("team delete: failed to delete metric points", "team_id", teamIDStr, "error", err)
	}
	if err := s.DB.DeleteMetricsSnapshotsByTeam(ctx, teamID); err != nil {
		slog.Warn("team delete: failed to delete metrics snapshots", "team_id", teamIDStr, "error", err)
	}

	// Delete all API keys for this team.
	if err := s.DB.DeleteAPIKeysByTeam(ctx, teamID); err != nil {
		slog.Warn("team delete: failed to delete API keys", "team_id", teamIDStr, "error", err)
	}

	// Delete all channels for this team.
	if err := s.DB.DeleteAllChannelsByTeam(ctx, teamID); err != nil {
		slog.Warn("team delete: failed to delete channels", "team_id", teamIDStr, "error", err)
	}

	if err := s.DB.SoftDeleteTeam(ctx, teamID); err != nil {
		return fmt.Errorf("soft delete team: %w", err)
	}

	// Clean up team-owned templates from all hosts in the background. A fresh
	// background context is used so the cleanup is not cancelled together with
	// the caller's ctx.
	go s.cleanupTeamTemplates(context.Background(), teamID)

	return nil
}
// cleanupTeamTemplates removes the team's template files from every online
// host and then deletes the template records from the DB. It is started
// asynchronously during team deletion and reports problems only through
// warning logs.
//
// If the templates or hosts cannot be listed, or the team has no templates,
// it returns without touching the DB records. Per-host failures (other than
// "not found") are logged and do not stop the remaining deletions.
func (s *TeamService) cleanupTeamTemplates(ctx context.Context, teamID pgtype.UUID) {
	templates, err := s.DB.ListTemplatesByTeamOnly(ctx, teamID)
	if err != nil {
		slog.Warn("team delete: failed to list templates for cleanup", "team_id", id.FormatTeamID(teamID), "error", err)
		return
	}
	if len(templates) == 0 {
		return
	}

	hosts, err := s.DB.ListActiveHosts(ctx)
	if err != nil {
		slog.Warn("team delete: failed to list hosts for template cleanup", "error", err)
		return
	}

	for _, tmpl := range templates {
		for _, host := range hosts {
			// Only hosts currently reported as online are contacted.
			if host.Status != "online" {
				continue
			}
			agent, agentErr := s.HostPool.GetForHost(host)
			if agentErr != nil {
				continue
			}

			req := connect.NewRequest(&pb.DeleteSnapshotRequest{
				TeamId:     id.UUIDString(tmpl.TeamID),
				TemplateId: id.UUIDString(tmpl.ID),
			})
			_, delErr := agent.DeleteSnapshot(ctx, req)
			if delErr == nil || connect.CodeOf(delErr) == connect.CodeNotFound {
				continue
			}
			slog.Warn("team delete: failed to delete template on host",
				"host_id", id.FormatHostID(host.ID),
				"template", tmpl.Name,
				"error", delErr,
			)
		}
	}

	// Remove DB records.
	if err := s.DB.DeleteTemplatesByTeam(ctx, teamID); err != nil {
		slog.Warn("team delete: failed to delete template records", "team_id", id.FormatTeamID(teamID), "error", err)
	}
}
// GetMembers returns every member of the team with name, email, role and join
// time. A NULL join timestamp is reported as the zero time.
func (s *TeamService) GetMembers(ctx context.Context, teamID pgtype.UUID) ([]MemberInfo, error) {
	rows, err := s.DB.GetTeamMembers(ctx, teamID)
	if err != nil {
		return nil, fmt.Errorf("get members: %w", err)
	}

	members := make([]MemberInfo, len(rows))
	for idx := range rows {
		row := rows[idx]
		info := MemberInfo{
			UserID: id.FormatUserID(row.ID),
			Name:   row.Name,
			Email:  row.Email,
			Role:   row.Role,
		}
		if row.JoinedAt.Valid {
			info.JoinedAt = row.JoinedAt.Time
		}
		members[idx] = info
	}
	return members, nil
}
// AddMember adds an existing user (looked up by email) to the team with the
// role "member".
//
// The caller's role is read from the DB and must be admin or owner. It fails
// with "user not found" if no account has that email and with "invalid:" if
// the user already belongs to the team. The returned MemberInfo does not
// carry a JoinedAt value.
func (s *TeamService) AddMember(ctx context.Context, teamID, callerUserID pgtype.UUID, email string) (MemberInfo, error) {
	role, err := s.callerRole(ctx, teamID, callerUserID)
	if err != nil {
		return MemberInfo{}, err
	}
	if err := requireAdmin(role); err != nil {
		return MemberInfo{}, err
	}

	target, err := s.DB.GetUserByEmail(ctx, email)
	if err != nil {
		if errors.Is(err, pgx.ErrNoRows) {
			return MemberInfo{}, fmt.Errorf("user not found: no account with that email")
		}
		return MemberInfo{}, fmt.Errorf("look up user: %w", err)
	}

	// Check if already a member: only "no rows" means the insert may proceed.
	_, memberCheckErr := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
		UserID: target.ID,
		TeamID: teamID,
	})
	switch {
	case memberCheckErr == nil:
		return MemberInfo{}, fmt.Errorf("invalid: user is already a member of this team")
	case !errors.Is(memberCheckErr, pgx.ErrNoRows):
		return MemberInfo{}, fmt.Errorf("check membership: %w", memberCheckErr)
	}

	if err := s.DB.InsertTeamMember(ctx, db.InsertTeamMemberParams{
		UserID:    target.ID,
		TeamID:    teamID,
		IsDefault: false,
		Role:      "member",
	}); err != nil {
		return MemberInfo{}, fmt.Errorf("insert member: %w", err)
	}

	return MemberInfo{UserID: id.FormatUserID(target.ID), Name: target.Name, Email: target.Email, Role: "member"}, nil
}
// RemoveMember removes a user from the team.
//
// The caller's role is read from the DB and must be admin or owner. The target
// must be a member of the team ("not found:" otherwise) and must not be the
// owner ("forbidden:" otherwise).
func (s *TeamService) RemoveMember(ctx context.Context, teamID, callerUserID, targetUserID pgtype.UUID) error {
	callerRole, err := s.callerRole(ctx, teamID, callerUserID)
	if err != nil {
		return err
	}
	if err := requireAdmin(callerRole); err != nil {
		return err
	}

	targetMembership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
		UserID: targetUserID,
		TeamID: teamID,
	})
	if err != nil {
		// errors.Is rather than == so a wrapped ErrNoRows still maps to "not found".
		if errors.Is(err, pgx.ErrNoRows) {
			return fmt.Errorf("not found: user is not a member of this team")
		}
		return fmt.Errorf("get target membership: %w", err)
	}
	if targetMembership.Role == "owner" {
		return fmt.Errorf("forbidden: the owner cannot be removed from the team")
	}

	if err := s.DB.DeleteTeamMember(ctx, db.DeleteTeamMemberParams{
		TeamID: teamID,
		UserID: targetUserID,
	}); err != nil {
		return fmt.Errorf("delete member: %w", err)
	}
	return nil
}
// UpdateMemberRole changes a member's role to admin or member.
//
// newRole is validated first and must be "admin" or "member". The caller's
// role is read from the DB and must be admin or owner. The target must be a
// member of the team ("not found:" otherwise), and the owner's role cannot be
// changed ("forbidden:" otherwise).
func (s *TeamService) UpdateMemberRole(ctx context.Context, teamID, callerUserID, targetUserID pgtype.UUID, newRole string) error {
	if newRole != "admin" && newRole != "member" {
		return fmt.Errorf("invalid: role must be admin or member")
	}

	callerRole, err := s.callerRole(ctx, teamID, callerUserID)
	if err != nil {
		return err
	}
	if err := requireAdmin(callerRole); err != nil {
		return err
	}

	targetMembership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
		UserID: targetUserID,
		TeamID: teamID,
	})
	if err != nil {
		// errors.Is rather than == so a wrapped ErrNoRows still maps to "not found".
		if errors.Is(err, pgx.ErrNoRows) {
			return fmt.Errorf("not found: user is not a member of this team")
		}
		return fmt.Errorf("get target membership: %w", err)
	}
	if targetMembership.Role == "owner" {
		return fmt.Errorf("forbidden: the owner's role cannot be changed")
	}

	if err := s.DB.UpdateMemberRole(ctx, db.UpdateMemberRoleParams{
		TeamID: teamID,
		UserID: targetUserID,
		Role:   newRole,
	}); err != nil {
		return fmt.Errorf("update role: %w", err)
	}
	return nil
}
// LeaveTeam removes the calling user's own membership from the team.
// The caller must currently be a member, and an owner is refused: the owner
// has to delete the team instead of leaving it.
func (s *TeamService) LeaveTeam(ctx context.Context, teamID, callerUserID pgtype.UUID) error {
	role, err := s.callerRole(ctx, teamID, callerUserID)
	switch {
	case err != nil:
		return err
	case role == "owner":
		return fmt.Errorf("forbidden: the owner cannot leave the team; delete the team instead")
	}

	membership := db.DeleteTeamMemberParams{TeamID: teamID, UserID: callerUserID}
	if delErr := s.DB.DeleteTeamMember(ctx, membership); delErr != nil {
		return fmt.Errorf("leave team: %w", delErr)
	}
	return nil
}
// SetBYOC turns on the BYOC feature flag for a team. The transition is
// one-way: a request with enabled=false is always rejected, and enabling an
// already-enabled team is a successful no-op.
//
// The team must exist and must not be soft-deleted; that is checked before
// the enabled flag is looked at.
// Admin-only — the caller must verify admin status before invoking this.
func (s *TeamService) SetBYOC(ctx context.Context, teamID pgtype.UUID, enabled bool) error {
	team, err := s.DB.GetTeam(ctx, teamID)
	switch {
	case err != nil:
		return fmt.Errorf("team not found: %w", err)
	case team.DeletedAt.Valid:
		return fmt.Errorf("team not found")
	case !enabled:
		return fmt.Errorf("invalid request: BYOC cannot be disabled once enabled")
	case team.IsByoc:
		// Already enabled — idempotent, nothing to write.
		return nil
	}

	update := db.SetTeamBYOCParams{ID: teamID, IsByoc: true}
	if setErr := s.DB.SetTeamBYOC(ctx, update); setErr != nil {
		return fmt.Errorf("set byoc: %w", setErr)
	}
	return nil
}
// AdminTeamRow is the shape returned by AdminListTeams.
// Every field is copied from the corresponding column of the admin team
// listing query.
type AdminTeamRow struct {
ID pgtype.UUID
Name string
Slug string
IsByoc bool
CreatedAt time.Time
// DeletedAt is nil unless the team has a deleted_at timestamp set.
DeletedAt *time.Time
MemberCount int32
OwnerName string
OwnerEmail string
ActiveSandboxCount int32
ChannelCount int32
}
// AdminListTeams returns one page of teams for the admin view, together with
// the total team count reported by CountTeamsAdmin. Each row carries member
// count, owner details, active sandbox count and channel count; DeletedAt is
// set only for teams that have a deletion timestamp.
// Admin-only — caller must verify admin status.
func (s *TeamService) AdminListTeams(ctx context.Context, limit, offset int32) ([]AdminTeamRow, int32, error) {
	page := db.ListTeamsAdminParams{Limit: limit, Offset: offset}
	teams, err := s.DB.ListTeamsAdmin(ctx, page)
	if err != nil {
		return nil, 0, fmt.Errorf("list teams: %w", err)
	}

	total, err := s.DB.CountTeamsAdmin(ctx)
	if err != nil {
		return nil, 0, fmt.Errorf("count teams: %w", err)
	}

	result := make([]AdminTeamRow, len(teams))
	for idx := range teams {
		src := teams[idx]
		dst := &result[idx]

		dst.ID = src.ID
		dst.Name = src.Name
		dst.Slug = src.Slug
		dst.IsByoc = src.IsByoc
		dst.CreatedAt = src.CreatedAt.Time
		dst.MemberCount = src.MemberCount
		dst.OwnerName = src.OwnerName
		dst.OwnerEmail = src.OwnerEmail
		dst.ActiveSandboxCount = src.ActiveSandboxCount
		dst.ChannelCount = src.ChannelCount

		if src.DeletedAt.Valid {
			// Copy into a fresh variable so each row owns its own pointer.
			when := src.DeletedAt.Time
			dst.DeletedAt = &when
		}
	}
	return result, total, nil
}
// DeleteTeamInternal soft-deletes a team and destroys all its active sandboxes.
// Used for system-initiated deletions (e.g. cascading from user account deletion)
// where no caller role check is needed.
//
// It delegates directly to deleteTeamCore: no membership, role, or
// already-deleted check is performed here.
func (s *TeamService) DeleteTeamInternal(ctx context.Context, teamID pgtype.UUID) error {
return s.deleteTeamCore(ctx, teamID)
}
// AdminDeleteTeam soft-deletes a team and destroys its active sandboxes on
// behalf of a platform admin. No team-ownership check is made — the caller
// must have verified admin status. The team must exist and must not already
// be soft-deleted; the teardown itself is delegated to deleteTeamCore.
func (s *TeamService) AdminDeleteTeam(ctx context.Context, teamID pgtype.UUID) error {
	team, err := s.DB.GetTeam(ctx, teamID)
	switch {
	case err != nil:
		return fmt.Errorf("team not found: %w", err)
	case team.DeletedAt.Valid:
		return fmt.Errorf("team not found")
	}
	return s.deleteTeamCore(ctx, teamID)
}

27
pkg/service/template.go Normal file
View File

@ -0,0 +1,27 @@
package service
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/pkg/db"
)
// TemplateService provides template/snapshot operations shared between the
// REST API and the dashboard.
type TemplateService struct {
// DB runs the generated template queries.
DB *db.Queries
}
// List returns the templates that belong to the given team. An empty
// typeFilter returns all of them; otherwise only templates whose type equals
// typeFilter ("base" or "snapshot") are returned. The filter value is passed
// to the query as-is.
func (s *TemplateService) List(ctx context.Context, teamID pgtype.UUID, typeFilter string) ([]db.Template, error) {
	if typeFilter == "" {
		return s.DB.ListTemplatesByTeam(ctx, teamID)
	}
	filtered := db.ListTemplatesByTeamAndTypeParams{
		TeamID: teamID,
		Type:   typeFilter,
	}
	return s.DB.ListTemplatesByTeamAndType(ctx, filtered)
}

107
pkg/service/user.go Normal file
View File

@ -0,0 +1,107 @@
package service
import (
"context"
"fmt"
"log/slog"
"time"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/pkg/db"
)
// UserService provides user management operations.
type UserService struct {
// DB runs the generated user, API key, team and sandbox queries.
DB *db.Queries
// SandboxSvc destroys sandboxes when a user is disabled or deleted. It may
// be nil, in which case that sandbox cleanup is skipped.
SandboxSvc *SandboxService
}
// AdminUserRow is the shape returned by AdminListUsers.
// Every field is copied from the corresponding column of the admin user
// listing query.
type AdminUserRow struct {
ID pgtype.UUID
Email string
Name string
IsAdmin bool
Status string
CreatedAt time.Time
TeamsJoined int32
TeamsOwned int32
}
// AdminListUsers returns one page of users for the admin view, each with the
// number of teams joined and owned, plus the total user count reported by
// CountUsersAdmin.
func (s *UserService) AdminListUsers(ctx context.Context, limit, offset int32) ([]AdminUserRow, int32, error) {
	page := db.ListUsersAdminParams{Limit: limit, Offset: offset}
	users, err := s.DB.ListUsersAdmin(ctx, page)
	if err != nil {
		return nil, 0, fmt.Errorf("list users: %w", err)
	}

	total, err := s.DB.CountUsersAdmin(ctx)
	if err != nil {
		return nil, 0, fmt.Errorf("count users: %w", err)
	}

	result := make([]AdminUserRow, len(users))
	for idx := range users {
		src := users[idx]
		dst := &result[idx]

		dst.ID = src.ID
		dst.Email = src.Email
		dst.Name = src.Name
		dst.IsAdmin = src.IsAdmin
		dst.Status = src.Status
		dst.CreatedAt = src.CreatedAt.Time
		dst.TeamsJoined = src.TeamsJoined
		dst.TeamsOwned = src.TeamsOwned
	}
	return result, total, nil
}
// SetUserStatus stores the given status on the user account.
//
// When the new status is "disabled" or "deleted", the API keys created by the
// user are deleted and the sandboxes of the teams they own are destroyed.
// Both follow-up steps are best-effort: their failures are logged and never
// returned. Only a failure to store the status itself is reported.
func (s *UserService) SetUserStatus(ctx context.Context, userID pgtype.UUID, status string) error {
	update := db.SetUserStatusParams{ID: userID, Status: status}
	if err := s.DB.SetUserStatus(ctx, update); err != nil {
		return fmt.Errorf("set user status: %w", err)
	}

	deactivated := status == "disabled" || status == "deleted"
	if !deactivated {
		return nil
	}

	if keyErr := s.DB.DeleteAPIKeysByCreator(ctx, userID); keyErr != nil {
		slog.Warn("failed to delete API keys for deactivated user", "user_id", userID, "error", keyErr)
	}
	s.destroySandboxesForOwnedTeams(ctx, userID)
	return nil
}
// destroySandboxesForOwnedTeams destroys every active sandbox (running,
// paused, hibernated, starting) of each team the user owns. It is a no-op
// when no SandboxService is configured. The work is best-effort: each failure
// is logged and the remaining teams and sandboxes are still processed, so a
// cleanup problem never blocks disabling the user.
func (s *UserService) destroySandboxesForOwnedTeams(ctx context.Context, userID pgtype.UUID) {
	if s.SandboxSvc == nil {
		return
	}

	ownedTeams, err := s.DB.GetOwnedTeamIDs(ctx, userID)
	if err != nil {
		slog.Warn("failed to list owned teams for sandbox cleanup", "user_id", userID, "error", err)
		return
	}

	for _, teamID := range ownedTeams {
		active, listErr := s.DB.ListActiveSandboxesByTeam(ctx, teamID)
		if listErr != nil {
			slog.Warn("failed to list active sandboxes for team", "team_id", teamID, "user_id", userID, "error", listErr)
			continue
		}
		for _, sb := range active {
			destroyErr := s.SandboxSvc.Destroy(ctx, sb.ID, teamID)
			if destroyErr == nil {
				continue
			}
			slog.Warn("failed to destroy sandbox during user disable",
				"sandbox_id", sb.ID, "team_id", teamID, "user_id", userID, "error", destroyErr)
		}
	}
}