v0.1.0 (#17)
pkg/service/apikey.go (new file, 65 lines)
@@ -0,0 +1,65 @@
package service

import (
    "context"
    "fmt"

    "github.com/jackc/pgx/v5/pgtype"

    "git.omukk.dev/wrenn/wrenn/pkg/auth"
    "git.omukk.dev/wrenn/wrenn/pkg/db"
    "git.omukk.dev/wrenn/wrenn/pkg/id"
)

// APIKeyService provides API key operations shared between the REST API and the dashboard.
type APIKeyService struct {
    DB *db.Queries
}

// APIKeyCreateResult holds the result of creating an API key, including the
// plaintext key which is only available at creation time.
type APIKeyCreateResult struct {
    Row       db.TeamApiKey
    Plaintext string
}

// Create generates a new API key for the given team.
func (s *APIKeyService) Create(ctx context.Context, teamID, userID pgtype.UUID, name string) (APIKeyCreateResult, error) {
    if name == "" {
        name = "Unnamed API Key"
    }

    plaintext, hash, err := auth.GenerateAPIKey()
    if err != nil {
        return APIKeyCreateResult{}, fmt.Errorf("generate key: %w", err)
    }

    row, err := s.DB.InsertAPIKey(ctx, db.InsertAPIKeyParams{
        ID:        id.NewAPIKeyID(),
        TeamID:    teamID,
        Name:      name,
        KeyHash:   hash,
        KeyPrefix: auth.APIKeyPrefix(plaintext),
        CreatedBy: userID,
    })
    if err != nil {
        return APIKeyCreateResult{}, fmt.Errorf("insert key: %w", err)
    }

    return APIKeyCreateResult{Row: row, Plaintext: plaintext}, nil
}

// List returns all API keys belonging to the given team.
func (s *APIKeyService) List(ctx context.Context, teamID pgtype.UUID) ([]db.TeamApiKey, error) {
    return s.DB.ListAPIKeysByTeam(ctx, teamID)
}

// ListWithCreator returns all API keys for the team, joined with the creator's email.
func (s *APIKeyService) ListWithCreator(ctx context.Context, teamID pgtype.UUID) ([]db.ListAPIKeysByTeamWithCreatorRow, error) {
    return s.DB.ListAPIKeysByTeamWithCreator(ctx, teamID)
}

// Delete removes an API key by ID, scoped to the given team.
func (s *APIKeyService) Delete(ctx context.Context, keyID, teamID pgtype.UUID) error {
    return s.DB.DeleteAPIKey(ctx, db.DeleteAPIKeyParams{ID: keyID, TeamID: teamID})
}
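Since only KeyHash and KeyPrefix are persisted, the plaintext has to be surfaced to the caller immediately. A hedged usage sketch (queries, teamID, and userID are assumed to come from the surrounding handler):

    keys := &service.APIKeyService{DB: queries}

    res, err := keys.Create(ctx, teamID, userID, "ci-deploy")
    if err != nil {
        return err
    }
    // res.Plaintext exists only in this response; afterwards the key can only
    // be identified by its stored prefix and verified against its hash.
    fmt.Println("copy this key now:", res.Plaintext)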
pkg/service/audit.go (new file, 113 lines)
@@ -0,0 +1,113 @@
package service

import (
    "context"
    "encoding/json"
    "fmt"
    "time"

    "github.com/jackc/pgx/v5/pgtype"

    "git.omukk.dev/wrenn/wrenn/pkg/db"
    "git.omukk.dev/wrenn/wrenn/pkg/id"
)

const auditMaxLimit = 200

// AuditEntry is a single audit log record returned by List.
type AuditEntry struct {
    ID           string
    TeamID       string
    ActorType    string
    ActorID      string // empty for system
    ActorName    string // empty for system
    ResourceType string
    ResourceID   string // empty when not applicable
    Action       string
    Scope        string
    Status       string // 'success', 'info', 'warning', 'error'
    Metadata     map[string]any
    CreatedAt    time.Time
}

// AuditListParams controls the ListAuditLogs query.
type AuditListParams struct {
    TeamID        pgtype.UUID
    AdminScoped   bool        // true → include admin-scoped events; false → team-scoped only
    ResourceTypes []string    // empty = no filter; multiple values = OR match
    Actions       []string    // empty = no filter; multiple values = OR match
    Before        time.Time   // zero = no cursor (start from latest)
    BeforeID      pgtype.UUID // tie-breaker: id of the last item at the Before timestamp; zero = no tie-break
    Limit         int         // defaults to 50; clamped to auditMaxLimit by List
}

// AuditService provides the read side of the audit log.
type AuditService struct {
    DB *db.Queries
}

// List returns a page of audit log entries for the given team.
func (s *AuditService) List(ctx context.Context, p AuditListParams) ([]AuditEntry, error) {
    limit := p.Limit
    if limit <= 0 {
        limit = 50
    }
    if limit > auditMaxLimit {
        limit = auditMaxLimit
    }

    scopes := []string{"team"}
    if p.AdminScoped {
        scopes = append(scopes, "admin")
    }

    var before pgtype.Timestamptz
    if !p.Before.IsZero() {
        before = pgtype.Timestamptz{Time: p.Before, Valid: true}
    }

    resourceTypes := p.ResourceTypes
    if resourceTypes == nil {
        resourceTypes = []string{}
    }
    actions := p.Actions
    if actions == nil {
        actions = []string{}
    }

    rows, err := s.DB.ListAuditLogs(ctx, db.ListAuditLogsParams{
        TeamID:  p.TeamID,
        Column2: scopes,
        Column3: resourceTypes,
        Column4: actions,
        Column5: before,
        ID:      p.BeforeID,
        Limit:   int32(limit),
    })
    if err != nil {
        return nil, fmt.Errorf("list audit logs: %w", err)
    }

    entries := make([]AuditEntry, len(rows))
    for i, row := range rows {
        var meta map[string]any
        if len(row.Metadata) > 0 {
            _ = json.Unmarshal(row.Metadata, &meta)
        }
        entries[i] = AuditEntry{
            ID:           id.FormatAuditLogID(row.ID),
            TeamID:       id.FormatTeamID(row.TeamID),
            ActorType:    row.ActorType,
            ActorID:      row.ActorID.String,
            ActorName:    row.ActorName,
            ResourceType: row.ResourceType,
            ResourceID:   row.ResourceID.String,
            Action:       row.Action,
            Scope:        row.Scope,
            Status:       row.Status,
            Metadata:     meta,
            CreatedAt:    row.CreatedAt.Time,
        }
    }
    return entries, nil
}
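Before/BeforeID implement keyset (cursor) pagination. A hedged sketch of walking the log, assuming an id.ParseAuditLogID helper that inverts id.FormatAuditLogID (not shown in this commit), with handle standing in for whatever processes a page:

    // Walk the audit log newest-to-oldest, 50 entries per page.
    page, err := audit.List(ctx, service.AuditListParams{TeamID: teamID, Limit: 50})
    if err != nil {
        return err
    }
    for len(page) > 0 {
        handle(page)

        last := page[len(page)-1]
        lastID, err := id.ParseAuditLogID(last.ID) // assumed inverse of FormatAuditLogID
        if err != nil {
            return err
        }
        page, err = audit.List(ctx, service.AuditListParams{
            TeamID:   teamID,
            Before:   last.CreatedAt, // strictly older than the last entry...
            BeforeID: lastID,         // ...with an ID tie-break at equal timestamps
            Limit:    50,
        })
        if err != nil {
            return err
        }
    }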
pkg/service/build.go (new file, 786 lines)
@@ -0,0 +1,786 @@
package service

import (
    "context"
    "encoding/json"
    "fmt"
    "log/slog"
    "strings"
    "sync"
    "time"

    "connectrpc.com/connect"
    "github.com/jackc/pgx/v5/pgtype"
    "github.com/redis/go-redis/v9"

    "git.omukk.dev/wrenn/wrenn/internal/recipe"
    "git.omukk.dev/wrenn/wrenn/pkg/db"
    "git.omukk.dev/wrenn/wrenn/pkg/id"
    "git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
    "git.omukk.dev/wrenn/wrenn/pkg/scheduler"
    pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
)

const (
    buildQueueKey       = "wrenn:build_queue"
    buildCommandTimeout = 30 * time.Second
)

// preBuildCmds run before the user recipe to prepare the build environment.
// apt update runs as root first, then USER switches to wrenn-user for the recipe.
var preBuildCmds = []string{
    "RUN apt update",
    "USER wrenn-user",
    "WORKDIR /home/wrenn-user",
}

// postBuildCmds run after the user recipe to clean up caches and reduce image size.
var postBuildCmds = []string{
    "RUN apt clean",
    "RUN apt autoremove -y",
    "RUN rm -rf /var/lib/apt/lists/*",
    "RUN rm -rf /tmp/build-files /tmp/build-files.*",
}

// buildAgentClient is the subset of the host agent client used by the build worker.
type buildAgentClient interface {
    CreateSandbox(ctx context.Context, req *connect.Request[pb.CreateSandboxRequest]) (*connect.Response[pb.CreateSandboxResponse], error)
    DestroySandbox(ctx context.Context, req *connect.Request[pb.DestroySandboxRequest]) (*connect.Response[pb.DestroySandboxResponse], error)
    Exec(ctx context.Context, req *connect.Request[pb.ExecRequest]) (*connect.Response[pb.ExecResponse], error)
    WriteFile(ctx context.Context, req *connect.Request[pb.WriteFileRequest]) (*connect.Response[pb.WriteFileResponse], error)
    CreateSnapshot(ctx context.Context, req *connect.Request[pb.CreateSnapshotRequest]) (*connect.Response[pb.CreateSnapshotResponse], error)
    FlattenRootfs(ctx context.Context, req *connect.Request[pb.FlattenRootfsRequest]) (*connect.Response[pb.FlattenRootfsResponse], error)
}

// BuildService handles template build orchestration.
type BuildService struct {
    DB        *db.Queries
    Redis     *redis.Client
    Pool      *lifecycle.HostClientPool
    Scheduler scheduler.HostScheduler

    mu        sync.Mutex
    cancelMap map[string]context.CancelFunc // buildID → per-build cancel func
    filesMap  map[string][]byte             // buildID → uploaded archive bytes
}

// BuildCreateParams holds the parameters for creating a template build.
type BuildCreateParams struct {
    Name         string
    BaseTemplate string
    Recipe       []string
    Healthcheck  string
    VCPUs        int32
    MemoryMB     int32
    SkipPrePost  bool
    Archive      []byte // Optional tar/tar.gz/zip archive for COPY commands.
    ArchiveName  string // Original filename (used to detect format).
}

// storeArchive stores uploaded archive bytes keyed by build ID for the worker.
func (s *BuildService) storeArchive(buildID string, data []byte) {
    s.mu.Lock()
    defer s.mu.Unlock()
    if s.filesMap == nil {
        s.filesMap = make(map[string][]byte)
    }
    s.filesMap[buildID] = data
}

// takeArchive retrieves and removes stored archive bytes for a build.
func (s *BuildService) takeArchive(buildID string) []byte {
    s.mu.Lock()
    defer s.mu.Unlock()
    data := s.filesMap[buildID]
    delete(s.filesMap, buildID)
    return data
}

// Create inserts a new build record and enqueues it to Redis.
func (s *BuildService) Create(ctx context.Context, p BuildCreateParams) (db.TemplateBuild, error) {
    if p.BaseTemplate == "" {
        p.BaseTemplate = "minimal"
    }
    if p.VCPUs <= 0 {
        p.VCPUs = 1
    }
    if p.MemoryMB <= 0 {
        p.MemoryMB = 512
    }

    recipeJSON, err := json.Marshal(p.Recipe)
    if err != nil {
        return db.TemplateBuild{}, fmt.Errorf("marshal recipe: %w", err)
    }

    buildID := id.NewBuildID()
    buildIDStr := id.FormatBuildID(buildID)
    newTemplateID := id.NewTemplateID()

    defaultSteps := len(preBuildCmds) + len(postBuildCmds)
    if p.SkipPrePost {
        defaultSteps = 0
    }

    build, err := s.DB.InsertTemplateBuild(ctx, db.InsertTemplateBuildParams{
        ID:           buildID,
        Name:         p.Name,
        BaseTemplate: p.BaseTemplate,
        Recipe:       recipeJSON,
        Healthcheck:  p.Healthcheck,
        Vcpus:        p.VCPUs,
        MemoryMb:     p.MemoryMB,
        TotalSteps:   int32(len(p.Recipe) + defaultSteps),
        TemplateID:   newTemplateID,
        TeamID:       id.PlatformTeamID,
        SkipPrePost:  p.SkipPrePost,
    })
    if err != nil {
        return db.TemplateBuild{}, fmt.Errorf("insert build: %w", err)
    }

    // Store archive before enqueue so the worker never dequeues without files.
    if len(p.Archive) > 0 {
        s.storeArchive(buildIDStr, p.Archive)
    }

    if err := s.Redis.RPush(ctx, buildQueueKey, buildIDStr).Err(); err != nil {
        s.takeArchive(buildIDStr) // clean up on enqueue failure
        return db.TemplateBuild{}, fmt.Errorf("enqueue build: %w", err)
    }

    return build, nil
}

// Get returns a single build by ID.
func (s *BuildService) Get(ctx context.Context, buildID pgtype.UUID) (db.TemplateBuild, error) {
    return s.DB.GetTemplateBuild(ctx, buildID)
}

// List returns all builds ordered by creation time.
func (s *BuildService) List(ctx context.Context) ([]db.TemplateBuild, error) {
    return s.DB.ListTemplateBuilds(ctx)
}

// Cancel cancels a pending or running build. For pending builds the status is
// updated in the DB and the worker skips it when dequeued. For running builds
// the per-build context is cancelled, which causes the current exec step to
// abort; executeBuild then detects the cancellation and records the status.
func (s *BuildService) Cancel(ctx context.Context, buildID pgtype.UUID) error {
    build, err := s.DB.GetTemplateBuild(ctx, buildID)
    if err != nil {
        return fmt.Errorf("get build: %w", err)
    }
    switch build.Status {
    case "success", "failed", "cancelled":
        return fmt.Errorf("build is already %s", build.Status)
    }

    // Mark cancelled in DB first. This handles both pending builds (which haven't
    // been picked up yet) and acts as a flag for executeBuild to check on start.
    if _, err := s.DB.UpdateBuildStatus(ctx, db.UpdateBuildStatusParams{
        ID: buildID, Status: "cancelled",
    }); err != nil {
        return fmt.Errorf("update build status: %w", err)
    }

    // If the build is currently running, signal its context.
    buildIDStr := id.FormatBuildID(buildID)
    s.mu.Lock()
    cancel, running := s.cancelMap[buildIDStr]
    s.mu.Unlock()
    if running {
        cancel()
    }

    return nil
}

// StartWorkers launches n goroutines that consume from the Redis build queue.
// The returned cancel function stops all workers.
func (s *BuildService) StartWorkers(ctx context.Context, n int) context.CancelFunc {
    ctx, cancel := context.WithCancel(ctx)
    for i := range n {
        go s.worker(ctx, i)
    }
    slog.Info("build workers started", "count", n)
    return cancel
}

func (s *BuildService) worker(ctx context.Context, workerID int) {
    log := slog.With("worker", workerID)
    for {
        // BLPOP blocks until a build ID is available or context is cancelled.
        result, err := s.Redis.BLPop(ctx, 0, buildQueueKey).Result()
        if err != nil {
            if ctx.Err() != nil {
                log.Info("build worker shutting down")
                return
            }
            log.Error("redis BLPOP error", "error", err)
            time.Sleep(time.Second)
            continue
        }
        // result[0] is the key, result[1] is the build ID (formatted string).
        buildIDStr := result[1]
        log.Info("picked up build", "build_id", buildIDStr)
        s.executeBuild(ctx, buildIDStr)
    }
}
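Service wiring at startup might look like this (a sketch; queries, rdb, pool, and sched are assumed to come from the application's setup):

    // Start the queue consumers once; keep the cancel func for shutdown.
    builds := &service.BuildService{DB: queries, Redis: rdb, Pool: pool, Scheduler: sched}
    stop := builds.StartWorkers(ctx, 4) // 4 concurrent BLPOP consumers
    defer stop()

    // Enqueue a build; a worker picks it up via BLPOP on wrenn:build_queue.
    b, err := builds.Create(ctx, service.BuildCreateParams{
        Name:   "python-dev",
        Recipe: []string{"RUN apt install -y python3"},
    })
    if err != nil {
        return err
    }
    slog.Info("build enqueued", "build_id", id.FormatBuildID(b.ID))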
func (s *BuildService) executeBuild(ctx context.Context, buildIDStr string) {
    log := slog.With("build_id", buildIDStr)

    buildID, err := id.ParseBuildID(buildIDStr)
    if err != nil {
        log.Error("invalid build ID from queue", "error", err)
        return
    }

    // Create a per-build context so this build can be cancelled independently of
    // the worker. Register in cancelMap before fetching the build so that a
    // concurrent Cancel call can always find and signal it.
    buildCtx, buildCancel := context.WithCancel(ctx)
    defer buildCancel()

    s.mu.Lock()
    if s.cancelMap == nil {
        s.cancelMap = make(map[string]context.CancelFunc)
    }
    s.cancelMap[buildIDStr] = buildCancel
    s.mu.Unlock()
    defer func() {
        s.mu.Lock()
        delete(s.cancelMap, buildIDStr)
        s.mu.Unlock()
    }()

    build, err := s.DB.GetTemplateBuild(buildCtx, buildID)
    if err != nil {
        log.Error("failed to fetch build", "error", err)
        return
    }

    // Skip if already cancelled (Cancel was called before we dequeued).
    if build.Status == "cancelled" {
        log.Info("build already cancelled, skipping")
        return
    }

    // Mark as running.
    if _, err := s.DB.UpdateBuildStatus(buildCtx, db.UpdateBuildStatusParams{
        ID: buildID, Status: "running",
    }); err != nil {
        log.Error("failed to update build status", "error", err)
        return
    }

    // Parse user recipe.
    var userRecipe []string
    if err := json.Unmarshal(build.Recipe, &userRecipe); err != nil {
        s.failBuild(buildCtx, buildID, fmt.Sprintf("invalid recipe JSON: %v", err))
        return
    }

    // Pick a platform host and create a sandbox.
    host, err := s.Scheduler.SelectHost(buildCtx, id.PlatformTeamID, false, build.MemoryMb, 5120)
    if err != nil {
        s.failBuild(buildCtx, buildID, fmt.Sprintf("no host available: %v", err))
        return
    }

    agent, err := s.Pool.GetForHost(host)
    if err != nil {
        s.failBuild(buildCtx, buildID, fmt.Sprintf("agent client error: %v", err))
        return
    }

    sandboxID := id.NewSandboxID()
    sandboxIDStr := id.FormatSandboxID(sandboxID)
    log = log.With("sandbox_id", sandboxIDStr, "host_id", id.FormatHostID(host.ID))

    // Resolve the base template to UUIDs. "minimal" is the zero sentinel.
    baseTeamID := id.PlatformTeamID
    baseTemplateID := id.MinimalTemplateID
    if build.BaseTemplate != "minimal" {
        baseTmpl, err := s.DB.GetPlatformTemplateByName(buildCtx, build.BaseTemplate)
        if err != nil {
            s.failBuild(buildCtx, buildID, fmt.Sprintf("base template %q not found: %v", build.BaseTemplate, err))
            return
        }
        baseTeamID = baseTmpl.TeamID
        baseTemplateID = baseTmpl.ID
    }

    resp, err := agent.CreateSandbox(buildCtx, connect.NewRequest(&pb.CreateSandboxRequest{
        SandboxId:  sandboxIDStr,
        Template:   build.BaseTemplate,
        TeamId:     id.UUIDString(baseTeamID),
        TemplateId: id.UUIDString(baseTemplateID),
        Vcpus:      build.Vcpus,
        MemoryMb:   build.MemoryMb,
        TimeoutSec: 0,    // no auto-pause for builds
        DiskSizeMb: 5120, // 5 GB for template builds
    }))
    if err != nil {
        s.failBuild(buildCtx, buildID, fmt.Sprintf("create sandbox failed: %v", err))
        return
    }
    // Capture sandbox metadata (envd/kernel/firecracker/agent versions).
    sandboxMetadata := resp.Msg.Metadata

    // Record sandbox/host association.
    _ = s.DB.UpdateBuildSandbox(buildCtx, db.UpdateBuildSandboxParams{
        ID:        buildID,
        SandboxID: sandboxID,
        HostID:    host.ID,
    })

    // Upload and extract the build archive if provided.
    archive := s.takeArchive(buildIDStr)
    if len(archive) > 0 {
        if err := s.uploadAndExtractArchive(buildCtx, agent, sandboxIDStr, archive, buildIDStr); err != nil {
            s.destroySandbox(buildCtx, agent, sandboxIDStr)
            s.failBuild(buildCtx, buildID, fmt.Sprintf("archive upload failed: %v", err))
            return
        }
    }

    // Parse recipe steps. preBuildCmds and postBuildCmds are hardcoded and always
    // valid; panic on error is appropriate here since it would be a programmer mistake.
    preBuildSteps, err := recipe.ParseRecipe(preBuildCmds)
    if err != nil {
        panic(fmt.Sprintf("invalid pre-build recipe: %v", err))
    }
    userRecipeSteps, err := recipe.ParseRecipe(userRecipe)
    if err != nil {
        s.destroySandbox(buildCtx, agent, sandboxIDStr)
        s.failBuild(buildCtx, buildID, fmt.Sprintf("recipe parse error: %v", err))
        return
    }
    postBuildSteps, err := recipe.ParseRecipe(postBuildCmds)
    if err != nil {
        panic(fmt.Sprintf("invalid post-build recipe: %v", err))
    }

    var logs []recipe.BuildLogEntry
    step := 0

    envVars, err := s.fetchSandboxEnv(buildCtx, agent, sandboxIDStr)
    if err != nil {
        log.Warn("failed to fetch sandbox env, using defaults", "error", err)
        envVars = map[string]string{
            "PATH": "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin",
            "HOME": "/root",
        }
    }
    bctx := &recipe.ExecContext{EnvVars: envVars, User: "root"}

    // Per-step progress callback for live UI updates.
    progressFn := func(currentStep int, allEntries []recipe.BuildLogEntry) {
        s.updateLogs(buildCtx, buildID, currentStep, allEntries)
    }

    runPhase := func(phase string, steps []recipe.Step, defaultTimeout time.Duration) bool {
        newEntries, nextStep, ok := recipe.Execute(buildCtx, phase, steps, sandboxIDStr, step, defaultTimeout, bctx, agent.Exec, func(currentStep int, phaseEntries []recipe.BuildLogEntry) {
            // Progress callback: combine prior logs with current phase entries.
            progressFn(currentStep, append(logs, phaseEntries...))
        })
        logs = append(logs, newEntries...)
        step = nextStep
        s.updateLogs(buildCtx, buildID, step, logs)
        if !ok {
            s.destroySandbox(buildCtx, agent, sandboxIDStr)
            // If the build was cancelled, status is already set — don't overwrite with "failed".
            if buildCtx.Err() != nil {
                return false
            }
            reason := "unknown error"
            if len(newEntries) > 0 {
                last := newEntries[len(newEntries)-1]
                reason = last.Stderr
                if reason == "" {
                    reason = fmt.Sprintf("exit code %d", last.Exit)
                }
            }
            s.failBuild(buildCtx, buildID, fmt.Sprintf("%s step %d failed: %s", phase, step, reason))
        }
        return ok
    }

    // Phase 1: Pre-build (as root) — runs apt update, then USER switches to wrenn-user.
    if !build.SkipPrePost {
        if !runPhase("pre-build", preBuildSteps, 0) {
            return
        }
    }

    // Phase 2: User recipe — starts as wrenn-user (set by USER in pre-build)
    // or root if skip_pre_post.
    if !runPhase("recipe", userRecipeSteps, buildCommandTimeout) {
        return
    }

    // Capture the final user and env vars as template defaults.
    // Filter out user-specific and runtime vars that should be resolved at
    // sandbox creation time, not baked in from the build environment.
    templateDefaultUser := bctx.User
    templateDefaultEnv := filterBuildEnv(bctx.EnvVars)

    // Phase 3: Post-build (as root) — cleanup.
    bctx.User = "root"
    if !build.SkipPrePost {
        if !runPhase("post-build", postBuildSteps, 0) {
            return
        }
    }

    // Healthcheck or direct snapshot.
    var sizeBytes int64
    if build.Healthcheck != "" {
        hc, err := recipe.ParseHealthcheck(build.Healthcheck)
        if err != nil {
            s.destroySandbox(buildCtx, agent, sandboxIDStr)
            s.failBuild(buildCtx, buildID, fmt.Sprintf("invalid healthcheck: %v", err))
            return
        }
        log.Info("running healthcheck", "cmd", hc.Cmd, "interval", hc.Interval, "timeout", hc.Timeout, "start_period", hc.StartPeriod, "retries", hc.Retries)
        if err := s.waitForHealthcheck(buildCtx, agent, sandboxIDStr, hc, templateDefaultUser); err != nil {
            s.destroySandbox(buildCtx, agent, sandboxIDStr)
            if buildCtx.Err() != nil {
                return
            }
            s.failBuild(buildCtx, buildID, fmt.Sprintf("healthcheck failed: %v", err))
            return
        }

        // Healthcheck passed → full snapshot (with memory/CPU state).
        log.Info("healthcheck passed, creating snapshot")
        snapResp, err := agent.CreateSnapshot(buildCtx, connect.NewRequest(&pb.CreateSnapshotRequest{
            SandboxId:  sandboxIDStr,
            Name:       build.Name,
            TeamId:     id.UUIDString(build.TeamID),
            TemplateId: id.UUIDString(build.TemplateID),
        }))
        if err != nil {
            s.destroySandbox(buildCtx, agent, sandboxIDStr)
            if buildCtx.Err() != nil {
                return
            }
            s.failBuild(buildCtx, buildID, fmt.Sprintf("create snapshot failed: %v", err))
            return
        }
        sizeBytes = snapResp.Msg.SizeBytes
    } else {
        // No healthcheck → image-only template (rootfs only).
        log.Info("no healthcheck, flattening rootfs")
        flatResp, err := agent.FlattenRootfs(buildCtx, connect.NewRequest(&pb.FlattenRootfsRequest{
            SandboxId:  sandboxIDStr,
            Name:       build.Name,
            TeamId:     id.UUIDString(build.TeamID),
            TemplateId: id.UUIDString(build.TemplateID),
        }))
        if err != nil {
            s.destroySandbox(buildCtx, agent, sandboxIDStr)
            if buildCtx.Err() != nil {
                return
            }
            s.failBuild(buildCtx, buildID, fmt.Sprintf("flatten rootfs failed: %v", err))
            return
        }
        sizeBytes = flatResp.Msg.SizeBytes
    }

    // Insert into the templates table as a global (platform) template.
    templateType := "base"
    if build.Healthcheck != "" {
        templateType = "snapshot"
    }

    // Serialize env vars for DB storage.
    defaultEnvJSON, err := json.Marshal(templateDefaultEnv)
    if err != nil {
        defaultEnvJSON = []byte("{}")
    }

    // Serialize sandbox metadata for DB storage.
    metadataJSON, err := json.Marshal(sandboxMetadata)
    if err != nil || len(sandboxMetadata) == 0 {
        metadataJSON = []byte("{}")
    }

    if _, err := s.DB.InsertTemplate(buildCtx, db.InsertTemplateParams{
        ID:          build.TemplateID,
        Name:        build.Name,
        Type:        templateType,
        Vcpus:       build.Vcpus,
        MemoryMb:    build.MemoryMb,
        SizeBytes:   sizeBytes,
        TeamID:      id.PlatformTeamID,
        DefaultUser: templateDefaultUser,
        DefaultEnv:  defaultEnvJSON,
        Metadata:    metadataJSON,
    }); err != nil {
        log.Error("failed to insert template record", "error", err)
        // Build succeeded on disk, just the DB record failed — don't mark as failed.
    }

    // Record defaults and metadata on the build record for inspection.
    _ = s.DB.UpdateBuildDefaults(buildCtx, db.UpdateBuildDefaultsParams{
        ID:          buildID,
        DefaultUser: templateDefaultUser,
        DefaultEnv:  defaultEnvJSON,
        Metadata:    metadataJSON,
    })

    // Both CreateSnapshot and FlattenRootfs destroy the sandbox as part of
    // their own process, so no additional destroy is needed here.

    // Mark build as success.
    if _, err := s.DB.UpdateBuildStatus(buildCtx, db.UpdateBuildStatusParams{
        ID: buildID, Status: "success",
    }); err != nil {
        log.Error("failed to mark build as success", "error", err)
    }

    log.Info("template build completed successfully", "name", build.Name)
}

// waitForHealthcheck repeatedly executes the healthcheck command inside the
// sandbox according to the config's interval, timeout, start-period, and retries.
// During the start period, failures are not counted toward the retry budget.
// Returns nil on the first successful check, or an error if retries are
// exhausted, the deadline passes, or the context is cancelled.
func (s *BuildService) waitForHealthcheck(ctx context.Context, agent buildAgentClient, sandboxIDStr string, hc recipe.HealthcheckConfig, user string) error {
    // Wrap the healthcheck command with su when a non-root user is set, so that
    // ~ expands to the correct home directory and the process runs with the
    // right UID (matching the template's default user).
    cmd := hc.Cmd
    if user != "" && user != "root" {
        cmd = "su " + recipe.Shellescape(user) + " -s /bin/sh -c " + recipe.Shellescape(hc.Cmd)
    }
    ticker := time.NewTicker(hc.Interval)
    defer ticker.Stop()

    // When retries > 0, set a deadline based on the retry budget.
    // When retries == 0 (unlimited), rely solely on the parent context deadline.
    var deadlineCh <-chan time.Time
    if hc.Retries > 0 {
        deadline := time.NewTimer(hc.StartPeriod + time.Duration(hc.Retries+1)*hc.Interval)
        defer deadline.Stop()
        deadlineCh = deadline.C
    }

    startedAt := time.Now()
    failCount := 0

    for {
        select {
        case <-ctx.Done():
            return ctx.Err()
        case <-deadlineCh:
            return fmt.Errorf("healthcheck timed out: exceeded %d attempts over %s", failCount, time.Since(startedAt))
        case <-ticker.C:
            execCtx, cancel := context.WithTimeout(ctx, hc.Timeout)
            resp, err := agent.Exec(execCtx, connect.NewRequest(&pb.ExecRequest{
                SandboxId:  sandboxIDStr,
                Cmd:        "/bin/sh",
                Args:       []string{"-c", cmd},
                TimeoutSec: int32(hc.Timeout.Seconds()),
            }))
            cancel()

            if err != nil {
                slog.Debug("healthcheck exec error (retrying)", "error", err)
                if time.Since(startedAt) >= hc.StartPeriod {
                    failCount++
                    if hc.Retries > 0 && failCount >= hc.Retries {
                        return fmt.Errorf("healthcheck failed after %d retries: exec error: %w", failCount, err)
                    }
                }
                continue
            }
            if resp.Msg.ExitCode == 0 {
                return nil
            }
            slog.Debug("healthcheck failed (retrying)", "exit_code", resp.Msg.ExitCode)
            if time.Since(startedAt) >= hc.StartPeriod {
                failCount++
                if hc.Retries > 0 && failCount >= hc.Retries {
                    return fmt.Errorf("healthcheck failed after %d retries: exit code %d", failCount, resp.Msg.ExitCode)
                }
            }
        }
    }
}
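Concretely, with interval = 5s, start period = 10s, and retries = 3, the deadline timer fires at 10s + (3+1)×5s = 30s; the first probe only runs at t = 5s (the ticker has no immediate tick), and failures before t = 10s never increment failCount.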
func (s *BuildService) updateLogs(ctx context.Context, buildID pgtype.UUID, step int, logs []recipe.BuildLogEntry) {
    logsJSON, err := json.Marshal(logs)
    if err != nil {
        slog.Warn("failed to marshal build logs", "error", err)
        return
    }
    if err := s.DB.UpdateBuildProgress(ctx, db.UpdateBuildProgressParams{
        ID:          buildID,
        CurrentStep: int32(step),
        Logs:        logsJSON,
    }); err != nil {
        slog.Warn("failed to update build progress", "error", err)
    }
}

func (s *BuildService) failBuild(_ context.Context, buildID pgtype.UUID, errMsg string) {
    slog.Error("build failed", "build_id", id.FormatBuildID(buildID), "error", errMsg)
    // Use a detached context so DB writes survive parent context cancellation (e.g. shutdown).
    ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
    defer cancel()
    if err := s.DB.UpdateBuildError(ctx, db.UpdateBuildErrorParams{
        ID:    buildID,
        Error: errMsg,
    }); err != nil {
        slog.Error("failed to update build error", "build_id", id.FormatBuildID(buildID), "error", err)
    }
}

func (s *BuildService) destroySandbox(_ context.Context, agent buildAgentClient, sandboxIDStr string) {
    // Use a detached context so cleanup succeeds even during shutdown.
    ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
    defer cancel()
    if _, err := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
        SandboxId: sandboxIDStr,
    })); err != nil {
        slog.Warn("failed to destroy build sandbox", "sandbox_id", sandboxIDStr, "error", err)
    }
}

// fetchSandboxEnv executes the 'env' command inside the specified sandbox via
// the build agent and returns the environment variables as a map.
func (s *BuildService) fetchSandboxEnv(ctx context.Context, agent buildAgentClient, sandboxIDStr string) (map[string]string, error) {
    resp, err := agent.Exec(ctx, connect.NewRequest(&pb.ExecRequest{
        SandboxId:  sandboxIDStr,
        Cmd:        "/bin/sh",
        Args:       []string{"-c", "env"},
        TimeoutSec: 10,
    }))
    if err != nil {
        return nil, fmt.Errorf("fetch env: %w", err)
    }

    if resp.Msg.ExitCode != 0 {
        return nil, fmt.Errorf("fetch env: command exited with code %d", resp.Msg.ExitCode)
    }

    return parseSandboxEnv(string(resp.Msg.Stdout)), nil
}

// parseSandboxEnv converts the raw newline-separated output of an 'env'
// command into a map. It skips empty lines and malformed entries, and
// correctly handles values containing '='.
func parseSandboxEnv(raw string) map[string]string {
    envVars := make(map[string]string)

    for line := range strings.SplitSeq(raw, "\n") {
        line = strings.TrimSpace(line)
        if line == "" {
            continue
        }

        parts := strings.SplitN(line, "=", 2)
        if len(parts) != 2 {
            continue
        }

        envVars[parts[0]] = parts[1]
    }

    return envVars
}
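For instance, because SplitN limits the split to two parts, an '=' inside a value survives (illustrative input/output only):

    // parseSandboxEnv("PATH=/usr/bin\nLS_COLORS=di=34:ln=36\n\nnoequals")
    // returns:
    //   map[string]string{
    //       "PATH":      "/usr/bin",
    //       "LS_COLORS": "di=34:ln=36", // '=' in the value is kept intact
    //   }
    // The blank line and "noequals" (no '=') are silently skipped.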
// uploadAndExtractArchive writes the archive to the sandbox and extracts it
// to /tmp/build-files/. Detects format from content (tar.gz, tar, zip).
func (s *BuildService) uploadAndExtractArchive(
    ctx context.Context,
    agent buildAgentClient,
    sandboxID string,
    archive []byte,
    buildID string,
) error {
    // Detect archive type from magic bytes.
    var archivePath, extractCmd string
    switch {
    case len(archive) >= 2 && archive[0] == 0x1f && archive[1] == 0x8b:
        // gzip (tar.gz)
        archivePath = "/tmp/build-files.tar.gz"
        extractCmd = "mkdir -p /tmp/build-files && tar xzf /tmp/build-files.tar.gz -C /tmp/build-files"
    case len(archive) >= 4 && string(archive[:4]) == "PK\x03\x04":
        // zip
        archivePath = "/tmp/build-files.zip"
        extractCmd = "mkdir -p /tmp/build-files && unzip -o /tmp/build-files.zip -d /tmp/build-files"
    case len(archive) >= 262 && string(archive[257:262]) == "ustar":
        // tar (ustar magic at offset 257)
        archivePath = "/tmp/build-files.tar"
        extractCmd = "mkdir -p /tmp/build-files && tar xf /tmp/build-files.tar -C /tmp/build-files"
    default:
        // Fallback: try tar.gz
        archivePath = "/tmp/build-files.tar.gz"
        extractCmd = "mkdir -p /tmp/build-files && tar xzf /tmp/build-files.tar.gz -C /tmp/build-files"
    }

    slog.Info("uploading build archive", "build_id", buildID, "path", archivePath, "size", len(archive))

    // Write archive to VM.
    if _, err := agent.WriteFile(ctx, connect.NewRequest(&pb.WriteFileRequest{
        SandboxId: sandboxID,
        Path:      archivePath,
        Content:   archive,
    })); err != nil {
        return fmt.Errorf("write archive: %w", err)
    }

    // Extract and ensure files are readable.
    fullCmd := extractCmd + " && chmod -R a+rX /tmp/build-files"

    resp, err := agent.Exec(ctx, connect.NewRequest(&pb.ExecRequest{
        SandboxId:  sandboxID,
        Cmd:        "/bin/sh",
        Args:       []string{"-c", fullCmd},
        TimeoutSec: 120,
    }))
    if err != nil {
        return fmt.Errorf("extract archive: %w", err)
    }
    if resp.Msg.ExitCode != 0 {
        return fmt.Errorf("extract archive: exit code %d: %s", resp.Msg.ExitCode, string(resp.Msg.Stderr))
    }

    return nil
}

// runtimeEnvVars lists env vars that are user- or session-specific and should
// not be persisted into template defaults. These are resolved at runtime by
// envd based on the actual user and sandbox context.
var runtimeEnvVars = map[string]bool{
    "HOME": true, "USER": true, "LOGNAME": true, "SHELL": true,
    "PWD": true, "OLDPWD": true, "HOSTNAME": true, "TERM": true,
    "SHLVL": true, "_": true,
    // Per-sandbox identifiers set by envd at boot via MMDS.
    "WRENN_SANDBOX_ID": true, "WRENN_TEMPLATE_ID": true,
}

// filterBuildEnv returns a copy of envVars with runtime/user-specific
// variables removed so they don't override envd's per-user resolution.
func filterBuildEnv(envVars map[string]string) map[string]string {
    filtered := make(map[string]string, len(envVars))
    for k, v := range envVars {
        if runtimeEnvVars[k] {
            continue
        }
        filtered[k] = v
    }
    return filtered
}
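The gzip branch of the magic-byte switch can be sanity-checked with the standard library alone (a throwaway sketch, not part of the commit):

    package main

    import (
        "bytes"
        "compress/gzip"
        "fmt"
    )

    func main() {
        // Any gzip stream begins with the two magic bytes 0x1f 0x8b.
        var buf bytes.Buffer
        zw := gzip.NewWriter(&buf)
        zw.Write([]byte("hello"))
        zw.Close()
        b := buf.Bytes()
        fmt.Println(len(b) >= 2 && b[0] == 0x1f && b[1] == 0x8b) // true → detected as tar.gz
    }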
pkg/service/host.go (new file, 628 lines)
@@ -0,0 +1,628 @@
package service

import (
    "context"
    "crypto/sha256"
    "encoding/json"
    "errors"
    "fmt"
    "log/slog"
    "time"

    "connectrpc.com/connect"
    "github.com/jackc/pgx/v5"
    "github.com/jackc/pgx/v5/pgtype"
    "github.com/redis/go-redis/v9"

    "git.omukk.dev/wrenn/wrenn/pkg/auth"
    "git.omukk.dev/wrenn/wrenn/pkg/db"
    "git.omukk.dev/wrenn/wrenn/pkg/id"
    "git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
    pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
)

// HostService provides host management operations.
type HostService struct {
    DB    *db.Queries
    Redis *redis.Client
    JWT   []byte
    Pool  *lifecycle.HostClientPool
    CA    *auth.CA // nil disables mTLS cert issuance (dev/test environments)
}

// HostCreateParams holds the parameters for creating a host.
type HostCreateParams struct {
    Type             string
    TeamID           pgtype.UUID // required for BYOC, zero value for regular
    Provider         string
    AvailabilityZone string
    RequestingUserID pgtype.UUID
    IsRequestorAdmin bool
}

// HostCreateResult holds the created host and the one-time registration token.
type HostCreateResult struct {
    Host              db.Host
    RegistrationToken string
}

// HostRegisterParams holds the parameters for host agent registration.
type HostRegisterParams struct {
    Token    string
    Arch     string
    CPUCores int32
    MemoryMB int32
    DiskGB   int32
    Address  string
}

// HostRegisterResult holds the registered host, its short-lived JWT, a long-lived
// refresh token, and optionally the host's mTLS certificate material.
type HostRegisterResult struct {
    Host         db.Host
    JWT          string
    RefreshToken string
    // mTLS cert material — empty when CA is not configured.
    CertPEM   string
    KeyPEM    string
    CACertPEM string
}

// HostRefreshResult holds a new JWT and rotated refresh token after a successful
// refresh, plus refreshed mTLS certificate material when CA is configured.
type HostRefreshResult struct {
    Host         db.Host
    JWT          string
    RefreshToken string
    // mTLS cert material — empty when CA is not configured.
    CertPEM   string
    KeyPEM    string
    CACertPEM string
}

// HostDeletePreview describes what will be affected by deleting a host.
type HostDeletePreview struct {
    Host       db.Host
    SandboxIDs []string
}

// regTokenPayload is the JSON stored in Redis for registration tokens.
type regTokenPayload struct {
    HostID  string `json:"host_id"`
    TokenID string `json:"token_id"`
}

const regTokenTTL = time.Hour

// requireAdminOrOwner returns nil iff the role is "owner" or "admin".
func requireAdminOrOwner(role string) error {
    if role == "owner" || role == "admin" {
        return nil
    }
    return fmt.Errorf("forbidden: only team owners and admins can manage BYOC hosts")
}

// Create creates a new host record and generates a one-time registration token.
func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreateResult, error) {
    if p.Type != "regular" && p.Type != "byoc" {
        return HostCreateResult{}, fmt.Errorf("invalid host type: must be 'regular' or 'byoc'")
    }

    if p.Type == "regular" {
        if !p.IsRequestorAdmin {
            return HostCreateResult{}, fmt.Errorf("forbidden: only admins can create regular hosts")
        }
    } else {
        // BYOC: platform admin, or team owner/admin.
        if !p.TeamID.Valid {
            return HostCreateResult{}, fmt.Errorf("invalid request: team_id is required for BYOC hosts")
        }
        if !p.IsRequestorAdmin {
            membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
                UserID: p.RequestingUserID,
                TeamID: p.TeamID,
            })
            if errors.Is(err, pgx.ErrNoRows) {
                return HostCreateResult{}, fmt.Errorf("forbidden: not a member of the specified team")
            }
            if err != nil {
                return HostCreateResult{}, fmt.Errorf("check team membership: %w", err)
            }
            if err := requireAdminOrOwner(membership.Role); err != nil {
                return HostCreateResult{}, err
            }
        }
    }

    // Validate team exists, is not deleted, and has BYOC enabled.
    if p.TeamID.Valid {
        team, err := s.DB.GetTeam(ctx, p.TeamID)
        if err != nil || team.DeletedAt.Valid {
            return HostCreateResult{}, fmt.Errorf("invalid request: team not found")
        }
        if !team.IsByoc {
            return HostCreateResult{}, fmt.Errorf("forbidden: BYOC is not enabled for this team")
        }
    }

    hostID := id.NewHostID()

    host, err := s.DB.InsertHost(ctx, db.InsertHostParams{
        ID:               hostID,
        Type:             p.Type,
        TeamID:           p.TeamID,
        Provider:         p.Provider,
        AvailabilityZone: p.AvailabilityZone,
        CreatedBy:        p.RequestingUserID,
    })
    if err != nil {
        return HostCreateResult{}, fmt.Errorf("insert host: %w", err)
    }

    // Generate a registration token and store it in Redis plus a Postgres audit trail.
    token := id.NewRegistrationToken()
    tokenID := id.NewHostTokenID()

    payload, _ := json.Marshal(regTokenPayload{
        HostID:  id.FormatHostID(hostID),
        TokenID: id.FormatHostTokenID(tokenID),
    })
    if err := s.Redis.Set(ctx, "host:reg:"+token, payload, regTokenTTL).Err(); err != nil {
        return HostCreateResult{}, fmt.Errorf("store registration token: %w", err)
    }

    now := time.Now()
    if _, err := s.DB.InsertHostToken(ctx, db.InsertHostTokenParams{
        ID:        tokenID,
        HostID:    hostID,
        CreatedBy: p.RequestingUserID,
        ExpiresAt: pgtype.Timestamptz{Time: now.Add(regTokenTTL), Valid: true},
    }); err != nil {
        slog.Warn("failed to insert host token audit record", "host_id", id.FormatHostID(hostID), "error", err)
    }

    return HostCreateResult{Host: host, RegistrationToken: token}, nil
}

// RegenerateToken issues a new registration token for a host still in "pending"
// status. This allows a retry when a previous registration attempt failed after
// the original token was consumed.
func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamID pgtype.UUID, isAdmin bool) (HostCreateResult, error) {
    host, err := s.DB.GetHost(ctx, hostID)
    if err != nil {
        return HostCreateResult{}, fmt.Errorf("host not found: %w", err)
    }
    if host.Status != "pending" {
        return HostCreateResult{}, fmt.Errorf("invalid state: can only regenerate token for pending hosts (status: %s)", host.Status)
    }

    if !isAdmin {
        if host.Type != "byoc" {
            return HostCreateResult{}, fmt.Errorf("forbidden: only admins can manage regular hosts")
        }
        if !host.TeamID.Valid || host.TeamID != teamID {
            return HostCreateResult{}, fmt.Errorf("forbidden: host does not belong to your team")
        }
        membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
            UserID: userID,
            TeamID: teamID,
        })
        if errors.Is(err, pgx.ErrNoRows) {
            return HostCreateResult{}, fmt.Errorf("forbidden: not a member of the specified team")
        }
        if err != nil {
            return HostCreateResult{}, fmt.Errorf("check team membership: %w", err)
        }
        if err := requireAdminOrOwner(membership.Role); err != nil {
            return HostCreateResult{}, err
        }
    }

    token := id.NewRegistrationToken()
    tokenID := id.NewHostTokenID()

    payload, _ := json.Marshal(regTokenPayload{
        HostID:  id.FormatHostID(hostID),
        TokenID: id.FormatHostTokenID(tokenID),
    })
    if err := s.Redis.Set(ctx, "host:reg:"+token, payload, regTokenTTL).Err(); err != nil {
        return HostCreateResult{}, fmt.Errorf("store registration token: %w", err)
    }

    now := time.Now()
    if _, err := s.DB.InsertHostToken(ctx, db.InsertHostTokenParams{
        ID:        tokenID,
        HostID:    hostID,
        CreatedBy: userID,
        ExpiresAt: pgtype.Timestamptz{Time: now.Add(regTokenTTL), Valid: true},
    }); err != nil {
        slog.Warn("failed to insert host token audit record", "host_id", id.FormatHostID(hostID), "error", err)
    }

    return HostCreateResult{Host: host, RegistrationToken: token}, nil
}

// Register validates a one-time registration token, updates the host with
// machine specs, and returns a short-lived host JWT plus a long-lived refresh token.
func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostRegisterResult, error) {
    // Atomic consume: GetDel returns the value and deletes it in one operation,
    // preventing concurrent requests from consuming the same token.
    raw, err := s.Redis.GetDel(ctx, "host:reg:"+p.Token).Bytes()
    if err == redis.Nil {
        return HostRegisterResult{}, fmt.Errorf("invalid or expired registration token")
    }
    if err != nil {
        return HostRegisterResult{}, fmt.Errorf("token lookup: %w", err)
    }

    var payload regTokenPayload
    if err := json.Unmarshal(raw, &payload); err != nil {
        return HostRegisterResult{}, fmt.Errorf("corrupted registration token")
    }

    hostID, err := id.ParseHostID(payload.HostID)
    if err != nil {
        return HostRegisterResult{}, fmt.Errorf("corrupted registration token: %w", err)
    }
    tokenID, err := id.ParseHostTokenID(payload.TokenID)
    if err != nil {
        return HostRegisterResult{}, fmt.Errorf("corrupted registration token: %w", err)
    }

    if _, err := s.DB.GetHost(ctx, hostID); err != nil {
        return HostRegisterResult{}, fmt.Errorf("host not found: %w", err)
    }

    // Sign the JWT before mutating the DB — if signing fails, the host stays pending.
    hostJWT, err := auth.SignHostJWT(s.JWT, hostID)
    if err != nil {
        return HostRegisterResult{}, fmt.Errorf("sign host token: %w", err)
    }

    // Issue an mTLS certificate if a CA is configured.
    var hc auth.HostCert
    if s.CA != nil {
        hc, err = auth.IssueHostCert(s.CA, id.FormatHostID(hostID), p.Address)
        if err != nil {
            return HostRegisterResult{}, fmt.Errorf("issue host cert: %w", err)
        }
    }

    // Atomically update only if still pending (defense-in-depth against races).
    rowsAffected, err := s.DB.RegisterHost(ctx, db.RegisterHostParams{
        ID:              hostID,
        Arch:            p.Arch,
        CpuCores:        p.CPUCores,
        MemoryMb:        p.MemoryMB,
        DiskGb:          p.DiskGB,
        Address:         p.Address,
        CertFingerprint: hc.Fingerprint,
        CertExpiresAt:   pgtype.Timestamptz{Time: hc.ExpiresAt, Valid: s.CA != nil},
    })
    if err != nil {
        return HostRegisterResult{}, fmt.Errorf("register host: %w", err)
    }
    if rowsAffected == 0 {
        return HostRegisterResult{}, fmt.Errorf("host already registered or not found")
    }

    // Mark the audit trail.
    if err := s.DB.MarkHostTokenUsed(ctx, tokenID); err != nil {
        slog.Warn("failed to mark host token used", "token_id", payload.TokenID, "error", err)
    }

    // Issue a long-lived refresh token.
    refreshToken, err := s.issueRefreshToken(ctx, hostID)
    if err != nil {
        return HostRegisterResult{}, fmt.Errorf("issue refresh token: %w", err)
    }

    // Re-fetch the host to get the updated state.
    host, err := s.DB.GetHost(ctx, hostID)
    if err != nil {
        return HostRegisterResult{}, fmt.Errorf("fetch updated host: %w", err)
    }

    result := HostRegisterResult{Host: host, JWT: hostJWT, RefreshToken: refreshToken}
    if s.CA != nil {
        result.CertPEM = hc.CertPEM
        result.KeyPEM = hc.KeyPEM
        result.CACertPEM = s.CA.PEM
    }
    return result, nil
}

// Refresh validates a refresh token, rotates it (revokes old, issues new),
// and returns a fresh JWT plus the new refresh token.
func (s *HostService) Refresh(ctx context.Context, refreshToken string) (HostRefreshResult, error) {
    hash := hashToken(refreshToken)

    row, err := s.DB.GetHostRefreshTokenByHash(ctx, hash)
    if errors.Is(err, pgx.ErrNoRows) {
        return HostRefreshResult{}, fmt.Errorf("invalid or expired refresh token")
    }
    if err != nil {
        return HostRefreshResult{}, fmt.Errorf("lookup refresh token: %w", err)
    }

    host, err := s.DB.GetHost(ctx, row.HostID)
    if err != nil {
        return HostRefreshResult{}, fmt.Errorf("host not found: %w", err)
    }

    // Sign a new JWT.
    hostJWT, err := auth.SignHostJWT(s.JWT, host.ID)
    if err != nil {
        return HostRefreshResult{}, fmt.Errorf("sign host JWT: %w", err)
    }

    // Renew the mTLS certificate if a CA is configured.
    var hc auth.HostCert
    if s.CA != nil {
        hc, err = auth.IssueHostCert(s.CA, id.FormatHostID(host.ID), host.Address)
        if err != nil {
            return HostRefreshResult{}, fmt.Errorf("renew host cert: %w", err)
        }
        if err := s.DB.UpdateHostCert(ctx, db.UpdateHostCertParams{
            ID:              host.ID,
            CertFingerprint: hc.Fingerprint,
            CertExpiresAt:   pgtype.Timestamptz{Time: hc.ExpiresAt, Valid: true},
        }); err != nil {
            return HostRefreshResult{}, fmt.Errorf("update host cert: %w", err)
        }
    }

    // Issue-then-revoke rotation: insert the new token first so a crash between
    // the two DB calls leaves the host with two valid tokens rather than zero.
    newRefreshToken, err := s.issueRefreshToken(ctx, host.ID)
    if err != nil {
        return HostRefreshResult{}, fmt.Errorf("issue new refresh token: %w", err)
    }

    // Revoke the old refresh token after the new one is safely persisted.
    if err := s.DB.RevokeHostRefreshToken(ctx, row.ID); err != nil {
        return HostRefreshResult{}, fmt.Errorf("revoke old refresh token: %w", err)
    }

    result := HostRefreshResult{Host: host, JWT: hostJWT, RefreshToken: newRefreshToken}
    if s.CA != nil {
        result.CertPEM = hc.CertPEM
        result.KeyPEM = hc.KeyPEM
        result.CACertPEM = s.CA.PEM
    }
    return result, nil
}

// issueRefreshToken creates a new refresh token record in the DB and returns
// the opaque token string.
func (s *HostService) issueRefreshToken(ctx context.Context, hostID pgtype.UUID) (string, error) {
    token := id.NewRefreshToken()
    hash := hashToken(token)
    now := time.Now()

    if _, err := s.DB.InsertHostRefreshToken(ctx, db.InsertHostRefreshTokenParams{
        ID:        id.NewRefreshTokenID(),
        HostID:    hostID,
        TokenHash: hash,
        ExpiresAt: pgtype.Timestamptz{Time: now.Add(auth.HostRefreshTokenExpiry), Valid: true},
    }); err != nil {
        return "", fmt.Errorf("insert refresh token: %w", err)
    }

    return token, nil
}

// hashToken returns the hex-encoded SHA-256 hash of the token.
func hashToken(token string) string {
    h := sha256.Sum256([]byte(token))
    return fmt.Sprintf("%x", h)
}

// Heartbeat updates the last heartbeat timestamp for a host and transitions
// any 'unreachable' host back to 'online'. Returns a "host not found" error
// (which becomes a 404) if the host record no longer exists (e.g., it was deleted).
func (s *HostService) Heartbeat(ctx context.Context, hostID pgtype.UUID) error {
    n, err := s.DB.UpdateHostHeartbeatAndStatus(ctx, hostID)
    if err != nil {
        return err
    }
    if n == 0 {
        return fmt.Errorf("host not found")
    }
    return nil
}

// List returns hosts visible to the caller.
// Admins see all hosts; non-admins see only BYOC hosts belonging to their team.
func (s *HostService) List(ctx context.Context, teamID pgtype.UUID, isAdmin bool) ([]db.Host, error) {
    if isAdmin {
        return s.DB.ListHosts(ctx)
    }
    return s.DB.ListHostsByTeam(ctx, teamID)
}

// Get returns a single host, enforcing access control.
func (s *HostService) Get(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool) (db.Host, error) {
    host, err := s.DB.GetHost(ctx, hostID)
    if err != nil {
        return db.Host{}, fmt.Errorf("host not found: %w", err)
    }
    if !isAdmin {
        if !host.TeamID.Valid || host.TeamID != teamID {
            return db.Host{}, fmt.Errorf("host not found")
        }
    }
    return host, nil
}

// DeletePreview returns what would be affected by deleting the host, without
// making any changes. Use this to show the user a confirmation prompt.
func (s *HostService) DeletePreview(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool) (HostDeletePreview, error) {
    host, err := s.checkDeletePermission(ctx, hostID, pgtype.UUID{}, teamID, isAdmin)
    if err != nil {
        return HostDeletePreview{}, err
    }

    sandboxes, err := s.DB.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
        HostID:  hostID,
        Column2: []string{"pending", "starting", "running", "missing"},
    })
    if err != nil {
        return HostDeletePreview{}, fmt.Errorf("list sandboxes: %w", err)
    }

    ids := make([]string, len(sandboxes))
    for i, sb := range sandboxes {
        ids[i] = id.FormatSandboxID(sb.ID)
    }

    return HostDeletePreview{Host: host, SandboxIDs: ids}, nil
}

// Delete removes a host. Without force it returns an error listing active
// sandboxes so the caller can present a confirmation. With force it gracefully
// destroys all running sandboxes before deleting the host record.
func (s *HostService) Delete(ctx context.Context, hostID, userID, teamID pgtype.UUID, isAdmin, force bool) error {
    host, err := s.checkDeletePermission(ctx, hostID, userID, teamID, isAdmin)
    if err != nil {
        return err
    }

    sandboxes, err := s.DB.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
        HostID:  hostID,
        Column2: []string{"pending", "starting", "running", "missing"},
    })
    if err != nil {
        return fmt.Errorf("list sandboxes: %w", err)
    }

    if len(sandboxes) > 0 && !force {
        ids := make([]string, len(sandboxes))
        for i, sb := range sandboxes {
            ids[i] = id.FormatSandboxID(sb.ID)
        }
        return &HostHasSandboxesError{SandboxIDs: ids}
    }

    hostIDStr := id.FormatHostID(hostID)

    // Gracefully destroy running sandboxes and terminate the agent (best-effort).
    if host.Address != "" {
        agent, err := s.Pool.GetForHost(host)
        if err == nil {
            for _, sb := range sandboxes {
                if sb.Status == "running" || sb.Status == "starting" {
                    _, rpcErr := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
                        SandboxId: id.FormatSandboxID(sb.ID),
                    }))
                    if rpcErr != nil && connect.CodeOf(rpcErr) != connect.CodeNotFound {
                        slog.Warn("delete host: failed to destroy sandbox on agent", "sandbox_id", id.FormatSandboxID(sb.ID), "error", rpcErr)
                    }
                }
            }
            // Tell the agent to shut itself down immediately.
            if _, rpcErr := agent.Terminate(ctx, connect.NewRequest(&pb.TerminateRequest{})); rpcErr != nil {
                slog.Warn("delete host: failed to send Terminate to agent", "host_id", hostIDStr, "error", rpcErr)
            }
        }
    }

    // Mark all affected sandboxes as stopped in the DB.
    if len(sandboxes) > 0 {
        sbIDs := make([]pgtype.UUID, len(sandboxes))
        for i, sb := range sandboxes {
            sbIDs[i] = sb.ID
        }
        if err := s.DB.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
            Column1: sbIDs,
            Status:  "stopped",
        }); err != nil {
            slog.Warn("delete host: failed to mark sandboxes stopped", "host_id", hostIDStr, "error", err)
        }
    }

    // Revoke all refresh tokens for this host.
    if err := s.DB.RevokeHostRefreshTokensByHost(ctx, hostID); err != nil {
        slog.Warn("delete host: failed to revoke refresh tokens", "host_id", hostIDStr, "error", err)
    }

    // Evict the client from the pool so no further RPCs are sent.
    if s.Pool != nil {
        s.Pool.Evict(hostIDStr)
    }

    return s.DB.DeleteHost(ctx, hostID)
}

// checkDeletePermission verifies the caller has permission to delete the given
// host and returns the host record on success.
func (s *HostService) checkDeletePermission(ctx context.Context, hostID, userID, teamID pgtype.UUID, isAdmin bool) (db.Host, error) {
    host, err := s.DB.GetHost(ctx, hostID)
    if err != nil {
        return db.Host{}, fmt.Errorf("host not found: %w", err)
    }

    if isAdmin {
        return host, nil
    }

    if host.Type != "byoc" {
        return db.Host{}, fmt.Errorf("forbidden: only admins can delete regular hosts")
    }
    if !host.TeamID.Valid || host.TeamID != teamID {
        return db.Host{}, fmt.Errorf("forbidden: host does not belong to your team")
    }

    if userID.Valid {
        membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
            UserID: userID,
            TeamID: teamID,
        })
        if errors.Is(err, pgx.ErrNoRows) {
            return db.Host{}, fmt.Errorf("forbidden: not a member of the specified team")
        }
        if err != nil {
            return db.Host{}, fmt.Errorf("check team membership: %w", err)
        }
        if err := requireAdminOrOwner(membership.Role); err != nil {
            return db.Host{}, err
        }
    }

    return host, nil
}

// AddTag adds a tag to a host.
func (s *HostService) AddTag(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool, tag string) error {
    if _, err := s.Get(ctx, hostID, teamID, isAdmin); err != nil {
        return err
    }
    return s.DB.AddHostTag(ctx, db.AddHostTagParams{HostID: hostID, Tag: tag})
}

// RemoveTag removes a tag from a host.
func (s *HostService) RemoveTag(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool, tag string) error {
    if _, err := s.Get(ctx, hostID, teamID, isAdmin); err != nil {
        return err
    }
    return s.DB.RemoveHostTag(ctx, db.RemoveHostTagParams{HostID: hostID, Tag: tag})
}

// ListTags returns all tags for a host.
func (s *HostService) ListTags(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool) ([]string, error) {
    if _, err := s.Get(ctx, hostID, teamID, isAdmin); err != nil {
        return nil, err
    }
    return s.DB.GetHostTags(ctx, hostID)
}

// HostHasSandboxesError is returned by Delete when the host has active sandboxes
// and force was not set. The caller should present the list to the user and
// re-call Delete with force=true if they confirm.
type HostHasSandboxesError struct {
    SandboxIDs []string
}

func (e *HostHasSandboxesError) Error() string {
    return fmt.Sprintf("host has %d active sandbox(es): %v", len(e.SandboxIDs), e.SandboxIDs)
}
|
||||
}
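
// Illustrative caller-side sketch (not part of the original file): the
// confirm-then-force flow the error type above is designed for. Only
// HostService.Delete and HostHasSandboxesError come from this package;
// userConfirmed is a hypothetical prompt.
//
//	err := svc.Delete(ctx, hostID, userID, teamID, isAdmin, false)
//	var hasSb *HostHasSandboxesError
//	if errors.As(err, &hasSb) {
//		// Show hasSb.SandboxIDs to the user; on confirmation, retry with force.
//		if userConfirmed(hasSb.SandboxIDs) {
//			err = svc.Delete(ctx, hostID, userID, teamID, isAdmin, true)
//		}
//	}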
451
pkg/service/sandbox.go
Normal file
@ -0,0 +1,451 @@
package service

import (
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"time"

	"connectrpc.com/connect"
	"github.com/jackc/pgx/v5/pgtype"

	"git.omukk.dev/wrenn/wrenn/pkg/db"
	"git.omukk.dev/wrenn/wrenn/pkg/id"
	"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
	"git.omukk.dev/wrenn/wrenn/pkg/scheduler"
	"git.omukk.dev/wrenn/wrenn/pkg/validate"
	pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
)

// SandboxService provides sandbox lifecycle operations shared between the
// REST API and the dashboard.
type SandboxService struct {
	DB        *db.Queries
	Pool      *lifecycle.HostClientPool
	Scheduler scheduler.HostScheduler
}

// SandboxCreateParams holds the parameters for creating a sandbox.
type SandboxCreateParams struct {
	TeamID     pgtype.UUID
	Template   string
	VCPUs      int32
	MemoryMB   int32
	TimeoutSec int32
	DiskSizeMB int32
}

// agentForSandbox looks up the host for the given sandbox and returns a client.
func (s *SandboxService) agentForSandbox(ctx context.Context, sandboxID pgtype.UUID) (hostagentClient, db.Sandbox, error) {
	sb, err := s.DB.GetSandbox(ctx, sandboxID)
	if err != nil {
		return nil, db.Sandbox{}, fmt.Errorf("sandbox not found: %w", err)
	}
	host, err := s.DB.GetHost(ctx, sb.HostID)
	if err != nil {
		return nil, db.Sandbox{}, fmt.Errorf("host not found for sandbox: %w", err)
	}
	agent, err := s.Pool.GetForHost(host)
	if err != nil {
		return nil, db.Sandbox{}, fmt.Errorf("get agent client: %w", err)
	}
	return agent, sb, nil
}

// hostagentClient is a local alias to avoid the full package path in signatures.
type hostagentClient = interface {
	CreateSandbox(ctx context.Context, req *connect.Request[pb.CreateSandboxRequest]) (*connect.Response[pb.CreateSandboxResponse], error)
	DestroySandbox(ctx context.Context, req *connect.Request[pb.DestroySandboxRequest]) (*connect.Response[pb.DestroySandboxResponse], error)
	PauseSandbox(ctx context.Context, req *connect.Request[pb.PauseSandboxRequest]) (*connect.Response[pb.PauseSandboxResponse], error)
	ResumeSandbox(ctx context.Context, req *connect.Request[pb.ResumeSandboxRequest]) (*connect.Response[pb.ResumeSandboxResponse], error)
	PingSandbox(ctx context.Context, req *connect.Request[pb.PingSandboxRequest]) (*connect.Response[pb.PingSandboxResponse], error)
	GetSandboxMetrics(ctx context.Context, req *connect.Request[pb.GetSandboxMetricsRequest]) (*connect.Response[pb.GetSandboxMetricsResponse], error)
	FlushSandboxMetrics(ctx context.Context, req *connect.Request[pb.FlushSandboxMetricsRequest]) (*connect.Response[pb.FlushSandboxMetricsResponse], error)
}

// Create creates a new sandbox: picks a host via the scheduler, inserts a pending
// DB record, calls the host agent, and updates the record to running.
func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db.Sandbox, error) {
	if p.Template == "" {
		p.Template = "minimal"
	}
	if err := validate.SafeName(p.Template); err != nil {
		return db.Sandbox{}, fmt.Errorf("invalid template name: %w", err)
	}
	if p.VCPUs <= 0 {
		p.VCPUs = 1
	}
	if p.MemoryMB <= 0 {
		p.MemoryMB = 512
	}
	if p.DiskSizeMB <= 0 {
		p.DiskSizeMB = 5120 // 5 GB default
	}

	// Resolve template name → (teamID, templateID).
	templateTeamID := id.PlatformTeamID
	templateID := id.MinimalTemplateID
	var templateDefaultUser string
	var templateDefaultEnv map[string]string
	if p.Template != "minimal" {
		tmpl, err := s.DB.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: p.Template, TeamID: p.TeamID})
		if err != nil {
			return db.Sandbox{}, fmt.Errorf("template %q not found: %w", p.Template, err)
		}
		templateTeamID = tmpl.TeamID
		templateID = tmpl.ID
		templateDefaultUser = tmpl.DefaultUser
		// Parse default_env JSONB into a map.
		if len(tmpl.DefaultEnv) > 0 {
			_ = json.Unmarshal(tmpl.DefaultEnv, &templateDefaultEnv)
		}
		// If the template is a snapshot, use its baked-in vcpus/memory.
		if tmpl.Type == "snapshot" {
			p.VCPUs = tmpl.Vcpus
			p.MemoryMB = tmpl.MemoryMb
		}
	}

	if !p.TeamID.Valid {
		return db.Sandbox{}, fmt.Errorf("invalid request: team_id is required")
	}

	// Determine whether this team uses BYOC hosts or platform hosts.
	team, err := s.DB.GetTeam(ctx, p.TeamID)
	if err != nil {
		return db.Sandbox{}, fmt.Errorf("team not found: %w", err)
	}

	// Pick a host for this sandbox.
	host, err := s.Scheduler.SelectHost(ctx, p.TeamID, team.IsByoc, p.MemoryMB, p.DiskSizeMB)
	if err != nil {
		return db.Sandbox{}, fmt.Errorf("select host: %w", err)
	}

	agent, err := s.Pool.GetForHost(host)
	if err != nil {
		return db.Sandbox{}, fmt.Errorf("get agent client: %w", err)
	}

	sandboxID := id.NewSandboxID()
	sandboxIDStr := id.FormatSandboxID(sandboxID)

	if _, err := s.DB.InsertSandbox(ctx, db.InsertSandboxParams{
		ID:             sandboxID,
		TeamID:         p.TeamID,
		HostID:         host.ID,
		Template:       p.Template,
		Status:         "pending",
		Vcpus:          p.VCPUs,
		MemoryMb:       p.MemoryMB,
		TimeoutSec:     p.TimeoutSec,
		DiskSizeMb:     p.DiskSizeMB,
		TemplateID:     templateID,
		TemplateTeamID: templateTeamID,
		Metadata:       []byte("{}"),
	}); err != nil {
		return db.Sandbox{}, fmt.Errorf("insert sandbox: %w", err)
	}

	resp, err := agent.CreateSandbox(ctx, connect.NewRequest(&pb.CreateSandboxRequest{
		SandboxId:   sandboxIDStr,
		Template:    p.Template,
		TeamId:      id.UUIDString(templateTeamID),
		TemplateId:  id.UUIDString(templateID),
		Vcpus:       p.VCPUs,
		MemoryMb:    p.MemoryMB,
		TimeoutSec:  p.TimeoutSec,
		DiskSizeMb:  p.DiskSizeMB,
		DefaultUser: templateDefaultUser,
		DefaultEnv:  templateDefaultEnv,
	}))
	if err != nil {
		if _, dbErr := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
			ID: sandboxID, Status: "error",
		}); dbErr != nil {
			slog.Warn("failed to update sandbox status to error", "id", sandboxIDStr, "error", dbErr)
		}
		return db.Sandbox{}, fmt.Errorf("agent create: %w", err)
	}

	now := time.Now()
	sb, err := s.DB.UpdateSandboxRunning(ctx, db.UpdateSandboxRunningParams{
		ID:      sandboxID,
		HostIp:  resp.Msg.HostIp,
		GuestIp: "",
		StartedAt: pgtype.Timestamptz{
			Time:  now,
			Valid: true,
		},
	})
	if err != nil {
		return db.Sandbox{}, fmt.Errorf("update sandbox running: %w", err)
	}

	// Store runtime metadata from the agent (envd/kernel/firecracker/agent versions).
	if meta := resp.Msg.Metadata; len(meta) > 0 {
		metaJSON, _ := json.Marshal(meta)
		if err := s.DB.UpdateSandboxMetadata(ctx, db.UpdateSandboxMetadataParams{
			ID:       sandboxID,
			Metadata: metaJSON,
		}); err != nil {
			slog.Warn("failed to store sandbox metadata", "id", sandboxIDStr, "error", err)
		}
		sb.Metadata = metaJSON
	}

	return sb, nil
}
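
// Illustrative sketch (not part of the original file): creating a sandbox with
// defaults. Zero-valued fields fall back to "minimal", 1 vCPU, 512 MB RAM, and
// a 5 GB disk; the service wiring shown here is hypothetical.
//
//	svc := &SandboxService{DB: queries, Pool: pool, Scheduler: sched}
//	sb, err := svc.Create(ctx, SandboxCreateParams{TeamID: teamID, TimeoutSec: 300})
//	if err != nil {
//		log.Fatal(err)
//	}
//	fmt.Println(id.FormatSandboxID(sb.ID), sb.Status) // status is "running" on success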

// List returns active sandboxes (excludes stopped/error) belonging to the given team.
func (s *SandboxService) List(ctx context.Context, teamID pgtype.UUID) ([]db.Sandbox, error) {
	return s.DB.ListSandboxesByTeam(ctx, teamID)
}

// Get returns a single sandbox by ID, scoped to the given team.
func (s *SandboxService) Get(ctx context.Context, sandboxID, teamID pgtype.UUID) (db.Sandbox, error) {
	return s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
}

// Pause snapshots and freezes a running sandbox to disk.
func (s *SandboxService) Pause(ctx context.Context, sandboxID, teamID pgtype.UUID) (db.Sandbox, error) {
	sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
	if err != nil {
		return db.Sandbox{}, fmt.Errorf("sandbox not found: %w", err)
	}
	if sb.Status != "running" {
		return db.Sandbox{}, fmt.Errorf("sandbox is not running (status: %s)", sb.Status)
	}

	agent, _, err := s.agentForSandbox(ctx, sandboxID)
	if err != nil {
		return db.Sandbox{}, err
	}

	sandboxIDStr := id.FormatSandboxID(sandboxID)

	// Pre-mark as "paused" in DB before the RPC so the reconciler does not
	// mark the sandbox "stopped" while the host agent processes the pause.
	if _, err := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
		ID: sandboxID, Status: "paused",
	}); err != nil {
		return db.Sandbox{}, fmt.Errorf("pre-mark paused: %w", err)
	}

	// Flush all metrics tiers before pausing so data survives in DB.
	s.flushAndPersistMetrics(ctx, agent, sandboxID, true)

	if _, err := agent.PauseSandbox(ctx, connect.NewRequest(&pb.PauseSandboxRequest{
		SandboxId: sandboxIDStr,
	})); err != nil {
		// Revert status on failure.
		if _, dbErr := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
			ID: sandboxID, Status: "running",
		}); dbErr != nil {
			slog.Warn("failed to revert sandbox status after pause error", "sandbox_id", sandboxIDStr, "error", dbErr)
		}
		return db.Sandbox{}, fmt.Errorf("agent pause: %w", err)
	}

	sb, err = s.DB.GetSandbox(ctx, sandboxID)
	if err != nil {
		return db.Sandbox{}, fmt.Errorf("get sandbox after pause: %w", err)
	}
	return sb, nil
}

// Resume restores a paused sandbox from snapshot.
func (s *SandboxService) Resume(ctx context.Context, sandboxID, teamID pgtype.UUID) (db.Sandbox, error) {
	sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
	if err != nil {
		return db.Sandbox{}, fmt.Errorf("sandbox not found: %w", err)
	}
	if sb.Status != "paused" {
		return db.Sandbox{}, fmt.Errorf("sandbox is not paused (status: %s)", sb.Status)
	}

	agent, _, err := s.agentForSandbox(ctx, sandboxID)
	if err != nil {
		return db.Sandbox{}, err
	}

	sandboxIDStr := id.FormatSandboxID(sandboxID)

	// Look up template defaults for resume.
	var resumeDefaultUser string
	var resumeDefaultEnv map[string]string
	if sb.TemplateID.Valid {
		tmpl, err := s.DB.GetTemplate(ctx, sb.TemplateID)
		if err == nil {
			resumeDefaultUser = tmpl.DefaultUser
			if len(tmpl.DefaultEnv) > 0 {
				_ = json.Unmarshal(tmpl.DefaultEnv, &resumeDefaultEnv)
			}
		}
	}

	// Extract kernel version hint from existing sandbox metadata.
	var kernelVersion string
	if len(sb.Metadata) > 0 {
		var meta map[string]string
		if err := json.Unmarshal(sb.Metadata, &meta); err == nil {
			kernelVersion = meta["kernel_version"]
		}
	}

	resp, err := agent.ResumeSandbox(ctx, connect.NewRequest(&pb.ResumeSandboxRequest{
		SandboxId:     sandboxIDStr,
		TimeoutSec:    sb.TimeoutSec,
		DefaultUser:   resumeDefaultUser,
		DefaultEnv:    resumeDefaultEnv,
		KernelVersion: kernelVersion,
	}))
	if err != nil {
		return db.Sandbox{}, fmt.Errorf("agent resume: %w", err)
	}

	now := time.Now()
	sb, err = s.DB.UpdateSandboxRunning(ctx, db.UpdateSandboxRunningParams{
		ID:      sandboxID,
		HostIp:  resp.Msg.HostIp,
		GuestIp: "",
		StartedAt: pgtype.Timestamptz{
			Time:  now,
			Valid: true,
		},
	})
	if err != nil {
		return db.Sandbox{}, fmt.Errorf("update status: %w", err)
	}

	// Update metadata with actual versions used after resume.
	if meta := resp.Msg.Metadata; len(meta) > 0 {
		metaJSON, _ := json.Marshal(meta)
		if err := s.DB.UpdateSandboxMetadata(ctx, db.UpdateSandboxMetadataParams{
			ID:       sandboxID,
			Metadata: metaJSON,
		}); err != nil {
			slog.Warn("failed to update sandbox metadata after resume", "id", sandboxIDStr, "error", err)
		}
		sb.Metadata = metaJSON
	}

	return sb, nil
}

// Destroy stops a sandbox and marks it as stopped.
func (s *SandboxService) Destroy(ctx context.Context, sandboxID, teamID pgtype.UUID) error {
	sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
	if err != nil {
		return fmt.Errorf("sandbox not found: %w", err)
	}

	agent, _, err := s.agentForSandbox(ctx, sandboxID)
	if err != nil {
		return err
	}

	sandboxIDStr := id.FormatSandboxID(sandboxID)

	// If running, flush 24h tier metrics for analytics before destroying.
	if sb.Status == "running" {
		s.flushAndPersistMetrics(ctx, agent, sandboxID, false)
	}

	// Destroy on host agent. A not-found response is fine — sandbox is already gone.
	if _, err := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
		SandboxId: sandboxIDStr,
	})); err != nil && connect.CodeOf(err) != connect.CodeNotFound {
		return fmt.Errorf("agent destroy: %w", err)
	}

	// For a paused sandbox, only keep 24h tier; remove the finer-grained tiers.
	if sb.Status == "paused" {
		_ = s.DB.DeleteSandboxMetricPointsByTier(ctx, db.DeleteSandboxMetricPointsByTierParams{
			SandboxID: sandboxID, Tier: "10m",
		})
		_ = s.DB.DeleteSandboxMetricPointsByTier(ctx, db.DeleteSandboxMetricPointsByTierParams{
			SandboxID: sandboxID, Tier: "2h",
		})
	}

	if _, err := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
		ID: sandboxID, Status: "stopped",
	}); err != nil {
		return fmt.Errorf("update status: %w", err)
	}
	return nil
}

// flushAndPersistMetrics calls FlushSandboxMetrics on the agent and stores
// the returned data to DB. If allTiers is true, all three tiers are saved;
// otherwise only the 24h tier (for post-destroy analytics).
func (s *SandboxService) flushAndPersistMetrics(ctx context.Context, agent hostagentClient, sandboxID pgtype.UUID, allTiers bool) {
	sandboxIDStr := id.FormatSandboxID(sandboxID)
	resp, err := agent.FlushSandboxMetrics(ctx, connect.NewRequest(&pb.FlushSandboxMetricsRequest{
		SandboxId: sandboxIDStr,
	}))
	if err != nil {
		slog.Warn("flush metrics failed (best-effort)", "sandbox_id", sandboxIDStr, "error", err)
		return
	}
	msg := resp.Msg

	if allTiers {
		s.persistMetricPoints(ctx, sandboxID, "10m", msg.Points_10M)
		s.persistMetricPoints(ctx, sandboxID, "2h", msg.Points_2H)
	}
	s.persistMetricPoints(ctx, sandboxID, "24h", msg.Points_24H)
}

func (s *SandboxService) persistMetricPoints(ctx context.Context, sandboxID pgtype.UUID, tier string, points []*pb.MetricPoint) {
	sandboxIDStr := id.FormatSandboxID(sandboxID)
	for _, p := range points {
		if err := s.DB.InsertSandboxMetricPoint(ctx, db.InsertSandboxMetricPointParams{
			SandboxID: sandboxID,
			Tier:      tier,
			Ts:        p.TimestampUnix,
			CpuPct:    p.CpuPct,
			MemBytes:  p.MemBytes,
			DiskBytes: p.DiskBytes,
		}); err != nil {
			slog.Warn("persist metric point failed", "sandbox_id", sandboxIDStr, "tier", tier, "error", err)
		}
	}
}

// Ping resets the inactivity timer for a running sandbox.
func (s *SandboxService) Ping(ctx context.Context, sandboxID, teamID pgtype.UUID) error {
	sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
	if err != nil {
		return fmt.Errorf("sandbox not found: %w", err)
	}
	if sb.Status != "running" {
		return fmt.Errorf("sandbox is not running (status: %s)", sb.Status)
	}

	agent, _, err := s.agentForSandbox(ctx, sandboxID)
	if err != nil {
		return err
	}

	sandboxIDStr := id.FormatSandboxID(sandboxID)

	if _, err := agent.PingSandbox(ctx, connect.NewRequest(&pb.PingSandboxRequest{
		SandboxId: sandboxIDStr,
	})); err != nil {
		return fmt.Errorf("agent ping: %w", err)
	}

	if err := s.DB.UpdateLastActive(ctx, db.UpdateLastActiveParams{
		ID: sandboxID,
		LastActiveAt: pgtype.Timestamptz{
			Time:  time.Now(),
			Valid: true,
		},
	}); err != nil {
		slog.Warn("ping: failed to update last_active_at", "sandbox_id", sandboxIDStr, "error", err)
	}
	return nil
}
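
// Illustrative sketch (not part of the original file): a full lifecycle driven
// through the service. Error handling is elided for brevity and svc wiring is
// hypothetical.
//
//	sb, _ := svc.Create(ctx, SandboxCreateParams{TeamID: teamID})
//	_ = svc.Ping(ctx, sb.ID, teamID)       // reset the inactivity timer
//	sb, _ = svc.Pause(ctx, sb.ID, teamID)  // snapshot to disk; metrics flushed first
//	sb, _ = svc.Resume(ctx, sb.ID, teamID) // restore from snapshot
//	_ = svc.Destroy(ctx, sb.ID, teamID)    // stop and mark stopped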
160
pkg/service/stats.go
Normal file
@ -0,0 +1,160 @@
package service

import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgtype"
	"github.com/jackc/pgx/v5/pgxpool"

	"git.omukk.dev/wrenn/wrenn/pkg/db"
)

// TimeRange identifies a chart time window.
type TimeRange string

const (
	Range5m  TimeRange = "5m"
	Range1h  TimeRange = "1h"
	Range6h  TimeRange = "6h"
	Range24h TimeRange = "24h"
	Range30d TimeRange = "30d"
)

type rangeConfig struct {
	bucketSec       int    // bucket width in seconds for time-series aggregation
	intervalLiteral string // PostgreSQL interval literal for the lookback window
}

var rangeConfigs = map[TimeRange]rangeConfig{
	Range5m:  {bucketSec: 3, intervalLiteral: "5 minutes"},
	Range1h:  {bucketSec: 30, intervalLiteral: "1 hour"},
	Range6h:  {bucketSec: 180, intervalLiteral: "6 hours"},
	Range24h: {bucketSec: 720, intervalLiteral: "24 hours"},
	Range30d: {bucketSec: 21600, intervalLiteral: "30 days"},
}

// ValidRange returns true if r is a known TimeRange value.
func ValidRange(r TimeRange) bool {
	_, ok := rangeConfigs[r]
	return ok
}

// StatPoint is one bucketed data point in the time-series.
type StatPoint struct {
	Bucket           time.Time
	RunningCount     int32
	VCPUsReserved    int32
	MemoryMBReserved int32
}

// CurrentStats holds the live values for a team, read directly from sandboxes.
type CurrentStats struct {
	RunningCount     int32
	VCPUsReserved    int32
	MemoryMBReserved int32
}

// PeakStats holds the 30-day maximum values for a team.
type PeakStats struct {
	RunningCount int32
	VCPUs        int32
	MemoryMB     int32
}

// StatsService computes sandbox metrics for the dashboard.
type StatsService struct {
	DB   *db.Queries
	Pool *pgxpool.Pool
}

// GetStats returns current stats, 30-day peaks, and a time-series for the
// given team and time range. If no snapshots exist yet, zeros are returned.
func (s *StatsService) GetStats(ctx context.Context, teamID pgtype.UUID, r TimeRange) (CurrentStats, PeakStats, []StatPoint, error) {
	cfg, ok := rangeConfigs[r]
	if !ok {
		return CurrentStats{}, PeakStats{}, nil, fmt.Errorf("unknown range: %s", r)
	}

	// Current live values — read directly from sandboxes so we always reflect
	// the true state even when no sandboxes are running.
	cur, err := s.DB.GetLiveMetrics(ctx, teamID)
	if err != nil {
		return CurrentStats{}, PeakStats{}, nil, fmt.Errorf("get live metrics: %w", err)
	}
	current := CurrentStats{
		RunningCount:     cur.RunningCount,
		VCPUsReserved:    cur.VcpusReserved,
		MemoryMBReserved: cur.MemoryMbReserved,
	}

	// 30-day peaks.
	var peaks PeakStats
	pk, err := s.DB.GetPeakMetrics(ctx, teamID)
	if err != nil && !errors.Is(err, pgx.ErrNoRows) {
		return CurrentStats{}, PeakStats{}, nil, fmt.Errorf("get peak metrics: %w", err)
	}
	if err == nil {
		peaks = PeakStats{
			RunningCount: pk.PeakRunningCount,
			VCPUs:        pk.PeakVcpus,
			MemoryMB:     pk.PeakMemoryMb,
		}
	}

	// Time-series — dynamic bucket width, executed via pgx directly.
	series, err := s.queryTimeSeries(ctx, teamID, cfg)
	if err != nil {
		return CurrentStats{}, PeakStats{}, nil, fmt.Errorf("get time series: %w", err)
	}

	return current, peaks, series, nil
}

// timeSeriesSQL uses an epoch-floor trick to bucket rows by an arbitrary
// integer number of seconds without requiring TimescaleDB.
//
// MAX is used instead of AVG so that short-lived running states are not
// averaged down to zero within a bucket. For capacity metrics the peak
// value in each bucket is what matters — AVG with ::INTEGER rounding
// caused running_count, vcpus, and memory to become inconsistent with
// each other (e.g. running=0 but vcpus=1).
//
// $1 = bucket width in seconds (integer)
// $2 = team_id
// $3 = lookback interval literal (e.g. '1 hour')
const timeSeriesSQL = `
SELECT
    to_timestamp(floor(extract(epoch FROM sampled_at) / $1) * $1) AS bucket,
    MAX(running_count)      AS running_count,
    MAX(vcpus_reserved)     AS vcpus_reserved,
    MAX(memory_mb_reserved) AS memory_mb_reserved
FROM sandbox_metrics_snapshots
WHERE team_id = $2
  AND sampled_at >= NOW() - $3::INTERVAL
GROUP BY bucket
ORDER BY bucket ASC
`
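
// Worked example (illustrative, not part of the original file): the epoch-floor
// bucketing the SQL above performs, reproduced in Go for a single timestamp.
//
//	sampledAt := time.Date(2024, 5, 1, 10, 7, 41, 0, time.UTC)
//	width := int64(180)                       // Range6h uses 180-second buckets
//	epoch := sampledAt.Unix()                 // 1714558061
//	bucket := time.Unix(epoch/width*width, 0) // integer division floors: 1714557960
//	fmt.Println(bucket.UTC())                 // 2024-05-01 10:06:00 +0000 UTC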

func (s *StatsService) queryTimeSeries(ctx context.Context, teamID pgtype.UUID, cfg rangeConfig) ([]StatPoint, error) {
	rows, err := s.Pool.Query(ctx, timeSeriesSQL, cfg.bucketSec, teamID, cfg.intervalLiteral)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var points []StatPoint
	for rows.Next() {
		var p StatPoint
		var bucket time.Time
		if err := rows.Scan(&bucket, &p.RunningCount, &p.VCPUsReserved, &p.MemoryMBReserved); err != nil {
			return nil, err
		}
		p.Bucket = bucket
		points = append(points, p)
	}
	return points, rows.Err()
}
544
pkg/service/team.go
Normal file
@ -0,0 +1,544 @@
package service

import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"regexp"
	"time"

	"connectrpc.com/connect"
	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgtype"
	"github.com/jackc/pgx/v5/pgxpool"

	"git.omukk.dev/wrenn/wrenn/pkg/db"
	"git.omukk.dev/wrenn/wrenn/pkg/id"
	"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
	pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
)

var teamNameRE = regexp.MustCompile(`^[A-Za-z0-9 _\-@']{1,128}$`)
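
// Examples (illustrative, not part of the original file): names the pattern
// above accepts and rejects.
//
//	teamNameRE.MatchString("Platform Team")   // true
//	teamNameRE.MatchString("ops-oncall@acme") // true  (- and @ are allowed)
//	teamNameRE.MatchString("O'Brien's Team")  // true  (' is allowed)
//	teamNameRE.MatchString("")                // false (at least 1 character)
//	teamNameRE.MatchString("team/alpha")      // false (/ is not in the class)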

// TeamService provides team management operations.
type TeamService struct {
	DB       *db.Queries
	Pool     *pgxpool.Pool
	HostPool *lifecycle.HostClientPool
}

// TeamWithRole pairs a team with the calling user's role in it.
type TeamWithRole struct {
	db.Team
	Role string `json:"role"`
}

// MemberInfo is a team member with resolved user details.
type MemberInfo struct {
	UserID   string    `json:"user_id"`
	Name     string    `json:"name"`
	Email    string    `json:"email"`
	Role     string    `json:"role"`
	JoinedAt time.Time `json:"joined_at"`
}

// callerRole fetches the calling user's role in the given team from DB.
// Returns an error wrapping "forbidden" if the caller is not a member.
func (s *TeamService) callerRole(ctx context.Context, teamID, callerUserID pgtype.UUID) (string, error) {
	m, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
		UserID: callerUserID,
		TeamID: teamID,
	})
	if err != nil {
		if errors.Is(err, pgx.ErrNoRows) {
			return "", fmt.Errorf("forbidden: not a member of this team")
		}
		return "", fmt.Errorf("get membership: %w", err)
	}
	return m.Role, nil
}

// requireAdmin returns an error if the caller is not an admin or owner.
func requireAdmin(role string) error {
	if role != "owner" && role != "admin" {
		return fmt.Errorf("forbidden: admin or owner role required")
	}
	return nil
}

// GetTeam returns the team by ID. Returns an error if the team is deleted or not found.
func (s *TeamService) GetTeam(ctx context.Context, teamID pgtype.UUID) (db.Team, error) {
	team, err := s.DB.GetTeam(ctx, teamID)
	if err != nil {
		if errors.Is(err, pgx.ErrNoRows) {
			return db.Team{}, fmt.Errorf("team not found")
		}
		return db.Team{}, fmt.Errorf("get team: %w", err)
	}
	if team.DeletedAt.Valid {
		return db.Team{}, fmt.Errorf("team not found")
	}
	return team, nil
}

// ListTeamsForUser returns all active teams the user belongs to, with their role in each.
func (s *TeamService) ListTeamsForUser(ctx context.Context, userID pgtype.UUID) ([]TeamWithRole, error) {
	rows, err := s.DB.GetTeamsForUser(ctx, userID)
	if err != nil {
		return nil, fmt.Errorf("list teams: %w", err)
	}
	result := make([]TeamWithRole, len(rows))
	for i, r := range rows {
		result[i] = TeamWithRole{
			Team: db.Team{ID: r.ID, Name: r.Name, CreatedAt: r.CreatedAt, IsByoc: r.IsByoc, Slug: r.Slug, DeletedAt: r.DeletedAt},
			Role: r.Role,
		}
	}
	return result, nil
}

// CreateTeam creates a new team owned by the given user.
func (s *TeamService) CreateTeam(ctx context.Context, ownerUserID pgtype.UUID, name string) (TeamWithRole, error) {
	if !teamNameRE.MatchString(name) {
		return TeamWithRole{}, fmt.Errorf("invalid team name: must be 1-128 characters from A-Z a-z 0-9 space _ - @ '")
	}

	tx, err := s.Pool.Begin(ctx)
	if err != nil {
		return TeamWithRole{}, fmt.Errorf("begin tx: %w", err)
	}
	defer tx.Rollback(ctx) //nolint:errcheck

	qtx := s.DB.WithTx(tx)

	teamID := id.NewTeamID()
	team, err := qtx.InsertTeam(ctx, db.InsertTeamParams{
		ID:   teamID,
		Name: name,
		Slug: id.NewTeamSlug(),
	})
	if err != nil {
		return TeamWithRole{}, fmt.Errorf("insert team: %w", err)
	}

	if err := qtx.InsertTeamMember(ctx, db.InsertTeamMemberParams{
		UserID:    ownerUserID,
		TeamID:    teamID,
		IsDefault: false,
		Role:      "owner",
	}); err != nil {
		return TeamWithRole{}, fmt.Errorf("insert owner: %w", err)
	}

	if err := tx.Commit(ctx); err != nil {
		return TeamWithRole{}, fmt.Errorf("commit: %w", err)
	}

	return TeamWithRole{Team: team, Role: "owner"}, nil
}
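
// Illustrative sketch (not part of the original file): creating a team and then
// renaming it as the owner. Service wiring is hypothetical.
//
//	svc := &TeamService{DB: queries, Pool: pgxPool, HostPool: hostPool}
//	tw, err := svc.CreateTeam(ctx, ownerID, "Platform Team")
//	if err != nil {
//		log.Fatal(err)
//	}
//	// The creator is the owner, so this passes the admin-or-owner check.
//	err = svc.RenameTeam(ctx, tw.ID, ownerID, "Platform Team EU")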

// RenameTeam updates the team name. Caller must be admin or owner (verified from DB).
func (s *TeamService) RenameTeam(ctx context.Context, teamID, callerUserID pgtype.UUID, newName string) error {
	if !teamNameRE.MatchString(newName) {
		return fmt.Errorf("invalid team name: must be 1-128 characters from A-Z a-z 0-9 space _ - @ '")
	}

	role, err := s.callerRole(ctx, teamID, callerUserID)
	if err != nil {
		return err
	}
	if err := requireAdmin(role); err != nil {
		return err
	}

	if err := s.DB.UpdateTeamName(ctx, db.UpdateTeamNameParams{ID: teamID, Name: newName}); err != nil {
		return fmt.Errorf("update name: %w", err)
	}
	return nil
}

// DeleteTeam soft-deletes the team and destroys all active sandboxes.
// Caller must be owner (verified from DB). Sandbox records are preserved
// (marked stopped), while API keys, channels, metrics, and templates are
// removed; the team row itself only has deleted_at set.
func (s *TeamService) DeleteTeam(ctx context.Context, teamID, callerUserID pgtype.UUID) error {
	role, err := s.callerRole(ctx, teamID, callerUserID)
	if err != nil {
		return err
	}
	if role != "owner" {
		return fmt.Errorf("forbidden: only the owner can delete a team")
	}

	return s.deleteTeamCore(ctx, teamID)
}

// deleteTeamCore contains the shared team deletion logic:
// destroy active sandboxes, clean up templates, soft-delete the team.
func (s *TeamService) deleteTeamCore(ctx context.Context, teamID pgtype.UUID) error {
	// Collect active sandboxes and stop them.
	sandboxes, err := s.DB.ListActiveSandboxesByTeam(ctx, teamID)
	if err != nil {
		return fmt.Errorf("list active sandboxes: %w", err)
	}

	var stopIDs []pgtype.UUID
	for _, sb := range sandboxes {
		host, hostErr := s.DB.GetHost(ctx, sb.HostID)
		if hostErr == nil {
			agent, agentErr := s.HostPool.GetForHost(host)
			if agentErr == nil {
				if _, err := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
					SandboxId: id.FormatSandboxID(sb.ID),
				})); err != nil && connect.CodeOf(err) != connect.CodeNotFound {
					slog.Warn("team delete: failed to destroy sandbox", "sandbox_id", id.FormatSandboxID(sb.ID), "error", err)
				}
			}
		}
		stopIDs = append(stopIDs, sb.ID)
	}

	if len(stopIDs) > 0 {
		if err := s.DB.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
			Column1: stopIDs,
			Status:  "stopped",
		}); err != nil {
			// Do not proceed to soft-delete if sandbox statuses couldn't be updated,
			// as that would leave orphaned "running" records for a deleted team.
			return fmt.Errorf("update sandbox statuses: %w", err)
		}
	}

	// Delete sandbox metrics for this team.
	if err := s.DB.DeleteMetricPointsByTeam(ctx, teamID); err != nil {
		slog.Warn("team delete: failed to delete metric points", "team_id", id.FormatTeamID(teamID), "error", err)
	}
	if err := s.DB.DeleteMetricsSnapshotsByTeam(ctx, teamID); err != nil {
		slog.Warn("team delete: failed to delete metrics snapshots", "team_id", id.FormatTeamID(teamID), "error", err)
	}

	// Delete all API keys for this team.
	if err := s.DB.DeleteAPIKeysByTeam(ctx, teamID); err != nil {
		slog.Warn("team delete: failed to delete API keys", "team_id", id.FormatTeamID(teamID), "error", err)
	}

	// Delete all channels for this team.
	if err := s.DB.DeleteAllChannelsByTeam(ctx, teamID); err != nil {
		slog.Warn("team delete: failed to delete channels", "team_id", id.FormatTeamID(teamID), "error", err)
	}

	// Clean up team-owned templates from all hosts in the background.
	go s.cleanupTeamTemplates(context.Background(), teamID)

	if err := s.DB.SoftDeleteTeam(ctx, teamID); err != nil {
		return fmt.Errorf("soft delete team: %w", err)
	}
	return nil
}

// cleanupTeamTemplates deletes all template files for a team from all online hosts,
// then removes the DB records. Called asynchronously during team deletion.
func (s *TeamService) cleanupTeamTemplates(ctx context.Context, teamID pgtype.UUID) {
	templates, err := s.DB.ListTemplatesByTeamOnly(ctx, teamID)
	if err != nil {
		slog.Warn("team delete: failed to list templates for cleanup", "team_id", id.FormatTeamID(teamID), "error", err)
		return
	}
	if len(templates) == 0 {
		return
	}

	hosts, err := s.DB.ListActiveHosts(ctx)
	if err != nil {
		slog.Warn("team delete: failed to list hosts for template cleanup", "error", err)
		return
	}

	for _, tmpl := range templates {
		for _, host := range hosts {
			if host.Status != "online" {
				continue
			}
			agent, err := s.HostPool.GetForHost(host)
			if err != nil {
				continue
			}
			if _, err := agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{
				TeamId:     id.UUIDString(tmpl.TeamID),
				TemplateId: id.UUIDString(tmpl.ID),
			})); err != nil && connect.CodeOf(err) != connect.CodeNotFound {
				slog.Warn("team delete: failed to delete template on host",
					"host_id", id.FormatHostID(host.ID),
					"template", tmpl.Name,
					"error", err,
				)
			}
		}
	}

	// Remove DB records.
	if err := s.DB.DeleteTemplatesByTeam(ctx, teamID); err != nil {
		slog.Warn("team delete: failed to delete template records", "team_id", id.FormatTeamID(teamID), "error", err)
	}
}

// GetMembers returns all members of the team with their emails and roles.
func (s *TeamService) GetMembers(ctx context.Context, teamID pgtype.UUID) ([]MemberInfo, error) {
	rows, err := s.DB.GetTeamMembers(ctx, teamID)
	if err != nil {
		return nil, fmt.Errorf("get members: %w", err)
	}
	members := make([]MemberInfo, len(rows))
	for i, r := range rows {
		var joinedAt time.Time
		if r.JoinedAt.Valid {
			joinedAt = r.JoinedAt.Time
		}
		members[i] = MemberInfo{
			UserID:   id.FormatUserID(r.ID),
			Name:     r.Name,
			Email:    r.Email,
			Role:     r.Role,
			JoinedAt: joinedAt,
		}
	}
	return members, nil
}

// AddMember adds an existing user (looked up by email) to the team as a member.
// Caller must be admin or owner (verified from DB).
func (s *TeamService) AddMember(ctx context.Context, teamID, callerUserID pgtype.UUID, email string) (MemberInfo, error) {
	role, err := s.callerRole(ctx, teamID, callerUserID)
	if err != nil {
		return MemberInfo{}, err
	}
	if err := requireAdmin(role); err != nil {
		return MemberInfo{}, err
	}

	target, err := s.DB.GetUserByEmail(ctx, email)
	if err != nil {
		if errors.Is(err, pgx.ErrNoRows) {
			return MemberInfo{}, fmt.Errorf("user not found: no account with that email")
		}
		return MemberInfo{}, fmt.Errorf("look up user: %w", err)
	}

	// Check if already a member.
	_, memberCheckErr := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
		UserID: target.ID,
		TeamID: teamID,
	})
	if memberCheckErr == nil {
		return MemberInfo{}, fmt.Errorf("invalid: user is already a member of this team")
	} else if !errors.Is(memberCheckErr, pgx.ErrNoRows) {
		return MemberInfo{}, fmt.Errorf("check membership: %w", memberCheckErr)
	}

	if err := s.DB.InsertTeamMember(ctx, db.InsertTeamMemberParams{
		UserID:    target.ID,
		TeamID:    teamID,
		IsDefault: false,
		Role:      "member",
	}); err != nil {
		return MemberInfo{}, fmt.Errorf("insert member: %w", err)
	}

	return MemberInfo{UserID: id.FormatUserID(target.ID), Name: target.Name, Email: target.Email, Role: "member"}, nil
}

// RemoveMember removes a user from the team.
// Caller must be admin or owner (verified from DB). Owner cannot be removed.
func (s *TeamService) RemoveMember(ctx context.Context, teamID, callerUserID, targetUserID pgtype.UUID) error {
	callerRole, err := s.callerRole(ctx, teamID, callerUserID)
	if err != nil {
		return err
	}
	if err := requireAdmin(callerRole); err != nil {
		return err
	}

	targetMembership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
		UserID: targetUserID,
		TeamID: teamID,
	})
	if err != nil {
		if errors.Is(err, pgx.ErrNoRows) {
			return fmt.Errorf("not found: user is not a member of this team")
		}
		return fmt.Errorf("get target membership: %w", err)
	}

	if targetMembership.Role == "owner" {
		return fmt.Errorf("forbidden: the owner cannot be removed from the team")
	}

	if err := s.DB.DeleteTeamMember(ctx, db.DeleteTeamMemberParams{
		TeamID: teamID,
		UserID: targetUserID,
	}); err != nil {
		return fmt.Errorf("delete member: %w", err)
	}
	return nil
}

// UpdateMemberRole changes a member's role to admin or member.
// Caller must be admin or owner (verified from DB). Owner's role cannot be changed.
// Valid target roles: "admin", "member".
func (s *TeamService) UpdateMemberRole(ctx context.Context, teamID, callerUserID, targetUserID pgtype.UUID, newRole string) error {
	if newRole != "admin" && newRole != "member" {
		return fmt.Errorf("invalid: role must be admin or member")
	}

	callerRole, err := s.callerRole(ctx, teamID, callerUserID)
	if err != nil {
		return err
	}
	if err := requireAdmin(callerRole); err != nil {
		return err
	}

	targetMembership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
		UserID: targetUserID,
		TeamID: teamID,
	})
	if err != nil {
		if errors.Is(err, pgx.ErrNoRows) {
			return fmt.Errorf("not found: user is not a member of this team")
		}
		return fmt.Errorf("get target membership: %w", err)
	}

	if targetMembership.Role == "owner" {
		return fmt.Errorf("forbidden: the owner's role cannot be changed")
	}

	if err := s.DB.UpdateMemberRole(ctx, db.UpdateMemberRoleParams{
		TeamID: teamID,
		UserID: targetUserID,
		Role:   newRole,
	}); err != nil {
		return fmt.Errorf("update role: %w", err)
	}
	return nil
}

// LeaveTeam removes the calling user from the team.
// The owner cannot leave; they must delete the team instead.
func (s *TeamService) LeaveTeam(ctx context.Context, teamID, callerUserID pgtype.UUID) error {
	role, err := s.callerRole(ctx, teamID, callerUserID)
	if err != nil {
		return err
	}
	if role == "owner" {
		return fmt.Errorf("forbidden: the owner cannot leave the team; delete the team instead")
	}

	if err := s.DB.DeleteTeamMember(ctx, db.DeleteTeamMemberParams{
		TeamID: teamID,
		UserID: callerUserID,
	}); err != nil {
		return fmt.Errorf("leave team: %w", err)
	}
	return nil
}

// SetBYOC enables the BYOC feature flag for a team. Once enabled, BYOC cannot
// be disabled — it is a one-way transition.
// Admin-only — the caller must verify admin status before invoking this.
func (s *TeamService) SetBYOC(ctx context.Context, teamID pgtype.UUID, enabled bool) error {
	team, err := s.DB.GetTeam(ctx, teamID)
	if err != nil {
		return fmt.Errorf("team not found: %w", err)
	}
	if team.DeletedAt.Valid {
		return fmt.Errorf("team not found")
	}
	if !enabled {
		return fmt.Errorf("invalid request: BYOC cannot be disabled once enabled")
	}
	if team.IsByoc {
		// Already enabled — idempotent, no-op.
		return nil
	}
	if err := s.DB.SetTeamBYOC(ctx, db.SetTeamBYOCParams{ID: teamID, IsByoc: true}); err != nil {
		return fmt.Errorf("set byoc: %w", err)
	}
	return nil
}

// AdminTeamRow is the shape returned by AdminListTeams.
type AdminTeamRow struct {
	ID                 pgtype.UUID
	Name               string
	Slug               string
	IsByoc             bool
	CreatedAt          time.Time
	DeletedAt          *time.Time
	MemberCount        int32
	OwnerName          string
	OwnerEmail         string
	ActiveSandboxCount int32
	ChannelCount       int32
}

// AdminListTeams returns a paginated list of all teams (excluding the platform
// team) with member counts, owner info, and active sandbox counts.
// Admin-only — caller must verify admin status.
func (s *TeamService) AdminListTeams(ctx context.Context, limit, offset int32) ([]AdminTeamRow, int32, error) {
	teams, err := s.DB.ListTeamsAdmin(ctx, db.ListTeamsAdminParams{
		Limit:  limit,
		Offset: offset,
	})
	if err != nil {
		return nil, 0, fmt.Errorf("list teams: %w", err)
	}

	total, err := s.DB.CountTeamsAdmin(ctx)
	if err != nil {
		return nil, 0, fmt.Errorf("count teams: %w", err)
	}

	rows := make([]AdminTeamRow, len(teams))
	for i, t := range teams {
		row := AdminTeamRow{
			ID:                 t.ID,
			Name:               t.Name,
			Slug:               t.Slug,
			IsByoc:             t.IsByoc,
			CreatedAt:          t.CreatedAt.Time,
			MemberCount:        t.MemberCount,
			OwnerName:          t.OwnerName,
			OwnerEmail:         t.OwnerEmail,
			ActiveSandboxCount: t.ActiveSandboxCount,
			ChannelCount:       t.ChannelCount,
		}
		if t.DeletedAt.Valid {
			deletedAt := t.DeletedAt.Time
			row.DeletedAt = &deletedAt
		}
		rows[i] = row
	}
	return rows, total, nil
}
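
// Illustrative sketch (not part of the original file): paging through all teams
// 50 at a time. The caller must already be verified as a platform admin.
//
//	const pageSize = 50
//	for offset := int32(0); ; offset += pageSize {
//		rows, total, err := svc.AdminListTeams(ctx, pageSize, offset)
//		if err != nil {
//			return err
//		}
//		for _, t := range rows {
//			fmt.Println(t.Name, t.MemberCount, t.ActiveSandboxCount)
//		}
//		if offset+pageSize >= total {
//			break
//		}
//	}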

// DeleteTeamInternal soft-deletes a team and destroys all its active sandboxes.
// Used for system-initiated deletions (e.g. cascading from user account deletion)
// where no caller role check is needed.
func (s *TeamService) DeleteTeamInternal(ctx context.Context, teamID pgtype.UUID) error {
	return s.deleteTeamCore(ctx, teamID)
}

// AdminDeleteTeam soft-deletes a team and destroys all its active sandboxes.
// Unlike DeleteTeam, this does not require the caller to be the team owner —
// it is admin-only (caller must verify admin status).
func (s *TeamService) AdminDeleteTeam(ctx context.Context, teamID pgtype.UUID) error {
	team, err := s.DB.GetTeam(ctx, teamID)
	if err != nil {
		return fmt.Errorf("team not found: %w", err)
	}
	if team.DeletedAt.Valid {
		return fmt.Errorf("team not found")
	}

	return s.deleteTeamCore(ctx, teamID)
}
27
pkg/service/template.go
Normal file
@ -0,0 +1,27 @@
package service

import (
	"context"

	"github.com/jackc/pgx/v5/pgtype"

	"git.omukk.dev/wrenn/wrenn/pkg/db"
)

// TemplateService provides template/snapshot operations shared between the
// REST API and the dashboard.
type TemplateService struct {
	DB *db.Queries
}

// List returns all templates belonging to the given team. If typeFilter is
// non-empty, only templates of that type ("base" or "snapshot") are returned.
func (s *TemplateService) List(ctx context.Context, teamID pgtype.UUID, typeFilter string) ([]db.Template, error) {
	if typeFilter != "" {
		return s.DB.ListTemplatesByTeamAndType(ctx, db.ListTemplatesByTeamAndTypeParams{
			TeamID: teamID,
			Type:   typeFilter,
		})
	}
	return s.DB.ListTemplatesByTeam(ctx, teamID)
}
107
pkg/service/user.go
Normal file
@ -0,0 +1,107 @@
package service

import (
	"context"
	"fmt"
	"log/slog"
	"time"

	"github.com/jackc/pgx/v5/pgtype"

	"git.omukk.dev/wrenn/wrenn/pkg/db"
	"git.omukk.dev/wrenn/wrenn/pkg/id"
)

// UserService provides user management operations.
type UserService struct {
	DB         *db.Queries
	SandboxSvc *SandboxService
}

// AdminUserRow is the shape returned by AdminListUsers.
type AdminUserRow struct {
	ID          pgtype.UUID
	Email       string
	Name        string
	IsAdmin     bool
	Status      string
	CreatedAt   time.Time
	TeamsJoined int32
	TeamsOwned  int32
}

// AdminListUsers returns a paginated list of all non-deleted users with team counts.
func (s *UserService) AdminListUsers(ctx context.Context, limit, offset int32) ([]AdminUserRow, int32, error) {
	users, err := s.DB.ListUsersAdmin(ctx, db.ListUsersAdminParams{
		Limit:  limit,
		Offset: offset,
	})
	if err != nil {
		return nil, 0, fmt.Errorf("list users: %w", err)
	}

	total, err := s.DB.CountUsersAdmin(ctx)
	if err != nil {
		return nil, 0, fmt.Errorf("count users: %w", err)
	}

	rows := make([]AdminUserRow, len(users))
	for i, u := range users {
		rows[i] = AdminUserRow{
			ID:          u.ID,
			Email:       u.Email,
			Name:        u.Name,
			IsAdmin:     u.IsAdmin,
			Status:      u.Status,
			CreatedAt:   u.CreatedAt.Time,
			TeamsJoined: u.TeamsJoined,
			TeamsOwned:  u.TeamsOwned,
		}
	}
	return rows, total, nil
}

// SetUserStatus sets the status of a user account.
func (s *UserService) SetUserStatus(ctx context.Context, userID pgtype.UUID, status string) error {
	if err := s.DB.SetUserStatus(ctx, db.SetUserStatusParams{
		ID:     userID,
		Status: status,
	}); err != nil {
		return fmt.Errorf("set user status: %w", err)
	}
	if status == "disabled" || status == "deleted" {
		if err := s.DB.DeleteAPIKeysByCreator(ctx, userID); err != nil {
			slog.Warn("failed to delete API keys for deactivated user", "user_id", id.FormatUserID(userID), "error", err)
		}
		s.destroySandboxesForOwnedTeams(ctx, userID)
	}
	return nil
}

// destroySandboxesForOwnedTeams destroys all active sandboxes (running, paused,
// hibernated, starting) for every team the user owns. Best-effort: errors are
// logged but do not prevent the user from being disabled.
func (s *UserService) destroySandboxesForOwnedTeams(ctx context.Context, userID pgtype.UUID) {
	if s.SandboxSvc == nil {
		return
	}

	teamIDs, err := s.DB.GetOwnedTeamIDs(ctx, userID)
	if err != nil {
		slog.Warn("failed to list owned teams for sandbox cleanup", "user_id", id.FormatUserID(userID), "error", err)
		return
	}

	for _, teamID := range teamIDs {
		sandboxes, err := s.DB.ListActiveSandboxesByTeam(ctx, teamID)
		if err != nil {
			slog.Warn("failed to list active sandboxes for team", "team_id", id.FormatTeamID(teamID), "user_id", id.FormatUserID(userID), "error", err)
			continue
		}
		for _, sb := range sandboxes {
			if err := s.SandboxSvc.Destroy(ctx, sb.ID, teamID); err != nil {
				slog.Warn("failed to destroy sandbox during user disable",
					"sandbox_id", id.FormatSandboxID(sb.ID), "team_id", id.FormatTeamID(teamID), "user_id", id.FormatUserID(userID), "error", err)
			}
		}
	}
}