forked from wrenn/wrenn
v0.2.0 (#50)
Co-authored-by: Tasnim Kabir Sadik <tksadik@omukk.dev> Reviewed-on: wrenn/wrenn#50
This commit is contained in:
@ -38,17 +38,56 @@ func (l *AuditLogger) publish(ctx context.Context, e events.Event) {
|
||||
}
|
||||
}
|
||||
|
||||
// publishTransient mirrors an event on the SSE Pub/Sub channel only.
|
||||
func (l *AuditLogger) publishTransient(ctx context.Context, e events.Event) {
|
||||
if l.pub != nil {
|
||||
l.pub.PublishTransient(ctx, e)
|
||||
}
|
||||
}
|
||||
|
||||
// outcomeFromErr returns OutcomeSuccess when err is nil, OutcomeError otherwise.
|
||||
func outcomeFromErr(err error) events.Outcome {
|
||||
if err != nil {
|
||||
return events.OutcomeError
|
||||
}
|
||||
return events.OutcomeSuccess
|
||||
}
|
||||
|
||||
// auditStatusFor maps an error and success-status into the audit row status.
|
||||
// On error → "error"; otherwise the supplied success status (e.g. "success", "warning", "info").
|
||||
func auditStatusFor(err error, okStatus string) string {
|
||||
if err != nil {
|
||||
return "error"
|
||||
}
|
||||
return okStatus
|
||||
}
|
||||
|
||||
func errString(err error) string {
|
||||
if err == nil {
|
||||
return ""
|
||||
}
|
||||
return err.Error()
|
||||
}
|
||||
|
||||
// mergeMeta returns a new map with err added when non-nil, preserving caller fields.
|
||||
func mergeMeta(base map[string]any, err error) map[string]any {
|
||||
if err == nil {
|
||||
return base
|
||||
}
|
||||
out := make(map[string]any, len(base)+1)
|
||||
for k, v := range base {
|
||||
out[k] = v
|
||||
}
|
||||
out["error"] = err.Error()
|
||||
return out
|
||||
}
|
||||
|
||||
// actorToEvent converts auth context fields to an events.Actor.
|
||||
func actorToEvent(ac auth.AuthContext) events.Actor {
|
||||
at, aid, aname := actorFields(ac)
|
||||
return events.Actor{Type: events.ActorKind(at), ID: aid, Name: aname}
|
||||
}
|
||||
|
||||
// systemActor returns an events.Actor for system-initiated events.
|
||||
func systemActor() events.Actor {
|
||||
return events.Actor{Type: events.ActorSystem}
|
||||
}
|
||||
|
||||
// actorFields extracts actor_type, actor_id, and actor_name from an AuthContext.
|
||||
// actor_id is stored as a prefixed string in the TEXT column.
|
||||
func actorFields(ac auth.AuthContext) (actorType, actorID, actorName string) {
|
||||
@ -171,87 +210,240 @@ func resolveHostTeamID(teamID pgtype.UUID) pgtype.UUID {
|
||||
|
||||
// --- Sandbox events (scope: team) ---
|
||||
|
||||
func (l *AuditLogger) LogSandboxCreate(ctx context.Context, ac auth.AuthContext, sandboxID pgtype.UUID, template string) {
|
||||
l.Log(ctx, newEntry(ac, ac.TeamID, "team", "sandbox", id.FormatSandboxID(sandboxID), "create", "success", map[string]any{"template": template}))
|
||||
// LogSandboxCreate records the result of a first-boot sandbox creation. err
|
||||
// nil ⇒ success; non-nil ⇒ error. Writes audit row and publishes a
|
||||
// capsule.create event with the derived outcome.
|
||||
func (l *AuditLogger) LogSandboxCreate(ctx context.Context, ac auth.AuthContext, sandboxID pgtype.UUID, template string, err error) {
|
||||
meta := map[string]any{"template": template}
|
||||
l.Log(ctx, newEntry(ac, ac.TeamID, "team", "sandbox", id.FormatSandboxID(sandboxID), "create", auditStatusFor(err, "success"), mergeMeta(meta, err)))
|
||||
l.publish(ctx, events.Event{
|
||||
Event: events.CapsuleCreated,
|
||||
Event: events.CapsuleCreate,
|
||||
Outcome: outcomeFromErr(err),
|
||||
Timestamp: events.Now(),
|
||||
TeamID: id.FormatTeamID(ac.TeamID),
|
||||
Actor: actorToEvent(ac),
|
||||
Resource: events.Resource{ID: id.FormatSandboxID(sandboxID), Type: "sandbox"},
|
||||
Metadata: map[string]string{"template": template},
|
||||
Error: errString(err),
|
||||
})
|
||||
}
|
||||
|
||||
func (l *AuditLogger) LogSandboxPause(ctx context.Context, ac auth.AuthContext, sandboxID pgtype.UUID) {
|
||||
l.Log(ctx, newEntry(ac, ac.TeamID, "team", "sandbox", id.FormatSandboxID(sandboxID), "pause", "success", nil))
|
||||
// LogSandboxPause records a user-initiated pause.
|
||||
func (l *AuditLogger) LogSandboxPause(ctx context.Context, ac auth.AuthContext, sandboxID pgtype.UUID, err error) {
|
||||
l.Log(ctx, newEntry(ac, ac.TeamID, "team", "sandbox", id.FormatSandboxID(sandboxID), "pause", auditStatusFor(err, "success"), mergeMeta(nil, err)))
|
||||
l.publish(ctx, events.Event{
|
||||
Event: events.CapsulePaused,
|
||||
Event: events.CapsulePause,
|
||||
Outcome: outcomeFromErr(err),
|
||||
Timestamp: events.Now(),
|
||||
TeamID: id.FormatTeamID(ac.TeamID),
|
||||
Actor: actorToEvent(ac),
|
||||
Resource: events.Resource{ID: id.FormatSandboxID(sandboxID), Type: "sandbox"},
|
||||
Error: errString(err),
|
||||
})
|
||||
}
|
||||
|
||||
// LogSandboxAutoPause records a system-initiated auto-pause (TTL or host reconciler).
|
||||
func (l *AuditLogger) LogSandboxAutoPause(ctx context.Context, teamID, sandboxID pgtype.UUID) {
|
||||
// LogSandboxAutoPause records a system-initiated auto-pause (TTL reaper or
|
||||
// reconciler restoration of paused state). Always system actor; metadata
|
||||
// carries the reason (e.g. "ttl_expired", "restored_after_host_recovery").
|
||||
func (l *AuditLogger) LogSandboxAutoPause(ctx context.Context, teamID, sandboxID pgtype.UUID, reason string, err error) {
|
||||
meta := map[string]any{"reason": reason}
|
||||
l.Log(ctx, Entry{
|
||||
TeamID: teamID, ActorType: "system",
|
||||
ResourceType: "sandbox", ResourceID: id.FormatSandboxID(sandboxID),
|
||||
Action: "pause", Scope: "team", Status: "info",
|
||||
Action: "pause", Scope: "team", Status: auditStatusFor(err, "info"),
|
||||
Metadata: mergeMeta(meta, err),
|
||||
})
|
||||
l.publish(ctx, events.Event{
|
||||
Event: events.CapsulePaused,
|
||||
Event: events.CapsulePause,
|
||||
Outcome: outcomeFromErr(err),
|
||||
Timestamp: events.Now(),
|
||||
TeamID: id.FormatTeamID(teamID),
|
||||
Actor: systemActor(),
|
||||
Actor: events.SystemActor(),
|
||||
Resource: events.Resource{ID: id.FormatSandboxID(sandboxID), Type: "sandbox"},
|
||||
Metadata: map[string]string{"reason": reason},
|
||||
Error: errString(err),
|
||||
})
|
||||
}
|
||||
|
||||
func (l *AuditLogger) LogSandboxResume(ctx context.Context, ac auth.AuthContext, sandboxID pgtype.UUID) {
|
||||
l.Log(ctx, newEntry(ac, ac.TeamID, "team", "sandbox", id.FormatSandboxID(sandboxID), "resume", "success", nil))
|
||||
// LogSandboxResume records a user-initiated unpause (resume from paused state).
|
||||
func (l *AuditLogger) LogSandboxResume(ctx context.Context, ac auth.AuthContext, sandboxID pgtype.UUID, err error) {
|
||||
l.Log(ctx, newEntry(ac, ac.TeamID, "team", "sandbox", id.FormatSandboxID(sandboxID), "resume", auditStatusFor(err, "success"), mergeMeta(nil, err)))
|
||||
l.publish(ctx, events.Event{
|
||||
Event: events.CapsuleRunning,
|
||||
Event: events.CapsuleResume,
|
||||
Outcome: outcomeFromErr(err),
|
||||
Timestamp: events.Now(),
|
||||
TeamID: id.FormatTeamID(ac.TeamID),
|
||||
Actor: actorToEvent(ac),
|
||||
Resource: events.Resource{ID: id.FormatSandboxID(sandboxID), Type: "sandbox"},
|
||||
Error: errString(err),
|
||||
})
|
||||
}
|
||||
|
||||
func (l *AuditLogger) LogSandboxDestroy(ctx context.Context, ac auth.AuthContext, sandboxID pgtype.UUID) {
|
||||
l.Log(ctx, newEntry(ac, ac.TeamID, "team", "sandbox", id.FormatSandboxID(sandboxID), "destroy", "warning", nil))
|
||||
// LogSandboxDestroy records a destroy action. ac carries the actor (user / api_key / system).
|
||||
// reason is added to metadata when non-empty (e.g. "orphaned", "cleanup_after_create_error", "ttl_expired").
|
||||
func (l *AuditLogger) LogSandboxDestroy(ctx context.Context, ac auth.AuthContext, sandboxID pgtype.UUID, err error) {
|
||||
l.LogSandboxDestroyWithReason(ctx, ac, sandboxID, "", err)
|
||||
}
|
||||
|
||||
// LogSandboxDestroyWithReason is LogSandboxDestroy with an explicit reason.
|
||||
func (l *AuditLogger) LogSandboxDestroyWithReason(ctx context.Context, ac auth.AuthContext, sandboxID pgtype.UUID, reason string, err error) {
|
||||
var (
|
||||
auditMeta map[string]any
|
||||
evtMeta map[string]string
|
||||
)
|
||||
if reason != "" {
|
||||
auditMeta = map[string]any{"reason": reason}
|
||||
evtMeta = map[string]string{"reason": reason}
|
||||
}
|
||||
l.Log(ctx, newEntry(ac, ac.TeamID, "team", "sandbox", id.FormatSandboxID(sandboxID), "destroy", auditStatusFor(err, "warning"), mergeMeta(auditMeta, err)))
|
||||
l.publish(ctx, events.Event{
|
||||
Event: events.CapsuleDestroyed,
|
||||
Event: events.CapsuleDestroy,
|
||||
Outcome: outcomeFromErr(err),
|
||||
Timestamp: events.Now(),
|
||||
TeamID: id.FormatTeamID(ac.TeamID),
|
||||
Actor: actorToEvent(ac),
|
||||
Resource: events.Resource{ID: id.FormatSandboxID(sandboxID), Type: "sandbox"},
|
||||
Metadata: evtMeta,
|
||||
Error: errString(err),
|
||||
})
|
||||
}
|
||||
|
||||
// LogSandboxCreateSystem records a system-derived create outcome (e.g. the
|
||||
// reconciler inferring a failed first-boot after the grace period expired).
|
||||
// reason is added to metadata; err controls outcome.
|
||||
func (l *AuditLogger) LogSandboxCreateSystem(ctx context.Context, teamID, sandboxID pgtype.UUID, reason string, err error) {
|
||||
meta := map[string]any{"reason": reason}
|
||||
l.Log(ctx, Entry{
|
||||
TeamID: teamID, ActorType: "system",
|
||||
ResourceType: "sandbox", ResourceID: id.FormatSandboxID(sandboxID),
|
||||
Action: "create", Scope: "team", Status: auditStatusFor(err, "info"),
|
||||
Metadata: mergeMeta(meta, err),
|
||||
})
|
||||
l.publish(ctx, events.Event{
|
||||
Event: events.CapsuleCreate,
|
||||
Outcome: outcomeFromErr(err),
|
||||
Timestamp: events.Now(),
|
||||
TeamID: id.FormatTeamID(teamID),
|
||||
Actor: events.SystemActor(),
|
||||
Resource: events.Resource{ID: id.FormatSandboxID(sandboxID), Type: "sandbox"},
|
||||
Metadata: map[string]string{"reason": reason},
|
||||
Error: errString(err),
|
||||
})
|
||||
}
|
||||
|
||||
// LogSandboxResumeSystem records a system-derived resume outcome (typically
|
||||
// reconciler-inferred error after the grace period).
|
||||
func (l *AuditLogger) LogSandboxResumeSystem(ctx context.Context, teamID, sandboxID pgtype.UUID, reason string, err error) {
|
||||
meta := map[string]any{"reason": reason}
|
||||
l.Log(ctx, Entry{
|
||||
TeamID: teamID, ActorType: "system",
|
||||
ResourceType: "sandbox", ResourceID: id.FormatSandboxID(sandboxID),
|
||||
Action: "resume", Scope: "team", Status: auditStatusFor(err, "info"),
|
||||
Metadata: mergeMeta(meta, err),
|
||||
})
|
||||
l.publish(ctx, events.Event{
|
||||
Event: events.CapsuleResume,
|
||||
Outcome: outcomeFromErr(err),
|
||||
Timestamp: events.Now(),
|
||||
TeamID: id.FormatTeamID(teamID),
|
||||
Actor: events.SystemActor(),
|
||||
Resource: events.Resource{ID: id.FormatSandboxID(sandboxID), Type: "sandbox"},
|
||||
Metadata: map[string]string{"reason": reason},
|
||||
Error: errString(err),
|
||||
})
|
||||
}
|
||||
|
||||
// LogSandboxDestroySystem records a system-initiated destroy (orphan cleanup,
|
||||
// cleanup-on-error, reconciler grace-period expiry). Always system actor.
|
||||
func (l *AuditLogger) LogSandboxDestroySystem(ctx context.Context, teamID, sandboxID pgtype.UUID, reason string, err error) {
|
||||
meta := map[string]any{"reason": reason}
|
||||
l.Log(ctx, Entry{
|
||||
TeamID: teamID, ActorType: "system",
|
||||
ResourceType: "sandbox", ResourceID: id.FormatSandboxID(sandboxID),
|
||||
Action: "destroy", Scope: "team", Status: auditStatusFor(err, "warning"),
|
||||
Metadata: mergeMeta(meta, err),
|
||||
})
|
||||
l.publish(ctx, events.Event{
|
||||
Event: events.CapsuleDestroy,
|
||||
Outcome: outcomeFromErr(err),
|
||||
Timestamp: events.Now(),
|
||||
TeamID: id.FormatTeamID(teamID),
|
||||
Actor: events.SystemActor(),
|
||||
Resource: events.Resource{ID: id.FormatSandboxID(sandboxID), Type: "sandbox"},
|
||||
Metadata: map[string]string{"reason": reason},
|
||||
Error: errString(err),
|
||||
})
|
||||
}
|
||||
|
||||
// LogSandboxStateChanged is a transient (SSE-only) event for ephemeral status
|
||||
// transitions (e.g. running → pausing → paused). Writes no audit row.
|
||||
func (l *AuditLogger) LogSandboxStateChanged(ctx context.Context, teamID, sandboxID pgtype.UUID, from, to string) {
|
||||
l.publishTransient(ctx, events.Event{
|
||||
Event: events.CapsuleStateChanged,
|
||||
Timestamp: events.Now(),
|
||||
TeamID: id.FormatTeamID(teamID),
|
||||
Actor: events.SystemActor(),
|
||||
Resource: events.Resource{ID: id.FormatSandboxID(sandboxID), Type: "sandbox"},
|
||||
Metadata: map[string]string{"from": from, "to": to},
|
||||
})
|
||||
}
|
||||
|
||||
// --- Snapshot events (scope: team) ---
|
||||
|
||||
func (l *AuditLogger) LogSnapshotCreate(ctx context.Context, ac auth.AuthContext, name string) {
|
||||
// LogSnapshotCreateRequested records that a user requested an async snapshot.
|
||||
// It writes the user-attributed audit row only — the terminal success/failure
|
||||
// event is published later by the background goroutine (system actor). Mirrors
|
||||
// the accept-time audit pattern used by LogSandboxPause.
|
||||
func (l *AuditLogger) LogSnapshotCreateRequested(ctx context.Context, ac auth.AuthContext, name string) {
|
||||
l.Log(ctx, newEntry(ac, ac.TeamID, "team", "snapshot", name, "create", "success", nil))
|
||||
l.publish(ctx, events.Event{
|
||||
Event: events.SnapshotCreated,
|
||||
Timestamp: events.Now(),
|
||||
TeamID: id.FormatTeamID(ac.TeamID),
|
||||
Actor: actorToEvent(ac),
|
||||
Resource: events.Resource{ID: name, Type: "snapshot"},
|
||||
}
|
||||
|
||||
// LogSnapshotCreateSystem records a system-actor snapshot transition inferred
|
||||
// by a reconciler (e.g. the HostMonitor recovering or failing a sandbox stuck
|
||||
// in "snapshotting"). It writes an audit row only and does NOT publish a
|
||||
// SnapshotCreate event: the reconciler has no template name, and emitting one
|
||||
// would surface a spurious "snapshot captured/failed" toast.
|
||||
func (l *AuditLogger) LogSnapshotCreateSystem(ctx context.Context, teamID, sandboxID pgtype.UUID, reason string, err error) {
|
||||
l.Log(ctx, Entry{
|
||||
TeamID: teamID, ActorType: "system",
|
||||
ResourceType: "sandbox", ResourceID: id.FormatSandboxID(sandboxID),
|
||||
Action: "snapshot", Scope: "team", Status: auditStatusFor(err, "info"),
|
||||
Metadata: mergeMeta(map[string]any{"reason": reason}, err),
|
||||
})
|
||||
}
|
||||
|
||||
func (l *AuditLogger) LogSnapshotDelete(ctx context.Context, ac auth.AuthContext, name string) {
|
||||
l.Log(ctx, newEntry(ac, ac.TeamID, "team", "snapshot", name, "delete", "warning", nil))
|
||||
func (l *AuditLogger) LogSnapshotDelete(ctx context.Context, ac auth.AuthContext, name string, err error) {
|
||||
l.Log(ctx, newEntry(ac, ac.TeamID, "team", "snapshot", name, "delete", auditStatusFor(err, "warning"), mergeMeta(nil, err)))
|
||||
l.publish(ctx, events.Event{
|
||||
Event: events.SnapshotDeleted,
|
||||
Event: events.SnapshotDelete,
|
||||
Outcome: outcomeFromErr(err),
|
||||
Timestamp: events.Now(),
|
||||
TeamID: id.FormatTeamID(ac.TeamID),
|
||||
Actor: actorToEvent(ac),
|
||||
Resource: events.Resource{ID: name, Type: "snapshot"},
|
||||
Error: errString(err),
|
||||
})
|
||||
}
|
||||
|
||||
// LogSnapshotDeleteSystem records system-initiated snapshot cleanup
|
||||
// (e.g. rollback after a failed snapshot create). Always system actor.
|
||||
func (l *AuditLogger) LogSnapshotDeleteSystem(ctx context.Context, teamID pgtype.UUID, name, reason string, err error) {
|
||||
meta := map[string]any{"reason": reason}
|
||||
l.Log(ctx, Entry{
|
||||
TeamID: teamID, ActorType: "system",
|
||||
ResourceType: "snapshot", ResourceID: name,
|
||||
Action: "delete", Scope: "team", Status: auditStatusFor(err, "warning"),
|
||||
Metadata: mergeMeta(meta, err),
|
||||
})
|
||||
l.publish(ctx, events.Event{
|
||||
Event: events.SnapshotDelete,
|
||||
Outcome: outcomeFromErr(err),
|
||||
Timestamp: events.Now(),
|
||||
TeamID: id.FormatTeamID(teamID),
|
||||
Actor: events.SystemActor(),
|
||||
Resource: events.Resource{ID: name, Type: "snapshot"},
|
||||
Metadata: map[string]string{"reason": reason},
|
||||
Error: errString(err),
|
||||
})
|
||||
}
|
||||
|
||||
@ -350,7 +542,7 @@ func (l *AuditLogger) logSystemHostEvent(ctx context.Context, teamID, hostID pgt
|
||||
Event: ev,
|
||||
Timestamp: events.Now(),
|
||||
TeamID: id.FormatTeamID(teamID),
|
||||
Actor: systemActor(),
|
||||
Actor: events.SystemActor(),
|
||||
Resource: events.Resource{ID: id.FormatHostID(hostID), Type: "host"},
|
||||
})
|
||||
}
|
||||
|
||||
@ -17,9 +17,10 @@ type AuthContext struct {
|
||||
Email string // empty when authenticated via API key
|
||||
Name string // empty when authenticated via API key
|
||||
Role string // owner, admin, or member; empty when authenticated via API key
|
||||
IsAdmin bool // platform-level admin; always false when authenticated via API key
|
||||
APIKeyID pgtype.UUID // populated when authenticated via API key; zero value for JWT auth
|
||||
APIKeyName string // display name of the key, snapshotted at auth time; empty for JWT auth
|
||||
IsAdmin bool // session-cached flag; admin gates always re-verify against the DB
|
||||
APIKeyID pgtype.UUID // populated when authenticated via API key; zero value for session auth
|
||||
APIKeyName string // display name of the key, snapshotted at auth time; empty for session auth
|
||||
SessionID string // populated for cookie-session auth; empty for API key auth
|
||||
}
|
||||
|
||||
// WithAuthContext returns a new context with the given AuthContext.
|
||||
|
||||
@ -10,62 +10,9 @@ import (
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
const jwtExpiry = 6 * time.Hour
|
||||
const hostJWTExpiry = 7 * 24 * time.Hour // 7 days; host refreshes via refresh token
|
||||
const HostRefreshTokenExpiry = 60 * 24 * time.Hour // 60 days; exported for service layer
|
||||
|
||||
// Claims are the JWT payload for user tokens.
|
||||
type Claims struct {
|
||||
Type string `json:"typ,omitempty"` // empty for user tokens; used to reject host tokens
|
||||
TeamID string `json:"team_id"`
|
||||
Role string `json:"role"` // owner, admin, or member within TeamID
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
IsAdmin bool `json:"is_admin,omitempty"` // platform-level admin flag
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// SignJWT signs a new 6-hour JWT for the given user.
|
||||
func SignJWT(secret []byte, userID, teamID pgtype.UUID, email, name, role string, isAdmin bool) (string, error) {
|
||||
now := time.Now()
|
||||
claims := Claims{
|
||||
TeamID: id.FormatTeamID(teamID),
|
||||
Role: role,
|
||||
Email: email,
|
||||
Name: name,
|
||||
IsAdmin: isAdmin,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: id.FormatUserID(userID),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(jwtExpiry)),
|
||||
},
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString(secret)
|
||||
}
|
||||
|
||||
// VerifyJWT parses and validates a user JWT, returning the claims on success.
|
||||
// Rejects host JWTs (which carry a "typ" claim) to prevent cross-token confusion.
|
||||
func VerifyJWT(secret []byte, tokenStr string) (Claims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenStr, &Claims{}, func(t *jwt.Token) (any, error) {
|
||||
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
|
||||
}
|
||||
return secret, nil
|
||||
})
|
||||
if err != nil {
|
||||
return Claims{}, fmt.Errorf("invalid token: %w", err)
|
||||
}
|
||||
c, ok := token.Claims.(*Claims)
|
||||
if !ok || !token.Valid {
|
||||
return Claims{}, fmt.Errorf("invalid token claims")
|
||||
}
|
||||
if c.Type == "host" {
|
||||
return Claims{}, fmt.Errorf("invalid token: host token cannot be used as user token")
|
||||
}
|
||||
return *c, nil
|
||||
}
|
||||
|
||||
// HostClaims are the JWT payload for host agent tokens.
|
||||
type HostClaims struct {
|
||||
Type string `json:"typ"` // always "host"
|
||||
|
||||
292
pkg/auth/session/middleware/middleware.go
Normal file
292
pkg/auth/session/middleware/middleware.go
Normal file
@ -0,0 +1,292 @@
|
||||
// Package middleware exposes the session/CSRF middleware and cookie helpers
|
||||
// that gate the browser-facing control plane API. It is the single source of
|
||||
// truth — both internal/api and cloud extensions call into this package so
|
||||
// auth semantics never diverge.
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth/session"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
// Cookie + header names. Exported so extensions and frontends can reference
|
||||
// the canonical values instead of hardcoding strings.
|
||||
const (
|
||||
SessionCookieName = "wrenn_sid"
|
||||
CSRFCookieName = "wrenn_csrf"
|
||||
CSRFHeaderName = "X-CSRF-Token"
|
||||
)
|
||||
|
||||
type errorBody struct {
|
||||
Error errorDetail `json:"error"`
|
||||
}
|
||||
|
||||
type errorDetail struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
func writeError(w http.ResponseWriter, status int, code, message string) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
_ = json.NewEncoder(w).Encode(errorBody{Error: errorDetail{Code: code, Message: message}})
|
||||
}
|
||||
|
||||
// IsSecure reports whether the inbound request should produce Secure cookies.
|
||||
// Honors X-Forwarded-Proto for deployments behind TLS-terminating proxies.
|
||||
func IsSecure(r *http.Request) bool {
|
||||
return r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https"
|
||||
}
|
||||
|
||||
// SetCookies writes both the opaque session-id cookie (HttpOnly) and the
|
||||
// JS-readable CSRF cookie used for double-submit validation.
|
||||
func SetCookies(w http.ResponseWriter, sid, csrfToken string, secure bool) {
|
||||
maxAge := int(session.AbsoluteCap.Seconds())
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: SessionCookieName,
|
||||
Value: sid,
|
||||
Path: "/",
|
||||
MaxAge: maxAge,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
Secure: secure,
|
||||
})
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: CSRFCookieName,
|
||||
Value: csrfToken,
|
||||
Path: "/",
|
||||
MaxAge: maxAge,
|
||||
HttpOnly: false,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
Secure: secure,
|
||||
})
|
||||
}
|
||||
|
||||
// ClearCookies invalidates the session and CSRF cookies on the response.
|
||||
func ClearCookies(w http.ResponseWriter, secure bool) {
|
||||
for _, name := range []string{SessionCookieName, CSRFCookieName} {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: name,
|
||||
Value: "",
|
||||
Path: "/",
|
||||
MaxAge: -1,
|
||||
HttpOnly: name == SessionCookieName,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
Secure: secure,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ResolveSession reads the session cookie and returns the hydrated session,
|
||||
// or session.ErrNotFound / session.ErrExpired on failure.
|
||||
func ResolveSession(ctx context.Context, queries *db.Queries, svc *session.Service, r *http.Request) (*session.Session, error) {
|
||||
cookie, err := r.Cookie(SessionCookieName)
|
||||
if err != nil || cookie.Value == "" {
|
||||
return nil, session.ErrNotFound
|
||||
}
|
||||
return svc.Get(ctx, cookie.Value, hydrateFromDB(queries))
|
||||
}
|
||||
|
||||
func hydrateFromDB(queries *db.Queries) func(context.Context, *session.Session) error {
|
||||
return func(ctx context.Context, sess *session.Session) error {
|
||||
user, err := queries.GetUserByID(ctx, sess.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if user.Status != "active" {
|
||||
return errors.New("account not active")
|
||||
}
|
||||
membership, err := queries.GetTeamMembership(ctx, db.GetTeamMembershipParams{
|
||||
UserID: sess.UserID,
|
||||
TeamID: sess.TeamID,
|
||||
})
|
||||
role := ""
|
||||
if err == nil {
|
||||
role = membership.Role
|
||||
} else if !errors.Is(err, pgx.ErrNoRows) {
|
||||
return err
|
||||
}
|
||||
sess.Email = user.Email
|
||||
sess.Name = user.Name
|
||||
sess.Role = role
|
||||
sess.IsAdmin = user.IsAdmin
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// AuthContextFromSession builds the AuthContext middleware stamps into the
|
||||
// request context after a successful session lookup.
|
||||
func AuthContextFromSession(sess *session.Session) auth.AuthContext {
|
||||
return auth.AuthContext{
|
||||
TeamID: sess.TeamID,
|
||||
UserID: sess.UserID,
|
||||
Email: sess.Email,
|
||||
Name: sess.Name,
|
||||
Role: sess.Role,
|
||||
IsAdmin: sess.IsAdmin,
|
||||
SessionID: sess.ID,
|
||||
}
|
||||
}
|
||||
|
||||
// AuthenticateAPIKey validates an X-API-Key value and returns a request
|
||||
// context carrying the API-key-scoped AuthContext on success.
|
||||
func AuthenticateAPIKey(ctx context.Context, queries *db.Queries, key, ip string) (context.Context, bool) {
|
||||
hash := auth.HashAPIKey(key)
|
||||
row, err := queries.GetAPIKeyByHash(ctx, hash)
|
||||
if err != nil {
|
||||
slog.Warn("api key auth failed", "prefix", auth.APIKeyPrefix(key), "ip", ip)
|
||||
return ctx, false
|
||||
}
|
||||
if err := queries.UpdateAPIKeyLastUsed(ctx, row.ID); err != nil {
|
||||
slog.Warn("failed to update api key last_used", "key_id", id.FormatAPIKeyID(row.ID), "error", err)
|
||||
}
|
||||
return auth.WithAuthContext(ctx, auth.AuthContext{
|
||||
TeamID: row.TeamID,
|
||||
APIKeyID: row.ID,
|
||||
APIKeyName: row.Name,
|
||||
}), true
|
||||
}
|
||||
|
||||
// RequireSession returns middleware that allows only requests carrying a
|
||||
// valid session cookie. On failure it clears stale cookies and responds 401.
|
||||
func RequireSession(svc *session.Service, queries *db.Queries) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
sess, err := ResolveSession(r.Context(), queries, svc, r)
|
||||
if err != nil {
|
||||
ClearCookies(w, IsSecure(r))
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "valid session required")
|
||||
return
|
||||
}
|
||||
ctx := auth.WithAuthContext(r.Context(), AuthContextFromSession(sess))
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RequireSessionOrAPIKey accepts X-API-Key (SDK) or wrenn_sid cookie (browser).
|
||||
func RequireSessionOrAPIKey(svc *session.Service, queries *db.Queries) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if key := r.Header.Get("X-API-Key"); key != "" {
|
||||
if ctx, ok := AuthenticateAPIKey(r.Context(), queries, key, r.RemoteAddr); ok {
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid API key")
|
||||
return
|
||||
}
|
||||
sess, err := ResolveSession(r.Context(), queries, svc, r)
|
||||
if err != nil {
|
||||
ClearCookies(w, IsSecure(r))
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "X-API-Key header or session cookie required")
|
||||
return
|
||||
}
|
||||
ctx := auth.WithAuthContext(r.Context(), AuthContextFromSession(sess))
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RequireAdmin enforces that the authenticated user is a platform admin.
|
||||
// Must run after RequireSession. Re-reads is_admin from Postgres so a freshly
|
||||
// revoked admin loses access on the next request — the cached session blob is
|
||||
// only used for UI hints, never authorization.
|
||||
func RequireAdmin(queries *db.Queries) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ac, ok := auth.FromContext(r.Context())
|
||||
if !ok {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "authentication required")
|
||||
return
|
||||
}
|
||||
user, err := queries.GetUserByID(r.Context(), ac.UserID)
|
||||
if err != nil || !user.IsAdmin {
|
||||
writeError(w, http.StatusForbidden, "forbidden", "admin access required")
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RequireCSRF returns middleware enforcing double-submit CSRF: the wrenn_csrf
|
||||
// cookie value must equal the X-CSRF-Token header. Skipped for safe methods
|
||||
// (GET/HEAD/OPTIONS) and for requests authenticated via X-API-Key.
|
||||
func RequireCSRF() func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodGet, http.MethodHead, http.MethodOptions:
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
if ac, ok := auth.FromContext(r.Context()); ok && ac.APIKeyID.Valid {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
cookie, err := r.Cookie(CSRFCookieName)
|
||||
header := r.Header.Get(CSRFHeaderName)
|
||||
if err != nil || cookie.Value == "" || header == "" ||
|
||||
subtle.ConstantTimeCompare([]byte(cookie.Value), []byte(header)) != 1 {
|
||||
writeError(w, http.StatusForbidden, "csrf_failed", "missing or invalid CSRF token")
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// IssueSession looks up identity columns, creates a fresh session, and writes
|
||||
// the cookies onto the response. Intended for extension flows (invite-accept,
|
||||
// admin impersonation, etc.) that need to log a user in without re-implementing
|
||||
// the cookie wire-up.
|
||||
func IssueSession(
|
||||
ctx context.Context,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
queries *db.Queries,
|
||||
svc *session.Service,
|
||||
userID, teamID pgtype.UUID,
|
||||
) (*session.Session, error) {
|
||||
user, err := queries.GetUserByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
role := ""
|
||||
membership, err := queries.GetTeamMembership(ctx, db.GetTeamMembershipParams{UserID: userID, TeamID: teamID})
|
||||
if err == nil {
|
||||
role = membership.Role
|
||||
} else if !errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, err
|
||||
}
|
||||
sess, err := svc.Create(ctx, userID, teamID, user.Email, user.Name, role, user.IsAdmin, r.UserAgent(), clientIP(r))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
SetCookies(w, sess.RawSID, sess.CSRFToken, IsSecure(r))
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
func clientIP(r *http.Request) string {
|
||||
if fwd := r.Header.Get("X-Forwarded-For"); fwd != "" {
|
||||
if i := strings.IndexByte(fwd, ','); i > 0 {
|
||||
return strings.TrimSpace(fwd[:i])
|
||||
}
|
||||
return strings.TrimSpace(fwd)
|
||||
}
|
||||
return r.RemoteAddr
|
||||
}
|
||||
443
pkg/auth/session/session.go
Normal file
443
pkg/auth/session/session.go
Normal file
@ -0,0 +1,443 @@
|
||||
// Package session implements opaque cookie-backed user sessions for the
|
||||
// browser-facing control plane. Sessions are stored durably in Postgres
|
||||
// (sessions table) and cached in Redis (wrenn:session:{sid}) for the hot
|
||||
// auth-middleware path.
|
||||
//
|
||||
// SIDs are 32 random bytes hex-encoded. CSRF tokens are issued alongside
|
||||
// each session and rotated on session rotation (e.g. team switch).
|
||||
//
|
||||
// Expiry has two limits:
|
||||
// - Idle: IdleWindow (6h) — Redis TTL slides on each successful Get.
|
||||
// - Absolute: AbsoluteCap (24h) — stored as expires_at; never extended.
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
)
|
||||
|
||||
const (
|
||||
// IdleWindow caps how long a session can sit idle before it expires.
|
||||
IdleWindow = 6 * time.Hour
|
||||
// AbsoluteCap caps total session lifetime regardless of activity.
|
||||
AbsoluteCap = 24 * time.Hour
|
||||
// touchDBInterval is the minimum gap between Postgres last_seen_at updates
|
||||
// for the same session. Redis TTL is bumped on every request; the DB is
|
||||
// only updated when stale by more than this interval.
|
||||
touchDBInterval = 1 * time.Minute
|
||||
|
||||
redisKeyPrefix = "wrenn:session:"
|
||||
)
|
||||
|
||||
// ErrNotFound is returned when no session exists for the given SID.
|
||||
var ErrNotFound = errors.New("session: not found")
|
||||
|
||||
// ErrExpired is returned when a session is past its absolute cap.
|
||||
var ErrExpired = errors.New("session: expired")
|
||||
|
||||
// Session is the in-memory representation of a logged-in user. The fields
|
||||
// after the identity block are denormalized from the users + team_members
|
||||
// tables for fast middleware lookups; they are refreshed on rotation and
|
||||
// invalidated by Revoke/RevokeAllForUser on identity changes.
|
||||
//
|
||||
// ID is the sha256(rawSID) hex digest — the value stored in Postgres and
|
||||
// used as the Redis cache key. RawSID is the un-hashed bearer secret;
|
||||
// it is only populated by Create and Rotate so the caller can write the
|
||||
// cookie, and is never serialized to Redis or persisted in Postgres.
|
||||
type Session struct {
|
||||
ID string `json:"id"`
|
||||
RawSID string `json:"-"`
|
||||
UserID pgtype.UUID `json:"user_id"`
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
CSRFToken string `json:"csrf"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
Role string `json:"role"`
|
||||
IsAdmin bool `json:"is_admin"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
LastSeenAt time.Time `json:"last_seen_at"`
|
||||
UserAgent string `json:"user_agent"`
|
||||
IPAddress string `json:"ip"`
|
||||
}
|
||||
|
||||
// Service issues, validates, and revokes sessions.
|
||||
type Service struct {
|
||||
db *db.Queries
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
// NewService constructs a session service backed by the given queries and
|
||||
// Redis client.
|
||||
func NewService(q *db.Queries, rdb *redis.Client) *Service {
|
||||
return &Service{db: q, rdb: rdb}
|
||||
}
|
||||
|
||||
// GenerateSID returns a fresh 32-byte hex-encoded session identifier.
|
||||
func GenerateSID() (string, error) {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// GenerateCSRFToken returns a fresh 32-byte hex-encoded CSRF token.
|
||||
func GenerateCSRFToken() (string, error) {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// HashSID returns the sha256 hex digest of a raw session ID. Storage and
|
||||
// lookups in Postgres + Redis use the hash; the raw value only lives in
|
||||
// the user's cookie and transiently in this process.
|
||||
func HashSID(rawSID string) string {
|
||||
sum := sha256.Sum256([]byte(rawSID))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
// Create issues a new session for the given user. Email, name, role and
|
||||
// is_admin are stamped into the session blob and used by middleware without
|
||||
// further DB lookups (except for admin gates, which always re-check the DB).
|
||||
func (s *Service) Create(
|
||||
ctx context.Context,
|
||||
userID, teamID pgtype.UUID,
|
||||
email, name, role string,
|
||||
isAdmin bool,
|
||||
userAgent, ipAddress string,
|
||||
) (*Session, error) {
|
||||
rawSID, err := GenerateSID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate sid: %w", err)
|
||||
}
|
||||
sidHash := HashSID(rawSID)
|
||||
csrf, err := GenerateCSRFToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate csrf: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
expiresAt := now.Add(AbsoluteCap)
|
||||
|
||||
row, err := s.db.InsertSession(ctx, db.InsertSessionParams{
|
||||
ID: sidHash,
|
||||
UserID: userID,
|
||||
TeamID: teamID,
|
||||
CsrfToken: csrf,
|
||||
UserAgent: userAgent,
|
||||
IpAddress: ipAddress,
|
||||
ExpiresAt: pgtype.Timestamptz{Time: expiresAt, Valid: true},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("insert session: %w", err)
|
||||
}
|
||||
|
||||
sess := &Session{
|
||||
ID: row.ID,
|
||||
RawSID: rawSID,
|
||||
UserID: row.UserID,
|
||||
TeamID: row.TeamID,
|
||||
CSRFToken: row.CsrfToken,
|
||||
Email: email,
|
||||
Name: name,
|
||||
Role: role,
|
||||
IsAdmin: isAdmin,
|
||||
CreatedAt: row.CreatedAt.Time,
|
||||
ExpiresAt: row.ExpiresAt.Time,
|
||||
LastSeenAt: row.LastSeenAt.Time,
|
||||
UserAgent: row.UserAgent,
|
||||
IPAddress: row.IpAddress,
|
||||
}
|
||||
|
||||
if err := s.writeCache(ctx, sess); err != nil {
|
||||
// Cache failures are non-fatal — middleware will fall back to DB.
|
||||
slog.Warn("session: write cache failed", "error", err)
|
||||
}
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
// Get loads a session by its raw SID (from the cookie), validates expiry,
|
||||
// and slides the idle window. The raw SID is hashed internally; storage and
|
||||
// lookups never see the un-hashed value. Returns ErrNotFound if the session
|
||||
// does not exist (or has been revoked) and ErrExpired if it is past its
|
||||
// absolute cap.
|
||||
//
|
||||
// The hydrate callback is invoked on cache miss to refetch identity columns
|
||||
// (email, name, role, is_admin) from the source tables before the session is
|
||||
// repopulated into Redis. Pass nil to skip identity refresh.
|
||||
func (s *Service) Get(
|
||||
ctx context.Context,
|
||||
rawSID string,
|
||||
hydrate func(context.Context, *Session) error,
|
||||
) (*Session, error) {
|
||||
if rawSID == "" {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
sidHash := HashSID(rawSID)
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
// Cache hit fast path.
|
||||
if sess, ok, err := s.readCache(ctx, sidHash); err != nil {
|
||||
slog.Warn("session: read cache failed", "error", err)
|
||||
} else if ok {
|
||||
if now.After(sess.ExpiresAt) {
|
||||
_ = s.revokeByHash(ctx, sidHash)
|
||||
return nil, ErrExpired
|
||||
}
|
||||
s.slideIdle(ctx, sess)
|
||||
s.maybeTouchDB(ctx, sess, now)
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
// Cache miss — fall back to DB.
|
||||
row, err := s.db.GetSession(ctx, sidHash)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("get session: %w", err)
|
||||
}
|
||||
if now.After(row.ExpiresAt.Time) {
|
||||
_ = s.db.DeleteSession(ctx, sidHash)
|
||||
return nil, ErrExpired
|
||||
}
|
||||
|
||||
sess := &Session{
|
||||
ID: row.ID,
|
||||
UserID: row.UserID,
|
||||
TeamID: row.TeamID,
|
||||
CSRFToken: row.CsrfToken,
|
||||
CreatedAt: row.CreatedAt.Time,
|
||||
ExpiresAt: row.ExpiresAt.Time,
|
||||
LastSeenAt: row.LastSeenAt.Time,
|
||||
UserAgent: row.UserAgent,
|
||||
IPAddress: row.IpAddress,
|
||||
}
|
||||
if hydrate != nil {
|
||||
if err := hydrate(ctx, sess); err != nil {
|
||||
return nil, fmt.Errorf("hydrate session: %w", err)
|
||||
}
|
||||
}
|
||||
if err := s.writeCache(ctx, sess); err != nil {
|
||||
slog.Warn("session: write cache failed", "error", err)
|
||||
}
|
||||
s.maybeTouchDB(ctx, sess, now)
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
// Rotate revokes the session identified by oldHashedSID and issues a fresh
|
||||
// one with possibly updated team/role. Used on team switch and any privilege
|
||||
// change. The old ID is a hash (taken from AuthContext.SessionID), not the
|
||||
// raw cookie value.
|
||||
func (s *Service) Rotate(
|
||||
ctx context.Context,
|
||||
oldHashedSID string,
|
||||
userID, teamID pgtype.UUID,
|
||||
email, name, role string,
|
||||
isAdmin bool,
|
||||
userAgent, ipAddress string,
|
||||
) (*Session, error) {
|
||||
if err := s.revokeByHash(ctx, oldHashedSID); err != nil {
|
||||
return nil, fmt.Errorf("revoke old: %w", err)
|
||||
}
|
||||
return s.Create(ctx, userID, teamID, email, name, role, isAdmin, userAgent, ipAddress)
|
||||
}
|
||||
|
||||
// Revoke deletes a single session by its hashed ID from both Redis and
|
||||
// Postgres. Callers in authenticated request paths already hold the hash
|
||||
// in AuthContext.SessionID; pass that value here.
|
||||
func (s *Service) Revoke(ctx context.Context, hashedSID string) error {
|
||||
return s.revokeByHash(ctx, hashedSID)
|
||||
}
|
||||
|
||||
func (s *Service) revokeByHash(ctx context.Context, sidHash string) error {
|
||||
if sidHash == "" {
|
||||
return nil
|
||||
}
|
||||
if err := s.rdb.Del(ctx, redisKey(sidHash)).Err(); err != nil {
|
||||
slog.Warn("session: del cache failed", "error", err)
|
||||
}
|
||||
if err := s.db.DeleteSession(ctx, sidHash); err != nil {
|
||||
return fmt.Errorf("delete session: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RevokeAllForUser deletes every session for a user. Used on password
|
||||
// add/change/reset and on logout-all.
|
||||
func (s *Service) RevokeAllForUser(ctx context.Context, userID pgtype.UUID) error {
|
||||
ids, err := s.db.DeleteSessionsByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete user sessions: %w", err)
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
pipe := s.rdb.Pipeline()
|
||||
for _, id := range ids {
|
||||
pipe.Del(ctx, redisKey(id))
|
||||
}
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
slog.Warn("session: pipeline del failed", "error", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListForUser returns all active session rows for a user, newest activity
|
||||
// first. Backed by Postgres directly — the Redis cache is opportunistic and
|
||||
// is not consulted here.
|
||||
func (s *Service) ListForUser(ctx context.Context, userID pgtype.UUID) ([]db.Session, error) {
|
||||
rows, err := s.db.ListSessionsByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list sessions: %w", err)
|
||||
}
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
// DeleteForUser deletes a single session if it belongs to the given user.
|
||||
// hashedSID is the stored hash (as returned by ListForUser / AuthContext),
|
||||
// not the raw cookie value. Returns no error if the SID does not exist or
|
||||
// belongs to someone else (caller is treated as having already lost
|
||||
// interest in it).
|
||||
func (s *Service) DeleteForUser(ctx context.Context, hashedSID string, userID pgtype.UUID) error {
|
||||
if err := s.rdb.Del(ctx, redisKey(hashedSID)).Err(); err != nil {
|
||||
slog.Warn("session: del cache failed", "error", err)
|
||||
}
|
||||
if err := s.db.DeleteSessionForUser(ctx, db.DeleteSessionForUserParams{ID: hashedSID, UserID: userID}); err != nil {
|
||||
return fmt.Errorf("delete session for user: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// InvalidateCacheForUser drops Redis cache entries for every session
|
||||
// belonging to the given user without revoking the underlying DB rows.
|
||||
// Next request rehydrates the session from Postgres + identity tables —
|
||||
// useful after a name change so cached identity is refreshed cheaply.
|
||||
func (s *Service) InvalidateCacheForUser(ctx context.Context, userID pgtype.UUID) error {
|
||||
rows, err := s.db.ListSessionsByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list sessions: %w", err)
|
||||
}
|
||||
if len(rows) == 0 {
|
||||
return nil
|
||||
}
|
||||
pipe := s.rdb.Pipeline()
|
||||
for _, row := range rows {
|
||||
pipe.Del(ctx, redisKey(row.ID))
|
||||
}
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
return fmt.Errorf("invalidate cache: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateTeam mutates the team_id on the current session in place (for the
|
||||
// non-rotation switch-team path; we do rotate in handlers, but the helper
|
||||
// is kept for completeness). hashedSID is the stored hash, not the raw
|
||||
// cookie value.
|
||||
func (s *Service) UpdateTeam(ctx context.Context, hashedSID string, teamID pgtype.UUID) error {
|
||||
if err := s.db.UpdateSessionTeam(ctx, db.UpdateSessionTeamParams{ID: hashedSID, TeamID: teamID}); err != nil {
|
||||
return fmt.Errorf("update session team: %w", err)
|
||||
}
|
||||
_ = s.rdb.Del(ctx, redisKey(hashedSID))
|
||||
return nil
|
||||
}
|
||||
|
||||
// StartCleaner returns a background-worker function that periodically prunes
|
||||
// rows whose absolute expiry has passed. Register it via
|
||||
// cpserver/cpextension BackgroundWorkers wiring.
|
||||
func (s *Service) StartCleaner() func(context.Context) {
|
||||
return func(ctx context.Context) {
|
||||
ticker := time.NewTicker(15 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.db.DeleteExpiredSessions(ctx); err != nil {
|
||||
slog.Warn("session: delete expired failed", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- internals ---
|
||||
|
||||
func redisKey(sid string) string { return redisKeyPrefix + sid }
|
||||
|
||||
func (s *Service) readCache(ctx context.Context, sid string) (*Session, bool, error) {
|
||||
raw, err := s.rdb.Get(ctx, redisKey(sid)).Bytes()
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return nil, false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
var sess Session
|
||||
if err := json.Unmarshal(raw, &sess); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
return &sess, true, nil
|
||||
}
|
||||
|
||||
func (s *Service) writeCache(ctx context.Context, sess *Session) error {
|
||||
buf, err := json.Marshal(sess)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ttl := time.Until(sess.ExpiresAt)
|
||||
if ttl > IdleWindow {
|
||||
ttl = IdleWindow
|
||||
}
|
||||
if ttl <= 0 {
|
||||
return nil
|
||||
}
|
||||
return s.rdb.Set(ctx, redisKey(sess.ID), buf, ttl).Err()
|
||||
}
|
||||
|
||||
func (s *Service) slideIdle(ctx context.Context, sess *Session) {
|
||||
ttl := time.Until(sess.ExpiresAt)
|
||||
if ttl > IdleWindow {
|
||||
ttl = IdleWindow
|
||||
}
|
||||
if ttl <= 0 {
|
||||
return
|
||||
}
|
||||
if err := s.rdb.Expire(ctx, redisKey(sess.ID), ttl).Err(); err != nil {
|
||||
slog.Warn("session: expire failed", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) maybeTouchDB(ctx context.Context, sess *Session, now time.Time) {
|
||||
if now.Sub(sess.LastSeenAt) < touchDBInterval {
|
||||
return
|
||||
}
|
||||
sid := sess.ID
|
||||
sess.LastSeenAt = now
|
||||
go func() {
|
||||
c, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := s.db.TouchSession(c, sid); err != nil {
|
||||
slog.Warn("session: touch db failed", "error", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
@ -101,6 +101,10 @@ func (d *Dispatcher) handleMessage(ctx context.Context, msg redis.XMessage) {
|
||||
return
|
||||
}
|
||||
|
||||
if isRedundantSystemFollowup(event) {
|
||||
return
|
||||
}
|
||||
|
||||
teamID, err := id.ParseTeamID(event.TeamID)
|
||||
if err != nil {
|
||||
slog.Warn("channels: invalid team ID in event", "team_id", event.TeamID, "error", err)
|
||||
@ -181,3 +185,23 @@ func (d *Dispatcher) decryptConfig(configJSON []byte) (map[string]string, error)
|
||||
func isGroupExistsError(err error) bool {
|
||||
return err != nil && err.Error() == "BUSYGROUP Consumer Group name already exists"
|
||||
}
|
||||
|
||||
// isRedundantSystemFollowup filters out capsule lifecycle events emitted by
|
||||
// the SandboxService background goroutine / host-agent callback after a
|
||||
// user-initiated action. The corresponding handler already publishes a
|
||||
// user-actor event for the same intent; without this filter, every user
|
||||
// action delivers two notifications.
|
||||
//
|
||||
// Genuinely system-only emitters (TTL auto-pause, host_monitor reconciler,
|
||||
// host-reported failures) always set Metadata["reason"], so they pass.
|
||||
func isRedundantSystemFollowup(e events.Event) bool {
|
||||
if e.Actor.Type != events.ActorSystem {
|
||||
return false
|
||||
}
|
||||
switch e.Event {
|
||||
case events.CapsuleCreate, events.CapsulePause, events.CapsuleResume, events.CapsuleDestroy:
|
||||
default:
|
||||
return false
|
||||
}
|
||||
return e.Metadata["reason"] == ""
|
||||
}
|
||||
|
||||
@ -14,27 +14,55 @@ func FormatMessage(e events.Event) string {
|
||||
|
||||
b.WriteString(formatSummary(e))
|
||||
fmt.Fprintf(&b, "\n\nEvent: %s", e.Event)
|
||||
if e.Outcome != "" {
|
||||
fmt.Fprintf(&b, "\nOutcome: %s", e.Outcome)
|
||||
}
|
||||
fmt.Fprintf(&b, "\nResource: %s %s", e.Resource.Type, e.Resource.ID)
|
||||
fmt.Fprintf(&b, "\nActor: %s", formatActor(e.Actor))
|
||||
fmt.Fprintf(&b, "\nTeam: %s", e.TeamID)
|
||||
fmt.Fprintf(&b, "\nTime: %s", e.Timestamp)
|
||||
if e.Error != "" {
|
||||
fmt.Fprintf(&b, "\nError: %s", e.Error)
|
||||
}
|
||||
if reason, ok := e.Metadata["reason"]; ok {
|
||||
fmt.Fprintf(&b, "\nReason: %s", reason)
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func formatSummary(e events.Event) string {
|
||||
failed := e.Outcome == events.OutcomeError
|
||||
switch e.Event {
|
||||
case events.CapsuleCreated:
|
||||
case events.CapsuleCreate:
|
||||
if failed {
|
||||
return fmt.Sprintf("Capsule %s failed to create", e.Resource.ID)
|
||||
}
|
||||
return fmt.Sprintf("Capsule %s created", e.Resource.ID)
|
||||
case events.CapsuleRunning:
|
||||
return fmt.Sprintf("Capsule %s is running", e.Resource.ID)
|
||||
case events.CapsulePaused:
|
||||
case events.CapsulePause:
|
||||
if failed {
|
||||
return fmt.Sprintf("Capsule %s failed to pause", e.Resource.ID)
|
||||
}
|
||||
return fmt.Sprintf("Capsule %s paused", e.Resource.ID)
|
||||
case events.CapsuleDestroyed:
|
||||
case events.CapsuleResume:
|
||||
if failed {
|
||||
return fmt.Sprintf("Capsule %s failed to resume", e.Resource.ID)
|
||||
}
|
||||
return fmt.Sprintf("Capsule %s resumed", e.Resource.ID)
|
||||
case events.CapsuleDestroy:
|
||||
if failed {
|
||||
return fmt.Sprintf("Capsule %s failed to destroy", e.Resource.ID)
|
||||
}
|
||||
return fmt.Sprintf("Capsule %s destroyed", e.Resource.ID)
|
||||
case events.SnapshotCreated:
|
||||
case events.SnapshotCreate:
|
||||
if failed {
|
||||
return fmt.Sprintf("Template snapshot %s failed to create", e.Resource.ID)
|
||||
}
|
||||
return fmt.Sprintf("Template snapshot %s created", e.Resource.ID)
|
||||
case events.SnapshotDeleted:
|
||||
case events.SnapshotDelete:
|
||||
if failed {
|
||||
return fmt.Sprintf("Template snapshot %s failed to delete", e.Resource.ID)
|
||||
}
|
||||
return fmt.Sprintf("Template snapshot %s deleted", e.Resource.ID)
|
||||
case events.HostUp:
|
||||
return fmt.Sprintf("Host %s is up", e.Resource.ID)
|
||||
|
||||
@ -10,7 +10,10 @@ import (
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/events"
|
||||
)
|
||||
|
||||
const streamKey = "wrenn:events"
|
||||
const (
|
||||
streamKey = "wrenn:events"
|
||||
ssePubSubChannel = "wrenn:sse"
|
||||
)
|
||||
|
||||
// Publisher pushes events onto the Redis stream for the dispatcher to consume.
|
||||
type Publisher struct {
|
||||
@ -22,8 +25,9 @@ func NewPublisher(rdb *redis.Client) *Publisher {
|
||||
return &Publisher{rdb: rdb}
|
||||
}
|
||||
|
||||
// Publish serializes the event and appends it to the global stream.
|
||||
// Fire-and-forget: failures are logged, never propagated.
|
||||
// Publish serializes the event, appends it to the durable Redis stream
|
||||
// (consumed by channel dispatcher for webhook/telegram delivery), and
|
||||
// mirrors it on the SSE Pub/Sub channel for the dashboard. Fire-and-forget.
|
||||
func (p *Publisher) Publish(ctx context.Context, e events.Event) {
|
||||
payload, err := json.Marshal(e)
|
||||
if err != nil {
|
||||
@ -41,4 +45,24 @@ func (p *Publisher) Publish(ctx context.Context, e events.Event) {
|
||||
}).Err(); err != nil {
|
||||
slog.Warn("channels: failed to publish event", "event", e.Event, "error", err)
|
||||
}
|
||||
|
||||
if err := p.rdb.Publish(ctx, ssePubSubChannel, string(payload)).Err(); err != nil {
|
||||
slog.Warn("channels: failed to publish SSE event", "event", e.Event, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// PublishTransient mirrors the event on the SSE Pub/Sub channel only — no
|
||||
// durable stream write, no channel dispatch. Used for ephemeral UI signals
|
||||
// (status transitions during start/pause/resume) that should reach the
|
||||
// dashboard live but must not be delivered to webhook/telegram subscribers.
|
||||
func (p *Publisher) PublishTransient(ctx context.Context, e events.Event) {
|
||||
payload, err := json.Marshal(e)
|
||||
if err != nil {
|
||||
slog.Warn("channels: failed to marshal transient event", "event", e.Event, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := p.rdb.Publish(ctx, ssePubSubChannel, string(payload)).Err(); err != nil {
|
||||
slog.Warn("channels: failed to publish transient SSE event", "event", e.Event, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -45,8 +45,8 @@ var requiredFields = map[string][]string{
|
||||
var validEvents map[string]bool
|
||||
|
||||
func init() {
|
||||
validEvents = make(map[string]bool, len(events.AllEventTypes))
|
||||
for _, et := range events.AllEventTypes {
|
||||
validEvents = make(map[string]bool, len(events.SubscribableEventTypes))
|
||||
for _, et := range events.SubscribableEventTypes {
|
||||
validEvents[et] = true
|
||||
}
|
||||
}
|
||||
|
||||
@ -6,13 +6,19 @@ package cpextension
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/audit"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth/oauth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth/session"
|
||||
sessionmw "git.omukk.dev/wrenn/wrenn/pkg/auth/session/middleware"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/channels"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/config"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/email"
|
||||
@ -20,19 +26,34 @@ import (
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/scheduler"
|
||||
)
|
||||
|
||||
// Re-exported cookie / header names so extensions can reference the canonical
|
||||
// values without depending on the middleware sub-package directly.
|
||||
const (
|
||||
SessionCookieName = sessionmw.SessionCookieName
|
||||
CSRFCookieName = sessionmw.CSRFCookieName
|
||||
CSRFHeaderName = sessionmw.CSRFHeaderName
|
||||
)
|
||||
|
||||
// ServerContext exposes the initialized dependencies that extensions can use
|
||||
// to register routes and start background workers. All fields are read-only
|
||||
// from the extension's perspective.
|
||||
type ServerContext struct {
|
||||
Queries *db.Queries
|
||||
PgPool *pgxpool.Pool
|
||||
Redis *redis.Client
|
||||
HostPool *lifecycle.HostClientPool
|
||||
Scheduler scheduler.HostScheduler
|
||||
CA *auth.CA
|
||||
Audit *audit.AuditLogger
|
||||
Mailer email.Mailer
|
||||
Queries *db.Queries
|
||||
PgPool *pgxpool.Pool
|
||||
Redis *redis.Client
|
||||
HostPool *lifecycle.HostClientPool
|
||||
Scheduler scheduler.HostScheduler
|
||||
CA *auth.CA
|
||||
Audit *audit.AuditLogger
|
||||
Mailer email.Mailer
|
||||
OAuthRegistry *oauth.Registry
|
||||
Channels *channels.Service
|
||||
ChannelPub *channels.Publisher
|
||||
// JWTSecret signs host-agent tokens and HMACs OAuth state cookies. User
|
||||
// auth uses cookie-backed sessions and does not depend on this value —
|
||||
// extensions should not use it to verify user identity.
|
||||
JWTSecret []byte
|
||||
Sessions *session.Service
|
||||
Config config.Config
|
||||
}
|
||||
|
||||
@ -56,3 +77,67 @@ type Extension interface {
|
||||
type MiddlewareProvider interface {
|
||||
Middlewares(ctx ServerContext) []func(http.Handler) http.Handler
|
||||
}
|
||||
|
||||
// AuthHook is optionally implemented by extensions that need to react to
|
||||
// identity lifecycle events. OnSignup runs synchronously inside the signup
|
||||
// handler — returning an error fails the request, which is the contract
|
||||
// billing extensions rely on (no Wrenn user without a Lago customer).
|
||||
// OnLogin and the delete hooks are fire-and-forget at the call site: errors
|
||||
// are logged but never block the user-visible flow.
|
||||
type AuthHook interface {
|
||||
OnSignup(ctx context.Context, userID, teamID pgtype.UUID, email string) error
|
||||
OnLogin(ctx context.Context, userID pgtype.UUID) error
|
||||
OnAccountSoftDelete(ctx context.Context, userID pgtype.UUID) error
|
||||
OnAccountHardDelete(ctx context.Context, userID pgtype.UUID) error
|
||||
}
|
||||
|
||||
// SandboxEvent is the canonical payload handed to SandboxEventHook
|
||||
// implementations. The Type field uses the public verb names ("created",
|
||||
// "started", "paused", "resumed", "stopped", "destroyed").
|
||||
type SandboxEvent struct {
|
||||
SandboxID pgtype.UUID
|
||||
TeamID pgtype.UUID
|
||||
Type string
|
||||
OccurredAt time.Time
|
||||
Metadata map[string]any
|
||||
}
|
||||
|
||||
// SandboxEventHook is optionally implemented by extensions that need to react
|
||||
// to sandbox lifecycle events for metering, audit shipping, etc. The hook is
|
||||
// invoked from inside the Redis stream consumer; returning an error causes
|
||||
// the message to be left un-acked so it will be redelivered. Hooks must be
|
||||
// idempotent.
|
||||
type SandboxEventHook interface {
|
||||
OnSandboxEvent(ctx context.Context, ev SandboxEvent) error
|
||||
}
|
||||
|
||||
// --- Auth middleware helpers exposed to extensions ---
|
||||
|
||||
// RequireSession returns middleware that enforces a valid session cookie.
|
||||
func RequireSession(sctx ServerContext) func(http.Handler) http.Handler {
|
||||
return sessionmw.RequireSession(sctx.Sessions, sctx.Queries)
|
||||
}
|
||||
|
||||
// RequireSessionOrAPIKey returns middleware that accepts either an X-API-Key
|
||||
// header (SDKs) or a wrenn_sid cookie (browser).
|
||||
func RequireSessionOrAPIKey(sctx ServerContext) func(http.Handler) http.Handler {
|
||||
return sessionmw.RequireSessionOrAPIKey(sctx.Sessions, sctx.Queries)
|
||||
}
|
||||
|
||||
// RequireAdmin returns middleware that gates routes on platform-admin status.
|
||||
// Must run after RequireSession.
|
||||
func RequireAdmin(sctx ServerContext) func(http.Handler) http.Handler {
|
||||
return sessionmw.RequireAdmin(sctx.Queries)
|
||||
}
|
||||
|
||||
// IssueSession creates a fresh session for the given user/team and writes
|
||||
// the cookies onto the response. Identity columns are looked up from the DB.
|
||||
func IssueSession(w http.ResponseWriter, r *http.Request, sctx ServerContext, userID, teamID pgtype.UUID) (*session.Session, error) {
|
||||
return sessionmw.IssueSession(r.Context(), w, r, sctx.Queries, sctx.Sessions, userID, teamID)
|
||||
}
|
||||
|
||||
// ClearSessionCookies invalidates the session and CSRF cookies. Suitable for
|
||||
// extension logout flows that aren't routed through OSS handlers.
|
||||
func ClearSessionCookies(w http.ResponseWriter, r *http.Request) {
|
||||
sessionmw.ClearCookies(w, sessionmw.IsSecure(r))
|
||||
}
|
||||
|
||||
@ -21,8 +21,10 @@ import (
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/audit"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth/oauth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth/session"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/channels"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/config"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/cpextension"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
@ -163,22 +165,37 @@ func Run(opts ...Option) {
|
||||
FromEmail: cfg.SMTPFromEmail,
|
||||
})
|
||||
|
||||
// Session service backs cookie auth for the browser; exposed to
|
||||
// extensions through ServerContext so cloud-repo code can revoke or
|
||||
// invalidate sessions on identity events without re-implementing the store.
|
||||
sessionSvc := session.NewService(queries, rdb)
|
||||
|
||||
// Build the server context that extensions receive.
|
||||
sctx := ServerContext{
|
||||
Queries: queries,
|
||||
PgPool: pool,
|
||||
Redis: rdb,
|
||||
HostPool: hostPool,
|
||||
Scheduler: hostScheduler,
|
||||
CA: ca,
|
||||
Audit: al,
|
||||
Mailer: mailer,
|
||||
JWTSecret: []byte(cfg.JWTSecret),
|
||||
Config: cfg,
|
||||
Queries: queries,
|
||||
PgPool: pool,
|
||||
Redis: rdb,
|
||||
HostPool: hostPool,
|
||||
Scheduler: hostScheduler,
|
||||
CA: ca,
|
||||
Audit: al,
|
||||
Mailer: mailer,
|
||||
OAuthRegistry: oauthRegistry,
|
||||
Channels: channelSvc,
|
||||
ChannelPub: channelPub,
|
||||
JWTSecret: []byte(cfg.JWTSecret),
|
||||
Sessions: sessionSvc,
|
||||
Config: cfg,
|
||||
}
|
||||
|
||||
// Host monitor (safety-net reconciliation every 5 minutes).
|
||||
// Primary state sync is push-based (host agent callbacks + CP background
|
||||
// goroutines). The monitor acts as a fallback for missed events, host death
|
||||
// detection, and transient status resolution.
|
||||
monitor := api.NewHostMonitor(queries, hostPool, al, 5*time.Minute)
|
||||
|
||||
// API server.
|
||||
srv := api.New(queries, hostPool, hostScheduler, pool, rdb, []byte(cfg.JWTSecret), oauthRegistry, cfg.OAuthRedirectURL, ca, al, channelSvc, mailer, o.extensions, sctx, o.version)
|
||||
srv := api.New(ctx, queries, hostPool, hostScheduler, pool, rdb, []byte(cfg.JWTSecret), oauthRegistry, cfg.OAuthRedirectURL, ca, al, channelPub, channelSvc, mailer, o.extensions, sctx, monitor, o.version)
|
||||
|
||||
// Start template build workers (2 concurrent).
|
||||
stopBuildWorkers := srv.BuildSvc.StartWorkers(ctx, 2)
|
||||
@ -187,10 +204,30 @@ func Run(opts ...Option) {
|
||||
// Start channel event dispatcher.
|
||||
channelDispatcher.Start(ctx)
|
||||
|
||||
// Start host monitor (passive + active reconciliation every 30s).
|
||||
monitor := api.NewHostMonitor(queries, hostPool, al, 15*time.Second)
|
||||
// Start sandbox event consumer (processes lifecycle events from Redis stream).
|
||||
var sandboxHooks []cpextension.SandboxEventHook
|
||||
for _, ext := range o.extensions {
|
||||
if h, ok := ext.(cpextension.SandboxEventHook); ok {
|
||||
sandboxHooks = append(sandboxHooks, h)
|
||||
}
|
||||
}
|
||||
sandboxEventConsumer := api.NewSandboxEventConsumer(rdb, queries, al, sandboxHooks)
|
||||
sandboxEventConsumer.Start(ctx)
|
||||
|
||||
// Start SSE relay (subscribes to Redis Pub/Sub, dispatches to connected clients).
|
||||
srv.SSERelay.Start(ctx)
|
||||
|
||||
// Start host monitor loop.
|
||||
monitor.Start(ctx)
|
||||
|
||||
// Collect AuthHook extensions for the hard-delete cleanup goroutine.
|
||||
var authHooks []cpextension.AuthHook
|
||||
for _, ext := range o.extensions {
|
||||
if h, ok := ext.(cpextension.AuthHook); ok {
|
||||
authHooks = append(authHooks, h)
|
||||
}
|
||||
}
|
||||
|
||||
// Hard-delete accounts that have been soft-deleted for more than 15 days (runs every 24h).
|
||||
// Audit logs referencing deleted users are anonymized before the user row is removed.
|
||||
// A notification email is sent to the user before their data is permanently removed.
|
||||
@ -218,6 +255,11 @@ func Run(opts ...Option) {
|
||||
slog.Error("account cleanup: failed to hard-delete user", "user_id", prefixedID, "error", err)
|
||||
continue
|
||||
}
|
||||
for _, h := range authHooks {
|
||||
if err := h.OnAccountHardDelete(ctx, row.ID); err != nil {
|
||||
slog.Warn("account cleanup: OnAccountHardDelete hook failed", "user_id", prefixedID, "error", err)
|
||||
}
|
||||
}
|
||||
if err := mailer.Send(ctx, row.Email, "Your Wrenn account has been deleted", email.EmailData{
|
||||
Message: "Your Wrenn account and all associated data have been permanently deleted. " +
|
||||
"This action was taken automatically because your account was scheduled for deletion more than 15 days ago.\n\n" +
|
||||
@ -246,7 +288,7 @@ func Run(opts ...Option) {
|
||||
// Start extension background workers.
|
||||
for _, ext := range o.extensions {
|
||||
for _, worker := range ext.BackgroundWorkers(sctx) {
|
||||
worker(ctx)
|
||||
go worker(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -172,8 +172,6 @@ SELECT
|
||||
h.created_by,
|
||||
h.created_at,
|
||||
h.updated_at,
|
||||
h.cert_fingerprint,
|
||||
h.cert_expires_at,
|
||||
COALESCE(SUM(s.vcpus) FILTER (WHERE s.status IN ('running', 'starting', 'pending')), 0)::int AS running_vcpus,
|
||||
COALESCE(SUM(s.memory_mb) FILTER (WHERE s.status IN ('running', 'starting', 'pending')), 0)::int AS running_memory_mb,
|
||||
COALESCE(SUM(s.disk_size_mb) FILTER (WHERE s.status IN ('running', 'starting', 'pending')), 0)::int AS running_disk_mb,
|
||||
@ -205,8 +203,6 @@ type GetHostsWithLoadRow struct {
|
||||
CreatedBy pgtype.UUID `json:"created_by"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
|
||||
CertFingerprint string `json:"cert_fingerprint"`
|
||||
CertExpiresAt pgtype.Timestamptz `json:"cert_expires_at"`
|
||||
RunningVcpus int32 `json:"running_vcpus"`
|
||||
RunningMemoryMb int32 `json:"running_memory_mb"`
|
||||
RunningDiskMb int32 `json:"running_disk_mb"`
|
||||
@ -242,8 +238,6 @@ func (q *Queries) GetHostsWithLoad(ctx context.Context) ([]GetHostsWithLoadRow,
|
||||
&i.CreatedBy,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.CertFingerprint,
|
||||
&i.CertExpiresAt,
|
||||
&i.RunningVcpus,
|
||||
&i.RunningMemoryMb,
|
||||
&i.RunningDiskMb,
|
||||
@ -427,6 +421,105 @@ func (q *Queries) ListHosts(ctx context.Context) ([]Host, error) {
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listHostsAdmin = `-- name: ListHostsAdmin :many
|
||||
SELECT
|
||||
h.id,
|
||||
h.type,
|
||||
h.team_id,
|
||||
h.provider,
|
||||
h.availability_zone,
|
||||
h.arch,
|
||||
h.cpu_cores,
|
||||
h.memory_mb,
|
||||
h.disk_gb,
|
||||
h.address,
|
||||
h.status,
|
||||
h.last_heartbeat_at,
|
||||
h.metadata,
|
||||
h.created_by,
|
||||
h.created_at,
|
||||
h.updated_at,
|
||||
COALESCE(SUM(s.vcpus) FILTER (WHERE s.status IN ('running', 'starting', 'pending')), 0)::int AS running_vcpus,
|
||||
COALESCE(SUM(s.memory_mb) FILTER (WHERE s.status IN ('running', 'starting', 'pending')), 0)::int AS running_memory_mb,
|
||||
COALESCE(SUM(s.disk_size_mb) FILTER (WHERE s.status IN ('running', 'starting', 'pending')), 0)::int AS running_disk_mb,
|
||||
COALESCE(SUM(s.memory_mb) FILTER (WHERE s.status = 'paused'), 0)::int AS paused_memory_mb,
|
||||
COALESCE(SUM(s.disk_size_mb) FILTER (WHERE s.status = 'paused'), 0)::int AS paused_disk_mb
|
||||
FROM hosts h
|
||||
LEFT JOIN sandboxes s ON s.host_id = h.id
|
||||
AND s.status IN ('running', 'paused', 'starting', 'pending')
|
||||
GROUP BY h.id
|
||||
ORDER BY h.created_at DESC
|
||||
`
|
||||
|
||||
type ListHostsAdminRow struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
Type string `json:"type"`
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
Provider string `json:"provider"`
|
||||
AvailabilityZone string `json:"availability_zone"`
|
||||
Arch string `json:"arch"`
|
||||
CpuCores int32 `json:"cpu_cores"`
|
||||
MemoryMb int32 `json:"memory_mb"`
|
||||
DiskGb int32 `json:"disk_gb"`
|
||||
Address string `json:"address"`
|
||||
Status string `json:"status"`
|
||||
LastHeartbeatAt pgtype.Timestamptz `json:"last_heartbeat_at"`
|
||||
Metadata []byte `json:"metadata"`
|
||||
CreatedBy pgtype.UUID `json:"created_by"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
|
||||
RunningVcpus int32 `json:"running_vcpus"`
|
||||
RunningMemoryMb int32 `json:"running_memory_mb"`
|
||||
RunningDiskMb int32 `json:"running_disk_mb"`
|
||||
PausedMemoryMb int32 `json:"paused_memory_mb"`
|
||||
PausedDiskMb int32 `json:"paused_disk_mb"`
|
||||
}
|
||||
|
||||
// Returns all hosts with per-host sandbox resource consumption aggregated.
|
||||
// Unlike GetHostsWithLoad, this returns ALL hosts (not just online) so admins
|
||||
// can see resource usage across the entire fleet including offline/pending hosts.
|
||||
func (q *Queries) ListHostsAdmin(ctx context.Context) ([]ListHostsAdminRow, error) {
|
||||
rows, err := q.db.Query(ctx, listHostsAdmin)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []ListHostsAdminRow
|
||||
for rows.Next() {
|
||||
var i ListHostsAdminRow
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.Type,
|
||||
&i.TeamID,
|
||||
&i.Provider,
|
||||
&i.AvailabilityZone,
|
||||
&i.Arch,
|
||||
&i.CpuCores,
|
||||
&i.MemoryMb,
|
||||
&i.DiskGb,
|
||||
&i.Address,
|
||||
&i.Status,
|
||||
&i.LastHeartbeatAt,
|
||||
&i.Metadata,
|
||||
&i.CreatedBy,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.RunningVcpus,
|
||||
&i.RunningMemoryMb,
|
||||
&i.RunningDiskMb,
|
||||
&i.PausedMemoryMb,
|
||||
&i.PausedDiskMb,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listHostsByStatus = `-- name: ListHostsByStatus :many
|
||||
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE status = $1 ORDER BY created_at DESC
|
||||
`
|
||||
@ -517,18 +610,71 @@ func (q *Queries) ListHostsByTag(ctx context.Context, tag string) ([]Host, error
|
||||
}
|
||||
|
||||
const listHostsByTeam = `-- name: ListHostsByTeam :many
|
||||
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE team_id = $1 AND type = 'byoc' ORDER BY created_at DESC
|
||||
SELECT
|
||||
h.id,
|
||||
h.type,
|
||||
h.team_id,
|
||||
h.provider,
|
||||
h.availability_zone,
|
||||
h.arch,
|
||||
h.cpu_cores,
|
||||
h.memory_mb,
|
||||
h.disk_gb,
|
||||
h.address,
|
||||
h.status,
|
||||
h.last_heartbeat_at,
|
||||
h.metadata,
|
||||
h.created_by,
|
||||
h.created_at,
|
||||
h.updated_at,
|
||||
COALESCE(SUM(s.vcpus) FILTER (WHERE s.status IN ('running', 'starting', 'pending')), 0)::int AS running_vcpus,
|
||||
COALESCE(SUM(s.memory_mb) FILTER (WHERE s.status IN ('running', 'starting', 'pending')), 0)::int AS running_memory_mb,
|
||||
COALESCE(SUM(s.disk_size_mb) FILTER (WHERE s.status IN ('running', 'starting', 'pending')), 0)::int AS running_disk_mb,
|
||||
COALESCE(SUM(s.memory_mb) FILTER (WHERE s.status = 'paused'), 0)::int AS paused_memory_mb,
|
||||
COALESCE(SUM(s.disk_size_mb) FILTER (WHERE s.status = 'paused'), 0)::int AS paused_disk_mb
|
||||
FROM hosts h
|
||||
LEFT JOIN sandboxes s ON s.host_id = h.id
|
||||
AND s.status IN ('running', 'paused', 'starting', 'pending')
|
||||
WHERE h.team_id = $1 AND h.type = 'byoc'
|
||||
GROUP BY h.id
|
||||
ORDER BY h.created_at DESC
|
||||
`
|
||||
|
||||
func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.UUID) ([]Host, error) {
|
||||
type ListHostsByTeamRow struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
Type string `json:"type"`
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
Provider string `json:"provider"`
|
||||
AvailabilityZone string `json:"availability_zone"`
|
||||
Arch string `json:"arch"`
|
||||
CpuCores int32 `json:"cpu_cores"`
|
||||
MemoryMb int32 `json:"memory_mb"`
|
||||
DiskGb int32 `json:"disk_gb"`
|
||||
Address string `json:"address"`
|
||||
Status string `json:"status"`
|
||||
LastHeartbeatAt pgtype.Timestamptz `json:"last_heartbeat_at"`
|
||||
Metadata []byte `json:"metadata"`
|
||||
CreatedBy pgtype.UUID `json:"created_by"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
|
||||
RunningVcpus int32 `json:"running_vcpus"`
|
||||
RunningMemoryMb int32 `json:"running_memory_mb"`
|
||||
RunningDiskMb int32 `json:"running_disk_mb"`
|
||||
PausedMemoryMb int32 `json:"paused_memory_mb"`
|
||||
PausedDiskMb int32 `json:"paused_disk_mb"`
|
||||
}
|
||||
|
||||
// Returns hosts by team with per-host sandbox resource consumption aggregated.
|
||||
// Follows the same aggregation pattern as ListHostsAdmin and GetHostsWithLoad.
|
||||
func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.UUID) ([]ListHostsByTeamRow, error) {
|
||||
rows, err := q.db.Query(ctx, listHostsByTeam, teamID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Host
|
||||
var items []ListHostsByTeamRow
|
||||
for rows.Next() {
|
||||
var i Host
|
||||
var i ListHostsByTeamRow
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.Type,
|
||||
@ -546,8 +692,11 @@ func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.UUID) ([]Ho
|
||||
&i.CreatedBy,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.CertFingerprint,
|
||||
&i.CertExpiresAt,
|
||||
&i.RunningVcpus,
|
||||
&i.RunningMemoryMb,
|
||||
&i.RunningDiskMb,
|
||||
&i.PausedMemoryMb,
|
||||
&i.PausedDiskMb,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -133,6 +133,33 @@ func (q *Queries) GetDailyUsage(ctx context.Context, arg GetDailyUsageParams) ([
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const getLatestSandboxMetricPoint = `-- name: GetLatestSandboxMetricPoint :one
|
||||
SELECT ts, cpu_pct, mem_bytes, disk_bytes
|
||||
FROM sandbox_metric_points
|
||||
WHERE sandbox_id = $1
|
||||
ORDER BY ts DESC
|
||||
LIMIT 1
|
||||
`
|
||||
|
||||
type GetLatestSandboxMetricPointRow struct {
|
||||
Ts int64 `json:"ts"`
|
||||
CpuPct float64 `json:"cpu_pct"`
|
||||
MemBytes int64 `json:"mem_bytes"`
|
||||
DiskBytes int64 `json:"disk_bytes"`
|
||||
}
|
||||
|
||||
func (q *Queries) GetLatestSandboxMetricPoint(ctx context.Context, sandboxID pgtype.UUID) (GetLatestSandboxMetricPointRow, error) {
|
||||
row := q.db.QueryRow(ctx, getLatestSandboxMetricPoint, sandboxID)
|
||||
var i GetLatestSandboxMetricPointRow
|
||||
err := row.Scan(
|
||||
&i.Ts,
|
||||
&i.CpuPct,
|
||||
&i.MemBytes,
|
||||
&i.DiskBytes,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getLiveMetrics = `-- name: GetLiveMetrics :one
|
||||
SELECT
|
||||
(COUNT(*) FILTER (WHERE status IN ('running', 'starting')))::INTEGER AS running_count,
|
||||
|
||||
@ -139,6 +139,18 @@ type SandboxMetricsSnapshot struct {
|
||||
MemoryMbReserved int32 `json:"memory_mb_reserved"`
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
ID string `json:"id"`
|
||||
UserID pgtype.UUID `json:"user_id"`
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
CsrfToken string `json:"csrf_token"`
|
||||
UserAgent string `json:"user_agent"`
|
||||
IpAddress string `json:"ip_address"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
LastSeenAt pgtype.Timestamptz `json:"last_seen_at"`
|
||||
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
|
||||
}
|
||||
|
||||
type Team struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
Name string `json:"name"`
|
||||
|
||||
@ -11,17 +11,24 @@ import (
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
const bulkRestoreRunning = `-- name: BulkRestoreRunning :exec
|
||||
const bulkRestoreMissingToStatus = `-- name: BulkRestoreMissingToStatus :exec
|
||||
UPDATE sandboxes
|
||||
SET status = 'running',
|
||||
SET status = $2,
|
||||
last_updated = NOW()
|
||||
WHERE id = ANY($1::uuid[]) AND status = 'missing'
|
||||
`
|
||||
|
||||
type BulkRestoreMissingToStatusParams struct {
|
||||
Column1 []pgtype.UUID `json:"column_1"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// Called by the reconciler when a host comes back online and its sandboxes are
|
||||
// confirmed alive. Restores only sandboxes that are in 'missing' state.
|
||||
func (q *Queries) BulkRestoreRunning(ctx context.Context, dollar_1 []pgtype.UUID) error {
|
||||
_, err := q.db.Exec(ctx, bulkRestoreRunning, dollar_1)
|
||||
// confirmed alive. Restores only sandboxes currently in 'missing' state to the
|
||||
// given target status (typically 'running' or 'paused' based on the live state
|
||||
// reported by the host agent's ListSandboxes RPC).
|
||||
func (q *Queries) BulkRestoreMissingToStatus(ctx context.Context, arg BulkRestoreMissingToStatusParams) error {
|
||||
_, err := q.db.Exec(ctx, bulkRestoreMissingToStatus, arg.Column1, arg.Status)
|
||||
return err
|
||||
}
|
||||
|
||||
@ -375,7 +382,7 @@ const markSandboxesMissingByHost = `-- name: MarkSandboxesMissingByHost :exec
|
||||
UPDATE sandboxes
|
||||
SET status = 'missing',
|
||||
last_updated = NOW()
|
||||
WHERE host_id = $1 AND status IN ('running', 'starting', 'pending')
|
||||
WHERE host_id = $1 AND status IN ('running', 'starting', 'pending', 'pausing', 'resuming', 'stopping')
|
||||
`
|
||||
|
||||
// Called when the host monitor marks a host unreachable.
|
||||
@ -403,6 +410,23 @@ func (q *Queries) UpdateLastActive(ctx context.Context, arg UpdateLastActivePara
|
||||
return err
|
||||
}
|
||||
|
||||
const updateSandboxDiskSize = `-- name: UpdateSandboxDiskSize :exec
|
||||
UPDATE sandboxes
|
||||
SET disk_size_mb = $2,
|
||||
last_updated = NOW()
|
||||
WHERE id = $1
|
||||
`
|
||||
|
||||
type UpdateSandboxDiskSizeParams struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
DiskSizeMb int32 `json:"disk_size_mb"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateSandboxDiskSize(ctx context.Context, arg UpdateSandboxDiskSizeParams) error {
|
||||
_, err := q.db.Exec(ctx, updateSandboxDiskSize, arg.ID, arg.DiskSizeMb)
|
||||
return err
|
||||
}
|
||||
|
||||
const updateSandboxMetadata = `-- name: UpdateSandboxMetadata :exec
|
||||
UPDATE sandboxes
|
||||
SET metadata = $2,
|
||||
@ -470,6 +494,61 @@ func (q *Queries) UpdateSandboxRunning(ctx context.Context, arg UpdateSandboxRun
|
||||
return i, err
|
||||
}
|
||||
|
||||
const updateSandboxRunningIf = `-- name: UpdateSandboxRunningIf :one
|
||||
UPDATE sandboxes
|
||||
SET status = 'running',
|
||||
host_ip = $3,
|
||||
guest_ip = $4,
|
||||
started_at = $5,
|
||||
last_active_at = $5,
|
||||
last_updated = NOW()
|
||||
WHERE id = $1 AND status = $2
|
||||
RETURNING id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id, metadata
|
||||
`
|
||||
|
||||
type UpdateSandboxRunningIfParams struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
Status string `json:"status"`
|
||||
HostIp string `json:"host_ip"`
|
||||
GuestIp string `json:"guest_ip"`
|
||||
StartedAt pgtype.Timestamptz `json:"started_at"`
|
||||
}
|
||||
|
||||
// Conditionally transition a sandbox to running only if the current status
|
||||
// matches the expected value. Prevents races where a user destroys a sandbox
|
||||
// while the create/resume goroutine is still in-flight.
|
||||
func (q *Queries) UpdateSandboxRunningIf(ctx context.Context, arg UpdateSandboxRunningIfParams) (Sandbox, error) {
|
||||
row := q.db.QueryRow(ctx, updateSandboxRunningIf,
|
||||
arg.ID,
|
||||
arg.Status,
|
||||
arg.HostIp,
|
||||
arg.GuestIp,
|
||||
arg.StartedAt,
|
||||
)
|
||||
var i Sandbox
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.TeamID,
|
||||
&i.HostID,
|
||||
&i.Template,
|
||||
&i.Status,
|
||||
&i.Vcpus,
|
||||
&i.MemoryMb,
|
||||
&i.TimeoutSec,
|
||||
&i.DiskSizeMb,
|
||||
&i.GuestIp,
|
||||
&i.HostIp,
|
||||
&i.CreatedAt,
|
||||
&i.StartedAt,
|
||||
&i.LastActiveAt,
|
||||
&i.LastUpdated,
|
||||
&i.TemplateID,
|
||||
&i.TemplateTeamID,
|
||||
&i.Metadata,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const updateSandboxStatus = `-- name: UpdateSandboxStatus :one
|
||||
UPDATE sandboxes
|
||||
SET status = $2,
|
||||
@ -508,3 +587,46 @@ func (q *Queries) UpdateSandboxStatus(ctx context.Context, arg UpdateSandboxStat
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const updateSandboxStatusIf = `-- name: UpdateSandboxStatusIf :one
|
||||
UPDATE sandboxes
|
||||
SET status = $3,
|
||||
last_updated = NOW()
|
||||
WHERE id = $1 AND status = $2
|
||||
RETURNING id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id, metadata
|
||||
`
|
||||
|
||||
type UpdateSandboxStatusIfParams struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
Status string `json:"status"`
|
||||
Status_2 string `json:"status_2"`
|
||||
}
|
||||
|
||||
// Atomically update status only when the current status matches the expected value.
|
||||
// Prevents background goroutines from overwriting a status that has since changed
|
||||
// (e.g. user destroyed a sandbox while Create was in-flight).
|
||||
func (q *Queries) UpdateSandboxStatusIf(ctx context.Context, arg UpdateSandboxStatusIfParams) (Sandbox, error) {
|
||||
row := q.db.QueryRow(ctx, updateSandboxStatusIf, arg.ID, arg.Status, arg.Status_2)
|
||||
var i Sandbox
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.TeamID,
|
||||
&i.HostID,
|
||||
&i.Template,
|
||||
&i.Status,
|
||||
&i.Vcpus,
|
||||
&i.MemoryMb,
|
||||
&i.TimeoutSec,
|
||||
&i.DiskSizeMb,
|
||||
&i.GuestIp,
|
||||
&i.HostIp,
|
||||
&i.CreatedAt,
|
||||
&i.StartedAt,
|
||||
&i.LastActiveAt,
|
||||
&i.LastUpdated,
|
||||
&i.TemplateID,
|
||||
&i.TemplateTeamID,
|
||||
&i.Metadata,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
187
pkg/db/sessions.sql.go
Normal file
187
pkg/db/sessions.sql.go
Normal file
@ -0,0 +1,187 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
// source: sessions.sql
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
const deleteExpiredSessions = `-- name: DeleteExpiredSessions :exec
|
||||
DELETE FROM sessions WHERE expires_at < NOW()
|
||||
`
|
||||
|
||||
func (q *Queries) DeleteExpiredSessions(ctx context.Context) error {
|
||||
_, err := q.db.Exec(ctx, deleteExpiredSessions)
|
||||
return err
|
||||
}
|
||||
|
||||
const deleteSession = `-- name: DeleteSession :exec
|
||||
DELETE FROM sessions WHERE id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) DeleteSession(ctx context.Context, id string) error {
|
||||
_, err := q.db.Exec(ctx, deleteSession, id)
|
||||
return err
|
||||
}
|
||||
|
||||
const deleteSessionForUser = `-- name: DeleteSessionForUser :exec
|
||||
DELETE FROM sessions WHERE id = $1 AND user_id = $2
|
||||
`
|
||||
|
||||
type DeleteSessionForUserParams struct {
|
||||
ID string `json:"id"`
|
||||
UserID pgtype.UUID `json:"user_id"`
|
||||
}
|
||||
|
||||
func (q *Queries) DeleteSessionForUser(ctx context.Context, arg DeleteSessionForUserParams) error {
|
||||
_, err := q.db.Exec(ctx, deleteSessionForUser, arg.ID, arg.UserID)
|
||||
return err
|
||||
}
|
||||
|
||||
const deleteSessionsByUserID = `-- name: DeleteSessionsByUserID :many
|
||||
DELETE FROM sessions WHERE user_id = $1 RETURNING id
|
||||
`
|
||||
|
||||
func (q *Queries) DeleteSessionsByUserID(ctx context.Context, userID pgtype.UUID) ([]string, error) {
|
||||
rows, err := q.db.Query(ctx, deleteSessionsByUserID, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []string
|
||||
for rows.Next() {
|
||||
var id string
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, id)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const getSession = `-- name: GetSession :one
|
||||
SELECT id, user_id, team_id, csrf_token, user_agent, ip_address, created_at, last_seen_at, expires_at FROM sessions WHERE id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) GetSession(ctx context.Context, id string) (Session, error) {
|
||||
row := q.db.QueryRow(ctx, getSession, id)
|
||||
var i Session
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.UserID,
|
||||
&i.TeamID,
|
||||
&i.CsrfToken,
|
||||
&i.UserAgent,
|
||||
&i.IpAddress,
|
||||
&i.CreatedAt,
|
||||
&i.LastSeenAt,
|
||||
&i.ExpiresAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const insertSession = `-- name: InsertSession :one
|
||||
INSERT INTO sessions (id, user_id, team_id, csrf_token, user_agent, ip_address, expires_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
RETURNING id, user_id, team_id, csrf_token, user_agent, ip_address, created_at, last_seen_at, expires_at
|
||||
`
|
||||
|
||||
type InsertSessionParams struct {
|
||||
ID string `json:"id"`
|
||||
UserID pgtype.UUID `json:"user_id"`
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
CsrfToken string `json:"csrf_token"`
|
||||
UserAgent string `json:"user_agent"`
|
||||
IpAddress string `json:"ip_address"`
|
||||
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
|
||||
}
|
||||
|
||||
func (q *Queries) InsertSession(ctx context.Context, arg InsertSessionParams) (Session, error) {
|
||||
row := q.db.QueryRow(ctx, insertSession,
|
||||
arg.ID,
|
||||
arg.UserID,
|
||||
arg.TeamID,
|
||||
arg.CsrfToken,
|
||||
arg.UserAgent,
|
||||
arg.IpAddress,
|
||||
arg.ExpiresAt,
|
||||
)
|
||||
var i Session
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.UserID,
|
||||
&i.TeamID,
|
||||
&i.CsrfToken,
|
||||
&i.UserAgent,
|
||||
&i.IpAddress,
|
||||
&i.CreatedAt,
|
||||
&i.LastSeenAt,
|
||||
&i.ExpiresAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const listSessionsByUserID = `-- name: ListSessionsByUserID :many
|
||||
SELECT id, user_id, team_id, csrf_token, user_agent, ip_address, created_at, last_seen_at, expires_at FROM sessions WHERE user_id = $1 ORDER BY last_seen_at DESC
|
||||
`
|
||||
|
||||
func (q *Queries) ListSessionsByUserID(ctx context.Context, userID pgtype.UUID) ([]Session, error) {
|
||||
rows, err := q.db.Query(ctx, listSessionsByUserID, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Session
|
||||
for rows.Next() {
|
||||
var i Session
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.UserID,
|
||||
&i.TeamID,
|
||||
&i.CsrfToken,
|
||||
&i.UserAgent,
|
||||
&i.IpAddress,
|
||||
&i.CreatedAt,
|
||||
&i.LastSeenAt,
|
||||
&i.ExpiresAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const touchSession = `-- name: TouchSession :exec
|
||||
UPDATE sessions SET last_seen_at = NOW() WHERE id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) TouchSession(ctx context.Context, id string) error {
|
||||
_, err := q.db.Exec(ctx, touchSession, id)
|
||||
return err
|
||||
}
|
||||
|
||||
const updateSessionTeam = `-- name: UpdateSessionTeam :exec
|
||||
UPDATE sessions SET team_id = $2 WHERE id = $1
|
||||
`
|
||||
|
||||
type UpdateSessionTeamParams struct {
|
||||
ID string `json:"id"`
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateSessionTeam(ctx context.Context, arg UpdateSessionTeamParams) error {
|
||||
_, err := q.db.Exec(ctx, updateSessionTeam, arg.ID, arg.TeamID)
|
||||
return err
|
||||
}
|
||||
@ -358,7 +358,9 @@ SELECT
|
||||
COALESCE(owner_u.name, '') AS owner_name,
|
||||
COALESCE(owner_u.email, '') AS owner_email,
|
||||
(SELECT COUNT(*) FROM sandboxes s WHERE s.team_id = t.id AND s.status IN ('running', 'paused', 'starting'))::int AS active_sandbox_count,
|
||||
(SELECT COUNT(*) FROM channels c WHERE c.team_id = t.id)::int AS channel_count
|
||||
(SELECT COUNT(*) FROM channels c WHERE c.team_id = t.id)::int AS channel_count,
|
||||
COALESCE((SELECT SUM(s.vcpus) FROM sandboxes s WHERE s.team_id = t.id AND s.status IN ('running', 'paused', 'starting')), 0)::int AS running_vcpus,
|
||||
COALESCE((SELECT SUM(s.memory_mb) FROM sandboxes s WHERE s.team_id = t.id AND s.status IN ('running', 'paused', 'starting')), 0)::int AS running_memory_mb
|
||||
FROM teams t
|
||||
LEFT JOIN users_teams owner_ut ON owner_ut.team_id = t.id AND owner_ut.role = 'owner'
|
||||
LEFT JOIN users owner_u ON owner_u.id = owner_ut.user_id
|
||||
@ -384,6 +386,8 @@ type ListTeamsAdminRow struct {
|
||||
OwnerEmail string `json:"owner_email"`
|
||||
ActiveSandboxCount int32 `json:"active_sandbox_count"`
|
||||
ChannelCount int32 `json:"channel_count"`
|
||||
RunningVcpus int32 `json:"running_vcpus"`
|
||||
RunningMemoryMb int32 `json:"running_memory_mb"`
|
||||
}
|
||||
|
||||
func (q *Queries) ListTeamsAdmin(ctx context.Context, arg ListTeamsAdminParams) ([]ListTeamsAdminRow, error) {
|
||||
@ -407,6 +411,8 @@ func (q *Queries) ListTeamsAdmin(ctx context.Context, arg ListTeamsAdminParams)
|
||||
&i.OwnerEmail,
|
||||
&i.ActiveSandboxCount,
|
||||
&i.ChannelCount,
|
||||
&i.RunningVcpus,
|
||||
&i.RunningMemoryMb,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -385,3 +385,17 @@ func (q *Queries) ListTemplatesByType(ctx context.Context, type_ string) ([]Temp
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const updateTemplateSize = `-- name: UpdateTemplateSize :exec
|
||||
UPDATE templates SET size_bytes = $2 WHERE id = $1
|
||||
`
|
||||
|
||||
type UpdateTemplateSizeParams struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
SizeBytes int64 `json:"size_bytes"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateTemplateSize(ctx context.Context, arg UpdateTemplateSizeParams) error {
|
||||
_, err := q.db.Exec(ctx, updateTemplateSize, arg.ID, arg.SizeBytes)
|
||||
return err
|
||||
}
|
||||
|
||||
@ -7,8 +7,18 @@ import (
|
||||
|
||||
// EventPublisher pushes events onto the notification stream.
|
||||
// Satisfied by *channels.Publisher.
|
||||
//
|
||||
// Publish writes durably (Redis stream + SSE Pub/Sub mirror) and is delivered
|
||||
// to subscribed channels. Use for actions that users may want webhook/telegram
|
||||
// notifications about.
|
||||
//
|
||||
// PublishTransient writes only to the SSE Pub/Sub mirror — no durable stream,
|
||||
// no channel delivery. Use for ephemeral UI signals (e.g., status transitions
|
||||
// while a sandbox is starting/pausing) that should reach the dashboard but
|
||||
// must not spam subscribers.
|
||||
type EventPublisher interface {
|
||||
Publish(ctx context.Context, e Event)
|
||||
PublishTransient(ctx context.Context, e Event)
|
||||
}
|
||||
|
||||
// ActorKind identifies what initiated an event.
|
||||
@ -27,42 +37,77 @@ type Actor struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
}
|
||||
|
||||
// SystemActor returns the canonical actor for system-initiated events
|
||||
// (TTL reaper, reconciler-inferred state, cleanup-on-error).
|
||||
func SystemActor() Actor {
|
||||
return Actor{Type: ActorSystem}
|
||||
}
|
||||
|
||||
// Resource identifies the object the event relates to.
|
||||
type Resource struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
// Event is the canonical notification payload published to the Redis stream
|
||||
// and delivered to channel subscribers.
|
||||
type Event struct {
|
||||
Event string `json:"event"`
|
||||
Timestamp string `json:"timestamp"`
|
||||
TeamID string `json:"team_id"`
|
||||
Actor Actor `json:"actor"`
|
||||
Resource Resource `json:"resource"`
|
||||
}
|
||||
// Outcome encodes whether an action succeeded or failed.
|
||||
type Outcome string
|
||||
|
||||
// Event type constants.
|
||||
const (
|
||||
CapsuleCreated = "capsule.created"
|
||||
CapsuleRunning = "capsule.running"
|
||||
CapsulePaused = "capsule.paused"
|
||||
CapsuleDestroyed = "capsule.destroyed"
|
||||
SnapshotCreated = "template.snapshot.created"
|
||||
SnapshotDeleted = "template.snapshot.deleted"
|
||||
HostUp = "host.up"
|
||||
HostDown = "host.down"
|
||||
OutcomeSuccess Outcome = "success"
|
||||
OutcomeError Outcome = "error"
|
||||
)
|
||||
|
||||
// AllEventTypes is the complete set of valid event type strings.
|
||||
var AllEventTypes = []string{
|
||||
CapsuleCreated,
|
||||
CapsuleRunning,
|
||||
CapsulePaused,
|
||||
CapsuleDestroyed,
|
||||
SnapshotCreated,
|
||||
SnapshotDeleted,
|
||||
// Event is the canonical notification payload published to the Redis stream
|
||||
// and delivered to channel subscribers.
|
||||
//
|
||||
// Outcome distinguishes success vs. failure for action events. It is empty
|
||||
// for events with no success/error semantics (state.changed, host.up, host.down).
|
||||
// Error carries the failure reason when Outcome == OutcomeError.
|
||||
// Metadata carries event-specific structured context (e.g., reason, from/to
|
||||
// state for transitions, inferred=true for reconciler-derived events).
|
||||
type Event struct {
|
||||
Event string `json:"event"`
|
||||
Outcome Outcome `json:"outcome,omitempty"`
|
||||
Timestamp string `json:"timestamp"`
|
||||
TeamID string `json:"team_id"`
|
||||
Actor Actor `json:"actor"`
|
||||
Resource Resource `json:"resource"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// Event type constants. Group-level names: subscription matches on Event,
|
||||
// Outcome is a payload field so webhook recipients can distinguish success
|
||||
// from failure without separate subscriptions.
|
||||
const (
|
||||
// Durable, subscribable. First boot only (subsequent unpauses are CapsuleResume).
|
||||
CapsuleCreate = "capsule.create"
|
||||
CapsulePause = "capsule.pause"
|
||||
CapsuleResume = "capsule.resume"
|
||||
CapsuleDestroy = "capsule.destroy"
|
||||
|
||||
// Durable, subscribable.
|
||||
SnapshotCreate = "template.snapshot.create"
|
||||
SnapshotDelete = "template.snapshot.delete"
|
||||
|
||||
// Durable, no outcome (binary by name).
|
||||
HostUp = "host.up"
|
||||
HostDown = "host.down"
|
||||
|
||||
// Transient (SSE-only via PublishTransient). Not subscribable.
|
||||
// Metadata: from, to (sandbox status strings).
|
||||
CapsuleStateChanged = "capsule.state.changed"
|
||||
)
|
||||
|
||||
// SubscribableEventTypes is the set of event types users can subscribe to
|
||||
// via channels (webhook, telegram, shoutrrr). Excludes transient events.
|
||||
var SubscribableEventTypes = []string{
|
||||
CapsuleCreate,
|
||||
CapsulePause,
|
||||
CapsuleResume,
|
||||
CapsuleDestroy,
|
||||
SnapshotCreate,
|
||||
SnapshotDelete,
|
||||
HostUp,
|
||||
HostDown,
|
||||
}
|
||||
|
||||
36
pkg/id/id.go
36
pkg/id/id.go
@ -2,6 +2,7 @@ package id
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"math/big"
|
||||
@ -156,10 +157,37 @@ func ParseChannelID(s string) (pgtype.UUID, error) { return parseUUID(PrefixCh
|
||||
// (e.g. base templates, shared infrastructure).
|
||||
var PlatformTeamID = pgtype.UUID{Bytes: [16]byte{}, Valid: true}
|
||||
|
||||
// MinimalTemplateID is the all-zeros UUID sentinel for the built-in "minimal"
|
||||
// template. When both team_id and template_id are zero, the host agent uses
|
||||
// the minimal rootfs at WRENN_DIR/images/minimal/.
|
||||
var MinimalTemplateID = pgtype.UUID{Bytes: [16]byte{}, Valid: true}
|
||||
// SystemTemplateMaxID is the highest template ID reserved for built-in system
|
||||
// base templates. Template IDs in [0, SystemTemplateMaxID] under the platform
|
||||
// team are protected: they cannot be deleted and live at the well-known
|
||||
// teams/{base36(0)}/{base36(id)} on-disk paths.
|
||||
const SystemTemplateMaxID = 1024
|
||||
|
||||
// templateID returns the all-zeros UUID with its low 64 bits set to n. Used to
|
||||
// mint the well-known IDs for the built-in system base templates.
|
||||
func templateID(n uint64) pgtype.UUID {
|
||||
var b [16]byte
|
||||
binary.BigEndian.PutUint64(b[8:], n)
|
||||
return pgtype.UUID{Bytes: b, Valid: true}
|
||||
}
|
||||
|
||||
// Well-known system base template IDs (platform team). The on-disk rootfs for
|
||||
// each lives at WRENN_DIR/images/teams/{base36(PlatformTeamID)}/{base36(id)}/.
|
||||
var (
|
||||
UbuntuTemplateID = templateID(0) // minimal-ubuntu (replaces the old "minimal")
|
||||
AlpineTemplateID = templateID(1) // minimal-alpine
|
||||
ArchTemplateID = templateID(2) // minimal-arch
|
||||
FedoraTemplateID = templateID(3) // minimal-fedora
|
||||
)
|
||||
|
||||
// IsReservedTemplateID reports whether t falls in the reserved system template
|
||||
// ID range [0, SystemTemplateMaxID] (i.e. the top 64 bits are zero and the
|
||||
// bottom 64 bits are <= SystemTemplateMaxID).
|
||||
func IsReservedTemplateID(t pgtype.UUID) bool {
|
||||
hi := binary.BigEndian.Uint64(t.Bytes[:8])
|
||||
lo := binary.BigEndian.Uint64(t.Bytes[8:])
|
||||
return hi == 0 && lo <= SystemTemplateMaxID
|
||||
}
|
||||
|
||||
// UUIDString converts a pgtype.UUID to a standard hyphenated UUID string
|
||||
// (e.g., "6ba7b810-9dad-11d1-80b4-00c04fd430c8"). Used for RPC wire format.
|
||||
|
||||
@ -165,7 +165,5 @@ func hostFromRow(r *db.GetHostsWithLoadRow) db.Host {
|
||||
CreatedBy: r.CreatedBy,
|
||||
CreatedAt: r.CreatedAt,
|
||||
UpdatedAt: r.UpdatedAt,
|
||||
CertFingerprint: r.CertFingerprint,
|
||||
CertExpiresAt: r.CertExpiresAt,
|
||||
}
|
||||
}
|
||||
|
||||
@ -2,6 +2,7 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
@ -26,14 +27,17 @@ const (
|
||||
buildCommandTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
// preBuildCmds run before the user recipe to prepare the build environment.
|
||||
// apt update runs as root first, then USER switches to wrenn-user for the recipe.
|
||||
// preBuildCmds run before the recipe to prepare the build environment, as
|
||||
// root. The build user (USER/WORKDIR) is not injected here — Create prepends
|
||||
// it to the persisted recipe instead, so "run as root" can omit it with no
|
||||
// build-level flag to track.
|
||||
var preBuildCmds = []string{
|
||||
"RUN apt update",
|
||||
"USER wrenn-user",
|
||||
"WORKDIR /home/wrenn-user",
|
||||
}
|
||||
|
||||
// buildUser is the non-root user a recipe runs as unless run_as_root is set.
|
||||
const buildUser = "wrenn-user"
|
||||
|
||||
// postBuildCmds run after the user recipe to clean up caches and reduce image size.
|
||||
var postBuildCmds = []string{
|
||||
"RUN apt clean",
|
||||
@ -47,6 +51,8 @@ type buildAgentClient interface {
|
||||
CreateSandbox(ctx context.Context, req *connect.Request[pb.CreateSandboxRequest]) (*connect.Response[pb.CreateSandboxResponse], error)
|
||||
DestroySandbox(ctx context.Context, req *connect.Request[pb.DestroySandboxRequest]) (*connect.Response[pb.DestroySandboxResponse], error)
|
||||
Exec(ctx context.Context, req *connect.Request[pb.ExecRequest]) (*connect.Response[pb.ExecResponse], error)
|
||||
PtyAttach(ctx context.Context, req *connect.Request[pb.PtyAttachRequest]) (*connect.ServerStreamForClient[pb.PtyAttachResponse], error)
|
||||
PtyKill(ctx context.Context, req *connect.Request[pb.PtyKillRequest]) (*connect.Response[pb.PtyKillResponse], error)
|
||||
WriteFile(ctx context.Context, req *connect.Request[pb.WriteFileRequest]) (*connect.Response[pb.WriteFileResponse], error)
|
||||
CreateSnapshot(ctx context.Context, req *connect.Request[pb.CreateSnapshotRequest]) (*connect.Response[pb.CreateSnapshotResponse], error)
|
||||
FlattenRootfs(ctx context.Context, req *connect.Request[pb.FlattenRootfsRequest]) (*connect.Response[pb.FlattenRootfsResponse], error)
|
||||
@ -73,6 +79,7 @@ type BuildCreateParams struct {
|
||||
VCPUs int32
|
||||
MemoryMB int32
|
||||
SkipPrePost bool
|
||||
RunAsRoot bool // Run the recipe as root instead of the non-root build user.
|
||||
Archive []byte // Optional tar/tar.gz/zip archive for COPY commands.
|
||||
ArchiveName string // Original filename (used to detect format).
|
||||
}
|
||||
@ -99,7 +106,7 @@ func (s *BuildService) takeArchive(buildID string) []byte {
|
||||
// Create inserts a new build record and enqueues it to Redis.
|
||||
func (s *BuildService) Create(ctx context.Context, p BuildCreateParams) (db.TemplateBuild, error) {
|
||||
if p.BaseTemplate == "" {
|
||||
p.BaseTemplate = "minimal"
|
||||
p.BaseTemplate = "minimal-ubuntu"
|
||||
}
|
||||
if p.VCPUs <= 0 {
|
||||
p.VCPUs = 1
|
||||
@ -108,7 +115,19 @@ func (s *BuildService) Create(ctx context.Context, p BuildCreateParams) (db.Temp
|
||||
p.MemoryMB = 512
|
||||
}
|
||||
|
||||
recipeJSON, err := json.Marshal(p.Recipe)
|
||||
// Assemble the recipe. Unless run_as_root is set, the non-root build user
|
||||
// is prepended as USER + WORKDIR steps. Persisting it in the recipe means
|
||||
// "run as root" needs no build-level flag — it simply omits these steps,
|
||||
// so wrenn-user is never created in a root template.
|
||||
recipeLines := p.Recipe
|
||||
if !p.RunAsRoot {
|
||||
recipeLines = append([]string{
|
||||
"USER " + buildUser,
|
||||
"WORKDIR /home/" + buildUser,
|
||||
}, recipeLines...)
|
||||
}
|
||||
|
||||
recipeJSON, err := json.Marshal(recipeLines)
|
||||
if err != nil {
|
||||
return db.TemplateBuild{}, fmt.Errorf("marshal recipe: %w", err)
|
||||
}
|
||||
@ -130,7 +149,7 @@ func (s *BuildService) Create(ctx context.Context, p BuildCreateParams) (db.Temp
|
||||
Healthcheck: p.Healthcheck,
|
||||
Vcpus: p.VCPUs,
|
||||
MemoryMb: p.MemoryMB,
|
||||
TotalSteps: int32(len(p.Recipe) + defaultSteps),
|
||||
TotalSteps: int32(len(recipeLines) + defaultSteps),
|
||||
TemplateID: newTemplateID,
|
||||
TeamID: id.PlatformTeamID,
|
||||
SkipPrePost: p.SkipPrePost,
|
||||
@ -183,6 +202,7 @@ func (s *BuildService) Cancel(ctx context.Context, buildID pgtype.UUID) error {
|
||||
}); err != nil {
|
||||
return fmt.Errorf("update build status: %w", err)
|
||||
}
|
||||
s.publishStatus(ctx, buildID, "cancelled", 0, 0, "")
|
||||
|
||||
// If the build is currently running, signal its context.
|
||||
buildIDStr := id.FormatBuildID(buildID)
|
||||
@ -274,6 +294,7 @@ func (s *BuildService) executeBuild(ctx context.Context, buildIDStr string) {
|
||||
log.Error("failed to update build status", "error", err)
|
||||
return
|
||||
}
|
||||
s.publishStatus(buildCtx, buildID, "running", 0, build.TotalSteps, "")
|
||||
|
||||
// Parse user recipe.
|
||||
var userRecipe []string
|
||||
@ -282,69 +303,11 @@ func (s *BuildService) executeBuild(ctx context.Context, buildIDStr string) {
|
||||
return
|
||||
}
|
||||
|
||||
// Pick a platform host and create a sandbox.
|
||||
host, err := s.Scheduler.SelectHost(buildCtx, id.PlatformTeamID, false, build.MemoryMb, 5120)
|
||||
agent, sandboxIDStr, sandboxMetadata, err := s.provisionBuildSandbox(buildCtx, buildID, buildIDStr, build, log)
|
||||
if err != nil {
|
||||
s.failBuild(buildCtx, buildID, fmt.Sprintf("no host available: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
agent, err := s.Pool.GetForHost(host)
|
||||
if err != nil {
|
||||
s.failBuild(buildCtx, buildID, fmt.Sprintf("agent client error: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
sandboxID := id.NewSandboxID()
|
||||
sandboxIDStr := id.FormatSandboxID(sandboxID)
|
||||
log = log.With("sandbox_id", sandboxIDStr, "host_id", id.FormatHostID(host.ID))
|
||||
|
||||
// Resolve the base template to UUIDs. "minimal" is the zero sentinel.
|
||||
baseTeamID := id.PlatformTeamID
|
||||
baseTemplateID := id.MinimalTemplateID
|
||||
if build.BaseTemplate != "minimal" {
|
||||
baseTmpl, err := s.DB.GetPlatformTemplateByName(buildCtx, build.BaseTemplate)
|
||||
if err != nil {
|
||||
s.failBuild(buildCtx, buildID, fmt.Sprintf("base template %q not found: %v", build.BaseTemplate, err))
|
||||
return
|
||||
}
|
||||
baseTeamID = baseTmpl.TeamID
|
||||
baseTemplateID = baseTmpl.ID
|
||||
}
|
||||
|
||||
resp, err := agent.CreateSandbox(buildCtx, connect.NewRequest(&pb.CreateSandboxRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Template: build.BaseTemplate,
|
||||
TeamId: id.UUIDString(baseTeamID),
|
||||
TemplateId: id.UUIDString(baseTemplateID),
|
||||
Vcpus: build.Vcpus,
|
||||
MemoryMb: build.MemoryMb,
|
||||
TimeoutSec: 0, // no auto-pause for builds
|
||||
DiskSizeMb: 5120, // 5 GB for template builds
|
||||
}))
|
||||
if err != nil {
|
||||
s.failBuild(buildCtx, buildID, fmt.Sprintf("create sandbox failed: %v", err))
|
||||
return
|
||||
}
|
||||
// Capture sandbox metadata (envd/kernel/firecracker/agent versions).
|
||||
sandboxMetadata := resp.Msg.Metadata
|
||||
|
||||
// Record sandbox/host association.
|
||||
_ = s.DB.UpdateBuildSandbox(buildCtx, db.UpdateBuildSandboxParams{
|
||||
ID: buildID,
|
||||
SandboxID: sandboxID,
|
||||
HostID: host.ID,
|
||||
})
|
||||
|
||||
// Upload and extract build archive if provided.
|
||||
archive := s.takeArchive(buildIDStr)
|
||||
if len(archive) > 0 {
|
||||
if err := s.uploadAndExtractArchive(buildCtx, agent, sandboxIDStr, archive, buildIDStr); err != nil {
|
||||
s.destroySandbox(buildCtx, agent, sandboxIDStr)
|
||||
s.failBuild(buildCtx, buildID, fmt.Sprintf("archive upload failed: %v", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
log = log.With("sandbox_id", sandboxIDStr)
|
||||
|
||||
// Parse recipe steps. preBuildCmds and postBuildCmds are hardcoded and always
|
||||
// valid; panic on error is appropriate here since it would be a programmer mistake.
|
||||
@ -376,16 +339,35 @@ func (s *BuildService) executeBuild(ctx context.Context, buildIDStr string) {
|
||||
}
|
||||
bctx := &recipe.ExecContext{EnvVars: envVars, User: "root"}
|
||||
|
||||
// Per-step progress callback for live UI updates.
|
||||
progressFn := func(currentStep int, allEntries []recipe.BuildLogEntry) {
|
||||
s.updateLogs(buildCtx, buildID, currentStep, allEntries)
|
||||
}
|
||||
streamFn := s.ptyStreamExec(agent)
|
||||
|
||||
runPhase := func(phase string, steps []recipe.Step, defaultTimeout time.Duration) bool {
|
||||
newEntries, nextStep, ok := recipe.Execute(buildCtx, phase, steps, sandboxIDStr, step, defaultTimeout, bctx, agent.Exec, func(currentStep int, phaseEntries []recipe.BuildLogEntry) {
|
||||
// Progress callback: combine prior logs with current phase entries.
|
||||
progressFn(currentStep, append(logs, phaseEntries...))
|
||||
})
|
||||
// step-start: published before each step begins.
|
||||
onStepStart := func(stepNum int, ph string, st recipe.Step) {
|
||||
publishBuildEvent(buildCtx, s.Redis, buildIDStr, BuildStreamEvent{
|
||||
Type: "step-start", Step: stepNum, Phase: ph, Cmd: st.Raw,
|
||||
})
|
||||
}
|
||||
// output: raw PTY bytes from a streaming RUN step, base64-encoded.
|
||||
onChunk := func(stepNum int, data []byte) {
|
||||
publishBuildEvent(buildCtx, s.Redis, buildIDStr, BuildStreamEvent{
|
||||
Type: "output", Step: stepNum, Data: base64.StdEncoding.EncodeToString(data),
|
||||
})
|
||||
}
|
||||
// onProgress: persist the DB log snapshot and publish step-end.
|
||||
onProgress := func(currentStep int, phaseEntries []recipe.BuildLogEntry) {
|
||||
s.updateLogs(buildCtx, buildID, currentStep, append(logs, phaseEntries...))
|
||||
if len(phaseEntries) > 0 {
|
||||
last := phaseEntries[len(phaseEntries)-1]
|
||||
publishBuildEvent(buildCtx, s.Redis, buildIDStr, BuildStreamEvent{
|
||||
Type: "step-end", Step: last.Step, Phase: last.Phase, Cmd: last.Cmd,
|
||||
Exit: last.Exit, Ok: last.Ok, ElapsedMs: last.Elapsed,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
newEntries, nextStep, ok := recipe.Execute(buildCtx, phase, steps, sandboxIDStr, step,
|
||||
defaultTimeout, bctx, agent.Exec, streamFn, onStepStart, onChunk, onProgress)
|
||||
logs = append(logs, newEntries...)
|
||||
step = nextStep
|
||||
s.updateLogs(buildCtx, buildID, step, logs)
|
||||
@ -408,15 +390,16 @@ func (s *BuildService) executeBuild(ctx context.Context, buildIDStr string) {
|
||||
return ok
|
||||
}
|
||||
|
||||
// Phase 1: Pre-build (as root) — creates wrenn-user, updates apt.
|
||||
// Phase 1: Pre-build (as root) — apt update.
|
||||
if !build.SkipPrePost {
|
||||
if !runPhase("pre-build", preBuildSteps, 0) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 2: User recipe — starts as wrenn-user (set by USER in pre-build)
|
||||
// or root if skip_pre_post.
|
||||
// Phase 2: Recipe — the persisted recipe. For non-root builds it begins
|
||||
// with the injected USER/WORKDIR steps that create and switch to the build
|
||||
// user; for run_as_root builds it runs as root throughout.
|
||||
if !runPhase("recipe", userRecipeSteps, buildCommandTimeout) {
|
||||
return
|
||||
}
|
||||
@ -435,81 +418,186 @@ func (s *BuildService) executeBuild(ctx context.Context, buildIDStr string) {
|
||||
}
|
||||
}
|
||||
|
||||
// Healthcheck or direct snapshot.
|
||||
// Finalize: healthcheck/snapshot/flatten → persist template → mark success.
|
||||
s.finalizeBuild(buildCtx, buildID, build, agent, sandboxIDStr, templateDefaultUser, templateDefaultEnv, sandboxMetadata, log)
|
||||
}
|
||||
|
||||
// provisionBuildSandbox picks a host, creates a sandbox, and uploads the build
|
||||
// archive. On failure it calls failBuild and returns an error.
|
||||
func (s *BuildService) provisionBuildSandbox(
|
||||
ctx context.Context,
|
||||
buildID pgtype.UUID,
|
||||
buildIDStr string,
|
||||
build db.TemplateBuild,
|
||||
log *slog.Logger,
|
||||
) (buildAgentClient, string, map[string]string, error) {
|
||||
host, err := s.Scheduler.SelectHost(ctx, id.PlatformTeamID, false, build.MemoryMb, 5120)
|
||||
if err != nil {
|
||||
s.failBuild(ctx, buildID, fmt.Sprintf("no host available: %v", err))
|
||||
return nil, "", nil, err
|
||||
}
|
||||
|
||||
agent, err := s.Pool.GetForHost(host)
|
||||
if err != nil {
|
||||
s.failBuild(ctx, buildID, fmt.Sprintf("agent client error: %v", err))
|
||||
return nil, "", nil, err
|
||||
}
|
||||
|
||||
sandboxID := id.NewSandboxID()
|
||||
sandboxIDStr := id.FormatSandboxID(sandboxID)
|
||||
log.Info("provisioning build sandbox", "sandbox_id", sandboxIDStr, "host_id", id.FormatHostID(host.ID))
|
||||
|
||||
// All base templates — including the built-in system ones — are
|
||||
// platform-owned rows, so resolve the path from the DB record.
|
||||
baseTmpl, err := s.DB.GetPlatformTemplateByName(ctx, build.BaseTemplate)
|
||||
if err != nil {
|
||||
s.failBuild(ctx, buildID, fmt.Sprintf("base template %q not found: %v", build.BaseTemplate, err))
|
||||
return nil, "", nil, err
|
||||
}
|
||||
baseTeamID := baseTmpl.TeamID
|
||||
baseTemplateID := baseTmpl.ID
|
||||
|
||||
resp, err := agent.CreateSandbox(ctx, connect.NewRequest(&pb.CreateSandboxRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Template: build.BaseTemplate,
|
||||
TeamId: id.UUIDString(baseTeamID),
|
||||
TemplateId: id.UUIDString(baseTemplateID),
|
||||
Vcpus: build.Vcpus,
|
||||
MemoryMb: build.MemoryMb,
|
||||
TimeoutSec: 0,
|
||||
DiskSizeMb: 0,
|
||||
}))
|
||||
if err != nil {
|
||||
s.failBuild(ctx, buildID, fmt.Sprintf("create sandbox failed: %v", err))
|
||||
return nil, "", nil, err
|
||||
}
|
||||
sandboxMetadata := resp.Msg.Metadata
|
||||
|
||||
_ = s.DB.UpdateBuildSandbox(ctx, db.UpdateBuildSandboxParams{
|
||||
ID: buildID,
|
||||
SandboxID: sandboxID,
|
||||
HostID: host.ID,
|
||||
})
|
||||
|
||||
if _, err := s.DB.InsertSandbox(ctx, db.InsertSandboxParams{
|
||||
ID: sandboxID,
|
||||
TeamID: id.PlatformTeamID,
|
||||
HostID: host.ID,
|
||||
Template: build.BaseTemplate,
|
||||
Status: "running",
|
||||
Vcpus: build.Vcpus,
|
||||
MemoryMb: build.MemoryMb,
|
||||
TimeoutSec: 0,
|
||||
DiskSizeMb: 0,
|
||||
TemplateID: baseTemplateID,
|
||||
TemplateTeamID: baseTeamID,
|
||||
Metadata: []byte("{}"),
|
||||
}); err != nil {
|
||||
log.Warn("failed to insert builder sandbox record", "error", err)
|
||||
}
|
||||
|
||||
if resp.Msg.DiskSizeMb > 0 {
|
||||
if err := s.DB.UpdateSandboxDiskSize(ctx, db.UpdateSandboxDiskSizeParams{
|
||||
ID: sandboxID,
|
||||
DiskSizeMb: resp.Msg.DiskSizeMb,
|
||||
}); err != nil {
|
||||
log.Warn("failed to update builder sandbox disk size", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
archive := s.takeArchive(buildIDStr)
|
||||
if len(archive) > 0 {
|
||||
if err := s.uploadAndExtractArchive(ctx, agent, sandboxIDStr, archive, buildIDStr); err != nil {
|
||||
s.destroySandbox(ctx, agent, sandboxIDStr)
|
||||
s.failBuild(ctx, buildID, fmt.Sprintf("archive upload failed: %v", err))
|
||||
return nil, "", nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return agent, sandboxIDStr, sandboxMetadata, nil
|
||||
}
|
||||
|
||||
// finalizeBuild handles the healthcheck/snapshot/flatten step and persists the
|
||||
// template record. Called after all recipe phases complete successfully.
|
||||
func (s *BuildService) finalizeBuild(
|
||||
ctx context.Context,
|
||||
buildID pgtype.UUID,
|
||||
build db.TemplateBuild,
|
||||
agent buildAgentClient,
|
||||
sandboxIDStr string,
|
||||
defaultUser string,
|
||||
defaultEnv map[string]string,
|
||||
sandboxMetadata map[string]string,
|
||||
log *slog.Logger,
|
||||
) {
|
||||
var sizeBytes int64
|
||||
if build.Healthcheck != "" {
|
||||
hc, err := recipe.ParseHealthcheck(build.Healthcheck)
|
||||
if err != nil {
|
||||
s.destroySandbox(buildCtx, agent, sandboxIDStr)
|
||||
s.failBuild(buildCtx, buildID, fmt.Sprintf("invalid healthcheck: %v", err))
|
||||
s.destroySandbox(ctx, agent, sandboxIDStr)
|
||||
s.failBuild(ctx, buildID, fmt.Sprintf("invalid healthcheck: %v", err))
|
||||
return
|
||||
}
|
||||
log.Info("running healthcheck", "cmd", hc.Cmd, "interval", hc.Interval, "timeout", hc.Timeout, "start_period", hc.StartPeriod, "retries", hc.Retries)
|
||||
if err := s.waitForHealthcheck(buildCtx, agent, sandboxIDStr, hc, templateDefaultUser); err != nil {
|
||||
s.destroySandbox(buildCtx, agent, sandboxIDStr)
|
||||
if buildCtx.Err() != nil {
|
||||
if err := s.waitForHealthcheck(ctx, agent, sandboxIDStr, hc, defaultUser); err != nil {
|
||||
s.destroySandbox(ctx, agent, sandboxIDStr)
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
s.failBuild(buildCtx, buildID, fmt.Sprintf("healthcheck failed: %v", err))
|
||||
s.failBuild(ctx, buildID, fmt.Sprintf("healthcheck failed: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Healthcheck passed → full snapshot (with memory/CPU state).
|
||||
log.Info("healthcheck passed, creating snapshot")
|
||||
snapResp, err := agent.CreateSnapshot(buildCtx, connect.NewRequest(&pb.CreateSnapshotRequest{
|
||||
snapResp, err := agent.CreateSnapshot(ctx, connect.NewRequest(&pb.CreateSnapshotRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Name: build.Name,
|
||||
TeamId: id.UUIDString(build.TeamID),
|
||||
TemplateId: id.UUIDString(build.TemplateID),
|
||||
}))
|
||||
if err != nil {
|
||||
s.destroySandbox(buildCtx, agent, sandboxIDStr)
|
||||
if buildCtx.Err() != nil {
|
||||
s.destroySandbox(ctx, agent, sandboxIDStr)
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
s.failBuild(buildCtx, buildID, fmt.Sprintf("create snapshot failed: %v", err))
|
||||
s.failBuild(ctx, buildID, fmt.Sprintf("create snapshot failed: %v", err))
|
||||
return
|
||||
}
|
||||
sizeBytes = snapResp.Msg.SizeBytes
|
||||
} else {
|
||||
// No healthcheck → image-only template (rootfs only).
|
||||
log.Info("no healthcheck, flattening rootfs")
|
||||
flatResp, err := agent.FlattenRootfs(buildCtx, connect.NewRequest(&pb.FlattenRootfsRequest{
|
||||
flatResp, err := agent.FlattenRootfs(ctx, connect.NewRequest(&pb.FlattenRootfsRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Name: build.Name,
|
||||
TeamId: id.UUIDString(build.TeamID),
|
||||
TemplateId: id.UUIDString(build.TemplateID),
|
||||
}))
|
||||
if err != nil {
|
||||
s.destroySandbox(buildCtx, agent, sandboxIDStr)
|
||||
if buildCtx.Err() != nil {
|
||||
s.destroySandbox(ctx, agent, sandboxIDStr)
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
s.failBuild(buildCtx, buildID, fmt.Sprintf("flatten rootfs failed: %v", err))
|
||||
s.failBuild(ctx, buildID, fmt.Sprintf("flatten rootfs failed: %v", err))
|
||||
return
|
||||
}
|
||||
sizeBytes = flatResp.Msg.SizeBytes
|
||||
}
|
||||
|
||||
// Insert into templates table as a global (platform) template.
|
||||
templateType := "base"
|
||||
if build.Healthcheck != "" {
|
||||
templateType = "snapshot"
|
||||
}
|
||||
|
||||
// Serialize env vars for DB storage.
|
||||
defaultEnvJSON, err := json.Marshal(templateDefaultEnv)
|
||||
defaultEnvJSON, err := json.Marshal(defaultEnv)
|
||||
if err != nil {
|
||||
defaultEnvJSON = []byte("{}")
|
||||
}
|
||||
|
||||
// Serialize sandbox metadata for DB storage.
|
||||
metadataJSON, err := json.Marshal(sandboxMetadata)
|
||||
if err != nil || len(sandboxMetadata) == 0 {
|
||||
metadataJSON = []byte("{}")
|
||||
}
|
||||
|
||||
if _, err := s.DB.InsertTemplate(buildCtx, db.InsertTemplateParams{
|
||||
if _, err := s.DB.InsertTemplate(ctx, db.InsertTemplateParams{
|
||||
ID: build.TemplateID,
|
||||
Name: build.Name,
|
||||
Type: templateType,
|
||||
@ -517,33 +605,28 @@ func (s *BuildService) executeBuild(ctx context.Context, buildIDStr string) {
|
||||
MemoryMb: build.MemoryMb,
|
||||
SizeBytes: sizeBytes,
|
||||
TeamID: id.PlatformTeamID,
|
||||
DefaultUser: templateDefaultUser,
|
||||
DefaultUser: defaultUser,
|
||||
DefaultEnv: defaultEnvJSON,
|
||||
Metadata: metadataJSON,
|
||||
}); err != nil {
|
||||
log.Error("failed to insert template record", "error", err)
|
||||
// Build succeeded on disk, just DB record failed — don't mark as failed.
|
||||
}
|
||||
|
||||
// Record defaults and metadata on the build record for inspection.
|
||||
_ = s.DB.UpdateBuildDefaults(buildCtx, db.UpdateBuildDefaultsParams{
|
||||
_ = s.DB.UpdateBuildDefaults(ctx, db.UpdateBuildDefaultsParams{
|
||||
ID: buildID,
|
||||
DefaultUser: templateDefaultUser,
|
||||
DefaultUser: defaultUser,
|
||||
DefaultEnv: defaultEnvJSON,
|
||||
Metadata: metadataJSON,
|
||||
})
|
||||
|
||||
// For CreateSnapshot, the sandbox is already destroyed by the snapshot process.
|
||||
// For FlattenRootfs, the sandbox is already destroyed by the flatten process.
|
||||
// No additional destroy needed.
|
||||
|
||||
// Mark build as success.
|
||||
if _, err := s.DB.UpdateBuildStatus(buildCtx, db.UpdateBuildStatusParams{
|
||||
if _, err := s.DB.UpdateBuildStatus(ctx, db.UpdateBuildStatusParams{
|
||||
ID: buildID, Status: "success",
|
||||
}); err != nil {
|
||||
log.Error("failed to mark build as success", "error", err)
|
||||
}
|
||||
s.publishStatus(ctx, buildID, "success", build.TotalSteps, build.TotalSteps, "")
|
||||
|
||||
s.destroySandbox(ctx, agent, sandboxIDStr)
|
||||
log.Info("template build completed successfully", "name", build.Name)
|
||||
}
|
||||
|
||||
@ -642,6 +725,91 @@ func (s *BuildService) failBuild(_ context.Context, buildID pgtype.UUID, errMsg
|
||||
}); err != nil {
|
||||
slog.Error("failed to update build error", "build_id", id.FormatBuildID(buildID), "error", err)
|
||||
}
|
||||
s.publishStatus(ctx, buildID, "failed", 0, 0, errMsg)
|
||||
}
|
||||
|
||||
// build PTY dimensions — wide enough for tools that adapt output to terminal
|
||||
// width (apt/pip progress bars).
|
||||
const (
|
||||
buildPtyCols = 120
|
||||
buildPtyRows = 40
|
||||
)
|
||||
|
||||
// publishStatus emits a build-status event to the build's live stream.
|
||||
func (s *BuildService) publishStatus(ctx context.Context, buildID pgtype.UUID, status string, currentStep, totalSteps int32, errMsg string) {
|
||||
publishBuildEvent(ctx, s.Redis, id.FormatBuildID(buildID), BuildStreamEvent{
|
||||
Type: "build-status",
|
||||
Status: status,
|
||||
CurrentStep: currentStep,
|
||||
TotalSteps: totalSteps,
|
||||
Error: errMsg,
|
||||
})
|
||||
}
|
||||
|
||||
// ptyStreamExec returns a recipe.StreamExecFunc that runs a shell command in a
|
||||
// PTY on the build sandbox via the host agent and streams its output. A PTY
|
||||
// makes build tools emit unbuffered, colorized output (apt/pip progress bars).
|
||||
func (s *BuildService) ptyStreamExec(agent buildAgentClient) recipe.StreamExecFunc {
|
||||
return func(ctx context.Context, sandboxID, shellCmd string) (<-chan recipe.PtyChunk, error) {
|
||||
tag := "build-" + id.NewPtyTag()
|
||||
stream, err := agent.PtyAttach(ctx, connect.NewRequest(&pb.PtyAttachRequest{
|
||||
SandboxId: sandboxID,
|
||||
Tag: tag,
|
||||
Cmd: "/bin/sh",
|
||||
Args: []string{"-c", shellCmd},
|
||||
Cols: buildPtyCols,
|
||||
Rows: buildPtyRows,
|
||||
}))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ch := make(chan recipe.PtyChunk, 64)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
defer stream.Close()
|
||||
|
||||
gotExit := false
|
||||
for stream.Receive() {
|
||||
switch ev := stream.Msg().Event.(type) {
|
||||
case *pb.PtyAttachResponse_Output:
|
||||
select {
|
||||
case ch <- recipe.PtyChunk{Data: ev.Output.Data}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
case *pb.PtyAttachResponse_Exited:
|
||||
gotExit = true
|
||||
select {
|
||||
case ch <- recipe.PtyChunk{Done: true, Exit: ev.Exited.ExitCode}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
if gotExit {
|
||||
return
|
||||
}
|
||||
// Stream ended with no exit event: timeout, cancellation, or error.
|
||||
// Kill the lingering guest process so it does not keep running.
|
||||
streamErr := stream.Err()
|
||||
if ctx.Err() != nil {
|
||||
killCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
_, _ = agent.PtyKill(killCtx, connect.NewRequest(&pb.PtyKillRequest{
|
||||
SandboxId: sandboxID, Tag: tag,
|
||||
}))
|
||||
cancel()
|
||||
if streamErr == nil {
|
||||
streamErr = ctx.Err()
|
||||
}
|
||||
}
|
||||
if streamErr == nil {
|
||||
streamErr = fmt.Errorf("pty stream ended without an exit event")
|
||||
}
|
||||
ch <- recipe.PtyChunk{Err: streamErr}
|
||||
}()
|
||||
return ch, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BuildService) destroySandbox(_ context.Context, agent buildAgentClient, sandboxIDStr string) {
|
||||
@ -653,6 +821,13 @@ func (s *BuildService) destroySandbox(_ context.Context, agent buildAgentClient,
|
||||
})); err != nil {
|
||||
slog.Warn("failed to destroy build sandbox", "sandbox_id", sandboxIDStr, "error", err)
|
||||
}
|
||||
if sbID, err := id.ParseSandboxID(sandboxIDStr); err == nil {
|
||||
if _, err := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
|
||||
ID: sbID, Status: "stopped",
|
||||
}); err != nil {
|
||||
slog.Warn("failed to mark builder sandbox stopped", "sandbox_id", sandboxIDStr, "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// fetchSandboxEnv executes the 'env' command inside the specified sandbox via
|
||||
@ -768,7 +943,7 @@ var runtimeEnvVars = map[string]bool{
|
||||
"HOME": true, "USER": true, "LOGNAME": true, "SHELL": true,
|
||||
"PWD": true, "OLDPWD": true, "HOSTNAME": true, "TERM": true,
|
||||
"SHLVL": true, "_": true,
|
||||
// Per-sandbox identifiers set by envd at boot via MMDS.
|
||||
// Per-sandbox identifiers set by envd at boot via PostInit.
|
||||
"WRENN_SANDBOX_ID": true, "WRENN_TEMPLATE_ID": true,
|
||||
}
|
||||
|
||||
|
||||
143
pkg/service/build_broker.go
Normal file
143
pkg/service/build_broker.go
Normal file
@ -0,0 +1,143 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// buildSubBuffer is the per-subscriber channel buffer. A slow WebSocket
|
||||
// consumer that fills the buffer drops live events; it recovers the full
|
||||
// build state from the DB log on reconnect.
|
||||
const buildSubBuffer = 256
|
||||
|
||||
// buildBrokerReconnect is the backoff before re-subscribing to Redis after a
|
||||
// subscription error.
|
||||
const buildBrokerReconnect = 2 * time.Second
|
||||
|
||||
// BuildBroker fans build events out from per-build Redis pub/sub channels to
|
||||
// in-process WebSocket subscribers. A Redis subscription is started lazily for
|
||||
// a build when its first client connects and torn down when the last leaves.
|
||||
//
|
||||
// The build worker publishes via publishBuildEvent (Redis only); the broker is
|
||||
// purely the read/fan-out side. Decoupling through Redis means the worker and
|
||||
// the WebSocket handler need not run in the same process.
|
||||
type BuildBroker struct {
|
||||
rdb *redis.Client
|
||||
mu sync.Mutex
|
||||
builds map[string]*buildFanout
|
||||
}
|
||||
|
||||
type buildFanout struct {
|
||||
subs map[chan BuildStreamEvent]struct{}
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewBuildBroker creates a broker reading from the given Redis client.
|
||||
func NewBuildBroker(rdb *redis.Client) *BuildBroker {
|
||||
return &BuildBroker{rdb: rdb, builds: make(map[string]*buildFanout)}
|
||||
}
|
||||
|
||||
// Subscribe registers an in-process subscriber for buildID's event stream and
|
||||
// returns the receive channel plus a release function. The first subscriber
|
||||
// for a build starts its Redis subscription; the last to release stops it.
|
||||
// The release function is idempotent and closes the channel.
|
||||
func (b *BuildBroker) Subscribe(buildID string) (<-chan BuildStreamEvent, func()) {
|
||||
ch := make(chan BuildStreamEvent, buildSubBuffer)
|
||||
|
||||
b.mu.Lock()
|
||||
fan, ok := b.builds[buildID]
|
||||
if !ok {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
fan = &buildFanout{subs: make(map[chan BuildStreamEvent]struct{}), cancel: cancel}
|
||||
b.builds[buildID] = fan
|
||||
go b.run(ctx, buildID)
|
||||
}
|
||||
fan.subs[ch] = struct{}{}
|
||||
b.mu.Unlock()
|
||||
|
||||
var once sync.Once
|
||||
release := func() {
|
||||
once.Do(func() {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
fan, ok := b.builds[buildID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if _, present := fan.subs[ch]; !present {
|
||||
return
|
||||
}
|
||||
delete(fan.subs, ch)
|
||||
close(ch)
|
||||
if len(fan.subs) == 0 {
|
||||
fan.cancel()
|
||||
delete(b.builds, buildID)
|
||||
}
|
||||
})
|
||||
}
|
||||
return ch, release
|
||||
}
|
||||
|
||||
// run keeps a Redis subscription alive for buildID, reconnecting on error,
|
||||
// until the fanout's context is cancelled (last subscriber left).
|
||||
func (b *BuildBroker) run(ctx context.Context, buildID string) {
|
||||
for ctx.Err() == nil {
|
||||
b.subscribeOnce(ctx, buildID)
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(buildBrokerReconnect):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BuildBroker) subscribeOnce(ctx context.Context, buildID string) {
|
||||
sub := b.rdb.Subscribe(ctx, buildStreamChannel(buildID))
|
||||
defer sub.Close()
|
||||
|
||||
msgCh := sub.Channel()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case msg, ok := <-msgCh:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
var ev BuildStreamEvent
|
||||
if err := json.Unmarshal([]byte(msg.Payload), &ev); err != nil {
|
||||
slog.Warn("build broker: bad event payload", "build_id", buildID, "error", err)
|
||||
continue
|
||||
}
|
||||
b.dispatch(buildID, ev)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// dispatch fans one event to every in-process subscriber. The send is
|
||||
// non-blocking; a full subscriber buffer drops the event. The mutex is held
|
||||
// for the whole dispatch so a concurrent release cannot close a channel
|
||||
// mid-send.
|
||||
func (b *BuildBroker) dispatch(buildID string, ev BuildStreamEvent) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
fan, ok := b.builds[buildID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
for ch := range fan.subs {
|
||||
select {
|
||||
case ch <- ev:
|
||||
default:
|
||||
slog.Debug("build broker: dropped event for slow consumer", "build_id", buildID)
|
||||
}
|
||||
}
|
||||
}
|
||||
72
pkg/service/build_stream.go
Normal file
72
pkg/service/build_stream.go
Normal file
@ -0,0 +1,72 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// buildStreamChannelPrefix is the Redis pub/sub channel prefix for live build
|
||||
// events. One channel per build: wrenn:build:{buildID}.
|
||||
const buildStreamChannelPrefix = "wrenn:build:"
|
||||
|
||||
func buildStreamChannel(buildID string) string {
|
||||
return buildStreamChannelPrefix + buildID
|
||||
}
|
||||
|
||||
// BuildStreamEvent is one event in a build's live stream. The same struct is
|
||||
// published to Redis by the build worker and forwarded verbatim to admin
|
||||
// WebSocket clients, so its JSON shape is the wire contract for both.
|
||||
//
|
||||
// Type discriminates the payload:
|
||||
// - "step-start": Step, Phase, Cmd set.
|
||||
// - "output": Step, Data (base64 PTY bytes) set.
|
||||
// - "step-end": Step, Phase, Cmd, Exit, Ok, ElapsedMs set.
|
||||
// - "build-status": Status, CurrentStep, TotalSteps, Error set.
|
||||
type BuildStreamEvent struct {
|
||||
Type string `json:"type"`
|
||||
Step int `json:"step,omitempty"`
|
||||
Phase string `json:"phase,omitempty"`
|
||||
Cmd string `json:"cmd,omitempty"`
|
||||
Data string `json:"data,omitempty"` // base64-encoded PTY output bytes
|
||||
Exit int32 `json:"exit,omitempty"`
|
||||
Ok bool `json:"ok,omitempty"`
|
||||
ElapsedMs int64 `json:"elapsed_ms,omitempty"`
|
||||
Status string `json:"status,omitempty"`
|
||||
CurrentStep int32 `json:"current_step,omitempty"`
|
||||
TotalSteps int32 `json:"total_steps,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
T int64 `json:"t"` // unix milliseconds, set at publish time
|
||||
}
|
||||
|
||||
// IsTerminalBuildStatus reports whether a build status is final (the worker
|
||||
// will publish no further events for it).
|
||||
func IsTerminalBuildStatus(status string) bool {
|
||||
switch status {
|
||||
case "success", "failed", "cancelled":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// publishBuildEvent fire-and-forget publishes one event to a build's Redis
|
||||
// channel. A missing/closed Redis connection only drops live events; the WS
|
||||
// client always has the DB log history to fall back on.
|
||||
func publishBuildEvent(ctx context.Context, rdb *redis.Client, buildID string, ev BuildStreamEvent) {
|
||||
if rdb == nil {
|
||||
return
|
||||
}
|
||||
ev.T = time.Now().UnixMilli()
|
||||
payload, err := json.Marshal(ev)
|
||||
if err != nil {
|
||||
slog.Warn("build event marshal failed", "build_id", buildID, "error", err)
|
||||
return
|
||||
}
|
||||
if err := rdb.Publish(ctx, buildStreamChannel(buildID), payload).Err(); err != nil {
|
||||
slog.Debug("build event publish failed", "build_id", buildID, "error", err)
|
||||
}
|
||||
}
|
||||
@ -94,6 +94,31 @@ type regTokenPayload struct {
|
||||
|
||||
const regTokenTTL = time.Hour
|
||||
|
||||
func (s *HostService) issueRegistrationToken(ctx context.Context, hostID, createdBy pgtype.UUID) (string, error) {
|
||||
token := id.NewRegistrationToken()
|
||||
tokenID := id.NewHostTokenID()
|
||||
|
||||
payload, _ := json.Marshal(regTokenPayload{
|
||||
HostID: id.FormatHostID(hostID),
|
||||
TokenID: id.FormatHostTokenID(tokenID),
|
||||
})
|
||||
if err := s.Redis.Set(ctx, "host:reg:"+token, payload, regTokenTTL).Err(); err != nil {
|
||||
return "", fmt.Errorf("store registration token: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
if _, err := s.DB.InsertHostToken(ctx, db.InsertHostTokenParams{
|
||||
ID: tokenID,
|
||||
HostID: hostID,
|
||||
CreatedBy: createdBy,
|
||||
ExpiresAt: pgtype.Timestamptz{Time: now.Add(regTokenTTL), Valid: true},
|
||||
}); err != nil {
|
||||
slog.Warn("failed to insert host token audit record", "host_id", id.FormatHostID(hostID), "error", err)
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// requireAdminOrOwner returns nil iff the role is "owner" or "admin".
|
||||
func requireAdminOrOwner(role string) error {
|
||||
if role == "owner" || role == "admin" {
|
||||
@ -159,26 +184,9 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat
|
||||
return HostCreateResult{}, fmt.Errorf("insert host: %w", err)
|
||||
}
|
||||
|
||||
// Generate registration token and store in Redis + Postgres audit trail.
|
||||
token := id.NewRegistrationToken()
|
||||
tokenID := id.NewHostTokenID()
|
||||
|
||||
payload, _ := json.Marshal(regTokenPayload{
|
||||
HostID: id.FormatHostID(hostID),
|
||||
TokenID: id.FormatHostTokenID(tokenID),
|
||||
})
|
||||
if err := s.Redis.Set(ctx, "host:reg:"+token, payload, regTokenTTL).Err(); err != nil {
|
||||
return HostCreateResult{}, fmt.Errorf("store registration token: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
if _, err := s.DB.InsertHostToken(ctx, db.InsertHostTokenParams{
|
||||
ID: tokenID,
|
||||
HostID: hostID,
|
||||
CreatedBy: p.RequestingUserID,
|
||||
ExpiresAt: pgtype.Timestamptz{Time: now.Add(regTokenTTL), Valid: true},
|
||||
}); err != nil {
|
||||
slog.Warn("failed to insert host token audit record", "host_id", id.FormatHostID(hostID), "error", err)
|
||||
token, err := s.issueRegistrationToken(ctx, hostID, p.RequestingUserID)
|
||||
if err != nil {
|
||||
return HostCreateResult{}, err
|
||||
}
|
||||
|
||||
return HostCreateResult{Host: host, RegistrationToken: token}, nil
|
||||
@ -218,25 +226,9 @@ func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamI
|
||||
}
|
||||
}
|
||||
|
||||
token := id.NewRegistrationToken()
|
||||
tokenID := id.NewHostTokenID()
|
||||
|
||||
payload, _ := json.Marshal(regTokenPayload{
|
||||
HostID: id.FormatHostID(hostID),
|
||||
TokenID: id.FormatHostTokenID(tokenID),
|
||||
})
|
||||
if err := s.Redis.Set(ctx, "host:reg:"+token, payload, regTokenTTL).Err(); err != nil {
|
||||
return HostCreateResult{}, fmt.Errorf("store registration token: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
if _, err := s.DB.InsertHostToken(ctx, db.InsertHostTokenParams{
|
||||
ID: tokenID,
|
||||
HostID: hostID,
|
||||
CreatedBy: userID,
|
||||
ExpiresAt: pgtype.Timestamptz{Time: now.Add(regTokenTTL), Valid: true},
|
||||
}); err != nil {
|
||||
slog.Warn("failed to insert host token audit record", "host_id", id.FormatHostID(hostID), "error", err)
|
||||
token, err := s.issueRegistrationToken(ctx, hostID, userID)
|
||||
if err != nil {
|
||||
return HostCreateResult{}, err
|
||||
}
|
||||
|
||||
return HostCreateResult{Host: host, RegistrationToken: token}, nil
|
||||
@ -434,13 +426,27 @@ func (s *HostService) Heartbeat(ctx context.Context, hostID pgtype.UUID) error {
|
||||
|
||||
// List returns hosts visible to the caller.
|
||||
// Admins see all hosts; non-admins see only BYOC hosts belonging to their team.
|
||||
func (s *HostService) List(ctx context.Context, teamID pgtype.UUID, isAdmin bool) ([]db.Host, error) {
|
||||
func (s *HostService) List(ctx context.Context, teamID pgtype.UUID, isAdmin bool) ([]db.ListHostsByTeamRow, error) {
|
||||
if isAdmin {
|
||||
return s.DB.ListHosts(ctx)
|
||||
rows, err := s.DB.ListHostsAdmin(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := make([]db.ListHostsByTeamRow, len(rows))
|
||||
for i, r := range rows {
|
||||
result[i] = db.ListHostsByTeamRow(r)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
return s.DB.ListHostsByTeam(ctx, teamID)
|
||||
}
|
||||
|
||||
// ListAdmin returns all hosts with aggregated resource consumption.
|
||||
// Admin-only — caller must verify admin status.
|
||||
func (s *HostService) ListAdmin(ctx context.Context) ([]db.ListHostsAdminRow, error) {
|
||||
return s.DB.ListHostsAdmin(ctx)
|
||||
}
|
||||
|
||||
// Get returns a single host, enforcing access control.
|
||||
func (s *HostService) Get(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool) (db.Host, error) {
|
||||
host, err := s.DB.GetHost(ctx, hostID)
|
||||
|
||||
@ -18,12 +18,28 @@ import (
|
||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||
)
|
||||
|
||||
// SandboxEventPublisher writes sandbox lifecycle events to the Redis stream.
|
||||
type SandboxEventPublisher func(ctx context.Context, event SandboxStateEvent)
|
||||
|
||||
// SandboxStateEvent is the event payload published to the Redis stream.
|
||||
type SandboxStateEvent struct {
|
||||
Event string `json:"event"`
|
||||
SandboxID string `json:"sandbox_id"`
|
||||
TeamID string `json:"team_id,omitempty"`
|
||||
HostID string `json:"host_id"`
|
||||
HostIP string `json:"host_ip,omitempty"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
}
|
||||
|
||||
// SandboxService provides sandbox lifecycle operations shared between the
|
||||
// REST API and the dashboard.
|
||||
type SandboxService struct {
|
||||
DB *db.Queries
|
||||
Pool *lifecycle.HostClientPool
|
||||
Scheduler scheduler.HostScheduler
|
||||
DB *db.Queries
|
||||
Pool *lifecycle.HostClientPool
|
||||
Scheduler scheduler.HostScheduler
|
||||
PublishEvent SandboxEventPublisher
|
||||
}
|
||||
|
||||
// SandboxCreateParams holds the parameters for creating a sandbox.
|
||||
@ -33,7 +49,24 @@ type SandboxCreateParams struct {
|
||||
VCPUs int32
|
||||
MemoryMB int32
|
||||
TimeoutSec int32
|
||||
DiskSizeMB int32
|
||||
}
|
||||
|
||||
// MinTimeoutSec mirrors internal/sandbox.MinTimeoutSec. Sub-minute TTLs race
|
||||
// the post-create startup window (DB insert → /init → memory loader); the
|
||||
// agent silently clamps anyway, but the CP must clamp too so the DB record
|
||||
// agrees with what the agent runs. 0 is preserved (no TTL).
|
||||
const MinTimeoutSec int32 = 60
|
||||
|
||||
// clampTimeout normalises a caller-supplied TTL the same way the host agent
|
||||
// does. Keep in sync with internal/sandbox.clampTimeout.
|
||||
func clampTimeout(timeoutSec int32) int32 {
|
||||
if timeoutSec <= 0 {
|
||||
return 0
|
||||
}
|
||||
if timeoutSec < MinTimeoutSec {
|
||||
return MinTimeoutSec
|
||||
}
|
||||
return timeoutSec
|
||||
}
|
||||
|
||||
// agentForSandbox looks up the host for the given sandbox and returns a client.
|
||||
@ -42,15 +75,31 @@ func (s *SandboxService) agentForSandbox(ctx context.Context, sandboxID pgtype.U
|
||||
if err != nil {
|
||||
return nil, db.Sandbox{}, fmt.Errorf("sandbox not found: %w", err)
|
||||
}
|
||||
host, err := s.DB.GetHost(ctx, sb.HostID)
|
||||
agent, err := s.agentForHost(ctx, sb.HostID)
|
||||
if err != nil {
|
||||
return nil, db.Sandbox{}, fmt.Errorf("host not found for sandbox: %w", err)
|
||||
return nil, db.Sandbox{}, err
|
||||
}
|
||||
return agent, sb, nil
|
||||
}
|
||||
|
||||
// agentForHost returns the host client by host UUID, skipping the sandbox
|
||||
// lookup. Used by callers that already have a db.Sandbox in hand.
|
||||
func (s *SandboxService) agentForHost(ctx context.Context, hostID pgtype.UUID) (hostagentClient, error) {
|
||||
host, err := s.DB.GetHost(ctx, hostID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("host not found: %w", err)
|
||||
}
|
||||
agent, err := s.Pool.GetForHost(host)
|
||||
if err != nil {
|
||||
return nil, db.Sandbox{}, fmt.Errorf("get agent client: %w", err)
|
||||
return nil, fmt.Errorf("get agent client: %w", err)
|
||||
}
|
||||
return agent, nil
|
||||
}
|
||||
|
||||
func (s *SandboxService) publishEvent(ctx context.Context, event SandboxStateEvent) {
|
||||
if s.PublishEvent != nil {
|
||||
s.PublishEvent(ctx, event)
|
||||
}
|
||||
return agent, sb, nil
|
||||
}
|
||||
|
||||
// hostagentClient is a local alias to avoid the full package path in signatures.
|
||||
@ -62,13 +111,16 @@ type hostagentClient = interface {
|
||||
PingSandbox(ctx context.Context, req *connect.Request[pb.PingSandboxRequest]) (*connect.Response[pb.PingSandboxResponse], error)
|
||||
GetSandboxMetrics(ctx context.Context, req *connect.Request[pb.GetSandboxMetricsRequest]) (*connect.Response[pb.GetSandboxMetricsResponse], error)
|
||||
FlushSandboxMetrics(ctx context.Context, req *connect.Request[pb.FlushSandboxMetricsRequest]) (*connect.Response[pb.FlushSandboxMetricsResponse], error)
|
||||
CreateSnapshot(ctx context.Context, req *connect.Request[pb.CreateSnapshotRequest]) (*connect.Response[pb.CreateSnapshotResponse], error)
|
||||
}
|
||||
|
||||
// Create creates a new sandbox: picks a host via the scheduler, inserts a pending
|
||||
// DB record, calls the host agent, and updates the record to running.
|
||||
// Create creates a new sandbox asynchronously: picks a host, inserts a
|
||||
// "starting" DB record, fires the agent RPC in a background goroutine, and
|
||||
// returns the sandbox immediately. The background goroutine publishes a
|
||||
// sandbox event to the Redis stream when the operation completes.
|
||||
func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db.Sandbox, error) {
|
||||
if p.Template == "" {
|
||||
p.Template = "minimal"
|
||||
p.Template = "minimal-ubuntu"
|
||||
}
|
||||
if err := validate.SafeName(p.Template); err != nil {
|
||||
return db.Sandbox{}, fmt.Errorf("invalid template name: %w", err)
|
||||
@ -79,46 +131,37 @@ func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db.
|
||||
if p.MemoryMB <= 0 {
|
||||
p.MemoryMB = 512
|
||||
}
|
||||
if p.DiskSizeMB <= 0 {
|
||||
p.DiskSizeMB = 5120 // 5 GB default
|
||||
}
|
||||
p.TimeoutSec = clampTimeout(p.TimeoutSec)
|
||||
|
||||
// Resolve template name → (teamID, templateID).
|
||||
templateTeamID := id.PlatformTeamID
|
||||
templateID := id.MinimalTemplateID
|
||||
var templateDefaultUser string
|
||||
// Resolve template name → (teamID, templateID). System base templates are
|
||||
// platform-owned rows like any other, so the lookup handles them too (the
|
||||
// query also matches platform templates for any team).
|
||||
tmpl, err := s.DB.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: p.Template, TeamID: p.TeamID})
|
||||
if err != nil {
|
||||
return db.Sandbox{}, fmt.Errorf("template %q not found: %w", p.Template, err)
|
||||
}
|
||||
templateTeamID := tmpl.TeamID
|
||||
templateID := tmpl.ID
|
||||
templateDefaultUser := tmpl.DefaultUser
|
||||
var templateDefaultEnv map[string]string
|
||||
if p.Template != "minimal" {
|
||||
tmpl, err := s.DB.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: p.Template, TeamID: p.TeamID})
|
||||
if err != nil {
|
||||
return db.Sandbox{}, fmt.Errorf("template %q not found: %w", p.Template, err)
|
||||
}
|
||||
templateTeamID = tmpl.TeamID
|
||||
templateID = tmpl.ID
|
||||
templateDefaultUser = tmpl.DefaultUser
|
||||
// Parse default_env JSONB into a map.
|
||||
if len(tmpl.DefaultEnv) > 0 {
|
||||
_ = json.Unmarshal(tmpl.DefaultEnv, &templateDefaultEnv)
|
||||
}
|
||||
// If the template is a snapshot, use its baked-in vcpus/memory.
|
||||
if tmpl.Type == "snapshot" {
|
||||
p.VCPUs = tmpl.Vcpus
|
||||
p.MemoryMB = tmpl.MemoryMb
|
||||
}
|
||||
if len(tmpl.DefaultEnv) > 0 {
|
||||
_ = json.Unmarshal(tmpl.DefaultEnv, &templateDefaultEnv)
|
||||
}
|
||||
if tmpl.Type == "snapshot" {
|
||||
p.VCPUs = tmpl.Vcpus
|
||||
p.MemoryMB = tmpl.MemoryMb
|
||||
}
|
||||
|
||||
if !p.TeamID.Valid {
|
||||
return db.Sandbox{}, fmt.Errorf("invalid request: team_id is required")
|
||||
}
|
||||
|
||||
// Determine whether this team uses BYOC hosts or platform hosts.
|
||||
team, err := s.DB.GetTeam(ctx, p.TeamID)
|
||||
if err != nil {
|
||||
return db.Sandbox{}, fmt.Errorf("team not found: %w", err)
|
||||
}
|
||||
|
||||
// Pick a host for this sandbox.
|
||||
host, err := s.Scheduler.SelectHost(ctx, p.TeamID, team.IsByoc, p.MemoryMB, p.DiskSizeMB)
|
||||
host, err := s.Scheduler.SelectHost(ctx, p.TeamID, team.IsByoc, p.MemoryMB, 0)
|
||||
if err != nil {
|
||||
return db.Sandbox{}, fmt.Errorf("select host: %w", err)
|
||||
}
|
||||
@ -130,25 +173,42 @@ func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db.
|
||||
|
||||
sandboxID := id.NewSandboxID()
|
||||
sandboxIDStr := id.FormatSandboxID(sandboxID)
|
||||
hostIDStr := id.FormatHostID(host.ID)
|
||||
|
||||
if _, err := s.DB.InsertSandbox(ctx, db.InsertSandboxParams{
|
||||
sb, err := s.DB.InsertSandbox(ctx, db.InsertSandboxParams{
|
||||
ID: sandboxID,
|
||||
TeamID: p.TeamID,
|
||||
HostID: host.ID,
|
||||
Template: p.Template,
|
||||
Status: "pending",
|
||||
Status: "starting",
|
||||
Vcpus: p.VCPUs,
|
||||
MemoryMb: p.MemoryMB,
|
||||
TimeoutSec: p.TimeoutSec,
|
||||
DiskSizeMb: p.DiskSizeMB,
|
||||
DiskSizeMb: 0,
|
||||
TemplateID: templateID,
|
||||
TemplateTeamID: templateTeamID,
|
||||
Metadata: []byte("{}"),
|
||||
}); err != nil {
|
||||
})
|
||||
if err != nil {
|
||||
return db.Sandbox{}, fmt.Errorf("insert sandbox: %w", err)
|
||||
}
|
||||
|
||||
resp, err := agent.CreateSandbox(ctx, connect.NewRequest(&pb.CreateSandboxRequest{
|
||||
teamIDStr := id.FormatTeamID(p.TeamID)
|
||||
go s.createInBackground(sandboxID, sandboxIDStr, hostIDStr, teamIDStr, agent, p, templateTeamID, templateID, templateDefaultUser, templateDefaultEnv)
|
||||
|
||||
return sb, nil
|
||||
}
|
||||
|
||||
func (s *SandboxService) createInBackground(
|
||||
sandboxID pgtype.UUID, sandboxIDStr, hostIDStr, teamIDStr string,
|
||||
agent hostagentClient, p SandboxCreateParams,
|
||||
templateTeamID, templateID pgtype.UUID,
|
||||
defaultUser string, defaultEnv map[string]string,
|
||||
) {
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
resp, err := agent.CreateSandbox(bgCtx, connect.NewRequest(&pb.CreateSandboxRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Template: p.Template,
|
||||
TeamId: id.UUIDString(templateTeamID),
|
||||
@ -156,46 +216,62 @@ func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db.
|
||||
Vcpus: p.VCPUs,
|
||||
MemoryMb: p.MemoryMB,
|
||||
TimeoutSec: p.TimeoutSec,
|
||||
DiskSizeMb: p.DiskSizeMB,
|
||||
DefaultUser: templateDefaultUser,
|
||||
DefaultEnv: templateDefaultEnv,
|
||||
DiskSizeMb: 0,
|
||||
DefaultUser: defaultUser,
|
||||
DefaultEnv: defaultEnv,
|
||||
}))
|
||||
if err != nil {
|
||||
if _, dbErr := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
|
||||
ID: sandboxID, Status: "error",
|
||||
slog.Warn("background create failed", "sandbox_id", sandboxIDStr, "error", err)
|
||||
errCtx, errCancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer errCancel()
|
||||
if _, dbErr := s.DB.UpdateSandboxStatusIf(errCtx, db.UpdateSandboxStatusIfParams{
|
||||
ID: sandboxID, Status: "starting", Status_2: "error",
|
||||
}); dbErr != nil {
|
||||
slog.Warn("failed to update sandbox status to error", "id", sandboxIDStr, "error", dbErr)
|
||||
slog.Warn("failed to update sandbox to error after create failure", "id", sandboxIDStr, "error", dbErr)
|
||||
}
|
||||
s.publishEvent(errCtx, SandboxStateEvent{
|
||||
Event: "sandbox.failed", SandboxID: sandboxIDStr, TeamID: teamIDStr, HostID: hostIDStr,
|
||||
Error: err.Error(), Timestamp: time.Now().Unix(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if resp.Msg.DiskSizeMb > 0 {
|
||||
if err := s.DB.UpdateSandboxDiskSize(bgCtx, db.UpdateSandboxDiskSizeParams{
|
||||
ID: sandboxID,
|
||||
DiskSizeMb: resp.Msg.DiskSizeMb,
|
||||
}); err != nil {
|
||||
slog.Warn("failed to update sandbox disk size", "id", sandboxIDStr, "error", err)
|
||||
}
|
||||
return db.Sandbox{}, fmt.Errorf("agent create: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
sb, err := s.DB.UpdateSandboxRunning(ctx, db.UpdateSandboxRunningParams{
|
||||
ID: sandboxID,
|
||||
HostIp: resp.Msg.HostIp,
|
||||
GuestIp: "",
|
||||
if _, dbErr := s.DB.UpdateSandboxRunningIf(bgCtx, db.UpdateSandboxRunningIfParams{
|
||||
ID: sandboxID,
|
||||
Status: "starting",
|
||||
HostIp: resp.Msg.HostIp,
|
||||
StartedAt: pgtype.Timestamptz{
|
||||
Time: now,
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return db.Sandbox{}, fmt.Errorf("update sandbox running: %w", err)
|
||||
}); dbErr != nil {
|
||||
slog.Warn("failed to update sandbox running after create", "id", sandboxIDStr, "error", dbErr)
|
||||
}
|
||||
|
||||
// Store runtime metadata from the agent (envd/kernel/firecracker/agent versions).
|
||||
if meta := resp.Msg.Metadata; len(meta) > 0 {
|
||||
metaJSON, _ := json.Marshal(meta)
|
||||
if err := s.DB.UpdateSandboxMetadata(ctx, db.UpdateSandboxMetadataParams{
|
||||
ID: sandboxID,
|
||||
Metadata: metaJSON,
|
||||
if err := s.DB.UpdateSandboxMetadata(bgCtx, db.UpdateSandboxMetadataParams{
|
||||
ID: sandboxID, Metadata: metaJSON,
|
||||
}); err != nil {
|
||||
slog.Warn("failed to store sandbox metadata", "id", sandboxIDStr, "error", err)
|
||||
}
|
||||
sb.Metadata = metaJSON
|
||||
}
|
||||
|
||||
return sb, nil
|
||||
s.publishEvent(bgCtx, SandboxStateEvent{
|
||||
Event: "sandbox.started", SandboxID: sandboxIDStr, TeamID: teamIDStr, HostID: hostIDStr,
|
||||
HostIP: resp.Msg.HostIp, Metadata: resp.Msg.Metadata,
|
||||
Timestamp: now.Unix(),
|
||||
})
|
||||
}
|
||||
|
||||
// List returns active sandboxes (excludes stopped/error) belonging to the given team.
|
||||
@ -208,152 +284,331 @@ func (s *SandboxService) Get(ctx context.Context, sandboxID, teamID pgtype.UUID)
|
||||
return s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
|
||||
}
|
||||
|
||||
// Pause snapshots and freezes a running sandbox to disk.
|
||||
// Pause asynchronously pauses a running sandbox. The DB CAS from "running"
|
||||
// to "pausing" is the authoritative gate against concurrent Pause/Destroy
|
||||
// calls; if it loses, no agent RPC fires.
|
||||
func (s *SandboxService) Pause(ctx context.Context, sandboxID, teamID pgtype.UUID) (db.Sandbox, error) {
|
||||
sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
|
||||
if err != nil {
|
||||
return db.Sandbox{}, fmt.Errorf("sandbox not found: %w", err)
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
return db.Sandbox{}, fmt.Errorf("sandbox is not running (status: %s)", sb.Status)
|
||||
if sb.Status == "paused" {
|
||||
return sb, nil
|
||||
}
|
||||
|
||||
agent, _, err := s.agentForSandbox(ctx, sandboxID)
|
||||
if _, err := s.DB.UpdateSandboxStatusIf(ctx, db.UpdateSandboxStatusIfParams{
|
||||
ID: sandboxID, Status: "running", Status_2: "pausing",
|
||||
}); err != nil {
|
||||
return db.Sandbox{}, fmt.Errorf("sandbox not in running state (current: %s)", sb.Status)
|
||||
}
|
||||
|
||||
agent, err := s.agentForHost(ctx, sb.HostID)
|
||||
if err != nil {
|
||||
// Roll back the CAS so the sandbox isn't stuck in "pausing".
|
||||
if _, rerr := s.DB.UpdateSandboxStatusIf(ctx, db.UpdateSandboxStatusIfParams{
|
||||
ID: sandboxID, Status: "pausing", Status_2: "running",
|
||||
}); rerr != nil {
|
||||
slog.Warn("failed to roll back pausing→running", "id", id.FormatSandboxID(sandboxID), "error", rerr)
|
||||
}
|
||||
return db.Sandbox{}, err
|
||||
}
|
||||
|
||||
sandboxIDStr := id.FormatSandboxID(sandboxID)
|
||||
hostIDStr := id.FormatHostID(sb.HostID)
|
||||
teamIDStr := id.FormatTeamID(sb.TeamID)
|
||||
|
||||
// Pre-mark as "paused" in DB before the RPC so the reconciler does not
|
||||
// mark the sandbox "stopped" while the host agent processes the pause.
|
||||
if _, err := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
|
||||
ID: sandboxID, Status: "paused",
|
||||
}); err != nil {
|
||||
return db.Sandbox{}, fmt.Errorf("pre-mark paused: %w", err)
|
||||
}
|
||||
go s.pauseInBackground(sandboxID, sandboxIDStr, hostIDStr, teamIDStr, agent)
|
||||
|
||||
// Flush all metrics tiers before pausing so data survives in DB.
|
||||
s.flushAndPersistMetrics(ctx, agent, sandboxID, true)
|
||||
|
||||
if _, err := agent.PauseSandbox(ctx, connect.NewRequest(&pb.PauseSandboxRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
})); err != nil {
|
||||
// Check if the agent still has this sandbox. If it was destroyed
|
||||
// (e.g. frozen VM couldn't be resumed), mark as "error" instead of
|
||||
// reverting to "running" — which would create a ghost record.
|
||||
// Use a fresh context since the original ctx may already be expired.
|
||||
revertStatus := "running"
|
||||
pingCtx, pingCancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
if _, pingErr := agent.PingSandbox(pingCtx, connect.NewRequest(&pb.PingSandboxRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
})); pingErr != nil {
|
||||
revertStatus = "error"
|
||||
slog.Warn("sandbox gone from agent after failed pause, marking as error", "sandbox_id", sandboxIDStr)
|
||||
}
|
||||
pingCancel()
|
||||
dbCtx, dbCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
if _, dbErr := s.DB.UpdateSandboxStatus(dbCtx, db.UpdateSandboxStatusParams{
|
||||
ID: sandboxID, Status: revertStatus,
|
||||
}); dbErr != nil {
|
||||
slog.Warn("failed to revert sandbox status after pause error", "sandbox_id", sandboxIDStr, "error", dbErr)
|
||||
}
|
||||
dbCancel()
|
||||
return db.Sandbox{}, fmt.Errorf("agent pause: %w", err)
|
||||
}
|
||||
|
||||
sb, err = s.DB.GetSandbox(ctx, sandboxID)
|
||||
if err != nil {
|
||||
return db.Sandbox{}, fmt.Errorf("get sandbox after pause: %w", err)
|
||||
}
|
||||
sb.Status = "pausing"
|
||||
return sb, nil
|
||||
}
|
||||
|
||||
// Resume restores a paused sandbox from snapshot.
|
||||
func (s *SandboxService) pauseInBackground(sandboxID pgtype.UUID, sandboxIDStr, hostIDStr, teamIDStr string, agent hostagentClient) {
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
// Flush metrics before the VM stops sampling so the persisted history
|
||||
// covers the entire run-up to the pause.
|
||||
s.flushAndPersistMetrics(bgCtx, agent, sandboxID, true)
|
||||
|
||||
if _, err := agent.PauseSandbox(bgCtx, connect.NewRequest(&pb.PauseSandboxRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
})); err != nil {
|
||||
slog.Warn("background pause failed", "sandbox_id", sandboxIDStr, "error", err)
|
||||
// Best-effort: try to recover the sandbox back to "running" so the
|
||||
// user isn't stuck in "pausing".
|
||||
if _, dbErr := s.DB.UpdateSandboxStatusIf(bgCtx, db.UpdateSandboxStatusIfParams{
|
||||
ID: sandboxID, Status: "pausing", Status_2: "running",
|
||||
}); dbErr != nil {
|
||||
slog.Warn("failed to recover pausing→running after pause failure", "id", sandboxIDStr, "error", dbErr)
|
||||
}
|
||||
s.publishEvent(bgCtx, SandboxStateEvent{
|
||||
Event: "sandbox.pause_failed", SandboxID: sandboxIDStr, TeamID: teamIDStr, HostID: hostIDStr,
|
||||
Error: err.Error(), Timestamp: time.Now().Unix(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := s.DB.UpdateSandboxStatusIf(bgCtx, db.UpdateSandboxStatusIfParams{
|
||||
ID: sandboxID, Status: "pausing", Status_2: "paused",
|
||||
}); err != nil {
|
||||
slog.Warn("failed to update sandbox to paused", "sandbox_id", sandboxIDStr, "error", err)
|
||||
}
|
||||
|
||||
s.publishEvent(bgCtx, SandboxStateEvent{
|
||||
Event: "sandbox.paused", SandboxID: sandboxIDStr, TeamID: teamIDStr, HostID: hostIDStr,
|
||||
Timestamp: time.Now().Unix(),
|
||||
})
|
||||
}
|
||||
|
||||
// Resume asynchronously resumes a paused sandbox on its original host.
|
||||
// The DB CAS from "paused" to "resuming" gates concurrent Resume/Destroy.
|
||||
func (s *SandboxService) Resume(ctx context.Context, sandboxID, teamID pgtype.UUID) (db.Sandbox, error) {
|
||||
sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
|
||||
if err != nil {
|
||||
return db.Sandbox{}, fmt.Errorf("sandbox not found: %w", err)
|
||||
}
|
||||
if sb.Status != "paused" {
|
||||
return db.Sandbox{}, fmt.Errorf("sandbox is not paused (status: %s)", sb.Status)
|
||||
if sb.Status == "running" {
|
||||
return sb, nil
|
||||
}
|
||||
|
||||
agent, _, err := s.agentForSandbox(ctx, sandboxID)
|
||||
if _, err := s.DB.UpdateSandboxStatusIf(ctx, db.UpdateSandboxStatusIfParams{
|
||||
ID: sandboxID, Status: "paused", Status_2: "resuming",
|
||||
}); err != nil {
|
||||
return db.Sandbox{}, fmt.Errorf("sandbox not in paused state (current: %s)", sb.Status)
|
||||
}
|
||||
|
||||
agent, err := s.agentForHost(ctx, sb.HostID)
|
||||
if err != nil {
|
||||
if _, rerr := s.DB.UpdateSandboxStatusIf(ctx, db.UpdateSandboxStatusIfParams{
|
||||
ID: sandboxID, Status: "resuming", Status_2: "paused",
|
||||
}); rerr != nil {
|
||||
slog.Warn("failed to roll back resuming→paused", "id", id.FormatSandboxID(sandboxID), "error", rerr)
|
||||
}
|
||||
return db.Sandbox{}, err
|
||||
}
|
||||
|
||||
// Look up template defaults so a resumed sandbox has the same env as
|
||||
// the original Create did.
|
||||
var defaultUser string
|
||||
var defaultEnv map[string]string
|
||||
if tmpl, terr := s.DB.GetTemplate(ctx, sb.TemplateID); terr == nil {
|
||||
defaultUser = tmpl.DefaultUser
|
||||
if len(tmpl.DefaultEnv) > 0 {
|
||||
_ = json.Unmarshal(tmpl.DefaultEnv, &defaultEnv)
|
||||
}
|
||||
}
|
||||
|
||||
sandboxIDStr := id.FormatSandboxID(sandboxID)
|
||||
hostIDStr := id.FormatHostID(sb.HostID)
|
||||
teamIDStr := id.FormatTeamID(sb.TeamID)
|
||||
|
||||
// Look up template defaults for resume.
|
||||
var resumeDefaultUser string
|
||||
var resumeDefaultEnv map[string]string
|
||||
if sb.TemplateID.Valid {
|
||||
tmpl, err := s.DB.GetTemplate(ctx, sb.TemplateID)
|
||||
if err == nil {
|
||||
resumeDefaultUser = tmpl.DefaultUser
|
||||
if len(tmpl.DefaultEnv) > 0 {
|
||||
_ = json.Unmarshal(tmpl.DefaultEnv, &resumeDefaultEnv)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract kernel version hint from existing sandbox metadata.
|
||||
var kernelVersion string
|
||||
if len(sb.Metadata) > 0 {
|
||||
var meta map[string]string
|
||||
if err := json.Unmarshal(sb.Metadata, &meta); err == nil {
|
||||
kernelVersion = meta["kernel_version"]
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := agent.ResumeSandbox(ctx, connect.NewRequest(&pb.ResumeSandboxRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
TimeoutSec: sb.TimeoutSec,
|
||||
DefaultUser: resumeDefaultUser,
|
||||
DefaultEnv: resumeDefaultEnv,
|
||||
KernelVersion: kernelVersion,
|
||||
}))
|
||||
if err != nil {
|
||||
return db.Sandbox{}, fmt.Errorf("agent resume: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
sb, err = s.DB.UpdateSandboxRunning(ctx, db.UpdateSandboxRunningParams{
|
||||
ID: sandboxID,
|
||||
HostIp: resp.Msg.HostIp,
|
||||
GuestIp: "",
|
||||
StartedAt: pgtype.Timestamptz{
|
||||
Time: now,
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return db.Sandbox{}, fmt.Errorf("update status: %w", err)
|
||||
}
|
||||
|
||||
// Update metadata with actual versions used after resume.
|
||||
if meta := resp.Msg.Metadata; len(meta) > 0 {
|
||||
metaJSON, _ := json.Marshal(meta)
|
||||
if err := s.DB.UpdateSandboxMetadata(ctx, db.UpdateSandboxMetadataParams{
|
||||
ID: sandboxID,
|
||||
Metadata: metaJSON,
|
||||
}); err != nil {
|
||||
slog.Warn("failed to update sandbox metadata after resume", "id", sandboxIDStr, "error", err)
|
||||
}
|
||||
sb.Metadata = metaJSON
|
||||
}
|
||||
go s.resumeInBackground(sandboxID, sandboxIDStr, hostIDStr, teamIDStr, agent, sb.TimeoutSec, defaultUser, defaultEnv)
|
||||
|
||||
sb.Status = "resuming"
|
||||
return sb, nil
|
||||
}
|
||||
|
||||
// Destroy stops a sandbox and marks it as stopped.
|
||||
func (s *SandboxService) resumeInBackground(
|
||||
sandboxID pgtype.UUID, sandboxIDStr, hostIDStr, teamIDStr string,
|
||||
agent hostagentClient, timeoutSec int32,
|
||||
defaultUser string, defaultEnv map[string]string,
|
||||
) {
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
resp, err := agent.ResumeSandbox(bgCtx, connect.NewRequest(&pb.ResumeSandboxRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
TimeoutSec: timeoutSec,
|
||||
DefaultUser: defaultUser,
|
||||
DefaultEnv: defaultEnv,
|
||||
}))
|
||||
if err != nil {
|
||||
slog.Warn("background resume failed", "sandbox_id", sandboxIDStr, "error", err)
|
||||
if _, dbErr := s.DB.UpdateSandboxStatusIf(bgCtx, db.UpdateSandboxStatusIfParams{
|
||||
ID: sandboxID, Status: "resuming", Status_2: "paused",
|
||||
}); dbErr != nil {
|
||||
slog.Warn("failed to recover resuming→paused after resume failure", "id", sandboxIDStr, "error", dbErr)
|
||||
}
|
||||
s.publishEvent(bgCtx, SandboxStateEvent{
|
||||
Event: "sandbox.resume_failed", SandboxID: sandboxIDStr, TeamID: teamIDStr, HostID: hostIDStr,
|
||||
Error: err.Error(), Timestamp: time.Now().Unix(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
if _, err := s.DB.UpdateSandboxRunningIf(bgCtx, db.UpdateSandboxRunningIfParams{
|
||||
ID: sandboxID,
|
||||
Status: "resuming",
|
||||
HostIp: resp.Msg.HostIp,
|
||||
StartedAt: pgtype.Timestamptz{Time: now, Valid: true},
|
||||
}); err != nil {
|
||||
slog.Warn("failed to update sandbox to running after resume", "id", sandboxIDStr, "error", err)
|
||||
}
|
||||
|
||||
if meta := resp.Msg.Metadata; len(meta) > 0 {
|
||||
metaJSON, _ := json.Marshal(meta)
|
||||
if err := s.DB.UpdateSandboxMetadata(bgCtx, db.UpdateSandboxMetadataParams{
|
||||
ID: sandboxID, Metadata: metaJSON,
|
||||
}); err != nil {
|
||||
slog.Warn("failed to store sandbox metadata after resume", "id", sandboxIDStr, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
s.publishEvent(bgCtx, SandboxStateEvent{
|
||||
Event: "sandbox.resumed", SandboxID: sandboxIDStr, TeamID: teamIDStr, HostID: hostIDStr,
|
||||
HostIP: resp.Msg.HostIp, Metadata: resp.Msg.Metadata,
|
||||
Timestamp: now.Unix(),
|
||||
})
|
||||
}
|
||||
|
||||
// CreateSnapshot asynchronously snapshots a running or paused sandbox,
|
||||
// publishing the result as a new template owned by the sandbox's team. The DB
|
||||
// CAS from the sandbox's current status to "snapshotting" is the authoritative
|
||||
// gate against concurrent Pause/Snapshot/Destroy calls; if it loses, no agent
|
||||
// RPC fires. A running sandbox is snapshotted live (CH briefly paused, then
|
||||
// resumed); a paused sandbox is snapshotted from its on-disk artefacts without
|
||||
// reviving the VM. Either way the sandbox returns to its original status on
|
||||
// completion. Returns the sandbox (now "snapshotting") and the resolved name.
|
||||
func (s *SandboxService) CreateSnapshot(ctx context.Context, sandboxID, teamID pgtype.UUID, name string) (db.Sandbox, string, error) {
|
||||
sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
|
||||
if err != nil {
|
||||
return db.Sandbox{}, "", fmt.Errorf("sandbox not found: %w", err)
|
||||
}
|
||||
if sb.Status != "running" && sb.Status != "paused" {
|
||||
return db.Sandbox{}, "", fmt.Errorf("sandbox is not running or paused (status: %s)", sb.Status)
|
||||
}
|
||||
origStatus := sb.Status
|
||||
|
||||
if name == "" {
|
||||
name = id.NewSnapshotName()
|
||||
}
|
||||
if err := validate.SafeName(name); err != nil {
|
||||
return db.Sandbox{}, "", fmt.Errorf("invalid name: %w", err)
|
||||
}
|
||||
// Reject duplicate names up front so we don't pause the VM and dump memory
|
||||
// only to fail on the template insert at the very end.
|
||||
if _, err := s.DB.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: name, TeamID: teamID}); err == nil {
|
||||
return db.Sandbox{}, "", fmt.Errorf("conflict: a snapshot named %q already exists", name)
|
||||
}
|
||||
|
||||
if _, err := s.DB.UpdateSandboxStatusIf(ctx, db.UpdateSandboxStatusIfParams{
|
||||
ID: sandboxID, Status: origStatus, Status_2: "snapshotting",
|
||||
}); err != nil {
|
||||
return db.Sandbox{}, "", fmt.Errorf("sandbox not in %s state (current: %s)", origStatus, sb.Status)
|
||||
}
|
||||
|
||||
agent, err := s.agentForHost(ctx, sb.HostID)
|
||||
if err != nil {
|
||||
// Roll back the CAS so the sandbox isn't stuck in "snapshotting".
|
||||
if _, rerr := s.DB.UpdateSandboxStatusIf(ctx, db.UpdateSandboxStatusIfParams{
|
||||
ID: sandboxID, Status: "snapshotting", Status_2: origStatus,
|
||||
}); rerr != nil {
|
||||
slog.Warn("failed to roll back snapshotting→"+origStatus, "id", id.FormatSandboxID(sandboxID), "error", rerr)
|
||||
}
|
||||
return db.Sandbox{}, "", err
|
||||
}
|
||||
|
||||
sandboxIDStr := id.FormatSandboxID(sandboxID)
|
||||
hostIDStr := id.FormatHostID(sb.HostID)
|
||||
teamIDStr := id.FormatTeamID(sb.TeamID)
|
||||
|
||||
// Notify other clients that the badge moved to "snapshotting".
|
||||
s.publishStateChanged(ctx, sandboxIDStr, teamIDStr, hostIDStr, origStatus, "snapshotting")
|
||||
|
||||
go s.snapshotInBackground(sandboxID, sandboxIDStr, hostIDStr, teamIDStr, teamID, agent, name, origStatus, sb.Vcpus, sb.MemoryMb)
|
||||
|
||||
sb.Status = "snapshotting"
|
||||
return sb, name, nil
|
||||
}
|
||||
|
||||
func (s *SandboxService) snapshotInBackground(
|
||||
sandboxID pgtype.UUID, sandboxIDStr, hostIDStr, teamIDStr string, teamID pgtype.UUID,
|
||||
agent hostagentClient, name, origStatus string, vcpus, memoryMB int32,
|
||||
) {
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
newTemplateID := id.NewSandboxID() // any random UUID
|
||||
templateUUID := pgtype.UUID{Bytes: newTemplateID.Bytes, Valid: true}
|
||||
|
||||
resp, err := agent.CreateSnapshot(bgCtx, connect.NewRequest(&pb.CreateSnapshotRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Name: name,
|
||||
TeamId: id.UUIDString(teamID),
|
||||
TemplateId: id.UUIDString(templateUUID),
|
||||
}))
|
||||
|
||||
// Either way, the host-side op is done; return the badge to its original
|
||||
// status (running for a live snapshot, paused for an on-disk one). Use a CAS
|
||||
// so a concurrent Destroy (which sets "stopping") wins: if the CAS misses,
|
||||
// the sandbox is no longer ours and we must NOT announce its old status. The
|
||||
// snapshot itself is still valid and is registered below — a snapshot
|
||||
// template outlives its source sandbox.
|
||||
if _, derr := s.DB.UpdateSandboxStatusIf(bgCtx, db.UpdateSandboxStatusIfParams{
|
||||
ID: sandboxID, Status: "snapshotting", Status_2: origStatus,
|
||||
}); derr != nil {
|
||||
slog.Warn("snapshotting→"+origStatus+" CAS missed (sandbox moved on); skipping state signal", "sandbox_id", sandboxIDStr, "error", derr)
|
||||
} else {
|
||||
s.publishStateChanged(bgCtx, sandboxIDStr, teamIDStr, hostIDStr, "snapshotting", origStatus)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
slog.Warn("background snapshot failed", "sandbox_id", sandboxIDStr, "error", err)
|
||||
s.publishEvent(bgCtx, SandboxStateEvent{
|
||||
Event: "sandbox.snapshot_failed", SandboxID: sandboxIDStr, TeamID: teamIDStr, HostID: hostIDStr,
|
||||
Metadata: map[string]string{"name": name}, Error: err.Error(), Timestamp: time.Now().Unix(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := s.DB.InsertTemplate(bgCtx, db.InsertTemplateParams{
|
||||
ID: templateUUID,
|
||||
Name: name,
|
||||
Type: "snapshot",
|
||||
Vcpus: vcpus,
|
||||
MemoryMb: memoryMB,
|
||||
SizeBytes: resp.Msg.SizeBytes,
|
||||
TeamID: teamID,
|
||||
DefaultUser: "",
|
||||
DefaultEnv: []byte("{}"),
|
||||
Metadata: []byte("{}"),
|
||||
}); err != nil {
|
||||
slog.Warn("failed to insert snapshot template", "sandbox_id", sandboxIDStr, "name", name, "error", err)
|
||||
s.publishEvent(bgCtx, SandboxStateEvent{
|
||||
Event: "sandbox.snapshot_failed", SandboxID: sandboxIDStr, TeamID: teamIDStr, HostID: hostIDStr,
|
||||
Metadata: map[string]string{"name": name}, Error: "failed to register snapshot", Timestamp: time.Now().Unix(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
s.publishEvent(bgCtx, SandboxStateEvent{
|
||||
Event: "sandbox.snapshotted", SandboxID: sandboxIDStr, TeamID: teamIDStr, HostID: hostIDStr,
|
||||
Metadata: map[string]string{"name": name}, Timestamp: time.Now().Unix(),
|
||||
})
|
||||
}
|
||||
|
||||
// publishStateChanged emits a transient capsule.state.changed event so the
|
||||
// dashboard flips the status badge during a transition that has no terminal
|
||||
// lifecycle verb of its own (e.g. the snapshotting round-trip).
|
||||
func (s *SandboxService) publishStateChanged(ctx context.Context, sandboxIDStr, teamIDStr, hostIDStr, from, to string) {
|
||||
s.publishEvent(ctx, SandboxStateEvent{
|
||||
Event: "sandbox.state_changed", SandboxID: sandboxIDStr, TeamID: teamIDStr, HostID: hostIDStr,
|
||||
Metadata: map[string]string{"from": from, "to": to}, Timestamp: time.Now().Unix(),
|
||||
})
|
||||
}
|
||||
|
||||
// Destroy stops a sandbox asynchronously. Pre-marks the DB status as
|
||||
// "stopping" and fires the agent RPC in a background goroutine.
|
||||
func (s *SandboxService) Destroy(ctx context.Context, sandboxID, teamID pgtype.UUID) error {
|
||||
sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
|
||||
if err != nil {
|
||||
return fmt.Errorf("sandbox not found: %w", err)
|
||||
}
|
||||
if sb.Status == "stopped" || sb.Status == "error" {
|
||||
return nil
|
||||
}
|
||||
|
||||
agent, _, err := s.agentForSandbox(ctx, sandboxID)
|
||||
if err != nil {
|
||||
@ -361,35 +616,54 @@ func (s *SandboxService) Destroy(ctx context.Context, sandboxID, teamID pgtype.U
|
||||
}
|
||||
|
||||
sandboxIDStr := id.FormatSandboxID(sandboxID)
|
||||
hostIDStr := id.FormatHostID(sb.HostID)
|
||||
teamIDStr := id.FormatTeamID(sb.TeamID)
|
||||
prevStatus := sb.Status
|
||||
|
||||
// If running, flush 24h tier metrics for analytics before destroying.
|
||||
if sb.Status == "running" {
|
||||
s.flushAndPersistMetrics(ctx, agent, sandboxID, false)
|
||||
if _, err := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
|
||||
ID: sandboxID, Status: "stopping",
|
||||
}); err != nil {
|
||||
return fmt.Errorf("pre-mark stopping: %w", err)
|
||||
}
|
||||
|
||||
// Destroy on host agent. A not-found response is fine — sandbox is already gone.
|
||||
if _, err := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
|
||||
go s.destroyInBackground(sandboxID, sandboxIDStr, hostIDStr, teamIDStr, agent, prevStatus)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SandboxService) destroyInBackground(sandboxID pgtype.UUID, sandboxIDStr, hostIDStr, teamIDStr string, agent hostagentClient, prevStatus string) {
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
if prevStatus == "running" || prevStatus == "pausing" {
|
||||
s.flushAndPersistMetrics(bgCtx, agent, sandboxID, false)
|
||||
}
|
||||
|
||||
if _, err := agent.DestroySandbox(bgCtx, connect.NewRequest(&pb.DestroySandboxRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
})); err != nil && connect.CodeOf(err) != connect.CodeNotFound {
|
||||
return fmt.Errorf("agent destroy: %w", err)
|
||||
slog.Warn("background destroy failed", "sandbox_id", sandboxIDStr, "error", err)
|
||||
}
|
||||
|
||||
// For a paused sandbox, only keep 24h tier; remove the finer-grained tiers.
|
||||
if sb.Status == "paused" {
|
||||
_ = s.DB.DeleteSandboxMetricPointsByTier(ctx, db.DeleteSandboxMetricPointsByTierParams{
|
||||
if prevStatus == "paused" {
|
||||
_ = s.DB.DeleteSandboxMetricPointsByTier(bgCtx, db.DeleteSandboxMetricPointsByTierParams{
|
||||
SandboxID: sandboxID, Tier: "10m",
|
||||
})
|
||||
_ = s.DB.DeleteSandboxMetricPointsByTier(ctx, db.DeleteSandboxMetricPointsByTierParams{
|
||||
_ = s.DB.DeleteSandboxMetricPointsByTier(bgCtx, db.DeleteSandboxMetricPointsByTierParams{
|
||||
SandboxID: sandboxID, Tier: "2h",
|
||||
})
|
||||
}
|
||||
|
||||
if _, err := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
|
||||
ID: sandboxID, Status: "stopped",
|
||||
if _, err := s.DB.UpdateSandboxStatusIf(bgCtx, db.UpdateSandboxStatusIfParams{
|
||||
ID: sandboxID, Status: "stopping", Status_2: "stopped",
|
||||
}); err != nil {
|
||||
return fmt.Errorf("update status: %w", err)
|
||||
slog.Warn("failed to update sandbox to stopped", "sandbox_id", sandboxIDStr, "error", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
s.publishEvent(bgCtx, SandboxStateEvent{
|
||||
Event: "sandbox.stopped", SandboxID: sandboxIDStr, TeamID: teamIDStr, HostID: hostIDStr,
|
||||
Timestamp: time.Now().Unix(),
|
||||
})
|
||||
}
|
||||
|
||||
// flushAndPersistMetrics calls FlushSandboxMetrics on the agent and stores
|
||||
@ -429,6 +703,40 @@ func (s *SandboxService) persistMetricPoints(ctx context.Context, sandboxID pgty
|
||||
}
|
||||
}
|
||||
|
||||
// GetDiskUsage returns the current disk usage in bytes for a sandbox.
|
||||
// For running or paused sandboxes, it queries the host agent for live data.
|
||||
// For other states or when the agent is unreachable, it falls back to the
|
||||
// last known metric point from the database.
|
||||
func (s *SandboxService) GetDiskUsage(ctx context.Context, sandboxID, teamID pgtype.UUID) (int64, error) {
|
||||
sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("sandbox not found: %w", err)
|
||||
}
|
||||
|
||||
// For running or paused sandboxes, try the agent for live disk usage.
|
||||
if sb.Status == "running" || sb.Status == "paused" {
|
||||
sandboxIDStr := id.FormatSandboxID(sandboxID)
|
||||
agent, hostErr := s.agentForHost(ctx, sb.HostID)
|
||||
if hostErr == nil {
|
||||
resp, err := agent.GetSandboxMetrics(ctx, connect.NewRequest(&pb.GetSandboxMetricsRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Range: "5m",
|
||||
}))
|
||||
if err == nil && len(resp.Msg.Points) > 0 {
|
||||
last := resp.Msg.Points[len(resp.Msg.Points)-1]
|
||||
return last.DiskBytes, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: query the database for the last known metric point.
|
||||
point, err := s.DB.GetLatestSandboxMetricPoint(ctx, sandboxID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return point.DiskBytes, nil
|
||||
}
|
||||
|
||||
// Ping resets the inactivity timer for a running sandbox.
|
||||
func (s *SandboxService) Ping(ctx context.Context, sandboxID, teamID pgtype.UUID) error {
|
||||
sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
|
||||
|
||||
@ -479,6 +479,8 @@ type AdminTeamRow struct {
|
||||
OwnerEmail string
|
||||
ActiveSandboxCount int32
|
||||
ChannelCount int32
|
||||
RunningVcpus int32
|
||||
RunningMemoryMb int32
|
||||
}
|
||||
|
||||
// AdminListTeams returns a paginated list of all teams (excluding the platform
|
||||
@ -511,6 +513,8 @@ func (s *TeamService) AdminListTeams(ctx context.Context, limit, offset int32) (
|
||||
OwnerEmail: t.OwnerEmail,
|
||||
ActiveSandboxCount: t.ActiveSandboxCount,
|
||||
ChannelCount: t.ChannelCount,
|
||||
RunningVcpus: t.RunningVcpus,
|
||||
RunningMemoryMb: t.RunningMemoryMb,
|
||||
}
|
||||
if t.DeletedAt.Valid {
|
||||
deletedAt := t.DeletedAt.Time
|
||||
|
||||
Reference in New Issue
Block a user