1
0
forked from wrenn/wrenn
Co-authored-by: Tasnim Kabir Sadik <tksadik@omukk.dev>

Reviewed-on: wrenn/wrenn#50
This commit is contained in:
2026-05-24 21:10:37 +00:00
parent 4707f16c76
commit 05ddf62399
203 changed files with 15815 additions and 9344 deletions

View File

@ -4,7 +4,6 @@ import (
"context"
"encoding/json"
"fmt"
"log/slog"
"net/http"
"time"
@ -25,49 +24,53 @@ import (
)
// snapshotHandler groups the dependencies shared by the snapshot HTTP
// handlers (Create, List, Delete) that are declared as its methods.
//
// NOTE(review): this text looks like a rendered diff with the +/- markers
// stripped, so svc/db/pool/audit appear twice below: once from the pre-change
// struct and once from the post-change struct that adds sandboxSvc. Confirm
// the field list against the real file.
type snapshotHandler struct {
svc *service.TemplateService
db *db.Queries
pool *lifecycle.HostClientPool
audit *audit.AuditLogger
svc *service.TemplateService
sandboxSvc *service.SandboxService
db *db.Queries
pool *lifecycle.HostClientPool
audit *audit.AuditLogger
}
func newSnapshotHandler(svc *service.TemplateService, db *db.Queries, pool *lifecycle.HostClientPool, al *audit.AuditLogger) *snapshotHandler {
return &snapshotHandler{svc: svc, db: db, pool: pool, audit: al}
func newSnapshotHandler(svc *service.TemplateService, sandboxSvc *service.SandboxService, db *db.Queries, pool *lifecycle.HostClientPool, al *audit.AuditLogger) *snapshotHandler {
return &snapshotHandler{svc: svc, sandboxSvc: sandboxSvc, db: db, pool: pool, audit: al}
}
// deleteSnapshotBroadcast attempts to delete snapshot files on all online hosts.
// Snapshots aren't currently host-tracked in the DB, so we broadcast to all hosts
// and ignore NotFound errors.
func deleteSnapshotBroadcast(ctx context.Context, queries *db.Queries, pool *lifecycle.HostClientPool, teamID, templateID pgtype.UUID) error {
// deleteSnapshotEverywhere removes a template's files from every active host.
// Templates aren't host-tracked in the DB, so it broadcasts to all hosts.
//
// It is strict by design: deletion is reported successful only when every
// active host has either removed the files or reported NotFound (it never
// held them). If any host is offline or returns an error, it returns an error
// and the caller MUST NOT delete the DB record — doing so would orphan the
// files on disk with no record left to retry against.
func deleteSnapshotEverywhere(ctx context.Context, queries *db.Queries, pool *lifecycle.HostClientPool, teamID, templateID pgtype.UUID) error {
hosts, err := queries.ListActiveHosts(ctx)
if err != nil {
return fmt.Errorf("list hosts: %w", err)
}
for _, host := range hosts {
if host.Status != "online" {
continue
return fmt.Errorf("host %s is %s — cannot guarantee snapshot file removal",
id.FormatHostID(host.ID), host.Status)
}
agent, err := pool.GetForHost(host)
if err != nil {
continue
return fmt.Errorf("connect to host %s: %w", id.FormatHostID(host.ID), err)
}
if _, err := agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{
TeamId: formatUUIDForRPC(teamID),
TemplateId: formatUUIDForRPC(templateID),
})); err != nil {
if connect.CodeOf(err) != connect.CodeNotFound {
slog.Warn("snapshot: failed to delete on host", "host_id", id.FormatHostID(host.ID), "error", err)
// NotFound just means this host never held the template.
if connect.CodeOf(err) == connect.CodeNotFound {
continue
}
return fmt.Errorf("delete snapshot on host %s: %w", id.FormatHostID(host.ID), err)
}
}
return nil
}
// createSnapshotRequest is the JSON request body decoded by the Create handler
// (POST /v1/snapshots).
type createSnapshotRequest struct {
// SandboxID identifies the sandbox to snapshot. Create rejects an empty value
// with a 400 response and parses the rest with id.ParseSandboxID.
SandboxID string `json:"sandbox_id"`
// Name is the caller-supplied snapshot name that Create passes on when
// creating the snapshot. Presumably optional — verify how an empty value is
// handled.
Name string `json:"name"`
}
type snapshotResponse struct {
Name string `json:"name"`
Type string `json:"type"`
@ -76,6 +79,7 @@ type snapshotResponse struct {
SizeBytes int64 `json:"size_bytes"`
CreatedAt string `json:"created_at"`
Platform bool `json:"platform"`
Protected bool `json:"protected"`
Metadata map[string]string `json:"metadata,omitempty"`
}
@ -85,6 +89,7 @@ func templateToResponse(t db.Template) snapshotResponse {
Type: t.Type,
SizeBytes: t.SizeBytes,
Platform: t.TeamID == id.PlatformTeamID,
Protected: layout.IsSystemTemplate(t.TeamID, t.ID),
}
if t.Vcpus != 0 {
resp.VCPUs = &t.Vcpus
@ -104,132 +109,42 @@ func templateToResponse(t db.Template) snapshotResponse {
return resp
}
// Create handles POST /v1/snapshots.
// createSnapshotRequest is the JSON request body decoded by the Create handler
// (POST /v1/snapshots).
type createSnapshotRequest struct {
// SandboxID identifies the sandbox to snapshot. Create rejects an empty value
// with a 400 response and parses the rest with id.ParseSandboxID.
SandboxID string `json:"sandbox_id"`
// Name is the caller-supplied snapshot name that Create passes on when
// creating the snapshot. Presumably optional — verify how an empty value is
// handled.
Name string `json:"name"`
}
// Create handles POST /v1/snapshots. Snapshots a running or paused sandbox and
// registers the result as a new template.
func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
var req createSnapshotRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
return
}
if req.SandboxID == "" {
writeError(w, http.StatusBadRequest, "invalid_request", "sandbox_id is required")
return
}
sandboxID, err := id.ParseSandboxID(req.SandboxID)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox_id")
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
return
}
ac := auth.MustFromContext(r.Context())
if req.Name == "" {
req.Name = id.NewSnapshotName()
}
if err := validate.SafeName(req.Name); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", fmt.Sprintf("invalid snapshot name: %s", err))
return
}
ctx := r.Context()
ac := auth.MustFromContext(ctx)
// Check for global name collision.
if _, err := h.db.GetPlatformTemplateByName(ctx, req.Name); err == nil {
writeError(w, http.StatusConflict, "name_reserved", "template name is reserved by a global template")
return
}
// Check if name already exists for this team.
if _, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err == nil {
writeError(w, http.StatusConflict, "template_name_taken",
"snapshot name already exists; delete the existing snapshot first to reuse this name")
return
}
// Verify sandbox exists, belongs to team, and is running or paused.
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
// Async: the VM briefly pauses to a "snapshotting" state, then resumes. The
// template is registered by a background goroutine; clients learn of the
// result via the SSE template.snapshot.create event (or by polling).
sb, name, err := h.sandboxSvc.CreateSnapshot(r.Context(), sandboxID, ac.TeamID, req.Name)
if err != nil {
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
return
}
if sb.Status != "running" && sb.Status != "paused" {
writeError(w, http.StatusConflict, "invalid_state", "sandbox must be running or paused")
return
}
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
if err != nil {
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
return
}
// Pre-mark sandbox as "paused" in DB BEFORE issuing the snapshot RPC.
// The host agent's CreateSnapshot removes the sandbox from its in-memory
// map immediately; if the reconciler fires during the flatten window and
// the DB still says "running", it will mark the sandbox "stopped".
if sb.Status == "running" {
if _, err := h.db.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
ID: sandboxID, Status: "paused",
}); err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to update sandbox status")
return
}
}
// Use a detached context with a generous timeout so the snapshot completes
// even if the client disconnects (the flatten step can take 10-20s).
snapCtx, snapCancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer snapCancel()
// Generate the new template ID upfront so the host agent knows where to store files.
newTemplateID := id.NewTemplateID()
resp, err := agent.CreateSnapshot(snapCtx, connect.NewRequest(&pb.CreateSnapshotRequest{
SandboxId: req.SandboxID,
Name: req.Name,
TeamId: formatUUIDForRPC(ac.TeamID),
TemplateId: formatUUIDForRPC(newTemplateID),
}))
if err != nil {
// Snapshot failed — revert status back to what it was.
if sb.Status == "running" {
if _, dbErr := h.db.UpdateSandboxStatus(snapCtx, db.UpdateSandboxStatusParams{
ID: sandboxID, Status: "running",
}); dbErr != nil {
slog.Error("failed to revert sandbox status after snapshot error", "sandbox_id", req.SandboxID, "error", dbErr)
}
}
status, code, msg := agentErrToHTTP(err)
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
h.audit.LogSnapshotCreateRequested(r.Context(), ac, name)
tmpl, err := h.db.InsertTemplate(snapCtx, db.InsertTemplateParams{
ID: newTemplateID,
Name: req.Name,
Type: "snapshot",
Vcpus: sb.Vcpus,
MemoryMb: sb.MemoryMb,
SizeBytes: resp.Msg.SizeBytes,
TeamID: ac.TeamID,
DefaultUser: "root",
DefaultEnv: []byte("{}"),
Metadata: sb.Metadata,
})
if err != nil {
slog.Error("failed to insert template record", "name", req.Name, "error", err)
writeError(w, http.StatusInternalServerError, "db_error", "snapshot created but failed to record in database")
return
}
h.audit.LogSnapshotCreate(snapCtx, ac, req.Name)
if ctx.Err() != nil {
slog.Info("snapshot created but client disconnected before response", "name", req.Name)
return
}
writeJSON(w, http.StatusCreated, templateToResponse(tmpl))
writeJSON(w, http.StatusAccepted, sandboxToResponse(sb))
}
// List handles GET /v1/snapshots.
@ -243,6 +158,11 @@ func (h *snapshotHandler) List(w http.ResponseWriter, r *http.Request) {
return
}
// Resolve actual on-disk sizes for templates with unknown size (e.g.
// system base templates seeded with size_bytes = 0). This queries a host
// agent and persists the result to the DB for subsequent requests.
templates = resolveTemplateSizes(r.Context(), h.db, h.pool, templates)
resp := make([]snapshotResponse, len(templates))
for i, t := range templates {
resp[i] = templateToResponse(t)
@ -271,21 +191,24 @@ func (h *snapshotHandler) Delete(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusForbidden, "forbidden", "platform templates cannot be deleted here")
return
}
if layout.IsMinimal(tmpl.TeamID, tmpl.ID) {
writeError(w, http.StatusForbidden, "forbidden", "the minimal template cannot be deleted")
if layout.IsSystemTemplate(tmpl.TeamID, tmpl.ID) {
writeError(w, http.StatusForbidden, "forbidden", "system base templates cannot be deleted")
return
}
if err := deleteSnapshotBroadcast(ctx, h.db, h.pool, tmpl.TeamID, tmpl.ID); err != nil {
writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete snapshot files")
if err := deleteSnapshotEverywhere(ctx, h.db, h.pool, tmpl.TeamID, tmpl.ID); err != nil {
h.audit.LogSnapshotDelete(r.Context(), ac, name, err)
writeError(w, http.StatusConflict, "delete_failed",
"could not remove snapshot files from all hosts: "+err.Error())
return
}
if err := h.db.DeleteTemplateByTeam(ctx, db.DeleteTemplateByTeamParams{Name: name, TeamID: ac.TeamID}); err != nil {
h.audit.LogSnapshotDelete(r.Context(), ac, name, err)
writeError(w, http.StatusInternalServerError, "db_error", "failed to delete template record")
return
}
h.audit.LogSnapshotDelete(r.Context(), ac, name)
h.audit.LogSnapshotDelete(r.Context(), ac, name, nil)
w.WriteHeader(http.StatusNoContent)
}