1
0
forked from wrenn/wrenn
Co-authored-by: Tasnim Kabir Sadik <tksadik@omukk.dev>

Reviewed-on: wrenn/wrenn#50
This commit is contained in:
2026-05-24 21:10:37 +00:00
parent 4707f16c76
commit 05ddf62399
203 changed files with 15815 additions and 9344 deletions

View File

@ -3,11 +3,20 @@ package api
import (
"context"
"fmt"
"log/slog"
"net/http"
"time"
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
"github.com/gorilla/websocket"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
"git.omukk.dev/wrenn/wrenn/proto/hostagent/gen/hostagentv1connect"
)
@ -20,3 +29,119 @@ func agentForHost(ctx context.Context, queries *db.Queries, pool *lifecycle.Host
}
return pool.GetForHost(host)
}
// requireRunningSandbox parses the sandbox ID from the URL, looks it up by team,
// and verifies it is running. On failure it writes the appropriate HTTP error and
// returns false.
func requireRunningSandbox(w http.ResponseWriter, r *http.Request, queries *db.Queries, teamID pgtype.UUID) (db.Sandbox, pgtype.UUID, string, bool) {
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return db.Sandbox{}, pgtype.UUID{}, "", false
}
sb, err := queries.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return db.Sandbox{}, pgtype.UUID{}, "", false
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
return db.Sandbox{}, pgtype.UUID{}, "", false
}
return sb, sandboxID, sandboxIDStr, true
}
// upgradeAndAuthenticate upgrades the HTTP connection to WebSocket. The
// auth context must already be populated by upstream middleware — browser
// clients via the wrenn_sid cookie (sent automatically on WS upgrade),
// SDK clients via X-API-Key. Requests without an auth context are rejected
// with a 401 before the upgrade.
func upgradeAndAuthenticate(w http.ResponseWriter, r *http.Request) (*websocket.Conn, auth.AuthContext, error) {
ac, hasAuth := auth.FromContext(r.Context())
if !hasAuth {
writeError(w, http.StatusUnauthorized, "unauthorized", "session cookie or X-API-Key required")
return nil, auth.AuthContext{}, fmt.Errorf("unauthenticated")
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return nil, auth.AuthContext{}, fmt.Errorf("websocket upgrade: %w", err)
}
return conn, ac, nil
}
// resolveTemplateSizes queries a host agent for the actual disk usage of any
// templates with size_bytes <= 0 (e.g. system base templates seeded with
// size_bytes = 0 before the rootfs was built). Results are persisted to the
// DB so subsequent requests serve the correct size without an RPC call.
// Errors are logged but do not prevent the caller from serving the templates.
func resolveTemplateSizes(ctx context.Context, queries *db.Queries, pool *lifecycle.HostClientPool, templates []db.Template) []db.Template {
needResolve := false
for _, t := range templates {
if t.SizeBytes <= 0 {
needResolve = true
break
}
}
if !needResolve {
return templates
}
hosts, err := queries.ListActiveHosts(ctx)
if err != nil || len(hosts) == 0 {
slog.Warn("resolveTemplateSizes: no active hosts available", "error", err)
return templates
}
agent, err := pool.GetForHost(hosts[0])
if err != nil {
slog.Warn("resolveTemplateSizes: failed to connect to host",
"host_id", id.UUIDString(hosts[0].ID), "error", err)
return templates
}
for i, t := range templates {
if t.SizeBytes > 0 {
continue
}
resp, err := agent.GetTemplateSize(ctx, connect.NewRequest(&pb.GetTemplateSizeRequest{
TeamId: formatUUIDForRPC(t.TeamID),
TemplateId: formatUUIDForRPC(t.ID),
}))
if err != nil {
slog.Warn("resolveTemplateSizes: failed to get size from host",
"template", t.Name, "error", err)
continue
}
templates[i].SizeBytes = resp.Msg.SizeBytes
if err := queries.UpdateTemplateSize(ctx, db.UpdateTemplateSizeParams{
ID: t.ID,
SizeBytes: resp.Msg.SizeBytes,
}); err != nil {
slog.Warn("resolveTemplateSizes: failed to persist size",
"template", t.Name, "error", err)
}
}
return templates
}
// updateLastActive updates the sandbox last_active_at timestamp.
// Uses a background context with timeout for streaming handlers where
// the request context may already be cancelled.
func updateLastActive(queries *db.Queries, sandboxID pgtype.UUID, sandboxIDStr string) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := queries.UpdateLastActive(ctx, db.UpdateLastActiveParams{
ID: sandboxID,
LastActiveAt: pgtype.Timestamptz{
Time: time.Now(),
Valid: true,
},
}); err != nil {
slog.Warn("failed to update last_active_at", "id", sandboxIDStr, "error", err)
}
}

View File

@ -0,0 +1,54 @@
package api
import (
"context"
"log/slog"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/pkg/cpextension"
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
// collectAuthHooks filters the extension list down to those that implement
// cpextension.AuthHook.
func collectAuthHooks(extensions []cpextension.Extension) []cpextension.AuthHook {
var hooks []cpextension.AuthHook
for _, ext := range extensions {
if h, ok := ext.(cpextension.AuthHook); ok {
hooks = append(hooks, h)
}
}
return hooks
}
// fireOnSignup runs every OnSignup hook sequentially. The first error wins —
// signup must abort so the cloud-side billing customer creation stays the
// gating step. Other hook errors after a failure are not run.
func fireOnSignup(ctx context.Context, hooks []cpextension.AuthHook, userID, teamID pgtype.UUID, email string) error {
for _, h := range hooks {
if err := h.OnSignup(ctx, userID, teamID, email); err != nil {
return err
}
}
return nil
}
// fireOnLogin runs every OnLogin hook. Errors are logged and swallowed —
// login must never be blocked by a misbehaving extension.
func fireOnLogin(ctx context.Context, hooks []cpextension.AuthHook, userID pgtype.UUID) {
for _, h := range hooks {
if err := h.OnLogin(ctx, userID); err != nil {
slog.Warn("auth hook OnLogin failed", "user_id", id.FormatUserID(userID), "error", err)
}
}
}
// fireOnSoftDelete runs every OnAccountSoftDelete hook. Errors are logged.
func fireOnSoftDelete(ctx context.Context, hooks []cpextension.AuthHook, userID pgtype.UUID) {
for _, h := range hooks {
if err := h.OnAccountSoftDelete(ctx, userID); err != nil {
slog.Warn("auth hook OnAccountSoftDelete failed", "user_id", id.FormatUserID(userID), "error", err)
}
}
}

View File

@ -1,14 +1,9 @@
package api
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"net/http"
"time"
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
"git.omukk.dev/wrenn/wrenn/pkg/audit"
@ -17,8 +12,6 @@ import (
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
"git.omukk.dev/wrenn/wrenn/pkg/service"
"git.omukk.dev/wrenn/wrenn/pkg/validate"
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
)
type adminCapsuleHandler struct {
@ -49,15 +42,18 @@ func (h *adminCapsuleHandler) Create(w http.ResponseWriter, r *http.Request) {
MemoryMB: req.MemoryMB,
TimeoutSec: req.TimeoutSec,
})
ac.TeamID = id.PlatformTeamID
h.audit.LogSandboxCreate(r.Context(), ac, sb.ID, req.Template, err)
if err != nil {
if sb.ID.Valid {
h.audit.LogSandboxDestroySystem(r.Context(), id.PlatformTeamID, sb.ID, "cleanup_after_create_error", nil)
}
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
ac.TeamID = id.PlatformTeamID
h.audit.LogSandboxCreate(r.Context(), ac, sb.ID, sb.Template)
writeJSON(w, http.StatusCreated, sandboxToResponse(sb))
writeJSON(w, http.StatusAccepted, sandboxToResponse(sb))
}
// List handles GET /v1/admin/capsules.
@ -106,26 +102,27 @@ func (h *adminCapsuleHandler) Destroy(w http.ResponseWriter, r *http.Request) {
return
}
if err := h.svc.Destroy(r.Context(), sandboxID, id.PlatformTeamID); err != nil {
ac.TeamID = id.PlatformTeamID
err = h.svc.Destroy(r.Context(), sandboxID, id.PlatformTeamID)
h.audit.LogSandboxDestroy(r.Context(), ac, sandboxID, err)
if err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
h.audit.LogSandboxDestroy(r.Context(), ac, sandboxID)
w.WriteHeader(http.StatusNoContent)
w.WriteHeader(http.StatusAccepted)
}
type adminSnapshotRequest struct {
Name string `json:"name"`
}
// Snapshot handles POST /v1/admin/capsules/{id}/snapshot.
// Pauses the capsule, takes a snapshot as a platform template, then destroys the capsule.
// Snapshot handles POST /v1/admin/capsules/{id}/snapshot. Takes a live
// snapshot of a platform-owned capsule and registers the result as a
// platform template (team_id = 00000000-...).
func (h *adminCapsuleHandler) Snapshot(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
@ -133,117 +130,22 @@ func (h *adminCapsuleHandler) Snapshot(w http.ResponseWriter, r *http.Request) {
}
var req adminSnapshotRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
return
}
if req.Name == "" {
req.Name = id.NewSnapshotName()
}
if err := validate.SafeName(req.Name); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", fmt.Sprintf("invalid snapshot name: %s", err))
return
}
ctx := r.Context()
// Verify sandbox exists and belongs to platform team BEFORE any
// destructive operations (template overwrite).
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: id.PlatformTeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" && sb.Status != "paused" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox must be running or paused")
return
}
// Check if name already exists as a platform template.
if existing, err := h.db.GetPlatformTemplateByName(ctx, req.Name); err == nil {
// Delete old snapshot files from all hosts before removing the DB record.
if err := deleteSnapshotBroadcast(ctx, h.db, h.pool, existing.TeamID, existing.ID); err != nil {
writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete existing snapshot files")
return
}
if err := h.db.DeleteTemplateByTeam(ctx, db.DeleteTemplateByTeamParams{Name: req.Name, TeamID: id.PlatformTeamID}); err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to remove existing template record")
if r.ContentLength > 0 {
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
return
}
}
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
sb, name, err := h.svc.CreateSnapshot(r.Context(), sandboxID, id.PlatformTeamID, req.Name)
if err != nil {
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
return
}
// Pre-mark sandbox as "paused" to prevent the reconciler from racing.
if sb.Status == "running" {
if _, err := h.db.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
ID: sandboxID, Status: "paused",
}); err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to update sandbox status")
return
}
}
// Use a detached context so the snapshot completes even if the client disconnects.
snapCtx, snapCancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer snapCancel()
newTemplateID := id.NewTemplateID()
resp, err := agent.CreateSnapshot(snapCtx, connect.NewRequest(&pb.CreateSnapshotRequest{
SandboxId: sandboxIDStr,
Name: req.Name,
TeamId: formatUUIDForRPC(id.PlatformTeamID),
TemplateId: formatUUIDForRPC(newTemplateID),
}))
if err != nil {
// Snapshot failed — revert status.
if sb.Status == "running" {
if _, dbErr := h.db.UpdateSandboxStatus(snapCtx, db.UpdateSandboxStatusParams{
ID: sandboxID, Status: "running",
}); dbErr != nil {
slog.Error("failed to revert sandbox status after snapshot error", "sandbox_id", sandboxIDStr, "error", dbErr)
}
}
status, code, msg := agentErrToHTTP(err)
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
ac := auth.MustFromContext(r.Context())
ac.TeamID = id.PlatformTeamID
h.audit.LogSnapshotCreateRequested(r.Context(), ac, name)
tmpl, err := h.db.InsertTemplate(snapCtx, db.InsertTemplateParams{
ID: newTemplateID,
Name: req.Name,
Type: "snapshot",
Vcpus: sb.Vcpus,
MemoryMb: sb.MemoryMb,
SizeBytes: resp.Msg.SizeBytes,
TeamID: id.PlatformTeamID,
DefaultUser: "root",
DefaultEnv: []byte("{}"),
Metadata: sb.Metadata,
})
if err != nil {
slog.Error("failed to insert template record", "name", req.Name, "error", err)
writeError(w, http.StatusInternalServerError, "db_error", "snapshot created but failed to record in database")
return
}
// Destroy the ephemeral capsule after successful snapshot.
if err := h.svc.Destroy(snapCtx, sandboxID, id.PlatformTeamID); err != nil {
slog.Error("failed to destroy capsule after snapshot", "sandbox_id", sandboxIDStr, "error", err)
// Don't fail the response — the snapshot was created successfully.
}
h.audit.LogSnapshotCreate(snapCtx, ac, req.Name)
if ctx.Err() != nil {
slog.Info("snapshot created but client disconnected before response", "name", req.Name)
return
}
writeJSON(w, http.StatusCreated, templateToResponse(tmpl))
writeJSON(w, http.StatusAccepted, sandboxToResponse(sb))
}

View File

@ -20,6 +20,8 @@ import (
"git.omukk.dev/wrenn/wrenn/internal/email"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/auth/session"
"git.omukk.dev/wrenn/wrenn/pkg/cpextension"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
@ -140,14 +142,15 @@ type switchTeamRequest struct {
type authHandler struct {
db *db.Queries
pool *pgxpool.Pool
jwtSecret []byte
sessions *session.Service
mailer email.Mailer
rdb *redis.Client
redirectURL string
authHooks []cpextension.AuthHook
}
func newAuthHandler(db *db.Queries, pool *pgxpool.Pool, jwtSecret []byte, mailer email.Mailer, rdb *redis.Client, redirectURL string) *authHandler {
return &authHandler{db: db, pool: pool, jwtSecret: jwtSecret, mailer: mailer, rdb: rdb, redirectURL: strings.TrimRight(redirectURL, "/")}
func newAuthHandler(db *db.Queries, pool *pgxpool.Pool, sessions *session.Service, mailer email.Mailer, rdb *redis.Client, redirectURL string, hooks []cpextension.AuthHook) *authHandler {
return &authHandler{db: db, pool: pool, sessions: sessions, mailer: mailer, rdb: rdb, redirectURL: strings.TrimRight(redirectURL, "/"), authHooks: hooks}
}
type signupRequest struct {
@ -166,11 +169,41 @@ type activateRequest struct {
}
type authResponse struct {
Token string `json:"token"`
UserID string `json:"user_id"`
TeamID string `json:"team_id"`
Email string `json:"email"`
Name string `json:"name"`
UserID string `json:"user_id"`
TeamID string `json:"team_id"`
Email string `json:"email"`
Name string `json:"name"`
Role string `json:"role"`
IsAdmin bool `json:"is_admin"`
}
// issueSession creates a new session and writes both the session and CSRF
// cookies to the response. On success it writes the authResponse JSON body.
func (h *authHandler) issueSession(
w http.ResponseWriter,
r *http.Request,
userID, teamID pgtype.UUID,
email, name, role string,
isAdmin bool,
) error {
sess, err := h.sessions.Create(r.Context(), userID, teamID, email, name, role, isAdmin, r.UserAgent(), clientIP(r))
if err != nil {
return err
}
setSessionCookies(w, sess.RawSID, sess.CSRFToken, isSecure(r))
return nil
}
// clientIP returns the request's apparent client IP, honoring
// X-Forwarded-For when behind a reverse proxy.
func clientIP(r *http.Request) string {
if fwd := r.Header.Get("X-Forwarded-For"); fwd != "" {
if i := strings.IndexByte(fwd, ','); i > 0 {
return strings.TrimSpace(fwd[:i])
}
return strings.TrimSpace(fwd)
}
return r.RemoteAddr
}
type signupResponse struct {
@ -253,8 +286,8 @@ func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) {
}
// Generate activation token and store in Redis.
rawToken := generateActivationToken()
tokenHash := hashActivationToken(rawToken)
rawToken := generateOpaqueToken()
tokenHash := hashOpaqueToken(rawToken)
redisKey := activationKeyPrefix + tokenHash
if err := h.rdb.Set(ctx, redisKey, id.FormatUserID(userID), activationTTL).Err(); err != nil {
@ -296,7 +329,7 @@ func (h *authHandler) Activate(w http.ResponseWriter, r *http.Request) {
}
ctx := r.Context()
tokenHash := hashActivationToken(req.Token)
tokenHash := hashOpaqueToken(req.Token)
redisKey := activationKeyPrefix + tokenHash
userIDStr, err := h.rdb.GetDel(ctx, redisKey).Result()
@ -345,18 +378,26 @@ func (h *authHandler) Activate(w http.ResponseWriter, r *http.Request) {
}
isAdmin := user.IsAdmin || isFirstUser
token, err := auth.SignJWT(h.jwtSecret, userID, team.ID, user.Email, user.Name, role, isAdmin)
if err != nil {
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
// Fire OnSignup before issuing a session — billing must succeed first.
if err := fireOnSignup(ctx, h.authHooks, userID, team.ID, user.Email); err != nil {
slog.Error("activate: OnSignup hook failed", "user_id", id.FormatUserID(userID), "error", err)
writeError(w, http.StatusInternalServerError, "signup_hook_failed", "failed to finalize account setup")
return
}
if err := h.issueSession(w, r, userID, team.ID, user.Email, user.Name, role, isAdmin); err != nil {
slog.Error("activate: failed to issue session", "error", err)
writeError(w, http.StatusInternalServerError, "internal_error", "failed to create session")
return
}
fireOnLogin(ctx, h.authHooks, userID)
writeJSON(w, http.StatusOK, authResponse{
Token: token,
UserID: id.FormatUserID(userID),
TeamID: id.FormatTeamID(team.ID),
Email: user.Email,
Name: user.Name,
UserID: id.FormatUserID(userID),
TeamID: id.FormatTeamID(team.ID),
Email: user.Email,
Name: user.Name,
Role: role,
IsAdmin: isAdmin,
})
}
@ -427,18 +468,20 @@ func (h *authHandler) Login(w http.ResponseWriter, r *http.Request) {
}
isAdmin := user.IsAdmin || isFirstUser
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, isAdmin)
if err != nil {
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
if err := h.issueSession(w, r, user.ID, team.ID, user.Email, user.Name, role, isAdmin); err != nil {
slog.Error("login: failed to issue session", "error", err)
writeError(w, http.StatusInternalServerError, "internal_error", "failed to create session")
return
}
fireOnLogin(ctx, h.authHooks, user.ID)
writeJSON(w, http.StatusOK, authResponse{
Token: token,
UserID: id.FormatUserID(user.ID),
TeamID: id.FormatTeamID(team.ID),
Email: user.Email,
Name: user.Name,
UserID: id.FormatUserID(user.ID),
TeamID: id.FormatTeamID(team.ID),
Email: user.Email,
Name: user.Name,
Role: role,
IsAdmin: isAdmin,
})
}
@ -503,24 +546,54 @@ func (h *authHandler) SwitchTeam(w http.ResponseWriter, r *http.Request) {
return
}
token, err := auth.SignJWT(h.jwtSecret, ac.UserID, teamID, ac.Email, user.Name, membership.Role, user.IsAdmin)
// Rotate the SID so any leaked old cookie loses access at the moment of
// privilege change.
newSess, err := h.sessions.Rotate(ctx, ac.SessionID, ac.UserID, teamID, user.Email, user.Name, membership.Role, user.IsAdmin, r.UserAgent(), clientIP(r))
if err != nil {
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
slog.Error("switch team: failed to rotate session", "error", err)
writeError(w, http.StatusInternalServerError, "internal_error", "failed to switch team")
return
}
setSessionCookies(w, newSess.RawSID, newSess.CSRFToken, isSecure(r))
writeJSON(w, http.StatusOK, authResponse{
Token: token,
UserID: id.FormatUserID(ac.UserID),
TeamID: id.FormatTeamID(teamID),
Email: ac.Email,
Name: user.Name,
UserID: id.FormatUserID(ac.UserID),
TeamID: id.FormatTeamID(teamID),
Email: user.Email,
Name: user.Name,
Role: membership.Role,
IsAdmin: user.IsAdmin,
})
}
// Logout handles POST /v1/auth/logout — revokes the caller's current session.
func (h *authHandler) Logout(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
if err := h.sessions.Revoke(r.Context(), ac.SessionID); err != nil {
slog.Warn("logout: revoke failed", "error", err)
}
clearSessionCookies(w, isSecure(r))
w.WriteHeader(http.StatusNoContent)
}
// LogoutAll handles POST /v1/auth/logout-all — revokes every session for the
// current user, including the caller's own.
func (h *authHandler) LogoutAll(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
if err := h.sessions.RevokeAllForUser(r.Context(), ac.UserID); err != nil {
slog.Error("logout-all: revoke failed", "error", err)
writeError(w, http.StatusInternalServerError, "internal_error", "failed to revoke sessions")
return
}
clearSessionCookies(w, isSecure(r))
w.WriteHeader(http.StatusNoContent)
}
// --- helpers ---
func generateActivationToken() string {
// generateOpaqueToken returns a fresh 16-byte hex-encoded random token,
// used for email activation links and password reset links.
func generateOpaqueToken() string {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
panic(fmt.Sprintf("crypto/rand failed: %v", err))
@ -528,7 +601,9 @@ func generateActivationToken() string {
return hex.EncodeToString(b)
}
func hashActivationToken(raw string) string {
// hashOpaqueToken returns the SHA-256 hex digest of raw, used as the
// lookup key for one-shot tokens stored in Redis.
func hashOpaqueToken(raw string) string {
h := sha256.Sum256([]byte(raw))
return hex.EncodeToString(h[:])
}

View File

@ -0,0 +1,158 @@
package api
import (
"context"
"encoding/base64"
"encoding/json"
"log/slog"
"net/http"
"time"
"github.com/go-chi/chi/v5"
"github.com/gorilla/websocket"
"git.omukk.dev/wrenn/wrenn/internal/recipe"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/service"
)
// buildStreamKeepalive is the interval between server pings on an idle build
// stream, preventing intermediaries from closing the WebSocket.
const buildStreamKeepalive = 30 * time.Second
// buildStreamHandler serves the live admin build console WebSocket.
type buildStreamHandler struct {
db *db.Queries
broker *service.BuildBroker
}
func newBuildStreamHandler(db *db.Queries, broker *service.BuildBroker) *buildStreamHandler {
return &buildStreamHandler{db: db, broker: broker}
}
// Stream handles WS /v1/admin/builds/{id}/stream. On connect it replays the
// completed-step history from the DB log, sends the current build status,
// then live-tails events from the build broker until the build finishes or
// the client disconnects. Admin auth is enforced by upstream middleware.
func (h *buildStreamHandler) Stream(w http.ResponseWriter, r *http.Request) {
buildIDStr := chi.URLParam(r, "id")
buildID, err := id.ParseBuildID(buildIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid build ID")
return
}
build, err := h.db.GetTemplateBuild(r.Context(), buildID)
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "build not found")
return
}
conn, _, err := upgradeAndAuthenticate(w, r)
if err != nil {
slog.Error("build stream websocket upgrade/auth failed", "error", err)
return
}
defer conn.Close()
h.runStream(r.Context(), conn, build)
}
func (h *buildStreamHandler) runStream(ctx context.Context, conn *websocket.Conn, build db.TemplateBuild) {
ws := &wsWriter{conn: conn}
buildIDStr := id.FormatBuildID(build.ID)
// Replay completed-step history from the DB log snapshot. lastStep is the
// highest step number already delivered, used to dedup overlapping live
// events for a step that finished between the DB read and the subscribe.
lastStep := replayBuildHistory(ws, build)
ws.writeJSON(service.BuildStreamEvent{
Type: "build-status",
Status: build.Status,
CurrentStep: build.CurrentStep,
TotalSteps: build.TotalSteps,
Error: build.Error,
})
// A finished build has no live events to follow.
if service.IsTerminalBuildStatus(build.Status) {
return
}
streamCtx, cancel := context.WithCancel(ctx)
defer cancel()
events, release := h.broker.Subscribe(buildIDStr)
defer release()
// Drain client reads so a disconnect cancels the stream. The client sends
// nothing meaningful; any read error means the socket is gone.
go func() {
for {
if _, _, err := conn.ReadMessage(); err != nil {
cancel()
return
}
}
}()
ticker := time.NewTicker(buildStreamKeepalive)
defer ticker.Stop()
for {
select {
case <-streamCtx.Done():
return
case <-ticker.C:
ws.writeJSON(map[string]string{"type": "ping"})
case ev, ok := <-events:
if !ok {
return
}
// Skip step events already covered by the history replay.
if ev.Type != "build-status" && ev.Step > 0 && ev.Step <= lastStep {
continue
}
ws.writeJSON(ev)
if ev.Type == "build-status" && service.IsTerminalBuildStatus(ev.Status) {
return
}
}
}
}
// replayBuildHistory synthesizes step-start/output/step-end events from the
// build's persisted log entries and writes them to the WebSocket. It returns
// the highest step number replayed.
func replayBuildHistory(ws *wsWriter, build db.TemplateBuild) int {
if len(build.Logs) == 0 {
return 0
}
var entries []recipe.BuildLogEntry
if err := json.Unmarshal(build.Logs, &entries); err != nil {
slog.Warn("build stream: bad log JSON", "build_id", id.FormatBuildID(build.ID), "error", err)
return 0
}
lastStep := 0
for _, e := range entries {
ws.writeJSON(service.BuildStreamEvent{Type: "step-start", Step: e.Step, Phase: e.Phase, Cmd: e.Cmd})
if out := e.Stdout + e.Stderr; out != "" {
ws.writeJSON(service.BuildStreamEvent{
Type: "output",
Step: e.Step,
Data: base64.StdEncoding.EncodeToString([]byte(out)),
})
}
ws.writeJSON(service.BuildStreamEvent{
Type: "step-end", Step: e.Step, Phase: e.Phase, Cmd: e.Cmd,
Exit: e.Exit, Ok: e.Ok, ElapsedMs: e.Elapsed,
})
if e.Step > lastStep {
lastStep = e.Step
}
}
return lastStep
}

View File

@ -9,7 +9,6 @@ import (
"strings"
"time"
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
"git.omukk.dev/wrenn/wrenn/internal/layout"
@ -20,7 +19,6 @@ import (
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
"git.omukk.dev/wrenn/wrenn/pkg/service"
"git.omukk.dev/wrenn/wrenn/pkg/validate"
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
)
type buildHandler struct {
@ -42,6 +40,7 @@ type createBuildRequest struct {
VCPUs int32 `json:"vcpus"`
MemoryMB int32 `json:"memory_mb"`
SkipPrePost bool `json:"skip_pre_post"`
RunAsRoot bool `json:"run_as_root"`
}
type buildResponse struct {
@ -181,6 +180,7 @@ func (h *buildHandler) Create(w http.ResponseWriter, r *http.Request) {
VCPUs: req.VCPUs,
MemoryMB: req.MemoryMB,
SkipPrePost: req.SkipPrePost,
RunAsRoot: req.RunAsRoot,
Archive: archive,
ArchiveName: archiveName,
})
@ -238,6 +238,9 @@ func (h *buildHandler) ListTemplates(w http.ResponseWriter, r *http.Request) {
return
}
// Resolve actual on-disk sizes for templates with unknown size.
templates = resolveTemplateSizes(r.Context(), h.db, h.pool, templates)
type templateResponse struct {
Name string `json:"name"`
Type string `json:"type"`
@ -246,6 +249,7 @@ func (h *buildHandler) ListTemplates(w http.ResponseWriter, r *http.Request) {
SizeBytes int64 `json:"size_bytes"`
TeamID string `json:"team_id"`
CreatedAt string `json:"created_at"`
Protected bool `json:"protected"`
}
resp := make([]templateResponse, len(templates))
@ -257,6 +261,7 @@ func (h *buildHandler) ListTemplates(w http.ResponseWriter, r *http.Request) {
MemoryMB: t.MemoryMb,
SizeBytes: t.SizeBytes,
TeamID: id.FormatTeamID(t.TeamID),
Protected: layout.IsSystemTemplate(t.TeamID, t.ID),
}
if t.CreatedAt.Valid {
resp[i].CreatedAt = t.CreatedAt.Time.Format(time.RFC3339)
@ -280,29 +285,17 @@ func (h *buildHandler) DeleteTemplate(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusNotFound, "not_found", "template not found")
return
}
if layout.IsMinimal(tmpl.TeamID, tmpl.ID) {
writeError(w, http.StatusForbidden, "forbidden", "the minimal template cannot be deleted")
if layout.IsSystemTemplate(tmpl.TeamID, tmpl.ID) {
writeError(w, http.StatusForbidden, "forbidden", "system base templates cannot be deleted")
return
}
// Broadcast delete to all online hosts.
hosts, _ := h.db.ListActiveHosts(ctx)
for _, host := range hosts {
if host.Status != "online" {
continue
}
agent, err := h.pool.GetForHost(host)
if err != nil {
continue
}
if _, err := agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{
TeamId: formatUUIDForRPC(tmpl.TeamID),
TemplateId: formatUUIDForRPC(tmpl.ID),
})); err != nil {
if connect.CodeOf(err) != connect.CodeNotFound {
slog.Warn("admin: failed to delete template on host", "host_id", id.FormatHostID(host.ID), "name", name, "error", err)
}
}
// Remove the files from every host before dropping the DB record, so a
// failure leaves the template intact and retryable rather than orphaned.
if err := deleteSnapshotEverywhere(ctx, h.db, h.pool, tmpl.TeamID, tmpl.ID); err != nil {
writeError(w, http.StatusConflict, "delete_failed",
"could not remove template files from all hosts: "+err.Error())
return
}
if err := h.db.DeleteTemplate(ctx, tmpl.ID); err != nil {

View File

@ -3,14 +3,11 @@ package api
import (
"encoding/base64"
"encoding/json"
"log/slog"
"net/http"
"time"
"unicode/utf8"
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
@ -58,23 +55,11 @@ type backgroundExecResponse struct {
// Exec handles POST /v1/capsules/{id}/exec.
func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
sb, sandboxID, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
if !ok {
return
}
@ -116,15 +101,7 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
return
}
if err := h.db.UpdateLastActive(ctx, db.UpdateLastActiveParams{
ID: sandboxID,
LastActiveAt: pgtype.Timestamptz{
Time: time.Now(),
Valid: true,
},
}); err != nil {
slog.Warn("failed to update last_active_at", "id", sandboxIDStr, "error", err)
}
updateLastActive(h.db, sandboxID, sandboxIDStr)
writeJSON(w, http.StatusAccepted, backgroundExecResponse{
SandboxID: sandboxIDStr,
@ -142,6 +119,8 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
Cmd: req.Cmd,
Args: req.Args,
TimeoutSec: req.TimeoutSec,
Envs: req.Envs,
Cwd: req.Cwd,
}))
if err != nil {
status, code, msg := agentErrToHTTP(err)
@ -151,41 +130,24 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
duration := time.Since(start)
// Update last active.
if err := h.db.UpdateLastActive(ctx, db.UpdateLastActiveParams{
ID: sandboxID,
LastActiveAt: pgtype.Timestamptz{
Time: time.Now(),
Valid: true,
},
}); err != nil {
slog.Warn("failed to update last_active_at", "id", sandboxIDStr, "error", err)
}
updateLastActive(h.db, sandboxID, sandboxIDStr)
// Use base64 encoding if output contains non-UTF-8 bytes.
stdout := resp.Msg.Stdout
stderr := resp.Msg.Stderr
encoding := "utf-8"
encoding := "utf-8"
stdoutStr, stderrStr := string(stdout), string(stderr)
if !utf8.Valid(stdout) || !utf8.Valid(stderr) {
encoding = "base64"
writeJSON(w, http.StatusOK, execResponse{
SandboxID: sandboxIDStr,
Cmd: req.Cmd,
Stdout: base64.StdEncoding.EncodeToString(stdout),
Stderr: base64.StdEncoding.EncodeToString(stderr),
ExitCode: resp.Msg.ExitCode,
DurationMs: duration.Milliseconds(),
Encoding: encoding,
})
return
stdoutStr = base64.StdEncoding.EncodeToString(stdout)
stderrStr = base64.StdEncoding.EncodeToString(stderr)
}
writeJSON(w, http.StatusOK, execResponse{
SandboxID: sandboxIDStr,
Cmd: req.Cmd,
Stdout: string(stdout),
Stderr: string(stderr),
Stdout: stdoutStr,
Stderr: stderrStr,
ExitCode: resp.Msg.ExitCode,
DurationMs: duration.Milliseconds(),
Encoding: encoding,

View File

@ -5,7 +5,6 @@ import (
"encoding/json"
"log/slog"
"net/http"
"time"
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
@ -20,13 +19,12 @@ import (
)
type execStreamHandler struct {
db *db.Queries
pool *lifecycle.HostClientPool
jwtSecret []byte
db *db.Queries
pool *lifecycle.HostClientPool
}
func newExecStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool, jwtSecret []byte) *execStreamHandler {
return &execStreamHandler{db: db, pool: pool, jwtSecret: jwtSecret}
func newExecStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool) *execStreamHandler {
return &execStreamHandler{db: db, pool: pool}
}
var upgrader = websocket.Upgrader{
@ -59,37 +57,9 @@ func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
return
}
// Authenticate: use context from middleware (API key) or WS first message (JWT).
ac, hasAuth := auth.FromContext(ctx)
if !hasAuth {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
slog.Error("websocket upgrade failed", "error", err)
return
}
defer conn.Close()
var wsAC auth.AuthContext
var authErr error
if isAdminWSRoute(ctx) {
wsAC, authErr = wsAuthenticateAdmin(ctx, conn, h.jwtSecret, h.db)
} else {
wsAC, authErr = wsAuthenticate(ctx, conn, h.jwtSecret, h.db)
}
if authErr != nil {
sendWSError(conn, "authentication failed")
return
}
ac = wsAC
h.runExecStream(ctx, conn, ac, sandboxID, sandboxIDStr)
return
}
conn, err := upgrader.Upgrade(w, r, nil)
conn, ac, err := upgradeAndAuthenticate(w, r)
if err != nil {
slog.Error("websocket upgrade failed", "error", err)
slog.Error("websocket upgrade/auth failed", "error", err)
return
}
defer conn.Close()
@ -186,18 +156,7 @@ func (h *execStreamHandler) runExecStream(ctx context.Context, conn *websocket.C
}
}
// Update last active using a fresh context (the request context may be cancelled).
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer updateCancel()
if err := h.db.UpdateLastActive(updateCtx, db.UpdateLastActiveParams{
ID: sandboxID,
LastActiveAt: pgtype.Timestamptz{
Time: time.Now(),
Valid: true,
},
}); err != nil {
slog.Warn("failed to update last active after stream exec", "sandbox_id", sandboxIDStr, "error", err)
}
updateLastActive(h.db, sandboxID, sandboxIDStr)
}
func sendWSError(conn *websocket.Conn, msg string) {

View File

@ -7,11 +7,9 @@ import (
"net/http"
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
)
@ -30,23 +28,11 @@ func newFilesHandler(db *db.Queries, pool *lifecycle.HostClientPool) *filesHandl
// - "path" text field: absolute destination path inside the sandbox
// - "file" file field: binary content to write
func (h *filesHandler) Upload(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
if !ok {
return
}
@ -108,23 +94,11 @@ type readFileRequest struct {
// Download handles POST /v1/capsules/{id}/files/read.
// Accepts JSON body with path, returns raw file content with Content-Disposition.
func (h *filesHandler) Download(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
if !ok {
return
}

View File

@ -8,11 +8,9 @@ import (
"net/http"
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
)
@ -30,23 +28,11 @@ func newFilesStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool) *file
// Expects multipart/form-data with "path" text field and "file" file field.
// Streams file content directly from the request body to the host agent without buffering.
func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
if !ok {
return
}
@ -103,6 +89,12 @@ func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request
// Open client-streaming RPC to host agent.
stream := agent.WriteFileStream(ctx)
var streamClosed bool
defer func() {
if !streamClosed {
_, _ = stream.CloseAndReceive()
}
}()
// Send metadata first.
if err := stream.Send(&pb.WriteFileStreamRequest{
@ -141,6 +133,7 @@ func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request
}
// Close and receive response.
streamClosed = true
if _, err := stream.CloseAndReceive(); err != nil {
status, code, msg := agentErrToHTTP(err)
writeError(w, status, code, msg)
@ -153,23 +146,11 @@ func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request
// StreamDownload handles POST /v1/capsules/{id}/files/stream/read.
// Accepts JSON body with path, streams file content back without buffering.
func (h *filesStreamHandler) StreamDownload(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
if !ok {
return
}

View File

@ -4,11 +4,9 @@ import (
"net/http"
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
)
@ -58,23 +56,11 @@ type removeRequest struct {
// ListDir handles POST /v1/capsules/{id}/files/list.
func (h *fsHandler) ListDir(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
if !ok {
return
}
@ -115,23 +101,11 @@ func (h *fsHandler) ListDir(w http.ResponseWriter, r *http.Request) {
// MakeDir handles POST /v1/capsules/{id}/files/mkdir.
func (h *fsHandler) MakeDir(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
if !ok {
return
}
@ -166,23 +140,11 @@ func (h *fsHandler) MakeDir(w http.ResponseWriter, r *http.Request) {
// Remove handles POST /v1/capsules/{id}/files/remove.
func (h *fsHandler) Remove(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
if !ok {
return
}

View File

@ -1,6 +1,7 @@
package api
import (
"context"
"errors"
"log/slog"
"net/http"
@ -21,10 +22,11 @@ type hostHandler struct {
svc *service.HostService
queries *db.Queries
audit *audit.AuditLogger
monitor *HostMonitor
}
func newHostHandler(svc *service.HostService, queries *db.Queries, al *audit.AuditLogger) *hostHandler {
return &hostHandler{svc: svc, queries: queries, audit: al}
func newHostHandler(svc *service.HostService, queries *db.Queries, al *audit.AuditLogger, monitor *HostMonitor) *hostHandler {
return &hostHandler{svc: svc, queries: queries, audit: al, monitor: monitor}
}
// Request/response types.
@ -98,6 +100,11 @@ type hostResponse struct {
CreatedBy string `json:"created_by"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
RunningVcpus int32 `json:"running_vcpus"`
RunningMemoryMb int32 `json:"running_memory_mb"`
RunningDiskMb int32 `json:"running_disk_mb"`
PausedMemoryMb int32 `json:"paused_memory_mb"`
PausedDiskMb int32 `json:"paused_disk_mb"`
}
func hostToResponse(h db.Host) hostResponse {
@ -136,12 +143,37 @@ func hostToResponse(h db.Host) hostResponse {
s := h.LastHeartbeatAt.Time.Format(time.RFC3339)
resp.LastHeartbeatAt = &s
}
// created_at and updated_at are NOT NULL DEFAULT NOW(), always valid.
resp.CreatedAt = h.CreatedAt.Time.Format(time.RFC3339)
resp.UpdatedAt = h.UpdatedAt.Time.Format(time.RFC3339)
return resp
}
func hostToResponseWithLoad(h db.ListHostsByTeamRow) hostResponse {
resp := hostToResponse(db.Host{
ID: h.ID,
Type: h.Type,
TeamID: h.TeamID,
Provider: h.Provider,
AvailabilityZone: h.AvailabilityZone,
Arch: h.Arch,
CpuCores: h.CpuCores,
MemoryMb: h.MemoryMb,
DiskGb: h.DiskGb,
Address: h.Address,
Status: h.Status,
LastHeartbeatAt: h.LastHeartbeatAt,
CreatedBy: h.CreatedBy,
CreatedAt: h.CreatedAt,
UpdatedAt: h.UpdatedAt,
})
resp.RunningVcpus = h.RunningVcpus
resp.RunningMemoryMb = h.RunningMemoryMb
resp.RunningDiskMb = h.RunningDiskMb
resp.PausedMemoryMb = h.PausedMemoryMb
resp.PausedDiskMb = h.PausedDiskMb
return resp
}
// isAdmin fetches the user record and returns whether they are an admin.
func (h *hostHandler) isAdmin(r *http.Request, userID pgtype.UUID) bool {
user, err := h.queries.GetUserByID(r.Context(), userID)
@ -233,7 +265,7 @@ func (h *hostHandler) List(w http.ResponseWriter, r *http.Request) {
resp := make([]hostResponse, len(hosts))
for i, host := range hosts {
resp[i] = hostToResponse(host)
resp[i] = hostToResponseWithLoad(host)
if host.TeamID.Valid {
key := id.FormatTeamID(host.TeamID)
if name, ok := teamNames[key]; ok {
@ -335,6 +367,54 @@ func (h *hostHandler) Delete(w http.ResponseWriter, r *http.Request) {
writeError(w, status, code, msg)
}
// AdminList handles GET /v1/admin/hosts.
// Returns all hosts with per-host resource consumption. Admin-only.
func (h *hostHandler) AdminList(w http.ResponseWriter, r *http.Request) {
hosts, err := h.svc.ListAdmin(r.Context())
if err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to list hosts")
return
}
// Collect unique team IDs to fetch team names.
var teamNames map[string]string
seen := make(map[string]struct{})
for _, host := range hosts {
if host.TeamID.Valid {
key := id.FormatTeamID(host.TeamID)
seen[key] = struct{}{}
}
}
if len(seen) > 0 {
teamNames = make(map[string]string, len(seen))
for _, host := range hosts {
if !host.TeamID.Valid {
continue
}
key := id.FormatTeamID(host.TeamID)
if _, ok := teamNames[key]; ok {
continue
}
if team, err := h.queries.GetTeam(r.Context(), host.TeamID); err == nil {
teamNames[key] = team.Name
}
}
}
resp := make([]hostResponse, len(hosts))
for i, host := range hosts {
resp[i] = hostToResponseWithLoad(db.ListHostsByTeamRow(host))
if host.TeamID.Valid {
key := id.FormatTeamID(host.TeamID)
if name, ok := teamNames[key]; ok {
resp[i].TeamName = &name
}
}
}
writeJSON(w, http.StatusOK, resp)
}
// RegenerateToken handles POST /v1/hosts/{id}/token.
func (h *hostHandler) RegenerateToken(w http.ResponseWriter, r *http.Request) {
hostIDStr := chi.URLParam(r, "id")
@ -426,9 +506,12 @@ func (h *hostHandler) Heartbeat(w http.ResponseWriter, r *http.Request) {
return
}
// Log marked_up if the host just recovered from unreachable.
// If the host just recovered from unreachable, log it and trigger immediate
// reconciliation so "missing" sandboxes are resolved without waiting for the
// next monitor tick.
if prevHost.Status == "unreachable" {
h.audit.LogHostMarkedUp(r.Context(), prevHost.TeamID, hc.HostID)
go h.monitor.ReconcileHost(context.Background(), hc.HostID)
}
w.WriteHeader(http.StatusNoContent)

View File

@ -2,9 +2,6 @@ package api
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"log/slog"
@ -20,6 +17,8 @@ import (
"git.omukk.dev/wrenn/wrenn/internal/email"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/auth/oauth"
"git.omukk.dev/wrenn/wrenn/pkg/auth/session"
"git.omukk.dev/wrenn/wrenn/pkg/cpextension"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/service"
@ -34,42 +33,62 @@ type meHandler struct {
db *db.Queries
pool *pgxpool.Pool
rdb *redis.Client
jwtSecret []byte
hmacKey []byte // HMAC key for OAuth state and link cookies (was jwtSecret)
sessions *session.Service
mailer email.Mailer
oauthRegistry *oauth.Registry
redirectURL string
teamSvc *service.TeamService
authHooks []cpextension.AuthHook
}
func newMeHandler(
db *db.Queries,
pool *pgxpool.Pool,
rdb *redis.Client,
jwtSecret []byte,
hmacKey []byte,
sessions *session.Service,
mailer email.Mailer,
registry *oauth.Registry,
redirectURL string,
teamSvc *service.TeamService,
hooks []cpextension.AuthHook,
) *meHandler {
return &meHandler{
db: db,
pool: pool,
rdb: rdb,
jwtSecret: jwtSecret,
hmacKey: hmacKey,
sessions: sessions,
mailer: mailer,
oauthRegistry: registry,
redirectURL: strings.TrimRight(redirectURL, "/"),
teamSvc: teamSvc,
authHooks: hooks,
}
}
type meResponse struct {
UserID string `json:"user_id"`
TeamID string `json:"team_id"`
Name string `json:"name"`
Email string `json:"email"`
Role string `json:"role"`
IsAdmin bool `json:"is_admin"`
HasPassword bool `json:"has_password"`
Providers []string `json:"providers"`
}
type sessionRow struct {
ID string `json:"id"`
UserAgent string `json:"user_agent"`
IPAddress string `json:"ip_address"`
CreatedAt string `json:"created_at"`
LastSeenAt string `json:"last_seen_at"`
ExpiresAt string `json:"expires_at"`
Current bool `json:"current"`
}
type updateNameRequest struct {
Name string `json:"name"`
}
@ -116,14 +135,19 @@ func (h *meHandler) GetMe(w http.ResponseWriter, r *http.Request) {
}
writeJSON(w, http.StatusOK, meResponse{
UserID: id.FormatUserID(ac.UserID),
TeamID: id.FormatTeamID(ac.TeamID),
Name: user.Name,
Email: user.Email,
Role: ac.Role,
IsAdmin: user.IsAdmin,
HasPassword: user.PasswordHash.Valid,
Providers: providerNames,
})
}
// UpdateName handles PATCH /v1/me — updates the user's name and re-issues a JWT.
// UpdateName handles PATCH /v1/me — updates the user's name and refreshes
// any cached session blobs so the new name shows up on next request.
func (h *meHandler) UpdateName(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
ctx := r.Context()
@ -148,31 +172,11 @@ func (h *meHandler) UpdateName(w http.ResponseWriter, r *http.Request) {
return
}
user, err := h.db.GetUserByID(ctx, ac.UserID)
if err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to get user")
return
if err := h.sessions.InvalidateCacheForUser(ctx, ac.UserID); err != nil {
slog.Warn("update name: invalidate session cache failed", "error", err)
}
team, role, err := loginTeam(ctx, h.db, ac.UserID)
if err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to get team")
return
}
token, err := auth.SignJWT(h.jwtSecret, ac.UserID, team.ID, user.Email, req.Name, role, user.IsAdmin)
if err != nil {
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
return
}
writeJSON(w, http.StatusOK, authResponse{
Token: token,
UserID: id.FormatUserID(ac.UserID),
TeamID: id.FormatTeamID(team.ID),
Email: user.Email,
Name: req.Name,
})
w.WriteHeader(http.StatusNoContent)
}
// ChangePassword handles POST /v1/me/password.
@ -235,6 +239,14 @@ func (h *meHandler) ChangePassword(w http.ResponseWriter, r *http.Request) {
return
}
// Revoke every session for this user — including the caller's — so a new
// password resets all device access. Clear cookies on the response so the
// caller is signed out immediately.
if err := h.sessions.RevokeAllForUser(ctx, ac.UserID); err != nil {
slog.Warn("change password: revoke sessions failed", "error", err)
}
clearSessionCookies(w, isSecure(r))
isAdding := !user.PasswordHash.Valid
go func() {
sendCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
@ -285,8 +297,8 @@ func (h *meHandler) RequestPasswordReset(w http.ResponseWriter, r *http.Request)
return
}
rawToken := generateResetToken()
tokenHash := hashResetToken(rawToken)
rawToken := generateOpaqueToken()
tokenHash := hashOpaqueToken(rawToken)
redisKey := passwordResetKeyPrefix + tokenHash
if err := h.rdb.Set(ctx, redisKey, id.FormatUserID(user.ID), passwordResetTTL).Err(); err != nil {
@ -330,7 +342,7 @@ func (h *meHandler) ConfirmPasswordReset(w http.ResponseWriter, r *http.Request)
}
ctx := r.Context()
tokenHash := hashResetToken(req.Token)
tokenHash := hashOpaqueToken(req.Token)
redisKey := passwordResetKeyPrefix + tokenHash
// GetDel atomically retrieves and removes the token in a single round-trip,
@ -371,6 +383,11 @@ func (h *meHandler) ConfirmPasswordReset(w http.ResponseWriter, r *http.Request)
return
}
// Reset invalidates every active session for the user.
if err := h.sessions.RevokeAllForUser(ctx, userID); err != nil {
slog.Warn("confirm password reset: revoke sessions failed", "error", err)
}
go func() {
sendCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
@ -404,7 +421,7 @@ func (h *meHandler) ConnectProvider(w http.ResponseWriter, r *http.Request) {
return
}
mac := computeHMAC(h.jwtSecret, state+":"+"login")
mac := computeHMAC(h.hmacKey, state+":"+"login")
http.SetCookie(w, &http.Cookie{
Name: "oauth_state",
Value: state + ":" + mac + ":" + "login",
@ -416,7 +433,7 @@ func (h *meHandler) ConnectProvider(w http.ResponseWriter, r *http.Request) {
})
userIDStr := id.FormatUserID(ac.UserID)
linkMac := computeHMAC(h.jwtSecret, userIDStr)
linkMac := computeHMAC(h.hmacKey, userIDStr)
http.SetCookie(w, &http.Cookie{
Name: "oauth_link_user_id",
Value: userIDStr + ":" + linkMac,
@ -552,6 +569,14 @@ func (h *meHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
return
}
// Revoke every active session and clear the caller's cookies.
if err := h.sessions.RevokeAllForUser(ctx, ac.UserID); err != nil {
slog.Warn("delete account: revoke sessions failed", "error", err)
}
clearSessionCookies(w, isSecure(r))
fireOnSoftDelete(ctx, h.authHooks, ac.UserID)
slog.Info("account soft-deleted", "user_id", id.FormatUserID(ac.UserID), "email", user.Email)
go func() {
@ -569,17 +594,4 @@ func (h *meHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}
// --- helpers ---
func generateResetToken() string {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
panic(fmt.Sprintf("crypto/rand failed: %v", err))
}
return hex.EncodeToString(b)
}
func hashResetToken(raw string) string {
h := sha256.Sum256([]byte(raw))
return hex.EncodeToString(h[:])
}
// (token helpers live in handlers_auth.go)

View File

@ -14,10 +14,12 @@ import (
"github.com/go-chi/chi/v5"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxpool"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/auth/oauth"
"git.omukk.dev/wrenn/wrenn/pkg/auth/session"
"git.omukk.dev/wrenn/wrenn/pkg/cpextension"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
@ -25,18 +27,22 @@ import (
type oauthHandler struct {
db *db.Queries
pool *pgxpool.Pool
jwtSecret []byte
hmacKey []byte // HMAC key for OAuth state and link cookies
sessions *session.Service
registry *oauth.Registry
redirectURL string // base frontend URL (e.g. "https://app.wrenn.dev")
authHooks []cpextension.AuthHook
}
func newOAuthHandler(db *db.Queries, pool *pgxpool.Pool, jwtSecret []byte, registry *oauth.Registry, redirectURL string) *oauthHandler {
func newOAuthHandler(db *db.Queries, pool *pgxpool.Pool, hmacKey []byte, sessions *session.Service, registry *oauth.Registry, redirectURL string, hooks []cpextension.AuthHook) *oauthHandler {
return &oauthHandler{
db: db,
pool: pool,
jwtSecret: jwtSecret,
hmacKey: hmacKey,
sessions: sessions,
registry: registry,
redirectURL: strings.TrimRight(redirectURL, "/"),
authHooks: hooks,
}
}
@ -61,7 +67,7 @@ func (h *oauthHandler) Redirect(w http.ResponseWriter, r *http.Request) {
intent = "login"
}
mac := computeHMAC(h.jwtSecret, state+":"+intent)
mac := computeHMAC(h.hmacKey, state+":"+intent)
cookieVal := state + ":" + mac + ":" + intent
http.SetCookie(w, &http.Cookie{
@ -121,7 +127,7 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
if len(parts) == 3 && parts[2] == "signup" {
intent = "signup"
}
if !hmac.Equal([]byte(computeHMAC(h.jwtSecret, nonce+":"+intent)), []byte(expectedMAC)) {
if !hmac.Equal([]byte(computeHMAC(h.hmacKey, nonce+":"+intent)), []byte(expectedMAC)) {
redirectWithError(w, r, redirectBase, "invalid_state")
return
}
@ -164,7 +170,7 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
// Verify the HMAC to prevent cookie forgery.
linkParts := strings.SplitN(linkCookie.Value, ":", 2)
if len(linkParts) != 2 || !hmac.Equal([]byte(computeHMAC(h.jwtSecret, linkParts[0])), []byte(linkParts[1])) {
if len(linkParts) != 2 || !hmac.Equal([]byte(computeHMAC(h.hmacKey, linkParts[0])), []byte(linkParts[1])) {
slog.Warn("oauth link: invalid or tampered link cookie")
http.Redirect(w, r, settingsBase+"?connect_error=invalid_state", http.StatusFound)
return
@ -244,13 +250,12 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
return
}
isAdmin := user.IsAdmin || isFirstUser
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, isAdmin)
if err != nil {
slog.Error("oauth login: failed to sign jwt", "error", err)
if err := h.issueSessionAndRedirect(w, r, user.ID, team.ID, user.Email, user.Name, role, isAdmin, redirectBase); err != nil {
slog.Error("oauth login: failed to issue session", "error", err)
redirectWithError(w, r, redirectBase, "internal_error")
return
}
redirectWithToken(w, r, redirectBase, token, id.FormatUserID(user.ID), id.FormatTeamID(team.ID), user.Email, user.Name)
fireOnLogin(ctx, h.authHooks, user.ID)
return
}
if !errors.Is(err, pgx.ErrNoRows) {
@ -374,10 +379,10 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
return
}
token, err := auth.SignJWT(h.jwtSecret, userID, teamID, email, profile.Name, "owner", isFirstUser)
if err != nil {
slog.Error("oauth: failed to sign jwt", "error", err)
redirectWithError(w, r, redirectBase, "internal_error")
// Fire OnSignup before session issuance — billing must succeed first.
if hookErr := fireOnSignup(ctx, h.authHooks, userID, teamID, email); hookErr != nil {
slog.Error("oauth signup: OnSignup hook failed", "user_id", id.FormatUserID(userID), "error", hookErr)
redirectWithError(w, r, redirectBase, "signup_hook_failed")
return
}
@ -385,14 +390,19 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
http.SetCookie(w, &http.Cookie{
Name: "wrenn_oauth_new_signup",
Value: "1",
Path: "/auth/",
Path: "/",
MaxAge: 60,
HttpOnly: false,
SameSite: http.SameSiteLaxMode,
Secure: isSecure(r),
})
redirectWithToken(w, r, redirectBase, token, id.FormatUserID(userID), id.FormatTeamID(teamID), email, profile.Name)
if err := h.issueSessionAndRedirect(w, r, userID, teamID, email, profile.Name, "owner", isFirstUser, redirectBase); err != nil {
slog.Error("oauth: failed to issue session", "error", err)
redirectWithError(w, r, redirectBase, "internal_error")
return
}
fireOnLogin(ctx, h.authHooks, userID)
}
// retryAsLogin handles the race where a concurrent request already created the user.
@ -431,33 +441,35 @@ func (h *oauthHandler) retryAsLogin(w http.ResponseWriter, r *http.Request, prov
return
}
isAdmin := user.IsAdmin || isFirstUser
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, isAdmin)
if err != nil {
slog.Error("oauth: retry login: failed to sign jwt", "error", err)
if err := h.issueSessionAndRedirect(w, r, user.ID, team.ID, user.Email, user.Name, role, isAdmin, redirectBase); err != nil {
slog.Error("oauth: retry login: failed to issue session", "error", err)
redirectWithError(w, r, redirectBase, "internal_error")
return
}
redirectWithToken(w, r, redirectBase, token, id.FormatUserID(user.ID), id.FormatTeamID(team.ID), user.Email, user.Name)
fireOnLogin(ctx, h.authHooks, user.ID)
}
func redirectWithToken(w http.ResponseWriter, r *http.Request, base, token, userID, teamID, email, name string) {
// Set auth data as short-lived cookies instead of URL query parameters.
// This prevents token leakage via server access logs, Referer headers, and browser history.
for _, c := range []http.Cookie{
{Name: "wrenn_oauth_token", Value: token},
{Name: "wrenn_oauth_user_id", Value: userID},
{Name: "wrenn_oauth_team_id", Value: teamID},
{Name: "wrenn_oauth_email", Value: email},
{Name: "wrenn_oauth_name", Value: name},
} {
c.Path = "/auth/"
c.MaxAge = 60
c.HttpOnly = false // frontend JS must read these
c.SameSite = http.SameSiteLaxMode
c.Secure = isSecure(r)
http.SetCookie(w, &c)
// issueSessionAndRedirect creates a session, sets the session and CSRF
// cookies, and redirects to the frontend dashboard. The redirectBase param
// is the OAuth callback URL; we ignore it after success and send the user to
// /dashboard directly (callback page will probe /v1/me to hydrate state).
func (h *oauthHandler) issueSessionAndRedirect(
w http.ResponseWriter,
r *http.Request,
userID, teamID pgtype.UUID,
email, name, role string,
isAdmin bool,
redirectBase string,
) error {
sess, err := h.sessions.Create(r.Context(), userID, teamID, email, name, role, isAdmin, r.UserAgent(), clientIP(r))
if err != nil {
return err
}
http.Redirect(w, r, base, http.StatusFound)
setSessionCookies(w, sess.RawSID, sess.CSRFToken, isSecure(r))
// Send the user to the callback page so the SPA can probe /v1/me and
// trigger any post-OAuth UX (e.g. the new-signup name confirmation).
http.Redirect(w, r, redirectBase, http.StatusFound)
return nil
}
func redirectWithError(w http.ResponseWriter, r *http.Request, base, code string) {
@ -477,7 +489,3 @@ func computeHMAC(key []byte, data string) string {
h.Write([]byte(data))
return hex.EncodeToString(h.Sum(nil))
}
func isSecure(r *http.Request) bool {
return r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https"
}

View File

@ -5,7 +5,6 @@ import (
"log/slog"
"net/http"
"strconv"
"time"
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
@ -20,13 +19,12 @@ import (
)
type processHandler struct {
db *db.Queries
pool *lifecycle.HostClientPool
jwtSecret []byte
db *db.Queries
pool *lifecycle.HostClientPool
}
func newProcessHandler(db *db.Queries, pool *lifecycle.HostClientPool, jwtSecret []byte) *processHandler {
return &processHandler{db: db, pool: pool, jwtSecret: jwtSecret}
func newProcessHandler(db *db.Queries, pool *lifecycle.HostClientPool) *processHandler {
return &processHandler{db: db, pool: pool}
}
// processResponse is a single entry in the process list.
@ -44,23 +42,11 @@ type processListResponse struct {
// ListProcesses handles GET /v1/capsules/{id}/processes.
func (h *processHandler) ListProcesses(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
if !ok {
return
}
@ -95,24 +81,12 @@ func (h *processHandler) ListProcesses(w http.ResponseWriter, r *http.Request) {
// KillProcess handles DELETE /v1/capsules/{id}/processes/{selector}.
// The selector can be a numeric PID or a string tag.
func (h *processHandler) KillProcess(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
selectorStr := chi.URLParam(r, "selector")
ctx := r.Context()
ac := auth.MustFromContext(ctx)
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
if !ok {
return
}
@ -146,14 +120,6 @@ func (h *processHandler) KillProcess(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}
// wsProcessOut is the JSON message sent to the WebSocket client.
type wsProcessOut struct {
Type string `json:"type"` // "start", "stdout", "stderr", "exit", "error"
PID uint32 `json:"pid,omitempty"` // only for "start"
Data string `json:"data,omitempty"` // only for "stdout", "stderr", "error"
ExitCode *int32 `json:"exit_code,omitempty"` // only for "exit"
}
// ConnectProcess handles WS /v1/capsules/{id}/processes/{selector}/stream.
func (h *processHandler) ConnectProcess(w http.ResponseWriter, r *http.Request) {
sandboxIDStr := chi.URLParam(r, "id")
@ -166,37 +132,9 @@ func (h *processHandler) ConnectProcess(w http.ResponseWriter, r *http.Request)
return
}
// Authenticate: use context from middleware (API key) or WS first message (JWT).
ac, hasAuth := auth.FromContext(ctx)
if !hasAuth {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
slog.Error("process stream websocket upgrade failed", "error", err)
return
}
defer conn.Close()
var wsAC auth.AuthContext
var authErr error
if isAdminWSRoute(ctx) {
wsAC, authErr = wsAuthenticateAdmin(ctx, conn, h.jwtSecret, h.db)
} else {
wsAC, authErr = wsAuthenticate(ctx, conn, h.jwtSecret, h.db)
}
if authErr != nil {
sendProcessWSError(conn, "authentication failed")
return
}
ac = wsAC
h.runConnectProcess(ctx, conn, ac, sandboxID, sandboxIDStr, selectorStr)
return
}
conn, err := upgrader.Upgrade(w, r, nil)
conn, ac, err := upgradeAndAuthenticate(w, r)
if err != nil {
slog.Error("process stream websocket upgrade failed", "error", err)
slog.Error("process stream websocket upgrade/auth failed", "error", err)
return
}
defer conn.Close()
@ -207,17 +145,17 @@ func (h *processHandler) ConnectProcess(w http.ResponseWriter, r *http.Request)
func (h *processHandler) runConnectProcess(ctx context.Context, conn *websocket.Conn, ac auth.AuthContext, sandboxID pgtype.UUID, sandboxIDStr, selectorStr string) {
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
if err != nil {
sendProcessWSError(conn, "sandbox not found")
sendWSError(conn, "sandbox not found")
return
}
if sb.Status != "running" {
sendProcessWSError(conn, "sandbox is not running (status: "+sb.Status+")")
sendWSError(conn, "sandbox is not running (status: "+sb.Status+")")
return
}
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
sendProcessWSError(conn, "sandbox host is not reachable")
sendWSError(conn, "sandbox host is not reachable")
return
}
@ -236,7 +174,7 @@ func (h *processHandler) runConnectProcess(ctx context.Context, conn *websocket.
stream, err := agent.ConnectProcess(streamCtx, connect.NewRequest(connectReq))
if err != nil {
sendProcessWSError(conn, "failed to connect to process: "+err.Error())
sendWSError(conn, "failed to connect to process: "+err.Error())
return
}
defer stream.Close()
@ -257,42 +195,27 @@ func (h *processHandler) runConnectProcess(ctx context.Context, conn *websocket.
resp := stream.Msg()
switch ev := resp.Event.(type) {
case *pb.ConnectProcessResponse_Start:
writeWSJSON(conn, wsProcessOut{Type: "start", PID: ev.Start.Pid})
writeWSJSON(conn, wsOutMsg{Type: "start", PID: ev.Start.Pid})
case *pb.ConnectProcessResponse_Data:
switch o := ev.Data.Output.(type) {
case *pb.ExecStreamData_Stdout:
writeWSJSON(conn, wsProcessOut{Type: "stdout", Data: string(o.Stdout)})
writeWSJSON(conn, wsOutMsg{Type: "stdout", Data: string(o.Stdout)})
case *pb.ExecStreamData_Stderr:
writeWSJSON(conn, wsProcessOut{Type: "stderr", Data: string(o.Stderr)})
writeWSJSON(conn, wsOutMsg{Type: "stderr", Data: string(o.Stderr)})
}
case *pb.ConnectProcessResponse_End:
exitCode := ev.End.ExitCode
writeWSJSON(conn, wsProcessOut{Type: "exit", ExitCode: &exitCode})
writeWSJSON(conn, wsOutMsg{Type: "exit", ExitCode: &exitCode})
}
}
if err := stream.Err(); err != nil {
if streamCtx.Err() == nil {
sendProcessWSError(conn, err.Error())
sendWSError(conn, err.Error())
}
}
// Update last active using a fresh context.
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer updateCancel()
if err := h.db.UpdateLastActive(updateCtx, db.UpdateLastActiveParams{
ID: sandboxID,
LastActiveAt: pgtype.Timestamptz{
Time: time.Now(),
Valid: true,
},
}); err != nil {
slog.Warn("failed to update last active after process stream", "sandbox_id", sandboxIDStr, "error", err)
}
}
func sendProcessWSError(conn *websocket.Conn, msg string) {
writeWSJSON(conn, wsProcessOut{Type: "error", Data: msg})
updateLastActive(h.db, sandboxID, sandboxIDStr)
}

View File

@ -30,13 +30,12 @@ const (
)
type ptyHandler struct {
db *db.Queries
pool *lifecycle.HostClientPool
jwtSecret []byte
db *db.Queries
pool *lifecycle.HostClientPool
}
func newPtyHandler(db *db.Queries, pool *lifecycle.HostClientPool, jwtSecret []byte) *ptyHandler {
return &ptyHandler{db: db, pool: pool, jwtSecret: jwtSecret}
func newPtyHandler(db *db.Queries, pool *lifecycle.HostClientPool) *ptyHandler {
return &ptyHandler{db: db, pool: pool}
}
// --- WebSocket message types ---
@ -90,40 +89,9 @@ func (h *ptyHandler) PtySession(w http.ResponseWriter, r *http.Request) {
return
}
// API key auth is handled by middleware (sets context).
// For browser JWT auth, we authenticate after upgrade via first WS message.
ac, hasAuth := auth.FromContext(ctx)
if !hasAuth {
// No pre-upgrade auth — upgrade first, then authenticate via WS message.
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
slog.Error("pty websocket upgrade failed", "error", err)
return
}
defer conn.Close()
ws := &wsWriter{conn: conn}
var wsAC auth.AuthContext
if isAdminWSRoute(ctx) {
wsAC, err = wsAuthenticateAdmin(ctx, conn, h.jwtSecret, h.db)
} else {
wsAC, err = wsAuthenticate(ctx, conn, h.jwtSecret, h.db)
}
if err != nil {
ws.writeJSON(wsPtyOut{Type: "error", Data: "authentication failed", Fatal: true})
return
}
ac = wsAC
h.runPtySession(ctx, ws, conn, ac, sandboxID, sandboxIDStr)
return
}
conn, err := upgrader.Upgrade(w, r, nil)
conn, ac, err := upgradeAndAuthenticate(w, r)
if err != nil {
slog.Error("pty websocket upgrade failed", "error", err)
slog.Error("pty websocket upgrade/auth failed", "error", err)
return
}
defer conn.Close()
@ -168,18 +136,7 @@ func (h *ptyHandler) runPtySession(ctx context.Context, ws *wsWriter, conn *webs
ws.writeJSON(wsPtyOut{Type: "error", Data: "first message must be type 'start' or 'connect'", Fatal: true})
}
// Update last active using a fresh context.
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer updateCancel()
if err := h.db.UpdateLastActive(updateCtx, db.UpdateLastActiveParams{
ID: sandboxID,
LastActiveAt: pgtype.Timestamptz{
Time: time.Now(),
Valid: true,
},
}); err != nil {
slog.Warn("failed to update last active after pty session", "sandbox_id", sandboxIDStr, "error", err)
}
updateLastActive(h.db, sandboxID, sandboxIDStr)
}
func (h *ptyHandler) handleStart(

View File

@ -37,8 +37,8 @@ type sandboxResponse struct {
VCPUs int32 `json:"vcpus"`
MemoryMB int32 `json:"memory_mb"`
TimeoutSec int32 `json:"timeout_sec"`
GuestIP string `json:"guest_ip,omitempty"`
HostIP string `json:"host_ip,omitempty"`
DiskSizeMB int32 `json:"disk_size_mb"`
DiskUsedMB *int64 `json:"disk_used_mb,omitempty"`
CreatedAt string `json:"created_at"`
StartedAt *string `json:"started_at,omitempty"`
LastActiveAt *string `json:"last_active_at,omitempty"`
@ -54,8 +54,7 @@ func sandboxToResponse(sb db.Sandbox) sandboxResponse {
VCPUs: sb.Vcpus,
MemoryMB: sb.MemoryMb,
TimeoutSec: sb.TimeoutSec,
GuestIP: sb.GuestIp,
HostIP: sb.HostIp,
DiskSizeMB: sb.DiskSizeMb,
}
if len(sb.Metadata) > 0 {
var meta map[string]string
@ -101,14 +100,17 @@ func (h *sandboxHandler) Create(w http.ResponseWriter, r *http.Request) {
MemoryMB: req.MemoryMB,
TimeoutSec: req.TimeoutSec,
})
h.audit.LogSandboxCreate(r.Context(), ac, sb.ID, req.Template, err)
if err != nil {
if sb.ID.Valid {
h.audit.LogSandboxDestroySystem(r.Context(), ac.TeamID, sb.ID, "cleanup_after_create_error", nil)
}
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
h.audit.LogSandboxCreate(r.Context(), ac, sb.ID, sb.Template)
writeJSON(w, http.StatusCreated, sandboxToResponse(sb))
writeJSON(w, http.StatusAccepted, sandboxToResponse(sb))
}
// List handles GET /v1/capsules.
@ -145,7 +147,15 @@ func (h *sandboxHandler) Get(w http.ResponseWriter, r *http.Request) {
return
}
writeJSON(w, http.StatusOK, sandboxToResponse(sb))
resp := sandboxToResponse(sb)
diskBytes, err := h.svc.GetDiskUsage(r.Context(), sandboxID, ac.TeamID)
if err == nil {
diskUsedMB := diskBytes / (1024 * 1024)
resp.DiskUsedMB = &diskUsedMB
}
writeJSON(w, http.StatusOK, resp)
}
// Pause handles POST /v1/capsules/{id}/pause.
@ -160,14 +170,14 @@ func (h *sandboxHandler) Pause(w http.ResponseWriter, r *http.Request) {
}
sb, err := h.svc.Pause(r.Context(), sandboxID, ac.TeamID)
h.audit.LogSandboxPause(r.Context(), ac, sandboxID, err)
if err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
h.audit.LogSandboxPause(r.Context(), ac, sandboxID)
writeJSON(w, http.StatusOK, sandboxToResponse(sb))
writeJSON(w, http.StatusAccepted, sandboxToResponse(sb))
}
// Resume handles POST /v1/capsules/{id}/resume.
@ -182,14 +192,14 @@ func (h *sandboxHandler) Resume(w http.ResponseWriter, r *http.Request) {
}
sb, err := h.svc.Resume(r.Context(), sandboxID, ac.TeamID)
h.audit.LogSandboxResume(r.Context(), ac, sandboxID, err)
if err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
h.audit.LogSandboxResume(r.Context(), ac, sandboxID)
writeJSON(w, http.StatusOK, sandboxToResponse(sb))
writeJSON(w, http.StatusAccepted, sandboxToResponse(sb))
}
// Ping handles POST /v1/capsules/{id}/ping.
@ -223,12 +233,13 @@ func (h *sandboxHandler) Destroy(w http.ResponseWriter, r *http.Request) {
return
}
if err := h.svc.Destroy(r.Context(), sandboxID, ac.TeamID); err != nil {
err = h.svc.Destroy(r.Context(), sandboxID, ac.TeamID)
h.audit.LogSandboxDestroy(r.Context(), ac, sandboxID, err)
if err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
h.audit.LogSandboxDestroy(r.Context(), ac, sandboxID)
w.WriteHeader(http.StatusNoContent)
w.WriteHeader(http.StatusAccepted)
}

View File

@ -0,0 +1,169 @@
package api
import (
"context"
"encoding/json"
"log/slog"
"net/http"
"time"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/channels"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/events"
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
type sandboxEventHandler struct {
db *db.Queries
eventPub *channels.Publisher
}
func newSandboxEventHandler(queries *db.Queries, eventPub *channels.Publisher) *sandboxEventHandler {
return &sandboxEventHandler{db: queries, eventPub: eventPub}
}
type sandboxEventRequest struct {
Event string `json:"event"`
SandboxID string `json:"sandbox_id"`
HostID string `json:"host_id"`
HostIP string `json:"host_ip,omitempty"`
Error string `json:"error,omitempty"`
Timestamp int64 `json:"timestamp"`
}
// Handle receives lifecycle event callbacks from host agents, translates the
// raw host event into the canonical events.Event taxonomy, and publishes once
// to the unified Redis stream. The SandboxEventConsumer (independent
// consumer group) drives DB reconciliation; the channels dispatcher delivers
// to subscribed channels; the SSE relay mirrors via Pub/Sub.
func (h *sandboxEventHandler) Handle(w http.ResponseWriter, r *http.Request) {
var req sandboxEventRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
return
}
if req.Event == "" || req.SandboxID == "" || req.HostID == "" {
writeError(w, http.StatusBadRequest, "invalid_request", "event, sandbox_id, and host_id are required")
return
}
hc := auth.MustHostFromContext(r.Context())
callerHostID := id.FormatHostID(hc.HostID)
if callerHostID != req.HostID {
writeError(w, http.StatusForbidden, "forbidden", "host_id does not match authenticated host")
return
}
if req.Timestamp == 0 {
req.Timestamp = time.Now().Unix()
}
evt, ok := h.translate(r.Context(), req)
if !ok {
// Unknown event type — log and accept so the host agent doesn't retry.
slog.Warn("sandbox event callback: untranslatable event", "event", req.Event, "sandbox_id", req.SandboxID)
w.WriteHeader(http.StatusNoContent)
return
}
h.eventPub.Publish(r.Context(), evt)
w.WriteHeader(http.StatusNoContent)
}
// translate converts a raw host-agent callback into the canonical event.
// For failure events without an in-flight verb (e.g. sandbox.failed), the
// current DB status is consulted to pick the appropriate verb.
func (h *sandboxEventHandler) translate(ctx context.Context, req sandboxEventRequest) (events.Event, bool) {
sandboxUUID, parseErr := id.ParseSandboxID(req.SandboxID)
if parseErr != nil {
return events.Event{}, false
}
var teamID pgtype.UUID
if sb, dbErr := h.db.GetSandbox(ctx, sandboxUUID); dbErr == nil {
teamID = sb.TeamID
}
base := events.Event{
Timestamp: time.Unix(req.Timestamp, 0).UTC().Format(time.RFC3339),
TeamID: id.FormatTeamID(teamID),
Actor: events.SystemActor(),
Resource: events.Resource{ID: req.SandboxID, Type: "sandbox"},
}
switch req.Event {
case "sandbox.started":
meta := map[string]string{}
if req.HostIP != "" {
meta["host_ip"] = req.HostIP
}
meta["host_id"] = req.HostID
base.Event = events.CapsuleCreate
base.Outcome = events.OutcomeSuccess
base.Metadata = meta
case "sandbox.resumed":
meta := map[string]string{"host_id": req.HostID}
if req.HostIP != "" {
meta["host_ip"] = req.HostIP
}
base.Event = events.CapsuleResume
base.Outcome = events.OutcomeSuccess
base.Metadata = meta
case "sandbox.paused":
base.Event = events.CapsulePause
base.Outcome = events.OutcomeSuccess
case "sandbox.auto_paused":
base.Event = events.CapsulePause
base.Outcome = events.OutcomeSuccess
base.Metadata = map[string]string{"reason": "ttl_expired"}
case "sandbox.stopped":
base.Event = events.CapsuleDestroy
base.Outcome = events.OutcomeSuccess
case "sandbox.pause_failed":
base.Event = events.CapsulePause
base.Outcome = events.OutcomeError
base.Error = req.Error
base.Metadata = map[string]string{"reason": "host_failure"}
case "sandbox.resume_failed":
base.Event = events.CapsuleResume
base.Outcome = events.OutcomeError
base.Error = req.Error
base.Metadata = map[string]string{"reason": "host_failure"}
case "sandbox.failed", "sandbox.error":
// Pick a verb based on the sandbox's current DB status.
verb := h.verbForFailure(ctx, sandboxUUID)
base.Event = verb
base.Outcome = events.OutcomeError
base.Error = req.Error
base.Metadata = map[string]string{"reason": "host_failure"}
default:
return events.Event{}, false
}
return base, true
}
func (h *sandboxEventHandler) verbForFailure(ctx context.Context, sandboxID pgtype.UUID) string {
sb, err := h.db.GetSandbox(ctx, sandboxID)
if err != nil {
return events.CapsuleDestroy
}
switch sb.Status {
case "starting":
return events.CapsuleCreate
case "resuming":
return events.CapsuleResume
case "pausing":
return events.CapsulePause
case "snapshotting":
// A snapshot pauses then resumes the VM; a host-side failure leaves the
// sandbox errored, not destroyed. Route through CapsuleCreate so the
// consumer's handleFailed marks it "error" rather than removing the row.
return events.CapsuleCreate
default:
return events.CapsuleDestroy
}
}

View File

@ -0,0 +1,66 @@
package api
import (
"log/slog"
"net/http"
"time"
"github.com/go-chi/chi/v5"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/auth/session"
)
type sessionsHandler struct {
sessions *session.Service
}
func newSessionsHandler(svc *session.Service) *sessionsHandler {
return &sessionsHandler{sessions: svc}
}
// List handles GET /v1/me/sessions — returns all active sessions for the
// current user, flagging the caller's own session as current.
func (h *sessionsHandler) List(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
rows, err := h.sessions.ListForUser(r.Context(), ac.UserID)
if err != nil {
slog.Error("list sessions: db error", "error", err)
writeError(w, http.StatusInternalServerError, "db_error", "failed to list sessions")
return
}
out := make([]sessionRow, 0, len(rows))
for _, row := range rows {
out = append(out, sessionRow{
ID: row.ID,
UserAgent: row.UserAgent,
IPAddress: row.IpAddress,
CreatedAt: row.CreatedAt.Time.UTC().Format(time.RFC3339),
LastSeenAt: row.LastSeenAt.Time.UTC().Format(time.RFC3339),
ExpiresAt: row.ExpiresAt.Time.UTC().Format(time.RFC3339),
Current: row.ID == ac.SessionID,
})
}
writeJSON(w, http.StatusOK, map[string]any{"sessions": out})
}
// Delete handles DELETE /v1/me/sessions/{id} — revokes a single session
// belonging to the current user. If the caller revokes their own session,
// their cookies are cleared on the response.
func (h *sessionsHandler) Delete(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
sid := chi.URLParam(r, "id")
if sid == "" {
writeError(w, http.StatusBadRequest, "invalid_request", "missing session id")
return
}
if err := h.sessions.DeleteForUser(r.Context(), sid, ac.UserID); err != nil {
slog.Error("delete session: db error", "error", err)
writeError(w, http.StatusInternalServerError, "db_error", "failed to delete session")
return
}
if sid == ac.SessionID {
clearSessionCookies(w, isSecure(r))
}
w.WriteHeader(http.StatusNoContent)
}

View File

@ -4,7 +4,6 @@ import (
"context"
"encoding/json"
"fmt"
"log/slog"
"net/http"
"time"
@ -25,49 +24,53 @@ import (
)
type snapshotHandler struct {
svc *service.TemplateService
db *db.Queries
pool *lifecycle.HostClientPool
audit *audit.AuditLogger
svc *service.TemplateService
sandboxSvc *service.SandboxService
db *db.Queries
pool *lifecycle.HostClientPool
audit *audit.AuditLogger
}
func newSnapshotHandler(svc *service.TemplateService, db *db.Queries, pool *lifecycle.HostClientPool, al *audit.AuditLogger) *snapshotHandler {
return &snapshotHandler{svc: svc, db: db, pool: pool, audit: al}
func newSnapshotHandler(svc *service.TemplateService, sandboxSvc *service.SandboxService, db *db.Queries, pool *lifecycle.HostClientPool, al *audit.AuditLogger) *snapshotHandler {
return &snapshotHandler{svc: svc, sandboxSvc: sandboxSvc, db: db, pool: pool, audit: al}
}
// deleteSnapshotBroadcast attempts to delete snapshot files on all online hosts.
// Snapshots aren't currently host-tracked in the DB, so we broadcast to all hosts
// and ignore NotFound errors.
func deleteSnapshotBroadcast(ctx context.Context, queries *db.Queries, pool *lifecycle.HostClientPool, teamID, templateID pgtype.UUID) error {
// deleteSnapshotEverywhere removes a template's files from every active host.
// Templates aren't host-tracked in the DB, so it broadcasts to all hosts.
//
// It is strict by design: deletion is reported successful only when every
// active host has either removed the files or reported NotFound (it never
// held them). If any host is offline or returns an error, it returns an error
// and the caller MUST NOT delete the DB record — doing so would orphan the
// files on disk with no record left to retry against.
func deleteSnapshotEverywhere(ctx context.Context, queries *db.Queries, pool *lifecycle.HostClientPool, teamID, templateID pgtype.UUID) error {
hosts, err := queries.ListActiveHosts(ctx)
if err != nil {
return fmt.Errorf("list hosts: %w", err)
}
for _, host := range hosts {
if host.Status != "online" {
continue
return fmt.Errorf("host %s is %s — cannot guarantee snapshot file removal",
id.FormatHostID(host.ID), host.Status)
}
agent, err := pool.GetForHost(host)
if err != nil {
continue
return fmt.Errorf("connect to host %s: %w", id.FormatHostID(host.ID), err)
}
if _, err := agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{
TeamId: formatUUIDForRPC(teamID),
TemplateId: formatUUIDForRPC(templateID),
})); err != nil {
if connect.CodeOf(err) != connect.CodeNotFound {
slog.Warn("snapshot: failed to delete on host", "host_id", id.FormatHostID(host.ID), "error", err)
// NotFound just means this host never held the template.
if connect.CodeOf(err) == connect.CodeNotFound {
continue
}
return fmt.Errorf("delete snapshot on host %s: %w", id.FormatHostID(host.ID), err)
}
}
return nil
}
type createSnapshotRequest struct {
SandboxID string `json:"sandbox_id"`
Name string `json:"name"`
}
type snapshotResponse struct {
Name string `json:"name"`
Type string `json:"type"`
@ -76,6 +79,7 @@ type snapshotResponse struct {
SizeBytes int64 `json:"size_bytes"`
CreatedAt string `json:"created_at"`
Platform bool `json:"platform"`
Protected bool `json:"protected"`
Metadata map[string]string `json:"metadata,omitempty"`
}
@ -85,6 +89,7 @@ func templateToResponse(t db.Template) snapshotResponse {
Type: t.Type,
SizeBytes: t.SizeBytes,
Platform: t.TeamID == id.PlatformTeamID,
Protected: layout.IsSystemTemplate(t.TeamID, t.ID),
}
if t.Vcpus != 0 {
resp.VCPUs = &t.Vcpus
@ -104,132 +109,42 @@ func templateToResponse(t db.Template) snapshotResponse {
return resp
}
// Create handles POST /v1/snapshots.
type createSnapshotRequest struct {
SandboxID string `json:"sandbox_id"`
Name string `json:"name"`
}
// Create handles POST /v1/snapshots. Snapshots a running or paused sandbox and
// registers the result as a new template.
func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
var req createSnapshotRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
return
}
if req.SandboxID == "" {
writeError(w, http.StatusBadRequest, "invalid_request", "sandbox_id is required")
return
}
sandboxID, err := id.ParseSandboxID(req.SandboxID)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox_id")
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
ac := auth.MustFromContext(r.Context())
if req.Name == "" {
req.Name = id.NewSnapshotName()
}
if err := validate.SafeName(req.Name); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", fmt.Sprintf("invalid snapshot name: %s", err))
return
}
ctx := r.Context()
ac := auth.MustFromContext(ctx)
// Check for global name collision.
if _, err := h.db.GetPlatformTemplateByName(ctx, req.Name); err == nil {
writeError(w, http.StatusConflict, "name_reserved", "template name is reserved by a global template")
return
}
// Check if name already exists for this team.
if _, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err == nil {
writeError(w, http.StatusConflict, "template_name_taken",
"snapshot name already exists; delete the existing snapshot first to reuse this name")
return
}
// Verify sandbox exists, belongs to team, and is running or paused.
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
// Async: the VM briefly pauses to a "snapshotting" state, then resumes. The
// template is registered by a background goroutine; clients learn of the
// result via the SSE template.snapshot.create event (or by polling).
sb, name, err := h.sandboxSvc.CreateSnapshot(r.Context(), sandboxID, ac.TeamID, req.Name)
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" && sb.Status != "paused" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox must be running or paused")
return
}
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
return
}
// Pre-mark sandbox as "paused" in DB BEFORE issuing the snapshot RPC.
// The host agent's CreateSnapshot removes the sandbox from its in-memory
// map immediately; if the reconciler fires during the flatten window and
// the DB still says "running", it will mark the sandbox "stopped".
if sb.Status == "running" {
if _, err := h.db.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
ID: sandboxID, Status: "paused",
}); err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to update sandbox status")
return
}
}
// Use a detached context with a generous timeout so the snapshot completes
// even if the client disconnects (the flatten step can take 10-20s).
snapCtx, snapCancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer snapCancel()
// Generate the new template ID upfront so the host agent knows where to store files.
newTemplateID := id.NewTemplateID()
resp, err := agent.CreateSnapshot(snapCtx, connect.NewRequest(&pb.CreateSnapshotRequest{
SandboxId: req.SandboxID,
Name: req.Name,
TeamId: formatUUIDForRPC(ac.TeamID),
TemplateId: formatUUIDForRPC(newTemplateID),
}))
if err != nil {
// Snapshot failed — revert status back to what it was.
if sb.Status == "running" {
if _, dbErr := h.db.UpdateSandboxStatus(snapCtx, db.UpdateSandboxStatusParams{
ID: sandboxID, Status: "running",
}); dbErr != nil {
slog.Error("failed to revert sandbox status after snapshot error", "sandbox_id", req.SandboxID, "error", dbErr)
}
}
status, code, msg := agentErrToHTTP(err)
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
h.audit.LogSnapshotCreateRequested(r.Context(), ac, name)
tmpl, err := h.db.InsertTemplate(snapCtx, db.InsertTemplateParams{
ID: newTemplateID,
Name: req.Name,
Type: "snapshot",
Vcpus: sb.Vcpus,
MemoryMb: sb.MemoryMb,
SizeBytes: resp.Msg.SizeBytes,
TeamID: ac.TeamID,
DefaultUser: "root",
DefaultEnv: []byte("{}"),
Metadata: sb.Metadata,
})
if err != nil {
slog.Error("failed to insert template record", "name", req.Name, "error", err)
writeError(w, http.StatusInternalServerError, "db_error", "snapshot created but failed to record in database")
return
}
h.audit.LogSnapshotCreate(snapCtx, ac, req.Name)
if ctx.Err() != nil {
slog.Info("snapshot created but client disconnected before response", "name", req.Name)
return
}
writeJSON(w, http.StatusCreated, templateToResponse(tmpl))
writeJSON(w, http.StatusAccepted, sandboxToResponse(sb))
}
// List handles GET /v1/snapshots.
@ -243,6 +158,11 @@ func (h *snapshotHandler) List(w http.ResponseWriter, r *http.Request) {
return
}
// Resolve actual on-disk sizes for templates with unknown size (e.g.
// system base templates seeded with size_bytes = 0). This queries a host
// agent and persists the result to the DB for subsequent requests.
templates = resolveTemplateSizes(r.Context(), h.db, h.pool, templates)
resp := make([]snapshotResponse, len(templates))
for i, t := range templates {
resp[i] = templateToResponse(t)
@ -271,21 +191,24 @@ func (h *snapshotHandler) Delete(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusForbidden, "forbidden", "platform templates cannot be deleted here")
return
}
if layout.IsMinimal(tmpl.TeamID, tmpl.ID) {
writeError(w, http.StatusForbidden, "forbidden", "the minimal template cannot be deleted")
if layout.IsSystemTemplate(tmpl.TeamID, tmpl.ID) {
writeError(w, http.StatusForbidden, "forbidden", "system base templates cannot be deleted")
return
}
if err := deleteSnapshotBroadcast(ctx, h.db, h.pool, tmpl.TeamID, tmpl.ID); err != nil {
writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete snapshot files")
if err := deleteSnapshotEverywhere(ctx, h.db, h.pool, tmpl.TeamID, tmpl.ID); err != nil {
h.audit.LogSnapshotDelete(r.Context(), ac, name, err)
writeError(w, http.StatusConflict, "delete_failed",
"could not remove snapshot files from all hosts: "+err.Error())
return
}
if err := h.db.DeleteTemplateByTeam(ctx, db.DeleteTemplateByTeamParams{Name: name, TeamID: ac.TeamID}); err != nil {
h.audit.LogSnapshotDelete(r.Context(), ac, name, err)
writeError(w, http.StatusInternalServerError, "db_error", "failed to delete template record")
return
}
h.audit.LogSnapshotDelete(r.Context(), ac, name)
h.audit.LogSnapshotDelete(r.Context(), ac, name, nil)
w.WriteHeader(http.StatusNoContent)
}

View File

@ -0,0 +1,79 @@
package api
import (
"fmt"
"net/http"
"time"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
const sseKeepaliveInterval = 30 * time.Second
type sseHandler struct {
broker *SSEBroker
}
func newSSEHandler(broker *SSEBroker) *sseHandler {
return &sseHandler{broker: broker}
}
// Stream handles GET /v1/events/stream. Authentication is performed by
// upstream middleware: browser clients use the wrenn_sid cookie (which the
// EventSource API forwards automatically); SDK clients use X-API-Key.
func (h *sseHandler) Stream(w http.ResponseWriter, r *http.Request) {
ac, ok := auth.FromContext(r.Context())
if !ok {
writeError(w, http.StatusUnauthorized, "unauthorized", "session cookie or X-API-Key required")
return
}
h.serveSSE(w, r, id.FormatTeamID(ac.TeamID), false)
}
// AdminStream handles GET /v1/admin/events/stream. Upstream middleware
// must enforce session auth + requireAdmin before reaching this handler.
func (h *sseHandler) AdminStream(w http.ResponseWriter, r *http.Request) {
ac, ok := auth.FromContext(r.Context())
if !ok {
writeError(w, http.StatusUnauthorized, "unauthorized", "admin session required")
return
}
h.serveSSE(w, r, id.FormatTeamID(ac.TeamID), true)
}
func (h *sseHandler) serveSSE(w http.ResponseWriter, r *http.Request, teamID string, isAdmin bool) {
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "streaming not supported", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("X-Accel-Buffering", "no")
subID, ch := h.broker.Subscribe(teamID, isAdmin)
defer h.broker.Unsubscribe(subID)
fmt.Fprintf(w, "event: connected\ndata: {\"message\":\"connected\"}\n\n")
flusher.Flush()
keepalive := time.NewTicker(sseKeepaliveInterval)
defer keepalive.Stop()
ctx := r.Context()
for {
select {
case <-ctx.Done():
return
case msg := <-ch:
fmt.Fprintf(w, "event: %s\ndata: %s\n\n", msg.EventType, msg.Data)
flusher.Flush()
case <-keepalive.C:
fmt.Fprintf(w, ": keepalive\n\n")
flusher.Flush()
}
}
}

View File

@ -14,19 +14,21 @@ import (
"git.omukk.dev/wrenn/wrenn/internal/email"
"git.omukk.dev/wrenn/wrenn/pkg/audit"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/auth/session"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/service"
)
type teamHandler struct {
svc *service.TeamService
audit *audit.AuditLogger
mailer email.Mailer
svc *service.TeamService
audit *audit.AuditLogger
mailer email.Mailer
sessions *session.Service
}
func newTeamHandler(svc *service.TeamService, al *audit.AuditLogger, mailer email.Mailer) *teamHandler {
return &teamHandler{svc: svc, audit: al, mailer: mailer}
func newTeamHandler(svc *service.TeamService, al *audit.AuditLogger, mailer email.Mailer, sessions *session.Service) *teamHandler {
return &teamHandler{svc: svc, audit: al, mailer: mailer, sessions: sessions}
}
// teamResponse is the JSON shape for a team.
@ -366,6 +368,11 @@ func (h *teamHandler) UpdateMemberRole(w http.ResponseWriter, r *http.Request) {
return
}
// Drop cached session blobs so the new role propagates immediately.
if err := h.sessions.InvalidateCacheForUser(r.Context(), targetUserID); err != nil {
_ = err
}
h.audit.LogMemberRoleUpdate(r.Context(), ac, targetUserID, req.Role)
w.WriteHeader(http.StatusNoContent)
}
@ -450,6 +457,8 @@ func (h *teamHandler) AdminListTeams(w http.ResponseWriter, r *http.Request) {
OwnerEmail string `json:"owner_email"`
ActiveSandboxCount int32 `json:"active_sandbox_count"`
ChannelCount int32 `json:"channel_count"`
RunningVcpus int32 `json:"running_vcpus"`
RunningMemoryMb int32 `json:"running_memory_mb"`
}
resp := make([]adminTeamResponse, len(teams))
@ -465,6 +474,8 @@ func (h *teamHandler) AdminListTeams(w http.ResponseWriter, r *http.Request) {
OwnerEmail: t.OwnerEmail,
ActiveSandboxCount: t.ActiveSandboxCount,
ChannelCount: t.ChannelCount,
RunningVcpus: t.RunningVcpus,
RunningMemoryMb: t.RunningMemoryMb,
}
if t.DeletedAt != nil {
s := t.DeletedAt.Format(time.RFC3339)

View File

@ -11,19 +11,21 @@ import (
"git.omukk.dev/wrenn/wrenn/pkg/audit"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/auth/session"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/service"
)
type usersHandler struct {
db *db.Queries
svc *service.UserService
audit *audit.AuditLogger
db *db.Queries
svc *service.UserService
audit *audit.AuditLogger
sessions *session.Service
}
func newUsersHandler(db *db.Queries, svc *service.UserService, al *audit.AuditLogger) *usersHandler {
return &usersHandler{db: db, svc: svc, audit: al}
func newUsersHandler(db *db.Queries, svc *service.UserService, al *audit.AuditLogger, sessions *session.Service) *usersHandler {
return &usersHandler{db: db, svc: svc, audit: al, sessions: sessions}
}
// Search handles GET /v1/users/search?email=<prefix>
@ -158,6 +160,10 @@ func (h *usersHandler) SetUserActive(w http.ResponseWriter, r *http.Request) {
if req.Active {
h.audit.LogUserActivate(r.Context(), ac, userID, user.Email)
} else {
// Disabled users must be kicked out of every active session.
if err := h.sessions.RevokeAllForUser(r.Context(), userID); err != nil {
_ = err
}
h.audit.LogUserDeactivate(r.Context(), ac, userID, user.Email)
}
w.WriteHeader(http.StatusNoContent)
@ -215,5 +221,14 @@ func (h *usersHandler) SetUserAdmin(w http.ResponseWriter, r *http.Request) {
}
h.audit.LogUserRevokeAdmin(r.Context(), ac, userID, user.Email)
}
// Invalidate cached session blobs so the new is_admin flag is reflected
// on the user's next request without waiting for the Redis TTL.
if err := h.sessions.InvalidateCacheForUser(r.Context(), userID); err != nil {
// Cache is best-effort; the DB is authoritative and requireAdmin
// always re-reads it.
_ = err
}
w.WriteHeader(http.StatusNoContent)
}

View File

@ -1,102 +0,0 @@
package api
import (
"context"
"fmt"
"time"
"github.com/gorilla/websocket"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
// ctxKeyAdminWS is a context key for flagging admin WS routes.
type ctxKeyAdminWS struct{}
// setAdminWSFlag marks the context as an admin WebSocket route.
func setAdminWSFlag(ctx context.Context) context.Context {
return context.WithValue(ctx, ctxKeyAdminWS{}, true)
}
// isAdminWSRoute checks if the request context was marked as admin WS.
func isAdminWSRoute(ctx context.Context) bool {
v, _ := ctx.Value(ctxKeyAdminWS{}).(bool)
return v
}
// wsAuthMsg is the first message a browser WS client sends to authenticate.
type wsAuthMsg struct {
Type string `json:"type"`
Token string `json:"token"`
}
// wsAuthenticate reads a JWT auth message from the WebSocket and returns the
// authenticated context. The caller must send this as the first message after
// connecting.
func wsAuthenticate(ctx context.Context, conn *websocket.Conn, jwtSecret []byte, queries *db.Queries) (auth.AuthContext, error) {
_ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
var msg wsAuthMsg
if err := conn.ReadJSON(&msg); err != nil {
return auth.AuthContext{}, fmt.Errorf("read auth message: %w", err)
}
_ = conn.SetReadDeadline(time.Time{}) // clear deadline
if msg.Type != "auth" || msg.Token == "" {
return auth.AuthContext{}, fmt.Errorf("first message must be type 'auth' with a token")
}
claims, err := auth.VerifyJWT(jwtSecret, msg.Token)
if err != nil {
return auth.AuthContext{}, fmt.Errorf("invalid or expired token: %w", err)
}
teamID, err := id.ParseTeamID(claims.TeamID)
if err != nil {
return auth.AuthContext{}, fmt.Errorf("invalid team ID in token: %w", err)
}
userID, err := id.ParseUserID(claims.Subject)
if err != nil {
return auth.AuthContext{}, fmt.Errorf("invalid user ID in token: %w", err)
}
user, err := queries.GetUserByID(ctx, userID)
if err != nil {
return auth.AuthContext{}, fmt.Errorf("user not found")
}
if user.Status != "active" {
return auth.AuthContext{}, fmt.Errorf("account deactivated")
}
return auth.AuthContext{
TeamID: teamID,
UserID: userID,
Email: claims.Email,
Name: claims.Name,
Role: claims.Role,
}, nil
}
// wsAuthenticateAdmin performs WS-based auth and verifies admin status,
// returning an AuthContext with the platform team ID.
func wsAuthenticateAdmin(ctx context.Context, conn *websocket.Conn, jwtSecret []byte, queries *db.Queries) (auth.AuthContext, error) {
ac, err := wsAuthenticate(ctx, conn, jwtSecret, queries)
if err != nil {
return auth.AuthContext{}, err
}
user, err := queries.GetUserByID(ctx, ac.UserID)
if err != nil {
return auth.AuthContext{}, fmt.Errorf("user not found")
}
if !user.IsAdmin {
return auth.AuthContext{}, fmt.Errorf("admin access required")
}
ac.TeamID = id.PlatformTeamID
return ac, nil
}

View File

@ -2,6 +2,7 @@ package api
import (
"context"
"errors"
"log/slog"
"time"
@ -15,10 +16,30 @@ import (
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
)
// errInferredTransientTimeout marks a state change that the reconciler
// inferred after a transient (starting/resuming) sandbox failed to settle
// within the grace period. Used as the err value on system audit calls so
// the published event carries Outcome=error with a human-readable message.
var errInferredTransientTimeout = errors.New("transient state did not settle within grace period")
// unreachableThreshold is how long a host can go without a heartbeat before
// it is considered unreachable (3 missed 30-second heartbeats).
const unreachableThreshold = 90 * time.Second
// transientGracePeriod is how long a sandbox is allowed to stay in a transient
// status (starting, resuming, pausing, stopping) before the monitor infers a
// final state. This prevents the monitor from racing against in-flight RPCs
// that may not have registered the sandbox on the host agent yet.
const transientGracePeriod = 2 * time.Minute
// snapshotGracePeriod is the grace for a sandbox stuck in "snapshotting" while
// the VM is still alive on the host. Snapshots dump guest RAM and flatten the
// rootfs, which can run for minutes on large sandboxes, and the agent reports
// the VM as alive throughout — so we must not race the in-flight operation.
// It exceeds the background goroutine's 10-minute deadline, so reaching it
// means the control plane crashed mid-snapshot and the sandbox needs recovery.
const snapshotGracePeriod = 15 * time.Minute
// HostMonitor runs on a fixed interval and performs two duties:
//
// 1. Passive check: marks hosts whose last_heartbeat_at is stale as
@ -77,6 +98,21 @@ func (m *HostMonitor) run(ctx context.Context) {
}
}
// ReconcileHost triggers immediate active reconciliation for a single host.
// Called when a host transitions from unreachable → online so sandboxes marked
// "missing" are resolved without waiting for the next monitor tick.
func (m *HostMonitor) ReconcileHost(ctx context.Context, hostID pgtype.UUID) {
host, err := m.db.GetHost(ctx, hostID)
if err != nil {
slog.Warn("host monitor: reconcile-on-connect: failed to get host", "error", err)
return
}
if host.Status != "online" {
return
}
m.checkHost(ctx, host)
}
func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) {
// --- Passive phase: check heartbeat staleness ---
@ -116,21 +152,29 @@ func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) {
return
}
// Build set of sandbox IDs alive on the host.
// The host agent returns sandbox IDs as strings (formatted with prefix).
alive := make(map[string]struct{}, len(resp.Msg.Sandboxes))
// Build map of sandbox ID -> reported status. Transient statuses
// (pausing/resuming/starting/stopping) are coerced to a presence-only
// entry: ListSandboxes can observe the in-memory status mid-transition
// (Pause flips the status under m.mu while List holds m.mu.RLock), and
// writing those transient labels into the DB would force the transient
// reconciliation phase to wait the full grace period before resolving.
// Recording the presence keeps "missing → restore" and "running →
// orphan-stop" logic correct without overwriting with stale labels;
// the next monitor tick reads the settled status.
aliveStatus := make(map[string]string, len(resp.Msg.Sandboxes))
for _, sb := range resp.Msg.Sandboxes {
alive[sb.SandboxId] = struct{}{}
}
autoPaused := make(map[string]struct{}, len(resp.Msg.AutoPausedSandboxIds))
for _, apID := range resp.Msg.AutoPausedSandboxIds {
autoPaused[apID] = struct{}{}
status := sb.Status
switch status {
case "pausing", "resuming", "starting", "stopping":
status = ""
}
aliveStatus[sb.SandboxId] = status
}
// --- Restore sandboxes that are "missing" in DB but alive on host ---
// This handles the case where CP marked them missing due to a transient
// heartbeat gap, but the host was actually fine.
// Handles transient heartbeat gaps where the host was actually fine. The
// reported status must be honored: a sandbox the agent paused while CP
// was disconnected must not be silently promoted back to running.
missingSandboxes, err := m.db.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
HostID: host.ID,
@ -139,34 +183,65 @@ func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) {
if err != nil {
slog.Warn("host monitor: failed to list missing sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
} else {
var toRestore []pgtype.UUID
var toStop []pgtype.UUID
restoreByStatus := make(map[string][]db.Sandbox)
var toStop []db.Sandbox
for _, sb := range missingSandboxes {
sbIDStr := id.FormatSandboxID(sb.ID)
if _, ok := alive[sbIDStr]; ok {
toRestore = append(toRestore, sb.ID)
} else {
toStop = append(toStop, sb.ID)
status, ok := aliveStatus[sbIDStr]
if !ok {
toStop = append(toStop, sb)
continue
}
if status == "" {
continue
}
restoreByStatus[status] = append(restoreByStatus[status], sb)
}
if len(toRestore) > 0 {
slog.Info("host monitor: restoring missing sandboxes", "host_id", id.FormatHostID(host.ID), "count", len(toRestore))
if err := m.db.BulkRestoreRunning(ctx, toRestore); err != nil {
slog.Warn("host monitor: failed to restore missing sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
for status, sbs := range restoreByStatus {
ids := make([]pgtype.UUID, len(sbs))
for i, sb := range sbs {
ids[i] = sb.ID
}
slog.Info("host monitor: restoring missing sandboxes", "host_id", id.FormatHostID(host.ID), "status", status, "count", len(ids))
if err := m.db.BulkRestoreMissingToStatus(ctx, db.BulkRestoreMissingToStatusParams{
Column1: ids,
Status: status,
}); err != nil {
slog.Warn("host monitor: failed to restore missing sandboxes", "host_id", id.FormatHostID(host.ID), "status", status, "error", err)
continue
}
// Only restore→paused emits a notification (per design: running restore is silent).
if status == "paused" {
for _, sb := range sbs {
m.audit.LogSandboxAutoPause(ctx, sb.TeamID, sb.ID, "restored_after_host_recovery", nil)
}
}
}
if len(toStop) > 0 {
slog.Info("host monitor: stopping confirmed-dead missing sandboxes", "host_id", id.FormatHostID(host.ID), "count", len(toStop))
ids := make([]pgtype.UUID, len(toStop))
for i, sb := range toStop {
ids[i] = sb.ID
}
slog.Info("host monitor: stopping confirmed-dead missing sandboxes", "host_id", id.FormatHostID(host.ID), "count", len(ids))
if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
Column1: toStop,
Column1: ids,
Status: "stopped",
}); err != nil {
slog.Warn("host monitor: failed to stop missing sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
} else {
for _, sb := range toStop {
m.audit.LogSandboxDestroySystem(ctx, sb.TeamID, sb.ID, "orphaned", nil)
}
}
}
}
// --- Find running sandboxes in DB that are no longer alive on the host ---
// --- Reconcile running sandboxes in DB against live host state ---
// Three cases per DB-running row:
// absent on host -> stopped
// present and running -> no change
// present but paused/etc. -> sync DB to reported status (catches the
// shutdown-pause notify failure case)
runningSandboxes, err := m.db.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
HostID: host.ID,
@ -177,40 +252,196 @@ func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) {
return
}
var toPause, toStop []pgtype.UUID
sbTeamID := make(map[pgtype.UUID]pgtype.UUID, len(runningSandboxes))
var toStop []db.Sandbox
syncByStatus := make(map[string][]db.Sandbox)
for _, sb := range runningSandboxes {
sbIDStr := id.FormatSandboxID(sb.ID)
sbTeamID[sb.ID] = sb.TeamID
if _, ok := alive[sbIDStr]; ok {
status, ok := aliveStatus[sbIDStr]
if !ok {
toStop = append(toStop, sb)
continue
}
if _, ok := autoPaused[sbIDStr]; ok {
toPause = append(toPause, sb.ID)
} else {
toStop = append(toStop, sb.ID)
if status == "running" || status == "" {
continue
}
syncByStatus[status] = append(syncByStatus[status], sb)
}
if len(toPause) > 0 {
slog.Info("host monitor: marking auto-paused sandboxes", "host_id", id.FormatHostID(host.ID), "count", len(toPause))
if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
Column1: toPause,
Status: "paused",
}); err != nil {
slog.Warn("host monitor: failed to mark paused", "host_id", id.FormatHostID(host.ID), "error", err)
}
for _, sbID := range toPause {
m.audit.LogSandboxAutoPause(ctx, sbTeamID[sbID], sbID)
}
}
if len(toStop) > 0 {
slog.Info("host monitor: marking orphaned sandboxes stopped", "host_id", id.FormatHostID(host.ID), "count", len(toStop))
ids := make([]pgtype.UUID, len(toStop))
for i, sb := range toStop {
ids[i] = sb.ID
}
slog.Info("host monitor: marking orphaned sandboxes stopped", "host_id", id.FormatHostID(host.ID), "count", len(ids))
if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
Column1: toStop,
Column1: ids,
Status: "stopped",
}); err != nil {
slog.Warn("host monitor: failed to mark stopped", "host_id", id.FormatHostID(host.ID), "error", err)
} else {
for _, sb := range toStop {
m.audit.LogSandboxDestroySystem(ctx, sb.TeamID, sb.ID, "orphaned", nil)
}
}
}
for status, sbs := range syncByStatus {
ids := make([]pgtype.UUID, len(sbs))
for i, sb := range sbs {
ids[i] = sb.ID
}
slog.Info("host monitor: syncing running→reported status", "host_id", id.FormatHostID(host.ID), "status", status, "count", len(ids))
if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
Column1: ids,
Status: status,
}); err != nil {
slog.Warn("host monitor: failed to sync running sandboxes", "host_id", id.FormatHostID(host.ID), "status", status, "error", err)
continue
}
if status == "paused" {
for _, sb := range sbs {
m.audit.LogSandboxAutoPause(ctx, sb.TeamID, sb.ID, "host_state_sync", nil)
}
}
}
// --- Reconcile DB-stopped + agent-paused zombies ---
// A sandbox the agent reports as 'paused' but DB has as 'stopped' is an
// orphan from a previous bug where a successful pause's auto_paused
// callback was lost (e.g. CP unreachable during agent shutdown). With the
// agent-side fix (RestorePausedSandboxes), the snapshot survives across
// agent restarts and surfaces here. Authoritative direction: DB wins
// (user already saw 'stopped' and may have stopped tracking it).
// Issue Destroy so the on-disk snapshot dir is removed and the agent's
// slot reservation released.
//
// Gate: only run the DB query if the agent reports at least one paused
// sandbox. Otherwise we'd fetch every historically-stopped sandbox on
// this host every monitor tick — unbounded growth over a host's lifetime.
hasPaused := false
for _, status := range aliveStatus {
if status == "paused" {
hasPaused = true
break
}
}
if hasPaused {
stoppedSandboxes, err := m.db.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
HostID: host.ID,
Column2: []string{"stopped"},
})
if err != nil {
slog.Warn("host monitor: failed to list stopped sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
} else {
for _, sb := range stoppedSandboxes {
sbIDStr := id.FormatSandboxID(sb.ID)
status, ok := aliveStatus[sbIDStr]
if !ok || status != "paused" {
continue
}
slog.Info("host monitor: destroying DB-stopped agent-paused zombie",
"host_id", id.FormatHostID(host.ID), "sandbox_id", sbIDStr)
if _, err := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
SandboxId: sbIDStr,
})); err != nil && connect.CodeOf(err) != connect.CodeNotFound {
slog.Warn("host monitor: zombie destroy failed",
"sandbox_id", sbIDStr, "error", err)
continue
}
m.audit.LogSandboxDestroySystem(ctx, sb.TeamID, sb.ID, "paused_zombie_cleanup", nil)
}
}
}
// --- Reconcile transient statuses (starting, resuming, pausing, stopping) ---
// These represent in-flight operations. If the sandbox is no longer alive on
// the host, infer the final state based on the transient status.
transientSandboxes, err := m.db.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
HostID: host.ID,
Column2: []string{"starting", "resuming", "pausing", "stopping", "snapshotting"},
})
if err != nil {
slog.Warn("host monitor: failed to list transient sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
return
}
for _, sb := range transientSandboxes {
sbIDStr := id.FormatSandboxID(sb.ID)
if agentStatus, ok := aliveStatus[sbIDStr]; ok {
// Sandbox is alive on host — the background goroutine should
// finalize the transition. For starting/resuming, if the sandbox
// is alive it means creation/resume succeeded.
if sb.Status == "starting" || sb.Status == "resuming" {
if _, err := m.db.UpdateSandboxStatusIf(ctx, db.UpdateSandboxStatusIfParams{
ID: sb.ID, Status: sb.Status, Status_2: "running",
}); err == nil {
slog.Info("host monitor: promoted transient sandbox to running", "sandbox_id", sbIDStr, "from", sb.Status)
}
}
// A snapshot keeps the source sandbox alive throughout, so an alive
// sandbox does NOT mean the snapshot finished. Only recover it once
// it has been stuck past the snapshot grace period (i.e. the CP
// crashed mid-op). Recover to the sandbox's actual host-side status:
// a running sandbox is snapshotted live and stays running, but a
// paused sandbox is snapshotted from disk and must return to paused.
if sb.Status == "snapshotting" &&
sb.LastUpdated.Valid && time.Since(sb.LastUpdated.Time) >= snapshotGracePeriod {
recoverTo := agentStatus
if recoverTo != "running" && recoverTo != "paused" {
// Coerced/unknown agent label — default to running.
recoverTo = "running"
}
if _, err := m.db.UpdateSandboxStatusIf(ctx, db.UpdateSandboxStatusIfParams{
ID: sb.ID, Status: "snapshotting", Status_2: recoverTo,
}); err == nil {
slog.Info("host monitor: recovered stuck snapshotting sandbox", "sandbox_id", sbIDStr, "to", recoverTo)
m.audit.LogSnapshotCreateSystem(ctx, sb.TeamID, sb.ID, "snapshot_recovered", nil)
}
}
continue
}
// Sandbox is not alive on host. If the transition is recent, give the
// in-flight RPC time to finish before declaring a final state.
if sb.LastUpdated.Valid && time.Since(sb.LastUpdated.Time) < transientGracePeriod {
slog.Debug("host monitor: transient sandbox still within grace period",
"sandbox_id", sbIDStr, "status", sb.Status,
"age", time.Since(sb.LastUpdated.Time).Round(time.Second))
continue
}
// Grace period expired — infer final state.
var finalStatus string
switch sb.Status {
case "starting", "resuming":
finalStatus = "error"
case "pausing":
finalStatus = "paused"
case "stopping":
finalStatus = "stopped"
case "snapshotting":
// VM is gone but DB says snapshotting → the snapshot died with the VM.
finalStatus = "error"
}
fromStatus := sb.Status
if _, err := m.db.UpdateSandboxStatusIf(ctx, db.UpdateSandboxStatusIfParams{
ID: sb.ID, Status: fromStatus, Status_2: finalStatus,
}); err == nil {
slog.Info("host monitor: resolved transient sandbox", "sandbox_id", sbIDStr, "from", fromStatus, "to", finalStatus)
inferredErr := errInferredTransientTimeout
switch fromStatus {
case "starting":
m.audit.LogSandboxCreateSystem(ctx, sb.TeamID, sb.ID, "transient_timeout", inferredErr)
case "resuming":
m.audit.LogSandboxResumeSystem(ctx, sb.TeamID, sb.ID, "transient_timeout", inferredErr)
case "pausing":
// Pause assumed to have succeeded host-side; emit success with inferred metadata.
m.audit.LogSandboxAutoPause(ctx, sb.TeamID, sb.ID, "transient_timeout_inferred", nil)
case "snapshotting":
// VM gone mid-snapshot; the sandbox is errored.
m.audit.LogSnapshotCreateSystem(ctx, sb.TeamID, sb.ID, "transient_timeout", inferredErr)
case "stopping":
m.audit.LogSandboxDestroySystem(ctx, sb.TeamID, sb.ID, "transient_timeout_inferred", nil)
}
}
}
}

View File

@ -50,7 +50,9 @@ func agentErrToHTTP(err error) (int, string, string) {
return http.StatusNotFound, "not_found", err.Error()
case connect.CodeInvalidArgument:
return http.StatusBadRequest, "invalid_request", err.Error()
case connect.CodeFailedPrecondition, connect.CodeAlreadyExists:
case connect.CodeAlreadyExists:
return http.StatusConflict, "already_exists", err.Error()
case connect.CodeFailedPrecondition:
return http.StatusConflict, "conflict", err.Error()
case connect.CodePermissionDenied:
return http.StatusForbidden, "forbidden", err.Error()

View File

@ -26,19 +26,10 @@ func injectPlatformTeam() func(http.Handler) http.Handler {
}
}
// markAdminWS flags the request context as an admin WebSocket route.
// Applied to admin WS endpoints that sit outside the requireJWT/requireAdmin
// middleware group. Handlers use isAdminWSRoute(ctx) to pick wsAuthenticateAdmin.
func markAdminWS(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
next.ServeHTTP(w, r.WithContext(setAdminWSFlag(r.Context())))
})
}
// requireAdmin validates that the authenticated user is a platform admin.
// Must run after requireJWT (depends on AuthContext being present).
// Re-validates against the DB — the JWT is_admin claim is for UI only;
// the DB is the source of truth for admin access.
// Must run after requireSession (depends on AuthContext being present).
// Re-validates against the DB — the session blob's is_admin flag is for
// UI hints; the DB is the source of truth for admin access.
func requireAdmin(queries *db.Queries) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

View File

@ -1,145 +0,0 @@
package api
import (
"log/slog"
"net/http"
"strings"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
// requireAPIKeyOrJWT accepts either X-API-Key header or Authorization: Bearer JWT.
// Both stamp TeamID into the request context via auth.AuthContext.
func requireAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Try API key first.
if key := r.Header.Get("X-API-Key"); key != "" {
hash := auth.HashAPIKey(key)
row, err := queries.GetAPIKeyByHash(r.Context(), hash)
if err != nil {
slog.Warn("api key auth failed", "prefix", auth.APIKeyPrefix(key), "ip", r.RemoteAddr)
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid API key")
return
}
if err := queries.UpdateAPIKeyLastUsed(r.Context(), row.ID); err != nil {
slog.Warn("failed to update api key last_used", "key_id", id.FormatAPIKeyID(row.ID), "error", err)
}
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
TeamID: row.TeamID,
APIKeyID: row.ID,
APIKeyName: row.Name,
})
next.ServeHTTP(w, r.WithContext(ctx))
return
}
// Try JWT bearer token from Authorization header.
tokenStr := ""
if header := r.Header.Get("Authorization"); strings.HasPrefix(header, "Bearer ") {
tokenStr = strings.TrimPrefix(header, "Bearer ")
}
if tokenStr != "" {
claims, err := auth.VerifyJWT(jwtSecret, tokenStr)
if err != nil {
slog.Warn("jwt auth failed", "error", err, "ip", r.RemoteAddr)
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid or expired token")
return
}
teamID, err := id.ParseTeamID(claims.TeamID)
if err != nil {
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid team ID in token")
return
}
userID, err := id.ParseUserID(claims.Subject)
if err != nil {
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid user ID in token")
return
}
// Verify user is still active in the database.
user, err := queries.GetUserByID(r.Context(), userID)
if err != nil {
slog.Warn("jwt auth: failed to look up user", "user_id", claims.Subject, "error", err)
writeError(w, http.StatusUnauthorized, "unauthorized", "user not found")
return
}
if user.Status != "active" {
writeError(w, http.StatusForbidden, "account_deactivated", "your account has been deactivated — contact your administrator to regain access")
return
}
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
TeamID: teamID,
UserID: userID,
Email: claims.Email,
Name: claims.Name,
Role: claims.Role,
})
next.ServeHTTP(w, r.WithContext(ctx))
return
}
writeError(w, http.StatusUnauthorized, "unauthorized", "X-API-Key or Authorization: Bearer <token> required")
})
}
}
// optionalAPIKeyOrJWT is like requireAPIKeyOrJWT but does not reject
// unauthenticated requests. It injects auth context when valid credentials
// are present (supporting SDK clients that set X-API-Key on WebSocket
// upgrades) and passes through otherwise so the handler can authenticate
// after the WebSocket upgrade via the first message.
func optionalAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Try API key.
if key := r.Header.Get("X-API-Key"); key != "" {
hash := auth.HashAPIKey(key)
row, err := queries.GetAPIKeyByHash(r.Context(), hash)
if err == nil {
if err := queries.UpdateAPIKeyLastUsed(r.Context(), row.ID); err != nil {
slog.Warn("failed to update api key last_used", "key_id", id.FormatAPIKeyID(row.ID), "error", err)
}
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
TeamID: row.TeamID,
APIKeyID: row.ID,
APIKeyName: row.Name,
})
next.ServeHTTP(w, r.WithContext(ctx))
return
}
}
// Try JWT bearer token.
if header := r.Header.Get("Authorization"); strings.HasPrefix(header, "Bearer ") {
tokenStr := strings.TrimPrefix(header, "Bearer ")
if claims, err := auth.VerifyJWT(jwtSecret, tokenStr); err == nil {
if teamID, err := id.ParseTeamID(claims.TeamID); err == nil {
if userID, err := id.ParseUserID(claims.Subject); err == nil {
if user, err := queries.GetUserByID(r.Context(), userID); err == nil && user.Status == "active" {
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
TeamID: teamID,
UserID: userID,
Email: claims.Email,
Name: claims.Name,
Role: claims.Role,
})
next.ServeHTTP(w, r.WithContext(ctx))
return
}
}
}
}
}
// No valid credentials — pass through for handler to authenticate.
next.ServeHTTP(w, r)
})
}
}

View File

@ -0,0 +1,38 @@
package api
import (
"crypto/subtle"
"net/http"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
)
// requireCSRF enforces double-submit CSRF: the wrenn_csrf cookie value must
// equal the X-CSRF-Token header. Skipped for safe methods (GET/HEAD/OPTIONS)
// and for requests authenticated via X-API-Key (SDK clients are not
// vulnerable to cross-site request forgery against cookie auth).
func requireCSRF() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet, http.MethodHead, http.MethodOptions:
next.ServeHTTP(w, r)
return
}
// API-key auth path: no CSRF check needed.
if ac, ok := auth.FromContext(r.Context()); ok && ac.APIKeyID.Valid {
next.ServeHTTP(w, r)
return
}
cookie, err := r.Cookie(csrfCookieName)
header := r.Header.Get("X-CSRF-Token")
if err != nil || cookie.Value == "" || header == "" ||
subtle.ConstantTimeCompare([]byte(cookie.Value), []byte(header)) != 1 {
writeError(w, http.StatusForbidden, "csrf_failed", "missing or invalid CSRF token")
return
}
next.ServeHTTP(w, r)
})
}
}

View File

@ -1,69 +0,0 @@
package api
import (
"log/slog"
"net/http"
"strings"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
// requireJWT validates a JWT from the Authorization: Bearer header.
// It also verifies the user is still active in the database.
// WebSocket upgrade requests without an Authorization header are passed through
// — WS handlers authenticate via the first message after upgrade.
func requireJWT(secret []byte, queries *db.Queries) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var tokenStr string
if header := r.Header.Get("Authorization"); strings.HasPrefix(header, "Bearer ") {
tokenStr = strings.TrimPrefix(header, "Bearer ")
}
if tokenStr == "" {
writeError(w, http.StatusUnauthorized, "unauthorized", "Authorization: Bearer <token> required")
return
}
claims, err := auth.VerifyJWT(secret, tokenStr)
if err != nil {
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid or expired token")
return
}
teamID, err := id.ParseTeamID(claims.TeamID)
if err != nil {
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid team ID in token")
return
}
userID, err := id.ParseUserID(claims.Subject)
if err != nil {
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid user ID in token")
return
}
// Verify user is still active in the database.
user, err := queries.GetUserByID(r.Context(), userID)
if err != nil {
slog.Warn("jwt auth: failed to look up user", "user_id", claims.Subject, "error", err)
writeError(w, http.StatusUnauthorized, "unauthorized", "user not found")
return
}
if user.Status != "active" {
writeError(w, http.StatusForbidden, "account_deactivated", "your account has been deactivated — contact your administrator to regain access")
return
}
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
TeamID: teamID,
UserID: userID,
Email: claims.Email,
Name: claims.Name,
Role: claims.Role,
IsAdmin: claims.IsAdmin,
})
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}

View File

@ -0,0 +1,34 @@
package api
import (
"net/http"
"git.omukk.dev/wrenn/wrenn/pkg/auth/session"
sessionmw "git.omukk.dev/wrenn/wrenn/pkg/auth/session/middleware"
"git.omukk.dev/wrenn/wrenn/pkg/db"
)
// Internal aliases — the canonical implementations live in the public
// pkg/auth/session/middleware package so cloud extensions can call them.
const csrfCookieName = sessionmw.CSRFCookieName
func requireSession(queries *db.Queries, svc *session.Service) func(http.Handler) http.Handler {
return sessionmw.RequireSession(svc, queries)
}
func requireSessionOrAPIKey(queries *db.Queries, svc *session.Service) func(http.Handler) http.Handler {
return sessionmw.RequireSessionOrAPIKey(svc, queries)
}
func setSessionCookies(w http.ResponseWriter, sid, csrfToken string, secure bool) {
sessionmw.SetCookies(w, sid, csrfToken, secure)
}
func clearSessionCookies(w http.ResponseWriter, secure bool) {
sessionmw.ClearCookies(w, secure)
}
func isSecure(r *http.Request) bool {
return sessionmw.IsSecure(r)
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,310 @@
package api
import (
"context"
"encoding/json"
"errors"
"log/slog"
"strings"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/redis/go-redis/v9"
"git.omukk.dev/wrenn/wrenn/pkg/audit"
"git.omukk.dev/wrenn/wrenn/pkg/cpextension"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/events"
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
const (
unifiedEventStream = "wrenn:events"
reconcilerConsumerGrp = "wrenn-sandbox-reconciler-v1"
reconcilerConsumer = "cp-0"
)
// SandboxEventConsumer reads capsule lifecycle events from the unified Redis
// stream and drives DB state reconciliation. Uses an independent consumer
// group so its cursor is separate from the channels dispatcher.
type SandboxEventConsumer struct {
rdb *redis.Client
db *db.Queries
audit *audit.AuditLogger
hooks []cpextension.SandboxEventHook
}
// NewSandboxEventConsumer creates a consumer.
func NewSandboxEventConsumer(rdb *redis.Client, queries *db.Queries, al *audit.AuditLogger, hooks []cpextension.SandboxEventHook) *SandboxEventConsumer {
return &SandboxEventConsumer{rdb: rdb, db: queries, audit: al, hooks: hooks}
}
// Start launches the consumer goroutine. Reads from "$" so prior history
// is not replayed.
func (c *SandboxEventConsumer) Start(ctx context.Context) {
go c.run(ctx)
}
func (c *SandboxEventConsumer) run(ctx context.Context) {
err := c.rdb.XGroupCreateMkStream(ctx, unifiedEventStream, reconcilerConsumerGrp, "$").Err()
if err != nil && err.Error() != "BUSYGROUP Consumer Group name already exists" {
slog.Error("sandbox event consumer: failed to create consumer group", "error", err)
return
}
for {
select {
case <-ctx.Done():
return
default:
}
streams, err := c.rdb.XReadGroup(ctx, &redis.XReadGroupArgs{
Group: reconcilerConsumerGrp,
Consumer: reconcilerConsumer,
Streams: []string{unifiedEventStream, ">"},
Count: 10,
Block: 5 * time.Second,
}).Result()
if err != nil {
if err == redis.Nil || ctx.Err() != nil {
continue
}
slog.Warn("sandbox event consumer: xreadgroup error", "error", err)
time.Sleep(1 * time.Second)
continue
}
for _, stream := range streams {
for _, msg := range stream.Messages {
c.handleMessage(ctx, msg)
}
}
}
}
func (c *SandboxEventConsumer) handleMessage(ctx context.Context, msg redis.XMessage) {
ack := true
defer func() {
if !ack {
return
}
ackCtx, ackCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer ackCancel()
if err := c.rdb.XAck(ackCtx, unifiedEventStream, reconcilerConsumerGrp, msg.ID).Err(); err != nil {
slog.Warn("sandbox event consumer: xack failed", "id", msg.ID, "error", err)
}
}()
payload, ok := msg.Values["payload"].(string)
if !ok {
slog.Warn("sandbox event consumer: message missing payload", "id", msg.ID)
return
}
var event events.Event
if err := json.Unmarshal([]byte(payload), &event); err != nil {
slog.Warn("sandbox event consumer: failed to unmarshal event", "id", msg.ID, "error", err)
return
}
// Only capsule.* events drive sandbox reconciliation.
if !strings.HasPrefix(event.Event, "capsule.") || event.Event == events.CapsuleStateChanged {
return
}
// Only system-actor events represent host-side state we need to reflect
// in the DB; user-actor events are already mirrored by the handler that
// produced them.
if event.Actor.Type != events.ActorSystem {
// Exception: handlers publish capsule.create with user actor before
// the host has reported back. Those are owned by the service goroutine.
return
}
sandboxID, err := id.ParseSandboxID(event.Resource.ID)
if err != nil {
slog.Warn("sandbox event consumer: invalid sandbox ID", "sandbox_id", event.Resource.ID, "error", err)
return
}
switch event.Event {
case events.CapsuleCreate:
if event.Outcome == events.OutcomeSuccess {
c.handleStarted(ctx, sandboxID, event, "starting")
} else {
c.handleFailed(ctx, sandboxID, event)
}
case events.CapsuleResume:
if event.Outcome == events.OutcomeSuccess {
c.handleStarted(ctx, sandboxID, event, "resuming")
} else {
c.handleFailed(ctx, sandboxID, event)
}
case events.CapsulePause:
if event.Outcome == events.OutcomeSuccess {
c.handleAutoPaused(ctx, sandboxID)
}
case events.CapsuleDestroy:
if event.Outcome == events.OutcomeSuccess {
c.handleStopped(ctx, sandboxID)
}
}
// Dispatch to extension hooks (cloud billing, audit shipping, etc.). Any
// hook error suppresses the ack so the message will be redelivered. Hooks
// MUST be idempotent — duplicate deliveries are expected on transient
// failures.
if len(c.hooks) > 0 && event.Outcome == events.OutcomeSuccess {
if verb, ok := canonicalSandboxVerb(event.Event); ok {
teamID, _ := id.ParseTeamID(event.TeamID)
meta := map[string]any{}
for k, v := range event.Metadata {
meta[k] = v
}
ev := cpextension.SandboxEvent{
SandboxID: sandboxID,
TeamID: teamID,
Type: verb,
OccurredAt: parseEventTimestamp(event.Timestamp),
Metadata: meta,
}
for _, h := range c.hooks {
if err := h.OnSandboxEvent(ctx, ev); err != nil {
slog.Warn("sandbox event hook failed; leaving message un-acked", "id", msg.ID, "event", event.Event, "error", err)
ack = false
return
}
}
}
}
}
func canonicalSandboxVerb(event string) (string, bool) {
switch event {
case events.CapsuleCreate:
return "created", true
case events.CapsuleResume:
return "resumed", true
case events.CapsulePause:
return "paused", true
case events.CapsuleDestroy:
return "destroyed", true
}
return "", false
}
func parseEventTimestamp(s string) time.Time {
if s == "" {
return time.Now().UTC()
}
t, err := time.Parse(time.RFC3339, s)
if err != nil {
return time.Now().UTC()
}
return t
}
// handleStarted is a fallback writer for capsule.create.success and
// capsule.resume.success. The background goroutine in SandboxService is the
// primary writer; this only succeeds if the goroutine's conditional update
// was missed.
func (c *SandboxEventConsumer) handleStarted(ctx context.Context, sandboxID pgtype.UUID, event events.Event, fromStatus string) {
hostIP := event.Metadata["host_ip"]
now := time.Now()
if _, err := c.db.UpdateSandboxRunningIf(ctx, db.UpdateSandboxRunningIfParams{
ID: sandboxID,
Status: fromStatus,
HostIp: hostIP,
StartedAt: pgtype.Timestamptz{
Time: now,
Valid: true,
},
}); err != nil {
return
}
}
func (c *SandboxEventConsumer) handleAutoPaused(ctx context.Context, sandboxID pgtype.UUID) {
for _, fromStatus := range []string{"running", "pausing"} {
if _, err := c.db.UpdateSandboxStatusIf(ctx, db.UpdateSandboxStatusIfParams{
ID: sandboxID, Status: fromStatus, Status_2: "paused",
}); err == nil {
slog.Debug("sandbox event consumer: auto-paused fallback applied", "sandbox_id", id.FormatSandboxID(sandboxID), "from", fromStatus)
return
}
}
}
func (c *SandboxEventConsumer) handleStopped(ctx context.Context, sandboxID pgtype.UUID) {
// stopping → stopped (CP-initiated destroy completed). No audit row here;
// the handler that issued the destroy already wrote one.
if _, err := c.db.UpdateSandboxStatusIf(ctx, db.UpdateSandboxStatusIfParams{
ID: sandboxID,
Status: "stopping",
Status_2: "stopped",
}); err == nil {
return
}
// running → stopped (autonomous destroy, e.g. TTL destroy fallback).
if _, err := c.db.UpdateSandboxStatusIf(ctx, db.UpdateSandboxStatusIfParams{
ID: sandboxID,
Status: "running",
Status_2: "stopped",
}); err != nil && !errors.Is(err, pgx.ErrNoRows) {
slog.Warn("sandbox event consumer: failed to update sandbox to stopped", "sandbox_id", id.FormatSandboxID(sandboxID), "error", err)
}
}
// handleFailed marks a sandbox as "error" when a verb event reports failure
// and writes a system audit row. The DB update is idempotent — the
// SandboxService background goroutine usually wrote "error" already on the
// fast-fail path, which settles in seconds and so never reaches the
// HostMonitor's transient-timeout reconciliation.
//
// audit.Log writes the row only — it does NOT republish an event, which would
// loop back into this consumer. Do not switch to LogSandboxCreateSystem here.
func (c *SandboxEventConsumer) handleFailed(ctx context.Context, sandboxID pgtype.UUID, event events.Event) {
for _, fromStatus := range []string{"running", "starting", "pausing", "resuming", "snapshotting"} {
if _, err := c.db.UpdateSandboxStatusIf(ctx, db.UpdateSandboxStatusIfParams{
ID: sandboxID, Status: fromStatus, Status_2: "error",
}); err == nil {
break
}
}
// The HostMonitor transient-timeout reconciler emits failure events via
// LogSandboxCreateSystem / LogSandboxResumeSystem, which already write
// their own audit row before publishing — auditing again here would
// double-count. Those helpers publish with reason="transient_timeout";
// the un-audited fast-fail (createInBackground) and host-callback paths
// do not, so only they need a row written here.
if event.Metadata["reason"] == "transient_timeout" {
return
}
action := "create"
if event.Event == events.CapsuleResume {
action = "resume"
}
reason := event.Metadata["reason"]
if reason == "" {
reason = action + "_failed"
}
meta := map[string]any{"reason": reason}
if event.Error != "" {
meta["error"] = event.Error
}
teamID, _ := id.ParseTeamID(event.TeamID)
c.audit.Log(ctx, audit.Entry{
TeamID: teamID,
ActorType: "system",
ResourceType: "sandbox",
ResourceID: id.FormatSandboxID(sandboxID),
Action: action,
Scope: "team",
Status: "error",
Metadata: meta,
})
}

View File

@ -1,6 +1,7 @@
package api
import (
"context"
_ "embed"
"fmt"
"net/http"
@ -13,9 +14,11 @@ import (
"git.omukk.dev/wrenn/wrenn/pkg/audit"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/auth/oauth"
authsession "git.omukk.dev/wrenn/wrenn/pkg/auth/session"
"git.omukk.dev/wrenn/wrenn/pkg/channels"
"git.omukk.dev/wrenn/wrenn/pkg/cpextension"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/events"
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
"git.omukk.dev/wrenn/wrenn/pkg/scheduler"
"git.omukk.dev/wrenn/wrenn/pkg/service"
@ -26,15 +29,19 @@ var openapiYAML []byte
// Server is the control plane HTTP server.
type Server struct {
router chi.Router
BuildSvc *service.BuildService
version string
router chi.Router
BuildSvc *service.BuildService
SSERelay *SSERelay
SessionSvc *authsession.Service
version string
}
// New constructs the chi router and registers all routes.
// Extensions are called after core routes are registered, allowing cloud
// or third-party code to add routes and middleware.
// New constructs the chi router and registers all routes. The jwtSecret is
// still used to sign host-agent JWTs (long-lived, for the wrenn-agent → cp
// trust path) and to HMAC OAuth state/link cookies; user authentication has
// migrated to opaque session cookies backed by the session service.
func New(
ctx context.Context,
queries *db.Queries,
pool *lifecycle.HostClientPool,
sched scheduler.HostScheduler,
@ -45,10 +52,12 @@ func New(
oauthRedirectURL string,
ca *auth.CA,
al *audit.AuditLogger,
eventPub *channels.Publisher,
channelSvc *channels.Service,
mailer email.Mailer,
extensions []cpextension.Extension,
sctx cpextension.ServerContext,
monitor *HostMonitor,
version string,
) *Server {
r := chi.NewRouter()
@ -61,8 +70,28 @@ func New(
}
}
// Session service backs cookie-based browser auth. The cpserver wires it
// through ServerContext so cloud-repo extensions can share the same
// instance; fall back to constructing one here if the host program
// instantiates the API directly (tests, ad-hoc tooling).
sessionSvc := sctx.Sessions
if sessionSvc == nil {
sessionSvc = authsession.NewService(queries, rdb)
}
// Shared service layer.
sandboxSvc := &service.SandboxService{DB: queries, Pool: pool, Scheduler: sched}
sandboxSvc.PublishEvent = func(ctx context.Context, event service.SandboxStateEvent) {
if evt, ok := serviceEventToCanonical(event); ok {
// State-change events are ephemeral UI signals — mirror them to the
// dashboard via Pub/Sub only, never to durable channel subscribers.
if evt.Event == events.CapsuleStateChanged {
eventPub.PublishTransient(ctx, evt)
} else {
eventPub.Publish(ctx, evt)
}
}
}
apiKeySvc := &service.APIKeyService{DB: queries}
templateSvc := &service.TemplateService{DB: queries}
hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret, Pool: pool, CA: ca}
@ -72,30 +101,40 @@ func New(
statsSvc := &service.StatsService{DB: queries, Pool: pgPool}
usageSvc := &service.UsageService{DB: queries}
buildSvc := &service.BuildService{DB: queries, Redis: rdb, Pool: pool, Scheduler: sched}
buildBroker := service.NewBuildBroker(rdb)
sandbox := newSandboxHandler(sandboxSvc, al)
exec := newExecHandler(queries, pool)
execStream := newExecStreamHandler(queries, pool, jwtSecret)
execStream := newExecStreamHandler(queries, pool)
files := newFilesHandler(queries, pool)
filesStream := newFilesStreamHandler(queries, pool)
fsH := newFSHandler(queries, pool)
snapshots := newSnapshotHandler(templateSvc, queries, pool, al)
authH := newAuthHandler(queries, pgPool, jwtSecret, mailer, rdb, oauthRedirectURL)
oauthH := newOAuthHandler(queries, pgPool, jwtSecret, oauthRegistry, oauthRedirectURL)
snapshots := newSnapshotHandler(templateSvc, sandboxSvc, queries, pool, al)
authHooks := collectAuthHooks(extensions)
authH := newAuthHandler(queries, pgPool, sessionSvc, mailer, rdb, oauthRedirectURL, authHooks)
oauthH := newOAuthHandler(queries, pgPool, jwtSecret, sessionSvc, oauthRegistry, oauthRedirectURL, authHooks)
apiKeys := newAPIKeyHandler(apiKeySvc, al)
hostH := newHostHandler(hostSvc, queries, al)
teamH := newTeamHandler(teamSvc, al, mailer)
usersH := newUsersHandler(queries, userSvc, al)
hostH := newHostHandler(hostSvc, queries, al, monitor)
teamH := newTeamHandler(teamSvc, al, mailer, sessionSvc)
usersH := newUsersHandler(queries, userSvc, al, sessionSvc)
auditH := newAuditHandler(auditSvc)
statsH := newStatsHandler(statsSvc)
usageH := newUsageHandler(usageSvc)
metricsH := newSandboxMetricsHandler(queries, pool)
buildH := newBuildHandler(buildSvc, queries, pool, al)
buildStreamH := newBuildStreamHandler(queries, buildBroker)
channelH := newChannelHandler(channelSvc, al)
ptyH := newPtyHandler(queries, pool, jwtSecret)
processH := newProcessHandler(queries, pool, jwtSecret)
ptyH := newPtyHandler(queries, pool)
processH := newProcessHandler(queries, pool)
adminCapsules := newAdminCapsuleHandler(sandboxSvc, queries, pool, al)
meH := newMeHandler(queries, pgPool, rdb, jwtSecret, mailer, oauthRegistry, oauthRedirectURL, teamSvc)
sandboxEvtH := newSandboxEventHandler(queries, eventPub)
meH := newMeHandler(queries, pgPool, rdb, jwtSecret, sessionSvc, mailer, oauthRegistry, oauthRedirectURL, teamSvc, authHooks)
sessionsH := newSessionsHandler(sessionSvc)
// SSE real-time event streaming.
sseBroker := NewSSEBroker()
sseRelay := NewSSERelay(rdb, queries, sseBroker)
sseH := newSSEHandler(sseBroker)
// Health check.
r.Get("/health", func(w http.ResponseWriter, r *http.Request) {
@ -107,7 +146,8 @@ func New(
r.Get("/openapi.yaml", serveOpenAPI)
r.Get("/docs", serveDocs)
// Unauthenticated auth endpoints.
// Unauthenticated auth endpoints. CSRF is not required here — there is
// no session cookie yet to be forged.
r.Post("/v1/auth/signup", authH.Signup)
r.Post("/v1/auth/login", authH.Login)
r.Post("/v1/auth/activate", authH.Activate)
@ -118,31 +158,45 @@ func New(
r.Post("/v1/me/password/reset", meH.RequestPasswordReset)
r.Post("/v1/me/password/reset/confirm", meH.ConfirmPasswordReset)
// JWT-authenticated: self-service account management.
csrf := requireCSRF()
// Session-authenticated: logout endpoints (must be inside the session
// group so the handler sees the current SID via AuthContext).
r.Group(func(r chi.Router) {
r.Use(requireSession(queries, sessionSvc))
r.Use(csrf)
r.Post("/v1/auth/logout", authH.Logout)
r.Post("/v1/auth/logout-all", authH.LogoutAll)
r.Post("/v1/auth/switch-team", authH.SwitchTeam)
})
// Session-authenticated: self-service account management.
r.Route("/v1/me", func(r chi.Router) {
r.Use(requireJWT(jwtSecret, queries))
r.Use(requireSession(queries, sessionSvc))
r.Use(csrf)
r.Get("/", meH.GetMe)
r.Patch("/", meH.UpdateName)
r.Post("/password", meH.ChangePassword)
r.Get("/providers/{provider}/connect", meH.ConnectProvider)
r.Delete("/providers/{provider}", meH.DisconnectProvider)
r.Delete("/", meH.DeleteAccount)
r.Get("/sessions", sessionsH.List)
r.Delete("/sessions/{id}", sessionsH.Delete)
})
// JWT-authenticated: switch active team.
r.With(requireJWT(jwtSecret, queries)).Post("/v1/auth/switch-team", authH.SwitchTeam)
// JWT-authenticated: API key management.
// Session-authenticated: API key management.
r.Route("/v1/api-keys", func(r chi.Router) {
r.Use(requireJWT(jwtSecret, queries))
r.Use(requireSession(queries, sessionSvc))
r.Use(csrf)
r.Post("/", apiKeys.Create)
r.Get("/", apiKeys.List)
r.Delete("/{id}", apiKeys.Delete)
})
// JWT-authenticated: team management.
// Session-authenticated: team management.
r.Route("/v1/teams", func(r chi.Router) {
r.Use(requireJWT(jwtSecret, queries))
r.Use(requireSession(queries, sessionSvc))
r.Use(csrf)
r.Get("/", teamH.List)
r.Post("/", teamH.Create)
r.Route("/{id}", func(r chi.Router) {
@ -157,14 +211,15 @@ func New(
})
})
// JWT-authenticated: user search (for add-member UI).
r.With(requireJWT(jwtSecret, queries)).Get("/v1/users/search", usersH.Search)
// Session-authenticated: user search (for add-member UI).
r.With(requireSession(queries, sessionSvc), csrf).Get("/v1/users/search", usersH.Search)
// Capsule lifecycle: accepts API key or JWT bearer token.
// Capsule lifecycle: API key (SDK) or session cookie (browser).
r.Route("/v1/capsules", func(r chi.Router) {
// Auth-required routes.
r.Group(func(r chi.Router) {
r.Use(requireAPIKeyOrJWT(queries, jwtSecret))
r.Use(requireSessionOrAPIKey(queries, sessionSvc))
r.Use(csrf)
r.Post("/", sandbox.Create)
r.Get("/", sandbox.List)
r.Get("/stats", statsH.GetStats)
@ -174,7 +229,8 @@ func New(
r.Route("/{id}", func(r chi.Router) {
// Auth-required non-WS routes.
r.Group(func(r chi.Router) {
r.Use(requireAPIKeyOrJWT(queries, jwtSecret))
r.Use(requireSessionOrAPIKey(queries, sessionSvc))
r.Use(csrf)
r.Get("/", sandbox.Get)
r.Delete("/", sandbox.Destroy)
r.Post("/exec", exec.Exec)
@ -193,11 +249,11 @@ func New(
r.Delete("/processes/{selector}", processH.KillProcess)
})
// WebSocket endpoints — handlers authenticate after upgrade.
// optionalAPIKeyOrJWT injects auth context from headers when
// present (SDK clients) but does not reject when absent (browsers).
// WebSocket endpoints — middleware injects auth context from
// cookie (browser) or X-API-Key (SDK) before upgrade. CSRF is
// not applicable to GET upgrades.
r.Group(func(r chi.Router) {
r.Use(optionalAPIKeyOrJWT(queries, jwtSecret))
r.Use(requireSessionOrAPIKey(queries, sessionSvc))
r.Get("/exec/stream", execStream.ExecStream)
r.Get("/pty", ptyH.PtySession)
r.Get("/processes/{selector}/stream", processH.ConnectProcess)
@ -205,9 +261,10 @@ func New(
})
})
// Snapshot / template management: accepts API key or JWT bearer token.
// Snapshot / template management: API key (SDK) or session (browser).
r.Route("/v1/snapshots", func(r chi.Router) {
r.Use(requireAPIKeyOrJWT(queries, jwtSecret))
r.Use(requireSessionOrAPIKey(queries, sessionSvc))
r.Use(csrf)
r.Post("/", snapshots.Create)
r.Get("/", snapshots.List)
r.Delete("/{name}", snapshots.Delete)
@ -221,12 +278,14 @@ func New(
// Unauthenticated: refresh token exchange.
r.Post("/auth/refresh", hostH.RefreshToken)
// Host-token-authenticated: heartbeat.
// Host-token-authenticated: heartbeat and lifecycle callbacks.
r.With(requireHostToken(jwtSecret)).Post("/{id}/heartbeat", hostH.Heartbeat)
r.With(requireHostToken(jwtSecret)).Post("/sandbox-events", sandboxEvtH.Handle)
// JWT-authenticated: host CRUD and tags.
// Session-authenticated: host CRUD and tags.
r.Group(func(r chi.Router) {
r.Use(requireJWT(jwtSecret, queries))
r.Use(requireSession(queries, sessionSvc))
r.Use(csrf)
r.Post("/", hostH.Create)
r.Get("/", hostH.List)
r.Route("/{id}", func(r chi.Router) {
@ -241,9 +300,10 @@ func New(
})
})
// JWT-authenticated: notification channels.
// Session-authenticated: notification channels.
r.Route("/v1/channels", func(r chi.Router) {
r.Use(requireJWT(jwtSecret, queries))
r.Use(requireSession(queries, sessionSvc))
r.Use(csrf)
r.Post("/", channelH.Create)
r.Get("/", channelH.List)
r.Post("/test", channelH.Test)
@ -255,18 +315,27 @@ func New(
})
})
// JWT-authenticated: audit log.
r.With(requireJWT(jwtSecret, queries)).Get("/v1/audit-logs", auditH.List)
// SSE event stream: browser sends wrenn_sid cookie on EventSource
// automatically; SDKs may set X-API-Key. Ticket-based auth is gone.
r.With(requireSessionOrAPIKey(queries, sessionSvc)).Get("/v1/events/stream", sseH.Stream)
// Platform admin routes — require JWT + DB-validated admin status.
// Session-authenticated: audit log.
r.With(requireSession(queries, sessionSvc), csrf).Get("/v1/audit-logs", auditH.List)
// Platform admin routes — session + DB-validated admin status.
r.Route("/v1/admin", func(r chi.Router) {
// Admin SSE event stream (sees all teams).
r.With(requireSession(queries, sessionSvc), requireAdmin(queries), injectPlatformTeam()).Get("/events/stream", sseH.AdminStream)
// Auth-required admin routes (non-capsule + capsule list/create).
r.Group(func(r chi.Router) {
r.Use(requireJWT(jwtSecret, queries))
r.Use(requireSession(queries, sessionSvc))
r.Use(requireAdmin(queries))
r.Use(csrf)
r.Get("/teams", teamH.AdminListTeams)
r.Put("/teams/{id}/byoc", teamH.SetBYOC)
r.Delete("/teams/{id}", teamH.AdminDeleteTeam)
r.Get("/hosts", hostH.AdminList)
r.Get("/users", usersH.AdminListUsers)
r.Put("/users/{id}/active", usersH.SetUserActive)
r.Put("/users/{id}/admin", usersH.SetUserAdmin)
@ -281,11 +350,21 @@ func New(
r.Get("/capsules", adminCapsules.List)
})
// Admin build console WebSocket — cookie + admin DB check before
// upgrade, no CSRF (WS upgrade). Builds are platform-scoped, not
// sandbox-scoped, so this sits outside the /capsules/{id} router.
r.Group(func(r chi.Router) {
r.Use(requireSession(queries, sessionSvc))
r.Use(requireAdmin(queries))
r.Get("/builds/{id}/stream", buildStreamH.Stream)
})
r.Route("/capsules/{id}", func(r chi.Router) {
// Auth-required non-WS admin capsule routes.
r.Group(func(r chi.Router) {
r.Use(requireJWT(jwtSecret, queries))
r.Use(requireSession(queries, sessionSvc))
r.Use(requireAdmin(queries))
r.Use(csrf)
r.Use(injectPlatformTeam())
r.Get("/", adminCapsules.Get)
r.Delete("/", adminCapsules.Destroy)
@ -301,11 +380,13 @@ func New(
r.Delete("/processes/{selector}", processH.KillProcess)
})
// Admin WebSocket endpoints — handlers authenticate after upgrade
// via wsAuthenticateAdmin. markAdminWS sets the context flag so
// handlers know to use admin auth instead of regular auth.
// Admin WebSocket endpoints — browser auth via cookie + admin DB check
// before upgrade. markAdminWS is retained as a context hint for any
// admin-specific behavior downstream.
r.Group(func(r chi.Router) {
r.Use(markAdminWS)
r.Use(requireSession(queries, sessionSvc))
r.Use(requireAdmin(queries))
r.Use(injectPlatformTeam())
r.Get("/exec/stream", execStream.ExecStream)
r.Get("/pty", ptyH.PtySession)
r.Get("/processes/{selector}/stream", processH.ConnectProcess)
@ -318,7 +399,7 @@ func New(
ext.RegisterRoutes(r, sctx)
}
return &Server{router: r, BuildSvc: buildSvc, version: version}
return &Server{router: r, BuildSvc: buildSvc, SSERelay: sseRelay, SessionSvc: sessionSvc, version: version}
}
// Handler returns the HTTP handler.
@ -363,3 +444,101 @@ func serveDocs(w http.ResponseWriter, r *http.Request) {
</body>
</html>`)
}
// serviceEventToCanonical maps a SandboxService background-goroutine event
// into the canonical events.Event taxonomy for unified publishing. Returns
// false for events that should not be broadcast.
func serviceEventToCanonical(e service.SandboxStateEvent) (events.Event, bool) {
var (
eventType string
outcome events.Outcome
metadata map[string]string
)
switch e.Event {
case "sandbox.started":
eventType = events.CapsuleCreate
outcome = events.OutcomeSuccess
case "sandbox.resumed":
eventType = events.CapsuleResume
outcome = events.OutcomeSuccess
case "sandbox.paused":
eventType = events.CapsulePause
outcome = events.OutcomeSuccess
case "sandbox.auto_paused":
eventType = events.CapsulePause
outcome = events.OutcomeSuccess
metadata = map[string]string{"reason": "ttl_expired"}
case "sandbox.stopped":
eventType = events.CapsuleDestroy
outcome = events.OutcomeSuccess
case "sandbox.pause_failed":
// reason must be non-empty or channels.isRedundantSystemFollowup
// filters this system-actor event out of webhook delivery.
eventType = events.CapsulePause
outcome = events.OutcomeError
metadata = map[string]string{"reason": "pause_failed"}
case "sandbox.resume_failed":
eventType = events.CapsuleResume
outcome = events.OutcomeError
metadata = map[string]string{"reason": "resume_failed"}
case "sandbox.failed":
// First-boot failure from the createInBackground goroutine. Without
// this case the event falls through to default and is dropped — no
// SSE push, no channel delivery, no DB reconciliation. reason must be
// non-empty or channels.isRedundantSystemFollowup filters it out.
eventType = events.CapsuleCreate
outcome = events.OutcomeError
metadata = map[string]string{"reason": "create_failed"}
case "sandbox.snapshotted":
// Completion of an async snapshot. The resource is the template name,
// not the sandbox, so the dashboard's snapshot list refreshes.
return events.Event{
Event: events.SnapshotCreate,
Outcome: events.OutcomeSuccess,
Timestamp: events.Now(),
TeamID: e.TeamID,
Actor: events.SystemActor(),
Resource: events.Resource{ID: e.Metadata["name"], Type: "snapshot"},
}, true
case "sandbox.snapshot_failed":
return events.Event{
Event: events.SnapshotCreate,
Outcome: events.OutcomeError,
Timestamp: events.Now(),
TeamID: e.TeamID,
Actor: events.SystemActor(),
Resource: events.Resource{ID: e.Metadata["name"], Type: "snapshot"},
Metadata: map[string]string{"reason": "snapshot_failed"},
Error: e.Error,
}, true
case "sandbox.state_changed":
// Transient badge transition with no terminal verb of its own. Carries
// from/to in metadata; routed via Pub/Sub only by the caller.
return events.Event{
Event: events.CapsuleStateChanged,
Timestamp: events.Now(),
TeamID: e.TeamID,
Actor: events.SystemActor(),
Resource: events.Resource{ID: e.SandboxID, Type: "sandbox"},
Metadata: e.Metadata,
}, true
default:
return events.Event{}, false
}
if e.HostIP != "" {
if metadata == nil {
metadata = map[string]string{}
}
metadata["host_ip"] = e.HostIP
}
return events.Event{
Event: eventType,
Outcome: outcome,
Timestamp: events.Now(),
TeamID: e.TeamID,
Actor: events.SystemActor(),
Resource: events.Resource{ID: e.SandboxID, Type: "sandbox"},
Metadata: metadata,
Error: e.Error,
}, true
}

View File

@ -0,0 +1,80 @@
package api
import (
"encoding/json"
"log/slog"
"sync"
"sync/atomic"
)
const sseChannelBuffer = 32
type sseMessage struct {
EventType string
Data json.RawMessage
}
type sseSubscriber struct {
teamID string
isAdmin bool
ch chan sseMessage
}
// SSEBroker is an in-process fan-out hub that dispatches events to connected
// SSE clients, filtering by team ownership.
type SSEBroker struct {
mu sync.RWMutex
nextID atomic.Uint64
subscribers map[uint64]*sseSubscriber
}
// NewSSEBroker constructs a broker with no subscribers.
func NewSSEBroker() *SSEBroker {
return &SSEBroker{
subscribers: make(map[uint64]*sseSubscriber),
}
}
// Subscribe registers a new SSE client. Returns a subscriber ID (for
// Unsubscribe) and a receive-only channel that delivers filtered events.
func (b *SSEBroker) Subscribe(teamID string, isAdmin bool) (uint64, <-chan sseMessage) {
id := b.nextID.Add(1)
ch := make(chan sseMessage, sseChannelBuffer)
sub := &sseSubscriber{teamID: teamID, isAdmin: isAdmin, ch: ch}
b.mu.Lock()
b.subscribers[id] = sub
b.mu.Unlock()
slog.Debug("sse: client subscribed", "sub_id", id, "team_id", teamID, "admin", isAdmin)
return id, ch
}
// Unsubscribe removes a client. The handler loop exits via context cancellation;
// the channel is not closed here to avoid send-on-closed-channel races with Dispatch.
func (b *SSEBroker) Unsubscribe(id uint64) {
b.mu.Lock()
delete(b.subscribers, id)
b.mu.Unlock()
slog.Debug("sse: client unsubscribed", "sub_id", id)
}
// Dispatch fans out an event to all matching subscribers. Admin subscribers
// receive all events; team subscribers only receive events for their team.
// Non-blocking: events are dropped for slow consumers.
func (b *SSEBroker) Dispatch(eventType string, teamID string, data json.RawMessage) {
b.mu.RLock()
defer b.mu.RUnlock()
msg := sseMessage{EventType: eventType, Data: data}
for id, sub := range b.subscribers {
if !sub.isAdmin && sub.teamID != teamID {
continue
}
select {
case sub.ch <- msg:
default:
slog.Warn("sse: dropped event for slow consumer", "sub_id", id, "event", eventType)
}
}
}

147
internal/api/sse_relay.go Normal file
View File

@ -0,0 +1,147 @@
package api
import (
"context"
"encoding/json"
"log/slog"
"time"
"github.com/jackc/pgx/v5"
"github.com/redis/go-redis/v9"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/events"
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
const ssePubSubChannel = "wrenn:sse"
// sseEventPayload is the JSON envelope sent to SSE clients.
type sseEventPayload struct {
Event string `json:"event"`
Outcome events.Outcome `json:"outcome,omitempty"`
Timestamp string `json:"timestamp"`
TeamID string `json:"team_id"`
Actor events.Actor `json:"actor"`
Resource events.Resource `json:"resource"`
Metadata map[string]string `json:"metadata,omitempty"`
Error string `json:"error,omitempty"`
Sandbox *sandboxResponse `json:"sandbox,omitempty"`
}
// SSERelay subscribes to the Redis Pub/Sub channel and dispatches hydrated
// events to the in-process broker. One instance per CP process.
type SSERelay struct {
rdb *redis.Client
db *db.Queries
broker *SSEBroker
}
// NewSSERelay constructs the relay.
func NewSSERelay(rdb *redis.Client, queries *db.Queries, broker *SSEBroker) *SSERelay {
return &SSERelay{rdb: rdb, db: queries, broker: broker}
}
// Start launches the Pub/Sub subscription goroutine. Returns when ctx is cancelled.
func (r *SSERelay) Start(ctx context.Context) {
go r.run(ctx)
}
func (r *SSERelay) run(ctx context.Context) {
for {
if ctx.Err() != nil {
return
}
r.subscribe(ctx)
// Backoff before reconnecting.
select {
case <-ctx.Done():
return
case <-time.After(2 * time.Second):
}
}
}
func (r *SSERelay) subscribe(ctx context.Context) {
pubsub := r.rdb.Subscribe(ctx, ssePubSubChannel)
defer pubsub.Close()
ch := pubsub.Channel()
for {
select {
case <-ctx.Done():
return
case msg, ok := <-ch:
if !ok {
return
}
r.handleMessage(ctx, msg)
}
}
}
func (r *SSERelay) handleMessage(ctx context.Context, msg *redis.Message) {
var event events.Event
if err := json.Unmarshal([]byte(msg.Payload), &event); err != nil {
slog.Warn("sse relay: failed to unmarshal event", "error", err)
return
}
payload := sseEventPayload{
Event: event.Event,
Outcome: event.Outcome,
Timestamp: event.Timestamp,
TeamID: event.TeamID,
Actor: event.Actor,
Resource: event.Resource,
Metadata: event.Metadata,
Error: event.Error,
}
// Hydrate sandbox state for capsule events.
if isCapsuleEvent(event.Event) {
sb, err := r.hydrateSandbox(ctx, event.Resource.ID)
if err != nil {
slog.Debug("sse relay: sandbox hydration failed (may be deleted)", "sandbox_id", event.Resource.ID, "error", err)
} else {
payload.Sandbox = sb
}
}
data, err := json.Marshal(payload)
if err != nil {
slog.Warn("sse relay: failed to marshal payload", "error", err)
return
}
r.broker.Dispatch(event.Event, event.TeamID, data)
}
func (r *SSERelay) hydrateSandbox(ctx context.Context, sandboxIDStr string) (*sandboxResponse, error) {
queryCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
if err != nil {
return nil, err
}
sb, err := r.db.GetSandbox(queryCtx, sandboxID)
if err != nil {
if err == pgx.ErrNoRows {
return nil, nil
}
return nil, err
}
resp := sandboxToResponse(sb)
return &resp, nil
}
func isCapsuleEvent(eventType string) bool {
switch eventType {
case events.CapsuleCreate, events.CapsulePause, events.CapsuleResume, events.CapsuleDestroy, events.CapsuleStateChanged:
return true
}
return false
}

View File

@ -80,8 +80,8 @@ func (r *LoopRegistry) Release(imagePath string) {
e.refcount--
if e.refcount <= 0 {
if err := losetupDetach(e.device); err != nil {
slog.Warn("losetup detach failed", "device", e.device, "error", err)
if err := losetupDetachRetry(e.device); err != nil {
slog.Error("losetup detach failed, loop device leaked", "device", e.device, "image", imagePath, "error", err)
}
delete(r.entries, imagePath)
slog.Info("loop device released", "image", imagePath, "device", e.device)
@ -94,8 +94,8 @@ func (r *LoopRegistry) ReleaseAll() {
defer r.mu.Unlock()
for path, e := range r.entries {
if err := losetupDetach(e.device); err != nil {
slog.Warn("losetup detach failed", "device", e.device, "error", err)
if err := losetupDetachRetry(e.device); err != nil {
slog.Error("losetup detach failed during shutdown", "device", e.device, "image", path, "error", err)
}
delete(r.entries, path)
}
@ -109,6 +109,31 @@ type SnapshotDevice struct {
CowLoopDev string // loop device for the CoW file
}
// attachCowAndCreate attaches a CoW file as a loop device, creates the
// dm-snapshot target, and returns the assembled SnapshotDevice. On failure
// it detaches the CoW loop device before returning.
func attachCowAndCreate(name, originLoopDev, cowPath string, originSizeBytes int64) (*SnapshotDevice, error) {
cowLoopDev, err := losetupCreateRW(cowPath)
if err != nil {
return nil, fmt.Errorf("losetup cow: %w", err)
}
sectors := originSizeBytes / 512
if err := dmsetupCreate(name, originLoopDev, cowLoopDev, sectors); err != nil {
if detachErr := losetupDetachRetry(cowLoopDev); detachErr != nil {
slog.Error("cow losetup detach failed during cleanup, loop device leaked", "device", cowLoopDev, "error", detachErr)
}
return nil, fmt.Errorf("dmsetup create: %w", err)
}
return &SnapshotDevice{
Name: name,
DevicePath: "/dev/mapper/" + name,
CowPath: cowPath,
CowLoopDev: cowLoopDev,
}, nil
}
// CreateSnapshot sets up a new dm-snapshot device.
//
// It creates a sparse CoW file, attaches it as a loop device, and creates
@ -117,45 +142,24 @@ type SnapshotDevice struct {
//
// The origin loop device must already exist (from LoopRegistry.Acquire).
func CreateSnapshot(name, originLoopDev, cowPath string, originSizeBytes, cowSizeBytes int64) (*SnapshotDevice, error) {
// Create sparse CoW file. The logical size limits how many blocks can be
// modified; because the file is sparse, only written blocks use real disk.
if err := createSparseFile(cowPath, cowSizeBytes); err != nil {
return nil, fmt.Errorf("create cow file: %w", err)
}
cowLoopDev, err := losetupCreateRW(cowPath)
dev, err := attachCowAndCreate(name, originLoopDev, cowPath, originSizeBytes)
if err != nil {
os.Remove(cowPath)
return nil, fmt.Errorf("losetup cow: %w", err)
return nil, err
}
// The dm-snapshot virtual device size must match the origin — the snapshot
// target maps 1:1 onto origin sectors. The CoW file just needs enough
// space to store all modified blocks (it's sparse, so 20GB costs nothing).
sectors := originSizeBytes / 512
if err := dmsetupCreate(name, originLoopDev, cowLoopDev, sectors); err != nil {
if detachErr := losetupDetach(cowLoopDev); detachErr != nil {
slog.Warn("cow losetup detach failed during cleanup", "device", cowLoopDev, "error", detachErr)
}
os.Remove(cowPath)
return nil, fmt.Errorf("dmsetup create: %w", err)
}
devPath := "/dev/mapper/" + name
slog.Info("dm-snapshot created",
"name", name,
"device", devPath,
"device", dev.DevicePath,
"origin", originLoopDev,
"cow", cowPath,
)
return &SnapshotDevice{
Name: name,
DevicePath: devPath,
CowPath: cowPath,
CowLoopDev: cowLoopDev,
}, nil
return dev, nil
}
// RestoreSnapshot re-attaches a dm-snapshot from an existing persistent CoW file.
@ -171,34 +175,19 @@ func RestoreSnapshot(ctx context.Context, name, originLoopDev, cowPath string, o
}
}
cowLoopDev, err := losetupCreateRW(cowPath)
dev, err := attachCowAndCreate(name, originLoopDev, cowPath, originSizeBytes)
if err != nil {
return nil, fmt.Errorf("losetup cow: %w", err)
return nil, err
}
sectors := originSizeBytes / 512
if err := dmsetupCreate(name, originLoopDev, cowLoopDev, sectors); err != nil {
if detachErr := losetupDetach(cowLoopDev); detachErr != nil {
slog.Warn("cow losetup detach failed during cleanup", "device", cowLoopDev, "error", detachErr)
}
return nil, fmt.Errorf("dmsetup create: %w", err)
}
devPath := "/dev/mapper/" + name
slog.Info("dm-snapshot restored",
"name", name,
"device", devPath,
"device", dev.DevicePath,
"origin", originLoopDev,
"cow", cowPath,
)
return &SnapshotDevice{
Name: name,
DevicePath: devPath,
CowPath: cowPath,
CowLoopDev: cowLoopDev,
}, nil
return dev, nil
}
// RemoveSnapshot tears down a dm-snapshot device and its CoW loop device.
@ -208,8 +197,8 @@ func RemoveSnapshot(ctx context.Context, dev *SnapshotDevice) error {
return fmt.Errorf("dmsetup remove %s: %w", dev.Name, err)
}
if err := losetupDetach(dev.CowLoopDev); err != nil {
slog.Warn("cow losetup detach failed", "device", dev.CowLoopDev, "error", err)
if err := losetupDetachRetry(dev.CowLoopDev); err != nil {
return fmt.Errorf("detach cow loop %s: %w", dev.CowLoopDev, err)
}
slog.Info("dm-snapshot removed", "name", dev.Name)
@ -272,6 +261,29 @@ func CleanupStaleDevices() {
}
}
// LogLoopState enumerates currently-attached loop devices that back wrenn
// rootfs images and logs them at INFO. Diagnostic only — meant to be called
// once at agent startup so leaked loop attachments from a prior crash are
// visible in the journal before the LoopRegistry starts refcounting.
func LogLoopState() {
out, err := exec.Command("losetup", "-l", "--noheadings", "--output", "NAME,BACK-FILE").CombinedOutput()
if err != nil {
slog.Debug("losetup -l failed", "error", err)
return
}
wrennCount := 0
for _, line := range strings.Split(strings.TrimSpace(string(out)), "\n") {
if !strings.Contains(line, "/var/lib/wrenn/") {
continue
}
wrennCount++
slog.Info("pre-existing loop attachment", "entry", strings.TrimSpace(line))
}
if wrennCount == 0 {
slog.Info("no pre-existing wrenn loop attachments")
}
}
// --- low-level helpers ---
// losetupCreate attaches a file as a read-only loop device.
@ -297,6 +309,24 @@ func losetupDetach(dev string) error {
return exec.Command("losetup", "-d", dev).Run()
}
// losetupDetachRetry detaches a loop device with retries for transient
// "device busy" errors (kernel may still hold references briefly after
// dm-snapshot removal).
func losetupDetachRetry(dev string) error {
var lastErr error
for attempt := range 5 {
if attempt > 0 {
time.Sleep(200 * time.Millisecond)
}
if err := losetupDetach(dev); err == nil {
return nil
} else {
lastErr = err
}
}
return fmt.Errorf("after 5 attempts: %w", lastErr)
}
// dmsetupCreate creates a dm-snapshot device with persistent metadata.
func dmsetupCreate(name, originDev, cowDev string, sectors int64) error {
// Table format: <start> <size> snapshot <origin> <cow> P <chunk_size>
@ -316,7 +346,7 @@ func dmDeviceExists(name string) bool {
// dmsetupRemove removes a device-mapper device, retrying on transient
// "device busy" errors that occur when the kernel hasn't fully released
// the device after a Firecracker process exits.
// the device after a VMM process exits.
func dmsetupRemove(ctx context.Context, name string) error {
var lastErr error
for attempt := range 5 {
@ -361,5 +391,9 @@ func createSparseFile(path string, sizeBytes int64) error {
os.Remove(path)
return err
}
return f.Close()
if err := f.Close(); err != nil {
os.Remove(path)
return err
}
return nil
}

View File

@ -10,9 +10,12 @@ import (
"mime/multipart"
"net/http"
"net/url"
"time"
"connectrpc.com/connect"
"github.com/google/uuid"
envdpb "git.omukk.dev/wrenn/wrenn/proto/envd/gen"
"git.omukk.dev/wrenn/wrenn/proto/envd/gen/genconnect"
)
@ -78,16 +81,31 @@ type ExecResult struct {
ExitCode int32
}
// ExecOpts holds optional parameters for Exec.
type ExecOpts struct {
Envs map[string]string
Cwd string
}
// Exec runs a command inside the sandbox and collects all stdout/stderr output.
// It blocks until the command completes.
func (c *Client) Exec(ctx context.Context, cmd string, args ...string) (*ExecResult, error) {
func (c *Client) Exec(ctx context.Context, cmd string, args []string, opts *ExecOpts) (*ExecResult, error) {
stdin := false
proc := &envdpb.ProcessConfig{
Cmd: cmd,
Args: args,
}
if opts != nil {
if len(opts.Envs) > 0 {
proc.Envs = opts.Envs
}
if opts.Cwd != "" {
proc.Cwd = &opts.Cwd
}
}
req := connect.NewRequest(&envdpb.StartRequest{
Process: &envdpb.ProcessConfig{
Cmd: cmd,
Args: args,
},
Stdin: &stdin,
Process: proc,
Stdin: &stdin,
})
stream, err := c.process.Start(ctx, req)
@ -294,7 +312,7 @@ func (c *Client) ReadFile(ctx context.Context, path string) ([]byte, error) {
// PrepareSnapshot calls envd's POST /snapshot/prepare endpoint, which stops
// the port scanner/forwarder and marks active connections for post-restore
// cleanup before Firecracker freezes vCPUs.
// cleanup before the VMM freezes vCPUs.
//
// Best-effort: the caller should log a warning on error but not abort the pause.
func (c *Client) PrepareSnapshot(ctx context.Context) error {
@ -317,27 +335,135 @@ func (c *Client) PrepareSnapshot(ctx context.Context) error {
return nil
}
// PostInit calls envd's POST /init endpoint, which triggers a re-read of
// Firecracker MMDS metadata. This updates WRENN_SANDBOX_ID, WRENN_TEMPLATE_ID
// env vars and the corresponding files under /run/wrenn/ inside the guest.
// Must be called after snapshot restore so envd picks up the new sandbox's metadata.
// MemoryPreloadStatus mirrors envd's /memory/preload response.
//
// State values: "idle", "running", "done", "failed", "cancelled".
type MemoryPreloadStatus struct {
State string `json:"state"`
Regions uint64 `json:"regions"`
Pages uint64 `json:"pages"`
Bytes uint64 `json:"bytes"`
ElapsedSec float64 `json:"elapsed_sec"`
Source string `json:"source"`
Error string `json:"error,omitempty"`
}
// StartMemoryPreload posts to envd's /memory/preload to spawn a guest-side
// loader that reads every physical RAM page. The request returns immediately
// after the loader is queued — the actual materialisation runs in a detached
// thread inside envd. Required after a snapshot restore with
// memory_restore_mode=ondemand so the next ch.snapshot writes a
// self-contained memory-ranges file.
//
// Use WaitMemoryPreload to block on completion or GetMemoryPreloadStatus to
// query progress.
func (c *Client) StartMemoryPreload(ctx context.Context) (MemoryPreloadStatus, error) {
return c.memoryPreloadRequest(ctx, http.MethodPost)
}
// GetMemoryPreloadStatus reads envd's /memory/preload status without
// starting a new loader.
func (c *Client) GetMemoryPreloadStatus(ctx context.Context) (MemoryPreloadStatus, error) {
return c.memoryPreloadRequest(ctx, http.MethodGet)
}
func (c *Client) memoryPreloadRequest(ctx context.Context, method string) (MemoryPreloadStatus, error) {
var status MemoryPreloadStatus
req, err := http.NewRequestWithContext(ctx, method, c.base+"/memory/preload", nil)
if err != nil {
return status, fmt.Errorf("create request: %w", err)
}
resp, err := c.httpClient.Do(req)
if err != nil {
return status, fmt.Errorf("memory preload %s: %w", method, err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return status, fmt.Errorf("memory preload %s: status %d: %s", method, resp.StatusCode, string(body))
}
if err := json.NewDecoder(resp.Body).Decode(&status); err != nil {
return status, fmt.Errorf("memory preload %s: decode: %w", method, err)
}
return status, nil
}
// WaitMemoryPreload polls envd until the loader is no longer running or ctx
// is cancelled. Returns the final status. Polling interval is fixed at 1s —
// the loader runs for many seconds to minutes, so finer polling wastes RPCs.
func (c *Client) WaitMemoryPreload(ctx context.Context) (MemoryPreloadStatus, error) {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
for {
status, err := c.GetMemoryPreloadStatus(ctx)
if err != nil {
return status, err
}
if status.State != "running" {
return status, nil
}
select {
case <-ctx.Done():
return status, ctx.Err()
case <-ticker.C:
}
}
}
// CancelMemoryPreload signals the in-guest memory preloader to stop early.
// Used during teardown so a pause/destroy doesn't have to wait for a
// multi-hundred-MiB read to finish.
func (c *Client) CancelMemoryPreload(ctx context.Context) error {
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.base+"/memory/preload/cancel", nil)
if err != nil {
return fmt.Errorf("create request: %w", err)
}
resp, err := c.httpClient.Do(req)
if err != nil {
return fmt.Errorf("preload cancel: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusNoContent {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("preload cancel: status %d: %s", resp.StatusCode, string(body))
}
return nil
}
// PostInit calls envd's POST /init endpoint to trigger post-boot or
// post-restore initialization. sandbox_id and template_id are passed
// so envd can set WRENN_SANDBOX_ID and WRENN_TEMPLATE_ID env vars.
func (c *Client) PostInit(ctx context.Context) error {
return c.PostInitWithDefaults(ctx, "", nil)
return c.PostInitWithDefaults(ctx, "", nil, "", "")
}
// PostInitWithDefaults calls envd's POST /init endpoint with optional default
// user and environment variables. These are applied to envd's defaults so all
// subsequent process executions use them.
func (c *Client) PostInitWithDefaults(ctx context.Context, defaultUser string, envVars map[string]string) error {
// user, environment variables, and sandbox metadata. These are applied to
// envd's defaults so all subsequent process executions use them.
//
// timestamp and lifecycle_id are always populated: envd uses them to snap
// the guest clock to the host's wall time and to detect post-resume calls
// (which trigger port-forwarder restart + NFS remount).
func (c *Client) PostInitWithDefaults(ctx context.Context, defaultUser string, envVars map[string]string, sandboxID, templateID string) error {
payload := map[string]any{
"timestamp": time.Now().UTC().Format(time.RFC3339Nano),
"lifecycle_id": uuid.NewString(),
}
if defaultUser != "" {
payload["defaultUser"] = defaultUser
}
if len(envVars) > 0 {
payload["envVars"] = envVars
}
if sandboxID != "" {
payload["sandbox_id"] = sandboxID
}
if templateID != "" {
payload["template_id"] = templateID
}
var body io.Reader
if defaultUser != "" || len(envVars) > 0 {
payload := make(map[string]any)
if defaultUser != "" {
payload["defaultUser"] = defaultUser
}
if len(envVars) > 0 {
payload["envVars"] = envVars
}
if len(payload) > 0 {
data, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("marshal init body: %w", err)

View File

@ -59,6 +59,28 @@ func (c *Client) FetchVersion(ctx context.Context) (string, error) {
return data.Version, nil
}
// WaitUntilRPCReady polls envd's Connect RPC layer until it responds
// successfully or the context is cancelled. This catches cases where envd's
// HTTP health endpoint works but the Connect protocol layer is not yet
// functional (e.g., after VM snapshot restore).
func (c *Client) WaitUntilRPCReady(ctx context.Context) error {
const retryInterval = 200 * time.Millisecond
ticker := time.NewTicker(retryInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return fmt.Errorf("envd RPC not ready: %w", ctx.Err())
case <-ticker.C:
if _, err := c.ListProcesses(ctx); err == nil {
return nil
}
}
}
}
// healthCheck sends a single GET /health request to envd.
func (c *Client) healthCheck(ctx context.Context) error {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.healthURL, nil)

View File

@ -0,0 +1,129 @@
package hostagent
import (
"bytes"
"context"
"encoding/json"
"fmt"
"log/slog"
"net/http"
"strings"
"sync"
"time"
)
// CallbackEvent is the payload sent to the CP's sandbox event callback endpoint.
type CallbackEvent struct {
Event string `json:"event"`
SandboxID string `json:"sandbox_id"`
HostID string `json:"host_id"`
Timestamp int64 `json:"timestamp"`
}
// CallbackSender sends sandbox lifecycle events to the CP via HTTP POST.
// Used for autonomous agent-side events (auto-pause, auto-destroy) that
// the CP cannot observe through its own RPC goroutines.
type CallbackSender struct {
cpURL string
hostID string
credFile string
client *http.Client
mu sync.RWMutex
jwt string
}
// NewCallbackSender creates a callback sender.
func NewCallbackSender(cpURL, credFile, hostID string) *CallbackSender {
jwt := ""
if tf, err := LoadTokenFile(credFile); err == nil {
jwt = tf.JWT
}
return &CallbackSender{
cpURL: strings.TrimRight(cpURL, "/"),
hostID: hostID,
credFile: credFile,
client: &http.Client{Timeout: 10 * time.Second},
jwt: jwt,
}
}
// UpdateJWT refreshes the JWT used for callback authentication.
// Called from the heartbeat's onCredsRefreshed hook.
func (s *CallbackSender) UpdateJWT(jwt string) {
s.mu.Lock()
s.jwt = jwt
s.mu.Unlock()
}
func (s *CallbackSender) getJWT() string {
s.mu.RLock()
defer s.mu.RUnlock()
return s.jwt
}
// Send sends a callback event to the CP synchronously with retries.
func (s *CallbackSender) Send(ctx context.Context, ev CallbackEvent) error {
ev.HostID = s.hostID
if ev.Timestamp == 0 {
ev.Timestamp = time.Now().Unix()
}
body, err := json.Marshal(ev)
if err != nil {
return fmt.Errorf("marshal callback event: %w", err)
}
url := s.cpURL + "/v1/hosts/sandbox-events"
var lastErr error
for attempt := 0; attempt < 3; attempt++ {
if attempt > 0 {
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(time.Duration(attempt) * 500 * time.Millisecond):
}
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return fmt.Errorf("create callback request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-Host-Token", s.getJWT())
resp, err := s.client.Do(req)
if err != nil {
lastErr = err
continue
}
resp.Body.Close()
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
if newCreds, refreshErr := RefreshCredentials(ctx, s.cpURL, s.credFile); refreshErr == nil {
s.UpdateJWT(newCreds.JWT)
}
lastErr = fmt.Errorf("callback auth failed: %d", resp.StatusCode)
continue
}
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
return nil
}
lastErr = fmt.Errorf("callback failed: status %d", resp.StatusCode)
}
return fmt.Errorf("callback failed after 3 attempts: %w", lastErr)
}
// SendAsync sends a callback event in a background goroutine.
func (s *CallbackSender) SendAsync(ev CallbackEvent) {
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := s.Send(ctx, ev); err != nil {
slog.Warn("callback send failed (reconciler will catch it)", "event", ev.Event, "sandbox_id", ev.SandboxID, "error", err)
}
}()
}

View File

@ -0,0 +1,31 @@
package hostagent
import (
"context"
"git.omukk.dev/wrenn/wrenn/internal/sandbox"
)
// callbackAdapter adapts CallbackSender to satisfy sandbox.EventSender.
type callbackAdapter struct {
sender *CallbackSender
}
// NewEventSender wraps a CallbackSender as a sandbox.EventSender.
func NewEventSender(sender *CallbackSender) sandbox.EventSender {
return &callbackAdapter{sender: sender}
}
func (a *callbackAdapter) SendAsync(event sandbox.LifecycleEvent) {
a.sender.SendAsync(CallbackEvent{
Event: event.Event,
SandboxID: event.SandboxID,
})
}
func (a *callbackAdapter) Send(ctx context.Context, event sandbox.LifecycleEvent) error {
return a.sender.Send(ctx, CallbackEvent{
Event: event.Event,
SandboxID: event.SandboxID,
})
}

View File

@ -2,13 +2,14 @@ package hostagent
import (
"context"
"errors"
"fmt"
"io"
"log/slog"
"mime/multipart"
"net/http"
"net/url"
"strings"
"os"
"time"
"connectrpc.com/connect"
@ -19,6 +20,7 @@ import (
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
"git.omukk.dev/wrenn/wrenn/proto/hostagent/gen/hostagentv1connect"
"git.omukk.dev/wrenn/wrenn/internal/envdclient"
"git.omukk.dev/wrenn/wrenn/internal/sandbox"
)
@ -49,38 +51,48 @@ func parseUUIDString(s string) (pgtype.UUID, error) {
return pgtype.UUID{Bytes: parsed, Valid: true}, nil
}
// parseSandboxIDs parses the team+template UUID pair every snapshot-targeting
// RPC handler receives, returning a CodeInvalidArgument Connect error on the
// first failure so the caller can `return nil, err` directly.
func parseSandboxIDs(teamIDStr, templateIDStr string) (teamID, templateID pgtype.UUID, err error) {
teamID, err = parseUUIDString(teamIDStr)
if err != nil {
return pgtype.UUID{}, pgtype.UUID{}, connect.NewError(connect.CodeInvalidArgument, err)
}
templateID, err = parseUUIDString(templateIDStr)
if err != nil {
return pgtype.UUID{}, pgtype.UUID{}, connect.NewError(connect.CodeInvalidArgument, err)
}
return teamID, templateID, nil
}
func (s *Server) CreateSandbox(
ctx context.Context,
req *connect.Request[pb.CreateSandboxRequest],
) (*connect.Response[pb.CreateSandboxResponse], error) {
msg := req.Msg
teamID, err := parseUUIDString(msg.TeamId)
teamID, templateID, err := parseSandboxIDs(msg.TeamId, msg.TemplateId)
if err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, err)
}
templateID, err := parseUUIDString(msg.TemplateId)
if err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, err)
return nil, err
}
sb, err := s.mgr.Create(ctx, msg.SandboxId, teamID, templateID, int(msg.Vcpus), int(msg.MemoryMb), int(msg.TimeoutSec), int(msg.DiskSizeMb))
sb, diskSizeBytes, err := s.mgr.Create(ctx, msg.SandboxId, teamID, templateID,
int(msg.Vcpus), int(msg.MemoryMb), int(msg.TimeoutSec), int(msg.DiskSizeMb),
msg.DefaultUser, msg.DefaultEnv)
if err != nil {
if errors.Is(err, sandbox.ErrDraining) {
return nil, connect.NewError(connect.CodeUnavailable, err)
}
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("create sandbox: %w", err))
}
// Apply template defaults (user, env vars) if provided.
if msg.DefaultUser != "" || len(msg.DefaultEnv) > 0 {
if err := s.mgr.SetDefaults(ctx, sb.ID, msg.DefaultUser, msg.DefaultEnv); err != nil {
slog.Warn("failed to set sandbox defaults", "sandbox", sb.ID, "error", err)
}
}
return connect.NewResponse(&pb.CreateSandboxResponse{
SandboxId: sb.ID,
Status: string(sb.Status),
HostIp: sb.HostIP.String(),
Metadata: sb.Metadata,
SandboxId: sb.ID,
Status: string(sb.Status),
HostIp: sb.HostIP.String(),
Metadata: sb.Metadata,
DiskSizeMb: int32(diskSizeBytes / (1024 * 1024)),
}), nil
}
@ -89,7 +101,7 @@ func (s *Server) DestroySandbox(
req *connect.Request[pb.DestroySandboxRequest],
) (*connect.Response[pb.DestroySandboxResponse], error) {
if err := s.mgr.Destroy(ctx, req.Msg.SandboxId); err != nil {
return nil, connect.NewError(connect.CodeNotFound, err)
return nil, mapSandboxError(err)
}
return connect.NewResponse(&pb.DestroySandboxResponse{}), nil
}
@ -99,7 +111,7 @@ func (s *Server) PauseSandbox(
req *connect.Request[pb.PauseSandboxRequest],
) (*connect.Response[pb.PauseSandboxResponse], error) {
if err := s.mgr.Pause(ctx, req.Msg.SandboxId); err != nil {
return nil, connect.NewError(connect.CodeInternal, err)
return nil, mapSandboxError(err)
}
return connect.NewResponse(&pb.PauseSandboxResponse{}), nil
}
@ -108,12 +120,10 @@ func (s *Server) ResumeSandbox(
ctx context.Context,
req *connect.Request[pb.ResumeSandboxRequest],
) (*connect.Response[pb.ResumeSandboxResponse], error) {
msg := req.Msg
sb, err := s.mgr.Resume(ctx, msg.SandboxId, int(msg.TimeoutSec), msg.KernelVersion, msg.DefaultUser, msg.DefaultEnv)
sb, err := s.mgr.Resume(ctx, req.Msg.SandboxId, int(req.Msg.TimeoutSec), req.Msg.DefaultUser, req.Msg.KernelVersion, req.Msg.DefaultEnv)
if err != nil {
return nil, connect.NewError(connect.CodeInternal, err)
return nil, mapSandboxError(err)
}
return connect.NewResponse(&pb.ResumeSandboxResponse{
SandboxId: sb.ID,
Status: string(sb.Status),
@ -126,41 +136,30 @@ func (s *Server) CreateSnapshot(
ctx context.Context,
req *connect.Request[pb.CreateSnapshotRequest],
) (*connect.Response[pb.CreateSnapshotResponse], error) {
msg := req.Msg
teamID, err := parseUUIDString(msg.TeamId)
teamID, templateID, err := parseSandboxIDs(req.Msg.TeamId, req.Msg.TemplateId)
if err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, err)
return nil, err
}
templateID, err := parseUUIDString(msg.TemplateId)
size, err := s.mgr.CreateSnapshot(ctx, req.Msg.SandboxId, teamID, templateID, req.Msg.Name)
if err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, err)
}
sizeBytes, err := s.mgr.CreateSnapshot(ctx, msg.SandboxId, teamID, templateID)
if err != nil {
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("create snapshot: %w", err))
return nil, mapSandboxError(err)
}
return connect.NewResponse(&pb.CreateSnapshotResponse{
SizeBytes: sizeBytes,
Name: req.Msg.Name,
SizeBytes: size,
}), nil
}
func (s *Server) DeleteSnapshot(
ctx context.Context,
_ context.Context,
req *connect.Request[pb.DeleteSnapshotRequest],
) (*connect.Response[pb.DeleteSnapshotResponse], error) {
msg := req.Msg
teamID, err := parseUUIDString(msg.TeamId)
teamID, templateID, err := parseSandboxIDs(req.Msg.TeamId, req.Msg.TemplateId)
if err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, err)
return nil, err
}
templateID, err := parseUUIDString(msg.TemplateId)
if err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, err)
}
if err := s.mgr.DeleteSnapshot(teamID, templateID); err != nil {
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("delete snapshot: %w", err))
return nil, mapSandboxError(err)
}
return connect.NewResponse(&pb.DeleteSnapshotResponse{}), nil
}
@ -169,22 +168,54 @@ func (s *Server) FlattenRootfs(
ctx context.Context,
req *connect.Request[pb.FlattenRootfsRequest],
) (*connect.Response[pb.FlattenRootfsResponse], error) {
msg := req.Msg
teamID, err := parseUUIDString(msg.TeamId)
teamID, templateID, err := parseSandboxIDs(req.Msg.TeamId, req.Msg.TemplateId)
if err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, err)
return nil, err
}
templateID, err := parseUUIDString(msg.TemplateId)
size, err := s.mgr.FlattenRootfs(ctx, req.Msg.SandboxId, teamID, templateID)
if err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, err)
}
sizeBytes, err := s.mgr.FlattenRootfs(ctx, msg.SandboxId, teamID, templateID)
if err != nil {
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("flatten rootfs: %w", err))
return nil, mapSandboxError(err)
}
return connect.NewResponse(&pb.FlattenRootfsResponse{
SizeBytes: sizeBytes,
SizeBytes: size,
}), nil
}
// mapSandboxError translates sandbox.Manager errors to Connect error codes
// via sentinel errors (errors.Is). Adding a new precondition sentinel in the
// sandbox package only requires extending this switch — no string sniffing.
func mapSandboxError(err error) error {
switch {
case errors.Is(err, sandbox.ErrNotFound):
return connect.NewError(connect.CodeNotFound, err)
case errors.Is(err, sandbox.ErrNotRunning), errors.Is(err, sandbox.ErrNotPaused):
return connect.NewError(connect.CodeFailedPrecondition, err)
case errors.Is(err, sandbox.ErrDraining):
return connect.NewError(connect.CodeUnavailable, err)
case errors.Is(err, sandbox.ErrInvalidRange):
return connect.NewError(connect.CodeInvalidArgument, err)
default:
return connect.NewError(connect.CodeInternal, err)
}
}
func (s *Server) GetTemplateSize(
ctx context.Context,
req *connect.Request[pb.GetTemplateSizeRequest],
) (*connect.Response[pb.GetTemplateSizeResponse], error) {
teamID, templateID, err := parseSandboxIDs(req.Msg.TeamId, req.Msg.TemplateId)
if err != nil {
return nil, err
}
size, err := s.mgr.TemplateRootfsSize(teamID, templateID)
if err != nil {
if os.IsNotExist(err) {
return nil, connect.NewError(connect.CodeNotFound, err)
}
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("get template size: %w", err))
}
return connect.NewResponse(&pb.GetTemplateSizeResponse{
SizeBytes: size,
}), nil
}
@ -193,7 +224,7 @@ func (s *Server) PingSandbox(
req *connect.Request[pb.PingSandboxRequest],
) (*connect.Response[pb.PingSandboxResponse], error) {
if err := s.mgr.Ping(req.Msg.SandboxId); err != nil {
if strings.Contains(err.Error(), "not found") {
if errors.Is(err, sandbox.ErrNotFound) {
return nil, connect.NewError(connect.CodeNotFound, err)
}
return nil, connect.NewError(connect.CodeFailedPrecondition, err)
@ -215,7 +246,12 @@ func (s *Server) Exec(
execCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
result, err := s.mgr.Exec(execCtx, msg.SandboxId, msg.Cmd, msg.Args...)
var opts *envdclient.ExecOpts
if len(msg.Envs) > 0 || msg.Cwd != "" {
opts = &envdclient.ExecOpts{Envs: msg.Envs, Cwd: msg.Cwd}
}
result, err := s.mgr.Exec(execCtx, msg.SandboxId, msg.Cmd, msg.Args, opts)
if err != nil {
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("exec: %w", err))
}
@ -227,6 +263,17 @@ func (s *Server) Exec(
}), nil
}
// envdErr propagates an error from the envd client, preserving its Connect
// error code (e.g. AlreadyExists, NotFound) so the control plane maps it to
// the correct HTTP status. Non-Connect errors fall back to CodeInternal.
func envdErr(action string, err error) error {
code := connect.CodeOf(err)
if code == connect.CodeUnknown {
code = connect.CodeInternal
}
return connect.NewError(code, fmt.Errorf("%s: %w", action, err))
}
func (s *Server) WriteFile(
ctx context.Context,
req *connect.Request[pb.WriteFileRequest],
@ -239,7 +286,7 @@ func (s *Server) WriteFile(
}
if err := client.WriteFile(ctx, msg.Path, msg.Content); err != nil {
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("write file: %w", err))
return nil, envdErr("write file", err)
}
return connect.NewResponse(&pb.WriteFileResponse{}), nil
@ -258,7 +305,7 @@ func (s *Server) ReadFile(
content, err := client.ReadFile(ctx, msg.Path)
if err != nil {
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("read file: %w", err))
return nil, envdErr("read file", err)
}
return connect.NewResponse(&pb.ReadFileResponse{Content: content}), nil
@ -277,7 +324,7 @@ func (s *Server) ListDir(
resp, err := client.ListDir(ctx, msg.Path, msg.Depth)
if err != nil {
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("list dir: %w", err))
return nil, envdErr("list dir", err)
}
entries := make([]*pb.FileEntry, 0, len(resp.Entries))
@ -301,7 +348,7 @@ func (s *Server) MakeDir(
resp, err := client.MakeDir(ctx, msg.Path)
if err != nil {
return nil, fmt.Errorf("make dir: %w", err)
return nil, envdErr("make dir", err)
}
return connect.NewResponse(&pb.MakeDirResponse{
@ -321,7 +368,7 @@ func (s *Server) RemovePath(
}
if err := client.Remove(ctx, msg.Path); err != nil {
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("remove: %w", err))
return nil, envdErr("remove", err)
}
return connect.NewResponse(&pb.RemovePathResponse{}), nil
@ -373,6 +420,8 @@ func (s *Server) ExecStream(
Error: ev.Error,
},
}
default:
continue
}
if err := stream.Send(&resp); err != nil {
return err
@ -548,6 +597,14 @@ func (s *Server) ListSandboxes(
infos := make([]*pb.SandboxInfo, len(sandboxes))
for i, sb := range sandboxes {
// Paused / restored-paused sandboxes have no active network slot, so
// HostIP is nil — net.IP(nil).String() returns the literal "<nil>"
// which would leak into DB host_ip columns and SDK responses. Emit
// empty string instead.
hostIP := ""
if sb.HostIP != nil {
hostIP = sb.HostIP.String()
}
infos[i] = &pb.SandboxInfo{
SandboxId: sb.ID,
Status: string(sb.Status),
@ -555,7 +612,7 @@ func (s *Server) ListSandboxes(
TemplateId: uuid.UUID(sb.TemplateID).String(),
Vcpus: int32(sb.VCPUs),
MemoryMb: int32(sb.MemoryMB),
HostIp: sb.HostIP.String(),
HostIp: hostIP,
CreatedAtUnix: sb.CreatedAt.Unix(),
LastActiveAtUnix: sb.LastActiveAt.Unix(),
TimeoutSec: int32(sb.TimeoutSec),
@ -588,13 +645,7 @@ func (s *Server) GetSandboxMetrics(
points, err := s.mgr.GetMetrics(msg.SandboxId, msg.Range)
if err != nil {
if strings.Contains(err.Error(), "not found") {
return nil, connect.NewError(connect.CodeNotFound, err)
}
if strings.Contains(err.Error(), "invalid range") {
return nil, connect.NewError(connect.CodeInvalidArgument, err)
}
return nil, connect.NewError(connect.CodeInternal, err)
return nil, mapSandboxError(err)
}
return connect.NewResponse(&pb.GetSandboxMetricsResponse{Points: metricPointsToPB(points)}), nil
@ -606,10 +657,7 @@ func (s *Server) FlushSandboxMetrics(
) (*connect.Response[pb.FlushSandboxMetricsResponse], error) {
pts10m, pts2h, pts24h, err := s.mgr.FlushMetrics(req.Msg.SandboxId)
if err != nil {
if strings.Contains(err.Error(), "not found") {
return nil, connect.NewError(connect.CodeNotFound, err)
}
return nil, connect.NewError(connect.CodeInternal, err)
return nil, mapSandboxError(err)
}
return connect.NewResponse(&pb.FlushSandboxMetricsResponse{
@ -759,7 +807,7 @@ func (s *Server) StartBackground(
pid, err := s.mgr.StartBackground(ctx, msg.SandboxId, msg.Tag, msg.Cmd, msg.Args, msg.Envs, msg.Cwd)
if err != nil {
if strings.Contains(err.Error(), "not found") {
if errors.Is(err, sandbox.ErrNotFound) {
return nil, connect.NewError(connect.CodeNotFound, err)
}
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("start background: %w", err))
@ -777,7 +825,7 @@ func (s *Server) ListProcesses(
) (*connect.Response[pb.ListProcessesResponse], error) {
procs, err := s.mgr.ListProcesses(ctx, req.Msg.SandboxId)
if err != nil {
if strings.Contains(err.Error(), "not found") {
if errors.Is(err, sandbox.ErrNotFound) {
return nil, connect.NewError(connect.CodeNotFound, err)
}
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("list processes: %w", err))
@ -828,7 +876,7 @@ func (s *Server) KillProcess(
}
if err := s.mgr.KillProcess(ctx, msg.SandboxId, pid, tag, signal); err != nil {
if strings.Contains(err.Error(), "not found") {
if errors.Is(err, sandbox.ErrNotFound) {
return nil, connect.NewError(connect.CodeNotFound, err)
}
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("kill process: %w", err))
@ -857,7 +905,7 @@ func (s *Server) ConnectProcess(
events, err := s.mgr.ConnectProcess(ctx, msg.SandboxId, pid, tag)
if err != nil {
if strings.Contains(err.Error(), "not found") {
if errors.Is(err, sandbox.ErrNotFound) {
return connect.NewError(connect.CodeNotFound, err)
}
return connect.NewError(connect.CodeInternal, fmt.Errorf("connect process: %w", err))
@ -889,6 +937,8 @@ func (s *Server) ConnectProcess(
Error: ev.Error,
},
}
default:
continue
}
if err := stream.Send(&resp); err != nil {
return err

View File

@ -6,26 +6,28 @@ import (
"path/filepath"
"sort"
"strings"
"time"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
// IsMinimal reports whether the given team and template IDs represent the
// built-in "minimal" template (both all-zeros).
func IsMinimal(teamID, templateID pgtype.UUID) bool {
return teamID.Bytes == id.PlatformTeamID.Bytes && templateID.Bytes == id.MinimalTemplateID.Bytes
func timeNowNano() int64 { return time.Now().UnixNano() }
// IsSystemTemplate reports whether the given team and template IDs represent a
// built-in system base template (minimal-ubuntu / -alpine / -arch / -fedora):
// platform-owned with a template ID in the reserved range. System templates are
// protected from deletion.
func IsSystemTemplate(teamID, templateID pgtype.UUID) bool {
return teamID.Bytes == id.PlatformTeamID.Bytes && id.IsReservedTemplateID(templateID)
}
// TemplateDir returns the on-disk directory for a template.
// TemplateDir returns the on-disk directory for a template. Every template —
// including the built-in system base templates — lives under the teams tree:
//
// minimal (zeros, zeros): {wrennDir}/images/minimal
// all others: {wrennDir}/images/teams/{base36(teamID)}/{base36(templateID)}
// {wrennDir}/images/teams/{base36(teamID)}/{base36(templateID)}
func TemplateDir(wrennDir string, teamID, templateID pgtype.UUID) string {
if IsMinimal(teamID, templateID) {
return filepath.Join(wrennDir, "images", "minimal")
}
return filepath.Join(wrennDir, "images", "teams",
id.UUIDToBase36(teamID.Bytes),
id.UUIDToBase36(templateID.Bytes))
@ -36,17 +38,64 @@ func TemplateRootfs(wrennDir string, teamID, templateID pgtype.UUID) string {
return filepath.Join(TemplateDir(wrennDir, teamID, templateID), "rootfs.ext4")
}
// PauseSnapshotDir returns the directory for a paused sandbox's snapshot files.
func PauseSnapshotDir(wrennDir, sandboxID string) string {
return filepath.Join(wrennDir, "snapshots", sandboxID)
// IsSnapshotTemplate reports whether dir contains a Cloud Hypervisor memory
// snapshot (state.json + config.json + memory-ranges) alongside the flattened
// rootfs.ext4. Used to distinguish snapshot templates (launch via CH restore)
// from base/disk-only templates (launch via fresh boot).
//
// state.json is CH-authoritative — its presence indicates a complete snapshot.
func IsSnapshotTemplate(dir string) bool {
for _, name := range []string{"state.json", "config.json", "rootfs.ext4"} {
if _, err := os.Stat(filepath.Join(dir, name)); err != nil {
return false
}
}
return true
}
// SandboxesDir returns the directory for running sandbox CoW files.
// SandboxCowName is the filename for a sandbox's CoW rootfs diff, kept inside
// the per-sandbox directory alongside any pause snapshot files.
const SandboxCowName = "rootfs.cow"
// SandboxDir returns the per-sandbox directory under sandboxes/. It holds
// the CoW file and, if the sandbox is paused, the snapshot files.
//
// Layout:
//
// {wrennDir}/sandboxes/{id}/rootfs.cow CoW file (persistent across pause/resume)
// {wrennDir}/sandboxes/{id}/ paused snapshot (config.json, state.json, memory-ranges, wrenn-snapshot.json)
// {wrennDir}/sandboxes/{id}.staging-*/ in-flight Pause writes (cleaned up by swapDir or startup GC)
// {wrennDir}/sandboxes/{id}.trash-*/ mid-swap previous generation (cleaned up by swapDir or startup GC)
func SandboxDir(wrennDir, sandboxID string) string {
return filepath.Join(wrennDir, "sandboxes", sandboxID)
}
// SandboxCowPath returns the path to a sandbox's CoW rootfs diff file.
func SandboxCowPath(wrennDir, sandboxID string) string {
return filepath.Join(SandboxDir(wrennDir, sandboxID), SandboxCowName)
}
// PauseSnapshotDir returns the directory for a paused sandbox's snapshot files.
// Same path as SandboxDir — pause snapshot files live alongside the CoW.
func PauseSnapshotDir(wrennDir, sandboxID string) string {
return SandboxDir(wrennDir, sandboxID)
}
// PauseStagingDir returns a fresh staging directory for an in-flight Pause.
// Each call returns a unique path (timestamped) so concurrent retries do not
// collide.
func PauseStagingDir(wrennDir, sandboxID string) string {
return filepath.Join(wrennDir, "sandboxes",
fmt.Sprintf("%s.staging-%d", sandboxID, timeNowNano()))
}
// SandboxesDir returns the directory for running sandbox CoW files and paused
// snapshot directories.
func SandboxesDir(wrennDir string) string {
return filepath.Join(wrennDir, "sandboxes")
}
// KernelPath returns the path to the Firecracker kernel.
// KernelPath returns the path to the VM kernel.
func KernelPath(wrennDir string) string {
return filepath.Join(wrennDir, "kernels", "vmlinux")
}

View File

@ -9,7 +9,7 @@ import (
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
func TestIsMinimal(t *testing.T) {
func TestIsSystemTemplate(t *testing.T) {
tests := []struct {
name string
teamID pgtype.UUID
@ -17,35 +17,41 @@ func TestIsMinimal(t *testing.T) {
want bool
}{
{
name: "both zeros",
name: "ubuntu (zeros, zeros)",
teamID: id.PlatformTeamID,
templateID: id.MinimalTemplateID,
templateID: id.UbuntuTemplateID,
want: true,
},
{
name: "non-zero team",
teamID: pgtype.UUID{Bytes: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, Valid: true},
templateID: id.MinimalTemplateID,
want: false,
},
{
name: "non-zero template",
name: "fedora (platform, id 3)",
teamID: id.PlatformTeamID,
templateID: pgtype.UUID{Bytes: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, Valid: true},
templateID: id.FedoraTemplateID,
want: true,
},
{
name: "platform, max reserved id",
teamID: id.PlatformTeamID,
templateID: pgtype.UUID{Bytes: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x04, 0x00}, Valid: true}, // 1024
want: true,
},
{
name: "platform, above reserved range",
teamID: id.PlatformTeamID,
templateID: pgtype.UUID{Bytes: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x04, 0x01}, Valid: true}, // 1025
want: false,
},
{
name: "both non-zero",
teamID: pgtype.UUID{Bytes: [16]byte{1}, Valid: true},
templateID: pgtype.UUID{Bytes: [16]byte{2}, Valid: true},
name: "non-platform team, reserved id",
teamID: pgtype.UUID{Bytes: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, Valid: true},
templateID: id.UbuntuTemplateID,
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := IsMinimal(tt.teamID, tt.templateID); got != tt.want {
t.Errorf("IsMinimal() = %v, want %v", got, tt.want)
if got := IsSystemTemplate(tt.teamID, tt.templateID); got != tt.want {
t.Errorf("IsSystemTemplate() = %v, want %v", got, tt.want)
}
})
}
@ -54,9 +60,11 @@ func TestIsMinimal(t *testing.T) {
func TestTemplateDir(t *testing.T) {
wrennDir := "/var/lib/wrenn"
t.Run("minimal", func(t *testing.T) {
got := TemplateDir(wrennDir, id.PlatformTeamID, id.MinimalTemplateID)
want := filepath.Join(wrennDir, "images", "minimal")
t.Run("system base template (ubuntu) lives under teams", func(t *testing.T) {
got := TemplateDir(wrennDir, id.PlatformTeamID, id.UbuntuTemplateID)
want := filepath.Join(wrennDir, "images", "teams",
id.UUIDToBase36(id.PlatformTeamID.Bytes),
id.UUIDToBase36(id.UbuntuTemplateID.Bytes))
if got != want {
t.Errorf("TemplateDir() = %q, want %q", got, want)
}
@ -88,8 +96,11 @@ func TestTemplateDir(t *testing.T) {
func TestTemplateRootfs(t *testing.T) {
wrennDir := "/var/lib/wrenn"
got := TemplateRootfs(wrennDir, id.PlatformTeamID, id.MinimalTemplateID)
want := filepath.Join(wrennDir, "images", "minimal", "rootfs.ext4")
got := TemplateRootfs(wrennDir, id.PlatformTeamID, id.UbuntuTemplateID)
want := filepath.Join(wrennDir, "images", "teams",
id.UUIDToBase36(id.PlatformTeamID.Bytes),
id.UUIDToBase36(id.UbuntuTemplateID.Bytes),
"rootfs.ext4")
if got != want {
t.Errorf("TemplateRootfs() = %q, want %q", got, want)
}
@ -97,12 +108,20 @@ func TestTemplateRootfs(t *testing.T) {
func TestPauseSnapshotDir(t *testing.T) {
got := PauseSnapshotDir("/var/lib/wrenn", "cl-abc123")
want := "/var/lib/wrenn/snapshots/cl-abc123"
want := "/var/lib/wrenn/sandboxes/cl-abc123"
if got != want {
t.Errorf("PauseSnapshotDir() = %q, want %q", got, want)
}
}
func TestPauseStagingDir(t *testing.T) {
got := PauseStagingDir("/var/lib/wrenn", "cl-abc123")
prefix := "/var/lib/wrenn/sandboxes/cl-abc123.staging-"
if len(got) <= len(prefix) || got[:len(prefix)] != prefix {
t.Errorf("PauseStagingDir() = %q, want prefix %q", got, prefix)
}
}
func TestSandboxesDir(t *testing.T) {
got := SandboxesDir("/var/lib/wrenn")
want := "/var/lib/wrenn/sandboxes"

View File

@ -9,12 +9,13 @@ import (
type SandboxStatus string
const (
StatusPending SandboxStatus = "pending"
StatusRunning SandboxStatus = "running"
StatusPausing SandboxStatus = "pausing"
StatusPaused SandboxStatus = "paused"
StatusStopped SandboxStatus = "stopped"
StatusError SandboxStatus = "error"
StatusPending SandboxStatus = "pending"
StatusRunning SandboxStatus = "running"
StatusPausing SandboxStatus = "pausing"
StatusPaused SandboxStatus = "paused"
StatusSnapshotting SandboxStatus = "snapshotting"
StatusStopped SandboxStatus = "stopped"
StatusError SandboxStatus = "error"
)
// Sandbox holds all state for a running sandbox on this host.

View File

@ -39,3 +39,19 @@ func (a *SlotAllocator) Release(index int) {
defer a.mu.Unlock()
delete(a.inUse, index)
}
// Reserve marks a specific slot index as in use. Returns an error if the
// index is out of range or already taken. Used on resume to re-acquire the
// slot a sandbox previously held so its host-reachable IP stays stable.
func (a *SlotAllocator) Reserve(index int) error {
if index < 1 || index > 32767 {
return fmt.Errorf("slot index out of range: %d", index)
}
a.mu.Lock()
defer a.mu.Unlock()
if a.inUse[index] {
return fmt.Errorf("slot %d already in use", index)
}
a.inUse[index] = true
return nil
}

View File

@ -42,6 +42,43 @@ func CleanupStaleNamespaces() {
// Clean up any stale wrenn iptables rules referencing old veth interfaces.
cleanupStaleIptablesRules()
// Flush any orphan conntrack rows for sandbox host-IPs. After a wedged
// destroy the netfilter conntrack table can retain DNAT/SNAT entries
// pointing at vanished interfaces, which makes new flows to recycled
// slot IPs misroute. Best-effort; missing conntrack binary is OK.
flushStaleConntrack()
}
// flushStaleConntrack removes conntrack rows referencing the sandbox host
// IP range (10.11.0.0/16) and the namespace veth range (10.12.0.0/16).
// Best-effort: silently skipped if conntrack(8) is absent.
func flushStaleConntrack() {
if _, err := exec.LookPath("conntrack"); err != nil {
slog.Debug("conntrack binary not found, skipping flush")
return
}
flushed := 0
for _, cidr := range []string{"10.11.0.0/16", "10.12.0.0/16"} {
for _, dir := range []string{"--src", "--dst"} {
out, err := exec.Command("conntrack", "-D", dir, cidr).CombinedOutput()
if err != nil {
// conntrack -D exits 1 when no entries match; not an
// error from our perspective.
slog.Debug("conntrack flush", "cidr", cidr, "dir", dir, "error", err)
continue
}
// Output looks like "conntrack v1.4.x ... 3 flow entries have been deleted."
// We only log INFO when at least one row was actually removed.
if strings.Contains(string(out), "have been deleted") &&
!strings.Contains(string(out), "0 flow entries") {
flushed++
}
}
}
if flushed > 0 {
slog.Info("flushed stale conntrack entries", "matched_filters", flushed)
}
}
// cleanupStaleIptablesRules removes host iptables rules that reference
@ -176,7 +213,7 @@ func NewSlot(index int) *Slot {
// CreateNetwork sets up the full network topology for a sandbox:
// - Named network namespace
// - Veth pair bridging host and namespace
// - TAP device inside namespace for Firecracker
// - TAP device inside namespace for Cloud Hypervisor
// - Routes and NAT rules for connectivity
//
// On error, all partially created resources are rolled back.
@ -430,6 +467,9 @@ func CreateNetwork(slot *Slot) error {
rollback()
return fmt.Errorf("add masquerade rule: %w", err)
}
rollbacks = append(rollbacks, func() {
_ = iptablesHost("-t", "nat", "-D", "POSTROUTING", "-s", fmt.Sprintf("%s/32", slot.VpeerIP.String()), "-o", defaultIface, "-j", "MASQUERADE")
})
slog.Info("network created",
"ns", slot.NamespaceID,
@ -444,6 +484,9 @@ func CreateNetwork(slot *Slot) error {
// All steps are attempted even if earlier ones fail. Returns a combined
// error describing which cleanup steps failed.
func RemoveNetwork(slot *Slot) error {
if slot == nil {
return nil
}
var errs []error
defaultIface, _ := getDefaultInterface()

View File

@ -41,6 +41,28 @@ type ExecFunc func(ctx context.Context, req *connect.Request[pb.ExecRequest]) (*
// accumulated log entries. Used for per-step DB progress updates.
type ProgressFunc func(step int, entries []BuildLogEntry)
// StepStartFunc is called immediately before a step begins executing.
type StepStartFunc func(step int, phase string, st Step)
// OutputChunkFunc is called with each raw output chunk produced by a streaming
// RUN step, as it arrives.
type OutputChunkFunc func(step int, data []byte)
// PtyChunk is one event from a streaming PTY exec: either an output chunk
// (Data set) or the terminal result (Done set, Exit/Err populated).
type PtyChunk struct {
Data []byte
Done bool
Exit int32
Err error
}
// StreamExecFunc runs shellCmd in a PTY inside sandboxID and returns a channel
// of PtyChunk events. The channel is closed after a Done chunk (or an Err
// chunk). It is the streaming counterpart of ExecFunc, used for RUN steps so
// build output reaches the client live.
type StreamExecFunc func(ctx context.Context, sandboxID, shellCmd string) (<-chan PtyChunk, error)
// Execute runs steps sequentially against sandboxID using execFn.
//
// - phase labels the log entries (e.g., "pre-build", "recipe", "post-build").
@ -63,6 +85,9 @@ func Execute(
defaultTimeout time.Duration,
bctx *ExecContext,
execFn ExecFunc,
streamFn StreamExecFunc,
onStepStart StepStartFunc,
onChunk OutputChunkFunc,
onProgress ProgressFunc,
) (entries []BuildLogEntry, nextStep int, ok bool) {
if defaultTimeout <= 0 {
@ -73,6 +98,9 @@ func Execute(
for _, st := range steps {
step++
slog.Info("executing build step", "phase", phase, "step", step, "instruction", st.Raw)
if onStepStart != nil {
onStepStart(step, phase, st)
}
switch st.Kind {
case KindENV:
@ -120,7 +148,13 @@ func Execute(
if st.Timeout > 0 {
timeout = st.Timeout
}
entry, succeeded := execRun(ctx, st, sandboxID, phase, step, timeout, bctx, execFn)
var entry BuildLogEntry
var succeeded bool
if streamFn != nil {
entry, succeeded = execRunStreaming(ctx, st, sandboxID, phase, step, timeout, bctx, streamFn, onChunk)
} else {
entry, succeeded = execRun(ctx, st, sandboxID, phase, step, timeout, bctx, execFn)
}
entries = append(entries, entry)
if !succeeded {
return entries, step, false
@ -171,6 +205,66 @@ func execRun(
return entry, entry.Ok
}
// execRunStreaming runs a RUN step in a PTY via streamFn, forwarding each
// output chunk to onChunk as it arrives. The merged PTY output is also
// accumulated into the returned BuildLogEntry.Stdout for cold log viewing.
// A PTY merges stdout and stderr onto one stream, so Stderr stays empty
// unless the exec itself fails to start.
func execRunStreaming(
ctx context.Context,
st Step,
sandboxID, phase string,
step int,
timeout time.Duration,
bctx *ExecContext,
streamFn StreamExecFunc,
onChunk OutputChunkFunc,
) (BuildLogEntry, bool) {
execCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
start := time.Now()
entry := BuildLogEntry{Step: step, Phase: phase, Cmd: st.Raw}
ch, err := streamFn(execCtx, sandboxID, bctx.WrappedCommand(st.Shell))
if err != nil {
entry.Stderr = fmt.Sprintf("exec error: %v", err)
entry.Elapsed = time.Since(start).Milliseconds()
return entry, false
}
var out []byte
gotDone := false
for chunk := range ch {
if chunk.Err != nil {
entry.Stdout = string(out)
entry.Stderr = fmt.Sprintf("exec error: %v", chunk.Err)
entry.Elapsed = time.Since(start).Milliseconds()
return entry, false
}
if chunk.Done {
entry.Exit = chunk.Exit
gotDone = true
continue
}
out = append(out, chunk.Data...)
if onChunk != nil {
onChunk(step, chunk.Data)
}
}
entry.Stdout = string(out)
entry.Elapsed = time.Since(start).Milliseconds()
// A channel that closes without a Done chunk means the stream ended
// early (cancelled/aborted). Treat that as a failure, never a success.
if !gotDone {
entry.Stderr = "exec error: build step stream ended without completion"
return entry, false
}
entry.Ok = entry.Exit == 0
return entry, entry.Ok
}
// execUser creates a unix user (if not exists), grants passwordless sudo,
// and updates bctx.User for subsequent steps.
func execUser(

View File

@ -0,0 +1,28 @@
package sandbox
import (
"fmt"
"os/exec"
"strings"
)
// DetectCHVersion runs the cloud-hypervisor binary with --version and
// parses the semver from the output (e.g. "cloud-hypervisor v43.0" → "43.0").
func DetectCHVersion(binaryPath string) (string, error) {
out, err := exec.Command(binaryPath, "--version").Output()
if err != nil {
return "", fmt.Errorf("run %s --version: %w", binaryPath, err)
}
line := strings.TrimSpace(string(out))
for field := range strings.FieldsSeq(line) {
v := strings.TrimPrefix(field, "v")
if v != field || strings.Contains(field, ".") {
if strings.Count(v, ".") >= 1 {
return v, nil
}
}
}
return "", fmt.Errorf("could not parse version from cloud-hypervisor output: %q", line)
}

View File

@ -10,12 +10,22 @@ import (
// ConnTracker tracks active proxy connections for a single sandbox and
// provides a drain mechanism for pre-pause graceful shutdown.
// It is safe for concurrent use.
//
// Internally we do not use sync.WaitGroup because Wait cannot be interrupted
// — a stuck handler would pin the waiter goroutine forever. Instead we keep
// an explicit counter guarded by mu plus a zeroCh that is closed when the
// counter transitions to 0, allowing Drain/ForceClose to select on it
// alongside cancellation and timeout signals without spawning helper
// goroutines that could leak across Reset boundaries.
type ConnTracker struct {
draining atomic.Bool
wg sync.WaitGroup
mu sync.Mutex
count int
zeroCh chan struct{} // closed when count drops to 0; recreated on next Acquire
// cancelMu protects cancelDrain so Reset can signal a timed-out Drain
// goroutine to exit, preventing goroutine leaks on repeated pause failures.
// to exit early.
cancelMu sync.Mutex
cancelDrain chan struct{}
@ -40,13 +50,18 @@ func (t *ConnTracker) Acquire() bool {
if t.draining.Load() {
return false
}
t.wg.Add(1)
// Re-check after Add: Drain may have set draining between our Load
// and Add. If so, undo the Add and reject the connection.
t.mu.Lock()
// Re-check under mu so a concurrent Drain that flipped draining cannot
// race past us with the counter already incremented.
if t.draining.Load() {
t.wg.Done()
t.mu.Unlock()
return false
}
t.count++
if t.count == 1 {
t.zeroCh = make(chan struct{})
}
t.mu.Unlock()
return true
}
@ -63,11 +78,32 @@ func (t *ConnTracker) Context() context.Context {
// Release marks one connection as complete. Must be called exactly once
// per successful Acquire.
func (t *ConnTracker) Release() {
t.wg.Done()
t.mu.Lock()
t.count--
if t.count == 0 && t.zeroCh != nil {
close(t.zeroCh)
t.zeroCh = nil
}
t.mu.Unlock()
}
// waitDrain returns a channel that closes when the in-flight count is zero,
// or a closed channel immediately if there's nothing in flight.
func (t *ConnTracker) waitDrain() <-chan struct{} {
t.mu.Lock()
defer t.mu.Unlock()
if t.count == 0 {
ch := make(chan struct{})
close(ch)
return ch
}
return t.zeroCh
}
// Drain marks the tracker as draining (all future Acquire calls return
// false) and waits up to timeout for in-flight connections to finish.
// Returns when the count hits 0, Reset is called, or the timeout fires —
// whichever happens first. No goroutine is leaked on timeout.
func (t *ConnTracker) Drain(timeout time.Duration) {
t.draining.Store(true)
@ -76,16 +112,9 @@ func (t *ConnTracker) Drain(timeout time.Duration) {
t.cancelDrain = cancel
t.cancelMu.Unlock()
done := make(chan struct{})
go func() {
t.wg.Wait()
close(done)
}()
select {
case <-done:
case <-t.waitDrain():
case <-cancel:
// Reset was called; stop waiting.
case <-time.After(timeout):
}
}
@ -101,22 +130,16 @@ func (t *ConnTracker) ForceClose() {
}
t.ctxMu.Unlock()
// Wait briefly for force-closed connections to call Release().
done := make(chan struct{})
go func() {
t.wg.Wait()
close(done)
}()
select {
case <-done:
case <-t.waitDrain():
case <-time.After(2 * time.Second):
}
}
// Reset re-enables the tracker after a failed drain. This allows the
// sandbox to accept proxy connections again if the pause operation fails
// and the VM is resumed. It also cancels any lingering Drain goroutine
// and creates a fresh context for new connections.
// and the VM is resumed. It also signals any lingering Drain to exit and
// creates a fresh context for new connections.
func (t *ConnTracker) Reset() {
t.cancelMu.Lock()
if t.cancelDrain != nil {
@ -130,7 +153,6 @@ func (t *ConnTracker) Reset() {
}
t.cancelMu.Unlock()
// Replace the cancelled context with a fresh one.
t.ctxMu.Lock()
t.ctx, t.cancel = context.WithCancel(context.Background())
t.ctxMu.Unlock()

View File

@ -1,30 +0,0 @@
package sandbox
import (
"fmt"
"os/exec"
"strings"
)
// DetectFirecrackerVersion runs the firecracker binary with --version and
// parses the semver from the output (e.g. "Firecracker v1.14.1" → "1.14.1").
func DetectFirecrackerVersion(binaryPath string) (string, error) {
out, err := exec.Command(binaryPath, "--version").Output()
if err != nil {
return "", fmt.Errorf("run %s --version: %w", binaryPath, err)
}
// Output is typically "Firecracker v1.14.1\n" or similar.
line := strings.TrimSpace(string(out))
for _, field := range strings.Fields(line) {
v := strings.TrimPrefix(field, "v")
if v != field || strings.Contains(field, ".") {
// Either had a "v" prefix or contains a dot — likely the version.
if strings.Count(v, ".") >= 1 {
return v, nil
}
}
}
return "", fmt.Errorf("could not parse version from firecracker output: %q", line)
}

View File

@ -9,6 +9,8 @@ import (
"strconv"
"strings"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/internal/layout"
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
@ -29,13 +31,9 @@ func EnsureImageSizes(wrennDir string, targetMB int) error {
}
targetBytes := int64(targetMB) * 1024 * 1024
// Expand the built-in minimal image.
minimalRootfs := layout.TemplateRootfs(wrennDir, id.PlatformTeamID, id.MinimalTemplateID)
if err := expandImage(minimalRootfs, targetBytes, targetMB); err != nil {
return err
}
// Walk teams/{teamDir}/{templateDir}/rootfs.ext4 two levels deep.
// Walk teams/{teamDir}/{templateDir}/rootfs.ext4 two levels deep. The
// built-in system base templates live under teams/{base36(0)}/... so this
// covers them too.
teamsDir := layout.TeamsDir(wrennDir)
teamEntries, err := os.ReadDir(teamsDir)
if err != nil {
@ -104,12 +102,19 @@ func ParseSizeToMB(s string) (int, error) {
}
}
// ShrinkMinimalImage shrinks the built-in minimal rootfs back to its minimum
// size using resize2fs -M. This is the inverse of EnsureImageSizes and should
// be called during graceful shutdown so the image is stored compactly on disk.
func ShrinkMinimalImage(wrennDir string) {
minimalRootfs := layout.TemplateRootfs(wrennDir, id.PlatformTeamID, id.MinimalTemplateID)
shrinkImage(minimalRootfs)
// ShrinkSystemImages shrinks the built-in system base rootfs images back to
// their minimum size using resize2fs -M. This is the inverse of
// EnsureImageSizes and should be called during graceful shutdown so the images
// are stored compactly on disk.
func ShrinkSystemImages(wrennDir string) {
for _, tmplID := range []pgtype.UUID{
id.UbuntuTemplateID,
id.AlpineTemplateID,
id.ArchTemplateID,
id.FedoraTemplateID,
} {
shrinkImage(layout.TemplateRootfs(wrennDir, id.PlatformTeamID, tmplID))
}
}
// shrinkImage shrinks a single rootfs image to its minimum size.

View File

@ -0,0 +1,187 @@
// Package sandbox: launching a fresh sandbox from a snapshot template.
//
// Mirrors the pause/resume restore path but produces a brand-new sandbox each
// call: fresh ID, fresh network slot, fresh CoW on top of the template's
// flattened rootfs. The CH process is launched with --restore + lazy memory
// (UFFD), and the post-restore memory loader is started so any subsequent
// CreateSnapshot taken from this descendant is self-contained (the
// pause-resume-pause chain guarantee, applied to template lineages).
package sandbox
import (
"context"
"fmt"
"log/slog"
"os"
"time"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/internal/devicemapper"
"git.omukk.dev/wrenn/wrenn/internal/layout"
"git.omukk.dev/wrenn/wrenn/internal/models"
"git.omukk.dev/wrenn/wrenn/internal/network"
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
// createFromSnapshotTemplate launches a new sandbox from a snapshot-template
// directory (state.json + config.json + memory-ranges + rootfs.ext4).
//
// The caller has already verified IsSnapshotTemplate(templateDir). Resources
// acquired here are rolled back on any failure; on success the sandbox is
// registered in m.boxes and runs in StatusRunning.
func (m *Manager) createFromSnapshotTemplate(
ctx context.Context,
sandboxID string,
teamID, templateID pgtype.UUID,
vcpus, memoryMB, timeoutSec, diskSizeMB int,
defaultUser string,
defaultEnv map[string]string,
) (*models.Sandbox, int64, error) {
templateDir := layout.TemplateDir(m.cfg.WrennDir, teamID, templateID)
baseRootfs := layout.TemplateRootfs(m.cfg.WrennDir, teamID, templateID)
meta, err := readSnapshotMeta(templateDir)
if err != nil {
return nil, 0, fmt.Errorf("read snapshot meta: %w", err)
}
if meta.SandboxDir == "" {
// CH's saved config.json hardcodes a tmpfs disk path; meta.SandboxDir
// is that exact path. A snapshot template without it cannot be launched.
return nil, 0, fmt.Errorf("snapshot template %s missing sandbox_dir in meta", templateDir)
}
// Acquire shared read-only loop on the flattened rootfs. Many sandboxes
// can share this loop concurrently — refcounted in LoopRegistry.
originLoop, err := m.loops.Acquire(baseRootfs)
if err != nil {
return nil, 0, fmt.Errorf("acquire loop: %w", err)
}
originSize, err := devicemapper.OriginSizeBytes(originLoop)
if err != nil {
m.loops.Release(baseRootfs)
return nil, 0, fmt.Errorf("origin size: %w", err)
}
// Per-sandbox CoW on top of the shared origin.
dmName := "wrenn-" + sandboxID
if err := os.MkdirAll(layout.SandboxDir(m.cfg.WrennDir, sandboxID), 0o755); err != nil {
m.loops.Release(baseRootfs)
return nil, 0, fmt.Errorf("create sandbox dir: %w", err)
}
cowPath := layout.SandboxCowPath(m.cfg.WrennDir, sandboxID)
cowSize := max(int64(diskSizeMB)*1024*1024, originSize)
dmDev, err := devicemapper.CreateSnapshot(dmName, originLoop, cowPath, originSize, cowSize)
if err != nil {
m.loops.Release(baseRootfs)
return nil, 0, fmt.Errorf("create dm-snapshot: %w", err)
}
res := &createResources{
sandboxID: sandboxID,
loops: m.loops,
loopImage: baseRootfs,
dmDevice: dmDev,
cowPath: cowPath,
slots: m.slots,
}
slotIdx, err := m.slots.Allocate()
if err != nil {
res.rollback()
return nil, 0, fmt.Errorf("allocate network slot: %w", err)
}
res.slotIdx = slotIdx
slot := network.NewSlot(slotIdx)
if err := network.CreateNetwork(slot); err != nil {
res.rollback()
return nil, 0, fmt.Errorf("create network: %w", err)
}
res.slot = slot
// CH's saved config.json hardcodes a tmpfs disk path; meta.SandboxDir is
// that exact path (carried forward verbatim across template chains, so a
// snapshot-of-a-snapshot resolves to the root ancestor's path). The
// launcher mounts a fresh tmpfs there inside its private mount namespace
// and symlinks rootfs.ext4 → our new dm device.
vmCfg := m.buildRestoreVMConfig(restoreInputs{
sandboxID: sandboxID,
templateID: id.UUIDString(templateID),
snapDir: templateDir,
rootfsPath: dmDev.DevicePath,
vcpus: vcpus,
memoryMB: memoryMB,
slot: slot,
sandboxDir: meta.SandboxDir,
})
client, err := m.launchRestoredVM(ctx, vmCfg, slot.HostIP.String())
if err != nil {
res.rollback()
return nil, 0, err
}
res.vm = m.vm
envdVersion, _ := client.FetchVersion(ctx)
now := time.Now()
sb := &sandboxState{
Sandbox: models.Sandbox{
ID: sandboxID,
Status: models.StatusRunning,
TemplateTeamID: teamID.Bytes,
TemplateID: templateID.Bytes,
VCPUs: vcpus,
MemoryMB: memoryMB,
TimeoutSec: timeoutSec,
SlotIndex: slotIdx,
HostIP: slot.HostIP,
RootfsPath: dmDev.DevicePath,
CreatedAt: now,
LastActiveAt: now,
Metadata: m.buildMetadata(envdVersion),
},
slot: slot,
connTracker: &ConnTracker{},
dmDevice: dmDev,
baseImagePath: baseRootfs,
sandboxDirOverride: meta.SandboxDir,
}
sb.client.Store(client)
m.mu.Lock()
m.boxes[sandboxID] = sb
m.mu.Unlock()
// /init lifecycle bump then start the memory loader. Loader is required
// so any future CreateSnapshot taken from this descendant captures all
// guest pages (otherwise SEEK_DATA/SEEK_HOLE would emit holes for the
// still-lazy UFFD pages — silent corruption across template chains).
m.initAndStartMemoryLoader(ctx, sb, defaultUser, id.UUIDString(templateID), defaultEnv)
m.startSampler(sb)
m.startCrashWatcher(sb)
slog.Info("sandbox launched from snapshot template",
"id", sandboxID,
"team_id", teamID,
"template_id", templateID,
"sandbox_dir", meta.SandboxDir,
"host_ip", slot.HostIP.String(),
"dm_device", dmDev.DevicePath,
)
return &sb.Sandbox, cowSize, nil
}
// templateExists returns true if a snapshot template already lives at
// TemplateDir(team, templateID). Used by CreateSnapshot to refuse silent
// overwrites — every snapshot must land in a fresh templateID.
func (m *Manager) templateExists(teamID, templateID pgtype.UUID) bool {
dir := layout.TemplateDir(m.cfg.WrennDir, teamID, templateID)
if _, err := os.Stat(dir); err != nil {
return false
}
return layout.IsSnapshotTemplate(dir)
}

File diff suppressed because it is too large Load Diff

1180
internal/sandbox/pause.go Normal file

File diff suppressed because it is too large Load Diff

View File

@ -1,13 +1,14 @@
package sandbox
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"strconv"
"strings"
"syscall"
"git.omukk.dev/wrenn/wrenn/internal/envdclient"
)
@ -48,42 +49,43 @@ func readCPUStat(pid int) (cpuStat, error) {
return cpuStat{utime: utime, stime: stime}, nil
}
// readEnvdMemUsed fetches mem_used from envd's /metrics endpoint. Returns
// guest-side total - MemAvailable (actual process memory, excluding reclaimable
// page cache). VmRSS of the Firecracker process includes guest page cache and
// never decreases, so this is the accurate metric for dashboard display.
func readEnvdMemUsed(client *envdclient.Client) (int64, error) {
resp, err := client.HTTPClient().Get(client.BaseURL() + "/metrics")
// envdMetrics holds metric values read from envd's /metrics endpoint.
type envdMetrics struct {
MemBytes int64
DiskBytes int64
}
// readEnvdMetrics fetches mem_used and disk_used from envd's /metrics endpoint.
// Returns guest-side process memory (total - available) and filesystem usage
// from statfs("/"). These are the guest-visible metrics users care about.
func readEnvdMetrics(ctx context.Context, client *envdclient.Client) (envdMetrics, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, client.BaseURL()+"/metrics", nil)
if err != nil {
return 0, fmt.Errorf("fetch envd metrics: %w", err)
return envdMetrics{}, fmt.Errorf("build metrics request: %w", err)
}
resp, err := client.HTTPClient().Do(req)
if err != nil {
return envdMetrics{}, fmt.Errorf("fetch envd metrics: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
return 0, fmt.Errorf("envd metrics: status %d", resp.StatusCode)
return envdMetrics{}, fmt.Errorf("envd metrics: status %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return 0, fmt.Errorf("read envd metrics body: %w", err)
return envdMetrics{}, fmt.Errorf("read envd metrics body: %w", err)
}
var m struct {
MemUsed int64 `json:"mem_used"`
MemUsed int64 `json:"mem_used"`
DiskUsed int64 `json:"disk_used"`
}
if err := json.Unmarshal(body, &m); err != nil {
return 0, fmt.Errorf("decode envd metrics: %w", err)
return envdMetrics{}, fmt.Errorf("decode envd metrics: %w", err)
}
return m.MemUsed, nil
}
// readDiskAllocated returns the actual allocated bytes (not apparent size)
// of the file at path. This uses stat's block count × 512.
func readDiskAllocated(path string) (int64, error) {
var stat syscall.Stat_t
if err := syscall.Stat(path, &stat); err != nil {
return 0, fmt.Errorf("stat %s: %w", path, err)
}
return stat.Blocks * 512, nil
return envdMetrics{MemBytes: m.MemUsed, DiskBytes: m.DiskUsed}, nil
}

186
internal/sandbox/punch.go Normal file
View File

@ -0,0 +1,186 @@
// Package sandbox: post-snapshot hole punching for memory-ranges files.
//
// CH v52's SEEK_DATA/SEEK_HOLE snapshot writer only skips ranges already
// hole in the source memfd. Pages the guest never reported as free are
// written verbatim — including pages whose contents happen to be all zero
// (fresh allocations the guest scribbled then released without telling the
// balloon driver). Walking the resulting file and punching any 4 KiB block
// of zeros recovers that space without any guest cooperation.
package sandbox
import (
"errors"
"fmt"
"io"
"log/slog"
"os"
"path/filepath"
"strings"
"golang.org/x/sys/unix"
)
const (
// punchBlockSize is the granularity at which we test for zero runs and
// issue FALLOC_FL_PUNCH_HOLE. Matches the kernel page size and the
// minimum hole size on ext4.
punchBlockSize = 4096
// punchReadSize is the IO chunk size used by the scan loop. We read
// many blocks per syscall and split them in-memory so a 20 GiB
// memory-ranges file costs ~20K read(2) syscalls instead of ~5M.
// Crucial under single-disk hosts where each syscall otherwise
// contends with sshd / journal IO.
punchReadSize = 1 << 20 // 1 MiB = 256 blocks
)
// punchZeroPagesInDir runs punchZeroPages on every memory* file in dir.
// CH writes its memory dump as one or more files prefixed "memory" inside
// the snapshot directory; everything else (config.json, state.json) is
// metadata and untouched.
func punchZeroPagesInDir(dir string) {
entries, err := os.ReadDir(dir)
if err != nil {
slog.Warn("punch: read snapshot dir", "dir", dir, "error", err)
return
}
for _, e := range entries {
if e.IsDir() || !strings.HasPrefix(e.Name(), "memory") {
continue
}
path := filepath.Join(dir, e.Name())
before, after, err := punchZeroPages(path)
if err != nil {
slog.Warn("punch: zero-page scan failed", "path", path, "error", err)
continue
}
slog.Info("punch: zero-page scan done",
"path", path,
"alloc_before", before,
"alloc_after", after,
"reclaimed", before-after)
}
}
// punchZeroPages scans path block-by-block, batching runs of all-zero 4 KiB
// blocks and punching them out via FALLOC_FL_PUNCH_HOLE. Existing holes are
// skipped via SEEK_DATA so a partially-sparse input stays cheap to scan.
//
// Returns the file's disk allocation (st_blocks * 512) before and after.
func punchZeroPages(path string) (int64, int64, error) {
f, err := os.OpenFile(path, os.O_RDWR, 0)
if err != nil {
return 0, 0, err
}
defer f.Close()
stBefore, err := statBlocks(f)
if err != nil {
return 0, 0, fmt.Errorf("stat before: %w", err)
}
fi, err := f.Stat()
if err != nil {
return 0, 0, fmt.Errorf("stat: %w", err)
}
size := fi.Size()
buf := make([]byte, punchReadSize)
off := int64(0)
for off < size {
// Skip ahead to next data region; nothing to do in holes.
next, err := f.Seek(off, 3) // SEEK_DATA = 3
if err != nil {
if errors.Is(err, io.EOF) || errors.Is(err, unix.ENXIO) {
break
}
return 0, 0, fmt.Errorf("seek_data @ %d: %w", off, err)
}
off = next &^ (punchBlockSize - 1) // align down to block
// Find end of this data extent.
endData, err := f.Seek(off, 4) // SEEK_HOLE = 4
if err != nil {
return 0, 0, fmt.Errorf("seek_hole @ %d: %w", off, err)
}
// Scan [off, endData) chunk by chunk; batch zero runs across both
// intra-chunk and inter-chunk boundaries so a contiguous zero
// region is punched in a single fallocate.
zeroStart := int64(-1)
cur := off
for cur < endData {
toRead := min(int64(len(buf)), endData-cur)
n, err := readAt(f, buf[:toRead], cur)
if err != nil {
return 0, 0, fmt.Errorf("read @ %d: %w", cur, err)
}
if n == 0 {
break
}
// Walk the chunk one block at a time, tracking zero runs.
for blkOff := 0; blkOff < n; blkOff += punchBlockSize {
blkEnd := min(blkOff+punchBlockSize, n)
blk := buf[blkOff:blkEnd]
blkAbs := cur + int64(blkOff)
if isZero(blk) && len(blk) == punchBlockSize {
if zeroStart < 0 {
zeroStart = blkAbs
}
} else if zeroStart >= 0 {
if err := punch(f, zeroStart, blkAbs-zeroStart); err != nil {
return 0, 0, err
}
zeroStart = -1
}
}
cur += int64(n)
}
if zeroStart >= 0 {
if err := punch(f, zeroStart, cur-zeroStart); err != nil {
return 0, 0, err
}
}
off = endData
}
stAfter, err := statBlocks(f)
if err != nil {
return 0, 0, fmt.Errorf("stat after: %w", err)
}
return stBefore, stAfter, nil
}
func punch(f *os.File, off, length int64) error {
mode := uint32(unix.FALLOC_FL_PUNCH_HOLE | unix.FALLOC_FL_KEEP_SIZE)
if err := unix.Fallocate(int(f.Fd()), mode, off, length); err != nil {
return fmt.Errorf("fallocate punch @ %d len %d: %w", off, length, err)
}
return nil
}
func readAt(f *os.File, buf []byte, off int64) (int, error) {
n, err := f.ReadAt(buf, off)
if err == io.EOF {
return n, nil
}
return n, err
}
func isZero(b []byte) bool {
for _, x := range b {
if x != 0 {
return false
}
}
return true
}
func statBlocks(f *os.File) (int64, error) {
var st unix.Stat_t
if err := unix.Fstat(int(f.Fd()), &st); err != nil {
return 0, err
}
return int64(st.Blocks) * 512, nil
}

118
internal/sandbox/restore.go Normal file
View File

@ -0,0 +1,118 @@
// Package sandbox: shared CH-restore helpers used by both Resume (paused →
// running) and the snapshot-template launch path (template → fresh sandbox).
//
// The two callers diverge in how they acquire resources (slot, dm-snapshot,
// sandbox identity) but converge on:
//
// build VMConfig → CreateFromSnapshot → vm.Resume → wait envd → balloon deflate
//
// These steps are extracted here so the sequence — and its quirks (paused
// post-restore state, balloon best-effort, restored disk path baked into
// CH's config.json) — has a single source of truth.
package sandbox
import (
"context"
"fmt"
"log/slog"
"path/filepath"
"git.omukk.dev/wrenn/wrenn/internal/envdclient"
"git.omukk.dev/wrenn/wrenn/internal/network"
"git.omukk.dev/wrenn/wrenn/internal/vm"
)
// restoreInputs is the common set of fields needed to build a restore VMConfig.
type restoreInputs struct {
sandboxID string // VM identity for the new CH process (sock path, log file)
templateID string // forwarded to envd via PostInit (informational)
snapDir string // directory containing CH snapshot artefacts
rootfsPath string // /dev/mapper/wrenn-{newID} — per-sandbox dm-snapshot
vcpus int
memoryMB int
slot *network.Slot
sandboxDir string // override for VMConfig.SandboxDir; "" = default
}
// buildRestoreVMConfig assembles the VMConfig used to launch a CH process in
// restore mode. sandboxDir, when non-empty, overrides the default
// "/tmp/ch-vm-{SandboxID}" — required when the snapshot's saved config.json
// points at a different sandbox's tmpfs path (i.e. snapshot-template launch).
func (m *Manager) buildRestoreVMConfig(in restoreInputs) vm.VMConfig {
return vm.VMConfig{
SandboxID: in.sandboxID,
TemplateID: in.templateID,
KernelPath: m.cfg.KernelPath,
RootfsPath: in.rootfsPath,
VCPUs: in.vcpus,
MemoryMB: in.memoryMB,
NetworkNamespace: in.slot.NamespaceID,
TapDevice: in.slot.TapName,
TapMAC: in.slot.TapMAC,
GuestIP: in.slot.GuestIP,
GatewayIP: in.slot.TapIP,
NetMask: in.slot.GuestNetMask,
VMMBin: m.cfg.VMMBin,
LogDir: filepath.Join(m.cfg.WrennDir, "logs"),
RestoreFromDir: in.snapDir,
RestoreLazyMemory: true,
SandboxDir: in.sandboxDir,
}
}
// launchRestoredVM starts CH in restore mode, resumes the vCPUs, waits for
// envd to be reachable, then best-effort deflates the balloon. On any failure
// the partial VM is destroyed before returning — the caller is responsible
// for tearing down dm/network/slot.
//
// Returns the connected envd client on success.
func (m *Manager) launchRestoredVM(ctx context.Context, vmCfg vm.VMConfig, hostIP string) (*envdclient.Client, error) {
if _, err := m.vm.CreateFromSnapshot(ctx, vmCfg); err != nil {
return nil, fmt.Errorf("create from snapshot: %w", err)
}
if err := m.vm.Resume(ctx, vmCfg.SandboxID); err != nil {
_ = m.vm.Destroy(context.Background(), vmCfg.SandboxID)
return nil, fmt.Errorf("vm resume: %w", err)
}
client := envdclient.New(hostIP)
waitCtx, waitCancel := context.WithTimeout(ctx, envdReadyTimeout(vmCfg.MemoryMB))
defer waitCancel()
if err := client.WaitUntilReady(waitCtx); err != nil {
_ = m.vm.Destroy(context.Background(), vmCfg.SandboxID)
return nil, fmt.Errorf("wait envd: %w", err)
}
// Best-effort balloon deflate. Free-page reporting drains pages while the
// sandbox runs; the resumed guest needs its full memory budget back. A
// failure leaves the guest memory-starved but doesn't break correctness.
if err := m.vm.UpdateBalloon(ctx, vmCfg.SandboxID, 0); err != nil {
slog.Warn("balloon deflate after restore failed", "id", vmCfg.SandboxID, "error", err)
}
return client, nil
}
// initAndStartMemoryLoader runs envd's /init lifecycle bump and then kicks
// off the background memory loader. Ordering matters: /init resets envd's
// mem_preload_* atomics, so the loader's POST /memory/preload must land
// after — otherwise the next CreateSnapshot/Pause would observe a stale
// "idle" state and snapshot a memfile full of holes.
//
// Must be called with sb already registered in m.boxes with StatusRunning
// and sb.client populated.
func (m *Manager) initAndStartMemoryLoader(ctx context.Context, sb *sandboxState, defaultUser, templateIDStr string, envVars map[string]string) {
initCtx, initCancel := context.WithTimeout(ctx, m.cfg.EnvdTimeout)
defer initCancel()
c := sb.client.Load()
if c == nil {
slog.Warn("post-restore PostInit skipped: envd client cleared", "id", sb.ID)
return
}
if err := c.PostInitWithDefaults(initCtx, defaultUser, envVars, sb.ID, templateIDStr); err != nil {
slog.Warn("post-restore PostInit failed", "id", sb.ID, "error", err)
}
m.startMemoryLoader(sb)
}

View File

@ -0,0 +1,208 @@
package sandbox
import (
"fmt"
"log/slog"
"os"
"path/filepath"
"sort"
"strings"
"time"
"github.com/google/uuid"
"git.omukk.dev/wrenn/wrenn/internal/layout"
"git.omukk.dev/wrenn/wrenn/internal/models"
)
// RestorePausedSandboxes scans WRENN_DIR/sandboxes/ for paused-sandbox
// snapshots left behind by a previous agent instance and re-registers them
// in m.boxes as StatusPaused. Without this, ListSandboxes would not report
// these sandboxes, and the CP's HostMonitor would mark them stopped via
// the missing-confirmed-dead reconcile path — orphaning the on-disk
// snapshot dir and surfacing a leaked "stopped" sandbox to users.
//
// Restored sandboxes hold ONLY the slot reservation; VM / network / dm /
// loop refcount stay unowned until Resume rebuilds them. baseImagePath is
// deliberately NOT set on the in-memory entry so cleanup() does not call
// loops.Release on a loop that was never Acquire'd — the registry tolerates
// a Release of an unknown key, but a coincident-same-base running sandbox
// would have its refcount decremented incorrectly.
//
// Must be called once at agent startup, AFTER CleanupOrphanPauseDirs (so
// .staging-* / .trash-* dirs are gone) and BEFORE the HTTP server starts
// serving — otherwise an early Create RPC can race the slot reservation.
//
// Corrupt snapshot dirs (unparseable meta, missing slot index) are renamed
// to .trash-{ts}/ so a future CleanupOrphanPauseDirs sweeps them. Soft
// errors are logged; this function never returns an error — startup should
// not fail because a single sandbox is unrecoverable.
func (m *Manager) RestorePausedSandboxes() {
sandboxesDir := layout.SandboxesDir(m.cfg.WrennDir)
entries, err := os.ReadDir(sandboxesDir)
if err != nil {
// Directory does not exist yet — fresh install, nothing to restore.
return
}
type candidate struct {
sandboxID string
snapDir string
meta *snapshotMeta
teamID [16]byte
templID [16]byte
}
// Pass 1: parse every snapshot meta. Trash anything unreadable or
// missing the slot index — those are crash artefacts, not recoverable
// sandboxes.
candidates := make([]candidate, 0, len(entries))
for _, e := range entries {
if !e.IsDir() {
continue
}
name := e.Name()
// Skip CleanupOrphanPauseDirs's territory. If it ran before us
// these are already gone; if not, leave them alone.
if strings.Contains(name, ".staging-") || strings.Contains(name, ".trash-") {
continue
}
snapDir := layout.PauseSnapshotDir(m.cfg.WrennDir, name)
meta, err := readSnapshotMeta(snapDir)
if err != nil {
slog.Warn("restore: unreadable snapshot meta, trashing dir",
"id", name, "error", err)
trashCorruptDir(snapDir)
continue
}
if meta.SlotIndex == 0 {
slog.Warn("restore: snapshot has no slot_index, trashing dir", "id", name)
trashCorruptDir(snapDir)
continue
}
teamBytes, err := parsePlainUUID(meta.TeamID)
if err != nil {
slog.Warn("restore: bad team_id in snapshot meta", "id", name, "error", err)
trashCorruptDir(snapDir)
continue
}
templateBytes, err := parsePlainUUID(meta.TemplateID)
if err != nil {
slog.Warn("restore: bad template_id in snapshot meta", "id", name, "error", err)
trashCorruptDir(snapDir)
continue
}
candidates = append(candidates, candidate{
sandboxID: name,
snapDir: snapDir,
meta: meta,
teamID: teamBytes,
templID: templateBytes,
})
}
// Pass 2: bucket by slot index, pick the newest CreatedAt per slot.
// Multiple candidates per slot happen when older paused-sandbox dirs
// were left on disk by the pre-fix leak (DB row marked stopped but the
// snapshot was never cleaned). The newest is the most likely live one;
// older losers are trashed so CleanupOrphanPauseDirs sweeps them on
// the next startup.
bySlot := make(map[int][]candidate, len(candidates))
for _, c := range candidates {
bySlot[c.meta.SlotIndex] = append(bySlot[c.meta.SlotIndex], c)
}
restored := 0
pruned := 0
for slot, cands := range bySlot {
sort.Slice(cands, func(i, j int) bool {
return cands[i].meta.CreatedAt.After(cands[j].meta.CreatedAt)
})
// Trash every loser. The host_monitor's zombie-cleanup path catches
// the winner if its DB row says 'stopped' — but losers never enter
// m.boxes and would otherwise sit on disk indefinitely.
for _, stale := range cands[1:] {
slog.Info("restore: pruning older snapshot for same slot",
"id", stale.sandboxID, "slot", slot, "created", stale.meta.CreatedAt,
"winner", cands[0].sandboxID, "winner_created", cands[0].meta.CreatedAt)
trashCorruptDir(stale.snapDir)
pruned++
}
winner := cands[0]
if err := m.slots.Reserve(winner.meta.SlotIndex); err != nil {
// Reserve only fails if another candidate (different slot value
// in meta but same numeric index) already grabbed it, or if the
// allocator is corrupt. Either way the snapshot is unusable
// without a slot, so trash it.
slog.Warn("restore: slot reservation failed, trashing dir",
"id", winner.sandboxID, "slot", winner.meta.SlotIndex, "error", err)
trashCorruptDir(winner.snapDir)
pruned++
continue
}
sb := &sandboxState{
Sandbox: models.Sandbox{
ID: winner.sandboxID,
Status: models.StatusPaused,
TemplateTeamID: winner.teamID,
TemplateID: winner.templID,
VCPUs: winner.meta.VCPUs,
MemoryMB: winner.meta.MemoryMB,
TimeoutSec: winner.meta.TimeoutSec,
SlotIndex: winner.meta.SlotIndex,
CreatedAt: winner.meta.CreatedAt,
// LastActiveAt cosmetic only — TTL reaper ignores non-Running.
LastActiveAt: winner.meta.CreatedAt,
},
// connTracker must be non-nil: resumeFromMeta calls Reset() on it
// unconditionally during rehydration. A nil pointer would panic.
connTracker: &ConnTracker{},
// baseImagePath intentionally left empty — see function doc.
// sandboxDirOverride intentionally left empty — resumeFromMeta
// reads meta.SandboxDir from disk on the resume path.
}
m.mu.Lock()
m.boxes[winner.sandboxID] = sb
m.mu.Unlock()
restored++
slog.Info("restored paused sandbox", "id", winner.sandboxID,
"slot", winner.meta.SlotIndex, "vcpus", winner.meta.VCPUs, "memory_mb", winner.meta.MemoryMB)
}
if restored > 0 || pruned > 0 {
slog.Info("paused sandbox restore complete", "restored", restored, "pruned", pruned)
}
}
// parsePlainUUID turns a standard hyphenated UUID string (as produced by
// id.UUIDString) back into the 16-byte representation used by sandboxState.
func parsePlainUUID(s string) ([16]byte, error) {
if s == "" {
return [16]byte{}, fmt.Errorf("empty uuid string")
}
u, err := uuid.Parse(s)
if err != nil {
return [16]byte{}, err
}
return [16]byte(u), nil
}
// trashCorruptDir renames a corrupt snapshot directory aside so a future
// CleanupOrphanPauseDirs sweeps it. Best-effort: if rename fails we log
// and move on — leaving the directory in place is safe (restore will skip
// it again next startup) but unwanted.
func trashCorruptDir(dir string) {
parent := filepath.Dir(dir)
base := filepath.Base(dir)
trash := filepath.Join(parent, fmt.Sprintf("%s.trash-%d", base, time.Now().UnixNano()))
if err := os.Rename(dir, trash); err != nil {
slog.Warn("restore: failed to trash corrupt snapshot dir",
"src", dir, "dst", trash, "error", err)
}
}

View File

@ -1,221 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// Modifications by M/S Omukk
// Package snapshot implements snapshot storage, header-based memory mapping,
// and memory file processing for Firecracker VM snapshots.
//
// The header system implements a generational copy-on-write memory mapping.
// Each snapshot generation stores only the blocks that changed since the
// previous generation. A Header contains a sorted list of BuildMap entries
// that together cover the entire memory address space, with each entry
// pointing to a specific generation's diff file.
package snapshot
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"github.com/google/uuid"
)
const metadataVersion = 1
// Metadata is the fixed-size header prefix describing the snapshot memory layout.
// Binary layout (little-endian, 64 bytes total):
//
// Version uint64 (8 bytes)
// BlockSize uint64 (8 bytes)
// Size uint64 (8 bytes) — total memory size in bytes
// Generation uint64 (8 bytes)
// BuildID [16]byte (UUID)
// BaseBuildID [16]byte (UUID)
type Metadata struct {
Version uint64
BlockSize uint64
Size uint64
Generation uint64
BuildID uuid.UUID
BaseBuildID uuid.UUID
}
// NewMetadata creates metadata for a first-generation snapshot.
func NewMetadata(buildID uuid.UUID, blockSize, size uint64) *Metadata {
return &Metadata{
Version: metadataVersion,
Generation: 0,
BlockSize: blockSize,
Size: size,
BuildID: buildID,
BaseBuildID: buildID,
}
}
// NextGeneration creates metadata for the next generation in the chain.
func (m *Metadata) NextGeneration(buildID uuid.UUID) *Metadata {
return &Metadata{
Version: m.Version,
Generation: m.Generation + 1,
BlockSize: m.BlockSize,
Size: m.Size,
BuildID: buildID,
BaseBuildID: m.BaseBuildID,
}
}
// BuildMap maps a contiguous range of the memory address space to a specific
// generation's diff file. Binary layout (little-endian, 40 bytes):
//
// Offset uint64 — byte offset in the virtual address space
// Length uint64 — byte count (multiple of BlockSize)
// BuildID [16]byte — which generation's diff file, uuid.Nil = zero-fill
// BuildStorageOffset uint64 — byte offset within that generation's diff file
type BuildMap struct {
Offset uint64
Length uint64
BuildID uuid.UUID
BuildStorageOffset uint64
}
// Header is the in-memory representation of a snapshot's memory mapping.
// It provides O(log N) lookup from any memory offset to the correct
// generation's diff file and offset within it.
type Header struct {
Metadata *Metadata
Mapping []*BuildMap
// blockStarts tracks which block indices start a new BuildMap entry.
// startMap provides direct access from block index to the BuildMap.
blockStarts []bool
startMap map[int64]*BuildMap
}
// NewHeader creates a Header from metadata and mapping entries.
// If mapping is nil/empty, a single entry covering the full size is created.
func NewHeader(metadata *Metadata, mapping []*BuildMap) (*Header, error) {
if metadata.BlockSize == 0 {
return nil, fmt.Errorf("block size cannot be zero")
}
if len(mapping) == 0 {
mapping = []*BuildMap{{
Offset: 0,
Length: metadata.Size,
BuildID: metadata.BuildID,
BuildStorageOffset: 0,
}}
}
blocks := TotalBlocks(int64(metadata.Size), int64(metadata.BlockSize))
starts := make([]bool, blocks)
startMap := make(map[int64]*BuildMap, len(mapping))
for _, m := range mapping {
idx := BlockIdx(int64(m.Offset), int64(metadata.BlockSize))
if idx >= 0 && idx < blocks {
starts[idx] = true
startMap[idx] = m
}
}
return &Header{
Metadata: metadata,
Mapping: mapping,
blockStarts: starts,
startMap: startMap,
}, nil
}
// GetShiftedMapping resolves a memory offset to the corresponding diff file
// offset, remaining length, and build ID. This is the hot path called for
// every UFFD page fault.
func (h *Header) GetShiftedMapping(_ context.Context, offset int64) (mappedOffset int64, mappedLength int64, buildID *uuid.UUID, err error) {
if offset < 0 || offset >= int64(h.Metadata.Size) {
return 0, 0, nil, fmt.Errorf("offset %d out of bounds (size: %d)", offset, h.Metadata.Size)
}
blockSize := int64(h.Metadata.BlockSize)
block := BlockIdx(offset, blockSize)
// Walk backwards to find the BuildMap that contains this block.
start := block
for start >= 0 {
if h.blockStarts[start] {
break
}
start--
}
if start < 0 {
return 0, 0, nil, fmt.Errorf("no mapping found for offset %d", offset)
}
m, ok := h.startMap[start]
if !ok {
return 0, 0, nil, fmt.Errorf("no mapping at block %d", start)
}
shift := (block - start) * blockSize
if shift >= int64(m.Length) {
return 0, 0, nil, fmt.Errorf("offset %d beyond mapping end (mapping offset=%d, length=%d)", offset, m.Offset, m.Length)
}
return int64(m.BuildStorageOffset) + shift, int64(m.Length) - shift, &m.BuildID, nil
}
// Serialize writes metadata + mapping entries to binary (little-endian).
func Serialize(metadata *Metadata, mappings []*BuildMap) ([]byte, error) {
var buf bytes.Buffer
if err := binary.Write(&buf, binary.LittleEndian, metadata); err != nil {
return nil, fmt.Errorf("write metadata: %w", err)
}
for _, m := range mappings {
if err := binary.Write(&buf, binary.LittleEndian, m); err != nil {
return nil, fmt.Errorf("write mapping: %w", err)
}
}
return buf.Bytes(), nil
}
// Deserialize reads a header from binary data.
func Deserialize(data []byte) (*Header, error) {
reader := bytes.NewReader(data)
var metadata Metadata
if err := binary.Read(reader, binary.LittleEndian, &metadata); err != nil {
return nil, fmt.Errorf("read metadata: %w", err)
}
var mappings []*BuildMap
for {
var m BuildMap
if err := binary.Read(reader, binary.LittleEndian, &m); err != nil {
if errors.Is(err, io.EOF) {
break
}
return nil, fmt.Errorf("read mapping: %w", err)
}
mappings = append(mappings, &m)
}
return NewHeader(&metadata, mappings)
}
// Block index helpers.
func TotalBlocks(size, blockSize int64) int64 {
return (size + blockSize - 1) / blockSize
}
func BlockIdx(offset, blockSize int64) int64 {
return offset / blockSize
}
func BlockOffset(idx, blockSize int64) int64 {
return idx * blockSize
}

View File

@ -7,14 +7,15 @@ import (
"os"
"path/filepath"
"syscall"
"github.com/google/uuid"
)
const (
SnapFileName = "snapfile"
MemDiffName = "memfile"
MemHeaderName = "memfile.header"
// Cloud Hypervisor snapshot files.
CHConfigFile = "config.json"
CHMemRangesFile = "memory-ranges"
CHStateFile = "state.json"
// Rootfs files.
RootfsFileName = "rootfs.ext4"
RootfsCowName = "rootfs.cow"
RootfsMetaName = "rootfs.meta"
@ -25,27 +26,6 @@ func DirPath(baseDir, name string) string {
return filepath.Join(baseDir, name)
}
// SnapPath returns the path to the VM state snapshot file.
func SnapPath(baseDir, name string) string {
return filepath.Join(DirPath(baseDir, name), SnapFileName)
}
// MemDiffPath returns the path to the compact memory diff file (legacy single-generation).
func MemDiffPath(baseDir, name string) string {
return filepath.Join(DirPath(baseDir, name), MemDiffName)
}
// MemDiffPathForBuild returns the path to a specific generation's diff file.
// Format: memfile.{buildID}
func MemDiffPathForBuild(baseDir, name string, buildID uuid.UUID) string {
return filepath.Join(DirPath(baseDir, name), fmt.Sprintf("memfile.%s", buildID.String()))
}
// MemHeaderPath returns the path to the memory mapping header file.
func MemHeaderPath(baseDir, name string) string {
return filepath.Join(DirPath(baseDir, name), MemHeaderName)
}
// RootfsPath returns the path to the rootfs image.
func RootfsPath(baseDir, name string) string {
return filepath.Join(DirPath(baseDir, name), RootfsFileName)
@ -61,10 +41,13 @@ func MetaPath(baseDir, name string) string {
return filepath.Join(DirPath(baseDir, name), RootfsMetaName)
}
// RootfsMeta records which base template a CoW file was created against.
// RootfsMeta records which base template a CoW file was created against
// and the VM resource config needed to restart the sampler on resume.
type RootfsMeta struct {
BaseTemplate string `json:"base_template"`
TemplateID string `json:"template_id,omitempty"`
VCPUs int `json:"vcpus,omitempty"`
MemoryMB int `json:"memory_mb,omitempty"`
}
// WriteMeta writes rootfs metadata to the snapshot directory.
@ -92,102 +75,6 @@ func ReadMeta(baseDir, name string) (*RootfsMeta, error) {
return &meta, nil
}
// Exists reports whether a complete snapshot exists (all required files present).
// Supports both legacy (rootfs.ext4) and CoW-based (rootfs.cow + rootfs.meta) snapshots.
// Memory diff files can be either legacy "memfile" or generation-specific "memfile.{uuid}".
func Exists(baseDir, name string) bool {
dir := DirPath(baseDir, name)
// snapfile and header are always required.
for _, f := range []string{SnapFileName, MemHeaderName} {
if _, err := os.Stat(filepath.Join(dir, f)); err != nil {
return false
}
}
// Check that at least one memfile exists (legacy or generation-specific).
// We verify by reading the header and checking that referenced diff files exist.
// Fall back to checking for the legacy memfile name if header can't be read.
if _, err := os.Stat(filepath.Join(dir, MemDiffName)); err != nil {
// No legacy memfile — check if any memfile.{uuid} exists by
// looking for files matching the pattern.
matches, _ := filepath.Glob(filepath.Join(dir, "memfile.*"))
hasGenDiff := false
for _, m := range matches {
base := filepath.Base(m)
if base != MemHeaderName {
hasGenDiff = true
break
}
}
if !hasGenDiff {
return false
}
}
// Accept either rootfs.ext4 (legacy/template) or rootfs.cow + rootfs.meta (dm-snapshot).
if _, err := os.Stat(filepath.Join(dir, RootfsFileName)); err == nil {
return true
}
if _, err := os.Stat(filepath.Join(dir, RootfsCowName)); err == nil {
if _, err := os.Stat(filepath.Join(dir, RootfsMetaName)); err == nil {
return true
}
}
return false
}
// IsTemplate reports whether a template image directory exists (has rootfs.ext4).
func IsTemplate(baseDir, name string) bool {
_, err := os.Stat(filepath.Join(DirPath(baseDir, name), RootfsFileName))
return err == nil
}
// IsSnapshot reports whether a directory is a snapshot (has all snapshot files).
func IsSnapshot(baseDir, name string) bool {
return Exists(baseDir, name)
}
// HasCow reports whether a snapshot uses CoW format (rootfs.cow + rootfs.meta)
// as opposed to legacy full rootfs (rootfs.ext4).
func HasCow(baseDir, name string) bool {
dir := DirPath(baseDir, name)
_, cowErr := os.Stat(filepath.Join(dir, RootfsCowName))
_, metaErr := os.Stat(filepath.Join(dir, RootfsMetaName))
return cowErr == nil && metaErr == nil
}
// ListDiffFiles returns a map of build ID → file path for all memory diff files
// referenced by the given header. Handles both the legacy "memfile" name
// (single-generation) and generation-specific "memfile.{uuid}" names.
func ListDiffFiles(baseDir, name string, header *Header) (map[string]string, error) {
dir := DirPath(baseDir, name)
result := make(map[string]string)
for _, m := range header.Mapping {
if m.BuildID == uuid.Nil {
continue // zero-fill, no file needed
}
idStr := m.BuildID.String()
if _, exists := result[idStr]; exists {
continue
}
// Try generation-specific path first, fall back to legacy.
genPath := filepath.Join(dir, fmt.Sprintf("memfile.%s", idStr))
if _, err := os.Stat(genPath); err == nil {
result[idStr] = genPath
continue
}
legacyPath := filepath.Join(dir, MemDiffName)
if _, err := os.Stat(legacyPath); err == nil {
result[idStr] = legacyPath
continue
}
return nil, fmt.Errorf("diff file not found for build %s", idStr)
}
return result, nil
}
// EnsureDir creates the snapshot directory if it doesn't exist.
func EnsureDir(baseDir, name string) error {
dir := DirPath(baseDir, name)

View File

@ -1,214 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// Modifications by M/S Omukk
package snapshot
import "github.com/google/uuid"
// CreateMapping converts a dirty-block bitset (represented as a []bool) into
// a sorted list of BuildMap entries. Consecutive dirty blocks are merged into
// a single entry. BuildStorageOffset tracks the sequential position in the
// compact diff file.
func CreateMapping(buildID uuid.UUID, dirty []bool, blockSize int64) []*BuildMap {
var mappings []*BuildMap
var runStart int64 = -1
var runLength int64
var storageOffset uint64
for i, set := range dirty {
if !set {
if runLength > 0 {
mappings = append(mappings, &BuildMap{
Offset: uint64(runStart) * uint64(blockSize),
Length: uint64(runLength) * uint64(blockSize),
BuildID: buildID,
BuildStorageOffset: storageOffset,
})
storageOffset += uint64(runLength) * uint64(blockSize)
runLength = 0
}
runStart = -1
continue
}
if runStart < 0 {
runStart = int64(i)
runLength = 1
} else {
runLength++
}
}
if runLength > 0 {
mappings = append(mappings, &BuildMap{
Offset: uint64(runStart) * uint64(blockSize),
Length: uint64(runLength) * uint64(blockSize),
BuildID: buildID,
BuildStorageOffset: storageOffset,
})
}
return mappings
}
// MergeMappings overlays diffMapping on top of baseMapping. Where they overlap,
// diff takes priority. The result covers the entire address space.
//
// Both inputs must be sorted by Offset. The base mapping should cover the full size.
//
// Inspired by e2b's snapshot system (Apache 2.0, modified by Omukk).
func MergeMappings(baseMapping, diffMapping []*BuildMap) []*BuildMap {
if len(diffMapping) == 0 {
return baseMapping
}
// Work on a copy of baseMapping to avoid mutating the original.
baseCopy := make([]*BuildMap, len(baseMapping))
for i, m := range baseMapping {
cp := *m
baseCopy[i] = &cp
}
var result []*BuildMap
var bi, di int
for bi < len(baseCopy) && di < len(diffMapping) {
base := baseCopy[bi]
diff := diffMapping[di]
if base.Length == 0 {
bi++
continue
}
if diff.Length == 0 {
di++
continue
}
// No overlap: base entirely before diff.
if base.Offset+base.Length <= diff.Offset {
result = append(result, base)
bi++
continue
}
// No overlap: diff entirely before base.
if diff.Offset+diff.Length <= base.Offset {
result = append(result, diff)
di++
continue
}
// Base fully inside diff — skip base.
if base.Offset >= diff.Offset && base.Offset+base.Length <= diff.Offset+diff.Length {
bi++
continue
}
// Diff fully inside base — split base around diff.
if diff.Offset >= base.Offset && diff.Offset+diff.Length <= base.Offset+base.Length {
leftLen := int64(diff.Offset) - int64(base.Offset)
if leftLen > 0 {
result = append(result, &BuildMap{
Offset: base.Offset,
Length: uint64(leftLen),
BuildID: base.BuildID,
BuildStorageOffset: base.BuildStorageOffset,
})
}
result = append(result, diff)
di++
rightShift := int64(diff.Offset) + int64(diff.Length) - int64(base.Offset)
rightLen := int64(base.Length) - rightShift
if rightLen > 0 {
baseCopy[bi] = &BuildMap{
Offset: base.Offset + uint64(rightShift),
Length: uint64(rightLen),
BuildID: base.BuildID,
BuildStorageOffset: base.BuildStorageOffset + uint64(rightShift),
}
} else {
bi++
}
continue
}
// Base starts after diff with overlap — emit diff, trim base.
if base.Offset > diff.Offset {
result = append(result, diff)
di++
rightShift := int64(diff.Offset) + int64(diff.Length) - int64(base.Offset)
rightLen := int64(base.Length) - rightShift
if rightLen > 0 {
baseCopy[bi] = &BuildMap{
Offset: base.Offset + uint64(rightShift),
Length: uint64(rightLen),
BuildID: base.BuildID,
BuildStorageOffset: base.BuildStorageOffset + uint64(rightShift),
}
} else {
bi++
}
continue
}
// Diff starts after base with overlap — emit left part of base.
if diff.Offset > base.Offset {
leftLen := int64(diff.Offset) - int64(base.Offset)
if leftLen > 0 {
result = append(result, &BuildMap{
Offset: base.Offset,
Length: uint64(leftLen),
BuildID: base.BuildID,
BuildStorageOffset: base.BuildStorageOffset,
})
}
bi++
continue
}
}
// Append remaining entries.
result = append(result, baseCopy[bi:]...)
result = append(result, diffMapping[di:]...)
return result
}
// NormalizeMappings merges adjacent entries with the same BuildID.
func NormalizeMappings(mappings []*BuildMap) []*BuildMap {
if len(mappings) == 0 {
return nil
}
result := make([]*BuildMap, 0, len(mappings))
current := &BuildMap{
Offset: mappings[0].Offset,
Length: mappings[0].Length,
BuildID: mappings[0].BuildID,
BuildStorageOffset: mappings[0].BuildStorageOffset,
}
for i := 1; i < len(mappings); i++ {
m := mappings[i]
if m.BuildID == current.BuildID {
current.Length += m.Length
} else {
result = append(result, current)
current = &BuildMap{
Offset: m.Offset,
Length: m.Length,
BuildID: m.BuildID,
BuildStorageOffset: m.BuildStorageOffset,
}
}
}
result = append(result, current)
return result
}

View File

@ -1,285 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// Modifications by M/S Omukk
package snapshot
import (
"context"
"fmt"
"io"
"os"
"github.com/google/uuid"
)
const (
// DefaultBlockSize is 4KB — standard page size for Firecracker.
DefaultBlockSize int64 = 4096
)
// ProcessMemfile reads a full memory file produced by Firecracker's
// PUT /snapshot/create, identifies non-zero blocks, and writes only those
// blocks to a compact diff file. Returns the Header describing the mapping.
//
// The output diff file contains non-zero blocks written sequentially.
// The header maps each block in the full address space to either:
// - A position in the diff file (for non-zero blocks)
// - uuid.Nil (for zero/empty blocks, served as zeros without I/O)
//
// buildID identifies this snapshot generation in the header chain.
func ProcessMemfile(memfilePath, diffPath, headerPath string, buildID uuid.UUID) (*Header, error) {
src, err := os.Open(memfilePath)
if err != nil {
return nil, fmt.Errorf("open memfile: %w", err)
}
defer src.Close()
info, err := src.Stat()
if err != nil {
return nil, fmt.Errorf("stat memfile: %w", err)
}
memSize := info.Size()
dst, err := os.Create(diffPath)
if err != nil {
return nil, fmt.Errorf("create diff file: %w", err)
}
defer dst.Close()
totalBlocks := TotalBlocks(memSize, DefaultBlockSize)
dirty := make([]bool, totalBlocks)
empty := make([]bool, totalBlocks)
buf := make([]byte, DefaultBlockSize)
for i := int64(0); i < totalBlocks; i++ {
n, err := io.ReadFull(src, buf)
if err != nil && err != io.ErrUnexpectedEOF {
return nil, fmt.Errorf("read block %d: %w", i, err)
}
// Zero-pad the last block if it's short.
if int64(n) < DefaultBlockSize {
for j := n; j < int(DefaultBlockSize); j++ {
buf[j] = 0
}
}
if isZeroBlock(buf) {
empty[i] = true
continue
}
dirty[i] = true
if _, err := dst.Write(buf); err != nil {
return nil, fmt.Errorf("write diff block %d: %w", i, err)
}
}
// Build header.
dirtyMappings := CreateMapping(buildID, dirty, DefaultBlockSize)
emptyMappings := CreateMapping(uuid.Nil, empty, DefaultBlockSize)
merged := MergeMappings(dirtyMappings, emptyMappings)
normalized := NormalizeMappings(merged)
metadata := NewMetadata(buildID, uint64(DefaultBlockSize), uint64(memSize))
header, err := NewHeader(metadata, normalized)
if err != nil {
return nil, fmt.Errorf("create header: %w", err)
}
// Write header to disk.
headerData, err := Serialize(metadata, normalized)
if err != nil {
return nil, fmt.Errorf("serialize header: %w", err)
}
if err := os.WriteFile(headerPath, headerData, 0644); err != nil {
return nil, fmt.Errorf("write header: %w", err)
}
return header, nil
}
// ProcessMemfileWithParent processes a memory file as a new generation on top
// of an existing parent header. The new diff file contains only blocks that
// differ from what the parent header maps. This is used for re-pause of a
// sandbox that was restored from a snapshot.
func ProcessMemfileWithParent(memfilePath, diffPath, headerPath string, parentHeader *Header, buildID uuid.UUID) (*Header, error) {
src, err := os.Open(memfilePath)
if err != nil {
return nil, fmt.Errorf("open memfile: %w", err)
}
defer src.Close()
info, err := src.Stat()
if err != nil {
return nil, fmt.Errorf("stat memfile: %w", err)
}
memSize := info.Size()
dst, err := os.Create(diffPath)
if err != nil {
return nil, fmt.Errorf("create diff file: %w", err)
}
defer dst.Close()
totalBlocks := TotalBlocks(memSize, DefaultBlockSize)
dirty := make([]bool, totalBlocks)
buf := make([]byte, DefaultBlockSize)
for i := int64(0); i < totalBlocks; i++ {
n, err := io.ReadFull(src, buf)
if err != nil && err != io.ErrUnexpectedEOF {
return nil, fmt.Errorf("read block %d: %w", i, err)
}
if int64(n) < DefaultBlockSize {
for j := n; j < int(DefaultBlockSize); j++ {
buf[j] = 0
}
}
if isZeroBlock(buf) {
// For a diff memfile, zero blocks mean "not dirtied since resume" —
// they should inherit the parent's mapping, not be zero-filled.
continue
}
dirty[i] = true
if _, err := dst.Write(buf); err != nil {
return nil, fmt.Errorf("write diff block %d: %w", i, err)
}
}
// Only dirty blocks go into the diff overlay; MergeMappings preserves the
// parent's mapping for everything else.
dirtyMappings := CreateMapping(buildID, dirty, DefaultBlockSize)
merged := MergeMappings(parentHeader.Mapping, dirtyMappings)
normalized := NormalizeMappings(merged)
metadata := parentHeader.Metadata.NextGeneration(buildID)
header, err := NewHeader(metadata, normalized)
if err != nil {
return nil, fmt.Errorf("create header: %w", err)
}
headerData, err := Serialize(metadata, normalized)
if err != nil {
return nil, fmt.Errorf("serialize header: %w", err)
}
if err := os.WriteFile(headerPath, headerData, 0644); err != nil {
return nil, fmt.Errorf("write header: %w", err)
}
return header, nil
}
// MergeDiffs consolidates multiple generation diff files into a single diff
// file and resets the generation counter to 0. This is a pure file-level
// operation — no Firecracker involvement.
//
// It reads each non-nil block from the appropriate diff file (as mapped by
// the header), writes them all sequentially into a single new diff file,
// and produces a fresh header pointing only at that file.
//
// diffFiles maps build ID (string) → open file path for each generation's diff.
func MergeDiffs(header *Header, diffFiles map[string]string, mergedDiffPath, headerPath string) (*Header, error) {
blockSize := int64(header.Metadata.BlockSize)
mergedBuildID := uuid.New()
// Open all source diff files.
sources := make(map[string]*os.File, len(diffFiles))
for id, path := range diffFiles {
f, err := os.Open(path)
if err != nil {
// Close already opened files.
for _, sf := range sources {
sf.Close()
}
return nil, fmt.Errorf("open diff file for build %s: %w", id, err)
}
sources[id] = f
}
defer func() {
for _, f := range sources {
f.Close()
}
}()
dst, err := os.Create(mergedDiffPath)
if err != nil {
return nil, fmt.Errorf("create merged diff file: %w", err)
}
defer dst.Close()
totalBlocks := TotalBlocks(int64(header.Metadata.Size), blockSize)
dirty := make([]bool, totalBlocks)
empty := make([]bool, totalBlocks)
buf := make([]byte, blockSize)
for i := int64(0); i < totalBlocks; i++ {
offset := i * blockSize
mappedOffset, _, buildID, err := header.GetShiftedMapping(context.Background(), offset)
if err != nil {
return nil, fmt.Errorf("lookup block %d: %w", i, err)
}
if *buildID == uuid.Nil {
empty[i] = true
continue
}
src, ok := sources[buildID.String()]
if !ok {
return nil, fmt.Errorf("no diff file for build %s (block %d)", buildID, i)
}
if _, err := src.ReadAt(buf, mappedOffset); err != nil {
return nil, fmt.Errorf("read block %d from build %s: %w", i, buildID, err)
}
dirty[i] = true
if _, err := dst.Write(buf); err != nil {
return nil, fmt.Errorf("write merged block %d: %w", i, err)
}
}
// Build fresh header with generation 0.
dirtyMappings := CreateMapping(mergedBuildID, dirty, blockSize)
emptyMappings := CreateMapping(uuid.Nil, empty, blockSize)
merged := MergeMappings(dirtyMappings, emptyMappings)
normalized := NormalizeMappings(merged)
metadata := NewMetadata(mergedBuildID, uint64(blockSize), header.Metadata.Size)
newHeader, err := NewHeader(metadata, normalized)
if err != nil {
return nil, fmt.Errorf("create merged header: %w", err)
}
headerData, err := Serialize(metadata, normalized)
if err != nil {
return nil, fmt.Errorf("serialize merged header: %w", err)
}
if err := os.WriteFile(headerPath, headerData, 0644); err != nil {
return nil, fmt.Errorf("write merged header: %w", err)
}
return newHeader, nil
}
// isZeroBlock checks if a block is entirely zero bytes.
func isZeroBlock(block []byte) bool {
// Fast path: compare 8 bytes at a time.
for i := 0; i+8 <= len(block); i += 8 {
if block[i] != 0 || block[i+1] != 0 || block[i+2] != 0 || block[i+3] != 0 ||
block[i+4] != 0 || block[i+5] != 0 || block[i+6] != 0 || block[i+7] != 0 {
return false
}
}
// Tail bytes.
for i := len(block) &^ 7; i < len(block); i++ {
if block[i] != 0 {
return false
}
}
return true
}

View File

@ -1,92 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// Modifications by M/S Omukk
// Package uffd implements a userfaultfd-based memory server for Firecracker
// snapshot restore. When a VM is restored from a snapshot, instead of loading
// the entire memory file upfront, the UFFD handler intercepts page faults
// and serves memory pages on demand from the snapshot's compact diff file.
package uffd
/*
#include <sys/syscall.h>
#include <fcntl.h>
#include <linux/userfaultfd.h>
#include <sys/ioctl.h>
struct uffd_pagefault {
__u64 flags;
__u64 address;
__u32 ptid;
};
*/
import "C"
import (
"fmt"
"syscall"
"unsafe"
)
const (
UFFD_EVENT_PAGEFAULT = C.UFFD_EVENT_PAGEFAULT
UFFD_EVENT_FORK = C.UFFD_EVENT_FORK
UFFD_EVENT_REMAP = C.UFFD_EVENT_REMAP
UFFD_EVENT_REMOVE = C.UFFD_EVENT_REMOVE
UFFD_EVENT_UNMAP = C.UFFD_EVENT_UNMAP
UFFD_PAGEFAULT_FLAG_WRITE = C.UFFD_PAGEFAULT_FLAG_WRITE
UFFDIO_COPY = C.UFFDIO_COPY
UFFDIO_COPY_MODE_WP = C.UFFDIO_COPY_MODE_WP
)
type (
uffdMsg = C.struct_uffd_msg
uffdPagefault = C.struct_uffd_pagefault
uffdioCopy = C.struct_uffdio_copy
)
// fd wraps a userfaultfd file descriptor received from Firecracker.
type fd uintptr
// copy installs a page into guest memory at the given address using UFFDIO_COPY.
// mode controls write-protection: use UFFDIO_COPY_MODE_WP to preserve WP bit.
func (f fd) copy(addr, pagesize uintptr, data []byte, mode C.ulonglong) error {
alignedAddr := addr &^ (pagesize - 1)
cpy := uffdioCopy{
src: C.ulonglong(uintptr(unsafe.Pointer(&data[0]))),
dst: C.ulonglong(alignedAddr),
len: C.ulonglong(pagesize),
mode: mode,
copy: 0,
}
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(f), UFFDIO_COPY, uintptr(unsafe.Pointer(&cpy)))
if errno != 0 {
return errno
}
if cpy.copy != C.longlong(pagesize) {
return fmt.Errorf("UFFDIO_COPY copied %d bytes, expected %d", cpy.copy, pagesize)
}
return nil
}
// close closes the userfaultfd file descriptor.
func (f fd) close() error {
return syscall.Close(int(f))
}
// getMsgEvent extracts the event type from a uffd_msg.
func getMsgEvent(msg *uffdMsg) C.uchar {
return msg.event
}
// getMsgArg extracts the arg union from a uffd_msg.
func getMsgArg(msg *uffdMsg) [24]byte {
return msg.arg
}
// getPagefaultAddress extracts the faulting address from a uffd_pagefault.
func getPagefaultAddress(pf *uffdPagefault) uintptr {
return uintptr(pf.address)
}

View File

@ -1,41 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// Modifications by M/S Omukk
//
// Modifications by Omukk (Wrenn Sandbox): merged Region and Mapping into
// single file, inlined shiftedOffset helper.
package uffd
import "fmt"
// Region is a mapping of guest memory to host virtual address space.
// Firecracker sends these as JSON when connecting to the UFFD socket.
// The JSON field names match Firecracker's UFFD protocol.
type Region struct {
BaseHostVirtAddr uintptr `json:"base_host_virt_addr"`
Size uintptr `json:"size"`
Offset uintptr `json:"offset"`
PageSize uintptr `json:"page_size_kib"` // Actually in bytes despite the name.
}
// Mapping translates between host virtual addresses and logical memory offsets.
type Mapping struct {
Regions []Region
}
// NewMapping creates a Mapping from a list of regions.
func NewMapping(regions []Region) *Mapping {
return &Mapping{Regions: regions}
}
// GetOffset converts a host virtual address to a logical memory file offset
// and returns the page size. This is called on every UFFD page fault.
func (m *Mapping) GetOffset(hostVirtAddr uintptr) (int64, uintptr, error) {
for _, r := range m.Regions {
if hostVirtAddr >= r.BaseHostVirtAddr && hostVirtAddr < r.BaseHostVirtAddr+r.Size {
offset := int64(hostVirtAddr-r.BaseHostVirtAddr) + int64(r.Offset)
return offset, r.PageSize, nil
}
}
return 0, 0, fmt.Errorf("address %#x not found in any memory region", hostVirtAddr)
}

View File

@ -1,451 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// Modifications by M/S Omukk
//
// Modifications by Omukk (Wrenn Sandbox): replaced errgroup with WaitGroup
// + semaphore, replaced fdexit abstraction with pipe, integrated with
// snapshot.Header-based DiffFileSource instead of block.ReadonlyDevice,
// fixed EAGAIN handling in poll loop.
package uffd
/*
#include <linux/userfaultfd.h>
*/
import "C"
import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net"
"os"
"sync"
"syscall"
"unsafe"
"golang.org/x/sys/unix"
"git.omukk.dev/wrenn/wrenn/internal/snapshot"
)
const (
fdSize = 4
regionMappingsSize = 1024
maxConcurrentFaults = 4096
)
// MemorySource provides page data for the UFFD handler.
// Given a logical memory offset and a size, it returns the page data.
type MemorySource interface {
ReadPage(ctx context.Context, offset int64, size int64) ([]byte, error)
}
// Server manages the UFFD Unix socket lifecycle and page fault handling
// for a single Firecracker snapshot restore.
type Server struct {
socketPath string
source MemorySource
lis *net.UnixListener
readyCh chan struct{}
readyOnce sync.Once
doneCh chan struct{}
doneErr error
// exitPipe signals the poll loop to stop.
exitR *os.File
exitW *os.File
// Set by handle() after Firecracker connects; read by Prefetch()
// after waiting on readyCh (which establishes happens-before).
uffdFd fd
mapping *Mapping
// Prefetch lifecycle: cancel stops the goroutine, prefetchDone is
// closed when it exits. Stop() drains prefetchDone before returning
// so the caller can safely close diff file handles.
prefetchCancel context.CancelFunc
prefetchDone chan struct{}
}
// NewServer creates a UFFD server that will listen on the given socket path
// and serve memory pages from the given source.
func NewServer(socketPath string, source MemorySource) *Server {
return &Server{
socketPath: socketPath,
source: source,
readyCh: make(chan struct{}),
doneCh: make(chan struct{}),
}
}
// Start begins listening on the Unix socket. Firecracker will connect to this
// socket after loadSnapshot is called with the UFFD backend.
// Start returns immediately; the server runs in a background goroutine.
func (s *Server) Start(ctx context.Context) error {
lis, err := net.ListenUnix("unix", &net.UnixAddr{Name: s.socketPath, Net: "unix"})
if err != nil {
return fmt.Errorf("listen on uffd socket: %w", err)
}
s.lis = lis
if err := os.Chmod(s.socketPath, 0o777); err != nil {
lis.Close()
return fmt.Errorf("chmod uffd socket: %w", err)
}
// Create exit signal pipe.
r, w, err := os.Pipe()
if err != nil {
lis.Close()
return fmt.Errorf("create exit pipe: %w", err)
}
s.exitR = r
s.exitW = w
go func() {
defer close(s.doneCh)
s.doneErr = s.handle(ctx)
s.lis.Close()
s.exitR.Close()
s.exitW.Close()
s.readyOnce.Do(func() { close(s.readyCh) })
}()
return nil
}
// Ready returns a channel that is closed when the UFFD handler is ready
// (after Firecracker has connected and sent the uffd fd).
func (s *Server) Ready() <-chan struct{} {
return s.readyCh
}
// Stop signals the UFFD poll loop to exit and waits for it to finish.
// Also cancels and waits for any running prefetch goroutine.
func (s *Server) Stop() error {
if s.prefetchCancel != nil {
s.prefetchCancel()
}
// Write a byte to the exit pipe to wake the poll loop.
_, _ = s.exitW.Write([]byte{0})
<-s.doneCh
if s.prefetchDone != nil {
<-s.prefetchDone
}
return s.doneErr
}
// Wait blocks until the server exits.
func (s *Server) Wait() error {
<-s.doneCh
return s.doneErr
}
// handle accepts the Firecracker connection, receives the UFFD fd via
// SCM_RIGHTS, and runs the page fault poll loop.
func (s *Server) handle(ctx context.Context) error {
conn, err := s.lis.Accept()
if err != nil {
return fmt.Errorf("accept uffd connection: %w", err)
}
unixConn := conn.(*net.UnixConn)
defer unixConn.Close()
// Read the memory region mappings (JSON) and the UFFD fd (SCM_RIGHTS).
regionBuf := make([]byte, regionMappingsSize)
uffdBuf := make([]byte, syscall.CmsgSpace(fdSize))
nRegion, nFd, _, _, err := unixConn.ReadMsgUnix(regionBuf, uffdBuf)
if err != nil {
return fmt.Errorf("read uffd message: %w", err)
}
var regions []Region
if err := json.Unmarshal(regionBuf[:nRegion], &regions); err != nil {
return fmt.Errorf("parse memory regions: %w", err)
}
controlMsgs, err := syscall.ParseSocketControlMessage(uffdBuf[:nFd])
if err != nil {
return fmt.Errorf("parse control messages: %w", err)
}
if len(controlMsgs) != 1 {
return fmt.Errorf("expected 1 control message, got %d", len(controlMsgs))
}
fds, err := syscall.ParseUnixRights(&controlMsgs[0])
if err != nil {
return fmt.Errorf("parse unix rights: %w", err)
}
if len(fds) != 1 {
return fmt.Errorf("expected 1 fd, got %d", len(fds))
}
uffdFd := fd(fds[0])
defer uffdFd.close()
mapping := NewMapping(regions)
// Store for use by Prefetch().
s.uffdFd = uffdFd
s.mapping = mapping
slog.Info("uffd handler connected",
"regions", len(regions),
"fd", int(uffdFd),
)
// Signal readiness.
s.readyOnce.Do(func() { close(s.readyCh) })
// Run the poll loop.
return s.serve(ctx, uffdFd, mapping)
}
// serve is the main poll loop. It polls the UFFD fd for page fault events
// and the exit pipe for shutdown signals.
func (s *Server) serve(ctx context.Context, uffdFd fd, mapping *Mapping) error {
pollFds := []unix.PollFd{
{Fd: int32(uffdFd), Events: unix.POLLIN},
{Fd: int32(s.exitR.Fd()), Events: unix.POLLIN},
}
var wg sync.WaitGroup
sem := make(chan struct{}, maxConcurrentFaults)
// Always wait for in-flight goroutines before returning, so the caller
// can safely close the uffd fd after serve returns.
defer wg.Wait()
for {
if _, err := unix.Poll(pollFds, -1); err != nil {
if err == unix.EINTR || err == unix.EAGAIN {
continue
}
return fmt.Errorf("poll: %w", err)
}
// Check exit signal.
if pollFds[1].Revents&unix.POLLIN != 0 {
return nil
}
if pollFds[0].Revents&unix.POLLIN == 0 {
continue
}
// Read the uffd_msg. The fd is O_NONBLOCK (set by Firecracker),
// so EAGAIN is expected — just go back to poll.
buf := make([]byte, unsafe.Sizeof(uffdMsg{}))
n, err := readUffdMsg(uffdFd, buf)
if err == syscall.EAGAIN {
continue
}
if err != nil {
return fmt.Errorf("read uffd msg: %w", err)
}
if n == 0 {
continue
}
msg := *(*uffdMsg)(unsafe.Pointer(&buf[0]))
event := getMsgEvent(&msg)
switch event {
case UFFD_EVENT_PAGEFAULT:
// Handled below.
case UFFD_EVENT_REMOVE, UFFD_EVENT_UNMAP, UFFD_EVENT_REMAP, UFFD_EVENT_FORK:
// Non-fatal lifecycle events from the guest kernel (e.g. balloon
// deflation, mmap/munmap). No action needed — continue polling.
continue
default:
return fmt.Errorf("unexpected uffd event type: %d", event)
}
arg := getMsgArg(&msg)
pf := *(*uffdPagefault)(unsafe.Pointer(&arg[0]))
addr := getPagefaultAddress(&pf)
offset, pagesize, err := mapping.GetOffset(addr)
if err != nil {
return fmt.Errorf("resolve address %#x: %w", addr, err)
}
sem <- struct{}{}
wg.Add(1)
go func() {
defer wg.Done()
defer func() { <-sem }()
if err := s.faultPage(ctx, uffdFd, addr, offset, pagesize); err != nil {
slog.Error("uffd fault page error",
"addr", fmt.Sprintf("%#x", addr),
"offset", offset,
"error", err,
)
}
}()
}
}
// readUffdMsg reads a single uffd_msg, retrying on EINTR.
// Returns (n, EAGAIN) if the non-blocking read has nothing available.
func readUffdMsg(uffdFd fd, buf []byte) (int, error) {
for {
n, err := syscall.Read(int(uffdFd), buf)
if err == syscall.EINTR {
continue
}
return n, err
}
}
// faultPage fetches a page from the memory source and copies it into
// guest memory via UFFDIO_COPY.
func (s *Server) faultPage(ctx context.Context, uffdFd fd, addr uintptr, offset int64, pagesize uintptr) error {
data, err := s.source.ReadPage(ctx, offset, int64(pagesize))
if err != nil {
return fmt.Errorf("read page at offset %d: %w", offset, err)
}
// Mode 0: no write-protect. Standard Firecracker does not register
// UFFD ranges with WP support, so UFFDIO_COPY_MODE_WP would fail.
if err := uffdFd.copy(addr, pagesize, data, 0); err != nil {
if errors.Is(err, unix.EEXIST) {
// Page already mapped (race with prefetch or concurrent fault).
return nil
}
return fmt.Errorf("uffdio_copy: %w", err)
}
return nil
}
// Prefetch proactively loads all guest memory pages in the background.
// It iterates over every page in every UFFD region and copies it from the
// diff file into guest memory via UFFDIO_COPY. Pages already loaded by
// on-demand faults return nil from faultPage (EEXIST handled internally).
// This eliminates the per-request latency caused by lazy page faulting
// after snapshot restore.
//
// The goroutine blocks on readyCh before reading the uffd fd and mapping
// fields (establishes happens-before with handle()). It uses an internal
// context independent of the caller's RPC context so it survives after the
// create/resume RPC returns. Stop() cancels and joins the goroutine.
func (s *Server) Prefetch() {
ctx, cancel := context.WithCancel(context.Background())
s.prefetchCancel = cancel
s.prefetchDone = make(chan struct{})
go func() {
defer close(s.prefetchDone)
// Wait for Firecracker to connect and send the uffd fd.
select {
case <-s.readyCh:
case <-ctx.Done():
return
}
uffdFd := s.uffdFd
mapping := s.mapping
if mapping == nil {
return
}
var total, errored int
for _, region := range mapping.Regions {
pageSize := region.PageSize
if pageSize == 0 {
continue
}
for off := uintptr(0); off < region.Size; off += pageSize {
if ctx.Err() != nil {
slog.Debug("uffd prefetch cancelled",
"pages", total, "errors", errored)
return
}
addr := region.BaseHostVirtAddr + off
memOffset := int64(off) + int64(region.Offset)
if err := s.faultPage(ctx, uffdFd, addr, memOffset, pageSize); err != nil {
errored++
} else {
total++
}
}
}
slog.Info("uffd prefetch complete",
"pages", total, "errors", errored)
}()
}
// DiffFileSource serves pages from a snapshot's compact diff file using
// the header's block mapping to resolve offsets.
type DiffFileSource struct {
header *snapshot.Header
// diffs maps build ID → open file handle for each generation's diff file.
diffs map[string]*os.File
}
// NewDiffFileSource creates a memory source backed by snapshot diff files.
// diffs maps build ID string to the file path of each generation's diff file.
func NewDiffFileSource(header *snapshot.Header, diffPaths map[string]string) (*DiffFileSource, error) {
diffs := make(map[string]*os.File, len(diffPaths))
for id, path := range diffPaths {
f, err := os.Open(path)
if err != nil {
// Close already opened files.
for _, opened := range diffs {
opened.Close()
}
return nil, fmt.Errorf("open diff file %s: %w", path, err)
}
diffs[id] = f
}
return &DiffFileSource{header: header, diffs: diffs}, nil
}
// ReadPage resolves a memory offset through the header mapping and reads
// the corresponding page from the correct generation's diff file.
func (s *DiffFileSource) ReadPage(ctx context.Context, offset int64, size int64) ([]byte, error) {
mappedOffset, _, buildID, err := s.header.GetShiftedMapping(ctx, offset)
if err != nil {
return nil, fmt.Errorf("resolve offset %d: %w", offset, err)
}
// uuid.Nil means zero-fill (empty page).
var nilUUID [16]byte
if *buildID == nilUUID {
return make([]byte, size), nil
}
f, ok := s.diffs[buildID.String()]
if !ok {
return nil, fmt.Errorf("no diff file for build %s", buildID)
}
buf := make([]byte, size)
n, err := f.ReadAt(buf, mappedOffset)
if err != nil && int64(n) < size {
return nil, fmt.Errorf("read diff at offset %d: %w", mappedOffset, err)
}
return buf, nil
}
// Close closes all open diff file handles.
func (s *DiffFileSource) Close() error {
var errs []error
for _, f := range s.diffs {
if err := f.Close(); err != nil {
errs = append(errs, err)
}
}
return errors.Join(errs...)
}

232
internal/vm/ch.go Normal file
View File

@ -0,0 +1,232 @@
package vm
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
)
// chClient talks to the Cloud Hypervisor HTTP API over a Unix socket.
type chClient struct {
http *http.Client
socketPath string
}
func newCHClient(socketPath string) *chClient {
return &chClient{
socketPath: socketPath,
http: &http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
var d net.Dialer
return d.DialContext(ctx, "unix", socketPath)
},
},
},
}
}
func (c *chClient) do(ctx context.Context, method, path string, body any) error {
return c.doJSON(ctx, method, path, body, nil)
}
// doJSON sends a request and optionally decodes a JSON response into out.
// out may be nil if the response body should be discarded.
func (c *chClient) doJSON(ctx context.Context, method, path string, body, out any) error {
var bodyReader io.Reader
if body != nil {
data, err := json.Marshal(body)
if err != nil {
return fmt.Errorf("marshal request body: %w", err)
}
bodyReader = bytes.NewReader(data)
}
req, err := http.NewRequestWithContext(ctx, method, "http://localhost"+path, bodyReader)
if err != nil {
return fmt.Errorf("create request: %w", err)
}
if body != nil {
req.Header.Set("Content-Type", "application/json")
}
resp, err := c.http.Do(req)
if err != nil {
return fmt.Errorf("%s %s: %w", method, path, err)
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
respBody, _ := io.ReadAll(resp.Body)
return fmt.Errorf("%s %s: status %d: %s", method, path, resp.StatusCode, string(respBody))
}
if out != nil {
if err := json.NewDecoder(resp.Body).Decode(out); err != nil {
return fmt.Errorf("%s %s: decode response: %w", method, path, err)
}
}
return nil
}
func boolPtr(b bool) *bool { return &b }
// --- CH API payload types ---
type chPayload struct {
Firmware string `json:"firmware,omitempty"`
Kernel string `json:"kernel"`
Cmdline string `json:"cmdline"`
}
type chCPUs struct {
BootVCPUs int `json:"boot_vcpus"`
MaxVCPUs int `json:"max_vcpus"`
}
type chMemory struct {
Size uint64 `json:"size"`
Shared bool `json:"shared,omitempty"`
// Thp uses a pointer with NO omitempty so explicit false is always
// serialized (CH defaults to true). Must be false so the backing memfile
// remains 4 KiB-granular: balloon-reported free pages get punched as
// holes and CH's SEEK_DATA/SEEK_HOLE snapshot writer (v52+) skips them.
// A nil Thp would silently re-enable THP and break sparse snapshots —
// rejecting "thp": null at the wire is preferable to a silent fallback.
Thp *bool `json:"thp"`
Prefault bool `json:"prefault,omitempty"`
HotplugSize uint64 `json:"hotplug_size,omitempty"`
HotplugMethod string `json:"hotplug_method,omitempty"`
}
type chDisk struct {
Path string `json:"path"`
Readonly bool `json:"readonly,omitempty"`
ImageType string `json:"image_type,omitempty"`
}
type chNet struct {
Tap string `json:"tap"`
MAC string `json:"mac"`
NumQs int `json:"num_queues,omitempty"`
QueueS int `json:"queue_size,omitempty"`
}
type chBalloon struct {
Size int64 `json:"size"`
DeflateOnOOM bool `json:"deflate_on_oom"`
FreePageRep bool `json:"free_page_reporting,omitempty"`
}
type chConsole struct {
Mode string `json:"mode"`
}
type chCreatePayload struct {
Payload chPayload `json:"payload"`
CPUs chCPUs `json:"cpus"`
Memory chMemory `json:"memory"`
Disks []chDisk `json:"disks"`
Net []chNet `json:"net"`
Balloon *chBalloon `json:"balloon,omitempty"`
Serial chConsole `json:"serial"`
Console chConsole `json:"console"`
}
// createVM sends the full VM configuration as a single payload.
func (c *chClient) createVM(ctx context.Context, cfg *VMConfig) error {
memBytes := uint64(cfg.MemoryMB) * 1024 * 1024
payload := chCreatePayload{
Payload: chPayload{
Kernel: cfg.KernelPath,
Cmdline: cfg.kernelArgs(),
},
CPUs: chCPUs{
BootVCPUs: cfg.VCPUs,
MaxVCPUs: cfg.VCPUs,
},
Memory: chMemory{
Size: memBytes,
Shared: true,
Thp: boolPtr(false),
},
Disks: []chDisk{
{
Path: cfg.SandboxDir + "/rootfs.ext4",
ImageType: "Raw",
},
},
Net: []chNet{
{
Tap: cfg.TapDevice,
MAC: cfg.TapMAC,
},
},
Balloon: &chBalloon{
Size: 0,
DeflateOnOOM: true,
FreePageRep: true,
},
Serial: chConsole{
Mode: "Tty",
},
Console: chConsole{
Mode: "Off",
},
}
return c.do(ctx, http.MethodPut, "/api/v1/vm.create", payload)
}
// bootVM starts the VM after creation.
func (c *chClient) bootVM(ctx context.Context) error {
return c.do(ctx, http.MethodPut, "/api/v1/vm.boot", nil)
}
// shutdownVMM cleanly shuts down the Cloud Hypervisor VMM process.
func (c *chClient) shutdownVMM(ctx context.Context) error {
return c.do(ctx, http.MethodPut, "/api/v1/vmm.shutdown", nil)
}
// resizeBalloon adjusts the balloon target at runtime.
// sizeBytes is memory to take FROM the guest (0 = give all back).
func (c *chClient) resizeBalloon(ctx context.Context, sizeBytes int64) error {
return c.do(ctx, http.MethodPut, "/api/v1/vm.resize", map[string]int64{
"desired_balloon": sizeBytes,
})
}
// pauseVM freezes guest vCPUs and devices via the CH API.
func (c *chClient) pauseVM(ctx context.Context) error {
return c.do(ctx, http.MethodPut, "/api/v1/vm.pause", nil)
}
// resumeVM unfreezes a paused VM via the CH API.
func (c *chClient) resumeVM(ctx context.Context) error {
return c.do(ctx, http.MethodPut, "/api/v1/vm.resume", nil)
}
// snapshotVM dumps VM config + state + memory to a directory URL of the form
// `file:///abs/path/`. VM must be paused before calling.
func (c *chClient) snapshotVM(ctx context.Context, destURL string) error {
return c.do(ctx, http.MethodPut, "/api/v1/vm.snapshot", map[string]string{
"destination_url": destURL,
})
}
// vmInfo reports the runtime state of the VM. Used after a restore to confirm
// CH successfully hydrated the snapshot before registering the VM.
func (c *chClient) vmInfo(ctx context.Context) (state string, err error) {
var resp struct {
State string `json:"state"`
}
if err := c.doJSON(ctx, http.MethodGet, "/api/v1/vm.info", nil, &resp); err != nil {
return "", err
}
return resp.State, nil
}

104
internal/vm/cleanup.go Normal file
View File

@ -0,0 +1,104 @@
package vm
import (
"log/slog"
"os"
"path/filepath"
"regexp"
"strconv"
"strings"
"syscall"
"time"
)
// CleanupStaleProcesses kills any cloud-hypervisor processes left behind by a
// previous agent that crashed without graceful shutdown. Must run at agent
// startup before devicemapper.CleanupStaleDevices — a still-running CH process
// holds the dm-snapshot open and would cause "Device or resource busy" on
// dmsetup remove.
//
// Matches processes by argv containing the wrenn CH API socket path
// (/tmp/ch-<sandboxID>.sock) so we don't kill unrelated cloud-hypervisor VMs
// the operator may be running.
//
// Also removes stale /tmp/ch-*.sock files once the owning process is gone.
func CleanupStaleProcesses() {
socketPattern := regexp.MustCompile(`/tmp/ch-[A-Za-z0-9-]+\.sock`)
pids, err := scanProcs()
if err != nil {
slog.Debug("scan procs failed", "error", err)
return
}
killed := 0
for _, pid := range pids {
cmdline, err := readCmdline(pid)
if err != nil {
continue
}
if !strings.Contains(cmdline, "cloud-hypervisor") {
continue
}
if !socketPattern.MatchString(cmdline) {
continue
}
slog.Warn("killing stale cloud-hypervisor process", "pid", pid, "cmdline", cmdline)
if err := syscall.Kill(pid, syscall.SIGTERM); err != nil {
slog.Warn("SIGTERM stale CH failed", "pid", pid, "error", err)
}
killed++
}
// Give SIGTERM'd processes a brief window to exit so subsequent dm/loop
// teardown sees no open fd, then SIGKILL anything still alive.
if killed > 0 {
time.Sleep(500 * time.Millisecond)
for _, pid := range pids {
cmdline, err := readCmdline(pid)
if err != nil {
continue
}
if !strings.Contains(cmdline, "cloud-hypervisor") || !socketPattern.MatchString(cmdline) {
continue
}
_ = syscall.Kill(pid, syscall.SIGKILL)
}
time.Sleep(200 * time.Millisecond)
}
matches, _ := filepath.Glob("/tmp/ch-*.sock")
for _, sock := range matches {
if err := os.Remove(sock); err == nil {
slog.Info("removed stale CH socket", "path", sock)
}
}
}
func scanProcs() ([]int, error) {
entries, err := os.ReadDir("/proc")
if err != nil {
return nil, err
}
var pids []int
for _, e := range entries {
if !e.IsDir() {
continue
}
pid, err := strconv.Atoi(e.Name())
if err != nil {
continue
}
pids = append(pids, pid)
}
return pids, nil
}
func readCmdline(pid int) (string, error) {
b, err := os.ReadFile("/proc/" + strconv.Itoa(pid) + "/cmdline")
if err != nil {
return "", err
}
// /proc/<pid>/cmdline is NUL-separated; convert to spaces for substring match.
return strings.ReplaceAll(string(b), "\x00", " "), nil
}

View File

@ -2,13 +2,25 @@ package vm
import "fmt"
// VMConfig holds the configuration for creating a Firecracker microVM.
// SandboxTmpDir returns the per-sandbox tmpfs mount point used inside the
// VMM's private mount namespace. Recorded as the disk path in CH's saved
// config.json, so restore paths must reconstruct it exactly to make the
// symlink prelude resolve.
func SandboxTmpDir(sandboxID string) string {
return fmt.Sprintf("/tmp/ch-vm-%s", sandboxID)
}
// SandboxSocketPath returns the Cloud Hypervisor API socket path for a sandbox.
func SandboxSocketPath(sandboxID string) string {
return fmt.Sprintf("/tmp/ch-%s.sock", sandboxID)
}
// VMConfig holds the configuration for creating a Cloud Hypervisor microVM.
type VMConfig struct {
// SandboxID is the unique identifier for this sandbox (e.g., "cl-a1b2c3d4").
SandboxID string
// TemplateID is the template UUID string used to populate MMDS metadata
// so that envd can read WRENN_TEMPLATE_ID from inside the guest.
// TemplateID is the template UUID string, passed to envd via PostInit.
TemplateID string
// KernelPath is the path to the uncompressed Linux kernel (vmlinux).
@ -25,12 +37,12 @@ type VMConfig struct {
MemoryMB int
// NetworkNamespace is the name of the network namespace to launch
// Firecracker inside (e.g., "ns-1"). The namespace must already exist
// Cloud Hypervisor inside (e.g., "ns-1"). The namespace must already exist
// with a TAP device configured.
NetworkNamespace string
// TapDevice is the name of the TAP device inside the network namespace
// that Firecracker will attach to (e.g., "tap0").
// that Cloud Hypervisor will attach to (e.g., "tap0").
TapDevice string
// TapMAC is the MAC address for the TAP device.
@ -45,19 +57,34 @@ type VMConfig struct {
// NetMask is the subnet mask for the guest network (e.g., "255.255.255.252").
NetMask string
// FirecrackerBin is the path to the firecracker binary.
FirecrackerBin string
// VMMBin is the path to the cloud-hypervisor binary.
VMMBin string
// SocketPath is the path for the Firecracker API Unix socket.
// SocketPath is the path for the Cloud Hypervisor API Unix socket.
SocketPath string
// SandboxDir is the tmpfs mount point for per-sandbox files inside the
// mount namespace (e.g., "/fc-vm").
// mount namespace (e.g., "/ch-vm").
SandboxDir string
// InitPath is the path to the init process inside the guest.
// Defaults to "/sbin/init" if empty.
InitPath string
// RestoreFromDir, if non-empty, switches the process launcher into restore
// mode. CH is invoked with `--restore source_url=file://{dir}/` instead of
// the fresh-boot path. The directory must contain CH's snapshot artefacts
// (config.json, state.json, memory-ranges, memory file).
RestoreFromDir string
// RestoreLazyMemory enables `memory_restore_mode=ondemand` so guest pages
// fault in lazily via userfaultfd. Only honored when RestoreFromDir is set.
RestoreLazyMemory bool
// LogDir is the directory for Cloud Hypervisor log files. If set, CH
// stdout/stderr are written to {LogDir}/ch-{SandboxID}.log instead of
// the parent process's stdout/stderr.
LogDir string
}
func (c *VMConfig) applyDefaults() {
@ -67,14 +94,14 @@ func (c *VMConfig) applyDefaults() {
if c.MemoryMB == 0 {
c.MemoryMB = 512
}
if c.FirecrackerBin == "" {
c.FirecrackerBin = "/usr/local/bin/firecracker"
if c.VMMBin == "" {
c.VMMBin = "/usr/local/bin/cloud-hypervisor"
}
if c.SocketPath == "" {
c.SocketPath = fmt.Sprintf("/tmp/fc-%s.sock", c.SandboxID)
c.SocketPath = SandboxSocketPath(c.SandboxID)
}
if c.SandboxDir == "" {
c.SandboxDir = "/tmp/fc-vm"
c.SandboxDir = SandboxTmpDir(c.SandboxID)
}
if c.TapDevice == "" {
c.TapDevice = "tap0"
@ -95,7 +122,7 @@ func (c *VMConfig) kernelArgs() string {
)
return fmt.Sprintf(
"console=ttyS0 reboot=k panic=1 pci=off quiet loglevel=1 clocksource=kvm-clock init=%s %s",
"console=ttyS0 root=/dev/vda rw rootflags=nodiscard reboot=k panic=1 quiet loglevel=1 init_on_free=1 clocksource=kvm-clock init=%s %s",
c.InitPath, ipArg,
)
}

View File

@ -1,202 +0,0 @@
package vm
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
)
// fcClient talks to the Firecracker HTTP API over a Unix socket.
type fcClient struct {
http *http.Client
socketPath string
}
func newFCClient(socketPath string) *fcClient {
return &fcClient{
socketPath: socketPath,
http: &http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
var d net.Dialer
return d.DialContext(ctx, "unix", socketPath)
},
},
// No global timeout — callers pass context.Context with appropriate
// deadlines. A fixed 10s timeout was too short for snapshot/resume
// operations on large-memory VMs (20GB+ memfiles).
},
}
}
func (c *fcClient) do(ctx context.Context, method, path string, body any) error {
var bodyReader io.Reader
if body != nil {
data, err := json.Marshal(body)
if err != nil {
return fmt.Errorf("marshal request body: %w", err)
}
bodyReader = bytes.NewReader(data)
}
// The host in the URL is ignored for Unix sockets; we use "localhost" by convention.
req, err := http.NewRequestWithContext(ctx, method, "http://localhost"+path, bodyReader)
if err != nil {
return fmt.Errorf("create request: %w", err)
}
if body != nil {
req.Header.Set("Content-Type", "application/json")
}
resp, err := c.http.Do(req)
if err != nil {
return fmt.Errorf("%s %s: %w", method, path, err)
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
respBody, _ := io.ReadAll(resp.Body)
return fmt.Errorf("%s %s: status %d: %s", method, path, resp.StatusCode, string(respBody))
}
return nil
}
// setBootSource configures the kernel and boot args.
func (c *fcClient) setBootSource(ctx context.Context, kernelPath, bootArgs string) error {
return c.do(ctx, http.MethodPut, "/boot-source", map[string]string{
"kernel_image_path": kernelPath,
"boot_args": bootArgs,
})
}
// setRootfsDrive configures the root filesystem drive.
func (c *fcClient) setRootfsDrive(ctx context.Context, driveID, path string, readOnly bool) error {
return c.do(ctx, http.MethodPut, "/drives/"+driveID, map[string]any{
"drive_id": driveID,
"path_on_host": path,
"is_root_device": true,
"is_read_only": readOnly,
})
}
// setNetworkInterface configures a network interface attached to a TAP device.
// A tx_rate_limiter caps sustained guest→host throughput to prevent user
// application traffic from completely saturating the TAP device and starving
// envd control traffic (PTY, exec, file ops).
func (c *fcClient) setNetworkInterface(ctx context.Context, ifaceID, tapName, macAddr string) error {
return c.do(ctx, http.MethodPut, "/network-interfaces/"+ifaceID, map[string]any{
"iface_id": ifaceID,
"host_dev_name": tapName,
"guest_mac": macAddr,
"tx_rate_limiter": map[string]any{
"bandwidth": map[string]any{
"size": 209715200, // 200 MB/s sustained
"refill_time": 1000, // refill period: 1 second
"one_time_burst": 104857600, // 100 MB initial burst
},
},
})
}
// setMachineConfig configures vCPUs, memory, and other machine settings.
func (c *fcClient) setMachineConfig(ctx context.Context, vcpus, memMB int) error {
return c.do(ctx, http.MethodPut, "/machine-config", map[string]any{
"vcpu_count": vcpus,
"mem_size_mib": memMB,
"smt": false,
})
}
// setMMDSConfig enables MMDS V2 token-based access on the given network interface.
// Must be called before startVM.
func (c *fcClient) setMMDSConfig(ctx context.Context, ifaceID string) error {
return c.do(ctx, http.MethodPut, "/mmds/config", map[string]any{
"version": "V2",
"network_interfaces": []string{ifaceID},
})
}
// mmdsMetadata is the metadata payload written to the Firecracker MMDS store.
// envd reads this via PollForMMDSOpts to populate WRENN_SANDBOX_ID and WRENN_TEMPLATE_ID.
type mmdsMetadata struct {
SandboxID string `json:"instanceID"`
TemplateID string `json:"envID"`
}
// setMMDS writes sandbox metadata to the Firecracker MMDS store.
// Can be called after the VM has started.
func (c *fcClient) setMMDS(ctx context.Context, sandboxID, templateID string) error {
return c.do(ctx, http.MethodPut, "/mmds", mmdsMetadata{
SandboxID: sandboxID,
TemplateID: templateID,
})
}
// setBalloon configures the Firecracker balloon device for dynamic memory
// management. deflateOnOom lets the guest reclaim balloon pages under memory
// pressure. statsInterval enables periodic stats via GET /balloon/statistics.
// Must be called before startVM.
func (c *fcClient) setBalloon(ctx context.Context, amountMiB int, deflateOnOom bool, statsIntervalS int) error {
return c.do(ctx, http.MethodPut, "/balloon", map[string]any{
"amount_mib": amountMiB,
"deflate_on_oom": deflateOnOom,
"stats_polling_interval_s": statsIntervalS,
})
}
// updateBalloon adjusts the balloon target at runtime.
func (c *fcClient) updateBalloon(ctx context.Context, amountMiB int) error {
return c.do(ctx, http.MethodPatch, "/balloon", map[string]any{
"amount_mib": amountMiB,
})
}
// startVM issues the InstanceStart action.
func (c *fcClient) startVM(ctx context.Context) error {
return c.do(ctx, http.MethodPut, "/actions", map[string]string{
"action_type": "InstanceStart",
})
}
// pauseVM pauses the microVM.
func (c *fcClient) pauseVM(ctx context.Context) error {
return c.do(ctx, http.MethodPatch, "/vm", map[string]string{
"state": "Paused",
})
}
// resumeVM resumes a paused microVM.
func (c *fcClient) resumeVM(ctx context.Context) error {
return c.do(ctx, http.MethodPatch, "/vm", map[string]string{
"state": "Resumed",
})
}
// createSnapshot creates a VM snapshot.
// snapshotType is "Full" (all memory) or "Diff" (only dirty pages since last resume).
func (c *fcClient) createSnapshot(ctx context.Context, snapPath, memPath, snapshotType string) error {
return c.do(ctx, http.MethodPut, "/snapshot/create", map[string]any{
"snapshot_type": snapshotType,
"snapshot_path": snapPath,
"mem_file_path": memPath,
})
}
// loadSnapshotWithUffd loads a VM snapshot using a UFFD socket for
// lazy memory loading. Firecracker will connect to the socket and
// send the uffd fd + memory region mappings.
func (c *fcClient) loadSnapshotWithUffd(ctx context.Context, snapPath, uffdSocketPath string) error {
return c.do(ctx, http.MethodPut, "/snapshot/load", map[string]any{
"snapshot_path": snapPath,
"resume_vm": false,
"mem_backend": map[string]any{
"backend_type": "Uffd",
"backend_path": uffdSocketPath,
},
})
}

View File

@ -1,128 +0,0 @@
package vm
import (
"context"
"fmt"
"log/slog"
"os"
"os/exec"
"syscall"
"time"
)
// process represents a running Firecracker process with mount and network
// namespace isolation.
type process struct {
cmd *exec.Cmd
cancel context.CancelFunc
exitCh chan struct{}
exitErr error
}
// startProcess launches the Firecracker binary inside an isolated mount namespace
// and the specified network namespace. The launch sequence:
//
// 1. unshare -m: creates a private mount namespace
// 2. mount --make-rprivate /: prevents mount propagation to host
// 3. mount tmpfs at SandboxDir: ephemeral workspace for this VM
// 4. symlink kernel and rootfs into SandboxDir
// 5. ip netns exec <ns>: enters the network namespace where TAP is configured
// 6. exec firecracker with the API socket path
func startProcess(ctx context.Context, cfg *VMConfig) (*process, error) {
// Use a background context for the long-lived Firecracker process.
// The request context (ctx) is only used for the startup phase — we must
// not tie the VM's lifetime to the HTTP request that created it.
execCtx, cancel := context.WithCancel(context.Background())
script := buildStartScript(cfg)
cmd := exec.CommandContext(execCtx, "unshare", "-m", "--", "bash", "-c", script)
cmd.SysProcAttr = &syscall.SysProcAttr{
Setsid: true, // new session so signals don't propagate from parent
}
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Start(); err != nil {
cancel()
return nil, fmt.Errorf("start firecracker process: %w", err)
}
p := &process{
cmd: cmd,
cancel: cancel,
exitCh: make(chan struct{}),
}
go func() {
p.exitErr = cmd.Wait()
close(p.exitCh)
}()
slog.Info("firecracker process started",
"pid", cmd.Process.Pid,
"sandbox", cfg.SandboxID,
)
return p, nil
}
// buildStartScript generates the bash script that sets up the mount namespace,
// symlinks kernel/rootfs, and execs Firecracker inside the network namespace.
func buildStartScript(cfg *VMConfig) string {
return fmt.Sprintf(`
set -euo pipefail
# Prevent mount propagation to the host
mount --make-rprivate /
# Create ephemeral tmpfs workspace
mkdir -p %[1]s
mount -t tmpfs tmpfs %[1]s
# Symlink kernel and rootfs into the workspace
ln -s %[2]s %[1]s/vmlinux
ln -s %[3]s %[1]s/rootfs.ext4
# Launch Firecracker inside the network namespace
exec ip netns exec %[4]s %[5]s --api-sock %[6]s
`,
cfg.SandboxDir, // 1
cfg.KernelPath, // 2
cfg.RootfsPath, // 3
cfg.NetworkNamespace, // 4
cfg.FirecrackerBin, // 5
cfg.SocketPath, // 6
)
}
// stop sends SIGTERM and waits for the process to exit. If it doesn't exit
// within 10 seconds, SIGKILL is sent.
func (p *process) stop() error {
if p.cmd.Process == nil {
return nil
}
// Send SIGTERM to the process group (negative PID).
if err := syscall.Kill(-p.cmd.Process.Pid, syscall.SIGTERM); err != nil {
slog.Debug("sigterm failed, process may have exited", "error", err)
}
select {
case <-p.exitCh:
return nil
case <-time.After(10 * time.Second):
slog.Warn("firecracker did not exit after SIGTERM, sending SIGKILL")
if err := syscall.Kill(-p.cmd.Process.Pid, syscall.SIGKILL); err != nil {
slog.Debug("sigkill failed", "error", err)
}
<-p.exitCh
return nil
}
}
// exited returns a channel that is closed when the process exits.
func (p *process) exited() <-chan struct{} {
return p.exitCh
}

View File

@ -5,18 +5,19 @@ import (
"fmt"
"log/slog"
"os"
"strings"
"sync"
"time"
)
// VM represents a running Firecracker microVM.
// VM represents a running Cloud Hypervisor microVM.
type VM struct {
Config VMConfig
process *process
client *fcClient
client *chClient
}
// Manager handles the lifecycle of Firecracker microVMs.
// Manager handles the lifecycle of Cloud Hypervisor microVMs.
type Manager struct {
mu sync.RWMutex
// vms tracks running VMs by sandbox ID.
@ -30,7 +31,7 @@ func NewManager() *Manager {
}
}
// Create boots a new Firecracker microVM with the given configuration.
// Create boots a new Cloud Hypervisor microVM with the given configuration.
// The network namespace and TAP device must already be set up.
func (m *Manager) Create(ctx context.Context, cfg VMConfig) (*VM, error) {
cfg.applyDefaults()
@ -38,7 +39,6 @@ func (m *Manager) Create(ctx context.Context, cfg VMConfig) (*VM, error) {
return nil, fmt.Errorf("invalid config: %w", err)
}
// Clean up any leftover socket from a previous run.
os.Remove(cfg.SocketPath)
slog.Info("creating VM",
@ -47,8 +47,8 @@ func (m *Manager) Create(ctx context.Context, cfg VMConfig) (*VM, error) {
"memory_mb", cfg.MemoryMB,
)
// Step 1: Launch the Firecracker process.
proc, err := startProcess(ctx, &cfg)
// Step 1: Launch the Cloud Hypervisor process.
proc, err := startProcess(&cfg)
if err != nil {
return nil, fmt.Errorf("start process: %w", err)
}
@ -59,25 +59,18 @@ func (m *Manager) Create(ctx context.Context, cfg VMConfig) (*VM, error) {
return nil, fmt.Errorf("wait for socket: %w", err)
}
// Step 3: Configure the VM via the Firecracker API.
client := newFCClient(cfg.SocketPath)
// Step 3: Configure and boot the VM via a single API call.
client := newCHClient(cfg.SocketPath)
if err := configureVM(ctx, client, &cfg); err != nil {
if err := client.createVM(ctx, &cfg); err != nil {
_ = proc.stop()
return nil, fmt.Errorf("configure VM: %w", err)
return nil, fmt.Errorf("create VM config: %w", err)
}
// Step 4: Start the VM.
if err := client.startVM(ctx); err != nil {
// Step 4: Boot the VM.
if err := client.bootVM(ctx); err != nil {
_ = proc.stop()
return nil, fmt.Errorf("start VM: %w", err)
}
// Step 5: Push sandbox metadata into MMDS so envd can read
// WRENN_SANDBOX_ID and WRENN_TEMPLATE_ID from inside the guest.
if err := client.setMMDS(ctx, cfg.SandboxID, cfg.TemplateID); err != nil {
_ = proc.stop()
return nil, fmt.Errorf("set MMDS metadata: %w", err)
return nil, fmt.Errorf("boot VM: %w", err)
}
vm := &VM{
@ -95,78 +88,34 @@ func (m *Manager) Create(ctx context.Context, cfg VMConfig) (*VM, error) {
return vm, nil
}
// configureVM sends the configuration to Firecracker via its HTTP API.
func configureVM(ctx context.Context, client *fcClient, cfg *VMConfig) error {
// Boot source (kernel + args)
if err := client.setBootSource(ctx, cfg.KernelPath, cfg.kernelArgs()); err != nil {
return fmt.Errorf("set boot source: %w", err)
}
// Root drive — use the symlink path inside the mount namespace so that
// snapshots record a stable path that works on restore.
rootfsSymlink := cfg.SandboxDir + "/rootfs.ext4"
if err := client.setRootfsDrive(ctx, "rootfs", rootfsSymlink, false); err != nil {
return fmt.Errorf("set rootfs drive: %w", err)
}
// Network interface
if err := client.setNetworkInterface(ctx, "eth0", cfg.TapDevice, cfg.TapMAC); err != nil {
return fmt.Errorf("set network interface: %w", err)
}
// Machine config (vCPUs + memory)
if err := client.setMachineConfig(ctx, cfg.VCPUs, cfg.MemoryMB); err != nil {
return fmt.Errorf("set machine config: %w", err)
}
// Balloon device — allows the host to reclaim unused guest memory.
// Start with 0 (no inflation). deflate_on_oom lets the guest reclaim
// balloon pages under memory pressure. Stats interval enables monitoring.
if err := client.setBalloon(ctx, 0, true, 5); err != nil {
slog.Warn("set balloon failed (non-fatal, VM will run without memory reclaim)", "error", err)
}
// MMDS config — enable V2 token access on eth0 so that envd can read
// WRENN_SANDBOX_ID and WRENN_TEMPLATE_ID from inside the guest.
if err := client.setMMDSConfig(ctx, "eth0"); err != nil {
return fmt.Errorf("set MMDS config: %w", err)
}
return nil
}
// Pause pauses a running VM.
// Pause freezes a running VM's vCPUs via the CH API.
func (m *Manager) Pause(ctx context.Context, sandboxID string) error {
m.mu.RLock()
vm, ok := m.vms[sandboxID]
m.mu.RUnlock()
vm, ok := m.Get(sandboxID)
if !ok {
return fmt.Errorf("VM not found: %s", sandboxID)
}
if err := vm.client.pauseVM(ctx); err != nil {
return fmt.Errorf("pause VM: %w", err)
}
slog.Info("VM paused", "sandbox", sandboxID)
return nil
return vm.client.pauseVM(ctx)
}
// Resume resumes a paused VM.
// Resume unfreezes a paused VM via the CH API.
func (m *Manager) Resume(ctx context.Context, sandboxID string) error {
m.mu.RLock()
vm, ok := m.vms[sandboxID]
m.mu.RUnlock()
vm, ok := m.Get(sandboxID)
if !ok {
return fmt.Errorf("VM not found: %s", sandboxID)
}
return vm.client.resumeVM(ctx)
}
if err := vm.client.resumeVM(ctx); err != nil {
return fmt.Errorf("resume VM: %w", err)
// Info returns the CH VM state (e.g. "Running", "Paused", "Shutdown") via
// the CH unix-socket API. Returns an error if the socket is dead or the VM
// is not registered. Use to probe liveness before issuing destructive ops
// like pause or snapshot.
func (m *Manager) Info(ctx context.Context, sandboxID string) (string, error) {
vm, ok := m.Get(sandboxID)
if !ok {
return "", fmt.Errorf("VM not found: %s", sandboxID)
}
slog.Info("VM resumed", "sandbox", sandboxID)
return nil
return vm.client.vmInfo(ctx)
}
// UpdateBalloon adjusts the balloon target for a running VM.
@ -179,7 +128,8 @@ func (m *Manager) UpdateBalloon(ctx context.Context, sandboxID string, amountMiB
return fmt.Errorf("VM not found: %s", sandboxID)
}
return vm.client.updateBalloon(ctx, amountMiB)
sizeBytes := int64(amountMiB) * 1024 * 1024
return vm.client.resizeBalloon(ctx, sizeBytes)
}
// Destroy stops and cleans up a VM.
@ -190,103 +140,98 @@ func (m *Manager) Destroy(ctx context.Context, sandboxID string) error {
m.mu.Unlock()
return fmt.Errorf("VM not found: %s", sandboxID)
}
delete(m.vms, sandboxID)
m.mu.Unlock()
slog.Info("destroying VM", "sandbox", sandboxID)
// Stop the Firecracker process.
// Try clean shutdown first, fall back to process kill.
shutdownCtx, shutdownCancel := context.WithTimeout(ctx, 5*time.Second)
if err := vm.client.shutdownVMM(shutdownCtx); err != nil {
slog.Debug("clean VMM shutdown failed, killing process", "sandbox", sandboxID, "error", err)
}
shutdownCancel()
if err := vm.process.stop(); err != nil {
slog.Warn("error stopping process", "sandbox", sandboxID, "error", err)
}
// Clean up the API socket.
os.Remove(vm.Config.SocketPath)
m.mu.Lock()
delete(m.vms, sandboxID)
m.mu.Unlock()
slog.Info("VM destroyed", "sandbox", sandboxID)
return nil
}
// Snapshot creates a VM snapshot. The VM must already be paused.
// snapshotType is "Full" (all memory) or "Diff" (only dirty pages since last resume).
func (m *Manager) Snapshot(ctx context.Context, sandboxID, snapPath, memPath, snapshotType string) error {
m.mu.RLock()
vm, ok := m.vms[sandboxID]
m.mu.RUnlock()
// Snapshot writes the VM's config/state/memory to snapshotDir via CH's
// vm.snapshot API. The VM must already be paused. snapshotDir must be an
// absolute path; it is passed to CH as `file://{dir}/`.
func (m *Manager) Snapshot(ctx context.Context, sandboxID, snapshotDir string) error {
vm, ok := m.Get(sandboxID)
if !ok {
return fmt.Errorf("VM not found: %s", sandboxID)
}
if err := vm.client.createSnapshot(ctx, snapPath, memPath, snapshotType); err != nil {
return fmt.Errorf("create snapshot: %w", err)
if err := os.MkdirAll(snapshotDir, 0o755); err != nil {
return fmt.Errorf("mkdir snapshot dir: %w", err)
}
slog.Info("VM snapshot created", "sandbox", sandboxID, "snap_path", snapPath, "type", snapshotType)
url := "file://" + strings.TrimRight(snapshotDir, "/") + "/"
if err := vm.client.snapshotVM(ctx, url); err != nil {
return fmt.Errorf("vm.snapshot: %w", err)
}
slog.Info("VM snapshot written", "sandbox", sandboxID, "dir", snapshotDir)
return nil
}
// CreateFromSnapshot boots a new Firecracker VM by loading a snapshot
// using UFFD for lazy memory loading. The network namespace and TAP
// device must already be set up.
// CreateFromSnapshot launches a Cloud Hypervisor process in restore mode,
// connecting it to an existing snapshot directory. The VM is left in the
// paused state — the caller is expected to call Resume after any post-restore
// setup (e.g. re-acquiring envd connectivity is implicit via TCP).
//
// No boot resources (kernel, drives, machine config) are configured —
// the snapshot carries all that state. The rootfs path recorded in the
// snapshot is resolved via a stable symlink at SandboxDir/rootfs.ext4
// inside the mount namespace (created by the start script in jailer.go).
//
// The sequence is:
// 1. Start FC process in mount+network namespace (creates tmpfs + rootfs symlink)
// 2. Wait for API socket
// 3. Load snapshot with UFFD backend
// 4. Resume VM execution
func (m *Manager) CreateFromSnapshot(ctx context.Context, cfg VMConfig, snapPath, uffdSocketPath string) (*VM, error) {
// cfg.RestoreFromDir must point to an absolute path containing the CH
// snapshot artefacts. The disk path inside config.json must already resolve
// (CH receives the same SandboxDir/rootfs.ext4 symlink as for fresh boot).
func (m *Manager) CreateFromSnapshot(ctx context.Context, cfg VMConfig) (*VM, error) {
cfg.applyDefaults()
if err := cfg.validate(); err != nil {
return nil, fmt.Errorf("invalid config: %w", err)
}
if cfg.RestoreFromDir == "" {
return nil, fmt.Errorf("RestoreFromDir is required for restore")
}
os.Remove(cfg.SocketPath)
slog.Info("restoring VM from snapshot",
"sandbox", cfg.SandboxID,
"snap_path", snapPath,
"restore_dir", cfg.RestoreFromDir,
"lazy_memory", cfg.RestoreLazyMemory,
)
// Step 1: Launch the Firecracker process.
// The start script creates a tmpfs at SandboxDir and symlinks
// rootfs.ext4 → cfg.RootfsPath, so the snapshot's recorded rootfs
// path (/fc-vm/rootfs.ext4) resolves to the new clone.
proc, err := startProcess(ctx, &cfg)
proc, err := startRestoreProcess(&cfg)
if err != nil {
return nil, fmt.Errorf("start process: %w", err)
return nil, fmt.Errorf("start restore process: %w", err)
}
// Step 2: Wait for the API socket.
if err := waitForSocket(ctx, cfg.SocketPath, proc); err != nil {
_ = proc.stop()
return nil, fmt.Errorf("wait for socket: %w", err)
}
client := newFCClient(cfg.SocketPath)
client := newCHClient(cfg.SocketPath)
// Step 3: Load the snapshot with UFFD backend.
// No boot resources are configured — the snapshot carries kernel,
// drive, network, and machine config state.
if err := client.loadSnapshotWithUffd(ctx, snapPath, uffdSocketPath); err != nil {
// Confirm CH actually hydrated the snapshot before registering. Without
// this check, a broken snapshot would leave a zombie *VM in the map that
// blocks future restores for the same sandbox ID.
state, err := client.vmInfo(ctx)
if err != nil {
_ = proc.stop()
return nil, fmt.Errorf("load snapshot: %w", err)
return nil, fmt.Errorf("vm.info after restore: %w", err)
}
// Step 4: Resume the VM.
if err := client.resumeVM(ctx); err != nil {
if state != "Paused" {
_ = proc.stop()
return nil, fmt.Errorf("resume VM: %w", err)
}
// Step 5: Push sandbox metadata into MMDS.
if err := client.setMMDS(ctx, cfg.SandboxID, cfg.TemplateID); err != nil {
_ = proc.stop()
return nil, fmt.Errorf("set MMDS metadata: %w", err)
return nil, fmt.Errorf("unexpected post-restore VM state %q (want Paused)", state)
}
vm := &VM{
@ -299,16 +244,20 @@ func (m *Manager) CreateFromSnapshot(ctx context.Context, cfg VMConfig, snapPath
m.vms[cfg.SandboxID] = vm
m.mu.Unlock()
slog.Info("VM restored from snapshot", "sandbox", cfg.SandboxID)
slog.Info("VM restored from snapshot (paused)", "sandbox", cfg.SandboxID)
return vm, nil
}
// PID returns the process ID of the unshare wrapper process.
// The actual Firecracker process is a direct child of this PID.
func (v *VM) PID() int {
return v.process.cmd.Process.Pid
}
// Exited returns a channel that is closed when the VM process exits.
func (v *VM) Exited() <-chan struct{} {
return v.process.exited()
}
// Get returns a running VM by sandbox ID.
func (m *Manager) Get(sandboxID string) (*VM, bool) {
m.mu.RLock()
@ -317,7 +266,7 @@ func (m *Manager) Get(sandboxID string) (*VM, bool) {
return vm, ok
}
// waitForSocket polls for the Firecracker API socket to appear on disk.
// waitForSocket polls for the Cloud Hypervisor API socket to appear on disk.
func waitForSocket(ctx context.Context, socketPath string, proc *process) error {
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()
@ -329,7 +278,7 @@ func waitForSocket(ctx context.Context, socketPath string, proc *process) error
case <-ctx.Done():
return ctx.Err()
case <-proc.exited():
return fmt.Errorf("firecracker process exited before socket was ready")
return fmt.Errorf("cloud-hypervisor process exited before socket was ready")
case <-timeout:
return fmt.Errorf("timed out waiting for API socket at %s", socketPath)
case <-ticker.C:

174
internal/vm/process.go Normal file
View File

@ -0,0 +1,174 @@
package vm
import (
"context"
"fmt"
"log/slog"
"os"
"os/exec"
"strings"
"syscall"
"time"
)
// process represents a running Cloud Hypervisor process with mount and network
// namespace isolation.
type process struct {
cmd *exec.Cmd
cancel context.CancelFunc
exitCh chan struct{}
exitErr error
logFile *os.File
}
// startProcess launches the Cloud Hypervisor binary inside an isolated mount
// namespace and the specified network namespace. Used for fresh boot (no
// snapshot). The launch sequence:
//
// 1. unshare -m: creates a private mount namespace
// 2. mount --make-rprivate /: prevents mount propagation to host
// 3. mount tmpfs at SandboxDir: ephemeral workspace for this VM
// 4. symlink kernel and rootfs into SandboxDir
// 5. ip netns exec <ns>: enters the network namespace where TAP is configured
// 6. exec cloud-hypervisor with the API socket path
func startProcess(cfg *VMConfig) (*process, error) {
script := buildStartScript(cfg)
return launchScript(script, cfg)
}
// startRestoreProcess launches CH in restore mode. It mirrors startProcess
// for namespace/tmpfs/symlink setup so the disk paths recorded in the
// snapshot's config.json remain valid, then execs CH with `--restore`.
func startRestoreProcess(cfg *VMConfig) (*process, error) {
script := buildRestoreScript(cfg)
return launchScript(script, cfg)
}
func launchScript(script string, cfg *VMConfig) (*process, error) {
execCtx, cancel := context.WithCancel(context.Background())
cmd := exec.CommandContext(execCtx, "unshare", "-m", "--", "bash", "-c", script)
cmd.SysProcAttr = &syscall.SysProcAttr{
Setsid: true,
}
var logFile *os.File
if cfg.LogDir != "" {
logPath := fmt.Sprintf("%s/ch-%s.log", cfg.LogDir, cfg.SandboxID)
f, err := os.OpenFile(logPath, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0640)
if err != nil {
cancel()
return nil, fmt.Errorf("open CH log file %s: %w", logPath, err)
}
cmd.Stdout = f
cmd.Stderr = f
logFile = f
}
if err := cmd.Start(); err != nil {
cancel()
if logFile != nil {
logFile.Close()
}
return nil, fmt.Errorf("start cloud-hypervisor process: %w", err)
}
p := &process{
cmd: cmd,
cancel: cancel,
exitCh: make(chan struct{}),
logFile: logFile,
}
go func() {
p.exitErr = cmd.Wait()
if p.logFile != nil {
p.logFile.Close()
}
close(p.exitCh)
}()
slog.Info("cloud-hypervisor process started",
"pid", cmd.Process.Pid,
"sandbox", cfg.SandboxID,
)
return p, nil
}
// buildStartScript generates the bash script for fresh boot: sets up mount
// namespace, symlinks kernel/rootfs, and execs Cloud Hypervisor.
func buildStartScript(cfg *VMConfig) string {
return buildLaunchScript(cfg, "")
}
// buildRestoreScript generates the bash script for restoring a VM from a
// snapshot directory. The mount/symlink prelude is identical to fresh boot
// so disk paths in the snapshot config.json resolve correctly.
func buildRestoreScript(cfg *VMConfig) string {
dir := strings.TrimRight(cfg.RestoreFromDir, "/")
restoreArg := fmt.Sprintf("--restore source_url=file://%s/", dir)
if cfg.RestoreLazyMemory {
restoreArg += ",memory_restore_mode=ondemand"
}
return buildLaunchScript(cfg, restoreArg)
}
// buildLaunchScript composes the namespace/tmpfs/symlink prelude and the
// final cloud-hypervisor exec line. extraArgs is appended verbatim — used
// to inject `--restore source_url=...` for restore launches.
func buildLaunchScript(cfg *VMConfig, extraArgs string) string {
chCmd := fmt.Sprintf("ip netns exec %s %s --api-socket path=%s",
cfg.NetworkNamespace, cfg.VMMBin, cfg.SocketPath)
if extraArgs != "" {
chCmd += " " + extraArgs
}
return fmt.Sprintf(`
set -euo pipefail
mount --make-rprivate /
mkdir -p %[1]s
mount -t tmpfs tmpfs %[1]s
ln -s %[2]s %[1]s/vmlinux
ln -s %[3]s %[1]s/rootfs.ext4
exec %[4]s
`,
cfg.SandboxDir, // 1
cfg.KernelPath, // 2
cfg.RootfsPath, // 3
chCmd, // 4
)
}
// stop sends SIGTERM and waits for the process to exit. If it doesn't exit
// within 10 seconds, SIGKILL is sent.
func (p *process) stop() error {
if p.cmd.Process == nil {
return nil
}
if err := syscall.Kill(-p.cmd.Process.Pid, syscall.SIGTERM); err != nil {
slog.Debug("sigterm failed, process may have exited", "error", err)
}
select {
case <-p.exitCh:
return nil
case <-time.After(10 * time.Second):
slog.Warn("cloud-hypervisor did not exit after SIGTERM, sending SIGKILL")
if err := syscall.Kill(-p.cmd.Process.Pid, syscall.SIGKILL); err != nil {
slog.Debug("sigkill failed", "error", err)
}
<-p.exitCh
return nil
}
}
// exited returns a channel that is closed when the process exits.
func (p *process) exited() <-chan struct{} {
return p.exitCh
}