diff --git a/envd/internal/api/conntracker.go b/envd/internal/api/conntracker.go new file mode 100644 index 0000000..054f920 --- /dev/null +++ b/envd/internal/api/conntracker.go @@ -0,0 +1,94 @@ +package api + +import ( + "net" + "net/http" + "sync" +) + +// ServerConnTracker tracks active HTTP connections via http.Server.ConnState. +// Before a Firecracker snapshot, it closes idle connections, disables +// keep-alives, and records which connections existed pre-snapshot. After +// restore, it closes ALL pre-snapshot connections (they are zombie TCP +// sockets) while leaving post-restore connections (like the /init request) +// untouched. +type ServerConnTracker struct { + mu sync.Mutex + conns map[net.Conn]http.ConnState + preSnapshot map[net.Conn]struct{} + srv *http.Server +} + +func NewServerConnTracker() *ServerConnTracker { + return &ServerConnTracker{ + conns: make(map[net.Conn]http.ConnState), + } +} + +// SetServer stores a reference to the http.Server for keep-alive control. +// Must be called before ListenAndServe. +func (t *ServerConnTracker) SetServer(srv *http.Server) { + t.mu.Lock() + t.srv = srv + t.mu.Unlock() +} + +// Track implements the http.Server.ConnState callback signature. +func (t *ServerConnTracker) Track(conn net.Conn, state http.ConnState) { + t.mu.Lock() + defer t.mu.Unlock() + switch state { + case http.StateNew, http.StateActive, http.StateIdle: + t.conns[conn] = state + case http.StateHijacked, http.StateClosed: + delete(t.conns, conn) + delete(t.preSnapshot, conn) + } +} + +// PrepareForSnapshot closes idle connections, disables keep-alives, and +// records every remaining (new or active) connection. With keep-alives +// disabled, each one closes once its response completes; RestoreAfterSnapshot +// closes any that still survived into the snapshot as zombie TCP sockets. +// +// GC cycles are handled by PortSubsystem.Stop(), which runs before this. 
+func (t *ServerConnTracker) PrepareForSnapshot() { + t.mu.Lock() + defer t.mu.Unlock() + + if t.srv != nil { + t.srv.SetKeepAlivesEnabled(false) + } + + t.preSnapshot = make(map[net.Conn]struct{}, len(t.conns)) + for conn, state := range t.conns { + if state == http.StateIdle { + conn.Close() + delete(t.conns, conn) + } else { + t.preSnapshot[conn] = struct{}{} + } + } +} + +// RestoreAfterSnapshot closes ALL pre-snapshot connections (zombie TCP +// sockets after restore) and re-enables keep-alives. Post-restore +// connections (like the /init request that triggers this call) are not +// in the preSnapshot set and are left untouched. +// +// Safe to call on first boot — preSnapshot is nil, so this is a no-op +// aside from enabling keep-alives (which are already enabled by default). +func (t *ServerConnTracker) RestoreAfterSnapshot() { + t.mu.Lock() + defer t.mu.Unlock() + + for conn := range t.preSnapshot { + conn.Close() + delete(t.conns, conn) + } + t.preSnapshot = nil + + if t.srv != nil { + t.srv.SetKeepAlivesEnabled(true) + } +} diff --git a/envd/internal/api/download_test.go b/envd/internal/api/download_test.go index a4379cc..fc01573 100644 --- a/envd/internal/api/download_test.go +++ b/envd/internal/api/download_test.go @@ -99,7 +99,7 @@ func TestGetFilesContentDisposition(t *testing.T) { EnvVars: utils.NewMap[string, string](), User: currentUser.Username, } - api := New(&logger, defaults, nil, false, context.Background(), nil, "test") + api := New(&logger, defaults, nil, false, context.Background(), nil, nil, "test") // Create request and response recorder req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil) @@ -148,7 +148,7 @@ func TestGetFilesContentDispositionWithNestedPath(t *testing.T) { EnvVars: utils.NewMap[string, string](), User: currentUser.Username, } - api := New(&logger, defaults, nil, false, context.Background(), nil, "test") + api := New(&logger, defaults, nil, false, context.Background(), nil, nil, 
"test") // Create request and response recorder req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil) @@ -191,7 +191,7 @@ func TestGetFiles_GzipEncoding_ExplicitIdentityOffWithRange(t *testing.T) { EnvVars: utils.NewMap[string, string](), User: currentUser.Username, } - api := New(&logger, defaults, nil, false, context.Background(), nil, "test") + api := New(&logger, defaults, nil, false, context.Background(), nil, nil, "test") // Create request and response recorder req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil) @@ -232,7 +232,7 @@ func TestGetFiles_GzipDownload(t *testing.T) { EnvVars: utils.NewMap[string, string](), User: currentUser.Username, } - api := New(&logger, defaults, nil, false, context.Background(), nil, "test") + api := New(&logger, defaults, nil, false, context.Background(), nil, nil, "test") req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil) req.Header.Set("Accept-Encoding", "gzip") @@ -297,7 +297,7 @@ func TestPostFiles_GzipUpload(t *testing.T) { EnvVars: utils.NewMap[string, string](), User: currentUser.Username, } - api := New(&logger, defaults, nil, false, context.Background(), nil, "test") + api := New(&logger, defaults, nil, false, context.Background(), nil, nil, "test") req := httptest.NewRequest(http.MethodPost, "/files?path="+url.QueryEscape(destPath), &gzBuf) req.Header.Set("Content-Type", mpWriter.FormDataContentType()) @@ -357,7 +357,7 @@ func TestGzipUploadThenGzipDownload(t *testing.T) { EnvVars: utils.NewMap[string, string](), User: currentUser.Username, } - api := New(&logger, defaults, nil, false, context.Background(), nil, "test") + api := New(&logger, defaults, nil, false, context.Background(), nil, nil, "test") uploadReq := httptest.NewRequest(http.MethodPost, "/files?path="+url.QueryEscape(destPath), &gzBuf) uploadReq.Header.Set("Content-Type", mpWriter.FormDataContentType()) diff --git 
a/envd/internal/api/init.go b/envd/internal/api/init.go index 3b2be4b..68a1b86 100644 --- a/envd/internal/api/init.go +++ b/envd/internal/api/init.go @@ -150,6 +150,12 @@ func (a *API) PostInit(w http.ResponseWriter, r *http.Request) { host.PollForMMDSOpts(ctx, a.mmdsChan, a.defaults.EnvVars) }() + // Close zombie connections from before the snapshot and re-enable + // keep-alives. On first boot this is a no-op (no zombie connections). + if a.connTracker != nil { + a.connTracker.RestoreAfterSnapshot() + } + // Start the port scanner and forwarder if they were stopped by a // pre-snapshot prepare call. Start is a no-op if already running, // so this is safe on first boot and only takes effect after restore. diff --git a/envd/internal/api/init_test.go b/envd/internal/api/init_test.go index 18ee203..9fe6ece 100644 --- a/envd/internal/api/init_test.go +++ b/envd/internal/api/init_test.go @@ -79,7 +79,7 @@ func newTestAPI(accessToken *SecureToken, mmdsClient MMDSClient) *API { defaults := &execcontext.Defaults{ EnvVars: utils.NewMap[string, string](), } - api := New(&logger, defaults, nil, false, context.Background(), nil, "test") + api := New(&logger, defaults, nil, false, context.Background(), nil, nil, "test") if accessToken != nil { api.accessToken.TakeFrom(accessToken) } diff --git a/envd/internal/api/snapshot.go b/envd/internal/api/snapshot.go index d9e2edd..6d13381 100644 --- a/envd/internal/api/snapshot.go +++ b/envd/internal/api/snapshot.go @@ -7,9 +7,11 @@ import ( "net/http" ) -// PostSnapshotPrepare quiesces continuous goroutines (port scanner, forwarder) -// and forces a GC cycle before Firecracker takes a VM snapshot. This ensures -// the Go runtime's page allocator is in a consistent state when vCPUs are frozen. +// PostSnapshotPrepare quiesces continuous goroutines (port scanner, forwarder), +// closes idle HTTP connections, and forces a GC cycle before Firecracker takes +// a VM snapshot. 
Closing connections prevents Go runtime corruption from stale +// TCP state after snapshot restore. Keep-alives are disabled so the current +// request's connection also closes after the response. // // Called by the host agent as a best-effort signal before vm.Pause(). func (a *API) PostSnapshotPrepare(w http.ResponseWriter, r *http.Request) { @@ -20,6 +22,11 @@ func (a *API) PostSnapshotPrepare(w http.ResponseWriter, r *http.Request) { a.logger.Info().Msg("snapshot/prepare: port subsystem quiesced") } + if a.connTracker != nil { + a.connTracker.PrepareForSnapshot() + a.logger.Info().Msg("snapshot/prepare: idle connections closed, keep-alives disabled") + } + w.Header().Set("Cache-Control", "no-store") w.WriteHeader(http.StatusNoContent) } diff --git a/envd/internal/api/store.go b/envd/internal/api/store.go index ca97957..5365604 100644 --- a/envd/internal/api/store.go +++ b/envd/internal/api/store.go @@ -47,9 +47,10 @@ type API struct { // long-lived goroutines after snapshot restore. rootCtx context.Context portSubsystem *publicport.PortSubsystem + connTracker *ServerConnTracker } -func New(l *zerolog.Logger, defaults *execcontext.Defaults, mmdsChan chan *host.MMDSOpts, isNotFC bool, rootCtx context.Context, portSubsystem *publicport.PortSubsystem, version string) *API { +func New(l *zerolog.Logger, defaults *execcontext.Defaults, mmdsChan chan *host.MMDSOpts, isNotFC bool, rootCtx context.Context, portSubsystem *publicport.PortSubsystem, connTracker *ServerConnTracker, version string) *API { return &API{ logger: l, defaults: defaults, @@ -60,6 +61,7 @@ func New(l *zerolog.Logger, defaults *execcontext.Defaults, mmdsChan chan *host. 
accessToken: &SecureToken{}, rootCtx: rootCtx, portSubsystem: portSubsystem, + connTracker: connTracker, version: version, } } diff --git a/envd/main.go b/envd/main.go index 1cd9403..3acd2c6 100644 --- a/envd/main.go +++ b/envd/main.go @@ -197,7 +197,9 @@ func main() { portSubsystem.Start(ctx) defer portSubsystem.Stop() - service := api.New(&envLogger, defaults, mmdsChan, isNotFC, ctx, portSubsystem, Version) + connTracker := api.NewServerConnTracker() + + service := api.New(&envLogger, defaults, mmdsChan, isNotFC, ctx, portSubsystem, connTracker, Version) handler := api.HandlerFromMux(service, m) middleware := authn.NewMiddleware(permissions.AuthenticateUsername) @@ -212,7 +214,9 @@ func main() { ReadTimeout: 0, WriteTimeout: 0, IdleTimeout: idleTimeout, + ConnState: connTracker.Track, } + connTracker.SetServer(s) // TODO: Not used anymore in template build, replaced by direct envd command call. if startCmdFlag != "" { diff --git a/internal/envdclient/client.go b/internal/envdclient/client.go index 294a37e..aed0349 100644 --- a/internal/envdclient/client.go +++ b/internal/envdclient/client.go @@ -19,10 +19,11 @@ import ( // Client wraps the Connect RPC client for envd's Process and Filesystem services. 
type Client struct { - hostIP string - base string - healthURL string - httpClient *http.Client + hostIP string + base string + healthURL string + httpClient *http.Client + streamingClient *http.Client process genconnect.ProcessClient filesystem genconnect.FilesystemClient @@ -32,29 +33,44 @@ type Client struct { func New(hostIP string) *Client { base := baseURL(hostIP) httpClient := newHTTPClient() + streamingClient := newStreamingHTTPClient() return &Client{ - hostIP: hostIP, - base: base, - healthURL: base + "/health", - httpClient: httpClient, - process: genconnect.NewProcessClient(httpClient, base), - filesystem: genconnect.NewFilesystemClient(httpClient, base), + hostIP: hostIP, + base: base, + healthURL: base + "/health", + httpClient: httpClient, + streamingClient: streamingClient, + process: genconnect.NewProcessClient(streamingClient, base), + filesystem: genconnect.NewFilesystemClient(httpClient, base), } } +// CloseIdleConnections closes idle connections on both the unary and streaming +// transports. Call this before taking a VM snapshot to remove stale TCP state +// from the guest. +func (c *Client) CloseIdleConnections() { + c.httpClient.CloseIdleConnections() + c.streamingClient.CloseIdleConnections() +} + // BaseURL returns the HTTP base URL for reaching envd. func (c *Client) BaseURL() string { return c.base } -// HTTPClient returns the underlying http.Client used for envd requests. -// Use this instead of http.DefaultClient when making direct HTTP calls to envd -// (e.g. file streaming) to avoid sharing the global transport with proxy traffic. +// HTTPClient returns the http.Client with a 2-minute request timeout. +// Suitable for short-lived envd calls (health, init, snapshot/prepare). func (c *Client) HTTPClient() *http.Client { return c.httpClient } +// StreamingHTTPClient returns the http.Client without a request timeout. +// Use for streaming file transfers or any request that may run indefinitely. 
+func (c *Client) StreamingHTTPClient() *http.Client { + return c.streamingClient +} + // ExecResult holds the output of a command execution. type ExecResult struct { Stdout []byte diff --git a/internal/envdclient/dialer.go b/internal/envdclient/dialer.go index 1813ceb..ffd3509 100644 --- a/internal/envdclient/dialer.go +++ b/internal/envdclient/dialer.go @@ -20,6 +20,22 @@ func baseURL(hostIP string) string { // so that proxy traffic to user services inside the sandbox cannot interfere // with envd RPC connections (PTY streams, exec, file ops). func newHTTPClient() *http.Client { + return &http.Client{ + Timeout: 2 * time.Minute, + Transport: &http.Transport{ + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 90 * time.Second, + DialContext: (&net.Dialer{ + Timeout: 10 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + }, + } +} + +// newStreamingHTTPClient returns an http.Client without an overall timeout, +// for long-lived streaming RPCs (PTY, exec stream) that can run indefinitely. 
+func newStreamingHTTPClient() *http.Client { return &http.Client{ Transport: &http.Transport{ MaxIdleConnsPerHost: 10, diff --git a/internal/hostagent/server.go b/internal/hostagent/server.go index e15ef0b..a1b40c8 100644 --- a/internal/hostagent/server.go +++ b/internal/hostagent/server.go @@ -459,7 +459,7 @@ func (s *Server) WriteFileStream( } httpReq.Header.Set("Content-Type", mpWriter.FormDataContentType()) - resp, err := client.HTTPClient().Do(httpReq) + resp, err := client.StreamingHTTPClient().Do(httpReq) if err != nil { pw.CloseWithError(err) <-errCh @@ -504,7 +504,7 @@ func (s *Server) ReadFileStream( return connect.NewError(connect.CodeInternal, fmt.Errorf("create request: %w", err)) } - resp, err := client.HTTPClient().Do(httpReq) + resp, err := client.StreamingHTTPClient().Do(httpReq) if err != nil { return connect.NewError(connect.CodeInternal, fmt.Errorf("read file stream: %w", err)) } diff --git a/internal/sandbox/manager.go b/internal/sandbox/manager.go index 3c49cd6..117d8c7 100644 --- a/internal/sandbox/manager.go +++ b/internal/sandbox/manager.go @@ -387,9 +387,17 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error { sb.connTracker.Drain(2 * time.Second) slog.Debug("pause: proxy connections drained", "id", sandboxID) - // Step 0b: Signal envd to quiesce continuous goroutines (port scanner, - // forwarder) and run GC before freezing vCPUs. This prevents Go runtime - // page allocator corruption ("bad summary data") on snapshot restore. + // Step 0b: Close host-side idle connections to envd. Done before + // PrepareSnapshot so FIN packets propagate to the guest during the + // PrepareSnapshot window (no extra sleep needed). + sb.client.CloseIdleConnections() + slog.Debug("pause: envd client idle connections closed", "id", sandboxID) + + // Step 0c: Signal envd to quiesce continuous goroutines (port scanner, + // forwarder), close idle HTTP connections, and run GC before freezing + // vCPUs. 
This prevents Go runtime page allocator corruption ("bad + // summary data") on snapshot restore. The call's round trip (bounded by + // the 3s timeout) also lets the guest kernel process the FINs from Step 0b. // Best-effort: a failure is logged but does not abort the pause. func() { prepCtx, prepCancel := context.WithTimeout(ctx, 3*time.Second)