- Copy envd source from e2b-dev/infra, internalize shared dependencies
into envd/internal/shared/ (keys, filesystem, id, smap, utils)
- Switch from gRPC to Connect RPC for all envd services
- Update module paths to git.omukk.dev/wrenn/{sandbox,sandbox/envd}
- Add proto specs (process, filesystem) with buf-based code generation
- Implement full envd: process exec, filesystem ops, port forwarding,
cgroup management, MMDS integration, and HTTP API
- Update main module dependencies (firecracker SDK, pgx, goose, etc.)
- Remove placeholder .gitkeep files replaced by real implementations
315 lines
8.9 KiB
Go
315 lines
8.9 KiB
Go
package api
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/netip"
|
|
"os/exec"
|
|
"time"
|
|
|
|
"github.com/awnumar/memguard"
|
|
"github.com/rs/zerolog"
|
|
"github.com/txn2/txeh"
|
|
"golang.org/x/sys/unix"
|
|
|
|
"git.omukk.dev/wrenn/sandbox/envd/internal/host"
|
|
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
|
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/keys"
|
|
)
|
|
|
|
var (
|
|
ErrAccessTokenMismatch = errors.New("access token validation failed")
|
|
ErrAccessTokenResetNotAuthorized = errors.New("access token reset not authorized")
|
|
)
|
|
|
|
const (
|
|
maxTimeInPast = 50 * time.Millisecond
|
|
maxTimeInFuture = 5 * time.Second
|
|
)
|
|
|
|
// validateInitAccessToken validates the access token for /init requests.
|
|
// Token is valid if it matches the existing token OR the MMDS hash.
|
|
// If neither exists, first-time setup is allowed.
|
|
func (a *API) validateInitAccessToken(ctx context.Context, requestToken *SecureToken) error {
|
|
requestTokenSet := requestToken.IsSet()
|
|
|
|
// Fast path: token matches existing
|
|
if a.accessToken.IsSet() && requestTokenSet && a.accessToken.EqualsSecure(requestToken) {
|
|
return nil
|
|
}
|
|
|
|
// Check MMDS only if token didn't match existing
|
|
matchesMMDS, mmdsExists := a.checkMMDSHash(ctx, requestToken)
|
|
|
|
switch {
|
|
case matchesMMDS:
|
|
return nil
|
|
case !a.accessToken.IsSet() && !mmdsExists:
|
|
return nil // first-time setup
|
|
case !requestTokenSet:
|
|
return ErrAccessTokenResetNotAuthorized
|
|
default:
|
|
return ErrAccessTokenMismatch
|
|
}
|
|
}
|
|
|
|
// checkMMDSHash checks if the request token matches the MMDS hash.
|
|
// Returns (matches, mmdsExists).
|
|
//
|
|
// The MMDS hash is set by the orchestrator during Resume:
|
|
// - hash(token): requires this specific token
|
|
// - hash(""): explicitly allows nil token (token reset authorized)
|
|
// - "": MMDS not properly configured, no authorization granted
|
|
func (a *API) checkMMDSHash(ctx context.Context, requestToken *SecureToken) (bool, bool) {
|
|
if a.isNotFC {
|
|
return false, false
|
|
}
|
|
|
|
mmdsHash, err := a.mmdsClient.GetAccessTokenHash(ctx)
|
|
if err != nil {
|
|
return false, false
|
|
}
|
|
|
|
if mmdsHash == "" {
|
|
return false, false
|
|
}
|
|
|
|
if !requestToken.IsSet() {
|
|
return mmdsHash == keys.HashAccessToken(""), true
|
|
}
|
|
|
|
tokenBytes, err := requestToken.Bytes()
|
|
if err != nil {
|
|
return false, true
|
|
}
|
|
defer memguard.WipeBytes(tokenBytes)
|
|
|
|
return keys.HashAccessTokenBytes(tokenBytes) == mmdsHash, true
|
|
}
|
|
|
|
func (a *API) PostInit(w http.ResponseWriter, r *http.Request) {
|
|
defer r.Body.Close()
|
|
|
|
ctx := r.Context()
|
|
|
|
operationID := logs.AssignOperationID()
|
|
logger := a.logger.With().Str(string(logs.OperationIDKey), operationID).Logger()
|
|
|
|
if r.Body != nil {
|
|
// Read raw body so we can wipe it after parsing
|
|
body, err := io.ReadAll(r.Body)
|
|
// Ensure body is wiped after we're done
|
|
defer memguard.WipeBytes(body)
|
|
if err != nil {
|
|
logger.Error().Msgf("Failed to read request body: %v", err)
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
|
|
return
|
|
}
|
|
|
|
var initRequest PostInitJSONBody
|
|
if len(body) > 0 {
|
|
err = json.Unmarshal(body, &initRequest)
|
|
if err != nil {
|
|
logger.Error().Msgf("Failed to decode request: %v", err)
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
|
|
return
|
|
}
|
|
}
|
|
|
|
// Ensure request token is destroyed if not transferred via TakeFrom.
|
|
// This handles: validation failures, timestamp-based skips, and any early returns.
|
|
// Safe because Destroy() is nil-safe and TakeFrom clears the source.
|
|
defer initRequest.AccessToken.Destroy()
|
|
|
|
a.initLock.Lock()
|
|
defer a.initLock.Unlock()
|
|
|
|
// Update data only if the request is newer or if there's no timestamp at all
|
|
if initRequest.Timestamp == nil || a.lastSetTime.SetToGreater(initRequest.Timestamp.UnixNano()) {
|
|
err = a.SetData(ctx, logger, initRequest)
|
|
if err != nil {
|
|
switch {
|
|
case errors.Is(err, ErrAccessTokenMismatch), errors.Is(err, ErrAccessTokenResetNotAuthorized):
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
default:
|
|
logger.Error().Msgf("Failed to set data: %v", err)
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
}
|
|
w.Write([]byte(err.Error()))
|
|
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
go func() { //nolint:contextcheck // TODO: fix this later
|
|
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
|
defer cancel()
|
|
host.PollForMMDSOpts(ctx, a.mmdsChan, a.defaults.EnvVars)
|
|
}()
|
|
|
|
w.Header().Set("Cache-Control", "no-store")
|
|
w.Header().Set("Content-Type", "")
|
|
|
|
w.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
func (a *API) SetData(ctx context.Context, logger zerolog.Logger, data PostInitJSONBody) error {
|
|
// Validate access token before proceeding with any action
|
|
// The request must provide a token that is either:
|
|
// 1. Matches the existing access token (if set), OR
|
|
// 2. Matches the MMDS hash (for token change during resume)
|
|
if err := a.validateInitAccessToken(ctx, data.AccessToken); err != nil {
|
|
return err
|
|
}
|
|
|
|
if data.Timestamp != nil {
|
|
// Check if current time differs significantly from the received timestamp
|
|
if shouldSetSystemTime(time.Now(), *data.Timestamp) {
|
|
logger.Debug().Msgf("Setting sandbox start time to: %v", *data.Timestamp)
|
|
ts := unix.NsecToTimespec(data.Timestamp.UnixNano())
|
|
err := unix.ClockSettime(unix.CLOCK_REALTIME, &ts)
|
|
if err != nil {
|
|
logger.Error().Msgf("Failed to set system time: %v", err)
|
|
}
|
|
} else {
|
|
logger.Debug().Msgf("Current time is within acceptable range of timestamp %v, not setting system time", *data.Timestamp)
|
|
}
|
|
}
|
|
|
|
if data.EnvVars != nil {
|
|
logger.Debug().Msg(fmt.Sprintf("Setting %d env vars", len(*data.EnvVars)))
|
|
|
|
for key, value := range *data.EnvVars {
|
|
logger.Debug().Msgf("Setting env var for %s", key)
|
|
a.defaults.EnvVars.Store(key, value)
|
|
}
|
|
}
|
|
|
|
if data.AccessToken.IsSet() {
|
|
logger.Debug().Msg("Setting access token")
|
|
a.accessToken.TakeFrom(data.AccessToken)
|
|
} else if a.accessToken.IsSet() {
|
|
logger.Debug().Msg("Clearing access token")
|
|
a.accessToken.Destroy()
|
|
}
|
|
|
|
if data.HyperloopIP != nil {
|
|
go a.SetupHyperloop(*data.HyperloopIP)
|
|
}
|
|
|
|
if data.DefaultUser != nil && *data.DefaultUser != "" {
|
|
logger.Debug().Msgf("Setting default user to: %s", *data.DefaultUser)
|
|
a.defaults.User = *data.DefaultUser
|
|
}
|
|
|
|
if data.DefaultWorkdir != nil && *data.DefaultWorkdir != "" {
|
|
logger.Debug().Msgf("Setting default workdir to: %s", *data.DefaultWorkdir)
|
|
a.defaults.Workdir = data.DefaultWorkdir
|
|
}
|
|
|
|
if data.VolumeMounts != nil {
|
|
for _, volume := range *data.VolumeMounts {
|
|
logger.Debug().Msgf("Mounting %s at %q", volume.NfsTarget, volume.Path)
|
|
|
|
go a.setupNfs(context.WithoutCancel(ctx), volume.NfsTarget, volume.Path)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (a *API) setupNfs(ctx context.Context, nfsTarget, path string) {
|
|
commands := [][]string{
|
|
{"mkdir", "-p", path},
|
|
{"mount", "-v", "-t", "nfs", "-o", "mountproto=tcp,mountport=2049,proto=tcp,port=2049,nfsvers=3,noacl", nfsTarget, path},
|
|
}
|
|
|
|
for _, command := range commands {
|
|
data, err := exec.CommandContext(ctx, command[0], command[1:]...).CombinedOutput()
|
|
|
|
logger := a.getLogger(err)
|
|
|
|
logger.
|
|
Strs("command", command).
|
|
Str("output", string(data)).
|
|
Msg("Mount NFS")
|
|
|
|
if err != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (a *API) SetupHyperloop(address string) {
|
|
a.hyperloopLock.Lock()
|
|
defer a.hyperloopLock.Unlock()
|
|
|
|
if err := rewriteHostsFile(address, "/etc/hosts"); err != nil {
|
|
a.logger.Error().Err(err).Msg("failed to modify hosts file")
|
|
} else {
|
|
a.defaults.EnvVars.Store("WRENN_EVENTS_ADDRESS", fmt.Sprintf("http://%s", address))
|
|
}
|
|
}
|
|
|
|
const eventsHost = "events.wrenn.local"
|
|
|
|
func rewriteHostsFile(address, path string) error {
|
|
hosts, err := txeh.NewHosts(&txeh.HostsConfig{
|
|
ReadFilePath: path,
|
|
WriteFilePath: path,
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create hosts: %w", err)
|
|
}
|
|
|
|
// Update /etc/hosts to point events.wrenn.local to the hyperloop IP
|
|
// This will remove any existing entries for events.wrenn.local first
|
|
ipFamily, err := getIPFamily(address)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get ip family: %w", err)
|
|
}
|
|
|
|
if ok, current, _ := hosts.HostAddressLookup(eventsHost, ipFamily); ok && current == address {
|
|
return nil // nothing to be done
|
|
}
|
|
|
|
hosts.AddHost(address, eventsHost)
|
|
|
|
return hosts.Save()
|
|
}
|
|
|
|
var (
|
|
ErrInvalidAddress = errors.New("invalid IP address")
|
|
ErrUnknownAddressFormat = errors.New("unknown IP address format")
|
|
)
|
|
|
|
func getIPFamily(address string) (txeh.IPFamily, error) {
|
|
addressIP, err := netip.ParseAddr(address)
|
|
if err != nil {
|
|
return txeh.IPFamilyV4, fmt.Errorf("failed to parse IP address: %w", err)
|
|
}
|
|
|
|
switch {
|
|
case addressIP.Is4():
|
|
return txeh.IPFamilyV4, nil
|
|
case addressIP.Is6():
|
|
return txeh.IPFamilyV6, nil
|
|
default:
|
|
return txeh.IPFamilyV4, fmt.Errorf("%w: %s", ErrUnknownAddressFormat, address)
|
|
}
|
|
}
|
|
|
|
// shouldSetSystemTime returns true if the current time differs significantly from the received timestamp,
|
|
// indicating the system clock should be adjusted. Returns true when the sandboxTime is more than
|
|
// maxTimeInPast before the hostTime or more than maxTimeInFuture after the hostTime.
|
|
func shouldSetSystemTime(sandboxTime, hostTime time.Time) bool {
|
|
return sandboxTime.Before(hostTime.Add(-maxTimeInPast)) || sandboxTime.After(hostTime.Add(maxTimeInFuture))
|
|
}
|