forked from wrenn/wrenn
Three root causes addressed: 1. Go page allocator corruption: allocations between the pre-snapshot GC and VM freeze leave the summary tree inconsistent. After restore, GC reads corrupted metadata — either panicking (killing PID 1 → kernel panic) or silently failing to collect, causing unbounded heap growth until OOM. Fix: move GC to after all HTTP allocations in PostSnapshotPrepare, then set GOMAXPROCS(1) so any remaining allocations run sequentially with no concurrent page allocator access. GOMAXPROCS is restored on first health check after restore. 2. PostInit timeout starvation: WaitUntilReady and PostInit shared a single 30s context. If WaitUntilReady consumed most of it, PostInit failed — RestoreAfterSnapshot never ran, leaving envd with keep-alives disabled and zombie connections. Fix: separate timeout contexts. 3. CP HTTP server missing timeouts: no ReadHeaderTimeout or IdleTimeout caused goroutine leaks from hung proxy connections. Fix: add both, matching host agent values. Also adds UFFD prefetch to proactively load all guest pages after restore, eliminating on-demand page fault latency for subsequent RPC calls.
305 lines
8.4 KiB
Go
305 lines
8.4 KiB
Go
// SPDX-License-Identifier: Apache-2.0
|
|
// Modifications by M/S Omukk
|
|
|
|
package api
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/netip"
|
|
"os/exec"
|
|
"time"
|
|
|
|
"git.omukk.dev/wrenn/sandbox/envd/internal/host"
|
|
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
|
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/keys"
|
|
"github.com/awnumar/memguard"
|
|
"github.com/rs/zerolog"
|
|
"github.com/txn2/txeh"
|
|
)
|
|
|
|
// ErrAccessTokenMismatch is returned when an /init request carries an access
// token that matches neither the installed token nor the MMDS-published hash.
var ErrAccessTokenMismatch = errors.New("access token validation failed")

// ErrAccessTokenResetNotAuthorized is returned when an /init request carries no
// access token but clearing the installed token has not been authorized.
var ErrAccessTokenResetNotAuthorized = errors.New("access token reset not authorized")
|
|
|
|
// validateInitAccessToken validates the access token for /init requests.
|
|
// Token is valid if it matches the existing token OR the MMDS hash.
|
|
// If neither exists, first-time setup is allowed.
|
|
func (a *API) validateInitAccessToken(ctx context.Context, requestToken *SecureToken) error {
|
|
requestTokenSet := requestToken.IsSet()
|
|
|
|
// Fast path: token matches existing
|
|
if a.accessToken.IsSet() && requestTokenSet && a.accessToken.EqualsSecure(requestToken) {
|
|
return nil
|
|
}
|
|
|
|
// Check MMDS only if token didn't match existing
|
|
matchesMMDS, mmdsExists := a.checkMMDSHash(ctx, requestToken)
|
|
|
|
switch {
|
|
case matchesMMDS:
|
|
return nil
|
|
case !a.accessToken.IsSet() && !mmdsExists:
|
|
return nil // first-time setup
|
|
case !requestTokenSet:
|
|
return ErrAccessTokenResetNotAuthorized
|
|
default:
|
|
return ErrAccessTokenMismatch
|
|
}
|
|
}
|
|
|
|
// checkMMDSHash checks if the request token matches the MMDS hash.
|
|
// Returns (matches, mmdsExists).
|
|
//
|
|
// The MMDS hash is set by the orchestrator during Resume:
|
|
// - hash(token): requires this specific token
|
|
// - hash(""): explicitly allows nil token (token reset authorized)
|
|
// - "": MMDS not properly configured, no authorization granted
|
|
func (a *API) checkMMDSHash(ctx context.Context, requestToken *SecureToken) (bool, bool) {
|
|
if a.isNotFC {
|
|
return false, false
|
|
}
|
|
|
|
mmdsHash, err := a.mmdsClient.GetAccessTokenHash(ctx)
|
|
if err != nil {
|
|
return false, false
|
|
}
|
|
|
|
if mmdsHash == "" {
|
|
return false, false
|
|
}
|
|
|
|
if !requestToken.IsSet() {
|
|
return mmdsHash == keys.HashAccessToken(""), true
|
|
}
|
|
|
|
tokenBytes, err := requestToken.Bytes()
|
|
if err != nil {
|
|
return false, true
|
|
}
|
|
defer memguard.WipeBytes(tokenBytes)
|
|
|
|
return keys.HashAccessTokenBytes(tokenBytes) == mmdsHash, true
|
|
}
|
|
|
|
func (a *API) PostInit(w http.ResponseWriter, r *http.Request) {
|
|
defer r.Body.Close()
|
|
|
|
ctx := r.Context()
|
|
|
|
operationID := logs.AssignOperationID()
|
|
logger := a.logger.With().Str(string(logs.OperationIDKey), operationID).Logger()
|
|
|
|
if r.Body != nil {
|
|
// Read raw body so we can wipe it after parsing
|
|
body, err := io.ReadAll(r.Body)
|
|
// Ensure body is wiped after we're done
|
|
defer memguard.WipeBytes(body)
|
|
if err != nil {
|
|
logger.Error().Msgf("Failed to read request body: %v", err)
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
|
|
return
|
|
}
|
|
|
|
var initRequest PostInitJSONBody
|
|
if len(body) > 0 {
|
|
err = json.Unmarshal(body, &initRequest)
|
|
if err != nil {
|
|
logger.Error().Msgf("Failed to decode request: %v", err)
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
|
|
return
|
|
}
|
|
}
|
|
|
|
// Ensure request token is destroyed if not transferred via TakeFrom.
|
|
// This handles: validation failures, timestamp-based skips, and any early returns.
|
|
// Safe because Destroy() is nil-safe and TakeFrom clears the source.
|
|
defer initRequest.AccessToken.Destroy()
|
|
|
|
a.initLock.Lock()
|
|
defer a.initLock.Unlock()
|
|
|
|
// Update data only if the request is newer or if there's no timestamp at all
|
|
if initRequest.Timestamp == nil || a.lastSetTime.SetToGreater(initRequest.Timestamp.UnixNano()) {
|
|
err = a.SetData(ctx, logger, initRequest)
|
|
if err != nil {
|
|
switch {
|
|
case errors.Is(err, ErrAccessTokenMismatch), errors.Is(err, ErrAccessTokenResetNotAuthorized):
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
default:
|
|
logger.Error().Msgf("Failed to set data: %v", err)
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
}
|
|
w.Write([]byte(err.Error()))
|
|
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
go func() { //nolint:contextcheck // TODO: fix this later
|
|
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
|
defer cancel()
|
|
host.PollForMMDSOpts(ctx, a.mmdsChan, a.defaults.EnvVars)
|
|
}()
|
|
|
|
// Safety net: if the health check's postRestoreRecovery didn't run yet
|
|
// (e.g. PostInit arrived before the first health check), re-enable GC
|
|
// here. On first boot needsRestore is false so CAS is a no-op.
|
|
if a.needsRestore.CompareAndSwap(true, false) {
|
|
a.postRestoreRecovery()
|
|
}
|
|
// RestoreAfterSnapshot is idempotent (clears preSnapshot set), and
|
|
// Start is a no-op if already running.
|
|
if a.connTracker != nil {
|
|
a.connTracker.RestoreAfterSnapshot()
|
|
}
|
|
if a.portSubsystem != nil {
|
|
a.portSubsystem.Start(a.rootCtx)
|
|
}
|
|
|
|
w.Header().Set("Cache-Control", "no-store")
|
|
w.Header().Set("Content-Type", "")
|
|
|
|
w.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
func (a *API) SetData(ctx context.Context, logger zerolog.Logger, data PostInitJSONBody) error {
|
|
// Validate access token before proceeding with any action
|
|
// The request must provide a token that is either:
|
|
// 1. Matches the existing access token (if set), OR
|
|
// 2. Matches the MMDS hash (for token change during resume)
|
|
if err := a.validateInitAccessToken(ctx, data.AccessToken); err != nil {
|
|
return err
|
|
}
|
|
|
|
if data.EnvVars != nil {
|
|
logger.Debug().Msg(fmt.Sprintf("Setting %d env vars", len(*data.EnvVars)))
|
|
|
|
for key, value := range *data.EnvVars {
|
|
logger.Debug().Msgf("Setting env var for %s", key)
|
|
a.defaults.EnvVars.Store(key, value)
|
|
}
|
|
}
|
|
|
|
if data.AccessToken.IsSet() {
|
|
logger.Debug().Msg("Setting access token")
|
|
a.accessToken.TakeFrom(data.AccessToken)
|
|
} else if a.accessToken.IsSet() {
|
|
logger.Debug().Msg("Clearing access token")
|
|
a.accessToken.Destroy()
|
|
}
|
|
|
|
if data.HyperloopIP != nil {
|
|
go a.SetupHyperloop(*data.HyperloopIP)
|
|
}
|
|
|
|
if data.DefaultUser != nil && *data.DefaultUser != "" {
|
|
logger.Debug().Msgf("Setting default user to: %s", *data.DefaultUser)
|
|
a.defaults.User = *data.DefaultUser
|
|
}
|
|
|
|
if data.DefaultWorkdir != nil && *data.DefaultWorkdir != "" {
|
|
logger.Debug().Msgf("Setting default workdir to: %s", *data.DefaultWorkdir)
|
|
a.defaults.Workdir = data.DefaultWorkdir
|
|
}
|
|
|
|
if data.VolumeMounts != nil {
|
|
for _, volume := range *data.VolumeMounts {
|
|
logger.Debug().Msgf("Mounting %s at %q", volume.NfsTarget, volume.Path)
|
|
|
|
go a.setupNfs(context.WithoutCancel(ctx), volume.NfsTarget, volume.Path)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (a *API) setupNfs(ctx context.Context, nfsTarget, path string) {
|
|
commands := [][]string{
|
|
{"mkdir", "-p", path},
|
|
{"mount", "-v", "-t", "nfs", "-o", "mountproto=tcp,mountport=2049,proto=tcp,port=2049,nfsvers=3,noacl", nfsTarget, path},
|
|
}
|
|
|
|
for _, command := range commands {
|
|
data, err := exec.CommandContext(ctx, command[0], command[1:]...).CombinedOutput()
|
|
|
|
logger := a.getLogger(err)
|
|
|
|
logger.
|
|
Strs("command", command).
|
|
Str("output", string(data)).
|
|
Msg("Mount NFS")
|
|
|
|
if err != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (a *API) SetupHyperloop(address string) {
|
|
a.hyperloopLock.Lock()
|
|
defer a.hyperloopLock.Unlock()
|
|
|
|
if err := rewriteHostsFile(address, "/etc/hosts"); err != nil {
|
|
a.logger.Error().Err(err).Msg("failed to modify hosts file")
|
|
} else {
|
|
a.defaults.EnvVars.Store("WRENN_EVENTS_ADDRESS", fmt.Sprintf("http://%s", address))
|
|
}
|
|
}
|
|
|
|
// eventsHost is the hostname that rewriteHostsFile maps to the hyperloop
// address in the hosts file.
const eventsHost string = "events.wrenn.local"
|
|
|
|
func rewriteHostsFile(address, path string) error {
|
|
hosts, err := txeh.NewHosts(&txeh.HostsConfig{
|
|
ReadFilePath: path,
|
|
WriteFilePath: path,
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create hosts: %w", err)
|
|
}
|
|
|
|
// Update /etc/hosts to point events.wrenn.local to the hyperloop IP
|
|
// This will remove any existing entries for events.wrenn.local first
|
|
ipFamily, err := getIPFamily(address)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get ip family: %w", err)
|
|
}
|
|
|
|
if ok, current, _ := hosts.HostAddressLookup(eventsHost, ipFamily); ok && current == address {
|
|
return nil // nothing to be done
|
|
}
|
|
|
|
hosts.AddHost(address, eventsHost)
|
|
|
|
return hosts.Save()
|
|
}
|
|
|
|
// ErrInvalidAddress reports that a string could not be parsed as an IP address.
var ErrInvalidAddress = errors.New("invalid IP address")

// ErrUnknownAddressFormat reports a parsed address that is neither IPv4 nor
// IPv6.
var ErrUnknownAddressFormat = errors.New("unknown IP address format")
|
|
|
|
func getIPFamily(address string) (txeh.IPFamily, error) {
|
|
addressIP, err := netip.ParseAddr(address)
|
|
if err != nil {
|
|
return txeh.IPFamilyV4, fmt.Errorf("failed to parse IP address: %w", err)
|
|
}
|
|
|
|
switch {
|
|
case addressIP.Is4():
|
|
return txeh.IPFamilyV4, nil
|
|
case addressIP.Is6():
|
|
return txeh.IPFamilyV6, nil
|
|
default:
|
|
return txeh.IPFamilyV4, fmt.Errorf("%w: %s", ErrUnknownAddressFormat, address)
|
|
}
|
|
}
|