1
0
forked from wrenn/wrenn
Files
wrenn-releases/envd/internal/api/init.go
pptx704 7ef9a64613 fix: close stale TCP connections across snapshot/restore to prevent envd hangs
After Firecracker snapshot restore, zombie TCP sockets from the previous
session cause Go runtime corruption inside the guest VM, making envd
unresponsive. This manifests as infinite loading in the file browser and
terminal timeouts (524) in production (HTTP/2 + Cloudflare) but not locally.

Four-part fix:
- Add ServerConnTracker to envd that tracks connections via ConnState callback,
  closes idle connections and disables keep-alives before snapshot, then closes
  all pre-snapshot zombie connections on restore (while preserving post-restore
  connections like the /init request)
- Split envdclient into timeout (2min) and streaming (no timeout) HTTP clients;
  use streaming client for file transfers and process RPCs
- Close host-side idle envdclient connections before PrepareSnapshot so FIN
  packets propagate during the 3s quiesce window
- Add StreamingHTTPClient() accessor; streaming file transfer handlers in
  hostagent use it instead of the timeout client
2026-05-02 05:19:37 +06:00

303 lines
8.4 KiB
Go

// SPDX-License-Identifier: Apache-2.0
// Modifications by M/S Omukk
package api
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/netip"
"os/exec"
"time"
"git.omukk.dev/wrenn/sandbox/envd/internal/host"
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/keys"
"github.com/awnumar/memguard"
"github.com/rs/zerolog"
"github.com/txn2/txeh"
)
var (
// ErrAccessTokenMismatch is returned by validateInitAccessToken when the
// request carries a token that was not accepted (it matched neither the
// stored access token nor the MMDS hash).
ErrAccessTokenMismatch = errors.New("access token validation failed")
// ErrAccessTokenResetNotAuthorized is returned by validateInitAccessToken
// when the request carries no token and none of the acceptance rules applied.
ErrAccessTokenResetNotAuthorized = errors.New("access token reset not authorized")
)
// validateInitAccessToken decides whether an /init request may proceed given
// the token it presented.
//
// The request is accepted when any of the following holds:
//   - the presented token equals the access token already stored on the API,
//   - checkMMDSHash reports that the presented token matches the MMDS hash,
//   - no access token is stored and checkMMDSHash reports no MMDS hash
//     (first-time setup).
//
// Otherwise it returns ErrAccessTokenResetNotAuthorized when the request
// carried no token, and ErrAccessTokenMismatch when it carried one.
func (a *API) validateInitAccessToken(ctx context.Context, requestToken *SecureToken) error {
	hasRequestToken := requestToken.IsSet()

	// Cheapest check first: the caller presented the token we already hold.
	if a.accessToken.IsSet() && hasRequestToken && a.accessToken.EqualsSecure(requestToken) {
		return nil
	}

	// MMDS is consulted only after the stored token failed to match.
	hashMatches, hashPresent := a.checkMMDSHash(ctx, requestToken)
	if hashMatches {
		return nil
	}
	if !a.accessToken.IsSet() && !hashPresent {
		return nil // first-time setup
	}
	if !hasRequestToken {
		return ErrAccessTokenResetNotAuthorized
	}

	return ErrAccessTokenMismatch
}
// checkMMDSHash compares the request token against the access-token hash
// fetched from MMDS and returns (matches, mmdsExists).
//
// mmdsExists is false (and matches is false) when a.isNotFC is set, when the
// hash cannot be fetched, or when the fetched hash is empty.
//
// The MMDS hash is set by the orchestrator during Resume:
//   - hash(token): requires this specific token
//   - hash(""): explicitly allows nil token (token reset authorized)
//   - "": MMDS not properly configured, no authorization granted
func (a *API) checkMMDSHash(ctx context.Context, requestToken *SecureToken) (bool, bool) {
	if a.isNotFC {
		return false, false
	}

	wantHash, err := a.mmdsClient.GetAccessTokenHash(ctx)
	if err != nil || wantHash == "" {
		return false, false
	}

	if !requestToken.IsSet() {
		// No token supplied: acceptable only when MMDS holds the hash of the
		// empty token.
		emptyHash := keys.HashAccessToken("")
		return wantHash == emptyHash, true
	}

	raw, err := requestToken.Bytes()
	if err != nil {
		return false, true
	}
	// Wipe the plaintext copy once the hash has been computed.
	defer memguard.WipeBytes(raw)

	gotHash := keys.HashAccessTokenBytes(raw)
	return gotHash == wantHash, true
}
// PostInit handles the /init request.
//
// When the request has a body, the body is read in full, decoded as a
// PostInitJSONBody, and applied through SetData while holding a.initLock. The
// data is applied only if the request has no timestamp or
// a.lastSetTime.SetToGreater accepts the request's timestamp. The raw body and
// any request token that was not transferred are wiped/destroyed on return.
//
// After that, regardless of whether a body was present, the handler starts a
// background MMDS poll (bounded to 60s), asks the connection tracker to
// restore after a snapshot, and starts the port subsystem.
//
// Responses:
//   - 204 No Content on success
//   - 400 Bad Request when the body cannot be read or decoded, or SetData
//     fails for any other reason
//   - 401 Unauthorized when SetData reports ErrAccessTokenMismatch or
//     ErrAccessTokenResetNotAuthorized
func (a *API) PostInit(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	operationID := logs.AssignOperationID()
	logger := a.logger.With().Str(string(logs.OperationIDKey), operationID).Logger()

	if r.Body != nil {
		// Close under the nil guard. net/http always supplies a non-nil Body
		// to server handlers, but a handler invoked directly with a request
		// that has no body would otherwise panic on a nil Body before the
		// guard is ever reached.
		defer r.Body.Close()

		// Read raw body so we can wipe it after parsing
		body, err := io.ReadAll(r.Body)
		// Ensure body is wiped after we're done
		defer memguard.WipeBytes(body)
		if err != nil {
			logger.Error().Msgf("Failed to read request body: %v", err)
			w.WriteHeader(http.StatusBadRequest)

			return
		}

		var initRequest PostInitJSONBody
		if len(body) > 0 {
			err = json.Unmarshal(body, &initRequest)
			if err != nil {
				logger.Error().Msgf("Failed to decode request: %v", err)
				w.WriteHeader(http.StatusBadRequest)

				return
			}
		}

		// Ensure request token is destroyed if not transferred via TakeFrom.
		// This handles: validation failures, timestamp-based skips, and any early returns.
		// Safe because Destroy() is nil-safe and TakeFrom clears the source.
		defer initRequest.AccessToken.Destroy()

		// The lock is released by defer, so it stays held for the rest of the
		// handler, including the post-restore steps below.
		a.initLock.Lock()
		defer a.initLock.Unlock()

		// Update data only if the request is newer or if there's no timestamp at all
		if initRequest.Timestamp == nil || a.lastSetTime.SetToGreater(initRequest.Timestamp.UnixNano()) {
			err = a.SetData(ctx, logger, initRequest)
			if err != nil {
				switch {
				case errors.Is(err, ErrAccessTokenMismatch), errors.Is(err, ErrAccessTokenResetNotAuthorized):
					w.WriteHeader(http.StatusUnauthorized)
				default:
					logger.Error().Msgf("Failed to set data: %v", err)
					w.WriteHeader(http.StatusBadRequest)
				}
				// Best effort: the status line is already sent, and a failed
				// write here only means the client went away.
				_, _ = w.Write([]byte(err.Error()))

				return
			}
		}
	}

	go func() { //nolint:contextcheck // TODO: fix this later
		ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
		defer cancel()
		host.PollForMMDSOpts(ctx, a.mmdsChan, a.defaults.EnvVars)
	}()

	// Close zombie connections from before the snapshot and re-enable
	// keep-alives. On first boot this is a no-op (no zombie connections).
	if a.connTracker != nil {
		a.connTracker.RestoreAfterSnapshot()
	}

	// Start the port scanner and forwarder if they were stopped by a
	// pre-snapshot prepare call. Start is a no-op if already running,
	// so this is safe on first boot and only takes effect after restore.
	if a.portSubsystem != nil {
		a.portSubsystem.Start(a.rootCtx)
	}

	w.Header().Set("Cache-Control", "no-store")
	w.Header().Set("Content-Type", "")

	w.WriteHeader(http.StatusNoContent)
}
// SetData validates the request's access token and, if it is accepted,
// applies the contents of an /init request to the API state.
//
// In order, it:
//   - stores every provided env var in a.defaults.EnvVars,
//   - takes over the request's access token, or destroys the stored token
//     when the request carries none,
//   - starts SetupHyperloop in a goroutine when a hyperloop IP is given,
//   - updates the default user and workdir when non-empty values are given,
//   - starts one setupNfs goroutine per volume mount, using a context
//     detached from the request's cancellation.
//
// It returns the error from validateInitAccessToken unchanged, and nil
// otherwise; failures in the background goroutines are not reported here.
func (a *API) SetData(ctx context.Context, logger zerolog.Logger, data PostInitJSONBody) error {
	// Nothing is touched unless the token check passes: the request token
	// must match either the stored token or the MMDS hash (token change
	// during resume), or this must be first-time setup.
	if err := a.validateInitAccessToken(ctx, data.AccessToken); err != nil {
		return err
	}

	if envVars := data.EnvVars; envVars != nil {
		logger.Debug().Msgf("Setting %d env vars", len(*envVars))

		for name, val := range *envVars {
			// Only the key is logged, never the value.
			logger.Debug().Msgf("Setting env var for %s", name)
			a.defaults.EnvVars.Store(name, val)
		}
	}

	switch {
	case data.AccessToken.IsSet():
		logger.Debug().Msg("Setting access token")
		a.accessToken.TakeFrom(data.AccessToken)
	case a.accessToken.IsSet():
		logger.Debug().Msg("Clearing access token")
		a.accessToken.Destroy()
	}

	if ip := data.HyperloopIP; ip != nil {
		go a.SetupHyperloop(*ip)
	}

	if user := data.DefaultUser; user != nil && *user != "" {
		logger.Debug().Msgf("Setting default user to: %s", *user)
		a.defaults.User = *user
	}

	if workdir := data.DefaultWorkdir; workdir != nil && *workdir != "" {
		logger.Debug().Msgf("Setting default workdir to: %s", *workdir)
		a.defaults.Workdir = workdir
	}

	if mounts := data.VolumeMounts; mounts != nil {
		for _, mount := range *mounts {
			logger.Debug().Msgf("Mounting %s at %q", mount.NfsTarget, mount.Path)

			// The mount outlives the request, so it must not inherit the
			// request's cancellation.
			go a.setupNfs(context.WithoutCancel(ctx), mount.NfsTarget, mount.Path)
		}
	}

	return nil
}
// setupNfs creates the mount point at path and mounts nfsTarget on it over
// NFSv3/TCP (port 2049).
//
// The two commands run in order; each one's combined output is logged through
// a.getLogger(err), and the sequence stops at the first command that fails.
// Nothing is returned to the caller.
func (a *API) setupNfs(ctx context.Context, nfsTarget, path string) {
	steps := [][]string{
		{"mkdir", "-p", path},
		{"mount", "-v", "-t", "nfs", "-o", "mountproto=tcp,mountport=2049,proto=tcp,port=2049,nfsvers=3,noacl", nfsTarget, path},
	}

	for _, argv := range steps {
		name, args := argv[0], argv[1:]
		output, runErr := exec.CommandContext(ctx, name, args...).CombinedOutput()

		a.getLogger(runErr).
			Strs("command", argv).
			Str("output", string(output)).
			Msg("Mount NFS")

		if runErr != nil {
			// A failed mkdir makes the mount pointless; give up.
			return
		}
	}
}
// SetupHyperloop points the events hostname at address in /etc/hosts and, on
// success, publishes the address as an http:// URL in the
// WRENN_EVENTS_ADDRESS default env var.
//
// Calls are serialized through a.hyperloopLock. If the hosts file cannot be
// rewritten, the error is logged and the env var is left untouched.
func (a *API) SetupHyperloop(address string) {
	a.hyperloopLock.Lock()
	defer a.hyperloopLock.Unlock()

	if err := rewriteHostsFile(address, "/etc/hosts"); err != nil {
		a.logger.Error().Err(err).Msg("failed to modify hosts file")

		return
	}

	// An IPv6 literal must be enclosed in brackets inside a URL authority
	// (RFC 3986 §3.2.2); IPv4 addresses are used as-is.
	// NOTE(review): a zoned address (e.g. fe80::1%eth0) would additionally
	// need its '%' escaped as %25 — confirm whether zones can occur here.
	urlHost := address
	if ip, err := netip.ParseAddr(address); err == nil && ip.Is6() {
		urlHost = "[" + address + "]"
	}

	a.defaults.EnvVars.Store("WRENN_EVENTS_ADDRESS", fmt.Sprintf("http://%s", urlHost))
}
// eventsHost is the hostname that rewriteHostsFile maps to the given address
// in the hosts file.
const eventsHost = "events.wrenn.local"
// rewriteHostsFile makes the hosts file at path resolve eventsHost to address.
//
// The file is read and written in place. If eventsHost already resolves to
// address for the address's IP family, the file is left untouched. It returns
// an error when the hosts file cannot be loaded, when address is not a valid
// IP address, or when saving fails.
func rewriteHostsFile(address, path string) error {
	cfg := &txeh.HostsConfig{
		ReadFilePath:  path,
		WriteFilePath: path,
	}

	hostsFile, err := txeh.NewHosts(cfg)
	if err != nil {
		return fmt.Errorf("failed to create hosts: %w", err)
	}

	family, err := getIPFamily(address)
	if err != nil {
		return fmt.Errorf("failed to get ip family: %w", err)
	}

	// Skip the write entirely when the mapping is already in place.
	found, mapped, _ := hostsFile.HostAddressLookup(eventsHost, family)
	if found && mapped == address {
		return nil
	}

	// Point events.wrenn.local at the hyperloop IP; AddHost removes any
	// existing entries for the host first.
	hostsFile.AddHost(address, eventsHost)

	return hostsFile.Save()
}
var (
// ErrInvalidAddress signals an invalid IP address.
// NOTE(review): not referenced in the visible part of this file — confirm it
// is used elsewhere.
ErrInvalidAddress = errors.New("invalid IP address")
// ErrUnknownAddressFormat is wrapped by getIPFamily when a successfully
// parsed address is reported as neither IPv4 nor IPv6.
ErrUnknownAddressFormat = errors.New("unknown IP address format")
)
// getIPFamily parses address and reports whether it is an IPv4 or IPv6
// address, expressed as a txeh.IPFamily.
//
// On any error the returned family is txeh.IPFamilyV4 and must be ignored.
// A parse failure wraps the netip error; an address that is neither IPv4 nor
// IPv6 wraps ErrUnknownAddressFormat.
func getIPFamily(address string) (txeh.IPFamily, error) {
	parsed, err := netip.ParseAddr(address)
	if err != nil {
		return txeh.IPFamilyV4, fmt.Errorf("failed to parse IP address: %w", err)
	}

	if parsed.Is4() {
		return txeh.IPFamilyV4, nil
	}

	if parsed.Is6() {
		return txeh.IPFamilyV6, nil
	}

	return txeh.IPFamilyV4, fmt.Errorf("%w: %s", ErrUnknownAddressFormat, address)
}