forked from wrenn/wrenn
v0.2.0 (#50)
Co-authored-by: Tasnim Kabir Sadik <tksadik@omukk.dev> Reviewed-on: wrenn/wrenn#50
This commit is contained in:
@ -3,11 +3,20 @@ package api
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||
"git.omukk.dev/wrenn/wrenn/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
@ -20,3 +29,119 @@ func agentForHost(ctx context.Context, queries *db.Queries, pool *lifecycle.Host
|
||||
}
|
||||
return pool.GetForHost(host)
|
||||
}
|
||||
|
||||
// requireRunningSandbox parses the sandbox ID from the URL, looks it up by team,
|
||||
// and verifies it is running. On failure it writes the appropriate HTTP error and
|
||||
// returns false.
|
||||
func requireRunningSandbox(w http.ResponseWriter, r *http.Request, queries *db.Queries, teamID pgtype.UUID) (db.Sandbox, pgtype.UUID, string, bool) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return db.Sandbox{}, pgtype.UUID{}, "", false
|
||||
}
|
||||
|
||||
sb, err := queries.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return db.Sandbox{}, pgtype.UUID{}, "", false
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
|
||||
return db.Sandbox{}, pgtype.UUID{}, "", false
|
||||
}
|
||||
|
||||
return sb, sandboxID, sandboxIDStr, true
|
||||
}
|
||||
|
||||
// upgradeAndAuthenticate upgrades the HTTP connection to WebSocket. The
|
||||
// auth context must already be populated by upstream middleware — browser
|
||||
// clients via the wrenn_sid cookie (sent automatically on WS upgrade),
|
||||
// SDK clients via X-API-Key. Requests without an auth context are rejected
|
||||
// with a 401 before the upgrade.
|
||||
func upgradeAndAuthenticate(w http.ResponseWriter, r *http.Request) (*websocket.Conn, auth.AuthContext, error) {
|
||||
ac, hasAuth := auth.FromContext(r.Context())
|
||||
if !hasAuth {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "session cookie or X-API-Key required")
|
||||
return nil, auth.AuthContext{}, fmt.Errorf("unauthenticated")
|
||||
}
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return nil, auth.AuthContext{}, fmt.Errorf("websocket upgrade: %w", err)
|
||||
}
|
||||
return conn, ac, nil
|
||||
}
|
||||
|
||||
// resolveTemplateSizes queries a host agent for the actual disk usage of any
|
||||
// templates with size_bytes <= 0 (e.g. system base templates seeded with
|
||||
// size_bytes = 0 before the rootfs was built). Results are persisted to the
|
||||
// DB so subsequent requests serve the correct size without an RPC call.
|
||||
// Errors are logged but do not prevent the caller from serving the templates.
|
||||
func resolveTemplateSizes(ctx context.Context, queries *db.Queries, pool *lifecycle.HostClientPool, templates []db.Template) []db.Template {
|
||||
needResolve := false
|
||||
for _, t := range templates {
|
||||
if t.SizeBytes <= 0 {
|
||||
needResolve = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !needResolve {
|
||||
return templates
|
||||
}
|
||||
|
||||
hosts, err := queries.ListActiveHosts(ctx)
|
||||
if err != nil || len(hosts) == 0 {
|
||||
slog.Warn("resolveTemplateSizes: no active hosts available", "error", err)
|
||||
return templates
|
||||
}
|
||||
|
||||
agent, err := pool.GetForHost(hosts[0])
|
||||
if err != nil {
|
||||
slog.Warn("resolveTemplateSizes: failed to connect to host",
|
||||
"host_id", id.UUIDString(hosts[0].ID), "error", err)
|
||||
return templates
|
||||
}
|
||||
|
||||
for i, t := range templates {
|
||||
if t.SizeBytes > 0 {
|
||||
continue
|
||||
}
|
||||
resp, err := agent.GetTemplateSize(ctx, connect.NewRequest(&pb.GetTemplateSizeRequest{
|
||||
TeamId: formatUUIDForRPC(t.TeamID),
|
||||
TemplateId: formatUUIDForRPC(t.ID),
|
||||
}))
|
||||
if err != nil {
|
||||
slog.Warn("resolveTemplateSizes: failed to get size from host",
|
||||
"template", t.Name, "error", err)
|
||||
continue
|
||||
}
|
||||
templates[i].SizeBytes = resp.Msg.SizeBytes
|
||||
if err := queries.UpdateTemplateSize(ctx, db.UpdateTemplateSizeParams{
|
||||
ID: t.ID,
|
||||
SizeBytes: resp.Msg.SizeBytes,
|
||||
}); err != nil {
|
||||
slog.Warn("resolveTemplateSizes: failed to persist size",
|
||||
"template", t.Name, "error", err)
|
||||
}
|
||||
}
|
||||
return templates
|
||||
}
|
||||
|
||||
// updateLastActive updates the sandbox last_active_at timestamp.
|
||||
// Uses a background context with timeout for streaming handlers where
|
||||
// the request context may already be cancelled.
|
||||
func updateLastActive(queries *db.Queries, sandboxID pgtype.UUID, sandboxIDStr string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := queries.UpdateLastActive(ctx, db.UpdateLastActiveParams{
|
||||
ID: sandboxID,
|
||||
LastActiveAt: pgtype.Timestamptz{
|
||||
Time: time.Now(),
|
||||
Valid: true,
|
||||
},
|
||||
}); err != nil {
|
||||
slog.Warn("failed to update last_active_at", "id", sandboxIDStr, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
54
internal/api/auth_hooks.go
Normal file
54
internal/api/auth_hooks.go
Normal file
@ -0,0 +1,54 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/cpextension"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
// collectAuthHooks filters the extension list down to those that implement
|
||||
// cpextension.AuthHook.
|
||||
func collectAuthHooks(extensions []cpextension.Extension) []cpextension.AuthHook {
|
||||
var hooks []cpextension.AuthHook
|
||||
for _, ext := range extensions {
|
||||
if h, ok := ext.(cpextension.AuthHook); ok {
|
||||
hooks = append(hooks, h)
|
||||
}
|
||||
}
|
||||
return hooks
|
||||
}
|
||||
|
||||
// fireOnSignup runs every OnSignup hook sequentially. The first error wins —
|
||||
// signup must abort so the cloud-side billing customer creation stays the
|
||||
// gating step. Other hook errors after a failure are not run.
|
||||
func fireOnSignup(ctx context.Context, hooks []cpextension.AuthHook, userID, teamID pgtype.UUID, email string) error {
|
||||
for _, h := range hooks {
|
||||
if err := h.OnSignup(ctx, userID, teamID, email); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// fireOnLogin runs every OnLogin hook. Errors are logged and swallowed —
|
||||
// login must never be blocked by a misbehaving extension.
|
||||
func fireOnLogin(ctx context.Context, hooks []cpextension.AuthHook, userID pgtype.UUID) {
|
||||
for _, h := range hooks {
|
||||
if err := h.OnLogin(ctx, userID); err != nil {
|
||||
slog.Warn("auth hook OnLogin failed", "user_id", id.FormatUserID(userID), "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// fireOnSoftDelete runs every OnAccountSoftDelete hook. Errors are logged.
|
||||
func fireOnSoftDelete(ctx context.Context, hooks []cpextension.AuthHook, userID pgtype.UUID) {
|
||||
for _, h := range hooks {
|
||||
if err := h.OnAccountSoftDelete(ctx, userID); err != nil {
|
||||
slog.Warn("auth hook OnAccountSoftDelete failed", "user_id", id.FormatUserID(userID), "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,14 +1,9 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/audit"
|
||||
@ -17,8 +12,6 @@ import (
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/service"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/validate"
|
||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||
)
|
||||
|
||||
type adminCapsuleHandler struct {
|
||||
@ -49,15 +42,18 @@ func (h *adminCapsuleHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
MemoryMB: req.MemoryMB,
|
||||
TimeoutSec: req.TimeoutSec,
|
||||
})
|
||||
ac.TeamID = id.PlatformTeamID
|
||||
h.audit.LogSandboxCreate(r.Context(), ac, sb.ID, req.Template, err)
|
||||
if err != nil {
|
||||
if sb.ID.Valid {
|
||||
h.audit.LogSandboxDestroySystem(r.Context(), id.PlatformTeamID, sb.ID, "cleanup_after_create_error", nil)
|
||||
}
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
ac.TeamID = id.PlatformTeamID
|
||||
h.audit.LogSandboxCreate(r.Context(), ac, sb.ID, sb.Template)
|
||||
writeJSON(w, http.StatusCreated, sandboxToResponse(sb))
|
||||
writeJSON(w, http.StatusAccepted, sandboxToResponse(sb))
|
||||
}
|
||||
|
||||
// List handles GET /v1/admin/capsules.
|
||||
@ -106,26 +102,27 @@ func (h *adminCapsuleHandler) Destroy(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.svc.Destroy(r.Context(), sandboxID, id.PlatformTeamID); err != nil {
|
||||
ac.TeamID = id.PlatformTeamID
|
||||
err = h.svc.Destroy(r.Context(), sandboxID, id.PlatformTeamID)
|
||||
h.audit.LogSandboxDestroy(r.Context(), ac, sandboxID, err)
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogSandboxDestroy(r.Context(), ac, sandboxID)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
}
|
||||
|
||||
type adminSnapshotRequest struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// Snapshot handles POST /v1/admin/capsules/{id}/snapshot.
|
||||
// Pauses the capsule, takes a snapshot as a platform template, then destroys the capsule.
|
||||
// Snapshot handles POST /v1/admin/capsules/{id}/snapshot. Takes a live
|
||||
// snapshot of a platform-owned capsule and registers the result as a
|
||||
// platform template (team_id = 00000000-...).
|
||||
func (h *adminCapsuleHandler) Snapshot(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
@ -133,117 +130,22 @@ func (h *adminCapsuleHandler) Snapshot(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
var req adminSnapshotRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
req.Name = id.NewSnapshotName()
|
||||
}
|
||||
if err := validate.SafeName(req.Name); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", fmt.Sprintf("invalid snapshot name: %s", err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
// Verify sandbox exists and belongs to platform team BEFORE any
|
||||
// destructive operations (template overwrite).
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: id.PlatformTeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" && sb.Status != "paused" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox must be running or paused")
|
||||
return
|
||||
}
|
||||
|
||||
// Check if name already exists as a platform template.
|
||||
if existing, err := h.db.GetPlatformTemplateByName(ctx, req.Name); err == nil {
|
||||
// Delete old snapshot files from all hosts before removing the DB record.
|
||||
if err := deleteSnapshotBroadcast(ctx, h.db, h.pool, existing.TeamID, existing.ID); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete existing snapshot files")
|
||||
return
|
||||
}
|
||||
if err := h.db.DeleteTemplateByTeam(ctx, db.DeleteTemplateByTeamParams{Name: req.Name, TeamID: id.PlatformTeamID}); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to remove existing template record")
|
||||
if r.ContentLength > 0 {
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||
sb, name, err := h.svc.CreateSnapshot(r.Context(), sandboxID, id.PlatformTeamID, req.Name)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
|
||||
return
|
||||
}
|
||||
|
||||
// Pre-mark sandbox as "paused" to prevent the reconciler from racing.
|
||||
if sb.Status == "running" {
|
||||
if _, err := h.db.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
|
||||
ID: sandboxID, Status: "paused",
|
||||
}); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to update sandbox status")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Use a detached context so the snapshot completes even if the client disconnects.
|
||||
snapCtx, snapCancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer snapCancel()
|
||||
|
||||
newTemplateID := id.NewTemplateID()
|
||||
|
||||
resp, err := agent.CreateSnapshot(snapCtx, connect.NewRequest(&pb.CreateSnapshotRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Name: req.Name,
|
||||
TeamId: formatUUIDForRPC(id.PlatformTeamID),
|
||||
TemplateId: formatUUIDForRPC(newTemplateID),
|
||||
}))
|
||||
if err != nil {
|
||||
// Snapshot failed — revert status.
|
||||
if sb.Status == "running" {
|
||||
if _, dbErr := h.db.UpdateSandboxStatus(snapCtx, db.UpdateSandboxStatusParams{
|
||||
ID: sandboxID, Status: "running",
|
||||
}); dbErr != nil {
|
||||
slog.Error("failed to revert sandbox status after snapshot error", "sandbox_id", sandboxIDStr, "error", dbErr)
|
||||
}
|
||||
}
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
ac.TeamID = id.PlatformTeamID
|
||||
h.audit.LogSnapshotCreateRequested(r.Context(), ac, name)
|
||||
|
||||
tmpl, err := h.db.InsertTemplate(snapCtx, db.InsertTemplateParams{
|
||||
ID: newTemplateID,
|
||||
Name: req.Name,
|
||||
Type: "snapshot",
|
||||
Vcpus: sb.Vcpus,
|
||||
MemoryMb: sb.MemoryMb,
|
||||
SizeBytes: resp.Msg.SizeBytes,
|
||||
TeamID: id.PlatformTeamID,
|
||||
DefaultUser: "root",
|
||||
DefaultEnv: []byte("{}"),
|
||||
Metadata: sb.Metadata,
|
||||
})
|
||||
if err != nil {
|
||||
slog.Error("failed to insert template record", "name", req.Name, "error", err)
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "snapshot created but failed to record in database")
|
||||
return
|
||||
}
|
||||
|
||||
// Destroy the ephemeral capsule after successful snapshot.
|
||||
if err := h.svc.Destroy(snapCtx, sandboxID, id.PlatformTeamID); err != nil {
|
||||
slog.Error("failed to destroy capsule after snapshot", "sandbox_id", sandboxIDStr, "error", err)
|
||||
// Don't fail the response — the snapshot was created successfully.
|
||||
}
|
||||
|
||||
h.audit.LogSnapshotCreate(snapCtx, ac, req.Name)
|
||||
|
||||
if ctx.Err() != nil {
|
||||
slog.Info("snapshot created but client disconnected before response", "name", req.Name)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusCreated, templateToResponse(tmpl))
|
||||
writeJSON(w, http.StatusAccepted, sandboxToResponse(sb))
|
||||
}
|
||||
|
||||
@ -20,6 +20,8 @@ import (
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/email"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth/session"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/cpextension"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
@ -140,14 +142,15 @@ type switchTeamRequest struct {
|
||||
type authHandler struct {
|
||||
db *db.Queries
|
||||
pool *pgxpool.Pool
|
||||
jwtSecret []byte
|
||||
sessions *session.Service
|
||||
mailer email.Mailer
|
||||
rdb *redis.Client
|
||||
redirectURL string
|
||||
authHooks []cpextension.AuthHook
|
||||
}
|
||||
|
||||
func newAuthHandler(db *db.Queries, pool *pgxpool.Pool, jwtSecret []byte, mailer email.Mailer, rdb *redis.Client, redirectURL string) *authHandler {
|
||||
return &authHandler{db: db, pool: pool, jwtSecret: jwtSecret, mailer: mailer, rdb: rdb, redirectURL: strings.TrimRight(redirectURL, "/")}
|
||||
func newAuthHandler(db *db.Queries, pool *pgxpool.Pool, sessions *session.Service, mailer email.Mailer, rdb *redis.Client, redirectURL string, hooks []cpextension.AuthHook) *authHandler {
|
||||
return &authHandler{db: db, pool: pool, sessions: sessions, mailer: mailer, rdb: rdb, redirectURL: strings.TrimRight(redirectURL, "/"), authHooks: hooks}
|
||||
}
|
||||
|
||||
type signupRequest struct {
|
||||
@ -166,11 +169,41 @@ type activateRequest struct {
|
||||
}
|
||||
|
||||
type authResponse struct {
|
||||
Token string `json:"token"`
|
||||
UserID string `json:"user_id"`
|
||||
TeamID string `json:"team_id"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
UserID string `json:"user_id"`
|
||||
TeamID string `json:"team_id"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
Role string `json:"role"`
|
||||
IsAdmin bool `json:"is_admin"`
|
||||
}
|
||||
|
||||
// issueSession creates a new session and writes both the session and CSRF
|
||||
// cookies to the response. On success it writes the authResponse JSON body.
|
||||
func (h *authHandler) issueSession(
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
userID, teamID pgtype.UUID,
|
||||
email, name, role string,
|
||||
isAdmin bool,
|
||||
) error {
|
||||
sess, err := h.sessions.Create(r.Context(), userID, teamID, email, name, role, isAdmin, r.UserAgent(), clientIP(r))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
setSessionCookies(w, sess.RawSID, sess.CSRFToken, isSecure(r))
|
||||
return nil
|
||||
}
|
||||
|
||||
// clientIP returns the request's apparent client IP, honoring
|
||||
// X-Forwarded-For when behind a reverse proxy.
|
||||
func clientIP(r *http.Request) string {
|
||||
if fwd := r.Header.Get("X-Forwarded-For"); fwd != "" {
|
||||
if i := strings.IndexByte(fwd, ','); i > 0 {
|
||||
return strings.TrimSpace(fwd[:i])
|
||||
}
|
||||
return strings.TrimSpace(fwd)
|
||||
}
|
||||
return r.RemoteAddr
|
||||
}
|
||||
|
||||
type signupResponse struct {
|
||||
@ -253,8 +286,8 @@ func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// Generate activation token and store in Redis.
|
||||
rawToken := generateActivationToken()
|
||||
tokenHash := hashActivationToken(rawToken)
|
||||
rawToken := generateOpaqueToken()
|
||||
tokenHash := hashOpaqueToken(rawToken)
|
||||
redisKey := activationKeyPrefix + tokenHash
|
||||
|
||||
if err := h.rdb.Set(ctx, redisKey, id.FormatUserID(userID), activationTTL).Err(); err != nil {
|
||||
@ -296,7 +329,7 @@ func (h *authHandler) Activate(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
tokenHash := hashActivationToken(req.Token)
|
||||
tokenHash := hashOpaqueToken(req.Token)
|
||||
redisKey := activationKeyPrefix + tokenHash
|
||||
|
||||
userIDStr, err := h.rdb.GetDel(ctx, redisKey).Result()
|
||||
@ -345,18 +378,26 @@ func (h *authHandler) Activate(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
isAdmin := user.IsAdmin || isFirstUser
|
||||
token, err := auth.SignJWT(h.jwtSecret, userID, team.ID, user.Email, user.Name, role, isAdmin)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
|
||||
// Fire OnSignup before issuing a session — billing must succeed first.
|
||||
if err := fireOnSignup(ctx, h.authHooks, userID, team.ID, user.Email); err != nil {
|
||||
slog.Error("activate: OnSignup hook failed", "user_id", id.FormatUserID(userID), "error", err)
|
||||
writeError(w, http.StatusInternalServerError, "signup_hook_failed", "failed to finalize account setup")
|
||||
return
|
||||
}
|
||||
if err := h.issueSession(w, r, userID, team.ID, user.Email, user.Name, role, isAdmin); err != nil {
|
||||
slog.Error("activate: failed to issue session", "error", err)
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to create session")
|
||||
return
|
||||
}
|
||||
fireOnLogin(ctx, h.authHooks, userID)
|
||||
|
||||
writeJSON(w, http.StatusOK, authResponse{
|
||||
Token: token,
|
||||
UserID: id.FormatUserID(userID),
|
||||
TeamID: id.FormatTeamID(team.ID),
|
||||
Email: user.Email,
|
||||
Name: user.Name,
|
||||
UserID: id.FormatUserID(userID),
|
||||
TeamID: id.FormatTeamID(team.ID),
|
||||
Email: user.Email,
|
||||
Name: user.Name,
|
||||
Role: role,
|
||||
IsAdmin: isAdmin,
|
||||
})
|
||||
}
|
||||
|
||||
@ -427,18 +468,20 @@ func (h *authHandler) Login(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
isAdmin := user.IsAdmin || isFirstUser
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, isAdmin)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
|
||||
if err := h.issueSession(w, r, user.ID, team.ID, user.Email, user.Name, role, isAdmin); err != nil {
|
||||
slog.Error("login: failed to issue session", "error", err)
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to create session")
|
||||
return
|
||||
}
|
||||
fireOnLogin(ctx, h.authHooks, user.ID)
|
||||
|
||||
writeJSON(w, http.StatusOK, authResponse{
|
||||
Token: token,
|
||||
UserID: id.FormatUserID(user.ID),
|
||||
TeamID: id.FormatTeamID(team.ID),
|
||||
Email: user.Email,
|
||||
Name: user.Name,
|
||||
UserID: id.FormatUserID(user.ID),
|
||||
TeamID: id.FormatTeamID(team.ID),
|
||||
Email: user.Email,
|
||||
Name: user.Name,
|
||||
Role: role,
|
||||
IsAdmin: isAdmin,
|
||||
})
|
||||
}
|
||||
|
||||
@ -503,24 +546,54 @@ func (h *authHandler) SwitchTeam(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
token, err := auth.SignJWT(h.jwtSecret, ac.UserID, teamID, ac.Email, user.Name, membership.Role, user.IsAdmin)
|
||||
// Rotate the SID so any leaked old cookie loses access at the moment of
|
||||
// privilege change.
|
||||
newSess, err := h.sessions.Rotate(ctx, ac.SessionID, ac.UserID, teamID, user.Email, user.Name, membership.Role, user.IsAdmin, r.UserAgent(), clientIP(r))
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
|
||||
slog.Error("switch team: failed to rotate session", "error", err)
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to switch team")
|
||||
return
|
||||
}
|
||||
setSessionCookies(w, newSess.RawSID, newSess.CSRFToken, isSecure(r))
|
||||
|
||||
writeJSON(w, http.StatusOK, authResponse{
|
||||
Token: token,
|
||||
UserID: id.FormatUserID(ac.UserID),
|
||||
TeamID: id.FormatTeamID(teamID),
|
||||
Email: ac.Email,
|
||||
Name: user.Name,
|
||||
UserID: id.FormatUserID(ac.UserID),
|
||||
TeamID: id.FormatTeamID(teamID),
|
||||
Email: user.Email,
|
||||
Name: user.Name,
|
||||
Role: membership.Role,
|
||||
IsAdmin: user.IsAdmin,
|
||||
})
|
||||
}
|
||||
|
||||
// Logout handles POST /v1/auth/logout — revokes the caller's current session.
|
||||
func (h *authHandler) Logout(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
if err := h.sessions.Revoke(r.Context(), ac.SessionID); err != nil {
|
||||
slog.Warn("logout: revoke failed", "error", err)
|
||||
}
|
||||
clearSessionCookies(w, isSecure(r))
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// LogoutAll handles POST /v1/auth/logout-all — revokes every session for the
|
||||
// current user, including the caller's own.
|
||||
func (h *authHandler) LogoutAll(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
if err := h.sessions.RevokeAllForUser(r.Context(), ac.UserID); err != nil {
|
||||
slog.Error("logout-all: revoke failed", "error", err)
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to revoke sessions")
|
||||
return
|
||||
}
|
||||
clearSessionCookies(w, isSecure(r))
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// --- helpers ---
|
||||
|
||||
func generateActivationToken() string {
|
||||
// generateOpaqueToken returns a fresh 16-byte hex-encoded random token,
|
||||
// used for email activation links and password reset links.
|
||||
func generateOpaqueToken() string {
|
||||
b := make([]byte, 16)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
panic(fmt.Sprintf("crypto/rand failed: %v", err))
|
||||
@ -528,7 +601,9 @@ func generateActivationToken() string {
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
func hashActivationToken(raw string) string {
|
||||
// hashOpaqueToken returns the SHA-256 hex digest of raw, used as the
|
||||
// lookup key for one-shot tokens stored in Redis.
|
||||
func hashOpaqueToken(raw string) string {
|
||||
h := sha256.Sum256([]byte(raw))
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
|
||||
158
internal/api/handlers_build_stream.go
Normal file
158
internal/api/handlers_build_stream.go
Normal file
@ -0,0 +1,158 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/recipe"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/service"
|
||||
)
|
||||
|
||||
// buildStreamKeepalive is the interval between server pings on an idle build
|
||||
// stream, preventing intermediaries from closing the WebSocket.
|
||||
const buildStreamKeepalive = 30 * time.Second
|
||||
|
||||
// buildStreamHandler serves the live admin build console WebSocket.
|
||||
type buildStreamHandler struct {
|
||||
db *db.Queries
|
||||
broker *service.BuildBroker
|
||||
}
|
||||
|
||||
func newBuildStreamHandler(db *db.Queries, broker *service.BuildBroker) *buildStreamHandler {
|
||||
return &buildStreamHandler{db: db, broker: broker}
|
||||
}
|
||||
|
||||
// Stream handles WS /v1/admin/builds/{id}/stream. On connect it replays the
|
||||
// completed-step history from the DB log, sends the current build status,
|
||||
// then live-tails events from the build broker until the build finishes or
|
||||
// the client disconnects. Admin auth is enforced by upstream middleware.
|
||||
func (h *buildStreamHandler) Stream(w http.ResponseWriter, r *http.Request) {
|
||||
buildIDStr := chi.URLParam(r, "id")
|
||||
buildID, err := id.ParseBuildID(buildIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid build ID")
|
||||
return
|
||||
}
|
||||
|
||||
build, err := h.db.GetTemplateBuild(r.Context(), buildID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "build not found")
|
||||
return
|
||||
}
|
||||
|
||||
conn, _, err := upgradeAndAuthenticate(w, r)
|
||||
if err != nil {
|
||||
slog.Error("build stream websocket upgrade/auth failed", "error", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
h.runStream(r.Context(), conn, build)
|
||||
}
|
||||
|
||||
func (h *buildStreamHandler) runStream(ctx context.Context, conn *websocket.Conn, build db.TemplateBuild) {
|
||||
ws := &wsWriter{conn: conn}
|
||||
buildIDStr := id.FormatBuildID(build.ID)
|
||||
|
||||
// Replay completed-step history from the DB log snapshot. lastStep is the
|
||||
// highest step number already delivered, used to dedup overlapping live
|
||||
// events for a step that finished between the DB read and the subscribe.
|
||||
lastStep := replayBuildHistory(ws, build)
|
||||
|
||||
ws.writeJSON(service.BuildStreamEvent{
|
||||
Type: "build-status",
|
||||
Status: build.Status,
|
||||
CurrentStep: build.CurrentStep,
|
||||
TotalSteps: build.TotalSteps,
|
||||
Error: build.Error,
|
||||
})
|
||||
|
||||
// A finished build has no live events to follow.
|
||||
if service.IsTerminalBuildStatus(build.Status) {
|
||||
return
|
||||
}
|
||||
|
||||
streamCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
events, release := h.broker.Subscribe(buildIDStr)
|
||||
defer release()
|
||||
|
||||
// Drain client reads so a disconnect cancels the stream. The client sends
|
||||
// nothing meaningful; any read error means the socket is gone.
|
||||
go func() {
|
||||
for {
|
||||
if _, _, err := conn.ReadMessage(); err != nil {
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
ticker := time.NewTicker(buildStreamKeepalive)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-streamCtx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
ws.writeJSON(map[string]string{"type": "ping"})
|
||||
case ev, ok := <-events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
// Skip step events already covered by the history replay.
|
||||
if ev.Type != "build-status" && ev.Step > 0 && ev.Step <= lastStep {
|
||||
continue
|
||||
}
|
||||
ws.writeJSON(ev)
|
||||
if ev.Type == "build-status" && service.IsTerminalBuildStatus(ev.Status) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// replayBuildHistory synthesizes step-start/output/step-end events from the
|
||||
// build's persisted log entries and writes them to the WebSocket. It returns
|
||||
// the highest step number replayed.
|
||||
func replayBuildHistory(ws *wsWriter, build db.TemplateBuild) int {
|
||||
if len(build.Logs) == 0 {
|
||||
return 0
|
||||
}
|
||||
var entries []recipe.BuildLogEntry
|
||||
if err := json.Unmarshal(build.Logs, &entries); err != nil {
|
||||
slog.Warn("build stream: bad log JSON", "build_id", id.FormatBuildID(build.ID), "error", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
lastStep := 0
|
||||
for _, e := range entries {
|
||||
ws.writeJSON(service.BuildStreamEvent{Type: "step-start", Step: e.Step, Phase: e.Phase, Cmd: e.Cmd})
|
||||
if out := e.Stdout + e.Stderr; out != "" {
|
||||
ws.writeJSON(service.BuildStreamEvent{
|
||||
Type: "output",
|
||||
Step: e.Step,
|
||||
Data: base64.StdEncoding.EncodeToString([]byte(out)),
|
||||
})
|
||||
}
|
||||
ws.writeJSON(service.BuildStreamEvent{
|
||||
Type: "step-end", Step: e.Step, Phase: e.Phase, Cmd: e.Cmd,
|
||||
Exit: e.Exit, Ok: e.Ok, ElapsedMs: e.Elapsed,
|
||||
})
|
||||
if e.Step > lastStep {
|
||||
lastStep = e.Step
|
||||
}
|
||||
}
|
||||
return lastStep
|
||||
}
|
||||
@ -9,7 +9,6 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/layout"
|
||||
@ -20,7 +19,6 @@ import (
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/service"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/validate"
|
||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||
)
|
||||
|
||||
type buildHandler struct {
|
||||
@ -42,6 +40,7 @@ type createBuildRequest struct {
|
||||
VCPUs int32 `json:"vcpus"`
|
||||
MemoryMB int32 `json:"memory_mb"`
|
||||
SkipPrePost bool `json:"skip_pre_post"`
|
||||
RunAsRoot bool `json:"run_as_root"`
|
||||
}
|
||||
|
||||
type buildResponse struct {
|
||||
@ -181,6 +180,7 @@ func (h *buildHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
VCPUs: req.VCPUs,
|
||||
MemoryMB: req.MemoryMB,
|
||||
SkipPrePost: req.SkipPrePost,
|
||||
RunAsRoot: req.RunAsRoot,
|
||||
Archive: archive,
|
||||
ArchiveName: archiveName,
|
||||
})
|
||||
@ -238,6 +238,9 @@ func (h *buildHandler) ListTemplates(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Resolve actual on-disk sizes for templates with unknown size.
|
||||
templates = resolveTemplateSizes(r.Context(), h.db, h.pool, templates)
|
||||
|
||||
type templateResponse struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
@ -246,6 +249,7 @@ func (h *buildHandler) ListTemplates(w http.ResponseWriter, r *http.Request) {
|
||||
SizeBytes int64 `json:"size_bytes"`
|
||||
TeamID string `json:"team_id"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
Protected bool `json:"protected"`
|
||||
}
|
||||
|
||||
resp := make([]templateResponse, len(templates))
|
||||
@ -257,6 +261,7 @@ func (h *buildHandler) ListTemplates(w http.ResponseWriter, r *http.Request) {
|
||||
MemoryMB: t.MemoryMb,
|
||||
SizeBytes: t.SizeBytes,
|
||||
TeamID: id.FormatTeamID(t.TeamID),
|
||||
Protected: layout.IsSystemTemplate(t.TeamID, t.ID),
|
||||
}
|
||||
if t.CreatedAt.Valid {
|
||||
resp[i].CreatedAt = t.CreatedAt.Time.Format(time.RFC3339)
|
||||
@ -280,29 +285,17 @@ func (h *buildHandler) DeleteTemplate(w http.ResponseWriter, r *http.Request) {
|
||||
writeError(w, http.StatusNotFound, "not_found", "template not found")
|
||||
return
|
||||
}
|
||||
if layout.IsMinimal(tmpl.TeamID, tmpl.ID) {
|
||||
writeError(w, http.StatusForbidden, "forbidden", "the minimal template cannot be deleted")
|
||||
if layout.IsSystemTemplate(tmpl.TeamID, tmpl.ID) {
|
||||
writeError(w, http.StatusForbidden, "forbidden", "system base templates cannot be deleted")
|
||||
return
|
||||
}
|
||||
|
||||
// Broadcast delete to all online hosts.
|
||||
hosts, _ := h.db.ListActiveHosts(ctx)
|
||||
for _, host := range hosts {
|
||||
if host.Status != "online" {
|
||||
continue
|
||||
}
|
||||
agent, err := h.pool.GetForHost(host)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if _, err := agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{
|
||||
TeamId: formatUUIDForRPC(tmpl.TeamID),
|
||||
TemplateId: formatUUIDForRPC(tmpl.ID),
|
||||
})); err != nil {
|
||||
if connect.CodeOf(err) != connect.CodeNotFound {
|
||||
slog.Warn("admin: failed to delete template on host", "host_id", id.FormatHostID(host.ID), "name", name, "error", err)
|
||||
}
|
||||
}
|
||||
// Remove the files from every host before dropping the DB record, so a
|
||||
// failure leaves the template intact and retryable rather than orphaned.
|
||||
if err := deleteSnapshotEverywhere(ctx, h.db, h.pool, tmpl.TeamID, tmpl.ID); err != nil {
|
||||
writeError(w, http.StatusConflict, "delete_failed",
|
||||
"could not remove template files from all hosts: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.db.DeleteTemplate(ctx, tmpl.ID); err != nil {
|
||||
|
||||
@ -3,14 +3,11 @@ package api
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
@ -58,23 +55,11 @@ type backgroundExecResponse struct {
|
||||
|
||||
// Exec handles POST /v1/capsules/{id}/exec.
|
||||
func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
|
||||
sb, sandboxID, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
@ -116,15 +101,7 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.db.UpdateLastActive(ctx, db.UpdateLastActiveParams{
|
||||
ID: sandboxID,
|
||||
LastActiveAt: pgtype.Timestamptz{
|
||||
Time: time.Now(),
|
||||
Valid: true,
|
||||
},
|
||||
}); err != nil {
|
||||
slog.Warn("failed to update last_active_at", "id", sandboxIDStr, "error", err)
|
||||
}
|
||||
updateLastActive(h.db, sandboxID, sandboxIDStr)
|
||||
|
||||
writeJSON(w, http.StatusAccepted, backgroundExecResponse{
|
||||
SandboxID: sandboxIDStr,
|
||||
@ -142,6 +119,8 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
|
||||
Cmd: req.Cmd,
|
||||
Args: req.Args,
|
||||
TimeoutSec: req.TimeoutSec,
|
||||
Envs: req.Envs,
|
||||
Cwd: req.Cwd,
|
||||
}))
|
||||
if err != nil {
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
@ -151,41 +130,24 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
duration := time.Since(start)
|
||||
|
||||
// Update last active.
|
||||
if err := h.db.UpdateLastActive(ctx, db.UpdateLastActiveParams{
|
||||
ID: sandboxID,
|
||||
LastActiveAt: pgtype.Timestamptz{
|
||||
Time: time.Now(),
|
||||
Valid: true,
|
||||
},
|
||||
}); err != nil {
|
||||
slog.Warn("failed to update last_active_at", "id", sandboxIDStr, "error", err)
|
||||
}
|
||||
updateLastActive(h.db, sandboxID, sandboxIDStr)
|
||||
|
||||
// Use base64 encoding if output contains non-UTF-8 bytes.
|
||||
stdout := resp.Msg.Stdout
|
||||
stderr := resp.Msg.Stderr
|
||||
encoding := "utf-8"
|
||||
|
||||
encoding := "utf-8"
|
||||
stdoutStr, stderrStr := string(stdout), string(stderr)
|
||||
if !utf8.Valid(stdout) || !utf8.Valid(stderr) {
|
||||
encoding = "base64"
|
||||
writeJSON(w, http.StatusOK, execResponse{
|
||||
SandboxID: sandboxIDStr,
|
||||
Cmd: req.Cmd,
|
||||
Stdout: base64.StdEncoding.EncodeToString(stdout),
|
||||
Stderr: base64.StdEncoding.EncodeToString(stderr),
|
||||
ExitCode: resp.Msg.ExitCode,
|
||||
DurationMs: duration.Milliseconds(),
|
||||
Encoding: encoding,
|
||||
})
|
||||
return
|
||||
stdoutStr = base64.StdEncoding.EncodeToString(stdout)
|
||||
stderrStr = base64.StdEncoding.EncodeToString(stderr)
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, execResponse{
|
||||
SandboxID: sandboxIDStr,
|
||||
Cmd: req.Cmd,
|
||||
Stdout: string(stdout),
|
||||
Stderr: string(stderr),
|
||||
Stdout: stdoutStr,
|
||||
Stderr: stderrStr,
|
||||
ExitCode: resp.Msg.ExitCode,
|
||||
DurationMs: duration.Milliseconds(),
|
||||
Encoding: encoding,
|
||||
|
||||
@ -5,7 +5,6 @@ import (
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
@ -20,13 +19,12 @@ import (
|
||||
)
|
||||
|
||||
type execStreamHandler struct {
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
jwtSecret []byte
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
}
|
||||
|
||||
func newExecStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool, jwtSecret []byte) *execStreamHandler {
|
||||
return &execStreamHandler{db: db, pool: pool, jwtSecret: jwtSecret}
|
||||
func newExecStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool) *execStreamHandler {
|
||||
return &execStreamHandler{db: db, pool: pool}
|
||||
}
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
@ -59,37 +57,9 @@ func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Authenticate: use context from middleware (API key) or WS first message (JWT).
|
||||
ac, hasAuth := auth.FromContext(ctx)
|
||||
|
||||
if !hasAuth {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
slog.Error("websocket upgrade failed", "error", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
var wsAC auth.AuthContext
|
||||
var authErr error
|
||||
if isAdminWSRoute(ctx) {
|
||||
wsAC, authErr = wsAuthenticateAdmin(ctx, conn, h.jwtSecret, h.db)
|
||||
} else {
|
||||
wsAC, authErr = wsAuthenticate(ctx, conn, h.jwtSecret, h.db)
|
||||
}
|
||||
if authErr != nil {
|
||||
sendWSError(conn, "authentication failed")
|
||||
return
|
||||
}
|
||||
ac = wsAC
|
||||
|
||||
h.runExecStream(ctx, conn, ac, sandboxID, sandboxIDStr)
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
conn, ac, err := upgradeAndAuthenticate(w, r)
|
||||
if err != nil {
|
||||
slog.Error("websocket upgrade failed", "error", err)
|
||||
slog.Error("websocket upgrade/auth failed", "error", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
@ -186,18 +156,7 @@ func (h *execStreamHandler) runExecStream(ctx context.Context, conn *websocket.C
|
||||
}
|
||||
}
|
||||
|
||||
// Update last active using a fresh context (the request context may be cancelled).
|
||||
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer updateCancel()
|
||||
if err := h.db.UpdateLastActive(updateCtx, db.UpdateLastActiveParams{
|
||||
ID: sandboxID,
|
||||
LastActiveAt: pgtype.Timestamptz{
|
||||
Time: time.Now(),
|
||||
Valid: true,
|
||||
},
|
||||
}); err != nil {
|
||||
slog.Warn("failed to update last active after stream exec", "sandbox_id", sandboxIDStr, "error", err)
|
||||
}
|
||||
updateLastActive(h.db, sandboxID, sandboxIDStr)
|
||||
}
|
||||
|
||||
func sendWSError(conn *websocket.Conn, msg string) {
|
||||
|
||||
@ -7,11 +7,9 @@ import (
|
||||
"net/http"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||
)
|
||||
@ -30,23 +28,11 @@ func newFilesHandler(db *db.Queries, pool *lifecycle.HostClientPool) *filesHandl
|
||||
// - "path" text field: absolute destination path inside the sandbox
|
||||
// - "file" file field: binary content to write
|
||||
func (h *filesHandler) Upload(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
|
||||
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
@ -108,23 +94,11 @@ type readFileRequest struct {
|
||||
// Download handles POST /v1/capsules/{id}/files/read.
|
||||
// Accepts JSON body with path, returns raw file content with Content-Disposition.
|
||||
func (h *filesHandler) Download(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
|
||||
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@ -8,11 +8,9 @@ import (
|
||||
"net/http"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||
)
|
||||
@ -30,23 +28,11 @@ func newFilesStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool) *file
|
||||
// Expects multipart/form-data with "path" text field and "file" file field.
|
||||
// Streams file content directly from the request body to the host agent without buffering.
|
||||
func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
|
||||
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
@ -103,6 +89,12 @@ func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request
|
||||
|
||||
// Open client-streaming RPC to host agent.
|
||||
stream := agent.WriteFileStream(ctx)
|
||||
var streamClosed bool
|
||||
defer func() {
|
||||
if !streamClosed {
|
||||
_, _ = stream.CloseAndReceive()
|
||||
}
|
||||
}()
|
||||
|
||||
// Send metadata first.
|
||||
if err := stream.Send(&pb.WriteFileStreamRequest{
|
||||
@ -141,6 +133,7 @@ func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
|
||||
// Close and receive response.
|
||||
streamClosed = true
|
||||
if _, err := stream.CloseAndReceive(); err != nil {
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
@ -153,23 +146,11 @@ func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request
|
||||
// StreamDownload handles POST /v1/capsules/{id}/files/stream/read.
|
||||
// Accepts JSON body with path, streams file content back without buffering.
|
||||
func (h *filesStreamHandler) StreamDownload(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
|
||||
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@ -4,11 +4,9 @@ import (
|
||||
"net/http"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||
)
|
||||
@ -58,23 +56,11 @@ type removeRequest struct {
|
||||
|
||||
// ListDir handles POST /v1/capsules/{id}/files/list.
|
||||
func (h *fsHandler) ListDir(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
|
||||
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
@ -115,23 +101,11 @@ func (h *fsHandler) ListDir(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// MakeDir handles POST /v1/capsules/{id}/files/mkdir.
|
||||
func (h *fsHandler) MakeDir(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
|
||||
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
@ -166,23 +140,11 @@ func (h *fsHandler) MakeDir(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Remove handles POST /v1/capsules/{id}/files/remove.
|
||||
func (h *fsHandler) Remove(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
|
||||
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
@ -21,10 +22,11 @@ type hostHandler struct {
|
||||
svc *service.HostService
|
||||
queries *db.Queries
|
||||
audit *audit.AuditLogger
|
||||
monitor *HostMonitor
|
||||
}
|
||||
|
||||
func newHostHandler(svc *service.HostService, queries *db.Queries, al *audit.AuditLogger) *hostHandler {
|
||||
return &hostHandler{svc: svc, queries: queries, audit: al}
|
||||
func newHostHandler(svc *service.HostService, queries *db.Queries, al *audit.AuditLogger, monitor *HostMonitor) *hostHandler {
|
||||
return &hostHandler{svc: svc, queries: queries, audit: al, monitor: monitor}
|
||||
}
|
||||
|
||||
// Request/response types.
|
||||
@ -98,6 +100,11 @@ type hostResponse struct {
|
||||
CreatedBy string `json:"created_by"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
RunningVcpus int32 `json:"running_vcpus"`
|
||||
RunningMemoryMb int32 `json:"running_memory_mb"`
|
||||
RunningDiskMb int32 `json:"running_disk_mb"`
|
||||
PausedMemoryMb int32 `json:"paused_memory_mb"`
|
||||
PausedDiskMb int32 `json:"paused_disk_mb"`
|
||||
}
|
||||
|
||||
func hostToResponse(h db.Host) hostResponse {
|
||||
@ -136,12 +143,37 @@ func hostToResponse(h db.Host) hostResponse {
|
||||
s := h.LastHeartbeatAt.Time.Format(time.RFC3339)
|
||||
resp.LastHeartbeatAt = &s
|
||||
}
|
||||
// created_at and updated_at are NOT NULL DEFAULT NOW(), always valid.
|
||||
resp.CreatedAt = h.CreatedAt.Time.Format(time.RFC3339)
|
||||
resp.UpdatedAt = h.UpdatedAt.Time.Format(time.RFC3339)
|
||||
return resp
|
||||
}
|
||||
|
||||
func hostToResponseWithLoad(h db.ListHostsByTeamRow) hostResponse {
|
||||
resp := hostToResponse(db.Host{
|
||||
ID: h.ID,
|
||||
Type: h.Type,
|
||||
TeamID: h.TeamID,
|
||||
Provider: h.Provider,
|
||||
AvailabilityZone: h.AvailabilityZone,
|
||||
Arch: h.Arch,
|
||||
CpuCores: h.CpuCores,
|
||||
MemoryMb: h.MemoryMb,
|
||||
DiskGb: h.DiskGb,
|
||||
Address: h.Address,
|
||||
Status: h.Status,
|
||||
LastHeartbeatAt: h.LastHeartbeatAt,
|
||||
CreatedBy: h.CreatedBy,
|
||||
CreatedAt: h.CreatedAt,
|
||||
UpdatedAt: h.UpdatedAt,
|
||||
})
|
||||
resp.RunningVcpus = h.RunningVcpus
|
||||
resp.RunningMemoryMb = h.RunningMemoryMb
|
||||
resp.RunningDiskMb = h.RunningDiskMb
|
||||
resp.PausedMemoryMb = h.PausedMemoryMb
|
||||
resp.PausedDiskMb = h.PausedDiskMb
|
||||
return resp
|
||||
}
|
||||
|
||||
// isAdmin fetches the user record and returns whether they are an admin.
|
||||
func (h *hostHandler) isAdmin(r *http.Request, userID pgtype.UUID) bool {
|
||||
user, err := h.queries.GetUserByID(r.Context(), userID)
|
||||
@ -233,7 +265,7 @@ func (h *hostHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
resp := make([]hostResponse, len(hosts))
|
||||
for i, host := range hosts {
|
||||
resp[i] = hostToResponse(host)
|
||||
resp[i] = hostToResponseWithLoad(host)
|
||||
if host.TeamID.Valid {
|
||||
key := id.FormatTeamID(host.TeamID)
|
||||
if name, ok := teamNames[key]; ok {
|
||||
@ -335,6 +367,54 @@ func (h *hostHandler) Delete(w http.ResponseWriter, r *http.Request) {
|
||||
writeError(w, status, code, msg)
|
||||
}
|
||||
|
||||
// AdminList handles GET /v1/admin/hosts.
|
||||
// Returns all hosts with per-host resource consumption. Admin-only.
|
||||
func (h *hostHandler) AdminList(w http.ResponseWriter, r *http.Request) {
|
||||
hosts, err := h.svc.ListAdmin(r.Context())
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to list hosts")
|
||||
return
|
||||
}
|
||||
|
||||
// Collect unique team IDs to fetch team names.
|
||||
var teamNames map[string]string
|
||||
seen := make(map[string]struct{})
|
||||
for _, host := range hosts {
|
||||
if host.TeamID.Valid {
|
||||
key := id.FormatTeamID(host.TeamID)
|
||||
seen[key] = struct{}{}
|
||||
}
|
||||
}
|
||||
if len(seen) > 0 {
|
||||
teamNames = make(map[string]string, len(seen))
|
||||
for _, host := range hosts {
|
||||
if !host.TeamID.Valid {
|
||||
continue
|
||||
}
|
||||
key := id.FormatTeamID(host.TeamID)
|
||||
if _, ok := teamNames[key]; ok {
|
||||
continue
|
||||
}
|
||||
if team, err := h.queries.GetTeam(r.Context(), host.TeamID); err == nil {
|
||||
teamNames[key] = team.Name
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
resp := make([]hostResponse, len(hosts))
|
||||
for i, host := range hosts {
|
||||
resp[i] = hostToResponseWithLoad(db.ListHostsByTeamRow(host))
|
||||
if host.TeamID.Valid {
|
||||
key := id.FormatTeamID(host.TeamID)
|
||||
if name, ok := teamNames[key]; ok {
|
||||
resp[i].TeamName = &name
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// RegenerateToken handles POST /v1/hosts/{id}/token.
|
||||
func (h *hostHandler) RegenerateToken(w http.ResponseWriter, r *http.Request) {
|
||||
hostIDStr := chi.URLParam(r, "id")
|
||||
@ -426,9 +506,12 @@ func (h *hostHandler) Heartbeat(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Log marked_up if the host just recovered from unreachable.
|
||||
// If the host just recovered from unreachable, log it and trigger immediate
|
||||
// reconciliation so "missing" sandboxes are resolved without waiting for the
|
||||
// next monitor tick.
|
||||
if prevHost.Status == "unreachable" {
|
||||
h.audit.LogHostMarkedUp(r.Context(), prevHost.TeamID, hc.HostID)
|
||||
go h.monitor.ReconcileHost(context.Background(), hc.HostID)
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
|
||||
@ -2,9 +2,6 @@ package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
@ -20,6 +17,8 @@ import (
|
||||
"git.omukk.dev/wrenn/wrenn/internal/email"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth/oauth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth/session"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/cpextension"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/service"
|
||||
@ -34,42 +33,62 @@ type meHandler struct {
|
||||
db *db.Queries
|
||||
pool *pgxpool.Pool
|
||||
rdb *redis.Client
|
||||
jwtSecret []byte
|
||||
hmacKey []byte // HMAC key for OAuth state and link cookies (was jwtSecret)
|
||||
sessions *session.Service
|
||||
mailer email.Mailer
|
||||
oauthRegistry *oauth.Registry
|
||||
redirectURL string
|
||||
teamSvc *service.TeamService
|
||||
authHooks []cpextension.AuthHook
|
||||
}
|
||||
|
||||
func newMeHandler(
|
||||
db *db.Queries,
|
||||
pool *pgxpool.Pool,
|
||||
rdb *redis.Client,
|
||||
jwtSecret []byte,
|
||||
hmacKey []byte,
|
||||
sessions *session.Service,
|
||||
mailer email.Mailer,
|
||||
registry *oauth.Registry,
|
||||
redirectURL string,
|
||||
teamSvc *service.TeamService,
|
||||
hooks []cpextension.AuthHook,
|
||||
) *meHandler {
|
||||
return &meHandler{
|
||||
db: db,
|
||||
pool: pool,
|
||||
rdb: rdb,
|
||||
jwtSecret: jwtSecret,
|
||||
hmacKey: hmacKey,
|
||||
sessions: sessions,
|
||||
mailer: mailer,
|
||||
oauthRegistry: registry,
|
||||
redirectURL: strings.TrimRight(redirectURL, "/"),
|
||||
teamSvc: teamSvc,
|
||||
authHooks: hooks,
|
||||
}
|
||||
}
|
||||
|
||||
type meResponse struct {
|
||||
UserID string `json:"user_id"`
|
||||
TeamID string `json:"team_id"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
Role string `json:"role"`
|
||||
IsAdmin bool `json:"is_admin"`
|
||||
HasPassword bool `json:"has_password"`
|
||||
Providers []string `json:"providers"`
|
||||
}
|
||||
|
||||
type sessionRow struct {
|
||||
ID string `json:"id"`
|
||||
UserAgent string `json:"user_agent"`
|
||||
IPAddress string `json:"ip_address"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
LastSeenAt string `json:"last_seen_at"`
|
||||
ExpiresAt string `json:"expires_at"`
|
||||
Current bool `json:"current"`
|
||||
}
|
||||
|
||||
type updateNameRequest struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
@ -116,14 +135,19 @@ func (h *meHandler) GetMe(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, meResponse{
|
||||
UserID: id.FormatUserID(ac.UserID),
|
||||
TeamID: id.FormatTeamID(ac.TeamID),
|
||||
Name: user.Name,
|
||||
Email: user.Email,
|
||||
Role: ac.Role,
|
||||
IsAdmin: user.IsAdmin,
|
||||
HasPassword: user.PasswordHash.Valid,
|
||||
Providers: providerNames,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateName handles PATCH /v1/me — updates the user's name and re-issues a JWT.
|
||||
// UpdateName handles PATCH /v1/me — updates the user's name and refreshes
|
||||
// any cached session blobs so the new name shows up on next request.
|
||||
func (h *meHandler) UpdateName(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
ctx := r.Context()
|
||||
@ -148,31 +172,11 @@ func (h *meHandler) UpdateName(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.db.GetUserByID(ctx, ac.UserID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to get user")
|
||||
return
|
||||
if err := h.sessions.InvalidateCacheForUser(ctx, ac.UserID); err != nil {
|
||||
slog.Warn("update name: invalidate session cache failed", "error", err)
|
||||
}
|
||||
|
||||
team, role, err := loginTeam(ctx, h.db, ac.UserID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to get team")
|
||||
return
|
||||
}
|
||||
|
||||
token, err := auth.SignJWT(h.jwtSecret, ac.UserID, team.ID, user.Email, req.Name, role, user.IsAdmin)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, authResponse{
|
||||
Token: token,
|
||||
UserID: id.FormatUserID(ac.UserID),
|
||||
TeamID: id.FormatTeamID(team.ID),
|
||||
Email: user.Email,
|
||||
Name: req.Name,
|
||||
})
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// ChangePassword handles POST /v1/me/password.
|
||||
@ -235,6 +239,14 @@ func (h *meHandler) ChangePassword(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Revoke every session for this user — including the caller's — so a new
|
||||
// password resets all device access. Clear cookies on the response so the
|
||||
// caller is signed out immediately.
|
||||
if err := h.sessions.RevokeAllForUser(ctx, ac.UserID); err != nil {
|
||||
slog.Warn("change password: revoke sessions failed", "error", err)
|
||||
}
|
||||
clearSessionCookies(w, isSecure(r))
|
||||
|
||||
isAdding := !user.PasswordHash.Valid
|
||||
go func() {
|
||||
sendCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
@ -285,8 +297,8 @@ func (h *meHandler) RequestPasswordReset(w http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
rawToken := generateResetToken()
|
||||
tokenHash := hashResetToken(rawToken)
|
||||
rawToken := generateOpaqueToken()
|
||||
tokenHash := hashOpaqueToken(rawToken)
|
||||
redisKey := passwordResetKeyPrefix + tokenHash
|
||||
|
||||
if err := h.rdb.Set(ctx, redisKey, id.FormatUserID(user.ID), passwordResetTTL).Err(); err != nil {
|
||||
@ -330,7 +342,7 @@ func (h *meHandler) ConfirmPasswordReset(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
tokenHash := hashResetToken(req.Token)
|
||||
tokenHash := hashOpaqueToken(req.Token)
|
||||
redisKey := passwordResetKeyPrefix + tokenHash
|
||||
|
||||
// GetDel atomically retrieves and removes the token in a single round-trip,
|
||||
@ -371,6 +383,11 @@ func (h *meHandler) ConfirmPasswordReset(w http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
// Reset invalidates every active session for the user.
|
||||
if err := h.sessions.RevokeAllForUser(ctx, userID); err != nil {
|
||||
slog.Warn("confirm password reset: revoke sessions failed", "error", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
sendCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
@ -404,7 +421,7 @@ func (h *meHandler) ConnectProvider(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
mac := computeHMAC(h.jwtSecret, state+":"+"login")
|
||||
mac := computeHMAC(h.hmacKey, state+":"+"login")
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "oauth_state",
|
||||
Value: state + ":" + mac + ":" + "login",
|
||||
@ -416,7 +433,7 @@ func (h *meHandler) ConnectProvider(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
|
||||
userIDStr := id.FormatUserID(ac.UserID)
|
||||
linkMac := computeHMAC(h.jwtSecret, userIDStr)
|
||||
linkMac := computeHMAC(h.hmacKey, userIDStr)
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "oauth_link_user_id",
|
||||
Value: userIDStr + ":" + linkMac,
|
||||
@ -552,6 +569,14 @@ func (h *meHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Revoke every active session and clear the caller's cookies.
|
||||
if err := h.sessions.RevokeAllForUser(ctx, ac.UserID); err != nil {
|
||||
slog.Warn("delete account: revoke sessions failed", "error", err)
|
||||
}
|
||||
clearSessionCookies(w, isSecure(r))
|
||||
|
||||
fireOnSoftDelete(ctx, h.authHooks, ac.UserID)
|
||||
|
||||
slog.Info("account soft-deleted", "user_id", id.FormatUserID(ac.UserID), "email", user.Email)
|
||||
|
||||
go func() {
|
||||
@ -569,17 +594,4 @@ func (h *meHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// --- helpers ---
|
||||
|
||||
func generateResetToken() string {
|
||||
b := make([]byte, 16)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
panic(fmt.Sprintf("crypto/rand failed: %v", err))
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
func hashResetToken(raw string) string {
|
||||
h := sha256.Sum256([]byte(raw))
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
// (token helpers live in handlers_auth.go)
|
||||
|
||||
@ -14,10 +14,12 @@ import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth/oauth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth/session"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/cpextension"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
@ -25,18 +27,22 @@ import (
|
||||
type oauthHandler struct {
|
||||
db *db.Queries
|
||||
pool *pgxpool.Pool
|
||||
jwtSecret []byte
|
||||
hmacKey []byte // HMAC key for OAuth state and link cookies
|
||||
sessions *session.Service
|
||||
registry *oauth.Registry
|
||||
redirectURL string // base frontend URL (e.g. "https://app.wrenn.dev")
|
||||
authHooks []cpextension.AuthHook
|
||||
}
|
||||
|
||||
func newOAuthHandler(db *db.Queries, pool *pgxpool.Pool, jwtSecret []byte, registry *oauth.Registry, redirectURL string) *oauthHandler {
|
||||
func newOAuthHandler(db *db.Queries, pool *pgxpool.Pool, hmacKey []byte, sessions *session.Service, registry *oauth.Registry, redirectURL string, hooks []cpextension.AuthHook) *oauthHandler {
|
||||
return &oauthHandler{
|
||||
db: db,
|
||||
pool: pool,
|
||||
jwtSecret: jwtSecret,
|
||||
hmacKey: hmacKey,
|
||||
sessions: sessions,
|
||||
registry: registry,
|
||||
redirectURL: strings.TrimRight(redirectURL, "/"),
|
||||
authHooks: hooks,
|
||||
}
|
||||
}
|
||||
|
||||
@ -61,7 +67,7 @@ func (h *oauthHandler) Redirect(w http.ResponseWriter, r *http.Request) {
|
||||
intent = "login"
|
||||
}
|
||||
|
||||
mac := computeHMAC(h.jwtSecret, state+":"+intent)
|
||||
mac := computeHMAC(h.hmacKey, state+":"+intent)
|
||||
cookieVal := state + ":" + mac + ":" + intent
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
@ -121,7 +127,7 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
if len(parts) == 3 && parts[2] == "signup" {
|
||||
intent = "signup"
|
||||
}
|
||||
if !hmac.Equal([]byte(computeHMAC(h.jwtSecret, nonce+":"+intent)), []byte(expectedMAC)) {
|
||||
if !hmac.Equal([]byte(computeHMAC(h.hmacKey, nonce+":"+intent)), []byte(expectedMAC)) {
|
||||
redirectWithError(w, r, redirectBase, "invalid_state")
|
||||
return
|
||||
}
|
||||
@ -164,7 +170,7 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Verify the HMAC to prevent cookie forgery.
|
||||
linkParts := strings.SplitN(linkCookie.Value, ":", 2)
|
||||
if len(linkParts) != 2 || !hmac.Equal([]byte(computeHMAC(h.jwtSecret, linkParts[0])), []byte(linkParts[1])) {
|
||||
if len(linkParts) != 2 || !hmac.Equal([]byte(computeHMAC(h.hmacKey, linkParts[0])), []byte(linkParts[1])) {
|
||||
slog.Warn("oauth link: invalid or tampered link cookie")
|
||||
http.Redirect(w, r, settingsBase+"?connect_error=invalid_state", http.StatusFound)
|
||||
return
|
||||
@ -244,13 +250,12 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
isAdmin := user.IsAdmin || isFirstUser
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, isAdmin)
|
||||
if err != nil {
|
||||
slog.Error("oauth login: failed to sign jwt", "error", err)
|
||||
if err := h.issueSessionAndRedirect(w, r, user.ID, team.ID, user.Email, user.Name, role, isAdmin, redirectBase); err != nil {
|
||||
slog.Error("oauth login: failed to issue session", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "internal_error")
|
||||
return
|
||||
}
|
||||
redirectWithToken(w, r, redirectBase, token, id.FormatUserID(user.ID), id.FormatTeamID(team.ID), user.Email, user.Name)
|
||||
fireOnLogin(ctx, h.authHooks, user.ID)
|
||||
return
|
||||
}
|
||||
if !errors.Is(err, pgx.ErrNoRows) {
|
||||
@ -374,10 +379,10 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
token, err := auth.SignJWT(h.jwtSecret, userID, teamID, email, profile.Name, "owner", isFirstUser)
|
||||
if err != nil {
|
||||
slog.Error("oauth: failed to sign jwt", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "internal_error")
|
||||
// Fire OnSignup before session issuance — billing must succeed first.
|
||||
if hookErr := fireOnSignup(ctx, h.authHooks, userID, teamID, email); hookErr != nil {
|
||||
slog.Error("oauth signup: OnSignup hook failed", "user_id", id.FormatUserID(userID), "error", hookErr)
|
||||
redirectWithError(w, r, redirectBase, "signup_hook_failed")
|
||||
return
|
||||
}
|
||||
|
||||
@ -385,14 +390,19 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "wrenn_oauth_new_signup",
|
||||
Value: "1",
|
||||
Path: "/auth/",
|
||||
Path: "/",
|
||||
MaxAge: 60,
|
||||
HttpOnly: false,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Secure: isSecure(r),
|
||||
})
|
||||
|
||||
redirectWithToken(w, r, redirectBase, token, id.FormatUserID(userID), id.FormatTeamID(teamID), email, profile.Name)
|
||||
if err := h.issueSessionAndRedirect(w, r, userID, teamID, email, profile.Name, "owner", isFirstUser, redirectBase); err != nil {
|
||||
slog.Error("oauth: failed to issue session", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "internal_error")
|
||||
return
|
||||
}
|
||||
fireOnLogin(ctx, h.authHooks, userID)
|
||||
}
|
||||
|
||||
// retryAsLogin handles the race where a concurrent request already created the user.
|
||||
@ -431,33 +441,35 @@ func (h *oauthHandler) retryAsLogin(w http.ResponseWriter, r *http.Request, prov
|
||||
return
|
||||
}
|
||||
isAdmin := user.IsAdmin || isFirstUser
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, isAdmin)
|
||||
if err != nil {
|
||||
slog.Error("oauth: retry login: failed to sign jwt", "error", err)
|
||||
if err := h.issueSessionAndRedirect(w, r, user.ID, team.ID, user.Email, user.Name, role, isAdmin, redirectBase); err != nil {
|
||||
slog.Error("oauth: retry login: failed to issue session", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "internal_error")
|
||||
return
|
||||
}
|
||||
redirectWithToken(w, r, redirectBase, token, id.FormatUserID(user.ID), id.FormatTeamID(team.ID), user.Email, user.Name)
|
||||
fireOnLogin(ctx, h.authHooks, user.ID)
|
||||
}
|
||||
|
||||
func redirectWithToken(w http.ResponseWriter, r *http.Request, base, token, userID, teamID, email, name string) {
|
||||
// Set auth data as short-lived cookies instead of URL query parameters.
|
||||
// This prevents token leakage via server access logs, Referer headers, and browser history.
|
||||
for _, c := range []http.Cookie{
|
||||
{Name: "wrenn_oauth_token", Value: token},
|
||||
{Name: "wrenn_oauth_user_id", Value: userID},
|
||||
{Name: "wrenn_oauth_team_id", Value: teamID},
|
||||
{Name: "wrenn_oauth_email", Value: email},
|
||||
{Name: "wrenn_oauth_name", Value: name},
|
||||
} {
|
||||
c.Path = "/auth/"
|
||||
c.MaxAge = 60
|
||||
c.HttpOnly = false // frontend JS must read these
|
||||
c.SameSite = http.SameSiteLaxMode
|
||||
c.Secure = isSecure(r)
|
||||
http.SetCookie(w, &c)
|
||||
// issueSessionAndRedirect creates a session, sets the session and CSRF
|
||||
// cookies, and redirects to the frontend dashboard. The redirectBase param
|
||||
// is the OAuth callback URL; we ignore it after success and send the user to
|
||||
// /dashboard directly (callback page will probe /v1/me to hydrate state).
|
||||
func (h *oauthHandler) issueSessionAndRedirect(
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
userID, teamID pgtype.UUID,
|
||||
email, name, role string,
|
||||
isAdmin bool,
|
||||
redirectBase string,
|
||||
) error {
|
||||
sess, err := h.sessions.Create(r.Context(), userID, teamID, email, name, role, isAdmin, r.UserAgent(), clientIP(r))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
http.Redirect(w, r, base, http.StatusFound)
|
||||
setSessionCookies(w, sess.RawSID, sess.CSRFToken, isSecure(r))
|
||||
// Send the user to the callback page so the SPA can probe /v1/me and
|
||||
// trigger any post-OAuth UX (e.g. the new-signup name confirmation).
|
||||
http.Redirect(w, r, redirectBase, http.StatusFound)
|
||||
return nil
|
||||
}
|
||||
|
||||
func redirectWithError(w http.ResponseWriter, r *http.Request, base, code string) {
|
||||
@ -477,7 +489,3 @@ func computeHMAC(key []byte, data string) string {
|
||||
h.Write([]byte(data))
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
func isSecure(r *http.Request) bool {
|
||||
return r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https"
|
||||
}
|
||||
|
||||
@ -5,7 +5,6 @@ import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
@ -20,13 +19,12 @@ import (
|
||||
)
|
||||
|
||||
type processHandler struct {
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
jwtSecret []byte
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
}
|
||||
|
||||
func newProcessHandler(db *db.Queries, pool *lifecycle.HostClientPool, jwtSecret []byte) *processHandler {
|
||||
return &processHandler{db: db, pool: pool, jwtSecret: jwtSecret}
|
||||
func newProcessHandler(db *db.Queries, pool *lifecycle.HostClientPool) *processHandler {
|
||||
return &processHandler{db: db, pool: pool}
|
||||
}
|
||||
|
||||
// processResponse is a single entry in the process list.
|
||||
@ -44,23 +42,11 @@ type processListResponse struct {
|
||||
|
||||
// ListProcesses handles GET /v1/capsules/{id}/processes.
|
||||
func (h *processHandler) ListProcesses(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
|
||||
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
@ -95,24 +81,12 @@ func (h *processHandler) ListProcesses(w http.ResponseWriter, r *http.Request) {
|
||||
// KillProcess handles DELETE /v1/capsules/{id}/processes/{selector}.
|
||||
// The selector can be a numeric PID or a string tag.
|
||||
func (h *processHandler) KillProcess(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
selectorStr := chi.URLParam(r, "selector")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
|
||||
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
@ -146,14 +120,6 @@ func (h *processHandler) KillProcess(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// wsProcessOut is the JSON message sent to the WebSocket client.
|
||||
type wsProcessOut struct {
|
||||
Type string `json:"type"` // "start", "stdout", "stderr", "exit", "error"
|
||||
PID uint32 `json:"pid,omitempty"` // only for "start"
|
||||
Data string `json:"data,omitempty"` // only for "stdout", "stderr", "error"
|
||||
ExitCode *int32 `json:"exit_code,omitempty"` // only for "exit"
|
||||
}
|
||||
|
||||
// ConnectProcess handles WS /v1/capsules/{id}/processes/{selector}/stream.
|
||||
func (h *processHandler) ConnectProcess(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
@ -166,37 +132,9 @@ func (h *processHandler) ConnectProcess(w http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
// Authenticate: use context from middleware (API key) or WS first message (JWT).
|
||||
ac, hasAuth := auth.FromContext(ctx)
|
||||
|
||||
if !hasAuth {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
slog.Error("process stream websocket upgrade failed", "error", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
var wsAC auth.AuthContext
|
||||
var authErr error
|
||||
if isAdminWSRoute(ctx) {
|
||||
wsAC, authErr = wsAuthenticateAdmin(ctx, conn, h.jwtSecret, h.db)
|
||||
} else {
|
||||
wsAC, authErr = wsAuthenticate(ctx, conn, h.jwtSecret, h.db)
|
||||
}
|
||||
if authErr != nil {
|
||||
sendProcessWSError(conn, "authentication failed")
|
||||
return
|
||||
}
|
||||
ac = wsAC
|
||||
|
||||
h.runConnectProcess(ctx, conn, ac, sandboxID, sandboxIDStr, selectorStr)
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
conn, ac, err := upgradeAndAuthenticate(w, r)
|
||||
if err != nil {
|
||||
slog.Error("process stream websocket upgrade failed", "error", err)
|
||||
slog.Error("process stream websocket upgrade/auth failed", "error", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
@ -207,17 +145,17 @@ func (h *processHandler) ConnectProcess(w http.ResponseWriter, r *http.Request)
|
||||
func (h *processHandler) runConnectProcess(ctx context.Context, conn *websocket.Conn, ac auth.AuthContext, sandboxID pgtype.UUID, sandboxIDStr, selectorStr string) {
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
sendProcessWSError(conn, "sandbox not found")
|
||||
sendWSError(conn, "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
sendProcessWSError(conn, "sandbox is not running (status: "+sb.Status+")")
|
||||
sendWSError(conn, "sandbox is not running (status: "+sb.Status+")")
|
||||
return
|
||||
}
|
||||
|
||||
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||
if err != nil {
|
||||
sendProcessWSError(conn, "sandbox host is not reachable")
|
||||
sendWSError(conn, "sandbox host is not reachable")
|
||||
return
|
||||
}
|
||||
|
||||
@ -236,7 +174,7 @@ func (h *processHandler) runConnectProcess(ctx context.Context, conn *websocket.
|
||||
|
||||
stream, err := agent.ConnectProcess(streamCtx, connect.NewRequest(connectReq))
|
||||
if err != nil {
|
||||
sendProcessWSError(conn, "failed to connect to process: "+err.Error())
|
||||
sendWSError(conn, "failed to connect to process: "+err.Error())
|
||||
return
|
||||
}
|
||||
defer stream.Close()
|
||||
@ -257,42 +195,27 @@ func (h *processHandler) runConnectProcess(ctx context.Context, conn *websocket.
|
||||
resp := stream.Msg()
|
||||
switch ev := resp.Event.(type) {
|
||||
case *pb.ConnectProcessResponse_Start:
|
||||
writeWSJSON(conn, wsProcessOut{Type: "start", PID: ev.Start.Pid})
|
||||
writeWSJSON(conn, wsOutMsg{Type: "start", PID: ev.Start.Pid})
|
||||
|
||||
case *pb.ConnectProcessResponse_Data:
|
||||
switch o := ev.Data.Output.(type) {
|
||||
case *pb.ExecStreamData_Stdout:
|
||||
writeWSJSON(conn, wsProcessOut{Type: "stdout", Data: string(o.Stdout)})
|
||||
writeWSJSON(conn, wsOutMsg{Type: "stdout", Data: string(o.Stdout)})
|
||||
case *pb.ExecStreamData_Stderr:
|
||||
writeWSJSON(conn, wsProcessOut{Type: "stderr", Data: string(o.Stderr)})
|
||||
writeWSJSON(conn, wsOutMsg{Type: "stderr", Data: string(o.Stderr)})
|
||||
}
|
||||
|
||||
case *pb.ConnectProcessResponse_End:
|
||||
exitCode := ev.End.ExitCode
|
||||
writeWSJSON(conn, wsProcessOut{Type: "exit", ExitCode: &exitCode})
|
||||
writeWSJSON(conn, wsOutMsg{Type: "exit", ExitCode: &exitCode})
|
||||
}
|
||||
}
|
||||
|
||||
if err := stream.Err(); err != nil {
|
||||
if streamCtx.Err() == nil {
|
||||
sendProcessWSError(conn, err.Error())
|
||||
sendWSError(conn, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// Update last active using a fresh context.
|
||||
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer updateCancel()
|
||||
if err := h.db.UpdateLastActive(updateCtx, db.UpdateLastActiveParams{
|
||||
ID: sandboxID,
|
||||
LastActiveAt: pgtype.Timestamptz{
|
||||
Time: time.Now(),
|
||||
Valid: true,
|
||||
},
|
||||
}); err != nil {
|
||||
slog.Warn("failed to update last active after process stream", "sandbox_id", sandboxIDStr, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func sendProcessWSError(conn *websocket.Conn, msg string) {
|
||||
writeWSJSON(conn, wsProcessOut{Type: "error", Data: msg})
|
||||
updateLastActive(h.db, sandboxID, sandboxIDStr)
|
||||
}
|
||||
|
||||
@ -30,13 +30,12 @@ const (
|
||||
)
|
||||
|
||||
type ptyHandler struct {
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
jwtSecret []byte
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
}
|
||||
|
||||
func newPtyHandler(db *db.Queries, pool *lifecycle.HostClientPool, jwtSecret []byte) *ptyHandler {
|
||||
return &ptyHandler{db: db, pool: pool, jwtSecret: jwtSecret}
|
||||
func newPtyHandler(db *db.Queries, pool *lifecycle.HostClientPool) *ptyHandler {
|
||||
return &ptyHandler{db: db, pool: pool}
|
||||
}
|
||||
|
||||
// --- WebSocket message types ---
|
||||
@ -90,40 +89,9 @@ func (h *ptyHandler) PtySession(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// API key auth is handled by middleware (sets context).
|
||||
// For browser JWT auth, we authenticate after upgrade via first WS message.
|
||||
ac, hasAuth := auth.FromContext(ctx)
|
||||
|
||||
if !hasAuth {
|
||||
// No pre-upgrade auth — upgrade first, then authenticate via WS message.
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
slog.Error("pty websocket upgrade failed", "error", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
ws := &wsWriter{conn: conn}
|
||||
|
||||
var wsAC auth.AuthContext
|
||||
if isAdminWSRoute(ctx) {
|
||||
wsAC, err = wsAuthenticateAdmin(ctx, conn, h.jwtSecret, h.db)
|
||||
} else {
|
||||
wsAC, err = wsAuthenticate(ctx, conn, h.jwtSecret, h.db)
|
||||
}
|
||||
if err != nil {
|
||||
ws.writeJSON(wsPtyOut{Type: "error", Data: "authentication failed", Fatal: true})
|
||||
return
|
||||
}
|
||||
ac = wsAC
|
||||
|
||||
h.runPtySession(ctx, ws, conn, ac, sandboxID, sandboxIDStr)
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
conn, ac, err := upgradeAndAuthenticate(w, r)
|
||||
if err != nil {
|
||||
slog.Error("pty websocket upgrade failed", "error", err)
|
||||
slog.Error("pty websocket upgrade/auth failed", "error", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
@ -168,18 +136,7 @@ func (h *ptyHandler) runPtySession(ctx context.Context, ws *wsWriter, conn *webs
|
||||
ws.writeJSON(wsPtyOut{Type: "error", Data: "first message must be type 'start' or 'connect'", Fatal: true})
|
||||
}
|
||||
|
||||
// Update last active using a fresh context.
|
||||
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer updateCancel()
|
||||
if err := h.db.UpdateLastActive(updateCtx, db.UpdateLastActiveParams{
|
||||
ID: sandboxID,
|
||||
LastActiveAt: pgtype.Timestamptz{
|
||||
Time: time.Now(),
|
||||
Valid: true,
|
||||
},
|
||||
}); err != nil {
|
||||
slog.Warn("failed to update last active after pty session", "sandbox_id", sandboxIDStr, "error", err)
|
||||
}
|
||||
updateLastActive(h.db, sandboxID, sandboxIDStr)
|
||||
}
|
||||
|
||||
func (h *ptyHandler) handleStart(
|
||||
|
||||
@ -37,8 +37,8 @@ type sandboxResponse struct {
|
||||
VCPUs int32 `json:"vcpus"`
|
||||
MemoryMB int32 `json:"memory_mb"`
|
||||
TimeoutSec int32 `json:"timeout_sec"`
|
||||
GuestIP string `json:"guest_ip,omitempty"`
|
||||
HostIP string `json:"host_ip,omitempty"`
|
||||
DiskSizeMB int32 `json:"disk_size_mb"`
|
||||
DiskUsedMB *int64 `json:"disk_used_mb,omitempty"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
StartedAt *string `json:"started_at,omitempty"`
|
||||
LastActiveAt *string `json:"last_active_at,omitempty"`
|
||||
@ -54,8 +54,7 @@ func sandboxToResponse(sb db.Sandbox) sandboxResponse {
|
||||
VCPUs: sb.Vcpus,
|
||||
MemoryMB: sb.MemoryMb,
|
||||
TimeoutSec: sb.TimeoutSec,
|
||||
GuestIP: sb.GuestIp,
|
||||
HostIP: sb.HostIp,
|
||||
DiskSizeMB: sb.DiskSizeMb,
|
||||
}
|
||||
if len(sb.Metadata) > 0 {
|
||||
var meta map[string]string
|
||||
@ -101,14 +100,17 @@ func (h *sandboxHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
MemoryMB: req.MemoryMB,
|
||||
TimeoutSec: req.TimeoutSec,
|
||||
})
|
||||
h.audit.LogSandboxCreate(r.Context(), ac, sb.ID, req.Template, err)
|
||||
if err != nil {
|
||||
if sb.ID.Valid {
|
||||
h.audit.LogSandboxDestroySystem(r.Context(), ac.TeamID, sb.ID, "cleanup_after_create_error", nil)
|
||||
}
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogSandboxCreate(r.Context(), ac, sb.ID, sb.Template)
|
||||
writeJSON(w, http.StatusCreated, sandboxToResponse(sb))
|
||||
writeJSON(w, http.StatusAccepted, sandboxToResponse(sb))
|
||||
}
|
||||
|
||||
// List handles GET /v1/capsules.
|
||||
@ -145,7 +147,15 @@ func (h *sandboxHandler) Get(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, sandboxToResponse(sb))
|
||||
resp := sandboxToResponse(sb)
|
||||
|
||||
diskBytes, err := h.svc.GetDiskUsage(r.Context(), sandboxID, ac.TeamID)
|
||||
if err == nil {
|
||||
diskUsedMB := diskBytes / (1024 * 1024)
|
||||
resp.DiskUsedMB = &diskUsedMB
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// Pause handles POST /v1/capsules/{id}/pause.
|
||||
@ -160,14 +170,14 @@ func (h *sandboxHandler) Pause(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
sb, err := h.svc.Pause(r.Context(), sandboxID, ac.TeamID)
|
||||
h.audit.LogSandboxPause(r.Context(), ac, sandboxID, err)
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogSandboxPause(r.Context(), ac, sandboxID)
|
||||
writeJSON(w, http.StatusOK, sandboxToResponse(sb))
|
||||
writeJSON(w, http.StatusAccepted, sandboxToResponse(sb))
|
||||
}
|
||||
|
||||
// Resume handles POST /v1/capsules/{id}/resume.
|
||||
@ -182,14 +192,14 @@ func (h *sandboxHandler) Resume(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
sb, err := h.svc.Resume(r.Context(), sandboxID, ac.TeamID)
|
||||
h.audit.LogSandboxResume(r.Context(), ac, sandboxID, err)
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogSandboxResume(r.Context(), ac, sandboxID)
|
||||
writeJSON(w, http.StatusOK, sandboxToResponse(sb))
|
||||
writeJSON(w, http.StatusAccepted, sandboxToResponse(sb))
|
||||
}
|
||||
|
||||
// Ping handles POST /v1/capsules/{id}/ping.
|
||||
@ -223,12 +233,13 @@ func (h *sandboxHandler) Destroy(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.svc.Destroy(r.Context(), sandboxID, ac.TeamID); err != nil {
|
||||
err = h.svc.Destroy(r.Context(), sandboxID, ac.TeamID)
|
||||
h.audit.LogSandboxDestroy(r.Context(), ac, sandboxID, err)
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogSandboxDestroy(r.Context(), ac, sandboxID)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
}
|
||||
|
||||
169
internal/api/handlers_sandbox_events.go
Normal file
169
internal/api/handlers_sandbox_events.go
Normal file
@ -0,0 +1,169 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/channels"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/events"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
type sandboxEventHandler struct {
|
||||
db *db.Queries
|
||||
eventPub *channels.Publisher
|
||||
}
|
||||
|
||||
func newSandboxEventHandler(queries *db.Queries, eventPub *channels.Publisher) *sandboxEventHandler {
|
||||
return &sandboxEventHandler{db: queries, eventPub: eventPub}
|
||||
}
|
||||
|
||||
type sandboxEventRequest struct {
|
||||
Event string `json:"event"`
|
||||
SandboxID string `json:"sandbox_id"`
|
||||
HostID string `json:"host_id"`
|
||||
HostIP string `json:"host_ip,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
}
|
||||
|
||||
// Handle receives lifecycle event callbacks from host agents, translates the
|
||||
// raw host event into the canonical events.Event taxonomy, and publishes once
|
||||
// to the unified Redis stream. The SandboxEventConsumer (independent
|
||||
// consumer group) drives DB reconciliation; the channels dispatcher delivers
|
||||
// to subscribed channels; the SSE relay mirrors via Pub/Sub.
|
||||
func (h *sandboxEventHandler) Handle(w http.ResponseWriter, r *http.Request) {
|
||||
var req sandboxEventRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Event == "" || req.SandboxID == "" || req.HostID == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "event, sandbox_id, and host_id are required")
|
||||
return
|
||||
}
|
||||
|
||||
hc := auth.MustHostFromContext(r.Context())
|
||||
callerHostID := id.FormatHostID(hc.HostID)
|
||||
if callerHostID != req.HostID {
|
||||
writeError(w, http.StatusForbidden, "forbidden", "host_id does not match authenticated host")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Timestamp == 0 {
|
||||
req.Timestamp = time.Now().Unix()
|
||||
}
|
||||
|
||||
evt, ok := h.translate(r.Context(), req)
|
||||
if !ok {
|
||||
// Unknown event type — log and accept so the host agent doesn't retry.
|
||||
slog.Warn("sandbox event callback: untranslatable event", "event", req.Event, "sandbox_id", req.SandboxID)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
h.eventPub.Publish(r.Context(), evt)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// translate converts a raw host-agent callback into the canonical event.
|
||||
// For failure events without an in-flight verb (e.g. sandbox.failed), the
|
||||
// current DB status is consulted to pick the appropriate verb.
|
||||
func (h *sandboxEventHandler) translate(ctx context.Context, req sandboxEventRequest) (events.Event, bool) {
|
||||
sandboxUUID, parseErr := id.ParseSandboxID(req.SandboxID)
|
||||
if parseErr != nil {
|
||||
return events.Event{}, false
|
||||
}
|
||||
|
||||
var teamID pgtype.UUID
|
||||
if sb, dbErr := h.db.GetSandbox(ctx, sandboxUUID); dbErr == nil {
|
||||
teamID = sb.TeamID
|
||||
}
|
||||
|
||||
base := events.Event{
|
||||
Timestamp: time.Unix(req.Timestamp, 0).UTC().Format(time.RFC3339),
|
||||
TeamID: id.FormatTeamID(teamID),
|
||||
Actor: events.SystemActor(),
|
||||
Resource: events.Resource{ID: req.SandboxID, Type: "sandbox"},
|
||||
}
|
||||
|
||||
switch req.Event {
|
||||
case "sandbox.started":
|
||||
meta := map[string]string{}
|
||||
if req.HostIP != "" {
|
||||
meta["host_ip"] = req.HostIP
|
||||
}
|
||||
meta["host_id"] = req.HostID
|
||||
base.Event = events.CapsuleCreate
|
||||
base.Outcome = events.OutcomeSuccess
|
||||
base.Metadata = meta
|
||||
case "sandbox.resumed":
|
||||
meta := map[string]string{"host_id": req.HostID}
|
||||
if req.HostIP != "" {
|
||||
meta["host_ip"] = req.HostIP
|
||||
}
|
||||
base.Event = events.CapsuleResume
|
||||
base.Outcome = events.OutcomeSuccess
|
||||
base.Metadata = meta
|
||||
case "sandbox.paused":
|
||||
base.Event = events.CapsulePause
|
||||
base.Outcome = events.OutcomeSuccess
|
||||
case "sandbox.auto_paused":
|
||||
base.Event = events.CapsulePause
|
||||
base.Outcome = events.OutcomeSuccess
|
||||
base.Metadata = map[string]string{"reason": "ttl_expired"}
|
||||
case "sandbox.stopped":
|
||||
base.Event = events.CapsuleDestroy
|
||||
base.Outcome = events.OutcomeSuccess
|
||||
case "sandbox.pause_failed":
|
||||
base.Event = events.CapsulePause
|
||||
base.Outcome = events.OutcomeError
|
||||
base.Error = req.Error
|
||||
base.Metadata = map[string]string{"reason": "host_failure"}
|
||||
case "sandbox.resume_failed":
|
||||
base.Event = events.CapsuleResume
|
||||
base.Outcome = events.OutcomeError
|
||||
base.Error = req.Error
|
||||
base.Metadata = map[string]string{"reason": "host_failure"}
|
||||
case "sandbox.failed", "sandbox.error":
|
||||
// Pick a verb based on the sandbox's current DB status.
|
||||
verb := h.verbForFailure(ctx, sandboxUUID)
|
||||
base.Event = verb
|
||||
base.Outcome = events.OutcomeError
|
||||
base.Error = req.Error
|
||||
base.Metadata = map[string]string{"reason": "host_failure"}
|
||||
default:
|
||||
return events.Event{}, false
|
||||
}
|
||||
return base, true
|
||||
}
|
||||
|
||||
func (h *sandboxEventHandler) verbForFailure(ctx context.Context, sandboxID pgtype.UUID) string {
|
||||
sb, err := h.db.GetSandbox(ctx, sandboxID)
|
||||
if err != nil {
|
||||
return events.CapsuleDestroy
|
||||
}
|
||||
switch sb.Status {
|
||||
case "starting":
|
||||
return events.CapsuleCreate
|
||||
case "resuming":
|
||||
return events.CapsuleResume
|
||||
case "pausing":
|
||||
return events.CapsulePause
|
||||
case "snapshotting":
|
||||
// A snapshot pauses then resumes the VM; a host-side failure leaves the
|
||||
// sandbox errored, not destroyed. Route through CapsuleCreate so the
|
||||
// consumer's handleFailed marks it "error" rather than removing the row.
|
||||
return events.CapsuleCreate
|
||||
default:
|
||||
return events.CapsuleDestroy
|
||||
}
|
||||
}
|
||||
66
internal/api/handlers_sessions.go
Normal file
66
internal/api/handlers_sessions.go
Normal file
@ -0,0 +1,66 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth/session"
|
||||
)
|
||||
|
||||
type sessionsHandler struct {
|
||||
sessions *session.Service
|
||||
}
|
||||
|
||||
func newSessionsHandler(svc *session.Service) *sessionsHandler {
|
||||
return &sessionsHandler{sessions: svc}
|
||||
}
|
||||
|
||||
// List handles GET /v1/me/sessions — returns all active sessions for the
|
||||
// current user, flagging the caller's own session as current.
|
||||
func (h *sessionsHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
rows, err := h.sessions.ListForUser(r.Context(), ac.UserID)
|
||||
if err != nil {
|
||||
slog.Error("list sessions: db error", "error", err)
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to list sessions")
|
||||
return
|
||||
}
|
||||
out := make([]sessionRow, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
out = append(out, sessionRow{
|
||||
ID: row.ID,
|
||||
UserAgent: row.UserAgent,
|
||||
IPAddress: row.IpAddress,
|
||||
CreatedAt: row.CreatedAt.Time.UTC().Format(time.RFC3339),
|
||||
LastSeenAt: row.LastSeenAt.Time.UTC().Format(time.RFC3339),
|
||||
ExpiresAt: row.ExpiresAt.Time.UTC().Format(time.RFC3339),
|
||||
Current: row.ID == ac.SessionID,
|
||||
})
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"sessions": out})
|
||||
}
|
||||
|
||||
// Delete handles DELETE /v1/me/sessions/{id} — revokes a single session
|
||||
// belonging to the current user. If the caller revokes their own session,
|
||||
// their cookies are cleared on the response.
|
||||
func (h *sessionsHandler) Delete(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
sid := chi.URLParam(r, "id")
|
||||
if sid == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "missing session id")
|
||||
return
|
||||
}
|
||||
if err := h.sessions.DeleteForUser(r.Context(), sid, ac.UserID); err != nil {
|
||||
slog.Error("delete session: db error", "error", err)
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to delete session")
|
||||
return
|
||||
}
|
||||
if sid == ac.SessionID {
|
||||
clearSessionCookies(w, isSecure(r))
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
@ -25,49 +24,53 @@ import (
|
||||
)
|
||||
|
||||
type snapshotHandler struct {
|
||||
svc *service.TemplateService
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
audit *audit.AuditLogger
|
||||
svc *service.TemplateService
|
||||
sandboxSvc *service.SandboxService
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
audit *audit.AuditLogger
|
||||
}
|
||||
|
||||
func newSnapshotHandler(svc *service.TemplateService, db *db.Queries, pool *lifecycle.HostClientPool, al *audit.AuditLogger) *snapshotHandler {
|
||||
return &snapshotHandler{svc: svc, db: db, pool: pool, audit: al}
|
||||
func newSnapshotHandler(svc *service.TemplateService, sandboxSvc *service.SandboxService, db *db.Queries, pool *lifecycle.HostClientPool, al *audit.AuditLogger) *snapshotHandler {
|
||||
return &snapshotHandler{svc: svc, sandboxSvc: sandboxSvc, db: db, pool: pool, audit: al}
|
||||
}
|
||||
|
||||
// deleteSnapshotBroadcast attempts to delete snapshot files on all online hosts.
|
||||
// Snapshots aren't currently host-tracked in the DB, so we broadcast to all hosts
|
||||
// and ignore NotFound errors.
|
||||
func deleteSnapshotBroadcast(ctx context.Context, queries *db.Queries, pool *lifecycle.HostClientPool, teamID, templateID pgtype.UUID) error {
|
||||
// deleteSnapshotEverywhere removes a template's files from every active host.
|
||||
// Templates aren't host-tracked in the DB, so it broadcasts to all hosts.
|
||||
//
|
||||
// It is strict by design: deletion is reported successful only when every
|
||||
// active host has either removed the files or reported NotFound (it never
|
||||
// held them). If any host is offline or returns an error, it returns an error
|
||||
// and the caller MUST NOT delete the DB record — doing so would orphan the
|
||||
// files on disk with no record left to retry against.
|
||||
func deleteSnapshotEverywhere(ctx context.Context, queries *db.Queries, pool *lifecycle.HostClientPool, teamID, templateID pgtype.UUID) error {
|
||||
hosts, err := queries.ListActiveHosts(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list hosts: %w", err)
|
||||
}
|
||||
for _, host := range hosts {
|
||||
if host.Status != "online" {
|
||||
continue
|
||||
return fmt.Errorf("host %s is %s — cannot guarantee snapshot file removal",
|
||||
id.FormatHostID(host.ID), host.Status)
|
||||
}
|
||||
agent, err := pool.GetForHost(host)
|
||||
if err != nil {
|
||||
continue
|
||||
return fmt.Errorf("connect to host %s: %w", id.FormatHostID(host.ID), err)
|
||||
}
|
||||
if _, err := agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{
|
||||
TeamId: formatUUIDForRPC(teamID),
|
||||
TemplateId: formatUUIDForRPC(templateID),
|
||||
})); err != nil {
|
||||
if connect.CodeOf(err) != connect.CodeNotFound {
|
||||
slog.Warn("snapshot: failed to delete on host", "host_id", id.FormatHostID(host.ID), "error", err)
|
||||
// NotFound just means this host never held the template.
|
||||
if connect.CodeOf(err) == connect.CodeNotFound {
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("delete snapshot on host %s: %w", id.FormatHostID(host.ID), err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type createSnapshotRequest struct {
|
||||
SandboxID string `json:"sandbox_id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type snapshotResponse struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
@ -76,6 +79,7 @@ type snapshotResponse struct {
|
||||
SizeBytes int64 `json:"size_bytes"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
Platform bool `json:"platform"`
|
||||
Protected bool `json:"protected"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
@ -85,6 +89,7 @@ func templateToResponse(t db.Template) snapshotResponse {
|
||||
Type: t.Type,
|
||||
SizeBytes: t.SizeBytes,
|
||||
Platform: t.TeamID == id.PlatformTeamID,
|
||||
Protected: layout.IsSystemTemplate(t.TeamID, t.ID),
|
||||
}
|
||||
if t.Vcpus != 0 {
|
||||
resp.VCPUs = &t.Vcpus
|
||||
@ -104,132 +109,42 @@ func templateToResponse(t db.Template) snapshotResponse {
|
||||
return resp
|
||||
}
|
||||
|
||||
// Create handles POST /v1/snapshots.
|
||||
type createSnapshotRequest struct {
|
||||
SandboxID string `json:"sandbox_id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// Create handles POST /v1/snapshots. Snapshots a running or paused sandbox and
|
||||
// registers the result as a new template.
|
||||
func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
var req createSnapshotRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.SandboxID == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "sandbox_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(req.SandboxID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox_id")
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
if req.Name == "" {
|
||||
req.Name = id.NewSnapshotName()
|
||||
}
|
||||
if err := validate.SafeName(req.Name); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", fmt.Sprintf("invalid snapshot name: %s", err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
// Check for global name collision.
|
||||
if _, err := h.db.GetPlatformTemplateByName(ctx, req.Name); err == nil {
|
||||
writeError(w, http.StatusConflict, "name_reserved", "template name is reserved by a global template")
|
||||
return
|
||||
}
|
||||
|
||||
// Check if name already exists for this team.
|
||||
if _, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err == nil {
|
||||
writeError(w, http.StatusConflict, "template_name_taken",
|
||||
"snapshot name already exists; delete the existing snapshot first to reuse this name")
|
||||
return
|
||||
}
|
||||
|
||||
// Verify sandbox exists, belongs to team, and is running or paused.
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
// Async: the VM briefly pauses to a "snapshotting" state, then resumes. The
|
||||
// template is registered by a background goroutine; clients learn of the
|
||||
// result via the SSE template.snapshot.create event (or by polling).
|
||||
sb, name, err := h.sandboxSvc.CreateSnapshot(r.Context(), sandboxID, ac.TeamID, req.Name)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" && sb.Status != "paused" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox must be running or paused")
|
||||
return
|
||||
}
|
||||
|
||||
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
|
||||
return
|
||||
}
|
||||
|
||||
// Pre-mark sandbox as "paused" in DB BEFORE issuing the snapshot RPC.
|
||||
// The host agent's CreateSnapshot removes the sandbox from its in-memory
|
||||
// map immediately; if the reconciler fires during the flatten window and
|
||||
// the DB still says "running", it will mark the sandbox "stopped".
|
||||
if sb.Status == "running" {
|
||||
if _, err := h.db.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
|
||||
ID: sandboxID, Status: "paused",
|
||||
}); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to update sandbox status")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Use a detached context with a generous timeout so the snapshot completes
|
||||
// even if the client disconnects (the flatten step can take 10-20s).
|
||||
snapCtx, snapCancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer snapCancel()
|
||||
|
||||
// Generate the new template ID upfront so the host agent knows where to store files.
|
||||
newTemplateID := id.NewTemplateID()
|
||||
|
||||
resp, err := agent.CreateSnapshot(snapCtx, connect.NewRequest(&pb.CreateSnapshotRequest{
|
||||
SandboxId: req.SandboxID,
|
||||
Name: req.Name,
|
||||
TeamId: formatUUIDForRPC(ac.TeamID),
|
||||
TemplateId: formatUUIDForRPC(newTemplateID),
|
||||
}))
|
||||
if err != nil {
|
||||
// Snapshot failed — revert status back to what it was.
|
||||
if sb.Status == "running" {
|
||||
if _, dbErr := h.db.UpdateSandboxStatus(snapCtx, db.UpdateSandboxStatusParams{
|
||||
ID: sandboxID, Status: "running",
|
||||
}); dbErr != nil {
|
||||
slog.Error("failed to revert sandbox status after snapshot error", "sandbox_id", req.SandboxID, "error", dbErr)
|
||||
}
|
||||
}
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
h.audit.LogSnapshotCreateRequested(r.Context(), ac, name)
|
||||
|
||||
tmpl, err := h.db.InsertTemplate(snapCtx, db.InsertTemplateParams{
|
||||
ID: newTemplateID,
|
||||
Name: req.Name,
|
||||
Type: "snapshot",
|
||||
Vcpus: sb.Vcpus,
|
||||
MemoryMb: sb.MemoryMb,
|
||||
SizeBytes: resp.Msg.SizeBytes,
|
||||
TeamID: ac.TeamID,
|
||||
DefaultUser: "root",
|
||||
DefaultEnv: []byte("{}"),
|
||||
Metadata: sb.Metadata,
|
||||
})
|
||||
if err != nil {
|
||||
slog.Error("failed to insert template record", "name", req.Name, "error", err)
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "snapshot created but failed to record in database")
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogSnapshotCreate(snapCtx, ac, req.Name)
|
||||
|
||||
if ctx.Err() != nil {
|
||||
slog.Info("snapshot created but client disconnected before response", "name", req.Name)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusCreated, templateToResponse(tmpl))
|
||||
writeJSON(w, http.StatusAccepted, sandboxToResponse(sb))
|
||||
}
|
||||
|
||||
// List handles GET /v1/snapshots.
|
||||
@ -243,6 +158,11 @@ func (h *snapshotHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Resolve actual on-disk sizes for templates with unknown size (e.g.
|
||||
// system base templates seeded with size_bytes = 0). This queries a host
|
||||
// agent and persists the result to the DB for subsequent requests.
|
||||
templates = resolveTemplateSizes(r.Context(), h.db, h.pool, templates)
|
||||
|
||||
resp := make([]snapshotResponse, len(templates))
|
||||
for i, t := range templates {
|
||||
resp[i] = templateToResponse(t)
|
||||
@ -271,21 +191,24 @@ func (h *snapshotHandler) Delete(w http.ResponseWriter, r *http.Request) {
|
||||
writeError(w, http.StatusForbidden, "forbidden", "platform templates cannot be deleted here")
|
||||
return
|
||||
}
|
||||
if layout.IsMinimal(tmpl.TeamID, tmpl.ID) {
|
||||
writeError(w, http.StatusForbidden, "forbidden", "the minimal template cannot be deleted")
|
||||
if layout.IsSystemTemplate(tmpl.TeamID, tmpl.ID) {
|
||||
writeError(w, http.StatusForbidden, "forbidden", "system base templates cannot be deleted")
|
||||
return
|
||||
}
|
||||
|
||||
if err := deleteSnapshotBroadcast(ctx, h.db, h.pool, tmpl.TeamID, tmpl.ID); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete snapshot files")
|
||||
if err := deleteSnapshotEverywhere(ctx, h.db, h.pool, tmpl.TeamID, tmpl.ID); err != nil {
|
||||
h.audit.LogSnapshotDelete(r.Context(), ac, name, err)
|
||||
writeError(w, http.StatusConflict, "delete_failed",
|
||||
"could not remove snapshot files from all hosts: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.db.DeleteTemplateByTeam(ctx, db.DeleteTemplateByTeamParams{Name: name, TeamID: ac.TeamID}); err != nil {
|
||||
h.audit.LogSnapshotDelete(r.Context(), ac, name, err)
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to delete template record")
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogSnapshotDelete(r.Context(), ac, name)
|
||||
h.audit.LogSnapshotDelete(r.Context(), ac, name, nil)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
79
internal/api/handlers_sse.go
Normal file
79
internal/api/handlers_sse.go
Normal file
@ -0,0 +1,79 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
const sseKeepaliveInterval = 30 * time.Second
|
||||
|
||||
type sseHandler struct {
|
||||
broker *SSEBroker
|
||||
}
|
||||
|
||||
func newSSEHandler(broker *SSEBroker) *sseHandler {
|
||||
return &sseHandler{broker: broker}
|
||||
}
|
||||
|
||||
// Stream handles GET /v1/events/stream. Authentication is performed by
|
||||
// upstream middleware: browser clients use the wrenn_sid cookie (which the
|
||||
// EventSource API forwards automatically); SDK clients use X-API-Key.
|
||||
func (h *sseHandler) Stream(w http.ResponseWriter, r *http.Request) {
|
||||
ac, ok := auth.FromContext(r.Context())
|
||||
if !ok {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "session cookie or X-API-Key required")
|
||||
return
|
||||
}
|
||||
h.serveSSE(w, r, id.FormatTeamID(ac.TeamID), false)
|
||||
}
|
||||
|
||||
// AdminStream handles GET /v1/admin/events/stream. Upstream middleware
|
||||
// must enforce session auth + requireAdmin before reaching this handler.
|
||||
func (h *sseHandler) AdminStream(w http.ResponseWriter, r *http.Request) {
|
||||
ac, ok := auth.FromContext(r.Context())
|
||||
if !ok {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "admin session required")
|
||||
return
|
||||
}
|
||||
h.serveSSE(w, r, id.FormatTeamID(ac.TeamID), true)
|
||||
}
|
||||
|
||||
func (h *sseHandler) serveSSE(w http.ResponseWriter, r *http.Request, teamID string, isAdmin bool) {
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
http.Error(w, "streaming not supported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("X-Accel-Buffering", "no")
|
||||
|
||||
subID, ch := h.broker.Subscribe(teamID, isAdmin)
|
||||
defer h.broker.Unsubscribe(subID)
|
||||
|
||||
fmt.Fprintf(w, "event: connected\ndata: {\"message\":\"connected\"}\n\n")
|
||||
flusher.Flush()
|
||||
|
||||
keepalive := time.NewTicker(sseKeepaliveInterval)
|
||||
defer keepalive.Stop()
|
||||
|
||||
ctx := r.Context()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case msg := <-ch:
|
||||
fmt.Fprintf(w, "event: %s\ndata: %s\n\n", msg.EventType, msg.Data)
|
||||
flusher.Flush()
|
||||
case <-keepalive.C:
|
||||
fmt.Fprintf(w, ": keepalive\n\n")
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -14,19 +14,21 @@ import (
|
||||
"git.omukk.dev/wrenn/wrenn/internal/email"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/audit"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth/session"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/service"
|
||||
)
|
||||
|
||||
type teamHandler struct {
|
||||
svc *service.TeamService
|
||||
audit *audit.AuditLogger
|
||||
mailer email.Mailer
|
||||
svc *service.TeamService
|
||||
audit *audit.AuditLogger
|
||||
mailer email.Mailer
|
||||
sessions *session.Service
|
||||
}
|
||||
|
||||
func newTeamHandler(svc *service.TeamService, al *audit.AuditLogger, mailer email.Mailer) *teamHandler {
|
||||
return &teamHandler{svc: svc, audit: al, mailer: mailer}
|
||||
func newTeamHandler(svc *service.TeamService, al *audit.AuditLogger, mailer email.Mailer, sessions *session.Service) *teamHandler {
|
||||
return &teamHandler{svc: svc, audit: al, mailer: mailer, sessions: sessions}
|
||||
}
|
||||
|
||||
// teamResponse is the JSON shape for a team.
|
||||
@ -366,6 +368,11 @@ func (h *teamHandler) UpdateMemberRole(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Drop cached session blobs so the new role propagates immediately.
|
||||
if err := h.sessions.InvalidateCacheForUser(r.Context(), targetUserID); err != nil {
|
||||
_ = err
|
||||
}
|
||||
|
||||
h.audit.LogMemberRoleUpdate(r.Context(), ac, targetUserID, req.Role)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
@ -450,6 +457,8 @@ func (h *teamHandler) AdminListTeams(w http.ResponseWriter, r *http.Request) {
|
||||
OwnerEmail string `json:"owner_email"`
|
||||
ActiveSandboxCount int32 `json:"active_sandbox_count"`
|
||||
ChannelCount int32 `json:"channel_count"`
|
||||
RunningVcpus int32 `json:"running_vcpus"`
|
||||
RunningMemoryMb int32 `json:"running_memory_mb"`
|
||||
}
|
||||
|
||||
resp := make([]adminTeamResponse, len(teams))
|
||||
@ -465,6 +474,8 @@ func (h *teamHandler) AdminListTeams(w http.ResponseWriter, r *http.Request) {
|
||||
OwnerEmail: t.OwnerEmail,
|
||||
ActiveSandboxCount: t.ActiveSandboxCount,
|
||||
ChannelCount: t.ChannelCount,
|
||||
RunningVcpus: t.RunningVcpus,
|
||||
RunningMemoryMb: t.RunningMemoryMb,
|
||||
}
|
||||
if t.DeletedAt != nil {
|
||||
s := t.DeletedAt.Format(time.RFC3339)
|
||||
|
||||
@ -11,19 +11,21 @@ import (
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/audit"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth/session"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/service"
|
||||
)
|
||||
|
||||
type usersHandler struct {
|
||||
db *db.Queries
|
||||
svc *service.UserService
|
||||
audit *audit.AuditLogger
|
||||
db *db.Queries
|
||||
svc *service.UserService
|
||||
audit *audit.AuditLogger
|
||||
sessions *session.Service
|
||||
}
|
||||
|
||||
func newUsersHandler(db *db.Queries, svc *service.UserService, al *audit.AuditLogger) *usersHandler {
|
||||
return &usersHandler{db: db, svc: svc, audit: al}
|
||||
func newUsersHandler(db *db.Queries, svc *service.UserService, al *audit.AuditLogger, sessions *session.Service) *usersHandler {
|
||||
return &usersHandler{db: db, svc: svc, audit: al, sessions: sessions}
|
||||
}
|
||||
|
||||
// Search handles GET /v1/users/search?email=<prefix>
|
||||
@ -158,6 +160,10 @@ func (h *usersHandler) SetUserActive(w http.ResponseWriter, r *http.Request) {
|
||||
if req.Active {
|
||||
h.audit.LogUserActivate(r.Context(), ac, userID, user.Email)
|
||||
} else {
|
||||
// Disabled users must be kicked out of every active session.
|
||||
if err := h.sessions.RevokeAllForUser(r.Context(), userID); err != nil {
|
||||
_ = err
|
||||
}
|
||||
h.audit.LogUserDeactivate(r.Context(), ac, userID, user.Email)
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
@ -215,5 +221,14 @@ func (h *usersHandler) SetUserAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
h.audit.LogUserRevokeAdmin(r.Context(), ac, userID, user.Email)
|
||||
}
|
||||
|
||||
// Invalidate cached session blobs so the new is_admin flag is reflected
|
||||
// on the user's next request without waiting for the Redis TTL.
|
||||
if err := h.sessions.InvalidateCacheForUser(r.Context(), userID); err != nil {
|
||||
// Cache is best-effort; the DB is authoritative and requireAdmin
|
||||
// always re-reads it.
|
||||
_ = err
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
@ -1,102 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
// ctxKeyAdminWS is a context key for flagging admin WS routes.
|
||||
type ctxKeyAdminWS struct{}
|
||||
|
||||
// setAdminWSFlag marks the context as an admin WebSocket route.
|
||||
func setAdminWSFlag(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, ctxKeyAdminWS{}, true)
|
||||
}
|
||||
|
||||
// isAdminWSRoute checks if the request context was marked as admin WS.
|
||||
func isAdminWSRoute(ctx context.Context) bool {
|
||||
v, _ := ctx.Value(ctxKeyAdminWS{}).(bool)
|
||||
return v
|
||||
}
|
||||
|
||||
// wsAuthMsg is the first message a browser WS client sends to authenticate.
|
||||
type wsAuthMsg struct {
|
||||
Type string `json:"type"`
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
// wsAuthenticate reads a JWT auth message from the WebSocket and returns the
|
||||
// authenticated context. The caller must send this as the first message after
|
||||
// connecting.
|
||||
func wsAuthenticate(ctx context.Context, conn *websocket.Conn, jwtSecret []byte, queries *db.Queries) (auth.AuthContext, error) {
|
||||
_ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||
|
||||
var msg wsAuthMsg
|
||||
if err := conn.ReadJSON(&msg); err != nil {
|
||||
return auth.AuthContext{}, fmt.Errorf("read auth message: %w", err)
|
||||
}
|
||||
|
||||
_ = conn.SetReadDeadline(time.Time{}) // clear deadline
|
||||
|
||||
if msg.Type != "auth" || msg.Token == "" {
|
||||
return auth.AuthContext{}, fmt.Errorf("first message must be type 'auth' with a token")
|
||||
}
|
||||
|
||||
claims, err := auth.VerifyJWT(jwtSecret, msg.Token)
|
||||
if err != nil {
|
||||
return auth.AuthContext{}, fmt.Errorf("invalid or expired token: %w", err)
|
||||
}
|
||||
|
||||
teamID, err := id.ParseTeamID(claims.TeamID)
|
||||
if err != nil {
|
||||
return auth.AuthContext{}, fmt.Errorf("invalid team ID in token: %w", err)
|
||||
}
|
||||
|
||||
userID, err := id.ParseUserID(claims.Subject)
|
||||
if err != nil {
|
||||
return auth.AuthContext{}, fmt.Errorf("invalid user ID in token: %w", err)
|
||||
}
|
||||
|
||||
user, err := queries.GetUserByID(ctx, userID)
|
||||
if err != nil {
|
||||
return auth.AuthContext{}, fmt.Errorf("user not found")
|
||||
}
|
||||
if user.Status != "active" {
|
||||
return auth.AuthContext{}, fmt.Errorf("account deactivated")
|
||||
}
|
||||
|
||||
return auth.AuthContext{
|
||||
TeamID: teamID,
|
||||
UserID: userID,
|
||||
Email: claims.Email,
|
||||
Name: claims.Name,
|
||||
Role: claims.Role,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// wsAuthenticateAdmin performs WS-based auth and verifies admin status,
|
||||
// returning an AuthContext with the platform team ID.
|
||||
func wsAuthenticateAdmin(ctx context.Context, conn *websocket.Conn, jwtSecret []byte, queries *db.Queries) (auth.AuthContext, error) {
|
||||
ac, err := wsAuthenticate(ctx, conn, jwtSecret, queries)
|
||||
if err != nil {
|
||||
return auth.AuthContext{}, err
|
||||
}
|
||||
|
||||
user, err := queries.GetUserByID(ctx, ac.UserID)
|
||||
if err != nil {
|
||||
return auth.AuthContext{}, fmt.Errorf("user not found")
|
||||
}
|
||||
if !user.IsAdmin {
|
||||
return auth.AuthContext{}, fmt.Errorf("admin access required")
|
||||
}
|
||||
|
||||
ac.TeamID = id.PlatformTeamID
|
||||
return ac, nil
|
||||
}
|
||||
@ -2,6 +2,7 @@ package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
@ -15,10 +16,30 @@ import (
|
||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||
)
|
||||
|
||||
// errInferredTransientTimeout marks a state change that the reconciler
|
||||
// inferred after a transient (starting/resuming) sandbox failed to settle
|
||||
// within the grace period. Used as the err value on system audit calls so
|
||||
// the published event carries Outcome=error with a human-readable message.
|
||||
var errInferredTransientTimeout = errors.New("transient state did not settle within grace period")
|
||||
|
||||
// unreachableThreshold is how long a host can go without a heartbeat before
|
||||
// it is considered unreachable (3 missed 30-second heartbeats).
|
||||
const unreachableThreshold = 90 * time.Second
|
||||
|
||||
// transientGracePeriod is how long a sandbox is allowed to stay in a transient
|
||||
// status (starting, resuming, pausing, stopping) before the monitor infers a
|
||||
// final state. This prevents the monitor from racing against in-flight RPCs
|
||||
// that may not have registered the sandbox on the host agent yet.
|
||||
const transientGracePeriod = 2 * time.Minute
|
||||
|
||||
// snapshotGracePeriod is the grace for a sandbox stuck in "snapshotting" while
|
||||
// the VM is still alive on the host. Snapshots dump guest RAM and flatten the
|
||||
// rootfs, which can run for minutes on large sandboxes, and the agent reports
|
||||
// the VM as alive throughout — so we must not race the in-flight operation.
|
||||
// It exceeds the background goroutine's 10-minute deadline, so reaching it
|
||||
// means the control plane crashed mid-snapshot and the sandbox needs recovery.
|
||||
const snapshotGracePeriod = 15 * time.Minute
|
||||
|
||||
// HostMonitor runs on a fixed interval and performs two duties:
|
||||
//
|
||||
// 1. Passive check: marks hosts whose last_heartbeat_at is stale as
|
||||
@ -77,6 +98,21 @@ func (m *HostMonitor) run(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// ReconcileHost triggers immediate active reconciliation for a single host.
|
||||
// Called when a host transitions from unreachable → online so sandboxes marked
|
||||
// "missing" are resolved without waiting for the next monitor tick.
|
||||
func (m *HostMonitor) ReconcileHost(ctx context.Context, hostID pgtype.UUID) {
|
||||
host, err := m.db.GetHost(ctx, hostID)
|
||||
if err != nil {
|
||||
slog.Warn("host monitor: reconcile-on-connect: failed to get host", "error", err)
|
||||
return
|
||||
}
|
||||
if host.Status != "online" {
|
||||
return
|
||||
}
|
||||
m.checkHost(ctx, host)
|
||||
}
|
||||
|
||||
func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) {
|
||||
// --- Passive phase: check heartbeat staleness ---
|
||||
|
||||
@ -116,21 +152,29 @@ func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) {
|
||||
return
|
||||
}
|
||||
|
||||
// Build set of sandbox IDs alive on the host.
|
||||
// The host agent returns sandbox IDs as strings (formatted with prefix).
|
||||
alive := make(map[string]struct{}, len(resp.Msg.Sandboxes))
|
||||
// Build map of sandbox ID -> reported status. Transient statuses
|
||||
// (pausing/resuming/starting/stopping) are coerced to a presence-only
|
||||
// entry: ListSandboxes can observe the in-memory status mid-transition
|
||||
// (Pause flips the status under m.mu while List holds m.mu.RLock), and
|
||||
// writing those transient labels into the DB would force the transient
|
||||
// reconciliation phase to wait the full grace period before resolving.
|
||||
// Recording the presence keeps "missing → restore" and "running →
|
||||
// orphan-stop" logic correct without overwriting with stale labels;
|
||||
// the next monitor tick reads the settled status.
|
||||
aliveStatus := make(map[string]string, len(resp.Msg.Sandboxes))
|
||||
for _, sb := range resp.Msg.Sandboxes {
|
||||
alive[sb.SandboxId] = struct{}{}
|
||||
}
|
||||
|
||||
autoPaused := make(map[string]struct{}, len(resp.Msg.AutoPausedSandboxIds))
|
||||
for _, apID := range resp.Msg.AutoPausedSandboxIds {
|
||||
autoPaused[apID] = struct{}{}
|
||||
status := sb.Status
|
||||
switch status {
|
||||
case "pausing", "resuming", "starting", "stopping":
|
||||
status = ""
|
||||
}
|
||||
aliveStatus[sb.SandboxId] = status
|
||||
}
|
||||
|
||||
// --- Restore sandboxes that are "missing" in DB but alive on host ---
|
||||
// This handles the case where CP marked them missing due to a transient
|
||||
// heartbeat gap, but the host was actually fine.
|
||||
// Handles transient heartbeat gaps where the host was actually fine. The
|
||||
// reported status must be honored: a sandbox the agent paused while CP
|
||||
// was disconnected must not be silently promoted back to running.
|
||||
|
||||
missingSandboxes, err := m.db.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
|
||||
HostID: host.ID,
|
||||
@ -139,34 +183,65 @@ func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) {
|
||||
if err != nil {
|
||||
slog.Warn("host monitor: failed to list missing sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
|
||||
} else {
|
||||
var toRestore []pgtype.UUID
|
||||
var toStop []pgtype.UUID
|
||||
restoreByStatus := make(map[string][]db.Sandbox)
|
||||
var toStop []db.Sandbox
|
||||
for _, sb := range missingSandboxes {
|
||||
sbIDStr := id.FormatSandboxID(sb.ID)
|
||||
if _, ok := alive[sbIDStr]; ok {
|
||||
toRestore = append(toRestore, sb.ID)
|
||||
} else {
|
||||
toStop = append(toStop, sb.ID)
|
||||
status, ok := aliveStatus[sbIDStr]
|
||||
if !ok {
|
||||
toStop = append(toStop, sb)
|
||||
continue
|
||||
}
|
||||
if status == "" {
|
||||
continue
|
||||
}
|
||||
restoreByStatus[status] = append(restoreByStatus[status], sb)
|
||||
}
|
||||
if len(toRestore) > 0 {
|
||||
slog.Info("host monitor: restoring missing sandboxes", "host_id", id.FormatHostID(host.ID), "count", len(toRestore))
|
||||
if err := m.db.BulkRestoreRunning(ctx, toRestore); err != nil {
|
||||
slog.Warn("host monitor: failed to restore missing sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
|
||||
for status, sbs := range restoreByStatus {
|
||||
ids := make([]pgtype.UUID, len(sbs))
|
||||
for i, sb := range sbs {
|
||||
ids[i] = sb.ID
|
||||
}
|
||||
slog.Info("host monitor: restoring missing sandboxes", "host_id", id.FormatHostID(host.ID), "status", status, "count", len(ids))
|
||||
if err := m.db.BulkRestoreMissingToStatus(ctx, db.BulkRestoreMissingToStatusParams{
|
||||
Column1: ids,
|
||||
Status: status,
|
||||
}); err != nil {
|
||||
slog.Warn("host monitor: failed to restore missing sandboxes", "host_id", id.FormatHostID(host.ID), "status", status, "error", err)
|
||||
continue
|
||||
}
|
||||
// Only restore→paused emits a notification (per design: running restore is silent).
|
||||
if status == "paused" {
|
||||
for _, sb := range sbs {
|
||||
m.audit.LogSandboxAutoPause(ctx, sb.TeamID, sb.ID, "restored_after_host_recovery", nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(toStop) > 0 {
|
||||
slog.Info("host monitor: stopping confirmed-dead missing sandboxes", "host_id", id.FormatHostID(host.ID), "count", len(toStop))
|
||||
ids := make([]pgtype.UUID, len(toStop))
|
||||
for i, sb := range toStop {
|
||||
ids[i] = sb.ID
|
||||
}
|
||||
slog.Info("host monitor: stopping confirmed-dead missing sandboxes", "host_id", id.FormatHostID(host.ID), "count", len(ids))
|
||||
if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
|
||||
Column1: toStop,
|
||||
Column1: ids,
|
||||
Status: "stopped",
|
||||
}); err != nil {
|
||||
slog.Warn("host monitor: failed to stop missing sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
|
||||
} else {
|
||||
for _, sb := range toStop {
|
||||
m.audit.LogSandboxDestroySystem(ctx, sb.TeamID, sb.ID, "orphaned", nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Find running sandboxes in DB that are no longer alive on the host ---
|
||||
// --- Reconcile running sandboxes in DB against live host state ---
|
||||
// Three cases per DB-running row:
|
||||
// absent on host -> stopped
|
||||
// present and running -> no change
|
||||
// present but paused/etc. -> sync DB to reported status (catches the
|
||||
// shutdown-pause notify failure case)
|
||||
|
||||
runningSandboxes, err := m.db.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
|
||||
HostID: host.ID,
|
||||
@ -177,40 +252,196 @@ func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) {
|
||||
return
|
||||
}
|
||||
|
||||
var toPause, toStop []pgtype.UUID
|
||||
sbTeamID := make(map[pgtype.UUID]pgtype.UUID, len(runningSandboxes))
|
||||
var toStop []db.Sandbox
|
||||
syncByStatus := make(map[string][]db.Sandbox)
|
||||
for _, sb := range runningSandboxes {
|
||||
sbIDStr := id.FormatSandboxID(sb.ID)
|
||||
sbTeamID[sb.ID] = sb.TeamID
|
||||
if _, ok := alive[sbIDStr]; ok {
|
||||
status, ok := aliveStatus[sbIDStr]
|
||||
if !ok {
|
||||
toStop = append(toStop, sb)
|
||||
continue
|
||||
}
|
||||
if _, ok := autoPaused[sbIDStr]; ok {
|
||||
toPause = append(toPause, sb.ID)
|
||||
} else {
|
||||
toStop = append(toStop, sb.ID)
|
||||
if status == "running" || status == "" {
|
||||
continue
|
||||
}
|
||||
syncByStatus[status] = append(syncByStatus[status], sb)
|
||||
}
|
||||
|
||||
if len(toPause) > 0 {
|
||||
slog.Info("host monitor: marking auto-paused sandboxes", "host_id", id.FormatHostID(host.ID), "count", len(toPause))
|
||||
if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
|
||||
Column1: toPause,
|
||||
Status: "paused",
|
||||
}); err != nil {
|
||||
slog.Warn("host monitor: failed to mark paused", "host_id", id.FormatHostID(host.ID), "error", err)
|
||||
}
|
||||
for _, sbID := range toPause {
|
||||
m.audit.LogSandboxAutoPause(ctx, sbTeamID[sbID], sbID)
|
||||
}
|
||||
}
|
||||
if len(toStop) > 0 {
|
||||
slog.Info("host monitor: marking orphaned sandboxes stopped", "host_id", id.FormatHostID(host.ID), "count", len(toStop))
|
||||
ids := make([]pgtype.UUID, len(toStop))
|
||||
for i, sb := range toStop {
|
||||
ids[i] = sb.ID
|
||||
}
|
||||
slog.Info("host monitor: marking orphaned sandboxes stopped", "host_id", id.FormatHostID(host.ID), "count", len(ids))
|
||||
if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
|
||||
Column1: toStop,
|
||||
Column1: ids,
|
||||
Status: "stopped",
|
||||
}); err != nil {
|
||||
slog.Warn("host monitor: failed to mark stopped", "host_id", id.FormatHostID(host.ID), "error", err)
|
||||
} else {
|
||||
for _, sb := range toStop {
|
||||
m.audit.LogSandboxDestroySystem(ctx, sb.TeamID, sb.ID, "orphaned", nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
for status, sbs := range syncByStatus {
|
||||
ids := make([]pgtype.UUID, len(sbs))
|
||||
for i, sb := range sbs {
|
||||
ids[i] = sb.ID
|
||||
}
|
||||
slog.Info("host monitor: syncing running→reported status", "host_id", id.FormatHostID(host.ID), "status", status, "count", len(ids))
|
||||
if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
|
||||
Column1: ids,
|
||||
Status: status,
|
||||
}); err != nil {
|
||||
slog.Warn("host monitor: failed to sync running sandboxes", "host_id", id.FormatHostID(host.ID), "status", status, "error", err)
|
||||
continue
|
||||
}
|
||||
if status == "paused" {
|
||||
for _, sb := range sbs {
|
||||
m.audit.LogSandboxAutoPause(ctx, sb.TeamID, sb.ID, "host_state_sync", nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Reconcile DB-stopped + agent-paused zombies ---
|
||||
// A sandbox the agent reports as 'paused' but DB has as 'stopped' is an
|
||||
// orphan from a previous bug where a successful pause's auto_paused
|
||||
// callback was lost (e.g. CP unreachable during agent shutdown). With the
|
||||
// agent-side fix (RestorePausedSandboxes), the snapshot survives across
|
||||
// agent restarts and surfaces here. Authoritative direction: DB wins
|
||||
// (user already saw 'stopped' and may have stopped tracking it).
|
||||
// Issue Destroy so the on-disk snapshot dir is removed and the agent's
|
||||
// slot reservation released.
|
||||
//
|
||||
// Gate: only run the DB query if the agent reports at least one paused
|
||||
// sandbox. Otherwise we'd fetch every historically-stopped sandbox on
|
||||
// this host every monitor tick — unbounded growth over a host's lifetime.
|
||||
hasPaused := false
|
||||
for _, status := range aliveStatus {
|
||||
if status == "paused" {
|
||||
hasPaused = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if hasPaused {
|
||||
stoppedSandboxes, err := m.db.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
|
||||
HostID: host.ID,
|
||||
Column2: []string{"stopped"},
|
||||
})
|
||||
if err != nil {
|
||||
slog.Warn("host monitor: failed to list stopped sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
|
||||
} else {
|
||||
for _, sb := range stoppedSandboxes {
|
||||
sbIDStr := id.FormatSandboxID(sb.ID)
|
||||
status, ok := aliveStatus[sbIDStr]
|
||||
if !ok || status != "paused" {
|
||||
continue
|
||||
}
|
||||
slog.Info("host monitor: destroying DB-stopped agent-paused zombie",
|
||||
"host_id", id.FormatHostID(host.ID), "sandbox_id", sbIDStr)
|
||||
if _, err := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
|
||||
SandboxId: sbIDStr,
|
||||
})); err != nil && connect.CodeOf(err) != connect.CodeNotFound {
|
||||
slog.Warn("host monitor: zombie destroy failed",
|
||||
"sandbox_id", sbIDStr, "error", err)
|
||||
continue
|
||||
}
|
||||
m.audit.LogSandboxDestroySystem(ctx, sb.TeamID, sb.ID, "paused_zombie_cleanup", nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Reconcile transient statuses (starting, resuming, pausing, stopping) ---
|
||||
// These represent in-flight operations. If the sandbox is no longer alive on
|
||||
// the host, infer the final state based on the transient status.
|
||||
|
||||
transientSandboxes, err := m.db.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
|
||||
HostID: host.ID,
|
||||
Column2: []string{"starting", "resuming", "pausing", "stopping", "snapshotting"},
|
||||
})
|
||||
if err != nil {
|
||||
slog.Warn("host monitor: failed to list transient sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, sb := range transientSandboxes {
|
||||
sbIDStr := id.FormatSandboxID(sb.ID)
|
||||
if agentStatus, ok := aliveStatus[sbIDStr]; ok {
|
||||
// Sandbox is alive on host — the background goroutine should
|
||||
// finalize the transition. For starting/resuming, if the sandbox
|
||||
// is alive it means creation/resume succeeded.
|
||||
if sb.Status == "starting" || sb.Status == "resuming" {
|
||||
if _, err := m.db.UpdateSandboxStatusIf(ctx, db.UpdateSandboxStatusIfParams{
|
||||
ID: sb.ID, Status: sb.Status, Status_2: "running",
|
||||
}); err == nil {
|
||||
slog.Info("host monitor: promoted transient sandbox to running", "sandbox_id", sbIDStr, "from", sb.Status)
|
||||
}
|
||||
}
|
||||
// A snapshot keeps the source sandbox alive throughout, so an alive
|
||||
// sandbox does NOT mean the snapshot finished. Only recover it once
|
||||
// it has been stuck past the snapshot grace period (i.e. the CP
|
||||
// crashed mid-op). Recover to the sandbox's actual host-side status:
|
||||
// a running sandbox is snapshotted live and stays running, but a
|
||||
// paused sandbox is snapshotted from disk and must return to paused.
|
||||
if sb.Status == "snapshotting" &&
|
||||
sb.LastUpdated.Valid && time.Since(sb.LastUpdated.Time) >= snapshotGracePeriod {
|
||||
recoverTo := agentStatus
|
||||
if recoverTo != "running" && recoverTo != "paused" {
|
||||
// Coerced/unknown agent label — default to running.
|
||||
recoverTo = "running"
|
||||
}
|
||||
if _, err := m.db.UpdateSandboxStatusIf(ctx, db.UpdateSandboxStatusIfParams{
|
||||
ID: sb.ID, Status: "snapshotting", Status_2: recoverTo,
|
||||
}); err == nil {
|
||||
slog.Info("host monitor: recovered stuck snapshotting sandbox", "sandbox_id", sbIDStr, "to", recoverTo)
|
||||
m.audit.LogSnapshotCreateSystem(ctx, sb.TeamID, sb.ID, "snapshot_recovered", nil)
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
// Sandbox is not alive on host. If the transition is recent, give the
|
||||
// in-flight RPC time to finish before declaring a final state.
|
||||
if sb.LastUpdated.Valid && time.Since(sb.LastUpdated.Time) < transientGracePeriod {
|
||||
slog.Debug("host monitor: transient sandbox still within grace period",
|
||||
"sandbox_id", sbIDStr, "status", sb.Status,
|
||||
"age", time.Since(sb.LastUpdated.Time).Round(time.Second))
|
||||
continue
|
||||
}
|
||||
|
||||
// Grace period expired — infer final state.
|
||||
var finalStatus string
|
||||
switch sb.Status {
|
||||
case "starting", "resuming":
|
||||
finalStatus = "error"
|
||||
case "pausing":
|
||||
finalStatus = "paused"
|
||||
case "stopping":
|
||||
finalStatus = "stopped"
|
||||
case "snapshotting":
|
||||
// VM is gone but DB says snapshotting → the snapshot died with the VM.
|
||||
finalStatus = "error"
|
||||
}
|
||||
fromStatus := sb.Status
|
||||
if _, err := m.db.UpdateSandboxStatusIf(ctx, db.UpdateSandboxStatusIfParams{
|
||||
ID: sb.ID, Status: fromStatus, Status_2: finalStatus,
|
||||
}); err == nil {
|
||||
slog.Info("host monitor: resolved transient sandbox", "sandbox_id", sbIDStr, "from", fromStatus, "to", finalStatus)
|
||||
inferredErr := errInferredTransientTimeout
|
||||
switch fromStatus {
|
||||
case "starting":
|
||||
m.audit.LogSandboxCreateSystem(ctx, sb.TeamID, sb.ID, "transient_timeout", inferredErr)
|
||||
case "resuming":
|
||||
m.audit.LogSandboxResumeSystem(ctx, sb.TeamID, sb.ID, "transient_timeout", inferredErr)
|
||||
case "pausing":
|
||||
// Pause assumed to have succeeded host-side; emit success with inferred metadata.
|
||||
m.audit.LogSandboxAutoPause(ctx, sb.TeamID, sb.ID, "transient_timeout_inferred", nil)
|
||||
case "snapshotting":
|
||||
// VM gone mid-snapshot; the sandbox is errored.
|
||||
m.audit.LogSnapshotCreateSystem(ctx, sb.TeamID, sb.ID, "transient_timeout", inferredErr)
|
||||
case "stopping":
|
||||
m.audit.LogSandboxDestroySystem(ctx, sb.TeamID, sb.ID, "transient_timeout_inferred", nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -50,7 +50,9 @@ func agentErrToHTTP(err error) (int, string, string) {
|
||||
return http.StatusNotFound, "not_found", err.Error()
|
||||
case connect.CodeInvalidArgument:
|
||||
return http.StatusBadRequest, "invalid_request", err.Error()
|
||||
case connect.CodeFailedPrecondition, connect.CodeAlreadyExists:
|
||||
case connect.CodeAlreadyExists:
|
||||
return http.StatusConflict, "already_exists", err.Error()
|
||||
case connect.CodeFailedPrecondition:
|
||||
return http.StatusConflict, "conflict", err.Error()
|
||||
case connect.CodePermissionDenied:
|
||||
return http.StatusForbidden, "forbidden", err.Error()
|
||||
|
||||
@ -26,19 +26,10 @@ func injectPlatformTeam() func(http.Handler) http.Handler {
|
||||
}
|
||||
}
|
||||
|
||||
// markAdminWS flags the request context as an admin WebSocket route.
|
||||
// Applied to admin WS endpoints that sit outside the requireJWT/requireAdmin
|
||||
// middleware group. Handlers use isAdminWSRoute(ctx) to pick wsAuthenticateAdmin.
|
||||
func markAdminWS(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
next.ServeHTTP(w, r.WithContext(setAdminWSFlag(r.Context())))
|
||||
})
|
||||
}
|
||||
|
||||
// requireAdmin validates that the authenticated user is a platform admin.
|
||||
// Must run after requireJWT (depends on AuthContext being present).
|
||||
// Re-validates against the DB — the JWT is_admin claim is for UI only;
|
||||
// the DB is the source of truth for admin access.
|
||||
// Must run after requireSession (depends on AuthContext being present).
|
||||
// Re-validates against the DB — the session blob's is_admin flag is for
|
||||
// UI hints; the DB is the source of truth for admin access.
|
||||
func requireAdmin(queries *db.Queries) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
@ -1,145 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
// requireAPIKeyOrJWT accepts either X-API-Key header or Authorization: Bearer JWT.
|
||||
// Both stamp TeamID into the request context via auth.AuthContext.
|
||||
func requireAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Try API key first.
|
||||
if key := r.Header.Get("X-API-Key"); key != "" {
|
||||
hash := auth.HashAPIKey(key)
|
||||
row, err := queries.GetAPIKeyByHash(r.Context(), hash)
|
||||
if err != nil {
|
||||
slog.Warn("api key auth failed", "prefix", auth.APIKeyPrefix(key), "ip", r.RemoteAddr)
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
if err := queries.UpdateAPIKeyLastUsed(r.Context(), row.ID); err != nil {
|
||||
slog.Warn("failed to update api key last_used", "key_id", id.FormatAPIKeyID(row.ID), "error", err)
|
||||
}
|
||||
|
||||
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
|
||||
TeamID: row.TeamID,
|
||||
APIKeyID: row.ID,
|
||||
APIKeyName: row.Name,
|
||||
})
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
|
||||
// Try JWT bearer token from Authorization header.
|
||||
tokenStr := ""
|
||||
if header := r.Header.Get("Authorization"); strings.HasPrefix(header, "Bearer ") {
|
||||
tokenStr = strings.TrimPrefix(header, "Bearer ")
|
||||
}
|
||||
if tokenStr != "" {
|
||||
claims, err := auth.VerifyJWT(jwtSecret, tokenStr)
|
||||
if err != nil {
|
||||
slog.Warn("jwt auth failed", "error", err, "ip", r.RemoteAddr)
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid or expired token")
|
||||
return
|
||||
}
|
||||
|
||||
teamID, err := id.ParseTeamID(claims.TeamID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid team ID in token")
|
||||
return
|
||||
}
|
||||
userID, err := id.ParseUserID(claims.Subject)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid user ID in token")
|
||||
return
|
||||
}
|
||||
|
||||
// Verify user is still active in the database.
|
||||
user, err := queries.GetUserByID(r.Context(), userID)
|
||||
if err != nil {
|
||||
slog.Warn("jwt auth: failed to look up user", "user_id", claims.Subject, "error", err)
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "user not found")
|
||||
return
|
||||
}
|
||||
if user.Status != "active" {
|
||||
writeError(w, http.StatusForbidden, "account_deactivated", "your account has been deactivated — contact your administrator to regain access")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
|
||||
TeamID: teamID,
|
||||
UserID: userID,
|
||||
Email: claims.Email,
|
||||
Name: claims.Name,
|
||||
Role: claims.Role,
|
||||
})
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "X-API-Key or Authorization: Bearer <token> required")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// optionalAPIKeyOrJWT is like requireAPIKeyOrJWT but does not reject
|
||||
// unauthenticated requests. It injects auth context when valid credentials
|
||||
// are present (supporting SDK clients that set X-API-Key on WebSocket
|
||||
// upgrades) and passes through otherwise so the handler can authenticate
|
||||
// after the WebSocket upgrade via the first message.
|
||||
func optionalAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Try API key.
|
||||
if key := r.Header.Get("X-API-Key"); key != "" {
|
||||
hash := auth.HashAPIKey(key)
|
||||
row, err := queries.GetAPIKeyByHash(r.Context(), hash)
|
||||
if err == nil {
|
||||
if err := queries.UpdateAPIKeyLastUsed(r.Context(), row.ID); err != nil {
|
||||
slog.Warn("failed to update api key last_used", "key_id", id.FormatAPIKeyID(row.ID), "error", err)
|
||||
}
|
||||
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
|
||||
TeamID: row.TeamID,
|
||||
APIKeyID: row.ID,
|
||||
APIKeyName: row.Name,
|
||||
})
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Try JWT bearer token.
|
||||
if header := r.Header.Get("Authorization"); strings.HasPrefix(header, "Bearer ") {
|
||||
tokenStr := strings.TrimPrefix(header, "Bearer ")
|
||||
if claims, err := auth.VerifyJWT(jwtSecret, tokenStr); err == nil {
|
||||
if teamID, err := id.ParseTeamID(claims.TeamID); err == nil {
|
||||
if userID, err := id.ParseUserID(claims.Subject); err == nil {
|
||||
if user, err := queries.GetUserByID(r.Context(), userID); err == nil && user.Status == "active" {
|
||||
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
|
||||
TeamID: teamID,
|
||||
UserID: userID,
|
||||
Email: claims.Email,
|
||||
Name: claims.Name,
|
||||
Role: claims.Role,
|
||||
})
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// No valid credentials — pass through for handler to authenticate.
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
38
internal/api/middleware_csrf.go
Normal file
38
internal/api/middleware_csrf.go
Normal file
@ -0,0 +1,38 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"net/http"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
)
|
||||
|
||||
// requireCSRF enforces double-submit CSRF: the wrenn_csrf cookie value must
|
||||
// equal the X-CSRF-Token header. Skipped for safe methods (GET/HEAD/OPTIONS)
|
||||
// and for requests authenticated via X-API-Key (SDK clients are not
|
||||
// vulnerable to cross-site request forgery against cookie auth).
|
||||
func requireCSRF() func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodGet, http.MethodHead, http.MethodOptions:
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
// API-key auth path: no CSRF check needed.
|
||||
if ac, ok := auth.FromContext(r.Context()); ok && ac.APIKeyID.Valid {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
cookie, err := r.Cookie(csrfCookieName)
|
||||
header := r.Header.Get("X-CSRF-Token")
|
||||
if err != nil || cookie.Value == "" || header == "" ||
|
||||
subtle.ConstantTimeCompare([]byte(cookie.Value), []byte(header)) != 1 {
|
||||
writeError(w, http.StatusForbidden, "csrf_failed", "missing or invalid CSRF token")
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -1,69 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
// requireJWT validates a JWT from the Authorization: Bearer header.
|
||||
// It also verifies the user is still active in the database.
|
||||
// WebSocket upgrade requests without an Authorization header are passed through
|
||||
// — WS handlers authenticate via the first message after upgrade.
|
||||
func requireJWT(secret []byte, queries *db.Queries) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var tokenStr string
|
||||
if header := r.Header.Get("Authorization"); strings.HasPrefix(header, "Bearer ") {
|
||||
tokenStr = strings.TrimPrefix(header, "Bearer ")
|
||||
}
|
||||
if tokenStr == "" {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "Authorization: Bearer <token> required")
|
||||
return
|
||||
}
|
||||
claims, err := auth.VerifyJWT(secret, tokenStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid or expired token")
|
||||
return
|
||||
}
|
||||
|
||||
teamID, err := id.ParseTeamID(claims.TeamID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid team ID in token")
|
||||
return
|
||||
}
|
||||
userID, err := id.ParseUserID(claims.Subject)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid user ID in token")
|
||||
return
|
||||
}
|
||||
|
||||
// Verify user is still active in the database.
|
||||
user, err := queries.GetUserByID(r.Context(), userID)
|
||||
if err != nil {
|
||||
slog.Warn("jwt auth: failed to look up user", "user_id", claims.Subject, "error", err)
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "user not found")
|
||||
return
|
||||
}
|
||||
if user.Status != "active" {
|
||||
writeError(w, http.StatusForbidden, "account_deactivated", "your account has been deactivated — contact your administrator to regain access")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
|
||||
TeamID: teamID,
|
||||
UserID: userID,
|
||||
Email: claims.Email,
|
||||
Name: claims.Name,
|
||||
Role: claims.Role,
|
||||
IsAdmin: claims.IsAdmin,
|
||||
})
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
34
internal/api/middleware_session.go
Normal file
34
internal/api/middleware_session.go
Normal file
@ -0,0 +1,34 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth/session"
|
||||
sessionmw "git.omukk.dev/wrenn/wrenn/pkg/auth/session/middleware"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
)
|
||||
|
||||
// Internal aliases — the canonical implementations live in the public
|
||||
// pkg/auth/session/middleware package so cloud extensions can call them.
|
||||
|
||||
const csrfCookieName = sessionmw.CSRFCookieName
|
||||
|
||||
func requireSession(queries *db.Queries, svc *session.Service) func(http.Handler) http.Handler {
|
||||
return sessionmw.RequireSession(svc, queries)
|
||||
}
|
||||
|
||||
func requireSessionOrAPIKey(queries *db.Queries, svc *session.Service) func(http.Handler) http.Handler {
|
||||
return sessionmw.RequireSessionOrAPIKey(svc, queries)
|
||||
}
|
||||
|
||||
func setSessionCookies(w http.ResponseWriter, sid, csrfToken string, secure bool) {
|
||||
sessionmw.SetCookies(w, sid, csrfToken, secure)
|
||||
}
|
||||
|
||||
func clearSessionCookies(w http.ResponseWriter, secure bool) {
|
||||
sessionmw.ClearCookies(w, secure)
|
||||
}
|
||||
|
||||
func isSecure(r *http.Request) bool {
|
||||
return sessionmw.IsSecure(r)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
310
internal/api/sandbox_event_consumer.go
Normal file
310
internal/api/sandbox_event_consumer.go
Normal file
@ -0,0 +1,310 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/audit"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/cpextension"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/events"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
const (
|
||||
unifiedEventStream = "wrenn:events"
|
||||
reconcilerConsumerGrp = "wrenn-sandbox-reconciler-v1"
|
||||
reconcilerConsumer = "cp-0"
|
||||
)
|
||||
|
||||
// SandboxEventConsumer reads capsule lifecycle events from the unified Redis
|
||||
// stream and drives DB state reconciliation. Uses an independent consumer
|
||||
// group so its cursor is separate from the channels dispatcher.
|
||||
type SandboxEventConsumer struct {
|
||||
rdb *redis.Client
|
||||
db *db.Queries
|
||||
audit *audit.AuditLogger
|
||||
hooks []cpextension.SandboxEventHook
|
||||
}
|
||||
|
||||
// NewSandboxEventConsumer creates a consumer.
|
||||
func NewSandboxEventConsumer(rdb *redis.Client, queries *db.Queries, al *audit.AuditLogger, hooks []cpextension.SandboxEventHook) *SandboxEventConsumer {
|
||||
return &SandboxEventConsumer{rdb: rdb, db: queries, audit: al, hooks: hooks}
|
||||
}
|
||||
|
||||
// Start launches the consumer goroutine. Reads from "$" so prior history
|
||||
// is not replayed.
|
||||
func (c *SandboxEventConsumer) Start(ctx context.Context) {
|
||||
go c.run(ctx)
|
||||
}
|
||||
|
||||
func (c *SandboxEventConsumer) run(ctx context.Context) {
|
||||
err := c.rdb.XGroupCreateMkStream(ctx, unifiedEventStream, reconcilerConsumerGrp, "$").Err()
|
||||
if err != nil && err.Error() != "BUSYGROUP Consumer Group name already exists" {
|
||||
slog.Error("sandbox event consumer: failed to create consumer group", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
streams, err := c.rdb.XReadGroup(ctx, &redis.XReadGroupArgs{
|
||||
Group: reconcilerConsumerGrp,
|
||||
Consumer: reconcilerConsumer,
|
||||
Streams: []string{unifiedEventStream, ">"},
|
||||
Count: 10,
|
||||
Block: 5 * time.Second,
|
||||
}).Result()
|
||||
|
||||
if err != nil {
|
||||
if err == redis.Nil || ctx.Err() != nil {
|
||||
continue
|
||||
}
|
||||
slog.Warn("sandbox event consumer: xreadgroup error", "error", err)
|
||||
time.Sleep(1 * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, stream := range streams {
|
||||
for _, msg := range stream.Messages {
|
||||
c.handleMessage(ctx, msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *SandboxEventConsumer) handleMessage(ctx context.Context, msg redis.XMessage) {
|
||||
ack := true
|
||||
defer func() {
|
||||
if !ack {
|
||||
return
|
||||
}
|
||||
ackCtx, ackCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer ackCancel()
|
||||
if err := c.rdb.XAck(ackCtx, unifiedEventStream, reconcilerConsumerGrp, msg.ID).Err(); err != nil {
|
||||
slog.Warn("sandbox event consumer: xack failed", "id", msg.ID, "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
payload, ok := msg.Values["payload"].(string)
|
||||
if !ok {
|
||||
slog.Warn("sandbox event consumer: message missing payload", "id", msg.ID)
|
||||
return
|
||||
}
|
||||
|
||||
var event events.Event
|
||||
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||
slog.Warn("sandbox event consumer: failed to unmarshal event", "id", msg.ID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Only capsule.* events drive sandbox reconciliation.
|
||||
if !strings.HasPrefix(event.Event, "capsule.") || event.Event == events.CapsuleStateChanged {
|
||||
return
|
||||
}
|
||||
// Only system-actor events represent host-side state we need to reflect
|
||||
// in the DB; user-actor events are already mirrored by the handler that
|
||||
// produced them.
|
||||
if event.Actor.Type != events.ActorSystem {
|
||||
// Exception: handlers publish capsule.create with user actor before
|
||||
// the host has reported back. Those are owned by the service goroutine.
|
||||
return
|
||||
}
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(event.Resource.ID)
|
||||
if err != nil {
|
||||
slog.Warn("sandbox event consumer: invalid sandbox ID", "sandbox_id", event.Resource.ID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
switch event.Event {
|
||||
case events.CapsuleCreate:
|
||||
if event.Outcome == events.OutcomeSuccess {
|
||||
c.handleStarted(ctx, sandboxID, event, "starting")
|
||||
} else {
|
||||
c.handleFailed(ctx, sandboxID, event)
|
||||
}
|
||||
case events.CapsuleResume:
|
||||
if event.Outcome == events.OutcomeSuccess {
|
||||
c.handleStarted(ctx, sandboxID, event, "resuming")
|
||||
} else {
|
||||
c.handleFailed(ctx, sandboxID, event)
|
||||
}
|
||||
case events.CapsulePause:
|
||||
if event.Outcome == events.OutcomeSuccess {
|
||||
c.handleAutoPaused(ctx, sandboxID)
|
||||
}
|
||||
case events.CapsuleDestroy:
|
||||
if event.Outcome == events.OutcomeSuccess {
|
||||
c.handleStopped(ctx, sandboxID)
|
||||
}
|
||||
}
|
||||
|
||||
// Dispatch to extension hooks (cloud billing, audit shipping, etc.). Any
|
||||
// hook error suppresses the ack so the message will be redelivered. Hooks
|
||||
// MUST be idempotent — duplicate deliveries are expected on transient
|
||||
// failures.
|
||||
if len(c.hooks) > 0 && event.Outcome == events.OutcomeSuccess {
|
||||
if verb, ok := canonicalSandboxVerb(event.Event); ok {
|
||||
teamID, _ := id.ParseTeamID(event.TeamID)
|
||||
meta := map[string]any{}
|
||||
for k, v := range event.Metadata {
|
||||
meta[k] = v
|
||||
}
|
||||
ev := cpextension.SandboxEvent{
|
||||
SandboxID: sandboxID,
|
||||
TeamID: teamID,
|
||||
Type: verb,
|
||||
OccurredAt: parseEventTimestamp(event.Timestamp),
|
||||
Metadata: meta,
|
||||
}
|
||||
for _, h := range c.hooks {
|
||||
if err := h.OnSandboxEvent(ctx, ev); err != nil {
|
||||
slog.Warn("sandbox event hook failed; leaving message un-acked", "id", msg.ID, "event", event.Event, "error", err)
|
||||
ack = false
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func canonicalSandboxVerb(event string) (string, bool) {
|
||||
switch event {
|
||||
case events.CapsuleCreate:
|
||||
return "created", true
|
||||
case events.CapsuleResume:
|
||||
return "resumed", true
|
||||
case events.CapsulePause:
|
||||
return "paused", true
|
||||
case events.CapsuleDestroy:
|
||||
return "destroyed", true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func parseEventTimestamp(s string) time.Time {
|
||||
if s == "" {
|
||||
return time.Now().UTC()
|
||||
}
|
||||
t, err := time.Parse(time.RFC3339, s)
|
||||
if err != nil {
|
||||
return time.Now().UTC()
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
// handleStarted is a fallback writer for capsule.create.success and
|
||||
// capsule.resume.success. The background goroutine in SandboxService is the
|
||||
// primary writer; this only succeeds if the goroutine's conditional update
|
||||
// was missed.
|
||||
func (c *SandboxEventConsumer) handleStarted(ctx context.Context, sandboxID pgtype.UUID, event events.Event, fromStatus string) {
|
||||
hostIP := event.Metadata["host_ip"]
|
||||
now := time.Now()
|
||||
if _, err := c.db.UpdateSandboxRunningIf(ctx, db.UpdateSandboxRunningIfParams{
|
||||
ID: sandboxID,
|
||||
Status: fromStatus,
|
||||
HostIp: hostIP,
|
||||
StartedAt: pgtype.Timestamptz{
|
||||
Time: now,
|
||||
Valid: true,
|
||||
},
|
||||
}); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (c *SandboxEventConsumer) handleAutoPaused(ctx context.Context, sandboxID pgtype.UUID) {
|
||||
for _, fromStatus := range []string{"running", "pausing"} {
|
||||
if _, err := c.db.UpdateSandboxStatusIf(ctx, db.UpdateSandboxStatusIfParams{
|
||||
ID: sandboxID, Status: fromStatus, Status_2: "paused",
|
||||
}); err == nil {
|
||||
slog.Debug("sandbox event consumer: auto-paused fallback applied", "sandbox_id", id.FormatSandboxID(sandboxID), "from", fromStatus)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *SandboxEventConsumer) handleStopped(ctx context.Context, sandboxID pgtype.UUID) {
|
||||
// stopping → stopped (CP-initiated destroy completed). No audit row here;
|
||||
// the handler that issued the destroy already wrote one.
|
||||
if _, err := c.db.UpdateSandboxStatusIf(ctx, db.UpdateSandboxStatusIfParams{
|
||||
ID: sandboxID,
|
||||
Status: "stopping",
|
||||
Status_2: "stopped",
|
||||
}); err == nil {
|
||||
return
|
||||
}
|
||||
// running → stopped (autonomous destroy, e.g. TTL destroy fallback).
|
||||
if _, err := c.db.UpdateSandboxStatusIf(ctx, db.UpdateSandboxStatusIfParams{
|
||||
ID: sandboxID,
|
||||
Status: "running",
|
||||
Status_2: "stopped",
|
||||
}); err != nil && !errors.Is(err, pgx.ErrNoRows) {
|
||||
slog.Warn("sandbox event consumer: failed to update sandbox to stopped", "sandbox_id", id.FormatSandboxID(sandboxID), "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// handleFailed marks a sandbox as "error" when a verb event reports failure
|
||||
// and writes a system audit row. The DB update is idempotent — the
|
||||
// SandboxService background goroutine usually wrote "error" already on the
|
||||
// fast-fail path, which settles in seconds and so never reaches the
|
||||
// HostMonitor's transient-timeout reconciliation.
|
||||
//
|
||||
// audit.Log writes the row only — it does NOT republish an event, which would
|
||||
// loop back into this consumer. Do not switch to LogSandboxCreateSystem here.
|
||||
func (c *SandboxEventConsumer) handleFailed(ctx context.Context, sandboxID pgtype.UUID, event events.Event) {
|
||||
for _, fromStatus := range []string{"running", "starting", "pausing", "resuming", "snapshotting"} {
|
||||
if _, err := c.db.UpdateSandboxStatusIf(ctx, db.UpdateSandboxStatusIfParams{
|
||||
ID: sandboxID, Status: fromStatus, Status_2: "error",
|
||||
}); err == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// The HostMonitor transient-timeout reconciler emits failure events via
|
||||
// LogSandboxCreateSystem / LogSandboxResumeSystem, which already write
|
||||
// their own audit row before publishing — auditing again here would
|
||||
// double-count. Those helpers publish with reason="transient_timeout";
|
||||
// the un-audited fast-fail (createInBackground) and host-callback paths
|
||||
// do not, so only they need a row written here.
|
||||
if event.Metadata["reason"] == "transient_timeout" {
|
||||
return
|
||||
}
|
||||
|
||||
action := "create"
|
||||
if event.Event == events.CapsuleResume {
|
||||
action = "resume"
|
||||
}
|
||||
reason := event.Metadata["reason"]
|
||||
if reason == "" {
|
||||
reason = action + "_failed"
|
||||
}
|
||||
meta := map[string]any{"reason": reason}
|
||||
if event.Error != "" {
|
||||
meta["error"] = event.Error
|
||||
}
|
||||
teamID, _ := id.ParseTeamID(event.TeamID)
|
||||
c.audit.Log(ctx, audit.Entry{
|
||||
TeamID: teamID,
|
||||
ActorType: "system",
|
||||
ResourceType: "sandbox",
|
||||
ResourceID: id.FormatSandboxID(sandboxID),
|
||||
Action: action,
|
||||
Scope: "team",
|
||||
Status: "error",
|
||||
Metadata: meta,
|
||||
})
|
||||
}
|
||||
@ -1,6 +1,7 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@ -13,9 +14,11 @@ import (
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/audit"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth/oauth"
|
||||
authsession "git.omukk.dev/wrenn/wrenn/pkg/auth/session"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/channels"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/cpextension"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/events"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/scheduler"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/service"
|
||||
@ -26,15 +29,19 @@ var openapiYAML []byte
|
||||
|
||||
// Server is the control plane HTTP server.
|
||||
type Server struct {
|
||||
router chi.Router
|
||||
BuildSvc *service.BuildService
|
||||
version string
|
||||
router chi.Router
|
||||
BuildSvc *service.BuildService
|
||||
SSERelay *SSERelay
|
||||
SessionSvc *authsession.Service
|
||||
version string
|
||||
}
|
||||
|
||||
// New constructs the chi router and registers all routes.
|
||||
// Extensions are called after core routes are registered, allowing cloud
|
||||
// or third-party code to add routes and middleware.
|
||||
// New constructs the chi router and registers all routes. The jwtSecret is
|
||||
// still used to sign host-agent JWTs (long-lived, for the wrenn-agent → cp
|
||||
// trust path) and to HMAC OAuth state/link cookies; user authentication has
|
||||
// migrated to opaque session cookies backed by the session service.
|
||||
func New(
|
||||
ctx context.Context,
|
||||
queries *db.Queries,
|
||||
pool *lifecycle.HostClientPool,
|
||||
sched scheduler.HostScheduler,
|
||||
@ -45,10 +52,12 @@ func New(
|
||||
oauthRedirectURL string,
|
||||
ca *auth.CA,
|
||||
al *audit.AuditLogger,
|
||||
eventPub *channels.Publisher,
|
||||
channelSvc *channels.Service,
|
||||
mailer email.Mailer,
|
||||
extensions []cpextension.Extension,
|
||||
sctx cpextension.ServerContext,
|
||||
monitor *HostMonitor,
|
||||
version string,
|
||||
) *Server {
|
||||
r := chi.NewRouter()
|
||||
@ -61,8 +70,28 @@ func New(
|
||||
}
|
||||
}
|
||||
|
||||
// Session service backs cookie-based browser auth. The cpserver wires it
|
||||
// through ServerContext so cloud-repo extensions can share the same
|
||||
// instance; fall back to constructing one here if the host program
|
||||
// instantiates the API directly (tests, ad-hoc tooling).
|
||||
sessionSvc := sctx.Sessions
|
||||
if sessionSvc == nil {
|
||||
sessionSvc = authsession.NewService(queries, rdb)
|
||||
}
|
||||
|
||||
// Shared service layer.
|
||||
sandboxSvc := &service.SandboxService{DB: queries, Pool: pool, Scheduler: sched}
|
||||
sandboxSvc.PublishEvent = func(ctx context.Context, event service.SandboxStateEvent) {
|
||||
if evt, ok := serviceEventToCanonical(event); ok {
|
||||
// State-change events are ephemeral UI signals — mirror them to the
|
||||
// dashboard via Pub/Sub only, never to durable channel subscribers.
|
||||
if evt.Event == events.CapsuleStateChanged {
|
||||
eventPub.PublishTransient(ctx, evt)
|
||||
} else {
|
||||
eventPub.Publish(ctx, evt)
|
||||
}
|
||||
}
|
||||
}
|
||||
apiKeySvc := &service.APIKeyService{DB: queries}
|
||||
templateSvc := &service.TemplateService{DB: queries}
|
||||
hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret, Pool: pool, CA: ca}
|
||||
@ -72,30 +101,40 @@ func New(
|
||||
statsSvc := &service.StatsService{DB: queries, Pool: pgPool}
|
||||
usageSvc := &service.UsageService{DB: queries}
|
||||
buildSvc := &service.BuildService{DB: queries, Redis: rdb, Pool: pool, Scheduler: sched}
|
||||
buildBroker := service.NewBuildBroker(rdb)
|
||||
|
||||
sandbox := newSandboxHandler(sandboxSvc, al)
|
||||
exec := newExecHandler(queries, pool)
|
||||
execStream := newExecStreamHandler(queries, pool, jwtSecret)
|
||||
execStream := newExecStreamHandler(queries, pool)
|
||||
files := newFilesHandler(queries, pool)
|
||||
filesStream := newFilesStreamHandler(queries, pool)
|
||||
fsH := newFSHandler(queries, pool)
|
||||
snapshots := newSnapshotHandler(templateSvc, queries, pool, al)
|
||||
authH := newAuthHandler(queries, pgPool, jwtSecret, mailer, rdb, oauthRedirectURL)
|
||||
oauthH := newOAuthHandler(queries, pgPool, jwtSecret, oauthRegistry, oauthRedirectURL)
|
||||
snapshots := newSnapshotHandler(templateSvc, sandboxSvc, queries, pool, al)
|
||||
authHooks := collectAuthHooks(extensions)
|
||||
authH := newAuthHandler(queries, pgPool, sessionSvc, mailer, rdb, oauthRedirectURL, authHooks)
|
||||
oauthH := newOAuthHandler(queries, pgPool, jwtSecret, sessionSvc, oauthRegistry, oauthRedirectURL, authHooks)
|
||||
apiKeys := newAPIKeyHandler(apiKeySvc, al)
|
||||
hostH := newHostHandler(hostSvc, queries, al)
|
||||
teamH := newTeamHandler(teamSvc, al, mailer)
|
||||
usersH := newUsersHandler(queries, userSvc, al)
|
||||
hostH := newHostHandler(hostSvc, queries, al, monitor)
|
||||
teamH := newTeamHandler(teamSvc, al, mailer, sessionSvc)
|
||||
usersH := newUsersHandler(queries, userSvc, al, sessionSvc)
|
||||
auditH := newAuditHandler(auditSvc)
|
||||
statsH := newStatsHandler(statsSvc)
|
||||
usageH := newUsageHandler(usageSvc)
|
||||
metricsH := newSandboxMetricsHandler(queries, pool)
|
||||
buildH := newBuildHandler(buildSvc, queries, pool, al)
|
||||
buildStreamH := newBuildStreamHandler(queries, buildBroker)
|
||||
channelH := newChannelHandler(channelSvc, al)
|
||||
ptyH := newPtyHandler(queries, pool, jwtSecret)
|
||||
processH := newProcessHandler(queries, pool, jwtSecret)
|
||||
ptyH := newPtyHandler(queries, pool)
|
||||
processH := newProcessHandler(queries, pool)
|
||||
adminCapsules := newAdminCapsuleHandler(sandboxSvc, queries, pool, al)
|
||||
meH := newMeHandler(queries, pgPool, rdb, jwtSecret, mailer, oauthRegistry, oauthRedirectURL, teamSvc)
|
||||
sandboxEvtH := newSandboxEventHandler(queries, eventPub)
|
||||
meH := newMeHandler(queries, pgPool, rdb, jwtSecret, sessionSvc, mailer, oauthRegistry, oauthRedirectURL, teamSvc, authHooks)
|
||||
sessionsH := newSessionsHandler(sessionSvc)
|
||||
|
||||
// SSE real-time event streaming.
|
||||
sseBroker := NewSSEBroker()
|
||||
sseRelay := NewSSERelay(rdb, queries, sseBroker)
|
||||
sseH := newSSEHandler(sseBroker)
|
||||
|
||||
// Health check.
|
||||
r.Get("/health", func(w http.ResponseWriter, r *http.Request) {
|
||||
@ -107,7 +146,8 @@ func New(
|
||||
r.Get("/openapi.yaml", serveOpenAPI)
|
||||
r.Get("/docs", serveDocs)
|
||||
|
||||
// Unauthenticated auth endpoints.
|
||||
// Unauthenticated auth endpoints. CSRF is not required here — there is
|
||||
// no session cookie yet to be forged.
|
||||
r.Post("/v1/auth/signup", authH.Signup)
|
||||
r.Post("/v1/auth/login", authH.Login)
|
||||
r.Post("/v1/auth/activate", authH.Activate)
|
||||
@ -118,31 +158,45 @@ func New(
|
||||
r.Post("/v1/me/password/reset", meH.RequestPasswordReset)
|
||||
r.Post("/v1/me/password/reset/confirm", meH.ConfirmPasswordReset)
|
||||
|
||||
// JWT-authenticated: self-service account management.
|
||||
csrf := requireCSRF()
|
||||
|
||||
// Session-authenticated: logout endpoints (must be inside the session
|
||||
// group so the handler sees the current SID via AuthContext).
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(requireSession(queries, sessionSvc))
|
||||
r.Use(csrf)
|
||||
r.Post("/v1/auth/logout", authH.Logout)
|
||||
r.Post("/v1/auth/logout-all", authH.LogoutAll)
|
||||
r.Post("/v1/auth/switch-team", authH.SwitchTeam)
|
||||
})
|
||||
|
||||
// Session-authenticated: self-service account management.
|
||||
r.Route("/v1/me", func(r chi.Router) {
|
||||
r.Use(requireJWT(jwtSecret, queries))
|
||||
r.Use(requireSession(queries, sessionSvc))
|
||||
r.Use(csrf)
|
||||
r.Get("/", meH.GetMe)
|
||||
r.Patch("/", meH.UpdateName)
|
||||
r.Post("/password", meH.ChangePassword)
|
||||
r.Get("/providers/{provider}/connect", meH.ConnectProvider)
|
||||
r.Delete("/providers/{provider}", meH.DisconnectProvider)
|
||||
r.Delete("/", meH.DeleteAccount)
|
||||
r.Get("/sessions", sessionsH.List)
|
||||
r.Delete("/sessions/{id}", sessionsH.Delete)
|
||||
})
|
||||
|
||||
// JWT-authenticated: switch active team.
|
||||
r.With(requireJWT(jwtSecret, queries)).Post("/v1/auth/switch-team", authH.SwitchTeam)
|
||||
|
||||
// JWT-authenticated: API key management.
|
||||
// Session-authenticated: API key management.
|
||||
r.Route("/v1/api-keys", func(r chi.Router) {
|
||||
r.Use(requireJWT(jwtSecret, queries))
|
||||
r.Use(requireSession(queries, sessionSvc))
|
||||
r.Use(csrf)
|
||||
r.Post("/", apiKeys.Create)
|
||||
r.Get("/", apiKeys.List)
|
||||
r.Delete("/{id}", apiKeys.Delete)
|
||||
})
|
||||
|
||||
// JWT-authenticated: team management.
|
||||
// Session-authenticated: team management.
|
||||
r.Route("/v1/teams", func(r chi.Router) {
|
||||
r.Use(requireJWT(jwtSecret, queries))
|
||||
r.Use(requireSession(queries, sessionSvc))
|
||||
r.Use(csrf)
|
||||
r.Get("/", teamH.List)
|
||||
r.Post("/", teamH.Create)
|
||||
r.Route("/{id}", func(r chi.Router) {
|
||||
@ -157,14 +211,15 @@ func New(
|
||||
})
|
||||
})
|
||||
|
||||
// JWT-authenticated: user search (for add-member UI).
|
||||
r.With(requireJWT(jwtSecret, queries)).Get("/v1/users/search", usersH.Search)
|
||||
// Session-authenticated: user search (for add-member UI).
|
||||
r.With(requireSession(queries, sessionSvc), csrf).Get("/v1/users/search", usersH.Search)
|
||||
|
||||
// Capsule lifecycle: accepts API key or JWT bearer token.
|
||||
// Capsule lifecycle: API key (SDK) or session cookie (browser).
|
||||
r.Route("/v1/capsules", func(r chi.Router) {
|
||||
// Auth-required routes.
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(requireAPIKeyOrJWT(queries, jwtSecret))
|
||||
r.Use(requireSessionOrAPIKey(queries, sessionSvc))
|
||||
r.Use(csrf)
|
||||
r.Post("/", sandbox.Create)
|
||||
r.Get("/", sandbox.List)
|
||||
r.Get("/stats", statsH.GetStats)
|
||||
@ -174,7 +229,8 @@ func New(
|
||||
r.Route("/{id}", func(r chi.Router) {
|
||||
// Auth-required non-WS routes.
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(requireAPIKeyOrJWT(queries, jwtSecret))
|
||||
r.Use(requireSessionOrAPIKey(queries, sessionSvc))
|
||||
r.Use(csrf)
|
||||
r.Get("/", sandbox.Get)
|
||||
r.Delete("/", sandbox.Destroy)
|
||||
r.Post("/exec", exec.Exec)
|
||||
@ -193,11 +249,11 @@ func New(
|
||||
r.Delete("/processes/{selector}", processH.KillProcess)
|
||||
})
|
||||
|
||||
// WebSocket endpoints — handlers authenticate after upgrade.
|
||||
// optionalAPIKeyOrJWT injects auth context from headers when
|
||||
// present (SDK clients) but does not reject when absent (browsers).
|
||||
// WebSocket endpoints — middleware injects auth context from
|
||||
// cookie (browser) or X-API-Key (SDK) before upgrade. CSRF is
|
||||
// not applicable to GET upgrades.
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(optionalAPIKeyOrJWT(queries, jwtSecret))
|
||||
r.Use(requireSessionOrAPIKey(queries, sessionSvc))
|
||||
r.Get("/exec/stream", execStream.ExecStream)
|
||||
r.Get("/pty", ptyH.PtySession)
|
||||
r.Get("/processes/{selector}/stream", processH.ConnectProcess)
|
||||
@ -205,9 +261,10 @@ func New(
|
||||
})
|
||||
})
|
||||
|
||||
// Snapshot / template management: accepts API key or JWT bearer token.
|
||||
// Snapshot / template management: API key (SDK) or session (browser).
|
||||
r.Route("/v1/snapshots", func(r chi.Router) {
|
||||
r.Use(requireAPIKeyOrJWT(queries, jwtSecret))
|
||||
r.Use(requireSessionOrAPIKey(queries, sessionSvc))
|
||||
r.Use(csrf)
|
||||
r.Post("/", snapshots.Create)
|
||||
r.Get("/", snapshots.List)
|
||||
r.Delete("/{name}", snapshots.Delete)
|
||||
@ -221,12 +278,14 @@ func New(
|
||||
// Unauthenticated: refresh token exchange.
|
||||
r.Post("/auth/refresh", hostH.RefreshToken)
|
||||
|
||||
// Host-token-authenticated: heartbeat.
|
||||
// Host-token-authenticated: heartbeat and lifecycle callbacks.
|
||||
r.With(requireHostToken(jwtSecret)).Post("/{id}/heartbeat", hostH.Heartbeat)
|
||||
r.With(requireHostToken(jwtSecret)).Post("/sandbox-events", sandboxEvtH.Handle)
|
||||
|
||||
// JWT-authenticated: host CRUD and tags.
|
||||
// Session-authenticated: host CRUD and tags.
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(requireJWT(jwtSecret, queries))
|
||||
r.Use(requireSession(queries, sessionSvc))
|
||||
r.Use(csrf)
|
||||
r.Post("/", hostH.Create)
|
||||
r.Get("/", hostH.List)
|
||||
r.Route("/{id}", func(r chi.Router) {
|
||||
@ -241,9 +300,10 @@ func New(
|
||||
})
|
||||
})
|
||||
|
||||
// JWT-authenticated: notification channels.
|
||||
// Session-authenticated: notification channels.
|
||||
r.Route("/v1/channels", func(r chi.Router) {
|
||||
r.Use(requireJWT(jwtSecret, queries))
|
||||
r.Use(requireSession(queries, sessionSvc))
|
||||
r.Use(csrf)
|
||||
r.Post("/", channelH.Create)
|
||||
r.Get("/", channelH.List)
|
||||
r.Post("/test", channelH.Test)
|
||||
@ -255,18 +315,27 @@ func New(
|
||||
})
|
||||
})
|
||||
|
||||
// JWT-authenticated: audit log.
|
||||
r.With(requireJWT(jwtSecret, queries)).Get("/v1/audit-logs", auditH.List)
|
||||
// SSE event stream: browser sends wrenn_sid cookie on EventSource
|
||||
// automatically; SDKs may set X-API-Key. Ticket-based auth is gone.
|
||||
r.With(requireSessionOrAPIKey(queries, sessionSvc)).Get("/v1/events/stream", sseH.Stream)
|
||||
|
||||
// Platform admin routes — require JWT + DB-validated admin status.
|
||||
// Session-authenticated: audit log.
|
||||
r.With(requireSession(queries, sessionSvc), csrf).Get("/v1/audit-logs", auditH.List)
|
||||
|
||||
// Platform admin routes — session + DB-validated admin status.
|
||||
r.Route("/v1/admin", func(r chi.Router) {
|
||||
// Admin SSE event stream (sees all teams).
|
||||
r.With(requireSession(queries, sessionSvc), requireAdmin(queries), injectPlatformTeam()).Get("/events/stream", sseH.AdminStream)
|
||||
|
||||
// Auth-required admin routes (non-capsule + capsule list/create).
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(requireJWT(jwtSecret, queries))
|
||||
r.Use(requireSession(queries, sessionSvc))
|
||||
r.Use(requireAdmin(queries))
|
||||
r.Use(csrf)
|
||||
r.Get("/teams", teamH.AdminListTeams)
|
||||
r.Put("/teams/{id}/byoc", teamH.SetBYOC)
|
||||
r.Delete("/teams/{id}", teamH.AdminDeleteTeam)
|
||||
r.Get("/hosts", hostH.AdminList)
|
||||
r.Get("/users", usersH.AdminListUsers)
|
||||
r.Put("/users/{id}/active", usersH.SetUserActive)
|
||||
r.Put("/users/{id}/admin", usersH.SetUserAdmin)
|
||||
@ -281,11 +350,21 @@ func New(
|
||||
r.Get("/capsules", adminCapsules.List)
|
||||
})
|
||||
|
||||
// Admin build console WebSocket — cookie + admin DB check before
|
||||
// upgrade, no CSRF (WS upgrade). Builds are platform-scoped, not
|
||||
// sandbox-scoped, so this sits outside the /capsules/{id} router.
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(requireSession(queries, sessionSvc))
|
||||
r.Use(requireAdmin(queries))
|
||||
r.Get("/builds/{id}/stream", buildStreamH.Stream)
|
||||
})
|
||||
|
||||
r.Route("/capsules/{id}", func(r chi.Router) {
|
||||
// Auth-required non-WS admin capsule routes.
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(requireJWT(jwtSecret, queries))
|
||||
r.Use(requireSession(queries, sessionSvc))
|
||||
r.Use(requireAdmin(queries))
|
||||
r.Use(csrf)
|
||||
r.Use(injectPlatformTeam())
|
||||
r.Get("/", adminCapsules.Get)
|
||||
r.Delete("/", adminCapsules.Destroy)
|
||||
@ -301,11 +380,13 @@ func New(
|
||||
r.Delete("/processes/{selector}", processH.KillProcess)
|
||||
})
|
||||
|
||||
// Admin WebSocket endpoints — handlers authenticate after upgrade
|
||||
// via wsAuthenticateAdmin. markAdminWS sets the context flag so
|
||||
// handlers know to use admin auth instead of regular auth.
|
||||
// Admin WebSocket endpoints — browser auth via cookie + admin DB check
|
||||
// before upgrade. markAdminWS is retained as a context hint for any
|
||||
// admin-specific behavior downstream.
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(markAdminWS)
|
||||
r.Use(requireSession(queries, sessionSvc))
|
||||
r.Use(requireAdmin(queries))
|
||||
r.Use(injectPlatformTeam())
|
||||
r.Get("/exec/stream", execStream.ExecStream)
|
||||
r.Get("/pty", ptyH.PtySession)
|
||||
r.Get("/processes/{selector}/stream", processH.ConnectProcess)
|
||||
@ -318,7 +399,7 @@ func New(
|
||||
ext.RegisterRoutes(r, sctx)
|
||||
}
|
||||
|
||||
return &Server{router: r, BuildSvc: buildSvc, version: version}
|
||||
return &Server{router: r, BuildSvc: buildSvc, SSERelay: sseRelay, SessionSvc: sessionSvc, version: version}
|
||||
}
|
||||
|
||||
// Handler returns the HTTP handler.
|
||||
@ -363,3 +444,101 @@ func serveDocs(w http.ResponseWriter, r *http.Request) {
|
||||
</body>
|
||||
</html>`)
|
||||
}
|
||||
|
||||
// serviceEventToCanonical maps a SandboxService background-goroutine event
|
||||
// into the canonical events.Event taxonomy for unified publishing. Returns
|
||||
// false for events that should not be broadcast.
|
||||
func serviceEventToCanonical(e service.SandboxStateEvent) (events.Event, bool) {
|
||||
var (
|
||||
eventType string
|
||||
outcome events.Outcome
|
||||
metadata map[string]string
|
||||
)
|
||||
switch e.Event {
|
||||
case "sandbox.started":
|
||||
eventType = events.CapsuleCreate
|
||||
outcome = events.OutcomeSuccess
|
||||
case "sandbox.resumed":
|
||||
eventType = events.CapsuleResume
|
||||
outcome = events.OutcomeSuccess
|
||||
case "sandbox.paused":
|
||||
eventType = events.CapsulePause
|
||||
outcome = events.OutcomeSuccess
|
||||
case "sandbox.auto_paused":
|
||||
eventType = events.CapsulePause
|
||||
outcome = events.OutcomeSuccess
|
||||
metadata = map[string]string{"reason": "ttl_expired"}
|
||||
case "sandbox.stopped":
|
||||
eventType = events.CapsuleDestroy
|
||||
outcome = events.OutcomeSuccess
|
||||
case "sandbox.pause_failed":
|
||||
// reason must be non-empty or channels.isRedundantSystemFollowup
|
||||
// filters this system-actor event out of webhook delivery.
|
||||
eventType = events.CapsulePause
|
||||
outcome = events.OutcomeError
|
||||
metadata = map[string]string{"reason": "pause_failed"}
|
||||
case "sandbox.resume_failed":
|
||||
eventType = events.CapsuleResume
|
||||
outcome = events.OutcomeError
|
||||
metadata = map[string]string{"reason": "resume_failed"}
|
||||
case "sandbox.failed":
|
||||
// First-boot failure from the createInBackground goroutine. Without
|
||||
// this case the event falls through to default and is dropped — no
|
||||
// SSE push, no channel delivery, no DB reconciliation. reason must be
|
||||
// non-empty or channels.isRedundantSystemFollowup filters it out.
|
||||
eventType = events.CapsuleCreate
|
||||
outcome = events.OutcomeError
|
||||
metadata = map[string]string{"reason": "create_failed"}
|
||||
case "sandbox.snapshotted":
|
||||
// Completion of an async snapshot. The resource is the template name,
|
||||
// not the sandbox, so the dashboard's snapshot list refreshes.
|
||||
return events.Event{
|
||||
Event: events.SnapshotCreate,
|
||||
Outcome: events.OutcomeSuccess,
|
||||
Timestamp: events.Now(),
|
||||
TeamID: e.TeamID,
|
||||
Actor: events.SystemActor(),
|
||||
Resource: events.Resource{ID: e.Metadata["name"], Type: "snapshot"},
|
||||
}, true
|
||||
case "sandbox.snapshot_failed":
|
||||
return events.Event{
|
||||
Event: events.SnapshotCreate,
|
||||
Outcome: events.OutcomeError,
|
||||
Timestamp: events.Now(),
|
||||
TeamID: e.TeamID,
|
||||
Actor: events.SystemActor(),
|
||||
Resource: events.Resource{ID: e.Metadata["name"], Type: "snapshot"},
|
||||
Metadata: map[string]string{"reason": "snapshot_failed"},
|
||||
Error: e.Error,
|
||||
}, true
|
||||
case "sandbox.state_changed":
|
||||
// Transient badge transition with no terminal verb of its own. Carries
|
||||
// from/to in metadata; routed via Pub/Sub only by the caller.
|
||||
return events.Event{
|
||||
Event: events.CapsuleStateChanged,
|
||||
Timestamp: events.Now(),
|
||||
TeamID: e.TeamID,
|
||||
Actor: events.SystemActor(),
|
||||
Resource: events.Resource{ID: e.SandboxID, Type: "sandbox"},
|
||||
Metadata: e.Metadata,
|
||||
}, true
|
||||
default:
|
||||
return events.Event{}, false
|
||||
}
|
||||
if e.HostIP != "" {
|
||||
if metadata == nil {
|
||||
metadata = map[string]string{}
|
||||
}
|
||||
metadata["host_ip"] = e.HostIP
|
||||
}
|
||||
return events.Event{
|
||||
Event: eventType,
|
||||
Outcome: outcome,
|
||||
Timestamp: events.Now(),
|
||||
TeamID: e.TeamID,
|
||||
Actor: events.SystemActor(),
|
||||
Resource: events.Resource{ID: e.SandboxID, Type: "sandbox"},
|
||||
Metadata: metadata,
|
||||
Error: e.Error,
|
||||
}, true
|
||||
}
|
||||
|
||||
80
internal/api/sse_broker.go
Normal file
80
internal/api/sse_broker.go
Normal file
@ -0,0 +1,80 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
const sseChannelBuffer = 32
|
||||
|
||||
type sseMessage struct {
|
||||
EventType string
|
||||
Data json.RawMessage
|
||||
}
|
||||
|
||||
type sseSubscriber struct {
|
||||
teamID string
|
||||
isAdmin bool
|
||||
ch chan sseMessage
|
||||
}
|
||||
|
||||
// SSEBroker is an in-process fan-out hub that dispatches events to connected
|
||||
// SSE clients, filtering by team ownership.
|
||||
type SSEBroker struct {
|
||||
mu sync.RWMutex
|
||||
nextID atomic.Uint64
|
||||
subscribers map[uint64]*sseSubscriber
|
||||
}
|
||||
|
||||
// NewSSEBroker constructs a broker with no subscribers.
|
||||
func NewSSEBroker() *SSEBroker {
|
||||
return &SSEBroker{
|
||||
subscribers: make(map[uint64]*sseSubscriber),
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe registers a new SSE client. Returns a subscriber ID (for
|
||||
// Unsubscribe) and a receive-only channel that delivers filtered events.
|
||||
func (b *SSEBroker) Subscribe(teamID string, isAdmin bool) (uint64, <-chan sseMessage) {
|
||||
id := b.nextID.Add(1)
|
||||
ch := make(chan sseMessage, sseChannelBuffer)
|
||||
sub := &sseSubscriber{teamID: teamID, isAdmin: isAdmin, ch: ch}
|
||||
|
||||
b.mu.Lock()
|
||||
b.subscribers[id] = sub
|
||||
b.mu.Unlock()
|
||||
|
||||
slog.Debug("sse: client subscribed", "sub_id", id, "team_id", teamID, "admin", isAdmin)
|
||||
return id, ch
|
||||
}
|
||||
|
||||
// Unsubscribe removes a client. The handler loop exits via context cancellation;
|
||||
// the channel is not closed here to avoid send-on-closed-channel races with Dispatch.
|
||||
func (b *SSEBroker) Unsubscribe(id uint64) {
|
||||
b.mu.Lock()
|
||||
delete(b.subscribers, id)
|
||||
b.mu.Unlock()
|
||||
slog.Debug("sse: client unsubscribed", "sub_id", id)
|
||||
}
|
||||
|
||||
// Dispatch fans out an event to all matching subscribers. Admin subscribers
|
||||
// receive all events; team subscribers only receive events for their team.
|
||||
// Non-blocking: events are dropped for slow consumers.
|
||||
func (b *SSEBroker) Dispatch(eventType string, teamID string, data json.RawMessage) {
|
||||
b.mu.RLock()
|
||||
defer b.mu.RUnlock()
|
||||
|
||||
msg := sseMessage{EventType: eventType, Data: data}
|
||||
for id, sub := range b.subscribers {
|
||||
if !sub.isAdmin && sub.teamID != teamID {
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case sub.ch <- msg:
|
||||
default:
|
||||
slog.Warn("sse: dropped event for slow consumer", "sub_id", id, "event", eventType)
|
||||
}
|
||||
}
|
||||
}
|
||||
147
internal/api/sse_relay.go
Normal file
147
internal/api/sse_relay.go
Normal file
@ -0,0 +1,147 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/events"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
const ssePubSubChannel = "wrenn:sse"
|
||||
|
||||
// sseEventPayload is the JSON envelope sent to SSE clients.
|
||||
type sseEventPayload struct {
|
||||
Event string `json:"event"`
|
||||
Outcome events.Outcome `json:"outcome,omitempty"`
|
||||
Timestamp string `json:"timestamp"`
|
||||
TeamID string `json:"team_id"`
|
||||
Actor events.Actor `json:"actor"`
|
||||
Resource events.Resource `json:"resource"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Sandbox *sandboxResponse `json:"sandbox,omitempty"`
|
||||
}
|
||||
|
||||
// SSERelay subscribes to the Redis Pub/Sub channel and dispatches hydrated
|
||||
// events to the in-process broker. One instance per CP process.
|
||||
type SSERelay struct {
|
||||
rdb *redis.Client
|
||||
db *db.Queries
|
||||
broker *SSEBroker
|
||||
}
|
||||
|
||||
// NewSSERelay constructs the relay.
|
||||
func NewSSERelay(rdb *redis.Client, queries *db.Queries, broker *SSEBroker) *SSERelay {
|
||||
return &SSERelay{rdb: rdb, db: queries, broker: broker}
|
||||
}
|
||||
|
||||
// Start launches the Pub/Sub subscription goroutine. Returns when ctx is cancelled.
|
||||
func (r *SSERelay) Start(ctx context.Context) {
|
||||
go r.run(ctx)
|
||||
}
|
||||
|
||||
func (r *SSERelay) run(ctx context.Context) {
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
r.subscribe(ctx)
|
||||
// Backoff before reconnecting.
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(2 * time.Second):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *SSERelay) subscribe(ctx context.Context) {
|
||||
pubsub := r.rdb.Subscribe(ctx, ssePubSubChannel)
|
||||
defer pubsub.Close()
|
||||
|
||||
ch := pubsub.Channel()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case msg, ok := <-ch:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
r.handleMessage(ctx, msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *SSERelay) handleMessage(ctx context.Context, msg *redis.Message) {
|
||||
var event events.Event
|
||||
if err := json.Unmarshal([]byte(msg.Payload), &event); err != nil {
|
||||
slog.Warn("sse relay: failed to unmarshal event", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
payload := sseEventPayload{
|
||||
Event: event.Event,
|
||||
Outcome: event.Outcome,
|
||||
Timestamp: event.Timestamp,
|
||||
TeamID: event.TeamID,
|
||||
Actor: event.Actor,
|
||||
Resource: event.Resource,
|
||||
Metadata: event.Metadata,
|
||||
Error: event.Error,
|
||||
}
|
||||
|
||||
// Hydrate sandbox state for capsule events.
|
||||
if isCapsuleEvent(event.Event) {
|
||||
sb, err := r.hydrateSandbox(ctx, event.Resource.ID)
|
||||
if err != nil {
|
||||
slog.Debug("sse relay: sandbox hydration failed (may be deleted)", "sandbox_id", event.Resource.ID, "error", err)
|
||||
} else {
|
||||
payload.Sandbox = sb
|
||||
}
|
||||
}
|
||||
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
slog.Warn("sse relay: failed to marshal payload", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
r.broker.Dispatch(event.Event, event.TeamID, data)
|
||||
}
|
||||
|
||||
func (r *SSERelay) hydrateSandbox(ctx context.Context, sandboxIDStr string) (*sandboxResponse, error) {
|
||||
queryCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sb, err := r.db.GetSandbox(queryCtx, sandboxID)
|
||||
if err != nil {
|
||||
if err == pgx.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp := sandboxToResponse(sb)
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
func isCapsuleEvent(eventType string) bool {
|
||||
switch eventType {
|
||||
case events.CapsuleCreate, events.CapsulePause, events.CapsuleResume, events.CapsuleDestroy, events.CapsuleStateChanged:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
Reference in New Issue
Block a user