1
0
forked from wrenn/wrenn

Replace Full snapshot fallback with file-level diff merge

Always use Firecracker Diff snapshots (fast, only changed pages) and
merge diff files at the file level when the generation cap is reached.
The previous approach used Firecracker's Full snapshot type which dumps
all memory to disk and can timeout, losing all snapshot data on failure.

Add snapshot.MergeDiffs() which reads each block from the appropriate
generation's diff file via the header mapping and writes them into a
single consolidated file with a fresh generation-0 header.
This commit is contained in:
2026-03-29 02:33:33 +06:00
parent 1ca10230a9
commit 8f06fc554a
2 changed files with 160 additions and 5 deletions

View File

@ -7,6 +7,7 @@ import (
"os" "os"
"os/exec" "os/exec"
"path/filepath" "path/filepath"
"strings"
"sync" "sync"
"time" "time"
@ -314,10 +315,11 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error {
} }
slog.Debug("pause: VM paused", "id", sandboxID, "elapsed", time.Since(pauseStart)) slog.Debug("pause: VM paused", "id", sandboxID, "elapsed", time.Since(pauseStart))
// Determine snapshot type: Diff if resumed from snapshot (avoids UFFD // Always use Diff when we have a parent snapshot — Diff only captures
// fault-in storm), Full otherwise or if generation cap is reached. // changed pages and is much faster than Full (which dumps all memory).
// For first-time pauses (no parent) we must use Full.
snapshotType := "Full" snapshotType := "Full"
if sb.parent != nil && sb.parent.header.Metadata.Generation < maxDiffGenerations { if sb.parent != nil {
snapshotType = "Diff" snapshotType = "Diff"
} }
@ -353,7 +355,7 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error {
headerPath := filepath.Join(pauseDir, snapshot.MemHeaderName) headerPath := filepath.Join(pauseDir, snapshot.MemHeaderName)
processStart := time.Now() processStart := time.Now()
if sb.parent != nil && snapshotType == "Diff" { if sb.parent != nil {
// Diff: process against parent header, producing only changed blocks. // Diff: process against parent header, producing only changed blocks.
diffPath := snapshot.MemDiffPathForBuild(pauseDir, "", buildID) diffPath := snapshot.MemDiffPathForBuild(pauseDir, "", buildID)
if _, err := snapshot.ProcessMemfileWithParent(rawMemPath, diffPath, headerPath, sb.parent.header, buildID); err != nil { if _, err := snapshot.ProcessMemfileWithParent(rawMemPath, diffPath, headerPath, sb.parent.header, buildID); err != nil {
@ -373,8 +375,50 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error {
} }
} }
} }
// If the generation cap is reached, merge all diff files into a
// single file to collapse the chain. This is a file-level operation
// (no Firecracker involvement) so it's fast and reliable.
generation := sb.parent.header.Metadata.Generation + 1
if generation >= maxDiffGenerations {
slog.Debug("pause: merging diff generations", "id", sandboxID, "generation", generation)
// Load the header we just wrote (it references all generations).
headerData, err := os.ReadFile(headerPath)
if err != nil {
warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir))
resumeOnError()
return fmt.Errorf("read header for merge: %w", err)
}
currentHeader, err := snapshot.Deserialize(headerData)
if err != nil {
warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir))
resumeOnError()
return fmt.Errorf("deserialize header for merge: %w", err)
}
// Locate all diff files referenced by the header.
diffFiles, err := snapshot.ListDiffFiles(pauseDir, "", currentHeader)
if err != nil {
warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir))
resumeOnError()
return fmt.Errorf("list diff files for merge: %w", err)
}
// Merge into a single new diff file.
mergedPath := snapshot.MemDiffPath(pauseDir, "")
if _, err := snapshot.MergeDiffs(currentHeader, diffFiles, mergedPath, headerPath); err != nil {
warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir))
resumeOnError()
return fmt.Errorf("merge diff files: %w", err)
}
// Remove the old per-generation diff files.
removeStaleMemDiffs(pauseDir)
slog.Debug("pause: diff merge complete", "id", sandboxID)
}
} else { } else {
// Full: first generation or generation cap reached — single diff file. // Full: first pause — no parent to diff against.
diffPath := snapshot.MemDiffPath(pauseDir, "") diffPath := snapshot.MemDiffPath(pauseDir, "")
if _, err := snapshot.ProcessMemfile(rawMemPath, diffPath, headerPath, buildID); err != nil { if _, err := snapshot.ProcessMemfile(rawMemPath, diffPath, headerPath, buildID); err != nil {
warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir)) warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir))
@ -1283,6 +1327,23 @@ func (m *Manager) PauseAll(ctx context.Context) {
} }
} }
// removeStaleMemDiffs removes memfile.{uuid} diff files from a snapshot
// directory. Called before writing a Full snapshot to prevent orphaned diffs
// from accumulating across generation resets.
func removeStaleMemDiffs(dir string) {
entries, err := os.ReadDir(dir)
if err != nil {
return
}
for _, e := range entries {
name := e.Name()
// Match "memfile.{uuid}" but not "memfile", "memfile.header", or "memfile.raw".
if strings.HasPrefix(name, "memfile.") && name != snapshot.MemHeaderName && name != "memfile.raw" {
os.Remove(filepath.Join(dir, name))
}
}
}
// warnErr logs a warning if err is non-nil. Used for best-effort cleanup // warnErr logs a warning if err is non-nil. Used for best-effort cleanup
// in error paths where the primary error has already been captured. // in error paths where the primary error has already been captured.
func warnErr(msg string, id string, err error) { func warnErr(msg string, id string, err error) {

View File

@ -4,6 +4,7 @@
package snapshot package snapshot
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"os" "os"
@ -172,6 +173,99 @@ func ProcessMemfileWithParent(memfilePath, diffPath, headerPath string, parentHe
return header, nil return header, nil
} }
// MergeDiffs consolidates multiple generation diff files into a single diff
// file and resets the generation counter to 0. This is a pure file-level
// operation — no Firecracker involvement.
//
// It reads each non-nil block from the appropriate diff file (as mapped by
// the header), writes them all sequentially into a single new diff file,
// and produces a fresh header pointing only at that file.
//
// diffFiles maps build ID (string) → open file path for each generation's diff.
func MergeDiffs(header *Header, diffFiles map[string]string, mergedDiffPath, headerPath string) (*Header, error) {
blockSize := int64(header.Metadata.BlockSize)
mergedBuildID := uuid.New()
// Open all source diff files.
sources := make(map[string]*os.File, len(diffFiles))
for id, path := range diffFiles {
f, err := os.Open(path)
if err != nil {
// Close already opened files.
for _, sf := range sources {
sf.Close()
}
return nil, fmt.Errorf("open diff file for build %s: %w", id, err)
}
sources[id] = f
}
defer func() {
for _, f := range sources {
f.Close()
}
}()
dst, err := os.Create(mergedDiffPath)
if err != nil {
return nil, fmt.Errorf("create merged diff file: %w", err)
}
defer dst.Close()
totalBlocks := TotalBlocks(int64(header.Metadata.Size), blockSize)
dirty := make([]bool, totalBlocks)
empty := make([]bool, totalBlocks)
buf := make([]byte, blockSize)
for i := int64(0); i < totalBlocks; i++ {
offset := i * blockSize
mappedOffset, _, buildID, err := header.GetShiftedMapping(context.Background(), offset)
if err != nil {
return nil, fmt.Errorf("lookup block %d: %w", i, err)
}
if *buildID == uuid.Nil {
empty[i] = true
continue
}
src, ok := sources[buildID.String()]
if !ok {
return nil, fmt.Errorf("no diff file for build %s (block %d)", buildID, i)
}
if _, err := src.ReadAt(buf, mappedOffset); err != nil {
return nil, fmt.Errorf("read block %d from build %s: %w", i, buildID, err)
}
dirty[i] = true
if _, err := dst.Write(buf); err != nil {
return nil, fmt.Errorf("write merged block %d: %w", i, err)
}
}
// Build fresh header with generation 0.
dirtyMappings := CreateMapping(mergedBuildID, dirty, blockSize)
emptyMappings := CreateMapping(uuid.Nil, empty, blockSize)
merged := MergeMappings(dirtyMappings, emptyMappings)
normalized := NormalizeMappings(merged)
metadata := NewMetadata(mergedBuildID, uint64(blockSize), header.Metadata.Size)
newHeader, err := NewHeader(metadata, normalized)
if err != nil {
return nil, fmt.Errorf("create merged header: %w", err)
}
headerData, err := Serialize(metadata, normalized)
if err != nil {
return nil, fmt.Errorf("serialize merged header: %w", err)
}
if err := os.WriteFile(headerPath, headerData, 0644); err != nil {
return nil, fmt.Errorf("write merged header: %w", err)
}
return newHeader, nil
}
// isZeroBlock checks if a block is entirely zero bytes. // isZeroBlock checks if a block is entirely zero bytes.
func isZeroBlock(block []byte) bool { func isZeroBlock(block []byte) bool {
// Fast path: compare 8 bytes at a time. // Fast path: compare 8 bytes at a time.