1
0
forked from wrenn/wrenn
This commit is contained in:
2026-04-16 19:24:25 +00:00
parent 172413e91e
commit 605ad666a0
239 changed files with 19966 additions and 3454 deletions

View File

@ -6,8 +6,8 @@ import (
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/internal/db"
"git.omukk.dev/wrenn/wrenn/internal/lifecycle"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
"git.omukk.dev/wrenn/wrenn/proto/hostagent/gen/hostagentv1connect"
)

View File

@ -17,10 +17,9 @@ import (
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/internal/auth"
"git.omukk.dev/wrenn/wrenn/internal/db"
"git.omukk.dev/wrenn/wrenn/internal/id"
"git.omukk.dev/wrenn/wrenn/internal/lifecycle"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
)
// Sentinel errors returned by proxyTarget, used to map to HTTP status codes
@ -44,7 +43,7 @@ func (e errProxySandboxNotRunning) Error() string {
return fmt.Sprintf("sandbox is not running (status: %s)", e.status)
}
// proxyCacheEntry caches the resolved agent URL for a (sandbox, team) pair.
// proxyCacheEntry caches the resolved agent URL for a sandbox.
// The *httputil.ReverseProxy is built per-request (cheap) so the Director closure
// can capture the correct port without the cache key needing to include it.
type proxyCacheEntry struct {
@ -52,23 +51,13 @@ type proxyCacheEntry struct {
expiresAt time.Time
}
// proxyCacheKey is a fixed-size key from two UUIDs, avoids string allocation.
type proxyCacheKey [32]byte
func makeProxyCacheKey(sandboxID, teamID pgtype.UUID) proxyCacheKey {
var k proxyCacheKey
copy(k[:16], sandboxID.Bytes[:])
copy(k[16:], teamID.Bytes[:])
return k
}
// SandboxProxyWrapper wraps an existing HTTP handler and intercepts requests
// whose Host header matches the {port}-{sandbox_id}.{domain} pattern. Matching
// requests are reverse-proxied through the host agent that owns the sandbox.
// All other requests are passed through to the inner handler.
//
// Authentication is via X-API-Key header only (no JWT). The API key's team
// must own the sandbox.
// No authentication is required — sandbox URLs are unguessable and access is
// scoped to the sandbox ID embedded in the hostname.
type SandboxProxyWrapper struct {
inner http.Handler
db *db.Queries
@ -76,7 +65,7 @@ type SandboxProxyWrapper struct {
transport http.RoundTripper
cacheMu sync.Mutex
cache map[proxyCacheKey]proxyCacheEntry
cache map[pgtype.UUID]proxyCacheEntry
}
// NewSandboxProxyWrapper creates a new proxy wrapper.
@ -86,19 +75,15 @@ func NewSandboxProxyWrapper(inner http.Handler, queries *db.Queries, pool *lifec
db: queries,
pool: pool,
transport: pool.Transport(),
cache: make(map[proxyCacheKey]proxyCacheEntry),
cache: make(map[pgtype.UUID]proxyCacheEntry),
}
}
// proxyTarget looks up the cached agent URL for (sandboxID, teamID).
// proxyTarget looks up the cached agent URL for sandboxID.
// On a miss it queries the DB, resolves the address, and populates the cache.
// The *httputil.ReverseProxy is built by the caller so the Director closure
// captures the correct port without the cache key needing to include it.
func (h *SandboxProxyWrapper) proxyTarget(ctx context.Context, sandboxID, teamID pgtype.UUID) (*url.URL, error) {
cacheKey := makeProxyCacheKey(sandboxID, teamID)
func (h *SandboxProxyWrapper) proxyTarget(ctx context.Context, sandboxID pgtype.UUID) (*url.URL, error) {
h.cacheMu.Lock()
entry, ok := h.cache[cacheKey]
entry, ok := h.cache[sandboxID]
h.cacheMu.Unlock()
if ok && time.Now().Before(entry.expiresAt) {
@ -106,10 +91,7 @@ func (h *SandboxProxyWrapper) proxyTarget(ctx context.Context, sandboxID, teamID
}
// Cache miss or expired — query DB.
target, err := h.db.GetSandboxProxyTarget(ctx, db.GetSandboxProxyTargetParams{
ID: sandboxID,
TeamID: teamID,
})
target, err := h.db.GetSandboxProxyTarget(ctx, sandboxID)
if err != nil {
return nil, errProxySandboxNotFound
}
@ -126,7 +108,7 @@ func (h *SandboxProxyWrapper) proxyTarget(ctx context.Context, sandboxID, teamID
}
h.cacheMu.Lock()
h.cache[cacheKey] = proxyCacheEntry{
h.cache[sandboxID] = proxyCacheEntry{
agentURL: agentURL,
expiresAt: time.Now().Add(proxyCacheTTL),
}
@ -135,11 +117,11 @@ func (h *SandboxProxyWrapper) proxyTarget(ctx context.Context, sandboxID, teamID
return agentURL, nil
}
// evictProxyCache removes the cached entry for a (sandbox, team) pair.
// evictProxyCache removes the cached entry for a sandbox.
// Called on 502 so a stopped/moved sandbox is re-resolved on the next request.
func (h *SandboxProxyWrapper) evictProxyCache(sandboxID, teamID pgtype.UUID) {
func (h *SandboxProxyWrapper) evictProxyCache(sandboxID pgtype.UUID) {
h.cacheMu.Lock()
delete(h.cache, makeProxyCacheKey(sandboxID, teamID))
delete(h.cache, sandboxID)
h.cacheMu.Unlock()
}
@ -166,20 +148,13 @@ func (h *SandboxProxyWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request)
return
}
// Authenticate: require API key or JWT, extract team ID.
teamID, err := h.authenticateRequest(r)
if err != nil {
writeError(w, http.StatusUnauthorized, "unauthorized", err.Error())
return
}
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
http.Error(w, "invalid sandbox ID", http.StatusBadRequest)
return
}
agentURL, err := h.proxyTarget(r.Context(), sandboxID, teamID)
agentURL, err := h.proxyTarget(r.Context(), sandboxID)
if err != nil {
switch {
case errors.Is(err, errProxySandboxNotFound):
@ -206,25 +181,9 @@ func (h *SandboxProxyWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request)
"port", port,
"error", err,
)
h.evictProxyCache(sandboxID, teamID)
h.evictProxyCache(sandboxID)
http.Error(w, "proxy error: "+err.Error(), http.StatusBadGateway)
},
}
proxy.ServeHTTP(w, r)
}
// authenticateRequest validates the request's API key and returns the team ID.
// Only API key authentication is supported for sandbox proxy requests (not JWT).
func (h *SandboxProxyWrapper) authenticateRequest(r *http.Request) (pgtype.UUID, error) {
key := r.Header.Get("X-API-Key")
if key == "" {
return pgtype.UUID{}, fmt.Errorf("X-API-Key header required")
}
hash := auth.HashAPIKey(key)
row, err := h.db.GetAPIKeyByHash(r.Context(), hash)
if err != nil {
return pgtype.UUID{}, fmt.Errorf("invalid API key")
}
return row.TeamID, nil
}

View File

@ -0,0 +1,248 @@
package api
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"net/http"
"time"
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
"git.omukk.dev/wrenn/wrenn/pkg/audit"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
"git.omukk.dev/wrenn/wrenn/pkg/service"
"git.omukk.dev/wrenn/wrenn/pkg/validate"
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
)
type adminCapsuleHandler struct {
svc *service.SandboxService
db *db.Queries
pool *lifecycle.HostClientPool
audit *audit.AuditLogger
}
func newAdminCapsuleHandler(svc *service.SandboxService, db *db.Queries, pool *lifecycle.HostClientPool, al *audit.AuditLogger) *adminCapsuleHandler {
return &adminCapsuleHandler{svc: svc, db: db, pool: pool, audit: al}
}
// Create handles POST /v1/admin/capsules.
func (h *adminCapsuleHandler) Create(w http.ResponseWriter, r *http.Request) {
var req createSandboxRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
return
}
ac := auth.MustFromContext(r.Context())
sb, err := h.svc.Create(r.Context(), service.SandboxCreateParams{
TeamID: id.PlatformTeamID,
Template: req.Template,
VCPUs: req.VCPUs,
MemoryMB: req.MemoryMB,
TimeoutSec: req.TimeoutSec,
})
if err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
h.audit.LogSandboxCreate(r.Context(), ac, sb.ID, sb.Template)
writeJSON(w, http.StatusCreated, sandboxToResponse(sb))
}
// List handles GET /v1/admin/capsules.
func (h *adminCapsuleHandler) List(w http.ResponseWriter, r *http.Request) {
sandboxes, err := h.svc.List(r.Context(), id.PlatformTeamID)
if err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to list sandboxes")
return
}
resp := make([]sandboxResponse, len(sandboxes))
for i, sb := range sandboxes {
resp[i] = sandboxToResponse(sb)
}
writeJSON(w, http.StatusOK, resp)
}
// Get handles GET /v1/admin/capsules/{id}.
func (h *adminCapsuleHandler) Get(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.svc.Get(r.Context(), sandboxID, id.PlatformTeamID)
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
writeJSON(w, http.StatusOK, sandboxToResponse(sb))
}
// Destroy handles DELETE /v1/admin/capsules/{id}.
func (h *adminCapsuleHandler) Destroy(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
if err := h.svc.Destroy(r.Context(), sandboxID, id.PlatformTeamID); err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
h.audit.LogSandboxDestroy(r.Context(), ac, sandboxID)
w.WriteHeader(http.StatusNoContent)
}
type adminSnapshotRequest struct {
Name string `json:"name"`
}
// Snapshot handles POST /v1/admin/capsules/{id}/snapshot.
// Pauses the capsule, takes a snapshot as a platform template, then destroys the capsule.
func (h *adminCapsuleHandler) Snapshot(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
var req adminSnapshotRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
return
}
if req.Name == "" {
req.Name = id.NewSnapshotName()
}
if err := validate.SafeName(req.Name); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", fmt.Sprintf("invalid snapshot name: %s", err))
return
}
ctx := r.Context()
// Verify sandbox exists and belongs to platform team BEFORE any
// destructive operations (template overwrite).
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: id.PlatformTeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" && sb.Status != "paused" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox must be running or paused")
return
}
// Check if name already exists as a platform template.
if existing, err := h.db.GetPlatformTemplateByName(ctx, req.Name); err == nil {
// Delete old snapshot files from all hosts before removing the DB record.
if err := deleteSnapshotBroadcast(ctx, h.db, h.pool, existing.TeamID, existing.ID); err != nil {
writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete existing snapshot files")
return
}
if err := h.db.DeleteTemplateByTeam(ctx, db.DeleteTemplateByTeamParams{Name: req.Name, TeamID: id.PlatformTeamID}); err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to remove existing template record")
return
}
}
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
return
}
// Pre-mark sandbox as "paused" to prevent the reconciler from racing.
if sb.Status == "running" {
if _, err := h.db.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
ID: sandboxID, Status: "paused",
}); err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to update sandbox status")
return
}
}
// Use a detached context so the snapshot completes even if the client disconnects.
snapCtx, snapCancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer snapCancel()
newTemplateID := id.NewTemplateID()
resp, err := agent.CreateSnapshot(snapCtx, connect.NewRequest(&pb.CreateSnapshotRequest{
SandboxId: sandboxIDStr,
Name: req.Name,
TeamId: formatUUIDForRPC(id.PlatformTeamID),
TemplateId: formatUUIDForRPC(newTemplateID),
}))
if err != nil {
// Snapshot failed — revert status.
if sb.Status == "running" {
if _, dbErr := h.db.UpdateSandboxStatus(snapCtx, db.UpdateSandboxStatusParams{
ID: sandboxID, Status: "running",
}); dbErr != nil {
slog.Error("failed to revert sandbox status after snapshot error", "sandbox_id", sandboxIDStr, "error", dbErr)
}
}
status, code, msg := agentErrToHTTP(err)
writeError(w, status, code, msg)
return
}
tmpl, err := h.db.InsertTemplate(snapCtx, db.InsertTemplateParams{
ID: newTemplateID,
Name: req.Name,
Type: "snapshot",
Vcpus: sb.Vcpus,
MemoryMb: sb.MemoryMb,
SizeBytes: resp.Msg.SizeBytes,
TeamID: id.PlatformTeamID,
DefaultUser: "root",
DefaultEnv: []byte("{}"),
Metadata: sb.Metadata,
})
if err != nil {
slog.Error("failed to insert template record", "name", req.Name, "error", err)
writeError(w, http.StatusInternalServerError, "db_error", "snapshot created but failed to record in database")
return
}
// Destroy the ephemeral capsule after successful snapshot.
if err := h.svc.Destroy(snapCtx, sandboxID, id.PlatformTeamID); err != nil {
slog.Error("failed to destroy capsule after snapshot", "sandbox_id", sandboxIDStr, "error", err)
// Don't fail the response — the snapshot was created successfully.
}
h.audit.LogSnapshotCreate(snapCtx, ac, req.Name)
if ctx.Err() != nil {
slog.Info("snapshot created but client disconnected before response", "name", req.Name)
return
}
writeJSON(w, http.StatusCreated, templateToResponse(tmpl))
}

View File

@ -6,11 +6,11 @@ import (
"github.com/go-chi/chi/v5"
"git.omukk.dev/wrenn/wrenn/internal/audit"
"git.omukk.dev/wrenn/wrenn/internal/auth"
"git.omukk.dev/wrenn/wrenn/internal/db"
"git.omukk.dev/wrenn/wrenn/internal/id"
"git.omukk.dev/wrenn/wrenn/internal/service"
"git.omukk.dev/wrenn/wrenn/pkg/audit"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/service"
)
type apiKeyHandler struct {
@ -63,7 +63,7 @@ func apiKeyWithCreatorToResponse(k db.ListAPIKeysByTeamWithCreatorRow) apiKeyRes
Name: k.Name,
KeyPrefix: k.KeyPrefix,
CreatedBy: id.FormatUserID(k.CreatedBy),
CreatorEmail: k.CreatorEmail,
CreatorEmail: k.CreatorEmail.String,
}
if k.CreatedAt.Valid {
resp.CreatedAt = k.CreatedAt.Time.Format(time.RFC3339)

View File

@ -8,9 +8,9 @@ import (
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/internal/auth"
"git.omukk.dev/wrenn/wrenn/internal/id"
"git.omukk.dev/wrenn/wrenn/internal/service"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/service"
)
type auditHandler struct {

View File

@ -2,19 +2,32 @@ package api
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"log/slog"
"net/http"
"strings"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/redis/go-redis/v9"
"git.omukk.dev/wrenn/wrenn/internal/auth"
"git.omukk.dev/wrenn/wrenn/internal/db"
"git.omukk.dev/wrenn/wrenn/internal/id"
"git.omukk.dev/wrenn/wrenn/internal/email"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
const (
activationKeyPrefix = "wrenn:activation:"
activationTTL = 30 * time.Minute
signupCooldown = 30 * time.Minute
)
// loginTeam returns the team and role to stamp into a login JWT.
@ -52,18 +65,89 @@ func loginTeam(ctx context.Context, q *db.Queries, userID pgtype.UUID) (db.Team,
}, first.Role, nil
}
// ensureDefaultTeam creates a default team for a user if they have none.
// This happens on first login after activation or for edge cases where a user
// has no teams. Returns the team, role, and whether the user was set as admin.
func ensureDefaultTeam(ctx context.Context, qtx *db.Queries, pool *pgxpool.Pool, userID pgtype.UUID, userName string) (db.Team, string, bool, error) {
// Try existing teams first.
team, role, err := loginTeam(ctx, qtx, userID)
if err == nil {
return team, role, false, nil
}
if !errors.Is(err, pgx.ErrNoRows) {
return db.Team{}, "", false, err
}
// No teams — create default team in a transaction.
tx, err := pool.Begin(ctx)
if err != nil {
return db.Team{}, "", false, fmt.Errorf("begin tx: %w", err)
}
defer tx.Rollback(ctx) //nolint:errcheck
txq := qtx.WithTx(tx)
// First active user to have a team becomes admin.
activeCount, err := txq.CountActiveUsers(ctx)
if err != nil {
return db.Team{}, "", false, fmt.Errorf("count active users: %w", err)
}
isFirstUser := activeCount == 1 // only this user is active
teamID := id.NewTeamID()
teamRow, err := txq.InsertTeam(ctx, db.InsertTeamParams{
ID: teamID,
Name: userName + "'s Team",
Slug: id.NewTeamSlug(),
})
if err != nil {
return db.Team{}, "", false, fmt.Errorf("insert team: %w", err)
}
if err := txq.InsertTeamMember(ctx, db.InsertTeamMemberParams{
UserID: userID,
TeamID: teamID,
IsDefault: true,
Role: "owner",
}); err != nil {
return db.Team{}, "", false, fmt.Errorf("insert team member: %w", err)
}
if isFirstUser {
if err := txq.SetUserAdmin(ctx, db.SetUserAdminParams{ID: userID, IsAdmin: true}); err != nil {
return db.Team{}, "", false, fmt.Errorf("set admin: %w", err)
}
}
if err := tx.Commit(ctx); err != nil {
return db.Team{}, "", false, fmt.Errorf("commit: %w", err)
}
return db.Team{
ID: teamRow.ID,
Name: teamRow.Name,
Slug: teamRow.Slug,
IsByoc: teamRow.IsByoc,
CreatedAt: teamRow.CreatedAt,
DeletedAt: teamRow.DeletedAt,
}, "owner", isFirstUser, nil
}
type switchTeamRequest struct {
TeamID string `json:"team_id"`
}
type authHandler struct {
db *db.Queries
pool *pgxpool.Pool
jwtSecret []byte
db *db.Queries
pool *pgxpool.Pool
jwtSecret []byte
mailer email.Mailer
rdb *redis.Client
redirectURL string
}
func newAuthHandler(db *db.Queries, pool *pgxpool.Pool, jwtSecret []byte) *authHandler {
return &authHandler{db: db, pool: pool, jwtSecret: jwtSecret}
func newAuthHandler(db *db.Queries, pool *pgxpool.Pool, jwtSecret []byte, mailer email.Mailer, rdb *redis.Client, redirectURL string) *authHandler {
return &authHandler{db: db, pool: pool, jwtSecret: jwtSecret, mailer: mailer, rdb: rdb, redirectURL: strings.TrimRight(redirectURL, "/")}
}
type signupRequest struct {
@ -77,6 +161,10 @@ type loginRequest struct {
Password string `json:"password"`
}
type activateRequest struct {
Token string `json:"token"`
}
type authResponse struct {
Token string `json:"token"`
UserID string `json:"user_id"`
@ -85,6 +173,10 @@ type authResponse struct {
Name string `json:"name"`
}
type signupResponse struct {
Message string `json:"message"`
}
// Signup handles POST /v1/auth/signup.
func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) {
var req signupRequest
@ -110,24 +202,41 @@ func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
// Check for existing user with this email.
existing, err := h.db.GetUserByEmail(ctx, req.Email)
if err == nil {
// User exists — decide what to do based on status.
switch existing.Status {
case "inactive":
// Unactivated user — allow re-signup after cooldown.
if time.Since(existing.CreatedAt.Time) < signupCooldown {
writeError(w, http.StatusConflict, "signup_cooldown",
"an activation email was recently sent to this address — please check your inbox or try again later")
return
}
// Cooldown passed — delete the old row and proceed with fresh signup.
if err := h.db.HardDeleteUser(ctx, existing.ID); err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to clean up previous signup")
return
}
default:
// active, disabled, deleted — email is taken.
writeError(w, http.StatusConflict, "email_taken", "an account with this email already exists")
return
}
} else if !errors.Is(err, pgx.ErrNoRows) {
writeError(w, http.StatusInternalServerError, "db_error", "failed to look up user")
return
}
passwordHash, err := auth.HashPassword(req.Password)
if err != nil {
writeError(w, http.StatusInternalServerError, "internal_error", "failed to hash password")
return
}
// Use a transaction to atomically create user + team + membership.
tx, err := h.pool.Begin(ctx)
if err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to begin transaction")
return
}
defer tx.Rollback(ctx) //nolint:errcheck
qtx := h.db.WithTx(tx)
userID := id.NewUserID()
_, err = qtx.InsertUser(ctx, db.InsertUserParams{
_, err = h.db.InsertUserInactive(ctx, db.InsertUserInactiveParams{
ID: userID,
Email: req.Email,
PasswordHash: pgtype.Text{String: passwordHash, Valid: true},
@ -143,44 +252,111 @@ func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) {
return
}
// Create default team.
teamID := id.NewTeamID()
if _, err := qtx.InsertTeam(ctx, db.InsertTeamParams{
ID: teamID,
Name: req.Name + "'s Team",
Slug: id.NewTeamSlug(),
// Generate activation token and store in Redis.
rawToken := generateActivationToken()
tokenHash := hashActivationToken(rawToken)
redisKey := activationKeyPrefix + tokenHash
if err := h.rdb.Set(ctx, redisKey, id.FormatUserID(userID), activationTTL).Err(); err != nil {
slog.Error("signup: failed to store activation token in redis", "error", err)
writeError(w, http.StatusInternalServerError, "internal_error", "failed to create activation token")
return
}
activateURL := h.redirectURL + "/activate?token=" + rawToken
go func() {
sendCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := h.mailer.Send(sendCtx, req.Email, "Activate your Wrenn account", email.EmailData{
RecipientName: req.Name,
Message: "Welcome to Wrenn! Click the button below to activate your account. This link expires in 30 minutes.",
Button: &email.Button{Text: "Activate Account", URL: activateURL},
Closing: "If you didn't create this account, you can safely ignore this email.",
}); err != nil {
slog.Warn("signup: failed to send activation email", "email", req.Email, "error", err)
}
}()
writeJSON(w, http.StatusCreated, signupResponse{
Message: "Account created. Please check your email to activate your account.",
})
}
// Activate handles POST /v1/auth/activate.
func (h *authHandler) Activate(w http.ResponseWriter, r *http.Request) {
var req activateRequest
if err := decodeJSON(r, &req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
return
}
if req.Token == "" {
writeError(w, http.StatusBadRequest, "invalid_request", "token is required")
return
}
ctx := r.Context()
tokenHash := hashActivationToken(req.Token)
redisKey := activationKeyPrefix + tokenHash
userIDStr, err := h.rdb.GetDel(ctx, redisKey).Result()
if errors.Is(err, redis.Nil) {
writeError(w, http.StatusBadRequest, "invalid_token", "activation link is invalid or has expired")
return
}
if err != nil {
writeError(w, http.StatusInternalServerError, "internal_error", "failed to verify token")
return
}
userID, err := id.ParseUserID(userIDStr)
if err != nil {
writeError(w, http.StatusInternalServerError, "internal_error", "invalid stored user ID")
return
}
user, err := h.db.GetUserByID(ctx, userID)
if err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to get user")
return
}
if user.Status != "inactive" {
writeError(w, http.StatusBadRequest, "already_activated", "this account has already been activated")
return
}
// Activate the user.
if err := h.db.SetUserStatus(ctx, db.SetUserStatusParams{
ID: userID,
Status: "active",
}); err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to create team")
slog.Error("activate: failed to set user status", "user_id", id.FormatUserID(userID), "error", err)
writeError(w, http.StatusInternalServerError, "db_error", "failed to activate user")
return
}
if err := qtx.InsertTeamMember(ctx, db.InsertTeamMemberParams{
UserID: userID,
TeamID: teamID,
IsDefault: true,
Role: "owner",
}); err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to add user to team")
// Create default team and log them in.
team, role, isFirstUser, err := ensureDefaultTeam(ctx, h.db, h.pool, userID, user.Name)
if err != nil {
slog.Error("activate: failed to create default team", "error", err)
writeError(w, http.StatusInternalServerError, "db_error", "failed to set up account")
return
}
if err := tx.Commit(ctx); err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to commit signup")
return
}
token, err := auth.SignJWT(h.jwtSecret, userID, teamID, req.Email, req.Name, "owner", false)
isAdmin := user.IsAdmin || isFirstUser
token, err := auth.SignJWT(h.jwtSecret, userID, team.ID, user.Email, user.Name, role, isAdmin)
if err != nil {
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
return
}
writeJSON(w, http.StatusCreated, authResponse{
writeJSON(w, http.StatusOK, authResponse{
Token: token,
UserID: id.FormatUserID(userID),
TeamID: id.FormatTeamID(teamID),
Email: req.Email,
Name: req.Name,
TeamID: id.FormatTeamID(team.ID),
Email: user.Email,
Name: user.Name,
})
}
@ -222,17 +398,36 @@ func (h *authHandler) Login(w http.ResponseWriter, r *http.Request) {
return
}
team, role, err := loginTeam(ctx, h.db, user.ID)
switch user.Status {
case "active":
// OK — proceed.
case "inactive":
slog.Warn("login failed: account not activated", "email", req.Email, "ip", r.RemoteAddr)
writeError(w, http.StatusForbidden, "account_not_activated", "please check your email and activate your account before signing in")
return
case "disabled":
slog.Warn("login failed: account disabled", "email", req.Email, "ip", r.RemoteAddr)
writeError(w, http.StatusForbidden, "account_disabled", "your account has been deactivated — contact your administrator to regain access")
return
case "deleted":
slog.Warn("login failed: account deleted", "email", req.Email, "ip", r.RemoteAddr)
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid email or password")
return
default:
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid email or password")
return
}
// Ensure user has a default team (creates one on first login after activation).
team, role, isFirstUser, err := ensureDefaultTeam(ctx, h.db, h.pool, user.ID, user.Name)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
writeError(w, http.StatusForbidden, "no_team", "user is not a member of any team")
return
}
slog.Error("login: failed to ensure default team", "error", err)
writeError(w, http.StatusInternalServerError, "db_error", "failed to look up team")
return
}
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, user.IsAdmin)
isAdmin := user.IsAdmin || isFirstUser
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, isAdmin)
if err != nil {
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
return
@ -322,3 +517,18 @@ func (h *authHandler) SwitchTeam(w http.ResponseWriter, r *http.Request) {
Name: user.Name,
})
}
// --- helpers ---
func generateActivationToken() string {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
panic(fmt.Sprintf("crypto/rand failed: %v", err))
}
return hex.EncodeToString(b)
}
func hashActivationToken(raw string) string {
h := sha256.Sum256([]byte(raw))
return hex.EncodeToString(h[:])
}

View File

@ -3,19 +3,21 @@ package api
import (
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"strings"
"time"
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
"git.omukk.dev/wrenn/wrenn/internal/db"
"git.omukk.dev/wrenn/wrenn/internal/id"
"git.omukk.dev/wrenn/wrenn/internal/layout"
"git.omukk.dev/wrenn/wrenn/internal/lifecycle"
"git.omukk.dev/wrenn/wrenn/internal/service"
"git.omukk.dev/wrenn/wrenn/internal/validate"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
"git.omukk.dev/wrenn/wrenn/pkg/service"
"git.omukk.dev/wrenn/wrenn/pkg/validate"
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
)
@ -54,6 +56,8 @@ type buildResponse struct {
Error *string `json:"error,omitempty"`
SandboxID *string `json:"sandbox_id,omitempty"`
HostID *string `json:"host_id,omitempty"`
DefaultUser string `json:"default_user"`
DefaultEnv json.RawMessage `json:"default_env"`
CreatedAt string `json:"created_at"`
StartedAt *string `json:"started_at,omitempty"`
CompletedAt *string `json:"completed_at,omitempty"`
@ -71,6 +75,8 @@ func buildToResponse(b db.TemplateBuild) buildResponse {
CurrentStep: b.CurrentStep,
TotalSteps: b.TotalSteps,
Logs: b.Logs,
DefaultUser: b.DefaultUser,
DefaultEnv: b.DefaultEnv,
}
if b.Healthcheck != "" {
resp.Healthcheck = &b.Healthcheck
@ -101,11 +107,54 @@ func buildToResponse(b db.TemplateBuild) buildResponse {
}
// Create handles POST /v1/admin/builds.
// Accepts either JSON body or multipart/form-data with a "config" JSON part
// and an optional "archive" file part (tar/tar.gz/zip for COPY commands).
func (h *buildHandler) Create(w http.ResponseWriter, r *http.Request) {
var req createBuildRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
return
var archive []byte
var archiveName string
ct := r.Header.Get("Content-Type")
if strings.HasPrefix(ct, "multipart/") {
// 100 MB max for multipart (archive + JSON config).
if err := r.ParseMultipartForm(100 << 20); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "failed to parse multipart form")
return
}
// Parse JSON config from "config" field.
configStr := r.FormValue("config")
if configStr == "" {
writeError(w, http.StatusBadRequest, "invalid_request", "multipart form requires a 'config' JSON field")
return
}
if err := json.Unmarshal([]byte(configStr), &req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid config JSON in multipart form")
return
}
// Read optional archive file (max 100 MB).
file, header, err := r.FormFile("archive")
if err == nil {
defer file.Close()
const maxArchiveSize = 100 << 20 // 100 MB
lr := io.LimitReader(file, maxArchiveSize+1)
archive, err = io.ReadAll(lr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "failed to read archive file")
return
}
if int64(len(archive)) > maxArchiveSize {
writeError(w, http.StatusRequestEntityTooLarge, "invalid_request", "archive exceeds 100 MB limit")
return
}
archiveName = header.Filename
}
} else {
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
return
}
}
if req.Name == "" {
@ -129,6 +178,8 @@ func (h *buildHandler) Create(w http.ResponseWriter, r *http.Request) {
VCPUs: req.VCPUs,
MemoryMB: req.MemoryMB,
SkipPrePost: req.SkipPrePost,
Archive: archive,
ArchiveName: archiveName,
})
if err != nil {
slog.Error("failed to create build", "error", err)

View File

@ -8,11 +8,11 @@ import (
"github.com/go-chi/chi/v5"
"github.com/jackc/pgx/v5"
"git.omukk.dev/wrenn/wrenn/internal/audit"
"git.omukk.dev/wrenn/wrenn/internal/auth"
"git.omukk.dev/wrenn/wrenn/internal/channels"
"git.omukk.dev/wrenn/wrenn/internal/db"
"git.omukk.dev/wrenn/wrenn/internal/id"
"git.omukk.dev/wrenn/wrenn/pkg/audit"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/channels"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
type channelHandler struct {

View File

@ -12,10 +12,10 @@ import (
"github.com/go-chi/chi/v5"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/internal/auth"
"git.omukk.dev/wrenn/wrenn/internal/db"
"git.omukk.dev/wrenn/wrenn/internal/id"
"git.omukk.dev/wrenn/wrenn/internal/lifecycle"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
)
@ -29,9 +29,13 @@ func newExecHandler(db *db.Queries, pool *lifecycle.HostClientPool) *execHandler
}
type execRequest struct {
Cmd string `json:"cmd"`
Args []string `json:"args"`
TimeoutSec int32 `json:"timeout_sec"`
Cmd string `json:"cmd"`
Args []string `json:"args"`
TimeoutSec int32 `json:"timeout_sec"`
Background bool `json:"background"`
Tag string `json:"tag"`
Envs map[string]string `json:"envs"`
Cwd string `json:"cwd"`
}
type execResponse struct {
@ -45,7 +49,14 @@ type execResponse struct {
Encoding string `json:"encoding"`
}
// Exec handles POST /v1/sandboxes/{id}/exec.
type backgroundExecResponse struct {
SandboxID string `json:"sandbox_id"`
Cmd string `json:"cmd"`
PID uint32 `json:"pid"`
Tag string `json:"tag"`
}
// Exec handles POST /v1/capsules/{id}/exec.
func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
@ -78,14 +89,54 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
return
}
start := time.Now()
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
return
}
// Background mode: start process and return immediately.
if req.Background {
tag := req.Tag
if tag == "" {
tag = "proc-" + id.NewPtyTag()
}
bgResp, err := agent.StartBackground(ctx, connect.NewRequest(&pb.StartBackgroundRequest{
SandboxId: sandboxIDStr,
Tag: tag,
Cmd: req.Cmd,
Args: req.Args,
Envs: req.Envs,
Cwd: req.Cwd,
}))
if err != nil {
status, code, msg := agentErrToHTTP(err)
writeError(w, status, code, msg)
return
}
if err := h.db.UpdateLastActive(ctx, db.UpdateLastActiveParams{
ID: sandboxID,
LastActiveAt: pgtype.Timestamptz{
Time: time.Now(),
Valid: true,
},
}); err != nil {
slog.Warn("failed to update last_active_at", "id", sandboxIDStr, "error", err)
}
writeJSON(w, http.StatusAccepted, backgroundExecResponse{
SandboxID: sandboxIDStr,
Cmd: req.Cmd,
PID: bgResp.Msg.Pid,
Tag: bgResp.Msg.Tag,
})
return
}
start := time.Now()
resp, err := agent.Exec(ctx, connect.NewRequest(&pb.ExecRequest{
SandboxId: sandboxIDStr,
Cmd: req.Cmd,

View File

@ -12,20 +12,21 @@ import (
"github.com/gorilla/websocket"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/internal/auth"
"git.omukk.dev/wrenn/wrenn/internal/db"
"git.omukk.dev/wrenn/wrenn/internal/id"
"git.omukk.dev/wrenn/wrenn/internal/lifecycle"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
)
type execStreamHandler struct {
db *db.Queries
pool *lifecycle.HostClientPool
db *db.Queries
pool *lifecycle.HostClientPool
jwtSecret []byte
}
func newExecStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool) *execStreamHandler {
return &execStreamHandler{db: db, pool: pool}
func newExecStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool, jwtSecret []byte) *execStreamHandler {
return &execStreamHandler{db: db, pool: pool, jwtSecret: jwtSecret}
}
var upgrader = websocket.Upgrader{
@ -47,11 +48,10 @@ type wsOutMsg struct {
ExitCode *int32 `json:"exit_code,omitempty"` // only for "exit"
}
// ExecStream handles WS /v1/sandboxes/{id}/exec/stream.
// ExecStream handles WS /v1/capsules/{id}/exec/stream.
func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
@ -59,13 +59,31 @@ func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
// Authenticate: use context from middleware (API key) or WS first message (JWT).
ac, hasAuth := auth.FromContext(ctx)
if !hasAuth {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
slog.Error("websocket upgrade failed", "error", err)
return
}
defer conn.Close()
var wsAC auth.AuthContext
var authErr error
if isAdminWSRoute(ctx) {
wsAC, authErr = wsAuthenticateAdmin(ctx, conn, h.jwtSecret, h.db)
} else {
wsAC, authErr = wsAuthenticate(ctx, conn, h.jwtSecret, h.db)
}
if authErr != nil {
sendWSError(conn, "authentication failed")
return
}
ac = wsAC
h.runExecStream(ctx, conn, ac, sandboxID, sandboxIDStr)
return
}
@ -76,6 +94,20 @@ func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
}
defer conn.Close()
h.runExecStream(ctx, conn, ac, sandboxID, sandboxIDStr)
}
func (h *execStreamHandler) runExecStream(ctx context.Context, conn *websocket.Conn, ac auth.AuthContext, sandboxID pgtype.UUID, sandboxIDStr string) {
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
sendWSError(conn, "sandbox not found")
return
}
if sb.Status != "running" {
sendWSError(conn, "sandbox is not running (status: "+sb.Status+")")
return
}
// Read the start message.
var startMsg wsStartMsg
if err := conn.ReadJSON(&startMsg); err != nil {

View File

@ -9,10 +9,10 @@ import (
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
"git.omukk.dev/wrenn/wrenn/internal/auth"
"git.omukk.dev/wrenn/wrenn/internal/db"
"git.omukk.dev/wrenn/wrenn/internal/id"
"git.omukk.dev/wrenn/wrenn/internal/lifecycle"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
)
@ -25,7 +25,7 @@ func newFilesHandler(db *db.Queries, pool *lifecycle.HostClientPool) *filesHandl
return &filesHandler{db: db, pool: pool}
}
// Upload handles POST /v1/sandboxes/{id}/files/write.
// Upload handles POST /v1/capsules/{id}/files/write.
// Expects multipart/form-data with:
// - "path" text field: absolute destination path inside the sandbox
// - "file" file field: binary content to write
@ -105,7 +105,7 @@ type readFileRequest struct {
Path string `json:"path"`
}
// Download handles POST /v1/sandboxes/{id}/files/read.
// Download handles POST /v1/capsules/{id}/files/read.
// Accepts JSON body with path, returns raw file content with Content-Disposition.
func (h *filesHandler) Download(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")

View File

@ -10,10 +10,10 @@ import (
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
"git.omukk.dev/wrenn/wrenn/internal/auth"
"git.omukk.dev/wrenn/wrenn/internal/db"
"git.omukk.dev/wrenn/wrenn/internal/id"
"git.omukk.dev/wrenn/wrenn/internal/lifecycle"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
)
@ -26,7 +26,7 @@ func newFilesStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool) *file
return &filesStreamHandler{db: db, pool: pool}
}
// StreamUpload handles POST /v1/sandboxes/{id}/files/stream/write.
// StreamUpload handles POST /v1/capsules/{id}/files/stream/write.
// Expects multipart/form-data with "path" text field and "file" file field.
// Streams file content directly from the request body to the host agent without buffering.
func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request) {
@ -150,7 +150,7 @@ func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request
w.WriteHeader(http.StatusNoContent)
}
// StreamDownload handles POST /v1/sandboxes/{id}/files/stream/read.
// StreamDownload handles POST /v1/capsules/{id}/files/stream/read.
// Accepts JSON body with path, streams file content back without buffering.
func (h *filesStreamHandler) StreamDownload(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")

236
internal/api/handlers_fs.go Normal file
View File

@ -0,0 +1,236 @@
package api
import (
"net/http"
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
)
type fsHandler struct {
db *db.Queries
pool *lifecycle.HostClientPool
}
func newFSHandler(db *db.Queries, pool *lifecycle.HostClientPool) *fsHandler {
return &fsHandler{db: db, pool: pool}
}
type listDirRequest struct {
Path string `json:"path"`
Depth uint32 `json:"depth"`
}
type fileEntryResponse struct {
Name string `json:"name"`
Path string `json:"path"`
Type string `json:"type"`
Size int64 `json:"size"`
Mode uint32 `json:"mode"`
Permissions string `json:"permissions"`
Owner string `json:"owner"`
Group string `json:"group"`
ModifiedAt int64 `json:"modified_at"`
SymlinkTarget *string `json:"symlink_target,omitempty"`
}
type listDirResponse struct {
Entries []fileEntryResponse `json:"entries"`
}
type makeDirRequest struct {
Path string `json:"path"`
}
type makeDirResponse struct {
Entry fileEntryResponse `json:"entry"`
}
type removeRequest struct {
Path string `json:"path"`
}
// ListDir handles POST /v1/capsules/{id}/files/list.
func (h *fsHandler) ListDir(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
return
}
var req listDirRequest
if err := decodeJSON(r, &req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
return
}
if req.Path == "" {
writeError(w, http.StatusBadRequest, "invalid_request", "path is required")
return
}
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
return
}
resp, err := agent.ListDir(ctx, connect.NewRequest(&pb.ListDirRequest{
SandboxId: sandboxIDStr,
Path: req.Path,
Depth: req.Depth,
}))
if err != nil {
status, code, msg := agentErrToHTTP(err)
writeError(w, status, code, msg)
return
}
entries := make([]fileEntryResponse, 0, len(resp.Msg.Entries))
for _, e := range resp.Msg.Entries {
entries = append(entries, fileEntryFromPB(e))
}
writeJSON(w, http.StatusOK, listDirResponse{Entries: entries})
}
// MakeDir handles POST /v1/capsules/{id}/files/mkdir.
func (h *fsHandler) MakeDir(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
return
}
var req makeDirRequest
if err := decodeJSON(r, &req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
return
}
if req.Path == "" {
writeError(w, http.StatusBadRequest, "invalid_request", "path is required")
return
}
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
return
}
resp, err := agent.MakeDir(ctx, connect.NewRequest(&pb.MakeDirRequest{
SandboxId: sandboxIDStr,
Path: req.Path,
}))
if err != nil {
status, code, msg := agentErrToHTTP(err)
writeError(w, status, code, msg)
return
}
writeJSON(w, http.StatusOK, makeDirResponse{Entry: fileEntryFromPB(resp.Msg.Entry)})
}
// Remove handles POST /v1/capsules/{id}/files/remove.
func (h *fsHandler) Remove(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
return
}
var req removeRequest
if err := decodeJSON(r, &req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
return
}
if req.Path == "" {
writeError(w, http.StatusBadRequest, "invalid_request", "path is required")
return
}
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
return
}
if _, err := agent.RemovePath(ctx, connect.NewRequest(&pb.RemovePathRequest{
SandboxId: sandboxIDStr,
Path: req.Path,
})); err != nil {
status, code, msg := agentErrToHTTP(err)
writeError(w, status, code, msg)
return
}
w.WriteHeader(http.StatusNoContent)
}
func fileEntryFromPB(e *pb.FileEntry) fileEntryResponse {
if e == nil {
return fileEntryResponse{}
}
resp := fileEntryResponse{
Name: e.Name,
Path: e.Path,
Type: e.Type,
Size: e.Size,
Mode: e.Mode,
Permissions: e.Permissions,
Owner: e.Owner,
Group: e.Group,
ModifiedAt: e.ModifiedAt,
}
if e.SymlinkTarget != nil {
resp.SymlinkTarget = e.SymlinkTarget
}
return resp
}

View File

@ -10,11 +10,11 @@ import (
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/internal/audit"
"git.omukk.dev/wrenn/wrenn/internal/auth"
"git.omukk.dev/wrenn/wrenn/internal/db"
"git.omukk.dev/wrenn/wrenn/internal/id"
"git.omukk.dev/wrenn/wrenn/internal/service"
"git.omukk.dev/wrenn/wrenn/pkg/audit"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/service"
)
type hostHandler struct {

585
internal/api/handlers_me.go Normal file
View File

@ -0,0 +1,585 @@
package api
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"log/slog"
"net/http"
"strings"
"time"
"github.com/go-chi/chi/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/redis/go-redis/v9"
"git.omukk.dev/wrenn/wrenn/internal/email"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/auth/oauth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/service"
)
const (
passwordResetKeyPrefix = "wrenn:password_reset:"
passwordResetTTL = 15 * time.Minute
)
type meHandler struct {
db *db.Queries
pool *pgxpool.Pool
rdb *redis.Client
jwtSecret []byte
mailer email.Mailer
oauthRegistry *oauth.Registry
redirectURL string
teamSvc *service.TeamService
}
func newMeHandler(
db *db.Queries,
pool *pgxpool.Pool,
rdb *redis.Client,
jwtSecret []byte,
mailer email.Mailer,
registry *oauth.Registry,
redirectURL string,
teamSvc *service.TeamService,
) *meHandler {
return &meHandler{
db: db,
pool: pool,
rdb: rdb,
jwtSecret: jwtSecret,
mailer: mailer,
oauthRegistry: registry,
redirectURL: strings.TrimRight(redirectURL, "/"),
teamSvc: teamSvc,
}
}
type meResponse struct {
Name string `json:"name"`
Email string `json:"email"`
HasPassword bool `json:"has_password"`
Providers []string `json:"providers"`
}
type updateNameRequest struct {
Name string `json:"name"`
}
type changePasswordRequest struct {
CurrentPassword string `json:"current_password"`
NewPassword string `json:"new_password"`
ConfirmPassword string `json:"confirm_password"`
}
type requestPasswordResetRequest struct {
Email string `json:"email"`
}
type confirmPasswordResetRequest struct {
Token string `json:"token"`
NewPassword string `json:"new_password"`
}
type deleteAccountRequest struct {
Confirmation string `json:"confirmation"`
}
// GetMe handles GET /v1/me.
func (h *meHandler) GetMe(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
ctx := r.Context()
user, err := h.db.GetUserByID(ctx, ac.UserID)
if err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to get user")
return
}
providers, err := h.db.GetOAuthProvidersByUserID(ctx, ac.UserID)
if err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to get providers")
return
}
providerNames := make([]string, 0, len(providers))
for _, p := range providers {
providerNames = append(providerNames, p.Provider)
}
writeJSON(w, http.StatusOK, meResponse{
Name: user.Name,
Email: user.Email,
HasPassword: user.PasswordHash.Valid,
Providers: providerNames,
})
}
// UpdateName handles PATCH /v1/me — updates the user's name and re-issues a JWT.
func (h *meHandler) UpdateName(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
ctx := r.Context()
var req updateNameRequest
if err := decodeJSON(r, &req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
return
}
req.Name = strings.TrimSpace(req.Name)
if req.Name == "" || len(req.Name) > 100 {
writeError(w, http.StatusBadRequest, "invalid_request", "name must be between 1 and 100 characters")
return
}
if err := h.db.UpdateUserName(ctx, db.UpdateUserNameParams{
ID: ac.UserID,
Name: req.Name,
}); err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to update name")
return
}
user, err := h.db.GetUserByID(ctx, ac.UserID)
if err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to get user")
return
}
team, role, err := loginTeam(ctx, h.db, ac.UserID)
if err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to get team")
return
}
token, err := auth.SignJWT(h.jwtSecret, ac.UserID, team.ID, user.Email, req.Name, role, user.IsAdmin)
if err != nil {
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
return
}
writeJSON(w, http.StatusOK, authResponse{
Token: token,
UserID: id.FormatUserID(ac.UserID),
TeamID: id.FormatTeamID(team.ID),
Email: user.Email,
Name: req.Name,
})
}
// ChangePassword handles POST /v1/me/password.
// For users with a password: requires current_password + new_password.
// For OAuth-only users: requires new_password + confirm_password.
func (h *meHandler) ChangePassword(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
ctx := r.Context()
var req changePasswordRequest
if err := decodeJSON(r, &req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
return
}
user, err := h.db.GetUserByID(ctx, ac.UserID)
if err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to get user")
return
}
if user.PasswordHash.Valid {
// Changing existing password — verify current.
if req.CurrentPassword == "" {
writeError(w, http.StatusBadRequest, "invalid_request", "current_password is required")
return
}
if err := auth.CheckPassword(user.PasswordHash.String, req.CurrentPassword); err != nil {
writeError(w, http.StatusUnauthorized, "wrong_password", "current password is incorrect")
return
}
} else {
// OAuth user adding a password — confirm must match.
if req.ConfirmPassword == "" {
writeError(w, http.StatusBadRequest, "invalid_request", "confirm_password is required")
return
}
if req.NewPassword != req.ConfirmPassword {
writeError(w, http.StatusBadRequest, "invalid_request", "passwords do not match")
return
}
}
if len(req.NewPassword) < 8 {
writeError(w, http.StatusBadRequest, "invalid_request", "password must be at least 8 characters")
return
}
hash, err := auth.HashPassword(req.NewPassword)
if err != nil {
writeError(w, http.StatusInternalServerError, "internal_error", "failed to hash password")
return
}
if err := h.db.UpdateUserPassword(ctx, db.UpdateUserPasswordParams{
ID: ac.UserID,
PasswordHash: pgtype.Text{String: hash, Valid: true},
}); err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to update password")
return
}
isAdding := !user.PasswordHash.Valid
go func() {
sendCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
subject, message := "Your Wrenn password was changed", "Your account password was successfully updated. If you did not make this change, reset your password immediately."
if isAdding {
subject = "Password added to your Wrenn account"
message = "A password has been added to your Wrenn account. You can now sign in with your email and password in addition to any connected OAuth providers."
}
if err := h.mailer.Send(sendCtx, user.Email, subject, email.EmailData{
RecipientName: user.Name,
Message: message,
Closing: "If you didn't make this change, contact support immediately.",
}); err != nil {
slog.Warn("change password: failed to send notification", "email", user.Email, "error", err)
}
}()
w.WriteHeader(http.StatusNoContent)
}
// RequestPasswordReset handles POST /v1/me/password/reset (unauthenticated).
// Always returns 200 to avoid leaking account existence.
func (h *meHandler) RequestPasswordReset(w http.ResponseWriter, r *http.Request) {
var req requestPasswordResetRequest
if err := decodeJSON(r, &req); err != nil {
w.WriteHeader(http.StatusNoContent)
return
}
req.Email = strings.TrimSpace(strings.ToLower(req.Email))
if req.Email == "" {
w.WriteHeader(http.StatusNoContent)
return
}
ctx := r.Context()
user, err := h.db.GetUserByEmail(ctx, req.Email)
if err != nil {
// Don't leak whether the email exists.
w.WriteHeader(http.StatusNoContent)
return
}
if user.Status != "active" {
w.WriteHeader(http.StatusNoContent)
return
}
rawToken := generateResetToken()
tokenHash := hashResetToken(rawToken)
redisKey := passwordResetKeyPrefix + tokenHash
if err := h.rdb.Set(ctx, redisKey, id.FormatUserID(user.ID), passwordResetTTL).Err(); err != nil {
slog.Error("password reset: failed to store token in redis", "error", err)
w.WriteHeader(http.StatusNoContent)
return
}
resetURL := h.redirectURL + "/reset-password?token=" + rawToken
go func() {
sendCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := h.mailer.Send(sendCtx, user.Email, "Reset your Wrenn password", email.EmailData{
RecipientName: user.Name,
Message: "We received a request to reset your password. Click the button below to set a new password. This link expires in 15 minutes.",
Button: &email.Button{Text: "Reset Password", URL: resetURL},
Closing: "If you didn't request a password reset, you can safely ignore this email.",
}); err != nil {
slog.Error("password reset: failed to send email", "email", user.Email, "error", err)
}
}()
w.WriteHeader(http.StatusNoContent)
}
// ConfirmPasswordReset handles POST /v1/me/password/reset/confirm (unauthenticated).
func (h *meHandler) ConfirmPasswordReset(w http.ResponseWriter, r *http.Request) {
var req confirmPasswordResetRequest
if err := decodeJSON(r, &req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
return
}
if req.Token == "" {
writeError(w, http.StatusBadRequest, "invalid_request", "token is required")
return
}
if len(req.NewPassword) < 8 {
writeError(w, http.StatusBadRequest, "invalid_request", "password must be at least 8 characters")
return
}
ctx := r.Context()
tokenHash := hashResetToken(req.Token)
redisKey := passwordResetKeyPrefix + tokenHash
// GetDel atomically retrieves and removes the token in a single round-trip,
// preventing concurrent requests from both consuming the same token.
userIDStr, err := h.rdb.GetDel(ctx, redisKey).Result()
if errors.Is(err, redis.Nil) {
writeError(w, http.StatusBadRequest, "invalid_token", "reset token is invalid or has expired")
return
}
if err != nil {
writeError(w, http.StatusInternalServerError, "internal_error", "failed to verify token")
return
}
userID, err := id.ParseUserID(userIDStr)
if err != nil {
writeError(w, http.StatusInternalServerError, "internal_error", "invalid stored user ID")
return
}
user, err := h.db.GetUserByID(ctx, userID)
if err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to get user")
return
}
hash, err := auth.HashPassword(req.NewPassword)
if err != nil {
writeError(w, http.StatusInternalServerError, "internal_error", "failed to hash password")
return
}
if err := h.db.UpdateUserPassword(ctx, db.UpdateUserPasswordParams{
ID: userID,
PasswordHash: pgtype.Text{String: hash, Valid: true},
}); err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to update password")
return
}
go func() {
sendCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := h.mailer.Send(sendCtx, user.Email, "Your Wrenn password was reset", email.EmailData{
RecipientName: user.Name,
Message: "Your password has been successfully reset. You can now sign in with your new password.",
Closing: "If you didn't request this change, contact support immediately.",
}); err != nil {
slog.Warn("confirm password reset: failed to send notification", "email", user.Email, "error", err)
}
}()
w.WriteHeader(http.StatusNoContent)
}
// ConnectProvider handles GET /v1/me/providers/{provider}/connect.
// Sets OAuth state + link cookies and returns the provider auth URL.
func (h *meHandler) ConnectProvider(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
provider := chi.URLParam(r, "provider")
p, ok := h.oauthRegistry.Get(provider)
if !ok {
writeError(w, http.StatusNotFound, "provider_not_found", "unsupported OAuth provider")
return
}
state, err := generateState()
if err != nil {
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate state")
return
}
mac := computeHMAC(h.jwtSecret, state)
http.SetCookie(w, &http.Cookie{
Name: "oauth_state",
Value: state + ":" + mac,
Path: "/",
MaxAge: 600,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
Secure: isSecure(r),
})
userIDStr := id.FormatUserID(ac.UserID)
linkMac := computeHMAC(h.jwtSecret, userIDStr)
http.SetCookie(w, &http.Cookie{
Name: "oauth_link_user_id",
Value: userIDStr + ":" + linkMac,
Path: "/",
MaxAge: 600,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
Secure: isSecure(r),
})
writeJSON(w, http.StatusOK, map[string]string{"auth_url": p.AuthCodeURL(state)})
}
// DisconnectProvider handles DELETE /v1/me/providers/{provider}.
func (h *meHandler) DisconnectProvider(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
provider := chi.URLParam(r, "provider")
ctx := r.Context()
user, err := h.db.GetUserByID(ctx, ac.UserID)
if err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to get user")
return
}
providers, err := h.db.GetOAuthProvidersByUserID(ctx, ac.UserID)
if err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to get providers")
return
}
// Ensure the user will still have at least one login method after disconnecting.
if !user.PasswordHash.Valid && len(providers) <= 1 {
writeError(w, http.StatusBadRequest, "last_login_method", "cannot disconnect your only login method — add a password first")
return
}
// Check the provider is actually linked to this user.
found := false
for _, p := range providers {
if p.Provider == provider {
found = true
break
}
}
if !found {
writeError(w, http.StatusNotFound, "not_found", "provider not connected")
return
}
if err := h.db.DeleteOAuthProvider(ctx, db.DeleteOAuthProviderParams{
UserID: ac.UserID,
Provider: provider,
}); err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to disconnect provider")
return
}
w.WriteHeader(http.StatusNoContent)
}
// DeleteAccount handles DELETE /v1/me — soft-deletes the user's account.
func (h *meHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
ctx := r.Context()
var req deleteAccountRequest
if err := decodeJSON(r, &req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
return
}
user, err := h.db.GetUserByID(ctx, ac.UserID)
if err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to get user")
return
}
if !strings.EqualFold(strings.TrimSpace(req.Confirmation), user.Email) {
writeError(w, http.StatusBadRequest, "invalid_request", "confirmation does not match your email address")
return
}
teamsBlocking, err := h.db.CountUserOwnedTeamsWithOtherMembers(ctx, ac.UserID)
if err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to check team ownership")
return
}
if teamsBlocking > 0 {
writeError(w, http.StatusConflict, "owns_team_with_members",
fmt.Sprintf("you own %d team(s) with other members — transfer ownership or remove members before deleting your account", teamsBlocking))
return
}
// Delete all teams the user solely owns (no other members).
// Team deletion involves RPC calls (sandbox destruction) that cannot be
// transactional, so we do those first as best-effort, then wrap the
// DB-only cleanup in a transaction.
soleTeams, err := h.db.ListSoleOwnedTeams(ctx, ac.UserID)
if err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to list owned teams")
return
}
for _, teamID := range soleTeams {
if err := h.teamSvc.DeleteTeamInternal(ctx, teamID); err != nil {
writeError(w, http.StatusInternalServerError, "db_error",
fmt.Sprintf("failed to delete sole-owned team %s", id.FormatTeamID(teamID)))
return
}
}
tx, err := h.pool.Begin(ctx)
if err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to start transaction")
return
}
defer tx.Rollback(ctx)
qtx := h.db.WithTx(tx)
if err := qtx.DeleteAPIKeysByCreator(ctx, ac.UserID); err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to delete user's API keys")
return
}
if err := qtx.SoftDeleteUser(ctx, ac.UserID); err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to delete account")
return
}
if err := tx.Commit(ctx); err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to commit account deletion")
return
}
slog.Info("account soft-deleted", "user_id", id.FormatUserID(ac.UserID), "email", user.Email)
go func() {
sendCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := h.mailer.Send(sendCtx, user.Email, "Your Wrenn account has been deleted", email.EmailData{
RecipientName: user.Name,
Message: "Your Wrenn account has been deactivated and is scheduled for permanent deletion in 15 days. If this was a mistake, contact support before then to recover your account.",
Closing: "Thank you for using Wrenn.",
}); err != nil {
slog.Warn("delete account: failed to send notification", "email", user.Email, "error", err)
}
}()
w.WriteHeader(http.StatusNoContent)
}
// --- helpers ---
func generateResetToken() string {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
panic(fmt.Sprintf("crypto/rand failed: %v", err))
}
return hex.EncodeToString(b)
}
func hashResetToken(raw string) string {
h := sha256.Sum256([]byte(raw))
return hex.EncodeToString(h[:])
}

View File

@ -9,10 +9,10 @@ import (
"github.com/go-chi/chi/v5"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/internal/auth"
"git.omukk.dev/wrenn/wrenn/internal/db"
"git.omukk.dev/wrenn/wrenn/internal/id"
"git.omukk.dev/wrenn/wrenn/internal/lifecycle"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
)
@ -38,7 +38,7 @@ type metricsResponse struct {
Points []metricPointResponse `json:"points"`
}
// GetMetrics handles GET /v1/sandboxes/{id}/metrics?range=10m|2h|24h.
// GetMetrics handles GET /v1/capsules/{id}/metrics?range=10m|2h|24h.
func (h *sandboxMetricsHandler) GetMetrics(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()

View File

@ -16,10 +16,10 @@ import (
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgxpool"
"git.omukk.dev/wrenn/wrenn/internal/auth"
"git.omukk.dev/wrenn/wrenn/internal/auth/oauth"
"git.omukk.dev/wrenn/wrenn/internal/db"
"git.omukk.dev/wrenn/wrenn/internal/id"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/auth/oauth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
type oauthHandler struct {
@ -137,6 +137,73 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
email := strings.TrimSpace(strings.ToLower(profile.Email))
// Check for a link operation initiated from the settings page.
if linkCookie, err := r.Cookie("oauth_link_user_id"); err == nil && linkCookie.Value != "" {
// Clear the link cookie immediately.
http.SetCookie(w, &http.Cookie{
Name: "oauth_link_user_id",
Value: "",
Path: "/",
MaxAge: -1,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
Secure: isSecure(r),
})
settingsBase := h.redirectURL + "/dashboard/settings"
// Verify the HMAC to prevent cookie forgery.
linkParts := strings.SplitN(linkCookie.Value, ":", 2)
if len(linkParts) != 2 || !hmac.Equal([]byte(computeHMAC(h.jwtSecret, linkParts[0])), []byte(linkParts[1])) {
slog.Warn("oauth link: invalid or tampered link cookie")
http.Redirect(w, r, settingsBase+"?connect_error=invalid_state", http.StatusFound)
return
}
userID, parseErr := id.ParseUserID(linkParts[0])
if parseErr != nil {
slog.Error("oauth link: invalid user ID in cookie", "error", parseErr)
http.Redirect(w, r, settingsBase+"?connect_error=invalid_state", http.StatusFound)
return
}
// Ensure the GitHub account isn't already linked to a different user.
existing, lookupErr := h.db.GetOAuthProvider(ctx, db.GetOAuthProviderParams{
Provider: provider,
ProviderID: profile.ProviderID,
})
if lookupErr == nil && existing.UserID != userID {
slog.Warn("oauth link: provider already linked to another account", "provider", provider)
http.Redirect(w, r, settingsBase+"?connect_error=already_linked", http.StatusFound)
return
}
if lookupErr == nil && existing.UserID == userID {
// Already linked to this user — treat as success.
http.Redirect(w, r, settingsBase+"?connected="+provider, http.StatusFound)
return
}
if !errors.Is(lookupErr, pgx.ErrNoRows) {
slog.Error("oauth link: db lookup failed", "error", lookupErr)
http.Redirect(w, r, settingsBase+"?connect_error=db_error", http.StatusFound)
return
}
if insertErr := h.db.InsertOAuthProvider(ctx, db.InsertOAuthProviderParams{
Provider: provider,
ProviderID: profile.ProviderID,
UserID: userID,
Email: email,
}); insertErr != nil {
slog.Error("oauth link: failed to insert provider", "error", insertErr)
http.Redirect(w, r, settingsBase+"?connect_error=db_error", http.StatusFound)
return
}
slog.Info("oauth link: provider linked", "provider", provider, "user_id", id.FormatUserID(userID))
http.Redirect(w, r, settingsBase+"?connected="+provider, http.StatusFound)
return
}
// Check if this OAuth identity already exists.
existing, err := h.db.GetOAuthProvider(ctx, db.GetOAuthProviderParams{
Provider: provider,
@ -145,18 +212,29 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
if err == nil {
// Existing OAuth user — log them in.
user, err := h.db.GetUserByID(ctx, existing.UserID)
if errors.Is(err, pgx.ErrNoRows) {
slog.Warn("oauth login: user no longer exists", "user_id", existing.UserID)
redirectWithError(w, r, redirectBase, "account_deactivated")
return
}
if err != nil {
slog.Error("oauth login: failed to get user", "error", err)
redirectWithError(w, r, redirectBase, "db_error")
return
}
team, role, err := loginTeam(ctx, h.db, user.ID)
if user.Status != "active" {
slog.Warn("oauth login: account not active", "email", user.Email, "status", user.Status)
redirectWithError(w, r, redirectBase, "account_deactivated")
return
}
team, role, isFirstUser, err := ensureDefaultTeam(ctx, h.db, h.pool, user.ID, user.Name)
if err != nil {
slog.Error("oauth login: failed to get team", "error", err)
slog.Error("oauth login: failed to ensure team", "error", err)
redirectWithError(w, r, redirectBase, "db_error")
return
}
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, user.IsAdmin)
isAdmin := user.IsAdmin || isFirstUser
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, isAdmin)
if err != nil {
slog.Error("oauth login: failed to sign jwt", "error", err)
redirectWithError(w, r, redirectBase, "internal_error")
@ -172,13 +250,21 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
}
// New OAuth identity — check for email collision.
_, err = h.db.GetUserByEmail(ctx, email)
existingUser, err := h.db.GetUserByEmail(ctx, email)
if err == nil {
// Email already taken by another account.
redirectWithError(w, r, redirectBase, "email_taken")
return
}
if !errors.Is(err, pgx.ErrNoRows) {
if existingUser.Status == "inactive" {
// Unactivated email signup — delete and let OAuth take over.
if delErr := h.db.HardDeleteUser(ctx, existingUser.ID); delErr != nil {
slog.Error("oauth: failed to delete inactive user", "error", delErr)
redirectWithError(w, r, redirectBase, "db_error")
return
}
} else {
// Email already taken by an active/disabled/deleted account.
redirectWithError(w, r, redirectBase, "email_taken")
return
}
} else if !errors.Is(err, pgx.ErrNoRows) {
slog.Error("oauth: email check failed", "error", err)
redirectWithError(w, r, redirectBase, "db_error")
return
@ -195,6 +281,15 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
qtx := h.db.WithTx(tx)
// The first user to sign up becomes a platform admin.
userCount, err := qtx.CountUsers(ctx)
if err != nil {
slog.Error("oauth: failed to count users", "error", err)
redirectWithError(w, r, redirectBase, "db_error")
return
}
isFirstUser := userCount == 0
userID := id.NewUserID()
_, err = qtx.InsertUserOAuth(ctx, db.InsertUserOAuthParams{
ID: userID,
@ -238,6 +333,14 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
return
}
if isFirstUser {
if err := qtx.SetUserAdmin(ctx, db.SetUserAdminParams{ID: userID, IsAdmin: true}); err != nil {
slog.Error("oauth: failed to set admin status", "error", err)
redirectWithError(w, r, redirectBase, "db_error")
return
}
}
if err := qtx.InsertOAuthProvider(ctx, db.InsertOAuthProviderParams{
Provider: provider,
ProviderID: profile.ProviderID,
@ -255,7 +358,7 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
return
}
token, err := auth.SignJWT(h.jwtSecret, userID, teamID, email, profile.Name, "owner", false)
token, err := auth.SignJWT(h.jwtSecret, userID, teamID, email, profile.Name, "owner", isFirstUser)
if err != nil {
slog.Error("oauth: failed to sign jwt", "error", err)
redirectWithError(w, r, redirectBase, "internal_error")
@ -279,18 +382,29 @@ func (h *oauthHandler) retryAsLogin(w http.ResponseWriter, r *http.Request, prov
return
}
user, err := h.db.GetUserByID(ctx, existing.UserID)
if errors.Is(err, pgx.ErrNoRows) {
slog.Warn("oauth: retry login: user no longer exists", "user_id", existing.UserID)
redirectWithError(w, r, redirectBase, "account_deactivated")
return
}
if err != nil {
slog.Error("oauth: retry login: failed to get user", "error", err)
redirectWithError(w, r, redirectBase, "db_error")
return
}
team, role, err := loginTeam(ctx, h.db, user.ID)
if user.Status != "active" {
slog.Warn("oauth: retry login: account not active", "email", user.Email, "status", user.Status)
redirectWithError(w, r, redirectBase, "account_deactivated")
return
}
team, role, isFirstUser, err := ensureDefaultTeam(ctx, h.db, h.pool, user.ID, user.Name)
if err != nil {
slog.Error("oauth: retry login: failed to get team", "error", err)
slog.Error("oauth: retry login: failed to ensure team", "error", err)
redirectWithError(w, r, redirectBase, "db_error")
return
}
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, user.IsAdmin)
isAdmin := user.IsAdmin || isFirstUser
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, isAdmin)
if err != nil {
slog.Error("oauth: retry login: failed to sign jwt", "error", err)
redirectWithError(w, r, redirectBase, "internal_error")

View File

@ -0,0 +1,298 @@
package api
import (
"context"
"log/slog"
"net/http"
"strconv"
"time"
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
"github.com/gorilla/websocket"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
)
type processHandler struct {
db *db.Queries
pool *lifecycle.HostClientPool
jwtSecret []byte
}
func newProcessHandler(db *db.Queries, pool *lifecycle.HostClientPool, jwtSecret []byte) *processHandler {
return &processHandler{db: db, pool: pool, jwtSecret: jwtSecret}
}
// processResponse is a single entry in the process list.
type processResponse struct {
PID uint32 `json:"pid"`
Tag string `json:"tag,omitempty"`
Cmd string `json:"cmd"`
Args []string `json:"args,omitempty"`
}
// processListResponse wraps the list of processes.
type processListResponse struct {
Processes []processResponse `json:"processes"`
}
// ListProcesses handles GET /v1/capsules/{id}/processes.
func (h *processHandler) ListProcesses(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
return
}
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
return
}
resp, err := agent.ListProcesses(ctx, connect.NewRequest(&pb.ListProcessesRequest{
SandboxId: sandboxIDStr,
}))
if err != nil {
status, code, msg := agentErrToHTTP(err)
writeError(w, status, code, msg)
return
}
procs := make([]processResponse, 0, len(resp.Msg.Processes))
for _, p := range resp.Msg.Processes {
procs = append(procs, processResponse{
PID: p.Pid,
Tag: p.Tag,
Cmd: p.Cmd,
Args: p.Args,
})
}
writeJSON(w, http.StatusOK, processListResponse{Processes: procs})
}
// KillProcess handles DELETE /v1/capsules/{id}/processes/{selector}.
// The selector can be a numeric PID or a string tag.
func (h *processHandler) KillProcess(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
selectorStr := chi.URLParam(r, "selector")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
return
}
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
return
}
// Build the kill request with PID or tag selector.
killReq := &pb.KillProcessRequest{
SandboxId: sandboxIDStr,
Signal: "SIGKILL",
}
if sig := r.URL.Query().Get("signal"); sig == "SIGTERM" {
killReq.Signal = "SIGTERM"
}
if pid, err := strconv.ParseUint(selectorStr, 10, 32); err == nil {
killReq.Selector = &pb.KillProcessRequest_Pid{Pid: uint32(pid)}
} else {
killReq.Selector = &pb.KillProcessRequest_Tag{Tag: selectorStr}
}
if _, err := agent.KillProcess(ctx, connect.NewRequest(killReq)); err != nil {
status, code, msg := agentErrToHTTP(err)
writeError(w, status, code, msg)
return
}
w.WriteHeader(http.StatusNoContent)
}
// wsProcessOut is the JSON message sent to the WebSocket client.
type wsProcessOut struct {
Type string `json:"type"` // "start", "stdout", "stderr", "exit", "error"
PID uint32 `json:"pid,omitempty"` // only for "start"
Data string `json:"data,omitempty"` // only for "stdout", "stderr", "error"
ExitCode *int32 `json:"exit_code,omitempty"` // only for "exit"
}
// ConnectProcess handles WS /v1/capsules/{id}/processes/{selector}/stream.
func (h *processHandler) ConnectProcess(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
selectorStr := chi.URLParam(r, "selector")
ctx := r.Context()
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
// Authenticate: use context from middleware (API key) or WS first message (JWT).
ac, hasAuth := auth.FromContext(ctx)
if !hasAuth {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
slog.Error("process stream websocket upgrade failed", "error", err)
return
}
defer conn.Close()
var wsAC auth.AuthContext
var authErr error
if isAdminWSRoute(ctx) {
wsAC, authErr = wsAuthenticateAdmin(ctx, conn, h.jwtSecret, h.db)
} else {
wsAC, authErr = wsAuthenticate(ctx, conn, h.jwtSecret, h.db)
}
if authErr != nil {
sendProcessWSError(conn, "authentication failed")
return
}
ac = wsAC
h.runConnectProcess(ctx, conn, ac, sandboxID, sandboxIDStr, selectorStr)
return
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
slog.Error("process stream websocket upgrade failed", "error", err)
return
}
defer conn.Close()
h.runConnectProcess(ctx, conn, ac, sandboxID, sandboxIDStr, selectorStr)
}
func (h *processHandler) runConnectProcess(ctx context.Context, conn *websocket.Conn, ac auth.AuthContext, sandboxID pgtype.UUID, sandboxIDStr, selectorStr string) {
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
sendProcessWSError(conn, "sandbox not found")
return
}
if sb.Status != "running" {
sendProcessWSError(conn, "sandbox is not running (status: "+sb.Status+")")
return
}
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
sendProcessWSError(conn, "sandbox host is not reachable")
return
}
// Build the connect request with PID or tag selector.
connectReq := &pb.ConnectProcessRequest{
SandboxId: sandboxIDStr,
}
if pid, err := strconv.ParseUint(selectorStr, 10, 32); err == nil {
connectReq.Selector = &pb.ConnectProcessRequest_Pid{Pid: uint32(pid)}
} else {
connectReq.Selector = &pb.ConnectProcessRequest_Tag{Tag: selectorStr}
}
streamCtx, cancel := context.WithCancel(ctx)
defer cancel()
stream, err := agent.ConnectProcess(streamCtx, connect.NewRequest(connectReq))
if err != nil {
sendProcessWSError(conn, "failed to connect to process: "+err.Error())
return
}
defer stream.Close()
// Listen for client disconnect in a goroutine.
go func() {
for {
_, _, err := conn.ReadMessage()
if err != nil {
cancel()
return
}
}
}()
// Forward stream events to WebSocket.
for stream.Receive() {
resp := stream.Msg()
switch ev := resp.Event.(type) {
case *pb.ConnectProcessResponse_Start:
writeWSJSON(conn, wsProcessOut{Type: "start", PID: ev.Start.Pid})
case *pb.ConnectProcessResponse_Data:
switch o := ev.Data.Output.(type) {
case *pb.ExecStreamData_Stdout:
writeWSJSON(conn, wsProcessOut{Type: "stdout", Data: string(o.Stdout)})
case *pb.ExecStreamData_Stderr:
writeWSJSON(conn, wsProcessOut{Type: "stderr", Data: string(o.Stderr)})
}
case *pb.ConnectProcessResponse_End:
exitCode := ev.End.ExitCode
writeWSJSON(conn, wsProcessOut{Type: "exit", ExitCode: &exitCode})
}
}
if err := stream.Err(); err != nil {
if streamCtx.Err() == nil {
sendProcessWSError(conn, err.Error())
}
}
// Update last active using a fresh context.
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer updateCancel()
if err := h.db.UpdateLastActive(updateCtx, db.UpdateLastActiveParams{
ID: sandboxID,
LastActiveAt: pgtype.Timestamptz{
Time: time.Now(),
Valid: true,
},
}); err != nil {
slog.Warn("failed to update last active after process stream", "sandbox_id", sandboxIDStr, "error", err)
}
}
func sendProcessWSError(conn *websocket.Conn, msg string) {
writeWSJSON(conn, wsProcessOut{Type: "error", Data: msg})
}

View File

@ -0,0 +1,400 @@
package api
import (
"context"
"encoding/base64"
"encoding/json"
"log/slog"
"net/http"
"sync"
"time"
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
"github.com/gorilla/websocket"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
"git.omukk.dev/wrenn/wrenn/proto/hostagent/gen/hostagentv1connect"
)
const (
ptyKeepaliveInterval = 30 * time.Second
ptyDefaultCmd = "/bin/bash"
ptyDefaultCols = 80
ptyDefaultRows = 24
)
type ptyHandler struct {
db *db.Queries
pool *lifecycle.HostClientPool
jwtSecret []byte
}
func newPtyHandler(db *db.Queries, pool *lifecycle.HostClientPool, jwtSecret []byte) *ptyHandler {
return &ptyHandler{db: db, pool: pool, jwtSecret: jwtSecret}
}
// --- WebSocket message types ---
// wsPtyIn is the inbound message from the client.
type wsPtyIn struct {
Type string `json:"type"` // "start", "connect", "input", "resize", "kill"
Cmd string `json:"cmd,omitempty"` // for "start"
Args []string `json:"args,omitempty"` // for "start"
Cols uint32 `json:"cols,omitempty"` // for "start", "resize"
Rows uint32 `json:"rows,omitempty"` // for "start", "resize"
Envs map[string]string `json:"envs,omitempty"` // for "start"
Cwd string `json:"cwd,omitempty"` // for "start"
User string `json:"user,omitempty"` // for "start"
Tag string `json:"tag,omitempty"` // for "connect"
Data string `json:"data,omitempty"` // for "input" (base64)
}
// wsPtyOut is the outbound message to the client.
type wsPtyOut struct {
Type string `json:"type"` // "started", "output", "exit", "error"
Tag string `json:"tag,omitempty"` // for "started"
PID uint32 `json:"pid,omitempty"` // for "started"
Data string `json:"data,omitempty"` // for "output" (base64), "error"
ExitCode *int32 `json:"exit_code,omitempty"` // for "exit"
Fatal bool `json:"fatal,omitempty"` // for "error"
}
// wsWriter wraps a websocket.Conn with a mutex for concurrent writes.
type wsWriter struct {
conn *websocket.Conn
mu sync.Mutex
}
func (w *wsWriter) writeJSON(v any) {
w.mu.Lock()
defer w.mu.Unlock()
if err := w.conn.WriteJSON(v); err != nil {
slog.Debug("pty websocket write error", "error", err)
}
}
// PtySession handles WS /v1/capsules/{id}/pty.
func (h *ptyHandler) PtySession(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
// API key auth is handled by middleware (sets context).
// For browser JWT auth, we authenticate after upgrade via first WS message.
ac, hasAuth := auth.FromContext(ctx)
if !hasAuth {
// No pre-upgrade auth — upgrade first, then authenticate via WS message.
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
slog.Error("pty websocket upgrade failed", "error", err)
return
}
defer conn.Close()
ws := &wsWriter{conn: conn}
var wsAC auth.AuthContext
if isAdminWSRoute(ctx) {
wsAC, err = wsAuthenticateAdmin(ctx, conn, h.jwtSecret, h.db)
} else {
wsAC, err = wsAuthenticate(ctx, conn, h.jwtSecret, h.db)
}
if err != nil {
ws.writeJSON(wsPtyOut{Type: "error", Data: "authentication failed", Fatal: true})
return
}
ac = wsAC
h.runPtySession(ctx, ws, conn, ac, sandboxID, sandboxIDStr)
return
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
slog.Error("pty websocket upgrade failed", "error", err)
return
}
defer conn.Close()
ws := &wsWriter{conn: conn}
h.runPtySession(ctx, ws, conn, ac, sandboxID, sandboxIDStr)
}
func (h *ptyHandler) runPtySession(ctx context.Context, ws *wsWriter, conn *websocket.Conn, ac auth.AuthContext, sandboxID pgtype.UUID, sandboxIDStr string) {
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
ws.writeJSON(wsPtyOut{Type: "error", Data: "sandbox not found", Fatal: true})
return
}
if sb.Status != "running" {
ws.writeJSON(wsPtyOut{Type: "error", Data: "sandbox is not running (status: " + sb.Status + ")", Fatal: true})
return
}
// Read the first message to determine start vs connect.
var firstMsg wsPtyIn
if err := conn.ReadJSON(&firstMsg); err != nil {
ws.writeJSON(wsPtyOut{Type: "error", Data: "failed to read first message: " + err.Error(), Fatal: true})
return
}
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
ws.writeJSON(wsPtyOut{Type: "error", Data: "sandbox host is not reachable", Fatal: true})
return
}
streamCtx, cancel := context.WithCancel(ctx)
defer cancel()
switch firstMsg.Type {
case "start":
h.handleStart(streamCtx, cancel, ws, agent, sandboxIDStr, firstMsg)
case "connect":
h.handleConnect(streamCtx, cancel, ws, agent, sandboxIDStr, firstMsg)
default:
ws.writeJSON(wsPtyOut{Type: "error", Data: "first message must be type 'start' or 'connect'", Fatal: true})
}
// Update last active using a fresh context.
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer updateCancel()
if err := h.db.UpdateLastActive(updateCtx, db.UpdateLastActiveParams{
ID: sandboxID,
LastActiveAt: pgtype.Timestamptz{
Time: time.Now(),
Valid: true,
},
}); err != nil {
slog.Warn("failed to update last active after pty session", "sandbox_id", sandboxIDStr, "error", err)
}
}
func (h *ptyHandler) handleStart(
ctx context.Context,
cancel context.CancelFunc,
ws *wsWriter,
agent hostagentv1connect.HostAgentServiceClient,
sandboxIDStr string,
msg wsPtyIn,
) {
cmd := msg.Cmd
if cmd == "" {
cmd = ptyDefaultCmd
}
cols := msg.Cols
if cols == 0 {
cols = ptyDefaultCols
}
rows := msg.Rows
if rows == 0 {
rows = ptyDefaultRows
}
tag := newPtyTag()
stream, err := agent.PtyAttach(ctx, connect.NewRequest(&pb.PtyAttachRequest{
SandboxId: sandboxIDStr,
Tag: tag,
Cmd: cmd,
Args: msg.Args,
Cols: cols,
Rows: rows,
Envs: msg.Envs,
Cwd: msg.Cwd,
User: msg.User,
}))
if err != nil {
ws.writeJSON(wsPtyOut{Type: "error", Data: "failed to start pty: " + err.Error(), Fatal: true})
return
}
defer stream.Close()
// Wait for the started event and forward it.
if !stream.Receive() {
if err := stream.Err(); err != nil {
ws.writeJSON(wsPtyOut{Type: "error", Data: "pty stream failed: " + err.Error(), Fatal: true})
}
return
}
resp := stream.Msg()
started, ok := resp.Event.(*pb.PtyAttachResponse_Started)
if !ok {
ws.writeJSON(wsPtyOut{Type: "error", Data: "expected started event from host agent", Fatal: true})
return
}
ws.writeJSON(wsPtyOut{Type: "started", Tag: started.Started.Tag, PID: started.Started.Pid})
runPtyLoop(ctx, cancel, ws, stream, agent, sandboxIDStr, tag)
}
func (h *ptyHandler) handleConnect(
ctx context.Context,
cancel context.CancelFunc,
ws *wsWriter,
agent hostagentv1connect.HostAgentServiceClient,
sandboxIDStr string,
msg wsPtyIn,
) {
if msg.Tag == "" {
ws.writeJSON(wsPtyOut{Type: "error", Data: "connect requires a 'tag' field", Fatal: true})
return
}
stream, err := agent.PtyAttach(ctx, connect.NewRequest(&pb.PtyAttachRequest{
SandboxId: sandboxIDStr,
Tag: msg.Tag,
}))
if err != nil {
ws.writeJSON(wsPtyOut{Type: "error", Data: "failed to connect to pty: " + err.Error(), Fatal: true})
return
}
defer stream.Close()
runPtyLoop(ctx, cancel, ws, stream, agent, sandboxIDStr, msg.Tag)
}
// runPtyLoop drives the bidirectional communication between the WebSocket
// and the host agent PTY stream.
func runPtyLoop(
ctx context.Context,
cancel context.CancelFunc,
ws *wsWriter,
stream *connect.ServerStreamForClient[pb.PtyAttachResponse],
agent hostagentv1connect.HostAgentServiceClient,
sandboxID string,
tag string,
) {
var wg sync.WaitGroup
// Output pump: read from Connect stream, write to WebSocket.
wg.Add(1)
go func() {
defer wg.Done()
defer cancel()
for stream.Receive() {
resp := stream.Msg()
switch ev := resp.Event.(type) {
case *pb.PtyAttachResponse_Started:
// Already handled before the loop for "start" mode.
// For "connect" mode this won't appear.
ws.writeJSON(wsPtyOut{Type: "started", Tag: ev.Started.Tag, PID: ev.Started.Pid})
case *pb.PtyAttachResponse_Output:
ws.writeJSON(wsPtyOut{
Type: "output",
Data: base64.StdEncoding.EncodeToString(ev.Output.Data),
})
case *pb.PtyAttachResponse_Exited:
exitCode := ev.Exited.ExitCode
ws.writeJSON(wsPtyOut{Type: "exit", ExitCode: &exitCode})
return
}
}
if err := stream.Err(); err != nil && ctx.Err() == nil {
ws.writeJSON(wsPtyOut{Type: "error", Data: err.Error()})
}
}()
// Input pump: read from WebSocket, dispatch to host agent.
wg.Add(1)
go func() {
defer wg.Done()
defer cancel()
for {
_, raw, err := ws.conn.ReadMessage()
if err != nil {
return
}
var msg wsPtyIn
if json.Unmarshal(raw, &msg) != nil {
continue
}
// Use a background context for unary RPCs so they complete
// even if the stream context is being cancelled.
rpcCtx, rpcCancel := context.WithTimeout(context.Background(), 5*time.Second)
switch msg.Type {
case "input":
data, err := base64.StdEncoding.DecodeString(msg.Data)
if err != nil {
rpcCancel()
continue
}
if _, err := agent.PtySendInput(rpcCtx, connect.NewRequest(&pb.PtySendInputRequest{
SandboxId: sandboxID,
Tag: tag,
Data: data,
})); err != nil {
slog.Debug("pty send input error", "error", err)
}
case "resize":
cols := msg.Cols
rows := msg.Rows
if cols > 0 && rows > 0 {
if _, err := agent.PtyResize(rpcCtx, connect.NewRequest(&pb.PtyResizeRequest{
SandboxId: sandboxID,
Tag: tag,
Cols: cols,
Rows: rows,
})); err != nil {
slog.Debug("pty resize error", "error", err)
}
}
case "kill":
if _, err := agent.PtyKill(rpcCtx, connect.NewRequest(&pb.PtyKillRequest{
SandboxId: sandboxID,
Tag: tag,
})); err != nil {
slog.Debug("pty kill error", "error", err)
}
}
rpcCancel()
}
}()
// Keepalive pump: send periodic pings to prevent idle WS closure.
wg.Add(1)
go func() {
defer wg.Done()
ticker := time.NewTicker(ptyKeepaliveInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
ws.writeJSON(wsPtyOut{Type: "ping"})
case <-ctx.Done():
return
}
}
}()
wg.Wait()
}
// newPtyTag returns a PTY session tag: "pty-" + 8 random hex chars.
func newPtyTag() string {
return "pty-" + id.NewPtyTag()
}

View File

@ -7,11 +7,11 @@ import (
"github.com/go-chi/chi/v5"
"git.omukk.dev/wrenn/wrenn/internal/audit"
"git.omukk.dev/wrenn/wrenn/internal/auth"
"git.omukk.dev/wrenn/wrenn/internal/db"
"git.omukk.dev/wrenn/wrenn/internal/id"
"git.omukk.dev/wrenn/wrenn/internal/service"
"git.omukk.dev/wrenn/wrenn/pkg/audit"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/service"
)
type sandboxHandler struct {
@ -31,18 +31,19 @@ type createSandboxRequest struct {
}
type sandboxResponse struct {
ID string `json:"id"`
Status string `json:"status"`
Template string `json:"template"`
VCPUs int32 `json:"vcpus"`
MemoryMB int32 `json:"memory_mb"`
TimeoutSec int32 `json:"timeout_sec"`
GuestIP string `json:"guest_ip,omitempty"`
HostIP string `json:"host_ip,omitempty"`
CreatedAt string `json:"created_at"`
StartedAt *string `json:"started_at,omitempty"`
LastActiveAt *string `json:"last_active_at,omitempty"`
LastUpdated string `json:"last_updated"`
ID string `json:"id"`
Status string `json:"status"`
Template string `json:"template"`
VCPUs int32 `json:"vcpus"`
MemoryMB int32 `json:"memory_mb"`
TimeoutSec int32 `json:"timeout_sec"`
GuestIP string `json:"guest_ip,omitempty"`
HostIP string `json:"host_ip,omitempty"`
CreatedAt string `json:"created_at"`
StartedAt *string `json:"started_at,omitempty"`
LastActiveAt *string `json:"last_active_at,omitempty"`
LastUpdated string `json:"last_updated"`
Metadata map[string]string `json:"metadata,omitempty"`
}
func sandboxToResponse(sb db.Sandbox) sandboxResponse {
@ -56,6 +57,12 @@ func sandboxToResponse(sb db.Sandbox) sandboxResponse {
GuestIP: sb.GuestIp,
HostIP: sb.HostIp,
}
if len(sb.Metadata) > 0 {
var meta map[string]string
if err := json.Unmarshal(sb.Metadata, &meta); err == nil && len(meta) > 0 {
resp.Metadata = meta
}
}
if sb.CreatedAt.Valid {
resp.CreatedAt = sb.CreatedAt.Time.Format(time.RFC3339)
}
@ -73,7 +80,7 @@ func sandboxToResponse(sb db.Sandbox) sandboxResponse {
return resp
}
// Create handles POST /v1/sandboxes.
// Create handles POST /v1/capsules.
func (h *sandboxHandler) Create(w http.ResponseWriter, r *http.Request) {
var req createSandboxRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
@ -104,7 +111,7 @@ func (h *sandboxHandler) Create(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusCreated, sandboxToResponse(sb))
}
// List handles GET /v1/sandboxes.
// List handles GET /v1/capsules.
func (h *sandboxHandler) List(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
sandboxes, err := h.svc.List(r.Context(), ac.TeamID)
@ -121,7 +128,7 @@ func (h *sandboxHandler) List(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, resp)
}
// Get handles GET /v1/sandboxes/{id}.
// Get handles GET /v1/capsules/{id}.
func (h *sandboxHandler) Get(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
@ -141,7 +148,7 @@ func (h *sandboxHandler) Get(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, sandboxToResponse(sb))
}
// Pause handles POST /v1/sandboxes/{id}/pause.
// Pause handles POST /v1/capsules/{id}/pause.
func (h *sandboxHandler) Pause(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
@ -163,7 +170,7 @@ func (h *sandboxHandler) Pause(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, sandboxToResponse(sb))
}
// Resume handles POST /v1/sandboxes/{id}/resume.
// Resume handles POST /v1/capsules/{id}/resume.
func (h *sandboxHandler) Resume(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
@ -185,7 +192,7 @@ func (h *sandboxHandler) Resume(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, sandboxToResponse(sb))
}
// Ping handles POST /v1/sandboxes/{id}/ping.
// Ping handles POST /v1/capsules/{id}/ping.
func (h *sandboxHandler) Ping(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
@ -205,7 +212,7 @@ func (h *sandboxHandler) Ping(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}
// Destroy handles DELETE /v1/sandboxes/{id}.
// Destroy handles DELETE /v1/capsules/{id}.
func (h *sandboxHandler) Destroy(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())

View File

@ -13,14 +13,14 @@ import (
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/internal/audit"
"git.omukk.dev/wrenn/wrenn/internal/auth"
"git.omukk.dev/wrenn/wrenn/internal/db"
"git.omukk.dev/wrenn/wrenn/internal/id"
"git.omukk.dev/wrenn/wrenn/internal/layout"
"git.omukk.dev/wrenn/wrenn/internal/lifecycle"
"git.omukk.dev/wrenn/wrenn/internal/service"
"git.omukk.dev/wrenn/wrenn/internal/validate"
"git.omukk.dev/wrenn/wrenn/pkg/audit"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
"git.omukk.dev/wrenn/wrenn/pkg/service"
"git.omukk.dev/wrenn/wrenn/pkg/validate"
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
)
@ -38,8 +38,8 @@ func newSnapshotHandler(svc *service.TemplateService, db *db.Queries, pool *life
// deleteSnapshotBroadcast attempts to delete snapshot files on all online hosts.
// Snapshots aren't currently host-tracked in the DB, so we broadcast to all hosts
// and ignore NotFound errors.
func (h *snapshotHandler) deleteSnapshotBroadcast(ctx context.Context, teamID, templateID pgtype.UUID) error {
hosts, err := h.db.ListActiveHosts(ctx)
func deleteSnapshotBroadcast(ctx context.Context, queries *db.Queries, pool *lifecycle.HostClientPool, teamID, templateID pgtype.UUID) error {
hosts, err := queries.ListActiveHosts(ctx)
if err != nil {
return fmt.Errorf("list hosts: %w", err)
}
@ -47,7 +47,7 @@ func (h *snapshotHandler) deleteSnapshotBroadcast(ctx context.Context, teamID, t
if host.Status != "online" {
continue
}
agent, err := h.pool.GetForHost(host)
agent, err := pool.GetForHost(host)
if err != nil {
continue
}
@ -69,13 +69,14 @@ type createSnapshotRequest struct {
}
type snapshotResponse struct {
Name string `json:"name"`
Type string `json:"type"`
VCPUs *int32 `json:"vcpus,omitempty"`
MemoryMB *int32 `json:"memory_mb,omitempty"`
SizeBytes int64 `json:"size_bytes"`
CreatedAt string `json:"created_at"`
Platform bool `json:"platform"`
Name string `json:"name"`
Type string `json:"type"`
VCPUs *int32 `json:"vcpus,omitempty"`
MemoryMB *int32 `json:"memory_mb,omitempty"`
SizeBytes int64 `json:"size_bytes"`
CreatedAt string `json:"created_at"`
Platform bool `json:"platform"`
Metadata map[string]string `json:"metadata,omitempty"`
}
func templateToResponse(t db.Template) snapshotResponse {
@ -94,6 +95,12 @@ func templateToResponse(t db.Template) snapshotResponse {
if t.CreatedAt.Valid {
resp.CreatedAt = t.CreatedAt.Time.Format(time.RFC3339)
}
if len(t.Metadata) > 0 {
var meta map[string]string
if err := json.Unmarshal(t.Metadata, &meta); err == nil && len(meta) > 0 {
resp.Metadata = meta
}
}
return resp
}
@ -126,7 +133,6 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ac := auth.MustFromContext(ctx)
overwrite := r.URL.Query().Get("overwrite") == "true"
// Check for global name collision.
if _, err := h.db.GetPlatformTemplateByName(ctx, req.Name); err == nil {
@ -135,20 +141,10 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
}
// Check if name already exists for this team.
if existing, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err == nil {
if !overwrite {
writeError(w, http.StatusConflict, "already_exists", "snapshot name already exists; use ?overwrite=true to replace")
return
}
// Delete old snapshot files from all hosts before removing the DB record.
if err := h.deleteSnapshotBroadcast(ctx, existing.TeamID, existing.ID); err != nil {
writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete existing snapshot files")
return
}
if err := h.db.DeleteTemplateByTeam(ctx, db.DeleteTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to remove existing template record")
return
}
if _, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err == nil {
writeError(w, http.StatusConflict, "template_name_taken",
"snapshot name already exists; delete the existing snapshot first to reuse this name")
return
}
// Verify sandbox exists, belongs to team, and is running or paused.
@ -210,13 +206,16 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
}
tmpl, err := h.db.InsertTemplate(snapCtx, db.InsertTemplateParams{
ID: newTemplateID,
Name: req.Name,
Type: "snapshot",
Vcpus: sb.Vcpus,
MemoryMb: sb.MemoryMb,
SizeBytes: resp.Msg.SizeBytes,
TeamID: ac.TeamID,
ID: newTemplateID,
Name: req.Name,
Type: "snapshot",
Vcpus: sb.Vcpus,
MemoryMb: sb.MemoryMb,
SizeBytes: resp.Msg.SizeBytes,
TeamID: ac.TeamID,
DefaultUser: "root",
DefaultEnv: []byte("{}"),
Metadata: sb.Metadata,
})
if err != nil {
slog.Error("failed to insert template record", "name", req.Name, "error", err)
@ -277,7 +276,7 @@ func (h *snapshotHandler) Delete(w http.ResponseWriter, r *http.Request) {
return
}
if err := h.deleteSnapshotBroadcast(ctx, tmpl.TeamID, tmpl.ID); err != nil {
if err := deleteSnapshotBroadcast(ctx, h.db, h.pool, tmpl.TeamID, tmpl.ID); err != nil {
writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete snapshot files")
return
}

View File

@ -5,8 +5,8 @@ import (
"net/http"
"time"
"git.omukk.dev/wrenn/wrenn/internal/auth"
"git.omukk.dev/wrenn/wrenn/internal/service"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/service"
)
type statsHandler struct {
@ -43,7 +43,7 @@ type statsResponse struct {
Series statsSeriesResponse `json:"series"`
}
// GetStats handles GET /v1/sandboxes/stats?range=5m|1h|6h|24h|30d
// GetStats handles GET /v1/capsules/stats?range=5m|1h|6h|24h|30d
func (h *statsHandler) GetStats(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())

View File

@ -1,6 +1,8 @@
package api
import (
"context"
"fmt"
"log/slog"
"net/http"
"strings"
@ -9,20 +11,22 @@ import (
"github.com/go-chi/chi/v5"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/internal/audit"
"git.omukk.dev/wrenn/wrenn/internal/auth"
"git.omukk.dev/wrenn/wrenn/internal/db"
"git.omukk.dev/wrenn/wrenn/internal/id"
"git.omukk.dev/wrenn/wrenn/internal/service"
"git.omukk.dev/wrenn/wrenn/internal/email"
"git.omukk.dev/wrenn/wrenn/pkg/audit"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/service"
)
type teamHandler struct {
svc *service.TeamService
audit *audit.AuditLogger
svc *service.TeamService
audit *audit.AuditLogger
mailer email.Mailer
}
func newTeamHandler(svc *service.TeamService, al *audit.AuditLogger) *teamHandler {
return &teamHandler{svc: svc, audit: al}
func newTeamHandler(svc *service.TeamService, al *audit.AuditLogger, mailer email.Mailer) *teamHandler {
return &teamHandler{svc: svc, audit: al, mailer: mailer}
}
// teamResponse is the JSON shape for a team.
@ -131,6 +135,15 @@ func (h *teamHandler) Create(w http.ResponseWriter, r *http.Request) {
return
}
go func() {
if err := h.mailer.Send(context.Background(), ac.Email, "Your team has been created", email.EmailData{
RecipientName: ac.Name,
Message: fmt.Sprintf("Your team \"%s\" has been created on Wrenn. You can now invite members and start creating sandboxes under this team.", req.Name),
}); err != nil {
slog.Warn("failed to send team created email", "email", ac.Email, "error", err)
}
}()
writeJSON(w, http.StatusCreated, teamWithRoleResponse{
teamResponse: teamToResponse(team.Team),
Role: team.Role,
@ -279,6 +292,21 @@ func (h *teamHandler) AddMember(w http.ResponseWriter, r *http.Request) {
if parseErr == nil {
h.audit.LogMemberAdd(r.Context(), ac, targetUserID, member.Email, member.Role)
}
go func() {
team, err := h.svc.GetTeam(context.Background(), teamID)
teamName := "a team"
if err == nil {
teamName = team.Name
}
if err := h.mailer.Send(context.Background(), member.Email, "You've been added to a team on Wrenn", email.EmailData{
RecipientName: member.Name,
Message: fmt.Sprintf("%s has added you to the team \"%s\" on Wrenn.", ac.Name, teamName),
}); err != nil {
slog.Warn("failed to send team invitation email", "email", member.Email, "error", err)
}
}()
writeJSON(w, http.StatusCreated, memberInfoToResponse(member))
}
@ -388,3 +416,87 @@ func (h *teamHandler) SetBYOC(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}
// AdminListTeams handles GET /v1/admin/teams?page=1
// Returns a paginated list of all teams with member counts, owner info, and active sandbox counts.
func (h *teamHandler) AdminListTeams(w http.ResponseWriter, r *http.Request) {
page := 1
if p := r.URL.Query().Get("page"); p != "" {
if _, err := fmt.Sscanf(p, "%d", &page); err != nil || page < 1 {
page = 1
}
}
const perPage = 100
offset := int32((page - 1) * perPage)
teams, total, err := h.svc.AdminListTeams(r.Context(), perPage, offset)
if err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
type adminTeamResponse struct {
ID string `json:"id"`
Name string `json:"name"`
Slug string `json:"slug"`
IsByoc bool `json:"is_byoc"`
CreatedAt string `json:"created_at"`
DeletedAt *string `json:"deleted_at"`
MemberCount int32 `json:"member_count"`
OwnerName string `json:"owner_name"`
OwnerEmail string `json:"owner_email"`
ActiveSandboxCount int32 `json:"active_sandbox_count"`
ChannelCount int32 `json:"channel_count"`
}
resp := make([]adminTeamResponse, len(teams))
for i, t := range teams {
r := adminTeamResponse{
ID: id.FormatTeamID(t.ID),
Name: t.Name,
Slug: t.Slug,
IsByoc: t.IsByoc,
CreatedAt: t.CreatedAt.Format(time.RFC3339),
MemberCount: t.MemberCount,
OwnerName: t.OwnerName,
OwnerEmail: t.OwnerEmail,
ActiveSandboxCount: t.ActiveSandboxCount,
ChannelCount: t.ChannelCount,
}
if t.DeletedAt != nil {
s := t.DeletedAt.Format(time.RFC3339)
r.DeletedAt = &s
}
resp[i] = r
}
totalPages := (total + perPage - 1) / perPage
writeJSON(w, http.StatusOK, map[string]any{
"teams": resp,
"total": total,
"page": page,
"per_page": perPage,
"total_pages": totalPages,
})
}
// AdminDeleteTeam handles DELETE /v1/admin/teams/{id}
// Soft-deletes a team and destroys all its active sandboxes.
func (h *teamHandler) AdminDeleteTeam(w http.ResponseWriter, r *http.Request) {
teamIDStr := chi.URLParam(r, "id")
teamID, err := id.ParseTeamID(teamIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid team ID")
return
}
if err := h.svc.AdminDeleteTeam(r.Context(), teamID); err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
w.WriteHeader(http.StatusNoContent)
}

View File

@ -1,22 +1,27 @@
package api
import (
"fmt"
"net/http"
"strings"
"time"
"github.com/go-chi/chi/v5"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/internal/auth"
"git.omukk.dev/wrenn/wrenn/internal/db"
"git.omukk.dev/wrenn/wrenn/internal/id"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/service"
)
type usersHandler struct {
db *db.Queries
db *db.Queries
svc *service.UserService
}
func newUsersHandler(db *db.Queries) *usersHandler {
return &usersHandler{db: db}
func newUsersHandler(db *db.Queries, svc *service.UserService) *usersHandler {
return &usersHandler{db: db, svc: svc}
}
// Search handles GET /v1/users/search?email=<prefix>
@ -50,3 +55,96 @@ func (h *usersHandler) Search(w http.ResponseWriter, r *http.Request) {
}
writeJSON(w, http.StatusOK, resp)
}
// AdminListUsers handles GET /v1/admin/users?page=1
// Returns a paginated list of all users with team counts.
func (h *usersHandler) AdminListUsers(w http.ResponseWriter, r *http.Request) {
page := 1
if p := r.URL.Query().Get("page"); p != "" {
if _, err := fmt.Sscanf(p, "%d", &page); err != nil || page < 1 {
page = 1
}
}
const perPage = 100
offset := int32((page - 1) * perPage)
users, total, err := h.svc.AdminListUsers(r.Context(), perPage, offset)
if err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
type adminUserResponse struct {
ID string `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
IsAdmin bool `json:"is_admin"`
Status string `json:"status"`
CreatedAt string `json:"created_at"`
TeamsJoined int32 `json:"teams_joined"`
TeamsOwned int32 `json:"teams_owned"`
}
resp := make([]adminUserResponse, len(users))
for i, u := range users {
resp[i] = adminUserResponse{
ID: id.FormatUserID(u.ID),
Email: u.Email,
Name: u.Name,
IsAdmin: u.IsAdmin,
Status: u.Status,
CreatedAt: u.CreatedAt.Format(time.RFC3339),
TeamsJoined: u.TeamsJoined,
TeamsOwned: u.TeamsOwned,
}
}
totalPages := (total + perPage - 1) / perPage
writeJSON(w, http.StatusOK, map[string]any{
"users": resp,
"total": total,
"page": page,
"per_page": perPage,
"total_pages": totalPages,
})
}
// SetUserActive handles PUT /v1/admin/users/{id}/active
// Enables or disables a user account. Admins cannot deactivate themselves.
func (h *usersHandler) SetUserActive(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
userIDStr := chi.URLParam(r, "id")
userID, err := id.ParseUserID(userIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid user ID")
return
}
var req struct {
Active bool `json:"active"`
}
if err := decodeJSON(r, &req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
return
}
if ac.UserID == userID && !req.Active {
writeError(w, http.StatusBadRequest, "invalid_request", "cannot deactivate your own account")
return
}
newStatus := "active"
if !req.Active {
newStatus = "disabled"
}
if err := h.svc.SetUserStatus(r.Context(), userID, newStatus); err != nil {
httpStatus, code, msg := serviceErrToHTTP(err)
writeError(w, httpStatus, code, msg)
return
}
w.WriteHeader(http.StatusNoContent)
}

109
internal/api/helpers_ws.go Normal file
View File

@ -0,0 +1,109 @@
package api
import (
"context"
"fmt"
"net/http"
"strings"
"time"
"github.com/gorilla/websocket"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
// isWebSocketUpgrade returns true if the request is a WebSocket upgrade.
func isWebSocketUpgrade(r *http.Request) bool {
return strings.EqualFold(r.Header.Get("Upgrade"), "websocket")
}
// ctxKeyAdminWS is a context key for flagging admin WS routes.
type ctxKeyAdminWS struct{}
// setAdminWSFlag marks the context as an admin WebSocket route.
func setAdminWSFlag(ctx context.Context) context.Context {
return context.WithValue(ctx, ctxKeyAdminWS{}, true)
}
// isAdminWSRoute checks if the request context was marked as admin WS.
func isAdminWSRoute(ctx context.Context) bool {
v, _ := ctx.Value(ctxKeyAdminWS{}).(bool)
return v
}
// wsAuthMsg is the first message a browser WS client sends to authenticate.
type wsAuthMsg struct {
Type string `json:"type"`
Token string `json:"token"`
}
// wsAuthenticate reads a JWT auth message from the WebSocket and returns the
// authenticated context. The caller must send this as the first message after
// connecting.
func wsAuthenticate(ctx context.Context, conn *websocket.Conn, jwtSecret []byte, queries *db.Queries) (auth.AuthContext, error) {
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
var msg wsAuthMsg
if err := conn.ReadJSON(&msg); err != nil {
return auth.AuthContext{}, fmt.Errorf("read auth message: %w", err)
}
conn.SetReadDeadline(time.Time{}) // clear deadline
if msg.Type != "auth" || msg.Token == "" {
return auth.AuthContext{}, fmt.Errorf("first message must be type 'auth' with a token")
}
claims, err := auth.VerifyJWT(jwtSecret, msg.Token)
if err != nil {
return auth.AuthContext{}, fmt.Errorf("invalid or expired token: %w", err)
}
teamID, err := id.ParseTeamID(claims.TeamID)
if err != nil {
return auth.AuthContext{}, fmt.Errorf("invalid team ID in token: %w", err)
}
userID, err := id.ParseUserID(claims.Subject)
if err != nil {
return auth.AuthContext{}, fmt.Errorf("invalid user ID in token: %w", err)
}
user, err := queries.GetUserByID(ctx, userID)
if err != nil {
return auth.AuthContext{}, fmt.Errorf("user not found")
}
if user.Status != "active" {
return auth.AuthContext{}, fmt.Errorf("account deactivated")
}
return auth.AuthContext{
TeamID: teamID,
UserID: userID,
Email: claims.Email,
Name: claims.Name,
Role: claims.Role,
}, nil
}
// wsAuthenticateAdmin performs WS-based auth and verifies admin status,
// returning an AuthContext with the platform team ID.
func wsAuthenticateAdmin(ctx context.Context, conn *websocket.Conn, jwtSecret []byte, queries *db.Queries) (auth.AuthContext, error) {
ac, err := wsAuthenticate(ctx, conn, jwtSecret, queries)
if err != nil {
return auth.AuthContext{}, err
}
user, err := queries.GetUserByID(ctx, ac.UserID)
if err != nil {
return auth.AuthContext{}, fmt.Errorf("user not found")
}
if !user.IsAdmin {
return auth.AuthContext{}, fmt.Errorf("admin access required")
}
ac.TeamID = id.PlatformTeamID
return ac, nil
}

View File

@ -8,10 +8,10 @@ import (
"connectrpc.com/connect"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/internal/audit"
"git.omukk.dev/wrenn/wrenn/internal/db"
"git.omukk.dev/wrenn/wrenn/internal/id"
"git.omukk.dev/wrenn/wrenn/internal/lifecycle"
"git.omukk.dev/wrenn/wrenn/pkg/audit"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
)

View File

@ -5,7 +5,7 @@ import (
"log/slog"
"time"
"git.omukk.dev/wrenn/wrenn/internal/db"
"git.omukk.dev/wrenn/wrenn/pkg/db"
)
// MetricsSampler records per-team sandbox resource usage to

View File

@ -14,7 +14,7 @@ import (
"connectrpc.com/connect"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/internal/id"
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
type errorResponse struct {
@ -50,8 +50,12 @@ func agentErrToHTTP(err error) (int, string, string) {
return http.StatusNotFound, "not_found", err.Error()
case connect.CodeInvalidArgument:
return http.StatusBadRequest, "invalid_request", err.Error()
case connect.CodeFailedPrecondition:
case connect.CodeFailedPrecondition, connect.CodeAlreadyExists:
return http.StatusConflict, "conflict", err.Error()
case connect.CodePermissionDenied:
return http.StatusForbidden, "forbidden", err.Error()
case connect.CodeUnimplemented:
return http.StatusNotImplemented, "agent_error", err.Error()
default:
return http.StatusBadGateway, "agent_error", err.Error()
}
@ -90,21 +94,25 @@ func serviceErrToHTTP(err error) (int, string, string) {
}
// Map well-known service error patterns.
// Return generic messages for most cases to avoid leaking internal details.
switch {
case strings.Contains(msg, "not found"):
return http.StatusNotFound, "not_found", msg
case strings.Contains(msg, "not running"), strings.Contains(msg, "not paused"):
return http.StatusConflict, "invalid_state", msg
return http.StatusNotFound, "not_found", "resource not found"
case strings.Contains(msg, "not running"):
return http.StatusConflict, "invalid_state", "resource is not running"
case strings.Contains(msg, "not paused"):
return http.StatusConflict, "invalid_state", "resource is not paused"
case strings.Contains(msg, "conflict:"):
return http.StatusConflict, "conflict", msg
return http.StatusConflict, "conflict", strings.TrimPrefix(msg, "conflict: ")
case strings.Contains(msg, "forbidden"):
return http.StatusForbidden, "forbidden", msg
return http.StatusForbidden, "forbidden", "forbidden"
case strings.Contains(msg, "invalid or expired"):
return http.StatusUnauthorized, "unauthorized", msg
return http.StatusUnauthorized, "unauthorized", "invalid or expired credentials"
case strings.Contains(msg, "invalid"):
return http.StatusBadRequest, "invalid_request", msg
return http.StatusBadRequest, "invalid_request", "invalid request"
default:
return http.StatusInternalServerError, "internal_error", msg
slog.Error("unhandled service error", "error", err)
return http.StatusInternalServerError, "internal_error", "an internal error occurred"
}
}

View File

@ -3,19 +3,47 @@ package api
import (
"net/http"
"git.omukk.dev/wrenn/wrenn/internal/auth"
"git.omukk.dev/wrenn/wrenn/internal/db"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
// injectPlatformTeam overwrites the AuthContext's TeamID with the platform
// sentinel UUID. This lets existing team-scoped handlers (exec, files, pty,
// metrics) work unchanged under admin routes. Must run after requireAdmin.
func injectPlatformTeam() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if _, ok := auth.FromContext(r.Context()); !ok {
// No auth context yet (WS upgrade); handler will inject platform team after WS auth.
next.ServeHTTP(w, r)
return
}
ac := auth.MustFromContext(r.Context())
ac.TeamID = id.PlatformTeamID
ctx := auth.WithAuthContext(r.Context(), ac)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
// requireAdmin validates that the authenticated user is a platform admin.
// Must run after requireJWT (depends on AuthContext being present).
// Re-validates against the DB — the JWT is_admin claim is for UI only;
// the DB is the source of truth for admin access.
// WebSocket upgrade requests without auth context are passed through —
// admin WS handlers verify admin status after upgrade via wsAuthenticateAdmin.
func requireAdmin(queries *db.Queries) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ac, ok := auth.FromContext(r.Context())
if !ok {
if isWebSocketUpgrade(r) {
ctx := r.Context()
ctx = setAdminWSFlag(ctx)
next.ServeHTTP(w, r.WithContext(ctx))
return
}
writeError(w, http.StatusUnauthorized, "unauthorized", "authentication required")
return
}

View File

@ -5,9 +5,9 @@ import (
"net/http"
"strings"
"git.omukk.dev/wrenn/wrenn/internal/auth"
"git.omukk.dev/wrenn/wrenn/internal/db"
"git.omukk.dev/wrenn/wrenn/internal/id"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
// requireAPIKeyOrJWT accepts either X-API-Key header or Authorization: Bearer JWT.
@ -38,9 +38,12 @@ func requireAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler
return
}
// Try JWT bearer token.
// Try JWT bearer token from Authorization header.
tokenStr := ""
if header := r.Header.Get("Authorization"); strings.HasPrefix(header, "Bearer ") {
tokenStr := strings.TrimPrefix(header, "Bearer ")
tokenStr = strings.TrimPrefix(header, "Bearer ")
}
if tokenStr != "" {
claims, err := auth.VerifyJWT(jwtSecret, tokenStr)
if err != nil {
slog.Warn("jwt auth failed", "error", err, "ip", r.RemoteAddr)
@ -59,6 +62,18 @@ func requireAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler
return
}
// Verify user is still active in the database.
user, err := queries.GetUserByID(r.Context(), userID)
if err != nil {
slog.Warn("jwt auth: failed to look up user", "user_id", claims.Subject, "error", err)
writeError(w, http.StatusUnauthorized, "unauthorized", "user not found")
return
}
if user.Status != "active" {
writeError(w, http.StatusForbidden, "account_deactivated", "your account has been deactivated — contact your administrator to regain access")
return
}
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
TeamID: teamID,
UserID: userID,
@ -70,7 +85,15 @@ func requireAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler
return
}
// WebSocket upgrade requests may not carry auth headers (browsers
// cannot set custom headers on WS connections). Pass through —
// the WS handler authenticates via the first message after upgrade.
if isWebSocketUpgrade(r) {
next.ServeHTTP(w, r)
return
}
writeError(w, http.StatusUnauthorized, "unauthorized", "X-API-Key or Authorization: Bearer <token> required")
})
}
}
}

View File

@ -3,8 +3,8 @@ package api
import (
"net/http"
"git.omukk.dev/wrenn/wrenn/internal/auth"
"git.omukk.dev/wrenn/wrenn/internal/id"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
// requireHostToken validates the X-Host-Token header containing a host JWT,

View File

@ -1,25 +1,37 @@
package api
import (
"log/slog"
"net/http"
"strings"
"git.omukk.dev/wrenn/wrenn/internal/auth"
"git.omukk.dev/wrenn/wrenn/internal/id"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
// requireJWT validates the Authorization: Bearer <token> header, verifies the JWT
// signature and expiry, and stamps UserID + TeamID + Email into the request context.
func requireJWT(secret []byte) func(http.Handler) http.Handler {
// requireJWT validates a JWT from the Authorization: Bearer header.
// It also verifies the user is still active in the database.
// WebSocket upgrade requests without an Authorization header are passed through
// — WS handlers authenticate via the first message after upgrade.
func requireJWT(secret []byte, queries *db.Queries) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
header := r.Header.Get("Authorization")
if !strings.HasPrefix(header, "Bearer ") {
var tokenStr string
if header := r.Header.Get("Authorization"); strings.HasPrefix(header, "Bearer ") {
tokenStr = strings.TrimPrefix(header, "Bearer ")
}
if tokenStr == "" {
// WebSocket upgrade requests may not have an Authorization header
// (browsers cannot set custom headers on WS connections). Let them
// through — the handler authenticates via the first WS message.
if isWebSocketUpgrade(r) {
next.ServeHTTP(w, r)
return
}
writeError(w, http.StatusUnauthorized, "unauthorized", "Authorization: Bearer <token> required")
return
}
tokenStr := strings.TrimPrefix(header, "Bearer ")
claims, err := auth.VerifyJWT(secret, tokenStr)
if err != nil {
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid or expired token")
@ -37,6 +49,18 @@ func requireJWT(secret []byte) func(http.Handler) http.Handler {
return
}
// Verify user is still active in the database.
user, err := queries.GetUserByID(r.Context(), userID)
if err != nil {
slog.Warn("jwt auth: failed to look up user", "user_id", claims.Subject, "error", err)
writeError(w, http.StatusUnauthorized, "unauthorized", "user not found")
return
}
if user.Status != "active" {
writeError(w, http.StatusForbidden, "account_deactivated", "your account has been deactivated — contact your administrator to regain access")
return
}
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
TeamID: teamID,
UserID: userID,

File diff suppressed because it is too large Load Diff

View File

@ -9,14 +9,16 @@ import (
"github.com/jackc/pgx/v5/pgxpool"
"github.com/redis/go-redis/v9"
"git.omukk.dev/wrenn/wrenn/internal/audit"
"git.omukk.dev/wrenn/wrenn/internal/auth"
"git.omukk.dev/wrenn/wrenn/internal/auth/oauth"
"git.omukk.dev/wrenn/wrenn/internal/channels"
"git.omukk.dev/wrenn/wrenn/internal/db"
"git.omukk.dev/wrenn/wrenn/internal/lifecycle"
"git.omukk.dev/wrenn/wrenn/internal/scheduler"
"git.omukk.dev/wrenn/wrenn/internal/service"
"git.omukk.dev/wrenn/wrenn/internal/email"
"git.omukk.dev/wrenn/wrenn/pkg/audit"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/auth/oauth"
"git.omukk.dev/wrenn/wrenn/pkg/channels"
"git.omukk.dev/wrenn/wrenn/pkg/cpextension"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
"git.omukk.dev/wrenn/wrenn/pkg/scheduler"
"git.omukk.dev/wrenn/wrenn/pkg/service"
)
//go:embed openapi.yaml
@ -26,9 +28,12 @@ var openapiYAML []byte
type Server struct {
router chi.Router
BuildSvc *service.BuildService
version string
}
// New constructs the chi router and registers all routes.
// Extensions are called after core routes are registered, allowing enterprise
// or third-party code to add routes and middleware.
func New(
queries *db.Queries,
pool *lifecycle.HostClientPool,
@ -41,6 +46,10 @@ func New(
ca *auth.CA,
al *audit.AuditLogger,
channelSvc *channels.Service,
mailer email.Mailer,
extensions []cpextension.Extension,
sctx cpextension.ServerContext,
version string,
) *Server {
r := chi.NewRouter()
r.Use(requestLogger())
@ -51,27 +60,39 @@ func New(
templateSvc := &service.TemplateService{DB: queries}
hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret, Pool: pool, CA: ca}
teamSvc := &service.TeamService{DB: queries, Pool: pgPool, HostPool: pool}
userSvc := &service.UserService{DB: queries, SandboxSvc: sandboxSvc}
auditSvc := &service.AuditService{DB: queries}
statsSvc := &service.StatsService{DB: queries, Pool: pgPool}
buildSvc := &service.BuildService{DB: queries, Redis: rdb, Pool: pool, Scheduler: sched}
sandbox := newSandboxHandler(sandboxSvc, al)
exec := newExecHandler(queries, pool)
execStream := newExecStreamHandler(queries, pool)
execStream := newExecStreamHandler(queries, pool, jwtSecret)
files := newFilesHandler(queries, pool)
filesStream := newFilesStreamHandler(queries, pool)
fsH := newFSHandler(queries, pool)
snapshots := newSnapshotHandler(templateSvc, queries, pool, al)
authH := newAuthHandler(queries, pgPool, jwtSecret)
authH := newAuthHandler(queries, pgPool, jwtSecret, mailer, rdb, oauthRedirectURL)
oauthH := newOAuthHandler(queries, pgPool, jwtSecret, oauthRegistry, oauthRedirectURL)
apiKeys := newAPIKeyHandler(apiKeySvc, al)
hostH := newHostHandler(hostSvc, queries, al)
teamH := newTeamHandler(teamSvc, al)
usersH := newUsersHandler(queries)
teamH := newTeamHandler(teamSvc, al, mailer)
usersH := newUsersHandler(queries, userSvc)
auditH := newAuditHandler(auditSvc)
statsH := newStatsHandler(statsSvc)
metricsH := newSandboxMetricsHandler(queries, pool)
buildH := newBuildHandler(buildSvc, queries, pool)
channelH := newChannelHandler(channelSvc, al)
ptyH := newPtyHandler(queries, pool, jwtSecret)
processH := newProcessHandler(queries, pool, jwtSecret)
adminCapsules := newAdminCapsuleHandler(sandboxSvc, queries, pool, al)
meH := newMeHandler(queries, pgPool, rdb, jwtSecret, mailer, oauthRegistry, oauthRedirectURL, teamSvc)
// Health check.
r.Get("/health", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
fmt.Fprintf(w, `{"status":"ok","version":%q}`, version)
})
// OpenAPI spec and docs.
r.Get("/openapi.yaml", serveOpenAPI)
@ -80,15 +101,31 @@ func New(
// Unauthenticated auth endpoints.
r.Post("/v1/auth/signup", authH.Signup)
r.Post("/v1/auth/login", authH.Login)
r.Post("/v1/auth/activate", authH.Activate)
r.Get("/auth/oauth/{provider}", oauthH.Redirect)
r.Get("/auth/oauth/{provider}/callback", oauthH.Callback)
// Unauthenticated: password reset request and confirmation.
r.Post("/v1/me/password/reset", meH.RequestPasswordReset)
r.Post("/v1/me/password/reset/confirm", meH.ConfirmPasswordReset)
// JWT-authenticated: self-service account management.
r.Route("/v1/me", func(r chi.Router) {
r.Use(requireJWT(jwtSecret, queries))
r.Get("/", meH.GetMe)
r.Patch("/", meH.UpdateName)
r.Post("/password", meH.ChangePassword)
r.Get("/providers/{provider}/connect", meH.ConnectProvider)
r.Delete("/providers/{provider}", meH.DisconnectProvider)
r.Delete("/", meH.DeleteAccount)
})
// JWT-authenticated: switch active team.
r.With(requireJWT(jwtSecret)).Post("/v1/auth/switch-team", authH.SwitchTeam)
r.With(requireJWT(jwtSecret, queries)).Post("/v1/auth/switch-team", authH.SwitchTeam)
// JWT-authenticated: API key management.
r.Route("/v1/api-keys", func(r chi.Router) {
r.Use(requireJWT(jwtSecret))
r.Use(requireJWT(jwtSecret, queries))
r.Post("/", apiKeys.Create)
r.Get("/", apiKeys.List)
r.Delete("/{id}", apiKeys.Delete)
@ -96,7 +133,7 @@ func New(
// JWT-authenticated: team management.
r.Route("/v1/teams", func(r chi.Router) {
r.Use(requireJWT(jwtSecret))
r.Use(requireJWT(jwtSecret, queries))
r.Get("/", teamH.List)
r.Post("/", teamH.Create)
r.Route("/{id}", func(r chi.Router) {
@ -112,10 +149,12 @@ func New(
})
// JWT-authenticated: user search (for add-member UI).
r.With(requireJWT(jwtSecret)).Get("/v1/users/search", usersH.Search)
r.With(requireJWT(jwtSecret, queries)).Get("/v1/users/search", usersH.Search)
// Sandbox lifecycle: accepts API key or JWT bearer token.
r.Route("/v1/sandboxes", func(r chi.Router) {
// Capsule lifecycle: accepts API key or JWT bearer token.
// WebSocket upgrade requests without auth headers are passed through by
// requireAPIKeyOrJWT — the WS handlers authenticate via first message.
r.Route("/v1/capsules", func(r chi.Router) {
r.Use(requireAPIKeyOrJWT(queries, jwtSecret))
r.Post("/", sandbox.Create)
r.Get("/", sandbox.List)
@ -133,7 +172,14 @@ func New(
r.Post("/files/read", files.Download)
r.Post("/files/stream/write", filesStream.StreamUpload)
r.Post("/files/stream/read", filesStream.StreamDownload)
r.Post("/files/list", fsH.ListDir)
r.Post("/files/mkdir", fsH.MakeDir)
r.Post("/files/remove", fsH.Remove)
r.Get("/metrics", metricsH.GetMetrics)
r.Get("/pty", ptyH.PtySession)
r.Get("/processes", processH.ListProcesses)
r.Delete("/processes/{selector}", processH.KillProcess)
r.Get("/processes/{selector}/stream", processH.ConnectProcess)
})
})
@ -158,7 +204,7 @@ func New(
// JWT-authenticated: host CRUD and tags.
r.Group(func(r chi.Router) {
r.Use(requireJWT(jwtSecret))
r.Use(requireJWT(jwtSecret, queries))
r.Post("/", hostH.Create)
r.Get("/", hostH.List)
r.Route("/{id}", func(r chi.Router) {
@ -175,7 +221,7 @@ func New(
// JWT-authenticated: notification channels.
r.Route("/v1/channels", func(r chi.Router) {
r.Use(requireJWT(jwtSecret))
r.Use(requireJWT(jwtSecret, queries))
r.Post("/", channelH.Create)
r.Get("/", channelH.List)
r.Post("/test", channelH.Test)
@ -188,22 +234,51 @@ func New(
})
// JWT-authenticated: audit log.
r.With(requireJWT(jwtSecret)).Get("/v1/audit-logs", auditH.List)
r.With(requireJWT(jwtSecret, queries)).Get("/v1/audit-logs", auditH.List)
// Platform admin routes — require JWT + DB-validated admin status.
r.Route("/v1/admin", func(r chi.Router) {
r.Use(requireJWT(jwtSecret))
r.Use(requireJWT(jwtSecret, queries))
r.Use(requireAdmin(queries))
r.Get("/teams", teamH.AdminListTeams)
r.Put("/teams/{id}/byoc", teamH.SetBYOC)
r.Delete("/teams/{id}", teamH.AdminDeleteTeam)
r.Get("/users", usersH.AdminListUsers)
r.Put("/users/{id}/active", usersH.SetUserActive)
r.Get("/templates", buildH.ListTemplates)
r.Delete("/templates/{name}", buildH.DeleteTemplate)
r.Post("/builds", buildH.Create)
r.Get("/builds", buildH.List)
r.Get("/builds/{id}", buildH.Get)
r.Post("/builds/{id}/cancel", buildH.Cancel)
r.Post("/capsules", adminCapsules.Create)
r.Get("/capsules", adminCapsules.List)
r.Route("/capsules/{id}", func(r chi.Router) {
r.Use(injectPlatformTeam())
r.Get("/", adminCapsules.Get)
r.Delete("/", adminCapsules.Destroy)
r.Post("/snapshot", adminCapsules.Snapshot)
r.Post("/exec", exec.Exec)
r.Get("/exec/stream", execStream.ExecStream)
r.Post("/files/write", files.Upload)
r.Post("/files/read", files.Download)
r.Post("/files/list", fsH.ListDir)
r.Post("/files/mkdir", fsH.MakeDir)
r.Post("/files/remove", fsH.Remove)
r.Get("/metrics", metricsH.GetMetrics)
r.Get("/pty", ptyH.PtySession)
r.Get("/processes", processH.ListProcesses)
r.Delete("/processes/{selector}", processH.KillProcess)
r.Get("/processes/{selector}/stream", processH.ConnectProcess)
})
})
return &Server{router: r, BuildSvc: buildSvc}
// Let extensions register their routes after all core routes.
for _, ext := range extensions {
ext.RegisterRoutes(r, sctx)
}
return &Server{router: r, BuildSvc: buildSvc, version: version}
}
// Handler returns the HTTP handler.
@ -211,6 +286,11 @@ func (s *Server) Handler() http.Handler {
return s.router
}
// Router returns the underlying chi.Router for direct access.
func (s *Server) Router() chi.Router {
return s.router
}
func serveOpenAPI(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/yaml")
_, _ = w.Write(openapiYAML)
@ -223,7 +303,7 @@ func serveDocs(w http.ResponseWriter, r *http.Request) {
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Wrenn Sandbox API</title>
<title>Wrenn API</title>
<link rel="stylesheet" href="https://unpkg.com/swagger-ui-dist@5.18.2/swagger-ui.css" integrity="sha384-rcbEi6xgdPk0iWkAQzT2F3FeBJXdG+ydrawGlfHAFIZG7wU6aKbQaRewysYpmrlW" crossorigin="anonymous">
<style>
body { margin: 0; background: #fafafa; }

View File

@ -1,569 +0,0 @@
package audit
import (
"context"
"encoding/json"
"log/slog"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/internal/auth"
"git.omukk.dev/wrenn/wrenn/internal/db"
"git.omukk.dev/wrenn/wrenn/internal/events"
"git.omukk.dev/wrenn/wrenn/internal/id"
)
// AuditLogger writes audit log entries for user-initiated and system events.
// All methods are fire-and-forget: failures are logged via slog and never
// propagated to the caller.
type AuditLogger struct {
db *db.Queries
pub events.EventPublisher // optional — nil disables event publishing
}
// New constructs an AuditLogger without event publishing.
func New(queries *db.Queries) *AuditLogger {
return &AuditLogger{db: queries}
}
// NewWithPublisher constructs an AuditLogger that also publishes channel events.
func NewWithPublisher(queries *db.Queries, pub events.EventPublisher) *AuditLogger {
return &AuditLogger{db: queries, pub: pub}
}
// publish sends an event to the notification stream if a publisher is configured.
func (l *AuditLogger) publish(ctx context.Context, e events.Event) {
if l.pub != nil {
l.pub.Publish(ctx, e)
}
}
// actorToEvent converts auth context fields to an events.Actor.
func actorToEvent(ac auth.AuthContext) events.Actor {
at, aid, aname := actorFields(ac)
return events.Actor{Type: events.ActorKind(at), ID: aid, Name: aname}
}
// systemActor returns an events.Actor for system-initiated events.
func systemActor() events.Actor {
return events.Actor{Type: events.ActorSystem}
}
// actorFields extracts actor_type, actor_id, and actor_name from an AuthContext.
// actor_id is stored as a prefixed string in the TEXT column.
func actorFields(ac auth.AuthContext) (actorType, actorID, actorName string) {
if ac.UserID.Valid {
return "user", id.FormatUserID(ac.UserID), ac.Name
}
if ac.APIKeyID.Valid {
return "api_key", id.FormatAPIKeyID(ac.APIKeyID), ac.APIKeyName
}
return "system", "", ""
}
func (l *AuditLogger) write(ctx context.Context, p db.InsertAuditLogParams) {
if err := l.db.InsertAuditLog(ctx, p); err != nil {
slog.Warn("audit: failed to write log entry",
"action", p.Action,
"resource_type", p.ResourceType,
"error", err,
)
}
}
func marshalMeta(meta map[string]any) []byte {
if len(meta) == 0 {
return []byte("{}")
}
b, err := json.Marshal(meta)
if err != nil {
return []byte("{}")
}
return b
}
// optText returns a valid pgtype.Text if s is non-empty, otherwise an invalid (NULL) one.
func optText(s string) pgtype.Text {
if s == "" {
return pgtype.Text{}
}
return pgtype.Text{String: s, Valid: true}
}
// --- Sandbox events (scope: team) ---
func (l *AuditLogger) LogSandboxCreate(ctx context.Context, ac auth.AuthContext, sandboxID pgtype.UUID, template string) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "sandbox",
ResourceID: optText(id.FormatSandboxID(sandboxID)),
Action: "create",
Scope: "team",
Status: "success",
Metadata: marshalMeta(map[string]any{"template": template}),
})
l.publish(ctx, events.Event{
Event: events.CapsuleCreated,
Timestamp: events.Now(),
TeamID: id.FormatTeamID(ac.TeamID),
Actor: actorToEvent(ac),
Resource: events.Resource{ID: id.FormatSandboxID(sandboxID), Type: "sandbox"},
})
}
func (l *AuditLogger) LogSandboxPause(ctx context.Context, ac auth.AuthContext, sandboxID pgtype.UUID) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "sandbox",
ResourceID: optText(id.FormatSandboxID(sandboxID)),
Action: "pause",
Scope: "team",
Status: "success",
Metadata: []byte("{}"),
})
l.publish(ctx, events.Event{
Event: events.CapsulePaused,
Timestamp: events.Now(),
TeamID: id.FormatTeamID(ac.TeamID),
Actor: actorToEvent(ac),
Resource: events.Resource{ID: id.FormatSandboxID(sandboxID), Type: "sandbox"},
})
}
// LogSandboxAutoPause records a system-initiated auto-pause (TTL or host reconciler).
func (l *AuditLogger) LogSandboxAutoPause(ctx context.Context, teamID, sandboxID pgtype.UUID) {
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: teamID,
ActorType: "system",
ActorID: pgtype.Text{},
ActorName: "",
ResourceType: "sandbox",
ResourceID: optText(id.FormatSandboxID(sandboxID)),
Action: "pause",
Scope: "team",
Status: "info",
Metadata: []byte("{}"),
})
l.publish(ctx, events.Event{
Event: events.CapsulePaused,
Timestamp: events.Now(),
TeamID: id.FormatTeamID(teamID),
Actor: systemActor(),
Resource: events.Resource{ID: id.FormatSandboxID(sandboxID), Type: "sandbox"},
})
}
func (l *AuditLogger) LogSandboxResume(ctx context.Context, ac auth.AuthContext, sandboxID pgtype.UUID) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "sandbox",
ResourceID: optText(id.FormatSandboxID(sandboxID)),
Action: "resume",
Scope: "team",
Status: "success",
Metadata: []byte("{}"),
})
l.publish(ctx, events.Event{
Event: events.CapsuleRunning,
Timestamp: events.Now(),
TeamID: id.FormatTeamID(ac.TeamID),
Actor: actorToEvent(ac),
Resource: events.Resource{ID: id.FormatSandboxID(sandboxID), Type: "sandbox"},
})
}
func (l *AuditLogger) LogSandboxDestroy(ctx context.Context, ac auth.AuthContext, sandboxID pgtype.UUID) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "sandbox",
ResourceID: optText(id.FormatSandboxID(sandboxID)),
Action: "destroy",
Scope: "team",
Status: "warning",
Metadata: []byte("{}"),
})
l.publish(ctx, events.Event{
Event: events.CapsuleDestroyed,
Timestamp: events.Now(),
TeamID: id.FormatTeamID(ac.TeamID),
Actor: actorToEvent(ac),
Resource: events.Resource{ID: id.FormatSandboxID(sandboxID), Type: "sandbox"},
})
}
// --- Snapshot events (scope: team) ---
func (l *AuditLogger) LogSnapshotCreate(ctx context.Context, ac auth.AuthContext, name string) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "snapshot",
ResourceID: optText(name),
Action: "create",
Scope: "team",
Status: "success",
Metadata: []byte("{}"),
})
l.publish(ctx, events.Event{
Event: events.SnapshotCreated,
Timestamp: events.Now(),
TeamID: id.FormatTeamID(ac.TeamID),
Actor: actorToEvent(ac),
Resource: events.Resource{ID: name, Type: "snapshot"},
})
}
func (l *AuditLogger) LogSnapshotDelete(ctx context.Context, ac auth.AuthContext, name string) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "snapshot",
ResourceID: optText(name),
Action: "delete",
Scope: "team",
Status: "warning",
Metadata: []byte("{}"),
})
l.publish(ctx, events.Event{
Event: events.SnapshotDeleted,
Timestamp: events.Now(),
TeamID: id.FormatTeamID(ac.TeamID),
Actor: actorToEvent(ac),
Resource: events.Resource{ID: name, Type: "snapshot"},
})
}
// --- Team events (scope: team) ---
func (l *AuditLogger) LogTeamRename(ctx context.Context, ac auth.AuthContext, teamID pgtype.UUID, oldName, newName string) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "team",
ResourceID: optText(id.FormatTeamID(teamID)),
Action: "rename",
Scope: "team",
Status: "info",
Metadata: marshalMeta(map[string]any{"old_name": oldName, "new_name": newName}),
})
}
// --- Channel events (scope: team) ---
func (l *AuditLogger) LogChannelCreate(ctx context.Context, ac auth.AuthContext, channelID pgtype.UUID, name, provider string) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "channel",
ResourceID: optText(id.FormatChannelID(channelID)),
Action: "create",
Scope: "team",
Status: "success",
Metadata: marshalMeta(map[string]any{"name": name, "provider": provider}),
})
}
func (l *AuditLogger) LogChannelUpdate(ctx context.Context, ac auth.AuthContext, channelID pgtype.UUID) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "channel",
ResourceID: optText(id.FormatChannelID(channelID)),
Action: "update",
Scope: "team",
Status: "info",
Metadata: []byte("{}"),
})
}
func (l *AuditLogger) LogChannelRotateConfig(ctx context.Context, ac auth.AuthContext, channelID pgtype.UUID) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "channel",
ResourceID: optText(id.FormatChannelID(channelID)),
Action: "rotate_config",
Scope: "team",
Status: "info",
Metadata: []byte("{}"),
})
}
func (l *AuditLogger) LogChannelDelete(ctx context.Context, ac auth.AuthContext, channelID pgtype.UUID) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "channel",
ResourceID: optText(id.FormatChannelID(channelID)),
Action: "delete",
Scope: "team",
Status: "warning",
Metadata: []byte("{}"),
})
}
// --- API key events (scope: team) ---
func (l *AuditLogger) LogAPIKeyCreate(ctx context.Context, ac auth.AuthContext, keyID pgtype.UUID, keyName string) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "api_key",
ResourceID: optText(id.FormatAPIKeyID(keyID)),
Action: "create",
Scope: "team",
Status: "success",
Metadata: marshalMeta(map[string]any{"name": keyName}),
})
}
func (l *AuditLogger) LogAPIKeyRevoke(ctx context.Context, ac auth.AuthContext, keyID pgtype.UUID) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "api_key",
ResourceID: optText(id.FormatAPIKeyID(keyID)),
Action: "revoke",
Scope: "team",
Status: "warning",
Metadata: []byte("{}"),
})
}
// --- Member events (scope: admin) ---
func (l *AuditLogger) LogMemberAdd(ctx context.Context, ac auth.AuthContext, targetUserID pgtype.UUID, targetEmail, role string) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "member",
ResourceID: optText(id.FormatUserID(targetUserID)),
Action: "add",
Scope: "admin",
Status: "success",
Metadata: marshalMeta(map[string]any{"email": targetEmail, "role": role}),
})
}
func (l *AuditLogger) LogMemberRemove(ctx context.Context, ac auth.AuthContext, targetUserID pgtype.UUID) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "member",
ResourceID: optText(id.FormatUserID(targetUserID)),
Action: "remove",
Scope: "admin",
Status: "warning",
Metadata: []byte("{}"),
})
}
func (l *AuditLogger) LogMemberLeave(ctx context.Context, ac auth.AuthContext) {
actorType, actorID, actorName := actorFields(ac)
resourceID := ""
if ac.UserID.Valid {
resourceID = id.FormatUserID(ac.UserID)
}
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "member",
ResourceID: optText(resourceID),
Action: "leave",
Scope: "admin",
Status: "info",
Metadata: []byte("{}"),
})
}
func (l *AuditLogger) LogMemberRoleUpdate(ctx context.Context, ac auth.AuthContext, targetUserID pgtype.UUID, newRole string) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "member",
ResourceID: optText(id.FormatUserID(targetUserID)),
Action: "role_update",
Scope: "admin",
Status: "info",
Metadata: marshalMeta(map[string]any{"new_role": newRole}),
})
}
// --- Host events (scope: admin) ---
func (l *AuditLogger) LogHostCreate(ctx context.Context, ac auth.AuthContext, hostID, teamID pgtype.UUID) {
actorType, actorID, actorName := actorFields(ac)
// For shared hosts with no owning team, use the caller's team.
logTeamID := teamID
if !logTeamID.Valid {
logTeamID = ac.TeamID
}
if !logTeamID.Valid {
return
}
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: logTeamID,
ActorType: actorType,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "host",
ResourceID: optText(id.FormatHostID(hostID)),
Action: "create",
Scope: "admin",
Status: "success",
Metadata: []byte("{}"),
})
}
func (l *AuditLogger) LogHostDelete(ctx context.Context, ac auth.AuthContext, hostID, teamID pgtype.UUID) {
actorType, actorID, actorName := actorFields(ac)
logTeamID := teamID
if !logTeamID.Valid {
logTeamID = ac.TeamID
}
if !logTeamID.Valid {
return
}
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: logTeamID,
ActorType: actorType,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "host",
ResourceID: optText(id.FormatHostID(hostID)),
Action: "delete",
Scope: "admin",
Status: "warning",
Metadata: []byte("{}"),
})
}
// LogHostMarkedDown records a system-initiated host status transition to unreachable.
// Scoped to "team" so BYOC team members can see when their hosts go down.
func (l *AuditLogger) LogHostMarkedDown(ctx context.Context, teamID, hostID pgtype.UUID) {
if !teamID.Valid {
return
}
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: teamID,
ActorType: "system",
ActorID: pgtype.Text{},
ActorName: "",
ResourceType: "host",
ResourceID: optText(id.FormatHostID(hostID)),
Action: "marked_down",
Scope: "team",
Status: "error",
Metadata: []byte("{}"),
})
l.publish(ctx, events.Event{
Event: events.HostDown,
Timestamp: events.Now(),
TeamID: id.FormatTeamID(teamID),
Actor: systemActor(),
Resource: events.Resource{ID: id.FormatHostID(hostID), Type: "host"},
})
}
// LogHostMarkedUp records a system-initiated host status transition back to online.
// Scoped to "team" so BYOC team members can see when their hosts recover.
func (l *AuditLogger) LogHostMarkedUp(ctx context.Context, teamID, hostID pgtype.UUID) {
if !teamID.Valid {
return
}
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: teamID,
ActorType: "system",
ActorID: pgtype.Text{},
ActorName: "",
ResourceType: "host",
ResourceID: optText(id.FormatHostID(hostID)),
Action: "marked_up",
Scope: "team",
Status: "success",
Metadata: []byte("{}"),
})
l.publish(ctx, events.Event{
Event: events.HostUp,
Timestamp: events.Now(),
TeamID: id.FormatTeamID(teamID),
Actor: systemActor(),
Resource: events.Resource{ID: id.FormatHostID(hostID), Type: "host"},
})
}

View File

@ -1,35 +0,0 @@
package auth
import (
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt"
)
// GenerateAPIKey returns a plaintext key in the form "wrn_" + 32 random hex chars
// and its SHA-256 hash. The caller must show the plaintext to the user exactly once;
// only the hash is stored.
func GenerateAPIKey() (plaintext, hash string, err error) {
b := make([]byte, 16) // 16 bytes → 32 hex chars
if _, err = rand.Read(b); err != nil {
return "", "", fmt.Errorf("generate api key: %w", err)
}
plaintext = "wrn_" + hex.EncodeToString(b)
hash = HashAPIKey(plaintext)
return plaintext, hash, nil
}
// HashAPIKey returns the hex-encoded SHA-256 hash of a plaintext API key.
func HashAPIKey(plaintext string) string {
sum := sha256.Sum256([]byte(plaintext))
return hex.EncodeToString(sum[:])
}
// APIKeyPrefix returns the first 8 characters of a plaintext API key (e.g. "wrn_ab12").
func APIKeyPrefix(plaintext string) string {
if len(plaintext) > 10 {
return plaintext[:10]
}
return plaintext
}

View File

@ -1,251 +0,0 @@
package auth
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"net"
"sync/atomic"
"time"
)
// CPCertRenewInterval is how often the control plane should renew its client
// certificate. It is set to half the cert TTL so there is always a wide safety
// margin before expiry.
const CPCertRenewInterval = cpCertTTL / 2
const (
hostCertTTL = 7 * 24 * time.Hour
cpCertTTL = 24 * time.Hour
)
// CA holds a parsed certificate authority ready to issue leaf certificates.
type CA struct {
Cert *x509.Certificate
Key *ecdsa.PrivateKey
PEM string // PEM-encoded certificate for embedding in register/refresh responses
}
// ParseCA parses PEM-encoded CA certificate and private key strings.
// The cert and key are expected to be ECDSA P-256.
func ParseCA(certPEM, keyPEM string) (*CA, error) {
certBlock, _ := pem.Decode([]byte(certPEM))
if certBlock == nil {
return nil, fmt.Errorf("failed to decode CA certificate PEM")
}
cert, err := x509.ParseCertificate(certBlock.Bytes)
if err != nil {
return nil, fmt.Errorf("parse CA certificate: %w", err)
}
keyBlock, _ := pem.Decode([]byte(keyPEM))
if keyBlock == nil {
return nil, fmt.Errorf("failed to decode CA key PEM")
}
keyIface, err := x509.ParseECPrivateKey(keyBlock.Bytes)
if err != nil {
return nil, fmt.Errorf("parse CA private key: %w", err)
}
return &CA{Cert: cert, Key: keyIface, PEM: certPEM}, nil
}
// HostCert holds all material returned when issuing a leaf cert for a host agent.
type HostCert struct {
CertPEM string
KeyPEM string
Fingerprint string // hex-encoded SHA-256 of DER bytes, stored in hosts.cert_fingerprint
ExpiresAt time.Time // stored in hosts.cert_expires_at
TLSCert tls.Certificate
}
// IssueHostCert generates an ECDSA P-256 key pair and issues a 7-day server
// certificate for the host agent. hostID becomes the common name; the host's
// IP address (parsed from hostAddr) is added as an IP SAN so Go's TLS
// stack can verify the connection without disabling hostname checking.
func IssueHostCert(ca *CA, hostID, hostAddr string) (HostCert, error) {
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return HostCert{}, fmt.Errorf("generate host key: %w", err)
}
serial, err := randomSerial()
if err != nil {
return HostCert{}, err
}
now := time.Now()
expires := now.Add(hostCertTTL)
tmpl := &x509.Certificate{
SerialNumber: serial,
Subject: pkix.Name{CommonName: hostID},
NotBefore: now.Add(-time.Minute), // small clock-skew tolerance
NotAfter: expires,
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
}
// Extract IP from "ip:port" address; fall back to DNS SAN if not parseable.
host, _, err := net.SplitHostPort(hostAddr)
if err != nil {
host = hostAddr
}
if ip := net.ParseIP(host); ip != nil {
tmpl.IPAddresses = []net.IP{ip}
} else {
tmpl.DNSNames = []string{host}
}
derBytes, err := x509.CreateCertificate(rand.Reader, tmpl, ca.Cert, &key.PublicKey, ca.Key)
if err != nil {
return HostCert{}, fmt.Errorf("create host certificate: %w", err)
}
certPEM := string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}))
keyDER, err := x509.MarshalECPrivateKey(key)
if err != nil {
return HostCert{}, fmt.Errorf("marshal host key: %w", err)
}
keyPEM := string(pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}))
tlsCert, err := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM))
if err != nil {
return HostCert{}, fmt.Errorf("build TLS certificate: %w", err)
}
fp := fmt.Sprintf("%x", sha256.Sum256(derBytes))
return HostCert{
CertPEM: certPEM,
KeyPEM: keyPEM,
Fingerprint: fp,
ExpiresAt: expires,
TLSCert: tlsCert,
}, nil
}
// IssueCPClientCert generates a short-lived (24h) ECDSA client certificate for
// the control plane to present during mTLS handshakes with host agents.
// Called once at CP startup; the result is embedded into the shared HTTP client.
func IssueCPClientCert(ca *CA) (tls.Certificate, error) {
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return tls.Certificate{}, fmt.Errorf("generate CP client key: %w", err)
}
serial, err := randomSerial()
if err != nil {
return tls.Certificate{}, err
}
now := time.Now()
tmpl := &x509.Certificate{
SerialNumber: serial,
Subject: pkix.Name{CommonName: "wrenn-cp"},
NotBefore: now.Add(-time.Minute),
NotAfter: now.Add(cpCertTTL),
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
}
derBytes, err := x509.CreateCertificate(rand.Reader, tmpl, ca.Cert, &key.PublicKey, ca.Key)
if err != nil {
return tls.Certificate{}, fmt.Errorf("create CP client certificate: %w", err)
}
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
keyDER, err := x509.MarshalECPrivateKey(key)
if err != nil {
return tls.Certificate{}, fmt.Errorf("marshal CP client key: %w", err)
}
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
return tls.X509KeyPair(certPEM, keyPEM)
}
// AgentTLSConfigFromPEM returns a tls.Config for the host agent using the
// PEM-encoded CA certificate. This is used on the agent side where only the
// CA certificate (not the private key) is available.
func AgentTLSConfigFromPEM(caCertPEM string, getCert func(*tls.ClientHelloInfo) (*tls.Certificate, error)) *tls.Config {
pool := x509.NewCertPool()
if !pool.AppendCertsFromPEM([]byte(caCertPEM)) {
return nil
}
return &tls.Config{
ClientAuth: tls.RequireAndVerifyClientCert,
ClientCAs: pool,
GetCertificate: getCert,
MinVersion: tls.VersionTLS13,
}
}
// CPCertStore provides lock-free read/write access to the control plane's
// current client TLS certificate. It is used with tls.Config.GetClientCertificate
// to enable hot-swap without restarting the HTTP client.
//
// The zero value is not usable; use NewCPCertStore to create one.
type CPCertStore struct {
ptr atomic.Pointer[tls.Certificate]
ca *CA
}
// NewCPCertStore issues an initial CP client certificate from ca and returns a
// store that can renew it in place. Returns an error if the initial issuance fails.
func NewCPCertStore(ca *CA) (*CPCertStore, error) {
s := &CPCertStore{ca: ca}
if err := s.Refresh(); err != nil {
return nil, err
}
return s, nil
}
// Refresh issues a fresh CP client certificate and atomically stores it.
// If issuance fails the existing cert is unchanged.
func (s *CPCertStore) Refresh() error {
cert, err := IssueCPClientCert(s.ca)
if err != nil {
return fmt.Errorf("renew CP client certificate: %w", err)
}
s.ptr.Store(&cert)
return nil
}
// GetClientCertificate satisfies tls.Config.GetClientCertificate. It is called
// per-handshake and always returns the most recently stored certificate.
func (s *CPCertStore) GetClientCertificate(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
cert := s.ptr.Load()
if cert == nil {
return nil, fmt.Errorf("no CP client certificate available")
}
return cert, nil
}
// CPClientTLSConfig returns a tls.Config for the CP's outbound HTTP client.
// It uses certStore.GetClientCertificate so the certificate can be renewed
// without replacing the config or transport.
func CPClientTLSConfig(ca *CA, certStore *CPCertStore) *tls.Config {
pool := x509.NewCertPool()
pool.AddCert(ca.Cert)
return &tls.Config{
RootCAs: pool,
GetClientCertificate: certStore.GetClientCertificate,
MinVersion: tls.VersionTLS13,
}
}
// randomSerial returns a random 128-bit certificate serial number.
func randomSerial() (*big.Int, error) {
serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
if err != nil {
return nil, fmt.Errorf("generate serial number: %w", err)
}
return serial, nil
}

View File

@ -1,72 +0,0 @@
package auth
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
type contextKey int
const authCtxKey contextKey = 0
// AuthContext is stamped into request context by auth middleware.
type AuthContext struct {
TeamID pgtype.UUID
UserID pgtype.UUID // zero value (Valid=false) when authenticated via API key
Email string // empty when authenticated via API key
Name string // empty when authenticated via API key
Role string // owner, admin, or member; empty when authenticated via API key
IsAdmin bool // platform-level admin; always false when authenticated via API key
APIKeyID pgtype.UUID // populated when authenticated via API key; zero value for JWT auth
APIKeyName string // display name of the key, snapshotted at auth time; empty for JWT auth
}
// WithAuthContext returns a new context with the given AuthContext.
func WithAuthContext(ctx context.Context, a AuthContext) context.Context {
return context.WithValue(ctx, authCtxKey, a)
}
// FromContext retrieves the AuthContext. Returns zero value and false if absent.
func FromContext(ctx context.Context) (AuthContext, bool) {
a, ok := ctx.Value(authCtxKey).(AuthContext)
return a, ok
}
// MustFromContext retrieves the AuthContext. Panics if absent — only call
// inside handlers behind auth middleware.
func MustFromContext(ctx context.Context) AuthContext {
a, ok := FromContext(ctx)
if !ok {
panic("auth: MustFromContext called on unauthenticated request")
}
return a
}
const hostCtxKey contextKey = 1
// HostContext is stamped into request context by host token middleware.
type HostContext struct {
HostID pgtype.UUID
}
// WithHostContext returns a new context with the given HostContext.
func WithHostContext(ctx context.Context, h HostContext) context.Context {
return context.WithValue(ctx, hostCtxKey, h)
}
// HostFromContext retrieves the HostContext. Returns zero value and false if absent.
func HostFromContext(ctx context.Context) (HostContext, bool) {
h, ok := ctx.Value(hostCtxKey).(HostContext)
return h, ok
}
// MustHostFromContext retrieves the HostContext. Panics if absent — only call
// inside handlers behind host token middleware.
func MustHostFromContext(ctx context.Context) HostContext {
h, ok := HostFromContext(ctx)
if !ok {
panic("auth: MustHostFromContext called on unauthenticated request")
}
return h
}

View File

@ -1,113 +0,0 @@
package auth
import (
"fmt"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/internal/id"
)
const jwtExpiry = 6 * time.Hour
const hostJWTExpiry = 7 * 24 * time.Hour // 7 days; host refreshes via refresh token
const HostRefreshTokenExpiry = 60 * 24 * time.Hour // 60 days; exported for service layer
// Claims are the JWT payload for user tokens.
type Claims struct {
Type string `json:"typ,omitempty"` // empty for user tokens; used to reject host tokens
TeamID string `json:"team_id"`
Role string `json:"role"` // owner, admin, or member within TeamID
Email string `json:"email"`
Name string `json:"name"`
IsAdmin bool `json:"is_admin,omitempty"` // platform-level admin flag
jwt.RegisteredClaims
}
// SignJWT signs a new 6-hour JWT for the given user.
func SignJWT(secret []byte, userID, teamID pgtype.UUID, email, name, role string, isAdmin bool) (string, error) {
now := time.Now()
claims := Claims{
TeamID: id.FormatTeamID(teamID),
Role: role,
Email: email,
Name: name,
IsAdmin: isAdmin,
RegisteredClaims: jwt.RegisteredClaims{
Subject: id.FormatUserID(userID),
IssuedAt: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(now.Add(jwtExpiry)),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString(secret)
}
// VerifyJWT parses and validates a user JWT, returning the claims on success.
// Rejects host JWTs (which carry a "typ" claim) to prevent cross-token confusion.
func VerifyJWT(secret []byte, tokenStr string) (Claims, error) {
token, err := jwt.ParseWithClaims(tokenStr, &Claims{}, func(t *jwt.Token) (any, error) {
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
}
return secret, nil
})
if err != nil {
return Claims{}, fmt.Errorf("invalid token: %w", err)
}
c, ok := token.Claims.(*Claims)
if !ok || !token.Valid {
return Claims{}, fmt.Errorf("invalid token claims")
}
if c.Type == "host" {
return Claims{}, fmt.Errorf("invalid token: host token cannot be used as user token")
}
return *c, nil
}
// HostClaims are the JWT payload for host agent tokens.
type HostClaims struct {
Type string `json:"typ"` // always "host"
HostID string `json:"host_id"`
jwt.RegisteredClaims
}
// SignHostJWT signs a long-lived (7-day) JWT for a registered host agent.
func SignHostJWT(secret []byte, hostID pgtype.UUID) (string, error) {
formatted := id.FormatHostID(hostID)
now := time.Now()
claims := HostClaims{
Type: "host",
HostID: formatted,
RegisteredClaims: jwt.RegisteredClaims{
Subject: formatted,
IssuedAt: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(now.Add(hostJWTExpiry)),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString(secret)
}
// VerifyHostJWT parses and validates a host JWT, returning the claims on success.
// It rejects user JWTs by checking the "typ" claim.
func VerifyHostJWT(secret []byte, tokenStr string) (HostClaims, error) {
token, err := jwt.ParseWithClaims(tokenStr, &HostClaims{}, func(t *jwt.Token) (any, error) {
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
}
return secret, nil
})
if err != nil {
return HostClaims{}, fmt.Errorf("invalid token: %w", err)
}
c, ok := token.Claims.(*HostClaims)
if !ok || !token.Valid {
return HostClaims{}, fmt.Errorf("invalid token claims")
}
if c.Type != "host" {
return HostClaims{}, fmt.Errorf("invalid token type: expected host")
}
return *c, nil
}

View File

@ -1,127 +0,0 @@
package oauth
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strconv"
"golang.org/x/oauth2"
"golang.org/x/oauth2/endpoints"
)
// GitHubProvider implements Provider for GitHub OAuth.
type GitHubProvider struct {
cfg *oauth2.Config
}
// NewGitHubProvider creates a GitHub OAuth provider.
func NewGitHubProvider(clientID, clientSecret, callbackURL string) *GitHubProvider {
return &GitHubProvider{
cfg: &oauth2.Config{
ClientID: clientID,
ClientSecret: clientSecret,
Endpoint: endpoints.GitHub,
Scopes: []string{"user:email"},
RedirectURL: callbackURL,
},
}
}
func (p *GitHubProvider) Name() string { return "github" }
func (p *GitHubProvider) AuthCodeURL(state string) string {
return p.cfg.AuthCodeURL(state, oauth2.AccessTypeOnline)
}
func (p *GitHubProvider) Exchange(ctx context.Context, code string) (UserProfile, error) {
token, err := p.cfg.Exchange(ctx, code)
if err != nil {
return UserProfile{}, fmt.Errorf("exchange code: %w", err)
}
client := p.cfg.Client(ctx, token)
profile, err := fetchGitHubUser(client)
if err != nil {
return UserProfile{}, err
}
// GitHub may not include email if the user's email is private.
if profile.Email == "" {
email, err := fetchGitHubPrimaryEmail(client)
if err != nil {
return UserProfile{}, err
}
profile.Email = email
}
return profile, nil
}
type githubUser struct {
ID int64 `json:"id"`
Login string `json:"login"`
Email string `json:"email"`
Name string `json:"name"`
}
func fetchGitHubUser(client *http.Client) (UserProfile, error) {
resp, err := client.Get("https://api.github.com/user")
if err != nil {
return UserProfile{}, fmt.Errorf("fetch github user: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return UserProfile{}, fmt.Errorf("github /user returned %d", resp.StatusCode)
}
var u githubUser
if err := json.NewDecoder(resp.Body).Decode(&u); err != nil {
return UserProfile{}, fmt.Errorf("decode github user: %w", err)
}
name := u.Name
if name == "" {
name = u.Login
}
return UserProfile{
ProviderID: strconv.FormatInt(u.ID, 10),
Email: u.Email,
Name: name,
}, nil
}
type githubEmail struct {
Email string `json:"email"`
Primary bool `json:"primary"`
Verified bool `json:"verified"`
}
func fetchGitHubPrimaryEmail(client *http.Client) (string, error) {
resp, err := client.Get("https://api.github.com/user/emails")
if err != nil {
return "", fmt.Errorf("fetch github emails: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("github /user/emails returned %d", resp.StatusCode)
}
var emails []githubEmail
if err := json.NewDecoder(resp.Body).Decode(&emails); err != nil {
return "", fmt.Errorf("decode github emails: %w", err)
}
for _, e := range emails {
if e.Primary && e.Verified {
return e.Email, nil
}
}
return "", fmt.Errorf("github account has no verified primary email")
}

View File

@ -1,41 +0,0 @@
package oauth
import "context"
// UserProfile is the normalized user info returned by an OAuth provider.
type UserProfile struct {
ProviderID string
Email string
Name string
}
// Provider abstracts an OAuth 2.0 identity provider.
type Provider interface {
// Name returns the provider identifier (e.g. "github", "google").
Name() string
// AuthCodeURL returns the URL to redirect the user to for authorization.
AuthCodeURL(state string) string
// Exchange trades an authorization code for a user profile.
Exchange(ctx context.Context, code string) (UserProfile, error)
}
// Registry maps provider names to Provider implementations.
type Registry struct {
providers map[string]Provider
}
// NewRegistry creates an empty provider registry.
func NewRegistry() *Registry {
return &Registry{providers: make(map[string]Provider)}
}
// Register adds a provider to the registry.
func (r *Registry) Register(p Provider) {
r.providers[p.Name()] = p
}
// Get looks up a provider by name.
func (r *Registry) Get(name string) (Provider, bool) {
p, ok := r.providers[name]
return p, ok
}

View File

@ -1,16 +0,0 @@
package auth
import "golang.org/x/crypto/bcrypt"
const bcryptCost = 12
// HashPassword returns the bcrypt hash of a plaintext password.
func HashPassword(plaintext string) (string, error) {
b, err := bcrypt.GenerateFromPassword([]byte(plaintext), bcryptCost)
return string(b), err
}
// CheckPassword returns nil if plaintext matches the stored hash.
func CheckPassword(hash, plaintext string) error {
return bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintext))
}

View File

@ -1,63 +0,0 @@
package channels
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"fmt"
"io"
)
// EncryptSecret encrypts plaintext using AES-256-GCM with a random nonce.
// Returns base64(nonce || ciphertext).
func EncryptSecret(key [32]byte, plaintext string) (string, error) {
block, err := aes.NewCipher(key[:])
if err != nil {
return "", fmt.Errorf("aes cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", fmt.Errorf("gcm: %w", err)
}
nonce := make([]byte, gcm.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return "", fmt.Errorf("nonce: %w", err)
}
ciphertext := gcm.Seal(nonce, nonce, []byte(plaintext), nil)
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
// DecryptSecret decrypts a value produced by EncryptSecret.
func DecryptSecret(key [32]byte, encoded string) (string, error) {
data, err := base64.StdEncoding.DecodeString(encoded)
if err != nil {
return "", fmt.Errorf("base64 decode: %w", err)
}
block, err := aes.NewCipher(key[:])
if err != nil {
return "", fmt.Errorf("aes cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", fmt.Errorf("gcm: %w", err)
}
nonceSize := gcm.NonceSize()
if len(data) < nonceSize {
return "", fmt.Errorf("ciphertext too short")
}
nonce, ciphertext := data[:nonceSize], data[nonceSize:]
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
if err != nil {
return "", fmt.Errorf("decrypt: %w", err)
}
return string(plaintext), nil
}

View File

@ -1,36 +0,0 @@
package channels
import (
"context"
"encoding/json"
"fmt"
"github.com/containrrr/shoutrrr"
"git.omukk.dev/wrenn/wrenn/internal/events"
)
// Deliver sends a notification to a single provider with the given config.
// For webhooks it uses HMAC-signed HTTP POST; for all others it uses shoutrrr.
func Deliver(ctx context.Context, provider string, config map[string]string, e events.Event) error {
payload, err := json.Marshal(e)
if err != nil {
return fmt.Errorf("marshal event: %w", err)
}
if provider == "webhook" {
wh := NewWebhookDelivery()
return wh.Deliver(ctx, config["url"], config["secret"], payload)
}
shoutrrrURL, err := ShoutrrrURL(provider, config)
if err != nil {
return fmt.Errorf("build shoutrrr URL: %w", err)
}
msg := FormatMessage(e)
if err := shoutrrr.Send(shoutrrrURL, msg); err != nil {
return fmt.Errorf("shoutrrr send: %w", err)
}
return nil
}

View File

@ -1,183 +0,0 @@
package channels
import (
"context"
"encoding/json"
"log/slog"
"time"
"github.com/redis/go-redis/v9"
"git.omukk.dev/wrenn/wrenn/internal/db"
"git.omukk.dev/wrenn/wrenn/internal/events"
"git.omukk.dev/wrenn/wrenn/internal/id"
)
const (
groupName = "wrenn-channels-v1"
consumerName = "cp-0"
)
// Dispatcher consumes events from the Redis stream and delivers them
// to matching notification channels.
type Dispatcher struct {
rdb *redis.Client
db *db.Queries
encKey [32]byte
webhook *WebhookDelivery
}
// NewDispatcher constructs an event dispatcher.
func NewDispatcher(rdb *redis.Client, queries *db.Queries, encKey [32]byte) *Dispatcher {
return &Dispatcher{
rdb: rdb,
db: queries,
encKey: encKey,
webhook: NewWebhookDelivery(),
}
}
// Start launches the consumer goroutine. Returns when ctx is cancelled.
func (d *Dispatcher) Start(ctx context.Context) {
go d.run(ctx)
}
func (d *Dispatcher) run(ctx context.Context) {
// Create consumer group idempotently. "$" means only new messages.
err := d.rdb.XGroupCreateMkStream(ctx, streamKey, groupName, "$").Err()
if err != nil && !isGroupExistsError(err) {
slog.Error("channels: failed to create consumer group", "error", err)
return
}
for {
select {
case <-ctx.Done():
return
default:
}
streams, err := d.rdb.XReadGroup(ctx, &redis.XReadGroupArgs{
Group: groupName,
Consumer: consumerName,
Streams: []string{streamKey, ">"},
Count: 10,
Block: 5 * time.Second,
}).Result()
if err != nil {
if err == redis.Nil || ctx.Err() != nil {
continue
}
slog.Warn("channels: xreadgroup error", "error", err)
time.Sleep(1 * time.Second)
continue
}
for _, stream := range streams {
for _, msg := range stream.Messages {
d.handleMessage(ctx, msg)
}
}
}
}
func (d *Dispatcher) handleMessage(ctx context.Context, msg redis.XMessage) {
defer func() {
if err := d.rdb.XAck(ctx, streamKey, groupName, msg.ID).Err(); err != nil {
slog.Warn("channels: xack failed", "id", msg.ID, "error", err)
}
}()
payload, ok := msg.Values["payload"].(string)
if !ok {
slog.Warn("channels: message missing payload", "id", msg.ID)
return
}
var event events.Event
if err := json.Unmarshal([]byte(payload), &event); err != nil {
slog.Warn("channels: failed to unmarshal event", "id", msg.ID, "error", err)
return
}
teamID, err := id.ParseTeamID(event.TeamID)
if err != nil {
slog.Warn("channels: invalid team ID in event", "team_id", event.TeamID, "error", err)
return
}
channels, err := d.db.ListChannelsForEvent(ctx, db.ListChannelsForEventParams{
TeamID: teamID,
EventType: event.Event,
})
if err != nil {
slog.Warn("channels: failed to list channels for event", "event", event.Event, "error", err)
return
}
for _, ch := range channels {
d.dispatch(ctx, ch, event)
}
}
// retryDelays defines the wait durations before each retry attempt.
var retryDelays = []time.Duration{10 * time.Second, 30 * time.Second}
func (d *Dispatcher) dispatch(ctx context.Context, ch db.Channel, e events.Event) {
config, err := d.decryptConfig(ch.Config)
if err != nil {
slog.Warn("channels: failed to decrypt config",
"channel_id", id.FormatChannelID(ch.ID), "error", err)
return
}
chID := id.FormatChannelID(ch.ID)
if err := Deliver(ctx, ch.Provider, config, e); err != nil {
slog.Warn("channels: delivery failed, scheduling retries",
"channel_id", chID, "provider", ch.Provider, "error", err)
go d.retryDeliver(ctx, ch.Provider, config, e, chID)
}
}
func (d *Dispatcher) retryDeliver(ctx context.Context, provider string, config map[string]string, e events.Event, chID string) {
for i, delay := range retryDelays {
select {
case <-ctx.Done():
return
case <-time.After(delay):
}
if err := Deliver(ctx, provider, config, e); err != nil {
slog.Warn("channels: retry delivery failed",
"channel_id", chID, "provider", provider,
"attempt", i+2, "error", err)
continue
}
return
}
slog.Error("channels: delivery failed after all retries",
"channel_id", chID, "provider", provider, "event", e.Event)
}
func (d *Dispatcher) decryptConfig(configJSON []byte) (map[string]string, error) {
var encrypted map[string]string
if err := json.Unmarshal(configJSON, &encrypted); err != nil {
return nil, err
}
decrypted := make(map[string]string, len(encrypted))
for k, v := range encrypted {
plaintext, err := DecryptSecret(d.encKey, v)
if err != nil {
return nil, err
}
decrypted[k] = plaintext
}
return decrypted, nil
}
func isGroupExistsError(err error) bool {
return err != nil && err.Error() == "BUSYGROUP Consumer Group name already exists"
}

View File

@ -1,65 +0,0 @@
package channels
import (
"fmt"
"strings"
"git.omukk.dev/wrenn/wrenn/internal/events"
)
// FormatMessage produces a human-readable notification string containing
// the event summary, resource details, actor, and timestamp.
func FormatMessage(e events.Event) string {
var b strings.Builder
b.WriteString(formatSummary(e))
fmt.Fprintf(&b, "\n\nEvent: %s", e.Event)
fmt.Fprintf(&b, "\nResource: %s %s", e.Resource.Type, e.Resource.ID)
fmt.Fprintf(&b, "\nActor: %s", formatActor(e.Actor))
fmt.Fprintf(&b, "\nTeam: %s", e.TeamID)
fmt.Fprintf(&b, "\nTime: %s", e.Timestamp)
return b.String()
}
func formatSummary(e events.Event) string {
switch e.Event {
case events.CapsuleCreated:
return fmt.Sprintf("Capsule %s created", e.Resource.ID)
case events.CapsuleRunning:
return fmt.Sprintf("Capsule %s is running", e.Resource.ID)
case events.CapsulePaused:
return fmt.Sprintf("Capsule %s paused", e.Resource.ID)
case events.CapsuleDestroyed:
return fmt.Sprintf("Capsule %s destroyed", e.Resource.ID)
case events.SnapshotCreated:
return fmt.Sprintf("Template snapshot %s created", e.Resource.ID)
case events.SnapshotDeleted:
return fmt.Sprintf("Template snapshot %s deleted", e.Resource.ID)
case events.HostUp:
return fmt.Sprintf("Host %s is up", e.Resource.ID)
case events.HostDown:
return fmt.Sprintf("Host %s is down", e.Resource.ID)
default:
return fmt.Sprintf("%s %s", e.Resource.Type, e.Resource.ID)
}
}
func formatActor(a events.Actor) string {
switch a.Type {
case events.ActorSystem:
return "system"
case events.ActorUser:
if a.Name != "" {
return fmt.Sprintf("%s (%s)", a.Name, a.ID)
}
return a.ID
case events.ActorAPIKey:
if a.Name != "" {
return fmt.Sprintf("api_key %s (%s)", a.Name, a.ID)
}
return fmt.Sprintf("api_key %s", a.ID)
default:
return string(a.Type)
}
}

View File

@ -1,44 +0,0 @@
package channels
import (
"context"
"encoding/json"
"log/slog"
"github.com/redis/go-redis/v9"
"git.omukk.dev/wrenn/wrenn/internal/events"
)
const streamKey = "wrenn:events"
// Publisher pushes events onto the Redis stream for the dispatcher to consume.
type Publisher struct {
rdb *redis.Client
}
// NewPublisher constructs an event publisher.
func NewPublisher(rdb *redis.Client) *Publisher {
return &Publisher{rdb: rdb}
}
// Publish serializes the event and appends it to the global stream.
// Fire-and-forget: failures are logged, never propagated.
func (p *Publisher) Publish(ctx context.Context, e events.Event) {
payload, err := json.Marshal(e)
if err != nil {
slog.Warn("channels: failed to marshal event", "event", e.Event, "error", err)
return
}
if err := p.rdb.XAdd(ctx, &redis.XAddArgs{
Stream: streamKey,
MaxLen: 10000,
Approx: true,
Values: map[string]interface{}{
"payload": string(payload),
},
}).Err(); err != nil {
slog.Warn("channels: failed to publish event", "event", e.Event, "error", err)
}
}

View File

@ -1,298 +0,0 @@
package channels
import (
"context"
"crypto/rand"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"strings"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/internal/db"
"git.omukk.dev/wrenn/wrenn/internal/events"
"git.omukk.dev/wrenn/wrenn/internal/id"
"git.omukk.dev/wrenn/wrenn/internal/validate"
)
// Valid providers.
var validProviders = map[string]bool{
"discord": true,
"slack": true,
"teams": true,
"googlechat": true,
"telegram": true,
"matrix": true,
"webhook": true,
}
// Required config fields per provider.
var requiredFields = map[string][]string{
"discord": {"webhook_url"},
"slack": {"webhook_url"},
"teams": {"webhook_url"},
"googlechat": {"webhook_url"},
"telegram": {"bot_token", "chat_id"},
"matrix": {"homeserver_url", "access_token", "room_id"},
"webhook": {"url"},
}
// validEvents maps event type strings to true for validation.
var validEvents map[string]bool
func init() {
validEvents = make(map[string]bool, len(events.AllEventTypes))
for _, et := range events.AllEventTypes {
validEvents[et] = true
}
}
// Service handles channel CRUD operations.
type Service struct {
DB *db.Queries
EncKey [32]byte
}
// CreateParams holds the parameters for creating a channel.
type CreateParams struct {
TeamID pgtype.UUID
Name string
Provider string
Config map[string]string
Events []string
}
// CreateResult holds the result of creating a channel.
type CreateResult struct {
Channel db.Channel
PlaintextSecret string // non-empty only for webhook provider
}
// Create creates a new notification channel.
func (s *Service) Create(ctx context.Context, p CreateParams) (CreateResult, error) {
clean, err := cleanName(p.Name)
if err != nil {
return CreateResult{}, err
}
p.Name = clean
if !validProviders[p.Provider] {
return CreateResult{}, fmt.Errorf("invalid: unsupported provider %q", p.Provider)
}
if len(p.Events) == 0 {
return CreateResult{}, fmt.Errorf("invalid: at least one event type is required")
}
for _, et := range p.Events {
if !validEvents[et] {
return CreateResult{}, fmt.Errorf("invalid: unknown event type %q", et)
}
}
// Validate required config fields.
for _, field := range requiredFields[p.Provider] {
if p.Config[field] == "" {
return CreateResult{}, fmt.Errorf("invalid: %s is required for %s", field, p.Provider)
}
}
// For webhooks, auto-generate secret if not provided.
var plaintextSecret string
if p.Provider == "webhook" {
if p.Config["secret"] == "" {
secret := generateSecret()
p.Config["secret"] = secret
plaintextSecret = secret
} else {
plaintextSecret = p.Config["secret"]
}
}
// Encrypt config fields.
encrypted := make(map[string]string, len(p.Config))
for k, v := range p.Config {
enc, err := EncryptSecret(s.EncKey, v)
if err != nil {
return CreateResult{}, fmt.Errorf("encrypt config field %s: %w", k, err)
}
encrypted[k] = enc
}
configJSON, err := json.Marshal(encrypted)
if err != nil {
return CreateResult{}, fmt.Errorf("marshal config: %w", err)
}
ch, err := s.DB.InsertChannel(ctx, db.InsertChannelParams{
ID: id.NewChannelID(),
TeamID: p.TeamID,
Name: p.Name,
Provider: p.Provider,
Config: configJSON,
EventTypes: p.Events,
})
if err != nil {
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) && pgErr.Code == "23505" {
return CreateResult{}, fmt.Errorf("conflict: channel name %q already exists", p.Name)
}
return CreateResult{}, fmt.Errorf("insert channel: %w", err)
}
return CreateResult{Channel: ch, PlaintextSecret: plaintextSecret}, nil
}
// List returns all channels belonging to the given team.
func (s *Service) List(ctx context.Context, teamID pgtype.UUID) ([]db.Channel, error) {
return s.DB.ListChannelsByTeam(ctx, teamID)
}
// Get returns a single channel by ID, scoped to the given team.
func (s *Service) Get(ctx context.Context, channelID, teamID pgtype.UUID) (db.Channel, error) {
return s.DB.GetChannelByTeam(ctx, db.GetChannelByTeamParams{ID: channelID, TeamID: teamID})
}
// Update updates a channel's name and event types.
func (s *Service) Update(ctx context.Context, channelID, teamID pgtype.UUID, name string, eventTypes []string) (db.Channel, error) {
clean, err := cleanName(name)
if err != nil {
return db.Channel{}, err
}
name = clean
if len(eventTypes) == 0 {
return db.Channel{}, fmt.Errorf("invalid: at least one event type is required")
}
for _, et := range eventTypes {
if !validEvents[et] {
return db.Channel{}, fmt.Errorf("invalid: unknown event type %q", et)
}
}
ch, err := s.DB.UpdateChannel(ctx, db.UpdateChannelParams{
ID: channelID,
TeamID: teamID,
Name: name,
EventTypes: eventTypes,
})
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return db.Channel{}, fmt.Errorf("channel not found")
}
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) && pgErr.Code == "23505" {
return db.Channel{}, fmt.Errorf("conflict: channel name %q already exists", name)
}
return db.Channel{}, fmt.Errorf("update channel: %w", err)
}
return ch, nil
}
// RotateConfig replaces a channel's config with new provider secrets.
func (s *Service) RotateConfig(ctx context.Context, channelID, teamID pgtype.UUID, config map[string]string) (db.Channel, error) {
// Look up the existing channel to get its provider for validation.
ch, err := s.DB.GetChannelByTeam(ctx, db.GetChannelByTeamParams{ID: channelID, TeamID: teamID})
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return db.Channel{}, fmt.Errorf("channel not found")
}
return db.Channel{}, fmt.Errorf("get channel: %w", err)
}
// Validate required config fields for this provider.
for _, field := range requiredFields[ch.Provider] {
if config[field] == "" {
return db.Channel{}, fmt.Errorf("invalid: %s is required for %s", field, ch.Provider)
}
}
// For webhooks, auto-generate secret if not provided.
if ch.Provider == "webhook" && config["secret"] == "" {
config["secret"] = generateSecret()
}
// Encrypt all config fields.
encrypted := make(map[string]string, len(config))
for k, v := range config {
enc, err := EncryptSecret(s.EncKey, v)
if err != nil {
return db.Channel{}, fmt.Errorf("encrypt config field %s: %w", k, err)
}
encrypted[k] = enc
}
configJSON, err := json.Marshal(encrypted)
if err != nil {
return db.Channel{}, fmt.Errorf("marshal config: %w", err)
}
updated, err := s.DB.UpdateChannelConfig(ctx, db.UpdateChannelConfigParams{
ID: channelID,
TeamID: teamID,
Config: configJSON,
})
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return db.Channel{}, fmt.Errorf("channel not found")
}
return db.Channel{}, fmt.Errorf("update channel config: %w", err)
}
return updated, nil
}
// Test validates config and sends a test notification without persisting anything.
func (s *Service) Test(ctx context.Context, provider string, config map[string]string) error {
if !validProviders[provider] {
return fmt.Errorf("invalid: unsupported provider %q", provider)
}
for _, field := range requiredFields[provider] {
if config[field] == "" {
return fmt.Errorf("invalid: %s is required for %s", field, provider)
}
}
// For webhooks, auto-generate a temporary secret if not provided.
if provider == "webhook" && config["secret"] == "" {
config["secret"] = generateSecret()
}
testEvent := events.Event{
Event: "channel.test",
Timestamp: events.Now(),
TeamID: "test",
Actor: events.Actor{Type: events.ActorSystem},
Resource: events.Resource{ID: "test", Type: "channel"},
}
return Deliver(ctx, provider, config, testEvent)
}
// Delete removes a channel by ID, scoped to the given team.
func (s *Service) Delete(ctx context.Context, channelID, teamID pgtype.UUID) error {
return s.DB.DeleteChannelByTeam(ctx, db.DeleteChannelByTeamParams{ID: channelID, TeamID: teamID})
}
// cleanName normalises a channel name: trim whitespace, lowercase, replace
// spaces with hyphens, then validate against SafeName rules.
func cleanName(name string) (string, error) {
name = strings.TrimSpace(name)
name = strings.ToLower(name)
name = strings.ReplaceAll(name, " ", "-")
if err := validate.SafeName(name); err != nil {
return "", fmt.Errorf("invalid: %w", err)
}
return name, nil
}
func generateSecret() string {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
panic(fmt.Sprintf("crypto/rand failed: %v", err))
}
return hex.EncodeToString(b)
}

View File

@ -1,119 +0,0 @@
package channels
import (
"fmt"
"net/url"
"regexp"
"strings"
)
// ShoutrrrURL builds a shoutrrr-compatible URL from structured provider config.
func ShoutrrrURL(provider string, config map[string]string) (string, error) {
switch provider {
case "discord":
return discordURL(config)
case "slack":
return slackURL(config)
case "teams":
return teamsURL(config)
case "googlechat":
return googlechatURL(config)
case "telegram":
return telegramURL(config)
case "matrix":
return matrixURL(config)
default:
return "", fmt.Errorf("unsupported shoutrrr provider: %s", provider)
}
}
// discordURL converts https://discord.com/api/webhooks/{id}/{token} → discord://{token}@{id}
func discordURL(config map[string]string) (string, error) {
u, err := url.Parse(config["webhook_url"])
if err != nil {
return "", fmt.Errorf("invalid discord webhook URL: %w", err)
}
// Path: /api/webhooks/{id}/{token}
parts := strings.Split(strings.TrimPrefix(u.Path, "/"), "/")
if len(parts) < 4 || parts[0] != "api" || parts[1] != "webhooks" {
return "", fmt.Errorf("unexpected discord webhook URL format")
}
webhookID, token := parts[2], parts[3]
return fmt.Sprintf("discord://%s@%s?splitLines=No", token, webhookID), nil
}
// slackURL converts https://hooks.slack.com/services/T.../B.../XXX → slack://T.../B.../XXX
func slackURL(config map[string]string) (string, error) {
u, err := url.Parse(config["webhook_url"])
if err != nil {
return "", fmt.Errorf("invalid slack webhook URL: %w", err)
}
// Path: /services/TXXXXX/BXXXXX/XXXXXXXX
parts := strings.Split(strings.TrimPrefix(u.Path, "/"), "/")
if len(parts) < 4 || parts[0] != "services" {
return "", fmt.Errorf("unexpected slack webhook URL format")
}
return fmt.Sprintf("slack://hook:%s-%s-%s@webhook", parts[1], parts[2], parts[3]), nil
}
// teamsWebhookRe extracts the 4 components from a Teams webhook URL.
// Format: https://<host>/<path>/{group}@{tenant}/IncomingWebhook/{altID}/{groupOwner}
var teamsWebhookRe = regexp.MustCompile(`([0-9a-f-]{36})@([0-9a-f-]{36})/[^/]+/([0-9a-f]{32})/([0-9a-f-]{36})`)
// teamsURL converts a Teams webhook URL → teams://Group@Tenant/AltID/GroupOwner
func teamsURL(config map[string]string) (string, error) {
webhookURL := config["webhook_url"]
if webhookURL == "" {
return "", fmt.Errorf("teams webhook_url is required")
}
groups := teamsWebhookRe.FindStringSubmatch(webhookURL)
if len(groups) != 5 {
return "", fmt.Errorf("unexpected teams webhook URL format")
}
group, tenant, altID, groupOwner := groups[1], groups[2], groups[3], groups[4]
return fmt.Sprintf("teams://%s@%s/%s/%s", group, tenant, altID, groupOwner), nil
}
// googlechatURL converts a Google Chat webhook URL to shoutrrr format.
// Input: https://chat.googleapis.com/v1/spaces/SPACE/messages?key=KEY&token=TOKEN
// Output: googlechat://chat.googleapis.com/v1/spaces/SPACE/messages?key=KEY&token=TOKEN
func googlechatURL(config map[string]string) (string, error) {
webhookURL := config["webhook_url"]
if webhookURL == "" {
return "", fmt.Errorf("googlechat webhook_url is required")
}
u, err := url.Parse(webhookURL)
if err != nil {
return "", fmt.Errorf("invalid googlechat webhook URL: %w", err)
}
if u.Host != "chat.googleapis.com" {
return "", fmt.Errorf("unexpected googlechat webhook URL host: %s", u.Host)
}
// Rebuild as googlechat:// scheme with same host, path, and query.
u.Scheme = "googlechat"
return u.String(), nil
}
// telegramURL builds telegram://token@telegram/?chats=chatID
func telegramURL(config map[string]string) (string, error) {
token := config["bot_token"]
chatID := config["chat_id"]
if token == "" || chatID == "" {
return "", fmt.Errorf("telegram bot_token and chat_id are required")
}
return fmt.Sprintf("telegram://%s@telegram/?chats=%s", token, chatID), nil
}
// matrixURL builds matrix://user:token@homeserver/room
func matrixURL(config map[string]string) (string, error) {
homeserver := config["homeserver_url"]
token := config["access_token"]
roomID := config["room_id"]
if homeserver == "" || token == "" || roomID == "" {
return "", fmt.Errorf("matrix homeserver_url, access_token, and room_id are required")
}
// Strip protocol from homeserver URL.
host := strings.TrimPrefix(strings.TrimPrefix(homeserver, "https://"), "http://")
// Room ID often starts with ! — URL-encode it.
return fmt.Sprintf("matrix://:%s@%s/%s", url.PathEscape(token), host, url.PathEscape(roomID)), nil
}

View File

@ -1,62 +0,0 @@
package channels
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"fmt"
"net/http"
"strings"
"time"
"github.com/google/uuid"
)
// WebhookDelivery delivers events to webhook URLs with HMAC signing.
type WebhookDelivery struct {
client *http.Client
}
// NewWebhookDelivery constructs a webhook delivery client.
func NewWebhookDelivery() *WebhookDelivery {
return &WebhookDelivery{
client: &http.Client{
Timeout: 10 * time.Second,
CheckRedirect: func(*http.Request, []*http.Request) error {
return http.ErrUseLastResponse
},
},
}
}
// Deliver signs and POSTs the event payload to the configured URL.
func (d *WebhookDelivery) Deliver(ctx context.Context, targetURL, secret string, payload []byte) error {
timestamp := time.Now().UTC().Format(time.RFC3339)
deliveryID := uuid.New().String()
// Compute HMAC-SHA256: sign over "timestamp.body".
mac := hmac.New(sha256.New, []byte(secret))
mac.Write([]byte(timestamp + "." + string(payload)))
signature := "sha256=" + hex.EncodeToString(mac.Sum(nil))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, strings.NewReader(string(payload)))
if err != nil {
return fmt.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-WRENN-SIGNATURE", signature)
req.Header.Set("X-Wrenn-Delivery", deliveryID)
req.Header.Set("X-Wrenn-Timestamp", timestamp)
resp, err := d.client.Do(req)
if err != nil {
return fmt.Errorf("http post: %w", err)
}
resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return fmt.Errorf("webhook returned %d", resp.StatusCode)
}
return nil
}

View File

@ -1,70 +0,0 @@
package config
import (
"encoding/hex"
"os"
"github.com/joho/godotenv"
)
// Config holds the control plane configuration.
type Config struct {
DatabaseURL string
RedisURL string
ListenAddr string
JWTSecret string
// mTLS — CP→Agent channel. Both must be set to enable mTLS; omitting either
// disables cert issuance and leaves agent connections on plain HTTP (dev mode).
CACert string // WRENN_CA_CERT — PEM-encoded internal CA certificate
CAKey string // WRENN_CA_KEY — PEM-encoded internal CA private key
OAuthGitHubClientID string
OAuthGitHubClientSecret string
OAuthRedirectURL string
CPPublicURL string
// Channels — encryption for channel secrets (AES-256-GCM).
EncryptionKeyHex string // WRENN_ENCRYPTION_KEY raw hex string (for validation)
EncryptionKey [32]byte // parsed 32-byte key
}
// Load reads configuration from a .env file (if present) and environment variables.
// Real environment variables take precedence over .env values.
func Load() Config {
// Best-effort load — missing .env file is fine.
_ = godotenv.Load()
cfg := Config{
DatabaseURL: envOrDefault("DATABASE_URL", "postgres://wrenn:wrenn@localhost:5432/wrenn?sslmode=disable"),
RedisURL: envOrDefault("REDIS_URL", "redis://localhost:6379/0"),
ListenAddr: envOrDefault("WRENN_CP_LISTEN_ADDR", ":8080"),
JWTSecret: os.Getenv("JWT_SECRET"),
CACert: os.Getenv("WRENN_CA_CERT"),
CAKey: os.Getenv("WRENN_CA_KEY"),
OAuthGitHubClientID: os.Getenv("OAUTH_GITHUB_CLIENT_ID"),
OAuthGitHubClientSecret: os.Getenv("OAUTH_GITHUB_CLIENT_SECRET"),
OAuthRedirectURL: envOrDefault("OAUTH_REDIRECT_URL", "https://app.wrenn.dev"),
CPPublicURL: os.Getenv("CP_PUBLIC_URL"),
EncryptionKeyHex: os.Getenv("WRENN_ENCRYPTION_KEY"),
}
if cfg.EncryptionKeyHex != "" {
b, err := hex.DecodeString(cfg.EncryptionKeyHex)
if err == nil && len(b) == 32 {
copy(cfg.EncryptionKey[:], b)
}
}
return cfg
}
func envOrDefault(key, def string) string {
if v := os.Getenv(key); v != "" {
return v
}
return def
}

View File

@ -1,177 +0,0 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: api_keys.sql
package db
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
const deleteAPIKey = `-- name: DeleteAPIKey :exec
DELETE FROM team_api_keys WHERE id = $1 AND team_id = $2
`
type DeleteAPIKeyParams struct {
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
}
func (q *Queries) DeleteAPIKey(ctx context.Context, arg DeleteAPIKeyParams) error {
_, err := q.db.Exec(ctx, deleteAPIKey, arg.ID, arg.TeamID)
return err
}
const getAPIKeyByHash = `-- name: GetAPIKeyByHash :one
SELECT id, team_id, name, key_hash, key_prefix, created_by, created_at, last_used FROM team_api_keys WHERE key_hash = $1
`
func (q *Queries) GetAPIKeyByHash(ctx context.Context, keyHash string) (TeamApiKey, error) {
row := q.db.QueryRow(ctx, getAPIKeyByHash, keyHash)
var i TeamApiKey
err := row.Scan(
&i.ID,
&i.TeamID,
&i.Name,
&i.KeyHash,
&i.KeyPrefix,
&i.CreatedBy,
&i.CreatedAt,
&i.LastUsed,
)
return i, err
}
const insertAPIKey = `-- name: InsertAPIKey :one
INSERT INTO team_api_keys (id, team_id, name, key_hash, key_prefix, created_by)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id, team_id, name, key_hash, key_prefix, created_by, created_at, last_used
`
type InsertAPIKeyParams struct {
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
Name string `json:"name"`
KeyHash string `json:"key_hash"`
KeyPrefix string `json:"key_prefix"`
CreatedBy pgtype.UUID `json:"created_by"`
}
func (q *Queries) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (TeamApiKey, error) {
row := q.db.QueryRow(ctx, insertAPIKey,
arg.ID,
arg.TeamID,
arg.Name,
arg.KeyHash,
arg.KeyPrefix,
arg.CreatedBy,
)
var i TeamApiKey
err := row.Scan(
&i.ID,
&i.TeamID,
&i.Name,
&i.KeyHash,
&i.KeyPrefix,
&i.CreatedBy,
&i.CreatedAt,
&i.LastUsed,
)
return i, err
}
const listAPIKeysByTeam = `-- name: ListAPIKeysByTeam :many
SELECT id, team_id, name, key_hash, key_prefix, created_by, created_at, last_used FROM team_api_keys WHERE team_id = $1 ORDER BY created_at DESC
`
func (q *Queries) ListAPIKeysByTeam(ctx context.Context, teamID pgtype.UUID) ([]TeamApiKey, error) {
rows, err := q.db.Query(ctx, listAPIKeysByTeam, teamID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []TeamApiKey
for rows.Next() {
var i TeamApiKey
if err := rows.Scan(
&i.ID,
&i.TeamID,
&i.Name,
&i.KeyHash,
&i.KeyPrefix,
&i.CreatedBy,
&i.CreatedAt,
&i.LastUsed,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listAPIKeysByTeamWithCreator = `-- name: ListAPIKeysByTeamWithCreator :many
SELECT k.id, k.team_id, k.name, k.key_hash, k.key_prefix, k.created_by, k.created_at, k.last_used,
u.email AS creator_email
FROM team_api_keys k
JOIN users u ON u.id = k.created_by
WHERE k.team_id = $1
ORDER BY k.created_at DESC
`
type ListAPIKeysByTeamWithCreatorRow struct {
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
Name string `json:"name"`
KeyHash string `json:"key_hash"`
KeyPrefix string `json:"key_prefix"`
CreatedBy pgtype.UUID `json:"created_by"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
LastUsed pgtype.Timestamptz `json:"last_used"`
CreatorEmail string `json:"creator_email"`
}
func (q *Queries) ListAPIKeysByTeamWithCreator(ctx context.Context, teamID pgtype.UUID) ([]ListAPIKeysByTeamWithCreatorRow, error) {
rows, err := q.db.Query(ctx, listAPIKeysByTeamWithCreator, teamID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []ListAPIKeysByTeamWithCreatorRow
for rows.Next() {
var i ListAPIKeysByTeamWithCreatorRow
if err := rows.Scan(
&i.ID,
&i.TeamID,
&i.Name,
&i.KeyHash,
&i.KeyPrefix,
&i.CreatedBy,
&i.CreatedAt,
&i.LastUsed,
&i.CreatorEmail,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const updateAPIKeyLastUsed = `-- name: UpdateAPIKeyLastUsed :exec
UPDATE team_api_keys SET last_used = NOW() WHERE id = $1
`
func (q *Queries) UpdateAPIKeyLastUsed(ctx context.Context, id pgtype.UUID) error {
_, err := q.db.Exec(ctx, updateAPIKeyLastUsed, id)
return err
}

View File

@ -1,111 +0,0 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: audit.sql
package db
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
const insertAuditLog = `-- name: InsertAuditLog :exec
INSERT INTO audit_logs (id, team_id, actor_type, actor_id, actor_name, resource_type, resource_id, action, scope, status, metadata)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
`
type InsertAuditLogParams struct {
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
ActorType string `json:"actor_type"`
ActorID pgtype.Text `json:"actor_id"`
ActorName string `json:"actor_name"`
ResourceType string `json:"resource_type"`
ResourceID pgtype.Text `json:"resource_id"`
Action string `json:"action"`
Scope string `json:"scope"`
Status string `json:"status"`
Metadata []byte `json:"metadata"`
}
func (q *Queries) InsertAuditLog(ctx context.Context, arg InsertAuditLogParams) error {
_, err := q.db.Exec(ctx, insertAuditLog,
arg.ID,
arg.TeamID,
arg.ActorType,
arg.ActorID,
arg.ActorName,
arg.ResourceType,
arg.ResourceID,
arg.Action,
arg.Scope,
arg.Status,
arg.Metadata,
)
return err
}
const listAuditLogs = `-- name: ListAuditLogs :many
SELECT id, team_id, actor_type, actor_id, actor_name, resource_type, resource_id, action, scope, status, metadata, created_at FROM audit_logs
WHERE team_id = $1
AND scope = ANY($2::text[])
AND (cardinality($3::text[]) = 0 OR resource_type = ANY($3::text[]))
AND (cardinality($4::text[]) = 0 OR action = ANY($4::text[]))
AND ($5::timestamptz IS NULL OR created_at < $5
OR (created_at = $5 AND id < $6))
ORDER BY created_at DESC, id DESC
LIMIT $7
`
type ListAuditLogsParams struct {
TeamID pgtype.UUID `json:"team_id"`
Column2 []string `json:"column_2"`
Column3 []string `json:"column_3"`
Column4 []string `json:"column_4"`
Column5 pgtype.Timestamptz `json:"column_5"`
ID pgtype.UUID `json:"id"`
Limit int32 `json:"limit"`
}
func (q *Queries) ListAuditLogs(ctx context.Context, arg ListAuditLogsParams) ([]AuditLog, error) {
rows, err := q.db.Query(ctx, listAuditLogs,
arg.TeamID,
arg.Column2,
arg.Column3,
arg.Column4,
arg.Column5,
arg.ID,
arg.Limit,
)
if err != nil {
return nil, err
}
defer rows.Close()
var items []AuditLog
for rows.Next() {
var i AuditLog
if err := rows.Scan(
&i.ID,
&i.TeamID,
&i.ActorType,
&i.ActorID,
&i.ActorName,
&i.ResourceType,
&i.ResourceID,
&i.Action,
&i.Scope,
&i.Status,
&i.Metadata,
&i.CreatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}

View File

@ -1,225 +0,0 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: channels.sql
package db
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
const deleteChannelByTeam = `-- name: DeleteChannelByTeam :exec
DELETE FROM channels WHERE id = $1 AND team_id = $2
`
type DeleteChannelByTeamParams struct {
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
}
func (q *Queries) DeleteChannelByTeam(ctx context.Context, arg DeleteChannelByTeamParams) error {
_, err := q.db.Exec(ctx, deleteChannelByTeam, arg.ID, arg.TeamID)
return err
}
const getChannelByTeam = `-- name: GetChannelByTeam :one
SELECT id, team_id, name, provider, config, event_types, created_at, updated_at FROM channels WHERE id = $1 AND team_id = $2
`
type GetChannelByTeamParams struct {
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
}
func (q *Queries) GetChannelByTeam(ctx context.Context, arg GetChannelByTeamParams) (Channel, error) {
row := q.db.QueryRow(ctx, getChannelByTeam, arg.ID, arg.TeamID)
var i Channel
err := row.Scan(
&i.ID,
&i.TeamID,
&i.Name,
&i.Provider,
&i.Config,
&i.EventTypes,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const insertChannel = `-- name: InsertChannel :one
INSERT INTO channels (id, team_id, name, provider, config, event_types)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id, team_id, name, provider, config, event_types, created_at, updated_at
`
type InsertChannelParams struct {
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
Name string `json:"name"`
Provider string `json:"provider"`
Config []byte `json:"config"`
EventTypes []string `json:"event_types"`
}
func (q *Queries) InsertChannel(ctx context.Context, arg InsertChannelParams) (Channel, error) {
row := q.db.QueryRow(ctx, insertChannel,
arg.ID,
arg.TeamID,
arg.Name,
arg.Provider,
arg.Config,
arg.EventTypes,
)
var i Channel
err := row.Scan(
&i.ID,
&i.TeamID,
&i.Name,
&i.Provider,
&i.Config,
&i.EventTypes,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const listChannelsByTeam = `-- name: ListChannelsByTeam :many
SELECT id, team_id, name, provider, config, event_types, created_at, updated_at FROM channels WHERE team_id = $1 ORDER BY created_at DESC
`
func (q *Queries) ListChannelsByTeam(ctx context.Context, teamID pgtype.UUID) ([]Channel, error) {
rows, err := q.db.Query(ctx, listChannelsByTeam, teamID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Channel
for rows.Next() {
var i Channel
if err := rows.Scan(
&i.ID,
&i.TeamID,
&i.Name,
&i.Provider,
&i.Config,
&i.EventTypes,
&i.CreatedAt,
&i.UpdatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listChannelsForEvent = `-- name: ListChannelsForEvent :many
SELECT id, team_id, name, provider, config, event_types, created_at, updated_at FROM channels
WHERE team_id = $1
AND $2::text = ANY(event_types)
ORDER BY created_at
`
type ListChannelsForEventParams struct {
TeamID pgtype.UUID `json:"team_id"`
EventType string `json:"event_type"`
}
func (q *Queries) ListChannelsForEvent(ctx context.Context, arg ListChannelsForEventParams) ([]Channel, error) {
rows, err := q.db.Query(ctx, listChannelsForEvent, arg.TeamID, arg.EventType)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Channel
for rows.Next() {
var i Channel
if err := rows.Scan(
&i.ID,
&i.TeamID,
&i.Name,
&i.Provider,
&i.Config,
&i.EventTypes,
&i.CreatedAt,
&i.UpdatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const updateChannel = `-- name: UpdateChannel :one
UPDATE channels SET name = $3, event_types = $4, updated_at = NOW()
WHERE id = $1 AND team_id = $2
RETURNING id, team_id, name, provider, config, event_types, created_at, updated_at
`
type UpdateChannelParams struct {
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
Name string `json:"name"`
EventTypes []string `json:"event_types"`
}
func (q *Queries) UpdateChannel(ctx context.Context, arg UpdateChannelParams) (Channel, error) {
row := q.db.QueryRow(ctx, updateChannel,
arg.ID,
arg.TeamID,
arg.Name,
arg.EventTypes,
)
var i Channel
err := row.Scan(
&i.ID,
&i.TeamID,
&i.Name,
&i.Provider,
&i.Config,
&i.EventTypes,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const updateChannelConfig = `-- name: UpdateChannelConfig :one
UPDATE channels SET config = $3, updated_at = NOW()
WHERE id = $1 AND team_id = $2
RETURNING id, team_id, name, provider, config, event_types, created_at, updated_at
`
type UpdateChannelConfigParams struct {
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
Config []byte `json:"config"`
}
func (q *Queries) UpdateChannelConfig(ctx context.Context, arg UpdateChannelConfigParams) (Channel, error) {
row := q.db.QueryRow(ctx, updateChannelConfig, arg.ID, arg.TeamID, arg.Config)
var i Channel
err := row.Scan(
&i.ID,
&i.TeamID,
&i.Name,
&i.Provider,
&i.Config,
&i.EventTypes,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}

View File

@ -1,32 +0,0 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
package db
import (
"context"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
)
type DBTX interface {
Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error)
Query(context.Context, string, ...interface{}) (pgx.Rows, error)
QueryRow(context.Context, string, ...interface{}) pgx.Row
}
func New(db DBTX) *Queries {
return &Queries{db: db}
}
type Queries struct {
db DBTX
}
func (q *Queries) WithTx(tx pgx.Tx) *Queries {
return &Queries{
db: tx,
}
}

View File

@ -1,92 +0,0 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: host_refresh_tokens.sql
package db
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
const deleteExpiredHostRefreshTokens = `-- name: DeleteExpiredHostRefreshTokens :exec
DELETE FROM host_refresh_tokens
WHERE expires_at < NOW() OR revoked_at IS NOT NULL
`
func (q *Queries) DeleteExpiredHostRefreshTokens(ctx context.Context) error {
_, err := q.db.Exec(ctx, deleteExpiredHostRefreshTokens)
return err
}
const getHostRefreshTokenByHash = `-- name: GetHostRefreshTokenByHash :one
SELECT id, host_id, token_hash, expires_at, created_at, revoked_at FROM host_refresh_tokens
WHERE token_hash = $1 AND revoked_at IS NULL AND expires_at > NOW()
`
func (q *Queries) GetHostRefreshTokenByHash(ctx context.Context, tokenHash string) (HostRefreshToken, error) {
row := q.db.QueryRow(ctx, getHostRefreshTokenByHash, tokenHash)
var i HostRefreshToken
err := row.Scan(
&i.ID,
&i.HostID,
&i.TokenHash,
&i.ExpiresAt,
&i.CreatedAt,
&i.RevokedAt,
)
return i, err
}
const insertHostRefreshToken = `-- name: InsertHostRefreshToken :one
INSERT INTO host_refresh_tokens (id, host_id, token_hash, expires_at)
VALUES ($1, $2, $3, $4)
RETURNING id, host_id, token_hash, expires_at, created_at, revoked_at
`
type InsertHostRefreshTokenParams struct {
ID pgtype.UUID `json:"id"`
HostID pgtype.UUID `json:"host_id"`
TokenHash string `json:"token_hash"`
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
}
func (q *Queries) InsertHostRefreshToken(ctx context.Context, arg InsertHostRefreshTokenParams) (HostRefreshToken, error) {
row := q.db.QueryRow(ctx, insertHostRefreshToken,
arg.ID,
arg.HostID,
arg.TokenHash,
arg.ExpiresAt,
)
var i HostRefreshToken
err := row.Scan(
&i.ID,
&i.HostID,
&i.TokenHash,
&i.ExpiresAt,
&i.CreatedAt,
&i.RevokedAt,
)
return i, err
}
const revokeHostRefreshToken = `-- name: RevokeHostRefreshToken :exec
UPDATE host_refresh_tokens SET revoked_at = NOW() WHERE id = $1
`
func (q *Queries) RevokeHostRefreshToken(ctx context.Context, id pgtype.UUID) error {
_, err := q.db.Exec(ctx, revokeHostRefreshToken, id)
return err
}
const revokeHostRefreshTokensByHost = `-- name: RevokeHostRefreshTokensByHost :exec
UPDATE host_refresh_tokens SET revoked_at = NOW()
WHERE host_id = $1 AND revoked_at IS NULL
`
func (q *Queries) RevokeHostRefreshTokensByHost(ctx context.Context, hostID pgtype.UUID) error {
_, err := q.db.Exec(ctx, revokeHostRefreshTokensByHost, hostID)
return err
}

View File

@ -1,632 +0,0 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: hosts.sql
package db
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
const addHostTag = `-- name: AddHostTag :exec
INSERT INTO host_tags (host_id, tag) VALUES ($1, $2) ON CONFLICT DO NOTHING
`
type AddHostTagParams struct {
HostID pgtype.UUID `json:"host_id"`
Tag string `json:"tag"`
}
func (q *Queries) AddHostTag(ctx context.Context, arg AddHostTagParams) error {
_, err := q.db.Exec(ctx, addHostTag, arg.HostID, arg.Tag)
return err
}
const deleteHost = `-- name: DeleteHost :exec
DELETE FROM hosts WHERE id = $1
`
func (q *Queries) DeleteHost(ctx context.Context, id pgtype.UUID) error {
_, err := q.db.Exec(ctx, deleteHost, id)
return err
}
const getHost = `-- name: GetHost :one
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE id = $1
`
func (q *Queries) GetHost(ctx context.Context, id pgtype.UUID) (Host, error) {
row := q.db.QueryRow(ctx, getHost, id)
var i Host
err := row.Scan(
&i.ID,
&i.Type,
&i.TeamID,
&i.Provider,
&i.AvailabilityZone,
&i.Arch,
&i.CpuCores,
&i.MemoryMb,
&i.DiskGb,
&i.Address,
&i.Status,
&i.LastHeartbeatAt,
&i.Metadata,
&i.CreatedBy,
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
&i.CertExpiresAt,
)
return i, err
}
const getHostByTeam = `-- name: GetHostByTeam :one
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE id = $1 AND team_id = $2
`
type GetHostByTeamParams struct {
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
}
func (q *Queries) GetHostByTeam(ctx context.Context, arg GetHostByTeamParams) (Host, error) {
row := q.db.QueryRow(ctx, getHostByTeam, arg.ID, arg.TeamID)
var i Host
err := row.Scan(
&i.ID,
&i.Type,
&i.TeamID,
&i.Provider,
&i.AvailabilityZone,
&i.Arch,
&i.CpuCores,
&i.MemoryMb,
&i.DiskGb,
&i.Address,
&i.Status,
&i.LastHeartbeatAt,
&i.Metadata,
&i.CreatedBy,
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
&i.CertExpiresAt,
)
return i, err
}
const getHostTags = `-- name: GetHostTags :many
SELECT tag FROM host_tags WHERE host_id = $1 ORDER BY tag
`
func (q *Queries) GetHostTags(ctx context.Context, hostID pgtype.UUID) ([]string, error) {
rows, err := q.db.Query(ctx, getHostTags, hostID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []string
for rows.Next() {
var tag string
if err := rows.Scan(&tag); err != nil {
return nil, err
}
items = append(items, tag)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getHostTokensByHost = `-- name: GetHostTokensByHost :many
SELECT id, host_id, created_by, created_at, expires_at, used_at FROM host_tokens WHERE host_id = $1 ORDER BY created_at DESC
`
func (q *Queries) GetHostTokensByHost(ctx context.Context, hostID pgtype.UUID) ([]HostToken, error) {
rows, err := q.db.Query(ctx, getHostTokensByHost, hostID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []HostToken
for rows.Next() {
var i HostToken
if err := rows.Scan(
&i.ID,
&i.HostID,
&i.CreatedBy,
&i.CreatedAt,
&i.ExpiresAt,
&i.UsedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const insertHost = `-- name: InsertHost :one
INSERT INTO hosts (id, type, team_id, provider, availability_zone, created_by)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at
`
type InsertHostParams struct {
ID pgtype.UUID `json:"id"`
Type string `json:"type"`
TeamID pgtype.UUID `json:"team_id"`
Provider string `json:"provider"`
AvailabilityZone string `json:"availability_zone"`
CreatedBy pgtype.UUID `json:"created_by"`
}
func (q *Queries) InsertHost(ctx context.Context, arg InsertHostParams) (Host, error) {
row := q.db.QueryRow(ctx, insertHost,
arg.ID,
arg.Type,
arg.TeamID,
arg.Provider,
arg.AvailabilityZone,
arg.CreatedBy,
)
var i Host
err := row.Scan(
&i.ID,
&i.Type,
&i.TeamID,
&i.Provider,
&i.AvailabilityZone,
&i.Arch,
&i.CpuCores,
&i.MemoryMb,
&i.DiskGb,
&i.Address,
&i.Status,
&i.LastHeartbeatAt,
&i.Metadata,
&i.CreatedBy,
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
&i.CertExpiresAt,
)
return i, err
}
const insertHostToken = `-- name: InsertHostToken :one
INSERT INTO host_tokens (id, host_id, created_by, expires_at)
VALUES ($1, $2, $3, $4)
RETURNING id, host_id, created_by, created_at, expires_at, used_at
`
type InsertHostTokenParams struct {
ID pgtype.UUID `json:"id"`
HostID pgtype.UUID `json:"host_id"`
CreatedBy pgtype.UUID `json:"created_by"`
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
}
func (q *Queries) InsertHostToken(ctx context.Context, arg InsertHostTokenParams) (HostToken, error) {
row := q.db.QueryRow(ctx, insertHostToken,
arg.ID,
arg.HostID,
arg.CreatedBy,
arg.ExpiresAt,
)
var i HostToken
err := row.Scan(
&i.ID,
&i.HostID,
&i.CreatedBy,
&i.CreatedAt,
&i.ExpiresAt,
&i.UsedAt,
)
return i, err
}
const listActiveHosts = `-- name: ListActiveHosts :many
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE status NOT IN ('pending', 'offline') ORDER BY created_at
`
// Returns all hosts that have completed registration (not pending/offline).
func (q *Queries) ListActiveHosts(ctx context.Context) ([]Host, error) {
rows, err := q.db.Query(ctx, listActiveHosts)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Host
for rows.Next() {
var i Host
if err := rows.Scan(
&i.ID,
&i.Type,
&i.TeamID,
&i.Provider,
&i.AvailabilityZone,
&i.Arch,
&i.CpuCores,
&i.MemoryMb,
&i.DiskGb,
&i.Address,
&i.Status,
&i.LastHeartbeatAt,
&i.Metadata,
&i.CreatedBy,
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
&i.CertExpiresAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listHosts = `-- name: ListHosts :many
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts ORDER BY created_at DESC
`
func (q *Queries) ListHosts(ctx context.Context) ([]Host, error) {
rows, err := q.db.Query(ctx, listHosts)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Host
for rows.Next() {
var i Host
if err := rows.Scan(
&i.ID,
&i.Type,
&i.TeamID,
&i.Provider,
&i.AvailabilityZone,
&i.Arch,
&i.CpuCores,
&i.MemoryMb,
&i.DiskGb,
&i.Address,
&i.Status,
&i.LastHeartbeatAt,
&i.Metadata,
&i.CreatedBy,
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
&i.CertExpiresAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listHostsByStatus = `-- name: ListHostsByStatus :many
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE status = $1 ORDER BY created_at DESC
`
func (q *Queries) ListHostsByStatus(ctx context.Context, status string) ([]Host, error) {
rows, err := q.db.Query(ctx, listHostsByStatus, status)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Host
for rows.Next() {
var i Host
if err := rows.Scan(
&i.ID,
&i.Type,
&i.TeamID,
&i.Provider,
&i.AvailabilityZone,
&i.Arch,
&i.CpuCores,
&i.MemoryMb,
&i.DiskGb,
&i.Address,
&i.Status,
&i.LastHeartbeatAt,
&i.Metadata,
&i.CreatedBy,
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
&i.CertExpiresAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listHostsByTag = `-- name: ListHostsByTag :many
SELECT h.id, h.type, h.team_id, h.provider, h.availability_zone, h.arch, h.cpu_cores, h.memory_mb, h.disk_gb, h.address, h.status, h.last_heartbeat_at, h.metadata, h.created_by, h.created_at, h.updated_at, h.cert_fingerprint, h.cert_expires_at FROM hosts h
JOIN host_tags ht ON ht.host_id = h.id
WHERE ht.tag = $1
ORDER BY h.created_at DESC
`
func (q *Queries) ListHostsByTag(ctx context.Context, tag string) ([]Host, error) {
rows, err := q.db.Query(ctx, listHostsByTag, tag)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Host
for rows.Next() {
var i Host
if err := rows.Scan(
&i.ID,
&i.Type,
&i.TeamID,
&i.Provider,
&i.AvailabilityZone,
&i.Arch,
&i.CpuCores,
&i.MemoryMb,
&i.DiskGb,
&i.Address,
&i.Status,
&i.LastHeartbeatAt,
&i.Metadata,
&i.CreatedBy,
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
&i.CertExpiresAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listHostsByTeam = `-- name: ListHostsByTeam :many
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE team_id = $1 AND type = 'byoc' ORDER BY created_at DESC
`
func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.UUID) ([]Host, error) {
rows, err := q.db.Query(ctx, listHostsByTeam, teamID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Host
for rows.Next() {
var i Host
if err := rows.Scan(
&i.ID,
&i.Type,
&i.TeamID,
&i.Provider,
&i.AvailabilityZone,
&i.Arch,
&i.CpuCores,
&i.MemoryMb,
&i.DiskGb,
&i.Address,
&i.Status,
&i.LastHeartbeatAt,
&i.Metadata,
&i.CreatedBy,
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
&i.CertExpiresAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listHostsByType = `-- name: ListHostsByType :many
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE type = $1 ORDER BY created_at DESC
`
func (q *Queries) ListHostsByType(ctx context.Context, type_ string) ([]Host, error) {
rows, err := q.db.Query(ctx, listHostsByType, type_)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Host
for rows.Next() {
var i Host
if err := rows.Scan(
&i.ID,
&i.Type,
&i.TeamID,
&i.Provider,
&i.AvailabilityZone,
&i.Arch,
&i.CpuCores,
&i.MemoryMb,
&i.DiskGb,
&i.Address,
&i.Status,
&i.LastHeartbeatAt,
&i.Metadata,
&i.CreatedBy,
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
&i.CertExpiresAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const markHostTokenUsed = `-- name: MarkHostTokenUsed :exec
UPDATE host_tokens SET used_at = NOW() WHERE id = $1
`
func (q *Queries) MarkHostTokenUsed(ctx context.Context, id pgtype.UUID) error {
_, err := q.db.Exec(ctx, markHostTokenUsed, id)
return err
}
const markHostUnreachable = `-- name: MarkHostUnreachable :exec
UPDATE hosts SET status = 'unreachable', updated_at = NOW() WHERE id = $1
`
func (q *Queries) MarkHostUnreachable(ctx context.Context, id pgtype.UUID) error {
_, err := q.db.Exec(ctx, markHostUnreachable, id)
return err
}
const registerHost = `-- name: RegisterHost :execrows
UPDATE hosts
SET arch = $2,
cpu_cores = $3,
memory_mb = $4,
disk_gb = $5,
address = $6,
cert_fingerprint = $7,
cert_expires_at = $8,
status = 'online',
last_heartbeat_at = NOW(),
updated_at = NOW()
WHERE id = $1 AND status = 'pending'
`
type RegisterHostParams struct {
ID pgtype.UUID `json:"id"`
Arch string `json:"arch"`
CpuCores int32 `json:"cpu_cores"`
MemoryMb int32 `json:"memory_mb"`
DiskGb int32 `json:"disk_gb"`
Address string `json:"address"`
CertFingerprint string `json:"cert_fingerprint"`
CertExpiresAt pgtype.Timestamptz `json:"cert_expires_at"`
}
func (q *Queries) RegisterHost(ctx context.Context, arg RegisterHostParams) (int64, error) {
result, err := q.db.Exec(ctx, registerHost,
arg.ID,
arg.Arch,
arg.CpuCores,
arg.MemoryMb,
arg.DiskGb,
arg.Address,
arg.CertFingerprint,
arg.CertExpiresAt,
)
if err != nil {
return 0, err
}
return result.RowsAffected(), nil
}
const removeHostTag = `-- name: RemoveHostTag :exec
DELETE FROM host_tags WHERE host_id = $1 AND tag = $2
`
type RemoveHostTagParams struct {
HostID pgtype.UUID `json:"host_id"`
Tag string `json:"tag"`
}
func (q *Queries) RemoveHostTag(ctx context.Context, arg RemoveHostTagParams) error {
_, err := q.db.Exec(ctx, removeHostTag, arg.HostID, arg.Tag)
return err
}
const updateHostCert = `-- name: UpdateHostCert :exec
UPDATE hosts
SET cert_fingerprint = $2,
cert_expires_at = $3,
updated_at = NOW()
WHERE id = $1
`
type UpdateHostCertParams struct {
ID pgtype.UUID `json:"id"`
CertFingerprint string `json:"cert_fingerprint"`
CertExpiresAt pgtype.Timestamptz `json:"cert_expires_at"`
}
func (q *Queries) UpdateHostCert(ctx context.Context, arg UpdateHostCertParams) error {
_, err := q.db.Exec(ctx, updateHostCert, arg.ID, arg.CertFingerprint, arg.CertExpiresAt)
return err
}
const updateHostHeartbeat = `-- name: UpdateHostHeartbeat :exec
UPDATE hosts SET last_heartbeat_at = NOW(), updated_at = NOW() WHERE id = $1
`
func (q *Queries) UpdateHostHeartbeat(ctx context.Context, id pgtype.UUID) error {
_, err := q.db.Exec(ctx, updateHostHeartbeat, id)
return err
}
const updateHostHeartbeatAndStatus = `-- name: UpdateHostHeartbeatAndStatus :execrows
UPDATE hosts
SET last_heartbeat_at = NOW(),
status = CASE WHEN status = 'unreachable' THEN 'online' ELSE status END,
updated_at = NOW()
WHERE id = $1
`
// Updates last_heartbeat_at and transitions unreachable hosts back to online.
// Returns 0 if no host was found (deleted), which the caller treats as 404.
func (q *Queries) UpdateHostHeartbeatAndStatus(ctx context.Context, id pgtype.UUID) (int64, error) {
result, err := q.db.Exec(ctx, updateHostHeartbeatAndStatus, id)
if err != nil {
return 0, err
}
return result.RowsAffected(), nil
}
const updateHostStatus = `-- name: UpdateHostStatus :exec
UPDATE hosts SET status = $2, updated_at = NOW() WHERE id = $1
`
type UpdateHostStatusParams struct {
ID pgtype.UUID `json:"id"`
Status string `json:"status"`
}
func (q *Queries) UpdateHostStatus(ctx context.Context, arg UpdateHostStatusParams) error {
_, err := q.db.Exec(ctx, updateHostStatus, arg.ID, arg.Status)
return err
}

View File

@ -1,250 +0,0 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: metrics.sql
package db
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
const deleteSandboxMetricPoints = `-- name: DeleteSandboxMetricPoints :exec
DELETE FROM sandbox_metric_points
WHERE sandbox_id = $1
`
func (q *Queries) DeleteSandboxMetricPoints(ctx context.Context, sandboxID pgtype.UUID) error {
_, err := q.db.Exec(ctx, deleteSandboxMetricPoints, sandboxID)
return err
}
const deleteSandboxMetricPointsByTier = `-- name: DeleteSandboxMetricPointsByTier :exec
DELETE FROM sandbox_metric_points
WHERE sandbox_id = $1 AND tier = $2
`
type DeleteSandboxMetricPointsByTierParams struct {
SandboxID pgtype.UUID `json:"sandbox_id"`
Tier string `json:"tier"`
}
func (q *Queries) DeleteSandboxMetricPointsByTier(ctx context.Context, arg DeleteSandboxMetricPointsByTierParams) error {
_, err := q.db.Exec(ctx, deleteSandboxMetricPointsByTier, arg.SandboxID, arg.Tier)
return err
}
const getLiveMetrics = `-- name: GetLiveMetrics :one
SELECT
(COUNT(*) FILTER (WHERE status IN ('running', 'starting')))::INTEGER AS running_count,
(COALESCE(SUM(vcpus) FILTER (WHERE status IN ('running', 'starting')), 0))::INTEGER AS vcpus_reserved,
(COALESCE(SUM(memory_mb) FILTER (WHERE status IN ('running', 'starting')), 0)
+ COALESCE(SUM(CEIL(memory_mb::NUMERIC / 2)) FILTER (WHERE status = 'paused'), 0))::INTEGER AS memory_mb_reserved
FROM sandboxes
WHERE team_id = $1
`
type GetLiveMetricsRow struct {
RunningCount int32 `json:"running_count"`
VcpusReserved int32 `json:"vcpus_reserved"`
MemoryMbReserved int32 `json:"memory_mb_reserved"`
}
// Reads directly from sandboxes for accurate real-time current values.
// CPU reserved = running + starting only (paused VMs release CPU).
// RAM reserved = running + starting + sum(ceil(each_paused/2)) (per-VM ceiling).
func (q *Queries) GetLiveMetrics(ctx context.Context, teamID pgtype.UUID) (GetLiveMetricsRow, error) {
row := q.db.QueryRow(ctx, getLiveMetrics, teamID)
var i GetLiveMetricsRow
err := row.Scan(&i.RunningCount, &i.VcpusReserved, &i.MemoryMbReserved)
return i, err
}
const getPeakMetrics = `-- name: GetPeakMetrics :one
SELECT
COALESCE(MAX(running_count), 0)::INTEGER AS peak_running_count,
COALESCE(MAX(vcpus_reserved), 0)::INTEGER AS peak_vcpus,
COALESCE(MAX(memory_mb_reserved), 0)::INTEGER AS peak_memory_mb
FROM sandbox_metrics_snapshots
WHERE team_id = $1
AND sampled_at > NOW() - INTERVAL '30 days'
`
type GetPeakMetricsRow struct {
PeakRunningCount int32 `json:"peak_running_count"`
PeakVcpus int32 `json:"peak_vcpus"`
PeakMemoryMb int32 `json:"peak_memory_mb"`
}
func (q *Queries) GetPeakMetrics(ctx context.Context, teamID pgtype.UUID) (GetPeakMetricsRow, error) {
row := q.db.QueryRow(ctx, getPeakMetrics, teamID)
var i GetPeakMetricsRow
err := row.Scan(&i.PeakRunningCount, &i.PeakVcpus, &i.PeakMemoryMb)
return i, err
}
const getSandboxMetricPoints = `-- name: GetSandboxMetricPoints :many
SELECT ts, cpu_pct, mem_bytes, disk_bytes
FROM sandbox_metric_points
WHERE sandbox_id = $1 AND tier = $2 AND ts >= $3
ORDER BY ts ASC
`
type GetSandboxMetricPointsParams struct {
SandboxID pgtype.UUID `json:"sandbox_id"`
Tier string `json:"tier"`
Ts int64 `json:"ts"`
}
type GetSandboxMetricPointsRow struct {
Ts int64 `json:"ts"`
CpuPct float64 `json:"cpu_pct"`
MemBytes int64 `json:"mem_bytes"`
DiskBytes int64 `json:"disk_bytes"`
}
func (q *Queries) GetSandboxMetricPoints(ctx context.Context, arg GetSandboxMetricPointsParams) ([]GetSandboxMetricPointsRow, error) {
rows, err := q.db.Query(ctx, getSandboxMetricPoints, arg.SandboxID, arg.Tier, arg.Ts)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GetSandboxMetricPointsRow
for rows.Next() {
var i GetSandboxMetricPointsRow
if err := rows.Scan(
&i.Ts,
&i.CpuPct,
&i.MemBytes,
&i.DiskBytes,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const insertMetricsSnapshot = `-- name: InsertMetricsSnapshot :exec
INSERT INTO sandbox_metrics_snapshots (team_id, running_count, vcpus_reserved, memory_mb_reserved)
VALUES ($1, $2, $3, $4)
`
type InsertMetricsSnapshotParams struct {
TeamID pgtype.UUID `json:"team_id"`
RunningCount int32 `json:"running_count"`
VcpusReserved int32 `json:"vcpus_reserved"`
MemoryMbReserved int32 `json:"memory_mb_reserved"`
}
func (q *Queries) InsertMetricsSnapshot(ctx context.Context, arg InsertMetricsSnapshotParams) error {
_, err := q.db.Exec(ctx, insertMetricsSnapshot,
arg.TeamID,
arg.RunningCount,
arg.VcpusReserved,
arg.MemoryMbReserved,
)
return err
}
const insertSandboxMetricPoint = `-- name: InsertSandboxMetricPoint :exec
INSERT INTO sandbox_metric_points (sandbox_id, tier, ts, cpu_pct, mem_bytes, disk_bytes)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (sandbox_id, tier, ts) DO NOTHING
`
type InsertSandboxMetricPointParams struct {
SandboxID pgtype.UUID `json:"sandbox_id"`
Tier string `json:"tier"`
Ts int64 `json:"ts"`
CpuPct float64 `json:"cpu_pct"`
MemBytes int64 `json:"mem_bytes"`
DiskBytes int64 `json:"disk_bytes"`
}
func (q *Queries) InsertSandboxMetricPoint(ctx context.Context, arg InsertSandboxMetricPointParams) error {
_, err := q.db.Exec(ctx, insertSandboxMetricPoint,
arg.SandboxID,
arg.Tier,
arg.Ts,
arg.CpuPct,
arg.MemBytes,
arg.DiskBytes,
)
return err
}
const pruneOldMetrics = `-- name: PruneOldMetrics :exec
DELETE FROM sandbox_metrics_snapshots
WHERE sampled_at < NOW() - INTERVAL '60 days'
`
func (q *Queries) PruneOldMetrics(ctx context.Context) error {
_, err := q.db.Exec(ctx, pruneOldMetrics)
return err
}
const pruneSandboxMetricPoints = `-- name: PruneSandboxMetricPoints :exec
DELETE FROM sandbox_metric_points
WHERE ts < EXTRACT(EPOCH FROM NOW() - INTERVAL '30 days')::BIGINT
`
// Remove metric points older than 30 days for destroyed sandboxes.
func (q *Queries) PruneSandboxMetricPoints(ctx context.Context) error {
_, err := q.db.Exec(ctx, pruneSandboxMetricPoints)
return err
}
const sampleSandboxMetrics = `-- name: SampleSandboxMetrics :many
SELECT
team_id,
(COUNT(*) FILTER (WHERE status IN ('running', 'starting')))::INTEGER AS running_count,
(COALESCE(SUM(vcpus) FILTER (WHERE status IN ('running', 'starting')), 0))::INTEGER AS vcpus_reserved,
(COALESCE(SUM(memory_mb) FILTER (WHERE status IN ('running', 'starting')), 0)
+ COALESCE(SUM(CEIL(memory_mb::NUMERIC / 2)) FILTER (WHERE status = 'paused'), 0))::INTEGER AS memory_mb_reserved
FROM sandboxes
GROUP BY team_id
`
type SampleSandboxMetricsRow struct {
TeamID pgtype.UUID `json:"team_id"`
RunningCount int32 `json:"running_count"`
VcpusReserved int32 `json:"vcpus_reserved"`
MemoryMbReserved int32 `json:"memory_mb_reserved"`
}
// Aggregates per-team resource usage from the live sandboxes table.
// Groups by all teams that have any sandbox row (including stopped) so that
// zero-value snapshots are recorded when all capsules are stopped, keeping the
// time-series charts continuous rather than trailing off into empty space.
// CPU reserved = running + starting only (paused VMs release CPU).
// RAM reserved = running + starting + sum(ceil(each_paused/2)) (per-VM ceiling).
func (q *Queries) SampleSandboxMetrics(ctx context.Context) ([]SampleSandboxMetricsRow, error) {
rows, err := q.db.Query(ctx, sampleSandboxMetrics)
if err != nil {
return nil, err
}
defer rows.Close()
var items []SampleSandboxMetricsRow
for rows.Next() {
var i SampleSandboxMetricsRow
if err := rows.Scan(
&i.TeamID,
&i.RunningCount,
&i.VcpusReserved,
&i.MemoryMbReserved,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}

View File

@ -1,204 +0,0 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
package db
import (
"github.com/jackc/pgx/v5/pgtype"
)
type AdminPermission struct {
ID pgtype.UUID `json:"id"`
UserID pgtype.UUID `json:"user_id"`
Permission string `json:"permission"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
}
type AuditLog struct {
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
ActorType string `json:"actor_type"`
ActorID pgtype.Text `json:"actor_id"`
ActorName string `json:"actor_name"`
ResourceType string `json:"resource_type"`
ResourceID pgtype.Text `json:"resource_id"`
Action string `json:"action"`
Scope string `json:"scope"`
Status string `json:"status"`
Metadata []byte `json:"metadata"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
}
type Channel struct {
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
Name string `json:"name"`
Provider string `json:"provider"`
Config []byte `json:"config"`
EventTypes []string `json:"event_types"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
}
type Host struct {
ID pgtype.UUID `json:"id"`
Type string `json:"type"`
TeamID pgtype.UUID `json:"team_id"`
Provider string `json:"provider"`
AvailabilityZone string `json:"availability_zone"`
Arch string `json:"arch"`
CpuCores int32 `json:"cpu_cores"`
MemoryMb int32 `json:"memory_mb"`
DiskGb int32 `json:"disk_gb"`
Address string `json:"address"`
Status string `json:"status"`
LastHeartbeatAt pgtype.Timestamptz `json:"last_heartbeat_at"`
Metadata []byte `json:"metadata"`
CreatedBy pgtype.UUID `json:"created_by"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
CertFingerprint string `json:"cert_fingerprint"`
CertExpiresAt pgtype.Timestamptz `json:"cert_expires_at"`
}
type HostRefreshToken struct {
ID pgtype.UUID `json:"id"`
HostID pgtype.UUID `json:"host_id"`
TokenHash string `json:"token_hash"`
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
RevokedAt pgtype.Timestamptz `json:"revoked_at"`
}
type HostTag struct {
HostID pgtype.UUID `json:"host_id"`
Tag string `json:"tag"`
}
type HostToken struct {
ID pgtype.UUID `json:"id"`
HostID pgtype.UUID `json:"host_id"`
CreatedBy pgtype.UUID `json:"created_by"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
UsedAt pgtype.Timestamptz `json:"used_at"`
}
type OauthProvider struct {
Provider string `json:"provider"`
ProviderID string `json:"provider_id"`
UserID pgtype.UUID `json:"user_id"`
Email string `json:"email"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
}
type Sandbox struct {
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
HostID pgtype.UUID `json:"host_id"`
Template string `json:"template"`
Status string `json:"status"`
Vcpus int32 `json:"vcpus"`
MemoryMb int32 `json:"memory_mb"`
TimeoutSec int32 `json:"timeout_sec"`
DiskSizeMb int32 `json:"disk_size_mb"`
GuestIp string `json:"guest_ip"`
HostIp string `json:"host_ip"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
StartedAt pgtype.Timestamptz `json:"started_at"`
LastActiveAt pgtype.Timestamptz `json:"last_active_at"`
LastUpdated pgtype.Timestamptz `json:"last_updated"`
TemplateID pgtype.UUID `json:"template_id"`
TemplateTeamID pgtype.UUID `json:"template_team_id"`
}
type SandboxMetricPoint struct {
SandboxID pgtype.UUID `json:"sandbox_id"`
Tier string `json:"tier"`
Ts int64 `json:"ts"`
CpuPct float64 `json:"cpu_pct"`
MemBytes int64 `json:"mem_bytes"`
DiskBytes int64 `json:"disk_bytes"`
}
type SandboxMetricsSnapshot struct {
ID int64 `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
SampledAt pgtype.Timestamptz `json:"sampled_at"`
RunningCount int32 `json:"running_count"`
VcpusReserved int32 `json:"vcpus_reserved"`
MemoryMbReserved int32 `json:"memory_mb_reserved"`
}
type Team struct {
ID pgtype.UUID `json:"id"`
Name string `json:"name"`
Slug string `json:"slug"`
IsByoc bool `json:"is_byoc"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
DeletedAt pgtype.Timestamptz `json:"deleted_at"`
}
type TeamApiKey struct {
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
Name string `json:"name"`
KeyHash string `json:"key_hash"`
KeyPrefix string `json:"key_prefix"`
CreatedBy pgtype.UUID `json:"created_by"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
LastUsed pgtype.Timestamptz `json:"last_used"`
}
type Template struct {
Name string `json:"name"`
Type string `json:"type"`
Vcpus int32 `json:"vcpus"`
MemoryMb int32 `json:"memory_mb"`
SizeBytes int64 `json:"size_bytes"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
TeamID pgtype.UUID `json:"team_id"`
ID pgtype.UUID `json:"id"`
}
type TemplateBuild struct {
ID pgtype.UUID `json:"id"`
Name string `json:"name"`
BaseTemplate string `json:"base_template"`
Recipe []byte `json:"recipe"`
Healthcheck string `json:"healthcheck"`
Vcpus int32 `json:"vcpus"`
MemoryMb int32 `json:"memory_mb"`
Status string `json:"status"`
CurrentStep int32 `json:"current_step"`
TotalSteps int32 `json:"total_steps"`
Logs []byte `json:"logs"`
Error string `json:"error"`
SandboxID pgtype.UUID `json:"sandbox_id"`
HostID pgtype.UUID `json:"host_id"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
StartedAt pgtype.Timestamptz `json:"started_at"`
CompletedAt pgtype.Timestamptz `json:"completed_at"`
TemplateID pgtype.UUID `json:"template_id"`
TeamID pgtype.UUID `json:"team_id"`
SkipPrePost bool `json:"skip_pre_post"`
}
type User struct {
ID pgtype.UUID `json:"id"`
Email string `json:"email"`
PasswordHash pgtype.Text `json:"password_hash"`
Name string `json:"name"`
IsAdmin bool `json:"is_admin"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
}
type UsersTeam struct {
UserID pgtype.UUID `json:"user_id"`
TeamID pgtype.UUID `json:"team_id"`
IsDefault bool `json:"is_default"`
Role string `json:"role"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
}

View File

@ -1,57 +0,0 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: oauth.sql
package db
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
const getOAuthProvider = `-- name: GetOAuthProvider :one
SELECT provider, provider_id, user_id, email, created_at FROM oauth_providers
WHERE provider = $1 AND provider_id = $2
`
type GetOAuthProviderParams struct {
Provider string `json:"provider"`
ProviderID string `json:"provider_id"`
}
func (q *Queries) GetOAuthProvider(ctx context.Context, arg GetOAuthProviderParams) (OauthProvider, error) {
row := q.db.QueryRow(ctx, getOAuthProvider, arg.Provider, arg.ProviderID)
var i OauthProvider
err := row.Scan(
&i.Provider,
&i.ProviderID,
&i.UserID,
&i.Email,
&i.CreatedAt,
)
return i, err
}
const insertOAuthProvider = `-- name: InsertOAuthProvider :exec
INSERT INTO oauth_providers (provider, provider_id, user_id, email)
VALUES ($1, $2, $3, $4)
`
type InsertOAuthProviderParams struct {
Provider string `json:"provider"`
ProviderID string `json:"provider_id"`
UserID pgtype.UUID `json:"user_id"`
Email string `json:"email"`
}
func (q *Queries) InsertOAuthProvider(ctx context.Context, arg InsertOAuthProviderParams) error {
_, err := q.db.Exec(ctx, insertOAuthProvider,
arg.Provider,
arg.ProviderID,
arg.UserID,
arg.Email,
)
return err
}

View File

@ -1,487 +0,0 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: sandboxes.sql
package db
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
const bulkRestoreRunning = `-- name: BulkRestoreRunning :exec
UPDATE sandboxes
SET status = 'running',
last_updated = NOW()
WHERE id = ANY($1::uuid[]) AND status = 'missing'
`
// Called by the reconciler when a host comes back online and its sandboxes are
// confirmed alive. Restores only sandboxes that are in 'missing' state.
func (q *Queries) BulkRestoreRunning(ctx context.Context, dollar_1 []pgtype.UUID) error {
_, err := q.db.Exec(ctx, bulkRestoreRunning, dollar_1)
return err
}
const bulkUpdateStatusByIDs = `-- name: BulkUpdateStatusByIDs :exec
UPDATE sandboxes
SET status = $2,
last_updated = NOW()
WHERE id = ANY($1::uuid[])
`
type BulkUpdateStatusByIDsParams struct {
Column1 []pgtype.UUID `json:"column_1"`
Status string `json:"status"`
}
func (q *Queries) BulkUpdateStatusByIDs(ctx context.Context, arg BulkUpdateStatusByIDsParams) error {
_, err := q.db.Exec(ctx, bulkUpdateStatusByIDs, arg.Column1, arg.Status)
return err
}
const getSandbox = `-- name: GetSandbox :one
SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes WHERE id = $1
`
func (q *Queries) GetSandbox(ctx context.Context, id pgtype.UUID) (Sandbox, error) {
row := q.db.QueryRow(ctx, getSandbox, id)
var i Sandbox
err := row.Scan(
&i.ID,
&i.TeamID,
&i.HostID,
&i.Template,
&i.Status,
&i.Vcpus,
&i.MemoryMb,
&i.TimeoutSec,
&i.DiskSizeMb,
&i.GuestIp,
&i.HostIp,
&i.CreatedAt,
&i.StartedAt,
&i.LastActiveAt,
&i.LastUpdated,
&i.TemplateID,
&i.TemplateTeamID,
)
return i, err
}
const getSandboxByTeam = `-- name: GetSandboxByTeam :one
SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes WHERE id = $1 AND team_id = $2
`
type GetSandboxByTeamParams struct {
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
}
func (q *Queries) GetSandboxByTeam(ctx context.Context, arg GetSandboxByTeamParams) (Sandbox, error) {
row := q.db.QueryRow(ctx, getSandboxByTeam, arg.ID, arg.TeamID)
var i Sandbox
err := row.Scan(
&i.ID,
&i.TeamID,
&i.HostID,
&i.Template,
&i.Status,
&i.Vcpus,
&i.MemoryMb,
&i.TimeoutSec,
&i.DiskSizeMb,
&i.GuestIp,
&i.HostIp,
&i.CreatedAt,
&i.StartedAt,
&i.LastActiveAt,
&i.LastUpdated,
&i.TemplateID,
&i.TemplateTeamID,
)
return i, err
}
const getSandboxProxyTarget = `-- name: GetSandboxProxyTarget :one
SELECT s.status, h.address AS host_address
FROM sandboxes s
JOIN hosts h ON h.id = s.host_id
WHERE s.id = $1 AND s.team_id = $2
`
type GetSandboxProxyTargetParams struct {
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
}
type GetSandboxProxyTargetRow struct {
Status string `json:"status"`
HostAddress string `json:"host_address"`
}
// Returns the sandbox status and its host's address in one query.
// Used by SandboxProxyWrapper to avoid two round-trips.
func (q *Queries) GetSandboxProxyTarget(ctx context.Context, arg GetSandboxProxyTargetParams) (GetSandboxProxyTargetRow, error) {
row := q.db.QueryRow(ctx, getSandboxProxyTarget, arg.ID, arg.TeamID)
var i GetSandboxProxyTargetRow
err := row.Scan(&i.Status, &i.HostAddress)
return i, err
}
const insertSandbox = `-- name: InsertSandbox :one
INSERT INTO sandboxes (id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, template_id, template_team_id)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
RETURNING id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id
`
type InsertSandboxParams struct {
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
HostID pgtype.UUID `json:"host_id"`
Template string `json:"template"`
Status string `json:"status"`
Vcpus int32 `json:"vcpus"`
MemoryMb int32 `json:"memory_mb"`
TimeoutSec int32 `json:"timeout_sec"`
DiskSizeMb int32 `json:"disk_size_mb"`
TemplateID pgtype.UUID `json:"template_id"`
TemplateTeamID pgtype.UUID `json:"template_team_id"`
}
func (q *Queries) InsertSandbox(ctx context.Context, arg InsertSandboxParams) (Sandbox, error) {
row := q.db.QueryRow(ctx, insertSandbox,
arg.ID,
arg.TeamID,
arg.HostID,
arg.Template,
arg.Status,
arg.Vcpus,
arg.MemoryMb,
arg.TimeoutSec,
arg.DiskSizeMb,
arg.TemplateID,
arg.TemplateTeamID,
)
var i Sandbox
err := row.Scan(
&i.ID,
&i.TeamID,
&i.HostID,
&i.Template,
&i.Status,
&i.Vcpus,
&i.MemoryMb,
&i.TimeoutSec,
&i.DiskSizeMb,
&i.GuestIp,
&i.HostIp,
&i.CreatedAt,
&i.StartedAt,
&i.LastActiveAt,
&i.LastUpdated,
&i.TemplateID,
&i.TemplateTeamID,
)
return i, err
}
const listActiveSandboxesByTeam = `-- name: ListActiveSandboxesByTeam :many
SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes
WHERE team_id = $1 AND status IN ('running', 'paused', 'starting')
ORDER BY created_at DESC
`
func (q *Queries) ListActiveSandboxesByTeam(ctx context.Context, teamID pgtype.UUID) ([]Sandbox, error) {
rows, err := q.db.Query(ctx, listActiveSandboxesByTeam, teamID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Sandbox
for rows.Next() {
var i Sandbox
if err := rows.Scan(
&i.ID,
&i.TeamID,
&i.HostID,
&i.Template,
&i.Status,
&i.Vcpus,
&i.MemoryMb,
&i.TimeoutSec,
&i.DiskSizeMb,
&i.GuestIp,
&i.HostIp,
&i.CreatedAt,
&i.StartedAt,
&i.LastActiveAt,
&i.LastUpdated,
&i.TemplateID,
&i.TemplateTeamID,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listSandboxes = `-- name: ListSandboxes :many
SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes ORDER BY created_at DESC
`
func (q *Queries) ListSandboxes(ctx context.Context) ([]Sandbox, error) {
rows, err := q.db.Query(ctx, listSandboxes)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Sandbox
for rows.Next() {
var i Sandbox
if err := rows.Scan(
&i.ID,
&i.TeamID,
&i.HostID,
&i.Template,
&i.Status,
&i.Vcpus,
&i.MemoryMb,
&i.TimeoutSec,
&i.DiskSizeMb,
&i.GuestIp,
&i.HostIp,
&i.CreatedAt,
&i.StartedAt,
&i.LastActiveAt,
&i.LastUpdated,
&i.TemplateID,
&i.TemplateTeamID,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listSandboxesByHostAndStatus = `-- name: ListSandboxesByHostAndStatus :many
SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes
WHERE host_id = $1 AND status = ANY($2::text[])
ORDER BY created_at DESC
`
type ListSandboxesByHostAndStatusParams struct {
HostID pgtype.UUID `json:"host_id"`
Column2 []string `json:"column_2"`
}
func (q *Queries) ListSandboxesByHostAndStatus(ctx context.Context, arg ListSandboxesByHostAndStatusParams) ([]Sandbox, error) {
rows, err := q.db.Query(ctx, listSandboxesByHostAndStatus, arg.HostID, arg.Column2)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Sandbox
for rows.Next() {
var i Sandbox
if err := rows.Scan(
&i.ID,
&i.TeamID,
&i.HostID,
&i.Template,
&i.Status,
&i.Vcpus,
&i.MemoryMb,
&i.TimeoutSec,
&i.DiskSizeMb,
&i.GuestIp,
&i.HostIp,
&i.CreatedAt,
&i.StartedAt,
&i.LastActiveAt,
&i.LastUpdated,
&i.TemplateID,
&i.TemplateTeamID,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listSandboxesByTeam = `-- name: ListSandboxesByTeam :many
SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id FROM sandboxes
WHERE team_id = $1 AND status NOT IN ('stopped', 'error')
ORDER BY created_at DESC
`
func (q *Queries) ListSandboxesByTeam(ctx context.Context, teamID pgtype.UUID) ([]Sandbox, error) {
rows, err := q.db.Query(ctx, listSandboxesByTeam, teamID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Sandbox
for rows.Next() {
var i Sandbox
if err := rows.Scan(
&i.ID,
&i.TeamID,
&i.HostID,
&i.Template,
&i.Status,
&i.Vcpus,
&i.MemoryMb,
&i.TimeoutSec,
&i.DiskSizeMb,
&i.GuestIp,
&i.HostIp,
&i.CreatedAt,
&i.StartedAt,
&i.LastActiveAt,
&i.LastUpdated,
&i.TemplateID,
&i.TemplateTeamID,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const markSandboxesMissingByHost = `-- name: MarkSandboxesMissingByHost :exec
UPDATE sandboxes
SET status = 'missing',
last_updated = NOW()
WHERE host_id = $1 AND status IN ('running', 'starting', 'pending')
`
// Called when the host monitor marks a host unreachable.
// Marks running/starting/pending sandboxes on that host as 'missing' so users see
// the sandbox is not currently reachable, without permanently losing the record.
func (q *Queries) MarkSandboxesMissingByHost(ctx context.Context, hostID pgtype.UUID) error {
_, err := q.db.Exec(ctx, markSandboxesMissingByHost, hostID)
return err
}
const updateLastActive = `-- name: UpdateLastActive :exec
UPDATE sandboxes
SET last_active_at = $2,
last_updated = NOW()
WHERE id = $1
`
type UpdateLastActiveParams struct {
ID pgtype.UUID `json:"id"`
LastActiveAt pgtype.Timestamptz `json:"last_active_at"`
}
func (q *Queries) UpdateLastActive(ctx context.Context, arg UpdateLastActiveParams) error {
_, err := q.db.Exec(ctx, updateLastActive, arg.ID, arg.LastActiveAt)
return err
}
const updateSandboxRunning = `-- name: UpdateSandboxRunning :one
UPDATE sandboxes
SET status = 'running',
host_ip = $2,
guest_ip = $3,
started_at = $4,
last_active_at = $4,
last_updated = NOW()
WHERE id = $1
RETURNING id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id
`
type UpdateSandboxRunningParams struct {
ID pgtype.UUID `json:"id"`
HostIp string `json:"host_ip"`
GuestIp string `json:"guest_ip"`
StartedAt pgtype.Timestamptz `json:"started_at"`
}
func (q *Queries) UpdateSandboxRunning(ctx context.Context, arg UpdateSandboxRunningParams) (Sandbox, error) {
row := q.db.QueryRow(ctx, updateSandboxRunning,
arg.ID,
arg.HostIp,
arg.GuestIp,
arg.StartedAt,
)
var i Sandbox
err := row.Scan(
&i.ID,
&i.TeamID,
&i.HostID,
&i.Template,
&i.Status,
&i.Vcpus,
&i.MemoryMb,
&i.TimeoutSec,
&i.DiskSizeMb,
&i.GuestIp,
&i.HostIp,
&i.CreatedAt,
&i.StartedAt,
&i.LastActiveAt,
&i.LastUpdated,
&i.TemplateID,
&i.TemplateTeamID,
)
return i, err
}
const updateSandboxStatus = `-- name: UpdateSandboxStatus :one
UPDATE sandboxes
SET status = $2,
last_updated = NOW()
WHERE id = $1
RETURNING id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id
`
type UpdateSandboxStatusParams struct {
ID pgtype.UUID `json:"id"`
Status string `json:"status"`
}
func (q *Queries) UpdateSandboxStatus(ctx context.Context, arg UpdateSandboxStatusParams) (Sandbox, error) {
row := q.db.QueryRow(ctx, updateSandboxStatus, arg.ID, arg.Status)
var i Sandbox
err := row.Scan(
&i.ID,
&i.TeamID,
&i.HostID,
&i.Template,
&i.Status,
&i.Vcpus,
&i.MemoryMb,
&i.TimeoutSec,
&i.DiskSizeMb,
&i.GuestIp,
&i.HostIp,
&i.CreatedAt,
&i.StartedAt,
&i.LastActiveAt,
&i.LastUpdated,
&i.TemplateID,
&i.TemplateTeamID,
)
return i, err
}

View File

@ -1,324 +0,0 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: teams.sql
package db
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
const deleteTeamMember = `-- name: DeleteTeamMember :exec
DELETE FROM users_teams WHERE team_id = $1 AND user_id = $2
`
type DeleteTeamMemberParams struct {
TeamID pgtype.UUID `json:"team_id"`
UserID pgtype.UUID `json:"user_id"`
}
func (q *Queries) DeleteTeamMember(ctx context.Context, arg DeleteTeamMemberParams) error {
_, err := q.db.Exec(ctx, deleteTeamMember, arg.TeamID, arg.UserID)
return err
}
const getBYOCTeams = `-- name: GetBYOCTeams :many
SELECT id, name, slug, is_byoc, created_at, deleted_at FROM teams WHERE is_byoc = TRUE AND deleted_at IS NULL ORDER BY created_at
`
func (q *Queries) GetBYOCTeams(ctx context.Context) ([]Team, error) {
rows, err := q.db.Query(ctx, getBYOCTeams)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Team
for rows.Next() {
var i Team
if err := rows.Scan(
&i.ID,
&i.Name,
&i.Slug,
&i.IsByoc,
&i.CreatedAt,
&i.DeletedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getDefaultTeamForUser = `-- name: GetDefaultTeamForUser :one
SELECT t.id, t.name, t.slug, t.is_byoc, t.created_at, t.deleted_at FROM teams t
JOIN users_teams ut ON ut.team_id = t.id
WHERE ut.user_id = $1 AND ut.is_default = TRUE AND t.deleted_at IS NULL
LIMIT 1
`
func (q *Queries) GetDefaultTeamForUser(ctx context.Context, userID pgtype.UUID) (Team, error) {
row := q.db.QueryRow(ctx, getDefaultTeamForUser, userID)
var i Team
err := row.Scan(
&i.ID,
&i.Name,
&i.Slug,
&i.IsByoc,
&i.CreatedAt,
&i.DeletedAt,
)
return i, err
}
const getTeam = `-- name: GetTeam :one
SELECT id, name, slug, is_byoc, created_at, deleted_at FROM teams WHERE id = $1
`
func (q *Queries) GetTeam(ctx context.Context, id pgtype.UUID) (Team, error) {
row := q.db.QueryRow(ctx, getTeam, id)
var i Team
err := row.Scan(
&i.ID,
&i.Name,
&i.Slug,
&i.IsByoc,
&i.CreatedAt,
&i.DeletedAt,
)
return i, err
}
const getTeamBySlug = `-- name: GetTeamBySlug :one
SELECT id, name, slug, is_byoc, created_at, deleted_at FROM teams WHERE slug = $1 AND deleted_at IS NULL
`
func (q *Queries) GetTeamBySlug(ctx context.Context, slug string) (Team, error) {
row := q.db.QueryRow(ctx, getTeamBySlug, slug)
var i Team
err := row.Scan(
&i.ID,
&i.Name,
&i.Slug,
&i.IsByoc,
&i.CreatedAt,
&i.DeletedAt,
)
return i, err
}
const getTeamMembers = `-- name: GetTeamMembers :many
SELECT u.id, u.name, u.email, ut.role, ut.created_at AS joined_at
FROM users_teams ut
JOIN users u ON u.id = ut.user_id
WHERE ut.team_id = $1
ORDER BY ut.created_at
`
type GetTeamMembersRow struct {
ID pgtype.UUID `json:"id"`
Name string `json:"name"`
Email string `json:"email"`
Role string `json:"role"`
JoinedAt pgtype.Timestamptz `json:"joined_at"`
}
func (q *Queries) GetTeamMembers(ctx context.Context, teamID pgtype.UUID) ([]GetTeamMembersRow, error) {
rows, err := q.db.Query(ctx, getTeamMembers, teamID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GetTeamMembersRow
for rows.Next() {
var i GetTeamMembersRow
if err := rows.Scan(
&i.ID,
&i.Name,
&i.Email,
&i.Role,
&i.JoinedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getTeamMembership = `-- name: GetTeamMembership :one
SELECT user_id, team_id, is_default, role, created_at FROM users_teams WHERE user_id = $1 AND team_id = $2
`
type GetTeamMembershipParams struct {
UserID pgtype.UUID `json:"user_id"`
TeamID pgtype.UUID `json:"team_id"`
}
func (q *Queries) GetTeamMembership(ctx context.Context, arg GetTeamMembershipParams) (UsersTeam, error) {
row := q.db.QueryRow(ctx, getTeamMembership, arg.UserID, arg.TeamID)
var i UsersTeam
err := row.Scan(
&i.UserID,
&i.TeamID,
&i.IsDefault,
&i.Role,
&i.CreatedAt,
)
return i, err
}
const getTeamsForUser = `-- name: GetTeamsForUser :many
SELECT t.id, t.name, t.slug, t.is_byoc, t.created_at, t.deleted_at, ut.role
FROM teams t
JOIN users_teams ut ON ut.team_id = t.id
WHERE ut.user_id = $1 AND t.deleted_at IS NULL
ORDER BY ut.created_at
`
type GetTeamsForUserRow struct {
ID pgtype.UUID `json:"id"`
Name string `json:"name"`
Slug string `json:"slug"`
IsByoc bool `json:"is_byoc"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
DeletedAt pgtype.Timestamptz `json:"deleted_at"`
Role string `json:"role"`
}
func (q *Queries) GetTeamsForUser(ctx context.Context, userID pgtype.UUID) ([]GetTeamsForUserRow, error) {
rows, err := q.db.Query(ctx, getTeamsForUser, userID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GetTeamsForUserRow
for rows.Next() {
var i GetTeamsForUserRow
if err := rows.Scan(
&i.ID,
&i.Name,
&i.Slug,
&i.IsByoc,
&i.CreatedAt,
&i.DeletedAt,
&i.Role,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const insertTeam = `-- name: InsertTeam :one
INSERT INTO teams (id, name, slug)
VALUES ($1, $2, $3)
RETURNING id, name, slug, is_byoc, created_at, deleted_at
`
type InsertTeamParams struct {
ID pgtype.UUID `json:"id"`
Name string `json:"name"`
Slug string `json:"slug"`
}
func (q *Queries) InsertTeam(ctx context.Context, arg InsertTeamParams) (Team, error) {
row := q.db.QueryRow(ctx, insertTeam, arg.ID, arg.Name, arg.Slug)
var i Team
err := row.Scan(
&i.ID,
&i.Name,
&i.Slug,
&i.IsByoc,
&i.CreatedAt,
&i.DeletedAt,
)
return i, err
}
const insertTeamMember = `-- name: InsertTeamMember :exec
INSERT INTO users_teams (user_id, team_id, is_default, role)
VALUES ($1, $2, $3, $4)
`
type InsertTeamMemberParams struct {
UserID pgtype.UUID `json:"user_id"`
TeamID pgtype.UUID `json:"team_id"`
IsDefault bool `json:"is_default"`
Role string `json:"role"`
}
func (q *Queries) InsertTeamMember(ctx context.Context, arg InsertTeamMemberParams) error {
_, err := q.db.Exec(ctx, insertTeamMember,
arg.UserID,
arg.TeamID,
arg.IsDefault,
arg.Role,
)
return err
}
const setTeamBYOC = `-- name: SetTeamBYOC :exec
UPDATE teams SET is_byoc = $2 WHERE id = $1
`
type SetTeamBYOCParams struct {
ID pgtype.UUID `json:"id"`
IsByoc bool `json:"is_byoc"`
}
func (q *Queries) SetTeamBYOC(ctx context.Context, arg SetTeamBYOCParams) error {
_, err := q.db.Exec(ctx, setTeamBYOC, arg.ID, arg.IsByoc)
return err
}
const softDeleteTeam = `-- name: SoftDeleteTeam :exec
UPDATE teams SET deleted_at = NOW() WHERE id = $1
`
func (q *Queries) SoftDeleteTeam(ctx context.Context, id pgtype.UUID) error {
_, err := q.db.Exec(ctx, softDeleteTeam, id)
return err
}
const updateMemberRole = `-- name: UpdateMemberRole :exec
UPDATE users_teams SET role = $3 WHERE team_id = $1 AND user_id = $2
`
type UpdateMemberRoleParams struct {
TeamID pgtype.UUID `json:"team_id"`
UserID pgtype.UUID `json:"user_id"`
Role string `json:"role"`
}
func (q *Queries) UpdateMemberRole(ctx context.Context, arg UpdateMemberRoleParams) error {
_, err := q.db.Exec(ctx, updateMemberRole, arg.TeamID, arg.UserID, arg.Role)
return err
}
const updateTeamName = `-- name: UpdateTeamName :exec
UPDATE teams SET name = $2 WHERE id = $1 AND deleted_at IS NULL
`
type UpdateTeamNameParams struct {
ID pgtype.UUID `json:"id"`
Name string `json:"name"`
}
func (q *Queries) UpdateTeamName(ctx context.Context, arg UpdateTeamNameParams) error {
_, err := q.db.Exec(ctx, updateTeamName, arg.ID, arg.Name)
return err
}

View File

@ -1,241 +0,0 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: template_builds.sql
package db
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
const getTemplateBuild = `-- name: GetTemplateBuild :one
SELECT id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at, template_id, team_id, skip_pre_post FROM template_builds WHERE id = $1
`
func (q *Queries) GetTemplateBuild(ctx context.Context, id pgtype.UUID) (TemplateBuild, error) {
row := q.db.QueryRow(ctx, getTemplateBuild, id)
var i TemplateBuild
err := row.Scan(
&i.ID,
&i.Name,
&i.BaseTemplate,
&i.Recipe,
&i.Healthcheck,
&i.Vcpus,
&i.MemoryMb,
&i.Status,
&i.CurrentStep,
&i.TotalSteps,
&i.Logs,
&i.Error,
&i.SandboxID,
&i.HostID,
&i.CreatedAt,
&i.StartedAt,
&i.CompletedAt,
&i.TemplateID,
&i.TeamID,
&i.SkipPrePost,
)
return i, err
}
const insertTemplateBuild = `-- name: InsertTemplateBuild :one
INSERT INTO template_builds (id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, total_steps, template_id, team_id, skip_pre_post)
VALUES ($1, $2, $3, $4, $5, $6, $7, 'pending', $8, $9, $10, $11)
RETURNING id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at, template_id, team_id, skip_pre_post
`
type InsertTemplateBuildParams struct {
ID pgtype.UUID `json:"id"`
Name string `json:"name"`
BaseTemplate string `json:"base_template"`
Recipe []byte `json:"recipe"`
Healthcheck string `json:"healthcheck"`
Vcpus int32 `json:"vcpus"`
MemoryMb int32 `json:"memory_mb"`
TotalSteps int32 `json:"total_steps"`
TemplateID pgtype.UUID `json:"template_id"`
TeamID pgtype.UUID `json:"team_id"`
SkipPrePost bool `json:"skip_pre_post"`
}
func (q *Queries) InsertTemplateBuild(ctx context.Context, arg InsertTemplateBuildParams) (TemplateBuild, error) {
row := q.db.QueryRow(ctx, insertTemplateBuild,
arg.ID,
arg.Name,
arg.BaseTemplate,
arg.Recipe,
arg.Healthcheck,
arg.Vcpus,
arg.MemoryMb,
arg.TotalSteps,
arg.TemplateID,
arg.TeamID,
arg.SkipPrePost,
)
var i TemplateBuild
err := row.Scan(
&i.ID,
&i.Name,
&i.BaseTemplate,
&i.Recipe,
&i.Healthcheck,
&i.Vcpus,
&i.MemoryMb,
&i.Status,
&i.CurrentStep,
&i.TotalSteps,
&i.Logs,
&i.Error,
&i.SandboxID,
&i.HostID,
&i.CreatedAt,
&i.StartedAt,
&i.CompletedAt,
&i.TemplateID,
&i.TeamID,
&i.SkipPrePost,
)
return i, err
}
const listTemplateBuilds = `-- name: ListTemplateBuilds :many
SELECT id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at, template_id, team_id, skip_pre_post FROM template_builds ORDER BY created_at DESC
`
func (q *Queries) ListTemplateBuilds(ctx context.Context) ([]TemplateBuild, error) {
rows, err := q.db.Query(ctx, listTemplateBuilds)
if err != nil {
return nil, err
}
defer rows.Close()
var items []TemplateBuild
for rows.Next() {
var i TemplateBuild
if err := rows.Scan(
&i.ID,
&i.Name,
&i.BaseTemplate,
&i.Recipe,
&i.Healthcheck,
&i.Vcpus,
&i.MemoryMb,
&i.Status,
&i.CurrentStep,
&i.TotalSteps,
&i.Logs,
&i.Error,
&i.SandboxID,
&i.HostID,
&i.CreatedAt,
&i.StartedAt,
&i.CompletedAt,
&i.TemplateID,
&i.TeamID,
&i.SkipPrePost,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const updateBuildError = `-- name: UpdateBuildError :exec
UPDATE template_builds
SET error = $2, status = 'failed', completed_at = NOW()
WHERE id = $1
`
type UpdateBuildErrorParams struct {
ID pgtype.UUID `json:"id"`
Error string `json:"error"`
}
func (q *Queries) UpdateBuildError(ctx context.Context, arg UpdateBuildErrorParams) error {
_, err := q.db.Exec(ctx, updateBuildError, arg.ID, arg.Error)
return err
}
const updateBuildProgress = `-- name: UpdateBuildProgress :exec
UPDATE template_builds
SET current_step = $2, logs = $3
WHERE id = $1
`
type UpdateBuildProgressParams struct {
ID pgtype.UUID `json:"id"`
CurrentStep int32 `json:"current_step"`
Logs []byte `json:"logs"`
}
func (q *Queries) UpdateBuildProgress(ctx context.Context, arg UpdateBuildProgressParams) error {
_, err := q.db.Exec(ctx, updateBuildProgress, arg.ID, arg.CurrentStep, arg.Logs)
return err
}
const updateBuildSandbox = `-- name: UpdateBuildSandbox :exec
UPDATE template_builds
SET sandbox_id = $2, host_id = $3
WHERE id = $1
`
type UpdateBuildSandboxParams struct {
ID pgtype.UUID `json:"id"`
SandboxID pgtype.UUID `json:"sandbox_id"`
HostID pgtype.UUID `json:"host_id"`
}
func (q *Queries) UpdateBuildSandbox(ctx context.Context, arg UpdateBuildSandboxParams) error {
_, err := q.db.Exec(ctx, updateBuildSandbox, arg.ID, arg.SandboxID, arg.HostID)
return err
}
const updateBuildStatus = `-- name: UpdateBuildStatus :one
UPDATE template_builds
SET status = $2,
started_at = CASE WHEN $2 = 'running' AND started_at IS NULL THEN NOW() ELSE started_at END,
completed_at = CASE WHEN $2 IN ('success', 'failed', 'cancelled') THEN NOW() ELSE completed_at END
WHERE id = $1
RETURNING id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at, template_id, team_id, skip_pre_post
`
type UpdateBuildStatusParams struct {
ID pgtype.UUID `json:"id"`
Status string `json:"status"`
}
func (q *Queries) UpdateBuildStatus(ctx context.Context, arg UpdateBuildStatusParams) (TemplateBuild, error) {
row := q.db.QueryRow(ctx, updateBuildStatus, arg.ID, arg.Status)
var i TemplateBuild
err := row.Scan(
&i.ID,
&i.Name,
&i.BaseTemplate,
&i.Recipe,
&i.Healthcheck,
&i.Vcpus,
&i.MemoryMb,
&i.Status,
&i.CurrentStep,
&i.TotalSteps,
&i.Logs,
&i.Error,
&i.SandboxID,
&i.HostID,
&i.CreatedAt,
&i.StartedAt,
&i.CompletedAt,
&i.TemplateID,
&i.TeamID,
&i.SkipPrePost,
)
return i, err
}

View File

@ -1,351 +0,0 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: templates.sql
package db
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
const deleteTemplate = `-- name: DeleteTemplate :exec
DELETE FROM templates WHERE id = $1
`
func (q *Queries) DeleteTemplate(ctx context.Context, id pgtype.UUID) error {
_, err := q.db.Exec(ctx, deleteTemplate, id)
return err
}
const deleteTemplateByTeam = `-- name: DeleteTemplateByTeam :exec
DELETE FROM templates WHERE name = $1 AND team_id = $2
`
type DeleteTemplateByTeamParams struct {
Name string `json:"name"`
TeamID pgtype.UUID `json:"team_id"`
}
func (q *Queries) DeleteTemplateByTeam(ctx context.Context, arg DeleteTemplateByTeamParams) error {
_, err := q.db.Exec(ctx, deleteTemplateByTeam, arg.Name, arg.TeamID)
return err
}
const deleteTemplatesByTeam = `-- name: DeleteTemplatesByTeam :exec
DELETE FROM templates WHERE team_id = $1
`
// Bulk delete all templates owned by a team (for team soft-delete cleanup).
func (q *Queries) DeleteTemplatesByTeam(ctx context.Context, teamID pgtype.UUID) error {
_, err := q.db.Exec(ctx, deleteTemplatesByTeam, teamID)
return err
}
const getPlatformTemplateByName = `-- name: GetPlatformTemplateByName :one
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE team_id = '00000000-0000-0000-0000-000000000000' AND name = $1
`
// Check if a global (platform) template exists with the given name.
func (q *Queries) GetPlatformTemplateByName(ctx context.Context, name string) (Template, error) {
row := q.db.QueryRow(ctx, getPlatformTemplateByName, name)
var i Template
err := row.Scan(
&i.Name,
&i.Type,
&i.Vcpus,
&i.MemoryMb,
&i.SizeBytes,
&i.CreatedAt,
&i.TeamID,
&i.ID,
)
return i, err
}
const getTemplate = `-- name: GetTemplate :one
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE id = $1
`
func (q *Queries) GetTemplate(ctx context.Context, id pgtype.UUID) (Template, error) {
row := q.db.QueryRow(ctx, getTemplate, id)
var i Template
err := row.Scan(
&i.Name,
&i.Type,
&i.Vcpus,
&i.MemoryMb,
&i.SizeBytes,
&i.CreatedAt,
&i.TeamID,
&i.ID,
)
return i, err
}
const getTemplateByName = `-- name: GetTemplateByName :one
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE team_id = $1 AND name = $2
`
type GetTemplateByNameParams struct {
TeamID pgtype.UUID `json:"team_id"`
Name string `json:"name"`
}
// Look up a template by team_id and name (exact team match, no global fallback).
func (q *Queries) GetTemplateByName(ctx context.Context, arg GetTemplateByNameParams) (Template, error) {
row := q.db.QueryRow(ctx, getTemplateByName, arg.TeamID, arg.Name)
var i Template
err := row.Scan(
&i.Name,
&i.Type,
&i.Vcpus,
&i.MemoryMb,
&i.SizeBytes,
&i.CreatedAt,
&i.TeamID,
&i.ID,
)
return i, err
}
const getTemplateByTeam = `-- name: GetTemplateByTeam :one
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE name = $1 AND (team_id = $2 OR team_id = '00000000-0000-0000-0000-000000000000')
`
type GetTemplateByTeamParams struct {
Name string `json:"name"`
TeamID pgtype.UUID `json:"team_id"`
}
// Platform templates (team_id = 00000000-...) are visible to all teams.
func (q *Queries) GetTemplateByTeam(ctx context.Context, arg GetTemplateByTeamParams) (Template, error) {
row := q.db.QueryRow(ctx, getTemplateByTeam, arg.Name, arg.TeamID)
var i Template
err := row.Scan(
&i.Name,
&i.Type,
&i.Vcpus,
&i.MemoryMb,
&i.SizeBytes,
&i.CreatedAt,
&i.TeamID,
&i.ID,
)
return i, err
}
const insertTemplate = `-- name: InsertTemplate :one
INSERT INTO templates (id, name, type, vcpus, memory_mb, size_bytes, team_id)
VALUES ($1, $2, $3, $4, $5, $6, $7)
RETURNING name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id
`
type InsertTemplateParams struct {
ID pgtype.UUID `json:"id"`
Name string `json:"name"`
Type string `json:"type"`
Vcpus int32 `json:"vcpus"`
MemoryMb int32 `json:"memory_mb"`
SizeBytes int64 `json:"size_bytes"`
TeamID pgtype.UUID `json:"team_id"`
}
func (q *Queries) InsertTemplate(ctx context.Context, arg InsertTemplateParams) (Template, error) {
row := q.db.QueryRow(ctx, insertTemplate,
arg.ID,
arg.Name,
arg.Type,
arg.Vcpus,
arg.MemoryMb,
arg.SizeBytes,
arg.TeamID,
)
var i Template
err := row.Scan(
&i.Name,
&i.Type,
&i.Vcpus,
&i.MemoryMb,
&i.SizeBytes,
&i.CreatedAt,
&i.TeamID,
&i.ID,
)
return i, err
}
const listTemplates = `-- name: ListTemplates :many
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates ORDER BY created_at DESC
`
func (q *Queries) ListTemplates(ctx context.Context) ([]Template, error) {
rows, err := q.db.Query(ctx, listTemplates)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Template
for rows.Next() {
var i Template
if err := rows.Scan(
&i.Name,
&i.Type,
&i.Vcpus,
&i.MemoryMb,
&i.SizeBytes,
&i.CreatedAt,
&i.TeamID,
&i.ID,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listTemplatesByTeam = `-- name: ListTemplatesByTeam :many
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE (team_id = $1 OR team_id = '00000000-0000-0000-0000-000000000000') ORDER BY created_at DESC
`
// Platform templates are visible to all teams.
func (q *Queries) ListTemplatesByTeam(ctx context.Context, teamID pgtype.UUID) ([]Template, error) {
rows, err := q.db.Query(ctx, listTemplatesByTeam, teamID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Template
for rows.Next() {
var i Template
if err := rows.Scan(
&i.Name,
&i.Type,
&i.Vcpus,
&i.MemoryMb,
&i.SizeBytes,
&i.CreatedAt,
&i.TeamID,
&i.ID,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listTemplatesByTeamAndType = `-- name: ListTemplatesByTeamAndType :many
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE (team_id = $1 OR team_id = '00000000-0000-0000-0000-000000000000') AND type = $2 ORDER BY created_at DESC
`
type ListTemplatesByTeamAndTypeParams struct {
TeamID pgtype.UUID `json:"team_id"`
Type string `json:"type"`
}
// Platform templates are visible to all teams.
func (q *Queries) ListTemplatesByTeamAndType(ctx context.Context, arg ListTemplatesByTeamAndTypeParams) ([]Template, error) {
rows, err := q.db.Query(ctx, listTemplatesByTeamAndType, arg.TeamID, arg.Type)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Template
for rows.Next() {
var i Template
if err := rows.Scan(
&i.Name,
&i.Type,
&i.Vcpus,
&i.MemoryMb,
&i.SizeBytes,
&i.CreatedAt,
&i.TeamID,
&i.ID,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listTemplatesByTeamOnly = `-- name: ListTemplatesByTeamOnly :many
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE team_id = $1 ORDER BY created_at DESC
`
// List templates owned by a specific team (NOT including platform templates).
func (q *Queries) ListTemplatesByTeamOnly(ctx context.Context, teamID pgtype.UUID) ([]Template, error) {
rows, err := q.db.Query(ctx, listTemplatesByTeamOnly, teamID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Template
for rows.Next() {
var i Template
if err := rows.Scan(
&i.Name,
&i.Type,
&i.Vcpus,
&i.MemoryMb,
&i.SizeBytes,
&i.CreatedAt,
&i.TeamID,
&i.ID,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listTemplatesByType = `-- name: ListTemplatesByType :many
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id FROM templates WHERE type = $1 ORDER BY created_at DESC
`
func (q *Queries) ListTemplatesByType(ctx context.Context, type_ string) ([]Template, error) {
rows, err := q.db.Query(ctx, listTemplatesByType, type_)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Template
for rows.Next() {
var i Template
if err := rows.Scan(
&i.Name,
&i.Type,
&i.Vcpus,
&i.MemoryMb,
&i.SizeBytes,
&i.CreatedAt,
&i.TeamID,
&i.ID,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}

View File

@ -1,276 +0,0 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: users.sql
package db
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
const deleteAdminPermission = `-- name: DeleteAdminPermission :exec
DELETE FROM admin_permissions WHERE user_id = $1 AND permission = $2
`
type DeleteAdminPermissionParams struct {
UserID pgtype.UUID `json:"user_id"`
Permission string `json:"permission"`
}
func (q *Queries) DeleteAdminPermission(ctx context.Context, arg DeleteAdminPermissionParams) error {
_, err := q.db.Exec(ctx, deleteAdminPermission, arg.UserID, arg.Permission)
return err
}
const getAdminPermissions = `-- name: GetAdminPermissions :many
SELECT id, user_id, permission, created_at FROM admin_permissions WHERE user_id = $1 ORDER BY permission
`
func (q *Queries) GetAdminPermissions(ctx context.Context, userID pgtype.UUID) ([]AdminPermission, error) {
rows, err := q.db.Query(ctx, getAdminPermissions, userID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []AdminPermission
for rows.Next() {
var i AdminPermission
if err := rows.Scan(
&i.ID,
&i.UserID,
&i.Permission,
&i.CreatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getAdminUsers = `-- name: GetAdminUsers :many
SELECT id, email, password_hash, name, is_admin, created_at, updated_at FROM users WHERE is_admin = TRUE ORDER BY created_at
`
func (q *Queries) GetAdminUsers(ctx context.Context) ([]User, error) {
rows, err := q.db.Query(ctx, getAdminUsers)
if err != nil {
return nil, err
}
defer rows.Close()
var items []User
for rows.Next() {
var i User
if err := rows.Scan(
&i.ID,
&i.Email,
&i.PasswordHash,
&i.Name,
&i.IsAdmin,
&i.CreatedAt,
&i.UpdatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getUserByEmail = `-- name: GetUserByEmail :one
SELECT id, email, password_hash, name, is_admin, created_at, updated_at FROM users WHERE email = $1
`
func (q *Queries) GetUserByEmail(ctx context.Context, email string) (User, error) {
row := q.db.QueryRow(ctx, getUserByEmail, email)
var i User
err := row.Scan(
&i.ID,
&i.Email,
&i.PasswordHash,
&i.Name,
&i.IsAdmin,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const getUserByID = `-- name: GetUserByID :one
SELECT id, email, password_hash, name, is_admin, created_at, updated_at FROM users WHERE id = $1
`
func (q *Queries) GetUserByID(ctx context.Context, id pgtype.UUID) (User, error) {
row := q.db.QueryRow(ctx, getUserByID, id)
var i User
err := row.Scan(
&i.ID,
&i.Email,
&i.PasswordHash,
&i.Name,
&i.IsAdmin,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const hasAdminPermission = `-- name: HasAdminPermission :one
SELECT EXISTS(
SELECT 1 FROM admin_permissions WHERE user_id = $1 AND permission = $2
) AS has_permission
`
type HasAdminPermissionParams struct {
UserID pgtype.UUID `json:"user_id"`
Permission string `json:"permission"`
}
func (q *Queries) HasAdminPermission(ctx context.Context, arg HasAdminPermissionParams) (bool, error) {
row := q.db.QueryRow(ctx, hasAdminPermission, arg.UserID, arg.Permission)
var has_permission bool
err := row.Scan(&has_permission)
return has_permission, err
}
const insertAdminPermission = `-- name: InsertAdminPermission :exec
INSERT INTO admin_permissions (id, user_id, permission)
VALUES ($1, $2, $3)
`
type InsertAdminPermissionParams struct {
ID pgtype.UUID `json:"id"`
UserID pgtype.UUID `json:"user_id"`
Permission string `json:"permission"`
}
func (q *Queries) InsertAdminPermission(ctx context.Context, arg InsertAdminPermissionParams) error {
_, err := q.db.Exec(ctx, insertAdminPermission, arg.ID, arg.UserID, arg.Permission)
return err
}
const insertUser = `-- name: InsertUser :one
INSERT INTO users (id, email, password_hash, name)
VALUES ($1, $2, $3, $4)
RETURNING id, email, password_hash, name, is_admin, created_at, updated_at
`
type InsertUserParams struct {
ID pgtype.UUID `json:"id"`
Email string `json:"email"`
PasswordHash pgtype.Text `json:"password_hash"`
Name string `json:"name"`
}
func (q *Queries) InsertUser(ctx context.Context, arg InsertUserParams) (User, error) {
row := q.db.QueryRow(ctx, insertUser,
arg.ID,
arg.Email,
arg.PasswordHash,
arg.Name,
)
var i User
err := row.Scan(
&i.ID,
&i.Email,
&i.PasswordHash,
&i.Name,
&i.IsAdmin,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const insertUserOAuth = `-- name: InsertUserOAuth :one
INSERT INTO users (id, email, name)
VALUES ($1, $2, $3)
RETURNING id, email, password_hash, name, is_admin, created_at, updated_at
`
type InsertUserOAuthParams struct {
ID pgtype.UUID `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
}
func (q *Queries) InsertUserOAuth(ctx context.Context, arg InsertUserOAuthParams) (User, error) {
row := q.db.QueryRow(ctx, insertUserOAuth, arg.ID, arg.Email, arg.Name)
var i User
err := row.Scan(
&i.ID,
&i.Email,
&i.PasswordHash,
&i.Name,
&i.IsAdmin,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const searchUsersByEmailPrefix = `-- name: SearchUsersByEmailPrefix :many
SELECT id, email FROM users WHERE email LIKE $1 || '%' ORDER BY email LIMIT 10
`
type SearchUsersByEmailPrefixRow struct {
ID pgtype.UUID `json:"id"`
Email string `json:"email"`
}
func (q *Queries) SearchUsersByEmailPrefix(ctx context.Context, dollar_1 pgtype.Text) ([]SearchUsersByEmailPrefixRow, error) {
rows, err := q.db.Query(ctx, searchUsersByEmailPrefix, dollar_1)
if err != nil {
return nil, err
}
defer rows.Close()
var items []SearchUsersByEmailPrefixRow
for rows.Next() {
var i SearchUsersByEmailPrefixRow
if err := rows.Scan(&i.ID, &i.Email); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const setUserAdmin = `-- name: SetUserAdmin :exec
UPDATE users SET is_admin = $2, updated_at = NOW() WHERE id = $1
`
type SetUserAdminParams struct {
ID pgtype.UUID `json:"id"`
IsAdmin bool `json:"is_admin"`
}
func (q *Queries) SetUserAdmin(ctx context.Context, arg SetUserAdminParams) error {
_, err := q.db.Exec(ctx, setUserAdmin, arg.ID, arg.IsAdmin)
return err
}
const updateUserName = `-- name: UpdateUserName :exec
UPDATE users SET name = $2, updated_at = NOW() WHERE id = $1
`
type UpdateUserNameParams struct {
ID pgtype.UUID `json:"id"`
Name string `json:"name"`
}
func (q *Queries) UpdateUserName(ctx context.Context, arg UpdateUserNameParams) error {
_, err := q.db.Exec(ctx, updateUserName, arg.ID, arg.Name)
return err
}

233
internal/email/email.go Normal file
View File

@ -0,0 +1,233 @@
// Package email provides transactional email sending via SMTP.
//
// Emails are rendered from embedded Go templates (html/template + text/template)
// and sent as multipart/alternative MIME messages. When SMTP is not configured
// (Host is empty), a no-op mailer is returned that logs and discards.
package email
import (
"bytes"
"context"
"crypto/tls"
"fmt"
"log/slog"
"mime"
"mime/multipart"
"mime/quotedprintable"
"net"
"net/smtp"
"net/textproto"
"net/url"
"strconv"
"strings"
)
// Config holds SMTP connection credentials. All fields except Host are
// optional — omitting Host disables email entirely (no-op mailer).
type Config struct {
Host string // SMTP server hostname
Port int // SMTP server port (default 587)
Username string // SMTP auth username
Password string // SMTP auth password
FromEmail string // envelope sender address
}
// Mailer sends transactional emails.
type Mailer interface {
Send(ctx context.Context, to string, subject string, data EmailData) error
}
// EmailData is the generic payload for all transactional emails.
// Templates conditionally render each field based on presence.
type EmailData struct {
RecipientName string // optional — used after "Hello"
Message string // main body (plain text; HTML template wraps it)
Button *Button // optional CTA button
Closing string // optional closing/footer message
}
// Button represents a call-to-action link rendered as a button in HTML
// and as a plain URL in the text variant.
type Button struct {
Text string // button label
URL string // target URL
}
// New constructs a Mailer. If cfg.Host is empty, returns a no-op mailer
// that logs at debug level and discards. Panics if templates fail to parse
// (indicates a build-time bug in embedded templates).
func New(cfg Config) Mailer {
if cfg.Host == "" {
slog.Info("email: SMTP not configured, using no-op mailer")
return &noopMailer{}
}
if cfg.Port == 0 {
cfg.Port = 587
}
tmpl := mustLoadTemplates()
slog.Info("email: SMTP configured", "host", cfg.Host, "port", cfg.Port, "from", cfg.FromEmail)
return &mailer{cfg: cfg, tmpl: tmpl}
}
// mailer is the live SMTP implementation.
type mailer struct {
cfg Config
tmpl *templates
}
func (m *mailer) Send(ctx context.Context, to string, subject string, data EmailData) error {
if data.Button != nil {
u, err := url.Parse(data.Button.URL)
if err != nil || (u.Scheme != "https" && u.Scheme != "http") {
return fmt.Errorf("invalid button URL scheme: %s", data.Button.URL)
}
}
htmlBody, err := m.tmpl.renderHTML(data)
if err != nil {
return fmt.Errorf("render html: %w", err)
}
textBody, err := m.tmpl.renderText(data)
if err != nil {
return fmt.Errorf("render text: %w", err)
}
msg, err := buildMIME(m.cfg.FromEmail, to, subject, htmlBody, textBody)
if err != nil {
return fmt.Errorf("build mime: %w", err)
}
if err := m.send(to, msg); err != nil {
return fmt.Errorf("send email to %s: %w", to, err)
}
slog.Info("email: sent", "to", to, "subject", subject)
return nil
}
// send dials the SMTP server and delivers the message.
// Port 465 uses implicit TLS; all other ports use STARTTLS.
func (m *mailer) send(to string, msg []byte) error {
addr := net.JoinHostPort(m.cfg.Host, strconv.Itoa(m.cfg.Port))
auth := smtp.PlainAuth("", m.cfg.Username, m.cfg.Password, m.cfg.Host)
if m.cfg.Port == 465 {
return m.sendImplicitTLS(addr, auth, to, msg)
}
// STARTTLS (port 587 or other).
return smtp.SendMail(addr, auth, m.cfg.FromEmail, []string{to}, msg)
}
// sendImplicitTLS handles port 465 (SMTPS) where the entire connection is TLS.
func (m *mailer) sendImplicitTLS(addr string, auth smtp.Auth, to string, msg []byte) error {
conn, err := tls.Dial("tcp", addr, &tls.Config{ServerName: m.cfg.Host})
if err != nil {
return fmt.Errorf("tls dial: %w", err)
}
defer conn.Close()
c, err := smtp.NewClient(conn, m.cfg.Host)
if err != nil {
return fmt.Errorf("smtp client: %w", err)
}
defer c.Close()
if err := c.Auth(auth); err != nil {
return fmt.Errorf("smtp auth: %w", err)
}
if err := c.Mail(m.cfg.FromEmail); err != nil {
return fmt.Errorf("smtp mail: %w", err)
}
if err := c.Rcpt(to); err != nil {
return fmt.Errorf("smtp rcpt: %w", err)
}
w, err := c.Data()
if err != nil {
return fmt.Errorf("smtp data: %w", err)
}
if _, err := w.Write(msg); err != nil {
return fmt.Errorf("smtp write: %w", err)
}
if err := w.Close(); err != nil {
return fmt.Errorf("smtp close data: %w", err)
}
return c.Quit()
}
// buildMIME assembles a multipart/alternative message with text and HTML parts.
// Both parts are quoted-printable encoded per RFC 2045.
func buildMIME(from, to, subject, htmlBody, textBody string) ([]byte, error) {
var headerBuf bytes.Buffer
var bodyBuf bytes.Buffer
// Sanitize header values to prevent header injection.
from = sanitizeHeader(from)
to = sanitizeHeader(to)
// Encode "From" with display name.
encodedFrom := mime.QEncoding.Encode("utf-8", "Wrenn") + " <" + from + ">"
// Build multipart body first to get the boundary.
mw := multipart.NewWriter(&bodyBuf)
// Text part (first = lowest preference per RFC 2046).
textPart, err := mw.CreatePart(textproto.MIMEHeader{
"Content-Type": {"text/plain; charset=utf-8"},
"Content-Transfer-Encoding": {"quoted-printable"},
})
if err != nil {
return nil, err
}
qpw := quotedprintable.NewWriter(textPart)
if _, err := qpw.Write([]byte(textBody)); err != nil {
return nil, err
}
if err := qpw.Close(); err != nil {
return nil, err
}
// HTML part (second = highest preference).
htmlPart, err := mw.CreatePart(textproto.MIMEHeader{
"Content-Type": {"text/html; charset=utf-8"},
"Content-Transfer-Encoding": {"quoted-printable"},
})
if err != nil {
return nil, err
}
qpw = quotedprintable.NewWriter(htmlPart)
if _, err := qpw.Write([]byte(htmlBody)); err != nil {
return nil, err
}
if err := qpw.Close(); err != nil {
return nil, err
}
if err := mw.Close(); err != nil {
return nil, err
}
// Write headers.
fmt.Fprintf(&headerBuf, "From: %s\r\n", encodedFrom)
fmt.Fprintf(&headerBuf, "To: %s\r\n", to)
fmt.Fprintf(&headerBuf, "Subject: %s\r\n", mime.QEncoding.Encode("utf-8", subject))
fmt.Fprintf(&headerBuf, "MIME-Version: 1.0\r\n")
fmt.Fprintf(&headerBuf, "Content-Type: multipart/alternative; boundary=\"%s\"\r\n", mw.Boundary())
fmt.Fprintf(&headerBuf, "\r\n")
headerBuf.Write(bodyBuf.Bytes())
return headerBuf.Bytes(), nil
}
// sanitizeHeader strips CR and LF characters to prevent SMTP header injection.
func sanitizeHeader(s string) string {
return strings.NewReplacer("\r", "", "\n", "").Replace(s)
}
// noopMailer discards emails when SMTP is not configured.
type noopMailer struct{}
func (n *noopMailer) Send(_ context.Context, to string, subject string, _ EmailData) error {
slog.Debug("email: no-op send", "to", to, "subject", subject)
return nil
}

View File

@ -0,0 +1,191 @@
package email
import (
"context"
"strings"
"testing"
)
func TestNoopMailerDoesNotError(t *testing.T) {
m := &noopMailer{}
err := m.Send(context.Background(), "test@example.com", "Test Subject", EmailData{
RecipientName: "Alice",
Message: "Hello world",
})
if err != nil {
t.Fatalf("noopMailer.Send() returned error: %v", err)
}
}
func TestNewReturnsNoopWhenHostEmpty(t *testing.T) {
m := New(Config{})
if _, ok := m.(*noopMailer); !ok {
t.Fatalf("expected noopMailer, got %T", m)
}
}
func TestNewReturnsMailerWhenHostSet(t *testing.T) {
m := New(Config{Host: "smtp.example.com"})
if _, ok := m.(*mailer); !ok {
t.Fatalf("expected *mailer, got %T", m)
}
}
func TestTemplateRenderHTML(t *testing.T) {
tmpl := mustLoadTemplates()
tests := []struct {
name string
data EmailData
want []string // substrings that must appear in output
}{
{
name: "with all fields",
data: EmailData{
RecipientName: "Alice",
Message: "Welcome to Wrenn!",
Button: &Button{Text: "Get Started", URL: "https://wrenn.dev"},
Closing: "See you soon.",
},
want: []string{"Alice", "Welcome to Wrenn!", "Get Started", "https://wrenn.dev", "See you soon."},
},
{
name: "message only",
data: EmailData{
Message: "Your password has been changed.",
},
want: []string{"Your password has been changed."},
},
{
name: "with button no closing",
data: EmailData{
RecipientName: "Bob",
Message: "Reset your password.",
Button: &Button{Text: "Reset Password", URL: "https://wrenn.dev/reset?token=abc"},
},
want: []string{"Bob", "Reset your password.", "Reset Password", "https://wrenn.dev/reset?token=abc"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
html, err := tmpl.renderHTML(tt.data)
if err != nil {
t.Fatalf("renderHTML() error: %v", err)
}
for _, s := range tt.want {
if !strings.Contains(html, s) {
t.Errorf("renderHTML() missing substring %q", s)
}
}
// Verify basic HTML structure.
if !strings.Contains(html, "<!DOCTYPE html>") {
t.Error("renderHTML() missing DOCTYPE")
}
if !strings.Contains(html, "wrenn.dev") {
t.Error("renderHTML() missing wrenn.dev reference")
}
})
}
}
func TestTemplateRenderText(t *testing.T) {
tmpl := mustLoadTemplates()
tests := []struct {
name string
data EmailData
want []string
}{
{
name: "with all fields",
data: EmailData{
RecipientName: "Alice",
Message: "Welcome to Wrenn!",
Button: &Button{Text: "Get Started", URL: "https://wrenn.dev"},
Closing: "See you soon.",
},
want: []string{"Hello Alice", "Welcome to Wrenn!", "Get Started: https://wrenn.dev", "See you soon."},
},
{
name: "message only",
data: EmailData{
Message: "Done.",
},
want: []string{"Hello,", "Done."},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
text, err := tmpl.renderText(tt.data)
if err != nil {
t.Fatalf("renderText() error: %v", err)
}
for _, s := range tt.want {
if !strings.Contains(text, s) {
t.Errorf("renderText() missing substring %q\nGot:\n%s", s, text)
}
}
})
}
}
func TestBuildMIME(t *testing.T) {
msg, err := buildMIME("noreply@wrenn.dev", "user@example.com", "Test Subject", "<h1>HTML</h1>", "Plain text")
if err != nil {
t.Fatalf("buildMIME() error: %v", err)
}
s := string(msg)
if !strings.Contains(s, "From:") {
t.Error("missing From header")
}
if !strings.Contains(s, "To: user@example.com") {
t.Error("missing To header")
}
if !strings.Contains(s, "Wrenn") {
t.Error("missing Wrenn sender name")
}
if !strings.Contains(s, "multipart/alternative") {
t.Error("missing multipart/alternative content type")
}
if !strings.Contains(s, "text/plain") {
t.Error("missing text/plain part")
}
if !strings.Contains(s, "text/html") {
t.Error("missing text/html part")
}
}
func TestBuildMIMENonASCII(t *testing.T) {
msg, err := buildMIME("noreply@wrenn.dev", "user@example.com", "Test", "<p>\u00c5ngstr\u00f6m</p>", "Hello \u00c5ngstr\u00f6m")
if err != nil {
t.Fatalf("buildMIME() error: %v", err)
}
s := string(msg)
// Non-ASCII characters should be QP-encoded, not appear as raw bytes.
// \u00c5 (U+00C5, 0xC3 0x85 in UTF-8) should be encoded as =C3=85.
if !strings.Contains(s, "=C3=85") {
t.Error("non-ASCII character not quoted-printable encoded")
}
}
func TestSanitizeHeader(t *testing.T) {
tests := []struct {
input string
want string
}{
{"normal@example.com", "normal@example.com"},
{"injected\r\nBcc: evil@example.com", "injectedBcc: evil@example.com"},
{"has\nnewline", "hasnewline"},
{"has\rcarriage", "hascarriage"},
}
for _, tt := range tests {
got := sanitizeHeader(tt.input)
if got != tt.want {
t.Errorf("sanitizeHeader(%q) = %q, want %q", tt.input, got, tt.want)
}
}
}

View File

@ -0,0 +1,52 @@
package email
import (
"bytes"
"embed"
"fmt"
"html/template"
text_template "text/template"
)
//go:embed templates/*.html templates/*.txt
var templateFS embed.FS
// templates holds the parsed HTML and plain-text template sets.
type templates struct {
html *template.Template
text *text_template.Template
}
// mustLoadTemplates parses all embedded templates. Panics on error
// because malformed templates are a build-time bug.
func mustLoadTemplates() *templates {
html, err := template.ParseFS(templateFS, "templates/*.html")
if err != nil {
panic(fmt.Sprintf("email: failed to parse HTML templates: %v", err))
}
text, err := text_template.ParseFS(templateFS, "templates/*.txt")
if err != nil {
panic(fmt.Sprintf("email: failed to parse text templates: %v", err))
}
return &templates{html: html, text: text}
}
// renderHTML executes the HTML base template with the given data.
func (t *templates) renderHTML(data EmailData) (string, error) {
var buf bytes.Buffer
if err := t.html.ExecuteTemplate(&buf, "base.html", data); err != nil {
return "", fmt.Errorf("execute html template: %w", err)
}
return buf.String(), nil
}
// renderText executes the plain-text base template with the given data.
func (t *templates) renderText(data EmailData) (string, error) {
var buf bytes.Buffer
if err := t.text.ExecuteTemplate(&buf, "base.txt", data); err != nil {
return "", fmt.Errorf("execute text template: %w", err)
}
return buf.String(), nil
}

View File

@ -0,0 +1,119 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<title>Wrenn</title>
<!--[if mso]>
<noscript>
<xml>
<o:OfficeDocumentSettings>
<o:PixelsPerInch>96</o:PixelsPerInch>
</o:OfficeDocumentSettings>
</xml>
</noscript>
<![endif]-->
<style type="text/css">
body, table, td, a { -webkit-text-size-adjust: 100%; -ms-text-size-adjust: 100%; }
table, td { mso-table-lspace: 0pt; mso-table-rspace: 0pt; }
img { -ms-interpolation-mode: bicubic; border: 0; height: auto; line-height: 100%; outline: none; text-decoration: none; }
body { margin: 0; padding: 0; width: 100% !important; }
</style>
</head>
<body style="margin: 0; padding: 0; background-color: #f4f3f1; font-family: 'Manrope', -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Helvetica, Arial, sans-serif; -webkit-font-smoothing: antialiased;">
<!-- Outer wrapper -->
<table role="presentation" cellpadding="0" cellspacing="0" width="100%" style="background-color: #f4f3f1;">
<tr>
<td align="center" style="padding: 40px 16px;">
<!-- Logo + Wordmark -->
<table role="presentation" cellpadding="0" cellspacing="0" width="560" style="max-width: 560px;">
<tr>
<td align="left" style="padding-bottom: 32px;">
<a href="https://wrenn.dev" style="text-decoration: none;">
<table role="presentation" cellpadding="0" cellspacing="0">
<tr>
<td style="vertical-align: middle;">
<img src="https://wrenn.dev/logo.png" alt="Wrenn" width="36" height="36" style="display: block; border-radius: 6px;">
</td>
<td style="vertical-align: middle; padding-left: 10px;">
<span style="font-family: 'Alice', Georgia, 'Times New Roman', serif; font-size: 22px; color: #1a1917; letter-spacing: 0.01em;">Wrenn</span>
</td>
</tr>
</table>
</a>
</td>
</tr>
</table>
<!-- Card -->
<table role="presentation" cellpadding="0" cellspacing="0" width="560" style="max-width: 560px; background-color: #ffffff; border: 1px solid #e5e4e0; border-radius: 8px;">
<tr>
<td style="padding: 44px 48px;">
<!-- Greeting -->
<p style="margin: 0 0 8px 0; font-size: 15px; line-height: 1.6; color: #3a3835;">
Hello{{if .RecipientName}} {{.RecipientName}}{{end}},
</p>
<!-- Message -->
<p style="margin: 0 0 36px 0; font-size: 15px; line-height: 1.7; color: #3a3835;">
{{.Message}}
</p>
<!-- Button -->
{{if .Button}}
<table role="presentation" cellpadding="0" cellspacing="0" style="margin: 0 0 36px 0;">
<tr>
<td align="center" style="background-color: #5e8c58; border-radius: 5px;">
<!--[if mso]>
<v:roundrect xmlns:v="urn:schemas-microsoft-com:vml" xmlns:w="urn:schemas-microsoft-com:office:word" href="{{.Button.URL}}" style="height:44px;v-text-anchor:middle;width:200px;" arcsize="12%" strokecolor="#5e8c58" fillcolor="#5e8c58">
<w:anchorlock/>
<center style="color:#ffffff;font-family:'Manrope',-apple-system,sans-serif;font-size:14px;font-weight:600;">{{.Button.Text}}</center>
</v:roundrect>
<![endif]-->
<!--[if !mso]><!-->
<a href="{{.Button.URL}}" target="_blank" style="display: inline-block; padding: 12px 32px; font-size: 14px; font-weight: 600; color: #ffffff; text-decoration: none; border-radius: 5px; background-color: #5e8c58;">
{{.Button.Text}}
</a>
<!--<![endif]-->
</td>
</tr>
</table>
<p style="margin: 0 0 12px 0; font-size: 12px; line-height: 1.5; color: #b5b0a8;">
If the button doesn't work, copy and paste this URL into your browser:<br>
<a href="{{.Button.URL}}" style="color: #5e8c58; word-break: break-all;">{{.Button.URL}}</a>
</p>
{{end}}
<!-- Closing -->
{{if .Closing}}
<p style="margin: {{if .Button}}20px{{else}}0{{end}} 0 0 0; font-size: 15px; line-height: 1.7; color: #3a3835;">
{{.Closing}}
</p>
{{end}}
</td>
</tr>
</table>
<!-- Footer -->
<table role="presentation" cellpadding="0" cellspacing="0" width="560" style="max-width: 560px;">
<tr>
<td style="padding: 32px 0 16px 0; text-align: center;">
<p style="margin: 0; font-size: 12px; line-height: 1.5; color: #9b9790;">
This is a transactional email from <a href="https://wrenn.dev" style="color: #5e8c58; text-decoration: none;">Wrenn</a>.
</p>
</td>
</tr>
</table>
</td>
</tr>
</table>
</body>
</html>

View File

@ -0,0 +1,13 @@
Hello{{if .RecipientName}} {{.RecipientName}}{{end}},
{{.Message}}
{{if .Button}}
{{.Button.Text}}: {{.Button.URL}}
{{end}}{{if .Closing}}
{{.Closing}}
{{end}}
---
This is a transactional email from Wrenn (https://wrenn.dev).

View File

@ -3,6 +3,7 @@ package envdclient
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
@ -268,6 +269,82 @@ func (c *Client) ReadFile(ctx context.Context, path string) ([]byte, error) {
return data, nil
}
// PrepareSnapshot calls envd's POST /snapshot/prepare endpoint, which quiesces
// continuous goroutines (port scanner, forwarder) and forces a GC cycle before
// Firecracker takes a VM snapshot. This ensures the Go runtime's page allocator
// is in a consistent state when vCPUs are frozen.
//
// Best-effort: the caller should log a warning on error but not abort the pause.
func (c *Client) PrepareSnapshot(ctx context.Context) error {
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.base+"/snapshot/prepare", nil)
if err != nil {
return fmt.Errorf("create request: %w", err)
}
resp, err := c.httpClient.Do(req)
if err != nil {
return fmt.Errorf("prepare snapshot: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusNoContent {
respBody, _ := io.ReadAll(resp.Body)
return fmt.Errorf("prepare snapshot: status %d: %s", resp.StatusCode, string(respBody))
}
return nil
}
// PostInit calls envd's POST /init endpoint, which triggers a re-read of
// Firecracker MMDS metadata. This updates WRENN_SANDBOX_ID, WRENN_TEMPLATE_ID
// env vars and the corresponding files under /run/wrenn/ inside the guest.
// Must be called after snapshot restore so envd picks up the new sandbox's metadata.
func (c *Client) PostInit(ctx context.Context) error {
return c.PostInitWithDefaults(ctx, "", nil)
}
// PostInitWithDefaults calls envd's POST /init endpoint with optional default
// user and environment variables. These are applied to envd's defaults so all
// subsequent process executions use them.
func (c *Client) PostInitWithDefaults(ctx context.Context, defaultUser string, envVars map[string]string) error {
var body io.Reader
if defaultUser != "" || len(envVars) > 0 {
payload := make(map[string]any)
if defaultUser != "" {
payload["defaultUser"] = defaultUser
}
if len(envVars) > 0 {
payload["envVars"] = envVars
}
data, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("marshal init body: %w", err)
}
body = bytes.NewReader(data)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.base+"/init", body)
if err != nil {
return fmt.Errorf("create request: %w", err)
}
if body != nil {
req.Header.Set("Content-Type", "application/json")
}
resp, err := c.httpClient.Do(req)
if err != nil {
return fmt.Errorf("post init: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusNoContent {
respBody, _ := io.ReadAll(resp.Body)
return fmt.Errorf("post init: status %d: %s", resp.StatusCode, string(respBody))
}
return nil
}
// ListDir lists directory contents inside the sandbox.
func (c *Client) ListDir(ctx context.Context, path string, depth uint32) (*envdpb.ListDirResponse, error) {
req := connect.NewRequest(&envdpb.ListDirRequest{
@ -282,3 +359,30 @@ func (c *Client) ListDir(ctx context.Context, path string, depth uint32) (*envdp
return resp.Msg, nil
}
// MakeDir creates a directory inside the sandbox.
func (c *Client) MakeDir(ctx context.Context, path string) (*envdpb.MakeDirResponse, error) {
req := connect.NewRequest(&envdpb.MakeDirRequest{
Path: path,
})
resp, err := c.filesystem.MakeDir(ctx, req)
if err != nil {
return nil, fmt.Errorf("make dir: %w", err)
}
return resp.Msg, nil
}
// Remove removes a file or directory inside the sandbox.
func (c *Client) Remove(ctx context.Context, path string) error {
req := connect.NewRequest(&envdpb.RemoveRequest{
Path: path,
})
if _, err := c.filesystem.Remove(ctx, req); err != nil {
return fmt.Errorf("remove: %w", err)
}
return nil
}

View File

@ -2,7 +2,9 @@ package envdclient
import (
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"time"
@ -31,6 +33,38 @@ func (c *Client) WaitUntilReady(ctx context.Context) error {
}
}
// FetchVersion queries envd's health endpoint and returns the reported version.
func (c *Client) FetchVersion(ctx context.Context) (string, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.healthURL, nil)
if err != nil {
return "", fmt.Errorf("build health request: %w", err)
}
resp, err := c.httpClient.Do(req)
if err != nil {
return "", fmt.Errorf("fetch envd version: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent {
return "", fmt.Errorf("health check returned %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil || len(body) == 0 {
return "", nil // envd may not support version reporting yet
}
var data struct {
Version string `json:"version"`
}
if err := json.Unmarshal(body, &data); err != nil {
return "", nil // non-JSON response, old envd
}
return data.Version, nil
}
// healthCheck sends a single GET /health request to envd.
func (c *Client) healthCheck(ctx context.Context) error {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.healthURL, nil)

View File

@ -0,0 +1,187 @@
package envdclient
import (
"context"
"fmt"
"io"
"log/slog"
"connectrpc.com/connect"
envdpb "git.omukk.dev/wrenn/wrenn/proto/envd/gen"
)
// ProcessInfo holds metadata about a running process inside the sandbox.
type ProcessInfo struct {
PID uint32
Tag string
Cmd string
Args []string
}
// StartBackground starts a process that runs independently of the RPC stream.
// It opens a Start stream, reads the first StartEvent to obtain the PID,
// then closes the stream. The process continues running inside the VM because
// envd binds it to context.Background().
func (c *Client) StartBackground(ctx context.Context, tag, cmd string, args []string, envs map[string]string, cwd string) (uint32, error) {
stdin := false
cfg := &envdpb.ProcessConfig{
Cmd: cmd,
Args: args,
Envs: envs,
}
if cwd != "" {
cfg.Cwd = &cwd
}
req := connect.NewRequest(&envdpb.StartRequest{
Process: cfg,
Tag: &tag,
Stdin: &stdin,
})
stream, err := c.process.Start(ctx, req)
if err != nil {
return 0, fmt.Errorf("start background process: %w", err)
}
defer stream.Close()
// Read events until we get the StartEvent with the PID.
for stream.Receive() {
msg := stream.Msg()
if msg.Event == nil {
continue
}
if start, ok := msg.Event.GetEvent().(*envdpb.ProcessEvent_Start); ok {
return start.Start.GetPid(), nil
}
}
if err := stream.Err(); err != nil && err != io.EOF {
return 0, fmt.Errorf("start background process stream: %w", err)
}
return 0, fmt.Errorf("start background process: no start event received")
}
// ConnectProcess re-attaches to a running process by PID or tag and returns
// a channel of streaming events. The channel is closed when the process ends
// or the context is cancelled.
func (c *Client) ConnectProcess(ctx context.Context, pid uint32, tag string) (<-chan ExecStreamEvent, error) {
var selector *envdpb.ProcessSelector
if tag != "" {
selector = &envdpb.ProcessSelector{
Selector: &envdpb.ProcessSelector_Tag{Tag: tag},
}
} else {
selector = &envdpb.ProcessSelector{
Selector: &envdpb.ProcessSelector_Pid{Pid: pid},
}
}
stream, err := c.process.Connect(ctx, connect.NewRequest(&envdpb.ConnectRequest{
Process: selector,
}))
if err != nil {
return nil, fmt.Errorf("connect process: %w", err)
}
ch := make(chan ExecStreamEvent, 16)
go func() {
defer close(ch)
defer stream.Close()
for stream.Receive() {
msg := stream.Msg()
if msg.Event == nil {
continue
}
var ev ExecStreamEvent
switch e := msg.Event.GetEvent().(type) {
case *envdpb.ProcessEvent_Start:
ev = ExecStreamEvent{Type: "start", PID: e.Start.GetPid()}
case *envdpb.ProcessEvent_Data:
switch o := e.Data.GetOutput().(type) {
case *envdpb.ProcessEvent_DataEvent_Stdout:
ev = ExecStreamEvent{Type: "stdout", Data: o.Stdout}
case *envdpb.ProcessEvent_DataEvent_Stderr:
ev = ExecStreamEvent{Type: "stderr", Data: o.Stderr}
default:
continue
}
case *envdpb.ProcessEvent_End:
ev = ExecStreamEvent{Type: "end", ExitCode: e.End.GetExitCode()}
if e.End.Error != nil {
ev.Error = e.End.GetError()
}
case *envdpb.ProcessEvent_Keepalive:
continue
}
select {
case ch <- ev:
case <-ctx.Done():
return
}
}
if err := stream.Err(); err != nil && err != io.EOF {
slog.Debug("connect process stream error", "error", err)
}
}()
return ch, nil
}
// ListProcesses returns all running processes inside the sandbox.
func (c *Client) ListProcesses(ctx context.Context) ([]ProcessInfo, error) {
resp, err := c.process.List(ctx, connect.NewRequest(&envdpb.ListRequest{}))
if err != nil {
return nil, fmt.Errorf("list processes: %w", err)
}
procs := make([]ProcessInfo, 0, len(resp.Msg.Processes))
for _, p := range resp.Msg.Processes {
info := ProcessInfo{
PID: p.Pid,
}
if p.Tag != nil {
info.Tag = *p.Tag
}
if p.Config != nil {
info.Cmd = p.Config.Cmd
info.Args = p.Config.Args
}
procs = append(procs, info)
}
return procs, nil
}
// KillProcess sends a signal to a process identified by PID or tag.
func (c *Client) KillProcess(ctx context.Context, pid uint32, tag string, signal envdpb.Signal) error {
var selector *envdpb.ProcessSelector
if tag != "" {
selector = &envdpb.ProcessSelector{
Selector: &envdpb.ProcessSelector_Tag{Tag: tag},
}
} else {
selector = &envdpb.ProcessSelector{
Selector: &envdpb.ProcessSelector_Pid{Pid: pid},
}
}
_, err := c.process.SendSignal(ctx, connect.NewRequest(&envdpb.SendSignalRequest{
Process: selector,
Signal: signal,
}))
if err != nil {
return fmt.Errorf("kill process: %w", err)
}
return nil
}

220
internal/envdclient/pty.go Normal file
View File

@ -0,0 +1,220 @@
package envdclient
import (
"context"
"fmt"
"io"
"log/slog"
"connectrpc.com/connect"
envdpb "git.omukk.dev/wrenn/wrenn/proto/envd/gen"
)
// PtyEvent represents a single event from a PTY output stream.
type PtyEvent struct {
Type string // "started", "output", "end"
PID uint32
Data []byte
ExitCode int32
Error string
}
// PtyStart starts a new PTY process in the guest and returns a channel of events.
// The tag is the stable identifier used to reconnect via PtyConnect.
// The channel is closed when the process ends or ctx is cancelled.
// NOTE: The user parameter from PtyAttachRequest is not yet supported by envd's
// ProcessConfig proto. When envd adds user support, thread it through here.
func (c *Client) PtyStart(ctx context.Context, tag, cmd string, args []string, cols, rows uint32, envs map[string]string, cwd string) (<-chan PtyEvent, error) {
stdin := true
cfg := &envdpb.ProcessConfig{
Cmd: cmd,
Args: args,
Envs: envs,
}
if cwd != "" {
cfg.Cwd = &cwd
}
req := connect.NewRequest(&envdpb.StartRequest{
Process: cfg,
Pty: &envdpb.PTY{
Size: &envdpb.PTY_Size{
Cols: cols,
Rows: rows,
},
},
Tag: &tag,
Stdin: &stdin,
})
stream, err := c.process.Start(ctx, req)
if err != nil {
return nil, fmt.Errorf("pty start: %w", err)
}
return drainPtyStream(ctx, &startStream{s: stream}, true), nil
}
// PtyConnect re-attaches to an existing PTY process by tag.
// Returns a channel of output events starting from the current point.
func (c *Client) PtyConnect(ctx context.Context, tag string) (<-chan PtyEvent, error) {
req := connect.NewRequest(&envdpb.ConnectRequest{
Process: &envdpb.ProcessSelector{
Selector: &envdpb.ProcessSelector_Tag{Tag: tag},
},
})
stream, err := c.process.Connect(ctx, req)
if err != nil {
return nil, fmt.Errorf("pty connect: %w", err)
}
return drainPtyStream(ctx, &connectStream{s: stream}, false), nil
}
// PtySendInput sends raw bytes to the PTY process identified by tag.
func (c *Client) PtySendInput(ctx context.Context, tag string, data []byte) error {
req := connect.NewRequest(&envdpb.SendInputRequest{
Process: &envdpb.ProcessSelector{
Selector: &envdpb.ProcessSelector_Tag{Tag: tag},
},
Input: &envdpb.ProcessInput{
Input: &envdpb.ProcessInput_Pty{Pty: data},
},
})
if _, err := c.process.SendInput(ctx, req); err != nil {
return fmt.Errorf("pty send input: %w", err)
}
return nil
}
// PtyResize updates the terminal dimensions for the PTY process identified by tag.
func (c *Client) PtyResize(ctx context.Context, tag string, cols, rows uint32) error {
req := connect.NewRequest(&envdpb.UpdateRequest{
Process: &envdpb.ProcessSelector{
Selector: &envdpb.ProcessSelector_Tag{Tag: tag},
},
Pty: &envdpb.PTY{
Size: &envdpb.PTY_Size{
Cols: cols,
Rows: rows,
},
},
})
if _, err := c.process.Update(ctx, req); err != nil {
return fmt.Errorf("pty resize: %w", err)
}
return nil
}
// PtyKill sends SIGKILL to the PTY process identified by tag.
func (c *Client) PtyKill(ctx context.Context, tag string) error {
req := connect.NewRequest(&envdpb.SendSignalRequest{
Process: &envdpb.ProcessSelector{
Selector: &envdpb.ProcessSelector_Tag{Tag: tag},
},
Signal: envdpb.Signal_SIGNAL_SIGKILL,
})
if _, err := c.process.SendSignal(ctx, req); err != nil {
return fmt.Errorf("pty kill: %w", err)
}
return nil
}
// eventStream is an interface covering both StartResponse and ConnectResponse streams.
type eventStream interface {
Receive() bool
Err() error
Close() error
}
type startStream struct {
s *connect.ServerStreamForClient[envdpb.StartResponse]
}
func (s *startStream) Receive() bool { return s.s.Receive() }
func (s *startStream) Err() error { return s.s.Err() }
func (s *startStream) Close() error { return s.s.Close() }
func (s *startStream) Event() *envdpb.ProcessEvent {
return s.s.Msg().GetEvent()
}
type connectStream struct {
s *connect.ServerStreamForClient[envdpb.ConnectResponse]
}
func (s *connectStream) Receive() bool { return s.s.Receive() }
func (s *connectStream) Err() error { return s.s.Err() }
func (s *connectStream) Close() error { return s.s.Close() }
func (s *connectStream) Event() *envdpb.ProcessEvent {
return s.s.Msg().GetEvent()
}
type eventProvider interface {
eventStream
Event() *envdpb.ProcessEvent
}
// drainPtyStream reads events from either a Start or Connect stream and maps
// them into PtyEvent values on a channel.
func drainPtyStream(ctx context.Context, stream eventProvider, expectStart bool) <-chan PtyEvent {
ch := make(chan PtyEvent, 16)
go func() {
defer close(ch)
defer stream.Close()
for stream.Receive() {
event := stream.Event()
if event == nil {
continue
}
var ev PtyEvent
switch e := event.GetEvent().(type) {
case *envdpb.ProcessEvent_Start:
if expectStart {
ev = PtyEvent{Type: "started", PID: e.Start.GetPid()}
} else {
continue
}
case *envdpb.ProcessEvent_Data:
switch o := e.Data.GetOutput().(type) {
case *envdpb.ProcessEvent_DataEvent_Pty:
ev = PtyEvent{Type: "output", Data: o.Pty}
case *envdpb.ProcessEvent_DataEvent_Stdout:
ev = PtyEvent{Type: "output", Data: o.Stdout}
case *envdpb.ProcessEvent_DataEvent_Stderr:
ev = PtyEvent{Type: "output", Data: o.Stderr}
default:
continue
}
case *envdpb.ProcessEvent_End:
ev = PtyEvent{Type: "end", ExitCode: e.End.GetExitCode()}
if e.End.Error != nil {
ev.Error = e.End.GetError()
}
case *envdpb.ProcessEvent_Keepalive:
continue
}
select {
case ch <- ev:
case <-ctx.Done():
return
}
}
if err := stream.Err(); err != nil && err != io.EOF {
slog.Debug("pty stream error", "error", err)
}
}()
return ch
}

View File

@ -1,73 +0,0 @@
package events
import (
"context"
"time"
)
// EventPublisher pushes events onto the notification stream.
// Satisfied by *channels.Publisher.
type EventPublisher interface {
Publish(ctx context.Context, e Event)
}
// ActorKind identifies what initiated an event.
type ActorKind string
const (
ActorUser ActorKind = "user"
ActorAPIKey ActorKind = "api_key"
ActorSystem ActorKind = "system"
)
// Actor describes who triggered an event.
type Actor struct {
Type ActorKind `json:"type"`
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
}
// Resource identifies the object the event relates to.
type Resource struct {
ID string `json:"id"`
Type string `json:"type"`
}
// Event is the canonical notification payload published to the Redis stream
// and delivered to channel subscribers.
type Event struct {
Event string `json:"event"`
Timestamp string `json:"timestamp"`
TeamID string `json:"team_id"`
Actor Actor `json:"actor"`
Resource Resource `json:"resource"`
}
// Event type constants.
const (
CapsuleCreated = "capsule.created"
CapsuleRunning = "capsule.running"
CapsulePaused = "capsule.paused"
CapsuleDestroyed = "capsule.destroyed"
SnapshotCreated = "template.snapshot.created"
SnapshotDeleted = "template.snapshot.deleted"
HostUp = "host.up"
HostDown = "host.down"
)
// AllEventTypes is the complete set of valid event type strings.
var AllEventTypes = []string{
CapsuleCreated,
CapsuleRunning,
CapsulePaused,
CapsuleDestroyed,
SnapshotCreated,
SnapshotDeleted,
HostUp,
HostDown,
}
// Now returns the current time formatted for event timestamps.
func Now() string {
return time.Now().UTC().Format(time.RFC3339)
}

View File

@ -15,6 +15,7 @@ import (
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgtype"
envdpb "git.omukk.dev/wrenn/wrenn/proto/envd/gen"
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
"git.omukk.dev/wrenn/wrenn/proto/hostagent/gen/hostagentv1connect"
@ -68,10 +69,18 @@ func (s *Server) CreateSandbox(
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("create sandbox: %w", err))
}
// Apply template defaults (user, env vars) if provided.
if msg.DefaultUser != "" || len(msg.DefaultEnv) > 0 {
if err := s.mgr.SetDefaults(ctx, sb.ID, msg.DefaultUser, msg.DefaultEnv); err != nil {
slog.Warn("failed to set sandbox defaults", "sandbox", sb.ID, "error", err)
}
}
return connect.NewResponse(&pb.CreateSandboxResponse{
SandboxId: sb.ID,
Status: string(sb.Status),
HostIp: sb.HostIP.String(),
Metadata: sb.Metadata,
}), nil
}
@ -99,14 +108,24 @@ func (s *Server) ResumeSandbox(
ctx context.Context,
req *connect.Request[pb.ResumeSandboxRequest],
) (*connect.Response[pb.ResumeSandboxResponse], error) {
sb, err := s.mgr.Resume(ctx, req.Msg.SandboxId, int(req.Msg.TimeoutSec))
msg := req.Msg
sb, err := s.mgr.Resume(ctx, msg.SandboxId, int(msg.TimeoutSec), msg.KernelVersion)
if err != nil {
return nil, connect.NewError(connect.CodeInternal, err)
}
// Apply template defaults (user, env vars) if provided.
if msg.DefaultUser != "" || len(msg.DefaultEnv) > 0 {
if err := s.mgr.SetDefaults(ctx, sb.ID, msg.DefaultUser, msg.DefaultEnv); err != nil {
slog.Warn("failed to set sandbox defaults on resume", "sandbox", sb.ID, "error", err)
}
}
return connect.NewResponse(&pb.ResumeSandboxResponse{
SandboxId: sb.ID,
Status: string(sb.Status),
HostIp: sb.HostIP.String(),
Metadata: sb.Metadata,
}), nil
}
@ -252,6 +271,69 @@ func (s *Server) ReadFile(
return connect.NewResponse(&pb.ReadFileResponse{Content: content}), nil
}
func (s *Server) ListDir(
ctx context.Context,
req *connect.Request[pb.ListDirRequest],
) (*connect.Response[pb.ListDirResponse], error) {
msg := req.Msg
client, err := s.mgr.GetClient(msg.SandboxId)
if err != nil {
return nil, connect.NewError(connect.CodeNotFound, err)
}
resp, err := client.ListDir(ctx, msg.Path, msg.Depth)
if err != nil {
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("list dir: %w", err))
}
entries := make([]*pb.FileEntry, 0, len(resp.Entries))
for _, e := range resp.Entries {
entries = append(entries, entryInfoToPB(e))
}
return connect.NewResponse(&pb.ListDirResponse{Entries: entries}), nil
}
func (s *Server) MakeDir(
ctx context.Context,
req *connect.Request[pb.MakeDirRequest],
) (*connect.Response[pb.MakeDirResponse], error) {
msg := req.Msg
client, err := s.mgr.GetClient(msg.SandboxId)
if err != nil {
return nil, connect.NewError(connect.CodeNotFound, err)
}
resp, err := client.MakeDir(ctx, msg.Path)
if err != nil {
return nil, fmt.Errorf("make dir: %w", err)
}
return connect.NewResponse(&pb.MakeDirResponse{
Entry: entryInfoToPB(resp.Entry),
}), nil
}
func (s *Server) RemovePath(
ctx context.Context,
req *connect.Request[pb.RemovePathRequest],
) (*connect.Response[pb.RemovePathResponse], error) {
msg := req.Msg
client, err := s.mgr.GetClient(msg.SandboxId)
if err != nil {
return nil, connect.NewError(connect.CodeNotFound, err)
}
if err := client.Remove(ctx, msg.Path); err != nil {
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("remove: %w", err))
}
return connect.NewResponse(&pb.RemovePathResponse{}), nil
}
func (s *Server) ExecStream(
ctx context.Context,
req *connect.Request[pb.ExecStreamRequest],
@ -436,6 +518,16 @@ func (s *Server) ReadFileStream(
// Stream file content in 64KB chunks.
buf := make([]byte, 64*1024)
for {
// Bail out early if the client disconnected or the context was cancelled.
select {
case <-ctx.Done():
if ctx.Err() == context.DeadlineExceeded {
return connect.NewError(connect.CodeDeadlineExceeded, ctx.Err())
}
return connect.NewError(connect.CodeCanceled, ctx.Err())
default:
}
n, err := resp.Body.Read(buf)
if n > 0 {
chunk := make([]byte, n)
@ -474,6 +566,7 @@ func (s *Server) ListSandboxes(
CreatedAtUnix: sb.CreatedAt.Unix(),
LastActiveAtUnix: sb.LastActiveAt.Unix(),
TimeoutSec: int32(sb.TimeoutSec),
Metadata: sb.Metadata,
}
}
@ -545,3 +638,269 @@ func metricPointsToPB(pts []sandbox.MetricPoint) []*pb.MetricPoint {
}
return out
}
func (s *Server) PtyAttach(
ctx context.Context,
req *connect.Request[pb.PtyAttachRequest],
stream *connect.ServerStream[pb.PtyAttachResponse],
) error {
msg := req.Msg
events, err := s.mgr.PtyAttach(ctx, msg.SandboxId, msg.Tag, msg.Cmd, msg.Args, msg.Cols, msg.Rows, msg.Envs, msg.Cwd)
if err != nil {
return connect.NewError(connect.CodeInternal, fmt.Errorf("pty attach: %w", err))
}
for ev := range events {
var resp pb.PtyAttachResponse
switch ev.Type {
case "started":
resp.Event = &pb.PtyAttachResponse_Started{
Started: &pb.PtyStarted{Pid: ev.PID, Tag: msg.Tag},
}
case "output":
resp.Event = &pb.PtyAttachResponse_Output{
Output: &pb.PtyOutput{Data: ev.Data},
}
case "end":
resp.Event = &pb.PtyAttachResponse_Exited{
Exited: &pb.PtyExited{ExitCode: ev.ExitCode, Error: ev.Error},
}
default:
continue
}
if err := stream.Send(&resp); err != nil {
return err
}
}
return nil
}
func (s *Server) PtySendInput(
ctx context.Context,
req *connect.Request[pb.PtySendInputRequest],
) (*connect.Response[pb.PtySendInputResponse], error) {
msg := req.Msg
if err := s.mgr.PtySendInput(ctx, msg.SandboxId, msg.Tag, msg.Data); err != nil {
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("pty send input: %w", err))
}
return connect.NewResponse(&pb.PtySendInputResponse{}), nil
}
func (s *Server) PtyResize(
ctx context.Context,
req *connect.Request[pb.PtyResizeRequest],
) (*connect.Response[pb.PtyResizeResponse], error) {
msg := req.Msg
if err := s.mgr.PtyResize(ctx, msg.SandboxId, msg.Tag, msg.Cols, msg.Rows); err != nil {
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("pty resize: %w", err))
}
return connect.NewResponse(&pb.PtyResizeResponse{}), nil
}
func (s *Server) PtyKill(
ctx context.Context,
req *connect.Request[pb.PtyKillRequest],
) (*connect.Response[pb.PtyKillResponse], error) {
msg := req.Msg
if err := s.mgr.PtyKill(ctx, msg.SandboxId, msg.Tag); err != nil {
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("pty kill: %w", err))
}
return connect.NewResponse(&pb.PtyKillResponse{}), nil
}
// entryInfoToPB maps an envd EntryInfo to a hostagent FileEntry.
func entryInfoToPB(e *envdpb.EntryInfo) *pb.FileEntry {
if e == nil {
return nil
}
var fileType string
switch e.Type {
case envdpb.FileType_FILE_TYPE_FILE:
fileType = "file"
case envdpb.FileType_FILE_TYPE_DIRECTORY:
fileType = "directory"
case envdpb.FileType_FILE_TYPE_SYMLINK:
fileType = "symlink"
default:
fileType = "unknown"
}
entry := &pb.FileEntry{
Name: e.Name,
Path: e.Path,
Type: fileType,
Size: e.Size,
Mode: e.Mode,
Permissions: e.Permissions,
Owner: e.Owner,
Group: e.Group,
}
if e.ModifiedTime != nil {
entry.ModifiedAt = e.ModifiedTime.GetSeconds()
}
if e.SymlinkTarget != nil {
entry.SymlinkTarget = e.SymlinkTarget
}
return entry
}
// ── Background Processes ────────────────────────────────────────────
func (s *Server) StartBackground(
ctx context.Context,
req *connect.Request[pb.StartBackgroundRequest],
) (*connect.Response[pb.StartBackgroundResponse], error) {
msg := req.Msg
pid, err := s.mgr.StartBackground(ctx, msg.SandboxId, msg.Tag, msg.Cmd, msg.Args, msg.Envs, msg.Cwd)
if err != nil {
if strings.Contains(err.Error(), "not found") {
return nil, connect.NewError(connect.CodeNotFound, err)
}
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("start background: %w", err))
}
return connect.NewResponse(&pb.StartBackgroundResponse{
Pid: pid,
Tag: msg.Tag,
}), nil
}
func (s *Server) ListProcesses(
ctx context.Context,
req *connect.Request[pb.ListProcessesRequest],
) (*connect.Response[pb.ListProcessesResponse], error) {
procs, err := s.mgr.ListProcesses(ctx, req.Msg.SandboxId)
if err != nil {
if strings.Contains(err.Error(), "not found") {
return nil, connect.NewError(connect.CodeNotFound, err)
}
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("list processes: %w", err))
}
entries := make([]*pb.ProcessEntry, 0, len(procs))
for _, p := range procs {
entries = append(entries, &pb.ProcessEntry{
Pid: p.PID,
Tag: p.Tag,
Cmd: p.Cmd,
Args: p.Args,
})
}
return connect.NewResponse(&pb.ListProcessesResponse{
Processes: entries,
}), nil
}
func (s *Server) KillProcess(
ctx context.Context,
req *connect.Request[pb.KillProcessRequest],
) (*connect.Response[pb.KillProcessResponse], error) {
msg := req.Msg
// Resolve PID/tag selector.
var pid uint32
var tag string
switch sel := msg.Selector.(type) {
case *pb.KillProcessRequest_Pid:
pid = sel.Pid
case *pb.KillProcessRequest_Tag:
tag = sel.Tag
default:
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("pid or tag is required"))
}
// Map signal string to envd enum.
var signal envdpb.Signal
switch msg.Signal {
case "", "SIGKILL":
signal = envdpb.Signal_SIGNAL_SIGKILL
case "SIGTERM":
signal = envdpb.Signal_SIGNAL_SIGTERM
default:
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("unsupported signal: %s (use SIGKILL or SIGTERM)", msg.Signal))
}
if err := s.mgr.KillProcess(ctx, msg.SandboxId, pid, tag, signal); err != nil {
if strings.Contains(err.Error(), "not found") {
return nil, connect.NewError(connect.CodeNotFound, err)
}
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("kill process: %w", err))
}
return connect.NewResponse(&pb.KillProcessResponse{}), nil
}
func (s *Server) ConnectProcess(
ctx context.Context,
req *connect.Request[pb.ConnectProcessRequest],
stream *connect.ServerStream[pb.ConnectProcessResponse],
) error {
msg := req.Msg
var pid uint32
var tag string
switch sel := msg.Selector.(type) {
case *pb.ConnectProcessRequest_Pid:
pid = sel.Pid
case *pb.ConnectProcessRequest_Tag:
tag = sel.Tag
default:
return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("pid or tag is required"))
}
events, err := s.mgr.ConnectProcess(ctx, msg.SandboxId, pid, tag)
if err != nil {
if strings.Contains(err.Error(), "not found") {
return connect.NewError(connect.CodeNotFound, err)
}
return connect.NewError(connect.CodeInternal, fmt.Errorf("connect process: %w", err))
}
for ev := range events {
var resp pb.ConnectProcessResponse
switch ev.Type {
case "start":
resp.Event = &pb.ConnectProcessResponse_Start{
Start: &pb.ExecStreamStart{Pid: ev.PID},
}
case "stdout":
resp.Event = &pb.ConnectProcessResponse_Data{
Data: &pb.ExecStreamData{
Output: &pb.ExecStreamData_Stdout{Stdout: ev.Data},
},
}
case "stderr":
resp.Event = &pb.ConnectProcessResponse_Data{
Data: &pb.ExecStreamData{
Output: &pb.ExecStreamData_Stderr{Stderr: ev.Data},
},
}
case "end":
resp.Event = &pb.ConnectProcessResponse_End{
End: &pb.ExecStreamEnd{
ExitCode: ev.ExitCode,
Error: ev.Error,
},
}
}
if err := stream.Send(&resp); err != nil {
return err
}
}
return nil
}

View File

@ -1,186 +0,0 @@
package id
import (
"crypto/rand"
"encoding/hex"
"fmt"
"math/big"
"strings"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgtype"
)
const (
base36Alphabet = "0123456789abcdefghijklmnopqrstuvwxyz"
base36IDLen = 25 // ceil(128 * log2 / log36) = 25 chars for a full UUID
)
var base36Base = big.NewInt(36)
// --- Generation ---
// newUUID returns a new random (v4) UUID wrapped in pgtype.UUID for direct DB use.
func newUUID() pgtype.UUID {
return pgtype.UUID{Bytes: uuid.New(), Valid: true}
}
func NewSandboxID() pgtype.UUID { return newUUID() }
func NewUserID() pgtype.UUID { return newUUID() }
func NewTeamID() pgtype.UUID { return newUUID() }
func NewAPIKeyID() pgtype.UUID { return newUUID() }
func NewHostID() pgtype.UUID { return newUUID() }
func NewHostTokenID() pgtype.UUID { return newUUID() }
func NewRefreshTokenID() pgtype.UUID { return newUUID() }
func NewAuditLogID() pgtype.UUID { return newUUID() }
func NewBuildID() pgtype.UUID { return newUUID() }
func NewAdminPermissionID() pgtype.UUID { return newUUID() }
func NewChannelID() pgtype.UUID { return newUUID() }
func NewTemplateID() pgtype.UUID { return newUUID() }
// NewSnapshotName generates a snapshot name: "template-" + 8 hex chars.
func NewSnapshotName() string {
return "template-" + hex8()
}
// NewTeamSlug generates a unique team slug in the format "xxxxxx-yyyyyy".
func NewTeamSlug() string {
b := make([]byte, 6)
if _, err := rand.Read(b); err != nil {
panic(fmt.Sprintf("crypto/rand failed: %v", err))
}
return hex.EncodeToString(b[:3]) + "-" + hex.EncodeToString(b[3:])
}
// NewRegistrationToken generates a 64-char hex token (32 bytes of entropy).
func NewRegistrationToken() string {
return hexToken(32)
}
// NewRefreshToken generates a 64-char hex token (32 bytes of entropy).
func NewRefreshToken() string {
return hexToken(32)
}
// --- Formatting (pgtype.UUID → prefixed string for API/RPC output) ---
const (
PrefixSandbox = "cl-"
PrefixUser = "usr-"
PrefixTeam = "team-"
PrefixAPIKey = "key-"
PrefixHost = "host-"
PrefixHostToken = "htok-"
PrefixRefreshToken = "hrt-"
PrefixAuditLog = "log-"
PrefixBuild = "bld-"
PrefixAdminPermission = "perm-"
PrefixChannel = "ch-"
)
// UUIDToBase36 encodes 16 UUID bytes as a 25-char base36 string (0-9a-z).
func UUIDToBase36(b [16]byte) string {
n := new(big.Int).SetBytes(b[:])
buf := make([]byte, base36IDLen)
mod := new(big.Int)
for i := base36IDLen - 1; i >= 0; i-- {
n.DivMod(n, base36Base, mod)
buf[i] = base36Alphabet[mod.Int64()]
}
return string(buf)
}
// base36ToUUID decodes a 25-char base36 string back to 16 UUID bytes.
func base36ToUUID(s string) ([16]byte, error) {
if len(s) != base36IDLen {
return [16]byte{}, fmt.Errorf("expected %d-char base36 ID, got %d", base36IDLen, len(s))
}
n := new(big.Int)
for _, c := range s {
idx := strings.IndexRune(base36Alphabet, c)
if idx < 0 {
return [16]byte{}, fmt.Errorf("invalid base36 character: %c", c)
}
n.Mul(n, base36Base)
n.Add(n, big.NewInt(int64(idx)))
}
b := n.Bytes()
var out [16]byte
// big.Int.Bytes() strips leading zeros; right-align into 16-byte array.
copy(out[16-len(b):], b)
return out, nil
}
func formatUUID(prefix string, id pgtype.UUID) string {
return prefix + UUIDToBase36(id.Bytes)
}
func FormatSandboxID(id pgtype.UUID) string { return formatUUID(PrefixSandbox, id) }
func FormatUserID(id pgtype.UUID) string { return formatUUID(PrefixUser, id) }
func FormatTeamID(id pgtype.UUID) string { return formatUUID(PrefixTeam, id) }
func FormatAPIKeyID(id pgtype.UUID) string { return formatUUID(PrefixAPIKey, id) }
func FormatHostID(id pgtype.UUID) string { return formatUUID(PrefixHost, id) }
func FormatHostTokenID(id pgtype.UUID) string { return formatUUID(PrefixHostToken, id) }
func FormatRefreshTokenID(id pgtype.UUID) string { return formatUUID(PrefixRefreshToken, id) }
func FormatAuditLogID(id pgtype.UUID) string { return formatUUID(PrefixAuditLog, id) }
func FormatBuildID(id pgtype.UUID) string { return formatUUID(PrefixBuild, id) }
func FormatChannelID(id pgtype.UUID) string { return formatUUID(PrefixChannel, id) }
// --- Parsing (prefixed string from API/RPC input → pgtype.UUID) ---
func parseUUID(prefix, s string) (pgtype.UUID, error) {
if !strings.HasPrefix(s, prefix) {
return pgtype.UUID{}, fmt.Errorf("invalid ID: expected %q prefix, got %q", prefix, s)
}
b, err := base36ToUUID(strings.TrimPrefix(s, prefix))
if err != nil {
return pgtype.UUID{}, fmt.Errorf("invalid ID %q: %w", s, err)
}
return pgtype.UUID{Bytes: b, Valid: true}, nil
}
func ParseSandboxID(s string) (pgtype.UUID, error) { return parseUUID(PrefixSandbox, s) }
func ParseUserID(s string) (pgtype.UUID, error) { return parseUUID(PrefixUser, s) }
func ParseTeamID(s string) (pgtype.UUID, error) { return parseUUID(PrefixTeam, s) }
func ParseAPIKeyID(s string) (pgtype.UUID, error) { return parseUUID(PrefixAPIKey, s) }
func ParseHostID(s string) (pgtype.UUID, error) { return parseUUID(PrefixHost, s) }
func ParseHostTokenID(s string) (pgtype.UUID, error) { return parseUUID(PrefixHostToken, s) }
func ParseAuditLogID(s string) (pgtype.UUID, error) { return parseUUID(PrefixAuditLog, s) }
func ParseBuildID(s string) (pgtype.UUID, error) { return parseUUID(PrefixBuild, s) }
func ParseChannelID(s string) (pgtype.UUID, error) { return parseUUID(PrefixChannel, s) }
// --- Well-known IDs ---
// PlatformTeamID is the all-zeros UUID reserved for platform-owned resources
// (e.g. base templates, shared infrastructure).
var PlatformTeamID = pgtype.UUID{Bytes: [16]byte{}, Valid: true}
// MinimalTemplateID is the all-zeros UUID sentinel for the built-in "minimal"
// template. When both team_id and template_id are zero, the host agent uses
// the minimal rootfs at WRENN_DIR/images/minimal/.
var MinimalTemplateID = pgtype.UUID{Bytes: [16]byte{}, Valid: true}
// UUIDString converts a pgtype.UUID to a standard hyphenated UUID string
// (e.g., "6ba7b810-9dad-11d1-80b4-00c04fd430c8"). Used for RPC wire format.
func UUIDString(id pgtype.UUID) string {
return uuid.UUID(id.Bytes).String()
}
// --- Helpers ---
func hex8() string {
b := make([]byte, 4)
if _, err := rand.Read(b); err != nil {
panic(fmt.Sprintf("crypto/rand failed: %v", err))
}
return hex.EncodeToString(b)
}
func hexToken(nBytes int) string {
b := make([]byte, nBytes)
if _, err := rand.Read(b); err != nil {
panic(fmt.Sprintf("crypto/rand failed: %v", err))
}
return hex.EncodeToString(b)
}

View File

@ -1,118 +0,0 @@
package id
import (
"testing"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgtype"
)
func TestBase36RoundTrip(t *testing.T) {
for i := 0; i < 1000; i++ {
orig := uuid.New()
encoded := UUIDToBase36(orig)
if len(encoded) != base36IDLen {
t.Fatalf("expected %d chars, got %d: %s", base36IDLen, len(encoded), encoded)
}
decoded, err := base36ToUUID(encoded)
if err != nil {
t.Fatalf("decode failed: %v", err)
}
if decoded != orig {
t.Fatalf("round-trip failed: %v → %s → %v", orig, encoded, decoded)
}
}
}
func TestBase36ZeroUUID(t *testing.T) {
var zero [16]byte
encoded := UUIDToBase36(zero)
if encoded != "0000000000000000000000000" {
t.Fatalf("zero UUID should encode to all zeros, got %s", encoded)
}
decoded, err := base36ToUUID(encoded)
if err != nil {
t.Fatalf("decode failed: %v", err)
}
if decoded != zero {
t.Fatalf("round-trip failed for zero UUID")
}
}
func TestFormatParseRoundTrip(t *testing.T) {
id := NewSandboxID()
formatted := FormatSandboxID(id)
if formatted[:3] != "cl-" {
t.Fatalf("expected cl- prefix, got %s", formatted)
}
if len(formatted) != 3+base36IDLen {
t.Fatalf("expected %d chars total, got %d: %s", 3+base36IDLen, len(formatted), formatted)
}
parsed, err := ParseSandboxID(formatted)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if parsed != id {
t.Fatalf("round-trip failed: %v → %s → %v", id, formatted, parsed)
}
}
func TestBase36InvalidInput(t *testing.T) {
// Wrong length.
if _, err := base36ToUUID("abc"); err == nil {
t.Fatal("expected error for short input")
}
// Invalid character.
if _, err := base36ToUUID("000000000000000000000000!"); err == nil {
t.Fatal("expected error for invalid character")
}
}
func TestPlatformTeamIDFormats(t *testing.T) {
formatted := FormatTeamID(PlatformTeamID)
parsed, err := ParseTeamID(formatted)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if parsed != PlatformTeamID {
t.Fatalf("platform team ID round-trip failed")
}
}
func TestMaxUUID(t *testing.T) {
max := [16]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}
encoded := UUIDToBase36(max)
if len(encoded) != base36IDLen {
t.Fatalf("max UUID encoding wrong length: %d", len(encoded))
}
decoded, err := base36ToUUID(encoded)
if err != nil {
t.Fatalf("decode failed: %v", err)
}
if decoded != max {
t.Fatalf("round-trip failed for max UUID")
}
}
func BenchmarkFormatSandboxID(b *testing.B) {
id := pgtype.UUID{Bytes: uuid.New(), Valid: true}
b.ResetTimer()
for i := 0; i < b.N; i++ {
FormatSandboxID(id)
}
}
func BenchmarkParseSandboxID(b *testing.B) {
id := pgtype.UUID{Bytes: uuid.New(), Valid: true}
s := FormatSandboxID(id)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = ParseSandboxID(s)
}
}

View File

@ -1,11 +1,15 @@
package layout
import (
"fmt"
"os"
"path/filepath"
"sort"
"strings"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/internal/id"
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
// IsMinimal reports whether the given team and template IDs represent the
@ -47,6 +51,75 @@ func KernelPath(wrennDir string) string {
return filepath.Join(wrennDir, "kernels", "vmlinux")
}
// KernelPathVersioned returns the path to a specific kernel version.
func KernelPathVersioned(wrennDir, version string) string {
return filepath.Join(wrennDir, "kernels", "vmlinux-"+version)
}
// LatestKernel scans the kernels directory for files matching vmlinux-{semver}
// and returns the path and version of the latest one (by semver sort).
func LatestKernel(wrennDir string) (path, version string, err error) {
dir := filepath.Join(wrennDir, "kernels")
return latestVersionedFile(dir, "vmlinux-")
}
// latestVersionedFile scans dir for files with the given prefix, extracts the
// version suffix, sorts by semver, and returns the path and version of the latest.
func latestVersionedFile(dir, prefix string) (path, version string, err error) {
entries, err := os.ReadDir(dir)
if err != nil {
return "", "", fmt.Errorf("read directory %s: %w", dir, err)
}
var versions []string
for _, e := range entries {
if e.IsDir() {
continue
}
name := e.Name()
if v, ok := strings.CutPrefix(name, prefix); ok && v != "" {
versions = append(versions, v)
}
}
if len(versions) == 0 {
return "", "", fmt.Errorf("no %s* files found in %s", prefix, dir)
}
sort.Slice(versions, func(i, j int) bool {
return compareSemver(versions[i], versions[j]) < 0
})
latest := versions[len(versions)-1]
return filepath.Join(dir, prefix+latest), latest, nil
}
// compareSemver compares two dotted-numeric version strings.
// Returns -1 if a < b, 0 if equal, 1 if a > b.
func compareSemver(a, b string) int {
aParts := strings.Split(a, ".")
bParts := strings.Split(b, ".")
maxLen := max(len(aParts), len(bParts))
for i := 0; i < maxLen; i++ {
var av, bv int
if i < len(aParts) {
_, _ = fmt.Sscanf(aParts[i], "%d", &av)
}
if i < len(bParts) {
_, _ = fmt.Sscanf(bParts[i], "%d", &bv)
}
if av < bv {
return -1
}
if av > bv {
return 1
}
}
return 0
}
// ImagesRoot returns the root images directory.
func ImagesRoot(wrennDir string) string {
return filepath.Join(wrennDir, "images")

View File

@ -6,7 +6,7 @@ import (
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/internal/id"
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
func TestIsMinimal(t *testing.T) {

View File

@ -1,125 +0,0 @@
package lifecycle
import (
"crypto/tls"
"fmt"
"net/http"
"strings"
"sync"
"time"
"git.omukk.dev/wrenn/wrenn/internal/db"
"git.omukk.dev/wrenn/wrenn/internal/id"
"git.omukk.dev/wrenn/wrenn/proto/hostagent/gen/hostagentv1connect"
)
// HostClientPool maintains a cache of Connect RPC clients keyed by host ID.
// Clients are created lazily on first access and evicted when a host is removed
// or goes unreachable. The pool is safe for concurrent use.
type HostClientPool struct {
mu sync.RWMutex
clients map[string]hostagentv1connect.HostAgentServiceClient
httpClient *http.Client
scheme string // "http://" or "https://"
}
// NewHostClientPool creates a pool that connects to agents over plain HTTP.
// Use NewHostClientPoolTLS when mTLS is required.
func NewHostClientPool() *HostClientPool {
return &HostClientPool{
clients: make(map[string]hostagentv1connect.HostAgentServiceClient),
httpClient: &http.Client{Timeout: 10 * time.Minute},
scheme: "http://",
}
}
// NewHostClientPoolTLS creates a pool that connects to agents over mTLS.
// tlsCfg should already carry the CP client cert and CA trust anchor
// (use auth.CPClientTLSConfig to construct it).
func NewHostClientPoolTLS(tlsCfg *tls.Config) *HostClientPool {
transport := &http.Transport{
TLSClientConfig: tlsCfg,
}
return &HostClientPool{
clients: make(map[string]hostagentv1connect.HostAgentServiceClient),
httpClient: &http.Client{
Timeout: 10 * time.Minute,
Transport: transport,
},
scheme: "https://",
}
}
// Get returns a Connect RPC client for the given host, creating one if necessary.
// address is the host agent address (ip:port or full URL). The scheme is added if absent.
func (p *HostClientPool) Get(hostID, address string) hostagentv1connect.HostAgentServiceClient {
p.mu.RLock()
c, ok := p.clients[hostID]
p.mu.RUnlock()
if ok {
return c
}
p.mu.Lock()
defer p.mu.Unlock()
// Double-check after acquiring write lock.
if c, ok = p.clients[hostID]; ok {
return c
}
c = hostagentv1connect.NewHostAgentServiceClient(p.httpClient, p.ensureScheme(address))
p.clients[hostID] = c
return c
}
// GetForHost is a convenience wrapper that extracts the address from a db.Host
// and returns an error if the host has no address recorded yet.
func (p *HostClientPool) GetForHost(h db.Host) (hostagentv1connect.HostAgentServiceClient, error) {
if h.Address == "" {
return nil, fmt.Errorf("host %s has no address", id.FormatHostID(h.ID))
}
return p.Get(id.FormatHostID(h.ID), h.Address), nil
}
// Evict removes the cached client for the given host, forcing a new client to be
// created on the next call to Get. Call this when a host's address changes or when
// a host is deleted.
func (p *HostClientPool) Evict(hostID string) {
p.mu.Lock()
delete(p.clients, hostID)
p.mu.Unlock()
}
// ensureScheme prepends the pool's configured scheme if the address has none.
func (p *HostClientPool) ensureScheme(addr string) string {
if strings.HasPrefix(addr, "http://") || strings.HasPrefix(addr, "https://") {
return addr
}
return p.scheme + addr
}
// Transport returns the http.RoundTripper used by this pool. Use this when you
// need to make raw HTTP requests to agent addresses with the same TLS settings
// as the pool's Connect RPC clients (e.g., the sandbox reverse proxy).
func (p *HostClientPool) Transport() http.RoundTripper {
if p.httpClient.Transport != nil {
return p.httpClient.Transport
}
return http.DefaultTransport
}
// ResolveAddr prepends the pool's configured scheme to addr if it has none.
// Use this when constructing URLs that must use the same transport as the pool
// (e.g., the sandbox proxy handler). Calling Get/GetForHost internally does
// the same thing, but ResolveAddr exposes it for callers that only need the URL.
func (p *HostClientPool) ResolveAddr(addr string) string {
return p.ensureScheme(addr)
}
// EnsureScheme adds "http://" if the address has no scheme.
// Deprecated: use pool.ResolveAddr which respects the pool's TLS setting.
func EnsureScheme(addr string) string {
if strings.HasPrefix(addr, "http://") || strings.HasPrefix(addr, "https://") {
return addr
}
return "http://" + addr
}

View File

@ -1 +0,0 @@
package lifecycle

View File

@ -30,4 +30,5 @@ type Sandbox struct {
RootfsPath string
CreatedAt time.Time
LastActiveAt time.Time
Metadata map[string]string
}

View File

@ -24,7 +24,7 @@ func (a *SlotAllocator) Allocate() (int, error) {
a.mu.Lock()
defer a.mu.Unlock()
for i := 1; i <= 65534; i++ {
for i := 1; i <= 32767; i++ {
if !a.inUse[i] {
a.inUse[i] = true
return i, nil

View File

@ -131,26 +131,31 @@ type Slot struct {
}
// NewSlot computes the addressing for the given slot index (1-based).
// Index must be in [1, 32767] so that veth offset (index*2) fits in 16 bits.
func NewSlot(index int) *Slot {
if index < 1 || index > 32767 {
panic(fmt.Sprintf("slot index %d out of range [1, 32767]", index))
}
hostBaseIP := net.ParseIP(hostBase).To4()
vrtBaseIP := net.ParseIP(vrtBase).To4()
hostIP := make(net.IP, 4)
copy(hostIP, hostBaseIP)
hostIP[2] += byte(index >> 8)
hostIP[3] += byte(index & 0xFF)
hostIP[2] = hostBaseIP[2] + byte(index>>8)
hostIP[3] = hostBaseIP[3] + byte(index&0xFF)
vethOffset := index * vrtAddressesPerSlot
vethIP := make(net.IP, 4)
copy(vethIP, vrtBaseIP)
vethIP[2] += byte(vethOffset >> 8)
vethIP[3] += byte(vethOffset & 0xFF)
vethIP[2] = vrtBaseIP[2] + byte(vethOffset>>8)
vethIP[3] = vrtBaseIP[3] + byte(vethOffset&0xFF)
vpeerOffset := vethOffset + 1
vpeerIP := make(net.IP, 4)
copy(vpeerIP, vrtBaseIP)
vpeerIP[2] += byte(vpeerOffset >> 8)
vpeerIP[3] += byte(vpeerOffset & 0xFF)
vpeerIP[2] = vrtBaseIP[2] + byte(vpeerOffset>>8)
vpeerIP[3] = vrtBaseIP[3] + byte(vpeerOffset&0xFF)
return &Slot{
Index: index,

View File

@ -7,10 +7,11 @@ import (
)
// ExecContext holds mutable state that persists across recipe steps.
// It is initialized empty and updated by ENV and WORKDIR steps.
// It is initialized empty and updated by ENV, WORKDIR, and USER steps.
type ExecContext struct {
WorkDir string
EnvVars map[string]string
User string // Current unix user for command execution.
}
// This regex matches:
@ -25,7 +26,20 @@ var envRegex = regexp.MustCompile(`\$\$|\$\{([a-zA-Z0-9_]*)\}|\$([a-zA-Z0-9_]+)`
// If WORKDIR and/or ENV are set, they are prepended as a shell preamble:
//
// cd '/the/dir' && KEY='val' /bin/sh -c 'original command'
//
// If USER is set to a non-root user, the entire command is wrapped with su:
//
// su <user> -s /bin/sh -c '<preamble + command>'
func (c *ExecContext) WrappedCommand(cmd string) string {
inner := c.innerCommand(cmd)
if c.User != "" && c.User != "root" {
return "su " + shellescape(c.User) + " -s /bin/sh -c " + shellescape(inner)
}
return inner
}
// innerCommand builds the command with workdir/env preamble but without user wrapping.
func (c *ExecContext) innerCommand(cmd string) string {
prefix := c.shellPrefix()
if prefix == "" {
return cmd
@ -42,7 +56,11 @@ func (c *ExecContext) WrappedCommand(cmd string) string {
// simultaneously before a healthcheck is evaluated.
func (c *ExecContext) StartCommand(cmd string) string {
prefix := c.shellPrefix()
return prefix + "nohup /bin/sh -c " + shellescape(cmd) + " >/dev/null 2>&1 &"
inner := prefix + "nohup /bin/sh -c " + shellescape(cmd) + " >/dev/null 2>&1 &"
if c.User != "" && c.User != "root" {
return "su " + shellescape(c.User) + " -s /bin/sh -c " + shellescape(inner)
}
return inner
}
// shellPrefix builds the "cd ... && KEY=val " preamble for a shell command.
@ -97,8 +115,11 @@ func expandEnv(s string, vars map[string]string) string {
})
}
// shellescape wraps s in single quotes, escaping any embedded single quotes.
// Shellescape wraps s in single quotes, escaping any embedded single quotes.
// This is POSIX-safe for paths, env values, and shell commands.
func shellescape(s string) string {
func Shellescape(s string) string {
return "'" + strings.ReplaceAll(s, "'", `'\''`) + "'"
}
// shellescape is the package-internal alias for Shellescape.
func shellescape(s string) string { return Shellescape(s) }

View File

@ -4,6 +4,7 @@ import (
"context"
"fmt"
"log/slog"
"path"
"strings"
"time"
@ -16,6 +17,10 @@ import (
// explicit --timeout flag.
const DefaultStepTimeout = 30 * time.Second
// BuildFilesDir is the directory inside the sandbox where uploaded build
// archives are extracted. COPY instructions reference paths relative to this.
const BuildFilesDir = "/tmp/build-files"
// BuildLogEntry is the per-step record stored in template_builds.logs (JSONB).
type BuildLogEntry struct {
Step int `json:"step"`
@ -32,13 +37,18 @@ type BuildLogEntry struct {
// the method on the hostagent Connect RPC client.
type ExecFunc func(ctx context.Context, req *connect.Request[pb.ExecRequest]) (*connect.Response[pb.ExecResponse], error)
// ProgressFunc is called after each step with the current step counter and
// accumulated log entries. Used for per-step DB progress updates.
type ProgressFunc func(step int, entries []BuildLogEntry)
// Execute runs steps sequentially against sandboxID using execFn.
//
// - phase labels the log entries (e.g., "pre-build", "recipe", "post-build").
// - startStep is the 1-based offset so entries are globally numbered across phases.
// - defaultTimeout applies to RUN steps with no per-step --timeout; 0 → 10 minutes.
// - bctx is mutated in place as ENV/WORKDIR steps execute, and carries forward
// - bctx is mutated in place as ENV/WORKDIR/USER steps execute, and carries forward
// into subsequent phases when the caller passes the same pointer.
// - onProgress is called after each step for live progress updates (may be nil).
//
// Returns all log entries appended during this call, the next step counter
// value, and whether all steps succeeded. On false the last entry contains
@ -53,6 +63,7 @@ func Execute(
defaultTimeout time.Duration,
bctx *ExecContext,
execFn ExecFunc,
onProgress ProgressFunc,
) (entries []BuildLogEntry, nextStep int, ok bool) {
if defaultTimeout <= 0 {
defaultTimeout = 10 * time.Minute
@ -72,19 +83,30 @@ func Execute(
entries = append(entries, BuildLogEntry{Step: step, Phase: phase, Cmd: st.Raw, Ok: true})
case KindWORKDIR:
// Create the directory if it doesn't exist.
mkdirEntry := execRawShell(ctx, st.Raw, sandboxID, phase, step, 10*time.Second, execFn,
"mkdir -p "+shellescape(st.Path))
if !mkdirEntry.Ok {
entries = append(entries, mkdirEntry)
return entries, step, false
}
bctx.WorkDir = st.Path
entries = append(entries, BuildLogEntry{Step: step, Phase: phase, Cmd: st.Raw, Ok: true})
mkdirEntry.Ok = true
entries = append(entries, mkdirEntry)
case KindUSER, KindCOPY:
verb := strings.ToUpper(strings.Fields(st.Raw)[0])
entries = append(entries, BuildLogEntry{
Step: step,
Phase: phase,
Cmd: st.Raw,
Stderr: verb + " is not yet supported",
Ok: false,
})
return entries, step, false
case KindUSER:
entry, succeeded := execUser(ctx, st, sandboxID, phase, step, bctx, execFn)
entries = append(entries, entry)
if !succeeded {
return entries, step, false
}
case KindCOPY:
entry, succeeded := execCopy(ctx, st, sandboxID, phase, step, bctx, execFn)
entries = append(entries, entry)
if !succeeded {
return entries, step, false
}
case KindSTART:
entry, succeeded := execStart(ctx, st, sandboxID, phase, step, bctx, execFn)
@ -104,6 +126,10 @@ func Execute(
return entries, step, false
}
}
if onProgress != nil {
onProgress(step, entries)
}
}
return entries, step, true
}
@ -145,6 +171,123 @@ func execRun(
return entry, entry.Ok
}
// execUser creates a unix user (if not exists), grants passwordless sudo,
// and updates bctx.User for subsequent steps.
func execUser(
ctx context.Context,
st Step,
sandboxID, phase string,
step int,
bctx *ExecContext,
execFn ExecFunc,
) (BuildLogEntry, bool) {
username := st.Key
// Create user if not exists, with home directory and bash shell.
// Grant passwordless sudo access (E2B convention).
// Uses printf %s to avoid shell injection in the sudoers line.
script := fmt.Sprintf(
"id %s >/dev/null 2>&1 || (adduser --disabled-password --gecos '' --shell /bin/bash %s && printf '%%s ALL=(ALL) NOPASSWD:ALL\\n' %s >> /etc/sudoers)",
shellescape(username), shellescape(username), shellescape(username),
)
entry := execRawShell(ctx, st.Raw, sandboxID, phase, step, 30*time.Second, execFn, script)
if entry.Ok {
bctx.User = username
// Update HOME so ~ expands correctly in subsequent RUN/WORKDIR steps.
if bctx.EnvVars == nil {
bctx.EnvVars = make(map[string]string)
}
if username == "root" {
bctx.EnvVars["HOME"] = "/root"
} else {
bctx.EnvVars["HOME"] = "/home/" + username
}
}
return entry, entry.Ok
}
// execCopy copies a file or directory from the build archive (extracted at
// BuildFilesDir) to the destination path inside the sandbox. Ownership is
// set to the current user from bctx.
func execCopy(
ctx context.Context,
st Step,
sandboxID, phase string,
step int,
bctx *ExecContext,
execFn ExecFunc,
) (BuildLogEntry, bool) {
// Validate all source paths: must be relative and not escape the archive directory.
var srcPaths []string
for _, s := range st.Srcs {
cleaned := path.Clean(s)
if strings.HasPrefix(cleaned, "..") || strings.HasPrefix(cleaned, "/") {
return BuildLogEntry{
Step: step,
Phase: phase,
Cmd: st.Raw,
Stderr: fmt.Sprintf("COPY source must be a relative path within the archive: %q", s),
}, false
}
srcPaths = append(srcPaths, shellescape(BuildFilesDir+"/"+cleaned))
}
dst := st.Dst
// Resolve relative destination against the current WORKDIR.
if dst != "" && dst[0] != '/' && bctx.WorkDir != "" {
dst = bctx.WorkDir + "/" + dst
}
owner := "root"
if bctx.User != "" {
owner = bctx.User
}
script := fmt.Sprintf(
"cp -r %s %s && chown -R %s:%s %s",
strings.Join(srcPaths, " "), shellescape(dst), shellescape(owner), shellescape(owner), shellescape(dst),
)
entry := execRawShell(ctx, st.Raw, sandboxID, phase, step, 60*time.Second, execFn, script)
return entry, entry.Ok
}
// execRawShell runs a shell command directly (as root) without ExecContext
// wrapping. Used for internal operations like user creation and file copy.
func execRawShell(
ctx context.Context,
raw, sandboxID, phase string,
step int,
timeout time.Duration,
execFn ExecFunc,
shellCmd string,
) BuildLogEntry {
execCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
start := time.Now()
resp, err := execFn(execCtx, connect.NewRequest(&pb.ExecRequest{
SandboxId: sandboxID,
Cmd: "/bin/sh",
Args: []string{"-c", shellCmd},
TimeoutSec: int32(timeout.Seconds()),
}))
entry := BuildLogEntry{
Step: step,
Phase: phase,
Cmd: raw,
Elapsed: time.Since(start).Milliseconds(),
}
if err != nil {
entry.Stderr = fmt.Sprintf("exec error: %v", err)
return entry
}
entry.Stdout = string(resp.Msg.Stdout)
entry.Stderr = string(resp.Msg.Stderr)
entry.Exit = resp.Msg.ExitCode
entry.Ok = resp.Msg.ExitCode == 0
return entry
}
func execStart(
ctx context.Context,
st Step,

View File

@ -24,9 +24,11 @@ type Step struct {
Raw string // original string, preserved for logging
Shell string // KindRUN, KindSTART: the shell command text
Timeout time.Duration // KindRUN: 0 means use caller's default
Key string // KindENV: variable name
Key string // KindENV: variable name; KindUSER: username
Value string // KindENV: variable value
Path string // KindWORKDIR: directory path
Srcs []string // KindCOPY: source paths (relative to build archive)
Dst string // KindCOPY: destination path inside sandbox
}
// ParseStep parses a single recipe instruction string into a Step.
@ -61,9 +63,9 @@ func ParseStep(s string) (Step, error) {
case "WORKDIR":
return parseWORKDIR(s, rest)
case "USER":
return Step{Kind: KindUSER, Raw: s}, nil
return parseUSER(s, rest)
case "COPY":
return Step{Kind: KindCOPY, Raw: s}, nil
return parseCOPY(s, rest)
default:
return Step{}, fmt.Errorf("unknown instruction %q (expected RUN, START, ENV, WORKDIR, USER, or COPY)", keyword)
}
@ -127,3 +129,33 @@ func parseWORKDIR(raw, path string) (Step, error) {
}
return Step{Kind: KindWORKDIR, Raw: raw, Path: path}, nil
}
func parseUSER(raw, username string) (Step, error) {
if username == "" {
return Step{}, fmt.Errorf("USER requires a username: %q", raw)
}
// Validate: alphanumeric, hyphens, underscores only; must start with a letter or underscore.
for i, c := range username {
if i == 0 && !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || c == '_') {
return Step{}, fmt.Errorf("USER username must start with a letter or underscore: %q", raw)
}
if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_' || c == '-') {
return Step{}, fmt.Errorf("USER username contains invalid character %q: %q", string(c), raw)
}
}
return Step{Kind: KindUSER, Raw: raw, Key: username}, nil
}
func parseCOPY(raw, rest string) (Step, error) {
if rest == "" {
return Step{}, fmt.Errorf("COPY requires <src>... <dst>: %q", raw)
}
parts := strings.Fields(rest)
if len(parts) < 2 {
return Step{}, fmt.Errorf("COPY requires <src>... <dst>: %q", raw)
}
// Last argument is the destination, everything before is sources.
dst := parts[len(parts)-1]
srcs := parts[:len(parts)-1]
return Step{Kind: KindCOPY, Raw: raw, Srcs: srcs, Dst: dst}, nil
}

View File

@ -1,6 +1,7 @@
package recipe
import (
"reflect"
"testing"
"time"
)
@ -111,16 +112,42 @@ func TestParseStep(t *testing.T) {
input: "WORKDIR",
wantErr: true,
},
// USER and COPY stubs
// USER
{
name: "USER stub",
name: "USER basic",
input: "USER www-data",
want: Step{Kind: KindUSER, Raw: "USER www-data"},
want: Step{Kind: KindUSER, Raw: "USER www-data", Key: "www-data"},
},
{
name: "COPY stub",
name: "USER empty",
input: "USER",
wantErr: true,
},
{
name: "USER invalid chars",
input: "USER bad user",
wantErr: true,
},
// COPY
{
name: "COPY basic",
input: "COPY config.yaml /etc/app/config.yaml",
want: Step{Kind: KindCOPY, Raw: "COPY config.yaml /etc/app/config.yaml"},
want: Step{Kind: KindCOPY, Raw: "COPY config.yaml /etc/app/config.yaml", Srcs: []string{"config.yaml"}, Dst: "/etc/app/config.yaml"},
},
{
name: "COPY multiple sources",
input: "COPY a.txt b.txt /dest/",
want: Step{Kind: KindCOPY, Raw: "COPY a.txt b.txt /dest/", Srcs: []string{"a.txt", "b.txt"}, Dst: "/dest/"},
},
{
name: "COPY missing dst",
input: "COPY config.yaml",
wantErr: true,
},
{
name: "COPY empty",
input: "COPY",
wantErr: true,
},
// Unknown keyword
{
@ -148,7 +175,7 @@ func TestParseStep(t *testing.T) {
if err != nil {
t.Fatalf("ParseStep(%q) unexpected error: %v", tc.input, err)
}
if got != tc.want {
if !reflect.DeepEqual(got, tc.want) {
t.Errorf("ParseStep(%q)\n got %+v\n want %+v", tc.input, got, tc.want)
}
})

View File

@ -0,0 +1,30 @@
package sandbox
import (
"fmt"
"os/exec"
"strings"
)
// DetectFirecrackerVersion runs the firecracker binary with --version and
// parses the semver from the output (e.g. "Firecracker v1.14.1" → "1.14.1").
func DetectFirecrackerVersion(binaryPath string) (string, error) {
out, err := exec.Command(binaryPath, "--version").Output()
if err != nil {
return "", fmt.Errorf("run %s --version: %w", binaryPath, err)
}
// Output is typically "Firecracker v1.14.1\n" or similar.
line := strings.TrimSpace(string(out))
for _, field := range strings.Fields(line) {
v := strings.TrimPrefix(field, "v")
if v != field || strings.Contains(field, ".") {
// Either had a "v" prefix or contains a dot — likely the version.
if strings.Count(v, ".") >= 1 {
return v, nil
}
}
}
return "", fmt.Errorf("could not parse version from firecracker output: %q", line)
}

View File

@ -6,9 +6,11 @@ import (
"os"
"os/exec"
"path/filepath"
"strconv"
"strings"
"git.omukk.dev/wrenn/wrenn/internal/id"
"git.omukk.dev/wrenn/wrenn/internal/layout"
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
// DefaultDiskSizeMB is the standard disk size for base images. Images smaller
@ -66,6 +68,73 @@ func EnsureImageSizes(wrennDir string, targetMB int) error {
return nil
}
// ParseSizeToMB parses a human-readable size string into megabytes.
// Supported suffixes: G, Gi (gibibytes), M, Mi (mebibytes).
// Examples: "5G" → 5120, "2Gi" → 2048, "1000M" → 1000, "512Mi" → 512.
func ParseSizeToMB(s string) (int, error) {
s = strings.TrimSpace(s)
if s == "" {
return 0, fmt.Errorf("empty size string")
}
// Find where the numeric part ends.
i := 0
for i < len(s) && (s[i] == '.' || (s[i] >= '0' && s[i] <= '9')) {
i++
}
if i == 0 {
return 0, fmt.Errorf("invalid size %q: no numeric value", s)
}
numStr := s[:i]
suffix := strings.TrimSpace(s[i:])
num, err := strconv.ParseFloat(numStr, 64)
if err != nil {
return 0, fmt.Errorf("invalid size %q: %w", s, err)
}
switch suffix {
case "G", "Gi":
return int(num * 1024), nil
case "M", "Mi", "":
return int(num), nil
default:
return 0, fmt.Errorf("invalid size %q: unknown suffix %q (use G, Gi, M, or Mi)", s, suffix)
}
}
// ShrinkMinimalImage shrinks the built-in minimal rootfs back to its minimum
// size using resize2fs -M. This is the inverse of EnsureImageSizes and should
// be called during graceful shutdown so the image is stored compactly on disk.
func ShrinkMinimalImage(wrennDir string) {
minimalRootfs := layout.TemplateRootfs(wrennDir, id.PlatformTeamID, id.MinimalTemplateID)
shrinkImage(minimalRootfs)
}
// shrinkImage shrinks a single rootfs image to its minimum size.
func shrinkImage(rootfs string) {
if _, err := os.Stat(rootfs); err != nil {
return
}
slog.Info("shrinking base image", "path", rootfs)
if out, err := exec.Command("e2fsck", "-fy", rootfs).CombinedOutput(); err != nil {
if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() > 1 {
slog.Warn("e2fsck before shrink failed", "path", rootfs, "output", string(out), "error", err)
return
}
}
if out, err := exec.Command("resize2fs", "-M", rootfs).CombinedOutput(); err != nil {
slog.Warn("resize2fs -M failed", "path", rootfs, "output", string(out), "error", err)
return
}
slog.Info("base image shrunk", "path", rootfs)
}
// expandImage expands a single rootfs image if it is smaller than targetBytes.
func expandImage(rootfs string, targetBytes int64, targetMB int) error {
info, err := os.Stat(rootfs)

View File

@ -17,19 +17,28 @@ import (
"git.omukk.dev/wrenn/wrenn/internal/devicemapper"
"git.omukk.dev/wrenn/wrenn/internal/envdclient"
"git.omukk.dev/wrenn/wrenn/internal/id"
"git.omukk.dev/wrenn/wrenn/internal/layout"
"git.omukk.dev/wrenn/wrenn/internal/models"
"git.omukk.dev/wrenn/wrenn/internal/network"
"git.omukk.dev/wrenn/wrenn/internal/snapshot"
"git.omukk.dev/wrenn/wrenn/internal/uffd"
"git.omukk.dev/wrenn/wrenn/internal/vm"
"git.omukk.dev/wrenn/wrenn/pkg/id"
envdpb "git.omukk.dev/wrenn/wrenn/proto/envd/gen"
)
// Config holds the paths and defaults for the sandbox manager.
type Config struct {
WrennDir string // root directory (e.g. /var/lib/wrenn); all sub-paths derived via layout package
EnvdTimeout time.Duration
WrennDir string // root directory (e.g. /var/lib/wrenn); all sub-paths derived via layout package
EnvdTimeout time.Duration
DefaultRootfsSizeMB int // target size for template rootfs images; 0 → DefaultDiskSizeMB
// Resolved at startup by the host agent.
KernelPath string // path to the latest vmlinux-x.y.z
KernelVersion string // semver extracted from filename
FirecrackerBin string // path to the firecracker binary
FirecrackerVersion string // semver from firecracker --version
AgentVersion string // host agent version (injected via ldflags)
}
// Manager orchestrates sandbox lifecycle: VM, network, filesystem, envd.
@ -84,6 +93,35 @@ type snapshotParent struct {
// preventing the crash.
const maxDiffGenerations = 8
// buildMetadata constructs the metadata map with version information.
func (m *Manager) buildMetadata(envdVersion string) map[string]string {
meta := map[string]string{
"kernel_version": m.cfg.KernelVersion,
"firecracker_version": m.cfg.FirecrackerVersion,
"agent_version": m.cfg.AgentVersion,
}
if envdVersion != "" {
meta["envd_version"] = envdVersion
}
return meta
}
// resolveKernelPath returns the kernel path for the given version hint.
// If the exact version exists on disk, it is used. Otherwise, falls back to
// the latest kernel (m.cfg.KernelPath).
func (m *Manager) resolveKernelPath(versionHint string) string {
if versionHint == "" {
return m.cfg.KernelPath
}
exact := layout.KernelPathVersioned(m.cfg.WrennDir, versionHint)
if _, err := os.Stat(exact); err == nil {
return exact
}
slog.Warn("requested kernel version not found, using latest",
"requested", versionHint, "latest", m.cfg.KernelVersion)
return m.cfg.KernelPath
}
// New creates a new sandbox manager.
func New(cfg Config) *Manager {
if cfg.EnvdTimeout == 0 {
@ -173,7 +211,7 @@ func (m *Manager) Create(ctx context.Context, sandboxID string, teamID, template
vmCfg := vm.VMConfig{
SandboxID: sandboxID,
TemplateID: id.UUIDString(templateID),
KernelPath: layout.KernelPath(m.cfg.WrennDir),
KernelPath: m.cfg.KernelPath,
RootfsPath: dmDev.DevicePath,
VCPUs: vcpus,
MemoryMB: memoryMB,
@ -183,6 +221,7 @@ func (m *Manager) Create(ctx context.Context, sandboxID string, teamID, template
GuestIP: slot.GuestIP,
GatewayIP: slot.TapIP,
NetMask: slot.GuestNetMask,
FirecrackerBin: m.cfg.FirecrackerBin,
}
if _, err := m.vm.Create(ctx, vmCfg); err != nil {
@ -209,6 +248,9 @@ func (m *Manager) Create(ctx context.Context, sandboxID string, teamID, template
return nil, fmt.Errorf("wait for envd: %w", err)
}
// Fetch envd version (best-effort).
envdVersion, _ := client.FetchVersion(ctx)
now := time.Now()
sb := &sandboxState{
Sandbox: models.Sandbox{
@ -224,6 +266,7 @@ func (m *Manager) Create(ctx context.Context, sandboxID string, teamID, template
RootfsPath: dmDev.DevicePath,
CreatedAt: now,
LastActiveAt: now,
Metadata: m.buildMetadata(envdVersion),
},
slot: slot,
client: client,
@ -326,6 +369,20 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error {
sb.connTracker.Drain(2 * time.Second)
slog.Debug("pause: proxy connections drained", "id", sandboxID)
// Step 0b: Signal envd to quiesce continuous goroutines (port scanner,
// forwarder) and run GC before freezing vCPUs. This prevents Go runtime
// page allocator corruption ("bad summary data") on snapshot restore.
// Best-effort: a failure is logged but does not abort the pause.
func() {
prepCtx, prepCancel := context.WithTimeout(ctx, 3*time.Second)
defer prepCancel()
if err := sb.client.PrepareSnapshot(prepCtx); err != nil {
slog.Warn("pause: pre-snapshot quiesce failed (best-effort)", "id", sandboxID, "error", err)
} else {
slog.Debug("pause: envd goroutines quiesced", "id", sandboxID)
}
}()
pauseStart := time.Now()
// Step 1: Pause the VM (freeze vCPUs).
@ -542,7 +599,7 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error {
// Resume restores a paused sandbox from its snapshot using UFFD for
// lazy memory loading. The sandbox gets a new network slot.
func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int) (*models.Sandbox, error) {
func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int, kernelVersion string) (*models.Sandbox, error) {
pauseDir := layout.PauseSnapshotDir(m.cfg.WrennDir, sandboxID)
if _, err := os.Stat(pauseDir); err != nil {
return nil, fmt.Errorf("no snapshot found for sandbox %s", sandboxID)
@ -656,7 +713,7 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int)
// Restore VM from snapshot.
vmCfg := vm.VMConfig{
SandboxID: sandboxID,
KernelPath: layout.KernelPath(m.cfg.WrennDir),
KernelPath: m.resolveKernelPath(kernelVersion),
RootfsPath: dmDev.DevicePath,
VCPUs: 1, // Placeholder; overridden by snapshot.
MemoryMB: int(header.Metadata.Size / (1024 * 1024)), // Placeholder; overridden by snapshot.
@ -666,6 +723,7 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int)
GuestIP: slot.GuestIP,
GatewayIP: slot.TapIP,
NetMask: slot.GuestNetMask,
FirecrackerBin: m.cfg.FirecrackerBin,
}
resumeSnapPath := filepath.Join(pauseDir, snapshot.SnapFileName)
@ -697,6 +755,14 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int)
return nil, fmt.Errorf("wait for envd: %w", err)
}
// Trigger envd to re-read MMDS so it picks up the new sandbox/template IDs.
if err := client.PostInit(waitCtx); err != nil {
slog.Warn("post-init failed after resume, metadata files may be stale", "sandbox", sandboxID, "error", err)
}
// Fetch envd version (best-effort).
envdVersion, _ := client.FetchVersion(ctx)
now := time.Now()
sb := &sandboxState{
Sandbox: models.Sandbox{
@ -710,6 +776,7 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int)
RootfsPath: dmDev.DevicePath,
CreatedAt: now,
LastActiveAt: now,
Metadata: m.buildMetadata(envdVersion),
},
slot: slot,
client: client,
@ -880,6 +947,18 @@ func (m *Manager) FlattenRootfs(ctx context.Context, sandboxID string, teamID, t
return 0, fmt.Errorf("sandbox %s not found", sandboxID)
}
// Flush guest page cache to disk before stopping the VM. Without this,
// files written by the build (e.g. pip-installed packages) may exist in the
// guest's page cache but not yet on the dm block device — flatten would then
// capture 0-byte files.
func() {
syncCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
if _, err := sb.client.Exec(syncCtx, "/bin/sync"); err != nil {
slog.Warn("flatten: guest sync failed (non-fatal)", "id", sb.ID, "error", err)
}
}()
// Stop the VM but keep the dm device alive for flattening.
m.stopSampler(sb)
if err := m.vm.Destroy(ctx, sb.ID); err != nil {
@ -919,8 +998,8 @@ func (m *Manager) FlattenRootfs(ctx context.Context, sandboxID string, teamID, t
// Clean up dm device and loop device now that flatten is complete.
m.cleanupDM(sb)
// Shrink the flattened image to its minimum size so stored templates are
// compact. EnsureImageSizes will re-expand them on the next agent startup.
// Shrink the flattened image to its minimum size, then re-expand to the
// configured default rootfs size so sandboxes see the full disk from boot.
if out, err := exec.Command("e2fsck", "-fy", outputPath).CombinedOutput(); err != nil {
if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() > 1 {
slog.Warn("e2fsck before shrink failed (non-fatal)", "output", string(out), "error", err)
@ -930,6 +1009,15 @@ func (m *Manager) FlattenRootfs(ctx context.Context, sandboxID string, teamID, t
slog.Warn("resize2fs -M failed (non-fatal)", "output", string(out), "error", err)
}
// Re-expand to default rootfs size.
targetMB := m.cfg.DefaultRootfsSizeMB
if targetMB <= 0 {
targetMB = DefaultDiskSizeMB
}
if err := expandImage(outputPath, int64(targetMB)*1024*1024, targetMB); err != nil {
slog.Warn("failed to expand template to default size (non-fatal)", "error", err)
}
sizeBytes, err := snapshot.DirSize(flattenDstDir, "")
if err != nil {
slog.Warn("failed to calculate template size", "error", err)
@ -1057,7 +1145,7 @@ func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID string, team
vmCfg := vm.VMConfig{
SandboxID: sandboxID,
TemplateID: id.UUIDString(templateID),
KernelPath: layout.KernelPath(m.cfg.WrennDir),
KernelPath: m.cfg.KernelPath,
RootfsPath: dmDev.DevicePath,
VCPUs: vcpus,
MemoryMB: memoryMB,
@ -1067,6 +1155,7 @@ func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID string, team
GuestIP: slot.GuestIP,
GatewayIP: slot.TapIP,
NetMask: slot.GuestNetMask,
FirecrackerBin: m.cfg.FirecrackerBin,
}
snapPath := filepath.Join(tmplDir, snapshot.SnapFileName)
@ -1098,6 +1187,14 @@ func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID string, team
return nil, fmt.Errorf("wait for envd: %w", err)
}
// Trigger envd to re-read MMDS so it picks up the new sandbox/template IDs.
if err := client.PostInit(waitCtx); err != nil {
slog.Warn("post-init failed after template restore, metadata files may be stale", "sandbox", sandboxID, "error", err)
}
// Fetch envd version (best-effort).
envdVersion, _ := client.FetchVersion(ctx)
now := time.Now()
sb := &sandboxState{
Sandbox: models.Sandbox{
@ -1113,6 +1210,7 @@ func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID string, team
RootfsPath: dmDev.DevicePath,
CreatedAt: now,
LastActiveAt: now,
Metadata: m.buildMetadata(envdVersion),
},
slot: slot,
client: client,
@ -1213,6 +1311,155 @@ func (m *Manager) GetClient(sandboxID string) (*envdclient.Client, error) {
return sb.client, nil
}
// SetDefaults calls envd's PostInit to configure the default user and
// environment variables for a running sandbox. This is called by the host
// agent after sandbox creation or resume when the template specifies defaults.
func (m *Manager) SetDefaults(ctx context.Context, sandboxID, defaultUser string, defaultEnv map[string]string) error {
if defaultUser == "" && len(defaultEnv) == 0 {
return nil
}
sb, err := m.get(sandboxID)
if err != nil {
return err
}
if sb.Status != models.StatusRunning {
return fmt.Errorf("sandbox %s is not running (status: %s)", sandboxID, sb.Status)
}
return sb.client.PostInitWithDefaults(ctx, defaultUser, defaultEnv)
}
// PtyAttach starts a new PTY process or reconnects to an existing one.
// If cmd is non-empty, starts a new process. If empty, reconnects using tag.
func (m *Manager) PtyAttach(ctx context.Context, sandboxID, tag, cmd string, args []string, cols, rows uint32, envs map[string]string, cwd string) (<-chan envdclient.PtyEvent, error) {
sb, err := m.get(sandboxID)
if err != nil {
return nil, err
}
if sb.Status != models.StatusRunning {
return nil, fmt.Errorf("sandbox %s is not running (status: %s)", sandboxID, sb.Status)
}
m.mu.Lock()
sb.LastActiveAt = time.Now()
m.mu.Unlock()
if cmd != "" {
return sb.client.PtyStart(ctx, tag, cmd, args, cols, rows, envs, cwd)
}
return sb.client.PtyConnect(ctx, tag)
}
// PtySendInput sends raw bytes to a PTY process in a sandbox.
func (m *Manager) PtySendInput(ctx context.Context, sandboxID, tag string, data []byte) error {
sb, err := m.get(sandboxID)
if err != nil {
return err
}
if sb.Status != models.StatusRunning {
return fmt.Errorf("sandbox %s is not running (status: %s)", sandboxID, sb.Status)
}
m.mu.Lock()
sb.LastActiveAt = time.Now()
m.mu.Unlock()
return sb.client.PtySendInput(ctx, tag, data)
}
// PtyResize updates the terminal dimensions for a PTY process in a sandbox.
func (m *Manager) PtyResize(ctx context.Context, sandboxID, tag string, cols, rows uint32) error {
sb, err := m.get(sandboxID)
if err != nil {
return err
}
if sb.Status != models.StatusRunning {
return fmt.Errorf("sandbox %s is not running (status: %s)", sandboxID, sb.Status)
}
return sb.client.PtyResize(ctx, tag, cols, rows)
}
// PtyKill sends SIGKILL to a PTY process in a sandbox.
func (m *Manager) PtyKill(ctx context.Context, sandboxID, tag string) error {
sb, err := m.get(sandboxID)
if err != nil {
return err
}
if sb.Status != models.StatusRunning {
return fmt.Errorf("sandbox %s is not running (status: %s)", sandboxID, sb.Status)
}
return sb.client.PtyKill(ctx, tag)
}
// StartBackground starts a background process inside a sandbox.
func (m *Manager) StartBackground(ctx context.Context, sandboxID, tag, cmd string, args []string, envs map[string]string, cwd string) (uint32, error) {
sb, err := m.get(sandboxID)
if err != nil {
return 0, err
}
if sb.Status != models.StatusRunning {
return 0, fmt.Errorf("sandbox %s is not running (status: %s)", sandboxID, sb.Status)
}
m.mu.Lock()
sb.LastActiveAt = time.Now()
m.mu.Unlock()
return sb.client.StartBackground(ctx, tag, cmd, args, envs, cwd)
}
// ConnectProcess re-attaches to a running process inside a sandbox.
func (m *Manager) ConnectProcess(ctx context.Context, sandboxID string, pid uint32, tag string) (<-chan envdclient.ExecStreamEvent, error) {
sb, err := m.get(sandboxID)
if err != nil {
return nil, err
}
if sb.Status != models.StatusRunning {
return nil, fmt.Errorf("sandbox %s is not running (status: %s)", sandboxID, sb.Status)
}
m.mu.Lock()
sb.LastActiveAt = time.Now()
m.mu.Unlock()
return sb.client.ConnectProcess(ctx, pid, tag)
}
// ListProcesses returns all running processes inside a sandbox.
func (m *Manager) ListProcesses(ctx context.Context, sandboxID string) ([]envdclient.ProcessInfo, error) {
sb, err := m.get(sandboxID)
if err != nil {
return nil, err
}
if sb.Status != models.StatusRunning {
return nil, fmt.Errorf("sandbox %s is not running (status: %s)", sandboxID, sb.Status)
}
m.mu.Lock()
sb.LastActiveAt = time.Now()
m.mu.Unlock()
return sb.client.ListProcesses(ctx)
}
// KillProcess sends a signal to a process inside a sandbox.
func (m *Manager) KillProcess(ctx context.Context, sandboxID string, pid uint32, tag string, signal envdpb.Signal) error {
sb, err := m.get(sandboxID)
if err != nil {
return err
}
if sb.Status != models.StatusRunning {
return fmt.Errorf("sandbox %s is not running (status: %s)", sandboxID, sb.Status)
}
m.mu.Lock()
sb.LastActiveAt = time.Now()
m.mu.Unlock()
return sb.client.KillProcess(ctx, pid, tag, signal)
}
// AcquireProxyConn atomically looks up a sandbox by ID and registers an
// in-flight proxy connection. Returns the sandbox's host-reachable IP, the
// connection tracker, and true on success. The caller must call

View File

@ -1 +0,0 @@
package scheduler

View File

@ -1,71 +0,0 @@
package scheduler
import (
"context"
"fmt"
"sync/atomic"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/internal/db"
)
// HostScheduler selects a host for a new sandbox. Implementations may use
// different strategies (round-robin, least-loaded, tag-based, etc.).
type HostScheduler interface {
// SelectHost returns a host that can accept a new sandbox.
// For BYOC teams (isByoc=true), only online BYOC hosts belonging to teamID
// are considered. For non-BYOC teams, only online regular (platform) hosts
// are considered. Returns an error if no suitable host is available.
SelectHost(ctx context.Context, teamID pgtype.UUID, isByoc bool) (db.Host, error)
}
// RoundRobinScheduler cycles through eligible online hosts in round-robin order.
// It re-fetches the host list on every call so that newly registered or
// recovered hosts are considered immediately.
type RoundRobinScheduler struct {
db *db.Queries
counter atomic.Int64
}
// NewRoundRobinScheduler creates a RoundRobinScheduler backed by the given DB.
func NewRoundRobinScheduler(queries *db.Queries) *RoundRobinScheduler {
return &RoundRobinScheduler{db: queries}
}
// SelectHost returns the next eligible online host in round-robin order.
func (s *RoundRobinScheduler) SelectHost(ctx context.Context, teamID pgtype.UUID, isByoc bool) (db.Host, error) {
hosts, err := s.db.ListActiveHosts(ctx)
if err != nil {
return db.Host{}, fmt.Errorf("list hosts: %w", err)
}
var eligible []db.Host
for _, h := range hosts {
if h.Status != "online" || h.Address == "" {
continue
}
if isByoc {
// BYOC team: only use hosts belonging to this team.
if h.Type != "byoc" || !h.TeamID.Valid || h.TeamID != teamID {
continue
}
} else {
// Non-BYOC team: only use platform (regular) hosts.
if h.Type != "regular" {
continue
}
}
eligible = append(eligible, h)
}
if len(eligible) == 0 {
if isByoc {
return db.Host{}, fmt.Errorf("no online BYOC hosts available for team")
}
return db.Host{}, fmt.Errorf("no online platform hosts available")
}
idx := s.counter.Add(1) - 1
return eligible[int(idx%int64(len(eligible)))], nil
}

View File

@ -1 +0,0 @@
package scheduler

View File

@ -1 +0,0 @@
package scheduler

View File

@ -1,65 +0,0 @@
package service
import (
"context"
"fmt"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/internal/auth"
"git.omukk.dev/wrenn/wrenn/internal/db"
"git.omukk.dev/wrenn/wrenn/internal/id"
)
// APIKeyService provides API key operations shared between the REST API and the dashboard.
type APIKeyService struct {
DB *db.Queries
}
// APIKeyCreateResult holds the result of creating an API key, including the
// plaintext key which is only available at creation time.
type APIKeyCreateResult struct {
Row db.TeamApiKey
Plaintext string
}
// Create generates a new API key for the given team.
func (s *APIKeyService) Create(ctx context.Context, teamID, userID pgtype.UUID, name string) (APIKeyCreateResult, error) {
if name == "" {
name = "Unnamed API Key"
}
plaintext, hash, err := auth.GenerateAPIKey()
if err != nil {
return APIKeyCreateResult{}, fmt.Errorf("generate key: %w", err)
}
row, err := s.DB.InsertAPIKey(ctx, db.InsertAPIKeyParams{
ID: id.NewAPIKeyID(),
TeamID: teamID,
Name: name,
KeyHash: hash,
KeyPrefix: auth.APIKeyPrefix(plaintext),
CreatedBy: userID,
})
if err != nil {
return APIKeyCreateResult{}, fmt.Errorf("insert key: %w", err)
}
return APIKeyCreateResult{Row: row, Plaintext: plaintext}, nil
}
// List returns all API keys belonging to the given team.
func (s *APIKeyService) List(ctx context.Context, teamID pgtype.UUID) ([]db.TeamApiKey, error) {
return s.DB.ListAPIKeysByTeam(ctx, teamID)
}
// ListWithCreator returns all API keys for the team, joined with the creator's email.
func (s *APIKeyService) ListWithCreator(ctx context.Context, teamID pgtype.UUID) ([]db.ListAPIKeysByTeamWithCreatorRow, error) {
return s.DB.ListAPIKeysByTeamWithCreator(ctx, teamID)
}
// Delete removes an API key by ID, scoped to the given team.
func (s *APIKeyService) Delete(ctx context.Context, keyID, teamID pgtype.UUID) error {
return s.DB.DeleteAPIKey(ctx, db.DeleteAPIKeyParams{ID: keyID, TeamID: teamID})
}

View File

@ -1,113 +0,0 @@
package service
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/internal/db"
"git.omukk.dev/wrenn/wrenn/internal/id"
)
const auditMaxLimit = 200
// AuditEntry is a single audit log record returned by List.
type AuditEntry struct {
ID string
TeamID string
ActorType string
ActorID string // empty for system
ActorName string // empty for system
ResourceType string
ResourceID string // empty when not applicable
Action string
Scope string
Status string // 'success', 'info', 'warning', 'error'
Metadata map[string]any
CreatedAt time.Time
}
// AuditListParams controls the ListAuditLogs query.
type AuditListParams struct {
TeamID pgtype.UUID
AdminScoped bool // true → include admin-scoped events; false → team-scoped only
ResourceTypes []string // empty = no filter; multiple values = OR match
Actions []string // empty = no filter; multiple values = OR match
Before time.Time // zero = no cursor (start from latest)
BeforeID pgtype.UUID // tie-breaker: id of the last item at the Before timestamp; zero = no tie-break
Limit int // clamped to auditMaxLimit by the handler
}
// AuditService provides the read side of the audit log.
type AuditService struct {
DB *db.Queries
}
// List returns a page of audit log entries for the given team.
func (s *AuditService) List(ctx context.Context, p AuditListParams) ([]AuditEntry, error) {
limit := p.Limit
if limit <= 0 {
limit = 50
}
if limit > auditMaxLimit {
limit = auditMaxLimit
}
scopes := []string{"team"}
if p.AdminScoped {
scopes = append(scopes, "admin")
}
var before pgtype.Timestamptz
if !p.Before.IsZero() {
before = pgtype.Timestamptz{Time: p.Before, Valid: true}
}
resourceTypes := p.ResourceTypes
if resourceTypes == nil {
resourceTypes = []string{}
}
actions := p.Actions
if actions == nil {
actions = []string{}
}
rows, err := s.DB.ListAuditLogs(ctx, db.ListAuditLogsParams{
TeamID: p.TeamID,
Column2: scopes,
Column3: resourceTypes,
Column4: actions,
Column5: before,
ID: p.BeforeID,
Limit: int32(limit),
})
if err != nil {
return nil, fmt.Errorf("list audit logs: %w", err)
}
entries := make([]AuditEntry, len(rows))
for i, row := range rows {
var meta map[string]any
if len(row.Metadata) > 0 {
_ = json.Unmarshal(row.Metadata, &meta)
}
entries[i] = AuditEntry{
ID: id.FormatAuditLogID(row.ID),
TeamID: id.FormatTeamID(row.TeamID),
ActorType: row.ActorType,
ActorID: row.ActorID.String,
ActorName: row.ActorName,
ResourceType: row.ResourceType,
ResourceID: row.ResourceID.String,
Action: row.Action,
Scope: row.Scope,
Status: row.Status,
Metadata: meta,
CreatedAt: row.CreatedAt.Time,
}
}
return entries, nil
}

View File

@ -1,605 +0,0 @@
package service
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"strings"
"sync"
"time"
"connectrpc.com/connect"
"github.com/jackc/pgx/v5/pgtype"
"github.com/redis/go-redis/v9"
"git.omukk.dev/wrenn/wrenn/internal/db"
"git.omukk.dev/wrenn/wrenn/internal/id"
"git.omukk.dev/wrenn/wrenn/internal/lifecycle"
"git.omukk.dev/wrenn/wrenn/internal/recipe"
"git.omukk.dev/wrenn/wrenn/internal/scheduler"
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
)
const (
buildQueueKey = "wrenn:build_queue"
buildCommandTimeout = 30 * time.Second
)
// preBuildCmds run before the user recipe to prepare the build environment.
var preBuildCmds = []string{
"RUN apt update",
}
// postBuildCmds run after the user recipe to clean up caches and reduce image size.
var postBuildCmds = []string{
"RUN apt clean",
"RUN apt autoremove -y",
"RUN rm -rf /var/lib/apt/lists/*",
}
// buildAgentClient is the subset of the host agent client used by the build worker.
type buildAgentClient interface {
CreateSandbox(ctx context.Context, req *connect.Request[pb.CreateSandboxRequest]) (*connect.Response[pb.CreateSandboxResponse], error)
DestroySandbox(ctx context.Context, req *connect.Request[pb.DestroySandboxRequest]) (*connect.Response[pb.DestroySandboxResponse], error)
Exec(ctx context.Context, req *connect.Request[pb.ExecRequest]) (*connect.Response[pb.ExecResponse], error)
CreateSnapshot(ctx context.Context, req *connect.Request[pb.CreateSnapshotRequest]) (*connect.Response[pb.CreateSnapshotResponse], error)
FlattenRootfs(ctx context.Context, req *connect.Request[pb.FlattenRootfsRequest]) (*connect.Response[pb.FlattenRootfsResponse], error)
}
// BuildService handles template build orchestration.
type BuildService struct {
DB *db.Queries
Redis *redis.Client
Pool *lifecycle.HostClientPool
Scheduler scheduler.HostScheduler
mu sync.Mutex
cancelMap map[string]context.CancelFunc // buildID → per-build cancel func
}
// BuildCreateParams holds the parameters for creating a template build.
type BuildCreateParams struct {
Name string
BaseTemplate string
Recipe []string
Healthcheck string
VCPUs int32
MemoryMB int32
SkipPrePost bool
}
// Create inserts a new build record and enqueues it to Redis.
func (s *BuildService) Create(ctx context.Context, p BuildCreateParams) (db.TemplateBuild, error) {
if p.BaseTemplate == "" {
p.BaseTemplate = "minimal"
}
if p.VCPUs <= 0 {
p.VCPUs = 1
}
if p.MemoryMB <= 0 {
p.MemoryMB = 512
}
recipeJSON, err := json.Marshal(p.Recipe)
if err != nil {
return db.TemplateBuild{}, fmt.Errorf("marshal recipe: %w", err)
}
buildID := id.NewBuildID()
buildIDStr := id.FormatBuildID(buildID)
newTemplateID := id.NewTemplateID()
defaultSteps := len(preBuildCmds) + len(postBuildCmds)
if p.SkipPrePost {
defaultSteps = 0
}
build, err := s.DB.InsertTemplateBuild(ctx, db.InsertTemplateBuildParams{
ID: buildID,
Name: p.Name,
BaseTemplate: p.BaseTemplate,
Recipe: recipeJSON,
Healthcheck: p.Healthcheck,
Vcpus: p.VCPUs,
MemoryMb: p.MemoryMB,
TotalSteps: int32(len(p.Recipe) + defaultSteps),
TemplateID: newTemplateID,
TeamID: id.PlatformTeamID,
SkipPrePost: p.SkipPrePost,
})
if err != nil {
return db.TemplateBuild{}, fmt.Errorf("insert build: %w", err)
}
// Enqueue build ID (as formatted string) to Redis for workers to pick up.
if err := s.Redis.RPush(ctx, buildQueueKey, buildIDStr).Err(); err != nil {
return db.TemplateBuild{}, fmt.Errorf("enqueue build: %w", err)
}
return build, nil
}
// Get returns a single build by ID.
func (s *BuildService) Get(ctx context.Context, buildID pgtype.UUID) (db.TemplateBuild, error) {
return s.DB.GetTemplateBuild(ctx, buildID)
}
// List returns all builds ordered by creation time.
func (s *BuildService) List(ctx context.Context) ([]db.TemplateBuild, error) {
return s.DB.ListTemplateBuilds(ctx)
}
// Cancel cancels a pending or running build. For pending builds the status is
// updated in the DB and the worker skips it when dequeued. For running builds
// the per-build context is cancelled, which causes the current exec step to
// abort; executeBuild then detects the cancellation and records the status.
func (s *BuildService) Cancel(ctx context.Context, buildID pgtype.UUID) error {
build, err := s.DB.GetTemplateBuild(ctx, buildID)
if err != nil {
return fmt.Errorf("get build: %w", err)
}
switch build.Status {
case "success", "failed", "cancelled":
return fmt.Errorf("build is already %s", build.Status)
}
// Mark cancelled in DB first. This handles both pending builds (which haven't
// been picked up yet) and acts as a flag for executeBuild to check on start.
if _, err := s.DB.UpdateBuildStatus(ctx, db.UpdateBuildStatusParams{
ID: buildID, Status: "cancelled",
}); err != nil {
return fmt.Errorf("update build status: %w", err)
}
// If the build is currently running, signal its context.
buildIDStr := id.FormatBuildID(buildID)
s.mu.Lock()
cancel, running := s.cancelMap[buildIDStr]
s.mu.Unlock()
if running {
cancel()
}
return nil
}
// StartWorkers launches n goroutines that consume from the Redis build queue.
// The returned cancel function stops all workers.
func (s *BuildService) StartWorkers(ctx context.Context, n int) context.CancelFunc {
ctx, cancel := context.WithCancel(ctx)
for i := range n {
go s.worker(ctx, i)
}
slog.Info("build workers started", "count", n)
return cancel
}
func (s *BuildService) worker(ctx context.Context, workerID int) {
log := slog.With("worker", workerID)
for {
// BLPOP blocks until a build ID is available or context is cancelled.
result, err := s.Redis.BLPop(ctx, 0, buildQueueKey).Result()
if err != nil {
if ctx.Err() != nil {
log.Info("build worker shutting down")
return
}
log.Error("redis BLPOP error", "error", err)
time.Sleep(time.Second)
continue
}
// result[0] is the key, result[1] is the build ID (formatted string).
buildIDStr := result[1]
log.Info("picked up build", "build_id", buildIDStr)
s.executeBuild(ctx, buildIDStr)
}
}
func (s *BuildService) executeBuild(ctx context.Context, buildIDStr string) {
log := slog.With("build_id", buildIDStr)
buildID, err := id.ParseBuildID(buildIDStr)
if err != nil {
log.Error("invalid build ID from queue", "error", err)
return
}
// Create a per-build context so this build can be cancelled independently of
// the worker. Register in cancelMap before fetching the build so that a
// concurrent Cancel call can always find and signal it.
buildCtx, buildCancel := context.WithCancel(ctx)
defer buildCancel()
s.mu.Lock()
if s.cancelMap == nil {
s.cancelMap = make(map[string]context.CancelFunc)
}
s.cancelMap[buildIDStr] = buildCancel
s.mu.Unlock()
defer func() {
s.mu.Lock()
delete(s.cancelMap, buildIDStr)
s.mu.Unlock()
}()
build, err := s.DB.GetTemplateBuild(buildCtx, buildID)
if err != nil {
log.Error("failed to fetch build", "error", err)
return
}
// Skip if already cancelled (Cancel was called before we dequeued).
if build.Status == "cancelled" {
log.Info("build already cancelled, skipping")
return
}
// Mark as running.
if _, err := s.DB.UpdateBuildStatus(buildCtx, db.UpdateBuildStatusParams{
ID: buildID, Status: "running",
}); err != nil {
log.Error("failed to update build status", "error", err)
return
}
// Parse user recipe.
var userRecipe []string
if err := json.Unmarshal(build.Recipe, &userRecipe); err != nil {
s.failBuild(buildCtx, buildID, fmt.Sprintf("invalid recipe JSON: %v", err))
return
}
// Pick a platform host and create a sandbox.
host, err := s.Scheduler.SelectHost(buildCtx, id.PlatformTeamID, false)
if err != nil {
s.failBuild(buildCtx, buildID, fmt.Sprintf("no host available: %v", err))
return
}
agent, err := s.Pool.GetForHost(host)
if err != nil {
s.failBuild(buildCtx, buildID, fmt.Sprintf("agent client error: %v", err))
return
}
sandboxID := id.NewSandboxID()
sandboxIDStr := id.FormatSandboxID(sandboxID)
log = log.With("sandbox_id", sandboxIDStr, "host_id", id.FormatHostID(host.ID))
// Resolve the base template to UUIDs. "minimal" is the zero sentinel.
baseTeamID := id.PlatformTeamID
baseTemplateID := id.MinimalTemplateID
if build.BaseTemplate != "minimal" {
baseTmpl, err := s.DB.GetPlatformTemplateByName(buildCtx, build.BaseTemplate)
if err != nil {
s.failBuild(buildCtx, buildID, fmt.Sprintf("base template %q not found: %v", build.BaseTemplate, err))
return
}
baseTeamID = baseTmpl.TeamID
baseTemplateID = baseTmpl.ID
}
resp, err := agent.CreateSandbox(buildCtx, connect.NewRequest(&pb.CreateSandboxRequest{
SandboxId: sandboxIDStr,
Template: build.BaseTemplate,
TeamId: id.UUIDString(baseTeamID),
TemplateId: id.UUIDString(baseTemplateID),
Vcpus: build.Vcpus,
MemoryMb: build.MemoryMb,
TimeoutSec: 0, // no auto-pause for builds
DiskSizeMb: 5120, // 5 GB for template builds
}))
if err != nil {
s.failBuild(buildCtx, buildID, fmt.Sprintf("create sandbox failed: %v", err))
return
}
_ = resp
// Record sandbox/host association.
_ = s.DB.UpdateBuildSandbox(buildCtx, db.UpdateBuildSandboxParams{
ID: buildID,
SandboxID: sandboxID,
HostID: host.ID,
})
// Parse recipe steps. preBuildCmds and postBuildCmds are hardcoded and always
// valid; panic on error is appropriate here since it would be a programmer mistake.
preBuildSteps, err := recipe.ParseRecipe(preBuildCmds)
if err != nil {
panic(fmt.Sprintf("invalid pre-build recipe: %v", err))
}
userRecipeSteps, err := recipe.ParseRecipe(userRecipe)
if err != nil {
s.destroySandbox(buildCtx, agent, sandboxIDStr)
s.failBuild(buildCtx, buildID, fmt.Sprintf("recipe parse error: %v", err))
return
}
postBuildSteps, err := recipe.ParseRecipe(postBuildCmds)
if err != nil {
panic(fmt.Sprintf("invalid post-build recipe: %v", err))
}
var logs []recipe.BuildLogEntry
step := 0
envVars, err := s.fetchSandboxEnv(buildCtx, agent, sandboxIDStr)
if err != nil {
log.Warn("failed to fetch sandbox env, using defaults", "error", err)
envVars = map[string]string{
"PATH": "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin",
"HOME": "/root",
}
}
bctx := &recipe.ExecContext{EnvVars: envVars}
runPhase := func(phase string, steps []recipe.Step, defaultTimeout time.Duration) bool {
newEntries, nextStep, ok := recipe.Execute(buildCtx, phase, steps, sandboxIDStr, step, defaultTimeout, bctx, agent.Exec)
logs = append(logs, newEntries...)
step = nextStep
s.updateLogs(buildCtx, buildID, step, logs)
if !ok {
s.destroySandbox(buildCtx, agent, sandboxIDStr)
// If the build was cancelled, status is already set — don't overwrite with "failed".
if buildCtx.Err() != nil {
return false
}
last := newEntries[len(newEntries)-1]
reason := last.Stderr
if reason == "" {
reason = fmt.Sprintf("exit code %d", last.Exit)
}
s.failBuild(buildCtx, buildID, fmt.Sprintf("%s step %d failed: %s", phase, step, reason))
}
return ok
}
if !build.SkipPrePost {
if !runPhase("pre-build", preBuildSteps, 0) {
return
}
}
if !runPhase("recipe", userRecipeSteps, buildCommandTimeout) {
return
}
if !build.SkipPrePost {
if !runPhase("post-build", postBuildSteps, 0) {
return
}
}
// Healthcheck or direct snapshot.
var sizeBytes int64
if build.Healthcheck != "" {
hc, err := recipe.ParseHealthcheck(build.Healthcheck)
if err != nil {
s.destroySandbox(buildCtx, agent, sandboxIDStr)
s.failBuild(buildCtx, buildID, fmt.Sprintf("invalid healthcheck: %v", err))
return
}
log.Info("running healthcheck", "cmd", hc.Cmd, "interval", hc.Interval, "timeout", hc.Timeout, "start_period", hc.StartPeriod, "retries", hc.Retries)
if err := s.waitForHealthcheck(buildCtx, agent, sandboxIDStr, hc); err != nil {
s.destroySandbox(buildCtx, agent, sandboxIDStr)
if buildCtx.Err() != nil {
return
}
s.failBuild(buildCtx, buildID, fmt.Sprintf("healthcheck failed: %v", err))
return
}
// Healthcheck passed → full snapshot (with memory/CPU state).
log.Info("healthcheck passed, creating snapshot")
snapResp, err := agent.CreateSnapshot(buildCtx, connect.NewRequest(&pb.CreateSnapshotRequest{
SandboxId: sandboxIDStr,
Name: build.Name,
TeamId: id.UUIDString(build.TeamID),
TemplateId: id.UUIDString(build.TemplateID),
}))
if err != nil {
s.destroySandbox(buildCtx, agent, sandboxIDStr)
if buildCtx.Err() != nil {
return
}
s.failBuild(buildCtx, buildID, fmt.Sprintf("create snapshot failed: %v", err))
return
}
sizeBytes = snapResp.Msg.SizeBytes
} else {
// No healthcheck → image-only template (rootfs only).
log.Info("no healthcheck, flattening rootfs")
flatResp, err := agent.FlattenRootfs(buildCtx, connect.NewRequest(&pb.FlattenRootfsRequest{
SandboxId: sandboxIDStr,
Name: build.Name,
TeamId: id.UUIDString(build.TeamID),
TemplateId: id.UUIDString(build.TemplateID),
}))
if err != nil {
s.destroySandbox(buildCtx, agent, sandboxIDStr)
if buildCtx.Err() != nil {
return
}
s.failBuild(buildCtx, buildID, fmt.Sprintf("flatten rootfs failed: %v", err))
return
}
sizeBytes = flatResp.Msg.SizeBytes
}
// Insert into templates table as a global (platform) template.
templateType := "base"
if build.Healthcheck != "" {
templateType = "snapshot"
}
if _, err := s.DB.InsertTemplate(buildCtx, db.InsertTemplateParams{
ID: build.TemplateID,
Name: build.Name,
Type: templateType,
Vcpus: build.Vcpus,
MemoryMb: build.MemoryMb,
SizeBytes: sizeBytes,
TeamID: id.PlatformTeamID,
}); err != nil {
log.Error("failed to insert template record", "error", err)
// Build succeeded on disk, just DB record failed — don't mark as failed.
}
// For CreateSnapshot, the sandbox is already destroyed by the snapshot process.
// For FlattenRootfs, the sandbox is already destroyed by the flatten process.
// No additional destroy needed.
// Mark build as success.
if _, err := s.DB.UpdateBuildStatus(buildCtx, db.UpdateBuildStatusParams{
ID: buildID, Status: "success",
}); err != nil {
log.Error("failed to mark build as success", "error", err)
}
log.Info("template build completed successfully", "name", build.Name)
}
// waitForHealthcheck repeatedly executes the healthcheck command inside the
// sandbox according to the config's interval, timeout, start-period, and
// retries.
// During the start period, failures are not counted toward the retry budget.
// Returns nil on the first successful check, or an error if retries are
// exhausted, the deadline passes, or the context is cancelled.
func (s *BuildService) waitForHealthcheck(ctx context.Context, agent buildAgentClient, sandboxIDStr string, hc recipe.HealthcheckConfig) error {
ticker := time.NewTicker(hc.Interval)
defer ticker.Stop()
// When retries > 0, set a deadline based on the retry budget.
// When retries == 0 (unlimited), rely solely on the parent context deadline.
var deadlineCh <-chan time.Time
if hc.Retries > 0 {
deadline := time.NewTimer(hc.StartPeriod + time.Duration(hc.Retries+1)*hc.Interval)
defer deadline.Stop()
deadlineCh = deadline.C
}
startedAt := time.Now()
failCount := 0
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-deadlineCh:
return fmt.Errorf("healthcheck timed out: exceeded %d attempts over %s", failCount, time.Since(startedAt))
case <-ticker.C:
execCtx, cancel := context.WithTimeout(ctx, hc.Timeout)
resp, err := agent.Exec(execCtx, connect.NewRequest(&pb.ExecRequest{
SandboxId: sandboxIDStr,
Cmd: "/bin/sh",
Args: []string{"-c", hc.Cmd},
TimeoutSec: int32(hc.Timeout.Seconds()),
}))
cancel()
if err != nil {
slog.Debug("healthcheck exec error (retrying)", "error", err)
if time.Since(startedAt) >= hc.StartPeriod {
failCount++
if hc.Retries > 0 && failCount >= hc.Retries {
return fmt.Errorf("healthcheck failed after %d retries: exec error: %w", failCount, err)
}
}
continue
}
if resp.Msg.ExitCode == 0 {
return nil
}
slog.Debug("healthcheck failed (retrying)", "exit_code", resp.Msg.ExitCode)
if time.Since(startedAt) >= hc.StartPeriod {
failCount++
if hc.Retries > 0 && failCount >= hc.Retries {
return fmt.Errorf("healthcheck failed after %d retries: exit code %d", failCount, resp.Msg.ExitCode)
}
}
}
}
}
func (s *BuildService) updateLogs(ctx context.Context, buildID pgtype.UUID, step int, logs []recipe.BuildLogEntry) {
logsJSON, err := json.Marshal(logs)
if err != nil {
slog.Warn("failed to marshal build logs", "error", err)
return
}
if err := s.DB.UpdateBuildProgress(ctx, db.UpdateBuildProgressParams{
ID: buildID,
CurrentStep: int32(step),
Logs: logsJSON,
}); err != nil {
slog.Warn("failed to update build progress", "error", err)
}
}
func (s *BuildService) failBuild(_ context.Context, buildID pgtype.UUID, errMsg string) {
slog.Error("build failed", "build_id", id.FormatBuildID(buildID), "error", errMsg)
// Use a detached context so DB writes survive parent context cancellation (e.g. shutdown).
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := s.DB.UpdateBuildError(ctx, db.UpdateBuildErrorParams{
ID: buildID,
Error: errMsg,
}); err != nil {
slog.Error("failed to update build error", "build_id", id.FormatBuildID(buildID), "error", err)
}
}
func (s *BuildService) destroySandbox(_ context.Context, agent buildAgentClient, sandboxIDStr string) {
// Use a detached context so cleanup succeeds even during shutdown.
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if _, err := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
SandboxId: sandboxIDStr,
})); err != nil {
slog.Warn("failed to destroy build sandbox", "sandbox_id", sandboxIDStr, "error", err)
}
}
// fetchSandboxEnv executes the 'env' command inside the specified sandbox via
// the build agent and returns environment variables
func (s *BuildService) fetchSandboxEnv(ctx context.Context,
agent buildAgentClient, sandboxIDStr string) (map[string]string, error) {
resp, err := agent.Exec(ctx, connect.NewRequest(&pb.ExecRequest{
SandboxId: sandboxIDStr,
Cmd: "/bin/sh",
Args: []string{"-c", "env"},
TimeoutSec: 10,
}))
if err != nil {
return nil, fmt.Errorf("fetch env: %w", err)
}
if resp.Msg.ExitCode != 0 {
return nil, fmt.Errorf("fetch env: command exited with code %d",
resp.Msg.ExitCode)
}
return parseSandboxEnv(string(resp.Msg.Stdout)), nil
}
// parseSandboxEnv converts the raw newline-separated output of an 'env'
// command into a map.
// It skips empty lines and malformed entries, and correctly handles values
// containing '='.
func parseSandboxEnv(raw string) map[string]string {
envVars := make(map[string]string)
for line := range strings.SplitSeq(raw, "\n") {
line = strings.TrimSpace(line)
if line == "" {
continue
}
parts := strings.SplitN(line, "=", 2)
if len(parts) != 2 {
continue
}
envVars[parts[0]] = parts[1]
}
return envVars
}

View File

@ -1,628 +0,0 @@
package service
import (
"context"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"log/slog"
"time"
"connectrpc.com/connect"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/redis/go-redis/v9"
"git.omukk.dev/wrenn/wrenn/internal/auth"
"git.omukk.dev/wrenn/wrenn/internal/db"
"git.omukk.dev/wrenn/wrenn/internal/id"
"git.omukk.dev/wrenn/wrenn/internal/lifecycle"
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
)
// HostService provides host management operations.
type HostService struct {
DB *db.Queries
Redis *redis.Client
JWT []byte
Pool *lifecycle.HostClientPool
CA *auth.CA // nil disables mTLS cert issuance (dev/test environments)
}
// HostCreateParams holds the parameters for creating a host.
type HostCreateParams struct {
Type string
TeamID pgtype.UUID // required for BYOC, zero value for regular
Provider string
AvailabilityZone string
RequestingUserID pgtype.UUID
IsRequestorAdmin bool
}
// HostCreateResult holds the created host and the one-time registration token.
type HostCreateResult struct {
Host db.Host
RegistrationToken string
}
// HostRegisterParams holds the parameters for host agent registration.
type HostRegisterParams struct {
Token string
Arch string
CPUCores int32
MemoryMB int32
DiskGB int32
Address string
}
// HostRegisterResult holds the registered host, its short-lived JWT, a long-lived
// refresh token, and optionally the host's mTLS certificate material.
type HostRegisterResult struct {
Host db.Host
JWT string
RefreshToken string
// mTLS cert material — empty when CA is not configured.
CertPEM string
KeyPEM string
CACertPEM string
}
// HostRefreshResult holds a new JWT and rotated refresh token after a successful
// refresh, plus refreshed mTLS certificate material when CA is configured.
type HostRefreshResult struct {
Host db.Host
JWT string
RefreshToken string
// mTLS cert material — empty when CA is not configured.
CertPEM string
KeyPEM string
CACertPEM string
}
// HostDeletePreview describes what will be affected by deleting a host.
type HostDeletePreview struct {
Host db.Host
SandboxIDs []string
}
// regTokenPayload is the JSON stored in Redis for registration tokens.
type regTokenPayload struct {
HostID string `json:"host_id"`
TokenID string `json:"token_id"`
}
const regTokenTTL = time.Hour
// requireAdminOrOwner returns nil iff the role is "owner" or "admin".
func requireAdminOrOwner(role string) error {
if role == "owner" || role == "admin" {
return nil
}
return fmt.Errorf("forbidden: only team owners and admins can manage BYOC hosts")
}
// Create creates a new host record and generates a one-time registration token.
func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreateResult, error) {
if p.Type != "regular" && p.Type != "byoc" {
return HostCreateResult{}, fmt.Errorf("invalid host type: must be 'regular' or 'byoc'")
}
if p.Type == "regular" {
if !p.IsRequestorAdmin {
return HostCreateResult{}, fmt.Errorf("forbidden: only admins can create regular hosts")
}
} else {
// BYOC: platform admin, or team owner/admin.
if !p.TeamID.Valid {
return HostCreateResult{}, fmt.Errorf("invalid request: team_id is required for BYOC hosts")
}
if !p.IsRequestorAdmin {
membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
UserID: p.RequestingUserID,
TeamID: p.TeamID,
})
if errors.Is(err, pgx.ErrNoRows) {
return HostCreateResult{}, fmt.Errorf("forbidden: not a member of the specified team")
}
if err != nil {
return HostCreateResult{}, fmt.Errorf("check team membership: %w", err)
}
if err := requireAdminOrOwner(membership.Role); err != nil {
return HostCreateResult{}, err
}
}
}
// Validate team exists, is not deleted, and has BYOC enabled.
if p.TeamID.Valid {
team, err := s.DB.GetTeam(ctx, p.TeamID)
if err != nil || team.DeletedAt.Valid {
return HostCreateResult{}, fmt.Errorf("invalid request: team not found")
}
if !team.IsByoc {
return HostCreateResult{}, fmt.Errorf("forbidden: BYOC is not enabled for this team")
}
}
hostID := id.NewHostID()
host, err := s.DB.InsertHost(ctx, db.InsertHostParams{
ID: hostID,
Type: p.Type,
TeamID: p.TeamID,
Provider: p.Provider,
AvailabilityZone: p.AvailabilityZone,
CreatedBy: p.RequestingUserID,
})
if err != nil {
return HostCreateResult{}, fmt.Errorf("insert host: %w", err)
}
// Generate registration token and store in Redis + Postgres audit trail.
token := id.NewRegistrationToken()
tokenID := id.NewHostTokenID()
payload, _ := json.Marshal(regTokenPayload{
HostID: id.FormatHostID(hostID),
TokenID: id.FormatHostTokenID(tokenID),
})
if err := s.Redis.Set(ctx, "host:reg:"+token, payload, regTokenTTL).Err(); err != nil {
return HostCreateResult{}, fmt.Errorf("store registration token: %w", err)
}
now := time.Now()
if _, err := s.DB.InsertHostToken(ctx, db.InsertHostTokenParams{
ID: tokenID,
HostID: hostID,
CreatedBy: p.RequestingUserID,
ExpiresAt: pgtype.Timestamptz{Time: now.Add(regTokenTTL), Valid: true},
}); err != nil {
slog.Warn("failed to insert host token audit record", "host_id", id.FormatHostID(hostID), "error", err)
}
return HostCreateResult{Host: host, RegistrationToken: token}, nil
}
// RegenerateToken issues a new registration token for a host still in "pending"
// status. This allows retry when a previous registration attempt failed after
// the original token was consumed.
func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamID pgtype.UUID, isAdmin bool) (HostCreateResult, error) {
host, err := s.DB.GetHost(ctx, hostID)
if err != nil {
return HostCreateResult{}, fmt.Errorf("host not found: %w", err)
}
if host.Status != "pending" {
return HostCreateResult{}, fmt.Errorf("invalid state: can only regenerate token for pending hosts (status: %s)", host.Status)
}
if !isAdmin {
if host.Type != "byoc" {
return HostCreateResult{}, fmt.Errorf("forbidden: only admins can manage regular hosts")
}
if !host.TeamID.Valid || host.TeamID != teamID {
return HostCreateResult{}, fmt.Errorf("forbidden: host does not belong to your team")
}
membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
UserID: userID,
TeamID: teamID,
})
if errors.Is(err, pgx.ErrNoRows) {
return HostCreateResult{}, fmt.Errorf("forbidden: not a member of the specified team")
}
if err != nil {
return HostCreateResult{}, fmt.Errorf("check team membership: %w", err)
}
if err := requireAdminOrOwner(membership.Role); err != nil {
return HostCreateResult{}, err
}
}
token := id.NewRegistrationToken()
tokenID := id.NewHostTokenID()
payload, _ := json.Marshal(regTokenPayload{
HostID: id.FormatHostID(hostID),
TokenID: id.FormatHostTokenID(tokenID),
})
if err := s.Redis.Set(ctx, "host:reg:"+token, payload, regTokenTTL).Err(); err != nil {
return HostCreateResult{}, fmt.Errorf("store registration token: %w", err)
}
now := time.Now()
if _, err := s.DB.InsertHostToken(ctx, db.InsertHostTokenParams{
ID: tokenID,
HostID: hostID,
CreatedBy: userID,
ExpiresAt: pgtype.Timestamptz{Time: now.Add(regTokenTTL), Valid: true},
}); err != nil {
slog.Warn("failed to insert host token audit record", "host_id", id.FormatHostID(hostID), "error", err)
}
return HostCreateResult{Host: host, RegistrationToken: token}, nil
}
// Register validates a one-time registration token, updates the host with
// machine specs, and returns a short-lived host JWT plus a long-lived refresh token.
func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostRegisterResult, error) {
// Atomic consume: GetDel returns the value and deletes in one operation,
// preventing concurrent requests from consuming the same token.
raw, err := s.Redis.GetDel(ctx, "host:reg:"+p.Token).Bytes()
if err == redis.Nil {
return HostRegisterResult{}, fmt.Errorf("invalid or expired registration token")
}
if err != nil {
return HostRegisterResult{}, fmt.Errorf("token lookup: %w", err)
}
var payload regTokenPayload
if err := json.Unmarshal(raw, &payload); err != nil {
return HostRegisterResult{}, fmt.Errorf("corrupted registration token")
}
hostID, err := id.ParseHostID(payload.HostID)
if err != nil {
return HostRegisterResult{}, fmt.Errorf("corrupted registration token: %w", err)
}
tokenID, err := id.ParseHostTokenID(payload.TokenID)
if err != nil {
return HostRegisterResult{}, fmt.Errorf("corrupted registration token: %w", err)
}
if _, err := s.DB.GetHost(ctx, hostID); err != nil {
return HostRegisterResult{}, fmt.Errorf("host not found: %w", err)
}
// Sign JWT before mutating DB — if signing fails, the host stays pending.
hostJWT, err := auth.SignHostJWT(s.JWT, hostID)
if err != nil {
return HostRegisterResult{}, fmt.Errorf("sign host token: %w", err)
}
// Issue mTLS certificate if CA is configured.
var hc auth.HostCert
if s.CA != nil {
hc, err = auth.IssueHostCert(s.CA, id.FormatHostID(hostID), p.Address)
if err != nil {
return HostRegisterResult{}, fmt.Errorf("issue host cert: %w", err)
}
}
// Atomically update only if still pending (defense-in-depth against races).
rowsAffected, err := s.DB.RegisterHost(ctx, db.RegisterHostParams{
ID: hostID,
Arch: p.Arch,
CpuCores: p.CPUCores,
MemoryMb: p.MemoryMB,
DiskGb: p.DiskGB,
Address: p.Address,
CertFingerprint: hc.Fingerprint,
CertExpiresAt: pgtype.Timestamptz{Time: hc.ExpiresAt, Valid: s.CA != nil},
})
if err != nil {
return HostRegisterResult{}, fmt.Errorf("register host: %w", err)
}
if rowsAffected == 0 {
return HostRegisterResult{}, fmt.Errorf("host already registered or not found")
}
// Mark audit trail.
if err := s.DB.MarkHostTokenUsed(ctx, tokenID); err != nil {
slog.Warn("failed to mark host token used", "token_id", payload.TokenID, "error", err)
}
// Issue a long-lived refresh token.
refreshToken, err := s.issueRefreshToken(ctx, hostID)
if err != nil {
return HostRegisterResult{}, fmt.Errorf("issue refresh token: %w", err)
}
// Re-fetch the host to get the updated state.
host, err := s.DB.GetHost(ctx, hostID)
if err != nil {
return HostRegisterResult{}, fmt.Errorf("fetch updated host: %w", err)
}
result := HostRegisterResult{Host: host, JWT: hostJWT, RefreshToken: refreshToken}
if s.CA != nil {
result.CertPEM = hc.CertPEM
result.KeyPEM = hc.KeyPEM
result.CACertPEM = s.CA.PEM
}
return result, nil
}
// Refresh validates a refresh token, rotates it (revokes old, issues new),
// and returns a fresh JWT plus the new refresh token.
func (s *HostService) Refresh(ctx context.Context, refreshToken string) (HostRefreshResult, error) {
hash := hashToken(refreshToken)
row, err := s.DB.GetHostRefreshTokenByHash(ctx, hash)
if errors.Is(err, pgx.ErrNoRows) {
return HostRefreshResult{}, fmt.Errorf("invalid or expired refresh token")
}
if err != nil {
return HostRefreshResult{}, fmt.Errorf("lookup refresh token: %w", err)
}
host, err := s.DB.GetHost(ctx, row.HostID)
if err != nil {
return HostRefreshResult{}, fmt.Errorf("host not found: %w", err)
}
// Sign new JWT.
hostJWT, err := auth.SignHostJWT(s.JWT, host.ID)
if err != nil {
return HostRefreshResult{}, fmt.Errorf("sign host JWT: %w", err)
}
// Renew mTLS certificate if CA is configured.
var hc auth.HostCert
if s.CA != nil {
hc, err = auth.IssueHostCert(s.CA, id.FormatHostID(host.ID), host.Address)
if err != nil {
return HostRefreshResult{}, fmt.Errorf("renew host cert: %w", err)
}
if err := s.DB.UpdateHostCert(ctx, db.UpdateHostCertParams{
ID: host.ID,
CertFingerprint: hc.Fingerprint,
CertExpiresAt: pgtype.Timestamptz{Time: hc.ExpiresAt, Valid: true},
}); err != nil {
return HostRefreshResult{}, fmt.Errorf("update host cert: %w", err)
}
}
// Issue-then-revoke rotation: insert new token first so a crash between
// the two DB calls leaves the host with two valid tokens rather than zero.
newRefreshToken, err := s.issueRefreshToken(ctx, host.ID)
if err != nil {
return HostRefreshResult{}, fmt.Errorf("issue new refresh token: %w", err)
}
// Revoke old refresh token after the new one is safely persisted.
if err := s.DB.RevokeHostRefreshToken(ctx, row.ID); err != nil {
return HostRefreshResult{}, fmt.Errorf("revoke old refresh token: %w", err)
}
result := HostRefreshResult{Host: host, JWT: hostJWT, RefreshToken: newRefreshToken}
if s.CA != nil {
result.CertPEM = hc.CertPEM
result.KeyPEM = hc.KeyPEM
result.CACertPEM = s.CA.PEM
}
return result, nil
}
// issueRefreshToken creates a new refresh token record in the DB and returns
// the opaque token string.
func (s *HostService) issueRefreshToken(ctx context.Context, hostID pgtype.UUID) (string, error) {
token := id.NewRefreshToken()
hash := hashToken(token)
now := time.Now()
if _, err := s.DB.InsertHostRefreshToken(ctx, db.InsertHostRefreshTokenParams{
ID: id.NewRefreshTokenID(),
HostID: hostID,
TokenHash: hash,
ExpiresAt: pgtype.Timestamptz{Time: now.Add(auth.HostRefreshTokenExpiry), Valid: true},
}); err != nil {
return "", fmt.Errorf("insert refresh token: %w", err)
}
return token, nil
}
// hashToken returns the hex-encoded SHA-256 hash of the token.
func hashToken(token string) string {
h := sha256.Sum256([]byte(token))
return fmt.Sprintf("%x", h)
}
// Heartbeat updates the last heartbeat timestamp for a host and transitions
// any 'unreachable' host back to 'online'. Returns a "host not found" error
// (which becomes 404) if the host record no longer exists (e.g., was deleted).
func (s *HostService) Heartbeat(ctx context.Context, hostID pgtype.UUID) error {
n, err := s.DB.UpdateHostHeartbeatAndStatus(ctx, hostID)
if err != nil {
return err
}
if n == 0 {
return fmt.Errorf("host not found")
}
return nil
}
// List returns hosts visible to the caller.
// Admins see all hosts; non-admins see only BYOC hosts belonging to their team.
func (s *HostService) List(ctx context.Context, teamID pgtype.UUID, isAdmin bool) ([]db.Host, error) {
if isAdmin {
return s.DB.ListHosts(ctx)
}
return s.DB.ListHostsByTeam(ctx, teamID)
}
// Get returns a single host, enforcing access control.
func (s *HostService) Get(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool) (db.Host, error) {
host, err := s.DB.GetHost(ctx, hostID)
if err != nil {
return db.Host{}, fmt.Errorf("host not found: %w", err)
}
if !isAdmin {
if !host.TeamID.Valid || host.TeamID != teamID {
return db.Host{}, fmt.Errorf("host not found")
}
}
return host, nil
}
// DeletePreview returns what would be affected by deleting the host, without
// making any changes. Use this to show the user a confirmation prompt.
func (s *HostService) DeletePreview(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool) (HostDeletePreview, error) {
host, err := s.checkDeletePermission(ctx, hostID, pgtype.UUID{}, teamID, isAdmin)
if err != nil {
return HostDeletePreview{}, err
}
sandboxes, err := s.DB.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
HostID: hostID,
Column2: []string{"pending", "starting", "running", "missing"},
})
if err != nil {
return HostDeletePreview{}, fmt.Errorf("list sandboxes: %w", err)
}
ids := make([]string, len(sandboxes))
for i, sb := range sandboxes {
ids[i] = id.FormatSandboxID(sb.ID)
}
return HostDeletePreview{Host: host, SandboxIDs: ids}, nil
}
// Delete removes a host. Without force it returns an error listing active
// sandboxes so the caller can present a confirmation. With force it gracefully
// destroys all running sandboxes before deleting the host record.
func (s *HostService) Delete(ctx context.Context, hostID, userID, teamID pgtype.UUID, isAdmin, force bool) error {
host, err := s.checkDeletePermission(ctx, hostID, userID, teamID, isAdmin)
if err != nil {
return err
}
sandboxes, err := s.DB.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
HostID: hostID,
Column2: []string{"pending", "starting", "running", "missing"},
})
if err != nil {
return fmt.Errorf("list sandboxes: %w", err)
}
if len(sandboxes) > 0 && !force {
ids := make([]string, len(sandboxes))
for i, sb := range sandboxes {
ids[i] = id.FormatSandboxID(sb.ID)
}
return &HostHasSandboxesError{SandboxIDs: ids}
}
hostIDStr := id.FormatHostID(hostID)
// Gracefully destroy running sandboxes and terminate the agent (best-effort).
if host.Address != "" {
agent, err := s.Pool.GetForHost(host)
if err == nil {
for _, sb := range sandboxes {
if sb.Status == "running" || sb.Status == "starting" {
_, rpcErr := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
SandboxId: id.FormatSandboxID(sb.ID),
}))
if rpcErr != nil && connect.CodeOf(rpcErr) != connect.CodeNotFound {
slog.Warn("delete host: failed to destroy sandbox on agent", "sandbox_id", id.FormatSandboxID(sb.ID), "error", rpcErr)
}
}
}
// Tell the agent to shut itself down immediately.
if _, rpcErr := agent.Terminate(ctx, connect.NewRequest(&pb.TerminateRequest{})); rpcErr != nil {
slog.Warn("delete host: failed to send Terminate to agent", "host_id", hostIDStr, "error", rpcErr)
}
}
}
// Mark all affected sandboxes as stopped in DB.
if len(sandboxes) > 0 {
sbIDs := make([]pgtype.UUID, len(sandboxes))
for i, sb := range sandboxes {
sbIDs[i] = sb.ID
}
if err := s.DB.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
Column1: sbIDs,
Status: "stopped",
}); err != nil {
slog.Warn("delete host: failed to mark sandboxes stopped", "host_id", hostIDStr, "error", err)
}
}
// Revoke all refresh tokens for this host.
if err := s.DB.RevokeHostRefreshTokensByHost(ctx, hostID); err != nil {
slog.Warn("delete host: failed to revoke refresh tokens", "host_id", hostIDStr, "error", err)
}
// Evict the client from the pool so no further RPCs are sent.
if s.Pool != nil {
s.Pool.Evict(id.FormatHostID(hostID))
}
return s.DB.DeleteHost(ctx, hostID)
}
// checkDeletePermission verifies the caller has permission to delete the given
// host and returns the host record on success.
func (s *HostService) checkDeletePermission(ctx context.Context, hostID, userID, teamID pgtype.UUID, isAdmin bool) (db.Host, error) {
host, err := s.DB.GetHost(ctx, hostID)
if err != nil {
return db.Host{}, fmt.Errorf("host not found: %w", err)
}
if isAdmin {
return host, nil
}
if host.Type != "byoc" {
return db.Host{}, fmt.Errorf("forbidden: only admins can delete regular hosts")
}
if !host.TeamID.Valid || host.TeamID != teamID {
return db.Host{}, fmt.Errorf("forbidden: host does not belong to your team")
}
if userID.Valid {
membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
UserID: userID,
TeamID: teamID,
})
if errors.Is(err, pgx.ErrNoRows) {
return db.Host{}, fmt.Errorf("forbidden: not a member of the specified team")
}
if err != nil {
return db.Host{}, fmt.Errorf("check team membership: %w", err)
}
if err := requireAdminOrOwner(membership.Role); err != nil {
return db.Host{}, err
}
}
return host, nil
}
// AddTag adds a tag to a host.
func (s *HostService) AddTag(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool, tag string) error {
if _, err := s.Get(ctx, hostID, teamID, isAdmin); err != nil {
return err
}
return s.DB.AddHostTag(ctx, db.AddHostTagParams{HostID: hostID, Tag: tag})
}
// RemoveTag removes a tag from a host.
func (s *HostService) RemoveTag(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool, tag string) error {
if _, err := s.Get(ctx, hostID, teamID, isAdmin); err != nil {
return err
}
return s.DB.RemoveHostTag(ctx, db.RemoveHostTagParams{HostID: hostID, Tag: tag})
}
// ListTags returns all tags for a host.
func (s *HostService) ListTags(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool) ([]string, error) {
if _, err := s.Get(ctx, hostID, teamID, isAdmin); err != nil {
return nil, err
}
return s.DB.GetHostTags(ctx, hostID)
}
// HostHasSandboxesError is returned by Delete when the host has active sandboxes
// and force was not set. The caller should present the list to the user and
// re-call Delete with force=true if they confirm.
type HostHasSandboxesError struct {
SandboxIDs []string
}
func (e *HostHasSandboxesError) Error() string {
return fmt.Sprintf("host has %d active sandbox(es): %v", len(e.SandboxIDs), e.SandboxIDs)
}

Some files were not shown because too many files have changed in this diff Show More