forked from wrenn/wrenn
v0.2.0 (#50)
Co-authored-by: Tasnim Kabir Sadik <tksadik@omukk.dev> Reviewed-on: wrenn/wrenn#50
This commit is contained in:
@ -3,11 +3,20 @@ package api
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||
"git.omukk.dev/wrenn/wrenn/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
@ -20,3 +29,119 @@ func agentForHost(ctx context.Context, queries *db.Queries, pool *lifecycle.Host
|
||||
}
|
||||
return pool.GetForHost(host)
|
||||
}
|
||||
|
||||
// requireRunningSandbox parses the sandbox ID from the URL, looks it up by team,
|
||||
// and verifies it is running. On failure it writes the appropriate HTTP error and
|
||||
// returns false.
|
||||
func requireRunningSandbox(w http.ResponseWriter, r *http.Request, queries *db.Queries, teamID pgtype.UUID) (db.Sandbox, pgtype.UUID, string, bool) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return db.Sandbox{}, pgtype.UUID{}, "", false
|
||||
}
|
||||
|
||||
sb, err := queries.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return db.Sandbox{}, pgtype.UUID{}, "", false
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
|
||||
return db.Sandbox{}, pgtype.UUID{}, "", false
|
||||
}
|
||||
|
||||
return sb, sandboxID, sandboxIDStr, true
|
||||
}
|
||||
|
||||
// upgradeAndAuthenticate upgrades the HTTP connection to WebSocket. The
|
||||
// auth context must already be populated by upstream middleware — browser
|
||||
// clients via the wrenn_sid cookie (sent automatically on WS upgrade),
|
||||
// SDK clients via X-API-Key. Requests without an auth context are rejected
|
||||
// with a 401 before the upgrade.
|
||||
func upgradeAndAuthenticate(w http.ResponseWriter, r *http.Request) (*websocket.Conn, auth.AuthContext, error) {
|
||||
ac, hasAuth := auth.FromContext(r.Context())
|
||||
if !hasAuth {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "session cookie or X-API-Key required")
|
||||
return nil, auth.AuthContext{}, fmt.Errorf("unauthenticated")
|
||||
}
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return nil, auth.AuthContext{}, fmt.Errorf("websocket upgrade: %w", err)
|
||||
}
|
||||
return conn, ac, nil
|
||||
}
|
||||
|
||||
// resolveTemplateSizes queries a host agent for the actual disk usage of any
|
||||
// templates with size_bytes <= 0 (e.g. system base templates seeded with
|
||||
// size_bytes = 0 before the rootfs was built). Results are persisted to the
|
||||
// DB so subsequent requests serve the correct size without an RPC call.
|
||||
// Errors are logged but do not prevent the caller from serving the templates.
|
||||
func resolveTemplateSizes(ctx context.Context, queries *db.Queries, pool *lifecycle.HostClientPool, templates []db.Template) []db.Template {
|
||||
needResolve := false
|
||||
for _, t := range templates {
|
||||
if t.SizeBytes <= 0 {
|
||||
needResolve = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !needResolve {
|
||||
return templates
|
||||
}
|
||||
|
||||
hosts, err := queries.ListActiveHosts(ctx)
|
||||
if err != nil || len(hosts) == 0 {
|
||||
slog.Warn("resolveTemplateSizes: no active hosts available", "error", err)
|
||||
return templates
|
||||
}
|
||||
|
||||
agent, err := pool.GetForHost(hosts[0])
|
||||
if err != nil {
|
||||
slog.Warn("resolveTemplateSizes: failed to connect to host",
|
||||
"host_id", id.UUIDString(hosts[0].ID), "error", err)
|
||||
return templates
|
||||
}
|
||||
|
||||
for i, t := range templates {
|
||||
if t.SizeBytes > 0 {
|
||||
continue
|
||||
}
|
||||
resp, err := agent.GetTemplateSize(ctx, connect.NewRequest(&pb.GetTemplateSizeRequest{
|
||||
TeamId: formatUUIDForRPC(t.TeamID),
|
||||
TemplateId: formatUUIDForRPC(t.ID),
|
||||
}))
|
||||
if err != nil {
|
||||
slog.Warn("resolveTemplateSizes: failed to get size from host",
|
||||
"template", t.Name, "error", err)
|
||||
continue
|
||||
}
|
||||
templates[i].SizeBytes = resp.Msg.SizeBytes
|
||||
if err := queries.UpdateTemplateSize(ctx, db.UpdateTemplateSizeParams{
|
||||
ID: t.ID,
|
||||
SizeBytes: resp.Msg.SizeBytes,
|
||||
}); err != nil {
|
||||
slog.Warn("resolveTemplateSizes: failed to persist size",
|
||||
"template", t.Name, "error", err)
|
||||
}
|
||||
}
|
||||
return templates
|
||||
}
|
||||
|
||||
// updateLastActive updates the sandbox last_active_at timestamp.
|
||||
// Uses a background context with timeout for streaming handlers where
|
||||
// the request context may already be cancelled.
|
||||
func updateLastActive(queries *db.Queries, sandboxID pgtype.UUID, sandboxIDStr string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := queries.UpdateLastActive(ctx, db.UpdateLastActiveParams{
|
||||
ID: sandboxID,
|
||||
LastActiveAt: pgtype.Timestamptz{
|
||||
Time: time.Now(),
|
||||
Valid: true,
|
||||
},
|
||||
}); err != nil {
|
||||
slog.Warn("failed to update last_active_at", "id", sandboxIDStr, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
54
internal/api/auth_hooks.go
Normal file
54
internal/api/auth_hooks.go
Normal file
@ -0,0 +1,54 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/cpextension"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
// collectAuthHooks filters the extension list down to those that implement
|
||||
// cpextension.AuthHook.
|
||||
func collectAuthHooks(extensions []cpextension.Extension) []cpextension.AuthHook {
|
||||
var hooks []cpextension.AuthHook
|
||||
for _, ext := range extensions {
|
||||
if h, ok := ext.(cpextension.AuthHook); ok {
|
||||
hooks = append(hooks, h)
|
||||
}
|
||||
}
|
||||
return hooks
|
||||
}
|
||||
|
||||
// fireOnSignup runs every OnSignup hook sequentially. The first error wins —
|
||||
// signup must abort so the cloud-side billing customer creation stays the
|
||||
// gating step. Other hook errors after a failure are not run.
|
||||
func fireOnSignup(ctx context.Context, hooks []cpextension.AuthHook, userID, teamID pgtype.UUID, email string) error {
|
||||
for _, h := range hooks {
|
||||
if err := h.OnSignup(ctx, userID, teamID, email); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// fireOnLogin runs every OnLogin hook. Errors are logged and swallowed —
|
||||
// login must never be blocked by a misbehaving extension.
|
||||
func fireOnLogin(ctx context.Context, hooks []cpextension.AuthHook, userID pgtype.UUID) {
|
||||
for _, h := range hooks {
|
||||
if err := h.OnLogin(ctx, userID); err != nil {
|
||||
slog.Warn("auth hook OnLogin failed", "user_id", id.FormatUserID(userID), "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// fireOnSoftDelete runs every OnAccountSoftDelete hook. Errors are logged.
|
||||
func fireOnSoftDelete(ctx context.Context, hooks []cpextension.AuthHook, userID pgtype.UUID) {
|
||||
for _, h := range hooks {
|
||||
if err := h.OnAccountSoftDelete(ctx, userID); err != nil {
|
||||
slog.Warn("auth hook OnAccountSoftDelete failed", "user_id", id.FormatUserID(userID), "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,14 +1,9 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/audit"
|
||||
@ -17,8 +12,6 @@ import (
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/service"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/validate"
|
||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||
)
|
||||
|
||||
type adminCapsuleHandler struct {
|
||||
@ -49,15 +42,18 @@ func (h *adminCapsuleHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
MemoryMB: req.MemoryMB,
|
||||
TimeoutSec: req.TimeoutSec,
|
||||
})
|
||||
ac.TeamID = id.PlatformTeamID
|
||||
h.audit.LogSandboxCreate(r.Context(), ac, sb.ID, req.Template, err)
|
||||
if err != nil {
|
||||
if sb.ID.Valid {
|
||||
h.audit.LogSandboxDestroySystem(r.Context(), id.PlatformTeamID, sb.ID, "cleanup_after_create_error", nil)
|
||||
}
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
ac.TeamID = id.PlatformTeamID
|
||||
h.audit.LogSandboxCreate(r.Context(), ac, sb.ID, sb.Template)
|
||||
writeJSON(w, http.StatusCreated, sandboxToResponse(sb))
|
||||
writeJSON(w, http.StatusAccepted, sandboxToResponse(sb))
|
||||
}
|
||||
|
||||
// List handles GET /v1/admin/capsules.
|
||||
@ -106,26 +102,27 @@ func (h *adminCapsuleHandler) Destroy(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.svc.Destroy(r.Context(), sandboxID, id.PlatformTeamID); err != nil {
|
||||
ac.TeamID = id.PlatformTeamID
|
||||
err = h.svc.Destroy(r.Context(), sandboxID, id.PlatformTeamID)
|
||||
h.audit.LogSandboxDestroy(r.Context(), ac, sandboxID, err)
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogSandboxDestroy(r.Context(), ac, sandboxID)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
}
|
||||
|
||||
type adminSnapshotRequest struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// Snapshot handles POST /v1/admin/capsules/{id}/snapshot.
|
||||
// Pauses the capsule, takes a snapshot as a platform template, then destroys the capsule.
|
||||
// Snapshot handles POST /v1/admin/capsules/{id}/snapshot. Takes a live
|
||||
// snapshot of a platform-owned capsule and registers the result as a
|
||||
// platform template (team_id = 00000000-...).
|
||||
func (h *adminCapsuleHandler) Snapshot(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
@ -133,117 +130,22 @@ func (h *adminCapsuleHandler) Snapshot(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
var req adminSnapshotRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
req.Name = id.NewSnapshotName()
|
||||
}
|
||||
if err := validate.SafeName(req.Name); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", fmt.Sprintf("invalid snapshot name: %s", err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
// Verify sandbox exists and belongs to platform team BEFORE any
|
||||
// destructive operations (template overwrite).
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: id.PlatformTeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" && sb.Status != "paused" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox must be running or paused")
|
||||
return
|
||||
}
|
||||
|
||||
// Check if name already exists as a platform template.
|
||||
if existing, err := h.db.GetPlatformTemplateByName(ctx, req.Name); err == nil {
|
||||
// Delete old snapshot files from all hosts before removing the DB record.
|
||||
if err := deleteSnapshotBroadcast(ctx, h.db, h.pool, existing.TeamID, existing.ID); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete existing snapshot files")
|
||||
return
|
||||
}
|
||||
if err := h.db.DeleteTemplateByTeam(ctx, db.DeleteTemplateByTeamParams{Name: req.Name, TeamID: id.PlatformTeamID}); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to remove existing template record")
|
||||
if r.ContentLength > 0 {
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||
sb, name, err := h.svc.CreateSnapshot(r.Context(), sandboxID, id.PlatformTeamID, req.Name)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
|
||||
return
|
||||
}
|
||||
|
||||
// Pre-mark sandbox as "paused" to prevent the reconciler from racing.
|
||||
if sb.Status == "running" {
|
||||
if _, err := h.db.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
|
||||
ID: sandboxID, Status: "paused",
|
||||
}); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to update sandbox status")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Use a detached context so the snapshot completes even if the client disconnects.
|
||||
snapCtx, snapCancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer snapCancel()
|
||||
|
||||
newTemplateID := id.NewTemplateID()
|
||||
|
||||
resp, err := agent.CreateSnapshot(snapCtx, connect.NewRequest(&pb.CreateSnapshotRequest{
|
||||
SandboxId: sandboxIDStr,
|
||||
Name: req.Name,
|
||||
TeamId: formatUUIDForRPC(id.PlatformTeamID),
|
||||
TemplateId: formatUUIDForRPC(newTemplateID),
|
||||
}))
|
||||
if err != nil {
|
||||
// Snapshot failed — revert status.
|
||||
if sb.Status == "running" {
|
||||
if _, dbErr := h.db.UpdateSandboxStatus(snapCtx, db.UpdateSandboxStatusParams{
|
||||
ID: sandboxID, Status: "running",
|
||||
}); dbErr != nil {
|
||||
slog.Error("failed to revert sandbox status after snapshot error", "sandbox_id", sandboxIDStr, "error", dbErr)
|
||||
}
|
||||
}
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
ac.TeamID = id.PlatformTeamID
|
||||
h.audit.LogSnapshotCreateRequested(r.Context(), ac, name)
|
||||
|
||||
tmpl, err := h.db.InsertTemplate(snapCtx, db.InsertTemplateParams{
|
||||
ID: newTemplateID,
|
||||
Name: req.Name,
|
||||
Type: "snapshot",
|
||||
Vcpus: sb.Vcpus,
|
||||
MemoryMb: sb.MemoryMb,
|
||||
SizeBytes: resp.Msg.SizeBytes,
|
||||
TeamID: id.PlatformTeamID,
|
||||
DefaultUser: "root",
|
||||
DefaultEnv: []byte("{}"),
|
||||
Metadata: sb.Metadata,
|
||||
})
|
||||
if err != nil {
|
||||
slog.Error("failed to insert template record", "name", req.Name, "error", err)
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "snapshot created but failed to record in database")
|
||||
return
|
||||
}
|
||||
|
||||
// Destroy the ephemeral capsule after successful snapshot.
|
||||
if err := h.svc.Destroy(snapCtx, sandboxID, id.PlatformTeamID); err != nil {
|
||||
slog.Error("failed to destroy capsule after snapshot", "sandbox_id", sandboxIDStr, "error", err)
|
||||
// Don't fail the response — the snapshot was created successfully.
|
||||
}
|
||||
|
||||
h.audit.LogSnapshotCreate(snapCtx, ac, req.Name)
|
||||
|
||||
if ctx.Err() != nil {
|
||||
slog.Info("snapshot created but client disconnected before response", "name", req.Name)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusCreated, templateToResponse(tmpl))
|
||||
writeJSON(w, http.StatusAccepted, sandboxToResponse(sb))
|
||||
}
|
||||
|
||||
@ -20,6 +20,8 @@ import (
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/email"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth/session"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/cpextension"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
@ -140,14 +142,15 @@ type switchTeamRequest struct {
|
||||
type authHandler struct {
|
||||
db *db.Queries
|
||||
pool *pgxpool.Pool
|
||||
jwtSecret []byte
|
||||
sessions *session.Service
|
||||
mailer email.Mailer
|
||||
rdb *redis.Client
|
||||
redirectURL string
|
||||
authHooks []cpextension.AuthHook
|
||||
}
|
||||
|
||||
func newAuthHandler(db *db.Queries, pool *pgxpool.Pool, jwtSecret []byte, mailer email.Mailer, rdb *redis.Client, redirectURL string) *authHandler {
|
||||
return &authHandler{db: db, pool: pool, jwtSecret: jwtSecret, mailer: mailer, rdb: rdb, redirectURL: strings.TrimRight(redirectURL, "/")}
|
||||
func newAuthHandler(db *db.Queries, pool *pgxpool.Pool, sessions *session.Service, mailer email.Mailer, rdb *redis.Client, redirectURL string, hooks []cpextension.AuthHook) *authHandler {
|
||||
return &authHandler{db: db, pool: pool, sessions: sessions, mailer: mailer, rdb: rdb, redirectURL: strings.TrimRight(redirectURL, "/"), authHooks: hooks}
|
||||
}
|
||||
|
||||
type signupRequest struct {
|
||||
@ -166,11 +169,41 @@ type activateRequest struct {
|
||||
}
|
||||
|
||||
type authResponse struct {
|
||||
Token string `json:"token"`
|
||||
UserID string `json:"user_id"`
|
||||
TeamID string `json:"team_id"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
UserID string `json:"user_id"`
|
||||
TeamID string `json:"team_id"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
Role string `json:"role"`
|
||||
IsAdmin bool `json:"is_admin"`
|
||||
}
|
||||
|
||||
// issueSession creates a new session and writes both the session and CSRF
|
||||
// cookies to the response. On success it writes the authResponse JSON body.
|
||||
func (h *authHandler) issueSession(
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
userID, teamID pgtype.UUID,
|
||||
email, name, role string,
|
||||
isAdmin bool,
|
||||
) error {
|
||||
sess, err := h.sessions.Create(r.Context(), userID, teamID, email, name, role, isAdmin, r.UserAgent(), clientIP(r))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
setSessionCookies(w, sess.RawSID, sess.CSRFToken, isSecure(r))
|
||||
return nil
|
||||
}
|
||||
|
||||
// clientIP returns the request's apparent client address.
//
// When an X-Forwarded-For header is present, its first (left-most)
// comma-separated entry is returned with surrounding whitespace removed. If
// the header is absent, or that first entry is blank (for example a header of
// ",10.0.0.1" or one containing only whitespace), r.RemoteAddr is returned
// instead so callers never receive an empty or comma-containing value.
//
// NOTE(review): X-Forwarded-For is supplied by the client unless a trusted
// reverse proxy overwrites it, so the result must not be used for access
// decisions — confirm the deployment always sits behind such a proxy. The
// RemoteAddr fallback is returned as-is; net/http documents it as "IP:port".
func clientIP(r *http.Request) string {
	// Only the left-most entry names the original client; later entries are
	// the proxies the request passed through.
	first, _, _ := strings.Cut(r.Header.Get("X-Forwarded-For"), ",")
	if ip := strings.TrimSpace(first); ip != "" {
		return ip
	}
	return r.RemoteAddr
}
|
||||
|
||||
type signupResponse struct {
|
||||
@ -253,8 +286,8 @@ func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// Generate activation token and store in Redis.
|
||||
rawToken := generateActivationToken()
|
||||
tokenHash := hashActivationToken(rawToken)
|
||||
rawToken := generateOpaqueToken()
|
||||
tokenHash := hashOpaqueToken(rawToken)
|
||||
redisKey := activationKeyPrefix + tokenHash
|
||||
|
||||
if err := h.rdb.Set(ctx, redisKey, id.FormatUserID(userID), activationTTL).Err(); err != nil {
|
||||
@ -296,7 +329,7 @@ func (h *authHandler) Activate(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
tokenHash := hashActivationToken(req.Token)
|
||||
tokenHash := hashOpaqueToken(req.Token)
|
||||
redisKey := activationKeyPrefix + tokenHash
|
||||
|
||||
userIDStr, err := h.rdb.GetDel(ctx, redisKey).Result()
|
||||
@ -345,18 +378,26 @@ func (h *authHandler) Activate(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
isAdmin := user.IsAdmin || isFirstUser
|
||||
token, err := auth.SignJWT(h.jwtSecret, userID, team.ID, user.Email, user.Name, role, isAdmin)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
|
||||
// Fire OnSignup before issuing a session — billing must succeed first.
|
||||
if err := fireOnSignup(ctx, h.authHooks, userID, team.ID, user.Email); err != nil {
|
||||
slog.Error("activate: OnSignup hook failed", "user_id", id.FormatUserID(userID), "error", err)
|
||||
writeError(w, http.StatusInternalServerError, "signup_hook_failed", "failed to finalize account setup")
|
||||
return
|
||||
}
|
||||
if err := h.issueSession(w, r, userID, team.ID, user.Email, user.Name, role, isAdmin); err != nil {
|
||||
slog.Error("activate: failed to issue session", "error", err)
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to create session")
|
||||
return
|
||||
}
|
||||
fireOnLogin(ctx, h.authHooks, userID)
|
||||
|
||||
writeJSON(w, http.StatusOK, authResponse{
|
||||
Token: token,
|
||||
UserID: id.FormatUserID(userID),
|
||||
TeamID: id.FormatTeamID(team.ID),
|
||||
Email: user.Email,
|
||||
Name: user.Name,
|
||||
UserID: id.FormatUserID(userID),
|
||||
TeamID: id.FormatTeamID(team.ID),
|
||||
Email: user.Email,
|
||||
Name: user.Name,
|
||||
Role: role,
|
||||
IsAdmin: isAdmin,
|
||||
})
|
||||
}
|
||||
|
||||
@ -427,18 +468,20 @@ func (h *authHandler) Login(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
isAdmin := user.IsAdmin || isFirstUser
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, isAdmin)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
|
||||
if err := h.issueSession(w, r, user.ID, team.ID, user.Email, user.Name, role, isAdmin); err != nil {
|
||||
slog.Error("login: failed to issue session", "error", err)
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to create session")
|
||||
return
|
||||
}
|
||||
fireOnLogin(ctx, h.authHooks, user.ID)
|
||||
|
||||
writeJSON(w, http.StatusOK, authResponse{
|
||||
Token: token,
|
||||
UserID: id.FormatUserID(user.ID),
|
||||
TeamID: id.FormatTeamID(team.ID),
|
||||
Email: user.Email,
|
||||
Name: user.Name,
|
||||
UserID: id.FormatUserID(user.ID),
|
||||
TeamID: id.FormatTeamID(team.ID),
|
||||
Email: user.Email,
|
||||
Name: user.Name,
|
||||
Role: role,
|
||||
IsAdmin: isAdmin,
|
||||
})
|
||||
}
|
||||
|
||||
@ -503,24 +546,54 @@ func (h *authHandler) SwitchTeam(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
token, err := auth.SignJWT(h.jwtSecret, ac.UserID, teamID, ac.Email, user.Name, membership.Role, user.IsAdmin)
|
||||
// Rotate the SID so any leaked old cookie loses access at the moment of
|
||||
// privilege change.
|
||||
newSess, err := h.sessions.Rotate(ctx, ac.SessionID, ac.UserID, teamID, user.Email, user.Name, membership.Role, user.IsAdmin, r.UserAgent(), clientIP(r))
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
|
||||
slog.Error("switch team: failed to rotate session", "error", err)
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to switch team")
|
||||
return
|
||||
}
|
||||
setSessionCookies(w, newSess.RawSID, newSess.CSRFToken, isSecure(r))
|
||||
|
||||
writeJSON(w, http.StatusOK, authResponse{
|
||||
Token: token,
|
||||
UserID: id.FormatUserID(ac.UserID),
|
||||
TeamID: id.FormatTeamID(teamID),
|
||||
Email: ac.Email,
|
||||
Name: user.Name,
|
||||
UserID: id.FormatUserID(ac.UserID),
|
||||
TeamID: id.FormatTeamID(teamID),
|
||||
Email: user.Email,
|
||||
Name: user.Name,
|
||||
Role: membership.Role,
|
||||
IsAdmin: user.IsAdmin,
|
||||
})
|
||||
}
|
||||
|
||||
// Logout handles POST /v1/auth/logout — revokes the caller's current session.
|
||||
func (h *authHandler) Logout(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
if err := h.sessions.Revoke(r.Context(), ac.SessionID); err != nil {
|
||||
slog.Warn("logout: revoke failed", "error", err)
|
||||
}
|
||||
clearSessionCookies(w, isSecure(r))
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// LogoutAll handles POST /v1/auth/logout-all — revokes every session for the
|
||||
// current user, including the caller's own.
|
||||
func (h *authHandler) LogoutAll(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
if err := h.sessions.RevokeAllForUser(r.Context(), ac.UserID); err != nil {
|
||||
slog.Error("logout-all: revoke failed", "error", err)
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to revoke sessions")
|
||||
return
|
||||
}
|
||||
clearSessionCookies(w, isSecure(r))
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// --- helpers ---
|
||||
|
||||
func generateActivationToken() string {
|
||||
// generateOpaqueToken returns a fresh 16-byte hex-encoded random token,
|
||||
// used for email activation links and password reset links.
|
||||
func generateOpaqueToken() string {
|
||||
b := make([]byte, 16)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
panic(fmt.Sprintf("crypto/rand failed: %v", err))
|
||||
@ -528,7 +601,9 @@ func generateActivationToken() string {
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
func hashActivationToken(raw string) string {
|
||||
// hashOpaqueToken returns the lowercase hex encoding of the SHA-256 digest of
// raw. The digest, rather than the raw token, is what gets used as the lookup
// key for one-shot tokens stored in Redis.
func hashOpaqueToken(raw string) string {
	hasher := sha256.New()
	// hash.Hash documents that Write never returns an error.
	hasher.Write([]byte(raw))
	return hex.EncodeToString(hasher.Sum(nil))
}
|
||||
|
||||
158
internal/api/handlers_build_stream.go
Normal file
158
internal/api/handlers_build_stream.go
Normal file
@ -0,0 +1,158 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/recipe"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/service"
|
||||
)
|
||||
|
||||
// buildStreamKeepalive is the period of the keepalive ticker on a live build
// stream. On every tick the server writes a JSON {"type":"ping"} message so
// intermediaries do not close the WebSocket for inactivity.
const buildStreamKeepalive time.Duration = 30 * time.Second
|
||||
|
||||
// buildStreamHandler serves the live admin build console WebSocket.
|
||||
type buildStreamHandler struct {
|
||||
db *db.Queries
|
||||
broker *service.BuildBroker
|
||||
}
|
||||
|
||||
func newBuildStreamHandler(db *db.Queries, broker *service.BuildBroker) *buildStreamHandler {
|
||||
return &buildStreamHandler{db: db, broker: broker}
|
||||
}
|
||||
|
||||
// Stream handles WS /v1/admin/builds/{id}/stream. On connect it replays the
|
||||
// completed-step history from the DB log, sends the current build status,
|
||||
// then live-tails events from the build broker until the build finishes or
|
||||
// the client disconnects. Admin auth is enforced by upstream middleware.
|
||||
func (h *buildStreamHandler) Stream(w http.ResponseWriter, r *http.Request) {
|
||||
buildIDStr := chi.URLParam(r, "id")
|
||||
buildID, err := id.ParseBuildID(buildIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid build ID")
|
||||
return
|
||||
}
|
||||
|
||||
build, err := h.db.GetTemplateBuild(r.Context(), buildID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "build not found")
|
||||
return
|
||||
}
|
||||
|
||||
conn, _, err := upgradeAndAuthenticate(w, r)
|
||||
if err != nil {
|
||||
slog.Error("build stream websocket upgrade/auth failed", "error", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
h.runStream(r.Context(), conn, build)
|
||||
}
|
||||
|
||||
func (h *buildStreamHandler) runStream(ctx context.Context, conn *websocket.Conn, build db.TemplateBuild) {
|
||||
ws := &wsWriter{conn: conn}
|
||||
buildIDStr := id.FormatBuildID(build.ID)
|
||||
|
||||
// Replay completed-step history from the DB log snapshot. lastStep is the
|
||||
// highest step number already delivered, used to dedup overlapping live
|
||||
// events for a step that finished between the DB read and the subscribe.
|
||||
lastStep := replayBuildHistory(ws, build)
|
||||
|
||||
ws.writeJSON(service.BuildStreamEvent{
|
||||
Type: "build-status",
|
||||
Status: build.Status,
|
||||
CurrentStep: build.CurrentStep,
|
||||
TotalSteps: build.TotalSteps,
|
||||
Error: build.Error,
|
||||
})
|
||||
|
||||
// A finished build has no live events to follow.
|
||||
if service.IsTerminalBuildStatus(build.Status) {
|
||||
return
|
||||
}
|
||||
|
||||
streamCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
events, release := h.broker.Subscribe(buildIDStr)
|
||||
defer release()
|
||||
|
||||
// Drain client reads so a disconnect cancels the stream. The client sends
|
||||
// nothing meaningful; any read error means the socket is gone.
|
||||
go func() {
|
||||
for {
|
||||
if _, _, err := conn.ReadMessage(); err != nil {
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
ticker := time.NewTicker(buildStreamKeepalive)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-streamCtx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
ws.writeJSON(map[string]string{"type": "ping"})
|
||||
case ev, ok := <-events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
// Skip step events already covered by the history replay.
|
||||
if ev.Type != "build-status" && ev.Step > 0 && ev.Step <= lastStep {
|
||||
continue
|
||||
}
|
||||
ws.writeJSON(ev)
|
||||
if ev.Type == "build-status" && service.IsTerminalBuildStatus(ev.Status) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// replayBuildHistory synthesizes step-start/output/step-end events from the
|
||||
// build's persisted log entries and writes them to the WebSocket. It returns
|
||||
// the highest step number replayed.
|
||||
func replayBuildHistory(ws *wsWriter, build db.TemplateBuild) int {
|
||||
if len(build.Logs) == 0 {
|
||||
return 0
|
||||
}
|
||||
var entries []recipe.BuildLogEntry
|
||||
if err := json.Unmarshal(build.Logs, &entries); err != nil {
|
||||
slog.Warn("build stream: bad log JSON", "build_id", id.FormatBuildID(build.ID), "error", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
lastStep := 0
|
||||
for _, e := range entries {
|
||||
ws.writeJSON(service.BuildStreamEvent{Type: "step-start", Step: e.Step, Phase: e.Phase, Cmd: e.Cmd})
|
||||
if out := e.Stdout + e.Stderr; out != "" {
|
||||
ws.writeJSON(service.BuildStreamEvent{
|
||||
Type: "output",
|
||||
Step: e.Step,
|
||||
Data: base64.StdEncoding.EncodeToString([]byte(out)),
|
||||
})
|
||||
}
|
||||
ws.writeJSON(service.BuildStreamEvent{
|
||||
Type: "step-end", Step: e.Step, Phase: e.Phase, Cmd: e.Cmd,
|
||||
Exit: e.Exit, Ok: e.Ok, ElapsedMs: e.Elapsed,
|
||||
})
|
||||
if e.Step > lastStep {
|
||||
lastStep = e.Step
|
||||
}
|
||||
}
|
||||
return lastStep
|
||||
}
|
||||
@ -9,7 +9,6 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/layout"
|
||||
@ -20,7 +19,6 @@ import (
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/service"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/validate"
|
||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||
)
|
||||
|
||||
type buildHandler struct {
|
||||
@ -42,6 +40,7 @@ type createBuildRequest struct {
|
||||
VCPUs int32 `json:"vcpus"`
|
||||
MemoryMB int32 `json:"memory_mb"`
|
||||
SkipPrePost bool `json:"skip_pre_post"`
|
||||
RunAsRoot bool `json:"run_as_root"`
|
||||
}
|
||||
|
||||
type buildResponse struct {
|
||||
@ -181,6 +180,7 @@ func (h *buildHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
VCPUs: req.VCPUs,
|
||||
MemoryMB: req.MemoryMB,
|
||||
SkipPrePost: req.SkipPrePost,
|
||||
RunAsRoot: req.RunAsRoot,
|
||||
Archive: archive,
|
||||
ArchiveName: archiveName,
|
||||
})
|
||||
@ -238,6 +238,9 @@ func (h *buildHandler) ListTemplates(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Resolve actual on-disk sizes for templates with unknown size.
|
||||
templates = resolveTemplateSizes(r.Context(), h.db, h.pool, templates)
|
||||
|
||||
type templateResponse struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
@ -246,6 +249,7 @@ func (h *buildHandler) ListTemplates(w http.ResponseWriter, r *http.Request) {
|
||||
SizeBytes int64 `json:"size_bytes"`
|
||||
TeamID string `json:"team_id"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
Protected bool `json:"protected"`
|
||||
}
|
||||
|
||||
resp := make([]templateResponse, len(templates))
|
||||
@ -257,6 +261,7 @@ func (h *buildHandler) ListTemplates(w http.ResponseWriter, r *http.Request) {
|
||||
MemoryMB: t.MemoryMb,
|
||||
SizeBytes: t.SizeBytes,
|
||||
TeamID: id.FormatTeamID(t.TeamID),
|
||||
Protected: layout.IsSystemTemplate(t.TeamID, t.ID),
|
||||
}
|
||||
if t.CreatedAt.Valid {
|
||||
resp[i].CreatedAt = t.CreatedAt.Time.Format(time.RFC3339)
|
||||
@ -280,29 +285,17 @@ func (h *buildHandler) DeleteTemplate(w http.ResponseWriter, r *http.Request) {
|
||||
writeError(w, http.StatusNotFound, "not_found", "template not found")
|
||||
return
|
||||
}
|
||||
if layout.IsMinimal(tmpl.TeamID, tmpl.ID) {
|
||||
writeError(w, http.StatusForbidden, "forbidden", "the minimal template cannot be deleted")
|
||||
if layout.IsSystemTemplate(tmpl.TeamID, tmpl.ID) {
|
||||
writeError(w, http.StatusForbidden, "forbidden", "system base templates cannot be deleted")
|
||||
return
|
||||
}
|
||||
|
||||
// Broadcast delete to all online hosts.
|
||||
hosts, _ := h.db.ListActiveHosts(ctx)
|
||||
for _, host := range hosts {
|
||||
if host.Status != "online" {
|
||||
continue
|
||||
}
|
||||
agent, err := h.pool.GetForHost(host)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if _, err := agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{
|
||||
TeamId: formatUUIDForRPC(tmpl.TeamID),
|
||||
TemplateId: formatUUIDForRPC(tmpl.ID),
|
||||
})); err != nil {
|
||||
if connect.CodeOf(err) != connect.CodeNotFound {
|
||||
slog.Warn("admin: failed to delete template on host", "host_id", id.FormatHostID(host.ID), "name", name, "error", err)
|
||||
}
|
||||
}
|
||||
// Remove the files from every host before dropping the DB record, so a
|
||||
// failure leaves the template intact and retryable rather than orphaned.
|
||||
if err := deleteSnapshotEverywhere(ctx, h.db, h.pool, tmpl.TeamID, tmpl.ID); err != nil {
|
||||
writeError(w, http.StatusConflict, "delete_failed",
|
||||
"could not remove template files from all hosts: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.db.DeleteTemplate(ctx, tmpl.ID); err != nil {
|
||||
|
||||
@ -3,14 +3,11 @@ package api
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
@ -58,23 +55,11 @@ type backgroundExecResponse struct {
|
||||
|
||||
// Exec handles POST /v1/capsules/{id}/exec.
|
||||
func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
|
||||
sb, sandboxID, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
@ -116,15 +101,7 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.db.UpdateLastActive(ctx, db.UpdateLastActiveParams{
|
||||
ID: sandboxID,
|
||||
LastActiveAt: pgtype.Timestamptz{
|
||||
Time: time.Now(),
|
||||
Valid: true,
|
||||
},
|
||||
}); err != nil {
|
||||
slog.Warn("failed to update last_active_at", "id", sandboxIDStr, "error", err)
|
||||
}
|
||||
updateLastActive(h.db, sandboxID, sandboxIDStr)
|
||||
|
||||
writeJSON(w, http.StatusAccepted, backgroundExecResponse{
|
||||
SandboxID: sandboxIDStr,
|
||||
@ -142,6 +119,8 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
|
||||
Cmd: req.Cmd,
|
||||
Args: req.Args,
|
||||
TimeoutSec: req.TimeoutSec,
|
||||
Envs: req.Envs,
|
||||
Cwd: req.Cwd,
|
||||
}))
|
||||
if err != nil {
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
@ -151,41 +130,24 @@ func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
duration := time.Since(start)
|
||||
|
||||
// Update last active.
|
||||
if err := h.db.UpdateLastActive(ctx, db.UpdateLastActiveParams{
|
||||
ID: sandboxID,
|
||||
LastActiveAt: pgtype.Timestamptz{
|
||||
Time: time.Now(),
|
||||
Valid: true,
|
||||
},
|
||||
}); err != nil {
|
||||
slog.Warn("failed to update last_active_at", "id", sandboxIDStr, "error", err)
|
||||
}
|
||||
updateLastActive(h.db, sandboxID, sandboxIDStr)
|
||||
|
||||
// Use base64 encoding if output contains non-UTF-8 bytes.
|
||||
stdout := resp.Msg.Stdout
|
||||
stderr := resp.Msg.Stderr
|
||||
encoding := "utf-8"
|
||||
|
||||
encoding := "utf-8"
|
||||
stdoutStr, stderrStr := string(stdout), string(stderr)
|
||||
if !utf8.Valid(stdout) || !utf8.Valid(stderr) {
|
||||
encoding = "base64"
|
||||
writeJSON(w, http.StatusOK, execResponse{
|
||||
SandboxID: sandboxIDStr,
|
||||
Cmd: req.Cmd,
|
||||
Stdout: base64.StdEncoding.EncodeToString(stdout),
|
||||
Stderr: base64.StdEncoding.EncodeToString(stderr),
|
||||
ExitCode: resp.Msg.ExitCode,
|
||||
DurationMs: duration.Milliseconds(),
|
||||
Encoding: encoding,
|
||||
})
|
||||
return
|
||||
stdoutStr = base64.StdEncoding.EncodeToString(stdout)
|
||||
stderrStr = base64.StdEncoding.EncodeToString(stderr)
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, execResponse{
|
||||
SandboxID: sandboxIDStr,
|
||||
Cmd: req.Cmd,
|
||||
Stdout: string(stdout),
|
||||
Stderr: string(stderr),
|
||||
Stdout: stdoutStr,
|
||||
Stderr: stderrStr,
|
||||
ExitCode: resp.Msg.ExitCode,
|
||||
DurationMs: duration.Milliseconds(),
|
||||
Encoding: encoding,
|
||||
|
||||
@ -5,7 +5,6 @@ import (
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
@ -20,13 +19,12 @@ import (
|
||||
)
|
||||
|
||||
type execStreamHandler struct {
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
jwtSecret []byte
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
}
|
||||
|
||||
func newExecStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool, jwtSecret []byte) *execStreamHandler {
|
||||
return &execStreamHandler{db: db, pool: pool, jwtSecret: jwtSecret}
|
||||
func newExecStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool) *execStreamHandler {
|
||||
return &execStreamHandler{db: db, pool: pool}
|
||||
}
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
@ -59,37 +57,9 @@ func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Authenticate: use context from middleware (API key) or WS first message (JWT).
|
||||
ac, hasAuth := auth.FromContext(ctx)
|
||||
|
||||
if !hasAuth {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
slog.Error("websocket upgrade failed", "error", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
var wsAC auth.AuthContext
|
||||
var authErr error
|
||||
if isAdminWSRoute(ctx) {
|
||||
wsAC, authErr = wsAuthenticateAdmin(ctx, conn, h.jwtSecret, h.db)
|
||||
} else {
|
||||
wsAC, authErr = wsAuthenticate(ctx, conn, h.jwtSecret, h.db)
|
||||
}
|
||||
if authErr != nil {
|
||||
sendWSError(conn, "authentication failed")
|
||||
return
|
||||
}
|
||||
ac = wsAC
|
||||
|
||||
h.runExecStream(ctx, conn, ac, sandboxID, sandboxIDStr)
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
conn, ac, err := upgradeAndAuthenticate(w, r)
|
||||
if err != nil {
|
||||
slog.Error("websocket upgrade failed", "error", err)
|
||||
slog.Error("websocket upgrade/auth failed", "error", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
@ -186,18 +156,7 @@ func (h *execStreamHandler) runExecStream(ctx context.Context, conn *websocket.C
|
||||
}
|
||||
}
|
||||
|
||||
// Update last active using a fresh context (the request context may be cancelled).
|
||||
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer updateCancel()
|
||||
if err := h.db.UpdateLastActive(updateCtx, db.UpdateLastActiveParams{
|
||||
ID: sandboxID,
|
||||
LastActiveAt: pgtype.Timestamptz{
|
||||
Time: time.Now(),
|
||||
Valid: true,
|
||||
},
|
||||
}); err != nil {
|
||||
slog.Warn("failed to update last active after stream exec", "sandbox_id", sandboxIDStr, "error", err)
|
||||
}
|
||||
updateLastActive(h.db, sandboxID, sandboxIDStr)
|
||||
}
|
||||
|
||||
func sendWSError(conn *websocket.Conn, msg string) {
|
||||
|
||||
@ -7,11 +7,9 @@ import (
|
||||
"net/http"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||
)
|
||||
@ -30,23 +28,11 @@ func newFilesHandler(db *db.Queries, pool *lifecycle.HostClientPool) *filesHandl
|
||||
// - "path" text field: absolute destination path inside the sandbox
|
||||
// - "file" file field: binary content to write
|
||||
func (h *filesHandler) Upload(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
|
||||
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
@ -108,23 +94,11 @@ type readFileRequest struct {
|
||||
// Download handles POST /v1/capsules/{id}/files/read.
|
||||
// Accepts JSON body with path, returns raw file content with Content-Disposition.
|
||||
func (h *filesHandler) Download(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
|
||||
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@ -8,11 +8,9 @@ import (
|
||||
"net/http"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||
)
|
||||
@ -30,23 +28,11 @@ func newFilesStreamHandler(db *db.Queries, pool *lifecycle.HostClientPool) *file
|
||||
// Expects multipart/form-data with "path" text field and "file" file field.
|
||||
// Streams file content directly from the request body to the host agent without buffering.
|
||||
func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
|
||||
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
@ -103,6 +89,12 @@ func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request
|
||||
|
||||
// Open client-streaming RPC to host agent.
|
||||
stream := agent.WriteFileStream(ctx)
|
||||
var streamClosed bool
|
||||
defer func() {
|
||||
if !streamClosed {
|
||||
_, _ = stream.CloseAndReceive()
|
||||
}
|
||||
}()
|
||||
|
||||
// Send metadata first.
|
||||
if err := stream.Send(&pb.WriteFileStreamRequest{
|
||||
@ -141,6 +133,7 @@ func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
|
||||
// Close and receive response.
|
||||
streamClosed = true
|
||||
if _, err := stream.CloseAndReceive(); err != nil {
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
@ -153,23 +146,11 @@ func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request
|
||||
// StreamDownload handles POST /v1/capsules/{id}/files/stream/read.
|
||||
// Accepts JSON body with path, streams file content back without buffering.
|
||||
func (h *filesStreamHandler) StreamDownload(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
|
||||
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@ -4,11 +4,9 @@ import (
|
||||
"net/http"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||
)
|
||||
@ -58,23 +56,11 @@ type removeRequest struct {
|
||||
|
||||
// ListDir handles POST /v1/capsules/{id}/files/list.
|
||||
func (h *fsHandler) ListDir(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
|
||||
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
@ -115,23 +101,11 @@ func (h *fsHandler) ListDir(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// MakeDir handles POST /v1/capsules/{id}/files/mkdir.
|
||||
func (h *fsHandler) MakeDir(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
|
||||
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
@ -166,23 +140,11 @@ func (h *fsHandler) MakeDir(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Remove handles POST /v1/capsules/{id}/files/remove.
|
||||
func (h *fsHandler) Remove(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
|
||||
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
@ -21,10 +22,11 @@ type hostHandler struct {
|
||||
svc *service.HostService
|
||||
queries *db.Queries
|
||||
audit *audit.AuditLogger
|
||||
monitor *HostMonitor
|
||||
}
|
||||
|
||||
func newHostHandler(svc *service.HostService, queries *db.Queries, al *audit.AuditLogger) *hostHandler {
|
||||
return &hostHandler{svc: svc, queries: queries, audit: al}
|
||||
func newHostHandler(svc *service.HostService, queries *db.Queries, al *audit.AuditLogger, monitor *HostMonitor) *hostHandler {
|
||||
return &hostHandler{svc: svc, queries: queries, audit: al, monitor: monitor}
|
||||
}
|
||||
|
||||
// Request/response types.
|
||||
@ -98,6 +100,11 @@ type hostResponse struct {
|
||||
CreatedBy string `json:"created_by"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
RunningVcpus int32 `json:"running_vcpus"`
|
||||
RunningMemoryMb int32 `json:"running_memory_mb"`
|
||||
RunningDiskMb int32 `json:"running_disk_mb"`
|
||||
PausedMemoryMb int32 `json:"paused_memory_mb"`
|
||||
PausedDiskMb int32 `json:"paused_disk_mb"`
|
||||
}
|
||||
|
||||
func hostToResponse(h db.Host) hostResponse {
|
||||
@ -136,12 +143,37 @@ func hostToResponse(h db.Host) hostResponse {
|
||||
s := h.LastHeartbeatAt.Time.Format(time.RFC3339)
|
||||
resp.LastHeartbeatAt = &s
|
||||
}
|
||||
// created_at and updated_at are NOT NULL DEFAULT NOW(), always valid.
|
||||
resp.CreatedAt = h.CreatedAt.Time.Format(time.RFC3339)
|
||||
resp.UpdatedAt = h.UpdatedAt.Time.Format(time.RFC3339)
|
||||
return resp
|
||||
}
|
||||
|
||||
func hostToResponseWithLoad(h db.ListHostsByTeamRow) hostResponse {
|
||||
resp := hostToResponse(db.Host{
|
||||
ID: h.ID,
|
||||
Type: h.Type,
|
||||
TeamID: h.TeamID,
|
||||
Provider: h.Provider,
|
||||
AvailabilityZone: h.AvailabilityZone,
|
||||
Arch: h.Arch,
|
||||
CpuCores: h.CpuCores,
|
||||
MemoryMb: h.MemoryMb,
|
||||
DiskGb: h.DiskGb,
|
||||
Address: h.Address,
|
||||
Status: h.Status,
|
||||
LastHeartbeatAt: h.LastHeartbeatAt,
|
||||
CreatedBy: h.CreatedBy,
|
||||
CreatedAt: h.CreatedAt,
|
||||
UpdatedAt: h.UpdatedAt,
|
||||
})
|
||||
resp.RunningVcpus = h.RunningVcpus
|
||||
resp.RunningMemoryMb = h.RunningMemoryMb
|
||||
resp.RunningDiskMb = h.RunningDiskMb
|
||||
resp.PausedMemoryMb = h.PausedMemoryMb
|
||||
resp.PausedDiskMb = h.PausedDiskMb
|
||||
return resp
|
||||
}
|
||||
|
||||
// isAdmin fetches the user record and returns whether they are an admin.
|
||||
func (h *hostHandler) isAdmin(r *http.Request, userID pgtype.UUID) bool {
|
||||
user, err := h.queries.GetUserByID(r.Context(), userID)
|
||||
@ -233,7 +265,7 @@ func (h *hostHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
resp := make([]hostResponse, len(hosts))
|
||||
for i, host := range hosts {
|
||||
resp[i] = hostToResponse(host)
|
||||
resp[i] = hostToResponseWithLoad(host)
|
||||
if host.TeamID.Valid {
|
||||
key := id.FormatTeamID(host.TeamID)
|
||||
if name, ok := teamNames[key]; ok {
|
||||
@ -335,6 +367,54 @@ func (h *hostHandler) Delete(w http.ResponseWriter, r *http.Request) {
|
||||
writeError(w, status, code, msg)
|
||||
}
|
||||
|
||||
// AdminList handles GET /v1/admin/hosts.
|
||||
// Returns all hosts with per-host resource consumption. Admin-only.
|
||||
func (h *hostHandler) AdminList(w http.ResponseWriter, r *http.Request) {
|
||||
hosts, err := h.svc.ListAdmin(r.Context())
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to list hosts")
|
||||
return
|
||||
}
|
||||
|
||||
// Collect unique team IDs to fetch team names.
|
||||
var teamNames map[string]string
|
||||
seen := make(map[string]struct{})
|
||||
for _, host := range hosts {
|
||||
if host.TeamID.Valid {
|
||||
key := id.FormatTeamID(host.TeamID)
|
||||
seen[key] = struct{}{}
|
||||
}
|
||||
}
|
||||
if len(seen) > 0 {
|
||||
teamNames = make(map[string]string, len(seen))
|
||||
for _, host := range hosts {
|
||||
if !host.TeamID.Valid {
|
||||
continue
|
||||
}
|
||||
key := id.FormatTeamID(host.TeamID)
|
||||
if _, ok := teamNames[key]; ok {
|
||||
continue
|
||||
}
|
||||
if team, err := h.queries.GetTeam(r.Context(), host.TeamID); err == nil {
|
||||
teamNames[key] = team.Name
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
resp := make([]hostResponse, len(hosts))
|
||||
for i, host := range hosts {
|
||||
resp[i] = hostToResponseWithLoad(db.ListHostsByTeamRow(host))
|
||||
if host.TeamID.Valid {
|
||||
key := id.FormatTeamID(host.TeamID)
|
||||
if name, ok := teamNames[key]; ok {
|
||||
resp[i].TeamName = &name
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// RegenerateToken handles POST /v1/hosts/{id}/token.
|
||||
func (h *hostHandler) RegenerateToken(w http.ResponseWriter, r *http.Request) {
|
||||
hostIDStr := chi.URLParam(r, "id")
|
||||
@ -426,9 +506,12 @@ func (h *hostHandler) Heartbeat(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Log marked_up if the host just recovered from unreachable.
|
||||
// If the host just recovered from unreachable, log it and trigger immediate
|
||||
// reconciliation so "missing" sandboxes are resolved without waiting for the
|
||||
// next monitor tick.
|
||||
if prevHost.Status == "unreachable" {
|
||||
h.audit.LogHostMarkedUp(r.Context(), prevHost.TeamID, hc.HostID)
|
||||
go h.monitor.ReconcileHost(context.Background(), hc.HostID)
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
|
||||
@ -2,9 +2,6 @@ package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
@ -20,6 +17,8 @@ import (
|
||||
"git.omukk.dev/wrenn/wrenn/internal/email"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth/oauth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth/session"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/cpextension"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/service"
|
||||
@ -34,42 +33,62 @@ type meHandler struct {
|
||||
db *db.Queries
|
||||
pool *pgxpool.Pool
|
||||
rdb *redis.Client
|
||||
jwtSecret []byte
|
||||
hmacKey []byte // HMAC key for OAuth state and link cookies (was jwtSecret)
|
||||
sessions *session.Service
|
||||
mailer email.Mailer
|
||||
oauthRegistry *oauth.Registry
|
||||
redirectURL string
|
||||
teamSvc *service.TeamService
|
||||
authHooks []cpextension.AuthHook
|
||||
}
|
||||
|
||||
func newMeHandler(
|
||||
db *db.Queries,
|
||||
pool *pgxpool.Pool,
|
||||
rdb *redis.Client,
|
||||
jwtSecret []byte,
|
||||
hmacKey []byte,
|
||||
sessions *session.Service,
|
||||
mailer email.Mailer,
|
||||
registry *oauth.Registry,
|
||||
redirectURL string,
|
||||
teamSvc *service.TeamService,
|
||||
hooks []cpextension.AuthHook,
|
||||
) *meHandler {
|
||||
return &meHandler{
|
||||
db: db,
|
||||
pool: pool,
|
||||
rdb: rdb,
|
||||
jwtSecret: jwtSecret,
|
||||
hmacKey: hmacKey,
|
||||
sessions: sessions,
|
||||
mailer: mailer,
|
||||
oauthRegistry: registry,
|
||||
redirectURL: strings.TrimRight(redirectURL, "/"),
|
||||
teamSvc: teamSvc,
|
||||
authHooks: hooks,
|
||||
}
|
||||
}
|
||||
|
||||
// meResponse is the JSON body that GetMe writes for the calling user.
type meResponse struct {
|
||||
// UserID and TeamID are the formatted IDs of ac.UserID / ac.TeamID
// (ac is presumably the request's auth context — confirm in GetMe).
UserID string `json:"user_id"`
|
||||
TeamID string `json:"team_id"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
// Role is copied from ac.Role; IsAdmin from the user record.
Role string `json:"role"`
|
||||
IsAdmin bool `json:"is_admin"`
|
||||
// HasPassword mirrors user.PasswordHash.Valid.
HasPassword bool `json:"has_password"`
|
||||
Providers []string `json:"providers"`
|
||||
}
|
||||
|
||||
// sessionRow is the JSON shape of a single session entry.
// NOTE(review): where it is populated is not visible in this part of the
// file; presumably one row per active session of the caller — confirm.
type sessionRow struct {
|
||||
ID string `json:"id"`
|
||||
UserAgent string `json:"user_agent"`
|
||||
IPAddress string `json:"ip_address"`
|
||||
// The three timestamps are carried as strings; the format is not visible
// here (other handlers in this package use time.RFC3339 — TODO confirm).
CreatedAt string `json:"created_at"`
|
||||
LastSeenAt string `json:"last_seen_at"`
|
||||
ExpiresAt string `json:"expires_at"`
|
||||
// Current presumably marks the session making the request — verify
// against the code that fills it in.
Current bool `json:"current"`
|
||||
}
|
||||
|
||||
// updateNameRequest is a JSON request body with a single "name" field.
// Presumably decoded by UpdateName (PATCH /v1/me) — confirm at the call site.
type updateNameRequest struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
@ -116,14 +135,19 @@ func (h *meHandler) GetMe(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, meResponse{
|
||||
UserID: id.FormatUserID(ac.UserID),
|
||||
TeamID: id.FormatTeamID(ac.TeamID),
|
||||
Name: user.Name,
|
||||
Email: user.Email,
|
||||
Role: ac.Role,
|
||||
IsAdmin: user.IsAdmin,
|
||||
HasPassword: user.PasswordHash.Valid,
|
||||
Providers: providerNames,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateName handles PATCH /v1/me — updates the user's name and re-issues a JWT.
|
||||
// UpdateName handles PATCH /v1/me — updates the user's name and refreshes
|
||||
// any cached session blobs so the new name shows up on next request.
|
||||
func (h *meHandler) UpdateName(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
ctx := r.Context()
|
||||
@ -148,31 +172,11 @@ func (h *meHandler) UpdateName(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.db.GetUserByID(ctx, ac.UserID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to get user")
|
||||
return
|
||||
if err := h.sessions.InvalidateCacheForUser(ctx, ac.UserID); err != nil {
|
||||
slog.Warn("update name: invalidate session cache failed", "error", err)
|
||||
}
|
||||
|
||||
team, role, err := loginTeam(ctx, h.db, ac.UserID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to get team")
|
||||
return
|
||||
}
|
||||
|
||||
token, err := auth.SignJWT(h.jwtSecret, ac.UserID, team.ID, user.Email, req.Name, role, user.IsAdmin)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, authResponse{
|
||||
Token: token,
|
||||
UserID: id.FormatUserID(ac.UserID),
|
||||
TeamID: id.FormatTeamID(team.ID),
|
||||
Email: user.Email,
|
||||
Name: req.Name,
|
||||
})
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// ChangePassword handles POST /v1/me/password.
|
||||
@ -235,6 +239,14 @@ func (h *meHandler) ChangePassword(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Revoke every session for this user — including the caller's — so a new
|
||||
// password resets all device access. Clear cookies on the response so the
|
||||
// caller is signed out immediately.
|
||||
if err := h.sessions.RevokeAllForUser(ctx, ac.UserID); err != nil {
|
||||
slog.Warn("change password: revoke sessions failed", "error", err)
|
||||
}
|
||||
clearSessionCookies(w, isSecure(r))
|
||||
|
||||
isAdding := !user.PasswordHash.Valid
|
||||
go func() {
|
||||
sendCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
@ -285,8 +297,8 @@ func (h *meHandler) RequestPasswordReset(w http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
rawToken := generateResetToken()
|
||||
tokenHash := hashResetToken(rawToken)
|
||||
rawToken := generateOpaqueToken()
|
||||
tokenHash := hashOpaqueToken(rawToken)
|
||||
redisKey := passwordResetKeyPrefix + tokenHash
|
||||
|
||||
if err := h.rdb.Set(ctx, redisKey, id.FormatUserID(user.ID), passwordResetTTL).Err(); err != nil {
|
||||
@ -330,7 +342,7 @@ func (h *meHandler) ConfirmPasswordReset(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
tokenHash := hashResetToken(req.Token)
|
||||
tokenHash := hashOpaqueToken(req.Token)
|
||||
redisKey := passwordResetKeyPrefix + tokenHash
|
||||
|
||||
// GetDel atomically retrieves and removes the token in a single round-trip,
|
||||
@ -371,6 +383,11 @@ func (h *meHandler) ConfirmPasswordReset(w http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
// Reset invalidates every active session for the user.
|
||||
if err := h.sessions.RevokeAllForUser(ctx, userID); err != nil {
|
||||
slog.Warn("confirm password reset: revoke sessions failed", "error", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
sendCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
@ -404,7 +421,7 @@ func (h *meHandler) ConnectProvider(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
mac := computeHMAC(h.jwtSecret, state+":"+"login")
|
||||
mac := computeHMAC(h.hmacKey, state+":"+"login")
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "oauth_state",
|
||||
Value: state + ":" + mac + ":" + "login",
|
||||
@ -416,7 +433,7 @@ func (h *meHandler) ConnectProvider(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
|
||||
userIDStr := id.FormatUserID(ac.UserID)
|
||||
linkMac := computeHMAC(h.jwtSecret, userIDStr)
|
||||
linkMac := computeHMAC(h.hmacKey, userIDStr)
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "oauth_link_user_id",
|
||||
Value: userIDStr + ":" + linkMac,
|
||||
@ -552,6 +569,14 @@ func (h *meHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Revoke every active session and clear the caller's cookies.
|
||||
if err := h.sessions.RevokeAllForUser(ctx, ac.UserID); err != nil {
|
||||
slog.Warn("delete account: revoke sessions failed", "error", err)
|
||||
}
|
||||
clearSessionCookies(w, isSecure(r))
|
||||
|
||||
fireOnSoftDelete(ctx, h.authHooks, ac.UserID)
|
||||
|
||||
slog.Info("account soft-deleted", "user_id", id.FormatUserID(ac.UserID), "email", user.Email)
|
||||
|
||||
go func() {
|
||||
@ -569,17 +594,4 @@ func (h *meHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// --- helpers ---
|
||||
|
||||
func generateResetToken() string {
|
||||
b := make([]byte, 16)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
panic(fmt.Sprintf("crypto/rand failed: %v", err))
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
func hashResetToken(raw string) string {
|
||||
h := sha256.Sum256([]byte(raw))
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
// (token helpers live in handlers_auth.go)
|
||||
|
||||
@ -14,10 +14,12 @@ import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth/oauth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth/session"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/cpextension"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
@ -25,18 +27,22 @@ import (
|
||||
type oauthHandler struct {
|
||||
db *db.Queries
|
||||
pool *pgxpool.Pool
|
||||
jwtSecret []byte
|
||||
hmacKey []byte // HMAC key for OAuth state and link cookies
|
||||
sessions *session.Service
|
||||
registry *oauth.Registry
|
||||
redirectURL string // base frontend URL (e.g. "https://app.wrenn.dev")
|
||||
authHooks []cpextension.AuthHook
|
||||
}
|
||||
|
||||
func newOAuthHandler(db *db.Queries, pool *pgxpool.Pool, jwtSecret []byte, registry *oauth.Registry, redirectURL string) *oauthHandler {
|
||||
func newOAuthHandler(db *db.Queries, pool *pgxpool.Pool, hmacKey []byte, sessions *session.Service, registry *oauth.Registry, redirectURL string, hooks []cpextension.AuthHook) *oauthHandler {
|
||||
return &oauthHandler{
|
||||
db: db,
|
||||
pool: pool,
|
||||
jwtSecret: jwtSecret,
|
||||
hmacKey: hmacKey,
|
||||
sessions: sessions,
|
||||
registry: registry,
|
||||
redirectURL: strings.TrimRight(redirectURL, "/"),
|
||||
authHooks: hooks,
|
||||
}
|
||||
}
|
||||
|
||||
@ -61,7 +67,7 @@ func (h *oauthHandler) Redirect(w http.ResponseWriter, r *http.Request) {
|
||||
intent = "login"
|
||||
}
|
||||
|
||||
mac := computeHMAC(h.jwtSecret, state+":"+intent)
|
||||
mac := computeHMAC(h.hmacKey, state+":"+intent)
|
||||
cookieVal := state + ":" + mac + ":" + intent
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
@ -121,7 +127,7 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
if len(parts) == 3 && parts[2] == "signup" {
|
||||
intent = "signup"
|
||||
}
|
||||
if !hmac.Equal([]byte(computeHMAC(h.jwtSecret, nonce+":"+intent)), []byte(expectedMAC)) {
|
||||
if !hmac.Equal([]byte(computeHMAC(h.hmacKey, nonce+":"+intent)), []byte(expectedMAC)) {
|
||||
redirectWithError(w, r, redirectBase, "invalid_state")
|
||||
return
|
||||
}
|
||||
@ -164,7 +170,7 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Verify the HMAC to prevent cookie forgery.
|
||||
linkParts := strings.SplitN(linkCookie.Value, ":", 2)
|
||||
if len(linkParts) != 2 || !hmac.Equal([]byte(computeHMAC(h.jwtSecret, linkParts[0])), []byte(linkParts[1])) {
|
||||
if len(linkParts) != 2 || !hmac.Equal([]byte(computeHMAC(h.hmacKey, linkParts[0])), []byte(linkParts[1])) {
|
||||
slog.Warn("oauth link: invalid or tampered link cookie")
|
||||
http.Redirect(w, r, settingsBase+"?connect_error=invalid_state", http.StatusFound)
|
||||
return
|
||||
@ -244,13 +250,12 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
isAdmin := user.IsAdmin || isFirstUser
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, isAdmin)
|
||||
if err != nil {
|
||||
slog.Error("oauth login: failed to sign jwt", "error", err)
|
||||
if err := h.issueSessionAndRedirect(w, r, user.ID, team.ID, user.Email, user.Name, role, isAdmin, redirectBase); err != nil {
|
||||
slog.Error("oauth login: failed to issue session", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "internal_error")
|
||||
return
|
||||
}
|
||||
redirectWithToken(w, r, redirectBase, token, id.FormatUserID(user.ID), id.FormatTeamID(team.ID), user.Email, user.Name)
|
||||
fireOnLogin(ctx, h.authHooks, user.ID)
|
||||
return
|
||||
}
|
||||
if !errors.Is(err, pgx.ErrNoRows) {
|
||||
@ -374,10 +379,10 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
token, err := auth.SignJWT(h.jwtSecret, userID, teamID, email, profile.Name, "owner", isFirstUser)
|
||||
if err != nil {
|
||||
slog.Error("oauth: failed to sign jwt", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "internal_error")
|
||||
// Fire OnSignup before session issuance — billing must succeed first.
|
||||
if hookErr := fireOnSignup(ctx, h.authHooks, userID, teamID, email); hookErr != nil {
|
||||
slog.Error("oauth signup: OnSignup hook failed", "user_id", id.FormatUserID(userID), "error", hookErr)
|
||||
redirectWithError(w, r, redirectBase, "signup_hook_failed")
|
||||
return
|
||||
}
|
||||
|
||||
@ -385,14 +390,19 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "wrenn_oauth_new_signup",
|
||||
Value: "1",
|
||||
Path: "/auth/",
|
||||
Path: "/",
|
||||
MaxAge: 60,
|
||||
HttpOnly: false,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Secure: isSecure(r),
|
||||
})
|
||||
|
||||
redirectWithToken(w, r, redirectBase, token, id.FormatUserID(userID), id.FormatTeamID(teamID), email, profile.Name)
|
||||
if err := h.issueSessionAndRedirect(w, r, userID, teamID, email, profile.Name, "owner", isFirstUser, redirectBase); err != nil {
|
||||
slog.Error("oauth: failed to issue session", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "internal_error")
|
||||
return
|
||||
}
|
||||
fireOnLogin(ctx, h.authHooks, userID)
|
||||
}
|
||||
|
||||
// retryAsLogin handles the race where a concurrent request already created the user.
|
||||
@ -431,33 +441,35 @@ func (h *oauthHandler) retryAsLogin(w http.ResponseWriter, r *http.Request, prov
|
||||
return
|
||||
}
|
||||
isAdmin := user.IsAdmin || isFirstUser
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, isAdmin)
|
||||
if err != nil {
|
||||
slog.Error("oauth: retry login: failed to sign jwt", "error", err)
|
||||
if err := h.issueSessionAndRedirect(w, r, user.ID, team.ID, user.Email, user.Name, role, isAdmin, redirectBase); err != nil {
|
||||
slog.Error("oauth: retry login: failed to issue session", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "internal_error")
|
||||
return
|
||||
}
|
||||
redirectWithToken(w, r, redirectBase, token, id.FormatUserID(user.ID), id.FormatTeamID(team.ID), user.Email, user.Name)
|
||||
fireOnLogin(ctx, h.authHooks, user.ID)
|
||||
}
|
||||
|
||||
func redirectWithToken(w http.ResponseWriter, r *http.Request, base, token, userID, teamID, email, name string) {
|
||||
// Set auth data as short-lived cookies instead of URL query parameters.
|
||||
// This prevents token leakage via server access logs, Referer headers, and browser history.
|
||||
for _, c := range []http.Cookie{
|
||||
{Name: "wrenn_oauth_token", Value: token},
|
||||
{Name: "wrenn_oauth_user_id", Value: userID},
|
||||
{Name: "wrenn_oauth_team_id", Value: teamID},
|
||||
{Name: "wrenn_oauth_email", Value: email},
|
||||
{Name: "wrenn_oauth_name", Value: name},
|
||||
} {
|
||||
c.Path = "/auth/"
|
||||
c.MaxAge = 60
|
||||
c.HttpOnly = false // frontend JS must read these
|
||||
c.SameSite = http.SameSiteLaxMode
|
||||
c.Secure = isSecure(r)
|
||||
http.SetCookie(w, &c)
|
||||
// issueSessionAndRedirect creates a session, sets the session and CSRF
|
||||
// cookies, and redirects to redirectBase, the frontend OAuth callback page.
|
||||
// The callback page then probes /v1/me to hydrate state. If session creation
|
||||
// fails, the error is returned before any cookie or redirect is written.
|
||||
func (h *oauthHandler) issueSessionAndRedirect(
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
userID, teamID pgtype.UUID,
|
||||
email, name, role string,
|
||||
isAdmin bool,
|
||||
redirectBase string,
|
||||
) error {
|
||||
sess, err := h.sessions.Create(r.Context(), userID, teamID, email, name, role, isAdmin, r.UserAgent(), clientIP(r))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
http.Redirect(w, r, base, http.StatusFound)
|
||||
setSessionCookies(w, sess.RawSID, sess.CSRFToken, isSecure(r))
|
||||
// Send the user to the callback page so the SPA can probe /v1/me and
|
||||
// trigger any post-OAuth UX (e.g. the new-signup name confirmation).
|
||||
http.Redirect(w, r, redirectBase, http.StatusFound)
|
||||
return nil
|
||||
}
|
||||
|
||||
func redirectWithError(w http.ResponseWriter, r *http.Request, base, code string) {
|
||||
@ -477,7 +489,3 @@ func computeHMAC(key []byte, data string) string {
|
||||
h.Write([]byte(data))
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
func isSecure(r *http.Request) bool {
|
||||
return r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https"
|
||||
}
|
||||
|
||||
@ -5,7 +5,6 @@ import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
@ -20,13 +19,12 @@ import (
|
||||
)
|
||||
|
||||
type processHandler struct {
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
jwtSecret []byte
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
}
|
||||
|
||||
func newProcessHandler(db *db.Queries, pool *lifecycle.HostClientPool, jwtSecret []byte) *processHandler {
|
||||
return &processHandler{db: db, pool: pool, jwtSecret: jwtSecret}
|
||||
func newProcessHandler(db *db.Queries, pool *lifecycle.HostClientPool) *processHandler {
|
||||
return &processHandler{db: db, pool: pool}
|
||||
}
|
||||
|
||||
// processResponse is a single entry in the process list.
|
||||
@ -44,23 +42,11 @@ type processListResponse struct {
|
||||
|
||||
// ListProcesses handles GET /v1/capsules/{id}/processes.
|
||||
func (h *processHandler) ListProcesses(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
|
||||
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
@ -95,24 +81,12 @@ func (h *processHandler) ListProcesses(w http.ResponseWriter, r *http.Request) {
|
||||
// KillProcess handles DELETE /v1/capsules/{id}/processes/{selector}.
|
||||
// The selector can be a numeric PID or a string tag.
|
||||
func (h *processHandler) KillProcess(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
selectorStr := chi.URLParam(r, "selector")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
|
||||
sb, _, sandboxIDStr, ok := requireRunningSandbox(w, r, h.db, ac.TeamID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
@ -146,14 +120,6 @@ func (h *processHandler) KillProcess(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// wsProcessOut is the JSON message sent to the WebSocket client.
|
||||
type wsProcessOut struct {
|
||||
Type string `json:"type"` // "start", "stdout", "stderr", "exit", "error"
|
||||
PID uint32 `json:"pid,omitempty"` // only for "start"
|
||||
Data string `json:"data,omitempty"` // only for "stdout", "stderr", "error"
|
||||
ExitCode *int32 `json:"exit_code,omitempty"` // only for "exit"
|
||||
}
|
||||
|
||||
// ConnectProcess handles WS /v1/capsules/{id}/processes/{selector}/stream.
|
||||
func (h *processHandler) ConnectProcess(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxIDStr := chi.URLParam(r, "id")
|
||||
@ -166,37 +132,9 @@ func (h *processHandler) ConnectProcess(w http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
// Authenticate: use context from middleware (API key) or WS first message (JWT).
|
||||
ac, hasAuth := auth.FromContext(ctx)
|
||||
|
||||
if !hasAuth {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
slog.Error("process stream websocket upgrade failed", "error", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
var wsAC auth.AuthContext
|
||||
var authErr error
|
||||
if isAdminWSRoute(ctx) {
|
||||
wsAC, authErr = wsAuthenticateAdmin(ctx, conn, h.jwtSecret, h.db)
|
||||
} else {
|
||||
wsAC, authErr = wsAuthenticate(ctx, conn, h.jwtSecret, h.db)
|
||||
}
|
||||
if authErr != nil {
|
||||
sendProcessWSError(conn, "authentication failed")
|
||||
return
|
||||
}
|
||||
ac = wsAC
|
||||
|
||||
h.runConnectProcess(ctx, conn, ac, sandboxID, sandboxIDStr, selectorStr)
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
conn, ac, err := upgradeAndAuthenticate(w, r)
|
||||
if err != nil {
|
||||
slog.Error("process stream websocket upgrade failed", "error", err)
|
||||
slog.Error("process stream websocket upgrade/auth failed", "error", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
@ -207,17 +145,17 @@ func (h *processHandler) ConnectProcess(w http.ResponseWriter, r *http.Request)
|
||||
func (h *processHandler) runConnectProcess(ctx context.Context, conn *websocket.Conn, ac auth.AuthContext, sandboxID pgtype.UUID, sandboxIDStr, selectorStr string) {
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
sendProcessWSError(conn, "sandbox not found")
|
||||
sendWSError(conn, "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
sendProcessWSError(conn, "sandbox is not running (status: "+sb.Status+")")
|
||||
sendWSError(conn, "sandbox is not running (status: "+sb.Status+")")
|
||||
return
|
||||
}
|
||||
|
||||
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||
if err != nil {
|
||||
sendProcessWSError(conn, "sandbox host is not reachable")
|
||||
sendWSError(conn, "sandbox host is not reachable")
|
||||
return
|
||||
}
|
||||
|
||||
@ -236,7 +174,7 @@ func (h *processHandler) runConnectProcess(ctx context.Context, conn *websocket.
|
||||
|
||||
stream, err := agent.ConnectProcess(streamCtx, connect.NewRequest(connectReq))
|
||||
if err != nil {
|
||||
sendProcessWSError(conn, "failed to connect to process: "+err.Error())
|
||||
sendWSError(conn, "failed to connect to process: "+err.Error())
|
||||
return
|
||||
}
|
||||
defer stream.Close()
|
||||
@ -257,42 +195,27 @@ func (h *processHandler) runConnectProcess(ctx context.Context, conn *websocket.
|
||||
resp := stream.Msg()
|
||||
switch ev := resp.Event.(type) {
|
||||
case *pb.ConnectProcessResponse_Start:
|
||||
writeWSJSON(conn, wsProcessOut{Type: "start", PID: ev.Start.Pid})
|
||||
writeWSJSON(conn, wsOutMsg{Type: "start", PID: ev.Start.Pid})
|
||||
|
||||
case *pb.ConnectProcessResponse_Data:
|
||||
switch o := ev.Data.Output.(type) {
|
||||
case *pb.ExecStreamData_Stdout:
|
||||
writeWSJSON(conn, wsProcessOut{Type: "stdout", Data: string(o.Stdout)})
|
||||
writeWSJSON(conn, wsOutMsg{Type: "stdout", Data: string(o.Stdout)})
|
||||
case *pb.ExecStreamData_Stderr:
|
||||
writeWSJSON(conn, wsProcessOut{Type: "stderr", Data: string(o.Stderr)})
|
||||
writeWSJSON(conn, wsOutMsg{Type: "stderr", Data: string(o.Stderr)})
|
||||
}
|
||||
|
||||
case *pb.ConnectProcessResponse_End:
|
||||
exitCode := ev.End.ExitCode
|
||||
writeWSJSON(conn, wsProcessOut{Type: "exit", ExitCode: &exitCode})
|
||||
writeWSJSON(conn, wsOutMsg{Type: "exit", ExitCode: &exitCode})
|
||||
}
|
||||
}
|
||||
|
||||
if err := stream.Err(); err != nil {
|
||||
if streamCtx.Err() == nil {
|
||||
sendProcessWSError(conn, err.Error())
|
||||
sendWSError(conn, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// Update last active using a fresh context.
|
||||
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer updateCancel()
|
||||
if err := h.db.UpdateLastActive(updateCtx, db.UpdateLastActiveParams{
|
||||
ID: sandboxID,
|
||||
LastActiveAt: pgtype.Timestamptz{
|
||||
Time: time.Now(),
|
||||
Valid: true,
|
||||
},
|
||||
}); err != nil {
|
||||
slog.Warn("failed to update last active after process stream", "sandbox_id", sandboxIDStr, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func sendProcessWSError(conn *websocket.Conn, msg string) {
|
||||
writeWSJSON(conn, wsProcessOut{Type: "error", Data: msg})
|
||||
updateLastActive(h.db, sandboxID, sandboxIDStr)
|
||||
}
|
||||
|
||||
@ -30,13 +30,12 @@ const (
|
||||
)
|
||||
|
||||
type ptyHandler struct {
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
jwtSecret []byte
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
}
|
||||
|
||||
func newPtyHandler(db *db.Queries, pool *lifecycle.HostClientPool, jwtSecret []byte) *ptyHandler {
|
||||
return &ptyHandler{db: db, pool: pool, jwtSecret: jwtSecret}
|
||||
func newPtyHandler(db *db.Queries, pool *lifecycle.HostClientPool) *ptyHandler {
|
||||
return &ptyHandler{db: db, pool: pool}
|
||||
}
|
||||
|
||||
// --- WebSocket message types ---
|
||||
@ -90,40 +89,9 @@ func (h *ptyHandler) PtySession(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// API key auth is handled by middleware (sets context).
|
||||
// For browser JWT auth, we authenticate after upgrade via first WS message.
|
||||
ac, hasAuth := auth.FromContext(ctx)
|
||||
|
||||
if !hasAuth {
|
||||
// No pre-upgrade auth — upgrade first, then authenticate via WS message.
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
slog.Error("pty websocket upgrade failed", "error", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
ws := &wsWriter{conn: conn}
|
||||
|
||||
var wsAC auth.AuthContext
|
||||
if isAdminWSRoute(ctx) {
|
||||
wsAC, err = wsAuthenticateAdmin(ctx, conn, h.jwtSecret, h.db)
|
||||
} else {
|
||||
wsAC, err = wsAuthenticate(ctx, conn, h.jwtSecret, h.db)
|
||||
}
|
||||
if err != nil {
|
||||
ws.writeJSON(wsPtyOut{Type: "error", Data: "authentication failed", Fatal: true})
|
||||
return
|
||||
}
|
||||
ac = wsAC
|
||||
|
||||
h.runPtySession(ctx, ws, conn, ac, sandboxID, sandboxIDStr)
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
conn, ac, err := upgradeAndAuthenticate(w, r)
|
||||
if err != nil {
|
||||
slog.Error("pty websocket upgrade failed", "error", err)
|
||||
slog.Error("pty websocket upgrade/auth failed", "error", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
@ -168,18 +136,7 @@ func (h *ptyHandler) runPtySession(ctx context.Context, ws *wsWriter, conn *webs
|
||||
ws.writeJSON(wsPtyOut{Type: "error", Data: "first message must be type 'start' or 'connect'", Fatal: true})
|
||||
}
|
||||
|
||||
// Update last active using a fresh context.
|
||||
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer updateCancel()
|
||||
if err := h.db.UpdateLastActive(updateCtx, db.UpdateLastActiveParams{
|
||||
ID: sandboxID,
|
||||
LastActiveAt: pgtype.Timestamptz{
|
||||
Time: time.Now(),
|
||||
Valid: true,
|
||||
},
|
||||
}); err != nil {
|
||||
slog.Warn("failed to update last active after pty session", "sandbox_id", sandboxIDStr, "error", err)
|
||||
}
|
||||
updateLastActive(h.db, sandboxID, sandboxIDStr)
|
||||
}
|
||||
|
||||
func (h *ptyHandler) handleStart(
|
||||
|
||||
@ -37,8 +37,8 @@ type sandboxResponse struct {
|
||||
VCPUs int32 `json:"vcpus"`
|
||||
MemoryMB int32 `json:"memory_mb"`
|
||||
TimeoutSec int32 `json:"timeout_sec"`
|
||||
GuestIP string `json:"guest_ip,omitempty"`
|
||||
HostIP string `json:"host_ip,omitempty"`
|
||||
DiskSizeMB int32 `json:"disk_size_mb"`
|
||||
DiskUsedMB *int64 `json:"disk_used_mb,omitempty"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
StartedAt *string `json:"started_at,omitempty"`
|
||||
LastActiveAt *string `json:"last_active_at,omitempty"`
|
||||
@ -54,8 +54,7 @@ func sandboxToResponse(sb db.Sandbox) sandboxResponse {
|
||||
VCPUs: sb.Vcpus,
|
||||
MemoryMB: sb.MemoryMb,
|
||||
TimeoutSec: sb.TimeoutSec,
|
||||
GuestIP: sb.GuestIp,
|
||||
HostIP: sb.HostIp,
|
||||
DiskSizeMB: sb.DiskSizeMb,
|
||||
}
|
||||
if len(sb.Metadata) > 0 {
|
||||
var meta map[string]string
|
||||
@ -101,14 +100,17 @@ func (h *sandboxHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
MemoryMB: req.MemoryMB,
|
||||
TimeoutSec: req.TimeoutSec,
|
||||
})
|
||||
h.audit.LogSandboxCreate(r.Context(), ac, sb.ID, req.Template, err)
|
||||
if err != nil {
|
||||
if sb.ID.Valid {
|
||||
h.audit.LogSandboxDestroySystem(r.Context(), ac.TeamID, sb.ID, "cleanup_after_create_error", nil)
|
||||
}
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogSandboxCreate(r.Context(), ac, sb.ID, sb.Template)
|
||||
writeJSON(w, http.StatusCreated, sandboxToResponse(sb))
|
||||
writeJSON(w, http.StatusAccepted, sandboxToResponse(sb))
|
||||
}
|
||||
|
||||
// List handles GET /v1/capsules.
|
||||
@ -145,7 +147,15 @@ func (h *sandboxHandler) Get(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, sandboxToResponse(sb))
|
||||
resp := sandboxToResponse(sb)
|
||||
|
||||
diskBytes, err := h.svc.GetDiskUsage(r.Context(), sandboxID, ac.TeamID)
|
||||
if err == nil {
|
||||
diskUsedMB := diskBytes / (1024 * 1024)
|
||||
resp.DiskUsedMB = &diskUsedMB
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// Pause handles POST /v1/capsules/{id}/pause.
|
||||
@ -160,14 +170,14 @@ func (h *sandboxHandler) Pause(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
sb, err := h.svc.Pause(r.Context(), sandboxID, ac.TeamID)
|
||||
h.audit.LogSandboxPause(r.Context(), ac, sandboxID, err)
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogSandboxPause(r.Context(), ac, sandboxID)
|
||||
writeJSON(w, http.StatusOK, sandboxToResponse(sb))
|
||||
writeJSON(w, http.StatusAccepted, sandboxToResponse(sb))
|
||||
}
|
||||
|
||||
// Resume handles POST /v1/capsules/{id}/resume.
|
||||
@ -182,14 +192,14 @@ func (h *sandboxHandler) Resume(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
sb, err := h.svc.Resume(r.Context(), sandboxID, ac.TeamID)
|
||||
h.audit.LogSandboxResume(r.Context(), ac, sandboxID, err)
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogSandboxResume(r.Context(), ac, sandboxID)
|
||||
writeJSON(w, http.StatusOK, sandboxToResponse(sb))
|
||||
writeJSON(w, http.StatusAccepted, sandboxToResponse(sb))
|
||||
}
|
||||
|
||||
// Ping handles POST /v1/capsules/{id}/ping.
|
||||
@ -223,12 +233,13 @@ func (h *sandboxHandler) Destroy(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.svc.Destroy(r.Context(), sandboxID, ac.TeamID); err != nil {
|
||||
err = h.svc.Destroy(r.Context(), sandboxID, ac.TeamID)
|
||||
h.audit.LogSandboxDestroy(r.Context(), ac, sandboxID, err)
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogSandboxDestroy(r.Context(), ac, sandboxID)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
}
|
||||
|
||||
169
internal/api/handlers_sandbox_events.go
Normal file
169
internal/api/handlers_sandbox_events.go
Normal file
@ -0,0 +1,169 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/channels"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/events"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
// sandboxEventHandler receives sandbox lifecycle callbacks from host agents
// and publishes them as canonical events.
type sandboxEventHandler struct {
|
||||
// db is queried for the sandbox's team and current status during translation.
db *db.Queries
|
||||
// eventPub publishes the translated event.
eventPub *channels.Publisher
|
||||
}
|
||||
|
||||
// newSandboxEventHandler builds a sandboxEventHandler from the given query set
// and event publisher.
func newSandboxEventHandler(queries *db.Queries, eventPub *channels.Publisher) *sandboxEventHandler {
|
||||
return &sandboxEventHandler{db: queries, eventPub: eventPub}
|
||||
}
|
||||
|
||||
// sandboxEventRequest is the JSON body of a host-agent lifecycle callback.
// Event, SandboxID, and HostID are required by Handle.
type sandboxEventRequest struct {
|
||||
// Event is the raw host event name, e.g. "sandbox.started".
Event string `json:"event"`
|
||||
// SandboxID is the formatted sandbox ID; it must parse via id.ParseSandboxID.
SandboxID string `json:"sandbox_id"`
|
||||
// HostID must equal the formatted ID of the authenticated host.
HostID string `json:"host_id"`
|
||||
// HostIP, when set, is copied into the metadata of started/resumed events.
HostIP string `json:"host_ip,omitempty"`
|
||||
// Error is copied into the published event for failure callbacks.
Error string `json:"error,omitempty"`
|
||||
// Timestamp is Unix seconds; Handle substitutes the current time when it is 0.
Timestamp int64 `json:"timestamp"`
|
||||
}
|
||||
|
||||
// Handle receives lifecycle event callbacks from host agents, translates the
|
||||
// raw host event into the canonical events.Event taxonomy, and publishes once
|
||||
// to the unified Redis stream. The SandboxEventConsumer (independent
|
||||
// consumer group) drives DB reconciliation; the channels dispatcher delivers
|
||||
// to subscribed channels; the SSE relay mirrors via Pub/Sub.
|
||||
//
// Responses: 400 for a malformed body or missing required fields, 403 when
// the body's host_id differs from the authenticated host, and 204 otherwise
// (including callbacks that cannot be translated).
func (h *sandboxEventHandler) Handle(w http.ResponseWriter, r *http.Request) {
|
||||
var req sandboxEventRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Event == "" || req.SandboxID == "" || req.HostID == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "event, sandbox_id, and host_id are required")
|
||||
return
|
||||
}
|
||||
|
||||
// A host may only report events under its own ID: compare the claimed
// host_id against the host identity placed in the context by auth.
hc := auth.MustHostFromContext(r.Context())
|
||||
callerHostID := id.FormatHostID(hc.HostID)
|
||||
if callerHostID != req.HostID {
|
||||
writeError(w, http.StatusForbidden, "forbidden", "host_id does not match authenticated host")
|
||||
return
|
||||
}
|
||||
|
||||
// A zero timestamp is treated as "not provided" and replaced with now.
if req.Timestamp == 0 {
|
||||
req.Timestamp = time.Now().Unix()
|
||||
}
|
||||
|
||||
evt, ok := h.translate(r.Context(), req)
|
||||
if !ok {
|
||||
// Unknown event type or unparseable sandbox ID — log and accept so the
// host agent doesn't retry.
|
||||
slog.Warn("sandbox event callback: untranslatable event", "event", req.Event, "sandbox_id", req.SandboxID)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
h.eventPub.Publish(r.Context(), evt)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// translate converts a raw host-agent callback into the canonical event.
|
||||
// For failure events without an in-flight verb (e.g. sandbox.failed), the
|
||||
// current DB status is consulted to pick the appropriate verb.
|
||||
//
// It returns false when the sandbox ID cannot be parsed or the event name is
// not one of the recognized "sandbox.*" callbacks.
func (h *sandboxEventHandler) translate(ctx context.Context, req sandboxEventRequest) (events.Event, bool) {
|
||||
sandboxUUID, parseErr := id.ParseSandboxID(req.SandboxID)
|
||||
if parseErr != nil {
|
||||
return events.Event{}, false
|
||||
}
|
||||
|
||||
// Best-effort team lookup: a DB error is ignored and teamID stays the zero
// value, so the event is still produced.
var teamID pgtype.UUID
|
||||
if sb, dbErr := h.db.GetSandbox(ctx, sandboxUUID); dbErr == nil {
|
||||
teamID = sb.TeamID
|
||||
}
|
||||
|
||||
// Fields shared by every translated event; the switch below fills in the
// verb, outcome, error, and metadata.
base := events.Event{
|
||||
Timestamp: time.Unix(req.Timestamp, 0).UTC().Format(time.RFC3339),
|
||||
TeamID: id.FormatTeamID(teamID),
|
||||
Actor: events.SystemActor(),
|
||||
Resource: events.Resource{ID: req.SandboxID, Type: "sandbox"},
|
||||
}
|
||||
|
||||
switch req.Event {
|
||||
case "sandbox.started":
|
||||
meta := map[string]string{}
|
||||
if req.HostIP != "" {
|
||||
meta["host_ip"] = req.HostIP
|
||||
}
|
||||
meta["host_id"] = req.HostID
|
||||
base.Event = events.CapsuleCreate
|
||||
base.Outcome = events.OutcomeSuccess
|
||||
base.Metadata = meta
|
||||
case "sandbox.resumed":
|
||||
meta := map[string]string{"host_id": req.HostID}
|
||||
if req.HostIP != "" {
|
||||
meta["host_ip"] = req.HostIP
|
||||
}
|
||||
base.Event = events.CapsuleResume
|
||||
base.Outcome = events.OutcomeSuccess
|
||||
base.Metadata = meta
|
||||
case "sandbox.paused":
|
||||
base.Event = events.CapsulePause
|
||||
base.Outcome = events.OutcomeSuccess
|
||||
case "sandbox.auto_paused":
|
||||
// Same verb as a manual pause; the metadata reason distinguishes it.
base.Event = events.CapsulePause
|
||||
base.Outcome = events.OutcomeSuccess
|
||||
base.Metadata = map[string]string{"reason": "ttl_expired"}
|
||||
case "sandbox.stopped":
|
||||
base.Event = events.CapsuleDestroy
|
||||
base.Outcome = events.OutcomeSuccess
|
||||
case "sandbox.pause_failed":
|
||||
base.Event = events.CapsulePause
|
||||
base.Outcome = events.OutcomeError
|
||||
base.Error = req.Error
|
||||
base.Metadata = map[string]string{"reason": "host_failure"}
|
||||
case "sandbox.resume_failed":
|
||||
base.Event = events.CapsuleResume
|
||||
base.Outcome = events.OutcomeError
|
||||
base.Error = req.Error
|
||||
base.Metadata = map[string]string{"reason": "host_failure"}
|
||||
case "sandbox.failed", "sandbox.error":
|
||||
// Pick a verb based on the sandbox's current DB status.
|
||||
verb := h.verbForFailure(ctx, sandboxUUID)
|
||||
base.Event = verb
|
||||
base.Outcome = events.OutcomeError
|
||||
base.Error = req.Error
|
||||
base.Metadata = map[string]string{"reason": "host_failure"}
|
||||
default:
|
||||
return events.Event{}, false
|
||||
}
|
||||
return base, true
|
||||
}
|
||||
|
||||
func (h *sandboxEventHandler) verbForFailure(ctx context.Context, sandboxID pgtype.UUID) string {
|
||||
sb, err := h.db.GetSandbox(ctx, sandboxID)
|
||||
if err != nil {
|
||||
return events.CapsuleDestroy
|
||||
}
|
||||
switch sb.Status {
|
||||
case "starting":
|
||||
return events.CapsuleCreate
|
||||
case "resuming":
|
||||
return events.CapsuleResume
|
||||
case "pausing":
|
||||
return events.CapsulePause
|
||||
case "snapshotting":
|
||||
// A snapshot pauses then resumes the VM; a host-side failure leaves the
|
||||
// sandbox errored, not destroyed. Route through CapsuleCreate so the
|
||||
// consumer's handleFailed marks it "error" rather than removing the row.
|
||||
return events.CapsuleCreate
|
||||
default:
|
||||
return events.CapsuleDestroy
|
||||
}
|
||||
}
|
||||
66
internal/api/handlers_sessions.go
Normal file
66
internal/api/handlers_sessions.go
Normal file
@ -0,0 +1,66 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth/session"
|
||||
)
|
||||
|
||||
type sessionsHandler struct {
|
||||
sessions *session.Service
|
||||
}
|
||||
|
||||
func newSessionsHandler(svc *session.Service) *sessionsHandler {
|
||||
return &sessionsHandler{sessions: svc}
|
||||
}
|
||||
|
||||
// List handles GET /v1/me/sessions — returns all active sessions for the
|
||||
// current user, flagging the caller's own session as current.
|
||||
func (h *sessionsHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
rows, err := h.sessions.ListForUser(r.Context(), ac.UserID)
|
||||
if err != nil {
|
||||
slog.Error("list sessions: db error", "error", err)
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to list sessions")
|
||||
return
|
||||
}
|
||||
out := make([]sessionRow, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
out = append(out, sessionRow{
|
||||
ID: row.ID,
|
||||
UserAgent: row.UserAgent,
|
||||
IPAddress: row.IpAddress,
|
||||
CreatedAt: row.CreatedAt.Time.UTC().Format(time.RFC3339),
|
||||
LastSeenAt: row.LastSeenAt.Time.UTC().Format(time.RFC3339),
|
||||
ExpiresAt: row.ExpiresAt.Time.UTC().Format(time.RFC3339),
|
||||
Current: row.ID == ac.SessionID,
|
||||
})
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"sessions": out})
|
||||
}
|
||||
|
||||
// Delete handles DELETE /v1/me/sessions/{id} — revokes a single session
|
||||
// belonging to the current user. If the caller revokes their own session,
|
||||
// their cookies are cleared on the response.
|
||||
func (h *sessionsHandler) Delete(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
sid := chi.URLParam(r, "id")
|
||||
if sid == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "missing session id")
|
||||
return
|
||||
}
|
||||
if err := h.sessions.DeleteForUser(r.Context(), sid, ac.UserID); err != nil {
|
||||
slog.Error("delete session: db error", "error", err)
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to delete session")
|
||||
return
|
||||
}
|
||||
if sid == ac.SessionID {
|
||||
clearSessionCookies(w, isSecure(r))
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
@ -25,49 +24,53 @@ import (
|
||||
)
|
||||
|
||||
type snapshotHandler struct {
|
||||
svc *service.TemplateService
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
audit *audit.AuditLogger
|
||||
svc *service.TemplateService
|
||||
sandboxSvc *service.SandboxService
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
audit *audit.AuditLogger
|
||||
}
|
||||
|
||||
func newSnapshotHandler(svc *service.TemplateService, db *db.Queries, pool *lifecycle.HostClientPool, al *audit.AuditLogger) *snapshotHandler {
|
||||
return &snapshotHandler{svc: svc, db: db, pool: pool, audit: al}
|
||||
func newSnapshotHandler(svc *service.TemplateService, sandboxSvc *service.SandboxService, db *db.Queries, pool *lifecycle.HostClientPool, al *audit.AuditLogger) *snapshotHandler {
|
||||
return &snapshotHandler{svc: svc, sandboxSvc: sandboxSvc, db: db, pool: pool, audit: al}
|
||||
}
|
||||
|
||||
// deleteSnapshotBroadcast attempts to delete snapshot files on all online hosts.
|
||||
// Snapshots aren't currently host-tracked in the DB, so we broadcast to all hosts
|
||||
// and ignore NotFound errors.
|
||||
func deleteSnapshotBroadcast(ctx context.Context, queries *db.Queries, pool *lifecycle.HostClientPool, teamID, templateID pgtype.UUID) error {
|
||||
// deleteSnapshotEverywhere removes a template's files from every active host.
|
||||
// Templates aren't host-tracked in the DB, so it broadcasts to all hosts.
|
||||
//
|
||||
// It is strict by design: deletion is reported successful only when every
|
||||
// active host has either removed the files or reported NotFound (it never
|
||||
// held them). If any host is offline or returns an error, it returns an error
|
||||
// and the caller MUST NOT delete the DB record — doing so would orphan the
|
||||
// files on disk with no record left to retry against.
|
||||
func deleteSnapshotEverywhere(ctx context.Context, queries *db.Queries, pool *lifecycle.HostClientPool, teamID, templateID pgtype.UUID) error {
|
||||
hosts, err := queries.ListActiveHosts(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list hosts: %w", err)
|
||||
}
|
||||
for _, host := range hosts {
|
||||
if host.Status != "online" {
|
||||
continue
|
||||
return fmt.Errorf("host %s is %s — cannot guarantee snapshot file removal",
|
||||
id.FormatHostID(host.ID), host.Status)
|
||||
}
|
||||
agent, err := pool.GetForHost(host)
|
||||
if err != nil {
|
||||
continue
|
||||
return fmt.Errorf("connect to host %s: %w", id.FormatHostID(host.ID), err)
|
||||
}
|
||||
if _, err := agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{
|
||||
TeamId: formatUUIDForRPC(teamID),
|
||||
TemplateId: formatUUIDForRPC(templateID),
|
||||
})); err != nil {
|
||||
if connect.CodeOf(err) != connect.CodeNotFound {
|
||||
slog.Warn("snapshot: failed to delete on host", "host_id", id.FormatHostID(host.ID), "error", err)
|
||||
// NotFound just means this host never held the template.
|
||||
if connect.CodeOf(err) == connect.CodeNotFound {
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("delete snapshot on host %s: %w", id.FormatHostID(host.ID), err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type createSnapshotRequest struct {
|
||||
SandboxID string `json:"sandbox_id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type snapshotResponse struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
@ -76,6 +79,7 @@ type snapshotResponse struct {
|
||||
SizeBytes int64 `json:"size_bytes"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
Platform bool `json:"platform"`
|
||||
Protected bool `json:"protected"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
@ -85,6 +89,7 @@ func templateToResponse(t db.Template) snapshotResponse {
|
||||
Type: t.Type,
|
||||
SizeBytes: t.SizeBytes,
|
||||
Platform: t.TeamID == id.PlatformTeamID,
|
||||
Protected: layout.IsSystemTemplate(t.TeamID, t.ID),
|
||||
}
|
||||
if t.Vcpus != 0 {
|
||||
resp.VCPUs = &t.Vcpus
|
||||
@ -104,132 +109,42 @@ func templateToResponse(t db.Template) snapshotResponse {
|
||||
return resp
|
||||
}
|
||||
|
||||
// Create handles POST /v1/snapshots.
|
||||
type createSnapshotRequest struct {
|
||||
SandboxID string `json:"sandbox_id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// Create handles POST /v1/snapshots. Snapshots a running or paused sandbox and
|
||||
// registers the result as a new template.
|
||||
func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
var req createSnapshotRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.SandboxID == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "sandbox_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(req.SandboxID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox_id")
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
if req.Name == "" {
|
||||
req.Name = id.NewSnapshotName()
|
||||
}
|
||||
if err := validate.SafeName(req.Name); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", fmt.Sprintf("invalid snapshot name: %s", err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
// Check for global name collision.
|
||||
if _, err := h.db.GetPlatformTemplateByName(ctx, req.Name); err == nil {
|
||||
writeError(w, http.StatusConflict, "name_reserved", "template name is reserved by a global template")
|
||||
return
|
||||
}
|
||||
|
||||
// Check if name already exists for this team.
|
||||
if _, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err == nil {
|
||||
writeError(w, http.StatusConflict, "template_name_taken",
|
||||
"snapshot name already exists; delete the existing snapshot first to reuse this name")
|
||||
return
|
||||
}
|
||||
|
||||
// Verify sandbox exists, belongs to team, and is running or paused.
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
// Async: the VM briefly pauses to a "snapshotting" state, then resumes. The
|
||||
// template is registered by a background goroutine; clients learn of the
|
||||
// result via the SSE template.snapshot.create event (or by polling).
|
||||
sb, name, err := h.sandboxSvc.CreateSnapshot(r.Context(), sandboxID, ac.TeamID, req.Name)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" && sb.Status != "paused" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox must be running or paused")
|
||||
return
|
||||
}
|
||||
|
||||
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
|
||||
return
|
||||
}
|
||||
|
||||
// Pre-mark sandbox as "paused" in DB BEFORE issuing the snapshot RPC.
|
||||
// The host agent's CreateSnapshot removes the sandbox from its in-memory
|
||||
// map immediately; if the reconciler fires during the flatten window and
|
||||
// the DB still says "running", it will mark the sandbox "stopped".
|
||||
if sb.Status == "running" {
|
||||
if _, err := h.db.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
|
||||
ID: sandboxID, Status: "paused",
|
||||
}); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to update sandbox status")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Use a detached context with a generous timeout so the snapshot completes
|
||||
// even if the client disconnects (the flatten step can take 10-20s).
|
||||
snapCtx, snapCancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer snapCancel()
|
||||
|
||||
// Generate the new template ID upfront so the host agent knows where to store files.
|
||||
newTemplateID := id.NewTemplateID()
|
||||
|
||||
resp, err := agent.CreateSnapshot(snapCtx, connect.NewRequest(&pb.CreateSnapshotRequest{
|
||||
SandboxId: req.SandboxID,
|
||||
Name: req.Name,
|
||||
TeamId: formatUUIDForRPC(ac.TeamID),
|
||||
TemplateId: formatUUIDForRPC(newTemplateID),
|
||||
}))
|
||||
if err != nil {
|
||||
// Snapshot failed — revert status back to what it was.
|
||||
if sb.Status == "running" {
|
||||
if _, dbErr := h.db.UpdateSandboxStatus(snapCtx, db.UpdateSandboxStatusParams{
|
||||
ID: sandboxID, Status: "running",
|
||||
}); dbErr != nil {
|
||||
slog.Error("failed to revert sandbox status after snapshot error", "sandbox_id", req.SandboxID, "error", dbErr)
|
||||
}
|
||||
}
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
h.audit.LogSnapshotCreateRequested(r.Context(), ac, name)
|
||||
|
||||
tmpl, err := h.db.InsertTemplate(snapCtx, db.InsertTemplateParams{
|
||||
ID: newTemplateID,
|
||||
Name: req.Name,
|
||||
Type: "snapshot",
|
||||
Vcpus: sb.Vcpus,
|
||||
MemoryMb: sb.MemoryMb,
|
||||
SizeBytes: resp.Msg.SizeBytes,
|
||||
TeamID: ac.TeamID,
|
||||
DefaultUser: "root",
|
||||
DefaultEnv: []byte("{}"),
|
||||
Metadata: sb.Metadata,
|
||||
})
|
||||
if err != nil {
|
||||
slog.Error("failed to insert template record", "name", req.Name, "error", err)
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "snapshot created but failed to record in database")
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogSnapshotCreate(snapCtx, ac, req.Name)
|
||||
|
||||
if ctx.Err() != nil {
|
||||
slog.Info("snapshot created but client disconnected before response", "name", req.Name)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusCreated, templateToResponse(tmpl))
|
||||
writeJSON(w, http.StatusAccepted, sandboxToResponse(sb))
|
||||
}
|
||||
|
||||
// List handles GET /v1/snapshots.
|
||||
@ -243,6 +158,11 @@ func (h *snapshotHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Resolve actual on-disk sizes for templates with unknown size (e.g.
|
||||
// system base templates seeded with size_bytes = 0). This queries a host
|
||||
// agent and persists the result to the DB for subsequent requests.
|
||||
templates = resolveTemplateSizes(r.Context(), h.db, h.pool, templates)
|
||||
|
||||
resp := make([]snapshotResponse, len(templates))
|
||||
for i, t := range templates {
|
||||
resp[i] = templateToResponse(t)
|
||||
@ -271,21 +191,24 @@ func (h *snapshotHandler) Delete(w http.ResponseWriter, r *http.Request) {
|
||||
writeError(w, http.StatusForbidden, "forbidden", "platform templates cannot be deleted here")
|
||||
return
|
||||
}
|
||||
if layout.IsMinimal(tmpl.TeamID, tmpl.ID) {
|
||||
writeError(w, http.StatusForbidden, "forbidden", "the minimal template cannot be deleted")
|
||||
if layout.IsSystemTemplate(tmpl.TeamID, tmpl.ID) {
|
||||
writeError(w, http.StatusForbidden, "forbidden", "system base templates cannot be deleted")
|
||||
return
|
||||
}
|
||||
|
||||
if err := deleteSnapshotBroadcast(ctx, h.db, h.pool, tmpl.TeamID, tmpl.ID); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete snapshot files")
|
||||
if err := deleteSnapshotEverywhere(ctx, h.db, h.pool, tmpl.TeamID, tmpl.ID); err != nil {
|
||||
h.audit.LogSnapshotDelete(r.Context(), ac, name, err)
|
||||
writeError(w, http.StatusConflict, "delete_failed",
|
||||
"could not remove snapshot files from all hosts: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.db.DeleteTemplateByTeam(ctx, db.DeleteTemplateByTeamParams{Name: name, TeamID: ac.TeamID}); err != nil {
|
||||
h.audit.LogSnapshotDelete(r.Context(), ac, name, err)
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to delete template record")
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogSnapshotDelete(r.Context(), ac, name)
|
||||
h.audit.LogSnapshotDelete(r.Context(), ac, name, nil)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
79
internal/api/handlers_sse.go
Normal file
79
internal/api/handlers_sse.go
Normal file
@ -0,0 +1,79 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
const sseKeepaliveInterval = 30 * time.Second
|
||||
|
||||
type sseHandler struct {
|
||||
broker *SSEBroker
|
||||
}
|
||||
|
||||
func newSSEHandler(broker *SSEBroker) *sseHandler {
|
||||
return &sseHandler{broker: broker}
|
||||
}
|
||||
|
||||
// Stream handles GET /v1/events/stream. Authentication is performed by
|
||||
// upstream middleware: browser clients use the wrenn_sid cookie (which the
|
||||
// EventSource API forwards automatically); SDK clients use X-API-Key.
|
||||
func (h *sseHandler) Stream(w http.ResponseWriter, r *http.Request) {
|
||||
ac, ok := auth.FromContext(r.Context())
|
||||
if !ok {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "session cookie or X-API-Key required")
|
||||
return
|
||||
}
|
||||
h.serveSSE(w, r, id.FormatTeamID(ac.TeamID), false)
|
||||
}
|
||||
|
||||
// AdminStream handles GET /v1/admin/events/stream. Upstream middleware
|
||||
// must enforce session auth + requireAdmin before reaching this handler.
|
||||
func (h *sseHandler) AdminStream(w http.ResponseWriter, r *http.Request) {
|
||||
ac, ok := auth.FromContext(r.Context())
|
||||
if !ok {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "admin session required")
|
||||
return
|
||||
}
|
||||
h.serveSSE(w, r, id.FormatTeamID(ac.TeamID), true)
|
||||
}
|
||||
|
||||
func (h *sseHandler) serveSSE(w http.ResponseWriter, r *http.Request, teamID string, isAdmin bool) {
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
http.Error(w, "streaming not supported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("X-Accel-Buffering", "no")
|
||||
|
||||
subID, ch := h.broker.Subscribe(teamID, isAdmin)
|
||||
defer h.broker.Unsubscribe(subID)
|
||||
|
||||
fmt.Fprintf(w, "event: connected\ndata: {\"message\":\"connected\"}\n\n")
|
||||
flusher.Flush()
|
||||
|
||||
keepalive := time.NewTicker(sseKeepaliveInterval)
|
||||
defer keepalive.Stop()
|
||||
|
||||
ctx := r.Context()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case msg := <-ch:
|
||||
fmt.Fprintf(w, "event: %s\ndata: %s\n\n", msg.EventType, msg.Data)
|
||||
flusher.Flush()
|
||||
case <-keepalive.C:
|
||||
fmt.Fprintf(w, ": keepalive\n\n")
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -14,19 +14,21 @@ import (
|
||||
"git.omukk.dev/wrenn/wrenn/internal/email"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/audit"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth/session"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/service"
|
||||
)
|
||||
|
||||
type teamHandler struct {
|
||||
svc *service.TeamService
|
||||
audit *audit.AuditLogger
|
||||
mailer email.Mailer
|
||||
svc *service.TeamService
|
||||
audit *audit.AuditLogger
|
||||
mailer email.Mailer
|
||||
sessions *session.Service
|
||||
}
|
||||
|
||||
func newTeamHandler(svc *service.TeamService, al *audit.AuditLogger, mailer email.Mailer) *teamHandler {
|
||||
return &teamHandler{svc: svc, audit: al, mailer: mailer}
|
||||
func newTeamHandler(svc *service.TeamService, al *audit.AuditLogger, mailer email.Mailer, sessions *session.Service) *teamHandler {
|
||||
return &teamHandler{svc: svc, audit: al, mailer: mailer, sessions: sessions}
|
||||
}
|
||||
|
||||
// teamResponse is the JSON shape for a team.
|
||||
@ -366,6 +368,11 @@ func (h *teamHandler) UpdateMemberRole(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Drop cached session blobs so the new role propagates immediately.
|
||||
if err := h.sessions.InvalidateCacheForUser(r.Context(), targetUserID); err != nil {
|
||||
_ = err
|
||||
}
|
||||
|
||||
h.audit.LogMemberRoleUpdate(r.Context(), ac, targetUserID, req.Role)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
@ -450,6 +457,8 @@ func (h *teamHandler) AdminListTeams(w http.ResponseWriter, r *http.Request) {
|
||||
OwnerEmail string `json:"owner_email"`
|
||||
ActiveSandboxCount int32 `json:"active_sandbox_count"`
|
||||
ChannelCount int32 `json:"channel_count"`
|
||||
RunningVcpus int32 `json:"running_vcpus"`
|
||||
RunningMemoryMb int32 `json:"running_memory_mb"`
|
||||
}
|
||||
|
||||
resp := make([]adminTeamResponse, len(teams))
|
||||
@ -465,6 +474,8 @@ func (h *teamHandler) AdminListTeams(w http.ResponseWriter, r *http.Request) {
|
||||
OwnerEmail: t.OwnerEmail,
|
||||
ActiveSandboxCount: t.ActiveSandboxCount,
|
||||
ChannelCount: t.ChannelCount,
|
||||
RunningVcpus: t.RunningVcpus,
|
||||
RunningMemoryMb: t.RunningMemoryMb,
|
||||
}
|
||||
if t.DeletedAt != nil {
|
||||
s := t.DeletedAt.Format(time.RFC3339)
|
||||
|
||||
@ -11,19 +11,21 @@ import (
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/audit"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth/session"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/service"
|
||||
)
|
||||
|
||||
type usersHandler struct {
|
||||
db *db.Queries
|
||||
svc *service.UserService
|
||||
audit *audit.AuditLogger
|
||||
db *db.Queries
|
||||
svc *service.UserService
|
||||
audit *audit.AuditLogger
|
||||
sessions *session.Service
|
||||
}
|
||||
|
||||
func newUsersHandler(db *db.Queries, svc *service.UserService, al *audit.AuditLogger) *usersHandler {
|
||||
return &usersHandler{db: db, svc: svc, audit: al}
|
||||
func newUsersHandler(db *db.Queries, svc *service.UserService, al *audit.AuditLogger, sessions *session.Service) *usersHandler {
|
||||
return &usersHandler{db: db, svc: svc, audit: al, sessions: sessions}
|
||||
}
|
||||
|
||||
// Search handles GET /v1/users/search?email=<prefix>
|
||||
@ -158,6 +160,10 @@ func (h *usersHandler) SetUserActive(w http.ResponseWriter, r *http.Request) {
|
||||
if req.Active {
|
||||
h.audit.LogUserActivate(r.Context(), ac, userID, user.Email)
|
||||
} else {
|
||||
// Disabled users must be kicked out of every active session.
|
||||
if err := h.sessions.RevokeAllForUser(r.Context(), userID); err != nil {
|
||||
_ = err
|
||||
}
|
||||
h.audit.LogUserDeactivate(r.Context(), ac, userID, user.Email)
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
@ -215,5 +221,14 @@ func (h *usersHandler) SetUserAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
h.audit.LogUserRevokeAdmin(r.Context(), ac, userID, user.Email)
|
||||
}
|
||||
|
||||
// Invalidate cached session blobs so the new is_admin flag is reflected
|
||||
// on the user's next request without waiting for the Redis TTL.
|
||||
if err := h.sessions.InvalidateCacheForUser(r.Context(), userID); err != nil {
|
||||
// Cache is best-effort; the DB is authoritative and requireAdmin
|
||||
// always re-reads it.
|
||||
_ = err
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
@ -1,102 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
// ctxKeyAdminWS is the unexported context-key type used to mark a request
// context as belonging to an admin WebSocket route.
type ctxKeyAdminWS struct{}
|
||||
|
||||
// setAdminWSFlag marks the context as an admin WebSocket route.
|
||||
func setAdminWSFlag(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, ctxKeyAdminWS{}, true)
|
||||
}
|
||||
|
||||
// isAdminWSRoute checks if the request context was marked as admin WS.
|
||||
func isAdminWSRoute(ctx context.Context) bool {
|
||||
v, _ := ctx.Value(ctxKeyAdminWS{}).(bool)
|
||||
return v
|
||||
}
|
||||
|
||||
// wsAuthMsg is the JSON frame a browser WebSocket client must send first in
// order to authenticate the connection.
type wsAuthMsg struct {
	// Type must be "auth" for the frame to be accepted.
	Type string `json:"type"`
	// Token carries the JWT that is verified by wsAuthenticate.
	Token string `json:"token"`
}
|
||||
|
||||
// wsAuthenticate reads a JWT auth message from the WebSocket and returns the
|
||||
// authenticated context. The caller must send this as the first message after
|
||||
// connecting.
|
||||
func wsAuthenticate(ctx context.Context, conn *websocket.Conn, jwtSecret []byte, queries *db.Queries) (auth.AuthContext, error) {
|
||||
_ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||
|
||||
var msg wsAuthMsg
|
||||
if err := conn.ReadJSON(&msg); err != nil {
|
||||
return auth.AuthContext{}, fmt.Errorf("read auth message: %w", err)
|
||||
}
|
||||
|
||||
_ = conn.SetReadDeadline(time.Time{}) // clear deadline
|
||||
|
||||
if msg.Type != "auth" || msg.Token == "" {
|
||||
return auth.AuthContext{}, fmt.Errorf("first message must be type 'auth' with a token")
|
||||
}
|
||||
|
||||
claims, err := auth.VerifyJWT(jwtSecret, msg.Token)
|
||||
if err != nil {
|
||||
return auth.AuthContext{}, fmt.Errorf("invalid or expired token: %w", err)
|
||||
}
|
||||
|
||||
teamID, err := id.ParseTeamID(claims.TeamID)
|
||||
if err != nil {
|
||||
return auth.AuthContext{}, fmt.Errorf("invalid team ID in token: %w", err)
|
||||
}
|
||||
|
||||
userID, err := id.ParseUserID(claims.Subject)
|
||||
if err != nil {
|
||||
return auth.AuthContext{}, fmt.Errorf("invalid user ID in token: %w", err)
|
||||
}
|
||||
|
||||
user, err := queries.GetUserByID(ctx, userID)
|
||||
if err != nil {
|
||||
return auth.AuthContext{}, fmt.Errorf("user not found")
|
||||
}
|
||||
if user.Status != "active" {
|
||||
return auth.AuthContext{}, fmt.Errorf("account deactivated")
|
||||
}
|
||||
|
||||
return auth.AuthContext{
|
||||
TeamID: teamID,
|
||||
UserID: userID,
|
||||
Email: claims.Email,
|
||||
Name: claims.Name,
|
||||
Role: claims.Role,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// wsAuthenticateAdmin performs WS-based auth and verifies admin status,
|
||||
// returning an AuthContext with the platform team ID.
|
||||
func wsAuthenticateAdmin(ctx context.Context, conn *websocket.Conn, jwtSecret []byte, queries *db.Queries) (auth.AuthContext, error) {
|
||||
ac, err := wsAuthenticate(ctx, conn, jwtSecret, queries)
|
||||
if err != nil {
|
||||
return auth.AuthContext{}, err
|
||||
}
|
||||
|
||||
user, err := queries.GetUserByID(ctx, ac.UserID)
|
||||
if err != nil {
|
||||
return auth.AuthContext{}, fmt.Errorf("user not found")
|
||||
}
|
||||
if !user.IsAdmin {
|
||||
return auth.AuthContext{}, fmt.Errorf("admin access required")
|
||||
}
|
||||
|
||||
ac.TeamID = id.PlatformTeamID
|
||||
return ac, nil
|
||||
}
|
||||
@ -2,6 +2,7 @@ package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
@ -15,10 +16,30 @@ import (
|
||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||
)
|
||||
|
||||
// errInferredTransientTimeout marks a state change that the reconciler
|
||||
// inferred after a transient (starting/resuming) sandbox failed to settle
|
||||
// within the grace period. Used as the err value on system audit calls so
|
||||
// the published event carries Outcome=error with a human-readable message.
|
||||
var errInferredTransientTimeout = errors.New("transient state did not settle within grace period")
|
||||
|
||||
// unreachableThreshold is how long a host can go without a heartbeat before
|
||||
// it is considered unreachable (3 missed 30-second heartbeats).
|
||||
const unreachableThreshold = 90 * time.Second
|
||||
|
||||
// transientGracePeriod is how long a sandbox is allowed to stay in a transient
|
||||
// status (starting, resuming, pausing, stopping) before the monitor infers a
|
||||
// final state. This prevents the monitor from racing against in-flight RPCs
|
||||
// that may not have registered the sandbox on the host agent yet.
|
||||
const transientGracePeriod = 2 * time.Minute
|
||||
|
||||
// snapshotGracePeriod is the grace for a sandbox stuck in "snapshotting" while
|
||||
// the VM is still alive on the host. Snapshots dump guest RAM and flatten the
|
||||
// rootfs, which can run for minutes on large sandboxes, and the agent reports
|
||||
// the VM as alive throughout — so we must not race the in-flight operation.
|
||||
// It exceeds the background goroutine's 10-minute deadline, so reaching it
|
||||
// means the control plane crashed mid-snapshot and the sandbox needs recovery.
|
||||
const snapshotGracePeriod = 15 * time.Minute
|
||||
|
||||
// HostMonitor runs on a fixed interval and performs two duties:
|
||||
//
|
||||
// 1. Passive check: marks hosts whose last_heartbeat_at is stale as
|
||||
@ -77,6 +98,21 @@ func (m *HostMonitor) run(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// ReconcileHost triggers immediate active reconciliation for a single host.
|
||||
// Called when a host transitions from unreachable → online so sandboxes marked
|
||||
// "missing" are resolved without waiting for the next monitor tick.
|
||||
func (m *HostMonitor) ReconcileHost(ctx context.Context, hostID pgtype.UUID) {
|
||||
host, err := m.db.GetHost(ctx, hostID)
|
||||
if err != nil {
|
||||
slog.Warn("host monitor: reconcile-on-connect: failed to get host", "error", err)
|
||||
return
|
||||
}
|
||||
if host.Status != "online" {
|
||||
return
|
||||
}
|
||||
m.checkHost(ctx, host)
|
||||
}
|
||||
|
||||
func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) {
|
||||
// --- Passive phase: check heartbeat staleness ---
|
||||
|
||||
@ -116,21 +152,29 @@ func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) {
|
||||
return
|
||||
}
|
||||
|
||||
// Build set of sandbox IDs alive on the host.
|
||||
// The host agent returns sandbox IDs as strings (formatted with prefix).
|
||||
alive := make(map[string]struct{}, len(resp.Msg.Sandboxes))
|
||||
// Build map of sandbox ID -> reported status. Transient statuses
|
||||
// (pausing/resuming/starting/stopping) are coerced to a presence-only
|
||||
// entry: ListSandboxes can observe the in-memory status mid-transition
|
||||
// (Pause flips the status under m.mu while List holds m.mu.RLock), and
|
||||
// writing those transient labels into the DB would force the transient
|
||||
// reconciliation phase to wait the full grace period before resolving.
|
||||
// Recording the presence keeps "missing → restore" and "running →
|
||||
// orphan-stop" logic correct without overwriting with stale labels;
|
||||
// the next monitor tick reads the settled status.
|
||||
aliveStatus := make(map[string]string, len(resp.Msg.Sandboxes))
|
||||
for _, sb := range resp.Msg.Sandboxes {
|
||||
alive[sb.SandboxId] = struct{}{}
|
||||
}
|
||||
|
||||
autoPaused := make(map[string]struct{}, len(resp.Msg.AutoPausedSandboxIds))
|
||||
for _, apID := range resp.Msg.AutoPausedSandboxIds {
|
||||
autoPaused[apID] = struct{}{}
|
||||
status := sb.Status
|
||||
switch status {
|
||||
case "pausing", "resuming", "starting", "stopping":
|
||||
status = ""
|
||||
}
|
||||
aliveStatus[sb.SandboxId] = status
|
||||
}
|
||||
|
||||
// --- Restore sandboxes that are "missing" in DB but alive on host ---
|
||||
// This handles the case where CP marked them missing due to a transient
|
||||
// heartbeat gap, but the host was actually fine.
|
||||
// Handles transient heartbeat gaps where the host was actually fine. The
|
||||
// reported status must be honored: a sandbox the agent paused while CP
|
||||
// was disconnected must not be silently promoted back to running.
|
||||
|
||||
missingSandboxes, err := m.db.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
|
||||
HostID: host.ID,
|
||||
@ -139,34 +183,65 @@ func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) {
|
||||
if err != nil {
|
||||
slog.Warn("host monitor: failed to list missing sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
|
||||
} else {
|
||||
var toRestore []pgtype.UUID
|
||||
var toStop []pgtype.UUID
|
||||
restoreByStatus := make(map[string][]db.Sandbox)
|
||||
var toStop []db.Sandbox
|
||||
for _, sb := range missingSandboxes {
|
||||
sbIDStr := id.FormatSandboxID(sb.ID)
|
||||
if _, ok := alive[sbIDStr]; ok {
|
||||
toRestore = append(toRestore, sb.ID)
|
||||
} else {
|
||||
toStop = append(toStop, sb.ID)
|
||||
status, ok := aliveStatus[sbIDStr]
|
||||
if !ok {
|
||||
toStop = append(toStop, sb)
|
||||
continue
|
||||
}
|
||||
if status == "" {
|
||||
continue
|
||||
}
|
||||
restoreByStatus[status] = append(restoreByStatus[status], sb)
|
||||
}
|
||||
if len(toRestore) > 0 {
|
||||
slog.Info("host monitor: restoring missing sandboxes", "host_id", id.FormatHostID(host.ID), "count", len(toRestore))
|
||||
if err := m.db.BulkRestoreRunning(ctx, toRestore); err != nil {
|
||||
slog.Warn("host monitor: failed to restore missing sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
|
||||
for status, sbs := range restoreByStatus {
|
||||
ids := make([]pgtype.UUID, len(sbs))
|
||||
for i, sb := range sbs {
|
||||
ids[i] = sb.ID
|
||||
}
|
||||
slog.Info("host monitor: restoring missing sandboxes", "host_id", id.FormatHostID(host.ID), "status", status, "count", len(ids))
|
||||
if err := m.db.BulkRestoreMissingToStatus(ctx, db.BulkRestoreMissingToStatusParams{
|
||||
Column1: ids,
|
||||
Status: status,
|
||||
}); err != nil {
|
||||
slog.Warn("host monitor: failed to restore missing sandboxes", "host_id", id.FormatHostID(host.ID), "status", status, "error", err)
|
||||
continue
|
||||
}
|
||||
// Only restore→paused emits a notification (per design: running restore is silent).
|
||||
if status == "paused" {
|
||||
for _, sb := range sbs {
|
||||
m.audit.LogSandboxAutoPause(ctx, sb.TeamID, sb.ID, "restored_after_host_recovery", nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(toStop) > 0 {
|
||||
slog.Info("host monitor: stopping confirmed-dead missing sandboxes", "host_id", id.FormatHostID(host.ID), "count", len(toStop))
|
||||
ids := make([]pgtype.UUID, len(toStop))
|
||||
for i, sb := range toStop {
|
||||
ids[i] = sb.ID
|
||||
}
|
||||
slog.Info("host monitor: stopping confirmed-dead missing sandboxes", "host_id", id.FormatHostID(host.ID), "count", len(ids))
|
||||
if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
|
||||
Column1: toStop,
|
||||
Column1: ids,
|
||||
Status: "stopped",
|
||||
}); err != nil {
|
||||
slog.Warn("host monitor: failed to stop missing sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
|
||||
} else {
|
||||
for _, sb := range toStop {
|
||||
m.audit.LogSandboxDestroySystem(ctx, sb.TeamID, sb.ID, "orphaned", nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Find running sandboxes in DB that are no longer alive on the host ---
|
||||
// --- Reconcile running sandboxes in DB against live host state ---
|
||||
// Three cases per DB-running row:
|
||||
// absent on host -> stopped
|
||||
// present and running -> no change
|
||||
// present but paused/etc. -> sync DB to reported status (catches the
|
||||
// shutdown-pause notify failure case)
|
||||
|
||||
runningSandboxes, err := m.db.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
|
||||
HostID: host.ID,
|
||||
@ -177,40 +252,196 @@ func (m *HostMonitor) checkHost(ctx context.Context, host db.Host) {
|
||||
return
|
||||
}
|
||||
|
||||
var toPause, toStop []pgtype.UUID
|
||||
sbTeamID := make(map[pgtype.UUID]pgtype.UUID, len(runningSandboxes))
|
||||
var toStop []db.Sandbox
|
||||
syncByStatus := make(map[string][]db.Sandbox)
|
||||
for _, sb := range runningSandboxes {
|
||||
sbIDStr := id.FormatSandboxID(sb.ID)
|
||||
sbTeamID[sb.ID] = sb.TeamID
|
||||
if _, ok := alive[sbIDStr]; ok {
|
||||
status, ok := aliveStatus[sbIDStr]
|
||||
if !ok {
|
||||
toStop = append(toStop, sb)
|
||||
continue
|
||||
}
|
||||
if _, ok := autoPaused[sbIDStr]; ok {
|
||||
toPause = append(toPause, sb.ID)
|
||||
} else {
|
||||
toStop = append(toStop, sb.ID)
|
||||
if status == "running" || status == "" {
|
||||
continue
|
||||
}
|
||||
syncByStatus[status] = append(syncByStatus[status], sb)
|
||||
}
|
||||
|
||||
if len(toPause) > 0 {
|
||||
slog.Info("host monitor: marking auto-paused sandboxes", "host_id", id.FormatHostID(host.ID), "count", len(toPause))
|
||||
if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
|
||||
Column1: toPause,
|
||||
Status: "paused",
|
||||
}); err != nil {
|
||||
slog.Warn("host monitor: failed to mark paused", "host_id", id.FormatHostID(host.ID), "error", err)
|
||||
}
|
||||
for _, sbID := range toPause {
|
||||
m.audit.LogSandboxAutoPause(ctx, sbTeamID[sbID], sbID)
|
||||
}
|
||||
}
|
||||
if len(toStop) > 0 {
|
||||
slog.Info("host monitor: marking orphaned sandboxes stopped", "host_id", id.FormatHostID(host.ID), "count", len(toStop))
|
||||
ids := make([]pgtype.UUID, len(toStop))
|
||||
for i, sb := range toStop {
|
||||
ids[i] = sb.ID
|
||||
}
|
||||
slog.Info("host monitor: marking orphaned sandboxes stopped", "host_id", id.FormatHostID(host.ID), "count", len(ids))
|
||||
if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
|
||||
Column1: toStop,
|
||||
Column1: ids,
|
||||
Status: "stopped",
|
||||
}); err != nil {
|
||||
slog.Warn("host monitor: failed to mark stopped", "host_id", id.FormatHostID(host.ID), "error", err)
|
||||
} else {
|
||||
for _, sb := range toStop {
|
||||
m.audit.LogSandboxDestroySystem(ctx, sb.TeamID, sb.ID, "orphaned", nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
for status, sbs := range syncByStatus {
|
||||
ids := make([]pgtype.UUID, len(sbs))
|
||||
for i, sb := range sbs {
|
||||
ids[i] = sb.ID
|
||||
}
|
||||
slog.Info("host monitor: syncing running→reported status", "host_id", id.FormatHostID(host.ID), "status", status, "count", len(ids))
|
||||
if err := m.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
|
||||
Column1: ids,
|
||||
Status: status,
|
||||
}); err != nil {
|
||||
slog.Warn("host monitor: failed to sync running sandboxes", "host_id", id.FormatHostID(host.ID), "status", status, "error", err)
|
||||
continue
|
||||
}
|
||||
if status == "paused" {
|
||||
for _, sb := range sbs {
|
||||
m.audit.LogSandboxAutoPause(ctx, sb.TeamID, sb.ID, "host_state_sync", nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Reconcile DB-stopped + agent-paused zombies ---
|
||||
// A sandbox the agent reports as 'paused' but DB has as 'stopped' is an
|
||||
// orphan from a previous bug where a successful pause's auto_paused
|
||||
// callback was lost (e.g. CP unreachable during agent shutdown). With the
|
||||
// agent-side fix (RestorePausedSandboxes), the snapshot survives across
|
||||
// agent restarts and surfaces here. Authoritative direction: DB wins
|
||||
// (user already saw 'stopped' and may have stopped tracking it).
|
||||
// Issue Destroy so the on-disk snapshot dir is removed and the agent's
|
||||
// slot reservation released.
|
||||
//
|
||||
// Gate: only run the DB query if the agent reports at least one paused
|
||||
// sandbox. Otherwise we'd fetch every historically-stopped sandbox on
|
||||
// this host every monitor tick — unbounded growth over a host's lifetime.
|
||||
hasPaused := false
|
||||
for _, status := range aliveStatus {
|
||||
if status == "paused" {
|
||||
hasPaused = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if hasPaused {
|
||||
stoppedSandboxes, err := m.db.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
|
||||
HostID: host.ID,
|
||||
Column2: []string{"stopped"},
|
||||
})
|
||||
if err != nil {
|
||||
slog.Warn("host monitor: failed to list stopped sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
|
||||
} else {
|
||||
for _, sb := range stoppedSandboxes {
|
||||
sbIDStr := id.FormatSandboxID(sb.ID)
|
||||
status, ok := aliveStatus[sbIDStr]
|
||||
if !ok || status != "paused" {
|
||||
continue
|
||||
}
|
||||
slog.Info("host monitor: destroying DB-stopped agent-paused zombie",
|
||||
"host_id", id.FormatHostID(host.ID), "sandbox_id", sbIDStr)
|
||||
if _, err := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
|
||||
SandboxId: sbIDStr,
|
||||
})); err != nil && connect.CodeOf(err) != connect.CodeNotFound {
|
||||
slog.Warn("host monitor: zombie destroy failed",
|
||||
"sandbox_id", sbIDStr, "error", err)
|
||||
continue
|
||||
}
|
||||
m.audit.LogSandboxDestroySystem(ctx, sb.TeamID, sb.ID, "paused_zombie_cleanup", nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Reconcile transient statuses (starting, resuming, pausing, stopping) ---
|
||||
// These represent in-flight operations. If the sandbox is no longer alive on
|
||||
// the host, infer the final state based on the transient status.
|
||||
|
||||
transientSandboxes, err := m.db.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
|
||||
HostID: host.ID,
|
||||
Column2: []string{"starting", "resuming", "pausing", "stopping", "snapshotting"},
|
||||
})
|
||||
if err != nil {
|
||||
slog.Warn("host monitor: failed to list transient sandboxes", "host_id", id.FormatHostID(host.ID), "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, sb := range transientSandboxes {
|
||||
sbIDStr := id.FormatSandboxID(sb.ID)
|
||||
if agentStatus, ok := aliveStatus[sbIDStr]; ok {
|
||||
// Sandbox is alive on host — the background goroutine should
|
||||
// finalize the transition. For starting/resuming, if the sandbox
|
||||
// is alive it means creation/resume succeeded.
|
||||
if sb.Status == "starting" || sb.Status == "resuming" {
|
||||
if _, err := m.db.UpdateSandboxStatusIf(ctx, db.UpdateSandboxStatusIfParams{
|
||||
ID: sb.ID, Status: sb.Status, Status_2: "running",
|
||||
}); err == nil {
|
||||
slog.Info("host monitor: promoted transient sandbox to running", "sandbox_id", sbIDStr, "from", sb.Status)
|
||||
}
|
||||
}
|
||||
// A snapshot keeps the source sandbox alive throughout, so an alive
|
||||
// sandbox does NOT mean the snapshot finished. Only recover it once
|
||||
// it has been stuck past the snapshot grace period (i.e. the CP
|
||||
// crashed mid-op). Recover to the sandbox's actual host-side status:
|
||||
// a running sandbox is snapshotted live and stays running, but a
|
||||
// paused sandbox is snapshotted from disk and must return to paused.
|
||||
if sb.Status == "snapshotting" &&
|
||||
sb.LastUpdated.Valid && time.Since(sb.LastUpdated.Time) >= snapshotGracePeriod {
|
||||
recoverTo := agentStatus
|
||||
if recoverTo != "running" && recoverTo != "paused" {
|
||||
// Coerced/unknown agent label — default to running.
|
||||
recoverTo = "running"
|
||||
}
|
||||
if _, err := m.db.UpdateSandboxStatusIf(ctx, db.UpdateSandboxStatusIfParams{
|
||||
ID: sb.ID, Status: "snapshotting", Status_2: recoverTo,
|
||||
}); err == nil {
|
||||
slog.Info("host monitor: recovered stuck snapshotting sandbox", "sandbox_id", sbIDStr, "to", recoverTo)
|
||||
m.audit.LogSnapshotCreateSystem(ctx, sb.TeamID, sb.ID, "snapshot_recovered", nil)
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
// Sandbox is not alive on host. If the transition is recent, give the
|
||||
// in-flight RPC time to finish before declaring a final state.
|
||||
if sb.LastUpdated.Valid && time.Since(sb.LastUpdated.Time) < transientGracePeriod {
|
||||
slog.Debug("host monitor: transient sandbox still within grace period",
|
||||
"sandbox_id", sbIDStr, "status", sb.Status,
|
||||
"age", time.Since(sb.LastUpdated.Time).Round(time.Second))
|
||||
continue
|
||||
}
|
||||
|
||||
// Grace period expired — infer final state.
|
||||
var finalStatus string
|
||||
switch sb.Status {
|
||||
case "starting", "resuming":
|
||||
finalStatus = "error"
|
||||
case "pausing":
|
||||
finalStatus = "paused"
|
||||
case "stopping":
|
||||
finalStatus = "stopped"
|
||||
case "snapshotting":
|
||||
// VM is gone but DB says snapshotting → the snapshot died with the VM.
|
||||
finalStatus = "error"
|
||||
}
|
||||
fromStatus := sb.Status
|
||||
if _, err := m.db.UpdateSandboxStatusIf(ctx, db.UpdateSandboxStatusIfParams{
|
||||
ID: sb.ID, Status: fromStatus, Status_2: finalStatus,
|
||||
}); err == nil {
|
||||
slog.Info("host monitor: resolved transient sandbox", "sandbox_id", sbIDStr, "from", fromStatus, "to", finalStatus)
|
||||
inferredErr := errInferredTransientTimeout
|
||||
switch fromStatus {
|
||||
case "starting":
|
||||
m.audit.LogSandboxCreateSystem(ctx, sb.TeamID, sb.ID, "transient_timeout", inferredErr)
|
||||
case "resuming":
|
||||
m.audit.LogSandboxResumeSystem(ctx, sb.TeamID, sb.ID, "transient_timeout", inferredErr)
|
||||
case "pausing":
|
||||
// Pause assumed to have succeeded host-side; emit success with inferred metadata.
|
||||
m.audit.LogSandboxAutoPause(ctx, sb.TeamID, sb.ID, "transient_timeout_inferred", nil)
|
||||
case "snapshotting":
|
||||
// VM gone mid-snapshot; the sandbox is errored.
|
||||
m.audit.LogSnapshotCreateSystem(ctx, sb.TeamID, sb.ID, "transient_timeout", inferredErr)
|
||||
case "stopping":
|
||||
m.audit.LogSandboxDestroySystem(ctx, sb.TeamID, sb.ID, "transient_timeout_inferred", nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -50,7 +50,9 @@ func agentErrToHTTP(err error) (int, string, string) {
|
||||
return http.StatusNotFound, "not_found", err.Error()
|
||||
case connect.CodeInvalidArgument:
|
||||
return http.StatusBadRequest, "invalid_request", err.Error()
|
||||
case connect.CodeFailedPrecondition, connect.CodeAlreadyExists:
|
||||
case connect.CodeAlreadyExists:
|
||||
return http.StatusConflict, "already_exists", err.Error()
|
||||
case connect.CodeFailedPrecondition:
|
||||
return http.StatusConflict, "conflict", err.Error()
|
||||
case connect.CodePermissionDenied:
|
||||
return http.StatusForbidden, "forbidden", err.Error()
|
||||
|
||||
@ -26,19 +26,10 @@ func injectPlatformTeam() func(http.Handler) http.Handler {
|
||||
}
|
||||
}
|
||||
|
||||
// markAdminWS flags the request context as an admin WebSocket route.
|
||||
// Applied to admin WS endpoints that sit outside the requireJWT/requireAdmin
|
||||
// middleware group. Handlers use isAdminWSRoute(ctx) to pick wsAuthenticateAdmin.
|
||||
func markAdminWS(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
next.ServeHTTP(w, r.WithContext(setAdminWSFlag(r.Context())))
|
||||
})
|
||||
}
|
||||
|
||||
// requireAdmin validates that the authenticated user is a platform admin.
|
||||
// Must run after requireJWT (depends on AuthContext being present).
|
||||
// Re-validates against the DB — the JWT is_admin claim is for UI only;
|
||||
// the DB is the source of truth for admin access.
|
||||
// Must run after requireSession (depends on AuthContext being present).
|
||||
// Re-validates against the DB — the session blob's is_admin flag is for
|
||||
// UI hints; the DB is the source of truth for admin access.
|
||||
func requireAdmin(queries *db.Queries) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
@ -1,145 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
// requireAPIKeyOrJWT accepts either X-API-Key header or Authorization: Bearer JWT.
|
||||
// Both stamp TeamID into the request context via auth.AuthContext.
|
||||
func requireAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Try API key first.
|
||||
if key := r.Header.Get("X-API-Key"); key != "" {
|
||||
hash := auth.HashAPIKey(key)
|
||||
row, err := queries.GetAPIKeyByHash(r.Context(), hash)
|
||||
if err != nil {
|
||||
slog.Warn("api key auth failed", "prefix", auth.APIKeyPrefix(key), "ip", r.RemoteAddr)
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
if err := queries.UpdateAPIKeyLastUsed(r.Context(), row.ID); err != nil {
|
||||
slog.Warn("failed to update api key last_used", "key_id", id.FormatAPIKeyID(row.ID), "error", err)
|
||||
}
|
||||
|
||||
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
|
||||
TeamID: row.TeamID,
|
||||
APIKeyID: row.ID,
|
||||
APIKeyName: row.Name,
|
||||
})
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
|
||||
// Try JWT bearer token from Authorization header.
|
||||
tokenStr := ""
|
||||
if header := r.Header.Get("Authorization"); strings.HasPrefix(header, "Bearer ") {
|
||||
tokenStr = strings.TrimPrefix(header, "Bearer ")
|
||||
}
|
||||
if tokenStr != "" {
|
||||
claims, err := auth.VerifyJWT(jwtSecret, tokenStr)
|
||||
if err != nil {
|
||||
slog.Warn("jwt auth failed", "error", err, "ip", r.RemoteAddr)
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid or expired token")
|
||||
return
|
||||
}
|
||||
|
||||
teamID, err := id.ParseTeamID(claims.TeamID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid team ID in token")
|
||||
return
|
||||
}
|
||||
userID, err := id.ParseUserID(claims.Subject)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid user ID in token")
|
||||
return
|
||||
}
|
||||
|
||||
// Verify user is still active in the database.
|
||||
user, err := queries.GetUserByID(r.Context(), userID)
|
||||
if err != nil {
|
||||
slog.Warn("jwt auth: failed to look up user", "user_id", claims.Subject, "error", err)
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "user not found")
|
||||
return
|
||||
}
|
||||
if user.Status != "active" {
|
||||
writeError(w, http.StatusForbidden, "account_deactivated", "your account has been deactivated — contact your administrator to regain access")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
|
||||
TeamID: teamID,
|
||||
UserID: userID,
|
||||
Email: claims.Email,
|
||||
Name: claims.Name,
|
||||
Role: claims.Role,
|
||||
})
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "X-API-Key or Authorization: Bearer <token> required")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// optionalAPIKeyOrJWT is like requireAPIKeyOrJWT but does not reject
|
||||
// unauthenticated requests. It injects auth context when valid credentials
|
||||
// are present (supporting SDK clients that set X-API-Key on WebSocket
|
||||
// upgrades) and passes through otherwise so the handler can authenticate
|
||||
// after the WebSocket upgrade via the first message.
|
||||
func optionalAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Try API key.
|
||||
if key := r.Header.Get("X-API-Key"); key != "" {
|
||||
hash := auth.HashAPIKey(key)
|
||||
row, err := queries.GetAPIKeyByHash(r.Context(), hash)
|
||||
if err == nil {
|
||||
if err := queries.UpdateAPIKeyLastUsed(r.Context(), row.ID); err != nil {
|
||||
slog.Warn("failed to update api key last_used", "key_id", id.FormatAPIKeyID(row.ID), "error", err)
|
||||
}
|
||||
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
|
||||
TeamID: row.TeamID,
|
||||
APIKeyID: row.ID,
|
||||
APIKeyName: row.Name,
|
||||
})
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Try JWT bearer token.
|
||||
if header := r.Header.Get("Authorization"); strings.HasPrefix(header, "Bearer ") {
|
||||
tokenStr := strings.TrimPrefix(header, "Bearer ")
|
||||
if claims, err := auth.VerifyJWT(jwtSecret, tokenStr); err == nil {
|
||||
if teamID, err := id.ParseTeamID(claims.TeamID); err == nil {
|
||||
if userID, err := id.ParseUserID(claims.Subject); err == nil {
|
||||
if user, err := queries.GetUserByID(r.Context(), userID); err == nil && user.Status == "active" {
|
||||
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
|
||||
TeamID: teamID,
|
||||
UserID: userID,
|
||||
Email: claims.Email,
|
||||
Name: claims.Name,
|
||||
Role: claims.Role,
|
||||
})
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// No valid credentials — pass through for handler to authenticate.
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
38
internal/api/middleware_csrf.go
Normal file
38
internal/api/middleware_csrf.go
Normal file
@ -0,0 +1,38 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"net/http"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
)
|
||||
|
||||
// requireCSRF enforces double-submit CSRF: the wrenn_csrf cookie value must
|
||||
// equal the X-CSRF-Token header. Skipped for safe methods (GET/HEAD/OPTIONS)
|
||||
// and for requests authenticated via X-API-Key (SDK clients are not
|
||||
// vulnerable to cross-site request forgery against cookie auth).
|
||||
func requireCSRF() func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodGet, http.MethodHead, http.MethodOptions:
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
// API-key auth path: no CSRF check needed.
|
||||
if ac, ok := auth.FromContext(r.Context()); ok && ac.APIKeyID.Valid {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
cookie, err := r.Cookie(csrfCookieName)
|
||||
header := r.Header.Get("X-CSRF-Token")
|
||||
if err != nil || cookie.Value == "" || header == "" ||
|
||||
subtle.ConstantTimeCompare([]byte(cookie.Value), []byte(header)) != 1 {
|
||||
writeError(w, http.StatusForbidden, "csrf_failed", "missing or invalid CSRF token")
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -1,69 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
// requireJWT validates a JWT from the Authorization: Bearer header.
|
||||
// It also verifies the user is still active in the database.
|
||||
// WebSocket upgrade requests without an Authorization header are passed through
|
||||
// — WS handlers authenticate via the first message after upgrade.
|
||||
func requireJWT(secret []byte, queries *db.Queries) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var tokenStr string
|
||||
if header := r.Header.Get("Authorization"); strings.HasPrefix(header, "Bearer ") {
|
||||
tokenStr = strings.TrimPrefix(header, "Bearer ")
|
||||
}
|
||||
if tokenStr == "" {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "Authorization: Bearer <token> required")
|
||||
return
|
||||
}
|
||||
claims, err := auth.VerifyJWT(secret, tokenStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid or expired token")
|
||||
return
|
||||
}
|
||||
|
||||
teamID, err := id.ParseTeamID(claims.TeamID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid team ID in token")
|
||||
return
|
||||
}
|
||||
userID, err := id.ParseUserID(claims.Subject)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid user ID in token")
|
||||
return
|
||||
}
|
||||
|
||||
// Verify user is still active in the database.
|
||||
user, err := queries.GetUserByID(r.Context(), userID)
|
||||
if err != nil {
|
||||
slog.Warn("jwt auth: failed to look up user", "user_id", claims.Subject, "error", err)
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "user not found")
|
||||
return
|
||||
}
|
||||
if user.Status != "active" {
|
||||
writeError(w, http.StatusForbidden, "account_deactivated", "your account has been deactivated — contact your administrator to regain access")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
|
||||
TeamID: teamID,
|
||||
UserID: userID,
|
||||
Email: claims.Email,
|
||||
Name: claims.Name,
|
||||
Role: claims.Role,
|
||||
IsAdmin: claims.IsAdmin,
|
||||
})
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
34
internal/api/middleware_session.go
Normal file
34
internal/api/middleware_session.go
Normal file
@ -0,0 +1,34 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth/session"
|
||||
sessionmw "git.omukk.dev/wrenn/wrenn/pkg/auth/session/middleware"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
)
|
||||
|
||||
// Internal aliases — the canonical implementations live in the public
|
||||
// pkg/auth/session/middleware package so cloud extensions can call them.
|
||||
|
||||
const csrfCookieName = sessionmw.CSRFCookieName
|
||||
|
||||
func requireSession(queries *db.Queries, svc *session.Service) func(http.Handler) http.Handler {
|
||||
return sessionmw.RequireSession(svc, queries)
|
||||
}
|
||||
|
||||
func requireSessionOrAPIKey(queries *db.Queries, svc *session.Service) func(http.Handler) http.Handler {
|
||||
return sessionmw.RequireSessionOrAPIKey(svc, queries)
|
||||
}
|
||||
|
||||
func setSessionCookies(w http.ResponseWriter, sid, csrfToken string, secure bool) {
|
||||
sessionmw.SetCookies(w, sid, csrfToken, secure)
|
||||
}
|
||||
|
||||
func clearSessionCookies(w http.ResponseWriter, secure bool) {
|
||||
sessionmw.ClearCookies(w, secure)
|
||||
}
|
||||
|
||||
func isSecure(r *http.Request) bool {
|
||||
return sessionmw.IsSecure(r)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
310
internal/api/sandbox_event_consumer.go
Normal file
310
internal/api/sandbox_event_consumer.go
Normal file
@ -0,0 +1,310 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/audit"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/cpextension"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/events"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
const (
	// unifiedEventStream is the Redis stream key the consumer reads from.
	unifiedEventStream = "wrenn:events"

	// reconcilerConsumerGrp is the consumer-group name used by the sandbox
	// reconciler on unifiedEventStream.
	reconcilerConsumerGrp = "wrenn-sandbox-reconciler-v1"

	// reconcilerConsumer is the consumer name this process uses inside
	// reconcilerConsumerGrp.
	reconcilerConsumer = "cp-0"
)
|
||||
|
||||
// SandboxEventConsumer reads capsule lifecycle events from the unified Redis
|
||||
// stream and drives DB state reconciliation. Uses an independent consumer
|
||||
// group so its cursor is separate from the channels dispatcher.
|
||||
type SandboxEventConsumer struct {
|
||||
rdb *redis.Client
|
||||
db *db.Queries
|
||||
audit *audit.AuditLogger
|
||||
hooks []cpextension.SandboxEventHook
|
||||
}
|
||||
|
||||
// NewSandboxEventConsumer creates a consumer.
|
||||
func NewSandboxEventConsumer(rdb *redis.Client, queries *db.Queries, al *audit.AuditLogger, hooks []cpextension.SandboxEventHook) *SandboxEventConsumer {
|
||||
return &SandboxEventConsumer{rdb: rdb, db: queries, audit: al, hooks: hooks}
|
||||
}
|
||||
|
||||
// Start launches the consumer goroutine. Reads from "$" so prior history
|
||||
// is not replayed.
|
||||
func (c *SandboxEventConsumer) Start(ctx context.Context) {
|
||||
go c.run(ctx)
|
||||
}
|
||||
|
||||
func (c *SandboxEventConsumer) run(ctx context.Context) {
|
||||
err := c.rdb.XGroupCreateMkStream(ctx, unifiedEventStream, reconcilerConsumerGrp, "$").Err()
|
||||
if err != nil && err.Error() != "BUSYGROUP Consumer Group name already exists" {
|
||||
slog.Error("sandbox event consumer: failed to create consumer group", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
streams, err := c.rdb.XReadGroup(ctx, &redis.XReadGroupArgs{
|
||||
Group: reconcilerConsumerGrp,
|
||||
Consumer: reconcilerConsumer,
|
||||
Streams: []string{unifiedEventStream, ">"},
|
||||
Count: 10,
|
||||
Block: 5 * time.Second,
|
||||
}).Result()
|
||||
|
||||
if err != nil {
|
||||
if err == redis.Nil || ctx.Err() != nil {
|
||||
continue
|
||||
}
|
||||
slog.Warn("sandbox event consumer: xreadgroup error", "error", err)
|
||||
time.Sleep(1 * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, stream := range streams {
|
||||
for _, msg := range stream.Messages {
|
||||
c.handleMessage(ctx, msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *SandboxEventConsumer) handleMessage(ctx context.Context, msg redis.XMessage) {
|
||||
ack := true
|
||||
defer func() {
|
||||
if !ack {
|
||||
return
|
||||
}
|
||||
ackCtx, ackCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer ackCancel()
|
||||
if err := c.rdb.XAck(ackCtx, unifiedEventStream, reconcilerConsumerGrp, msg.ID).Err(); err != nil {
|
||||
slog.Warn("sandbox event consumer: xack failed", "id", msg.ID, "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
payload, ok := msg.Values["payload"].(string)
|
||||
if !ok {
|
||||
slog.Warn("sandbox event consumer: message missing payload", "id", msg.ID)
|
||||
return
|
||||
}
|
||||
|
||||
var event events.Event
|
||||
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||
slog.Warn("sandbox event consumer: failed to unmarshal event", "id", msg.ID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Only capsule.* events drive sandbox reconciliation.
|
||||
if !strings.HasPrefix(event.Event, "capsule.") || event.Event == events.CapsuleStateChanged {
|
||||
return
|
||||
}
|
||||
// Only system-actor events represent host-side state we need to reflect
|
||||
// in the DB; user-actor events are already mirrored by the handler that
|
||||
// produced them.
|
||||
if event.Actor.Type != events.ActorSystem {
|
||||
// Exception: handlers publish capsule.create with user actor before
|
||||
// the host has reported back. Those are owned by the service goroutine.
|
||||
return
|
||||
}
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(event.Resource.ID)
|
||||
if err != nil {
|
||||
slog.Warn("sandbox event consumer: invalid sandbox ID", "sandbox_id", event.Resource.ID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
switch event.Event {
|
||||
case events.CapsuleCreate:
|
||||
if event.Outcome == events.OutcomeSuccess {
|
||||
c.handleStarted(ctx, sandboxID, event, "starting")
|
||||
} else {
|
||||
c.handleFailed(ctx, sandboxID, event)
|
||||
}
|
||||
case events.CapsuleResume:
|
||||
if event.Outcome == events.OutcomeSuccess {
|
||||
c.handleStarted(ctx, sandboxID, event, "resuming")
|
||||
} else {
|
||||
c.handleFailed(ctx, sandboxID, event)
|
||||
}
|
||||
case events.CapsulePause:
|
||||
if event.Outcome == events.OutcomeSuccess {
|
||||
c.handleAutoPaused(ctx, sandboxID)
|
||||
}
|
||||
case events.CapsuleDestroy:
|
||||
if event.Outcome == events.OutcomeSuccess {
|
||||
c.handleStopped(ctx, sandboxID)
|
||||
}
|
||||
}
|
||||
|
||||
// Dispatch to extension hooks (cloud billing, audit shipping, etc.). Any
|
||||
// hook error suppresses the ack so the message will be redelivered. Hooks
|
||||
// MUST be idempotent — duplicate deliveries are expected on transient
|
||||
// failures.
|
||||
if len(c.hooks) > 0 && event.Outcome == events.OutcomeSuccess {
|
||||
if verb, ok := canonicalSandboxVerb(event.Event); ok {
|
||||
teamID, _ := id.ParseTeamID(event.TeamID)
|
||||
meta := map[string]any{}
|
||||
for k, v := range event.Metadata {
|
||||
meta[k] = v
|
||||
}
|
||||
ev := cpextension.SandboxEvent{
|
||||
SandboxID: sandboxID,
|
||||
TeamID: teamID,
|
||||
Type: verb,
|
||||
OccurredAt: parseEventTimestamp(event.Timestamp),
|
||||
Metadata: meta,
|
||||
}
|
||||
for _, h := range c.hooks {
|
||||
if err := h.OnSandboxEvent(ctx, ev); err != nil {
|
||||
slog.Warn("sandbox event hook failed; leaving message un-acked", "id", msg.ID, "event", event.Event, "error", err)
|
||||
ack = false
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func canonicalSandboxVerb(event string) (string, bool) {
|
||||
switch event {
|
||||
case events.CapsuleCreate:
|
||||
return "created", true
|
||||
case events.CapsuleResume:
|
||||
return "resumed", true
|
||||
case events.CapsulePause:
|
||||
return "paused", true
|
||||
case events.CapsuleDestroy:
|
||||
return "destroyed", true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
// parseEventTimestamp converts an RFC 3339 timestamp string into a time.Time.
// An empty or unparseable string yields the current time in UTC instead, so
// callers always receive a usable value.
func parseEventTimestamp(s string) time.Time {
	if s != "" {
		if parsed, err := time.Parse(time.RFC3339, s); err == nil {
			return parsed
		}
	}
	return time.Now().UTC()
}
|
||||
|
||||
// handleStarted is a fallback writer for capsule.create.success and
|
||||
// capsule.resume.success. The background goroutine in SandboxService is the
|
||||
// primary writer; this only succeeds if the goroutine's conditional update
|
||||
// was missed.
|
||||
func (c *SandboxEventConsumer) handleStarted(ctx context.Context, sandboxID pgtype.UUID, event events.Event, fromStatus string) {
|
||||
hostIP := event.Metadata["host_ip"]
|
||||
now := time.Now()
|
||||
if _, err := c.db.UpdateSandboxRunningIf(ctx, db.UpdateSandboxRunningIfParams{
|
||||
ID: sandboxID,
|
||||
Status: fromStatus,
|
||||
HostIp: hostIP,
|
||||
StartedAt: pgtype.Timestamptz{
|
||||
Time: now,
|
||||
Valid: true,
|
||||
},
|
||||
}); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (c *SandboxEventConsumer) handleAutoPaused(ctx context.Context, sandboxID pgtype.UUID) {
|
||||
for _, fromStatus := range []string{"running", "pausing"} {
|
||||
if _, err := c.db.UpdateSandboxStatusIf(ctx, db.UpdateSandboxStatusIfParams{
|
||||
ID: sandboxID, Status: fromStatus, Status_2: "paused",
|
||||
}); err == nil {
|
||||
slog.Debug("sandbox event consumer: auto-paused fallback applied", "sandbox_id", id.FormatSandboxID(sandboxID), "from", fromStatus)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *SandboxEventConsumer) handleStopped(ctx context.Context, sandboxID pgtype.UUID) {
|
||||
// stopping → stopped (CP-initiated destroy completed). No audit row here;
|
||||
// the handler that issued the destroy already wrote one.
|
||||
if _, err := c.db.UpdateSandboxStatusIf(ctx, db.UpdateSandboxStatusIfParams{
|
||||
ID: sandboxID,
|
||||
Status: "stopping",
|
||||
Status_2: "stopped",
|
||||
}); err == nil {
|
||||
return
|
||||
}
|
||||
// running → stopped (autonomous destroy, e.g. TTL destroy fallback).
|
||||
if _, err := c.db.UpdateSandboxStatusIf(ctx, db.UpdateSandboxStatusIfParams{
|
||||
ID: sandboxID,
|
||||
Status: "running",
|
||||
Status_2: "stopped",
|
||||
}); err != nil && !errors.Is(err, pgx.ErrNoRows) {
|
||||
slog.Warn("sandbox event consumer: failed to update sandbox to stopped", "sandbox_id", id.FormatSandboxID(sandboxID), "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// handleFailed marks a sandbox as "error" when a verb event reports failure
|
||||
// and writes a system audit row. The DB update is idempotent — the
|
||||
// SandboxService background goroutine usually wrote "error" already on the
|
||||
// fast-fail path, which settles in seconds and so never reaches the
|
||||
// HostMonitor's transient-timeout reconciliation.
|
||||
//
|
||||
// audit.Log writes the row only — it does NOT republish an event, which would
|
||||
// loop back into this consumer. Do not switch to LogSandboxCreateSystem here.
|
||||
func (c *SandboxEventConsumer) handleFailed(ctx context.Context, sandboxID pgtype.UUID, event events.Event) {
|
||||
for _, fromStatus := range []string{"running", "starting", "pausing", "resuming", "snapshotting"} {
|
||||
if _, err := c.db.UpdateSandboxStatusIf(ctx, db.UpdateSandboxStatusIfParams{
|
||||
ID: sandboxID, Status: fromStatus, Status_2: "error",
|
||||
}); err == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// The HostMonitor transient-timeout reconciler emits failure events via
|
||||
// LogSandboxCreateSystem / LogSandboxResumeSystem, which already write
|
||||
// their own audit row before publishing — auditing again here would
|
||||
// double-count. Those helpers publish with reason="transient_timeout";
|
||||
// the un-audited fast-fail (createInBackground) and host-callback paths
|
||||
// do not, so only they need a row written here.
|
||||
if event.Metadata["reason"] == "transient_timeout" {
|
||||
return
|
||||
}
|
||||
|
||||
action := "create"
|
||||
if event.Event == events.CapsuleResume {
|
||||
action = "resume"
|
||||
}
|
||||
reason := event.Metadata["reason"]
|
||||
if reason == "" {
|
||||
reason = action + "_failed"
|
||||
}
|
||||
meta := map[string]any{"reason": reason}
|
||||
if event.Error != "" {
|
||||
meta["error"] = event.Error
|
||||
}
|
||||
teamID, _ := id.ParseTeamID(event.TeamID)
|
||||
c.audit.Log(ctx, audit.Entry{
|
||||
TeamID: teamID,
|
||||
ActorType: "system",
|
||||
ResourceType: "sandbox",
|
||||
ResourceID: id.FormatSandboxID(sandboxID),
|
||||
Action: action,
|
||||
Scope: "team",
|
||||
Status: "error",
|
||||
Metadata: meta,
|
||||
})
|
||||
}
|
||||
@ -1,6 +1,7 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@ -13,9 +14,11 @@ import (
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/audit"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth/oauth"
|
||||
authsession "git.omukk.dev/wrenn/wrenn/pkg/auth/session"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/channels"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/cpextension"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/events"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/scheduler"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/service"
|
||||
@ -26,15 +29,19 @@ var openapiYAML []byte
|
||||
|
||||
// Server is the control plane HTTP server.
|
||||
type Server struct {
|
||||
router chi.Router
|
||||
BuildSvc *service.BuildService
|
||||
version string
|
||||
router chi.Router
|
||||
BuildSvc *service.BuildService
|
||||
SSERelay *SSERelay
|
||||
SessionSvc *authsession.Service
|
||||
version string
|
||||
}
|
||||
|
||||
// New constructs the chi router and registers all routes.
|
||||
// Extensions are called after core routes are registered, allowing cloud
|
||||
// or third-party code to add routes and middleware.
|
||||
// New constructs the chi router and registers all routes. The jwtSecret is
|
||||
// still used to sign host-agent JWTs (long-lived, for the wrenn-agent → cp
|
||||
// trust path) and to HMAC OAuth state/link cookies; user authentication has
|
||||
// migrated to opaque session cookies backed by the session service.
|
||||
func New(
|
||||
ctx context.Context,
|
||||
queries *db.Queries,
|
||||
pool *lifecycle.HostClientPool,
|
||||
sched scheduler.HostScheduler,
|
||||
@ -45,10 +52,12 @@ func New(
|
||||
oauthRedirectURL string,
|
||||
ca *auth.CA,
|
||||
al *audit.AuditLogger,
|
||||
eventPub *channels.Publisher,
|
||||
channelSvc *channels.Service,
|
||||
mailer email.Mailer,
|
||||
extensions []cpextension.Extension,
|
||||
sctx cpextension.ServerContext,
|
||||
monitor *HostMonitor,
|
||||
version string,
|
||||
) *Server {
|
||||
r := chi.NewRouter()
|
||||
@ -61,8 +70,28 @@ func New(
|
||||
}
|
||||
}
|
||||
|
||||
// Session service backs cookie-based browser auth. The cpserver wires it
|
||||
// through ServerContext so cloud-repo extensions can share the same
|
||||
// instance; fall back to constructing one here if the host program
|
||||
// instantiates the API directly (tests, ad-hoc tooling).
|
||||
sessionSvc := sctx.Sessions
|
||||
if sessionSvc == nil {
|
||||
sessionSvc = authsession.NewService(queries, rdb)
|
||||
}
|
||||
|
||||
// Shared service layer.
|
||||
sandboxSvc := &service.SandboxService{DB: queries, Pool: pool, Scheduler: sched}
|
||||
sandboxSvc.PublishEvent = func(ctx context.Context, event service.SandboxStateEvent) {
|
||||
if evt, ok := serviceEventToCanonical(event); ok {
|
||||
// State-change events are ephemeral UI signals — mirror them to the
|
||||
// dashboard via Pub/Sub only, never to durable channel subscribers.
|
||||
if evt.Event == events.CapsuleStateChanged {
|
||||
eventPub.PublishTransient(ctx, evt)
|
||||
} else {
|
||||
eventPub.Publish(ctx, evt)
|
||||
}
|
||||
}
|
||||
}
|
||||
apiKeySvc := &service.APIKeyService{DB: queries}
|
||||
templateSvc := &service.TemplateService{DB: queries}
|
||||
hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret, Pool: pool, CA: ca}
|
||||
@ -72,30 +101,40 @@ func New(
|
||||
statsSvc := &service.StatsService{DB: queries, Pool: pgPool}
|
||||
usageSvc := &service.UsageService{DB: queries}
|
||||
buildSvc := &service.BuildService{DB: queries, Redis: rdb, Pool: pool, Scheduler: sched}
|
||||
buildBroker := service.NewBuildBroker(rdb)
|
||||
|
||||
sandbox := newSandboxHandler(sandboxSvc, al)
|
||||
exec := newExecHandler(queries, pool)
|
||||
execStream := newExecStreamHandler(queries, pool, jwtSecret)
|
||||
execStream := newExecStreamHandler(queries, pool)
|
||||
files := newFilesHandler(queries, pool)
|
||||
filesStream := newFilesStreamHandler(queries, pool)
|
||||
fsH := newFSHandler(queries, pool)
|
||||
snapshots := newSnapshotHandler(templateSvc, queries, pool, al)
|
||||
authH := newAuthHandler(queries, pgPool, jwtSecret, mailer, rdb, oauthRedirectURL)
|
||||
oauthH := newOAuthHandler(queries, pgPool, jwtSecret, oauthRegistry, oauthRedirectURL)
|
||||
snapshots := newSnapshotHandler(templateSvc, sandboxSvc, queries, pool, al)
|
||||
authHooks := collectAuthHooks(extensions)
|
||||
authH := newAuthHandler(queries, pgPool, sessionSvc, mailer, rdb, oauthRedirectURL, authHooks)
|
||||
oauthH := newOAuthHandler(queries, pgPool, jwtSecret, sessionSvc, oauthRegistry, oauthRedirectURL, authHooks)
|
||||
apiKeys := newAPIKeyHandler(apiKeySvc, al)
|
||||
hostH := newHostHandler(hostSvc, queries, al)
|
||||
teamH := newTeamHandler(teamSvc, al, mailer)
|
||||
usersH := newUsersHandler(queries, userSvc, al)
|
||||
hostH := newHostHandler(hostSvc, queries, al, monitor)
|
||||
teamH := newTeamHandler(teamSvc, al, mailer, sessionSvc)
|
||||
usersH := newUsersHandler(queries, userSvc, al, sessionSvc)
|
||||
auditH := newAuditHandler(auditSvc)
|
||||
statsH := newStatsHandler(statsSvc)
|
||||
usageH := newUsageHandler(usageSvc)
|
||||
metricsH := newSandboxMetricsHandler(queries, pool)
|
||||
buildH := newBuildHandler(buildSvc, queries, pool, al)
|
||||
buildStreamH := newBuildStreamHandler(queries, buildBroker)
|
||||
channelH := newChannelHandler(channelSvc, al)
|
||||
ptyH := newPtyHandler(queries, pool, jwtSecret)
|
||||
processH := newProcessHandler(queries, pool, jwtSecret)
|
||||
ptyH := newPtyHandler(queries, pool)
|
||||
processH := newProcessHandler(queries, pool)
|
||||
adminCapsules := newAdminCapsuleHandler(sandboxSvc, queries, pool, al)
|
||||
meH := newMeHandler(queries, pgPool, rdb, jwtSecret, mailer, oauthRegistry, oauthRedirectURL, teamSvc)
|
||||
sandboxEvtH := newSandboxEventHandler(queries, eventPub)
|
||||
meH := newMeHandler(queries, pgPool, rdb, jwtSecret, sessionSvc, mailer, oauthRegistry, oauthRedirectURL, teamSvc, authHooks)
|
||||
sessionsH := newSessionsHandler(sessionSvc)
|
||||
|
||||
// SSE real-time event streaming.
|
||||
sseBroker := NewSSEBroker()
|
||||
sseRelay := NewSSERelay(rdb, queries, sseBroker)
|
||||
sseH := newSSEHandler(sseBroker)
|
||||
|
||||
// Health check.
|
||||
r.Get("/health", func(w http.ResponseWriter, r *http.Request) {
|
||||
@ -107,7 +146,8 @@ func New(
|
||||
r.Get("/openapi.yaml", serveOpenAPI)
|
||||
r.Get("/docs", serveDocs)
|
||||
|
||||
// Unauthenticated auth endpoints.
|
||||
// Unauthenticated auth endpoints. CSRF is not required here — there is
|
||||
// no session cookie yet to be forged.
|
||||
r.Post("/v1/auth/signup", authH.Signup)
|
||||
r.Post("/v1/auth/login", authH.Login)
|
||||
r.Post("/v1/auth/activate", authH.Activate)
|
||||
@ -118,31 +158,45 @@ func New(
|
||||
r.Post("/v1/me/password/reset", meH.RequestPasswordReset)
|
||||
r.Post("/v1/me/password/reset/confirm", meH.ConfirmPasswordReset)
|
||||
|
||||
// JWT-authenticated: self-service account management.
|
||||
csrf := requireCSRF()
|
||||
|
||||
// Session-authenticated: logout endpoints (must be inside the session
|
||||
// group so the handler sees the current SID via AuthContext).
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(requireSession(queries, sessionSvc))
|
||||
r.Use(csrf)
|
||||
r.Post("/v1/auth/logout", authH.Logout)
|
||||
r.Post("/v1/auth/logout-all", authH.LogoutAll)
|
||||
r.Post("/v1/auth/switch-team", authH.SwitchTeam)
|
||||
})
|
||||
|
||||
// Session-authenticated: self-service account management.
|
||||
r.Route("/v1/me", func(r chi.Router) {
|
||||
r.Use(requireJWT(jwtSecret, queries))
|
||||
r.Use(requireSession(queries, sessionSvc))
|
||||
r.Use(csrf)
|
||||
r.Get("/", meH.GetMe)
|
||||
r.Patch("/", meH.UpdateName)
|
||||
r.Post("/password", meH.ChangePassword)
|
||||
r.Get("/providers/{provider}/connect", meH.ConnectProvider)
|
||||
r.Delete("/providers/{provider}", meH.DisconnectProvider)
|
||||
r.Delete("/", meH.DeleteAccount)
|
||||
r.Get("/sessions", sessionsH.List)
|
||||
r.Delete("/sessions/{id}", sessionsH.Delete)
|
||||
})
|
||||
|
||||
// JWT-authenticated: switch active team.
|
||||
r.With(requireJWT(jwtSecret, queries)).Post("/v1/auth/switch-team", authH.SwitchTeam)
|
||||
|
||||
// JWT-authenticated: API key management.
|
||||
// Session-authenticated: API key management.
|
||||
r.Route("/v1/api-keys", func(r chi.Router) {
|
||||
r.Use(requireJWT(jwtSecret, queries))
|
||||
r.Use(requireSession(queries, sessionSvc))
|
||||
r.Use(csrf)
|
||||
r.Post("/", apiKeys.Create)
|
||||
r.Get("/", apiKeys.List)
|
||||
r.Delete("/{id}", apiKeys.Delete)
|
||||
})
|
||||
|
||||
// JWT-authenticated: team management.
|
||||
// Session-authenticated: team management.
|
||||
r.Route("/v1/teams", func(r chi.Router) {
|
||||
r.Use(requireJWT(jwtSecret, queries))
|
||||
r.Use(requireSession(queries, sessionSvc))
|
||||
r.Use(csrf)
|
||||
r.Get("/", teamH.List)
|
||||
r.Post("/", teamH.Create)
|
||||
r.Route("/{id}", func(r chi.Router) {
|
||||
@ -157,14 +211,15 @@ func New(
|
||||
})
|
||||
})
|
||||
|
||||
// JWT-authenticated: user search (for add-member UI).
|
||||
r.With(requireJWT(jwtSecret, queries)).Get("/v1/users/search", usersH.Search)
|
||||
// Session-authenticated: user search (for add-member UI).
|
||||
r.With(requireSession(queries, sessionSvc), csrf).Get("/v1/users/search", usersH.Search)
|
||||
|
||||
// Capsule lifecycle: accepts API key or JWT bearer token.
|
||||
// Capsule lifecycle: API key (SDK) or session cookie (browser).
|
||||
r.Route("/v1/capsules", func(r chi.Router) {
|
||||
// Auth-required routes.
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(requireAPIKeyOrJWT(queries, jwtSecret))
|
||||
r.Use(requireSessionOrAPIKey(queries, sessionSvc))
|
||||
r.Use(csrf)
|
||||
r.Post("/", sandbox.Create)
|
||||
r.Get("/", sandbox.List)
|
||||
r.Get("/stats", statsH.GetStats)
|
||||
@ -174,7 +229,8 @@ func New(
|
||||
r.Route("/{id}", func(r chi.Router) {
|
||||
// Auth-required non-WS routes.
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(requireAPIKeyOrJWT(queries, jwtSecret))
|
||||
r.Use(requireSessionOrAPIKey(queries, sessionSvc))
|
||||
r.Use(csrf)
|
||||
r.Get("/", sandbox.Get)
|
||||
r.Delete("/", sandbox.Destroy)
|
||||
r.Post("/exec", exec.Exec)
|
||||
@ -193,11 +249,11 @@ func New(
|
||||
r.Delete("/processes/{selector}", processH.KillProcess)
|
||||
})
|
||||
|
||||
// WebSocket endpoints — handlers authenticate after upgrade.
|
||||
// optionalAPIKeyOrJWT injects auth context from headers when
|
||||
// present (SDK clients) but does not reject when absent (browsers).
|
||||
// WebSocket endpoints — middleware injects auth context from
|
||||
// cookie (browser) or X-API-Key (SDK) before upgrade. CSRF is
|
||||
// not applicable to GET upgrades.
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(optionalAPIKeyOrJWT(queries, jwtSecret))
|
||||
r.Use(requireSessionOrAPIKey(queries, sessionSvc))
|
||||
r.Get("/exec/stream", execStream.ExecStream)
|
||||
r.Get("/pty", ptyH.PtySession)
|
||||
r.Get("/processes/{selector}/stream", processH.ConnectProcess)
|
||||
@ -205,9 +261,10 @@ func New(
|
||||
})
|
||||
})
|
||||
|
||||
// Snapshot / template management: accepts API key or JWT bearer token.
|
||||
// Snapshot / template management: API key (SDK) or session (browser).
|
||||
r.Route("/v1/snapshots", func(r chi.Router) {
|
||||
r.Use(requireAPIKeyOrJWT(queries, jwtSecret))
|
||||
r.Use(requireSessionOrAPIKey(queries, sessionSvc))
|
||||
r.Use(csrf)
|
||||
r.Post("/", snapshots.Create)
|
||||
r.Get("/", snapshots.List)
|
||||
r.Delete("/{name}", snapshots.Delete)
|
||||
@ -221,12 +278,14 @@ func New(
|
||||
// Unauthenticated: refresh token exchange.
|
||||
r.Post("/auth/refresh", hostH.RefreshToken)
|
||||
|
||||
// Host-token-authenticated: heartbeat.
|
||||
// Host-token-authenticated: heartbeat and lifecycle callbacks.
|
||||
r.With(requireHostToken(jwtSecret)).Post("/{id}/heartbeat", hostH.Heartbeat)
|
||||
r.With(requireHostToken(jwtSecret)).Post("/sandbox-events", sandboxEvtH.Handle)
|
||||
|
||||
// JWT-authenticated: host CRUD and tags.
|
||||
// Session-authenticated: host CRUD and tags.
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(requireJWT(jwtSecret, queries))
|
||||
r.Use(requireSession(queries, sessionSvc))
|
||||
r.Use(csrf)
|
||||
r.Post("/", hostH.Create)
|
||||
r.Get("/", hostH.List)
|
||||
r.Route("/{id}", func(r chi.Router) {
|
||||
@ -241,9 +300,10 @@ func New(
|
||||
})
|
||||
})
|
||||
|
||||
// JWT-authenticated: notification channels.
|
||||
// Session-authenticated: notification channels.
|
||||
r.Route("/v1/channels", func(r chi.Router) {
|
||||
r.Use(requireJWT(jwtSecret, queries))
|
||||
r.Use(requireSession(queries, sessionSvc))
|
||||
r.Use(csrf)
|
||||
r.Post("/", channelH.Create)
|
||||
r.Get("/", channelH.List)
|
||||
r.Post("/test", channelH.Test)
|
||||
@ -255,18 +315,27 @@ func New(
|
||||
})
|
||||
})
|
||||
|
||||
// JWT-authenticated: audit log.
|
||||
r.With(requireJWT(jwtSecret, queries)).Get("/v1/audit-logs", auditH.List)
|
||||
// SSE event stream: browser sends wrenn_sid cookie on EventSource
|
||||
// automatically; SDKs may set X-API-Key. Ticket-based auth is gone.
|
||||
r.With(requireSessionOrAPIKey(queries, sessionSvc)).Get("/v1/events/stream", sseH.Stream)
|
||||
|
||||
// Platform admin routes — require JWT + DB-validated admin status.
|
||||
// Session-authenticated: audit log.
|
||||
r.With(requireSession(queries, sessionSvc), csrf).Get("/v1/audit-logs", auditH.List)
|
||||
|
||||
// Platform admin routes — session + DB-validated admin status.
|
||||
r.Route("/v1/admin", func(r chi.Router) {
|
||||
// Admin SSE event stream (sees all teams).
|
||||
r.With(requireSession(queries, sessionSvc), requireAdmin(queries), injectPlatformTeam()).Get("/events/stream", sseH.AdminStream)
|
||||
|
||||
// Auth-required admin routes (non-capsule + capsule list/create).
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(requireJWT(jwtSecret, queries))
|
||||
r.Use(requireSession(queries, sessionSvc))
|
||||
r.Use(requireAdmin(queries))
|
||||
r.Use(csrf)
|
||||
r.Get("/teams", teamH.AdminListTeams)
|
||||
r.Put("/teams/{id}/byoc", teamH.SetBYOC)
|
||||
r.Delete("/teams/{id}", teamH.AdminDeleteTeam)
|
||||
r.Get("/hosts", hostH.AdminList)
|
||||
r.Get("/users", usersH.AdminListUsers)
|
||||
r.Put("/users/{id}/active", usersH.SetUserActive)
|
||||
r.Put("/users/{id}/admin", usersH.SetUserAdmin)
|
||||
@ -281,11 +350,21 @@ func New(
|
||||
r.Get("/capsules", adminCapsules.List)
|
||||
})
|
||||
|
||||
// Admin build console WebSocket — cookie + admin DB check before
|
||||
// upgrade, no CSRF (WS upgrade). Builds are platform-scoped, not
|
||||
// sandbox-scoped, so this sits outside the /capsules/{id} router.
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(requireSession(queries, sessionSvc))
|
||||
r.Use(requireAdmin(queries))
|
||||
r.Get("/builds/{id}/stream", buildStreamH.Stream)
|
||||
})
|
||||
|
||||
r.Route("/capsules/{id}", func(r chi.Router) {
|
||||
// Auth-required non-WS admin capsule routes.
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(requireJWT(jwtSecret, queries))
|
||||
r.Use(requireSession(queries, sessionSvc))
|
||||
r.Use(requireAdmin(queries))
|
||||
r.Use(csrf)
|
||||
r.Use(injectPlatformTeam())
|
||||
r.Get("/", adminCapsules.Get)
|
||||
r.Delete("/", adminCapsules.Destroy)
|
||||
@ -301,11 +380,13 @@ func New(
|
||||
r.Delete("/processes/{selector}", processH.KillProcess)
|
||||
})
|
||||
|
||||
// Admin WebSocket endpoints — handlers authenticate after upgrade
|
||||
// via wsAuthenticateAdmin. markAdminWS sets the context flag so
|
||||
// handlers know to use admin auth instead of regular auth.
|
||||
// Admin WebSocket endpoints — browser auth via cookie + admin DB check
|
||||
// before upgrade. markAdminWS is retained as a context hint for any
|
||||
// admin-specific behavior downstream.
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(markAdminWS)
|
||||
r.Use(requireSession(queries, sessionSvc))
|
||||
r.Use(requireAdmin(queries))
|
||||
r.Use(injectPlatformTeam())
|
||||
r.Get("/exec/stream", execStream.ExecStream)
|
||||
r.Get("/pty", ptyH.PtySession)
|
||||
r.Get("/processes/{selector}/stream", processH.ConnectProcess)
|
||||
@ -318,7 +399,7 @@ func New(
|
||||
ext.RegisterRoutes(r, sctx)
|
||||
}
|
||||
|
||||
return &Server{router: r, BuildSvc: buildSvc, version: version}
|
||||
return &Server{router: r, BuildSvc: buildSvc, SSERelay: sseRelay, SessionSvc: sessionSvc, version: version}
|
||||
}
|
||||
|
||||
// Handler returns the HTTP handler.
|
||||
@ -363,3 +444,101 @@ func serveDocs(w http.ResponseWriter, r *http.Request) {
|
||||
</body>
|
||||
</html>`)
|
||||
}
|
||||
|
||||
// serviceEventToCanonical maps a SandboxService background-goroutine event
|
||||
// into the canonical events.Event taxonomy for unified publishing. Returns
|
||||
// false for events that should not be broadcast.
|
||||
func serviceEventToCanonical(e service.SandboxStateEvent) (events.Event, bool) {
|
||||
var (
|
||||
eventType string
|
||||
outcome events.Outcome
|
||||
metadata map[string]string
|
||||
)
|
||||
switch e.Event {
|
||||
case "sandbox.started":
|
||||
eventType = events.CapsuleCreate
|
||||
outcome = events.OutcomeSuccess
|
||||
case "sandbox.resumed":
|
||||
eventType = events.CapsuleResume
|
||||
outcome = events.OutcomeSuccess
|
||||
case "sandbox.paused":
|
||||
eventType = events.CapsulePause
|
||||
outcome = events.OutcomeSuccess
|
||||
case "sandbox.auto_paused":
|
||||
eventType = events.CapsulePause
|
||||
outcome = events.OutcomeSuccess
|
||||
metadata = map[string]string{"reason": "ttl_expired"}
|
||||
case "sandbox.stopped":
|
||||
eventType = events.CapsuleDestroy
|
||||
outcome = events.OutcomeSuccess
|
||||
case "sandbox.pause_failed":
|
||||
// reason must be non-empty or channels.isRedundantSystemFollowup
|
||||
// filters this system-actor event out of webhook delivery.
|
||||
eventType = events.CapsulePause
|
||||
outcome = events.OutcomeError
|
||||
metadata = map[string]string{"reason": "pause_failed"}
|
||||
case "sandbox.resume_failed":
|
||||
eventType = events.CapsuleResume
|
||||
outcome = events.OutcomeError
|
||||
metadata = map[string]string{"reason": "resume_failed"}
|
||||
case "sandbox.failed":
|
||||
// First-boot failure from the createInBackground goroutine. Without
|
||||
// this case the event falls through to default and is dropped — no
|
||||
// SSE push, no channel delivery, no DB reconciliation. reason must be
|
||||
// non-empty or channels.isRedundantSystemFollowup filters it out.
|
||||
eventType = events.CapsuleCreate
|
||||
outcome = events.OutcomeError
|
||||
metadata = map[string]string{"reason": "create_failed"}
|
||||
case "sandbox.snapshotted":
|
||||
// Completion of an async snapshot. The resource is the template name,
|
||||
// not the sandbox, so the dashboard's snapshot list refreshes.
|
||||
return events.Event{
|
||||
Event: events.SnapshotCreate,
|
||||
Outcome: events.OutcomeSuccess,
|
||||
Timestamp: events.Now(),
|
||||
TeamID: e.TeamID,
|
||||
Actor: events.SystemActor(),
|
||||
Resource: events.Resource{ID: e.Metadata["name"], Type: "snapshot"},
|
||||
}, true
|
||||
case "sandbox.snapshot_failed":
|
||||
return events.Event{
|
||||
Event: events.SnapshotCreate,
|
||||
Outcome: events.OutcomeError,
|
||||
Timestamp: events.Now(),
|
||||
TeamID: e.TeamID,
|
||||
Actor: events.SystemActor(),
|
||||
Resource: events.Resource{ID: e.Metadata["name"], Type: "snapshot"},
|
||||
Metadata: map[string]string{"reason": "snapshot_failed"},
|
||||
Error: e.Error,
|
||||
}, true
|
||||
case "sandbox.state_changed":
|
||||
// Transient badge transition with no terminal verb of its own. Carries
|
||||
// from/to in metadata; routed via Pub/Sub only by the caller.
|
||||
return events.Event{
|
||||
Event: events.CapsuleStateChanged,
|
||||
Timestamp: events.Now(),
|
||||
TeamID: e.TeamID,
|
||||
Actor: events.SystemActor(),
|
||||
Resource: events.Resource{ID: e.SandboxID, Type: "sandbox"},
|
||||
Metadata: e.Metadata,
|
||||
}, true
|
||||
default:
|
||||
return events.Event{}, false
|
||||
}
|
||||
if e.HostIP != "" {
|
||||
if metadata == nil {
|
||||
metadata = map[string]string{}
|
||||
}
|
||||
metadata["host_ip"] = e.HostIP
|
||||
}
|
||||
return events.Event{
|
||||
Event: eventType,
|
||||
Outcome: outcome,
|
||||
Timestamp: events.Now(),
|
||||
TeamID: e.TeamID,
|
||||
Actor: events.SystemActor(),
|
||||
Resource: events.Resource{ID: e.SandboxID, Type: "sandbox"},
|
||||
Metadata: metadata,
|
||||
Error: e.Error,
|
||||
}, true
|
||||
}
|
||||
|
||||
80
internal/api/sse_broker.go
Normal file
80
internal/api/sse_broker.go
Normal file
@ -0,0 +1,80 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// sseChannelBuffer is the capacity of each subscriber's message channel.
const sseChannelBuffer = 32
|
||||
|
||||
// sseMessage is a single event queued for delivery to an SSE subscriber.
type sseMessage struct {
	// EventType is the event name passed to Dispatch.
	EventType string
	// Data is the pre-encoded JSON body of the event.
	Data json.RawMessage
}
|
||||
|
||||
type sseSubscriber struct {
|
||||
teamID string
|
||||
isAdmin bool
|
||||
ch chan sseMessage
|
||||
}
|
||||
|
||||
// SSEBroker is an in-process fan-out hub that dispatches events to connected
|
||||
// SSE clients, filtering by team ownership.
|
||||
type SSEBroker struct {
|
||||
mu sync.RWMutex
|
||||
nextID atomic.Uint64
|
||||
subscribers map[uint64]*sseSubscriber
|
||||
}
|
||||
|
||||
// NewSSEBroker constructs a broker with no subscribers.
|
||||
func NewSSEBroker() *SSEBroker {
|
||||
return &SSEBroker{
|
||||
subscribers: make(map[uint64]*sseSubscriber),
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe registers a new SSE client. Returns a subscriber ID (for
|
||||
// Unsubscribe) and a receive-only channel that delivers filtered events.
|
||||
func (b *SSEBroker) Subscribe(teamID string, isAdmin bool) (uint64, <-chan sseMessage) {
|
||||
id := b.nextID.Add(1)
|
||||
ch := make(chan sseMessage, sseChannelBuffer)
|
||||
sub := &sseSubscriber{teamID: teamID, isAdmin: isAdmin, ch: ch}
|
||||
|
||||
b.mu.Lock()
|
||||
b.subscribers[id] = sub
|
||||
b.mu.Unlock()
|
||||
|
||||
slog.Debug("sse: client subscribed", "sub_id", id, "team_id", teamID, "admin", isAdmin)
|
||||
return id, ch
|
||||
}
|
||||
|
||||
// Unsubscribe removes a client. The handler loop exits via context cancellation;
|
||||
// the channel is not closed here to avoid send-on-closed-channel races with Dispatch.
|
||||
func (b *SSEBroker) Unsubscribe(id uint64) {
|
||||
b.mu.Lock()
|
||||
delete(b.subscribers, id)
|
||||
b.mu.Unlock()
|
||||
slog.Debug("sse: client unsubscribed", "sub_id", id)
|
||||
}
|
||||
|
||||
// Dispatch fans out an event to all matching subscribers. Admin subscribers
|
||||
// receive all events; team subscribers only receive events for their team.
|
||||
// Non-blocking: events are dropped for slow consumers.
|
||||
func (b *SSEBroker) Dispatch(eventType string, teamID string, data json.RawMessage) {
|
||||
b.mu.RLock()
|
||||
defer b.mu.RUnlock()
|
||||
|
||||
msg := sseMessage{EventType: eventType, Data: data}
|
||||
for id, sub := range b.subscribers {
|
||||
if !sub.isAdmin && sub.teamID != teamID {
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case sub.ch <- msg:
|
||||
default:
|
||||
slog.Warn("sse: dropped event for slow consumer", "sub_id", id, "event", eventType)
|
||||
}
|
||||
}
|
||||
}
|
||||
147
internal/api/sse_relay.go
Normal file
147
internal/api/sse_relay.go
Normal file
@ -0,0 +1,147 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/events"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
// ssePubSubChannel is the Redis Pub/Sub channel the SSE relay subscribes to.
const ssePubSubChannel = "wrenn:sse"
|
||||
|
||||
// sseEventPayload is the JSON envelope sent to SSE clients.
|
||||
type sseEventPayload struct {
|
||||
Event string `json:"event"`
|
||||
Outcome events.Outcome `json:"outcome,omitempty"`
|
||||
Timestamp string `json:"timestamp"`
|
||||
TeamID string `json:"team_id"`
|
||||
Actor events.Actor `json:"actor"`
|
||||
Resource events.Resource `json:"resource"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Sandbox *sandboxResponse `json:"sandbox,omitempty"`
|
||||
}
|
||||
|
||||
// SSERelay subscribes to the Redis Pub/Sub channel and dispatches hydrated
|
||||
// events to the in-process broker. One instance per CP process.
|
||||
type SSERelay struct {
|
||||
rdb *redis.Client
|
||||
db *db.Queries
|
||||
broker *SSEBroker
|
||||
}
|
||||
|
||||
// NewSSERelay constructs the relay.
|
||||
func NewSSERelay(rdb *redis.Client, queries *db.Queries, broker *SSEBroker) *SSERelay {
|
||||
return &SSERelay{rdb: rdb, db: queries, broker: broker}
|
||||
}
|
||||
|
||||
// Start launches the Pub/Sub subscription goroutine. Returns when ctx is cancelled.
|
||||
func (r *SSERelay) Start(ctx context.Context) {
|
||||
go r.run(ctx)
|
||||
}
|
||||
|
||||
func (r *SSERelay) run(ctx context.Context) {
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
r.subscribe(ctx)
|
||||
// Backoff before reconnecting.
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(2 * time.Second):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *SSERelay) subscribe(ctx context.Context) {
|
||||
pubsub := r.rdb.Subscribe(ctx, ssePubSubChannel)
|
||||
defer pubsub.Close()
|
||||
|
||||
ch := pubsub.Channel()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case msg, ok := <-ch:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
r.handleMessage(ctx, msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *SSERelay) handleMessage(ctx context.Context, msg *redis.Message) {
|
||||
var event events.Event
|
||||
if err := json.Unmarshal([]byte(msg.Payload), &event); err != nil {
|
||||
slog.Warn("sse relay: failed to unmarshal event", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
payload := sseEventPayload{
|
||||
Event: event.Event,
|
||||
Outcome: event.Outcome,
|
||||
Timestamp: event.Timestamp,
|
||||
TeamID: event.TeamID,
|
||||
Actor: event.Actor,
|
||||
Resource: event.Resource,
|
||||
Metadata: event.Metadata,
|
||||
Error: event.Error,
|
||||
}
|
||||
|
||||
// Hydrate sandbox state for capsule events.
|
||||
if isCapsuleEvent(event.Event) {
|
||||
sb, err := r.hydrateSandbox(ctx, event.Resource.ID)
|
||||
if err != nil {
|
||||
slog.Debug("sse relay: sandbox hydration failed (may be deleted)", "sandbox_id", event.Resource.ID, "error", err)
|
||||
} else {
|
||||
payload.Sandbox = sb
|
||||
}
|
||||
}
|
||||
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
slog.Warn("sse relay: failed to marshal payload", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
r.broker.Dispatch(event.Event, event.TeamID, data)
|
||||
}
|
||||
|
||||
func (r *SSERelay) hydrateSandbox(ctx context.Context, sandboxIDStr string) (*sandboxResponse, error) {
|
||||
queryCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(sandboxIDStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sb, err := r.db.GetSandbox(queryCtx, sandboxID)
|
||||
if err != nil {
|
||||
if err == pgx.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp := sandboxToResponse(sb)
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
func isCapsuleEvent(eventType string) bool {
|
||||
switch eventType {
|
||||
case events.CapsuleCreate, events.CapsulePause, events.CapsuleResume, events.CapsuleDestroy, events.CapsuleStateChanged:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
@ -80,8 +80,8 @@ func (r *LoopRegistry) Release(imagePath string) {
|
||||
|
||||
e.refcount--
|
||||
if e.refcount <= 0 {
|
||||
if err := losetupDetach(e.device); err != nil {
|
||||
slog.Warn("losetup detach failed", "device", e.device, "error", err)
|
||||
if err := losetupDetachRetry(e.device); err != nil {
|
||||
slog.Error("losetup detach failed, loop device leaked", "device", e.device, "image", imagePath, "error", err)
|
||||
}
|
||||
delete(r.entries, imagePath)
|
||||
slog.Info("loop device released", "image", imagePath, "device", e.device)
|
||||
@ -94,8 +94,8 @@ func (r *LoopRegistry) ReleaseAll() {
|
||||
defer r.mu.Unlock()
|
||||
|
||||
for path, e := range r.entries {
|
||||
if err := losetupDetach(e.device); err != nil {
|
||||
slog.Warn("losetup detach failed", "device", e.device, "error", err)
|
||||
if err := losetupDetachRetry(e.device); err != nil {
|
||||
slog.Error("losetup detach failed during shutdown", "device", e.device, "image", path, "error", err)
|
||||
}
|
||||
delete(r.entries, path)
|
||||
}
|
||||
@ -109,6 +109,31 @@ type SnapshotDevice struct {
|
||||
CowLoopDev string // loop device for the CoW file
|
||||
}
|
||||
|
||||
// attachCowAndCreate attaches a CoW file as a loop device, creates the
|
||||
// dm-snapshot target, and returns the assembled SnapshotDevice. On failure
|
||||
// it detaches the CoW loop device before returning.
|
||||
func attachCowAndCreate(name, originLoopDev, cowPath string, originSizeBytes int64) (*SnapshotDevice, error) {
|
||||
cowLoopDev, err := losetupCreateRW(cowPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("losetup cow: %w", err)
|
||||
}
|
||||
|
||||
sectors := originSizeBytes / 512
|
||||
if err := dmsetupCreate(name, originLoopDev, cowLoopDev, sectors); err != nil {
|
||||
if detachErr := losetupDetachRetry(cowLoopDev); detachErr != nil {
|
||||
slog.Error("cow losetup detach failed during cleanup, loop device leaked", "device", cowLoopDev, "error", detachErr)
|
||||
}
|
||||
return nil, fmt.Errorf("dmsetup create: %w", err)
|
||||
}
|
||||
|
||||
return &SnapshotDevice{
|
||||
Name: name,
|
||||
DevicePath: "/dev/mapper/" + name,
|
||||
CowPath: cowPath,
|
||||
CowLoopDev: cowLoopDev,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateSnapshot sets up a new dm-snapshot device.
|
||||
//
|
||||
// It creates a sparse CoW file, attaches it as a loop device, and creates
|
||||
@ -117,45 +142,24 @@ type SnapshotDevice struct {
|
||||
//
|
||||
// The origin loop device must already exist (from LoopRegistry.Acquire).
|
||||
func CreateSnapshot(name, originLoopDev, cowPath string, originSizeBytes, cowSizeBytes int64) (*SnapshotDevice, error) {
|
||||
// Create sparse CoW file. The logical size limits how many blocks can be
|
||||
// modified; because the file is sparse, only written blocks use real disk.
|
||||
if err := createSparseFile(cowPath, cowSizeBytes); err != nil {
|
||||
return nil, fmt.Errorf("create cow file: %w", err)
|
||||
}
|
||||
|
||||
cowLoopDev, err := losetupCreateRW(cowPath)
|
||||
dev, err := attachCowAndCreate(name, originLoopDev, cowPath, originSizeBytes)
|
||||
if err != nil {
|
||||
os.Remove(cowPath)
|
||||
return nil, fmt.Errorf("losetup cow: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// The dm-snapshot virtual device size must match the origin — the snapshot
|
||||
// target maps 1:1 onto origin sectors. The CoW file just needs enough
|
||||
// space to store all modified blocks (it's sparse, so 20GB costs nothing).
|
||||
sectors := originSizeBytes / 512
|
||||
if err := dmsetupCreate(name, originLoopDev, cowLoopDev, sectors); err != nil {
|
||||
if detachErr := losetupDetach(cowLoopDev); detachErr != nil {
|
||||
slog.Warn("cow losetup detach failed during cleanup", "device", cowLoopDev, "error", detachErr)
|
||||
}
|
||||
os.Remove(cowPath)
|
||||
return nil, fmt.Errorf("dmsetup create: %w", err)
|
||||
}
|
||||
|
||||
devPath := "/dev/mapper/" + name
|
||||
|
||||
slog.Info("dm-snapshot created",
|
||||
"name", name,
|
||||
"device", devPath,
|
||||
"device", dev.DevicePath,
|
||||
"origin", originLoopDev,
|
||||
"cow", cowPath,
|
||||
)
|
||||
|
||||
return &SnapshotDevice{
|
||||
Name: name,
|
||||
DevicePath: devPath,
|
||||
CowPath: cowPath,
|
||||
CowLoopDev: cowLoopDev,
|
||||
}, nil
|
||||
return dev, nil
|
||||
}
|
||||
|
||||
// RestoreSnapshot re-attaches a dm-snapshot from an existing persistent CoW file.
|
||||
@ -171,34 +175,19 @@ func RestoreSnapshot(ctx context.Context, name, originLoopDev, cowPath string, o
|
||||
}
|
||||
}
|
||||
|
||||
cowLoopDev, err := losetupCreateRW(cowPath)
|
||||
dev, err := attachCowAndCreate(name, originLoopDev, cowPath, originSizeBytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("losetup cow: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sectors := originSizeBytes / 512
|
||||
if err := dmsetupCreate(name, originLoopDev, cowLoopDev, sectors); err != nil {
|
||||
if detachErr := losetupDetach(cowLoopDev); detachErr != nil {
|
||||
slog.Warn("cow losetup detach failed during cleanup", "device", cowLoopDev, "error", detachErr)
|
||||
}
|
||||
return nil, fmt.Errorf("dmsetup create: %w", err)
|
||||
}
|
||||
|
||||
devPath := "/dev/mapper/" + name
|
||||
|
||||
slog.Info("dm-snapshot restored",
|
||||
"name", name,
|
||||
"device", devPath,
|
||||
"device", dev.DevicePath,
|
||||
"origin", originLoopDev,
|
||||
"cow", cowPath,
|
||||
)
|
||||
|
||||
return &SnapshotDevice{
|
||||
Name: name,
|
||||
DevicePath: devPath,
|
||||
CowPath: cowPath,
|
||||
CowLoopDev: cowLoopDev,
|
||||
}, nil
|
||||
return dev, nil
|
||||
}
|
||||
|
||||
// RemoveSnapshot tears down a dm-snapshot device and its CoW loop device.
|
||||
@ -208,8 +197,8 @@ func RemoveSnapshot(ctx context.Context, dev *SnapshotDevice) error {
|
||||
return fmt.Errorf("dmsetup remove %s: %w", dev.Name, err)
|
||||
}
|
||||
|
||||
if err := losetupDetach(dev.CowLoopDev); err != nil {
|
||||
slog.Warn("cow losetup detach failed", "device", dev.CowLoopDev, "error", err)
|
||||
if err := losetupDetachRetry(dev.CowLoopDev); err != nil {
|
||||
return fmt.Errorf("detach cow loop %s: %w", dev.CowLoopDev, err)
|
||||
}
|
||||
|
||||
slog.Info("dm-snapshot removed", "name", dev.Name)
|
||||
@ -272,6 +261,29 @@ func CleanupStaleDevices() {
|
||||
}
|
||||
}
|
||||
|
||||
// LogLoopState lists the currently attached loop devices via `losetup -l` and
// logs, at INFO, every entry whose line mentions /var/lib/wrenn/. When no
// such entry exists it logs a single line saying so. If losetup itself fails
// the error is logged at DEBUG and nothing else happens.
//
// It is purely diagnostic: intended to be called once at agent startup so
// loop attachments leaked by an earlier crash show up in the journal.
func LogLoopState() {
	raw, err := exec.Command("losetup", "-l", "--noheadings", "--output", "NAME,BACK-FILE").CombinedOutput()
	if err != nil {
		slog.Debug("losetup -l failed", "error", err)
		return
	}

	found := false
	for _, entry := range strings.Split(strings.TrimSpace(string(raw)), "\n") {
		if strings.Contains(entry, "/var/lib/wrenn/") {
			found = true
			slog.Info("pre-existing loop attachment", "entry", strings.TrimSpace(entry))
		}
	}

	if !found {
		slog.Info("no pre-existing wrenn loop attachments")
	}
}
|
||||
|
||||
// --- low-level helpers ---
|
||||
|
||||
// losetupCreate attaches a file as a read-only loop device.
|
||||
@ -297,6 +309,24 @@ func losetupDetach(dev string) error {
|
||||
return exec.Command("losetup", "-d", dev).Run()
|
||||
}
|
||||
|
||||
// losetupDetachRetry detaches a loop device with retries for transient
|
||||
// "device busy" errors (kernel may still hold references briefly after
|
||||
// dm-snapshot removal).
|
||||
func losetupDetachRetry(dev string) error {
|
||||
var lastErr error
|
||||
for attempt := range 5 {
|
||||
if attempt > 0 {
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
if err := losetupDetach(dev); err == nil {
|
||||
return nil
|
||||
} else {
|
||||
lastErr = err
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("after 5 attempts: %w", lastErr)
|
||||
}
|
||||
|
||||
// dmsetupCreate creates a dm-snapshot device with persistent metadata.
|
||||
func dmsetupCreate(name, originDev, cowDev string, sectors int64) error {
|
||||
// Table format: <start> <size> snapshot <origin> <cow> P <chunk_size>
|
||||
@ -316,7 +346,7 @@ func dmDeviceExists(name string) bool {
|
||||
|
||||
// dmsetupRemove removes a device-mapper device, retrying on transient
|
||||
// "device busy" errors that occur when the kernel hasn't fully released
|
||||
// the device after a Firecracker process exits.
|
||||
// the device after a VMM process exits.
|
||||
func dmsetupRemove(ctx context.Context, name string) error {
|
||||
var lastErr error
|
||||
for attempt := range 5 {
|
||||
@ -361,5 +391,9 @@ func createSparseFile(path string, sizeBytes int64) error {
|
||||
os.Remove(path)
|
||||
return err
|
||||
}
|
||||
return f.Close()
|
||||
if err := f.Close(); err != nil {
|
||||
os.Remove(path)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -10,9 +10,12 @@ import (
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
envdpb "git.omukk.dev/wrenn/wrenn/proto/envd/gen"
|
||||
"git.omukk.dev/wrenn/wrenn/proto/envd/gen/genconnect"
|
||||
)
|
||||
@ -78,16 +81,31 @@ type ExecResult struct {
|
||||
ExitCode int32
|
||||
}
|
||||
|
||||
// ExecOpts carries the optional settings accepted by Exec.
type ExecOpts struct {
	// Envs holds extra environment variables for the process.
	Envs map[string]string
	// Cwd is the working directory for the process.
	Cwd string
}
|
||||
|
||||
// Exec runs a command inside the sandbox and collects all stdout/stderr output.
|
||||
// It blocks until the command completes.
|
||||
func (c *Client) Exec(ctx context.Context, cmd string, args ...string) (*ExecResult, error) {
|
||||
func (c *Client) Exec(ctx context.Context, cmd string, args []string, opts *ExecOpts) (*ExecResult, error) {
|
||||
stdin := false
|
||||
proc := &envdpb.ProcessConfig{
|
||||
Cmd: cmd,
|
||||
Args: args,
|
||||
}
|
||||
if opts != nil {
|
||||
if len(opts.Envs) > 0 {
|
||||
proc.Envs = opts.Envs
|
||||
}
|
||||
if opts.Cwd != "" {
|
||||
proc.Cwd = &opts.Cwd
|
||||
}
|
||||
}
|
||||
req := connect.NewRequest(&envdpb.StartRequest{
|
||||
Process: &envdpb.ProcessConfig{
|
||||
Cmd: cmd,
|
||||
Args: args,
|
||||
},
|
||||
Stdin: &stdin,
|
||||
Process: proc,
|
||||
Stdin: &stdin,
|
||||
})
|
||||
|
||||
stream, err := c.process.Start(ctx, req)
|
||||
@ -294,7 +312,7 @@ func (c *Client) ReadFile(ctx context.Context, path string) ([]byte, error) {
|
||||
|
||||
// PrepareSnapshot calls envd's POST /snapshot/prepare endpoint, which stops
|
||||
// the port scanner/forwarder and marks active connections for post-restore
|
||||
// cleanup before Firecracker freezes vCPUs.
|
||||
// cleanup before the VMM freezes vCPUs.
|
||||
//
|
||||
// Best-effort: the caller should log a warning on error but not abort the pause.
|
||||
func (c *Client) PrepareSnapshot(ctx context.Context) error {
|
||||
@ -317,27 +335,135 @@ func (c *Client) PrepareSnapshot(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// PostInit calls envd's POST /init endpoint, which triggers a re-read of
|
||||
// Firecracker MMDS metadata. This updates WRENN_SANDBOX_ID, WRENN_TEMPLATE_ID
|
||||
// env vars and the corresponding files under /run/wrenn/ inside the guest.
|
||||
// Must be called after snapshot restore so envd picks up the new sandbox's metadata.
|
||||
// MemoryPreloadStatus is the decoded body of envd's /memory/preload
// response.
//
// State is one of "idle", "running", "done", "failed", "cancelled".
type MemoryPreloadStatus struct {
	State      string  `json:"state"`
	Regions    uint64  `json:"regions"`
	Pages      uint64  `json:"pages"`
	Bytes      uint64  `json:"bytes"`
	ElapsedSec float64 `json:"elapsed_sec"`
	Source     string  `json:"source"`
	Error      string  `json:"error,omitempty"`
}
|
||||
|
||||
// StartMemoryPreload posts to envd's /memory/preload to spawn a guest-side
|
||||
// loader that reads every physical RAM page. The request returns immediately
|
||||
// after the loader is queued — the actual materialisation runs in a detached
|
||||
// thread inside envd. Required after a snapshot restore with
|
||||
// memory_restore_mode=ondemand so the next ch.snapshot writes a
|
||||
// self-contained memory-ranges file.
|
||||
//
|
||||
// Use WaitMemoryPreload to block on completion or GetMemoryPreloadStatus to
|
||||
// query progress.
|
||||
func (c *Client) StartMemoryPreload(ctx context.Context) (MemoryPreloadStatus, error) {
|
||||
return c.memoryPreloadRequest(ctx, http.MethodPost)
|
||||
}
|
||||
|
||||
// GetMemoryPreloadStatus reads envd's /memory/preload status without
|
||||
// starting a new loader.
|
||||
func (c *Client) GetMemoryPreloadStatus(ctx context.Context) (MemoryPreloadStatus, error) {
|
||||
return c.memoryPreloadRequest(ctx, http.MethodGet)
|
||||
}
|
||||
|
||||
func (c *Client) memoryPreloadRequest(ctx context.Context, method string) (MemoryPreloadStatus, error) {
|
||||
var status MemoryPreloadStatus
|
||||
req, err := http.NewRequestWithContext(ctx, method, c.base+"/memory/preload", nil)
|
||||
if err != nil {
|
||||
return status, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return status, fmt.Errorf("memory preload %s: %w", method, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return status, fmt.Errorf("memory preload %s: status %d: %s", method, resp.StatusCode, string(body))
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&status); err != nil {
|
||||
return status, fmt.Errorf("memory preload %s: decode: %w", method, err)
|
||||
}
|
||||
return status, nil
|
||||
}
|
||||
|
||||
// WaitMemoryPreload polls envd until the loader is no longer running or ctx
|
||||
// is cancelled. Returns the final status. Polling interval is fixed at 1s —
|
||||
// the loader runs for many seconds to minutes, so finer polling wastes RPCs.
|
||||
func (c *Client) WaitMemoryPreload(ctx context.Context) (MemoryPreloadStatus, error) {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
status, err := c.GetMemoryPreloadStatus(ctx)
|
||||
if err != nil {
|
||||
return status, err
|
||||
}
|
||||
if status.State != "running" {
|
||||
return status, nil
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return status, ctx.Err()
|
||||
case <-ticker.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CancelMemoryPreload signals the in-guest memory preloader to stop early.
|
||||
// Used during teardown so a pause/destroy doesn't have to wait for a
|
||||
// multi-hundred-MiB read to finish.
|
||||
func (c *Client) CancelMemoryPreload(ctx context.Context) error {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.base+"/memory/preload/cancel", nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("preload cancel: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusNoContent {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("preload cancel: status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PostInit calls envd's POST /init endpoint to trigger post-boot or
|
||||
// post-restore initialization. sandbox_id and template_id are passed
|
||||
// so envd can set WRENN_SANDBOX_ID and WRENN_TEMPLATE_ID env vars.
|
||||
func (c *Client) PostInit(ctx context.Context) error {
|
||||
return c.PostInitWithDefaults(ctx, "", nil)
|
||||
return c.PostInitWithDefaults(ctx, "", nil, "", "")
|
||||
}
|
||||
|
||||
// PostInitWithDefaults calls envd's POST /init endpoint with optional default
|
||||
// user and environment variables. These are applied to envd's defaults so all
|
||||
// subsequent process executions use them.
|
||||
func (c *Client) PostInitWithDefaults(ctx context.Context, defaultUser string, envVars map[string]string) error {
|
||||
// user, environment variables, and sandbox metadata. These are applied to
|
||||
// envd's defaults so all subsequent process executions use them.
|
||||
//
|
||||
// timestamp and lifecycle_id are always populated: envd uses them to snap
|
||||
// the guest clock to the host's wall time and to detect post-resume calls
|
||||
// (which trigger port-forwarder restart + NFS remount).
|
||||
func (c *Client) PostInitWithDefaults(ctx context.Context, defaultUser string, envVars map[string]string, sandboxID, templateID string) error {
|
||||
payload := map[string]any{
|
||||
"timestamp": time.Now().UTC().Format(time.RFC3339Nano),
|
||||
"lifecycle_id": uuid.NewString(),
|
||||
}
|
||||
if defaultUser != "" {
|
||||
payload["defaultUser"] = defaultUser
|
||||
}
|
||||
if len(envVars) > 0 {
|
||||
payload["envVars"] = envVars
|
||||
}
|
||||
if sandboxID != "" {
|
||||
payload["sandbox_id"] = sandboxID
|
||||
}
|
||||
if templateID != "" {
|
||||
payload["template_id"] = templateID
|
||||
}
|
||||
|
||||
var body io.Reader
|
||||
if defaultUser != "" || len(envVars) > 0 {
|
||||
payload := make(map[string]any)
|
||||
if defaultUser != "" {
|
||||
payload["defaultUser"] = defaultUser
|
||||
}
|
||||
if len(envVars) > 0 {
|
||||
payload["envVars"] = envVars
|
||||
}
|
||||
if len(payload) > 0 {
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal init body: %w", err)
|
||||
|
||||
@ -59,6 +59,28 @@ func (c *Client) FetchVersion(ctx context.Context) (string, error) {
|
||||
return data.Version, nil
|
||||
}
|
||||
|
||||
// WaitUntilRPCReady polls envd's Connect RPC layer until it responds
|
||||
// successfully or the context is cancelled. This catches cases where envd's
|
||||
// HTTP health endpoint works but the Connect protocol layer is not yet
|
||||
// functional (e.g., after VM snapshot restore).
|
||||
func (c *Client) WaitUntilRPCReady(ctx context.Context) error {
|
||||
const retryInterval = 200 * time.Millisecond
|
||||
|
||||
ticker := time.NewTicker(retryInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("envd RPC not ready: %w", ctx.Err())
|
||||
case <-ticker.C:
|
||||
if _, err := c.ListProcesses(ctx); err == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// healthCheck sends a single GET /health request to envd.
|
||||
func (c *Client) healthCheck(ctx context.Context) error {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.healthURL, nil)
|
||||
|
||||
129
internal/hostagent/callback.go
Normal file
129
internal/hostagent/callback.go
Normal file
@ -0,0 +1,129 @@
|
||||
package hostagent
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CallbackEvent is the JSON body posted to the CP's sandbox event callback
// endpoint.
type CallbackEvent struct {
	Event     string `json:"event"`
	SandboxID string `json:"sandbox_id"`
	// HostID is overwritten by CallbackSender.Send with the sender's host ID.
	HostID string `json:"host_id"`
	// Timestamp is in Unix seconds; Send fills it in when it is zero.
	Timestamp int64 `json:"timestamp"`
}
|
||||
|
||||
// CallbackSender delivers sandbox lifecycle events to the CP over HTTP POST.
// It exists for events the agent produces on its own (auto-pause,
// auto-destroy), which the CP cannot observe through its own RPC goroutines.
type CallbackSender struct {
	cpURL    string       // CP base URL without trailing slash
	hostID   string       // stamped onto every outgoing event
	credFile string       // credentials file used when refreshing the JWT
	client   *http.Client // HTTP client used for the callbacks

	mu  sync.RWMutex // guards jwt
	jwt string       // token sent in the X-Host-Token header
}
|
||||
|
||||
// NewCallbackSender creates a callback sender.
|
||||
func NewCallbackSender(cpURL, credFile, hostID string) *CallbackSender {
|
||||
jwt := ""
|
||||
if tf, err := LoadTokenFile(credFile); err == nil {
|
||||
jwt = tf.JWT
|
||||
}
|
||||
return &CallbackSender{
|
||||
cpURL: strings.TrimRight(cpURL, "/"),
|
||||
hostID: hostID,
|
||||
credFile: credFile,
|
||||
client: &http.Client{Timeout: 10 * time.Second},
|
||||
jwt: jwt,
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateJWT refreshes the JWT used for callback authentication.
|
||||
// Called from the heartbeat's onCredsRefreshed hook.
|
||||
func (s *CallbackSender) UpdateJWT(jwt string) {
|
||||
s.mu.Lock()
|
||||
s.jwt = jwt
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
func (s *CallbackSender) getJWT() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.jwt
|
||||
}
|
||||
|
||||
// Send sends a callback event to the CP synchronously with retries.
|
||||
func (s *CallbackSender) Send(ctx context.Context, ev CallbackEvent) error {
|
||||
ev.HostID = s.hostID
|
||||
if ev.Timestamp == 0 {
|
||||
ev.Timestamp = time.Now().Unix()
|
||||
}
|
||||
|
||||
body, err := json.Marshal(ev)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal callback event: %w", err)
|
||||
}
|
||||
|
||||
url := s.cpURL + "/v1/hosts/sandbox-events"
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; attempt < 3; attempt++ {
|
||||
if attempt > 0 {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(time.Duration(attempt) * 500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return fmt.Errorf("create callback request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("X-Host-Token", s.getJWT())
|
||||
|
||||
resp, err := s.client.Do(req)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
|
||||
if newCreds, refreshErr := RefreshCredentials(ctx, s.cpURL, s.credFile); refreshErr == nil {
|
||||
s.UpdateJWT(newCreds.JWT)
|
||||
}
|
||||
lastErr = fmt.Errorf("callback auth failed: %d", resp.StatusCode)
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||
return nil
|
||||
}
|
||||
|
||||
lastErr = fmt.Errorf("callback failed: status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return fmt.Errorf("callback failed after 3 attempts: %w", lastErr)
|
||||
}
|
||||
|
||||
// SendAsync sends a callback event in a background goroutine.
|
||||
func (s *CallbackSender) SendAsync(ev CallbackEvent) {
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
if err := s.Send(ctx, ev); err != nil {
|
||||
slog.Warn("callback send failed (reconciler will catch it)", "event", ev.Event, "sandbox_id", ev.SandboxID, "error", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
31
internal/hostagent/callback_adapter.go
Normal file
31
internal/hostagent/callback_adapter.go
Normal file
@ -0,0 +1,31 @@
|
||||
package hostagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/sandbox"
|
||||
)
|
||||
|
||||
// callbackAdapter adapts CallbackSender to satisfy sandbox.EventSender.
|
||||
type callbackAdapter struct {
|
||||
sender *CallbackSender
|
||||
}
|
||||
|
||||
// NewEventSender wraps a CallbackSender as a sandbox.EventSender.
|
||||
func NewEventSender(sender *CallbackSender) sandbox.EventSender {
|
||||
return &callbackAdapter{sender: sender}
|
||||
}
|
||||
|
||||
func (a *callbackAdapter) SendAsync(event sandbox.LifecycleEvent) {
|
||||
a.sender.SendAsync(CallbackEvent{
|
||||
Event: event.Event,
|
||||
SandboxID: event.SandboxID,
|
||||
})
|
||||
}
|
||||
|
||||
func (a *callbackAdapter) Send(ctx context.Context, event sandbox.LifecycleEvent) error {
|
||||
return a.sender.Send(ctx, CallbackEvent{
|
||||
Event: event.Event,
|
||||
SandboxID: event.SandboxID,
|
||||
})
|
||||
}
|
||||
@ -2,13 +2,14 @@ package hostagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
@ -19,6 +20,7 @@ import (
|
||||
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
|
||||
"git.omukk.dev/wrenn/wrenn/proto/hostagent/gen/hostagentv1connect"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/envdclient"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/sandbox"
|
||||
)
|
||||
|
||||
@ -49,38 +51,48 @@ func parseUUIDString(s string) (pgtype.UUID, error) {
|
||||
return pgtype.UUID{Bytes: parsed, Valid: true}, nil
|
||||
}
|
||||
|
||||
// parseSandboxIDs parses the team+template UUID pair every snapshot-targeting
|
||||
// RPC handler receives, returning a CodeInvalidArgument Connect error on the
|
||||
// first failure so the caller can `return nil, err` directly.
|
||||
func parseSandboxIDs(teamIDStr, templateIDStr string) (teamID, templateID pgtype.UUID, err error) {
|
||||
teamID, err = parseUUIDString(teamIDStr)
|
||||
if err != nil {
|
||||
return pgtype.UUID{}, pgtype.UUID{}, connect.NewError(connect.CodeInvalidArgument, err)
|
||||
}
|
||||
templateID, err = parseUUIDString(templateIDStr)
|
||||
if err != nil {
|
||||
return pgtype.UUID{}, pgtype.UUID{}, connect.NewError(connect.CodeInvalidArgument, err)
|
||||
}
|
||||
return teamID, templateID, nil
|
||||
}
|
||||
|
||||
func (s *Server) CreateSandbox(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.CreateSandboxRequest],
|
||||
) (*connect.Response[pb.CreateSandboxResponse], error) {
|
||||
msg := req.Msg
|
||||
|
||||
teamID, err := parseUUIDString(msg.TeamId)
|
||||
teamID, templateID, err := parseSandboxIDs(msg.TeamId, msg.TemplateId)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, err)
|
||||
}
|
||||
templateID, err := parseUUIDString(msg.TemplateId)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sb, err := s.mgr.Create(ctx, msg.SandboxId, teamID, templateID, int(msg.Vcpus), int(msg.MemoryMb), int(msg.TimeoutSec), int(msg.DiskSizeMb))
|
||||
sb, diskSizeBytes, err := s.mgr.Create(ctx, msg.SandboxId, teamID, templateID,
|
||||
int(msg.Vcpus), int(msg.MemoryMb), int(msg.TimeoutSec), int(msg.DiskSizeMb),
|
||||
msg.DefaultUser, msg.DefaultEnv)
|
||||
if err != nil {
|
||||
if errors.Is(err, sandbox.ErrDraining) {
|
||||
return nil, connect.NewError(connect.CodeUnavailable, err)
|
||||
}
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("create sandbox: %w", err))
|
||||
}
|
||||
|
||||
// Apply template defaults (user, env vars) if provided.
|
||||
if msg.DefaultUser != "" || len(msg.DefaultEnv) > 0 {
|
||||
if err := s.mgr.SetDefaults(ctx, sb.ID, msg.DefaultUser, msg.DefaultEnv); err != nil {
|
||||
slog.Warn("failed to set sandbox defaults", "sandbox", sb.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
return connect.NewResponse(&pb.CreateSandboxResponse{
|
||||
SandboxId: sb.ID,
|
||||
Status: string(sb.Status),
|
||||
HostIp: sb.HostIP.String(),
|
||||
Metadata: sb.Metadata,
|
||||
SandboxId: sb.ID,
|
||||
Status: string(sb.Status),
|
||||
HostIp: sb.HostIP.String(),
|
||||
Metadata: sb.Metadata,
|
||||
DiskSizeMb: int32(diskSizeBytes / (1024 * 1024)),
|
||||
}), nil
|
||||
}
|
||||
|
||||
@ -89,7 +101,7 @@ func (s *Server) DestroySandbox(
|
||||
req *connect.Request[pb.DestroySandboxRequest],
|
||||
) (*connect.Response[pb.DestroySandboxResponse], error) {
|
||||
if err := s.mgr.Destroy(ctx, req.Msg.SandboxId); err != nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, err)
|
||||
return nil, mapSandboxError(err)
|
||||
}
|
||||
return connect.NewResponse(&pb.DestroySandboxResponse{}), nil
|
||||
}
|
||||
@ -99,7 +111,7 @@ func (s *Server) PauseSandbox(
|
||||
req *connect.Request[pb.PauseSandboxRequest],
|
||||
) (*connect.Response[pb.PauseSandboxResponse], error) {
|
||||
if err := s.mgr.Pause(ctx, req.Msg.SandboxId); err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, err)
|
||||
return nil, mapSandboxError(err)
|
||||
}
|
||||
return connect.NewResponse(&pb.PauseSandboxResponse{}), nil
|
||||
}
|
||||
@ -108,12 +120,10 @@ func (s *Server) ResumeSandbox(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.ResumeSandboxRequest],
|
||||
) (*connect.Response[pb.ResumeSandboxResponse], error) {
|
||||
msg := req.Msg
|
||||
sb, err := s.mgr.Resume(ctx, msg.SandboxId, int(msg.TimeoutSec), msg.KernelVersion, msg.DefaultUser, msg.DefaultEnv)
|
||||
sb, err := s.mgr.Resume(ctx, req.Msg.SandboxId, int(req.Msg.TimeoutSec), req.Msg.DefaultUser, req.Msg.KernelVersion, req.Msg.DefaultEnv)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, err)
|
||||
return nil, mapSandboxError(err)
|
||||
}
|
||||
|
||||
return connect.NewResponse(&pb.ResumeSandboxResponse{
|
||||
SandboxId: sb.ID,
|
||||
Status: string(sb.Status),
|
||||
@ -126,41 +136,30 @@ func (s *Server) CreateSnapshot(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.CreateSnapshotRequest],
|
||||
) (*connect.Response[pb.CreateSnapshotResponse], error) {
|
||||
msg := req.Msg
|
||||
teamID, err := parseUUIDString(msg.TeamId)
|
||||
teamID, templateID, err := parseSandboxIDs(req.Msg.TeamId, req.Msg.TemplateId)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, err)
|
||||
return nil, err
|
||||
}
|
||||
templateID, err := parseUUIDString(msg.TemplateId)
|
||||
size, err := s.mgr.CreateSnapshot(ctx, req.Msg.SandboxId, teamID, templateID, req.Msg.Name)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, err)
|
||||
}
|
||||
|
||||
sizeBytes, err := s.mgr.CreateSnapshot(ctx, msg.SandboxId, teamID, templateID)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("create snapshot: %w", err))
|
||||
return nil, mapSandboxError(err)
|
||||
}
|
||||
return connect.NewResponse(&pb.CreateSnapshotResponse{
|
||||
SizeBytes: sizeBytes,
|
||||
Name: req.Msg.Name,
|
||||
SizeBytes: size,
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (s *Server) DeleteSnapshot(
|
||||
ctx context.Context,
|
||||
_ context.Context,
|
||||
req *connect.Request[pb.DeleteSnapshotRequest],
|
||||
) (*connect.Response[pb.DeleteSnapshotResponse], error) {
|
||||
msg := req.Msg
|
||||
teamID, err := parseUUIDString(msg.TeamId)
|
||||
teamID, templateID, err := parseSandboxIDs(req.Msg.TeamId, req.Msg.TemplateId)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, err)
|
||||
return nil, err
|
||||
}
|
||||
templateID, err := parseUUIDString(msg.TemplateId)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, err)
|
||||
}
|
||||
|
||||
if err := s.mgr.DeleteSnapshot(teamID, templateID); err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("delete snapshot: %w", err))
|
||||
return nil, mapSandboxError(err)
|
||||
}
|
||||
return connect.NewResponse(&pb.DeleteSnapshotResponse{}), nil
|
||||
}
|
||||
@ -169,22 +168,54 @@ func (s *Server) FlattenRootfs(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.FlattenRootfsRequest],
|
||||
) (*connect.Response[pb.FlattenRootfsResponse], error) {
|
||||
msg := req.Msg
|
||||
teamID, err := parseUUIDString(msg.TeamId)
|
||||
teamID, templateID, err := parseSandboxIDs(req.Msg.TeamId, req.Msg.TemplateId)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, err)
|
||||
return nil, err
|
||||
}
|
||||
templateID, err := parseUUIDString(msg.TemplateId)
|
||||
size, err := s.mgr.FlattenRootfs(ctx, req.Msg.SandboxId, teamID, templateID)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, err)
|
||||
}
|
||||
|
||||
sizeBytes, err := s.mgr.FlattenRootfs(ctx, msg.SandboxId, teamID, templateID)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("flatten rootfs: %w", err))
|
||||
return nil, mapSandboxError(err)
|
||||
}
|
||||
return connect.NewResponse(&pb.FlattenRootfsResponse{
|
||||
SizeBytes: sizeBytes,
|
||||
SizeBytes: size,
|
||||
}), nil
|
||||
}
|
||||
|
||||
// mapSandboxError translates sandbox.Manager errors to Connect error codes
|
||||
// via sentinel errors (errors.Is). Adding a new precondition sentinel in the
|
||||
// sandbox package only requires extending this switch — no string sniffing.
|
||||
func mapSandboxError(err error) error {
|
||||
switch {
|
||||
case errors.Is(err, sandbox.ErrNotFound):
|
||||
return connect.NewError(connect.CodeNotFound, err)
|
||||
case errors.Is(err, sandbox.ErrNotRunning), errors.Is(err, sandbox.ErrNotPaused):
|
||||
return connect.NewError(connect.CodeFailedPrecondition, err)
|
||||
case errors.Is(err, sandbox.ErrDraining):
|
||||
return connect.NewError(connect.CodeUnavailable, err)
|
||||
case errors.Is(err, sandbox.ErrInvalidRange):
|
||||
return connect.NewError(connect.CodeInvalidArgument, err)
|
||||
default:
|
||||
return connect.NewError(connect.CodeInternal, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) GetTemplateSize(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.GetTemplateSizeRequest],
|
||||
) (*connect.Response[pb.GetTemplateSizeResponse], error) {
|
||||
teamID, templateID, err := parseSandboxIDs(req.Msg.TeamId, req.Msg.TemplateId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
size, err := s.mgr.TemplateRootfsSize(teamID, templateID)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, connect.NewError(connect.CodeNotFound, err)
|
||||
}
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("get template size: %w", err))
|
||||
}
|
||||
return connect.NewResponse(&pb.GetTemplateSizeResponse{
|
||||
SizeBytes: size,
|
||||
}), nil
|
||||
}
|
||||
|
||||
@ -193,7 +224,7 @@ func (s *Server) PingSandbox(
|
||||
req *connect.Request[pb.PingSandboxRequest],
|
||||
) (*connect.Response[pb.PingSandboxResponse], error) {
|
||||
if err := s.mgr.Ping(req.Msg.SandboxId); err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
if errors.Is(err, sandbox.ErrNotFound) {
|
||||
return nil, connect.NewError(connect.CodeNotFound, err)
|
||||
}
|
||||
return nil, connect.NewError(connect.CodeFailedPrecondition, err)
|
||||
@ -215,7 +246,12 @@ func (s *Server) Exec(
|
||||
execCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
result, err := s.mgr.Exec(execCtx, msg.SandboxId, msg.Cmd, msg.Args...)
|
||||
var opts *envdclient.ExecOpts
|
||||
if len(msg.Envs) > 0 || msg.Cwd != "" {
|
||||
opts = &envdclient.ExecOpts{Envs: msg.Envs, Cwd: msg.Cwd}
|
||||
}
|
||||
|
||||
result, err := s.mgr.Exec(execCtx, msg.SandboxId, msg.Cmd, msg.Args, opts)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("exec: %w", err))
|
||||
}
|
||||
@ -227,6 +263,17 @@ func (s *Server) Exec(
|
||||
}), nil
|
||||
}
|
||||
|
||||
// envdErr propagates an error from the envd client, preserving its Connect
|
||||
// error code (e.g. AlreadyExists, NotFound) so the control plane maps it to
|
||||
// the correct HTTP status. Non-Connect errors fall back to CodeInternal.
|
||||
func envdErr(action string, err error) error {
|
||||
code := connect.CodeOf(err)
|
||||
if code == connect.CodeUnknown {
|
||||
code = connect.CodeInternal
|
||||
}
|
||||
return connect.NewError(code, fmt.Errorf("%s: %w", action, err))
|
||||
}
|
||||
|
||||
func (s *Server) WriteFile(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.WriteFileRequest],
|
||||
@ -239,7 +286,7 @@ func (s *Server) WriteFile(
|
||||
}
|
||||
|
||||
if err := client.WriteFile(ctx, msg.Path, msg.Content); err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("write file: %w", err))
|
||||
return nil, envdErr("write file", err)
|
||||
}
|
||||
|
||||
return connect.NewResponse(&pb.WriteFileResponse{}), nil
|
||||
@ -258,7 +305,7 @@ func (s *Server) ReadFile(
|
||||
|
||||
content, err := client.ReadFile(ctx, msg.Path)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("read file: %w", err))
|
||||
return nil, envdErr("read file", err)
|
||||
}
|
||||
|
||||
return connect.NewResponse(&pb.ReadFileResponse{Content: content}), nil
|
||||
@ -277,7 +324,7 @@ func (s *Server) ListDir(
|
||||
|
||||
resp, err := client.ListDir(ctx, msg.Path, msg.Depth)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("list dir: %w", err))
|
||||
return nil, envdErr("list dir", err)
|
||||
}
|
||||
|
||||
entries := make([]*pb.FileEntry, 0, len(resp.Entries))
|
||||
@ -301,7 +348,7 @@ func (s *Server) MakeDir(
|
||||
|
||||
resp, err := client.MakeDir(ctx, msg.Path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("make dir: %w", err)
|
||||
return nil, envdErr("make dir", err)
|
||||
}
|
||||
|
||||
return connect.NewResponse(&pb.MakeDirResponse{
|
||||
@ -321,7 +368,7 @@ func (s *Server) RemovePath(
|
||||
}
|
||||
|
||||
if err := client.Remove(ctx, msg.Path); err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("remove: %w", err))
|
||||
return nil, envdErr("remove", err)
|
||||
}
|
||||
|
||||
return connect.NewResponse(&pb.RemovePathResponse{}), nil
|
||||
@ -373,6 +420,8 @@ func (s *Server) ExecStream(
|
||||
Error: ev.Error,
|
||||
},
|
||||
}
|
||||
default:
|
||||
continue
|
||||
}
|
||||
if err := stream.Send(&resp); err != nil {
|
||||
return err
|
||||
@ -548,6 +597,14 @@ func (s *Server) ListSandboxes(
|
||||
|
||||
infos := make([]*pb.SandboxInfo, len(sandboxes))
|
||||
for i, sb := range sandboxes {
|
||||
// Paused / restored-paused sandboxes have no active network slot, so
|
||||
// HostIP is nil — net.IP(nil).String() returns the literal "<nil>"
|
||||
// which would leak into DB host_ip columns and SDK responses. Emit
|
||||
// empty string instead.
|
||||
hostIP := ""
|
||||
if sb.HostIP != nil {
|
||||
hostIP = sb.HostIP.String()
|
||||
}
|
||||
infos[i] = &pb.SandboxInfo{
|
||||
SandboxId: sb.ID,
|
||||
Status: string(sb.Status),
|
||||
@ -555,7 +612,7 @@ func (s *Server) ListSandboxes(
|
||||
TemplateId: uuid.UUID(sb.TemplateID).String(),
|
||||
Vcpus: int32(sb.VCPUs),
|
||||
MemoryMb: int32(sb.MemoryMB),
|
||||
HostIp: sb.HostIP.String(),
|
||||
HostIp: hostIP,
|
||||
CreatedAtUnix: sb.CreatedAt.Unix(),
|
||||
LastActiveAtUnix: sb.LastActiveAt.Unix(),
|
||||
TimeoutSec: int32(sb.TimeoutSec),
|
||||
@ -588,13 +645,7 @@ func (s *Server) GetSandboxMetrics(
|
||||
|
||||
points, err := s.mgr.GetMetrics(msg.SandboxId, msg.Range)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
return nil, connect.NewError(connect.CodeNotFound, err)
|
||||
}
|
||||
if strings.Contains(err.Error(), "invalid range") {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, err)
|
||||
}
|
||||
return nil, connect.NewError(connect.CodeInternal, err)
|
||||
return nil, mapSandboxError(err)
|
||||
}
|
||||
|
||||
return connect.NewResponse(&pb.GetSandboxMetricsResponse{Points: metricPointsToPB(points)}), nil
|
||||
@ -606,10 +657,7 @@ func (s *Server) FlushSandboxMetrics(
|
||||
) (*connect.Response[pb.FlushSandboxMetricsResponse], error) {
|
||||
pts10m, pts2h, pts24h, err := s.mgr.FlushMetrics(req.Msg.SandboxId)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
return nil, connect.NewError(connect.CodeNotFound, err)
|
||||
}
|
||||
return nil, connect.NewError(connect.CodeInternal, err)
|
||||
return nil, mapSandboxError(err)
|
||||
}
|
||||
|
||||
return connect.NewResponse(&pb.FlushSandboxMetricsResponse{
|
||||
@ -759,7 +807,7 @@ func (s *Server) StartBackground(
|
||||
|
||||
pid, err := s.mgr.StartBackground(ctx, msg.SandboxId, msg.Tag, msg.Cmd, msg.Args, msg.Envs, msg.Cwd)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
if errors.Is(err, sandbox.ErrNotFound) {
|
||||
return nil, connect.NewError(connect.CodeNotFound, err)
|
||||
}
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("start background: %w", err))
|
||||
@ -777,7 +825,7 @@ func (s *Server) ListProcesses(
|
||||
) (*connect.Response[pb.ListProcessesResponse], error) {
|
||||
procs, err := s.mgr.ListProcesses(ctx, req.Msg.SandboxId)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
if errors.Is(err, sandbox.ErrNotFound) {
|
||||
return nil, connect.NewError(connect.CodeNotFound, err)
|
||||
}
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("list processes: %w", err))
|
||||
@ -828,7 +876,7 @@ func (s *Server) KillProcess(
|
||||
}
|
||||
|
||||
if err := s.mgr.KillProcess(ctx, msg.SandboxId, pid, tag, signal); err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
if errors.Is(err, sandbox.ErrNotFound) {
|
||||
return nil, connect.NewError(connect.CodeNotFound, err)
|
||||
}
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("kill process: %w", err))
|
||||
@ -857,7 +905,7 @@ func (s *Server) ConnectProcess(
|
||||
|
||||
events, err := s.mgr.ConnectProcess(ctx, msg.SandboxId, pid, tag)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
if errors.Is(err, sandbox.ErrNotFound) {
|
||||
return connect.NewError(connect.CodeNotFound, err)
|
||||
}
|
||||
return connect.NewError(connect.CodeInternal, fmt.Errorf("connect process: %w", err))
|
||||
@ -889,6 +937,8 @@ func (s *Server) ConnectProcess(
|
||||
Error: ev.Error,
|
||||
},
|
||||
}
|
||||
default:
|
||||
continue
|
||||
}
|
||||
if err := stream.Send(&resp); err != nil {
|
||||
return err
|
||||
|
||||
@ -6,26 +6,28 @@ import (
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
// IsMinimal reports whether the given team and template IDs represent the
|
||||
// built-in "minimal" template (both all-zeros).
|
||||
func IsMinimal(teamID, templateID pgtype.UUID) bool {
|
||||
return teamID.Bytes == id.PlatformTeamID.Bytes && templateID.Bytes == id.MinimalTemplateID.Bytes
|
||||
func timeNowNano() int64 { return time.Now().UnixNano() }
|
||||
|
||||
// IsSystemTemplate reports whether the given team and template IDs represent a
|
||||
// built-in system base template (minimal-ubuntu / -alpine / -arch / -fedora):
|
||||
// platform-owned with a template ID in the reserved range. System templates are
|
||||
// protected from deletion.
|
||||
func IsSystemTemplate(teamID, templateID pgtype.UUID) bool {
|
||||
return teamID.Bytes == id.PlatformTeamID.Bytes && id.IsReservedTemplateID(templateID)
|
||||
}
|
||||
|
||||
// TemplateDir returns the on-disk directory for a template.
|
||||
// TemplateDir returns the on-disk directory for a template. Every template —
|
||||
// including the built-in system base templates — lives under the teams tree:
|
||||
//
|
||||
// minimal (zeros, zeros): {wrennDir}/images/minimal
|
||||
// all others: {wrennDir}/images/teams/{base36(teamID)}/{base36(templateID)}
|
||||
// {wrennDir}/images/teams/{base36(teamID)}/{base36(templateID)}
|
||||
func TemplateDir(wrennDir string, teamID, templateID pgtype.UUID) string {
|
||||
if IsMinimal(teamID, templateID) {
|
||||
return filepath.Join(wrennDir, "images", "minimal")
|
||||
}
|
||||
return filepath.Join(wrennDir, "images", "teams",
|
||||
id.UUIDToBase36(teamID.Bytes),
|
||||
id.UUIDToBase36(templateID.Bytes))
|
||||
@ -36,17 +38,64 @@ func TemplateRootfs(wrennDir string, teamID, templateID pgtype.UUID) string {
|
||||
return filepath.Join(TemplateDir(wrennDir, teamID, templateID), "rootfs.ext4")
|
||||
}
|
||||
|
||||
// PauseSnapshotDir returns the directory for a paused sandbox's snapshot files.
|
||||
func PauseSnapshotDir(wrennDir, sandboxID string) string {
|
||||
return filepath.Join(wrennDir, "snapshots", sandboxID)
|
||||
// IsSnapshotTemplate reports whether dir looks like a snapshot template, as
// opposed to a base/disk-only template. Callers use the answer to choose
// between launching via Cloud Hypervisor restore and a fresh boot.
//
// The check is purely file-based: state.json, config.json and rootfs.ext4
// must all be stat-able inside dir. The memory-ranges file is not inspected
// here. Any stat failure (not only "does not exist") yields false.
func IsSnapshotTemplate(dir string) bool {
	required := [...]string{"state.json", "config.json", "rootfs.ext4"}

	for i := 0; i < len(required); i++ {
		_, statErr := os.Stat(filepath.Join(dir, required[i]))
		if statErr != nil {
			return false
		}
	}
	return true
}
|
||||
|
||||
// SandboxesDir returns the directory for running sandbox CoW files.
|
||||
// SandboxCowName is the filename for a sandbox's CoW rootfs diff, kept inside
|
||||
// the per-sandbox directory alongside any pause snapshot files.
|
||||
const SandboxCowName = "rootfs.cow"
|
||||
|
||||
// SandboxDir builds the path of the directory dedicated to one sandbox,
// located under {wrennDir}/sandboxes. That directory contains the sandbox's
// CoW file and, while the sandbox is paused, its snapshot files.
//
// On-disk layout:
//
//	{wrennDir}/sandboxes/{id}/rootfs.cow   CoW file (persistent across pause/resume)
//	{wrennDir}/sandboxes/{id}/             paused snapshot (config.json, state.json, memory-ranges, wrenn-snapshot.json)
//	{wrennDir}/sandboxes/{id}.staging-*/   in-flight Pause writes (cleaned up by swapDir or startup GC)
//	{wrennDir}/sandboxes/{id}.trash-*/     mid-swap previous generation (cleaned up by swapDir or startup GC)
func SandboxDir(wrennDir, sandboxID string) string {
	sandboxesRoot := filepath.Join(wrennDir, "sandboxes")
	return filepath.Join(sandboxesRoot, sandboxID)
}
|
||||
|
||||
// SandboxCowPath returns the path to a sandbox's CoW rootfs diff file.
|
||||
func SandboxCowPath(wrennDir, sandboxID string) string {
|
||||
return filepath.Join(SandboxDir(wrennDir, sandboxID), SandboxCowName)
|
||||
}
|
||||
|
||||
// PauseSnapshotDir returns the directory for a paused sandbox's snapshot files.
|
||||
// Same path as SandboxDir — pause snapshot files live alongside the CoW.
|
||||
func PauseSnapshotDir(wrennDir, sandboxID string) string {
|
||||
return SandboxDir(wrennDir, sandboxID)
|
||||
}
|
||||
|
||||
// PauseStagingDir returns a fresh staging directory for an in-flight Pause.
|
||||
// Each call returns a unique path (timestamped) so concurrent retries do not
|
||||
// collide.
|
||||
func PauseStagingDir(wrennDir, sandboxID string) string {
|
||||
return filepath.Join(wrennDir, "sandboxes",
|
||||
fmt.Sprintf("%s.staging-%d", sandboxID, timeNowNano()))
|
||||
}
|
||||
|
||||
// SandboxesDir gives the parent directory ({wrennDir}/sandboxes) under which
// the CoW files of running sandboxes and the snapshot directories of paused
// sandboxes are kept.
func SandboxesDir(wrennDir string) string {
	const leaf = "sandboxes"
	return filepath.Join(wrennDir, leaf)
}
|
||||
|
||||
// KernelPath returns the path to the Firecracker kernel.
|
||||
// KernelPath gives the location of the VM kernel image:
// {wrennDir}/kernels/vmlinux.
func KernelPath(wrennDir string) string {
	kernelsDir := filepath.Join(wrennDir, "kernels")
	return filepath.Join(kernelsDir, "vmlinux")
}
|
||||
|
||||
@ -9,7 +9,7 @@ import (
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
func TestIsMinimal(t *testing.T) {
|
||||
func TestIsSystemTemplate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
teamID pgtype.UUID
|
||||
@ -17,35 +17,41 @@ func TestIsMinimal(t *testing.T) {
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "both zeros",
|
||||
name: "ubuntu (zeros, zeros)",
|
||||
teamID: id.PlatformTeamID,
|
||||
templateID: id.MinimalTemplateID,
|
||||
templateID: id.UbuntuTemplateID,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "non-zero team",
|
||||
teamID: pgtype.UUID{Bytes: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, Valid: true},
|
||||
templateID: id.MinimalTemplateID,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "non-zero template",
|
||||
name: "fedora (platform, id 3)",
|
||||
teamID: id.PlatformTeamID,
|
||||
templateID: pgtype.UUID{Bytes: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, Valid: true},
|
||||
templateID: id.FedoraTemplateID,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "platform, max reserved id",
|
||||
teamID: id.PlatformTeamID,
|
||||
templateID: pgtype.UUID{Bytes: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x04, 0x00}, Valid: true}, // 1024
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "platform, above reserved range",
|
||||
teamID: id.PlatformTeamID,
|
||||
templateID: pgtype.UUID{Bytes: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x04, 0x01}, Valid: true}, // 1025
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "both non-zero",
|
||||
teamID: pgtype.UUID{Bytes: [16]byte{1}, Valid: true},
|
||||
templateID: pgtype.UUID{Bytes: [16]byte{2}, Valid: true},
|
||||
name: "non-platform team, reserved id",
|
||||
teamID: pgtype.UUID{Bytes: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, Valid: true},
|
||||
templateID: id.UbuntuTemplateID,
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := IsMinimal(tt.teamID, tt.templateID); got != tt.want {
|
||||
t.Errorf("IsMinimal() = %v, want %v", got, tt.want)
|
||||
if got := IsSystemTemplate(tt.teamID, tt.templateID); got != tt.want {
|
||||
t.Errorf("IsSystemTemplate() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
@ -54,9 +60,11 @@ func TestIsMinimal(t *testing.T) {
|
||||
func TestTemplateDir(t *testing.T) {
|
||||
wrennDir := "/var/lib/wrenn"
|
||||
|
||||
t.Run("minimal", func(t *testing.T) {
|
||||
got := TemplateDir(wrennDir, id.PlatformTeamID, id.MinimalTemplateID)
|
||||
want := filepath.Join(wrennDir, "images", "minimal")
|
||||
t.Run("system base template (ubuntu) lives under teams", func(t *testing.T) {
|
||||
got := TemplateDir(wrennDir, id.PlatformTeamID, id.UbuntuTemplateID)
|
||||
want := filepath.Join(wrennDir, "images", "teams",
|
||||
id.UUIDToBase36(id.PlatformTeamID.Bytes),
|
||||
id.UUIDToBase36(id.UbuntuTemplateID.Bytes))
|
||||
if got != want {
|
||||
t.Errorf("TemplateDir() = %q, want %q", got, want)
|
||||
}
|
||||
@ -88,8 +96,11 @@ func TestTemplateDir(t *testing.T) {
|
||||
|
||||
func TestTemplateRootfs(t *testing.T) {
|
||||
wrennDir := "/var/lib/wrenn"
|
||||
got := TemplateRootfs(wrennDir, id.PlatformTeamID, id.MinimalTemplateID)
|
||||
want := filepath.Join(wrennDir, "images", "minimal", "rootfs.ext4")
|
||||
got := TemplateRootfs(wrennDir, id.PlatformTeamID, id.UbuntuTemplateID)
|
||||
want := filepath.Join(wrennDir, "images", "teams",
|
||||
id.UUIDToBase36(id.PlatformTeamID.Bytes),
|
||||
id.UUIDToBase36(id.UbuntuTemplateID.Bytes),
|
||||
"rootfs.ext4")
|
||||
if got != want {
|
||||
t.Errorf("TemplateRootfs() = %q, want %q", got, want)
|
||||
}
|
||||
@ -97,12 +108,20 @@ func TestTemplateRootfs(t *testing.T) {
|
||||
|
||||
func TestPauseSnapshotDir(t *testing.T) {
|
||||
got := PauseSnapshotDir("/var/lib/wrenn", "cl-abc123")
|
||||
want := "/var/lib/wrenn/snapshots/cl-abc123"
|
||||
want := "/var/lib/wrenn/sandboxes/cl-abc123"
|
||||
if got != want {
|
||||
t.Errorf("PauseSnapshotDir() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPauseStagingDir(t *testing.T) {
|
||||
got := PauseStagingDir("/var/lib/wrenn", "cl-abc123")
|
||||
prefix := "/var/lib/wrenn/sandboxes/cl-abc123.staging-"
|
||||
if len(got) <= len(prefix) || got[:len(prefix)] != prefix {
|
||||
t.Errorf("PauseStagingDir() = %q, want prefix %q", got, prefix)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSandboxesDir(t *testing.T) {
|
||||
got := SandboxesDir("/var/lib/wrenn")
|
||||
want := "/var/lib/wrenn/sandboxes"
|
||||
|
||||
@ -9,12 +9,13 @@ import (
|
||||
type SandboxStatus string
|
||||
|
||||
const (
|
||||
StatusPending SandboxStatus = "pending"
|
||||
StatusRunning SandboxStatus = "running"
|
||||
StatusPausing SandboxStatus = "pausing"
|
||||
StatusPaused SandboxStatus = "paused"
|
||||
StatusStopped SandboxStatus = "stopped"
|
||||
StatusError SandboxStatus = "error"
|
||||
StatusPending SandboxStatus = "pending"
|
||||
StatusRunning SandboxStatus = "running"
|
||||
StatusPausing SandboxStatus = "pausing"
|
||||
StatusPaused SandboxStatus = "paused"
|
||||
StatusSnapshotting SandboxStatus = "snapshotting"
|
||||
StatusStopped SandboxStatus = "stopped"
|
||||
StatusError SandboxStatus = "error"
|
||||
)
|
||||
|
||||
// Sandbox holds all state for a running sandbox on this host.
|
||||
|
||||
@ -39,3 +39,19 @@ func (a *SlotAllocator) Release(index int) {
|
||||
defer a.mu.Unlock()
|
||||
delete(a.inUse, index)
|
||||
}
|
||||
|
||||
// Reserve marks a specific slot index as in use. Returns an error if the
|
||||
// index is out of range or already taken. Used on resume to re-acquire the
|
||||
// slot a sandbox previously held so its host-reachable IP stays stable.
|
||||
func (a *SlotAllocator) Reserve(index int) error {
|
||||
if index < 1 || index > 32767 {
|
||||
return fmt.Errorf("slot index out of range: %d", index)
|
||||
}
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
if a.inUse[index] {
|
||||
return fmt.Errorf("slot %d already in use", index)
|
||||
}
|
||||
a.inUse[index] = true
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -42,6 +42,43 @@ func CleanupStaleNamespaces() {
|
||||
|
||||
// Clean up any stale wrenn iptables rules referencing old veth interfaces.
|
||||
cleanupStaleIptablesRules()
|
||||
|
||||
// Flush any orphan conntrack rows for sandbox host-IPs. After a wedged
|
||||
// destroy the netfilter conntrack table can retain DNAT/SNAT entries
|
||||
// pointing at vanished interfaces, which makes new flows to recycled
|
||||
// slot IPs misroute. Best-effort; missing conntrack binary is OK.
|
||||
flushStaleConntrack()
|
||||
}
|
||||
|
||||
// flushStaleConntrack deletes conntrack rows whose source or destination lies
// in the sandbox host IP range (10.11.0.0/16) or the namespace veth range
// (10.12.0.0/16) by invoking `conntrack -D` once per (range, direction) pair.
//
// The operation is best-effort: it is skipped when the conntrack binary is
// not on PATH, and a failing invocation is only logged at debug level
// (conntrack -D exits 1 when nothing matches, which is not a problem here).
// An info line is logged when at least one invocation removed rows; the
// logged number is the count of such invocations, not the count of rows.
func flushStaleConntrack() {
	if _, err := exec.LookPath("conntrack"); err != nil {
		slog.Debug("conntrack binary not found, skipping flush")
		return
	}

	flushed := 0
	for _, cidr := range []string{"10.11.0.0/16", "10.12.0.0/16"} {
		for _, dir := range []string{"--src", "--dst"} {
			out, err := exec.Command("conntrack", "-D", dir, cidr).CombinedOutput()
			if err != nil {
				// conntrack -D exits 1 when no entries match; not an
				// error from our perspective.
				slog.Debug("conntrack flush", "cidr", cidr, "dir", dir, "error", err)
				continue
			}
			if conntrackDeletedAny(string(out)) {
				flushed++
			}
		}
	}

	if flushed > 0 {
		slog.Info("flushed stale conntrack entries", "matched_filters", flushed)
	}
}

// conntrackDeletedAny reports whether conntrack's summary line
// ("... N flow entries have been deleted.") in output states N > 0.
//
// The count is read digit by digit rather than by searching for the
// substring "0 flow entries", which would also match 10, 20, 100, ... and
// wrongly treat those successful flushes as empty.
func conntrackDeletedAny(output string) bool {
	const summary = " flow entries have been deleted"

	end := strings.Index(output, summary)
	if end < 0 {
		return false
	}

	// Walk backwards over the decimal count that precedes the summary text.
	start := end
	for start > 0 && output[start-1] >= '0' && output[start-1] <= '9' {
		start--
	}

	// A count made only of zeros (or no count at all) means nothing was removed.
	return strings.TrimLeft(output[start:end], "0") != ""
}
|
||||
|
||||
// cleanupStaleIptablesRules removes host iptables rules that reference
|
||||
@ -176,7 +213,7 @@ func NewSlot(index int) *Slot {
|
||||
// CreateNetwork sets up the full network topology for a sandbox:
|
||||
// - Named network namespace
|
||||
// - Veth pair bridging host and namespace
|
||||
// - TAP device inside namespace for Firecracker
|
||||
// - TAP device inside namespace for Cloud Hypervisor
|
||||
// - Routes and NAT rules for connectivity
|
||||
//
|
||||
// On error, all partially created resources are rolled back.
|
||||
@ -430,6 +467,9 @@ func CreateNetwork(slot *Slot) error {
|
||||
rollback()
|
||||
return fmt.Errorf("add masquerade rule: %w", err)
|
||||
}
|
||||
rollbacks = append(rollbacks, func() {
|
||||
_ = iptablesHost("-t", "nat", "-D", "POSTROUTING", "-s", fmt.Sprintf("%s/32", slot.VpeerIP.String()), "-o", defaultIface, "-j", "MASQUERADE")
|
||||
})
|
||||
|
||||
slog.Info("network created",
|
||||
"ns", slot.NamespaceID,
|
||||
@ -444,6 +484,9 @@ func CreateNetwork(slot *Slot) error {
|
||||
// All steps are attempted even if earlier ones fail. Returns a combined
|
||||
// error describing which cleanup steps failed.
|
||||
func RemoveNetwork(slot *Slot) error {
|
||||
if slot == nil {
|
||||
return nil
|
||||
}
|
||||
var errs []error
|
||||
|
||||
defaultIface, _ := getDefaultInterface()
|
||||
|
||||
@ -41,6 +41,28 @@ type ExecFunc func(ctx context.Context, req *connect.Request[pb.ExecRequest]) (*
|
||||
// accumulated log entries. Used for per-step DB progress updates.
|
||||
type ProgressFunc func(step int, entries []BuildLogEntry)
|
||||
|
||||
// StepStartFunc is called immediately before a step begins executing.
|
||||
type StepStartFunc func(step int, phase string, st Step)
|
||||
|
||||
// OutputChunkFunc is called with each raw output chunk produced by a streaming
|
||||
// RUN step, as it arrives.
|
||||
type OutputChunkFunc func(step int, data []byte)
|
||||
|
||||
// PtyChunk is a single event emitted by a streaming PTY exec. An event is
// either a piece of output (Data set) or the terminal result (Done set, with
// Exit and/or Err describing the outcome).
type PtyChunk struct {
	// Data holds the raw output bytes of an output event.
	Data []byte

	// Done marks the terminal event of the stream.
	Done bool

	// Exit is the exit code reported with the terminal event.
	Exit int32

	// Err is non-nil when the exec failed.
	Err error
}
|
||||
|
||||
// StreamExecFunc runs shellCmd in a PTY inside sandboxID and returns a channel
|
||||
// of PtyChunk events. The channel is closed after a Done chunk (or an Err
|
||||
// chunk). It is the streaming counterpart of ExecFunc, used for RUN steps so
|
||||
// build output reaches the client live.
|
||||
type StreamExecFunc func(ctx context.Context, sandboxID, shellCmd string) (<-chan PtyChunk, error)
|
||||
|
||||
// Execute runs steps sequentially against sandboxID using execFn.
|
||||
//
|
||||
// - phase labels the log entries (e.g., "pre-build", "recipe", "post-build").
|
||||
@ -63,6 +85,9 @@ func Execute(
|
||||
defaultTimeout time.Duration,
|
||||
bctx *ExecContext,
|
||||
execFn ExecFunc,
|
||||
streamFn StreamExecFunc,
|
||||
onStepStart StepStartFunc,
|
||||
onChunk OutputChunkFunc,
|
||||
onProgress ProgressFunc,
|
||||
) (entries []BuildLogEntry, nextStep int, ok bool) {
|
||||
if defaultTimeout <= 0 {
|
||||
@ -73,6 +98,9 @@ func Execute(
|
||||
for _, st := range steps {
|
||||
step++
|
||||
slog.Info("executing build step", "phase", phase, "step", step, "instruction", st.Raw)
|
||||
if onStepStart != nil {
|
||||
onStepStart(step, phase, st)
|
||||
}
|
||||
|
||||
switch st.Kind {
|
||||
case KindENV:
|
||||
@ -120,7 +148,13 @@ func Execute(
|
||||
if st.Timeout > 0 {
|
||||
timeout = st.Timeout
|
||||
}
|
||||
entry, succeeded := execRun(ctx, st, sandboxID, phase, step, timeout, bctx, execFn)
|
||||
var entry BuildLogEntry
|
||||
var succeeded bool
|
||||
if streamFn != nil {
|
||||
entry, succeeded = execRunStreaming(ctx, st, sandboxID, phase, step, timeout, bctx, streamFn, onChunk)
|
||||
} else {
|
||||
entry, succeeded = execRun(ctx, st, sandboxID, phase, step, timeout, bctx, execFn)
|
||||
}
|
||||
entries = append(entries, entry)
|
||||
if !succeeded {
|
||||
return entries, step, false
|
||||
@ -171,6 +205,66 @@ func execRun(
|
||||
return entry, entry.Ok
|
||||
}
|
||||
|
||||
// execRunStreaming runs a RUN step in a PTY via streamFn, forwarding each
|
||||
// output chunk to onChunk as it arrives. The merged PTY output is also
|
||||
// accumulated into the returned BuildLogEntry.Stdout for cold log viewing.
|
||||
// A PTY merges stdout and stderr onto one stream, so Stderr stays empty
|
||||
// unless the exec itself fails to start.
|
||||
func execRunStreaming(
|
||||
ctx context.Context,
|
||||
st Step,
|
||||
sandboxID, phase string,
|
||||
step int,
|
||||
timeout time.Duration,
|
||||
bctx *ExecContext,
|
||||
streamFn StreamExecFunc,
|
||||
onChunk OutputChunkFunc,
|
||||
) (BuildLogEntry, bool) {
|
||||
execCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
start := time.Now()
|
||||
entry := BuildLogEntry{Step: step, Phase: phase, Cmd: st.Raw}
|
||||
|
||||
ch, err := streamFn(execCtx, sandboxID, bctx.WrappedCommand(st.Shell))
|
||||
if err != nil {
|
||||
entry.Stderr = fmt.Sprintf("exec error: %v", err)
|
||||
entry.Elapsed = time.Since(start).Milliseconds()
|
||||
return entry, false
|
||||
}
|
||||
|
||||
var out []byte
|
||||
gotDone := false
|
||||
for chunk := range ch {
|
||||
if chunk.Err != nil {
|
||||
entry.Stdout = string(out)
|
||||
entry.Stderr = fmt.Sprintf("exec error: %v", chunk.Err)
|
||||
entry.Elapsed = time.Since(start).Milliseconds()
|
||||
return entry, false
|
||||
}
|
||||
if chunk.Done {
|
||||
entry.Exit = chunk.Exit
|
||||
gotDone = true
|
||||
continue
|
||||
}
|
||||
out = append(out, chunk.Data...)
|
||||
if onChunk != nil {
|
||||
onChunk(step, chunk.Data)
|
||||
}
|
||||
}
|
||||
|
||||
entry.Stdout = string(out)
|
||||
entry.Elapsed = time.Since(start).Milliseconds()
|
||||
// A channel that closes without a Done chunk means the stream ended
|
||||
// early (cancelled/aborted). Treat that as a failure, never a success.
|
||||
if !gotDone {
|
||||
entry.Stderr = "exec error: build step stream ended without completion"
|
||||
return entry, false
|
||||
}
|
||||
entry.Ok = entry.Exit == 0
|
||||
return entry, entry.Ok
|
||||
}
|
||||
|
||||
// execUser creates a unix user (if not exists), grants passwordless sudo,
|
||||
// and updates bctx.User for subsequent steps.
|
||||
func execUser(
|
||||
|
||||
28
internal/sandbox/chversion.go
Normal file
28
internal/sandbox/chversion.go
Normal file
@ -0,0 +1,28 @@
|
||||
package sandbox
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// DetectCHVersion runs the cloud-hypervisor binary with --version and
|
||||
// parses the semver from the output (e.g. "cloud-hypervisor v43.0" → "43.0").
|
||||
func DetectCHVersion(binaryPath string) (string, error) {
|
||||
out, err := exec.Command(binaryPath, "--version").Output()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("run %s --version: %w", binaryPath, err)
|
||||
}
|
||||
|
||||
line := strings.TrimSpace(string(out))
|
||||
for field := range strings.FieldsSeq(line) {
|
||||
v := strings.TrimPrefix(field, "v")
|
||||
if v != field || strings.Contains(field, ".") {
|
||||
if strings.Count(v, ".") >= 1 {
|
||||
return v, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("could not parse version from cloud-hypervisor output: %q", line)
|
||||
}
|
||||
@ -10,12 +10,22 @@ import (
|
||||
// ConnTracker tracks active proxy connections for a single sandbox and
|
||||
// provides a drain mechanism for pre-pause graceful shutdown.
|
||||
// It is safe for concurrent use.
|
||||
//
|
||||
// Internally we do not use sync.WaitGroup because Wait cannot be interrupted
|
||||
// — a stuck handler would pin the waiter goroutine forever. Instead we keep
|
||||
// an explicit counter guarded by mu plus a zeroCh that is closed when the
|
||||
// counter transitions to 0, allowing Drain/ForceClose to select on it
|
||||
// alongside cancellation and timeout signals without spawning helper
|
||||
// goroutines that could leak across Reset boundaries.
|
||||
type ConnTracker struct {
|
||||
draining atomic.Bool
|
||||
wg sync.WaitGroup
|
||||
|
||||
mu sync.Mutex
|
||||
count int
|
||||
zeroCh chan struct{} // closed when count drops to 0; recreated on next Acquire
|
||||
|
||||
// cancelMu protects cancelDrain so Reset can signal a timed-out Drain
|
||||
// goroutine to exit, preventing goroutine leaks on repeated pause failures.
|
||||
// to exit early.
|
||||
cancelMu sync.Mutex
|
||||
cancelDrain chan struct{}
|
||||
|
||||
@ -40,13 +50,18 @@ func (t *ConnTracker) Acquire() bool {
|
||||
if t.draining.Load() {
|
||||
return false
|
||||
}
|
||||
t.wg.Add(1)
|
||||
// Re-check after Add: Drain may have set draining between our Load
|
||||
// and Add. If so, undo the Add and reject the connection.
|
||||
t.mu.Lock()
|
||||
// Re-check under mu so a concurrent Drain that flipped draining cannot
|
||||
// race past us with the counter already incremented.
|
||||
if t.draining.Load() {
|
||||
t.wg.Done()
|
||||
t.mu.Unlock()
|
||||
return false
|
||||
}
|
||||
t.count++
|
||||
if t.count == 1 {
|
||||
t.zeroCh = make(chan struct{})
|
||||
}
|
||||
t.mu.Unlock()
|
||||
return true
|
||||
}
|
||||
|
||||
@ -63,11 +78,32 @@ func (t *ConnTracker) Context() context.Context {
|
||||
// Release marks one connection as complete. Must be called exactly once
|
||||
// per successful Acquire.
|
||||
func (t *ConnTracker) Release() {
|
||||
t.wg.Done()
|
||||
t.mu.Lock()
|
||||
t.count--
|
||||
if t.count == 0 && t.zeroCh != nil {
|
||||
close(t.zeroCh)
|
||||
t.zeroCh = nil
|
||||
}
|
||||
t.mu.Unlock()
|
||||
}
|
||||
|
||||
// waitDrain returns a channel that closes when the in-flight count is zero,
|
||||
// or a closed channel immediately if there's nothing in flight.
|
||||
func (t *ConnTracker) waitDrain() <-chan struct{} {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
if t.count == 0 {
|
||||
ch := make(chan struct{})
|
||||
close(ch)
|
||||
return ch
|
||||
}
|
||||
return t.zeroCh
|
||||
}
|
||||
|
||||
// Drain marks the tracker as draining (all future Acquire calls return
|
||||
// false) and waits up to timeout for in-flight connections to finish.
|
||||
// Returns when the count hits 0, Reset is called, or the timeout fires —
|
||||
// whichever happens first. No goroutine is leaked on timeout.
|
||||
func (t *ConnTracker) Drain(timeout time.Duration) {
|
||||
t.draining.Store(true)
|
||||
|
||||
@ -76,16 +112,9 @@ func (t *ConnTracker) Drain(timeout time.Duration) {
|
||||
t.cancelDrain = cancel
|
||||
t.cancelMu.Unlock()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
t.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-t.waitDrain():
|
||||
case <-cancel:
|
||||
// Reset was called; stop waiting.
|
||||
case <-time.After(timeout):
|
||||
}
|
||||
}
|
||||
@ -101,22 +130,16 @@ func (t *ConnTracker) ForceClose() {
|
||||
}
|
||||
t.ctxMu.Unlock()
|
||||
|
||||
// Wait briefly for force-closed connections to call Release().
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
t.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
select {
|
||||
case <-done:
|
||||
case <-t.waitDrain():
|
||||
case <-time.After(2 * time.Second):
|
||||
}
|
||||
}
|
||||
|
||||
// Reset re-enables the tracker after a failed drain. This allows the
|
||||
// sandbox to accept proxy connections again if the pause operation fails
|
||||
// and the VM is resumed. It also cancels any lingering Drain goroutine
|
||||
// and creates a fresh context for new connections.
|
||||
// and the VM is resumed. It also signals any lingering Drain to exit and
|
||||
// creates a fresh context for new connections.
|
||||
func (t *ConnTracker) Reset() {
|
||||
t.cancelMu.Lock()
|
||||
if t.cancelDrain != nil {
|
||||
@ -130,7 +153,6 @@ func (t *ConnTracker) Reset() {
|
||||
}
|
||||
t.cancelMu.Unlock()
|
||||
|
||||
// Replace the cancelled context with a fresh one.
|
||||
t.ctxMu.Lock()
|
||||
t.ctx, t.cancel = context.WithCancel(context.Background())
|
||||
t.ctxMu.Unlock()
|
||||
|
||||
@ -1,30 +0,0 @@
|
||||
package sandbox
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// DetectFirecrackerVersion executes "<binaryPath> --version" and extracts the
// version number from its output (e.g. "Firecracker v1.14.1" → "1.14.1").
//
// The result is the first whitespace-separated token that still contains a
// dot after one leading "v" is stripped. An error is returned when the
// command fails or when no such token is found.
func DetectFirecrackerVersion(binaryPath string) (string, error) {
	raw, err := exec.Command(binaryPath, "--version").Output()
	if err != nil {
		return "", fmt.Errorf("run %s --version: %w", binaryPath, err)
	}

	// Typical output looks like "Firecracker v1.14.1\n".
	text := strings.TrimSpace(string(raw))
	tokens := strings.Fields(text)
	for i := range tokens {
		candidate := strings.TrimPrefix(tokens[i], "v")
		if !strings.Contains(candidate, ".") {
			continue
		}
		// A dotted token is treated as the version, with or without the "v".
		return candidate, nil
	}

	return "", fmt.Errorf("could not parse version from firecracker output: %q", text)
}
|
||||
@ -9,6 +9,8 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/layout"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
@ -29,13 +31,9 @@ func EnsureImageSizes(wrennDir string, targetMB int) error {
|
||||
}
|
||||
targetBytes := int64(targetMB) * 1024 * 1024
|
||||
|
||||
// Expand the built-in minimal image.
|
||||
minimalRootfs := layout.TemplateRootfs(wrennDir, id.PlatformTeamID, id.MinimalTemplateID)
|
||||
if err := expandImage(minimalRootfs, targetBytes, targetMB); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Walk teams/{teamDir}/{templateDir}/rootfs.ext4 two levels deep.
|
||||
// Walk teams/{teamDir}/{templateDir}/rootfs.ext4 two levels deep. The
|
||||
// built-in system base templates live under teams/{base36(0)}/... so this
|
||||
// covers them too.
|
||||
teamsDir := layout.TeamsDir(wrennDir)
|
||||
teamEntries, err := os.ReadDir(teamsDir)
|
||||
if err != nil {
|
||||
@ -104,12 +102,19 @@ func ParseSizeToMB(s string) (int, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// ShrinkMinimalImage shrinks the built-in minimal rootfs back to its minimum
|
||||
// size using resize2fs -M. This is the inverse of EnsureImageSizes and should
|
||||
// be called during graceful shutdown so the image is stored compactly on disk.
|
||||
func ShrinkMinimalImage(wrennDir string) {
|
||||
minimalRootfs := layout.TemplateRootfs(wrennDir, id.PlatformTeamID, id.MinimalTemplateID)
|
||||
shrinkImage(minimalRootfs)
|
||||
// ShrinkSystemImages shrinks the built-in system base rootfs images back to
|
||||
// their minimum size using resize2fs -M. This is the inverse of
|
||||
// EnsureImageSizes and should be called during graceful shutdown so the images
|
||||
// are stored compactly on disk.
|
||||
func ShrinkSystemImages(wrennDir string) {
|
||||
for _, tmplID := range []pgtype.UUID{
|
||||
id.UbuntuTemplateID,
|
||||
id.AlpineTemplateID,
|
||||
id.ArchTemplateID,
|
||||
id.FedoraTemplateID,
|
||||
} {
|
||||
shrinkImage(layout.TemplateRootfs(wrennDir, id.PlatformTeamID, tmplID))
|
||||
}
|
||||
}
|
||||
|
||||
// shrinkImage shrinks a single rootfs image to its minimum size.
|
||||
|
||||
187
internal/sandbox/launch_snapshot.go
Normal file
187
internal/sandbox/launch_snapshot.go
Normal file
@ -0,0 +1,187 @@
|
||||
// Package sandbox: launching a fresh sandbox from a snapshot template.
|
||||
//
|
||||
// Mirrors the pause/resume restore path but produces a brand-new sandbox each
|
||||
// call: fresh ID, fresh network slot, fresh CoW on top of the template's
|
||||
// flattened rootfs. The CH process is launched with --restore + lazy memory
|
||||
// (UFFD), and the post-restore memory loader is started so any subsequent
|
||||
// CreateSnapshot taken from this descendant is self-contained (the
|
||||
// pause-resume-pause chain guarantee, applied to template lineages).
|
||||
package sandbox
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/devicemapper"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/layout"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/models"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/network"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
// createFromSnapshotTemplate launches a new sandbox from a snapshot-template
|
||||
// directory (state.json + config.json + memory-ranges + rootfs.ext4).
|
||||
//
|
||||
// The caller has already verified IsSnapshotTemplate(templateDir). Resources
|
||||
// acquired here are rolled back on any failure; on success the sandbox is
|
||||
// registered in m.boxes and runs in StatusRunning.
|
||||
func (m *Manager) createFromSnapshotTemplate(
|
||||
ctx context.Context,
|
||||
sandboxID string,
|
||||
teamID, templateID pgtype.UUID,
|
||||
vcpus, memoryMB, timeoutSec, diskSizeMB int,
|
||||
defaultUser string,
|
||||
defaultEnv map[string]string,
|
||||
) (*models.Sandbox, int64, error) {
|
||||
templateDir := layout.TemplateDir(m.cfg.WrennDir, teamID, templateID)
|
||||
baseRootfs := layout.TemplateRootfs(m.cfg.WrennDir, teamID, templateID)
|
||||
|
||||
meta, err := readSnapshotMeta(templateDir)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("read snapshot meta: %w", err)
|
||||
}
|
||||
if meta.SandboxDir == "" {
|
||||
// CH's saved config.json hardcodes a tmpfs disk path; meta.SandboxDir
|
||||
// is that exact path. A snapshot template without it cannot be launched.
|
||||
return nil, 0, fmt.Errorf("snapshot template %s missing sandbox_dir in meta", templateDir)
|
||||
}
|
||||
|
||||
// Acquire shared read-only loop on the flattened rootfs. Many sandboxes
|
||||
// can share this loop concurrently — refcounted in LoopRegistry.
|
||||
originLoop, err := m.loops.Acquire(baseRootfs)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("acquire loop: %w", err)
|
||||
}
|
||||
originSize, err := devicemapper.OriginSizeBytes(originLoop)
|
||||
if err != nil {
|
||||
m.loops.Release(baseRootfs)
|
||||
return nil, 0, fmt.Errorf("origin size: %w", err)
|
||||
}
|
||||
|
||||
// Per-sandbox CoW on top of the shared origin.
|
||||
dmName := "wrenn-" + sandboxID
|
||||
if err := os.MkdirAll(layout.SandboxDir(m.cfg.WrennDir, sandboxID), 0o755); err != nil {
|
||||
m.loops.Release(baseRootfs)
|
||||
return nil, 0, fmt.Errorf("create sandbox dir: %w", err)
|
||||
}
|
||||
cowPath := layout.SandboxCowPath(m.cfg.WrennDir, sandboxID)
|
||||
cowSize := max(int64(diskSizeMB)*1024*1024, originSize)
|
||||
dmDev, err := devicemapper.CreateSnapshot(dmName, originLoop, cowPath, originSize, cowSize)
|
||||
if err != nil {
|
||||
m.loops.Release(baseRootfs)
|
||||
return nil, 0, fmt.Errorf("create dm-snapshot: %w", err)
|
||||
}
|
||||
|
||||
res := &createResources{
|
||||
sandboxID: sandboxID,
|
||||
loops: m.loops,
|
||||
loopImage: baseRootfs,
|
||||
dmDevice: dmDev,
|
||||
cowPath: cowPath,
|
||||
slots: m.slots,
|
||||
}
|
||||
|
||||
slotIdx, err := m.slots.Allocate()
|
||||
if err != nil {
|
||||
res.rollback()
|
||||
return nil, 0, fmt.Errorf("allocate network slot: %w", err)
|
||||
}
|
||||
res.slotIdx = slotIdx
|
||||
slot := network.NewSlot(slotIdx)
|
||||
|
||||
if err := network.CreateNetwork(slot); err != nil {
|
||||
res.rollback()
|
||||
return nil, 0, fmt.Errorf("create network: %w", err)
|
||||
}
|
||||
res.slot = slot
|
||||
|
||||
// CH's saved config.json hardcodes a tmpfs disk path; meta.SandboxDir is
|
||||
// that exact path (carried forward verbatim across template chains, so a
|
||||
// snapshot-of-a-snapshot resolves to the root ancestor's path). The
|
||||
// launcher mounts a fresh tmpfs there inside its private mount namespace
|
||||
// and symlinks rootfs.ext4 → our new dm device.
|
||||
vmCfg := m.buildRestoreVMConfig(restoreInputs{
|
||||
sandboxID: sandboxID,
|
||||
templateID: id.UUIDString(templateID),
|
||||
snapDir: templateDir,
|
||||
rootfsPath: dmDev.DevicePath,
|
||||
vcpus: vcpus,
|
||||
memoryMB: memoryMB,
|
||||
slot: slot,
|
||||
sandboxDir: meta.SandboxDir,
|
||||
})
|
||||
|
||||
client, err := m.launchRestoredVM(ctx, vmCfg, slot.HostIP.String())
|
||||
if err != nil {
|
||||
res.rollback()
|
||||
return nil, 0, err
|
||||
}
|
||||
res.vm = m.vm
|
||||
|
||||
envdVersion, _ := client.FetchVersion(ctx)
|
||||
|
||||
now := time.Now()
|
||||
sb := &sandboxState{
|
||||
Sandbox: models.Sandbox{
|
||||
ID: sandboxID,
|
||||
Status: models.StatusRunning,
|
||||
TemplateTeamID: teamID.Bytes,
|
||||
TemplateID: templateID.Bytes,
|
||||
VCPUs: vcpus,
|
||||
MemoryMB: memoryMB,
|
||||
TimeoutSec: timeoutSec,
|
||||
SlotIndex: slotIdx,
|
||||
HostIP: slot.HostIP,
|
||||
RootfsPath: dmDev.DevicePath,
|
||||
CreatedAt: now,
|
||||
LastActiveAt: now,
|
||||
Metadata: m.buildMetadata(envdVersion),
|
||||
},
|
||||
slot: slot,
|
||||
connTracker: &ConnTracker{},
|
||||
dmDevice: dmDev,
|
||||
baseImagePath: baseRootfs,
|
||||
sandboxDirOverride: meta.SandboxDir,
|
||||
}
|
||||
sb.client.Store(client)
|
||||
|
||||
m.mu.Lock()
|
||||
m.boxes[sandboxID] = sb
|
||||
m.mu.Unlock()
|
||||
|
||||
// /init lifecycle bump then start the memory loader. Loader is required
|
||||
// so any future CreateSnapshot taken from this descendant captures all
|
||||
// guest pages (otherwise SEEK_DATA/SEEK_HOLE would emit holes for the
|
||||
// still-lazy UFFD pages — silent corruption across template chains).
|
||||
m.initAndStartMemoryLoader(ctx, sb, defaultUser, id.UUIDString(templateID), defaultEnv)
|
||||
|
||||
m.startSampler(sb)
|
||||
m.startCrashWatcher(sb)
|
||||
|
||||
slog.Info("sandbox launched from snapshot template",
|
||||
"id", sandboxID,
|
||||
"team_id", teamID,
|
||||
"template_id", templateID,
|
||||
"sandbox_dir", meta.SandboxDir,
|
||||
"host_ip", slot.HostIP.String(),
|
||||
"dm_device", dmDev.DevicePath,
|
||||
)
|
||||
|
||||
return &sb.Sandbox, cowSize, nil
|
||||
}
|
||||
|
||||
// templateExists returns true if a snapshot template already lives at
|
||||
// TemplateDir(team, templateID). Used by CreateSnapshot to refuse silent
|
||||
// overwrites — every snapshot must land in a fresh templateID.
|
||||
func (m *Manager) templateExists(teamID, templateID pgtype.UUID) bool {
|
||||
dir := layout.TemplateDir(m.cfg.WrennDir, teamID, templateID)
|
||||
if _, err := os.Stat(dir); err != nil {
|
||||
return false
|
||||
}
|
||||
return layout.IsSnapshotTemplate(dir)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
1180
internal/sandbox/pause.go
Normal file
1180
internal/sandbox/pause.go
Normal file
File diff suppressed because it is too large
Load Diff
@ -1,13 +1,14 @@
|
||||
package sandbox
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/envdclient"
|
||||
)
|
||||
@ -48,42 +49,43 @@ func readCPUStat(pid int) (cpuStat, error) {
|
||||
return cpuStat{utime: utime, stime: stime}, nil
|
||||
}
|
||||
|
||||
// readEnvdMemUsed fetches mem_used from envd's /metrics endpoint. Returns
|
||||
// guest-side total - MemAvailable (actual process memory, excluding reclaimable
|
||||
// page cache). VmRSS of the Firecracker process includes guest page cache and
|
||||
// never decreases, so this is the accurate metric for dashboard display.
|
||||
func readEnvdMemUsed(client *envdclient.Client) (int64, error) {
|
||||
resp, err := client.HTTPClient().Get(client.BaseURL() + "/metrics")
|
||||
// envdMetrics holds metric values read from envd's /metrics endpoint.
|
||||
type envdMetrics struct {
|
||||
MemBytes int64
|
||||
DiskBytes int64
|
||||
}
|
||||
|
||||
// readEnvdMetrics fetches mem_used and disk_used from envd's /metrics endpoint.
|
||||
// Returns guest-side process memory (total - available) and filesystem usage
|
||||
// from statfs("/"). These are the guest-visible metrics users care about.
|
||||
func readEnvdMetrics(ctx context.Context, client *envdclient.Client) (envdMetrics, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, client.BaseURL()+"/metrics", nil)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("fetch envd metrics: %w", err)
|
||||
return envdMetrics{}, fmt.Errorf("build metrics request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := client.HTTPClient().Do(req)
|
||||
if err != nil {
|
||||
return envdMetrics{}, fmt.Errorf("fetch envd metrics: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return 0, fmt.Errorf("envd metrics: status %d", resp.StatusCode)
|
||||
return envdMetrics{}, fmt.Errorf("envd metrics: status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("read envd metrics body: %w", err)
|
||||
return envdMetrics{}, fmt.Errorf("read envd metrics body: %w", err)
|
||||
}
|
||||
|
||||
var m struct {
|
||||
MemUsed int64 `json:"mem_used"`
|
||||
MemUsed int64 `json:"mem_used"`
|
||||
DiskUsed int64 `json:"disk_used"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &m); err != nil {
|
||||
return 0, fmt.Errorf("decode envd metrics: %w", err)
|
||||
return envdMetrics{}, fmt.Errorf("decode envd metrics: %w", err)
|
||||
}
|
||||
|
||||
return m.MemUsed, nil
|
||||
}
|
||||
|
||||
// readDiskAllocated returns the actual allocated bytes (not apparent size)
|
||||
// of the file at path. This uses stat's block count × 512.
|
||||
func readDiskAllocated(path string) (int64, error) {
|
||||
var stat syscall.Stat_t
|
||||
if err := syscall.Stat(path, &stat); err != nil {
|
||||
return 0, fmt.Errorf("stat %s: %w", path, err)
|
||||
}
|
||||
return stat.Blocks * 512, nil
|
||||
return envdMetrics{MemBytes: m.MemUsed, DiskBytes: m.DiskUsed}, nil
|
||||
}
|
||||
|
||||
186
internal/sandbox/punch.go
Normal file
186
internal/sandbox/punch.go
Normal file
@ -0,0 +1,186 @@
|
||||
// Package sandbox: post-snapshot hole punching for memory-ranges files.
|
||||
//
|
||||
// CH v52's SEEK_DATA/SEEK_HOLE snapshot writer only skips ranges already
|
||||
// hole in the source memfd. Pages the guest never reported as free are
|
||||
// written verbatim — including pages whose contents happen to be all zero
|
||||
// (fresh allocations the guest scribbled then released without telling the
|
||||
// balloon driver). Walking the resulting file and punching any 4 KiB block
|
||||
// of zeros recovers that space without any guest cooperation.
|
||||
package sandbox
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
const (
|
||||
// punchBlockSize is the granularity at which we test for zero runs and
|
||||
// issue FALLOC_FL_PUNCH_HOLE. Matches the kernel page size and the
|
||||
// minimum hole size on ext4.
|
||||
punchBlockSize = 4096
|
||||
|
||||
// punchReadSize is the IO chunk size used by the scan loop. We read
|
||||
// many blocks per syscall and split them in-memory so a 20 GiB
|
||||
// memory-ranges file costs ~20K read(2) syscalls instead of ~5M.
|
||||
// Crucial under single-disk hosts where each syscall otherwise
|
||||
// contends with sshd / journal IO.
|
||||
punchReadSize = 1 << 20 // 1 MiB = 256 blocks
|
||||
)
|
||||
|
||||
// punchZeroPagesInDir runs punchZeroPages on every memory* file in dir.
|
||||
// CH writes its memory dump as one or more files prefixed "memory" inside
|
||||
// the snapshot directory; everything else (config.json, state.json) is
|
||||
// metadata and untouched.
|
||||
func punchZeroPagesInDir(dir string) {
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
slog.Warn("punch: read snapshot dir", "dir", dir, "error", err)
|
||||
return
|
||||
}
|
||||
for _, e := range entries {
|
||||
if e.IsDir() || !strings.HasPrefix(e.Name(), "memory") {
|
||||
continue
|
||||
}
|
||||
path := filepath.Join(dir, e.Name())
|
||||
before, after, err := punchZeroPages(path)
|
||||
if err != nil {
|
||||
slog.Warn("punch: zero-page scan failed", "path", path, "error", err)
|
||||
continue
|
||||
}
|
||||
slog.Info("punch: zero-page scan done",
|
||||
"path", path,
|
||||
"alloc_before", before,
|
||||
"alloc_after", after,
|
||||
"reclaimed", before-after)
|
||||
}
|
||||
}
|
||||
|
||||
// punchZeroPages scans path block-by-block, batching runs of all-zero 4 KiB
|
||||
// blocks and punching them out via FALLOC_FL_PUNCH_HOLE. Existing holes are
|
||||
// skipped via SEEK_DATA so a partially-sparse input stays cheap to scan.
|
||||
//
|
||||
// Returns the file's disk allocation (st_blocks * 512) before and after.
|
||||
func punchZeroPages(path string) (int64, int64, error) {
|
||||
f, err := os.OpenFile(path, os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
stBefore, err := statBlocks(f)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("stat before: %w", err)
|
||||
}
|
||||
|
||||
fi, err := f.Stat()
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("stat: %w", err)
|
||||
}
|
||||
size := fi.Size()
|
||||
|
||||
buf := make([]byte, punchReadSize)
|
||||
off := int64(0)
|
||||
|
||||
for off < size {
|
||||
// Skip ahead to next data region; nothing to do in holes.
|
||||
next, err := f.Seek(off, 3) // SEEK_DATA = 3
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, unix.ENXIO) {
|
||||
break
|
||||
}
|
||||
return 0, 0, fmt.Errorf("seek_data @ %d: %w", off, err)
|
||||
}
|
||||
off = next &^ (punchBlockSize - 1) // align down to block
|
||||
|
||||
// Find end of this data extent.
|
||||
endData, err := f.Seek(off, 4) // SEEK_HOLE = 4
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("seek_hole @ %d: %w", off, err)
|
||||
}
|
||||
|
||||
// Scan [off, endData) chunk by chunk; batch zero runs across both
|
||||
// intra-chunk and inter-chunk boundaries so a contiguous zero
|
||||
// region is punched in a single fallocate.
|
||||
zeroStart := int64(-1)
|
||||
cur := off
|
||||
for cur < endData {
|
||||
toRead := min(int64(len(buf)), endData-cur)
|
||||
n, err := readAt(f, buf[:toRead], cur)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("read @ %d: %w", cur, err)
|
||||
}
|
||||
if n == 0 {
|
||||
break
|
||||
}
|
||||
// Walk the chunk one block at a time, tracking zero runs.
|
||||
for blkOff := 0; blkOff < n; blkOff += punchBlockSize {
|
||||
blkEnd := min(blkOff+punchBlockSize, n)
|
||||
blk := buf[blkOff:blkEnd]
|
||||
blkAbs := cur + int64(blkOff)
|
||||
if isZero(blk) && len(blk) == punchBlockSize {
|
||||
if zeroStart < 0 {
|
||||
zeroStart = blkAbs
|
||||
}
|
||||
} else if zeroStart >= 0 {
|
||||
if err := punch(f, zeroStart, blkAbs-zeroStart); err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
zeroStart = -1
|
||||
}
|
||||
}
|
||||
cur += int64(n)
|
||||
}
|
||||
if zeroStart >= 0 {
|
||||
if err := punch(f, zeroStart, cur-zeroStart); err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
}
|
||||
off = endData
|
||||
}
|
||||
|
||||
stAfter, err := statBlocks(f)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("stat after: %w", err)
|
||||
}
|
||||
return stBefore, stAfter, nil
|
||||
}
|
||||
|
||||
func punch(f *os.File, off, length int64) error {
|
||||
mode := uint32(unix.FALLOC_FL_PUNCH_HOLE | unix.FALLOC_FL_KEEP_SIZE)
|
||||
if err := unix.Fallocate(int(f.Fd()), mode, off, length); err != nil {
|
||||
return fmt.Errorf("fallocate punch @ %d len %d: %w", off, length, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// readAt wraps f.ReadAt and treats io.EOF as a successful (possibly short)
// read: the byte count is returned with a nil error. Any other error is
// passed through unchanged.
func readAt(f *os.File, buf []byte, off int64) (int, error) {
	switch n, err := f.ReadAt(buf, off); err {
	case io.EOF:
		return n, nil
	default:
		return n, err
	}
}
|
||||
|
||||
// isZero reports whether every byte of b is 0. An empty or nil slice counts
// as all-zero.
func isZero(b []byte) bool {
	for i := range b {
		if b[i] != 0 {
			return false
		}
	}
	return true
}
|
||||
|
||||
func statBlocks(f *os.File) (int64, error) {
|
||||
var st unix.Stat_t
|
||||
if err := unix.Fstat(int(f.Fd()), &st); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int64(st.Blocks) * 512, nil
|
||||
}
|
||||
118
internal/sandbox/restore.go
Normal file
118
internal/sandbox/restore.go
Normal file
@ -0,0 +1,118 @@
|
||||
// Package sandbox: shared CH-restore helpers used by both Resume (paused →
|
||||
// running) and the snapshot-template launch path (template → fresh sandbox).
|
||||
//
|
||||
// The two callers diverge in how they acquire resources (slot, dm-snapshot,
|
||||
// sandbox identity) but converge on:
|
||||
//
|
||||
// build VMConfig → CreateFromSnapshot → vm.Resume → wait envd → balloon deflate
|
||||
//
|
||||
// These steps are extracted here so the sequence — and its quirks (paused
|
||||
// post-restore state, balloon best-effort, restored disk path baked into
|
||||
// CH's config.json) — has a single source of truth.
|
||||
package sandbox
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"path/filepath"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/envdclient"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/network"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/vm"
|
||||
)
|
||||
|
||||
// restoreInputs is the common set of fields needed to build a restore VMConfig.
|
||||
type restoreInputs struct {
|
||||
sandboxID string // VM identity for the new CH process (sock path, log file)
|
||||
templateID string // forwarded to envd via PostInit (informational)
|
||||
snapDir string // directory containing CH snapshot artefacts
|
||||
rootfsPath string // /dev/mapper/wrenn-{newID} — per-sandbox dm-snapshot
|
||||
vcpus int
|
||||
memoryMB int
|
||||
slot *network.Slot
|
||||
sandboxDir string // override for VMConfig.SandboxDir; "" = default
|
||||
}
|
||||
|
||||
// buildRestoreVMConfig assembles the VMConfig used to launch a CH process in
|
||||
// restore mode. sandboxDir, when non-empty, overrides the default
|
||||
// "/tmp/ch-vm-{SandboxID}" — required when the snapshot's saved config.json
|
||||
// points at a different sandbox's tmpfs path (i.e. snapshot-template launch).
|
||||
func (m *Manager) buildRestoreVMConfig(in restoreInputs) vm.VMConfig {
|
||||
return vm.VMConfig{
|
||||
SandboxID: in.sandboxID,
|
||||
TemplateID: in.templateID,
|
||||
KernelPath: m.cfg.KernelPath,
|
||||
RootfsPath: in.rootfsPath,
|
||||
VCPUs: in.vcpus,
|
||||
MemoryMB: in.memoryMB,
|
||||
NetworkNamespace: in.slot.NamespaceID,
|
||||
TapDevice: in.slot.TapName,
|
||||
TapMAC: in.slot.TapMAC,
|
||||
GuestIP: in.slot.GuestIP,
|
||||
GatewayIP: in.slot.TapIP,
|
||||
NetMask: in.slot.GuestNetMask,
|
||||
VMMBin: m.cfg.VMMBin,
|
||||
LogDir: filepath.Join(m.cfg.WrennDir, "logs"),
|
||||
RestoreFromDir: in.snapDir,
|
||||
RestoreLazyMemory: true,
|
||||
SandboxDir: in.sandboxDir,
|
||||
}
|
||||
}
|
||||
|
||||
// launchRestoredVM starts CH in restore mode, resumes the vCPUs, waits for
|
||||
// envd to be reachable, then best-effort deflates the balloon. On any failure
|
||||
// the partial VM is destroyed before returning — the caller is responsible
|
||||
// for tearing down dm/network/slot.
|
||||
//
|
||||
// Returns the connected envd client on success.
|
||||
func (m *Manager) launchRestoredVM(ctx context.Context, vmCfg vm.VMConfig, hostIP string) (*envdclient.Client, error) {
|
||||
if _, err := m.vm.CreateFromSnapshot(ctx, vmCfg); err != nil {
|
||||
return nil, fmt.Errorf("create from snapshot: %w", err)
|
||||
}
|
||||
|
||||
if err := m.vm.Resume(ctx, vmCfg.SandboxID); err != nil {
|
||||
_ = m.vm.Destroy(context.Background(), vmCfg.SandboxID)
|
||||
return nil, fmt.Errorf("vm resume: %w", err)
|
||||
}
|
||||
|
||||
client := envdclient.New(hostIP)
|
||||
waitCtx, waitCancel := context.WithTimeout(ctx, envdReadyTimeout(vmCfg.MemoryMB))
|
||||
defer waitCancel()
|
||||
if err := client.WaitUntilReady(waitCtx); err != nil {
|
||||
_ = m.vm.Destroy(context.Background(), vmCfg.SandboxID)
|
||||
return nil, fmt.Errorf("wait envd: %w", err)
|
||||
}
|
||||
|
||||
// Best-effort balloon deflate. Free-page reporting drains pages while the
|
||||
// sandbox runs; the resumed guest needs its full memory budget back. A
|
||||
// failure leaves the guest memory-starved but doesn't break correctness.
|
||||
if err := m.vm.UpdateBalloon(ctx, vmCfg.SandboxID, 0); err != nil {
|
||||
slog.Warn("balloon deflate after restore failed", "id", vmCfg.SandboxID, "error", err)
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// initAndStartMemoryLoader runs envd's /init lifecycle bump and then kicks
|
||||
// off the background memory loader. Ordering matters: /init resets envd's
|
||||
// mem_preload_* atomics, so the loader's POST /memory/preload must land
|
||||
// after — otherwise the next CreateSnapshot/Pause would observe a stale
|
||||
// "idle" state and snapshot a memfile full of holes.
|
||||
//
|
||||
// Must be called with sb already registered in m.boxes with StatusRunning
|
||||
// and sb.client populated.
|
||||
func (m *Manager) initAndStartMemoryLoader(ctx context.Context, sb *sandboxState, defaultUser, templateIDStr string, envVars map[string]string) {
|
||||
initCtx, initCancel := context.WithTimeout(ctx, m.cfg.EnvdTimeout)
|
||||
defer initCancel()
|
||||
c := sb.client.Load()
|
||||
if c == nil {
|
||||
slog.Warn("post-restore PostInit skipped: envd client cleared", "id", sb.ID)
|
||||
return
|
||||
}
|
||||
if err := c.PostInitWithDefaults(initCtx, defaultUser, envVars, sb.ID, templateIDStr); err != nil {
|
||||
slog.Warn("post-restore PostInit failed", "id", sb.ID, "error", err)
|
||||
}
|
||||
|
||||
m.startMemoryLoader(sb)
|
||||
}
|
||||
208
internal/sandbox/restore_paused.go
Normal file
208
internal/sandbox/restore_paused.go
Normal file
@ -0,0 +1,208 @@
|
||||
package sandbox
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/layout"
|
||||
"git.omukk.dev/wrenn/wrenn/internal/models"
|
||||
)
|
||||
|
||||
// RestorePausedSandboxes scans WRENN_DIR/sandboxes/ for paused-sandbox
|
||||
// snapshots left behind by a previous agent instance and re-registers them
|
||||
// in m.boxes as StatusPaused. Without this, ListSandboxes would not report
|
||||
// these sandboxes, and the CP's HostMonitor would mark them stopped via
|
||||
// the missing-confirmed-dead reconcile path — orphaning the on-disk
|
||||
// snapshot dir and surfacing a leaked "stopped" sandbox to users.
|
||||
//
|
||||
// Restored sandboxes hold ONLY the slot reservation; VM / network / dm /
|
||||
// loop refcount stay unowned until Resume rebuilds them. baseImagePath is
|
||||
// deliberately NOT set on the in-memory entry so cleanup() does not call
|
||||
// loops.Release on a loop that was never Acquire'd — the registry tolerates
|
||||
// a Release of an unknown key, but a coincident-same-base running sandbox
|
||||
// would have its refcount decremented incorrectly.
|
||||
//
|
||||
// Must be called once at agent startup, AFTER CleanupOrphanPauseDirs (so
|
||||
// .staging-* / .trash-* dirs are gone) and BEFORE the HTTP server starts
|
||||
// serving — otherwise an early Create RPC can race the slot reservation.
|
||||
//
|
||||
// Corrupt snapshot dirs (unparseable meta, missing slot index) are renamed
|
||||
// to .trash-{ts}/ so a future CleanupOrphanPauseDirs sweeps them. Soft
|
||||
// errors are logged; this function never returns an error — startup should
|
||||
// not fail because a single sandbox is unrecoverable.
|
||||
func (m *Manager) RestorePausedSandboxes() {
|
||||
sandboxesDir := layout.SandboxesDir(m.cfg.WrennDir)
|
||||
entries, err := os.ReadDir(sandboxesDir)
|
||||
if err != nil {
|
||||
// Directory does not exist yet — fresh install, nothing to restore.
|
||||
return
|
||||
}
|
||||
|
||||
type candidate struct {
|
||||
sandboxID string
|
||||
snapDir string
|
||||
meta *snapshotMeta
|
||||
teamID [16]byte
|
||||
templID [16]byte
|
||||
}
|
||||
|
||||
// Pass 1: parse every snapshot meta. Trash anything unreadable or
|
||||
// missing the slot index — those are crash artefacts, not recoverable
|
||||
// sandboxes.
|
||||
candidates := make([]candidate, 0, len(entries))
|
||||
for _, e := range entries {
|
||||
if !e.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := e.Name()
|
||||
// Skip CleanupOrphanPauseDirs's territory. If it ran before us
|
||||
// these are already gone; if not, leave them alone.
|
||||
if strings.Contains(name, ".staging-") || strings.Contains(name, ".trash-") {
|
||||
continue
|
||||
}
|
||||
|
||||
snapDir := layout.PauseSnapshotDir(m.cfg.WrennDir, name)
|
||||
meta, err := readSnapshotMeta(snapDir)
|
||||
if err != nil {
|
||||
slog.Warn("restore: unreadable snapshot meta, trashing dir",
|
||||
"id", name, "error", err)
|
||||
trashCorruptDir(snapDir)
|
||||
continue
|
||||
}
|
||||
if meta.SlotIndex == 0 {
|
||||
slog.Warn("restore: snapshot has no slot_index, trashing dir", "id", name)
|
||||
trashCorruptDir(snapDir)
|
||||
continue
|
||||
}
|
||||
teamBytes, err := parsePlainUUID(meta.TeamID)
|
||||
if err != nil {
|
||||
slog.Warn("restore: bad team_id in snapshot meta", "id", name, "error", err)
|
||||
trashCorruptDir(snapDir)
|
||||
continue
|
||||
}
|
||||
templateBytes, err := parsePlainUUID(meta.TemplateID)
|
||||
if err != nil {
|
||||
slog.Warn("restore: bad template_id in snapshot meta", "id", name, "error", err)
|
||||
trashCorruptDir(snapDir)
|
||||
continue
|
||||
}
|
||||
candidates = append(candidates, candidate{
|
||||
sandboxID: name,
|
||||
snapDir: snapDir,
|
||||
meta: meta,
|
||||
teamID: teamBytes,
|
||||
templID: templateBytes,
|
||||
})
|
||||
}
|
||||
|
||||
// Pass 2: bucket by slot index, pick the newest CreatedAt per slot.
|
||||
// Multiple candidates per slot happen when older paused-sandbox dirs
|
||||
// were left on disk by the pre-fix leak (DB row marked stopped but the
|
||||
// snapshot was never cleaned). The newest is the most likely live one;
|
||||
// older losers are trashed so CleanupOrphanPauseDirs sweeps them on
|
||||
// the next startup.
|
||||
bySlot := make(map[int][]candidate, len(candidates))
|
||||
for _, c := range candidates {
|
||||
bySlot[c.meta.SlotIndex] = append(bySlot[c.meta.SlotIndex], c)
|
||||
}
|
||||
|
||||
restored := 0
|
||||
pruned := 0
|
||||
for slot, cands := range bySlot {
|
||||
sort.Slice(cands, func(i, j int) bool {
|
||||
return cands[i].meta.CreatedAt.After(cands[j].meta.CreatedAt)
|
||||
})
|
||||
|
||||
// Trash every loser. The host_monitor's zombie-cleanup path catches
|
||||
// the winner if its DB row says 'stopped' — but losers never enter
|
||||
// m.boxes and would otherwise sit on disk indefinitely.
|
||||
for _, stale := range cands[1:] {
|
||||
slog.Info("restore: pruning older snapshot for same slot",
|
||||
"id", stale.sandboxID, "slot", slot, "created", stale.meta.CreatedAt,
|
||||
"winner", cands[0].sandboxID, "winner_created", cands[0].meta.CreatedAt)
|
||||
trashCorruptDir(stale.snapDir)
|
||||
pruned++
|
||||
}
|
||||
|
||||
winner := cands[0]
|
||||
if err := m.slots.Reserve(winner.meta.SlotIndex); err != nil {
|
||||
// Reserve only fails if another candidate (different slot value
|
||||
// in meta but same numeric index) already grabbed it, or if the
|
||||
// allocator is corrupt. Either way the snapshot is unusable
|
||||
// without a slot, so trash it.
|
||||
slog.Warn("restore: slot reservation failed, trashing dir",
|
||||
"id", winner.sandboxID, "slot", winner.meta.SlotIndex, "error", err)
|
||||
trashCorruptDir(winner.snapDir)
|
||||
pruned++
|
||||
continue
|
||||
}
|
||||
|
||||
sb := &sandboxState{
|
||||
Sandbox: models.Sandbox{
|
||||
ID: winner.sandboxID,
|
||||
Status: models.StatusPaused,
|
||||
TemplateTeamID: winner.teamID,
|
||||
TemplateID: winner.templID,
|
||||
VCPUs: winner.meta.VCPUs,
|
||||
MemoryMB: winner.meta.MemoryMB,
|
||||
TimeoutSec: winner.meta.TimeoutSec,
|
||||
SlotIndex: winner.meta.SlotIndex,
|
||||
CreatedAt: winner.meta.CreatedAt,
|
||||
// LastActiveAt cosmetic only — TTL reaper ignores non-Running.
|
||||
LastActiveAt: winner.meta.CreatedAt,
|
||||
},
|
||||
// connTracker must be non-nil: resumeFromMeta calls Reset() on it
|
||||
// unconditionally during rehydration. A nil pointer would panic.
|
||||
connTracker: &ConnTracker{},
|
||||
// baseImagePath intentionally left empty — see function doc.
|
||||
// sandboxDirOverride intentionally left empty — resumeFromMeta
|
||||
// reads meta.SandboxDir from disk on the resume path.
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
m.boxes[winner.sandboxID] = sb
|
||||
m.mu.Unlock()
|
||||
restored++
|
||||
|
||||
slog.Info("restored paused sandbox", "id", winner.sandboxID,
|
||||
"slot", winner.meta.SlotIndex, "vcpus", winner.meta.VCPUs, "memory_mb", winner.meta.MemoryMB)
|
||||
}
|
||||
|
||||
if restored > 0 || pruned > 0 {
|
||||
slog.Info("paused sandbox restore complete", "restored", restored, "pruned", pruned)
|
||||
}
|
||||
}
|
||||
|
||||
// parsePlainUUID turns a standard hyphenated UUID string (as produced by
|
||||
// id.UUIDString) back into the 16-byte representation used by sandboxState.
|
||||
func parsePlainUUID(s string) ([16]byte, error) {
|
||||
if s == "" {
|
||||
return [16]byte{}, fmt.Errorf("empty uuid string")
|
||||
}
|
||||
u, err := uuid.Parse(s)
|
||||
if err != nil {
|
||||
return [16]byte{}, err
|
||||
}
|
||||
return [16]byte(u), nil
|
||||
}
|
||||
|
||||
// trashCorruptDir moves a corrupt snapshot directory out of the way by
// renaming it to "<name>.trash-<unix-nanos>" next to the original, so that a
// later CleanupOrphanPauseDirs run can sweep it.
//
// The operation is best-effort: a failed rename is logged and otherwise
// ignored. Leaving the directory in place is safe (restore will skip it again
// on the next startup) but unwanted.
func trashCorruptDir(dir string) {
	trashName := fmt.Sprintf("%s.trash-%d", filepath.Base(dir), time.Now().UnixNano())
	trash := filepath.Join(filepath.Dir(dir), trashName)

	err := os.Rename(dir, trash)
	if err == nil {
		return
	}
	slog.Warn("restore: failed to trash corrupt snapshot dir",
		"src", dir, "dst", trash, "error", err)
}
|
||||
@ -1,221 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// Modifications by M/S Omukk
|
||||
|
||||
// Package snapshot implements snapshot storage, header-based memory mapping,
|
||||
// and memory file processing for Firecracker VM snapshots.
|
||||
//
|
||||
// The header system implements a generational copy-on-write memory mapping.
|
||||
// Each snapshot generation stores only the blocks that changed since the
|
||||
// previous generation. A Header contains a sorted list of BuildMap entries
|
||||
// that together cover the entire memory address space, with each entry
|
||||
// pointing to a specific generation's diff file.
|
||||
package snapshot
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const metadataVersion = 1
|
||||
|
||||
// Metadata is the fixed-size header prefix describing the snapshot memory layout.
// Binary layout (little-endian, 64 bytes total):
//
//	Version     uint64   (8 bytes)
//	BlockSize   uint64   (8 bytes)
//	Size        uint64   (8 bytes) — total memory size in bytes
//	Generation  uint64   (8 bytes)
//	BuildID     [16]byte (UUID)
//	BaseBuildID [16]byte (UUID)
//
// The struct is written and read as-is with encoding/binary (see Serialize
// and Deserialize), so the field order below IS the on-disk layout and must
// not be changed.
type Metadata struct {
	Version     uint64    // header format version
	BlockSize   uint64    // block granularity in bytes
	Size        uint64    // total memory size in bytes
	Generation  uint64    // 0 for the first snapshot, +1 per NextGeneration
	BuildID     uuid.UUID // build that produced this generation
	BaseBuildID uuid.UUID // build of the first generation in the chain
}
|
||||
|
||||
// NewMetadata creates metadata for a first-generation snapshot.
|
||||
func NewMetadata(buildID uuid.UUID, blockSize, size uint64) *Metadata {
|
||||
return &Metadata{
|
||||
Version: metadataVersion,
|
||||
Generation: 0,
|
||||
BlockSize: blockSize,
|
||||
Size: size,
|
||||
BuildID: buildID,
|
||||
BaseBuildID: buildID,
|
||||
}
|
||||
}
|
||||
|
||||
// NextGeneration creates metadata for the next generation in the chain.
|
||||
func (m *Metadata) NextGeneration(buildID uuid.UUID) *Metadata {
|
||||
return &Metadata{
|
||||
Version: m.Version,
|
||||
Generation: m.Generation + 1,
|
||||
BlockSize: m.BlockSize,
|
||||
Size: m.Size,
|
||||
BuildID: buildID,
|
||||
BaseBuildID: m.BaseBuildID,
|
||||
}
|
||||
}
|
||||
|
||||
// BuildMap maps a contiguous range of the memory address space to a specific
// generation's diff file. Binary layout (little-endian, 40 bytes):
//
//	Offset             uint64   — byte offset in the virtual address space
//	Length             uint64   — byte count (multiple of BlockSize)
//	BuildID            [16]byte — which generation's diff file, uuid.Nil = zero-fill
//	BuildStorageOffset uint64   — byte offset within that generation's diff file
//
// Entries are serialized field-by-field with encoding/binary, so the field
// order below is the on-disk layout and must not be changed.
type BuildMap struct {
	Offset             uint64
	Length             uint64
	BuildID            uuid.UUID
	BuildStorageOffset uint64
}
|
||||
|
||||
// Header is the in-memory representation of a snapshot's memory mapping.
// It resolves any memory offset to the generation's diff file that backs it
// and the offset within that file.
//
// NOTE(review): this comment previously advertised O(log N) lookup, but
// GetShiftedMapping scans blockStarts backwards one block at a time until it
// hits the start of an entry, so the cost grows with the length of the entry
// (in blocks) rather than logarithmically.
type Header struct {
	Metadata *Metadata
	Mapping  []*BuildMap

	// blockStarts tracks which block indices start a new BuildMap entry.
	// startMap provides direct access from block index to the BuildMap.
	blockStarts []bool
	startMap    map[int64]*BuildMap
}
|
||||
|
||||
// NewHeader creates a Header from metadata and mapping entries.
|
||||
// If mapping is nil/empty, a single entry covering the full size is created.
|
||||
func NewHeader(metadata *Metadata, mapping []*BuildMap) (*Header, error) {
|
||||
if metadata.BlockSize == 0 {
|
||||
return nil, fmt.Errorf("block size cannot be zero")
|
||||
}
|
||||
|
||||
if len(mapping) == 0 {
|
||||
mapping = []*BuildMap{{
|
||||
Offset: 0,
|
||||
Length: metadata.Size,
|
||||
BuildID: metadata.BuildID,
|
||||
BuildStorageOffset: 0,
|
||||
}}
|
||||
}
|
||||
|
||||
blocks := TotalBlocks(int64(metadata.Size), int64(metadata.BlockSize))
|
||||
starts := make([]bool, blocks)
|
||||
startMap := make(map[int64]*BuildMap, len(mapping))
|
||||
|
||||
for _, m := range mapping {
|
||||
idx := BlockIdx(int64(m.Offset), int64(metadata.BlockSize))
|
||||
if idx >= 0 && idx < blocks {
|
||||
starts[idx] = true
|
||||
startMap[idx] = m
|
||||
}
|
||||
}
|
||||
|
||||
return &Header{
|
||||
Metadata: metadata,
|
||||
Mapping: mapping,
|
||||
blockStarts: starts,
|
||||
startMap: startMap,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetShiftedMapping resolves a memory offset to the corresponding diff file
|
||||
// offset, remaining length, and build ID. This is the hot path called for
|
||||
// every UFFD page fault.
|
||||
func (h *Header) GetShiftedMapping(_ context.Context, offset int64) (mappedOffset int64, mappedLength int64, buildID *uuid.UUID, err error) {
|
||||
if offset < 0 || offset >= int64(h.Metadata.Size) {
|
||||
return 0, 0, nil, fmt.Errorf("offset %d out of bounds (size: %d)", offset, h.Metadata.Size)
|
||||
}
|
||||
|
||||
blockSize := int64(h.Metadata.BlockSize)
|
||||
block := BlockIdx(offset, blockSize)
|
||||
|
||||
// Walk backwards to find the BuildMap that contains this block.
|
||||
start := block
|
||||
for start >= 0 {
|
||||
if h.blockStarts[start] {
|
||||
break
|
||||
}
|
||||
start--
|
||||
}
|
||||
if start < 0 {
|
||||
return 0, 0, nil, fmt.Errorf("no mapping found for offset %d", offset)
|
||||
}
|
||||
|
||||
m, ok := h.startMap[start]
|
||||
if !ok {
|
||||
return 0, 0, nil, fmt.Errorf("no mapping at block %d", start)
|
||||
}
|
||||
|
||||
shift := (block - start) * blockSize
|
||||
if shift >= int64(m.Length) {
|
||||
return 0, 0, nil, fmt.Errorf("offset %d beyond mapping end (mapping offset=%d, length=%d)", offset, m.Offset, m.Length)
|
||||
}
|
||||
|
||||
return int64(m.BuildStorageOffset) + shift, int64(m.Length) - shift, &m.BuildID, nil
|
||||
}
|
||||
|
||||
// Serialize writes metadata + mapping entries to binary (little-endian).
|
||||
func Serialize(metadata *Metadata, mappings []*BuildMap) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
if err := binary.Write(&buf, binary.LittleEndian, metadata); err != nil {
|
||||
return nil, fmt.Errorf("write metadata: %w", err)
|
||||
}
|
||||
|
||||
for _, m := range mappings {
|
||||
if err := binary.Write(&buf, binary.LittleEndian, m); err != nil {
|
||||
return nil, fmt.Errorf("write mapping: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// Deserialize reads a header from binary data.
|
||||
func Deserialize(data []byte) (*Header, error) {
|
||||
reader := bytes.NewReader(data)
|
||||
|
||||
var metadata Metadata
|
||||
if err := binary.Read(reader, binary.LittleEndian, &metadata); err != nil {
|
||||
return nil, fmt.Errorf("read metadata: %w", err)
|
||||
}
|
||||
|
||||
var mappings []*BuildMap
|
||||
for {
|
||||
var m BuildMap
|
||||
if err := binary.Read(reader, binary.LittleEndian, &m); err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
return nil, fmt.Errorf("read mapping: %w", err)
|
||||
}
|
||||
mappings = append(mappings, &m)
|
||||
}
|
||||
|
||||
return NewHeader(&metadata, mappings)
|
||||
}
|
||||
|
||||
// Block index helpers.

// TotalBlocks returns how many blocks of blockSize bytes are needed to hold
// size bytes, rounding up. blockSize must be non-zero.
func TotalBlocks(size, blockSize int64) int64 {
	padded := size + blockSize - 1
	return padded / blockSize
}

// BlockIdx returns the index of the block that contains the byte at offset.
// blockSize must be non-zero.
func BlockIdx(offset, blockSize int64) int64 {
	idx := offset / blockSize
	return idx
}

// BlockOffset returns the byte offset at which block idx begins.
func BlockOffset(idx, blockSize int64) int64 {
	start := idx * blockSize
	return start
}
|
||||
@ -7,14 +7,15 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
SnapFileName = "snapfile"
|
||||
MemDiffName = "memfile"
|
||||
MemHeaderName = "memfile.header"
|
||||
// Cloud Hypervisor snapshot files.
|
||||
CHConfigFile = "config.json"
|
||||
CHMemRangesFile = "memory-ranges"
|
||||
CHStateFile = "state.json"
|
||||
|
||||
// Rootfs files.
|
||||
RootfsFileName = "rootfs.ext4"
|
||||
RootfsCowName = "rootfs.cow"
|
||||
RootfsMetaName = "rootfs.meta"
|
||||
@ -25,27 +26,6 @@ func DirPath(baseDir, name string) string {
|
||||
return filepath.Join(baseDir, name)
|
||||
}
|
||||
|
||||
// SnapPath returns the path to the VM state snapshot file.
|
||||
func SnapPath(baseDir, name string) string {
|
||||
return filepath.Join(DirPath(baseDir, name), SnapFileName)
|
||||
}
|
||||
|
||||
// MemDiffPath returns the path to the compact memory diff file (legacy single-generation).
|
||||
func MemDiffPath(baseDir, name string) string {
|
||||
return filepath.Join(DirPath(baseDir, name), MemDiffName)
|
||||
}
|
||||
|
||||
// MemDiffPathForBuild returns the path to a specific generation's diff file.
|
||||
// Format: memfile.{buildID}
|
||||
func MemDiffPathForBuild(baseDir, name string, buildID uuid.UUID) string {
|
||||
return filepath.Join(DirPath(baseDir, name), fmt.Sprintf("memfile.%s", buildID.String()))
|
||||
}
|
||||
|
||||
// MemHeaderPath returns the path to the memory mapping header file.
|
||||
func MemHeaderPath(baseDir, name string) string {
|
||||
return filepath.Join(DirPath(baseDir, name), MemHeaderName)
|
||||
}
|
||||
|
||||
// RootfsPath returns the path to the rootfs image.
|
||||
func RootfsPath(baseDir, name string) string {
|
||||
return filepath.Join(DirPath(baseDir, name), RootfsFileName)
|
||||
@ -61,10 +41,13 @@ func MetaPath(baseDir, name string) string {
|
||||
return filepath.Join(DirPath(baseDir, name), RootfsMetaName)
|
||||
}
|
||||
|
||||
// RootfsMeta records which base template a CoW file was created against
// and the VM resource config needed to restart the sampler on resume.
// It is stored as JSON; every field except BaseTemplate is omitted from the
// encoding when it holds its zero value.
type RootfsMeta struct {
	BaseTemplate string `json:"base_template"`
	TemplateID   string `json:"template_id,omitempty"`
	VCPUs        int    `json:"vcpus,omitempty"`
	MemoryMB     int    `json:"memory_mb,omitempty"`
}
|
||||
|
||||
// WriteMeta writes rootfs metadata to the snapshot directory.
|
||||
@ -92,102 +75,6 @@ func ReadMeta(baseDir, name string) (*RootfsMeta, error) {
|
||||
return &meta, nil
|
||||
}
|
||||
|
||||
// Exists reports whether a complete snapshot exists (all required files present).
|
||||
// Supports both legacy (rootfs.ext4) and CoW-based (rootfs.cow + rootfs.meta) snapshots.
|
||||
// Memory diff files can be either legacy "memfile" or generation-specific "memfile.{uuid}".
|
||||
func Exists(baseDir, name string) bool {
|
||||
dir := DirPath(baseDir, name)
|
||||
|
||||
// snapfile and header are always required.
|
||||
for _, f := range []string{SnapFileName, MemHeaderName} {
|
||||
if _, err := os.Stat(filepath.Join(dir, f)); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Check that at least one memfile exists (legacy or generation-specific).
|
||||
// We verify by reading the header and checking that referenced diff files exist.
|
||||
// Fall back to checking for the legacy memfile name if header can't be read.
|
||||
if _, err := os.Stat(filepath.Join(dir, MemDiffName)); err != nil {
|
||||
// No legacy memfile — check if any memfile.{uuid} exists by
|
||||
// looking for files matching the pattern.
|
||||
matches, _ := filepath.Glob(filepath.Join(dir, "memfile.*"))
|
||||
hasGenDiff := false
|
||||
for _, m := range matches {
|
||||
base := filepath.Base(m)
|
||||
if base != MemHeaderName {
|
||||
hasGenDiff = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasGenDiff {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Accept either rootfs.ext4 (legacy/template) or rootfs.cow + rootfs.meta (dm-snapshot).
|
||||
if _, err := os.Stat(filepath.Join(dir, RootfsFileName)); err == nil {
|
||||
return true
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(dir, RootfsCowName)); err == nil {
|
||||
if _, err := os.Stat(filepath.Join(dir, RootfsMetaName)); err == nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsTemplate reports whether a template image directory exists (has rootfs.ext4).
|
||||
func IsTemplate(baseDir, name string) bool {
|
||||
_, err := os.Stat(filepath.Join(DirPath(baseDir, name), RootfsFileName))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// IsSnapshot reports whether a directory is a snapshot (has all snapshot files).
|
||||
func IsSnapshot(baseDir, name string) bool {
|
||||
return Exists(baseDir, name)
|
||||
}
|
||||
|
||||
// HasCow reports whether a snapshot uses CoW format (rootfs.cow + rootfs.meta)
|
||||
// as opposed to legacy full rootfs (rootfs.ext4).
|
||||
func HasCow(baseDir, name string) bool {
|
||||
dir := DirPath(baseDir, name)
|
||||
_, cowErr := os.Stat(filepath.Join(dir, RootfsCowName))
|
||||
_, metaErr := os.Stat(filepath.Join(dir, RootfsMetaName))
|
||||
return cowErr == nil && metaErr == nil
|
||||
}
|
||||
|
||||
// ListDiffFiles returns a map of build ID → file path for all memory diff files
|
||||
// referenced by the given header. Handles both the legacy "memfile" name
|
||||
// (single-generation) and generation-specific "memfile.{uuid}" names.
|
||||
func ListDiffFiles(baseDir, name string, header *Header) (map[string]string, error) {
|
||||
dir := DirPath(baseDir, name)
|
||||
result := make(map[string]string)
|
||||
|
||||
for _, m := range header.Mapping {
|
||||
if m.BuildID == uuid.Nil {
|
||||
continue // zero-fill, no file needed
|
||||
}
|
||||
idStr := m.BuildID.String()
|
||||
if _, exists := result[idStr]; exists {
|
||||
continue
|
||||
}
|
||||
// Try generation-specific path first, fall back to legacy.
|
||||
genPath := filepath.Join(dir, fmt.Sprintf("memfile.%s", idStr))
|
||||
if _, err := os.Stat(genPath); err == nil {
|
||||
result[idStr] = genPath
|
||||
continue
|
||||
}
|
||||
legacyPath := filepath.Join(dir, MemDiffName)
|
||||
if _, err := os.Stat(legacyPath); err == nil {
|
||||
result[idStr] = legacyPath
|
||||
continue
|
||||
}
|
||||
return nil, fmt.Errorf("diff file not found for build %s", idStr)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// EnsureDir creates the snapshot directory if it doesn't exist.
|
||||
func EnsureDir(baseDir, name string) error {
|
||||
dir := DirPath(baseDir, name)
|
||||
|
||||
@ -1,214 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// Modifications by M/S Omukk
|
||||
|
||||
package snapshot
|
||||
|
||||
import "github.com/google/uuid"
|
||||
|
||||
// CreateMapping converts a dirty-block bitset (represented as a []bool) into
|
||||
// a sorted list of BuildMap entries. Consecutive dirty blocks are merged into
|
||||
// a single entry. BuildStorageOffset tracks the sequential position in the
|
||||
// compact diff file.
|
||||
func CreateMapping(buildID uuid.UUID, dirty []bool, blockSize int64) []*BuildMap {
|
||||
var mappings []*BuildMap
|
||||
var runStart int64 = -1
|
||||
var runLength int64
|
||||
var storageOffset uint64
|
||||
|
||||
for i, set := range dirty {
|
||||
if !set {
|
||||
if runLength > 0 {
|
||||
mappings = append(mappings, &BuildMap{
|
||||
Offset: uint64(runStart) * uint64(blockSize),
|
||||
Length: uint64(runLength) * uint64(blockSize),
|
||||
BuildID: buildID,
|
||||
BuildStorageOffset: storageOffset,
|
||||
})
|
||||
storageOffset += uint64(runLength) * uint64(blockSize)
|
||||
runLength = 0
|
||||
}
|
||||
runStart = -1
|
||||
continue
|
||||
}
|
||||
|
||||
if runStart < 0 {
|
||||
runStart = int64(i)
|
||||
runLength = 1
|
||||
} else {
|
||||
runLength++
|
||||
}
|
||||
}
|
||||
|
||||
if runLength > 0 {
|
||||
mappings = append(mappings, &BuildMap{
|
||||
Offset: uint64(runStart) * uint64(blockSize),
|
||||
Length: uint64(runLength) * uint64(blockSize),
|
||||
BuildID: buildID,
|
||||
BuildStorageOffset: storageOffset,
|
||||
})
|
||||
}
|
||||
|
||||
return mappings
|
||||
}
|
||||
|
||||
// MergeMappings overlays diffMapping on top of baseMapping. Where they overlap,
// diff takes priority. The result covers the entire address space.
//
// Both inputs must be sorted by Offset. The base mapping should cover the full size.
//
// The entries of baseMapping are copied before being trimmed, so the caller's
// base entries are never modified. Entries of diffMapping are placed into the
// result by pointer, not copied. When diffMapping is empty, baseMapping
// itself is returned (same slice, no copy).
//
// Inspired by e2b's snapshot system (Apache 2.0, modified by Omukk).
func MergeMappings(baseMapping, diffMapping []*BuildMap) []*BuildMap {
	if len(diffMapping) == 0 {
		return baseMapping
	}

	// Work on a copy of baseMapping to avoid mutating the original.
	baseCopy := make([]*BuildMap, len(baseMapping))
	for i, m := range baseMapping {
		cp := *m
		baseCopy[i] = &cp
	}

	var result []*BuildMap
	// bi / di are cursors into baseCopy / diffMapping. A base entry that is
	// only partially consumed is replaced in baseCopy[bi] by its remainder
	// and revisited on the next iteration.
	var bi, di int

	for bi < len(baseCopy) && di < len(diffMapping) {
		base := baseCopy[bi]
		diff := diffMapping[di]

		// Zero-length entries carry no data; drop them.
		if base.Length == 0 {
			bi++
			continue
		}
		if diff.Length == 0 {
			di++
			continue
		}

		// No overlap: base entirely before diff.
		if base.Offset+base.Length <= diff.Offset {
			result = append(result, base)
			bi++
			continue
		}

		// No overlap: diff entirely before base.
		if diff.Offset+diff.Length <= base.Offset {
			result = append(result, diff)
			di++
			continue
		}

		// Base fully inside diff — skip base.
		// The diff is not emitted yet: it may also cover following base entries.
		if base.Offset >= diff.Offset && base.Offset+base.Length <= diff.Offset+diff.Length {
			bi++
			continue
		}

		// Diff fully inside base — split base around diff.
		if diff.Offset >= base.Offset && diff.Offset+diff.Length <= base.Offset+base.Length {
			leftLen := int64(diff.Offset) - int64(base.Offset)
			if leftLen > 0 {
				result = append(result, &BuildMap{
					Offset:             base.Offset,
					Length:             uint64(leftLen),
					BuildID:            base.BuildID,
					BuildStorageOffset: base.BuildStorageOffset,
				})
			}

			result = append(result, diff)
			di++

			// rightShift is how many bytes of base lie at or before the
			// diff's end; the remainder of base starts that far in, both in
			// address space and in its diff file.
			rightShift := int64(diff.Offset) + int64(diff.Length) - int64(base.Offset)
			rightLen := int64(base.Length) - rightShift

			if rightLen > 0 {
				baseCopy[bi] = &BuildMap{
					Offset:             base.Offset + uint64(rightShift),
					Length:             uint64(rightLen),
					BuildID:            base.BuildID,
					BuildStorageOffset: base.BuildStorageOffset + uint64(rightShift),
				}
			} else {
				bi++
			}
			continue
		}

		// Base starts after diff with overlap — emit diff, trim base.
		if base.Offset > diff.Offset {
			result = append(result, diff)
			di++

			rightShift := int64(diff.Offset) + int64(diff.Length) - int64(base.Offset)
			rightLen := int64(base.Length) - rightShift

			if rightLen > 0 {
				baseCopy[bi] = &BuildMap{
					Offset:             base.Offset + uint64(rightShift),
					Length:             uint64(rightLen),
					BuildID:            base.BuildID,
					BuildStorageOffset: base.BuildStorageOffset + uint64(rightShift),
				}
			} else {
				bi++
			}
			continue
		}

		// Diff starts after base with overlap — emit left part of base.
		// The diff reaches past the end of base here (the "diff fully inside
		// base" case was handled above), so base has no right-hand remainder.
		if diff.Offset > base.Offset {
			leftLen := int64(diff.Offset) - int64(base.Offset)
			if leftLen > 0 {
				result = append(result, &BuildMap{
					Offset:             base.Offset,
					Length:             uint64(leftLen),
					BuildID:            base.BuildID,
					BuildStorageOffset: base.BuildStorageOffset,
				})
			}
			bi++
			continue
		}
	}

	// Append remaining entries.
	// At most one of the two tails is non-empty, since the loop ends when
	// either cursor reaches the end.
	result = append(result, baseCopy[bi:]...)
	result = append(result, diffMapping[di:]...)

	return result
}
|
||||
|
||||
// NormalizeMappings merges adjacent entries with the same BuildID.
|
||||
func NormalizeMappings(mappings []*BuildMap) []*BuildMap {
|
||||
if len(mappings) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := make([]*BuildMap, 0, len(mappings))
|
||||
current := &BuildMap{
|
||||
Offset: mappings[0].Offset,
|
||||
Length: mappings[0].Length,
|
||||
BuildID: mappings[0].BuildID,
|
||||
BuildStorageOffset: mappings[0].BuildStorageOffset,
|
||||
}
|
||||
|
||||
for i := 1; i < len(mappings); i++ {
|
||||
m := mappings[i]
|
||||
if m.BuildID == current.BuildID {
|
||||
current.Length += m.Length
|
||||
} else {
|
||||
result = append(result, current)
|
||||
current = &BuildMap{
|
||||
Offset: m.Offset,
|
||||
Length: m.Length,
|
||||
BuildID: m.BuildID,
|
||||
BuildStorageOffset: m.BuildStorageOffset,
|
||||
}
|
||||
}
|
||||
}
|
||||
result = append(result, current)
|
||||
|
||||
return result
|
||||
}
|
||||
@ -1,285 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// Modifications by M/S Omukk
|
||||
|
||||
package snapshot
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultBlockSize is 4KB — standard page size for Firecracker.
|
||||
DefaultBlockSize int64 = 4096
|
||||
)
|
||||
|
||||
// ProcessMemfile reads a full memory file produced by Firecracker's
|
||||
// PUT /snapshot/create, identifies non-zero blocks, and writes only those
|
||||
// blocks to a compact diff file. Returns the Header describing the mapping.
|
||||
//
|
||||
// The output diff file contains non-zero blocks written sequentially.
|
||||
// The header maps each block in the full address space to either:
|
||||
// - A position in the diff file (for non-zero blocks)
|
||||
// - uuid.Nil (for zero/empty blocks, served as zeros without I/O)
|
||||
//
|
||||
// buildID identifies this snapshot generation in the header chain.
|
||||
func ProcessMemfile(memfilePath, diffPath, headerPath string, buildID uuid.UUID) (*Header, error) {
|
||||
src, err := os.Open(memfilePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open memfile: %w", err)
|
||||
}
|
||||
defer src.Close()
|
||||
|
||||
info, err := src.Stat()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("stat memfile: %w", err)
|
||||
}
|
||||
memSize := info.Size()
|
||||
|
||||
dst, err := os.Create(diffPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create diff file: %w", err)
|
||||
}
|
||||
defer dst.Close()
|
||||
|
||||
totalBlocks := TotalBlocks(memSize, DefaultBlockSize)
|
||||
dirty := make([]bool, totalBlocks)
|
||||
empty := make([]bool, totalBlocks)
|
||||
buf := make([]byte, DefaultBlockSize)
|
||||
|
||||
for i := int64(0); i < totalBlocks; i++ {
|
||||
n, err := io.ReadFull(src, buf)
|
||||
if err != nil && err != io.ErrUnexpectedEOF {
|
||||
return nil, fmt.Errorf("read block %d: %w", i, err)
|
||||
}
|
||||
|
||||
// Zero-pad the last block if it's short.
|
||||
if int64(n) < DefaultBlockSize {
|
||||
for j := n; j < int(DefaultBlockSize); j++ {
|
||||
buf[j] = 0
|
||||
}
|
||||
}
|
||||
|
||||
if isZeroBlock(buf) {
|
||||
empty[i] = true
|
||||
continue
|
||||
}
|
||||
|
||||
dirty[i] = true
|
||||
if _, err := dst.Write(buf); err != nil {
|
||||
return nil, fmt.Errorf("write diff block %d: %w", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Build header.
|
||||
dirtyMappings := CreateMapping(buildID, dirty, DefaultBlockSize)
|
||||
emptyMappings := CreateMapping(uuid.Nil, empty, DefaultBlockSize)
|
||||
merged := MergeMappings(dirtyMappings, emptyMappings)
|
||||
normalized := NormalizeMappings(merged)
|
||||
|
||||
metadata := NewMetadata(buildID, uint64(DefaultBlockSize), uint64(memSize))
|
||||
header, err := NewHeader(metadata, normalized)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create header: %w", err)
|
||||
}
|
||||
|
||||
// Write header to disk.
|
||||
headerData, err := Serialize(metadata, normalized)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("serialize header: %w", err)
|
||||
}
|
||||
if err := os.WriteFile(headerPath, headerData, 0644); err != nil {
|
||||
return nil, fmt.Errorf("write header: %w", err)
|
||||
}
|
||||
|
||||
return header, nil
|
||||
}
|
||||
|
||||
// ProcessMemfileWithParent processes a memory file as a new generation on top
|
||||
// of an existing parent header. The new diff file contains only blocks that
|
||||
// differ from what the parent header maps. This is used for re-pause of a
|
||||
// sandbox that was restored from a snapshot.
|
||||
func ProcessMemfileWithParent(memfilePath, diffPath, headerPath string, parentHeader *Header, buildID uuid.UUID) (*Header, error) {
|
||||
src, err := os.Open(memfilePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open memfile: %w", err)
|
||||
}
|
||||
defer src.Close()
|
||||
|
||||
info, err := src.Stat()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("stat memfile: %w", err)
|
||||
}
|
||||
memSize := info.Size()
|
||||
|
||||
dst, err := os.Create(diffPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create diff file: %w", err)
|
||||
}
|
||||
defer dst.Close()
|
||||
|
||||
totalBlocks := TotalBlocks(memSize, DefaultBlockSize)
|
||||
dirty := make([]bool, totalBlocks)
|
||||
buf := make([]byte, DefaultBlockSize)
|
||||
|
||||
for i := int64(0); i < totalBlocks; i++ {
|
||||
n, err := io.ReadFull(src, buf)
|
||||
if err != nil && err != io.ErrUnexpectedEOF {
|
||||
return nil, fmt.Errorf("read block %d: %w", i, err)
|
||||
}
|
||||
|
||||
if int64(n) < DefaultBlockSize {
|
||||
for j := n; j < int(DefaultBlockSize); j++ {
|
||||
buf[j] = 0
|
||||
}
|
||||
}
|
||||
|
||||
if isZeroBlock(buf) {
|
||||
// For a diff memfile, zero blocks mean "not dirtied since resume" —
|
||||
// they should inherit the parent's mapping, not be zero-filled.
|
||||
continue
|
||||
}
|
||||
|
||||
dirty[i] = true
|
||||
if _, err := dst.Write(buf); err != nil {
|
||||
return nil, fmt.Errorf("write diff block %d: %w", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Only dirty blocks go into the diff overlay; MergeMappings preserves the
|
||||
// parent's mapping for everything else.
|
||||
dirtyMappings := CreateMapping(buildID, dirty, DefaultBlockSize)
|
||||
merged := MergeMappings(parentHeader.Mapping, dirtyMappings)
|
||||
normalized := NormalizeMappings(merged)
|
||||
|
||||
metadata := parentHeader.Metadata.NextGeneration(buildID)
|
||||
header, err := NewHeader(metadata, normalized)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create header: %w", err)
|
||||
}
|
||||
|
||||
headerData, err := Serialize(metadata, normalized)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("serialize header: %w", err)
|
||||
}
|
||||
if err := os.WriteFile(headerPath, headerData, 0644); err != nil {
|
||||
return nil, fmt.Errorf("write header: %w", err)
|
||||
}
|
||||
|
||||
return header, nil
|
||||
}
|
||||
|
||||
// MergeDiffs consolidates multiple generation diff files into a single diff
|
||||
// file and resets the generation counter to 0. This is a pure file-level
|
||||
// operation — no Firecracker involvement.
|
||||
//
|
||||
// It reads each non-nil block from the appropriate diff file (as mapped by
|
||||
// the header), writes them all sequentially into a single new diff file,
|
||||
// and produces a fresh header pointing only at that file.
|
||||
//
|
||||
// diffFiles maps build ID (string) → open file path for each generation's diff.
|
||||
func MergeDiffs(header *Header, diffFiles map[string]string, mergedDiffPath, headerPath string) (*Header, error) {
|
||||
blockSize := int64(header.Metadata.BlockSize)
|
||||
mergedBuildID := uuid.New()
|
||||
|
||||
// Open all source diff files.
|
||||
sources := make(map[string]*os.File, len(diffFiles))
|
||||
for id, path := range diffFiles {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
// Close already opened files.
|
||||
for _, sf := range sources {
|
||||
sf.Close()
|
||||
}
|
||||
return nil, fmt.Errorf("open diff file for build %s: %w", id, err)
|
||||
}
|
||||
sources[id] = f
|
||||
}
|
||||
defer func() {
|
||||
for _, f := range sources {
|
||||
f.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
dst, err := os.Create(mergedDiffPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create merged diff file: %w", err)
|
||||
}
|
||||
defer dst.Close()
|
||||
|
||||
totalBlocks := TotalBlocks(int64(header.Metadata.Size), blockSize)
|
||||
dirty := make([]bool, totalBlocks)
|
||||
empty := make([]bool, totalBlocks)
|
||||
buf := make([]byte, blockSize)
|
||||
|
||||
for i := int64(0); i < totalBlocks; i++ {
|
||||
offset := i * blockSize
|
||||
mappedOffset, _, buildID, err := header.GetShiftedMapping(context.Background(), offset)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("lookup block %d: %w", i, err)
|
||||
}
|
||||
|
||||
if *buildID == uuid.Nil {
|
||||
empty[i] = true
|
||||
continue
|
||||
}
|
||||
|
||||
src, ok := sources[buildID.String()]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no diff file for build %s (block %d)", buildID, i)
|
||||
}
|
||||
|
||||
if _, err := src.ReadAt(buf, mappedOffset); err != nil {
|
||||
return nil, fmt.Errorf("read block %d from build %s: %w", i, buildID, err)
|
||||
}
|
||||
|
||||
dirty[i] = true
|
||||
if _, err := dst.Write(buf); err != nil {
|
||||
return nil, fmt.Errorf("write merged block %d: %w", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Build fresh header with generation 0.
|
||||
dirtyMappings := CreateMapping(mergedBuildID, dirty, blockSize)
|
||||
emptyMappings := CreateMapping(uuid.Nil, empty, blockSize)
|
||||
merged := MergeMappings(dirtyMappings, emptyMappings)
|
||||
normalized := NormalizeMappings(merged)
|
||||
|
||||
metadata := NewMetadata(mergedBuildID, uint64(blockSize), header.Metadata.Size)
|
||||
newHeader, err := NewHeader(metadata, normalized)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create merged header: %w", err)
|
||||
}
|
||||
|
||||
headerData, err := Serialize(metadata, normalized)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("serialize merged header: %w", err)
|
||||
}
|
||||
if err := os.WriteFile(headerPath, headerData, 0644); err != nil {
|
||||
return nil, fmt.Errorf("write merged header: %w", err)
|
||||
}
|
||||
|
||||
return newHeader, nil
|
||||
}
|
||||
|
||||
// isZeroBlock reports whether every byte of block is zero. An empty or nil
// block is considered zero.
func isZeroBlock(block []byte) bool {
	for _, b := range block {
		if b != 0 {
			return false
		}
	}
	return true
}
|
||||
@ -1,92 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// Modifications by M/S Omukk
|
||||
|
||||
// Package uffd implements a userfaultfd-based memory server for Firecracker
|
||||
// snapshot restore. When a VM is restored from a snapshot, instead of loading
|
||||
// the entire memory file upfront, the UFFD handler intercepts page faults
|
||||
// and serves memory pages on demand from the snapshot's compact diff file.
|
||||
package uffd
|
||||
|
||||
/*
|
||||
#include <sys/syscall.h>
|
||||
#include <fcntl.h>
|
||||
#include <linux/userfaultfd.h>
|
||||
#include <sys/ioctl.h>
|
||||
|
||||
struct uffd_pagefault {
|
||||
__u64 flags;
|
||||
__u64 address;
|
||||
__u32 ptid;
|
||||
};
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
const (
|
||||
UFFD_EVENT_PAGEFAULT = C.UFFD_EVENT_PAGEFAULT
|
||||
UFFD_EVENT_FORK = C.UFFD_EVENT_FORK
|
||||
UFFD_EVENT_REMAP = C.UFFD_EVENT_REMAP
|
||||
UFFD_EVENT_REMOVE = C.UFFD_EVENT_REMOVE
|
||||
UFFD_EVENT_UNMAP = C.UFFD_EVENT_UNMAP
|
||||
UFFD_PAGEFAULT_FLAG_WRITE = C.UFFD_PAGEFAULT_FLAG_WRITE
|
||||
UFFDIO_COPY = C.UFFDIO_COPY
|
||||
UFFDIO_COPY_MODE_WP = C.UFFDIO_COPY_MODE_WP
|
||||
)
|
||||
|
||||
type (
|
||||
uffdMsg = C.struct_uffd_msg
|
||||
uffdPagefault = C.struct_uffd_pagefault
|
||||
uffdioCopy = C.struct_uffdio_copy
|
||||
)
|
||||
|
||||
// fd wraps a userfaultfd file descriptor received from Firecracker.
|
||||
type fd uintptr
|
||||
|
||||
// copy installs a page into guest memory at the given address using UFFDIO_COPY.
|
||||
// mode controls write-protection: use UFFDIO_COPY_MODE_WP to preserve WP bit.
|
||||
func (f fd) copy(addr, pagesize uintptr, data []byte, mode C.ulonglong) error {
|
||||
alignedAddr := addr &^ (pagesize - 1)
|
||||
cpy := uffdioCopy{
|
||||
src: C.ulonglong(uintptr(unsafe.Pointer(&data[0]))),
|
||||
dst: C.ulonglong(alignedAddr),
|
||||
len: C.ulonglong(pagesize),
|
||||
mode: mode,
|
||||
copy: 0,
|
||||
}
|
||||
|
||||
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(f), UFFDIO_COPY, uintptr(unsafe.Pointer(&cpy)))
|
||||
if errno != 0 {
|
||||
return errno
|
||||
}
|
||||
|
||||
if cpy.copy != C.longlong(pagesize) {
|
||||
return fmt.Errorf("UFFDIO_COPY copied %d bytes, expected %d", cpy.copy, pagesize)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// close closes the userfaultfd file descriptor.
|
||||
func (f fd) close() error {
|
||||
return syscall.Close(int(f))
|
||||
}
|
||||
|
||||
// getMsgEvent extracts the event type from a uffd_msg.
|
||||
func getMsgEvent(msg *uffdMsg) C.uchar {
|
||||
return msg.event
|
||||
}
|
||||
|
||||
// getMsgArg extracts the arg union from a uffd_msg.
|
||||
func getMsgArg(msg *uffdMsg) [24]byte {
|
||||
return msg.arg
|
||||
}
|
||||
|
||||
// getPagefaultAddress extracts the faulting address from a uffd_pagefault.
|
||||
func getPagefaultAddress(pf *uffdPagefault) uintptr {
|
||||
return uintptr(pf.address)
|
||||
}
|
||||
@ -1,41 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// Modifications by M/S Omukk
|
||||
//
|
||||
// Modifications by Omukk (Wrenn Sandbox): merged Region and Mapping into
|
||||
// single file, inlined shiftedOffset helper.
|
||||
|
||||
package uffd
|
||||
|
||||
import "fmt"
|
||||
|
||||
// Region is a mapping of guest memory to host virtual address space.
// Firecracker sends these as JSON when connecting to the UFFD socket.
// The JSON field names match Firecracker's UFFD protocol.
type Region struct {
	BaseHostVirtAddr uintptr `json:"base_host_virt_addr"`
	Size             uintptr `json:"size"`
	Offset           uintptr `json:"offset"`
	PageSize         uintptr `json:"page_size_kib"` // Actually in bytes despite the name.
}

// Mapping translates between host virtual addresses and logical memory offsets.
type Mapping struct {
	Regions []Region
}

// NewMapping creates a Mapping from a list of regions. The slice is stored
// as-is, not copied.
func NewMapping(regions []Region) *Mapping {
	return &Mapping{Regions: regions}
}

// GetOffset converts a host virtual address to a logical memory file offset
// and returns the page size of the region containing it. This is called on
// every UFFD page fault.
//
// Regions are searched in order and the first one whose half-open range
// [BaseHostVirtAddr, BaseHostVirtAddr+Size) contains hostVirtAddr wins. An
// error is returned when no region contains the address.
func (m *Mapping) GetOffset(hostVirtAddr uintptr) (int64, uintptr, error) {
	for _, r := range m.Regions {
		if hostVirtAddr < r.BaseHostVirtAddr {
			continue
		}
		// Compare the distance from the base against the size instead of
		// computing base+size, which wraps around for a region that ends at
		// the top of the address space and would make it unmatchable.
		delta := hostVirtAddr - r.BaseHostVirtAddr
		if delta >= r.Size {
			continue
		}
		return int64(delta) + int64(r.Offset), r.PageSize, nil
	}
	return 0, 0, fmt.Errorf("address %#x not found in any memory region", hostVirtAddr)
}
|
||||
@ -1,451 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// Modifications by M/S Omukk
|
||||
//
|
||||
// Modifications by Omukk (Wrenn Sandbox): replaced errgroup with WaitGroup
|
||||
// + semaphore, replaced fdexit abstraction with pipe, integrated with
|
||||
// snapshot.Header-based DiffFileSource instead of block.ReadonlyDevice,
|
||||
// fixed EAGAIN handling in poll loop.
|
||||
|
||||
package uffd
|
||||
|
||||
/*
|
||||
#include <linux/userfaultfd.h>
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/internal/snapshot"
|
||||
)
|
||||
|
||||
const (
	// fdSize is the payload size passed to CmsgSpace when sizing the
	// control-message buffer that receives the uffd descriptor.
	fdSize = 4
	// regionMappingsSize is the size of the buffer the region JSON is read into.
	regionMappingsSize = 1024
	// maxConcurrentFaults caps the number of page-fault goroutines in flight.
	maxConcurrentFaults = 4096
)
|
||||
|
||||
// MemorySource is where the UFFD handler gets page contents from.
type MemorySource interface {
	// ReadPage returns the data for the page at the given logical memory
	// offset; size is the number of bytes requested.
	ReadPage(ctx context.Context, offset int64, size int64) ([]byte, error)
}
|
||||
|
||||
// Server manages the UFFD Unix socket lifecycle and page fault handling
|
||||
// for a single Firecracker snapshot restore.
|
||||
type Server struct {
|
||||
socketPath string
|
||||
source MemorySource
|
||||
lis *net.UnixListener
|
||||
|
||||
readyCh chan struct{}
|
||||
readyOnce sync.Once
|
||||
doneCh chan struct{}
|
||||
doneErr error
|
||||
|
||||
// exitPipe signals the poll loop to stop.
|
||||
exitR *os.File
|
||||
exitW *os.File
|
||||
|
||||
// Set by handle() after Firecracker connects; read by Prefetch()
|
||||
// after waiting on readyCh (which establishes happens-before).
|
||||
uffdFd fd
|
||||
mapping *Mapping
|
||||
|
||||
// Prefetch lifecycle: cancel stops the goroutine, prefetchDone is
|
||||
// closed when it exits. Stop() drains prefetchDone before returning
|
||||
// so the caller can safely close diff file handles.
|
||||
prefetchCancel context.CancelFunc
|
||||
prefetchDone chan struct{}
|
||||
}
|
||||
|
||||
// NewServer creates a UFFD server that will listen on the given socket path
|
||||
// and serve memory pages from the given source.
|
||||
func NewServer(socketPath string, source MemorySource) *Server {
|
||||
return &Server{
|
||||
socketPath: socketPath,
|
||||
source: source,
|
||||
readyCh: make(chan struct{}),
|
||||
doneCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins listening on the Unix socket. Firecracker will connect to this
|
||||
// socket after loadSnapshot is called with the UFFD backend.
|
||||
// Start returns immediately; the server runs in a background goroutine.
|
||||
func (s *Server) Start(ctx context.Context) error {
|
||||
lis, err := net.ListenUnix("unix", &net.UnixAddr{Name: s.socketPath, Net: "unix"})
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen on uffd socket: %w", err)
|
||||
}
|
||||
s.lis = lis
|
||||
|
||||
if err := os.Chmod(s.socketPath, 0o777); err != nil {
|
||||
lis.Close()
|
||||
return fmt.Errorf("chmod uffd socket: %w", err)
|
||||
}
|
||||
|
||||
// Create exit signal pipe.
|
||||
r, w, err := os.Pipe()
|
||||
if err != nil {
|
||||
lis.Close()
|
||||
return fmt.Errorf("create exit pipe: %w", err)
|
||||
}
|
||||
s.exitR = r
|
||||
s.exitW = w
|
||||
|
||||
go func() {
|
||||
defer close(s.doneCh)
|
||||
s.doneErr = s.handle(ctx)
|
||||
s.lis.Close()
|
||||
s.exitR.Close()
|
||||
s.exitW.Close()
|
||||
s.readyOnce.Do(func() { close(s.readyCh) })
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ready returns a channel that is closed when the UFFD handler is ready
|
||||
// (after Firecracker has connected and sent the uffd fd).
|
||||
func (s *Server) Ready() <-chan struct{} {
|
||||
return s.readyCh
|
||||
}
|
||||
|
||||
// Stop signals the UFFD poll loop to exit and waits for it to finish.
|
||||
// Also cancels and waits for any running prefetch goroutine.
|
||||
func (s *Server) Stop() error {
|
||||
if s.prefetchCancel != nil {
|
||||
s.prefetchCancel()
|
||||
}
|
||||
// Write a byte to the exit pipe to wake the poll loop.
|
||||
_, _ = s.exitW.Write([]byte{0})
|
||||
<-s.doneCh
|
||||
if s.prefetchDone != nil {
|
||||
<-s.prefetchDone
|
||||
}
|
||||
return s.doneErr
|
||||
}
|
||||
|
||||
// Wait blocks until the server exits.
|
||||
func (s *Server) Wait() error {
|
||||
<-s.doneCh
|
||||
return s.doneErr
|
||||
}
|
||||
|
||||
// handle accepts the Firecracker connection, receives the UFFD fd via
|
||||
// SCM_RIGHTS, and runs the page fault poll loop.
|
||||
func (s *Server) handle(ctx context.Context) error {
|
||||
conn, err := s.lis.Accept()
|
||||
if err != nil {
|
||||
return fmt.Errorf("accept uffd connection: %w", err)
|
||||
}
|
||||
|
||||
unixConn := conn.(*net.UnixConn)
|
||||
defer unixConn.Close()
|
||||
|
||||
// Read the memory region mappings (JSON) and the UFFD fd (SCM_RIGHTS).
|
||||
regionBuf := make([]byte, regionMappingsSize)
|
||||
uffdBuf := make([]byte, syscall.CmsgSpace(fdSize))
|
||||
|
||||
nRegion, nFd, _, _, err := unixConn.ReadMsgUnix(regionBuf, uffdBuf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read uffd message: %w", err)
|
||||
}
|
||||
|
||||
var regions []Region
|
||||
if err := json.Unmarshal(regionBuf[:nRegion], ®ions); err != nil {
|
||||
return fmt.Errorf("parse memory regions: %w", err)
|
||||
}
|
||||
|
||||
controlMsgs, err := syscall.ParseSocketControlMessage(uffdBuf[:nFd])
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse control messages: %w", err)
|
||||
}
|
||||
if len(controlMsgs) != 1 {
|
||||
return fmt.Errorf("expected 1 control message, got %d", len(controlMsgs))
|
||||
}
|
||||
|
||||
fds, err := syscall.ParseUnixRights(&controlMsgs[0])
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse unix rights: %w", err)
|
||||
}
|
||||
if len(fds) != 1 {
|
||||
return fmt.Errorf("expected 1 fd, got %d", len(fds))
|
||||
}
|
||||
|
||||
uffdFd := fd(fds[0])
|
||||
defer uffdFd.close()
|
||||
|
||||
mapping := NewMapping(regions)
|
||||
|
||||
// Store for use by Prefetch().
|
||||
s.uffdFd = uffdFd
|
||||
s.mapping = mapping
|
||||
|
||||
slog.Info("uffd handler connected",
|
||||
"regions", len(regions),
|
||||
"fd", int(uffdFd),
|
||||
)
|
||||
|
||||
// Signal readiness.
|
||||
s.readyOnce.Do(func() { close(s.readyCh) })
|
||||
|
||||
// Run the poll loop.
|
||||
return s.serve(ctx, uffdFd, mapping)
|
||||
}
|
||||
|
||||
// serve is the main poll loop. It polls the UFFD fd for page fault events
|
||||
// and the exit pipe for shutdown signals.
|
||||
func (s *Server) serve(ctx context.Context, uffdFd fd, mapping *Mapping) error {
|
||||
pollFds := []unix.PollFd{
|
||||
{Fd: int32(uffdFd), Events: unix.POLLIN},
|
||||
{Fd: int32(s.exitR.Fd()), Events: unix.POLLIN},
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
sem := make(chan struct{}, maxConcurrentFaults)
|
||||
|
||||
// Always wait for in-flight goroutines before returning, so the caller
|
||||
// can safely close the uffd fd after serve returns.
|
||||
defer wg.Wait()
|
||||
|
||||
for {
|
||||
if _, err := unix.Poll(pollFds, -1); err != nil {
|
||||
if err == unix.EINTR || err == unix.EAGAIN {
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("poll: %w", err)
|
||||
}
|
||||
|
||||
// Check exit signal.
|
||||
if pollFds[1].Revents&unix.POLLIN != 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if pollFds[0].Revents&unix.POLLIN == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Read the uffd_msg. The fd is O_NONBLOCK (set by Firecracker),
|
||||
// so EAGAIN is expected — just go back to poll.
|
||||
buf := make([]byte, unsafe.Sizeof(uffdMsg{}))
|
||||
n, err := readUffdMsg(uffdFd, buf)
|
||||
if err == syscall.EAGAIN {
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("read uffd msg: %w", err)
|
||||
}
|
||||
if n == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
msg := *(*uffdMsg)(unsafe.Pointer(&buf[0]))
|
||||
event := getMsgEvent(&msg)
|
||||
|
||||
switch event {
|
||||
case UFFD_EVENT_PAGEFAULT:
|
||||
// Handled below.
|
||||
case UFFD_EVENT_REMOVE, UFFD_EVENT_UNMAP, UFFD_EVENT_REMAP, UFFD_EVENT_FORK:
|
||||
// Non-fatal lifecycle events from the guest kernel (e.g. balloon
|
||||
// deflation, mmap/munmap). No action needed — continue polling.
|
||||
continue
|
||||
default:
|
||||
return fmt.Errorf("unexpected uffd event type: %d", event)
|
||||
}
|
||||
|
||||
arg := getMsgArg(&msg)
|
||||
pf := *(*uffdPagefault)(unsafe.Pointer(&arg[0]))
|
||||
addr := getPagefaultAddress(&pf)
|
||||
|
||||
offset, pagesize, err := mapping.GetOffset(addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolve address %#x: %w", addr, err)
|
||||
}
|
||||
|
||||
sem <- struct{}{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer func() { <-sem }()
|
||||
|
||||
if err := s.faultPage(ctx, uffdFd, addr, offset, pagesize); err != nil {
|
||||
slog.Error("uffd fault page error",
|
||||
"addr", fmt.Sprintf("%#x", addr),
|
||||
"offset", offset,
|
||||
"error", err,
|
||||
)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// readUffdMsg reads a single uffd_msg, retrying on EINTR.
|
||||
// Returns (n, EAGAIN) if the non-blocking read has nothing available.
|
||||
func readUffdMsg(uffdFd fd, buf []byte) (int, error) {
|
||||
for {
|
||||
n, err := syscall.Read(int(uffdFd), buf)
|
||||
if err == syscall.EINTR {
|
||||
continue
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
|
||||
// faultPage fetches a page from the memory source and copies it into
|
||||
// guest memory via UFFDIO_COPY.
|
||||
func (s *Server) faultPage(ctx context.Context, uffdFd fd, addr uintptr, offset int64, pagesize uintptr) error {
|
||||
data, err := s.source.ReadPage(ctx, offset, int64(pagesize))
|
||||
if err != nil {
|
||||
return fmt.Errorf("read page at offset %d: %w", offset, err)
|
||||
}
|
||||
|
||||
// Mode 0: no write-protect. Standard Firecracker does not register
|
||||
// UFFD ranges with WP support, so UFFDIO_COPY_MODE_WP would fail.
|
||||
if err := uffdFd.copy(addr, pagesize, data, 0); err != nil {
|
||||
if errors.Is(err, unix.EEXIST) {
|
||||
// Page already mapped (race with prefetch or concurrent fault).
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("uffdio_copy: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Prefetch proactively loads all guest memory pages in the background.
|
||||
// It iterates over every page in every UFFD region and copies it from the
|
||||
// diff file into guest memory via UFFDIO_COPY. Pages already loaded by
|
||||
// on-demand faults return nil from faultPage (EEXIST handled internally).
|
||||
// This eliminates the per-request latency caused by lazy page faulting
|
||||
// after snapshot restore.
|
||||
//
|
||||
// The goroutine blocks on readyCh before reading the uffd fd and mapping
|
||||
// fields (establishes happens-before with handle()). It uses an internal
|
||||
// context independent of the caller's RPC context so it survives after the
|
||||
// create/resume RPC returns. Stop() cancels and joins the goroutine.
|
||||
func (s *Server) Prefetch() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
s.prefetchCancel = cancel
|
||||
s.prefetchDone = make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(s.prefetchDone)
|
||||
|
||||
// Wait for Firecracker to connect and send the uffd fd.
|
||||
select {
|
||||
case <-s.readyCh:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
|
||||
uffdFd := s.uffdFd
|
||||
mapping := s.mapping
|
||||
if mapping == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var total, errored int
|
||||
for _, region := range mapping.Regions {
|
||||
pageSize := region.PageSize
|
||||
if pageSize == 0 {
|
||||
continue
|
||||
}
|
||||
for off := uintptr(0); off < region.Size; off += pageSize {
|
||||
if ctx.Err() != nil {
|
||||
slog.Debug("uffd prefetch cancelled",
|
||||
"pages", total, "errors", errored)
|
||||
return
|
||||
}
|
||||
|
||||
addr := region.BaseHostVirtAddr + off
|
||||
memOffset := int64(off) + int64(region.Offset)
|
||||
|
||||
if err := s.faultPage(ctx, uffdFd, addr, memOffset, pageSize); err != nil {
|
||||
errored++
|
||||
} else {
|
||||
total++
|
||||
}
|
||||
}
|
||||
}
|
||||
slog.Info("uffd prefetch complete",
|
||||
"pages", total, "errors", errored)
|
||||
}()
|
||||
}
|
||||
|
||||
// DiffFileSource serves pages from a snapshot's compact diff file using
|
||||
// the header's block mapping to resolve offsets.
|
||||
type DiffFileSource struct {
|
||||
header *snapshot.Header
|
||||
// diffs maps build ID → open file handle for each generation's diff file.
|
||||
diffs map[string]*os.File
|
||||
}
|
||||
|
||||
// NewDiffFileSource creates a memory source backed by snapshot diff files.
|
||||
// diffs maps build ID string to the file path of each generation's diff file.
|
||||
func NewDiffFileSource(header *snapshot.Header, diffPaths map[string]string) (*DiffFileSource, error) {
|
||||
diffs := make(map[string]*os.File, len(diffPaths))
|
||||
for id, path := range diffPaths {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
// Close already opened files.
|
||||
for _, opened := range diffs {
|
||||
opened.Close()
|
||||
}
|
||||
return nil, fmt.Errorf("open diff file %s: %w", path, err)
|
||||
}
|
||||
diffs[id] = f
|
||||
}
|
||||
return &DiffFileSource{header: header, diffs: diffs}, nil
|
||||
}
|
||||
|
||||
// ReadPage resolves a memory offset through the header mapping and reads
|
||||
// the corresponding page from the correct generation's diff file.
|
||||
func (s *DiffFileSource) ReadPage(ctx context.Context, offset int64, size int64) ([]byte, error) {
|
||||
mappedOffset, _, buildID, err := s.header.GetShiftedMapping(ctx, offset)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolve offset %d: %w", offset, err)
|
||||
}
|
||||
|
||||
// uuid.Nil means zero-fill (empty page).
|
||||
var nilUUID [16]byte
|
||||
if *buildID == nilUUID {
|
||||
return make([]byte, size), nil
|
||||
}
|
||||
|
||||
f, ok := s.diffs[buildID.String()]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no diff file for build %s", buildID)
|
||||
}
|
||||
|
||||
buf := make([]byte, size)
|
||||
n, err := f.ReadAt(buf, mappedOffset)
|
||||
if err != nil && int64(n) < size {
|
||||
return nil, fmt.Errorf("read diff at offset %d: %w", mappedOffset, err)
|
||||
}
|
||||
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
// Close closes all open diff file handles.
|
||||
func (s *DiffFileSource) Close() error {
|
||||
var errs []error
|
||||
for _, f := range s.diffs {
|
||||
if err := f.Close(); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
232
internal/vm/ch.go
Normal file
232
internal/vm/ch.go
Normal file
@ -0,0 +1,232 @@
|
||||
package vm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// chClient talks to the Cloud Hypervisor HTTP API over a Unix socket.
type chClient struct {
	http       *http.Client
	socketPath string
}

// newCHClient returns a client whose HTTP transport dials socketPath for
// every connection, regardless of the host in the request URL.
func newCHClient(socketPath string) *chClient {
	return &chClient{
		socketPath: socketPath,
		http: &http.Client{
			Transport: &http.Transport{
				DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
					var d net.Dialer
					return d.DialContext(ctx, "unix", socketPath)
				},
			},
		},
	}
}

// do sends a request and discards the response body. See doJSON.
func (c *chClient) do(ctx context.Context, method, path string, body any) error {
	return c.doJSON(ctx, method, path, body, nil)
}

// doJSON sends a request and optionally decodes a JSON response into out.
// out may be nil if the response body should be discarded.
//
// A non-nil body is JSON-encoded and sent with Content-Type
// application/json. Any response status >= 300 is returned as an error that
// includes the status code and the response body text.
func (c *chClient) doJSON(ctx context.Context, method, path string, body, out any) error {
	var bodyReader io.Reader
	if body != nil {
		data, err := json.Marshal(body)
		if err != nil {
			return fmt.Errorf("marshal request body: %w", err)
		}
		bodyReader = bytes.NewReader(data)
	}

	// The host part is a placeholder; the transport always dials the socket.
	req, err := http.NewRequestWithContext(ctx, method, "http://localhost"+path, bodyReader)
	if err != nil {
		return fmt.Errorf("create request: %w", err)
	}
	if body != nil {
		req.Header.Set("Content-Type", "application/json")
	}

	resp, err := c.http.Do(req)
	if err != nil {
		return fmt.Errorf("%s %s: %w", method, path, err)
	}
	defer func() {
		// Read the body to EOF before closing so the transport can keep the
		// connection alive for the next call instead of tearing it down.
		_, _ = io.Copy(io.Discard, resp.Body)
		resp.Body.Close()
	}()

	if resp.StatusCode >= 300 {
		// Best effort: the body is only used to enrich the error message.
		respBody, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("%s %s: status %d: %s", method, path, resp.StatusCode, string(respBody))
	}

	if out != nil {
		if err := json.NewDecoder(resp.Body).Decode(out); err != nil {
			return fmt.Errorf("%s %s: decode response: %w", method, path, err)
		}
	}
	return nil
}
|
||||
|
||||
// boolPtr returns a pointer to a fresh copy of b.
func boolPtr(b bool) *bool {
	v := b
	return &v
}
|
||||
|
||||
// --- CH API payload types ---
|
||||
|
||||
// chPayload is the "payload" section of a vm.create request: what to boot.
type chPayload struct {
	Firmware string `json:"firmware,omitempty"`
	Kernel   string `json:"kernel"`
	Cmdline  string `json:"cmdline"`
}

// chCPUs is the "cpus" section of a vm.create request.
type chCPUs struct {
	BootVCPUs int `json:"boot_vcpus"`
	MaxVCPUs  int `json:"max_vcpus"`
}

// chMemory is the "memory" section of a vm.create request.
type chMemory struct {
	Size   uint64 `json:"size"`
	Shared bool   `json:"shared,omitempty"`
	// Thp uses a pointer with NO omitempty so explicit false is always
	// serialized (CH defaults to true). Must be false so the backing memfile
	// remains 4 KiB-granular: balloon-reported free pages get punched as
	// holes and CH's SEEK_DATA/SEEK_HOLE snapshot writer (v52+) skips them.
	// A nil Thp would silently re-enable THP and break sparse snapshots —
	// rejecting "thp": null at the wire is preferable to a silent fallback.
	Thp           *bool  `json:"thp"`
	Prefault      bool   `json:"prefault,omitempty"`
	HotplugSize   uint64 `json:"hotplug_size,omitempty"`
	HotplugMethod string `json:"hotplug_method,omitempty"`
}

// chDisk is one entry of the "disks" list.
type chDisk struct {
	Path      string `json:"path"`
	Readonly  bool   `json:"readonly,omitempty"`
	ImageType string `json:"image_type,omitempty"`
}

// chNet is one entry of the "net" list.
type chNet struct {
	Tap    string `json:"tap"`
	MAC    string `json:"mac"`
	NumQs  int    `json:"num_queues,omitempty"`
	QueueS int    `json:"queue_size,omitempty"`
}

// chBalloon is the optional "balloon" section.
type chBalloon struct {
	Size         int64 `json:"size"`
	DeflateOnOOM bool  `json:"deflate_on_oom"`
	FreePageRep  bool  `json:"free_page_reporting,omitempty"`
}

// chConsole selects the mode of a serial or console device.
type chConsole struct {
	Mode string `json:"mode"`
}

// chCreatePayload is the complete request body sent to vm.create.
type chCreatePayload struct {
	Payload chPayload  `json:"payload"`
	CPUs    chCPUs     `json:"cpus"`
	Memory  chMemory   `json:"memory"`
	Disks   []chDisk   `json:"disks"`
	Net     []chNet    `json:"net"`
	Balloon *chBalloon `json:"balloon,omitempty"`
	Serial  chConsole  `json:"serial"`
	Console chConsole  `json:"console"`
}
|
||||
|
||||
// createVM sends the full VM configuration as a single payload.
|
||||
func (c *chClient) createVM(ctx context.Context, cfg *VMConfig) error {
|
||||
memBytes := uint64(cfg.MemoryMB) * 1024 * 1024
|
||||
|
||||
payload := chCreatePayload{
|
||||
Payload: chPayload{
|
||||
Kernel: cfg.KernelPath,
|
||||
Cmdline: cfg.kernelArgs(),
|
||||
},
|
||||
CPUs: chCPUs{
|
||||
BootVCPUs: cfg.VCPUs,
|
||||
MaxVCPUs: cfg.VCPUs,
|
||||
},
|
||||
Memory: chMemory{
|
||||
Size: memBytes,
|
||||
Shared: true,
|
||||
Thp: boolPtr(false),
|
||||
},
|
||||
Disks: []chDisk{
|
||||
{
|
||||
Path: cfg.SandboxDir + "/rootfs.ext4",
|
||||
ImageType: "Raw",
|
||||
},
|
||||
},
|
||||
Net: []chNet{
|
||||
{
|
||||
Tap: cfg.TapDevice,
|
||||
MAC: cfg.TapMAC,
|
||||
},
|
||||
},
|
||||
Balloon: &chBalloon{
|
||||
Size: 0,
|
||||
DeflateOnOOM: true,
|
||||
FreePageRep: true,
|
||||
},
|
||||
Serial: chConsole{
|
||||
Mode: "Tty",
|
||||
},
|
||||
Console: chConsole{
|
||||
Mode: "Off",
|
||||
},
|
||||
}
|
||||
|
||||
return c.do(ctx, http.MethodPut, "/api/v1/vm.create", payload)
|
||||
}
|
||||
|
||||
// bootVM starts the VM after creation.
|
||||
func (c *chClient) bootVM(ctx context.Context) error {
|
||||
return c.do(ctx, http.MethodPut, "/api/v1/vm.boot", nil)
|
||||
}
|
||||
|
||||
// shutdownVMM cleanly shuts down the Cloud Hypervisor VMM process.
|
||||
func (c *chClient) shutdownVMM(ctx context.Context) error {
|
||||
return c.do(ctx, http.MethodPut, "/api/v1/vmm.shutdown", nil)
|
||||
}
|
||||
|
||||
// resizeBalloon adjusts the balloon target at runtime.
|
||||
// sizeBytes is memory to take FROM the guest (0 = give all back).
|
||||
func (c *chClient) resizeBalloon(ctx context.Context, sizeBytes int64) error {
|
||||
return c.do(ctx, http.MethodPut, "/api/v1/vm.resize", map[string]int64{
|
||||
"desired_balloon": sizeBytes,
|
||||
})
|
||||
}
|
||||
|
||||
// pauseVM freezes guest vCPUs and devices via the CH API.
|
||||
func (c *chClient) pauseVM(ctx context.Context) error {
|
||||
return c.do(ctx, http.MethodPut, "/api/v1/vm.pause", nil)
|
||||
}
|
||||
|
||||
// resumeVM unfreezes a paused VM via the CH API.
|
||||
func (c *chClient) resumeVM(ctx context.Context) error {
|
||||
return c.do(ctx, http.MethodPut, "/api/v1/vm.resume", nil)
|
||||
}
|
||||
|
||||
// snapshotVM dumps VM config + state + memory to a directory URL of the form
|
||||
// `file:///abs/path/`. VM must be paused before calling.
|
||||
func (c *chClient) snapshotVM(ctx context.Context, destURL string) error {
|
||||
return c.do(ctx, http.MethodPut, "/api/v1/vm.snapshot", map[string]string{
|
||||
"destination_url": destURL,
|
||||
})
|
||||
}
|
||||
|
||||
// vmInfo reports the runtime state of the VM. Used after a restore to confirm
|
||||
// CH successfully hydrated the snapshot before registering the VM.
|
||||
func (c *chClient) vmInfo(ctx context.Context) (state string, err error) {
|
||||
var resp struct {
|
||||
State string `json:"state"`
|
||||
}
|
||||
if err := c.doJSON(ctx, http.MethodGet, "/api/v1/vm.info", nil, &resp); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return resp.State, nil
|
||||
}
|
||||
104
internal/vm/cleanup.go
Normal file
104
internal/vm/cleanup.go
Normal file
@ -0,0 +1,104 @@
|
||||
package vm
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CleanupStaleProcesses kills any cloud-hypervisor processes left behind by a
|
||||
// previous agent that crashed without graceful shutdown. Must run at agent
|
||||
// startup before devicemapper.CleanupStaleDevices — a still-running CH process
|
||||
// holds the dm-snapshot open and would cause "Device or resource busy" on
|
||||
// dmsetup remove.
|
||||
//
|
||||
// Matches processes by argv containing the wrenn CH API socket path
|
||||
// (/tmp/ch-<sandboxID>.sock) so we don't kill unrelated cloud-hypervisor VMs
|
||||
// the operator may be running.
|
||||
//
|
||||
// Also removes stale /tmp/ch-*.sock files once the owning process is gone.
|
||||
func CleanupStaleProcesses() {
|
||||
socketPattern := regexp.MustCompile(`/tmp/ch-[A-Za-z0-9-]+\.sock`)
|
||||
|
||||
pids, err := scanProcs()
|
||||
if err != nil {
|
||||
slog.Debug("scan procs failed", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
killed := 0
|
||||
for _, pid := range pids {
|
||||
cmdline, err := readCmdline(pid)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if !strings.Contains(cmdline, "cloud-hypervisor") {
|
||||
continue
|
||||
}
|
||||
if !socketPattern.MatchString(cmdline) {
|
||||
continue
|
||||
}
|
||||
slog.Warn("killing stale cloud-hypervisor process", "pid", pid, "cmdline", cmdline)
|
||||
if err := syscall.Kill(pid, syscall.SIGTERM); err != nil {
|
||||
slog.Warn("SIGTERM stale CH failed", "pid", pid, "error", err)
|
||||
}
|
||||
killed++
|
||||
}
|
||||
|
||||
// Give SIGTERM'd processes a brief window to exit so subsequent dm/loop
|
||||
// teardown sees no open fd, then SIGKILL anything still alive.
|
||||
if killed > 0 {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
for _, pid := range pids {
|
||||
cmdline, err := readCmdline(pid)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if !strings.Contains(cmdline, "cloud-hypervisor") || !socketPattern.MatchString(cmdline) {
|
||||
continue
|
||||
}
|
||||
_ = syscall.Kill(pid, syscall.SIGKILL)
|
||||
}
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
matches, _ := filepath.Glob("/tmp/ch-*.sock")
|
||||
for _, sock := range matches {
|
||||
if err := os.Remove(sock); err == nil {
|
||||
slog.Info("removed stale CH socket", "path", sock)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// scanProcs lists /proc and returns the integer value of every directory
// entry whose name parses as a decimal number. Non-directories and entries
// with non-numeric names are skipped. The only error returned is a failure
// to read /proc itself.
func scanProcs() ([]int, error) {
	dirents, err := os.ReadDir("/proc")
	if err != nil {
		return nil, err
	}

	var found []int
	for _, ent := range dirents {
		if ent.IsDir() {
			if n, convErr := strconv.Atoi(ent.Name()); convErr == nil {
				found = append(found, n)
			}
		}
	}
	return found, nil
}
|
||||
|
||||
// readCmdline returns the contents of /proc/<pid>/cmdline with every NUL
// byte replaced by a space, so callers can run plain substring and regexp
// matches over the whole argv. A read failure yields an empty string and the
// underlying error.
func readCmdline(pid int) (string, error) {
	procFile := strings.Join([]string{"/proc", strconv.Itoa(pid), "cmdline"}, "/")

	raw, err := os.ReadFile(procFile)
	if err != nil {
		return "", err
	}

	// The kernel separates arguments with NUL bytes; swap each for a space.
	return strings.Replace(string(raw), "\x00", " ", -1), nil
}
|
||||
@ -2,13 +2,25 @@ package vm
|
||||
|
||||
import "fmt"
|
||||
|
||||
// VMConfig holds the configuration for creating a Firecracker microVM.
|
||||
// SandboxTmpDir builds the per-sandbox tmpfs mount point used inside the
// VMM's private mount namespace: "/tmp/ch-vm-<sandboxID>". CH records this
// as the disk path in its saved config.json, so restore paths have to
// reconstruct exactly the same string for the symlink prelude to resolve.
func SandboxTmpDir(sandboxID string) string {
	const layout = "/tmp/ch-vm-%s"
	return fmt.Sprintf(layout, sandboxID)
}
|
||||
|
||||
// SandboxSocketPath builds the Cloud Hypervisor API socket path for a
// sandbox: "/tmp/ch-<sandboxID>.sock".
func SandboxSocketPath(sandboxID string) string {
	const layout = "/tmp/ch-%s.sock"
	return fmt.Sprintf(layout, sandboxID)
}
|
||||
|
||||
// VMConfig holds the configuration for creating a Cloud Hypervisor microVM.
|
||||
type VMConfig struct {
|
||||
// SandboxID is the unique identifier for this sandbox (e.g., "cl-a1b2c3d4").
|
||||
SandboxID string
|
||||
|
||||
// TemplateID is the template UUID string used to populate MMDS metadata
|
||||
// so that envd can read WRENN_TEMPLATE_ID from inside the guest.
|
||||
// TemplateID is the template UUID string, passed to envd via PostInit.
|
||||
TemplateID string
|
||||
|
||||
// KernelPath is the path to the uncompressed Linux kernel (vmlinux).
|
||||
@ -25,12 +37,12 @@ type VMConfig struct {
|
||||
MemoryMB int
|
||||
|
||||
// NetworkNamespace is the name of the network namespace to launch
|
||||
// Firecracker inside (e.g., "ns-1"). The namespace must already exist
|
||||
// Cloud Hypervisor inside (e.g., "ns-1"). The namespace must already exist
|
||||
// with a TAP device configured.
|
||||
NetworkNamespace string
|
||||
|
||||
// TapDevice is the name of the TAP device inside the network namespace
|
||||
// that Firecracker will attach to (e.g., "tap0").
|
||||
// that Cloud Hypervisor will attach to (e.g., "tap0").
|
||||
TapDevice string
|
||||
|
||||
// TapMAC is the MAC address for the TAP device.
|
||||
@ -45,19 +57,34 @@ type VMConfig struct {
|
||||
// NetMask is the subnet mask for the guest network (e.g., "255.255.255.252").
|
||||
NetMask string
|
||||
|
||||
// FirecrackerBin is the path to the firecracker binary.
|
||||
FirecrackerBin string
|
||||
// VMMBin is the path to the cloud-hypervisor binary.
|
||||
VMMBin string
|
||||
|
||||
// SocketPath is the path for the Firecracker API Unix socket.
|
||||
// SocketPath is the path for the Cloud Hypervisor API Unix socket.
|
||||
SocketPath string
|
||||
|
||||
// SandboxDir is the tmpfs mount point for per-sandbox files inside the
|
||||
// mount namespace (e.g., "/fc-vm").
|
||||
// mount namespace (e.g., "/ch-vm").
|
||||
SandboxDir string
|
||||
|
||||
// InitPath is the path to the init process inside the guest.
|
||||
// Defaults to "/sbin/init" if empty.
|
||||
InitPath string
|
||||
|
||||
// RestoreFromDir, if non-empty, switches the process launcher into restore
|
||||
// mode. CH is invoked with `--restore source_url=file://{dir}/` instead of
|
||||
// the fresh-boot path. The directory must contain CH's snapshot artefacts
|
||||
// (config.json, state.json, memory-ranges, memory file).
|
||||
RestoreFromDir string
|
||||
|
||||
// RestoreLazyMemory enables `memory_restore_mode=ondemand` so guest pages
|
||||
// fault in lazily via userfaultfd. Only honored when RestoreFromDir is set.
|
||||
RestoreLazyMemory bool
|
||||
|
||||
// LogDir is the directory for Cloud Hypervisor log files. If set, CH
|
||||
// stdout/stderr are written to {LogDir}/ch-{SandboxID}.log instead of
|
||||
// the parent process's stdout/stderr.
|
||||
LogDir string
|
||||
}
|
||||
|
||||
func (c *VMConfig) applyDefaults() {
|
||||
@ -67,14 +94,14 @@ func (c *VMConfig) applyDefaults() {
|
||||
if c.MemoryMB == 0 {
|
||||
c.MemoryMB = 512
|
||||
}
|
||||
if c.FirecrackerBin == "" {
|
||||
c.FirecrackerBin = "/usr/local/bin/firecracker"
|
||||
if c.VMMBin == "" {
|
||||
c.VMMBin = "/usr/local/bin/cloud-hypervisor"
|
||||
}
|
||||
if c.SocketPath == "" {
|
||||
c.SocketPath = fmt.Sprintf("/tmp/fc-%s.sock", c.SandboxID)
|
||||
c.SocketPath = SandboxSocketPath(c.SandboxID)
|
||||
}
|
||||
if c.SandboxDir == "" {
|
||||
c.SandboxDir = "/tmp/fc-vm"
|
||||
c.SandboxDir = SandboxTmpDir(c.SandboxID)
|
||||
}
|
||||
if c.TapDevice == "" {
|
||||
c.TapDevice = "tap0"
|
||||
@ -95,7 +122,7 @@ func (c *VMConfig) kernelArgs() string {
|
||||
)
|
||||
|
||||
return fmt.Sprintf(
|
||||
"console=ttyS0 reboot=k panic=1 pci=off quiet loglevel=1 clocksource=kvm-clock init=%s %s",
|
||||
"console=ttyS0 root=/dev/vda rw rootflags=nodiscard reboot=k panic=1 quiet loglevel=1 init_on_free=1 clocksource=kvm-clock init=%s %s",
|
||||
c.InitPath, ipArg,
|
||||
)
|
||||
}
|
||||
|
||||
@ -1,202 +0,0 @@
|
||||
package vm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// fcClient talks to the Firecracker HTTP API over a Unix socket.
|
||||
type fcClient struct {
|
||||
http *http.Client
|
||||
socketPath string
|
||||
}
|
||||
|
||||
func newFCClient(socketPath string) *fcClient {
|
||||
return &fcClient{
|
||||
socketPath: socketPath,
|
||||
http: &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||
var d net.Dialer
|
||||
return d.DialContext(ctx, "unix", socketPath)
|
||||
},
|
||||
},
|
||||
// No global timeout — callers pass context.Context with appropriate
|
||||
// deadlines. A fixed 10s timeout was too short for snapshot/resume
|
||||
// operations on large-memory VMs (20GB+ memfiles).
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *fcClient) do(ctx context.Context, method, path string, body any) error {
|
||||
var bodyReader io.Reader
|
||||
if body != nil {
|
||||
data, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal request body: %w", err)
|
||||
}
|
||||
bodyReader = bytes.NewReader(data)
|
||||
}
|
||||
|
||||
// The host in the URL is ignored for Unix sockets; we use "localhost" by convention.
|
||||
req, err := http.NewRequestWithContext(ctx, method, "http://localhost"+path, bodyReader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
if body != nil {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
resp, err := c.http.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s %s: %w", method, path, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 300 {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("%s %s: status %d: %s", method, path, resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setBootSource configures the kernel and boot args.
|
||||
func (c *fcClient) setBootSource(ctx context.Context, kernelPath, bootArgs string) error {
|
||||
return c.do(ctx, http.MethodPut, "/boot-source", map[string]string{
|
||||
"kernel_image_path": kernelPath,
|
||||
"boot_args": bootArgs,
|
||||
})
|
||||
}
|
||||
|
||||
// setRootfsDrive configures the root filesystem drive.
|
||||
func (c *fcClient) setRootfsDrive(ctx context.Context, driveID, path string, readOnly bool) error {
|
||||
return c.do(ctx, http.MethodPut, "/drives/"+driveID, map[string]any{
|
||||
"drive_id": driveID,
|
||||
"path_on_host": path,
|
||||
"is_root_device": true,
|
||||
"is_read_only": readOnly,
|
||||
})
|
||||
}
|
||||
|
||||
// setNetworkInterface configures a network interface attached to a TAP device.
|
||||
// A tx_rate_limiter caps sustained guest→host throughput to prevent user
|
||||
// application traffic from completely saturating the TAP device and starving
|
||||
// envd control traffic (PTY, exec, file ops).
|
||||
func (c *fcClient) setNetworkInterface(ctx context.Context, ifaceID, tapName, macAddr string) error {
|
||||
return c.do(ctx, http.MethodPut, "/network-interfaces/"+ifaceID, map[string]any{
|
||||
"iface_id": ifaceID,
|
||||
"host_dev_name": tapName,
|
||||
"guest_mac": macAddr,
|
||||
"tx_rate_limiter": map[string]any{
|
||||
"bandwidth": map[string]any{
|
||||
"size": 209715200, // 200 MB/s sustained
|
||||
"refill_time": 1000, // refill period: 1 second
|
||||
"one_time_burst": 104857600, // 100 MB initial burst
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// setMachineConfig configures vCPUs, memory, and other machine settings.
|
||||
func (c *fcClient) setMachineConfig(ctx context.Context, vcpus, memMB int) error {
|
||||
return c.do(ctx, http.MethodPut, "/machine-config", map[string]any{
|
||||
"vcpu_count": vcpus,
|
||||
"mem_size_mib": memMB,
|
||||
"smt": false,
|
||||
})
|
||||
}
|
||||
|
||||
// setMMDSConfig enables MMDS V2 token-based access on the given network interface.
|
||||
// Must be called before startVM.
|
||||
func (c *fcClient) setMMDSConfig(ctx context.Context, ifaceID string) error {
|
||||
return c.do(ctx, http.MethodPut, "/mmds/config", map[string]any{
|
||||
"version": "V2",
|
||||
"network_interfaces": []string{ifaceID},
|
||||
})
|
||||
}
|
||||
|
||||
// mmdsMetadata is the metadata payload written to the Firecracker MMDS store.
|
||||
// envd reads this via PollForMMDSOpts to populate WRENN_SANDBOX_ID and WRENN_TEMPLATE_ID.
|
||||
type mmdsMetadata struct {
|
||||
SandboxID string `json:"instanceID"`
|
||||
TemplateID string `json:"envID"`
|
||||
}
|
||||
|
||||
// setMMDS writes sandbox metadata to the Firecracker MMDS store.
|
||||
// Can be called after the VM has started.
|
||||
func (c *fcClient) setMMDS(ctx context.Context, sandboxID, templateID string) error {
|
||||
return c.do(ctx, http.MethodPut, "/mmds", mmdsMetadata{
|
||||
SandboxID: sandboxID,
|
||||
TemplateID: templateID,
|
||||
})
|
||||
}
|
||||
|
||||
// setBalloon configures the Firecracker balloon device for dynamic memory
|
||||
// management. deflateOnOom lets the guest reclaim balloon pages under memory
|
||||
// pressure. statsInterval enables periodic stats via GET /balloon/statistics.
|
||||
// Must be called before startVM.
|
||||
func (c *fcClient) setBalloon(ctx context.Context, amountMiB int, deflateOnOom bool, statsIntervalS int) error {
|
||||
return c.do(ctx, http.MethodPut, "/balloon", map[string]any{
|
||||
"amount_mib": amountMiB,
|
||||
"deflate_on_oom": deflateOnOom,
|
||||
"stats_polling_interval_s": statsIntervalS,
|
||||
})
|
||||
}
|
||||
|
||||
// updateBalloon adjusts the balloon target at runtime.
|
||||
func (c *fcClient) updateBalloon(ctx context.Context, amountMiB int) error {
|
||||
return c.do(ctx, http.MethodPatch, "/balloon", map[string]any{
|
||||
"amount_mib": amountMiB,
|
||||
})
|
||||
}
|
||||
|
||||
// startVM issues the InstanceStart action.
|
||||
func (c *fcClient) startVM(ctx context.Context) error {
|
||||
return c.do(ctx, http.MethodPut, "/actions", map[string]string{
|
||||
"action_type": "InstanceStart",
|
||||
})
|
||||
}
|
||||
|
||||
// pauseVM pauses the microVM.
|
||||
func (c *fcClient) pauseVM(ctx context.Context) error {
|
||||
return c.do(ctx, http.MethodPatch, "/vm", map[string]string{
|
||||
"state": "Paused",
|
||||
})
|
||||
}
|
||||
|
||||
// resumeVM resumes a paused microVM.
|
||||
func (c *fcClient) resumeVM(ctx context.Context) error {
|
||||
return c.do(ctx, http.MethodPatch, "/vm", map[string]string{
|
||||
"state": "Resumed",
|
||||
})
|
||||
}
|
||||
|
||||
// createSnapshot creates a VM snapshot.
|
||||
// snapshotType is "Full" (all memory) or "Diff" (only dirty pages since last resume).
|
||||
func (c *fcClient) createSnapshot(ctx context.Context, snapPath, memPath, snapshotType string) error {
|
||||
return c.do(ctx, http.MethodPut, "/snapshot/create", map[string]any{
|
||||
"snapshot_type": snapshotType,
|
||||
"snapshot_path": snapPath,
|
||||
"mem_file_path": memPath,
|
||||
})
|
||||
}
|
||||
|
||||
// loadSnapshotWithUffd loads a VM snapshot using a UFFD socket for
|
||||
// lazy memory loading. Firecracker will connect to the socket and
|
||||
// send the uffd fd + memory region mappings.
|
||||
func (c *fcClient) loadSnapshotWithUffd(ctx context.Context, snapPath, uffdSocketPath string) error {
|
||||
return c.do(ctx, http.MethodPut, "/snapshot/load", map[string]any{
|
||||
"snapshot_path": snapPath,
|
||||
"resume_vm": false,
|
||||
"mem_backend": map[string]any{
|
||||
"backend_type": "Uffd",
|
||||
"backend_path": uffdSocketPath,
|
||||
},
|
||||
})
|
||||
}
|
||||
@ -1,128 +0,0 @@
|
||||
package vm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
// process represents a running Firecracker process with mount and network
|
||||
// namespace isolation.
|
||||
type process struct {
|
||||
cmd *exec.Cmd
|
||||
cancel context.CancelFunc
|
||||
|
||||
exitCh chan struct{}
|
||||
exitErr error
|
||||
}
|
||||
|
||||
// startProcess launches the Firecracker binary inside an isolated mount namespace
|
||||
// and the specified network namespace. The launch sequence:
|
||||
//
|
||||
// 1. unshare -m: creates a private mount namespace
|
||||
// 2. mount --make-rprivate /: prevents mount propagation to host
|
||||
// 3. mount tmpfs at SandboxDir: ephemeral workspace for this VM
|
||||
// 4. symlink kernel and rootfs into SandboxDir
|
||||
// 5. ip netns exec <ns>: enters the network namespace where TAP is configured
|
||||
// 6. exec firecracker with the API socket path
|
||||
func startProcess(ctx context.Context, cfg *VMConfig) (*process, error) {
|
||||
// Use a background context for the long-lived Firecracker process.
|
||||
// The request context (ctx) is only used for the startup phase — we must
|
||||
// not tie the VM's lifetime to the HTTP request that created it.
|
||||
execCtx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
script := buildStartScript(cfg)
|
||||
|
||||
cmd := exec.CommandContext(execCtx, "unshare", "-m", "--", "bash", "-c", script)
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
Setsid: true, // new session so signals don't propagate from parent
|
||||
}
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("start firecracker process: %w", err)
|
||||
}
|
||||
|
||||
p := &process{
|
||||
cmd: cmd,
|
||||
cancel: cancel,
|
||||
exitCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
go func() {
|
||||
p.exitErr = cmd.Wait()
|
||||
close(p.exitCh)
|
||||
}()
|
||||
|
||||
slog.Info("firecracker process started",
|
||||
"pid", cmd.Process.Pid,
|
||||
"sandbox", cfg.SandboxID,
|
||||
)
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// buildStartScript generates the bash script that sets up the mount namespace,
|
||||
// symlinks kernel/rootfs, and execs Firecracker inside the network namespace.
|
||||
func buildStartScript(cfg *VMConfig) string {
|
||||
return fmt.Sprintf(`
|
||||
set -euo pipefail
|
||||
|
||||
# Prevent mount propagation to the host
|
||||
mount --make-rprivate /
|
||||
|
||||
# Create ephemeral tmpfs workspace
|
||||
mkdir -p %[1]s
|
||||
mount -t tmpfs tmpfs %[1]s
|
||||
|
||||
# Symlink kernel and rootfs into the workspace
|
||||
ln -s %[2]s %[1]s/vmlinux
|
||||
ln -s %[3]s %[1]s/rootfs.ext4
|
||||
|
||||
# Launch Firecracker inside the network namespace
|
||||
exec ip netns exec %[4]s %[5]s --api-sock %[6]s
|
||||
`,
|
||||
cfg.SandboxDir, // 1
|
||||
cfg.KernelPath, // 2
|
||||
cfg.RootfsPath, // 3
|
||||
cfg.NetworkNamespace, // 4
|
||||
cfg.FirecrackerBin, // 5
|
||||
cfg.SocketPath, // 6
|
||||
)
|
||||
}
|
||||
|
||||
// stop sends SIGTERM and waits for the process to exit. If it doesn't exit
|
||||
// within 10 seconds, SIGKILL is sent.
|
||||
func (p *process) stop() error {
|
||||
if p.cmd.Process == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send SIGTERM to the process group (negative PID).
|
||||
if err := syscall.Kill(-p.cmd.Process.Pid, syscall.SIGTERM); err != nil {
|
||||
slog.Debug("sigterm failed, process may have exited", "error", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-p.exitCh:
|
||||
return nil
|
||||
case <-time.After(10 * time.Second):
|
||||
slog.Warn("firecracker did not exit after SIGTERM, sending SIGKILL")
|
||||
if err := syscall.Kill(-p.cmd.Process.Pid, syscall.SIGKILL); err != nil {
|
||||
slog.Debug("sigkill failed", "error", err)
|
||||
}
|
||||
<-p.exitCh
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// exited returns a channel that is closed when the process exits.
|
||||
func (p *process) exited() <-chan struct{} {
|
||||
return p.exitCh
|
||||
}
|
||||
@ -5,18 +5,19 @@ import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// VM represents a running Firecracker microVM.
|
||||
// VM represents a running Cloud Hypervisor microVM.
|
||||
type VM struct {
|
||||
Config VMConfig
|
||||
process *process
|
||||
client *fcClient
|
||||
client *chClient
|
||||
}
|
||||
|
||||
// Manager handles the lifecycle of Firecracker microVMs.
|
||||
// Manager handles the lifecycle of Cloud Hypervisor microVMs.
|
||||
type Manager struct {
|
||||
mu sync.RWMutex
|
||||
// vms tracks running VMs by sandbox ID.
|
||||
@ -30,7 +31,7 @@ func NewManager() *Manager {
|
||||
}
|
||||
}
|
||||
|
||||
// Create boots a new Firecracker microVM with the given configuration.
|
||||
// Create boots a new Cloud Hypervisor microVM with the given configuration.
|
||||
// The network namespace and TAP device must already be set up.
|
||||
func (m *Manager) Create(ctx context.Context, cfg VMConfig) (*VM, error) {
|
||||
cfg.applyDefaults()
|
||||
@ -38,7 +39,6 @@ func (m *Manager) Create(ctx context.Context, cfg VMConfig) (*VM, error) {
|
||||
return nil, fmt.Errorf("invalid config: %w", err)
|
||||
}
|
||||
|
||||
// Clean up any leftover socket from a previous run.
|
||||
os.Remove(cfg.SocketPath)
|
||||
|
||||
slog.Info("creating VM",
|
||||
@ -47,8 +47,8 @@ func (m *Manager) Create(ctx context.Context, cfg VMConfig) (*VM, error) {
|
||||
"memory_mb", cfg.MemoryMB,
|
||||
)
|
||||
|
||||
// Step 1: Launch the Firecracker process.
|
||||
proc, err := startProcess(ctx, &cfg)
|
||||
// Step 1: Launch the Cloud Hypervisor process.
|
||||
proc, err := startProcess(&cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("start process: %w", err)
|
||||
}
|
||||
@ -59,25 +59,18 @@ func (m *Manager) Create(ctx context.Context, cfg VMConfig) (*VM, error) {
|
||||
return nil, fmt.Errorf("wait for socket: %w", err)
|
||||
}
|
||||
|
||||
// Step 3: Configure the VM via the Firecracker API.
|
||||
client := newFCClient(cfg.SocketPath)
|
||||
// Step 3: Configure and boot the VM via a single API call.
|
||||
client := newCHClient(cfg.SocketPath)
|
||||
|
||||
if err := configureVM(ctx, client, &cfg); err != nil {
|
||||
if err := client.createVM(ctx, &cfg); err != nil {
|
||||
_ = proc.stop()
|
||||
return nil, fmt.Errorf("configure VM: %w", err)
|
||||
return nil, fmt.Errorf("create VM config: %w", err)
|
||||
}
|
||||
|
||||
// Step 4: Start the VM.
|
||||
if err := client.startVM(ctx); err != nil {
|
||||
// Step 4: Boot the VM.
|
||||
if err := client.bootVM(ctx); err != nil {
|
||||
_ = proc.stop()
|
||||
return nil, fmt.Errorf("start VM: %w", err)
|
||||
}
|
||||
|
||||
// Step 5: Push sandbox metadata into MMDS so envd can read
|
||||
// WRENN_SANDBOX_ID and WRENN_TEMPLATE_ID from inside the guest.
|
||||
if err := client.setMMDS(ctx, cfg.SandboxID, cfg.TemplateID); err != nil {
|
||||
_ = proc.stop()
|
||||
return nil, fmt.Errorf("set MMDS metadata: %w", err)
|
||||
return nil, fmt.Errorf("boot VM: %w", err)
|
||||
}
|
||||
|
||||
vm := &VM{
|
||||
@ -95,78 +88,34 @@ func (m *Manager) Create(ctx context.Context, cfg VMConfig) (*VM, error) {
|
||||
return vm, nil
|
||||
}
|
||||
|
||||
// configureVM sends the configuration to Firecracker via its HTTP API.
|
||||
func configureVM(ctx context.Context, client *fcClient, cfg *VMConfig) error {
|
||||
// Boot source (kernel + args)
|
||||
if err := client.setBootSource(ctx, cfg.KernelPath, cfg.kernelArgs()); err != nil {
|
||||
return fmt.Errorf("set boot source: %w", err)
|
||||
}
|
||||
|
||||
// Root drive — use the symlink path inside the mount namespace so that
|
||||
// snapshots record a stable path that works on restore.
|
||||
rootfsSymlink := cfg.SandboxDir + "/rootfs.ext4"
|
||||
if err := client.setRootfsDrive(ctx, "rootfs", rootfsSymlink, false); err != nil {
|
||||
return fmt.Errorf("set rootfs drive: %w", err)
|
||||
}
|
||||
|
||||
// Network interface
|
||||
if err := client.setNetworkInterface(ctx, "eth0", cfg.TapDevice, cfg.TapMAC); err != nil {
|
||||
return fmt.Errorf("set network interface: %w", err)
|
||||
}
|
||||
|
||||
// Machine config (vCPUs + memory)
|
||||
if err := client.setMachineConfig(ctx, cfg.VCPUs, cfg.MemoryMB); err != nil {
|
||||
return fmt.Errorf("set machine config: %w", err)
|
||||
}
|
||||
|
||||
// Balloon device — allows the host to reclaim unused guest memory.
|
||||
// Start with 0 (no inflation). deflate_on_oom lets the guest reclaim
|
||||
// balloon pages under memory pressure. Stats interval enables monitoring.
|
||||
if err := client.setBalloon(ctx, 0, true, 5); err != nil {
|
||||
slog.Warn("set balloon failed (non-fatal, VM will run without memory reclaim)", "error", err)
|
||||
}
|
||||
|
||||
// MMDS config — enable V2 token access on eth0 so that envd can read
|
||||
// WRENN_SANDBOX_ID and WRENN_TEMPLATE_ID from inside the guest.
|
||||
if err := client.setMMDSConfig(ctx, "eth0"); err != nil {
|
||||
return fmt.Errorf("set MMDS config: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Pause pauses a running VM.
|
||||
// Pause freezes a running VM's vCPUs via the CH API.
|
||||
func (m *Manager) Pause(ctx context.Context, sandboxID string) error {
|
||||
m.mu.RLock()
|
||||
vm, ok := m.vms[sandboxID]
|
||||
m.mu.RUnlock()
|
||||
vm, ok := m.Get(sandboxID)
|
||||
if !ok {
|
||||
return fmt.Errorf("VM not found: %s", sandboxID)
|
||||
}
|
||||
|
||||
if err := vm.client.pauseVM(ctx); err != nil {
|
||||
return fmt.Errorf("pause VM: %w", err)
|
||||
}
|
||||
|
||||
slog.Info("VM paused", "sandbox", sandboxID)
|
||||
return nil
|
||||
return vm.client.pauseVM(ctx)
|
||||
}
|
||||
|
||||
// Resume resumes a paused VM.
|
||||
// Resume unfreezes a paused VM via the CH API.
|
||||
func (m *Manager) Resume(ctx context.Context, sandboxID string) error {
|
||||
m.mu.RLock()
|
||||
vm, ok := m.vms[sandboxID]
|
||||
m.mu.RUnlock()
|
||||
vm, ok := m.Get(sandboxID)
|
||||
if !ok {
|
||||
return fmt.Errorf("VM not found: %s", sandboxID)
|
||||
}
|
||||
return vm.client.resumeVM(ctx)
|
||||
}
|
||||
|
||||
if err := vm.client.resumeVM(ctx); err != nil {
|
||||
return fmt.Errorf("resume VM: %w", err)
|
||||
// Info returns the CH VM state (e.g. "Running", "Paused", "Shutdown") via
|
||||
// the CH unix-socket API. Returns an error if the socket is dead or the VM
|
||||
// is not registered. Use to probe liveness before issuing destructive ops
|
||||
// like pause or snapshot.
|
||||
func (m *Manager) Info(ctx context.Context, sandboxID string) (string, error) {
|
||||
vm, ok := m.Get(sandboxID)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("VM not found: %s", sandboxID)
|
||||
}
|
||||
|
||||
slog.Info("VM resumed", "sandbox", sandboxID)
|
||||
return nil
|
||||
return vm.client.vmInfo(ctx)
|
||||
}
|
||||
|
||||
// UpdateBalloon adjusts the balloon target for a running VM.
|
||||
@ -179,7 +128,8 @@ func (m *Manager) UpdateBalloon(ctx context.Context, sandboxID string, amountMiB
|
||||
return fmt.Errorf("VM not found: %s", sandboxID)
|
||||
}
|
||||
|
||||
return vm.client.updateBalloon(ctx, amountMiB)
|
||||
sizeBytes := int64(amountMiB) * 1024 * 1024
|
||||
return vm.client.resizeBalloon(ctx, sizeBytes)
|
||||
}
|
||||
|
||||
// Destroy stops and cleans up a VM.
|
||||
@ -190,103 +140,98 @@ func (m *Manager) Destroy(ctx context.Context, sandboxID string) error {
|
||||
m.mu.Unlock()
|
||||
return fmt.Errorf("VM not found: %s", sandboxID)
|
||||
}
|
||||
delete(m.vms, sandboxID)
|
||||
m.mu.Unlock()
|
||||
|
||||
slog.Info("destroying VM", "sandbox", sandboxID)
|
||||
|
||||
// Stop the Firecracker process.
|
||||
// Try clean shutdown first, fall back to process kill.
|
||||
shutdownCtx, shutdownCancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
if err := vm.client.shutdownVMM(shutdownCtx); err != nil {
|
||||
slog.Debug("clean VMM shutdown failed, killing process", "sandbox", sandboxID, "error", err)
|
||||
}
|
||||
shutdownCancel()
|
||||
|
||||
if err := vm.process.stop(); err != nil {
|
||||
slog.Warn("error stopping process", "sandbox", sandboxID, "error", err)
|
||||
}
|
||||
|
||||
// Clean up the API socket.
|
||||
os.Remove(vm.Config.SocketPath)
|
||||
|
||||
m.mu.Lock()
|
||||
delete(m.vms, sandboxID)
|
||||
m.mu.Unlock()
|
||||
|
||||
slog.Info("VM destroyed", "sandbox", sandboxID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Snapshot creates a VM snapshot. The VM must already be paused.
|
||||
// snapshotType is "Full" (all memory) or "Diff" (only dirty pages since last resume).
|
||||
func (m *Manager) Snapshot(ctx context.Context, sandboxID, snapPath, memPath, snapshotType string) error {
|
||||
m.mu.RLock()
|
||||
vm, ok := m.vms[sandboxID]
|
||||
m.mu.RUnlock()
|
||||
// Snapshot writes the VM's config/state/memory to snapshotDir via CH's
|
||||
// vm.snapshot API. The VM must already be paused. snapshotDir must be an
|
||||
// absolute path; it is passed to CH as `file://{dir}/`.
|
||||
func (m *Manager) Snapshot(ctx context.Context, sandboxID, snapshotDir string) error {
|
||||
vm, ok := m.Get(sandboxID)
|
||||
if !ok {
|
||||
return fmt.Errorf("VM not found: %s", sandboxID)
|
||||
}
|
||||
|
||||
if err := vm.client.createSnapshot(ctx, snapPath, memPath, snapshotType); err != nil {
|
||||
return fmt.Errorf("create snapshot: %w", err)
|
||||
if err := os.MkdirAll(snapshotDir, 0o755); err != nil {
|
||||
return fmt.Errorf("mkdir snapshot dir: %w", err)
|
||||
}
|
||||
|
||||
slog.Info("VM snapshot created", "sandbox", sandboxID, "snap_path", snapPath, "type", snapshotType)
|
||||
url := "file://" + strings.TrimRight(snapshotDir, "/") + "/"
|
||||
if err := vm.client.snapshotVM(ctx, url); err != nil {
|
||||
return fmt.Errorf("vm.snapshot: %w", err)
|
||||
}
|
||||
slog.Info("VM snapshot written", "sandbox", sandboxID, "dir", snapshotDir)
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateFromSnapshot boots a new Firecracker VM by loading a snapshot
|
||||
// using UFFD for lazy memory loading. The network namespace and TAP
|
||||
// device must already be set up.
|
||||
// CreateFromSnapshot launches a Cloud Hypervisor process in restore mode,
|
||||
// connecting it to an existing snapshot directory. The VM is left in the
|
||||
// paused state — the caller is expected to call Resume after any post-restore
|
||||
// setup (e.g. re-acquiring envd connectivity is implicit via TCP).
|
||||
//
|
||||
// No boot resources (kernel, drives, machine config) are configured —
|
||||
// the snapshot carries all that state. The rootfs path recorded in the
|
||||
// snapshot is resolved via a stable symlink at SandboxDir/rootfs.ext4
|
||||
// inside the mount namespace (created by the start script in jailer.go).
|
||||
//
|
||||
// The sequence is:
|
||||
// 1. Start FC process in mount+network namespace (creates tmpfs + rootfs symlink)
|
||||
// 2. Wait for API socket
|
||||
// 3. Load snapshot with UFFD backend
|
||||
// 4. Resume VM execution
|
||||
func (m *Manager) CreateFromSnapshot(ctx context.Context, cfg VMConfig, snapPath, uffdSocketPath string) (*VM, error) {
|
||||
// cfg.RestoreFromDir must point to an absolute path containing the CH
|
||||
// snapshot artefacts. The disk path inside config.json must already resolve
|
||||
// (CH receives the same SandboxDir/rootfs.ext4 symlink as for fresh boot).
|
||||
func (m *Manager) CreateFromSnapshot(ctx context.Context, cfg VMConfig) (*VM, error) {
|
||||
cfg.applyDefaults()
|
||||
if err := cfg.validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid config: %w", err)
|
||||
}
|
||||
if cfg.RestoreFromDir == "" {
|
||||
return nil, fmt.Errorf("RestoreFromDir is required for restore")
|
||||
}
|
||||
|
||||
os.Remove(cfg.SocketPath)
|
||||
|
||||
slog.Info("restoring VM from snapshot",
|
||||
"sandbox", cfg.SandboxID,
|
||||
"snap_path", snapPath,
|
||||
"restore_dir", cfg.RestoreFromDir,
|
||||
"lazy_memory", cfg.RestoreLazyMemory,
|
||||
)
|
||||
|
||||
// Step 1: Launch the Firecracker process.
|
||||
// The start script creates a tmpfs at SandboxDir and symlinks
|
||||
// rootfs.ext4 → cfg.RootfsPath, so the snapshot's recorded rootfs
|
||||
// path (/fc-vm/rootfs.ext4) resolves to the new clone.
|
||||
proc, err := startProcess(ctx, &cfg)
|
||||
proc, err := startRestoreProcess(&cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("start process: %w", err)
|
||||
return nil, fmt.Errorf("start restore process: %w", err)
|
||||
}
|
||||
|
||||
// Step 2: Wait for the API socket.
|
||||
if err := waitForSocket(ctx, cfg.SocketPath, proc); err != nil {
|
||||
_ = proc.stop()
|
||||
return nil, fmt.Errorf("wait for socket: %w", err)
|
||||
}
|
||||
|
||||
client := newFCClient(cfg.SocketPath)
|
||||
client := newCHClient(cfg.SocketPath)
|
||||
|
||||
// Step 3: Load the snapshot with UFFD backend.
|
||||
// No boot resources are configured — the snapshot carries kernel,
|
||||
// drive, network, and machine config state.
|
||||
if err := client.loadSnapshotWithUffd(ctx, snapPath, uffdSocketPath); err != nil {
|
||||
// Confirm CH actually hydrated the snapshot before registering. Without
|
||||
// this check, a broken snapshot would leave a zombie *VM in the map that
|
||||
// blocks future restores for the same sandbox ID.
|
||||
state, err := client.vmInfo(ctx)
|
||||
if err != nil {
|
||||
_ = proc.stop()
|
||||
return nil, fmt.Errorf("load snapshot: %w", err)
|
||||
return nil, fmt.Errorf("vm.info after restore: %w", err)
|
||||
}
|
||||
|
||||
// Step 4: Resume the VM.
|
||||
if err := client.resumeVM(ctx); err != nil {
|
||||
if state != "Paused" {
|
||||
_ = proc.stop()
|
||||
return nil, fmt.Errorf("resume VM: %w", err)
|
||||
}
|
||||
|
||||
// Step 5: Push sandbox metadata into MMDS.
|
||||
if err := client.setMMDS(ctx, cfg.SandboxID, cfg.TemplateID); err != nil {
|
||||
_ = proc.stop()
|
||||
return nil, fmt.Errorf("set MMDS metadata: %w", err)
|
||||
return nil, fmt.Errorf("unexpected post-restore VM state %q (want Paused)", state)
|
||||
}
|
||||
|
||||
vm := &VM{
|
||||
@ -299,16 +244,20 @@ func (m *Manager) CreateFromSnapshot(ctx context.Context, cfg VMConfig, snapPath
|
||||
m.vms[cfg.SandboxID] = vm
|
||||
m.mu.Unlock()
|
||||
|
||||
slog.Info("VM restored from snapshot", "sandbox", cfg.SandboxID)
|
||||
slog.Info("VM restored from snapshot (paused)", "sandbox", cfg.SandboxID)
|
||||
return vm, nil
|
||||
}
|
||||
|
||||
// PID returns the process ID of the unshare wrapper process.
|
||||
// The actual Firecracker process is a direct child of this PID.
|
||||
func (v *VM) PID() int {
|
||||
return v.process.cmd.Process.Pid
|
||||
}
|
||||
|
||||
// Exited returns a channel that is closed when the VM process exits.
|
||||
func (v *VM) Exited() <-chan struct{} {
|
||||
return v.process.exited()
|
||||
}
|
||||
|
||||
// Get returns a running VM by sandbox ID.
|
||||
func (m *Manager) Get(sandboxID string) (*VM, bool) {
|
||||
m.mu.RLock()
|
||||
@ -317,7 +266,7 @@ func (m *Manager) Get(sandboxID string) (*VM, bool) {
|
||||
return vm, ok
|
||||
}
|
||||
|
||||
// waitForSocket polls for the Firecracker API socket to appear on disk.
|
||||
// waitForSocket polls for the Cloud Hypervisor API socket to appear on disk.
|
||||
func waitForSocket(ctx context.Context, socketPath string, proc *process) error {
|
||||
ticker := time.NewTicker(10 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
@ -329,7 +278,7 @@ func waitForSocket(ctx context.Context, socketPath string, proc *process) error
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-proc.exited():
|
||||
return fmt.Errorf("firecracker process exited before socket was ready")
|
||||
return fmt.Errorf("cloud-hypervisor process exited before socket was ready")
|
||||
case <-timeout:
|
||||
return fmt.Errorf("timed out waiting for API socket at %s", socketPath)
|
||||
case <-ticker.C:
|
||||
|
||||
174
internal/vm/process.go
Normal file
174
internal/vm/process.go
Normal file
@ -0,0 +1,174 @@
|
||||
package vm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
// process is the handle for one launched Cloud Hypervisor command, started
// inside a private mount namespace and a network namespace.
type process struct {
	cmd     *exec.Cmd          // the started command
	cancel  context.CancelFunc // cancels the context the command was created with
	exitCh  chan struct{}      // closed after cmd.Wait has returned
	exitErr error              // result of cmd.Wait; set before exitCh is closed
	logFile *os.File           // stdout/stderr sink, nil when no log dir is configured
}
|
||||
|
||||
// startProcess launches the Cloud Hypervisor binary inside an isolated mount
|
||||
// namespace and the specified network namespace. Used for fresh boot (no
|
||||
// snapshot). The launch sequence:
|
||||
//
|
||||
// 1. unshare -m: creates a private mount namespace
|
||||
// 2. mount --make-rprivate /: prevents mount propagation to host
|
||||
// 3. mount tmpfs at SandboxDir: ephemeral workspace for this VM
|
||||
// 4. symlink kernel and rootfs into SandboxDir
|
||||
// 5. ip netns exec <ns>: enters the network namespace where TAP is configured
|
||||
// 6. exec cloud-hypervisor with the API socket path
|
||||
func startProcess(cfg *VMConfig) (*process, error) {
|
||||
script := buildStartScript(cfg)
|
||||
return launchScript(script, cfg)
|
||||
}
|
||||
|
||||
// startRestoreProcess launches CH in restore mode. It mirrors startProcess
|
||||
// for namespace/tmpfs/symlink setup so the disk paths recorded in the
|
||||
// snapshot's config.json remain valid, then execs CH with `--restore`.
|
||||
func startRestoreProcess(cfg *VMConfig) (*process, error) {
|
||||
script := buildRestoreScript(cfg)
|
||||
return launchScript(script, cfg)
|
||||
}
|
||||
|
||||
func launchScript(script string, cfg *VMConfig) (*process, error) {
|
||||
execCtx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
cmd := exec.CommandContext(execCtx, "unshare", "-m", "--", "bash", "-c", script)
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
Setsid: true,
|
||||
}
|
||||
|
||||
var logFile *os.File
|
||||
if cfg.LogDir != "" {
|
||||
logPath := fmt.Sprintf("%s/ch-%s.log", cfg.LogDir, cfg.SandboxID)
|
||||
f, err := os.OpenFile(logPath, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0640)
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("open CH log file %s: %w", logPath, err)
|
||||
}
|
||||
cmd.Stdout = f
|
||||
cmd.Stderr = f
|
||||
logFile = f
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
cancel()
|
||||
if logFile != nil {
|
||||
logFile.Close()
|
||||
}
|
||||
return nil, fmt.Errorf("start cloud-hypervisor process: %w", err)
|
||||
}
|
||||
|
||||
p := &process{
|
||||
cmd: cmd,
|
||||
cancel: cancel,
|
||||
exitCh: make(chan struct{}),
|
||||
logFile: logFile,
|
||||
}
|
||||
|
||||
go func() {
|
||||
p.exitErr = cmd.Wait()
|
||||
if p.logFile != nil {
|
||||
p.logFile.Close()
|
||||
}
|
||||
close(p.exitCh)
|
||||
}()
|
||||
|
||||
slog.Info("cloud-hypervisor process started",
|
||||
"pid", cmd.Process.Pid,
|
||||
"sandbox", cfg.SandboxID,
|
||||
)
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// buildStartScript generates the bash script for fresh boot: sets up mount
|
||||
// namespace, symlinks kernel/rootfs, and execs Cloud Hypervisor.
|
||||
func buildStartScript(cfg *VMConfig) string {
|
||||
return buildLaunchScript(cfg, "")
|
||||
}
|
||||
|
||||
// buildRestoreScript generates the bash script for restoring a VM from a
|
||||
// snapshot directory. The mount/symlink prelude is identical to fresh boot
|
||||
// so disk paths in the snapshot config.json resolve correctly.
|
||||
func buildRestoreScript(cfg *VMConfig) string {
|
||||
dir := strings.TrimRight(cfg.RestoreFromDir, "/")
|
||||
restoreArg := fmt.Sprintf("--restore source_url=file://%s/", dir)
|
||||
if cfg.RestoreLazyMemory {
|
||||
restoreArg += ",memory_restore_mode=ondemand"
|
||||
}
|
||||
return buildLaunchScript(cfg, restoreArg)
|
||||
}
|
||||
|
||||
// buildLaunchScript composes the namespace/tmpfs/symlink prelude and the
|
||||
// final cloud-hypervisor exec line. extraArgs is appended verbatim — used
|
||||
// to inject `--restore source_url=...` for restore launches.
|
||||
func buildLaunchScript(cfg *VMConfig, extraArgs string) string {
|
||||
chCmd := fmt.Sprintf("ip netns exec %s %s --api-socket path=%s",
|
||||
cfg.NetworkNamespace, cfg.VMMBin, cfg.SocketPath)
|
||||
if extraArgs != "" {
|
||||
chCmd += " " + extraArgs
|
||||
}
|
||||
return fmt.Sprintf(`
|
||||
set -euo pipefail
|
||||
|
||||
mount --make-rprivate /
|
||||
|
||||
mkdir -p %[1]s
|
||||
mount -t tmpfs tmpfs %[1]s
|
||||
|
||||
ln -s %[2]s %[1]s/vmlinux
|
||||
ln -s %[3]s %[1]s/rootfs.ext4
|
||||
|
||||
exec %[4]s
|
||||
`,
|
||||
cfg.SandboxDir, // 1
|
||||
cfg.KernelPath, // 2
|
||||
cfg.RootfsPath, // 3
|
||||
chCmd, // 4
|
||||
)
|
||||
}
|
||||
|
||||
// stop sends SIGTERM and waits for the process to exit. If it doesn't exit
|
||||
// within 10 seconds, SIGKILL is sent.
|
||||
func (p *process) stop() error {
|
||||
if p.cmd.Process == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := syscall.Kill(-p.cmd.Process.Pid, syscall.SIGTERM); err != nil {
|
||||
slog.Debug("sigterm failed, process may have exited", "error", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-p.exitCh:
|
||||
return nil
|
||||
case <-time.After(10 * time.Second):
|
||||
slog.Warn("cloud-hypervisor did not exit after SIGTERM, sending SIGKILL")
|
||||
if err := syscall.Kill(-p.cmd.Process.Pid, syscall.SIGKILL); err != nil {
|
||||
slog.Debug("sigkill failed", "error", err)
|
||||
}
|
||||
<-p.exitCh
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// exited returns a channel that is closed when the process exits.
|
||||
func (p *process) exited() <-chan struct{} {
|
||||
return p.exitCh
|
||||
}
|
||||
Reference in New Issue
Block a user