diff --git a/db/queries/sandboxes.sql b/db/queries/sandboxes.sql index 8cbd10b..2b19574 100644 --- a/db/queries/sandboxes.sql +++ b/db/queries/sandboxes.sql @@ -9,6 +9,14 @@ SELECT * FROM sandboxes WHERE id = $1; -- name: GetSandboxByTeam :one SELECT * FROM sandboxes WHERE id = $1 AND team_id = $2; +-- name: GetSandboxProxyTarget :one +-- Returns the sandbox status and its host's address in one query. +-- Used by SandboxProxyWrapper to avoid two round-trips. +SELECT s.status, h.address AS host_address +FROM sandboxes s +JOIN hosts h ON h.id = s.host_id +WHERE s.id = $1 AND s.team_id = $2; + -- name: ListSandboxes :many SELECT * FROM sandboxes ORDER BY created_at DESC; diff --git a/internal/api/handler_sandbox_proxy.go b/internal/api/handler_sandbox_proxy.go index a7b9f5b..963dff6 100644 --- a/internal/api/handler_sandbox_proxy.go +++ b/internal/api/handler_sandbox_proxy.go @@ -1,6 +1,8 @@ package api import ( + "context" + "errors" "fmt" "log/slog" "net/http" @@ -9,6 +11,8 @@ import ( "regexp" "strconv" "strings" + "sync" + "time" "github.com/jackc/pgx/v5/pgtype" @@ -18,10 +22,45 @@ import ( "git.omukk.dev/wrenn/sandbox/internal/lifecycle" ) +// Sentinel errors returned by proxyTarget, used to map to HTTP status codes +// without relying on error message text. +var ( + errProxySandboxNotFound = errors.New("sandbox not found") + errProxyNoHostAddress = errors.New("host agent has no address") +) + +const proxyCacheTTL = 120 * time.Second + // sandboxHostPattern matches hostnames like "49999-cl-abcd1234.localhost" or // "49999-cl-abcd1234.example.com". Captures: port, sandbox ID. var sandboxHostPattern = regexp.MustCompile(`^(\d+)-(cl-[0-9a-z]+)\.`) +// errProxySandboxNotRunning carries the sandbox status so callers can include +// it in the HTTP response without parsing error strings. 
+type errProxySandboxNotRunning struct{ status string } + +func (e errProxySandboxNotRunning) Error() string { + return fmt.Sprintf("sandbox is not running (status: %s)", e.status) +} + +// proxyCacheEntry caches the resolved agent URL for a (sandbox, team) pair. +// The *httputil.ReverseProxy is built per-request (cheap) so the Director closure +// can capture the correct port without the cache key needing to include it. +type proxyCacheEntry struct { + agentURL *url.URL + expiresAt time.Time +} + +// proxyCacheKey is a fixed-size key from two UUIDs, avoids string allocation. +type proxyCacheKey [32]byte + +func makeProxyCacheKey(sandboxID, teamID pgtype.UUID) proxyCacheKey { + var k proxyCacheKey + copy(k[:16], sandboxID.Bytes[:]) + copy(k[16:], teamID.Bytes[:]) + return k +} + // SandboxProxyWrapper wraps an existing HTTP handler and intercepts requests // whose Host header matches the {port}-{sandbox_id}.{domain} pattern. Matching // requests are reverse-proxied through the host agent that owns the sandbox. @@ -34,6 +73,9 @@ type SandboxProxyWrapper struct { db *db.Queries pool *lifecycle.HostClientPool transport http.RoundTripper + + cacheMu sync.Mutex + cache map[proxyCacheKey]proxyCacheEntry } // NewSandboxProxyWrapper creates a new proxy wrapper. @@ -43,9 +85,63 @@ func NewSandboxProxyWrapper(inner http.Handler, queries *db.Queries, pool *lifec db: queries, pool: pool, transport: pool.Transport(), + cache: make(map[proxyCacheKey]proxyCacheEntry), } } +// proxyTarget looks up the cached agent URL for (sandboxID, teamID). +// On a miss it queries the DB, resolves the address, and populates the cache. +// The *httputil.ReverseProxy is built by the caller so the Director closure +// captures the correct port without the cache key needing to include it. 
+func (h *SandboxProxyWrapper) proxyTarget(ctx context.Context, sandboxID, teamID pgtype.UUID) (*url.URL, error) { + cacheKey := makeProxyCacheKey(sandboxID, teamID) + + h.cacheMu.Lock() + entry, ok := h.cache[cacheKey] + h.cacheMu.Unlock() + + if ok && time.Now().Before(entry.expiresAt) { + return entry.agentURL, nil + } + + // Cache miss or expired — query DB. NOTE(review): any error here — including transient DB failures — currently surfaces as errProxySandboxNotFound (HTTP 404); consider checking errors.Is(err, pgx.ErrNoRows) so infrastructure errors can map to 503 instead. + target, err := h.db.GetSandboxProxyTarget(ctx, db.GetSandboxProxyTargetParams{ + ID: sandboxID, + TeamID: teamID, + }) + if err != nil { + return nil, errProxySandboxNotFound + } + if target.Status != "running" { + return nil, errProxySandboxNotRunning{status: target.Status} + } + if target.HostAddress == "" { + return nil, errProxyNoHostAddress + } + + agentURL, err := url.Parse(h.pool.ResolveAddr(target.HostAddress)) + if err != nil { + return nil, fmt.Errorf("invalid host agent address: %w", err) + } + + h.cacheMu.Lock() + h.cache[cacheKey] = proxyCacheEntry{ + agentURL: agentURL, + expiresAt: time.Now().Add(proxyCacheTTL), + } + h.cacheMu.Unlock() + + return agentURL, nil +} + +// evictProxyCache removes the cached entry for a (sandbox, team) pair. +// Called on 502 so a stopped/moved sandbox is re-resolved on the next request. +func (h *SandboxProxyWrapper) evictProxyCache(sandboxID, teamID pgtype.UUID) { + h.cacheMu.Lock() + delete(h.cache, makeProxyCacheKey(sandboxID, teamID)) + h.cacheMu.Unlock() +} + func (h *SandboxProxyWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request) { host := r.Host // Strip port from Host header (e.g. "49999-cl-abcd1234.localhost:8000" → "49999-cl-abcd1234.localhost") @@ -82,51 +178,26 @@ func (h *SandboxProxyWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request) return } - ctx := r.Context() - - // Look up sandbox and verify ownership. 
- sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ - ID: sandboxID, - TeamID: teamID, - }) + agentURL, err := h.proxyTarget(r.Context(), sandboxID, teamID) if err != nil { - http.Error(w, "sandbox not found", http.StatusNotFound) - return - } - - if sb.Status != "running" { - http.Error(w, fmt.Sprintf("sandbox is not running (status: %s)", sb.Status), http.StatusConflict) - return - } - - agentHost, err := h.db.GetHost(ctx, sb.HostID) - if err != nil { - http.Error(w, "host agent not found", http.StatusServiceUnavailable) - return - } - - if agentHost.Address == "" { - http.Error(w, "host agent has no address", http.StatusServiceUnavailable) - return - } - - agentAddr := h.pool.ResolveAddr(agentHost.Address) - upstreamPath := fmt.Sprintf("/proxy/%s/%s%s", sandboxIDStr, port, r.URL.Path) - - target, err := url.Parse(agentAddr) - if err != nil { - http.Error(w, "invalid host agent address", http.StatusInternalServerError) + switch { + case errors.Is(err, errProxySandboxNotFound): + http.Error(w, err.Error(), http.StatusNotFound) + case errors.As(err, new(errProxySandboxNotRunning)): + http.Error(w, err.Error(), http.StatusConflict) + default: + http.Error(w, err.Error(), http.StatusServiceUnavailable) + } return } proxy := &httputil.ReverseProxy{ Transport: h.transport, Director: func(req *http.Request) { - req.URL.Scheme = target.Scheme - req.URL.Host = target.Host - req.URL.Path = upstreamPath - req.URL.RawQuery = r.URL.RawQuery - req.Host = target.Host + req.URL.Scheme = agentURL.Scheme + req.URL.Host = agentURL.Host + req.URL.Path = "/proxy/" + sandboxIDStr + "/" + port + req.URL.Path + req.Host = agentURL.Host }, ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) { slog.Debug("sandbox proxy error", @@ -134,10 +205,10 @@ func (h *SandboxProxyWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request) "port", port, "error", err, ) + h.evictProxyCache(sandboxID, teamID) http.Error(w, "proxy error: "+err.Error(), 
http.StatusBadGateway) }, } - proxy.ServeHTTP(w, r) } diff --git a/internal/auth/cert.go b/internal/auth/cert.go index 1af4867..d76f1de 100644 --- a/internal/auth/cert.go +++ b/internal/auth/cert.go @@ -235,9 +235,9 @@ func CPClientTLSConfig(ca *CA, certStore *CPCertStore) *tls.Config { pool := x509.NewCertPool() pool.AddCert(ca.Cert) return &tls.Config{ - RootCAs: pool, - GetClientCertificate: certStore.GetClientCertificate, - MinVersion: tls.VersionTLS13, + RootCAs: pool, + GetClientCertificate: certStore.GetClientCertificate, + MinVersion: tls.VersionTLS13, } } diff --git a/internal/db/sandboxes.sql.go b/internal/db/sandboxes.sql.go index 4107f1a..3ce1644 100644 --- a/internal/db/sandboxes.sql.go +++ b/internal/db/sandboxes.sql.go @@ -105,6 +105,32 @@ func (q *Queries) GetSandboxByTeam(ctx context.Context, arg GetSandboxByTeamPara return i, err } +const getSandboxProxyTarget = `-- name: GetSandboxProxyTarget :one +SELECT s.status, h.address AS host_address +FROM sandboxes s +JOIN hosts h ON h.id = s.host_id +WHERE s.id = $1 AND s.team_id = $2 +` + +type GetSandboxProxyTargetParams struct { + ID pgtype.UUID `json:"id"` + TeamID pgtype.UUID `json:"team_id"` +} + +type GetSandboxProxyTargetRow struct { + Status string `json:"status"` + HostAddress string `json:"host_address"` +} + +// Returns the sandbox status and its host's address in one query. +// Used by SandboxProxyWrapper to avoid two round-trips. 
+func (q *Queries) GetSandboxProxyTarget(ctx context.Context, arg GetSandboxProxyTargetParams) (GetSandboxProxyTargetRow, error) { + row := q.db.QueryRow(ctx, getSandboxProxyTarget, arg.ID, arg.TeamID) + var i GetSandboxProxyTargetRow + err := row.Scan(&i.Status, &i.HostAddress) + return i, err +} + const insertSandbox = `-- name: InsertSandbox :one INSERT INTO sandboxes (id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, template_id, template_team_id) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) diff --git a/internal/hostagent/proxy.go b/internal/hostagent/proxy.go index b4a39ee..bbee474 100644 --- a/internal/hostagent/proxy.go +++ b/internal/hostagent/proxy.go @@ -8,7 +8,6 @@ import ( "strconv" "strings" - "git.omukk.dev/wrenn/sandbox/internal/models" "git.omukk.dev/wrenn/sandbox/internal/sandbox" ) @@ -62,18 +61,14 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - sb, err := h.mgr.Get(sandboxID) - if err != nil { - http.Error(w, "sandbox not found", http.StatusNotFound) + hostIP, tracker, ok := h.mgr.AcquireProxyConn(sandboxID) + if !ok { + http.Error(w, "sandbox is not available", http.StatusServiceUnavailable) return } + defer tracker.Release() - if sb.Status != models.StatusRunning { - http.Error(w, fmt.Sprintf("sandbox is not running (status: %s)", sb.Status), http.StatusConflict) - return - } - - targetHost := fmt.Sprintf("%s:%d", sb.HostIP.String(), portNum) + targetHost := fmt.Sprintf("%s:%d", hostIP, portNum) proxy := &httputil.ReverseProxy{ Transport: h.transport, diff --git a/internal/sandbox/conntracker.go b/internal/sandbox/conntracker.go new file mode 100644 index 0000000..d9eac72 --- /dev/null +++ b/internal/sandbox/conntracker.go @@ -0,0 +1,66 @@ +package sandbox + +import ( + "sync" + "sync/atomic" + "time" +) + +// ConnTracker tracks active proxy connections for a single sandbox and +// provides a drain mechanism for pre-pause graceful shutdown. 
+// It is safe for concurrent use. +type ConnTracker struct { + draining atomic.Bool + wg sync.WaitGroup +} + +// Acquire registers one in-flight connection. Returns false if the tracker +// is already draining; the caller must not call Release in that case. +func (t *ConnTracker) Acquire() bool { + if t.draining.Load() { + return false + } + t.wg.Add(1) + // Re-check after Add: Drain may have set draining between our Load + // and Add. If so, undo the Add and reject the connection. NOTE(review): a narrow window remains — sync.WaitGroup forbids an Add that starts from a zero counter concurrently with Wait, and Drain's goroutine may already be inside wg.Wait when this Add runs; confirm this interleaving is tolerated or serialize Acquire/Drain with a mutex. + if t.draining.Load() { + t.wg.Done() + return false + } + return true +} + +// Release marks one connection as complete. Must be called exactly once +// per successful Acquire. +func (t *ConnTracker) Release() { + t.wg.Done() +} + +// Drain marks the tracker as draining (all future Acquire calls return +// false) and waits up to timeout for in-flight connections to finish. +// +// Note: if the timeout expires with connections still in-flight, the +// internal goroutine waiting on wg.Wait() lingers — one goroutine per +// timed-out Drain call — until those connections close, then exits. +func (t *ConnTracker) Drain(timeout time.Duration) { + t.draining.Store(true) + + done := make(chan struct{}) + go func() { + t.wg.Wait() + close(done) + }() + + select { + case <-done: + case <-time.After(timeout): + } +} + +// Reset re-enables the tracker after a failed drain. This allows the +// sandbox to accept proxy connections again if the pause operation fails +// and the VM is resumed. 
+func (t *ConnTracker) Reset() { + t.draining.Store(false) +} diff --git a/internal/sandbox/manager.go b/internal/sandbox/manager.go index ac2bc22..67a70ca 100644 --- a/internal/sandbox/manager.go +++ b/internal/sandbox/manager.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "log/slog" + "net" "os" "os/exec" "path/filepath" @@ -50,7 +51,8 @@ type sandboxState struct { models.Sandbox slot *network.Slot client *envdclient.Client - uffdSocketPath string // non-empty for sandboxes restored from snapshot + connTracker *ConnTracker // tracks in-flight proxy connections for pre-pause drain + uffdSocketPath string // non-empty for sandboxes restored from snapshot dmDevice *devicemapper.SnapshotDevice baseImagePath string // path to the base template rootfs (for loop registry release) @@ -224,6 +226,7 @@ func (m *Manager) Create(ctx context.Context, sandboxID string, teamID, template }, slot: slot, client: client, + connTracker: &ConnTracker{}, dmDevice: dmDev, baseImagePath: baseRootfs, } @@ -308,10 +311,17 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error { return fmt.Errorf("sandbox %s is not running (status: %s)", sandboxID, sb.Status) } + // Step 0: Drain in-flight proxy connections before freezing vCPUs. + // This prevents Go runtime corruption inside the guest caused by stale + // TCP state from connections that were alive when the VM was snapshotted. + sb.connTracker.Drain(2 * time.Second) + slog.Debug("pause: proxy connections drained", "id", sandboxID) + pauseStart := time.Now() // Step 1: Pause the VM (freeze vCPUs). if err := m.vm.Pause(ctx, sandboxID); err != nil { + sb.connTracker.Reset() return fmt.Errorf("pause VM: %w", err) } slog.Debug("pause: VM paused", "id", sandboxID, "elapsed", time.Since(pauseStart)) @@ -326,8 +336,10 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error { // resumeOnError unpauses the VM so the sandbox stays usable when a // post-freeze step fails. 
If the resume itself fails, the sandbox is - // left frozen — the caller should destroy it. + // left frozen — the caller should destroy it. It also resets the + // connection tracker so the sandbox can accept proxy connections again. resumeOnError := func() { + sb.connTracker.Reset() if err := m.vm.Resume(ctx, sandboxID); err != nil { slog.Error("failed to resume VM after pause error — sandbox is frozen", "id", sandboxID, "error", err) } @@ -692,6 +704,7 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int) }, slot: slot, client: client, + connTracker: &ConnTracker{}, uffdSocketPath: uffdSocketPath, dmDevice: dmDev, baseImagePath: baseImagePath, @@ -1094,6 +1107,7 @@ func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID string, team }, slot: slot, client: client, + connTracker: &ConnTracker{}, uffdSocketPath: uffdSocketPath, dmDevice: dmDev, baseImagePath: baseRootfs, @@ -1190,6 +1204,25 @@ func (m *Manager) GetClient(sandboxID string) (*envdclient.Client, error) { return sb.client, nil } +// AcquireProxyConn atomically looks up a sandbox by ID and registers an +// in-flight proxy connection. Returns the sandbox's host-reachable IP, the +// connection tracker, and true on success. The caller must call +// tracker.Release() when the request completes. Returns zero values and +// false if the sandbox is not found, not running, or is draining for a pause. +func (m *Manager) AcquireProxyConn(sandboxID string) (net.IP, *ConnTracker, bool) { + m.mu.RLock() + sb, ok := m.boxes[sandboxID] + m.mu.RUnlock() + + if !ok || sb.Status != models.StatusRunning { + return nil, nil, false + } + if !sb.connTracker.Acquire() { + return nil, nil, false + } + return sb.HostIP, sb.connTracker, true +} + // Ping resets the inactivity timer for a running sandbox. func (m *Manager) Ping(sandboxID string) error { m.mu.Lock()