diff --git a/internal/api/handlers_sandbox.go b/internal/api/handlers_sandbox.go index aa4ed07..bb06a5f 100644 --- a/internal/api/handlers_sandbox.go +++ b/internal/api/handlers_sandbox.go @@ -2,6 +2,7 @@ package api import ( "encoding/json" + "fmt" "log/slog" "net/http" "time" @@ -12,6 +13,7 @@ import ( "git.omukk.dev/wrenn/sandbox/internal/db" "git.omukk.dev/wrenn/sandbox/internal/id" + "git.omukk.dev/wrenn/sandbox/internal/validate" pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen" "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect" ) @@ -86,6 +88,10 @@ func (h *sandboxHandler) Create(w http.ResponseWriter, r *http.Request) { if req.Template == "" { req.Template = "minimal" } + if err := validate.SafeName(req.Template); err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", fmt.Sprintf("invalid template name: %s", err)) + return + } if req.VCPUs <= 0 { req.VCPUs = 1 } diff --git a/internal/api/handlers_snapshots.go b/internal/api/handlers_snapshots.go index d07fee9..8e6b36f 100644 --- a/internal/api/handlers_snapshots.go +++ b/internal/api/handlers_snapshots.go @@ -2,6 +2,7 @@ package api import ( "encoding/json" + "fmt" "log/slog" "net/http" "time" @@ -12,6 +13,7 @@ import ( "git.omukk.dev/wrenn/sandbox/internal/db" "git.omukk.dev/wrenn/sandbox/internal/id" + "git.omukk.dev/wrenn/sandbox/internal/validate" pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen" "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect" ) @@ -73,6 +75,10 @@ func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) { if req.Name == "" { req.Name = id.NewSnapshotName() } + if err := validate.SafeName(req.Name); err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", fmt.Sprintf("invalid snapshot name: %s", err)) + return + } ctx := r.Context() overwrite := r.URL.Query().Get("overwrite") == "true" @@ -166,6 +172,10 @@ func (h *snapshotHandler) List(w http.ResponseWriter, r *http.Request) { // Delete handles DELETE /v1/snapshots/{name}. func (h *snapshotHandler) Delete(w http.ResponseWriter, r *http.Request) { name := chi.URLParam(r, "name") + if err := validate.SafeName(name); err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", fmt.Sprintf("invalid snapshot name: %s", err)) + return + } ctx := r.Context() if _, err := h.db.GetTemplate(ctx, name); err != nil { diff --git a/internal/network/setup.go b/internal/network/setup.go index 933d68e..70a8a54 100644 --- a/internal/network/setup.go +++ b/internal/network/setup.go @@ -1,6 +1,7 @@ package network import ( + "errors" "fmt" "log/slog" "net" @@ -93,6 +94,8 @@ func NewSlot(index int) *Slot { // - Veth pair bridging host and namespace // - TAP device inside namespace for Firecracker // - Routes and NAT rules for connectivity +// +// On error, all partially created resources are rolled back. func CreateNetwork(slot *Slot) error { // Lock this goroutine to the OS thread — required for netns manipulation. runtime.LockOSThread() @@ -106,12 +109,25 @@ func CreateNetwork(slot *Slot) error { defer hostNS.Close() defer func() { _ = netns.Set(hostNS) }() + // rollbacks accumulates cleanup functions; on error they run in reverse. + var rollbacks []func() + rollback := func() { + for i := len(rollbacks) - 1; i >= 0; i-- { + rollbacks[i]() + } + } + // Create named network namespace. ns, err := netns.NewNamed(slot.NamespaceID) if err != nil { return fmt.Errorf("create namespace %s: %w", slot.NamespaceID, err) } defer ns.Close() + // Deleting the namespace also cleans up TAP, loopback, namespace-internal + // routes, and namespace-internal iptables rules. + rollbacks = append(rollbacks, func() { + _ = netns.DeleteNamed(slot.NamespaceID) + }) // We are now inside the new namespace. slog.Info("created network namespace", "ns", slot.NamespaceID) @@ -124,12 +140,14 @@ func CreateNetwork(slot *Slot) error { PeerName: "eth0", } if err := netlink.LinkAdd(veth); err != nil { + rollback() return fmt.Errorf("create veth pair: %w", err) } // Configure vpeer (eth0) inside namespace. vpeer, err := netlink.LinkByName("eth0") if err != nil { + rollback() return fmt.Errorf("find eth0: %w", err) } vpeerAddr := &netlink.Addr{ @@ -139,20 +157,30 @@ func CreateNetwork(slot *Slot) error { }, } if err := netlink.AddrAdd(vpeer, vpeerAddr); err != nil { + rollback() return fmt.Errorf("set vpeer addr: %w", err) } if err := netlink.LinkSetUp(vpeer); err != nil { + rollback() return fmt.Errorf("bring up vpeer: %w", err) } // Move veth to host namespace. vethLink, err := netlink.LinkByName(slot.VethName) if err != nil { + rollback() return fmt.Errorf("find veth: %w", err) } if err := netlink.LinkSetNsFd(vethLink, int(hostNS)); err != nil { + rollback() return fmt.Errorf("move veth to host ns: %w", err) } + // Once the veth is in the host namespace, we need to clean it up from there. + rollbacks = append(rollbacks, func() { + if l, err := netlink.LinkByName(slot.VethName); err == nil { + _ = netlink.LinkDel(l) + } + }) // Create TAP device inside namespace. tapAttrs := netlink.NewLinkAttrs() @@ -162,10 +190,12 @@ func CreateNetwork(slot *Slot) error { Mode: netlink.TUNTAP_MODE_TAP, } if err := netlink.LinkAdd(tap); err != nil { + rollback() return fmt.Errorf("create tap device: %w", err) } tapLink, err := netlink.LinkByName(tapName) if err != nil { + rollback() return fmt.Errorf("find tap: %w", err) } tapAddr := &netlink.Addr{ @@ -175,18 +205,22 @@ func CreateNetwork(slot *Slot) error { }, } if err := netlink.AddrAdd(tapLink, tapAddr); err != nil { + rollback() return fmt.Errorf("set tap addr: %w", err) } if err := netlink.LinkSetUp(tapLink); err != nil { + rollback() return fmt.Errorf("bring up tap: %w", err) } // Bring up loopback. lo, err := netlink.LinkByName("lo") if err != nil { + rollback() return fmt.Errorf("find loopback: %w", err) } if err := netlink.LinkSetUp(lo); err != nil { + rollback() return fmt.Errorf("bring up loopback: %w", err) } @@ -195,6 +229,7 @@ func CreateNetwork(slot *Slot) error { Scope: netlink.SCOPE_UNIVERSE, Gw: slot.VethIP, }); err != nil { + rollback() return fmt.Errorf("add default route in namespace: %w", err) } @@ -202,6 +237,7 @@ func CreateNetwork(slot *Slot) error { if err := nsExec(slot.NamespaceID, "sysctl", "-w", "net.ipv4.ip_forward=1", ); err != nil { + rollback() return fmt.Errorf("enable ip_forward in namespace: %w", err) } @@ -212,6 +248,7 @@ func CreateNetwork(slot *Slot) error { "-o", "eth0", "-s", guestIP, "-j", "SNAT", "--to", slot.VpeerIP.String(), ); err != nil { + rollback() return fmt.Errorf("add SNAT rule: %w", err) } // Inbound: host -> guest. Packets arrive with dst=hostIP, DNAT to guest IP. @@ -220,17 +257,20 @@ func CreateNetwork(slot *Slot) error { "-i", "eth0", "-d", slot.HostIP.String(), "-j", "DNAT", "--to", guestIP, ); err != nil { + rollback() return fmt.Errorf("add DNAT rule: %w", err) } // Switch back to host namespace for host-side config. if err := netns.Set(hostNS); err != nil { + rollback() return fmt.Errorf("switch to host ns: %w", err) } // Configure veth on host side. hostVeth, err := netlink.LinkByName(slot.VethName) if err != nil { + rollback() return fmt.Errorf("find veth in host: %w", err) } vethAddr := &netlink.Addr{ @@ -240,9 +280,11 @@ func CreateNetwork(slot *Slot) error { }, } if err := netlink.AddrAdd(hostVeth, vethAddr); err != nil { + rollback() return fmt.Errorf("set veth addr: %w", err) } if err := netlink.LinkSetUp(hostVeth); err != nil { + rollback() return fmt.Errorf("bring up veth: %w", err) } @@ -252,12 +294,17 @@ func CreateNetwork(slot *Slot) error { Dst: hostNet, Gw: slot.VpeerIP, }); err != nil { + rollback() return fmt.Errorf("add host route: %w", err) } + rollbacks = append(rollbacks, func() { + _ = netlink.RouteDel(&netlink.Route{Dst: hostNet, Gw: slot.VpeerIP}) + }) // Find default gateway interface for FORWARD rules. defaultIface, err := getDefaultInterface() if err != nil { + rollback() return fmt.Errorf("get default interface: %w", err) } @@ -267,15 +314,24 @@ func CreateNetwork(slot *Slot) error { "-i", slot.VethName, "-o", defaultIface, "-j", "ACCEPT", ); err != nil { + rollback() return fmt.Errorf("add forward rule (out): %w", err) } + rollbacks = append(rollbacks, func() { + _ = iptablesHost("-D", "FORWARD", "-i", slot.VethName, "-o", defaultIface, "-j", "ACCEPT") + }) + if err := iptablesHost( "-A", "FORWARD", "-i", defaultIface, "-o", slot.VethName, "-j", "ACCEPT", ); err != nil { + rollback() return fmt.Errorf("add forward rule (in): %w", err) } + rollbacks = append(rollbacks, func() { + _ = iptablesHost("-D", "FORWARD", "-i", defaultIface, "-o", slot.VethName, "-j", "ACCEPT") + }) // MASQUERADE for outbound traffic from sandbox. // After SNAT inside the namespace, outbound packets arrive on the host @@ -286,6 +342,7 @@ func CreateNetwork(slot *Slot) error { "-o", defaultIface, "-j", "MASQUERADE", ); err != nil { + rollback() return fmt.Errorf("add masquerade rule: %w", err) } @@ -299,47 +356,65 @@ func CreateNetwork(slot *Slot) error { } // RemoveNetwork tears down the network topology for a sandbox. +// All steps are attempted even if earlier ones fail. Returns a combined +// error describing which cleanup steps failed. func RemoveNetwork(slot *Slot) error { + var errs []error + defaultIface, _ := getDefaultInterface() - // Remove host-side iptables rules (best effort). + // Remove host-side iptables rules. if defaultIface != "" { - _ = iptablesHost( + if err := iptablesHost( "-D", "FORWARD", "-i", slot.VethName, "-o", defaultIface, "-j", "ACCEPT", - ) - _ = iptablesHost( + ); err != nil { + errs = append(errs, fmt.Errorf("remove forward rule (out): %w", err)) + } + if err := iptablesHost( "-D", "FORWARD", "-i", defaultIface, "-o", slot.VethName, "-j", "ACCEPT", - ) - _ = iptablesHost( + ); err != nil { + errs = append(errs, fmt.Errorf("remove forward rule (in): %w", err)) + } + if err := iptablesHost( "-t", "nat", "-D", "POSTROUTING", "-s", fmt.Sprintf("%s/32", slot.VpeerIP.String()), "-o", defaultIface, "-j", "MASQUERADE", - ) + ); err != nil { + errs = append(errs, fmt.Errorf("remove masquerade rule: %w", err)) + } + } else { + errs = append(errs, fmt.Errorf("could not determine default interface; host iptables rules not removed")) } // Remove host route. _, hostNet, _ := net.ParseCIDR(fmt.Sprintf("%s/32", slot.HostIP.String())) - _ = netlink.RouteDel(&netlink.Route{ + if err := netlink.RouteDel(&netlink.Route{ Dst: hostNet, Gw: slot.VpeerIP, - }) + }); err != nil { + errs = append(errs, fmt.Errorf("remove host route: %w", err)) + } // Delete veth (also destroys the peer in the namespace). if veth, err := netlink.LinkByName(slot.VethName); err == nil { - _ = netlink.LinkDel(veth) + if err := netlink.LinkDel(veth); err != nil { + errs = append(errs, fmt.Errorf("delete veth: %w", err)) + } } // Delete the named namespace. - _ = netns.DeleteNamed(slot.NamespaceID) + if err := netns.DeleteNamed(slot.NamespaceID); err != nil { + errs = append(errs, fmt.Errorf("delete namespace: %w", err)) + } - slog.Info("network removed", "ns", slot.NamespaceID) + slog.Info("network removed", "ns", slot.NamespaceID, "cleanup_errors", len(errs)) - return nil + return errors.Join(errs...) } // nsExec runs a command inside a network namespace. diff --git a/internal/sandbox/manager.go b/internal/sandbox/manager.go index 9ea7865..5fd86d8 100644 --- a/internal/sandbox/manager.go +++ b/internal/sandbox/manager.go @@ -18,6 +18,7 @@ import ( "git.omukk.dev/wrenn/sandbox/internal/network" "git.omukk.dev/wrenn/sandbox/internal/snapshot" "git.omukk.dev/wrenn/sandbox/internal/uffd" + "git.omukk.dev/wrenn/sandbox/internal/validate" "git.omukk.dev/wrenn/sandbox/internal/vm" ) @@ -83,6 +84,9 @@ func (m *Manager) Create(ctx context.Context, sandboxID, template string, vcpus, if template == "" { template = "minimal" } + if err := validate.SafeName(template); err != nil { + return nil, fmt.Errorf("invalid template name: %w", err) + } // Check if template refers to a snapshot (has snapfile + memfile + header + rootfs). if snapshot.IsSnapshot(m.cfg.ImagesDir, template) { @@ -560,6 +564,10 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string) (*models.Sandbox // so the template has no dependency on the original base image. Memory state // and VM snapshot files are copied as-is. func (m *Manager) CreateSnapshot(ctx context.Context, sandboxID, name string) (int64, error) { + if err := validate.SafeName(name); err != nil { + return 0, fmt.Errorf("invalid snapshot name: %w", err) + } + // If the sandbox is running, pause it first. if _, err := m.get(sandboxID); err == nil { if err := m.Pause(ctx, sandboxID); err != nil { @@ -648,6 +656,9 @@ func (m *Manager) CreateSnapshot(ctx context.Context, sandboxID, name string) (i // DeleteSnapshot removes a snapshot template from disk. func (m *Manager) DeleteSnapshot(name string) error { + if err := validate.SafeName(name); err != nil { + return fmt.Errorf("invalid snapshot name: %w", err) + } return snapshot.Remove(m.cfg.ImagesDir, name) } diff --git a/internal/validate/name.go b/internal/validate/name.go new file mode 100644 index 0000000..2051d87 --- /dev/null +++ b/internal/validate/name.go @@ -0,0 +1,24 @@ +package validate + +import ( + "fmt" + "regexp" +) + +// nameRe matches safe path component names: alphanumeric start, then +// alphanumeric, dash, underscore, or dot. Max 64 characters. +var nameRe = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9._-]{0,63}$`) + +// SafeName checks that name is safe for use as a single filesystem path +// component. It rejects empty strings, path separators, ".." sequences, +// leading dots, and anything outside the alphanumeric+dash+underscore+dot +// allowlist. +func SafeName(name string) error { + if name == "" { + return fmt.Errorf("name must not be empty") + } + if !nameRe.MatchString(name) { + return fmt.Errorf("name %q contains invalid characters or is too long (max 64, must match %s)", name, nameRe.String()) + } + return nil +} diff --git a/internal/validate/name_test.go b/internal/validate/name_test.go new file mode 100644 index 0000000..4b7769e --- /dev/null +++ b/internal/validate/name_test.go @@ -0,0 +1,41 @@ +package validate + +import "testing" + +func TestSafeName(t *testing.T) { + tests := []struct { + name string + input string + wantErr bool + }{ + {"simple", "minimal", false}, + {"with-dash", "template-abc123", false}, + {"with-dot", "my-snapshot.v2", false}, + {"sandbox-id", "sb-12345678", false}, + {"single-char", "a", false}, + {"numbers", "123", false}, + {"max-length", "abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz01", false}, + + {"empty", "", true}, + {"dot-dot", "..", true}, + {"single-dot", ".", true}, + {"leading-dot", ".hidden", true}, + {"slash", "foo/bar", true}, + {"backslash", "foo\\bar", true}, + {"traversal", "../etc/passwd", true}, + {"embedded-traversal", "foo/../bar", true}, + {"space", "foo bar", true}, + {"too-long", "abcdefghijklmnopqrstuvwxyz012345678901abcdefghijklmnopqrstuvwxyz01", true}, + {"absolute", "/etc/passwd", true}, + {"tilde", "~root", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := SafeName(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("SafeName(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr) + } + }) + } +}