diff --git a/internal/sandbox/manager.go b/internal/sandbox/manager.go index 84e04aa..a8293e6 100644 --- a/internal/sandbox/manager.go +++ b/internal/sandbox/manager.go @@ -7,6 +7,7 @@ import ( "os" "os/exec" "path/filepath" + "strings" "sync" "time" @@ -314,10 +315,11 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error { } slog.Debug("pause: VM paused", "id", sandboxID, "elapsed", time.Since(pauseStart)) - // Determine snapshot type: Diff if resumed from snapshot (avoids UFFD - // fault-in storm), Full otherwise or if generation cap is reached. + // Always use Diff when we have a parent snapshot — Diff only captures + // changed pages and is much faster than Full (which dumps all memory). + // For first-time pauses (no parent) we must use Full. snapshotType := "Full" - if sb.parent != nil && sb.parent.header.Metadata.Generation < maxDiffGenerations { + if sb.parent != nil { snapshotType = "Diff" } @@ -353,7 +355,7 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error { headerPath := filepath.Join(pauseDir, snapshot.MemHeaderName) processStart := time.Now() - if sb.parent != nil && snapshotType == "Diff" { + if sb.parent != nil { // Diff: process against parent header, producing only changed blocks. diffPath := snapshot.MemDiffPathForBuild(pauseDir, "", buildID) if _, err := snapshot.ProcessMemfileWithParent(rawMemPath, diffPath, headerPath, sb.parent.header, buildID); err != nil { @@ -373,8 +375,50 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error { } } } + + // If the generation cap is reached, merge all diff files into a + // single file to collapse the chain. This is a file-level operation + // (no Firecracker involvement) so it's fast and reliable. + generation := sb.parent.header.Metadata.Generation + 1 + if generation >= maxDiffGenerations { + slog.Debug("pause: merging diff generations", "id", sandboxID, "generation", generation) + + // Load the header we just wrote (it references all generations). + headerData, err := os.ReadFile(headerPath) + if err != nil { + warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir)) + resumeOnError() + return fmt.Errorf("read header for merge: %w", err) + } + currentHeader, err := snapshot.Deserialize(headerData) + if err != nil { + warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir)) + resumeOnError() + return fmt.Errorf("deserialize header for merge: %w", err) + } + + // Locate all diff files referenced by the header. + diffFiles, err := snapshot.ListDiffFiles(pauseDir, "", currentHeader) + if err != nil { + warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir)) + resumeOnError() + return fmt.Errorf("list diff files for merge: %w", err) + } + + // Merge into a single new diff file. + mergedPath := snapshot.MemDiffPath(pauseDir, "") + if _, err := snapshot.MergeDiffs(currentHeader, diffFiles, mergedPath, headerPath); err != nil { + warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir)) + resumeOnError() + return fmt.Errorf("merge diff files: %w", err) + } + + // Remove the old per-generation diff files. + removeStaleMemDiffs(pauseDir) + slog.Debug("pause: diff merge complete", "id", sandboxID) + } } else { - // Full: first generation or generation cap reached — single diff file. + // Full: first pause — no parent to diff against. diffPath := snapshot.MemDiffPath(pauseDir, "") if _, err := snapshot.ProcessMemfile(rawMemPath, diffPath, headerPath, buildID); err != nil { warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir)) @@ -1283,6 +1327,23 @@ func (m *Manager) PauseAll(ctx context.Context) { } } +// removeStaleMemDiffs removes memfile.{uuid} diff files from a snapshot +// directory. Called before writing a Full snapshot to prevent orphaned diffs +// from accumulating across generation resets. +func removeStaleMemDiffs(dir string) { + entries, err := os.ReadDir(dir) + if err != nil { + return + } + for _, e := range entries { + name := e.Name() + // Match "memfile.{uuid}" but not "memfile", "memfile.header", or "memfile.raw". + if strings.HasPrefix(name, "memfile.") && name != snapshot.MemHeaderName && name != "memfile.raw" { + os.Remove(filepath.Join(dir, name)) + } + } +} + // warnErr logs a warning if err is non-nil. Used for best-effort cleanup // in error paths where the primary error has already been captured. func warnErr(msg string, id string, err error) { diff --git a/internal/snapshot/memfile.go b/internal/snapshot/memfile.go index aabe885..f7b14f9 100644 --- a/internal/snapshot/memfile.go +++ b/internal/snapshot/memfile.go @@ -4,6 +4,7 @@ package snapshot import ( + "context" "fmt" "io" "os" @@ -172,6 +173,99 @@ func ProcessMemfileWithParent(memfilePath, diffPath, headerPath string, parentHe return header, nil } +// MergeDiffs consolidates multiple generation diff files into a single diff +// file and resets the generation counter to 0. This is a pure file-level +// operation — no Firecracker involvement. +// +// It reads each non-nil block from the appropriate diff file (as mapped by +// the header), writes them all sequentially into a single new diff file, +// and produces a fresh header pointing only at that file. +// +// diffFiles maps build ID (string) → open file path for each generation's diff. +func MergeDiffs(header *Header, diffFiles map[string]string, mergedDiffPath, headerPath string) (*Header, error) { + blockSize := int64(header.Metadata.BlockSize) + mergedBuildID := uuid.New() + + // Open all source diff files. + sources := make(map[string]*os.File, len(diffFiles)) + for id, path := range diffFiles { + f, err := os.Open(path) + if err != nil { + // Close already opened files. + for _, sf := range sources { + sf.Close() + } + return nil, fmt.Errorf("open diff file for build %s: %w", id, err) + } + sources[id] = f + } + defer func() { + for _, f := range sources { + f.Close() + } + }() + + dst, err := os.Create(mergedDiffPath) + if err != nil { + return nil, fmt.Errorf("create merged diff file: %w", err) + } + defer dst.Close() + + totalBlocks := TotalBlocks(int64(header.Metadata.Size), blockSize) + dirty := make([]bool, totalBlocks) + empty := make([]bool, totalBlocks) + buf := make([]byte, blockSize) + + for i := int64(0); i < totalBlocks; i++ { + offset := i * blockSize + mappedOffset, _, buildID, err := header.GetShiftedMapping(context.Background(), offset) + if err != nil { + return nil, fmt.Errorf("lookup block %d: %w", i, err) + } + + if *buildID == uuid.Nil { + empty[i] = true + continue + } + + src, ok := sources[buildID.String()] + if !ok { + return nil, fmt.Errorf("no diff file for build %s (block %d)", buildID, i) + } + + if _, err := src.ReadAt(buf, mappedOffset); err != nil { + return nil, fmt.Errorf("read block %d from build %s: %w", i, buildID, err) + } + + dirty[i] = true + if _, err := dst.Write(buf); err != nil { + return nil, fmt.Errorf("write merged block %d: %w", i, err) + } + } + + // Build fresh header with generation 0. + dirtyMappings := CreateMapping(mergedBuildID, dirty, blockSize) + emptyMappings := CreateMapping(uuid.Nil, empty, blockSize) + merged := MergeMappings(dirtyMappings, emptyMappings) + normalized := NormalizeMappings(merged) + + metadata := NewMetadata(mergedBuildID, uint64(blockSize), header.Metadata.Size) + newHeader, err := NewHeader(metadata, normalized) + if err != nil { + return nil, fmt.Errorf("create merged header: %w", err) + } + + headerData, err := Serialize(metadata, normalized) + if err != nil { + return nil, fmt.Errorf("serialize merged header: %w", err) + } + if err := os.WriteFile(headerPath, headerData, 0644); err != nil { + return nil, fmt.Errorf("write merged header: %w", err) + } + + return newHeader, nil +} + // isZeroBlock checks if a block is entirely zero bytes. func isZeroBlock(block []byte) bool { // Fast path: compare 8 bytes at a time.