forked from wrenn/wrenn
v0.2.0 (#50)
Co-authored-by: Tasnim Kabir Sadik <tksadik@omukk.dev> Reviewed-on: wrenn/wrenn#50
This commit is contained in:
@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
@ -25,49 +24,53 @@ import (
|
||||
)
|
||||
|
||||
type snapshotHandler struct {
|
||||
svc *service.TemplateService
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
audit *audit.AuditLogger
|
||||
svc *service.TemplateService
|
||||
sandboxSvc *service.SandboxService
|
||||
db *db.Queries
|
||||
pool *lifecycle.HostClientPool
|
||||
audit *audit.AuditLogger
|
||||
}
|
||||
|
||||
func newSnapshotHandler(svc *service.TemplateService, db *db.Queries, pool *lifecycle.HostClientPool, al *audit.AuditLogger) *snapshotHandler {
|
||||
return &snapshotHandler{svc: svc, db: db, pool: pool, audit: al}
|
||||
func newSnapshotHandler(svc *service.TemplateService, sandboxSvc *service.SandboxService, db *db.Queries, pool *lifecycle.HostClientPool, al *audit.AuditLogger) *snapshotHandler {
|
||||
return &snapshotHandler{svc: svc, sandboxSvc: sandboxSvc, db: db, pool: pool, audit: al}
|
||||
}
|
||||
|
||||
// deleteSnapshotBroadcast attempts to delete snapshot files on all online hosts.
|
||||
// Snapshots aren't currently host-tracked in the DB, so we broadcast to all hosts
|
||||
// and ignore NotFound errors.
|
||||
func deleteSnapshotBroadcast(ctx context.Context, queries *db.Queries, pool *lifecycle.HostClientPool, teamID, templateID pgtype.UUID) error {
|
||||
// deleteSnapshotEverywhere removes a template's files from every active host.
|
||||
// Templates aren't host-tracked in the DB, so it broadcasts to all hosts.
|
||||
//
|
||||
// It is strict by design: deletion is reported successful only when every
|
||||
// active host has either removed the files or reported NotFound (it never
|
||||
// held them). If any host is offline or returns an error, it returns an error
|
||||
// and the caller MUST NOT delete the DB record — doing so would orphan the
|
||||
// files on disk with no record left to retry against.
|
||||
func deleteSnapshotEverywhere(ctx context.Context, queries *db.Queries, pool *lifecycle.HostClientPool, teamID, templateID pgtype.UUID) error {
|
||||
hosts, err := queries.ListActiveHosts(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list hosts: %w", err)
|
||||
}
|
||||
for _, host := range hosts {
|
||||
if host.Status != "online" {
|
||||
continue
|
||||
return fmt.Errorf("host %s is %s — cannot guarantee snapshot file removal",
|
||||
id.FormatHostID(host.ID), host.Status)
|
||||
}
|
||||
agent, err := pool.GetForHost(host)
|
||||
if err != nil {
|
||||
continue
|
||||
return fmt.Errorf("connect to host %s: %w", id.FormatHostID(host.ID), err)
|
||||
}
|
||||
if _, err := agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{
|
||||
TeamId: formatUUIDForRPC(teamID),
|
||||
TemplateId: formatUUIDForRPC(templateID),
|
||||
})); err != nil {
|
||||
if connect.CodeOf(err) != connect.CodeNotFound {
|
||||
slog.Warn("snapshot: failed to delete on host", "host_id", id.FormatHostID(host.ID), "error", err)
|
||||
// NotFound just means this host never held the template.
|
||||
if connect.CodeOf(err) == connect.CodeNotFound {
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("delete snapshot on host %s: %w", id.FormatHostID(host.ID), err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type createSnapshotRequest struct {
|
||||
SandboxID string `json:"sandbox_id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type snapshotResponse struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
@ -76,6 +79,7 @@ type snapshotResponse struct {
|
||||
SizeBytes int64 `json:"size_bytes"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
Platform bool `json:"platform"`
|
||||
Protected bool `json:"protected"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
@ -85,6 +89,7 @@ func templateToResponse(t db.Template) snapshotResponse {
|
||||
Type: t.Type,
|
||||
SizeBytes: t.SizeBytes,
|
||||
Platform: t.TeamID == id.PlatformTeamID,
|
||||
Protected: layout.IsSystemTemplate(t.TeamID, t.ID),
|
||||
}
|
||||
if t.Vcpus != 0 {
|
||||
resp.VCPUs = &t.Vcpus
|
||||
@ -104,132 +109,42 @@ func templateToResponse(t db.Template) snapshotResponse {
|
||||
return resp
|
||||
}
|
||||
|
||||
// Create handles POST /v1/snapshots.
|
||||
type createSnapshotRequest struct {
|
||||
SandboxID string `json:"sandbox_id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// Create handles POST /v1/snapshots. Snapshots a running or paused sandbox and
|
||||
// registers the result as a new template.
|
||||
func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
var req createSnapshotRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.SandboxID == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "sandbox_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
sandboxID, err := id.ParseSandboxID(req.SandboxID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox_id")
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid sandbox ID")
|
||||
return
|
||||
}
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
if req.Name == "" {
|
||||
req.Name = id.NewSnapshotName()
|
||||
}
|
||||
if err := validate.SafeName(req.Name); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", fmt.Sprintf("invalid snapshot name: %s", err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
// Check for global name collision.
|
||||
if _, err := h.db.GetPlatformTemplateByName(ctx, req.Name); err == nil {
|
||||
writeError(w, http.StatusConflict, "name_reserved", "template name is reserved by a global template")
|
||||
return
|
||||
}
|
||||
|
||||
// Check if name already exists for this team.
|
||||
if _, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err == nil {
|
||||
writeError(w, http.StatusConflict, "template_name_taken",
|
||||
"snapshot name already exists; delete the existing snapshot first to reuse this name")
|
||||
return
|
||||
}
|
||||
|
||||
// Verify sandbox exists, belongs to team, and is running or paused.
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
// Async: the VM briefly pauses to a "snapshotting" state, then resumes. The
|
||||
// template is registered by a background goroutine; clients learn of the
|
||||
// result via the SSE template.snapshot.create event (or by polling).
|
||||
sb, name, err := h.sandboxSvc.CreateSnapshot(r.Context(), sandboxID, ac.TeamID, req.Name)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" && sb.Status != "paused" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox must be running or paused")
|
||||
return
|
||||
}
|
||||
|
||||
agent, err := agentForHost(ctx, h.db, h.pool, sb.HostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "host_unavailable", "sandbox host is not reachable")
|
||||
return
|
||||
}
|
||||
|
||||
// Pre-mark sandbox as "paused" in DB BEFORE issuing the snapshot RPC.
|
||||
// The host agent's CreateSnapshot removes the sandbox from its in-memory
|
||||
// map immediately; if the reconciler fires during the flatten window and
|
||||
// the DB still says "running", it will mark the sandbox "stopped".
|
||||
if sb.Status == "running" {
|
||||
if _, err := h.db.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
|
||||
ID: sandboxID, Status: "paused",
|
||||
}); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to update sandbox status")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Use a detached context with a generous timeout so the snapshot completes
|
||||
// even if the client disconnects (the flatten step can take 10-20s).
|
||||
snapCtx, snapCancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer snapCancel()
|
||||
|
||||
// Generate the new template ID upfront so the host agent knows where to store files.
|
||||
newTemplateID := id.NewTemplateID()
|
||||
|
||||
resp, err := agent.CreateSnapshot(snapCtx, connect.NewRequest(&pb.CreateSnapshotRequest{
|
||||
SandboxId: req.SandboxID,
|
||||
Name: req.Name,
|
||||
TeamId: formatUUIDForRPC(ac.TeamID),
|
||||
TemplateId: formatUUIDForRPC(newTemplateID),
|
||||
}))
|
||||
if err != nil {
|
||||
// Snapshot failed — revert status back to what it was.
|
||||
if sb.Status == "running" {
|
||||
if _, dbErr := h.db.UpdateSandboxStatus(snapCtx, db.UpdateSandboxStatusParams{
|
||||
ID: sandboxID, Status: "running",
|
||||
}); dbErr != nil {
|
||||
slog.Error("failed to revert sandbox status after snapshot error", "sandbox_id", req.SandboxID, "error", dbErr)
|
||||
}
|
||||
}
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
h.audit.LogSnapshotCreateRequested(r.Context(), ac, name)
|
||||
|
||||
tmpl, err := h.db.InsertTemplate(snapCtx, db.InsertTemplateParams{
|
||||
ID: newTemplateID,
|
||||
Name: req.Name,
|
||||
Type: "snapshot",
|
||||
Vcpus: sb.Vcpus,
|
||||
MemoryMb: sb.MemoryMb,
|
||||
SizeBytes: resp.Msg.SizeBytes,
|
||||
TeamID: ac.TeamID,
|
||||
DefaultUser: "root",
|
||||
DefaultEnv: []byte("{}"),
|
||||
Metadata: sb.Metadata,
|
||||
})
|
||||
if err != nil {
|
||||
slog.Error("failed to insert template record", "name", req.Name, "error", err)
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "snapshot created but failed to record in database")
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogSnapshotCreate(snapCtx, ac, req.Name)
|
||||
|
||||
if ctx.Err() != nil {
|
||||
slog.Info("snapshot created but client disconnected before response", "name", req.Name)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusCreated, templateToResponse(tmpl))
|
||||
writeJSON(w, http.StatusAccepted, sandboxToResponse(sb))
|
||||
}
|
||||
|
||||
// List handles GET /v1/snapshots.
|
||||
@ -243,6 +158,11 @@ func (h *snapshotHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Resolve actual on-disk sizes for templates with unknown size (e.g.
|
||||
// system base templates seeded with size_bytes = 0). This queries a host
|
||||
// agent and persists the result to the DB for subsequent requests.
|
||||
templates = resolveTemplateSizes(r.Context(), h.db, h.pool, templates)
|
||||
|
||||
resp := make([]snapshotResponse, len(templates))
|
||||
for i, t := range templates {
|
||||
resp[i] = templateToResponse(t)
|
||||
@ -271,21 +191,24 @@ func (h *snapshotHandler) Delete(w http.ResponseWriter, r *http.Request) {
|
||||
writeError(w, http.StatusForbidden, "forbidden", "platform templates cannot be deleted here")
|
||||
return
|
||||
}
|
||||
if layout.IsMinimal(tmpl.TeamID, tmpl.ID) {
|
||||
writeError(w, http.StatusForbidden, "forbidden", "the minimal template cannot be deleted")
|
||||
if layout.IsSystemTemplate(tmpl.TeamID, tmpl.ID) {
|
||||
writeError(w, http.StatusForbidden, "forbidden", "system base templates cannot be deleted")
|
||||
return
|
||||
}
|
||||
|
||||
if err := deleteSnapshotBroadcast(ctx, h.db, h.pool, tmpl.TeamID, tmpl.ID); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "agent_error", "failed to delete snapshot files")
|
||||
if err := deleteSnapshotEverywhere(ctx, h.db, h.pool, tmpl.TeamID, tmpl.ID); err != nil {
|
||||
h.audit.LogSnapshotDelete(r.Context(), ac, name, err)
|
||||
writeError(w, http.StatusConflict, "delete_failed",
|
||||
"could not remove snapshot files from all hosts: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.db.DeleteTemplateByTeam(ctx, db.DeleteTemplateByTeamParams{Name: name, TeamID: ac.TeamID}); err != nil {
|
||||
h.audit.LogSnapshotDelete(r.Context(), ac, name, err)
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to delete template record")
|
||||
return
|
||||
}
|
||||
|
||||
h.audit.LogSnapshotDelete(r.Context(), ac, name)
|
||||
h.audit.LogSnapshotDelete(r.Context(), ac, name, nil)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user