Add sandbox snapshot and restore with UFFD lazy memory loading
Implement full snapshot lifecycle: pause (snapshot + free resources), resume (UFFD-based lazy restore), and named snapshot templates that can spawn new sandboxes from frozen VM state. Key changes: - Snapshot header system with generational diff mapping (inspired by e2b) - UFFD server for lazy page fault handling during snapshot restore - Stable rootfs symlink path (/tmp/fc-vm/) for snapshot compatibility - Templates DB table and CRUD API endpoints (POST/GET/DELETE /v1/snapshots) - CreateSnapshot/DeleteSnapshot RPCs in hostagent proto - Reconciler excludes paused sandboxes (expected absent from host agent) - Snapshot templates lock vcpus/memory to baked-in values - Proper cleanup of uffd sockets and pause snapshot files on destroy
This commit is contained in:
87
internal/uffd/fd.go
Normal file
87
internal/uffd/fd.go
Normal file
@ -0,0 +1,87 @@
|
||||
// Package uffd implements a userfaultfd-based memory server for Firecracker
|
||||
// snapshot restore. When a VM is restored from a snapshot, instead of loading
|
||||
// the entire memory file upfront, the UFFD handler intercepts page faults
|
||||
// and serves memory pages on demand from the snapshot's compact diff file.
|
||||
//
|
||||
// Inspired by e2b's UFFD implementation (Apache 2.0, modified by Omukk).
|
||||
package uffd
|
||||
|
||||
/*
|
||||
#include <sys/syscall.h>
|
||||
#include <fcntl.h>
|
||||
#include <linux/userfaultfd.h>
|
||||
#include <sys/ioctl.h>
|
||||
|
||||
struct uffd_pagefault {
|
||||
__u64 flags;
|
||||
__u64 address;
|
||||
__u32 ptid;
|
||||
};
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
const (
|
||||
UFFD_EVENT_PAGEFAULT = C.UFFD_EVENT_PAGEFAULT
|
||||
UFFD_PAGEFAULT_FLAG_WRITE = C.UFFD_PAGEFAULT_FLAG_WRITE
|
||||
UFFDIO_COPY = C.UFFDIO_COPY
|
||||
UFFDIO_COPY_MODE_WP = C.UFFDIO_COPY_MODE_WP
|
||||
)
|
||||
|
||||
type (
|
||||
uffdMsg = C.struct_uffd_msg
|
||||
uffdPagefault = C.struct_uffd_pagefault
|
||||
uffdioCopy = C.struct_uffdio_copy
|
||||
)
|
||||
|
||||
// fd wraps a userfaultfd file descriptor received from Firecracker.
|
||||
type fd uintptr
|
||||
|
||||
// copy installs a page into guest memory at the given address using UFFDIO_COPY.
|
||||
// mode controls write-protection: use UFFDIO_COPY_MODE_WP to preserve WP bit.
|
||||
func (f fd) copy(addr, pagesize uintptr, data []byte, mode C.ulonglong) error {
|
||||
alignedAddr := addr &^ (pagesize - 1)
|
||||
cpy := uffdioCopy{
|
||||
src: C.ulonglong(uintptr(unsafe.Pointer(&data[0]))),
|
||||
dst: C.ulonglong(alignedAddr),
|
||||
len: C.ulonglong(pagesize),
|
||||
mode: mode,
|
||||
copy: 0,
|
||||
}
|
||||
|
||||
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(f), UFFDIO_COPY, uintptr(unsafe.Pointer(&cpy)))
|
||||
if errno != 0 {
|
||||
return errno
|
||||
}
|
||||
|
||||
if cpy.copy != C.longlong(pagesize) {
|
||||
return fmt.Errorf("UFFDIO_COPY copied %d bytes, expected %d", cpy.copy, pagesize)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// close closes the userfaultfd file descriptor.
|
||||
func (f fd) close() error {
|
||||
return syscall.Close(int(f))
|
||||
}
|
||||
|
||||
// getMsgEvent extracts the event type from a uffd_msg.
|
||||
func getMsgEvent(msg *uffdMsg) C.uchar {
|
||||
return msg.event
|
||||
}
|
||||
|
||||
// getMsgArg extracts the arg union from a uffd_msg.
|
||||
func getMsgArg(msg *uffdMsg) [24]byte {
|
||||
return msg.arg
|
||||
}
|
||||
|
||||
// getPagefaultAddress extracts the faulting address from a uffd_pagefault.
|
||||
func getPagefaultAddress(pf *uffdPagefault) uintptr {
|
||||
return uintptr(pf.address)
|
||||
}
|
||||
35
internal/uffd/region.go
Normal file
35
internal/uffd/region.go
Normal file
@ -0,0 +1,35 @@
|
||||
package uffd
|
||||
|
||||
import "fmt"
|
||||
|
||||
// Region is a mapping of guest memory to host virtual address space.
|
||||
// Firecracker sends these as JSON when connecting to the UFFD socket.
|
||||
// The JSON field names match Firecracker's UFFD protocol.
|
||||
type Region struct {
|
||||
BaseHostVirtAddr uintptr `json:"base_host_virt_addr"`
|
||||
Size uintptr `json:"size"`
|
||||
Offset uintptr `json:"offset"`
|
||||
PageSize uintptr `json:"page_size_kib"` // Actually in bytes despite the name.
|
||||
}
|
||||
|
||||
// Mapping translates between host virtual addresses and logical memory offsets.
|
||||
type Mapping struct {
|
||||
Regions []Region
|
||||
}
|
||||
|
||||
// NewMapping creates a Mapping from a list of regions.
|
||||
func NewMapping(regions []Region) *Mapping {
|
||||
return &Mapping{Regions: regions}
|
||||
}
|
||||
|
||||
// GetOffset converts a host virtual address to a logical memory file offset
|
||||
// and returns the page size. This is called on every UFFD page fault.
|
||||
func (m *Mapping) GetOffset(hostVirtAddr uintptr) (int64, uintptr, error) {
|
||||
for _, r := range m.Regions {
|
||||
if hostVirtAddr >= r.BaseHostVirtAddr && hostVirtAddr < r.BaseHostVirtAddr+r.Size {
|
||||
offset := int64(hostVirtAddr-r.BaseHostVirtAddr) + int64(r.Offset)
|
||||
return offset, r.PageSize, nil
|
||||
}
|
||||
}
|
||||
return 0, 0, fmt.Errorf("address %#x not found in any memory region", hostVirtAddr)
|
||||
}
|
||||
352
internal/uffd/server.go
Normal file
352
internal/uffd/server.go
Normal file
@ -0,0 +1,352 @@
|
||||
package uffd
|
||||
|
||||
/*
|
||||
#include <linux/userfaultfd.h>
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/snapshot"
|
||||
)
|
||||
|
||||
const (
|
||||
fdSize = 4
|
||||
regionMappingsSize = 1024
|
||||
maxConcurrentFaults = 4096
|
||||
)
|
||||
|
||||
// MemorySource provides page data for the UFFD handler.
|
||||
// Given a logical memory offset and a size, it returns the page data.
|
||||
type MemorySource interface {
|
||||
ReadPage(ctx context.Context, offset int64, size int64) ([]byte, error)
|
||||
}
|
||||
|
||||
// Server manages the UFFD Unix socket lifecycle and page fault handling
|
||||
// for a single Firecracker snapshot restore.
|
||||
type Server struct {
|
||||
socketPath string
|
||||
source MemorySource
|
||||
lis *net.UnixListener
|
||||
|
||||
readyCh chan struct{}
|
||||
readyOnce sync.Once
|
||||
doneCh chan struct{}
|
||||
doneErr error
|
||||
|
||||
// exitPipe signals the poll loop to stop.
|
||||
exitR *os.File
|
||||
exitW *os.File
|
||||
}
|
||||
|
||||
// NewServer creates a UFFD server that will listen on the given socket path
|
||||
// and serve memory pages from the given source.
|
||||
func NewServer(socketPath string, source MemorySource) *Server {
|
||||
return &Server{
|
||||
socketPath: socketPath,
|
||||
source: source,
|
||||
readyCh: make(chan struct{}),
|
||||
doneCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins listening on the Unix socket. Firecracker will connect to this
|
||||
// socket after loadSnapshot is called with the UFFD backend.
|
||||
// Start returns immediately; the server runs in a background goroutine.
|
||||
func (s *Server) Start(ctx context.Context) error {
|
||||
lis, err := net.ListenUnix("unix", &net.UnixAddr{Name: s.socketPath, Net: "unix"})
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen on uffd socket: %w", err)
|
||||
}
|
||||
s.lis = lis
|
||||
|
||||
if err := os.Chmod(s.socketPath, 0o777); err != nil {
|
||||
lis.Close()
|
||||
return fmt.Errorf("chmod uffd socket: %w", err)
|
||||
}
|
||||
|
||||
// Create exit signal pipe.
|
||||
r, w, err := os.Pipe()
|
||||
if err != nil {
|
||||
lis.Close()
|
||||
return fmt.Errorf("create exit pipe: %w", err)
|
||||
}
|
||||
s.exitR = r
|
||||
s.exitW = w
|
||||
|
||||
go func() {
|
||||
defer close(s.doneCh)
|
||||
s.doneErr = s.handle(ctx)
|
||||
s.lis.Close()
|
||||
s.exitR.Close()
|
||||
s.exitW.Close()
|
||||
s.readyOnce.Do(func() { close(s.readyCh) })
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ready returns a channel that is closed when the UFFD handler is ready
|
||||
// (after Firecracker has connected and sent the uffd fd).
|
||||
func (s *Server) Ready() <-chan struct{} {
|
||||
return s.readyCh
|
||||
}
|
||||
|
||||
// Stop signals the UFFD poll loop to exit and waits for it to finish.
|
||||
func (s *Server) Stop() error {
|
||||
// Write a byte to the exit pipe to wake the poll loop.
|
||||
s.exitW.Write([]byte{0})
|
||||
<-s.doneCh
|
||||
return s.doneErr
|
||||
}
|
||||
|
||||
// Wait blocks until the server exits.
|
||||
func (s *Server) Wait() error {
|
||||
<-s.doneCh
|
||||
return s.doneErr
|
||||
}
|
||||
|
||||
// handle accepts the Firecracker connection, receives the UFFD fd via
|
||||
// SCM_RIGHTS, and runs the page fault poll loop.
|
||||
func (s *Server) handle(ctx context.Context) error {
|
||||
conn, err := s.lis.Accept()
|
||||
if err != nil {
|
||||
return fmt.Errorf("accept uffd connection: %w", err)
|
||||
}
|
||||
|
||||
unixConn := conn.(*net.UnixConn)
|
||||
defer unixConn.Close()
|
||||
|
||||
// Read the memory region mappings (JSON) and the UFFD fd (SCM_RIGHTS).
|
||||
regionBuf := make([]byte, regionMappingsSize)
|
||||
uffdBuf := make([]byte, syscall.CmsgSpace(fdSize))
|
||||
|
||||
nRegion, nFd, _, _, err := unixConn.ReadMsgUnix(regionBuf, uffdBuf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read uffd message: %w", err)
|
||||
}
|
||||
|
||||
var regions []Region
|
||||
if err := json.Unmarshal(regionBuf[:nRegion], ®ions); err != nil {
|
||||
return fmt.Errorf("parse memory regions: %w", err)
|
||||
}
|
||||
|
||||
controlMsgs, err := syscall.ParseSocketControlMessage(uffdBuf[:nFd])
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse control messages: %w", err)
|
||||
}
|
||||
if len(controlMsgs) != 1 {
|
||||
return fmt.Errorf("expected 1 control message, got %d", len(controlMsgs))
|
||||
}
|
||||
|
||||
fds, err := syscall.ParseUnixRights(&controlMsgs[0])
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse unix rights: %w", err)
|
||||
}
|
||||
if len(fds) != 1 {
|
||||
return fmt.Errorf("expected 1 fd, got %d", len(fds))
|
||||
}
|
||||
|
||||
uffdFd := fd(fds[0])
|
||||
defer uffdFd.close()
|
||||
|
||||
mapping := NewMapping(regions)
|
||||
|
||||
slog.Info("uffd handler connected",
|
||||
"regions", len(regions),
|
||||
"fd", int(uffdFd),
|
||||
)
|
||||
|
||||
// Signal readiness.
|
||||
s.readyOnce.Do(func() { close(s.readyCh) })
|
||||
|
||||
// Run the poll loop.
|
||||
return s.serve(ctx, uffdFd, mapping)
|
||||
}
|
||||
|
||||
// serve is the main poll loop. It polls the UFFD fd for page fault events
|
||||
// and the exit pipe for shutdown signals.
|
||||
func (s *Server) serve(ctx context.Context, uffdFd fd, mapping *Mapping) error {
|
||||
pollFds := []unix.PollFd{
|
||||
{Fd: int32(uffdFd), Events: unix.POLLIN},
|
||||
{Fd: int32(s.exitR.Fd()), Events: unix.POLLIN},
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
sem := make(chan struct{}, maxConcurrentFaults)
|
||||
|
||||
// Always wait for in-flight goroutines before returning, so the caller
|
||||
// can safely close the uffd fd after serve returns.
|
||||
defer wg.Wait()
|
||||
|
||||
for {
|
||||
if _, err := unix.Poll(pollFds, -1); err != nil {
|
||||
if err == unix.EINTR || err == unix.EAGAIN {
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("poll: %w", err)
|
||||
}
|
||||
|
||||
// Check exit signal.
|
||||
if pollFds[1].Revents&unix.POLLIN != 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if pollFds[0].Revents&unix.POLLIN == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Read the uffd_msg. The fd is O_NONBLOCK (set by Firecracker),
|
||||
// so EAGAIN is expected — just go back to poll.
|
||||
buf := make([]byte, unsafe.Sizeof(uffdMsg{}))
|
||||
n, err := readUffdMsg(uffdFd, buf)
|
||||
if err == syscall.EAGAIN {
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("read uffd msg: %w", err)
|
||||
}
|
||||
if n == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
msg := *(*uffdMsg)(unsafe.Pointer(&buf[0]))
|
||||
if getMsgEvent(&msg) != UFFD_EVENT_PAGEFAULT {
|
||||
return fmt.Errorf("unexpected uffd event type: %d", getMsgEvent(&msg))
|
||||
}
|
||||
|
||||
arg := getMsgArg(&msg)
|
||||
pf := *(*uffdPagefault)(unsafe.Pointer(&arg[0]))
|
||||
addr := getPagefaultAddress(&pf)
|
||||
|
||||
offset, pagesize, err := mapping.GetOffset(addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolve address %#x: %w", addr, err)
|
||||
}
|
||||
|
||||
sem <- struct{}{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer func() { <-sem }()
|
||||
|
||||
if err := s.faultPage(ctx, uffdFd, addr, offset, pagesize); err != nil {
|
||||
slog.Error("uffd fault page error",
|
||||
"addr", fmt.Sprintf("%#x", addr),
|
||||
"offset", offset,
|
||||
"error", err,
|
||||
)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// readUffdMsg reads a single uffd_msg, retrying on EINTR.
|
||||
// Returns (n, EAGAIN) if the non-blocking read has nothing available.
|
||||
func readUffdMsg(uffdFd fd, buf []byte) (int, error) {
|
||||
for {
|
||||
n, err := syscall.Read(int(uffdFd), buf)
|
||||
if err == syscall.EINTR {
|
||||
continue
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
|
||||
// faultPage fetches a page from the memory source and copies it into
|
||||
// guest memory via UFFDIO_COPY.
|
||||
func (s *Server) faultPage(ctx context.Context, uffdFd fd, addr uintptr, offset int64, pagesize uintptr) error {
|
||||
data, err := s.source.ReadPage(ctx, offset, int64(pagesize))
|
||||
if err != nil {
|
||||
return fmt.Errorf("read page at offset %d: %w", offset, err)
|
||||
}
|
||||
|
||||
// Mode 0: no write-protect. Standard Firecracker does not register
|
||||
// UFFD ranges with WP support, so UFFDIO_COPY_MODE_WP would fail.
|
||||
if err := uffdFd.copy(addr, pagesize, data, 0); err != nil {
|
||||
if errors.Is(err, unix.EEXIST) {
|
||||
// Page already mapped (race with prefetch or concurrent fault).
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("uffdio_copy: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DiffFileSource serves pages from a snapshot's compact diff file using
|
||||
// the header's block mapping to resolve offsets.
|
||||
type DiffFileSource struct {
|
||||
header *snapshot.Header
|
||||
// diffs maps build ID → open file handle for each generation's diff file.
|
||||
diffs map[string]*os.File
|
||||
}
|
||||
|
||||
// NewDiffFileSource creates a memory source backed by snapshot diff files.
|
||||
// diffs maps build ID string to the file path of each generation's diff file.
|
||||
func NewDiffFileSource(header *snapshot.Header, diffPaths map[string]string) (*DiffFileSource, error) {
|
||||
diffs := make(map[string]*os.File, len(diffPaths))
|
||||
for id, path := range diffPaths {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
// Close already opened files.
|
||||
for _, opened := range diffs {
|
||||
opened.Close()
|
||||
}
|
||||
return nil, fmt.Errorf("open diff file %s: %w", path, err)
|
||||
}
|
||||
diffs[id] = f
|
||||
}
|
||||
return &DiffFileSource{header: header, diffs: diffs}, nil
|
||||
}
|
||||
|
||||
// ReadPage resolves a memory offset through the header mapping and reads
|
||||
// the corresponding page from the correct generation's diff file.
|
||||
func (s *DiffFileSource) ReadPage(ctx context.Context, offset int64, size int64) ([]byte, error) {
|
||||
mappedOffset, _, buildID, err := s.header.GetShiftedMapping(ctx, offset)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolve offset %d: %w", offset, err)
|
||||
}
|
||||
|
||||
// uuid.Nil means zero-fill (empty page).
|
||||
var nilUUID [16]byte
|
||||
if *buildID == nilUUID {
|
||||
return make([]byte, size), nil
|
||||
}
|
||||
|
||||
f, ok := s.diffs[buildID.String()]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no diff file for build %s", buildID)
|
||||
}
|
||||
|
||||
buf := make([]byte, size)
|
||||
n, err := f.ReadAt(buf, mappedOffset)
|
||||
if err != nil && int64(n) < size {
|
||||
return nil, fmt.Errorf("read diff at offset %d: %w", mappedOffset, err)
|
||||
}
|
||||
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
// Close closes all open diff file handles.
|
||||
func (s *DiffFileSource) Close() error {
|
||||
var errs []error
|
||||
for _, f := range s.diffs {
|
||||
if err := f.Close(); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
Reference in New Issue
Block a user