forked from wrenn/wrenn
After Firecracker snapshot restore, zombie TCP sockets from the previous session cause Go runtime corruption inside the guest VM, making envd unresponsive. This manifests as infinite loading in the file browser and terminal timeouts (524) in production (HTTP/2 + Cloudflare) but not locally. Four-part fix: - Add ServerConnTracker to envd that tracks connections via ConnState callback, closes idle connections and disables keep-alives before snapshot, then closes all pre-snapshot zombie connections on restore (while preserving post-restore connections like the /init request) - Split envdclient into timeout (2min) and streaming (no timeout) HTTP clients; use streaming client for file transfers and process RPCs - Close host-side idle envdclient connections before PrepareSnapshot so FIN packets propagate during the 3s quiesce window - Add StreamingHTTPClient() accessor; streaming file transfer handlers in hostagent use it instead of the timeout client
303 lines
8.4 KiB
Go
303 lines
8.4 KiB
Go
// SPDX-License-Identifier: Apache-2.0
|
|
// Modifications by M/S Omukk
|
|
|
|
package api
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/netip"
|
|
"os/exec"
|
|
"time"
|
|
|
|
"git.omukk.dev/wrenn/sandbox/envd/internal/host"
|
|
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
|
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/keys"
|
|
"github.com/awnumar/memguard"
|
|
"github.com/rs/zerolog"
|
|
"github.com/txn2/txeh"
|
|
)
|
|
|
|
// ErrAccessTokenMismatch is returned by validateInitAccessToken when an /init
// request carries an access token that matches neither the token envd
// currently holds nor the hash published via MMDS.
var ErrAccessTokenMismatch = errors.New("access token validation failed")

// ErrAccessTokenResetNotAuthorized is returned by validateInitAccessToken when
// an /init request carries no access token but clearing the existing token has
// not been authorized.
var ErrAccessTokenResetNotAuthorized = errors.New("access token reset not authorized")
|
|
|
|
// validateInitAccessToken validates the access token for /init requests.
|
|
// Token is valid if it matches the existing token OR the MMDS hash.
|
|
// If neither exists, first-time setup is allowed.
|
|
func (a *API) validateInitAccessToken(ctx context.Context, requestToken *SecureToken) error {
|
|
requestTokenSet := requestToken.IsSet()
|
|
|
|
// Fast path: token matches existing
|
|
if a.accessToken.IsSet() && requestTokenSet && a.accessToken.EqualsSecure(requestToken) {
|
|
return nil
|
|
}
|
|
|
|
// Check MMDS only if token didn't match existing
|
|
matchesMMDS, mmdsExists := a.checkMMDSHash(ctx, requestToken)
|
|
|
|
switch {
|
|
case matchesMMDS:
|
|
return nil
|
|
case !a.accessToken.IsSet() && !mmdsExists:
|
|
return nil // first-time setup
|
|
case !requestTokenSet:
|
|
return ErrAccessTokenResetNotAuthorized
|
|
default:
|
|
return ErrAccessTokenMismatch
|
|
}
|
|
}
|
|
|
|
// checkMMDSHash checks if the request token matches the MMDS hash.
|
|
// Returns (matches, mmdsExists).
|
|
//
|
|
// The MMDS hash is set by the orchestrator during Resume:
|
|
// - hash(token): requires this specific token
|
|
// - hash(""): explicitly allows nil token (token reset authorized)
|
|
// - "": MMDS not properly configured, no authorization granted
|
|
func (a *API) checkMMDSHash(ctx context.Context, requestToken *SecureToken) (bool, bool) {
|
|
if a.isNotFC {
|
|
return false, false
|
|
}
|
|
|
|
mmdsHash, err := a.mmdsClient.GetAccessTokenHash(ctx)
|
|
if err != nil {
|
|
return false, false
|
|
}
|
|
|
|
if mmdsHash == "" {
|
|
return false, false
|
|
}
|
|
|
|
if !requestToken.IsSet() {
|
|
return mmdsHash == keys.HashAccessToken(""), true
|
|
}
|
|
|
|
tokenBytes, err := requestToken.Bytes()
|
|
if err != nil {
|
|
return false, true
|
|
}
|
|
defer memguard.WipeBytes(tokenBytes)
|
|
|
|
return keys.HashAccessTokenBytes(tokenBytes) == mmdsHash, true
|
|
}
|
|
|
|
func (a *API) PostInit(w http.ResponseWriter, r *http.Request) {
|
|
defer r.Body.Close()
|
|
|
|
ctx := r.Context()
|
|
|
|
operationID := logs.AssignOperationID()
|
|
logger := a.logger.With().Str(string(logs.OperationIDKey), operationID).Logger()
|
|
|
|
if r.Body != nil {
|
|
// Read raw body so we can wipe it after parsing
|
|
body, err := io.ReadAll(r.Body)
|
|
// Ensure body is wiped after we're done
|
|
defer memguard.WipeBytes(body)
|
|
if err != nil {
|
|
logger.Error().Msgf("Failed to read request body: %v", err)
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
|
|
return
|
|
}
|
|
|
|
var initRequest PostInitJSONBody
|
|
if len(body) > 0 {
|
|
err = json.Unmarshal(body, &initRequest)
|
|
if err != nil {
|
|
logger.Error().Msgf("Failed to decode request: %v", err)
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
|
|
return
|
|
}
|
|
}
|
|
|
|
// Ensure request token is destroyed if not transferred via TakeFrom.
|
|
// This handles: validation failures, timestamp-based skips, and any early returns.
|
|
// Safe because Destroy() is nil-safe and TakeFrom clears the source.
|
|
defer initRequest.AccessToken.Destroy()
|
|
|
|
a.initLock.Lock()
|
|
defer a.initLock.Unlock()
|
|
|
|
// Update data only if the request is newer or if there's no timestamp at all
|
|
if initRequest.Timestamp == nil || a.lastSetTime.SetToGreater(initRequest.Timestamp.UnixNano()) {
|
|
err = a.SetData(ctx, logger, initRequest)
|
|
if err != nil {
|
|
switch {
|
|
case errors.Is(err, ErrAccessTokenMismatch), errors.Is(err, ErrAccessTokenResetNotAuthorized):
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
default:
|
|
logger.Error().Msgf("Failed to set data: %v", err)
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
}
|
|
w.Write([]byte(err.Error()))
|
|
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
go func() { //nolint:contextcheck // TODO: fix this later
|
|
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
|
defer cancel()
|
|
host.PollForMMDSOpts(ctx, a.mmdsChan, a.defaults.EnvVars)
|
|
}()
|
|
|
|
// Close zombie connections from before the snapshot and re-enable
|
|
// keep-alives. On first boot this is a no-op (no zombie connections).
|
|
if a.connTracker != nil {
|
|
a.connTracker.RestoreAfterSnapshot()
|
|
}
|
|
|
|
// Start the port scanner and forwarder if they were stopped by a
|
|
// pre-snapshot prepare call. Start is a no-op if already running,
|
|
// so this is safe on first boot and only takes effect after restore.
|
|
if a.portSubsystem != nil {
|
|
a.portSubsystem.Start(a.rootCtx)
|
|
}
|
|
|
|
w.Header().Set("Cache-Control", "no-store")
|
|
w.Header().Set("Content-Type", "")
|
|
|
|
w.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
func (a *API) SetData(ctx context.Context, logger zerolog.Logger, data PostInitJSONBody) error {
|
|
// Validate access token before proceeding with any action
|
|
// The request must provide a token that is either:
|
|
// 1. Matches the existing access token (if set), OR
|
|
// 2. Matches the MMDS hash (for token change during resume)
|
|
if err := a.validateInitAccessToken(ctx, data.AccessToken); err != nil {
|
|
return err
|
|
}
|
|
|
|
if data.EnvVars != nil {
|
|
logger.Debug().Msg(fmt.Sprintf("Setting %d env vars", len(*data.EnvVars)))
|
|
|
|
for key, value := range *data.EnvVars {
|
|
logger.Debug().Msgf("Setting env var for %s", key)
|
|
a.defaults.EnvVars.Store(key, value)
|
|
}
|
|
}
|
|
|
|
if data.AccessToken.IsSet() {
|
|
logger.Debug().Msg("Setting access token")
|
|
a.accessToken.TakeFrom(data.AccessToken)
|
|
} else if a.accessToken.IsSet() {
|
|
logger.Debug().Msg("Clearing access token")
|
|
a.accessToken.Destroy()
|
|
}
|
|
|
|
if data.HyperloopIP != nil {
|
|
go a.SetupHyperloop(*data.HyperloopIP)
|
|
}
|
|
|
|
if data.DefaultUser != nil && *data.DefaultUser != "" {
|
|
logger.Debug().Msgf("Setting default user to: %s", *data.DefaultUser)
|
|
a.defaults.User = *data.DefaultUser
|
|
}
|
|
|
|
if data.DefaultWorkdir != nil && *data.DefaultWorkdir != "" {
|
|
logger.Debug().Msgf("Setting default workdir to: %s", *data.DefaultWorkdir)
|
|
a.defaults.Workdir = data.DefaultWorkdir
|
|
}
|
|
|
|
if data.VolumeMounts != nil {
|
|
for _, volume := range *data.VolumeMounts {
|
|
logger.Debug().Msgf("Mounting %s at %q", volume.NfsTarget, volume.Path)
|
|
|
|
go a.setupNfs(context.WithoutCancel(ctx), volume.NfsTarget, volume.Path)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (a *API) setupNfs(ctx context.Context, nfsTarget, path string) {
|
|
commands := [][]string{
|
|
{"mkdir", "-p", path},
|
|
{"mount", "-v", "-t", "nfs", "-o", "mountproto=tcp,mountport=2049,proto=tcp,port=2049,nfsvers=3,noacl", nfsTarget, path},
|
|
}
|
|
|
|
for _, command := range commands {
|
|
data, err := exec.CommandContext(ctx, command[0], command[1:]...).CombinedOutput()
|
|
|
|
logger := a.getLogger(err)
|
|
|
|
logger.
|
|
Strs("command", command).
|
|
Str("output", string(data)).
|
|
Msg("Mount NFS")
|
|
|
|
if err != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (a *API) SetupHyperloop(address string) {
|
|
a.hyperloopLock.Lock()
|
|
defer a.hyperloopLock.Unlock()
|
|
|
|
if err := rewriteHostsFile(address, "/etc/hosts"); err != nil {
|
|
a.logger.Error().Err(err).Msg("failed to modify hosts file")
|
|
} else {
|
|
a.defaults.EnvVars.Store("WRENN_EVENTS_ADDRESS", fmt.Sprintf("http://%s", address))
|
|
}
|
|
}
|
|
|
|
const (
	// eventsHost is the hostname that rewriteHostsFile maps to the hyperloop IP.
	eventsHost = "events.wrenn.local"
)
|
|
|
|
func rewriteHostsFile(address, path string) error {
|
|
hosts, err := txeh.NewHosts(&txeh.HostsConfig{
|
|
ReadFilePath: path,
|
|
WriteFilePath: path,
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create hosts: %w", err)
|
|
}
|
|
|
|
// Update /etc/hosts to point events.wrenn.local to the hyperloop IP
|
|
// This will remove any existing entries for events.wrenn.local first
|
|
ipFamily, err := getIPFamily(address)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get ip family: %w", err)
|
|
}
|
|
|
|
if ok, current, _ := hosts.HostAddressLookup(eventsHost, ipFamily); ok && current == address {
|
|
return nil // nothing to be done
|
|
}
|
|
|
|
hosts.AddHost(address, eventsHost)
|
|
|
|
return hosts.Save()
|
|
}
|
|
|
|
// ErrInvalidAddress reports an address string that is not a valid IP address.
var ErrInvalidAddress = errors.New("invalid IP address")

// ErrUnknownAddressFormat reports an address that parsed but is neither IPv4
// nor IPv6.
var ErrUnknownAddressFormat = errors.New("unknown IP address format")
|
|
|
|
func getIPFamily(address string) (txeh.IPFamily, error) {
|
|
addressIP, err := netip.ParseAddr(address)
|
|
if err != nil {
|
|
return txeh.IPFamilyV4, fmt.Errorf("failed to parse IP address: %w", err)
|
|
}
|
|
|
|
switch {
|
|
case addressIP.Is4():
|
|
return txeh.IPFamilyV4, nil
|
|
case addressIP.Is6():
|
|
return txeh.IPFamilyV6, nil
|
|
default:
|
|
return txeh.IPFamilyV4, fmt.Errorf("%w: %s", ErrUnknownAddressFormat, address)
|
|
}
|
|
}
|