1
0
forked from wrenn/wrenn
Co-authored-by: Tasnim Kabir Sadik <tksadik@omukk.dev>

Reviewed-on: wrenn/wrenn#50
This commit is contained in:
2026-05-24 21:10:37 +00:00
parent 4707f16c76
commit 05ddf62399
203 changed files with 15815 additions and 9344 deletions

View File

@ -3,11 +3,20 @@ package api
import (
"context"
"fmt"
"log/slog"
"net/http"
"time"
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
"github.com/gorilla/websocket"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
"git.omukk.dev/wrenn/wrenn/proto/hostagent/gen/hostagentv1connect"
)
@ -20,3 +29,119 @@ func agentForHost(ctx context.Context, queries *db.Queries, pool *lifecycle.Host
}
return pool.GetForHost(host)
}
// requireRunningSandbox parses the sandbox ID from the URL, looks it up by team,
// and verifies it is running. On failure it writes the appropriate HTTP error and
// returns false.
func requireRunningSandbox(w http.ResponseWriter, r *http.Request, queries *db.Queries, teamID pgtype.UUID) (db.Sandbox, pgtype.UUID, string, bool) {
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return db.Sandbox{}, pgtype.UUID{}, "", false
}
sb, err := queries.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return db.Sandbox{}, pgtype.UUID{}, "", false
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
return db.Sandbox{}, pgtype.UUID{}, "", false
}
return sb, sandboxID, sandboxIDStr, true
}
// upgradeAndAuthenticate upgrades the HTTP connection to WebSocket. The
// auth context must already be populated by upstream middleware — browser
// clients via the wrenn_sid cookie (sent automatically on WS upgrade),
// SDK clients via X-API-Key. Requests without an auth context are rejected
// with a 401 before the upgrade.
func upgradeAndAuthenticate(w http.ResponseWriter, r *http.Request) (*websocket.Conn, auth.AuthContext, error) {
ac, hasAuth := auth.FromContext(r.Context())
if !hasAuth {
writeError(w, http.StatusUnauthorized, "unauthorized", "session cookie or X-API-Key required")
return nil, auth.AuthContext{}, fmt.Errorf("unauthenticated")
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return nil, auth.AuthContext{}, fmt.Errorf("websocket upgrade: %w", err)
}
return conn, ac, nil
}
// resolveTemplateSizes queries a host agent for the actual disk usage of any
// templates with size_bytes <= 0 (e.g. system base templates seeded with
// size_bytes = 0 before the rootfs was built). Results are persisted to the
// DB so subsequent requests serve the correct size without an RPC call.
// Errors are logged but do not prevent the caller from serving the templates.
func resolveTemplateSizes(ctx context.Context, queries *db.Queries, pool *lifecycle.HostClientPool, templates []db.Template) []db.Template {
needResolve := false
for _, t := range templates {
if t.SizeBytes <= 0 {
needResolve = true
break
}
}
if !needResolve {
return templates
}
hosts, err := queries.ListActiveHosts(ctx)
if err != nil || len(hosts) == 0 {
slog.Warn("resolveTemplateSizes: no active hosts available", "error", err)
return templates
}
agent, err := pool.GetForHost(hosts[0])
if err != nil {
slog.Warn("resolveTemplateSizes: failed to connect to host",
"host_id", id.UUIDString(hosts[0].ID), "error", err)
return templates
}
for i, t := range templates {
if t.SizeBytes > 0 {
continue
}
resp, err := agent.GetTemplateSize(ctx, connect.NewRequest(&pb.GetTemplateSizeRequest{
TeamId: formatUUIDForRPC(t.TeamID),
TemplateId: formatUUIDForRPC(t.ID),
}))
if err != nil {
slog.Warn("resolveTemplateSizes: failed to get size from host",
"template", t.Name, "error", err)
continue
}
templates[i].SizeBytes = resp.Msg.SizeBytes
if err := queries.UpdateTemplateSize(ctx, db.UpdateTemplateSizeParams{
ID: t.ID,
SizeBytes: resp.Msg.SizeBytes,
}); err != nil {
slog.Warn("resolveTemplateSizes: failed to persist size",
"template", t.Name, "error", err)
}
}
return templates
}
// updateLastActive updates the sandbox last_active_at timestamp.
// Uses a background context with timeout for streaming handlers where
// the request context may already be cancelled.
func updateLastActive(queries *db.Queries, sandboxID pgtype.UUID, sandboxIDStr string) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := queries.UpdateLastActive(ctx, db.UpdateLastActiveParams{
ID: sandboxID,
LastActiveAt: pgtype.Timestamptz{
Time: time.Now(),
Valid: true,
},
}); err != nil {
slog.Warn("failed to update last_active_at", "id", sandboxIDStr, "error", err)
}
}

View File

@ -0,0 +1,54 @@
package api
import (
"context"
"log/slog"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/pkg/cpextension"
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
// collectAuthHooks filters the extension list down to those that implement
// cpextension.AuthHook.
func collectAuthHooks(extensions []cpextension.Extension) []cpextension.AuthHook {
var hooks []cpextension.AuthHook
for _, ext := range extensions {
if h, ok := ext.(cpextension.AuthHook); ok {
hooks = append(hooks, h)
}
}
return hooks
}
// fireOnSignup runs every OnSignup hook sequentially. The first error wins —
// signup must abort so the cloud-side billing customer creation stays the
// gating step. Other hook errors after a failure are not run.
func fireOnSignup(ctx context.Context, hooks []cpextension.AuthHook, userID, teamID pgtype.UUID, email string) error {
for _, h := range hooks {
if err := h.OnSignup(ctx, userID, teamID, email); err != nil {
return err
}
}
return nil
}
// fireOnLogin runs every OnLogin hook. Errors are logged and swallowed —
// login must never be blocked by a misbehaving extension.
func fireOnLogin(ctx context.Context, hooks []cpextension.AuthHook, userID pgtype.UUID) {
for _, h := range hooks {
if err := h.OnLogin(ctx, userID); err != nil {
slog.Warn("auth hook OnLogin failed", "user_id", id.FormatUserID(userID), "error", err)
}
}
}
// fireOnSoftDelete runs every OnAccountSoftDelete hook. Errors are logged.
func fireOnSoftDelete(ctx context.Context, hooks []cpextension.AuthHook, userID pgtype.UUID) {
for _, h := range hooks {
if err := h.OnAccountSoftDelete(ctx, userID); err != nil {
slog.Warn("auth hook OnAccountSoftDelete failed", "user_id", id.FormatUserID(userID), "error", err)
}
}
}

View File

@ -1,14 +1,9 @@
package api
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"net/http"
"time"
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
"git.omukk.dev/wrenn/wrenn/pkg/audit"
@ -17,8 +12,6 @@ import (
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
"git.omukk.dev/wrenn/wrenn/pkg/service"
"git.omukk.dev/wrenn/wrenn/pkg/validate"
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
)
type adminCapsuleHandler struct {
@ -49,15 +42,18 @@ func (h *adminCapsuleHandler) Create(w http.ResponseWriter, r *http.Request) {
MemoryMB: req.MemoryMB,
TimeoutSec: req.TimeoutSec,
})
ac.TeamID = id.PlatformTeamID
h.audit.LogSandboxCreate(r.Context(), ac, sb.ID, req.Template, err)
if err != nil {
if sb.ID.Valid {
h.audit.LogSandboxDestroySystem(r.Context(), id.PlatformTeamID, sb.ID, "cleanup_after_create_error", nil)
}
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
ac.TeamID = id.PlatformTeamID
h.audit.LogSandboxCreate(r.Context(), ac, sb.ID, sb.Template)
writeJSON(w, http.StatusCreated, sandboxToResponse(sb))
writeJSON(w, http.StatusAccepted, sandboxToResponse(sb))
}
// List handles GET /v1/admin/capsules.
@ -106,26 +102,27 @@ func (h *adminCapsuleHandler) Destroy(w http.ResponseWriter, r *http.Request) {
return
}
if err := h.svc.Destroy(r.Context(), sandboxID, id.PlatformTeamID); err != nil {
ac.TeamID = id.PlatformTeamID
err = h.svc.Destroy(r.Context(), sandboxID, id.PlatformTeamID)
h.audit.LogSandboxDestroy(r.Context(), ac, sandboxID, err)
if err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
h.audit.LogSandboxDestroy(r.Context(), ac, sandboxID)
w.WriteHeader(http.StatusNoContent)
w.WriteHeader(http.StatusAccepted)
}
type adminSnapshotRequest struct {
Name string `json:"name"`
}
// Snapshot handles POST /v1/admin/capsules/{id}/snapshot.
// Pauses the capsule, takes a snapshot as a platform template, then destroys the capsule.
// Snapshot handles POST /v1/admin/capsules/{id}/snapshot. Takes a live
// snapshot of a platform-owned capsule and registers the result as a
// platform template (team_id = 00000000-...).
func (h *adminCapsuleHandler) Snapshot(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
@ -133,117 +130,22 @@ func (h *adminCapsuleHandler) Snapshot(w http.ResponseWriter, r *http.Request) {
}
var req adminSnapshotRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
return
}
if req.Name == "" {
req.Name = id.NewSnapshotName()
}
if err := validate.SafeName(req.Name); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", fmt.Sprintf("invalid snapshot name: %s", err))
return
}
ctx := r.Context()
// Verify sandbox exists and belongs to platform team BEFORE any
// destructive operations (template overwrite).
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: id.PlatformTeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" && sb.Status != "paused" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox must be running or paused")
return
}
// Check if name already exists as a platform template.
if existing, err := h.db.GetPlatformTemplateByName(ctx, req.Name); err == nil {
// Delete old snapshot files from all hosts before removing the DB record.
if err := deleteSnapshotBroadcast(ctx, h.db, h.pool, existing.TeamID, existing.ID); err != nil {
writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete existing snapshot files")
return
}
if err := h.db.DeleteTemplateByTeam(ctx, db.DeleteTemplateByTeamParams{Name: req.Name, TeamID: id.PlatformTeamID}); err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to remove existing template record")
if r.ContentLength > 0 {
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
return
}
}
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
sb, name, err := h.svc.CreateSnapshot(r.Context(), sandboxID, id.PlatformTeamID, req.Name)
if err != nil {
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
return
}
// Pre-mark sandbox as "paused" to prevent the reconciler from racing.
if sb.Status == "running" {
if _, err := h.db.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
ID: sandboxID, Status: "paused",
}); err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to update sandbox status")
return
}
}
// Use a detached context so the snapshot completes even if the client disconnects.
snapCtx, snapCancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer snapCancel()
newTemplateID := id.NewTemplateID()
resp, err := agent.CreateSnapshot(snapCtx, connect.NewRequest(&pb.CreateSnapshotRequest{
SandboxId: sandboxIDStr,
Name: req.Name,
TeamId: formatUUIDForRPC(id.PlatformTeamID),
TemplateId: formatUUIDForRPC(newTemplateID),
}))
if err != nil {
// Snapshot failed — revert status.
if sb.Status == "running" {
if _, dbErr := h.db.UpdateSandboxStatus(snapCtx, db.UpdateSandboxStatusParams{
ID: sandboxID, Status: "running",
}); dbErr != nil {
slog.Error("failed to revert sandbox status after snapshot error", "sandbox_id", sandboxIDStr, "error", dbErr)
}
}
status, code, msg := agentErrToHTTP(err)
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
ac := auth.MustFromContext(r.Context())
ac.TeamID = id.PlatformTeamID
h.audit.LogSnapshotCreateRequested(r.Context(), ac, name)
tmpl, err := h.db.InsertTemplate(snapCtx, db.InsertTemplateParams{
ID: newTemplateID,
Name: req.Name,
Type: "snapshot",
Vcpus: sb.Vcpus,
MemoryMb: sb.MemoryMb,
SizeBytes: resp.Msg.SizeBytes,
TeamID: id.PlatformTeamID,
DefaultUser: "root",
DefaultEnv: []byte("{}"),
Metadata: sb.Metadata,
})
if err != nil {
slog.Error("failed to insert template record", "name", req.Name, "error", err)
writeError(w, http.StatusInternalServerError, "db_error", "snapshot created but failed to record in database")
return
}
// Destroy the ephemeral capsule after successful snapshot.
if err := h.svc.Destroy(snapCtx, sandboxID, id.PlatformTeamID); err != nil {
slog.Error("failed to destroy capsule after snapshot", "sandbox_id", sandboxIDStr, "error", err)
// Don't fail the response — the snapshot was created successfully.
}
h.audit.LogSnapshotCreate(snapCtx, ac, req.Name)
if ctx.Err() != nil {
slog.Info("snapshot created but client disconnected before response", "name", req.Name)
return
}
writeJSON(w, http.StatusCreated, templateToResponse(tmpl))
writeJSON(w, http.StatusAccepted, sandboxToResponse(sb))
}

View File

@ -20,6 +20,8 @@ import (
"git.omukk.dev/wrenn/wrenn/internal/email"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/auth/session"
"git.omukk.dev/wrenn/wrenn/pkg/cpextension"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
@ -140,14 +142,15 @@ type switchTeamRequest struct {
type authHandler struct {
db *db.Queries
pool *pgxpool.Pool
jwtSecret []byte
sessions *session.Service
mailer email.Mailer
rdb *redis.Client
redirectURL string
authHooks []cpextension.AuthHook
}
func newAuthHandler(db *db.Queries, pool *pgxpool.Pool, jwtSecret []byte, mailer email.Mailer, rdb *redis.Client, redirectURL string) *authHandler {
return &authHandler{db: db, pool: pool, jwtSecret: jwtSecret, mailer: mailer, rdb: rdb, redirectURL: strings.TrimRight(redirectURL, "/")}
func newAuthHandler(db *db.Queries, pool *pgxpool.Pool, sessions *session.Service, mailer email.Mailer, rdb *redis.Client, redirectURL string, hooks []cpextension.AuthHook) *authHandler {
return &authHandler{db: db, pool: pool, sessions: sessions, mailer: mailer, rdb: rdb, redirectURL: strings.TrimRight(redirectURL, "/"), authHooks: hooks}
}
type signupRequest struct {
@ -166,11 +169,41 @@ type activateRequest struct {
}
type authResponse struct {
Token string `json:"token"`
UserID string `json:"user_id"`
TeamID string `json:"team_id"`
Email string `json:"email"`
Name string `json:"name"`
UserID string `json:"user_id"`
TeamID string `json:"team_id"`
Email string `json:"email"`
Name string `json:"name"`
Role string `json:"role"`
IsAdmin bool `json:"is_admin"`
}
// issueSession creates a new session and writes both the session and CSRF
// cookies to the response. On success it writes the authResponse JSON body.
func (h *authHandler) issueSession(
w http.ResponseWriter,
r *http.Request,
userID, teamID pgtype.UUID,
email, name, role string,
isAdmin bool,
) error {
sess, err := h.sessions.Create(r.Context(), userID, teamID, email, name, role, isAdmin, r.UserAgent(), clientIP(r))
if err != nil {
return err
}
setSessionCookies(w, sess.RawSID, sess.CSRFToken, isSecure(r))
return nil
}
// clientIP returns the request's apparent client IP, honoring
// X-Forwarded-For when behind a reverse proxy.
func clientIP(r *http.Request) string {
if fwd := r.Header.Get("X-Forwarded-For"); fwd != "" {
if i := strings.IndexByte(fwd, ','); i > 0 {
return strings.TrimSpace(fwd[:i])
}
return strings.TrimSpace(fwd)
}
return r.RemoteAddr
}
type signupResponse struct {
@ -253,8 +286,8 @@ func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) {
}
// Generate activation token and store in Redis.
rawToken := generateActivationToken()
tokenHash := hashActivationToken(rawToken)
rawToken := generateOpaqueToken()
tokenHash := hashOpaqueToken(rawToken)
redisKey := activationKeyPrefix + tokenHash
if err := h.rdb.Set(ctx, redisKey, id.FormatUserID(userID), activationTTL).Err(); err != nil {
@ -296,7 +329,7 @@ func (h *authHandler) Activate(w http.ResponseWriter, r *http.Request) {
}
ctx := r.Context()
tokenHash := hashActivationToken(req.Token)
tokenHash := hashOpaqueToken(req.Token)
redisKey := activationKeyPrefix + tokenHash
userIDStr, err := h.rdb.GetDel(ctx, redisKey).Result()
@ -345,18 +378,26 @@ func (h *authHandler) Activate(w http.ResponseWriter, r *http.Request) {
}
isAdmin := user.IsAdmin || isFirstUser
token, err := auth.SignJWT(h.jwtSecret, userID, team.ID, user.Email, user.Name, role, isAdmin)
if err != nil {
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
// Fire OnSignup before issuing a session — billing must succeed first.
if err := fireOnSignup(ctx, h.authHooks, userID, team.ID, user.Email); err != nil {
slog.Error("activate: OnSignup hook failed", "user_id", id.FormatUserID(userID), "error", err)
writeError(w, http.StatusInternalServerError, "signup_hook_failed", "failed to finalize account setup")
return
}
if err := h.issueSession(w, r, userID, team.ID, user.Email, user.Name, role, isAdmin); err != nil {
slog.Error("activate: failed to issue session", "error", err)
writeError(w, http.StatusInternalServerError, "internal_error", "failed to create session")
return
}
fireOnLogin(ctx, h.authHooks, userID)
writeJSON(w, http.StatusOK, authResponse{
Token: token,
UserID: id.FormatUserID(userID),
TeamID: id.FormatTeamID(team.ID),
Email: user.Email,
Name: user.Name,
UserID: id.FormatUserID(userID),
TeamID: id.FormatTeamID(team.ID),
Email: user.Email,
Name: user.Name,
Role: role,
IsAdmin: isAdmin,
})
}
@ -427,18 +468,20 @@ func (h *authHandler) Login(w http.ResponseWriter, r *http.Request) {
}
isAdmin := user.IsAdmin || isFirstUser
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, isAdmin)
if err != nil {
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
if err := h.issueSession(w, r, user.ID, team.ID, user.Email, user.Name, role, isAdmin); err != nil {
slog.Error("login: failed to issue session", "error", err)
writeError(w, http.StatusInternalServerError, "internal_error", "failed to create session")
return
}
fireOnLogin(ctx, h.authHooks, user.ID)
writeJSON(w, http.StatusOK, authResponse{
Token: token,
UserID: id.FormatUserID(user.ID),
TeamID: id.FormatTeamID(team.ID),
Email: user.Email,
Name: user.Name,
UserID: id.FormatUserID(user.ID),
TeamID: id.FormatTeamID(team.ID),
Email: user.Email,
Name: user.Name,
Role: role,
IsAdmin: isAdmin,
})
}
@ -503,24 +546,54 @@ func (h *authHandler) SwitchTeam(w http.ResponseWriter, r *http.Request) {
return
}
token, err := auth.SignJWT(h.jwtSecret, ac.UserID, teamID, ac.Email, user.Name, membership.Role, user.IsAdmin)
// Rotate the SID so any leaked old cookie loses access at the moment of
// privilege change.
newSess, err := h.sessions.Rotate(ctx, ac.SessionID, ac.UserID, teamID, user.Email, user.Name, membership.Role, user.IsAdmin, r.UserAgent(), clientIP(r))
if err != nil {
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
slog.Error("switch team: failed to rotate session", "error", err)
writeError(w, http.StatusInternalServerError, "internal_error", "failed to switch team")
return
}
setSessionCookies(w, newSess.RawSID, newSess.CSRFToken, isSecure(r))
writeJSON(w, http.StatusOK, authResponse{
Token: token,
UserID: id.FormatUserID(ac.UserID),
TeamID: id.FormatTeamID(teamID),
Email: ac.Email,
Name: user.Name,
UserID: id.FormatUserID(ac.UserID),
TeamID: id.FormatTeamID(teamID),
Email: user.Email,
Name: user.Name,
Role: membership.Role,
IsAdmin: user.IsAdmin,
})
}
// Logout handles POST /v1/auth/logout — revokes the caller's current session.
func (h *authHandler) Logout(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
if err := h.sessions.Revoke(r.Context(), ac.SessionID); err != nil {
slog.Warn("logout: revoke failed", "error", err)
}
clearSessionCookies(w, isSecure(r))
w.WriteHeader(http.StatusNoContent)
}
// LogoutAll handles POST /v1/auth/logout-all — revokes every session for the
// current user, including the caller's own.
func (h *authHandler) LogoutAll(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
if err := h.sessions.RevokeAllForUser(r.Context(), ac.UserID); err != nil {
slog.Error("logout-all: revoke failed", "error", err)
writeError(w, http.StatusInternalServerError, "internal_error", "failed to revoke sessions")
return
}
clearSessionCookies(w, isSecure(r))
w.WriteHeader(http.StatusNoContent)
}
// --- helpers ---
func generateActivationToken() string {
// generateOpaqueToken returns a fresh 16-byte hex-encoded random token,
// used for email activation links and password reset links.
func generateOpaqueToken() string {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
panic(fmt.Sprintf("crypto/rand failed: %v", err))
@ -528,7 +601,9 @@ func generateActivationToken() string {
return hex.EncodeToString(b)
}
func hashActivationToken(raw string) string {
// hashOpaqueToken returns the SHA-256 hex digest of raw, used as the
// lookup key for one-shot tokens stored in Redis.
func hashOpaqueToken(raw string) string {
h := sha256.Sum256([]byte(raw))
return hex.EncodeToString(h[:])
}

View File

@ -0,0 +1,158 @@
package api
import (
"context"
"encoding/base64"
"encoding/json"
"log/slog"
"net/http"
"time"
"github.com/go-chi/chi/v5"
"github.com/gorilla/websocket"
"git.omukk.dev/wrenn/wrenn/internal/recipe"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/service"
)
// buildStreamKeepalive is the interval between server pings on an idle build
// stream, preventing intermediaries from closing the WebSocket.
const buildStreamKeepalive = 30 * time.Second
// buildStreamHandler serves the live admin build console WebSocket.
type buildStreamHandler struct {
db *db.Queries
broker *service.BuildBroker
}
func newBuildStreamHandler(db *db.Queries, broker *service.BuildBroker) *buildStreamHandler {
return &buildStreamHandler{db: db, broker: broker}
}
// Stream handles WS /v1/admin/builds/{id}/stream. On connect it replays the
// completed-step history from the DB log, sends the current build status,
// then live-tails events from the build broker until the build finishes or
// the client disconnects. Admin auth is enforced by upstream middleware.
func (h *buildStreamHandler) Stream(w http.ResponseWriter, r *http.Request) {
buildIDStr := chi.URLParam(r, "id")
buildID, err := id.ParseBuildID(buildIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid build ID")
return
}
build, err := h.db.GetTemplateBuild(r.Context(), buildID)
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "build not found")
return
}
conn, _, err := upgradeAndAuthenticate(w, r)
if err != nil {
slog.Error("build stream websocket upgrade/auth failed", "error", err)
return
}
defer conn.Close()
h.runStream(r.Context(), conn, build)
}
func (h *buildStreamHandler) runStream(ctx context.Context, conn *websocket.Conn, build db.TemplateBuild) {
ws := &wsWriter{conn: conn}
buildIDStr := id.FormatBuildID(build.ID)
// Replay completed-step history from the DB log snapshot. lastStep is the
// highest step number already delivered, used to dedup overlapping live
// events for a step that finished between the DB read and the subscribe.
lastStep := replayBuildHistory(ws, build)
ws.writeJSON(service.BuildStreamEvent{
Type: "build-status",
Status: build.Status,
CurrentStep: build.CurrentStep,
TotalSteps: build.TotalSteps,
Error: build.Error,
})
// A finished build has no live events to follow.
if service.IsTerminalBuildStatus(build.Status) {
return
}
streamCtx, cancel := context.WithCancel(ctx)
defer cancel()
events, release := h.broker.Subscribe(buildIDStr)
defer release()
// Drain client reads so a disconnect cancels the stream. The client sends
// nothing meaningful; any read error means the socket is gone.
go func() {
for {
if _, _, err := conn.ReadMessage(); err != nil {
cancel()
return
}
}
}()
ticker := time.NewTicker(buildStreamKeepalive)
defer ticker.Stop()
for {
select {
case <-streamCtx.Done():
return
case <-ticker.C:
ws.writeJSON(map[string]string{"type": "ping"})
case ev, ok := <-events:
if !ok {
return
}
// Skip step events already covered by the history replay.
if ev.Type != "build-status" && ev.Step > 0 && ev.Step <= lastStep {
continue
}
ws.writeJSON(ev)
if ev.Type == "build-status" && service.IsTerminalBuildStatus(ev.Status) {
return
}
}
}
}
// replayBuildHistory synthesizes step-start/output/step-end events from the
// build's persisted log entries and writes them to the WebSocket. It returns
// the highest step number replayed.
func replayBuildHistory(ws *wsWriter, build db.TemplateBuild) int {
if len(build.Logs) == 0 {
return 0
}
var entries []recipe.BuildLogEntry
if err := json.Unmarshal(build.Logs, &entries); err != nil {
slog.Warn("build stream: bad log JSON", "build_id", id.FormatBuildID(build.ID), "error", err)
return 0
}
lastStep := 0
for _, e := range entries {
ws.writeJSON(service.BuildStreamEvent{Type: "step-start", Step: e.Step, Phase: e.Phase, Cmd: e.Cmd})
if out := e.Stdout + e.Stderr; out != "" {
ws.writeJSON(service.BuildStreamEvent{
Type: "output",
Step: e.Step,
Data: base64.StdEncoding.EncodeToString([]byte(out)),
})
}
ws.writeJSON(service.BuildStreamEvent{
Type: "step-end", Step: e.Step, Phase: e.Phase, Cmd: e.Cmd,
Exit: e.Exit, Ok: e.Ok, ElapsedMs: e.Elapsed,
})
if e.Step > lastStep {
lastStep = e.Step
}
}
return lastStep
}

View File

@ -9,7 +9,6 @@ import (
"strings"
"time"
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
"git.omukk.dev/wrenn/wrenn/internal/layout"
@ -20,7 +19,6 @@ import (
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
"git.omukk.dev/wrenn/wrenn/pkg/service"
"git.omukk.dev/wrenn/wrenn/pkg/validate"
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
)
type buildHandler struct {
@ -42,6 +40,7 @@ type createBuildRequest struct {
VCPUs int32 `json:"vcpus"`
MemoryMB int32 `json:"memory_mb"`
SkipPrePost bool `json:"skip_pre_post"`
RunAsRoot bool `json:"run_as_root"`
}
type buildResponse struct {
@ -181,6 +180,7 @@ func (h *buildHandler) Create(w http.ResponseWriter, r *http.Request) {
VCPUs: req.VCPUs,
MemoryMB: req.MemoryMB,
SkipPrePost: req.SkipPrePost,
RunAsRoot: req.RunAsRoot,
Archive: archive,
ArchiveName: archiveName,
})
@ -238,6 +238,9 @@ func (h *buildHandler) ListTemplates(w http.ResponseWriter, r *http.Request) {
return
}
// Resolve actual on-disk sizes for templates with unknown size.
templates = resolveTemplateSizes(r.Context(), h.db, h.pool, templates)
type templateResponse struct {
Name string `json:"name"`
Type string `json:"type"`
@ -246,6 +249,7 @@ func (h *buildHandler) ListTemplates(w http.ResponseWriter, r *http.Request) {
SizeBytes int64 `json:"size_bytes"`
TeamID string `json:"team_id"`
CreatedAt string `json:"created_at"`
Protected bool `json:"protected"`
}
resp := make([]templateResponse, len(templates))
@ -257,6 +261,7 @@ func (h *buildHandler) ListTemplates(w http.ResponseWriter, r *http.Request) {
MemoryMB: t.MemoryMb,
SizeBytes: t.SizeBytes,
TeamID: id.FormatTeamID(t.TeamID),
Protected: layout.IsSystemTemplate(t.TeamID, t.ID),
}
if t.CreatedAt.Valid {
resp[i].CreatedAt = t.CreatedAt.Time.Format(time.RFC3339)
@ -280,29 +285,17 @@ func (h *buildHandler) DeleteTemplate(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusNotFound, "not_found", "template not found")
return
}
if layout.IsMinimal(tmpl.TeamID, tmpl.ID) {
writeError(w, http.StatusForbidden, "forbidden", "the minimal template cannot be deleted")
if layout.IsSystemTemplate(tmpl.TeamID, tmpl.ID) {
writeError(w, http.StatusForbidden, "forbidden", "system base templates cannot be deleted")
return
}
// Broadcast delete to all online hosts.
hosts, _ := h.db.ListActiveHosts(ctx)
for _, host := range hosts {
if host.Status != "online" {
continue
}
agent, err := h.pool.GetForHost(host)
if err != nil {
continue
}
if _, err := agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{
TeamId: formatUUIDForRPC(tmpl.TeamID),
TemplateId: formatUUIDForRPC(tmpl.ID),
})); err != nil {
if connect.CodeOf(err) != connect.CodeNotFound {
slog.Warn("admin: failed to delete template on host", "host_id", id.FormatHostID(host.ID), "name", name, "error", err)
}
}
// Remove the files from every host before dropping the DB record, so a
// failure leaves the template intact and retryable rather than orphaned.
if err := deleteSnapshotEverywhere(ctx, h.db, h.pool, tmpl.TeamID, tmpl.ID); err != nil {
writeError(w, http.StatusConflict, "delete_failed",
"could not remove template files from all hosts: "+err.Error())
return
}
if err := h.db.DeleteTemplate(ctx, tmpl.ID); err != nil {

View File

@ -3,14 +3,11 @@ package api
import (
"encoding/base64"
"encoding/json"
"log/slog"
"net/http"
"time"
"unicode/utf8"
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
@ -58,23 +55,11 @@ type backgroundExecResponse struct {
// Exec handles POST /v1/capsules/{id}/exec.
func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
sb, sandboxID, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
if !ok {
return
}
@ -116,15 +101,7 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
return
}
if err := h.db.UpdateLastActive(ctx, db.UpdateLastActiveParams{
ID: sandboxID,
LastActiveAt: pgtype.Timestamptz{
Time: time.Now(),
Valid: true,
},
}); err != nil {
slog.Warn("failed to update last_active_at", "id", sandboxIDStr, "error", err)
}
updateLastActive(h.db, sandboxID, sandboxIDStr)
writeJSON(w, http.StatusAccepted, backgroundExecResponse{
SandboxID: sandboxIDStr,
@ -142,6 +119,8 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
Cmd: req.Cmd,
Args: req.Args,
TimeoutSec: req.TimeoutSec,
Envs: req.Envs,
Cwd: req.Cwd,
}))
if err != nil {
status, code, msg := agentErrToHTTP(err)
@ -151,41 +130,24 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
duration := time.Since(start)
// Update last active.
if err := h.db.UpdateLastActive(ctx, db.UpdateLastActiveParams{
ID: sandboxID,
LastActiveAt: pgtype.Timestamptz{
Time: time.Now(),
Valid: true,
},
}); err != nil {
slog.Warn("failed to update last_active_at", "id", sandboxIDStr, "error", err)
}
updateLastActive(h.db, sandboxID, sandboxIDStr)
// Use base64 encoding if output contains non-UTF-8 bytes.
stdout := resp.Msg.Stdout
stderr := resp.Msg.Stderr
encoding := "utf-8"
encoding := "utf-8"
stdoutStr, stderrStr := string(stdout), string(stderr)
if !utf8.Valid(stdout) || !utf8.Valid(stderr) {
encoding = "base64"
writeJSON(w, http.StatusOK, execResponse{
SandboxID: sandboxIDStr,
Cmd: req.Cmd,
Stdout: base64.StdEncoding.EncodeToString(stdout),
Stderr: base64.StdEncoding.EncodeToString(stderr),
ExitCode: resp.Msg.ExitCode,
DurationMs: duration.Milliseconds(),
Encoding: encoding,
})
return
stdoutStr = base64.StdEncoding.EncodeToString(stdout)
stderrStr = base64.StdEncoding.EncodeToString(stderr)
}
writeJSON(w, http.StatusOK, execResponse{
SandboxID: sandboxIDStr,
Cmd: req.Cmd,
Stdout: string(stdout),
Stderr: string(stderr),
Stdout: stdoutStr,
Stderr: stderrStr,
ExitCode: resp.Msg.ExitCode,
DurationMs: duration.Milliseconds(),
Encoding: encoding,

View File

@ -5,7 +5,6 @@ import (
"encoding/json"
"log/slog"
"net/http"
"time"
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
@ -20,13 +19,12 @@ import (
)
type execStreamHandler struct {
db *db.Queries
pool *lifecycle.HostClientPool
jwtSecret []byte
db *db.Queries
pool *lifecycle.HostClientPool
}
func newExecStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool, jwtSecret []byte) *execStreamHandler {
return &execStreamHandler{db: db, pool: pool, jwtSecret: jwtSecret}
func newExecStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool) *execStreamHandler {
return &execStreamHandler{db: db, pool: pool}
}
var upgrader = websocket.Upgrader{
@ -59,37 +57,9 @@ func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
return
}
// Authenticate: use context from middleware (API key) or WS first message (JWT).
ac, hasAuth := auth.FromContext(ctx)
if !hasAuth {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
slog.Error("websocket upgrade failed", "error", err)
return
}
defer conn.Close()
var wsAC auth.AuthContext
var authErr error
if isAdminWSRoute(ctx) {
wsAC, authErr = wsAuthenticateAdmin(ctx, conn, h.jwtSecret, h.db)
} else {
wsAC, authErr = wsAuthenticate(ctx, conn, h.jwtSecret, h.db)
}
if authErr != nil {
sendWSError(conn, "authentication failed")
return
}
ac = wsAC
h.runExecStream(ctx, conn, ac, sandboxID, sandboxIDStr)
return
}
conn, err := upgrader.Upgrade(w, r, nil)
conn, ac, err := upgradeAndAuthenticate(w, r)
if err != nil {
slog.Error("websocket upgrade failed", "error", err)
slog.Error("websocket upgrade/auth failed", "error", err)
return
}
defer conn.Close()
@ -186,18 +156,7 @@ func (h *execStreamHandler) runExecStream(ctx context.Context, conn *websocket.C
}
}
// Update last active using a fresh context (the request context may be cancelled).
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer updateCancel()
if err := h.db.UpdateLastActive(updateCtx, db.UpdateLastActiveParams{
ID: sandboxID,
LastActiveAt: pgtype.Timestamptz{
Time: time.Now(),
Valid: true,
},
}); err != nil {
slog.Warn("failed to update last active after stream exec", "sandbox_id", sandboxIDStr, "error", err)
}
updateLastActive(h.db, sandboxID, sandboxIDStr)
}
func sendWSError(conn *websocket.Conn, msg string) {

View File

@ -7,11 +7,9 @@ import (
"net/http"
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
)
@ -30,23 +28,11 @@ func newFilesHandler(db *db.Queries, pool *lifecycle.HostClientPool) *filesHandl
// - "path" text field: absolute destination path inside the sandbox
// - "file" file field: binary content to write
func (h *filesHandler) Upload(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
if !ok {
return
}
@ -108,23 +94,11 @@ type readFileRequest struct {
// Download handles POST /v1/capsules/{id}/files/read.
// Accepts JSON body with path, returns raw file content with Content-Disposition.
func (h *filesHandler) Download(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
if !ok {
return
}

View File

@ -8,11 +8,9 @@ import (
"net/http"
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
)
@ -30,23 +28,11 @@ func newFilesStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool) *file
// Expects multipart/form-data with "path" text field and "file" file field.
// Streams file content directly from the request body to the host agent without buffering.
func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
if !ok {
return
}
@ -103,6 +89,12 @@ func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request
// Open client-streaming RPC to host agent.
stream := agent.WriteFileStream(ctx)
var streamClosed bool
defer func() {
if !streamClosed {
_, _ = stream.CloseAndReceive()
}
}()
// Send metadata first.
if err := stream.Send(&pb.WriteFileStreamRequest{
@ -141,6 +133,7 @@ func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request
}
// Close and receive response.
streamClosed = true
if _, err := stream.CloseAndReceive(); err != nil {
status, code, msg := agentErrToHTTP(err)
writeError(w, status, code, msg)
@ -153,23 +146,11 @@ func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request
// StreamDownload handles POST /v1/capsules/{id}/files/stream/read.
// Accepts JSON body with path, streams file content back without buffering.
func (h *filesStreamHandler) StreamDownload(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
if !ok {
return
}

View File

@ -4,11 +4,9 @@ import (
"net/http"
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
)
@ -58,23 +56,11 @@ type removeRequest struct {
// ListDir handles POST /v1/capsules/{id}/files/list.
func (h *fsHandler) ListDir(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
if !ok {
return
}
@ -115,23 +101,11 @@ func (h *fsHandler) ListDir(w http.ResponseWriter, r *http.Request) {
// MakeDir handles POST /v1/capsules/{id}/files/mkdir.
func (h *fsHandler) MakeDir(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
if !ok {
return
}
@ -166,23 +140,11 @@ func (h *fsHandler) MakeDir(w http.ResponseWriter, r *http.Request) {
// Remove handles POST /v1/capsules/{id}/files/remove.
func (h *fsHandler) Remove(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
if !ok {
return
}

View File

@ -1,6 +1,7 @@
package api
import (
"context"
"errors"
"log/slog"
"net/http"
@ -21,10 +22,11 @@ type hostHandler struct {
svc *service.HostService
queries *db.Queries
audit *audit.AuditLogger
monitor *HostMonitor
}
func newHostHandler(svc *service.HostService, queries *db.Queries, al *audit.AuditLogger) *hostHandler {
return &hostHandler{svc: svc, queries: queries, audit: al}
func newHostHandler(svc *service.HostService, queries *db.Queries, al *audit.AuditLogger, monitor *HostMonitor) *hostHandler {
return &hostHandler{svc: svc, queries: queries, audit: al, monitor: monitor}
}
// Request/response types.
@ -98,6 +100,11 @@ type hostResponse struct {
CreatedBy string `json:"created_by"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
RunningVcpus int32 `json:"running_vcpus"`
RunningMemoryMb int32 `json:"running_memory_mb"`
RunningDiskMb int32 `json:"running_disk_mb"`
PausedMemoryMb int32 `json:"paused_memory_mb"`
PausedDiskMb int32 `json:"paused_disk_mb"`
}
func hostToResponse(h db.Host) hostResponse {
@ -136,12 +143,37 @@ func hostToResponse(h db.Host) hostResponse {
s := h.LastHeartbeatAt.Time.Format(time.RFC3339)
resp.LastHeartbeatAt = &s
}
// created_at and updated_at are NOT NULL DEFAULT NOW(), always valid.
resp.CreatedAt = h.CreatedAt.Time.Format(time.RFC3339)
resp.UpdatedAt = h.UpdatedAt.Time.Format(time.RFC3339)
return resp
}
func hostToResponseWithLoad(h db.ListHostsByTeamRow) hostResponse {
resp := hostToResponse(db.Host{
ID: h.ID,
Type: h.Type,
TeamID: h.TeamID,
Provider: h.Provider,
AvailabilityZone: h.AvailabilityZone,
Arch: h.Arch,
CpuCores: h.CpuCores,
MemoryMb: h.MemoryMb,
DiskGb: h.DiskGb,
Address: h.Address,
Status: h.Status,
LastHeartbeatAt: h.LastHeartbeatAt,
CreatedBy: h.CreatedBy,
CreatedAt: h.CreatedAt,
UpdatedAt: h.UpdatedAt,
})
resp.RunningVcpus = h.RunningVcpus
resp.RunningMemoryMb = h.RunningMemoryMb
resp.RunningDiskMb = h.RunningDiskMb
resp.PausedMemoryMb = h.PausedMemoryMb
resp.PausedDiskMb = h.PausedDiskMb
return resp
}
// isAdmin fetches the user record and returns whether they are an admin.
func (h *hostHandler) isAdmin(r *http.Request, userID pgtype.UUID) bool {
user, err := h.queries.GetUserByID(r.Context(), userID)
@ -233,7 +265,7 @@ func (h *hostHandler) List(w http.ResponseWriter, r *http.Request) {
resp := make([]hostResponse, len(hosts))
for i, host := range hosts {
resp[i] = hostToResponse(host)
resp[i] = hostToResponseWithLoad(host)
if host.TeamID.Valid {
key := id.FormatTeamID(host.TeamID)
if name, ok := teamNames[key]; ok {
@ -335,6 +367,54 @@ func (h *hostHandler) Delete(w http.ResponseWriter, r *http.Request) {
writeError(w, status, code, msg)
}
// AdminList handles GET /v1/admin/hosts.
// Returns all hosts with per-host resource consumption. Admin-only.
func (h *hostHandler) AdminList(w http.ResponseWriter, r *http.Request) {
hosts, err := h.svc.ListAdmin(r.Context())
if err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to list hosts")
return
}
// Collect unique team IDs to fetch team names.
var teamNames map[string]string
seen := make(map[string]struct{})
for _, host := range hosts {
if host.TeamID.Valid {
key := id.FormatTeamID(host.TeamID)
seen[key] = struct{}{}
}
}
if len(seen) > 0 {
teamNames = make(map[string]string, len(seen))
for _, host := range hosts {
if !host.TeamID.Valid {
continue
}
key := id.FormatTeamID(host.TeamID)
if _, ok := teamNames[key]; ok {
continue
}
if team, err := h.queries.GetTeam(r.Context(), host.TeamID); err == nil {
teamNames[key] = team.Name
}
}
}
resp := make([]hostResponse, len(hosts))
for i, host := range hosts {
resp[i] = hostToResponseWithLoad(db.ListHostsByTeamRow(host))
if host.TeamID.Valid {
key := id.FormatTeamID(host.TeamID)
if name, ok := teamNames[key]; ok {
resp[i].TeamName = &name
}
}
}
writeJSON(w, http.StatusOK, resp)
}
// RegenerateToken handles POST /v1/hosts/{id}/token.
func (h *hostHandler) RegenerateToken(w http.ResponseWriter, r *http.Request) {
hostIDStr := chi.URLParam(r, "id")
@ -426,9 +506,12 @@ func (h *hostHandler) Heartbeat(w http.ResponseWriter, r *http.Request) {
return
}
// Log marked_up if the host just recovered from unreachable.
// If the host just recovered from unreachable, log it and trigger immediate
// reconciliation so "missing" sandboxes are resolved without waiting for the
// next monitor tick.
if prevHost.Status == "unreachable" {
h.audit.LogHostMarkedUp(r.Context(), prevHost.TeamID, hc.HostID)
go h.monitor.ReconcileHost(context.Background(), hc.HostID)
}
w.WriteHeader(http.StatusNoContent)

View File

@ -2,9 +2,6 @@ package api
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"log/slog"
@ -20,6 +17,8 @@ import (
"git.omukk.dev/wrenn/wrenn/internal/email"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/auth/oauth"
"git.omukk.dev/wrenn/wrenn/pkg/auth/session"
"git.omukk.dev/wrenn/wrenn/pkg/cpextension"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/service"
@ -34,42 +33,62 @@ type meHandler struct {
db *db.Queries
pool *pgxpool.Pool
rdb *redis.Client
jwtSecret []byte
hmacKey []byte // HMAC key for OAuth state and link cookies (was jwtSecret)
sessions *session.Service
mailer email.Mailer
oauthRegistry *oauth.Registry
redirectURL string
teamSvc *service.TeamService
authHooks []cpextension.AuthHook
}
func newMeHandler(
db *db.Queries,
pool *pgxpool.Pool,
rdb *redis.Client,
jwtSecret []byte,
hmacKey []byte,
sessions *session.Service,
mailer email.Mailer,
registry *oauth.Registry,
redirectURL string,
teamSvc *service.TeamService,
hooks []cpextension.AuthHook,
) *meHandler {
return &meHandler{
db: db,
pool: pool,
rdb: rdb,
jwtSecret: jwtSecret,
hmacKey: hmacKey,
sessions: sessions,
mailer: mailer,
oauthRegistry: registry,
redirectURL: strings.TrimRight(redirectURL, "/"),
teamSvc: teamSvc,
authHooks: hooks,
}
}
type meResponse struct {
UserID string `json:"user_id"`
TeamID string `json:"team_id"`
Name string `json:"name"`
Email string `json:"email"`
Role string `json:"role"`
IsAdmin bool `json:"is_admin"`
HasPassword bool `json:"has_password"`
Providers []string `json:"providers"`
}
type sessionRow struct {
ID string `json:"id"`
UserAgent string `json:"user_agent"`
IPAddress string `json:"ip_address"`
CreatedAt string `json:"created_at"`
LastSeenAt string `json:"last_seen_at"`
ExpiresAt string `json:"expires_at"`
Current bool `json:"current"`
}
type updateNameRequest struct {
Name string `json:"name"`
}
@ -116,14 +135,19 @@ func (h *meHandler) GetMe(w http.ResponseWriter, r *http.Request) {
}
writeJSON(w, http.StatusOK, meResponse{
UserID: id.FormatUserID(ac.UserID),
TeamID: id.FormatTeamID(ac.TeamID),
Name: user.Name,
Email: user.Email,
Role: ac.Role,
IsAdmin: user.IsAdmin,
HasPassword: user.PasswordHash.Valid,
Providers: providerNames,
})
}
// UpdateName handles PATCH /v1/me — updates the user's name and re-issues a JWT.
// UpdateName handles PATCH /v1/me — updates the user's name and refreshes
// any cached session blobs so the new name shows up on next request.
func (h *meHandler) UpdateName(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
ctx := r.Context()
@ -148,31 +172,11 @@ func (h *meHandler) UpdateName(w http.ResponseWriter, r *http.Request) {
return
}
user, err := h.db.GetUserByID(ctx, ac.UserID)
if err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to get user")
return
if err := h.sessions.InvalidateCacheForUser(ctx, ac.UserID); err != nil {
slog.Warn("update name: invalidate session cache failed", "error", err)
}
team, role, err := loginTeam(ctx, h.db, ac.UserID)
if err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to get team")
return
}
token, err := auth.SignJWT(h.jwtSecret, ac.UserID, team.ID, user.Email, req.Name, role, user.IsAdmin)
if err != nil {
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
return
}
writeJSON(w, http.StatusOK, authResponse{
Token: token,
UserID: id.FormatUserID(ac.UserID),
TeamID: id.FormatTeamID(team.ID),
Email: user.Email,
Name: req.Name,
})
w.WriteHeader(http.StatusNoContent)
}
// ChangePassword handles POST /v1/me/password.
@ -235,6 +239,14 @@ func (h *meHandler) ChangePassword(w http.ResponseWriter, r *http.Request) {
return
}
// Revoke every session for this user — including the caller's — so a new
// password resets all device access. Clear cookies on the response so the
// caller is signed out immediately.
if err := h.sessions.RevokeAllForUser(ctx, ac.UserID); err != nil {
slog.Warn("change password: revoke sessions failed", "error", err)
}
clearSessionCookies(w, isSecure(r))
isAdding := !user.PasswordHash.Valid
go func() {
sendCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
@ -285,8 +297,8 @@ func (h *meHandler) RequestPasswordReset(w http.ResponseWriter, r *http.Request)
return
}
rawToken := generateResetToken()
tokenHash := hashResetToken(rawToken)
rawToken := generateOpaqueToken()
tokenHash := hashOpaqueToken(rawToken)
redisKey := passwordResetKeyPrefix + tokenHash
if err := h.rdb.Set(ctx, redisKey, id.FormatUserID(user.ID), passwordResetTTL).Err(); err != nil {
@ -330,7 +342,7 @@ func (h *meHandler) ConfirmPasswordReset(w http.ResponseWriter, r *http.Request)
}
ctx := r.Context()
tokenHash := hashResetToken(req.Token)
tokenHash := hashOpaqueToken(req.Token)
redisKey := passwordResetKeyPrefix + tokenHash
// GetDel atomically retrieves and removes the token in a single round-trip,
@ -371,6 +383,11 @@ func (h *meHandler) ConfirmPasswordReset(w http.ResponseWriter, r *http.Request)
return
}
// Reset invalidates every active session for the user.
if err := h.sessions.RevokeAllForUser(ctx, userID); err != nil {
slog.Warn("confirm password reset: revoke sessions failed", "error", err)
}
go func() {
sendCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
@ -404,7 +421,7 @@ func (h *meHandler) ConnectProvider(w http.ResponseWriter, r *http.Request) {
return
}
mac := computeHMAC(h.jwtSecret, state+":"+"login")
mac := computeHMAC(h.hmacKey, state+":"+"login")
http.SetCookie(w, &http.Cookie{
Name: "oauth_state",
Value: state + ":" + mac + ":" + "login",
@ -416,7 +433,7 @@ func (h *meHandler) ConnectProvider(w http.ResponseWriter, r *http.Request) {
})
userIDStr := id.FormatUserID(ac.UserID)
linkMac := computeHMAC(h.jwtSecret, userIDStr)
linkMac := computeHMAC(h.hmacKey, userIDStr)
http.SetCookie(w, &http.Cookie{
Name: "oauth_link_user_id",
Value: userIDStr + ":" + linkMac,
@ -552,6 +569,14 @@ func (h *meHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
return
}
// Revoke every active session and clear the caller's cookies.
if err := h.sessions.RevokeAllForUser(ctx, ac.UserID); err != nil {
slog.Warn("delete account: revoke sessions failed", "error", err)
}
clearSessionCookies(w, isSecure(r))
fireOnSoftDelete(ctx, h.authHooks, ac.UserID)
slog.Info("account soft-deleted", "user_id", id.FormatUserID(ac.UserID), "email", user.Email)
go func() {
@ -569,17 +594,4 @@ func (h *meHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}
// --- helpers ---
func generateResetToken() string {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
panic(fmt.Sprintf("crypto/rand failed: %v", err))
}
return hex.EncodeToString(b)
}
func hashResetToken(raw string) string {
h := sha256.Sum256([]byte(raw))
return hex.EncodeToString(h[:])
}
// (token helpers live in handlers_auth.go)

View File

@ -14,10 +14,12 @@ import (
"github.com/go-chi/chi/v5"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxpool"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/auth/oauth"
"git.omukk.dev/wrenn/wrenn/pkg/auth/session"
"git.omukk.dev/wrenn/wrenn/pkg/cpextension"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
@ -25,18 +27,22 @@ import (
type oauthHandler struct {
db *db.Queries
pool *pgxpool.Pool
jwtSecret []byte
hmacKey []byte // HMAC key for OAuth state and link cookies
sessions *session.Service
registry *oauth.Registry
redirectURL string // base frontend URL (e.g. "https://app.wrenn.dev")
authHooks []cpextension.AuthHook
}
func newOAuthHandler(db *db.Queries, pool *pgxpool.Pool, jwtSecret []byte, registry *oauth.Registry, redirectURL string) *oauthHandler {
func newOAuthHandler(db *db.Queries, pool *pgxpool.Pool, hmacKey []byte, sessions *session.Service, registry *oauth.Registry, redirectURL string, hooks []cpextension.AuthHook) *oauthHandler {
return &oauthHandler{
db: db,
pool: pool,
jwtSecret: jwtSecret,
hmacKey: hmacKey,
sessions: sessions,
registry: registry,
redirectURL: strings.TrimRight(redirectURL, "/"),
authHooks: hooks,
}
}
@ -61,7 +67,7 @@ func (h *oauthHandler) Redirect(w http.ResponseWriter, r *http.Request) {
intent = "login"
}
mac := computeHMAC(h.jwtSecret, state+":"+intent)
mac := computeHMAC(h.hmacKey, state+":"+intent)
cookieVal := state + ":" + mac + ":" + intent
http.SetCookie(w, &http.Cookie{
@ -121,7 +127,7 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
if len(parts) == 3 && parts[2] == "signup" {
intent = "signup"
}
if !hmac.Equal([]byte(computeHMAC(h.jwtSecret, nonce+":"+intent)), []byte(expectedMAC)) {
if !hmac.Equal([]byte(computeHMAC(h.hmacKey, nonce+":"+intent)), []byte(expectedMAC)) {
redirectWithError(w, r, redirectBase, "invalid_state")
return
}
@ -164,7 +170,7 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
// Verify the HMAC to prevent cookie forgery.
linkParts := strings.SplitN(linkCookie.Value, ":", 2)
if len(linkParts) != 2 || !hmac.Equal([]byte(computeHMAC(h.jwtSecret, linkParts[0])), []byte(linkParts[1])) {
if len(linkParts) != 2 || !hmac.Equal([]byte(computeHMAC(h.hmacKey, linkParts[0])), []byte(linkParts[1])) {
slog.Warn("oauth link: invalid or tampered link cookie")
http.Redirect(w, r, settingsBase+"?connect_error=invalid_state", http.StatusFound)
return
@ -244,13 +250,12 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
return
}
isAdmin := user.IsAdmin || isFirstUser
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, isAdmin)
if err != nil {
slog.Error("oauth login: failed to sign jwt", "error", err)
if err := h.issueSessionAndRedirect(w, r, user.ID, team.ID, user.Email, user.Name, role, isAdmin, redirectBase); err != nil {
slog.Error("oauth login: failed to issue session", "error", err)
redirectWithError(w, r, redirectBase, "internal_error")
return
}
redirectWithToken(w, r, redirectBase, token, id.FormatUserID(user.ID), id.FormatTeamID(team.ID), user.Email, user.Name)
fireOnLogin(ctx, h.authHooks, user.ID)
return
}
if !errors.Is(err, pgx.ErrNoRows) {
@ -374,10 +379,10 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
return
}
token, err := auth.SignJWT(h.jwtSecret, userID, teamID, email, profile.Name, "owner", isFirstUser)
if err != nil {
slog.Error("oauth: failed to sign jwt", "error", err)
redirectWithError(w, r, redirectBase, "internal_error")
// Fire OnSignup before session issuance — billing must succeed first.
if hookErr := fireOnSignup(ctx, h.authHooks, userID, teamID, email); hookErr != nil {
slog.Error("oauth signup: OnSignup hook failed", "user_id", id.FormatUserID(userID), "error", hookErr)
redirectWithError(w, r, redirectBase, "signup_hook_failed")
return
}
@ -385,14 +390,19 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
http.SetCookie(w, &http.Cookie{
Name: "wrenn_oauth_new_signup",
Value: "1",
Path: "/auth/",
Path: "/",
MaxAge: 60,
HttpOnly: false,
SameSite: http.SameSiteLaxMode,
Secure: isSecure(r),
})
redirectWithToken(w, r, redirectBase, token, id.FormatUserID(userID), id.FormatTeamID(teamID), email, profile.Name)
if err := h.issueSessionAndRedirect(w, r, userID, teamID, email, profile.Name, "owner", isFirstUser, redirectBase); err != nil {
slog.Error("oauth: failed to issue session", "error", err)
redirectWithError(w, r, redirectBase, "internal_error")
return
}
fireOnLogin(ctx, h.authHooks, userID)
}
// retryAsLogin handles the race where a concurrent request already created the user.
@ -431,33 +441,35 @@ func (h *oauthHandler) retryAsLogin(w http.ResponseWriter, r *http.Request, prov
return
}
isAdmin := user.IsAdmin || isFirstUser
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, isAdmin)
if err != nil {
slog.Error("oauth: retry login: failed to sign jwt", "error", err)
if err := h.issueSessionAndRedirect(w, r, user.ID, team.ID, user.Email, user.Name, role, isAdmin, redirectBase); err != nil {
slog.Error("oauth: retry login: failed to issue session", "error", err)
redirectWithError(w, r, redirectBase, "internal_error")
return
}
redirectWithToken(w, r, redirectBase, token, id.FormatUserID(user.ID), id.FormatTeamID(team.ID), user.Email, user.Name)
fireOnLogin(ctx, h.authHooks, user.ID)
}
func redirectWithToken(w http.ResponseWriter, r *http.Request, base, token, userID, teamID, email, name string) {
// Set auth data as short-lived cookies instead of URL query parameters.
// This prevents token leakage via server access logs, Referer headers, and browser history.
for _, c := range []http.Cookie{
{Name: "wrenn_oauth_token", Value: token},
{Name: "wrenn_oauth_user_id", Value: userID},
{Name: "wrenn_oauth_team_id", Value: teamID},
{Name: "wrenn_oauth_email", Value: email},
{Name: "wrenn_oauth_name", Value: name},
} {
c.Path = "/auth/"
c.MaxAge = 60
c.HttpOnly = false // frontend JS must read these
c.SameSite = http.SameSiteLaxMode
c.Secure = isSecure(r)
http.SetCookie(w, &c)
// issueSessionAndRedirect creates a session, sets the session and CSRF
// cookies, and redirects to the frontend dashboard. The redirectBase param
// is the OAuth callback URL; we ignore it after success and send the user to
// /dashboard directly (callback page will probe /v1/me to hydrate state).
func (h *oauthHandler) issueSessionAndRedirect(
w http.ResponseWriter,
r *http.Request,
userID, teamID pgtype.UUID,
email, name, role string,
isAdmin bool,
redirectBase string,
) error {
sess, err := h.sessions.Create(r.Context(), userID, teamID, email, name, role, isAdmin, r.UserAgent(), clientIP(r))
if err != nil {
return err
}
http.Redirect(w, r, base, http.StatusFound)
setSessionCookies(w, sess.RawSID, sess.CSRFToken, isSecure(r))
// Send the user to the callback page so the SPA can probe /v1/me and
// trigger any post-OAuth UX (e.g. the new-signup name confirmation).
http.Redirect(w, r, redirectBase, http.StatusFound)
return nil
}
func redirectWithError(w http.ResponseWriter, r *http.Request, base, code string) {
@ -477,7 +489,3 @@ func computeHMAC(key []byte, data string) string {
h.Write([]byte(data))
return hex.EncodeToString(h.Sum(nil))
}
func isSecure(r *http.Request) bool {
return r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https"
}

View File

@ -5,7 +5,6 @@ import (
"log/slog"
"net/http"
"strconv"
"time"
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
@ -20,13 +19,12 @@ import (
)
type processHandler struct {
db *db.Queries
pool *lifecycle.HostClientPool
jwtSecret []byte
db *db.Queries
pool *lifecycle.HostClientPool
}
func newProcessHandler(db *db.Queries, pool *lifecycle.HostClientPool, jwtSecret []byte) *processHandler {
return &processHandler{db: db, pool: pool, jwtSecret: jwtSecret}
func newProcessHandler(db *db.Queries, pool *lifecycle.HostClientPool) *processHandler {
return &processHandler{db: db, pool: pool}
}
// processResponse is a single entry in the process list.
@ -44,23 +42,11 @@ type processListResponse struct {
// ListProcesses handles GET /v1/capsules/{id}/processes.
func (h *processHandler) ListProcesses(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
if !ok {
return
}
@ -95,24 +81,12 @@ func (h *processHandler) ListProcesses(w http.ResponseWriter, r *http.Request) {
// KillProcess handles DELETE /v1/capsules/{id}/processes/{selector}.
// The selector can be a numeric PID or a string tag.
func (h *processHandler) KillProcess(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
selectorStr := chi.URLParam(r, "selector")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
if !ok {
return
}
@ -146,14 +120,6 @@ func (h *processHandler) KillProcess(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}
// wsProcessOut is the JSON message sent to the WebSocket client.
type wsProcessOut struct {
Type string `json:"type"` // "start", "stdout", "stderr", "exit", "error"
PID uint32 `json:"pid,omitempty"` // only for "start"
Data string `json:"data,omitempty"` // only for "stdout", "stderr", "error"
ExitCode *int32 `json:"exit_code,omitempty"` // only for "exit"
}
// ConnectProcess handles WS /v1/capsules/{id}/processes/{selector}/stream.
func (h *processHandler) ConnectProcess(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
@ -166,37 +132,9 @@ func (h *processHandler) ConnectProcess(w http.ResponseWriter, r *http.Request)
return
}
// Authenticate: use context from middleware (API key) or WS first message (JWT).
ac, hasAuth := auth.FromContext(ctx)
if !hasAuth {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
slog.Error("process stream websocket upgrade failed", "error", err)
return
}
defer conn.Close()
var wsAC auth.AuthContext
var authErr error
if isAdminWSRoute(ctx) {
wsAC, authErr = wsAuthenticateAdmin(ctx, conn, h.jwtSecret, h.db)
} else {
wsAC, authErr = wsAuthenticate(ctx, conn, h.jwtSecret, h.db)
}
if authErr != nil {
sendProcessWSError(conn, "authentication failed")
return
}
ac = wsAC
h.runConnectProcess(ctx, conn, ac, sandboxID, sandboxIDStr, selectorStr)
return
}
conn, err := upgrader.Upgrade(w, r, nil)
conn, ac, err := upgradeAndAuthenticate(w, r)
if err != nil {
slog.Error("process stream websocket upgrade failed", "error", err)
slog.Error("process stream websocket upgrade/auth failed", "error", err)
return
}
defer conn.Close()
@ -207,17 +145,17 @@ func (h *processHandler) ConnectProcess(w http.ResponseWriter, r *http.Request)
func (h *processHandler) runConnectProcess(ctx context.Context, conn *websocket.Conn, ac auth.AuthContext, sandboxID pgtype.UUID, sandboxIDStr, selectorStr string) {
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
sendProcessWSError(conn, "sandbox not found")
sendWSError(conn, "sandbox not found")
return
}
if sb.Status != "running" {
sendProcessWSError(conn, "sandbox is not running (status: "+sb.Status+")")
sendWSError(conn, "sandbox is not running (status: "+sb.Status+")")
return
}
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
sendProcessWSError(conn, "sandbox host is not reachable")
sendWSError(conn, "sandbox host is not reachable")
return
}
@ -236,7 +174,7 @@ func (h *processHandler) runConnectProcess(ctx context.Context, conn *websocket.
stream, err := agent.ConnectProcess(streamCtx, connect.NewRequest(connectReq))
if err != nil {
sendProcessWSError(conn, "failed to connect to process: "+err.Error())
sendWSError(conn, "failed to connect to process: "+err.Error())
return
}
defer stream.Close()
@ -257,42 +195,27 @@ func (h *processHandler) runConnectProcess(ctx context.Context, conn *websocket.
resp := stream.Msg()
switch ev := resp.Event.(type) {
case *pb.ConnectProcessResponse_Start:
writeWSJSON(conn, wsProcessOut{Type: "start", PID: ev.Start.Pid})
writeWSJSON(conn, wsOutMsg{Type: "start", PID: ev.Start.Pid})
case *pb.ConnectProcessResponse_Data:
switch o := ev.Data.Output.(type) {
case *pb.ExecStreamData_Stdout:
writeWSJSON(conn, wsProcessOut{Type: "stdout", Data: string(o.Stdout)})
writeWSJSON(conn, wsOutMsg{Type: "stdout", Data: string(o.Stdout)})
case *pb.ExecStreamData_Stderr:
writeWSJSON(conn, wsProcessOut{Type: "stderr", Data: string(o.Stderr)})
writeWSJSON(conn, wsOutMsg{Type: "stderr", Data: string(o.Stderr)})
}
case *pb.ConnectProcessResponse_End:
exitCode := ev.End.ExitCode
writeWSJSON(conn, wsProcessOut{Type: "exit", ExitCode: &exitCode})
writeWSJSON(conn, wsOutMsg{Type: "exit", ExitCode: &exitCode})
}
}
if err := stream.Err(); err != nil {
if streamCtx.Err() == nil {
sendProcessWSError(conn, err.Error())
sendWSError(conn, err.Error())
}
}
// Update last active using a fresh context.
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer updateCancel()
if err := h.db.UpdateLastActive(updateCtx, db.UpdateLastActiveParams{
ID: sandboxID,
LastActiveAt: pgtype.Timestamptz{
Time: time.Now(),
Valid: true,
},
}); err != nil {
slog.Warn("failed to update last active after process stream", "sandbox_id", sandboxIDStr, "error", err)
}
}
func sendProcessWSError(conn *websocket.Conn, msg string) {
writeWSJSON(conn, wsProcessOut{Type: "error", Data: msg})
updateLastActive(h.db, sandboxID, sandboxIDStr)
}

View File

@ -30,13 +30,12 @@ const (
)
type ptyHandler struct {
db *db.Queries
pool *lifecycle.HostClientPool
jwtSecret []byte
db *db.Queries
pool *lifecycle.HostClientPool
}
func newPtyHandler(db *db.Queries, pool *lifecycle.HostClientPool, jwtSecret []byte) *ptyHandler {
return &ptyHandler{db: db, pool: pool, jwtSecret: jwtSecret}
func newPtyHandler(db *db.Queries, pool *lifecycle.HostClientPool) *ptyHandler {
return &ptyHandler{db: db, pool: pool}
}
// --- WebSocket message types ---
@ -90,40 +89,9 @@ func (h *ptyHandler) PtySession(w http.ResponseWriter, r *http.Request) {
return
}
// API key auth is handled by middleware (sets context).
// For browser JWT auth, we authenticate after upgrade via first WS message.
ac, hasAuth := auth.FromContext(ctx)
if !hasAuth {
// No pre-upgrade auth — upgrade first, then authenticate via WS message.
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
slog.Error("pty websocket upgrade failed", "error", err)
return
}
defer conn.Close()
ws := &wsWriter{conn: conn}
var wsAC auth.AuthContext
if isAdminWSRoute(ctx) {
wsAC, err = wsAuthenticateAdmin(ctx, conn, h.jwtSecret, h.db)
} else {
wsAC, err = wsAuthenticate(ctx, conn, h.jwtSecret, h.db)
}
if err != nil {
ws.writeJSON(wsPtyOut{Type: "error", Data: "authentication failed", Fatal: true})
return
}
ac = wsAC
h.runPtySession(ctx, ws, conn, ac, sandboxID, sandboxIDStr)
return
}
conn, err := upgrader.Upgrade(w, r, nil)
conn, ac, err := upgradeAndAuthenticate(w, r)
if err != nil {
slog.Error("pty websocket upgrade failed", "error", err)
slog.Error("pty websocket upgrade/auth failed", "error", err)
return
}
defer conn.Close()
@ -168,18 +136,7 @@ func (h *ptyHandler) runPtySession(ctx context.Context, ws *wsWriter, conn *webs
ws.writeJSON(wsPtyOut{Type: "error", Data: "first message must be type 'start' or 'connect'", Fatal: true})
}
// Update last active using a fresh context.
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer updateCancel()
if err := h.db.UpdateLastActive(updateCtx, db.UpdateLastActiveParams{
ID: sandboxID,
LastActiveAt: pgtype.Timestamptz{
Time: time.Now(),
Valid: true,
},
}); err != nil {
slog.Warn("failed to update last active after pty session", "sandbox_id", sandboxIDStr, "error", err)
}
updateLastActive(h.db, sandboxID, sandboxIDStr)
}
func (h *ptyHandler) handleStart(

View File

@ -37,8 +37,8 @@ type sandboxResponse struct {
VCPUs int32 `json:"vcpus"`
MemoryMB int32 `json:"memory_mb"`
TimeoutSec int32 `json:"timeout_sec"`
GuestIP string `json:"guest_ip,omitempty"`
HostIP string `json:"host_ip,omitempty"`
DiskSizeMB int32 `json:"disk_size_mb"`
DiskUsedMB *int64 `json:"disk_used_mb,omitempty"`
CreatedAt string `json:"created_at"`
StartedAt *string `json:"started_at,omitempty"`
LastActiveAt *string `json:"last_active_at,omitempty"`
@ -54,8 +54,7 @@ func sandboxToResponse(sb db.Sandbox) sandboxResponse {
VCPUs: sb.Vcpus,
MemoryMB: sb.MemoryMb,
TimeoutSec: sb.TimeoutSec,
GuestIP: sb.GuestIp,
HostIP: sb.HostIp,
DiskSizeMB: sb.DiskSizeMb,
}
if len(sb.Metadata) > 0 {
var meta map[string]string
@ -101,14 +100,17 @@ func (h *sandboxHandler) Create(w http.ResponseWriter, r *http.Request) {
MemoryMB: req.MemoryMB,
TimeoutSec: req.TimeoutSec,
})
h.audit.LogSandboxCreate(r.Context(), ac, sb.ID, req.Template, err)
if err != nil {
if sb.ID.Valid {
h.audit.LogSandboxDestroySystem(r.Context(), ac.TeamID, sb.ID, "cleanup_after_create_error", nil)
}
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
h.audit.LogSandboxCreate(r.Context(), ac, sb.ID, sb.Template)
writeJSON(w, http.StatusCreated, sandboxToResponse(sb))
writeJSON(w, http.StatusAccepted, sandboxToResponse(sb))
}
// List handles GET /v1/capsules.
@ -145,7 +147,15 @@ func (h *sandboxHandler) Get(w http.ResponseWriter, r *http.Request) {
return
}
writeJSON(w, http.StatusOK, sandboxToResponse(sb))
resp := sandboxToResponse(sb)
diskBytes, err := h.svc.GetDiskUsage(r.Context(), sandboxID, ac.TeamID)
if err == nil {
diskUsedMB := diskBytes / (1024 * 1024)
resp.DiskUsedMB = &diskUsedMB
}
writeJSON(w, http.StatusOK, resp)
}
// Pause handles POST /v1/capsules/{id}/pause.
@ -160,14 +170,14 @@ func (h *sandboxHandler) Pause(w http.ResponseWriter, r *http.Request) {
}
sb, err := h.svc.Pause(r.Context(), sandboxID, ac.TeamID)
h.audit.LogSandboxPause(r.Context(), ac, sandboxID, err)
if err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
h.audit.LogSandboxPause(r.Context(), ac, sandboxID)
writeJSON(w, http.StatusOK, sandboxToResponse(sb))
writeJSON(w, http.StatusAccepted, sandboxToResponse(sb))
}
// Resume handles POST /v1/capsules/{id}/resume.
@ -182,14 +192,14 @@ func (h *sandboxHandler) Resume(w http.ResponseWriter, r *http.Request) {
}
sb, err := h.svc.Resume(r.Context(), sandboxID, ac.TeamID)
h.audit.LogSandboxResume(r.Context(), ac, sandboxID, err)
if err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
h.audit.LogSandboxResume(r.Context(), ac, sandboxID)
writeJSON(w, http.StatusOK, sandboxToResponse(sb))
writeJSON(w, http.StatusAccepted, sandboxToResponse(sb))
}
// Ping handles POST /v1/capsules/{id}/ping.
@ -223,12 +233,13 @@ func (h *sandboxHandler) Destroy(w http.ResponseWriter, r *http.Request) {
return
}
if err := h.svc.Destroy(r.Context(), sandboxID, ac.TeamID); err != nil {
err = h.svc.Destroy(r.Context(), sandboxID, ac.TeamID)
h.audit.LogSandboxDestroy(r.Context(), ac, sandboxID, err)
if err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
h.audit.LogSandboxDestroy(r.Context(), ac, sandboxID)
w.WriteHeader(http.StatusNoContent)
w.WriteHeader(http.StatusAccepted)
}

View File

@ -0,0 +1,169 @@
package api
import (
"context"
"encoding/json"
"log/slog"
"net/http"
"time"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/channels"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/events"
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
type sandboxEventHandler struct {
db *db.Queries
eventPub *channels.Publisher
}
func newSandboxEventHandler(queries *db.Queries, eventPub *channels.Publisher) *sandboxEventHandler {
return &sandboxEventHandler{db: queries, eventPub: eventPub}
}
type sandboxEventRequest struct {
Event string `json:"event"`
SandboxID string `json:"sandbox_id"`
HostID string `json:"host_id"`
HostIP string `json:"host_ip,omitempty"`
Error string `json:"error,omitempty"`
Timestamp int64 `json:"timestamp"`
}
// Handle receives lifecycle event callbacks from host agents, translates the
// raw host event into the canonical events.Event taxonomy, and publishes once
// to the unified Redis stream. The SandboxEventConsumer (independent
// consumer group) drives DB reconciliation; the channels dispatcher delivers
// to subscribed channels; the SSE relay mirrors via Pub/Sub.
func (h *sandboxEventHandler) Handle(w http.ResponseWriter, r *http.Request) {
var req sandboxEventRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
return
}
if req.Event == "" || req.SandboxID == "" || req.HostID == "" {
writeError(w, http.StatusBadRequest, "invalid_request", "event, sandbox_id, and host_id are required")
return
}
hc := auth.MustHostFromContext(r.Context())
callerHostID := id.FormatHostID(hc.HostID)
if callerHostID != req.HostID {
writeError(w, http.StatusForbidden, "forbidden", "host_id does not match authenticated host")
return
}
if req.Timestamp == 0 {
req.Timestamp = time.Now().Unix()
}
evt, ok := h.translate(r.Context(), req)
if !ok {
// Unknown event type — log and accept so the host agent doesn't retry.
slog.Warn("sandbox event callback: untranslatable event", "event", req.Event, "sandbox_id", req.SandboxID)
w.WriteHeader(http.StatusNoContent)
return
}
h.eventPub.Publish(r.Context(), evt)
w.WriteHeader(http.StatusNoContent)
}
// translate converts a raw host-agent callback into the canonical event.
// For failure events without an in-flight verb (e.g. sandbox.failed), the
// current DB status is consulted to pick the appropriate verb.
func (h *sandboxEventHandler) translate(ctx context.Context, req sandboxEventRequest) (events.Event, bool) {
sandboxUUID, parseErr := id.ParseSandboxID(req.SandboxID)
if parseErr != nil {
return events.Event{}, false
}
var teamID pgtype.UUID
if sb, dbErr := h.db.GetSandbox(ctx, sandboxUUID); dbErr == nil {
teamID = sb.TeamID
}
base := events.Event{
Timestamp: time.Unix(req.Timestamp, 0).UTC().Format(time.RFC3339),
TeamID: id.FormatTeamID(teamID),
Actor: events.SystemActor(),
Resource: events.Resource{ID: req.SandboxID, Type: "sandbox"},
}
switch req.Event {
case "sandbox.started":
meta := map[string]string{}
if req.HostIP != "" {
meta["host_ip"] = req.HostIP
}
meta["host_id"] = req.HostID
base.Event = events.CapsuleCreate
base.Outcome = events.OutcomeSuccess
base.Metadata = meta
case "sandbox.resumed":
meta := map[string]string{"host_id": req.HostID}
if req.HostIP != "" {
meta["host_ip"] = req.HostIP
}
base.Event = events.CapsuleResume
base.Outcome = events.OutcomeSuccess
base.Metadata = meta
case "sandbox.paused":
base.Event = events.CapsulePause
base.Outcome = events.OutcomeSuccess
case "sandbox.auto_paused":
base.Event = events.CapsulePause
base.Outcome = events.OutcomeSuccess
base.Metadata = map[string]string{"reason": "ttl_expired"}
case "sandbox.stopped":
base.Event = events.CapsuleDestroy
base.Outcome = events.OutcomeSuccess
case "sandbox.pause_failed":
base.Event = events.CapsulePause
base.Outcome = events.OutcomeError
base.Error = req.Error
base.Metadata = map[string]string{"reason": "host_failure"}
case "sandbox.resume_failed":
base.Event = events.CapsuleResume
base.Outcome = events.OutcomeError
base.Error = req.Error
base.Metadata = map[string]string{"reason": "host_failure"}
case "sandbox.failed", "sandbox.error":
// Pick a verb based on the sandbox's current DB status.
verb := h.verbForFailure(ctx, sandboxUUID)
base.Event = verb
base.Outcome = events.OutcomeError
base.Error = req.Error
base.Metadata = map[string]string{"reason": "host_failure"}
default:
return events.Event{}, false
}
return base, true
}
func (h *sandboxEventHandler) verbForFailure(ctx context.Context, sandboxID pgtype.UUID) string {
sb, err := h.db.GetSandbox(ctx, sandboxID)
if err != nil {
return events.CapsuleDestroy
}
switch sb.Status {
case "starting":
return events.CapsuleCreate
case "resuming":
return events.CapsuleResume
case "pausing":
return events.CapsulePause
case "snapshotting":
// A snapshot pauses then resumes the VM; a host-side failure leaves the
// sandbox errored, not destroyed. Route through CapsuleCreate so the
// consumer's handleFailed marks it "error" rather than removing the row.
return events.CapsuleCreate
default:
return events.CapsuleDestroy
}
}

View File

@ -0,0 +1,66 @@
package api
import (
"log/slog"
"net/http"
"time"
"github.com/go-chi/chi/v5"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/auth/session"
)
type sessionsHandler struct {
sessions *session.Service
}
func newSessionsHandler(svc *session.Service) *sessionsHandler {
return &sessionsHandler{sessions: svc}
}
// List handles GET /v1/me/sessions — returns all active sessions for the
// current user, flagging the caller's own session as current.
func (h *sessionsHandler) List(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
rows, err := h.sessions.ListForUser(r.Context(), ac.UserID)
if err != nil {
slog.Error("list sessions: db error", "error", err)
writeError(w, http.StatusInternalServerError, "db_error", "failed to list sessions")
return
}
out := make([]sessionRow, 0, len(rows))
for _, row := range rows {
out = append(out, sessionRow{
ID: row.ID,
UserAgent: row.UserAgent,
IPAddress: row.IpAddress,
CreatedAt: row.CreatedAt.Time.UTC().Format(time.RFC3339),
LastSeenAt: row.LastSeenAt.Time.UTC().Format(time.RFC3339),
ExpiresAt: row.ExpiresAt.Time.UTC().Format(time.RFC3339),
Current: row.ID == ac.SessionID,
})
}
writeJSON(w, http.StatusOK, map[string]any{"sessions": out})
}
// Delete handles DELETE /v1/me/sessions/{id} — revokes a single session
// belonging to the current user. If the caller revokes their own session,
// their cookies are cleared on the response.
func (h *sessionsHandler) Delete(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
sid := chi.URLParam(r, "id")
if sid == "" {
writeError(w, http.StatusBadRequest, "invalid_request", "missing session id")
return
}
if err := h.sessions.DeleteForUser(r.Context(), sid, ac.UserID); err != nil {
slog.Error("delete session: db error", "error", err)
writeError(w, http.StatusInternalServerError, "db_error", "failed to delete session")
return
}
if sid == ac.SessionID {
clearSessionCookies(w, isSecure(r))
}
w.WriteHeader(http.StatusNoContent)
}

View File

@ -4,7 +4,6 @@ import (
"context"
"encoding/json"
"fmt"
"log/slog"
"net/http"
"time"
@ -25,49 +24,53 @@ import (
)
type snapshotHandler struct {
svc *service.TemplateService
db *db.Queries
pool *lifecycle.HostClientPool
audit *audit.AuditLogger
svc *service.TemplateService
sandboxSvc *service.SandboxService
db *db.Queries
pool *lifecycle.HostClientPool
audit *audit.AuditLogger
}
func newSnapshotHandler(svc *service.TemplateService, db *db.Queries, pool *lifecycle.HostClientPool, al *audit.AuditLogger) *snapshotHandler {
return &snapshotHandler{svc: svc, db: db, pool: pool, audit: al}
func newSnapshotHandler(svc *service.TemplateService, sandboxSvc *service.SandboxService, db *db.Queries, pool *lifecycle.HostClientPool, al *audit.AuditLogger) *snapshotHandler {
return &snapshotHandler{svc: svc, sandboxSvc: sandboxSvc, db: db, pool: pool, audit: al}
}
// deleteSnapshotBroadcast attempts to delete snapshot files on all online hosts.
// Snapshots aren't currently host-tracked in the DB, so we broadcast to all hosts
// and ignore NotFound errors.
func deleteSnapshotBroadcast(ctx context.Context, queries *db.Queries, pool *lifecycle.HostClientPool, teamID, templateID pgtype.UUID) error {
// deleteSnapshotEverywhere removes a template's files from every active host.
// Templates aren't host-tracked in the DB, so it broadcasts to all hosts.
//
// It is strict by design: deletion is reported successful only when every
// active host has either removed the files or reported NotFound (it never
// held them). If any host is offline or returns an error, it returns an error
// and the caller MUST NOT delete the DB record — doing so would orphan the
// files on disk with no record left to retry against.
func deleteSnapshotEverywhere(ctx context.Context, queries *db.Queries, pool *lifecycle.HostClientPool, teamID, templateID pgtype.UUID) error {
hosts, err := queries.ListActiveHosts(ctx)
if err != nil {
return fmt.Errorf("list hosts: %w", err)
}
for _, host := range hosts {
if host.Status != "online" {
continue
return fmt.Errorf("host %s is %s — cannot guarantee snapshot file removal",
id.FormatHostID(host.ID), host.Status)
}
agent, err := pool.GetForHost(host)
if err != nil {
continue
return fmt.Errorf("connect to host %s: %w", id.FormatHostID(host.ID), err)
}
if _, err := agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{
TeamId: formatUUIDForRPC(teamID),
TemplateId: formatUUIDForRPC(templateID),
})); err != nil {
if connect.CodeOf(err) != connect.CodeNotFound {
slog.Warn("snapshot: failed to delete on host", "host_id", id.FormatHostID(host.ID), "error", err)
// NotFound just means this host never held the template.
if connect.CodeOf(err) == connect.CodeNotFound {
continue
}
return fmt.Errorf("delete snapshot on host %s: %w", id.FormatHostID(host.ID), err)
}
}
return nil
}
type createSnapshotRequest struct {
SandboxID string `json:"sandbox_id"`
Name string `json:"name"`
}
type snapshotResponse struct {
Name string `json:"name"`
Type string `json:"type"`
@ -76,6 +79,7 @@ type snapshotResponse struct {
SizeBytes int64 `json:"size_bytes"`
CreatedAt string `json:"created_at"`
Platform bool `json:"platform"`
Protected bool `json:"protected"`
Metadata map[string]string `json:"metadata,omitempty"`
}
@ -85,6 +89,7 @@ func templateToResponse(t db.Template) snapshotResponse {
Type: t.Type,
SizeBytes: t.SizeBytes,
Platform: t.TeamID == id.PlatformTeamID,
Protected: layout.IsSystemTemplate(t.TeamID, t.ID),
}
if t.Vcpus != 0 {
resp.VCPUs = &t.Vcpus
@ -104,132 +109,42 @@ func templateToResponse(t db.Template) snapshotResponse {
return resp
}
// Create handles POST /v1/snapshots.
type createSnapshotRequest struct {
SandboxID string `json:"sandbox_id"`
Name string `json:"name"`
}
// Create handles POST /v1/snapshots. Snapshots a running or paused sandbox and
// registers the result as a new template.
func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
var req createSnapshotRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
return
}
if req.SandboxID == "" {
writeError(w, http.StatusBadRequest, "invalid_request", "sandbox_id is required")
return
}
sandboxID, err := id.ParseSandboxID(req.SandboxID)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox_id")
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
ac := auth.MustFromContext(r.Context())
if req.Name == "" {
req.Name = id.NewSnapshotName()
}
if err := validate.SafeName(req.Name); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", fmt.Sprintf("invalid snapshot name: %s", err))
return
}
ctx := r.Context()
ac := auth.MustFromContext(ctx)
// Check for global name collision.
if _, err := h.db.GetPlatformTemplateByName(ctx, req.Name); err == nil {
writeError(w, http.StatusConflict, "name_reserved", "template name is reserved by a global template")
return
}
// Check if name already exists for this team.
if _, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err == nil {
writeError(w, http.StatusConflict, "template_name_taken",
"snapshot name already exists; delete the existing snapshot first to reuse this name")
return
}
// Verify sandbox exists, belongs to team, and is running or paused.
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
// Async: the VM briefly pauses to a "snapshotting" state, then resumes. The
// template is registered by a background goroutine; clients learn of the
// result via the SSE template.snapshot.create event (or by polling).
sb, name, err := h.sandboxSvc.CreateSnapshot(r.Context(), sandboxID, ac.TeamID, req.Name)
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" && sb.Status != "paused" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox must be running or paused")
return
}
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
return
}
// Pre-mark sandbox as "paused" in DB BEFORE issuing the snapshot RPC.
// The host agent's CreateSnapshot removes the sandbox from its in-memory
// map immediately; if the reconciler fires during the flatten window and
// the DB still says "running", it will mark the sandbox "stopped".
if sb.Status == "running" {
if _, err := h.db.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
ID: sandboxID, Status: "paused",
}); err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to update sandbox status")
return
}
}
// Use a detached context with a generous timeout so the snapshot completes
// even if the client disconnects (the flatten step can take 10-20s).
snapCtx, snapCancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer snapCancel()
// Generate the new template ID upfront so the host agent knows where to store files.
newTemplateID := id.NewTemplateID()
resp, err := agent.CreateSnapshot(snapCtx, connect.NewRequest(&pb.CreateSnapshotRequest{
SandboxId: req.SandboxID,
Name: req.Name,
TeamId: formatUUIDForRPC(ac.TeamID),
TemplateId: formatUUIDForRPC(newTemplateID),
}))
if err != nil {
// Snapshot failed — revert status back to what it was.
if sb.Status == "running" {
if _, dbErr := h.db.UpdateSandboxStatus(snapCtx, db.UpdateSandboxStatusParams{
ID: sandboxID, Status: "running",
}); dbErr != nil {
slog.Error("failed to revert sandbox status after snapshot error", "sandbox_id", req.SandboxID, "error", dbErr)
}
}
status, code, msg := agentErrToHTTP(err)
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
h.audit.LogSnapshotCreateRequested(r.Context(), ac, name)
tmpl, err := h.db.InsertTemplate(snapCtx, db.InsertTemplateParams{
ID: newTemplateID,
Name: req.Name,
Type: "snapshot",
Vcpus: sb.Vcpus,
MemoryMb: sb.MemoryMb,
SizeBytes: resp.Msg.SizeBytes,
TeamID: ac.TeamID,
DefaultUser: "root",
DefaultEnv: []byte("{}"),
Metadata: sb.Metadata,
})
if err != nil {
slog.Error("failed to insert template record", "name", req.Name, "error", err)
writeError(w, http.StatusInternalServerError, "db_error", "snapshot created but failed to record in database")
return
}
h.audit.LogSnapshotCreate(snapCtx, ac, req.Name)
if ctx.Err() != nil {
slog.Info("snapshot created but client disconnected before response", "name", req.Name)
return
}
writeJSON(w, http.StatusCreated, templateToResponse(tmpl))
writeJSON(w, http.StatusAccepted, sandboxToResponse(sb))
}
// List handles GET /v1/snapshots.
@ -243,6 +158,11 @@ func (h *snapshotHandler) List(w http.ResponseWriter, r *http.Request) {
return
}
// Resolve actual on-disk sizes for templates with unknown size (e.g.
// system base templates seeded with size_bytes = 0). This queries a host
// agent and persists the result to the DB for subsequent requests.
templates = resolveTemplateSizes(r.Context(), h.db, h.pool, templates)
resp := make([]snapshotResponse, len(templates))
for i, t := range templates {
resp[i] = templateToResponse(t)
@ -271,21 +191,24 @@ func (h *snapshotHandler) Delete(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusForbidden, "forbidden", "platform templates cannot be deleted here")
return
}
if layout.IsMinimal(tmpl.TeamID, tmpl.ID) {
writeError(w, http.StatusForbidden, "forbidden", "the minimal template cannot be deleted")
if layout.IsSystemTemplate(tmpl.TeamID, tmpl.ID) {
writeError(w, http.StatusForbidden, "forbidden", "system base templates cannot be deleted")
return
}
if err := deleteSnapshotBroadcast(ctx, h.db, h.pool, tmpl.TeamID, tmpl.ID); err != nil {
writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete snapshot files")
if err := deleteSnapshotEverywhere(ctx, h.db, h.pool, tmpl.TeamID, tmpl.ID); err != nil {
h.audit.LogSnapshotDelete(r.Context(), ac, name, err)
writeError(w, http.StatusConflict, "delete_failed",
"could not remove snapshot files from all hosts: "+err.Error())
return
}
if err := h.db.DeleteTemplateByTeam(ctx, db.DeleteTemplateByTeamParams{Name: name, TeamID: ac.TeamID}); err != nil {
h.audit.LogSnapshotDelete(r.Context(), ac, name, err)
writeError(w, http.StatusInternalServerError, "db_error", "failed to delete template record")
return
}
h.audit.LogSnapshotDelete(r.Context(), ac, name)
h.audit.LogSnapshotDelete(r.Context(), ac, name, nil)
w.WriteHeader(http.StatusNoContent)
}

View File

@ -0,0 +1,79 @@
package api
import (
"fmt"
"net/http"
"time"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
const sseKeepaliveInterval = 30 * time.Second
type sseHandler struct {
broker *SSEBroker
}
func newSSEHandler(broker *SSEBroker) *sseHandler {
return &sseHandler{broker: broker}
}
// Stream handles GET /v1/events/stream. Authentication is performed by
// upstream middleware: browser clients use the wrenn_sid cookie (which the
// EventSource API forwards automatically); SDK clients use X-API-Key.
func (h *sseHandler) Stream(w http.ResponseWriter, r *http.Request) {
ac, ok := auth.FromContext(r.Context())
if !ok {
writeError(w, http.StatusUnauthorized, "unauthorized", "session cookie or X-API-Key required")
return
}
h.serveSSE(w, r, id.FormatTeamID(ac.TeamID), false)
}
// AdminStream handles GET /v1/admin/events/stream. Upstream middleware
// must enforce session auth + requireAdmin before reaching this handler.
func (h *sseHandler) AdminStream(w http.ResponseWriter, r *http.Request) {
ac, ok := auth.FromContext(r.Context())
if !ok {
writeError(w, http.StatusUnauthorized, "unauthorized", "admin session required")
return
}
h.serveSSE(w, r, id.FormatTeamID(ac.TeamID), true)
}
func (h *sseHandler) serveSSE(w http.ResponseWriter, r *http.Request, teamID string, isAdmin bool) {
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "streaming not supported", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("X-Accel-Buffering", "no")
subID, ch := h.broker.Subscribe(teamID, isAdmin)
defer h.broker.Unsubscribe(subID)
fmt.Fprintf(w, "event: connected\ndata: {\"message\":\"connected\"}\n\n")
flusher.Flush()
keepalive := time.NewTicker(sseKeepaliveInterval)
defer keepalive.Stop()
ctx := r.Context()
for {
select {
case <-ctx.Done():
return
case msg := <-ch:
fmt.Fprintf(w, "event: %s\ndata: %s\n\n", msg.EventType, msg.Data)
flusher.Flush()
case <-keepalive.C:
fmt.Fprintf(w, ": keepalive\n\n")
flusher.Flush()
}
}
}

View File

@ -14,19 +14,21 @@ import (
"git.omukk.dev/wrenn/wrenn/internal/email"
"git.omukk.dev/wrenn/wrenn/pkg/audit"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/auth/session"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/service"
)
type teamHandler struct {
svc *service.TeamService
audit *audit.AuditLogger
mailer email.Mailer
svc *service.TeamService
audit *audit.AuditLogger
mailer email.Mailer
sessions *session.Service
}
func newTeamHandler(svc *service.TeamService, al *audit.AuditLogger, mailer email.Mailer) *teamHandler {
return &teamHandler{svc: svc, audit: al, mailer: mailer}
func newTeamHandler(svc *service.TeamService, al *audit.AuditLogger, mailer email.Mailer, sessions *session.Service) *teamHandler {
return &teamHandler{svc: svc, audit: al, mailer: mailer, sessions: sessions}
}
// teamResponse is the JSON shape for a team.
@ -366,6 +368,11 @@ func (h *teamHandler) UpdateMemberRole(w http.ResponseWriter, r *http.Request) {
return
}
// Drop cached session blobs so the new role propagates immediately.
if err := h.sessions.InvalidateCacheForUser(r.Context(), targetUserID); err != nil {
_ = err
}
h.audit.LogMemberRoleUpdate(r.Context(), ac, targetUserID, req.Role)
w.WriteHeader(http.StatusNoContent)
}
@ -450,6 +457,8 @@ func (h *teamHandler) AdminListTeams(w http.ResponseWriter, r *http.Request) {
OwnerEmail string `json:"owner_email"`
ActiveSandboxCount int32 `json:"active_sandbox_count"`
ChannelCount int32 `json:"channel_count"`
RunningVcpus int32 `json:"running_vcpus"`
RunningMemoryMb int32 `json:"running_memory_mb"`
}
resp := make([]adminTeamResponse, len(teams))
@ -465,6 +474,8 @@ func (h *teamHandler) AdminListTeams(w http.ResponseWriter, r *http.Request) {
OwnerEmail: t.OwnerEmail,
ActiveSandboxCount: t.ActiveSandboxCount,
ChannelCount: t.ChannelCount,
RunningVcpus: t.RunningVcpus,
RunningMemoryMb: t.RunningMemoryMb,
}
if t.DeletedAt != nil {
s := t.DeletedAt.Format(time.RFC3339)

View File

@ -11,19 +11,21 @@ import (
"git.omukk.dev/wrenn/wrenn/pkg/audit"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/auth/session"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/service"
)
type usersHandler struct {
db *db.Queries
svc *service.UserService
audit *audit.AuditLogger
db *db.Queries
svc *service.UserService
audit *audit.AuditLogger
sessions *session.Service
}
func newUsersHandler(db *db.Queries, svc *service.UserService, al *audit.AuditLogger) *usersHandler {
return &usersHandler{db: db, svc: svc, audit: al}
func newUsersHandler(db *db.Queries, svc *service.UserService, al *audit.AuditLogger, sessions *session.Service) *usersHandler {
return &usersHandler{db: db, svc: svc, audit: al, sessions: sessions}
}
// Search handles GET /v1/users/search?email=<prefix>
@ -158,6 +160,10 @@ func (h *usersHandler) SetUserActive(w http.ResponseWriter, r *http.Request) {
if req.Active {
h.audit.LogUserActivate(r.Context(), ac, userID, user.Email)
} else {
// Disabled users must be kicked out of every active session.
if err := h.sessions.RevokeAllForUser(r.Context(), userID); err != nil {
_ = err
}
h.audit.LogUserDeactivate(r.Context(), ac, userID, user.Email)
}
w.WriteHeader(http.StatusNoContent)
@ -215,5 +221,14 @@ func (h *usersHandler) SetUserAdmin(w http.ResponseWriter, r *http.Request) {
}
h.audit.LogUserRevokeAdmin(r.Context(), ac, userID, user.Email)
}
// Invalidate cached session blobs so the new is_admin flag is reflected
// on the user's next request without waiting for the Redis TTL.
if err := h.sessions.InvalidateCacheForUser(r.Context(), userID); err != nil {
// Cache is best-effort; the DB is authoritative and requireAdmin
// always re-reads it.
_ = err
}
w.WriteHeader(http.StatusNoContent)
}

View File

@ -1,102 +0,0 @@
package api
import (
"context"
"fmt"
"time"
"github.com/gorilla/websocket"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
// ctxKeyAdminWS is a context key for flagging admin WS routes.
type ctxKeyAdminWS struct{}
// setAdminWSFlag marks the context as an admin WebSocket route.
func setAdminWSFlag(ctx context.Context) context.Context {
return context.WithValue(ctx, ctxKeyAdminWS{}, true)
}
// isAdminWSRoute checks if the request context was marked as admin WS.
func isAdminWSRoute(ctx context.Context) bool {
v, _ := ctx.Value(ctxKeyAdminWS{}).(bool)
return v
}
// wsAuthMsg is the first message a browser WS client sends to authenticate.
type wsAuthMsg struct {
Type string `json:"type"`
Token string `json:"token"`
}
// wsAuthenticate reads a JWT auth message from the WebSocket and returns the
// authenticated context. The caller must send this as the first message after
// connecting.
func wsAuthenticate(ctx context.Context, conn *websocket.Conn, jwtSecret []byte, queries *db.Queries) (auth.AuthContext, error) {
_ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
var msg wsAuthMsg
if err := conn.ReadJSON(&msg); err != nil {
return auth.AuthContext{}, fmt.Errorf("read auth message: %w", err)
}
_ = conn.SetReadDeadline(time.Time{}) // clear deadline
if msg.Type != "auth" || msg.Token == "" {
return auth.AuthContext{}, fmt.Errorf("first message must be type 'auth' with a token")
}
claims, err := auth.VerifyJWT(jwtSecret, msg.Token)
if err != nil {
return auth.AuthContext{}, fmt.Errorf("invalid or expired token: %w", err)
}
teamID, err := id.ParseTeamID(claims.TeamID)
if err != nil {
return auth.AuthContext{}, fmt.Errorf("invalid team ID in token: %w", err)
}
userID, err := id.ParseUserID(claims.Subject)
if err != nil {
return auth.AuthContext{}, fmt.Errorf("invalid user ID in token: %w", err)
}
user, err := queries.GetUserByID(ctx, userID)
if err != nil {
return auth.AuthContext{}, fmt.Errorf("user not found")
}
if user.Status != "active" {
return auth.AuthContext{}, fmt.Errorf("account deactivated")
}
return auth.AuthContext{
TeamID: teamID,
UserID: userID,
Email: claims.Email,
Name: claims.Name,
Role: claims.Role,
}, nil
}
// wsAuthenticateAdmin performs WS-based auth and verifies admin status,
// returning an AuthContext with the platform team ID.
func wsAuthenticateAdmin(ctx context.Context, conn *websocket.Conn, jwtSecret []byte, queries *db.Queries) (auth.AuthContext, error) {
ac, err := wsAuthenticate(ctx, conn, jwtSecret, queries)
if err != nil {
return auth.AuthContext{}, err
}
user, err := queries.GetUserByID(ctx, ac.UserID)
if err != nil {
return auth.AuthContext{}, fmt.Errorf("user not found")
}
if !user.IsAdmin {
return auth.AuthContext{}, fmt.Errorf("admin access required")
}
ac.TeamID = id.PlatformTeamID
return ac, nil
}

View File

@ -2,6 +2,7 @@ package api
import (
"context"
"errors"
"log/slog"
"time"
@ -15,10 +16,30 @@ import (
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
)
// errInferredTransientTimeout marks a state change that the reconciler
// inferred after a transient (starting/resuming) sandbox failed to settle
// within the grace period. Used as the err value on system audit calls so
// the published event carries Outcome=error with a human-readable message.
var errInferredTransientTimeout = errors.New("transient state did not settle within grace period")
// unreachableThreshold is how long a host can go without a heartbeat before
// it is considered unreachable (3 missed 30-second heartbeats).
const unreachableThreshold = 90 * time.Second
// transientGracePeriod is how long a sandbox is allowed to stay in a transient
// status (starting, resuming, pausing, stopping) before the monitor infers a
// final state. This prevents the monitor from racing against in-flight RPCs
// that may not have registered the sandbox on the host agent yet.
const transientGracePeriod = 2 * time.Minute
// snapshotGracePeriod is the grace for a sandbox stuck in "snapshotting" while
// the VM is still alive on the host. Snapshots dump guest RAM and flatten the
// rootfs, which can run for minutes on large sandboxes, and the agent reports
// the VM as alive throughout — so we must not race the in-flight operation.
// It exceeds the background goroutine's 10-minute deadline, so reaching it
// means the control plane crashed mid-snapshot and the sandbox needs recovery.
const snapshotGracePeriod = 15 * time.Minute
// HostMonitor runs on a fixed interval and performs two duties:
//
// 1. Passive check: marks hosts whose last_heartbeat_at is stale as
@ -77,6 +98,21 @@ func (m *HostMonitor) run(ctx context.Context) {
}
}
// ReconcileHost triggers immediate active reconciliation for a single host.
// Called when a host transitions from unreachable → online so sandboxes marked
// "missing" are resolved without waiting for the next monitor tick.
func (m *HostMonitor) ReconcileHost(ctx context.Context, hostID pgtype.UUID) {
host, err := m.db.GetHost(ctx, hostID)
if err != nil {
slog.Warn("host monitor: reconcile-on-connect: failed to get host", "error", err)
return
}
if host.Status != "online" {
return
}
m.checkHost(ctx, host)
}
func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) {
// --- Passive phase: check heartbeat staleness ---
@ -116,21 +152,29 @@ func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) {
return
}
// Build set of sandbox IDs alive on the host.
// The host agent returns sandbox IDs as strings (formatted with prefix).
alive := make(map[string]struct{}, len(resp.Msg.Sandboxes))
// Build map of sandbox ID -> reported status. Transient statuses
// (pausing/resuming/starting/stopping) are coerced to a presence-only
// entry: ListSandboxes can observe the in-memory status mid-transition
// (Pause flips the status under m.mu while List holds m.mu.RLock), and
// writing those transient labels into the DB would force the transient
// reconciliation phase to wait the full grace period before resolving.
// Recording the presence keeps "missing → restore" and "running →
// orphan-stop" logic correct without overwriting with stale labels;
// the next monitor tick reads the settled status.
aliveStatus := make(map[string]string, len(resp.Msg.Sandboxes))
for _, sb := range resp.Msg.Sandboxes {
alive[sb.SandboxId] = struct{}{}
}
autoPaused := make(map[string]struct{}, len(resp.Msg.AutoPausedSandboxIds))
for _, apID := range resp.Msg.AutoPausedSandboxIds {
autoPaused[apID] = struct{}{}
status := sb.Status
switch status {
case "pausing", "resuming", "starting", "stopping":
status = ""
}
aliveStatus[sb.SandboxId] = status
}
// --- Restore sandboxes that are "missing" in DB but alive on host ---
// This handles the case where CP marked them missing due to a transient
// heartbeat gap, but the host was actually fine.
// Handles transient heartbeat gaps where the host was actually fine. The
// reported status must be honored: a sandbox the agent paused while CP
// was disconnected must not be silently promoted back to running.
missingSandboxes, err := m.db.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
HostID: host.ID,
@ -139,34 +183,65 @@ func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) {
if err != nil {
slog.Warn("host monitor: failed to list missing sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
} else {
var toRestore []pgtype.UUID
var toStop []pgtype.UUID
restoreByStatus := make(map[string][]db.Sandbox)
var toStop []db.Sandbox
for _, sb := range missingSandboxes {
sbIDStr := id.FormatSandboxID(sb.ID)
if _, ok := alive[sbIDStr]; ok {
toRestore = append(toRestore, sb.ID)
} else {
toStop = append(toStop, sb.ID)
status, ok := aliveStatus[sbIDStr]
if !ok {
toStop = append(toStop, sb)
continue
}
if status == "" {
continue
}
restoreByStatus[status] = append(restoreByStatus[status], sb)
}
if len(toRestore) > 0 {
slog.Info("host monitor: restoring missing sandboxes", "host_id", id.FormatHostID(host.ID), "count", len(toRestore))
if err := m.db.BulkRestoreRunning(ctx, toRestore); err != nil {
slog.Warn("host monitor: failed to restore missing sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
for status, sbs := range restoreByStatus {
ids := make([]pgtype.UUID, len(sbs))
for i, sb := range sbs {
ids[i] = sb.ID
}
slog.Info("host monitor: restoring missing sandboxes", "host_id", id.FormatHostID(host.ID), "status", status, "count", len(ids))
if err := m.db.BulkRestoreMissingToStatus(ctx, db.BulkRestoreMissingToStatusParams{
Column1: ids,
Status: status,
}); err != nil {
slog.Warn("host monitor: failed to restore missing sandboxes", "host_id", id.FormatHostID(host.ID), "status", status, "error", err)
continue
}
// Only restore→paused emits a notification (per design: running restore is silent).
if status == "paused" {
for _, sb := range sbs {
m.audit.LogSandboxAutoPause(ctx, sb.TeamID, sb.ID, "restored_after_host_recovery", nil)
}
}
}
if len(toStop) > 0 {
slog.Info("host monitor: stopping confirmed-dead missing sandboxes", "host_id", id.FormatHostID(host.ID), "count", len(toStop))
ids := make([]pgtype.UUID, len(toStop))
for i, sb := range toStop {
ids[i] = sb.ID
}
slog.Info("host monitor: stopping confirmed-dead missing sandboxes", "host_id", id.FormatHostID(host.ID), "count", len(ids))
if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
Column1: toStop,
Column1: ids,
Status: "stopped",
}); err != nil {
slog.Warn("host monitor: failed to stop missing sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
} else {
for _, sb := range toStop {
m.audit.LogSandboxDestroySystem(ctx, sb.TeamID, sb.ID, "orphaned", nil)
}
}
}
}
// --- Find running sandboxes in DB that are no longer alive on the host ---
// --- Reconcile running sandboxes in DB against live host state ---
// Three cases per DB-running row:
// absent on host -> stopped
// present and running -> no change
// present but paused/etc. -> sync DB to reported status (catches the
// shutdown-pause notify failure case)
runningSandboxes, err := m.db.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
HostID: host.ID,
@ -177,40 +252,196 @@ func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) {
return
}
var toPause, toStop []pgtype.UUID
sbTeamID := make(map[pgtype.UUID]pgtype.UUID, len(runningSandboxes))
var toStop []db.Sandbox
syncByStatus := make(map[string][]db.Sandbox)
for _, sb := range runningSandboxes {
sbIDStr := id.FormatSandboxID(sb.ID)
sbTeamID[sb.ID] = sb.TeamID
if _, ok := alive[sbIDStr]; ok {
status, ok := aliveStatus[sbIDStr]
if !ok {
toStop = append(toStop, sb)
continue
}
if _, ok := autoPaused[sbIDStr]; ok {
toPause = append(toPause, sb.ID)
} else {
toStop = append(toStop, sb.ID)
if status == "running" || status == "" {
continue
}
syncByStatus[status] = append(syncByStatus[status], sb)
}
if len(toPause) > 0 {
slog.Info("host monitor: marking auto-paused sandboxes", "host_id", id.FormatHostID(host.ID), "count", len(toPause))
if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
Column1: toPause,
Status: "paused",
}); err != nil {
slog.Warn("host monitor: failed to mark paused", "host_id", id.FormatHostID(host.ID), "error", err)
}
for _, sbID := range toPause {
m.audit.LogSandboxAutoPause(ctx, sbTeamID[sbID], sbID)
}
}
if len(toStop) > 0 {
slog.Info("host monitor: marking orphaned sandboxes stopped", "host_id", id.FormatHostID(host.ID), "count", len(toStop))
ids := make([]pgtype.UUID, len(toStop))
for i, sb := range toStop {
ids[i] = sb.ID
}
slog.Info("host monitor: marking orphaned sandboxes stopped", "host_id", id.FormatHostID(host.ID), "count", len(ids))
if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
Column1: toStop,
Column1: ids,
Status: "stopped",
}); err != nil {
slog.Warn("host monitor: failed to mark stopped", "host_id", id.FormatHostID(host.ID), "error", err)
} else {
for _, sb := range toStop {
m.audit.LogSandboxDestroySystem(ctx, sb.TeamID, sb.ID, "orphaned", nil)
}
}
}
for status, sbs := range syncByStatus {
ids := make([]pgtype.UUID, len(sbs))
for i, sb := range sbs {
ids[i] = sb.ID
}
slog.Info("host monitor: syncing running→reported status", "host_id", id.FormatHostID(host.ID), "status", status, "count", len(ids))
if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
Column1: ids,
Status: status,
}); err != nil {
slog.Warn("host monitor: failed to sync running sandboxes", "host_id", id.FormatHostID(host.ID), "status", status, "error", err)
continue
}
if status == "paused" {
for _, sb := range sbs {
m.audit.LogSandboxAutoPause(ctx, sb.TeamID, sb.ID, "host_state_sync", nil)
}
}
}
// --- Reconcile DB-stopped + agent-paused zombies ---
// A sandbox the agent reports as 'paused' but DB has as 'stopped' is an
// orphan from a previous bug where a successful pause's auto_paused
// callback was lost (e.g. CP unreachable during agent shutdown). With the
// agent-side fix (RestorePausedSandboxes), the snapshot survives across
// agent restarts and surfaces here. Authoritative direction: DB wins
// (user already saw 'stopped' and may have stopped tracking it).
// Issue Destroy so the on-disk snapshot dir is removed and the agent's
// slot reservation released.
//
// Gate: only run the DB query if the agent reports at least one paused
// sandbox. Otherwise we'd fetch every historically-stopped sandbox on
// this host every monitor tick — unbounded growth over a host's lifetime.
hasPaused := false
for _, status := range aliveStatus {
if status == "paused" {
hasPaused = true
break
}
}
if hasPaused {
stoppedSandboxes, err := m.db.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
HostID: host.ID,
Column2: []string{"stopped"},
})
if err != nil {
slog.Warn("host monitor: failed to list stopped sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
} else {
for _, sb := range stoppedSandboxes {
sbIDStr := id.FormatSandboxID(sb.ID)
status, ok := aliveStatus[sbIDStr]
if !ok || status != "paused" {
continue
}
slog.Info("host monitor: destroying DB-stopped agent-paused zombie",
"host_id", id.FormatHostID(host.ID), "sandbox_id", sbIDStr)
if _, err := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
SandboxId: sbIDStr,
})); err != nil && connect.CodeOf(err) != connect.CodeNotFound {
slog.Warn("host monitor: zombie destroy failed",
"sandbox_id", sbIDStr, "error", err)
continue
}
m.audit.LogSandboxDestroySystem(ctx, sb.TeamID, sb.ID, "paused_zombie_cleanup", nil)
}
}
}
// --- Reconcile transient statuses (starting, resuming, pausing, stopping) ---
// These represent in-flight operations. If the sandbox is no longer alive on
// the host, infer the final state based on the transient status.
transientSandboxes, err := m.db.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
HostID: host.ID,
Column2: []string{"starting", "resuming", "pausing", "stopping", "snapshotting"},
})
if err != nil {
slog.Warn("host monitor: failed to list transient sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
return
}
for _, sb := range transientSandboxes {
sbIDStr := id.FormatSandboxID(sb.ID)
if agentStatus, ok := aliveStatus[sbIDStr]; ok {
// Sandbox is alive on host — the background goroutine should
// finalize the transition. For starting/resuming, if the sandbox
// is alive it means creation/resume succeeded.
if sb.Status == "starting" || sb.Status == "resuming" {
if _, err := m.db.UpdateSandboxStatusIf(ctx, db.UpdateSandboxStatusIfParams{
ID: sb.ID, Status: sb.Status, Status_2: "running",
}); err == nil {
slog.Info("host monitor: promoted transient sandbox to running", "sandbox_id", sbIDStr, "from", sb.Status)
}
}
// A snapshot keeps the source sandbox alive throughout, so an alive
// sandbox does NOT mean the snapshot finished. Only recover it once
// it has been stuck past the snapshot grace period (i.e. the CP
// crashed mid-op). Recover to the sandbox's actual host-side status:
// a running sandbox is snapshotted live and stays running, but a
// paused sandbox is snapshotted from disk and must return to paused.
if sb.Status == "snapshotting" &&
sb.LastUpdated.Valid && time.Since(sb.LastUpdated.Time) >= snapshotGracePeriod {
recoverTo := agentStatus
if recoverTo != "running" && recoverTo != "paused" {
// Coerced/unknown agent label — default to running.
recoverTo = "running"
}
if _, err := m.db.UpdateSandboxStatusIf(ctx, db.UpdateSandboxStatusIfParams{
ID: sb.ID, Status: "snapshotting", Status_2: recoverTo,
}); err == nil {
slog.Info("host monitor: recovered stuck snapshotting sandbox", "sandbox_id", sbIDStr, "to", recoverTo)
m.audit.LogSnapshotCreateSystem(ctx, sb.TeamID, sb.ID, "snapshot_recovered", nil)
}
}
continue
}
// Sandbox is not alive on host. If the transition is recent, give the
// in-flight RPC time to finish before declaring a final state.
if sb.LastUpdated.Valid && time.Since(sb.LastUpdated.Time) < transientGracePeriod {
slog.Debug("host monitor: transient sandbox still within grace period",
"sandbox_id", sbIDStr, "status", sb.Status,
"age", time.Since(sb.LastUpdated.Time).Round(time.Second))
continue
}
// Grace period expired — infer final state.
var finalStatus string
switch sb.Status {
case "starting", "resuming":
finalStatus = "error"
case "pausing":
finalStatus = "paused"
case "stopping":
finalStatus = "stopped"
case "snapshotting":
// VM is gone but DB says snapshotting → the snapshot died with the VM.
finalStatus = "error"
}
fromStatus := sb.Status
if _, err := m.db.UpdateSandboxStatusIf(ctx, db.UpdateSandboxStatusIfParams{
ID: sb.ID, Status: fromStatus, Status_2: finalStatus,
}); err == nil {
slog.Info("host monitor: resolved transient sandbox", "sandbox_id", sbIDStr, "from", fromStatus, "to", finalStatus)
inferredErr := errInferredTransientTimeout
switch fromStatus {
case "starting":
m.audit.LogSandboxCreateSystem(ctx, sb.TeamID, sb.ID, "transient_timeout", inferredErr)
case "resuming":
m.audit.LogSandboxResumeSystem(ctx, sb.TeamID, sb.ID, "transient_timeout", inferredErr)
case "pausing":
// Pause assumed to have succeeded host-side; emit success with inferred metadata.
m.audit.LogSandboxAutoPause(ctx, sb.TeamID, sb.ID, "transient_timeout_inferred", nil)
case "snapshotting":
// VM gone mid-snapshot; the sandbox is errored.
m.audit.LogSnapshotCreateSystem(ctx, sb.TeamID, sb.ID, "transient_timeout", inferredErr)
case "stopping":
m.audit.LogSandboxDestroySystem(ctx, sb.TeamID, sb.ID, "transient_timeout_inferred", nil)
}
}
}
}

View File

@ -50,7 +50,9 @@ func agentErrToHTTP(err error) (int, string, string) {
return http.StatusNotFound, "not_found", err.Error()
case connect.CodeInvalidArgument:
return http.StatusBadRequest, "invalid_request", err.Error()
case connect.CodeFailedPrecondition, connect.CodeAlreadyExists:
case connect.CodeAlreadyExists:
return http.StatusConflict, "already_exists", err.Error()
case connect.CodeFailedPrecondition:
return http.StatusConflict, "conflict", err.Error()
case connect.CodePermissionDenied:
return http.StatusForbidden, "forbidden", err.Error()

View File

@ -26,19 +26,10 @@ func injectPlatformTeam() func(http.Handler) http.Handler {
}
}
// markAdminWS flags the request context as an admin WebSocket route.
// Applied to admin WS endpoints that sit outside the requireJWT/requireAdmin
// middleware group. Handlers use isAdminWSRoute(ctx) to pick wsAuthenticateAdmin.
func markAdminWS(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
next.ServeHTTP(w, r.WithContext(setAdminWSFlag(r.Context())))
})
}
// requireAdmin validates that the authenticated user is a platform admin.
// Must run after requireJWT (depends on AuthContext being present).
// Re-validates against the DB — the JWT is_admin claim is for UI only;
// the DB is the source of truth for admin access.
// Must run after requireSession (depends on AuthContext being present).
// Re-validates against the DB — the session blob's is_admin flag is for
// UI hints; the DB is the source of truth for admin access.
func requireAdmin(queries *db.Queries) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

View File

@ -1,145 +0,0 @@
package api
import (
"log/slog"
"net/http"
"strings"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
// requireAPIKeyOrJWT accepts either X-API-Key header or Authorization: Bearer JWT.
// Both stamp TeamID into the request context via auth.AuthContext.
func requireAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Try API key first.
if key := r.Header.Get("X-API-Key"); key != "" {
hash := auth.HashAPIKey(key)
row, err := queries.GetAPIKeyByHash(r.Context(), hash)
if err != nil {
slog.Warn("api key auth failed", "prefix", auth.APIKeyPrefix(key), "ip", r.RemoteAddr)
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid API key")
return
}
if err := queries.UpdateAPIKeyLastUsed(r.Context(), row.ID); err != nil {
slog.Warn("failed to update api key last_used", "key_id", id.FormatAPIKeyID(row.ID), "error", err)
}
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
TeamID: row.TeamID,
APIKeyID: row.ID,
APIKeyName: row.Name,
})
next.ServeHTTP(w, r.WithContext(ctx))
return
}
// Try JWT bearer token from Authorization header.
tokenStr := ""
if header := r.Header.Get("Authorization"); strings.HasPrefix(header, "Bearer ") {
tokenStr = strings.TrimPrefix(header, "Bearer ")
}
if tokenStr != "" {
claims, err := auth.VerifyJWT(jwtSecret, tokenStr)
if err != nil {
slog.Warn("jwt auth failed", "error", err, "ip", r.RemoteAddr)
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid or expired token")
return
}
teamID, err := id.ParseTeamID(claims.TeamID)
if err != nil {
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid team ID in token")
return
}
userID, err := id.ParseUserID(claims.Subject)
if err != nil {
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid user ID in token")
return
}
// Verify user is still active in the database.
user, err := queries.GetUserByID(r.Context(), userID)
if err != nil {
slog.Warn("jwt auth: failed to look up user", "user_id", claims.Subject, "error", err)
writeError(w, http.StatusUnauthorized, "unauthorized", "user not found")
return
}
if user.Status != "active" {
writeError(w, http.StatusForbidden, "account_deactivated", "your account has been deactivated — contact your administrator to regain access")
return
}
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
TeamID: teamID,
UserID: userID,
Email: claims.Email,
Name: claims.Name,
Role: claims.Role,
})
next.ServeHTTP(w, r.WithContext(ctx))
return
}
writeError(w, http.StatusUnauthorized, "unauthorized", "X-API-Key or Authorization: Bearer <token> required")
})
}
}
// optionalAPIKeyOrJWT is like requireAPIKeyOrJWT but does not reject
// unauthenticated requests. It injects auth context when valid credentials
// are present (supporting SDK clients that set X-API-Key on WebSocket
// upgrades) and passes through otherwise so the handler can authenticate
// after the WebSocket upgrade via the first message.
func optionalAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Try API key.
if key := r.Header.Get("X-API-Key"); key != "" {
hash := auth.HashAPIKey(key)
row, err := queries.GetAPIKeyByHash(r.Context(), hash)
if err == nil {
if err := queries.UpdateAPIKeyLastUsed(r.Context(), row.ID); err != nil {
slog.Warn("failed to update api key last_used", "key_id", id.FormatAPIKeyID(row.ID), "error", err)
}
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
TeamID: row.TeamID,
APIKeyID: row.ID,
APIKeyName: row.Name,
})
next.ServeHTTP(w, r.WithContext(ctx))
return
}
}
// Try JWT bearer token.
if header := r.Header.Get("Authorization"); strings.HasPrefix(header, "Bearer ") {
tokenStr := strings.TrimPrefix(header, "Bearer ")
if claims, err := auth.VerifyJWT(jwtSecret, tokenStr); err == nil {
if teamID, err := id.ParseTeamID(claims.TeamID); err == nil {
if userID, err := id.ParseUserID(claims.Subject); err == nil {
if user, err := queries.GetUserByID(r.Context(), userID); err == nil && user.Status == "active" {
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
TeamID: teamID,
UserID: userID,
Email: claims.Email,
Name: claims.Name,
Role: claims.Role,
})
next.ServeHTTP(w, r.WithContext(ctx))
return
}
}
}
}
}
// No valid credentials — pass through for handler to authenticate.
next.ServeHTTP(w, r)
})
}
}

View File

@ -0,0 +1,38 @@
package api
import (
"crypto/subtle"
"net/http"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
)
// requireCSRF enforces double-submit CSRF: the wrenn_csrf cookie value must
// equal the X-CSRF-Token header. Skipped for safe methods (GET/HEAD/OPTIONS)
// and for requests authenticated via X-API-Key (SDK clients are not
// vulnerable to cross-site request forgery against cookie auth).
func requireCSRF() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet, http.MethodHead, http.MethodOptions:
next.ServeHTTP(w, r)
return
}
// API-key auth path: no CSRF check needed.
if ac, ok := auth.FromContext(r.Context()); ok && ac.APIKeyID.Valid {
next.ServeHTTP(w, r)
return
}
cookie, err := r.Cookie(csrfCookieName)
header := r.Header.Get("X-CSRF-Token")
if err != nil || cookie.Value == "" || header == "" ||
subtle.ConstantTimeCompare([]byte(cookie.Value), []byte(header)) != 1 {
writeError(w, http.StatusForbidden, "csrf_failed", "missing or invalid CSRF token")
return
}
next.ServeHTTP(w, r)
})
}
}

View File

@ -1,69 +0,0 @@
package api
import (
"log/slog"
"net/http"
"strings"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
// requireJWT validates a JWT from the Authorization: Bearer header.
// It also verifies the user is still active in the database.
// WebSocket upgrade requests without an Authorization header are passed through
// — WS handlers authenticate via the first message after upgrade.
func requireJWT(secret []byte, queries *db.Queries) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var tokenStr string
if header := r.Header.Get("Authorization"); strings.HasPrefix(header, "Bearer ") {
tokenStr = strings.TrimPrefix(header, "Bearer ")
}
if tokenStr == "" {
writeError(w, http.StatusUnauthorized, "unauthorized", "Authorization: Bearer <token> required")
return
}
claims, err := auth.VerifyJWT(secret, tokenStr)
if err != nil {
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid or expired token")
return
}
teamID, err := id.ParseTeamID(claims.TeamID)
if err != nil {
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid team ID in token")
return
}
userID, err := id.ParseUserID(claims.Subject)
if err != nil {
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid user ID in token")
return
}
// Verify user is still active in the database.
user, err := queries.GetUserByID(r.Context(), userID)
if err != nil {
slog.Warn("jwt auth: failed to look up user", "user_id", claims.Subject, "error", err)
writeError(w, http.StatusUnauthorized, "unauthorized", "user not found")
return
}
if user.Status != "active" {
writeError(w, http.StatusForbidden, "account_deactivated", "your account has been deactivated — contact your administrator to regain access")
return
}
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
TeamID: teamID,
UserID: userID,
Email: claims.Email,
Name: claims.Name,
Role: claims.Role,
IsAdmin: claims.IsAdmin,
})
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}

View File

@ -0,0 +1,34 @@
package api
import (
"net/http"
"git.omukk.dev/wrenn/wrenn/pkg/auth/session"
sessionmw "git.omukk.dev/wrenn/wrenn/pkg/auth/session/middleware"
"git.omukk.dev/wrenn/wrenn/pkg/db"
)
// Internal aliases — the canonical implementations live in the public
// pkg/auth/session/middleware package so cloud extensions can call them.
const csrfCookieName = sessionmw.CSRFCookieName
func requireSession(queries *db.Queries, svc *session.Service) func(http.Handler) http.Handler {
return sessionmw.RequireSession(svc, queries)
}
func requireSessionOrAPIKey(queries *db.Queries, svc *session.Service) func(http.Handler) http.Handler {
return sessionmw.RequireSessionOrAPIKey(svc, queries)
}
func setSessionCookies(w http.ResponseWriter, sid, csrfToken string, secure bool) {
sessionmw.SetCookies(w, sid, csrfToken, secure)
}
func clearSessionCookies(w http.ResponseWriter, secure bool) {
sessionmw.ClearCookies(w, secure)
}
func isSecure(r *http.Request) bool {
return sessionmw.IsSecure(r)
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,310 @@
package api
import (
"context"
"encoding/json"
"errors"
"log/slog"
"strings"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/redis/go-redis/v9"
"git.omukk.dev/wrenn/wrenn/pkg/audit"
"git.omukk.dev/wrenn/wrenn/pkg/cpextension"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/events"
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
const (
unifiedEventStream = "wrenn:events"
reconcilerConsumerGrp = "wrenn-sandbox-reconciler-v1"
reconcilerConsumer = "cp-0"
)
// SandboxEventConsumer reads capsule lifecycle events from the unified Redis
// stream and drives DB state reconciliation. Uses an independent consumer
// group so its cursor is separate from the channels dispatcher.
type SandboxEventConsumer struct {
rdb *redis.Client
db *db.Queries
audit *audit.AuditLogger
hooks []cpextension.SandboxEventHook
}
// NewSandboxEventConsumer creates a consumer.
func NewSandboxEventConsumer(rdb *redis.Client, queries *db.Queries, al *audit.AuditLogger, hooks []cpextension.SandboxEventHook) *SandboxEventConsumer {
return &SandboxEventConsumer{rdb: rdb, db: queries, audit: al, hooks: hooks}
}
// Start launches the consumer goroutine. Reads from "$" so prior history
// is not replayed.
func (c *SandboxEventConsumer) Start(ctx context.Context) {
go c.run(ctx)
}
func (c *SandboxEventConsumer) run(ctx context.Context) {
err := c.rdb.XGroupCreateMkStream(ctx, unifiedEventStream, reconcilerConsumerGrp, "$").Err()
if err != nil && err.Error() != "BUSYGROUP Consumer Group name already exists" {
slog.Error("sandbox event consumer: failed to create consumer group", "error", err)
return
}
for {
select {
case <-ctx.Done():
return
default:
}
streams, err := c.rdb.XReadGroup(ctx, &redis.XReadGroupArgs{
Group: reconcilerConsumerGrp,
Consumer: reconcilerConsumer,
Streams: []string{unifiedEventStream, ">"},
Count: 10,
Block: 5 * time.Second,
}).Result()
if err != nil {
if err == redis.Nil || ctx.Err() != nil {
continue
}
slog.Warn("sandbox event consumer: xreadgroup error", "error", err)
time.Sleep(1 * time.Second)
continue
}
for _, stream := range streams {
for _, msg := range stream.Messages {
c.handleMessage(ctx, msg)
}
}
}
}
func (c *SandboxEventConsumer) handleMessage(ctx context.Context, msg redis.XMessage) {
ack := true
defer func() {
if !ack {
return
}
ackCtx, ackCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer ackCancel()
if err := c.rdb.XAck(ackCtx, unifiedEventStream, reconcilerConsumerGrp, msg.ID).Err(); err != nil {
slog.Warn("sandbox event consumer: xack failed", "id", msg.ID, "error", err)
}
}()
payload, ok := msg.Values["payload"].(string)
if !ok {
slog.Warn("sandbox event consumer: message missing payload", "id", msg.ID)
return
}
var event events.Event
if err := json.Unmarshal([]byte(payload), &event); err != nil {
slog.Warn("sandbox event consumer: failed to unmarshal event", "id", msg.ID, "error", err)
return
}
// Only capsule.* events drive sandbox reconciliation.
if !strings.HasPrefix(event.Event, "capsule.") || event.Event == events.CapsuleStateChanged {
return
}
// Only system-actor events represent host-side state we need to reflect
// in the DB; user-actor events are already mirrored by the handler that
// produced them.
if event.Actor.Type != events.ActorSystem {
// Exception: handlers publish capsule.create with user actor before
// the host has reported back. Those are owned by the service goroutine.
return
}
sandboxID, err := id.ParseSandboxID(event.Resource.ID)
if err != nil {
slog.Warn("sandbox event consumer: invalid sandbox ID", "sandbox_id", event.Resource.ID, "error", err)
return
}
switch event.Event {
case events.CapsuleCreate:
if event.Outcome == events.OutcomeSuccess {
c.handleStarted(ctx, sandboxID, event, "starting")
} else {
c.handleFailed(ctx, sandboxID, event)
}
case events.CapsuleResume:
if event.Outcome == events.OutcomeSuccess {
c.handleStarted(ctx, sandboxID, event, "resuming")
} else {
c.handleFailed(ctx, sandboxID, event)
}
case events.CapsulePause:
if event.Outcome == events.OutcomeSuccess {
c.handleAutoPaused(ctx, sandboxID)
}
case events.CapsuleDestroy:
if event.Outcome == events.OutcomeSuccess {
c.handleStopped(ctx, sandboxID)
}
}
// Dispatch to extension hooks (cloud billing, audit shipping, etc.). Any
// hook error suppresses the ack so the message will be redelivered. Hooks
// MUST be idempotent — duplicate deliveries are expected on transient
// failures.
if len(c.hooks) > 0 && event.Outcome == events.OutcomeSuccess {
if verb, ok := canonicalSandboxVerb(event.Event); ok {
teamID, _ := id.ParseTeamID(event.TeamID)
meta := map[string]any{}
for k, v := range event.Metadata {
meta[k] = v
}
ev := cpextension.SandboxEvent{
SandboxID: sandboxID,
TeamID: teamID,
Type: verb,
OccurredAt: parseEventTimestamp(event.Timestamp),
Metadata: meta,
}
for _, h := range c.hooks {
if err := h.OnSandboxEvent(ctx, ev); err != nil {
slog.Warn("sandbox event hook failed; leaving message un-acked", "id", msg.ID, "event", event.Event, "error", err)
ack = false
return
}
}
}
}
}
func canonicalSandboxVerb(event string) (string, bool) {
switch event {
case events.CapsuleCreate:
return "created", true
case events.CapsuleResume:
return "resumed", true
case events.CapsulePause:
return "paused", true
case events.CapsuleDestroy:
return "destroyed", true
}
return "", false
}
func parseEventTimestamp(s string) time.Time {
if s == "" {
return time.Now().UTC()
}
t, err := time.Parse(time.RFC3339, s)
if err != nil {
return time.Now().UTC()
}
return t
}
// handleStarted is a fallback writer for capsule.create.success and
// capsule.resume.success. The background goroutine in SandboxService is the
// primary writer; this only succeeds if the goroutine's conditional update
// was missed.
func (c *SandboxEventConsumer) handleStarted(ctx context.Context, sandboxID pgtype.UUID, event events.Event, fromStatus string) {
hostIP := event.Metadata["host_ip"]
now := time.Now()
if _, err := c.db.UpdateSandboxRunningIf(ctx, db.UpdateSandboxRunningIfParams{
ID: sandboxID,
Status: fromStatus,
HostIp: hostIP,
StartedAt: pgtype.Timestamptz{
Time: now,
Valid: true,
},
}); err != nil {
return
}
}
func (c *SandboxEventConsumer) handleAutoPaused(ctx context.Context, sandboxID pgtype.UUID) {
for _, fromStatus := range []string{"running", "pausing"} {
if _, err := c.db.UpdateSandboxStatusIf(ctx, db.UpdateSandboxStatusIfParams{
ID: sandboxID, Status: fromStatus, Status_2: "paused",
}); err == nil {
slog.Debug("sandbox event consumer: auto-paused fallback applied", "sandbox_id", id.FormatSandboxID(sandboxID), "from", fromStatus)
return
}
}
}
func (c *SandboxEventConsumer) handleStopped(ctx context.Context, sandboxID pgtype.UUID) {
// stopping → stopped (CP-initiated destroy completed). No audit row here;
// the handler that issued the destroy already wrote one.
if _, err := c.db.UpdateSandboxStatusIf(ctx, db.UpdateSandboxStatusIfParams{
ID: sandboxID,
Status: "stopping",
Status_2: "stopped",
}); err == nil {
return
}
// running → stopped (autonomous destroy, e.g. TTL destroy fallback).
if _, err := c.db.UpdateSandboxStatusIf(ctx, db.UpdateSandboxStatusIfParams{
ID: sandboxID,
Status: "running",
Status_2: "stopped",
}); err != nil && !errors.Is(err, pgx.ErrNoRows) {
slog.Warn("sandbox event consumer: failed to update sandbox to stopped", "sandbox_id", id.FormatSandboxID(sandboxID), "error", err)
}
}
// handleFailed marks a sandbox as "error" when a verb event reports failure
// and writes a system audit row. The DB update is idempotent — the
// SandboxService background goroutine usually wrote "error" already on the
// fast-fail path, which settles in seconds and so never reaches the
// HostMonitor's transient-timeout reconciliation.
//
// audit.Log writes the row only — it does NOT republish an event, which would
// loop back into this consumer. Do not switch to LogSandboxCreateSystem here.
func (c *SandboxEventConsumer) handleFailed(ctx context.Context, sandboxID pgtype.UUID, event events.Event) {
for _, fromStatus := range []string{"running", "starting", "pausing", "resuming", "snapshotting"} {
if _, err := c.db.UpdateSandboxStatusIf(ctx, db.UpdateSandboxStatusIfParams{
ID: sandboxID, Status: fromStatus, Status_2: "error",
}); err == nil {
break
}
}
// The HostMonitor transient-timeout reconciler emits failure events via
// LogSandboxCreateSystem / LogSandboxResumeSystem, which already write
// their own audit row before publishing — auditing again here would
// double-count. Those helpers publish with reason="transient_timeout";
// the un-audited fast-fail (createInBackground) and host-callback paths
// do not, so only they need a row written here.
if event.Metadata["reason"] == "transient_timeout" {
return
}
action := "create"
if event.Event == events.CapsuleResume {
action = "resume"
}
reason := event.Metadata["reason"]
if reason == "" {
reason = action + "_failed"
}
meta := map[string]any{"reason": reason}
if event.Error != "" {
meta["error"] = event.Error
}
teamID, _ := id.ParseTeamID(event.TeamID)
c.audit.Log(ctx, audit.Entry{
TeamID: teamID,
ActorType: "system",
ResourceType: "sandbox",
ResourceID: id.FormatSandboxID(sandboxID),
Action: action,
Scope: "team",
Status: "error",
Metadata: meta,
})
}

View File

@ -1,6 +1,7 @@
package api
import (
"context"
_ "embed"
"fmt"
"net/http"
@ -13,9 +14,11 @@ import (
"git.omukk.dev/wrenn/wrenn/pkg/audit"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/auth/oauth"
authsession "git.omukk.dev/wrenn/wrenn/pkg/auth/session"
"git.omukk.dev/wrenn/wrenn/pkg/channels"
"git.omukk.dev/wrenn/wrenn/pkg/cpextension"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/events"
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
"git.omukk.dev/wrenn/wrenn/pkg/scheduler"
"git.omukk.dev/wrenn/wrenn/pkg/service"
@ -26,15 +29,19 @@ var openapiYAML []byte
// Server is the control plane HTTP server.
type Server struct {
router chi.Router
BuildSvc *service.BuildService
version string
router chi.Router
BuildSvc *service.BuildService
SSERelay *SSERelay
SessionSvc *authsession.Service
version string
}
// New constructs the chi router and registers all routes.
// Extensions are called after core routes are registered, allowing cloud
// or third-party code to add routes and middleware.
// New constructs the chi router and registers all routes. The jwtSecret is
// still used to sign host-agent JWTs (long-lived, for the wrenn-agent → cp
// trust path) and to HMAC OAuth state/link cookies; user authentication has
// migrated to opaque session cookies backed by the session service.
func New(
ctx context.Context,
queries *db.Queries,
pool *lifecycle.HostClientPool,
sched scheduler.HostScheduler,
@ -45,10 +52,12 @@ func New(
oauthRedirectURL string,
ca *auth.CA,
al *audit.AuditLogger,
eventPub *channels.Publisher,
channelSvc *channels.Service,
mailer email.Mailer,
extensions []cpextension.Extension,
sctx cpextension.ServerContext,
monitor *HostMonitor,
version string,
) *Server {
r := chi.NewRouter()
@ -61,8 +70,28 @@ func New(
}
}
// Session service backs cookie-based browser auth. The cpserver wires it
// through ServerContext so cloud-repo extensions can share the same
// instance; fall back to constructing one here if the host program
// instantiates the API directly (tests, ad-hoc tooling).
sessionSvc := sctx.Sessions
if sessionSvc == nil {
sessionSvc = authsession.NewService(queries, rdb)
}
// Shared service layer.
sandboxSvc := &service.SandboxService{DB: queries, Pool: pool, Scheduler: sched}
sandboxSvc.PublishEvent = func(ctx context.Context, event service.SandboxStateEvent) {
if evt, ok := serviceEventToCanonical(event); ok {
// State-change events are ephemeral UI signals — mirror them to the
// dashboard via Pub/Sub only, never to durable channel subscribers.
if evt.Event == events.CapsuleStateChanged {
eventPub.PublishTransient(ctx, evt)
} else {
eventPub.Publish(ctx, evt)
}
}
}
apiKeySvc := &service.APIKeyService{DB: queries}
templateSvc := &service.TemplateService{DB: queries}
hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret, Pool: pool, CA: ca}
@ -72,30 +101,40 @@ func New(
statsSvc := &service.StatsService{DB: queries, Pool: pgPool}
usageSvc := &service.UsageService{DB: queries}
buildSvc := &service.BuildService{DB: queries, Redis: rdb, Pool: pool, Scheduler: sched}
buildBroker := service.NewBuildBroker(rdb)
sandbox := newSandboxHandler(sandboxSvc, al)
exec := newExecHandler(queries, pool)
execStream := newExecStreamHandler(queries, pool, jwtSecret)
execStream := newExecStreamHandler(queries, pool)
files := newFilesHandler(queries, pool)
filesStream := newFilesStreamHandler(queries, pool)
fsH := newFSHandler(queries, pool)
snapshots := newSnapshotHandler(templateSvc, queries, pool, al)
authH := newAuthHandler(queries, pgPool, jwtSecret, mailer, rdb, oauthRedirectURL)
oauthH := newOAuthHandler(queries, pgPool, jwtSecret, oauthRegistry, oauthRedirectURL)
snapshots := newSnapshotHandler(templateSvc, sandboxSvc, queries, pool, al)
authHooks := collectAuthHooks(extensions)
authH := newAuthHandler(queries, pgPool, sessionSvc, mailer, rdb, oauthRedirectURL, authHooks)
oauthH := newOAuthHandler(queries, pgPool, jwtSecret, sessionSvc, oauthRegistry, oauthRedirectURL, authHooks)
apiKeys := newAPIKeyHandler(apiKeySvc, al)
hostH := newHostHandler(hostSvc, queries, al)
teamH := newTeamHandler(teamSvc, al, mailer)
usersH := newUsersHandler(queries, userSvc, al)
hostH := newHostHandler(hostSvc, queries, al, monitor)
teamH := newTeamHandler(teamSvc, al, mailer, sessionSvc)
usersH := newUsersHandler(queries, userSvc, al, sessionSvc)
auditH := newAuditHandler(auditSvc)
statsH := newStatsHandler(statsSvc)
usageH := newUsageHandler(usageSvc)
metricsH := newSandboxMetricsHandler(queries, pool)
buildH := newBuildHandler(buildSvc, queries, pool, al)
buildStreamH := newBuildStreamHandler(queries, buildBroker)
channelH := newChannelHandler(channelSvc, al)
ptyH := newPtyHandler(queries, pool, jwtSecret)
processH := newProcessHandler(queries, pool, jwtSecret)
ptyH := newPtyHandler(queries, pool)
processH := newProcessHandler(queries, pool)
adminCapsules := newAdminCapsuleHandler(sandboxSvc, queries, pool, al)
meH := newMeHandler(queries, pgPool, rdb, jwtSecret, mailer, oauthRegistry, oauthRedirectURL, teamSvc)
sandboxEvtH := newSandboxEventHandler(queries, eventPub)
meH := newMeHandler(queries, pgPool, rdb, jwtSecret, sessionSvc, mailer, oauthRegistry, oauthRedirectURL, teamSvc, authHooks)
sessionsH := newSessionsHandler(sessionSvc)
// SSE real-time event streaming.
sseBroker := NewSSEBroker()
sseRelay := NewSSERelay(rdb, queries, sseBroker)
sseH := newSSEHandler(sseBroker)
// Health check.
r.Get("/health", func(w http.ResponseWriter, r *http.Request) {
@ -107,7 +146,8 @@ func New(
r.Get("/openapi.yaml", serveOpenAPI)
r.Get("/docs", serveDocs)
// Unauthenticated auth endpoints.
// Unauthenticated auth endpoints. CSRF is not required here — there is
// no session cookie yet to be forged.
r.Post("/v1/auth/signup", authH.Signup)
r.Post("/v1/auth/login", authH.Login)
r.Post("/v1/auth/activate", authH.Activate)
@ -118,31 +158,45 @@ func New(
r.Post("/v1/me/password/reset", meH.RequestPasswordReset)
r.Post("/v1/me/password/reset/confirm", meH.ConfirmPasswordReset)
// JWT-authenticated: self-service account management.
csrf := requireCSRF()
// Session-authenticated: logout endpoints (must be inside the session
// group so the handler sees the current SID via AuthContext).
r.Group(func(r chi.Router) {
r.Use(requireSession(queries, sessionSvc))
r.Use(csrf)
r.Post("/v1/auth/logout", authH.Logout)
r.Post("/v1/auth/logout-all", authH.LogoutAll)
r.Post("/v1/auth/switch-team", authH.SwitchTeam)
})
// Session-authenticated: self-service account management.
r.Route("/v1/me", func(r chi.Router) {
r.Use(requireJWT(jwtSecret, queries))
r.Use(requireSession(queries, sessionSvc))
r.Use(csrf)
r.Get("/", meH.GetMe)
r.Patch("/", meH.UpdateName)
r.Post("/password", meH.ChangePassword)
r.Get("/providers/{provider}/connect", meH.ConnectProvider)
r.Delete("/providers/{provider}", meH.DisconnectProvider)
r.Delete("/", meH.DeleteAccount)
r.Get("/sessions", sessionsH.List)
r.Delete("/sessions/{id}", sessionsH.Delete)
})
// JWT-authenticated: switch active team.
r.With(requireJWT(jwtSecret, queries)).Post("/v1/auth/switch-team", authH.SwitchTeam)
// JWT-authenticated: API key management.
// Session-authenticated: API key management.
r.Route("/v1/api-keys", func(r chi.Router) {
r.Use(requireJWT(jwtSecret, queries))
r.Use(requireSession(queries, sessionSvc))
r.Use(csrf)
r.Post("/", apiKeys.Create)
r.Get("/", apiKeys.List)
r.Delete("/{id}", apiKeys.Delete)
})
// JWT-authenticated: team management.
// Session-authenticated: team management.
r.Route("/v1/teams", func(r chi.Router) {
r.Use(requireJWT(jwtSecret, queries))
r.Use(requireSession(queries, sessionSvc))
r.Use(csrf)
r.Get("/", teamH.List)
r.Post("/", teamH.Create)
r.Route("/{id}", func(r chi.Router) {
@ -157,14 +211,15 @@ func New(
})
})
// JWT-authenticated: user search (for add-member UI).
r.With(requireJWT(jwtSecret, queries)).Get("/v1/users/search", usersH.Search)
// Session-authenticated: user search (for add-member UI).
r.With(requireSession(queries, sessionSvc), csrf).Get("/v1/users/search", usersH.Search)
// Capsule lifecycle: accepts API key or JWT bearer token.
// Capsule lifecycle: API key (SDK) or session cookie (browser).
r.Route("/v1/capsules", func(r chi.Router) {
// Auth-required routes.
r.Group(func(r chi.Router) {
r.Use(requireAPIKeyOrJWT(queries, jwtSecret))
r.Use(requireSessionOrAPIKey(queries, sessionSvc))
r.Use(csrf)
r.Post("/", sandbox.Create)
r.Get("/", sandbox.List)
r.Get("/stats", statsH.GetStats)
@ -174,7 +229,8 @@ func New(
r.Route("/{id}", func(r chi.Router) {
// Auth-required non-WS routes.
r.Group(func(r chi.Router) {
r.Use(requireAPIKeyOrJWT(queries, jwtSecret))
r.Use(requireSessionOrAPIKey(queries, sessionSvc))
r.Use(csrf)
r.Get("/", sandbox.Get)
r.Delete("/", sandbox.Destroy)
r.Post("/exec", exec.Exec)
@ -193,11 +249,11 @@ func New(
r.Delete("/processes/{selector}", processH.KillProcess)
})
// WebSocket endpoints — handlers authenticate after upgrade.
// optionalAPIKeyOrJWT injects auth context from headers when
// present (SDK clients) but does not reject when absent (browsers).
// WebSocket endpoints — middleware injects auth context from
// cookie (browser) or X-API-Key (SDK) before upgrade. CSRF is
// not applicable to GET upgrades.
r.Group(func(r chi.Router) {
r.Use(optionalAPIKeyOrJWT(queries, jwtSecret))
r.Use(requireSessionOrAPIKey(queries, sessionSvc))
r.Get("/exec/stream", execStream.ExecStream)
r.Get("/pty", ptyH.PtySession)
r.Get("/processes/{selector}/stream", processH.ConnectProcess)
@ -205,9 +261,10 @@ func New(
})
})
// Snapshot / template management: accepts API key or JWT bearer token.
// Snapshot / template management: API key (SDK) or session (browser).
r.Route("/v1/snapshots", func(r chi.Router) {
r.Use(requireAPIKeyOrJWT(queries, jwtSecret))
r.Use(requireSessionOrAPIKey(queries, sessionSvc))
r.Use(csrf)
r.Post("/", snapshots.Create)
r.Get("/", snapshots.List)
r.Delete("/{name}", snapshots.Delete)
@ -221,12 +278,14 @@ func New(
// Unauthenticated: refresh token exchange.
r.Post("/auth/refresh", hostH.RefreshToken)
// Host-token-authenticated: heartbeat.
// Host-token-authenticated: heartbeat and lifecycle callbacks.
r.With(requireHostToken(jwtSecret)).Post("/{id}/heartbeat", hostH.Heartbeat)
r.With(requireHostToken(jwtSecret)).Post("/sandbox-events", sandboxEvtH.Handle)
// JWT-authenticated: host CRUD and tags.
// Session-authenticated: host CRUD and tags.
r.Group(func(r chi.Router) {
r.Use(requireJWT(jwtSecret, queries))
r.Use(requireSession(queries, sessionSvc))
r.Use(csrf)
r.Post("/", hostH.Create)
r.Get("/", hostH.List)
r.Route("/{id}", func(r chi.Router) {
@ -241,9 +300,10 @@ func New(
})
})
// JWT-authenticated: notification channels.
// Session-authenticated: notification channels.
r.Route("/v1/channels", func(r chi.Router) {
r.Use(requireJWT(jwtSecret, queries))
r.Use(requireSession(queries, sessionSvc))
r.Use(csrf)
r.Post("/", channelH.Create)
r.Get("/", channelH.List)
r.Post("/test", channelH.Test)
@ -255,18 +315,27 @@ func New(
})
})
// JWT-authenticated: audit log.
r.With(requireJWT(jwtSecret, queries)).Get("/v1/audit-logs", auditH.List)
// SSE event stream: browser sends wrenn_sid cookie on EventSource
// automatically; SDKs may set X-API-Key. Ticket-based auth is gone.
r.With(requireSessionOrAPIKey(queries, sessionSvc)).Get("/v1/events/stream", sseH.Stream)
// Platform admin routes — require JWT + DB-validated admin status.
// Session-authenticated: audit log.
r.With(requireSession(queries, sessionSvc), csrf).Get("/v1/audit-logs", auditH.List)
// Platform admin routes — session + DB-validated admin status.
r.Route("/v1/admin", func(r chi.Router) {
// Admin SSE event stream (sees all teams).
r.With(requireSession(queries, sessionSvc), requireAdmin(queries), injectPlatformTeam()).Get("/events/stream", sseH.AdminStream)
// Auth-required admin routes (non-capsule + capsule list/create).
r.Group(func(r chi.Router) {
r.Use(requireJWT(jwtSecret, queries))
r.Use(requireSession(queries, sessionSvc))
r.Use(requireAdmin(queries))
r.Use(csrf)
r.Get("/teams", teamH.AdminListTeams)
r.Put("/teams/{id}/byoc", teamH.SetBYOC)
r.Delete("/teams/{id}", teamH.AdminDeleteTeam)
r.Get("/hosts", hostH.AdminList)
r.Get("/users", usersH.AdminListUsers)
r.Put("/users/{id}/active", usersH.SetUserActive)
r.Put("/users/{id}/admin", usersH.SetUserAdmin)
@ -281,11 +350,21 @@ func New(
r.Get("/capsules", adminCapsules.List)
})
// Admin build console WebSocket — cookie + admin DB check before
// upgrade, no CSRF (WS upgrade). Builds are platform-scoped, not
// sandbox-scoped, so this sits outside the /capsules/{id} router.
r.Group(func(r chi.Router) {
r.Use(requireSession(queries, sessionSvc))
r.Use(requireAdmin(queries))
r.Get("/builds/{id}/stream", buildStreamH.Stream)
})
r.Route("/capsules/{id}", func(r chi.Router) {
// Auth-required non-WS admin capsule routes.
r.Group(func(r chi.Router) {
r.Use(requireJWT(jwtSecret, queries))
r.Use(requireSession(queries, sessionSvc))
r.Use(requireAdmin(queries))
r.Use(csrf)
r.Use(injectPlatformTeam())
r.Get("/", adminCapsules.Get)
r.Delete("/", adminCapsules.Destroy)
@ -301,11 +380,13 @@ func New(
r.Delete("/processes/{selector}", processH.KillProcess)
})
// Admin WebSocket endpoints — handlers authenticate after upgrade
// via wsAuthenticateAdmin. markAdminWS sets the context flag so
// handlers know to use admin auth instead of regular auth.
// Admin WebSocket endpoints — browser auth via cookie + admin DB check
// before upgrade. markAdminWS is retained as a context hint for any
// admin-specific behavior downstream.
r.Group(func(r chi.Router) {
r.Use(markAdminWS)
r.Use(requireSession(queries, sessionSvc))
r.Use(requireAdmin(queries))
r.Use(injectPlatformTeam())
r.Get("/exec/stream", execStream.ExecStream)
r.Get("/pty", ptyH.PtySession)
r.Get("/processes/{selector}/stream", processH.ConnectProcess)
@ -318,7 +399,7 @@ func New(
ext.RegisterRoutes(r, sctx)
}
return &Server{router: r, BuildSvc: buildSvc, version: version}
return &Server{router: r, BuildSvc: buildSvc, SSERelay: sseRelay, SessionSvc: sessionSvc, version: version}
}
// Handler returns the HTTP handler.
@ -363,3 +444,101 @@ func serveDocs(w http.ResponseWriter, r *http.Request) {
</body>
</html>`)
}
// serviceEventToCanonical maps a SandboxService background-goroutine event
// into the canonical events.Event taxonomy for unified publishing. Returns
// false for events that should not be broadcast.
func serviceEventToCanonical(e service.SandboxStateEvent) (events.Event, bool) {
var (
eventType string
outcome events.Outcome
metadata map[string]string
)
switch e.Event {
case "sandbox.started":
eventType = events.CapsuleCreate
outcome = events.OutcomeSuccess
case "sandbox.resumed":
eventType = events.CapsuleResume
outcome = events.OutcomeSuccess
case "sandbox.paused":
eventType = events.CapsulePause
outcome = events.OutcomeSuccess
case "sandbox.auto_paused":
eventType = events.CapsulePause
outcome = events.OutcomeSuccess
metadata = map[string]string{"reason": "ttl_expired"}
case "sandbox.stopped":
eventType = events.CapsuleDestroy
outcome = events.OutcomeSuccess
case "sandbox.pause_failed":
// reason must be non-empty or channels.isRedundantSystemFollowup
// filters this system-actor event out of webhook delivery.
eventType = events.CapsulePause
outcome = events.OutcomeError
metadata = map[string]string{"reason": "pause_failed"}
case "sandbox.resume_failed":
eventType = events.CapsuleResume
outcome = events.OutcomeError
metadata = map[string]string{"reason": "resume_failed"}
case "sandbox.failed":
// First-boot failure from the createInBackground goroutine. Without
// this case the event falls through to default and is dropped — no
// SSE push, no channel delivery, no DB reconciliation. reason must be
// non-empty or channels.isRedundantSystemFollowup filters it out.
eventType = events.CapsuleCreate
outcome = events.OutcomeError
metadata = map[string]string{"reason": "create_failed"}
case "sandbox.snapshotted":
// Completion of an async snapshot. The resource is the template name,
// not the sandbox, so the dashboard's snapshot list refreshes.
return events.Event{
Event: events.SnapshotCreate,
Outcome: events.OutcomeSuccess,
Timestamp: events.Now(),
TeamID: e.TeamID,
Actor: events.SystemActor(),
Resource: events.Resource{ID: e.Metadata["name"], Type: "snapshot"},
}, true
case "sandbox.snapshot_failed":
return events.Event{
Event: events.SnapshotCreate,
Outcome: events.OutcomeError,
Timestamp: events.Now(),
TeamID: e.TeamID,
Actor: events.SystemActor(),
Resource: events.Resource{ID: e.Metadata["name"], Type: "snapshot"},
Metadata: map[string]string{"reason": "snapshot_failed"},
Error: e.Error,
}, true
case "sandbox.state_changed":
// Transient badge transition with no terminal verb of its own. Carries
// from/to in metadata; routed via Pub/Sub only by the caller.
return events.Event{
Event: events.CapsuleStateChanged,
Timestamp: events.Now(),
TeamID: e.TeamID,
Actor: events.SystemActor(),
Resource: events.Resource{ID: e.SandboxID, Type: "sandbox"},
Metadata: e.Metadata,
}, true
default:
return events.Event{}, false
}
if e.HostIP != "" {
if metadata == nil {
metadata = map[string]string{}
}
metadata["host_ip"] = e.HostIP
}
return events.Event{
Event: eventType,
Outcome: outcome,
Timestamp: events.Now(),
TeamID: e.TeamID,
Actor: events.SystemActor(),
Resource: events.Resource{ID: e.SandboxID, Type: "sandbox"},
Metadata: metadata,
Error: e.Error,
}, true
}

View File

@ -0,0 +1,80 @@
package api
import (
"encoding/json"
"log/slog"
"sync"
"sync/atomic"
)
const sseChannelBuffer = 32
type sseMessage struct {
EventType string
Data json.RawMessage
}
type sseSubscriber struct {
teamID string
isAdmin bool
ch chan sseMessage
}
// SSEBroker is an in-process fan-out hub that dispatches events to connected
// SSE clients, filtering by team ownership.
type SSEBroker struct {
mu sync.RWMutex
nextID atomic.Uint64
subscribers map[uint64]*sseSubscriber
}
// NewSSEBroker constructs a broker with no subscribers.
func NewSSEBroker() *SSEBroker {
return &SSEBroker{
subscribers: make(map[uint64]*sseSubscriber),
}
}
// Subscribe registers a new SSE client. Returns a subscriber ID (for
// Unsubscribe) and a receive-only channel that delivers filtered events.
func (b *SSEBroker) Subscribe(teamID string, isAdmin bool) (uint64, <-chan sseMessage) {
id := b.nextID.Add(1)
ch := make(chan sseMessage, sseChannelBuffer)
sub := &sseSubscriber{teamID: teamID, isAdmin: isAdmin, ch: ch}
b.mu.Lock()
b.subscribers[id] = sub
b.mu.Unlock()
slog.Debug("sse: client subscribed", "sub_id", id, "team_id", teamID, "admin", isAdmin)
return id, ch
}
// Unsubscribe removes a client. The handler loop exits via context cancellation;
// the channel is not closed here to avoid send-on-closed-channel races with Dispatch.
func (b *SSEBroker) Unsubscribe(id uint64) {
b.mu.Lock()
delete(b.subscribers, id)
b.mu.Unlock()
slog.Debug("sse: client unsubscribed", "sub_id", id)
}
// Dispatch fans out an event to all matching subscribers. Admin subscribers
// receive all events; team subscribers only receive events for their team.
// Non-blocking: events are dropped for slow consumers.
func (b *SSEBroker) Dispatch(eventType string, teamID string, data json.RawMessage) {
b.mu.RLock()
defer b.mu.RUnlock()
msg := sseMessage{EventType: eventType, Data: data}
for id, sub := range b.subscribers {
if !sub.isAdmin && sub.teamID != teamID {
continue
}
select {
case sub.ch <- msg:
default:
slog.Warn("sse: dropped event for slow consumer", "sub_id", id, "event", eventType)
}
}
}

147
internal/api/sse_relay.go Normal file
View File

@ -0,0 +1,147 @@
package api
import (
"context"
"encoding/json"
"log/slog"
"time"
"github.com/jackc/pgx/v5"
"github.com/redis/go-redis/v9"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/events"
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
const ssePubSubChannel = "wrenn:sse"
// sseEventPayload is the JSON envelope sent to SSE clients.
type sseEventPayload struct {
Event string `json:"event"`
Outcome events.Outcome `json:"outcome,omitempty"`
Timestamp string `json:"timestamp"`
TeamID string `json:"team_id"`
Actor events.Actor `json:"actor"`
Resource events.Resource `json:"resource"`
Metadata map[string]string `json:"metadata,omitempty"`
Error string `json:"error,omitempty"`
Sandbox *sandboxResponse `json:"sandbox,omitempty"`
}
// SSERelay subscribes to the Redis Pub/Sub channel and dispatches hydrated
// events to the in-process broker. One instance per CP process.
type SSERelay struct {
rdb *redis.Client
db *db.Queries
broker *SSEBroker
}
// NewSSERelay constructs the relay.
func NewSSERelay(rdb *redis.Client, queries *db.Queries, broker *SSEBroker) *SSERelay {
return &SSERelay{rdb: rdb, db: queries, broker: broker}
}
// Start launches the Pub/Sub subscription goroutine. Returns when ctx is cancelled.
func (r *SSERelay) Start(ctx context.Context) {
go r.run(ctx)
}
func (r *SSERelay) run(ctx context.Context) {
for {
if ctx.Err() != nil {
return
}
r.subscribe(ctx)
// Backoff before reconnecting.
select {
case <-ctx.Done():
return
case <-time.After(2 * time.Second):
}
}
}
func (r *SSERelay) subscribe(ctx context.Context) {
pubsub := r.rdb.Subscribe(ctx, ssePubSubChannel)
defer pubsub.Close()
ch := pubsub.Channel()
for {
select {
case <-ctx.Done():
return
case msg, ok := <-ch:
if !ok {
return
}
r.handleMessage(ctx, msg)
}
}
}
func (r *SSERelay) handleMessage(ctx context.Context, msg *redis.Message) {
var event events.Event
if err := json.Unmarshal([]byte(msg.Payload), &event); err != nil {
slog.Warn("sse relay: failed to unmarshal event", "error", err)
return
}
payload := sseEventPayload{
Event: event.Event,
Outcome: event.Outcome,
Timestamp: event.Timestamp,
TeamID: event.TeamID,
Actor: event.Actor,
Resource: event.Resource,
Metadata: event.Metadata,
Error: event.Error,
}
// Hydrate sandbox state for capsule events.
if isCapsuleEvent(event.Event) {
sb, err := r.hydrateSandbox(ctx, event.Resource.ID)
if err != nil {
slog.Debug("sse relay: sandbox hydration failed (may be deleted)", "sandbox_id", event.Resource.ID, "error", err)
} else {
payload.Sandbox = sb
}
}
data, err := json.Marshal(payload)
if err != nil {
slog.Warn("sse relay: failed to marshal payload", "error", err)
return
}
r.broker.Dispatch(event.Event, event.TeamID, data)
}
func (r *SSERelay) hydrateSandbox(ctx context.Context, sandboxIDStr string) (*sandboxResponse, error) {
queryCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
return nil, err
}
sb, err := r.db.GetSandbox(queryCtx, sandboxID)
if err != nil {
if err == pgx.ErrNoRows {
return nil, nil
}
return nil, err
}
resp := sandboxToResponse(sb)
return &resp, nil
}
func isCapsuleEvent(eventType string) bool {
switch eventType {
case events.CapsuleCreate, events.CapsulePause, events.CapsuleResume, events.CapsuleDestroy, events.CapsuleStateChanged:
return true
}
return false
}