Port envd from e2b with internalized shared packages and Connect RPC
- Copy envd source from e2b-dev/infra, internalize shared dependencies
into envd/internal/shared/ (keys, filesystem, id, smap, utils)
- Switch from gRPC to Connect RPC for all envd services
- Update module paths to git.omukk.dev/wrenn/{sandbox,sandbox/envd}
- Add proto specs (process, filesystem) with buf-based code generation
- Implement full envd: process exec, filesystem ops, port forwarding,
cgroup management, MMDS integration, and HTTP API
- Update main module dependencies (firecracker SDK, pgx, goose, etc.)
- Remove placeholder .gitkeep files replaced by real implementations
This commit is contained in:
314
envd/internal/api/init.go
Normal file
314
envd/internal/api/init.go
Normal file
@ -0,0 +1,314 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"time"
|
||||
|
||||
"github.com/awnumar/memguard"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/txn2/txeh"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/host"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/keys"
|
||||
)
|
||||
|
||||
// Sentinel errors produced while validating the access token of an /init
// request. PostInit maps both of them to HTTP 401.
var (
	// ErrAccessTokenMismatch is returned when the request carries a token that
	// matches neither the stored access token nor the MMDS hash.
	ErrAccessTokenMismatch = errors.New("access token validation failed")

	// ErrAccessTokenResetNotAuthorized is returned when the request carries no
	// token and clearing the token has not been authorized.
	ErrAccessTokenResetNotAuthorized = errors.New("access token reset not authorized")
)
|
||||
|
||||
// Tolerances used by shouldSetSystemTime when comparing the sandbox clock
// with the timestamp received in an /init request.
const (
	// maxTimeInPast is how far the sandbox clock may lag behind the received
	// timestamp before the system clock is reset.
	maxTimeInPast = 50 * time.Millisecond

	// maxTimeInFuture is how far the sandbox clock may run ahead of the
	// received timestamp before the system clock is reset.
	maxTimeInFuture = 5 * time.Second
)
|
||||
|
||||
// validateInitAccessToken validates the access token for /init requests.
|
||||
// Token is valid if it matches the existing token OR the MMDS hash.
|
||||
// If neither exists, first-time setup is allowed.
|
||||
func (a *API) validateInitAccessToken(ctx context.Context, requestToken *SecureToken) error {
|
||||
requestTokenSet := requestToken.IsSet()
|
||||
|
||||
// Fast path: token matches existing
|
||||
if a.accessToken.IsSet() && requestTokenSet && a.accessToken.EqualsSecure(requestToken) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check MMDS only if token didn't match existing
|
||||
matchesMMDS, mmdsExists := a.checkMMDSHash(ctx, requestToken)
|
||||
|
||||
switch {
|
||||
case matchesMMDS:
|
||||
return nil
|
||||
case !a.accessToken.IsSet() && !mmdsExists:
|
||||
return nil // first-time setup
|
||||
case !requestTokenSet:
|
||||
return ErrAccessTokenResetNotAuthorized
|
||||
default:
|
||||
return ErrAccessTokenMismatch
|
||||
}
|
||||
}
|
||||
|
||||
// checkMMDSHash checks if the request token matches the MMDS hash.
|
||||
// Returns (matches, mmdsExists).
|
||||
//
|
||||
// The MMDS hash is set by the orchestrator during Resume:
|
||||
// - hash(token): requires this specific token
|
||||
// - hash(""): explicitly allows nil token (token reset authorized)
|
||||
// - "": MMDS not properly configured, no authorization granted
|
||||
func (a *API) checkMMDSHash(ctx context.Context, requestToken *SecureToken) (bool, bool) {
|
||||
if a.isNotFC {
|
||||
return false, false
|
||||
}
|
||||
|
||||
mmdsHash, err := a.mmdsClient.GetAccessTokenHash(ctx)
|
||||
if err != nil {
|
||||
return false, false
|
||||
}
|
||||
|
||||
if mmdsHash == "" {
|
||||
return false, false
|
||||
}
|
||||
|
||||
if !requestToken.IsSet() {
|
||||
return mmdsHash == keys.HashAccessToken(""), true
|
||||
}
|
||||
|
||||
tokenBytes, err := requestToken.Bytes()
|
||||
if err != nil {
|
||||
return false, true
|
||||
}
|
||||
defer memguard.WipeBytes(tokenBytes)
|
||||
|
||||
return keys.HashAccessTokenBytes(tokenBytes) == mmdsHash, true
|
||||
}
|
||||
|
||||
func (a *API) PostInit(w http.ResponseWriter, r *http.Request) {
|
||||
defer r.Body.Close()
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
operationID := logs.AssignOperationID()
|
||||
logger := a.logger.With().Str(string(logs.OperationIDKey), operationID).Logger()
|
||||
|
||||
if r.Body != nil {
|
||||
// Read raw body so we can wipe it after parsing
|
||||
body, err := io.ReadAll(r.Body)
|
||||
// Ensure body is wiped after we're done
|
||||
defer memguard.WipeBytes(body)
|
||||
if err != nil {
|
||||
logger.Error().Msgf("Failed to read request body: %v", err)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
var initRequest PostInitJSONBody
|
||||
if len(body) > 0 {
|
||||
err = json.Unmarshal(body, &initRequest)
|
||||
if err != nil {
|
||||
logger.Error().Msgf("Failed to decode request: %v", err)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure request token is destroyed if not transferred via TakeFrom.
|
||||
// This handles: validation failures, timestamp-based skips, and any early returns.
|
||||
// Safe because Destroy() is nil-safe and TakeFrom clears the source.
|
||||
defer initRequest.AccessToken.Destroy()
|
||||
|
||||
a.initLock.Lock()
|
||||
defer a.initLock.Unlock()
|
||||
|
||||
// Update data only if the request is newer or if there's no timestamp at all
|
||||
if initRequest.Timestamp == nil || a.lastSetTime.SetToGreater(initRequest.Timestamp.UnixNano()) {
|
||||
err = a.SetData(ctx, logger, initRequest)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, ErrAccessTokenMismatch), errors.Is(err, ErrAccessTokenResetNotAuthorized):
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
default:
|
||||
logger.Error().Msgf("Failed to set data: %v", err)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
}
|
||||
w.Write([]byte(err.Error()))
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
go func() { //nolint:contextcheck // TODO: fix this later
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
host.PollForMMDSOpts(ctx, a.mmdsChan, a.defaults.EnvVars)
|
||||
}()
|
||||
|
||||
w.Header().Set("Cache-Control", "no-store")
|
||||
w.Header().Set("Content-Type", "")
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (a *API) SetData(ctx context.Context, logger zerolog.Logger, data PostInitJSONBody) error {
|
||||
// Validate access token before proceeding with any action
|
||||
// The request must provide a token that is either:
|
||||
// 1. Matches the existing access token (if set), OR
|
||||
// 2. Matches the MMDS hash (for token change during resume)
|
||||
if err := a.validateInitAccessToken(ctx, data.AccessToken); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if data.Timestamp != nil {
|
||||
// Check if current time differs significantly from the received timestamp
|
||||
if shouldSetSystemTime(time.Now(), *data.Timestamp) {
|
||||
logger.Debug().Msgf("Setting sandbox start time to: %v", *data.Timestamp)
|
||||
ts := unix.NsecToTimespec(data.Timestamp.UnixNano())
|
||||
err := unix.ClockSettime(unix.CLOCK_REALTIME, &ts)
|
||||
if err != nil {
|
||||
logger.Error().Msgf("Failed to set system time: %v", err)
|
||||
}
|
||||
} else {
|
||||
logger.Debug().Msgf("Current time is within acceptable range of timestamp %v, not setting system time", *data.Timestamp)
|
||||
}
|
||||
}
|
||||
|
||||
if data.EnvVars != nil {
|
||||
logger.Debug().Msg(fmt.Sprintf("Setting %d env vars", len(*data.EnvVars)))
|
||||
|
||||
for key, value := range *data.EnvVars {
|
||||
logger.Debug().Msgf("Setting env var for %s", key)
|
||||
a.defaults.EnvVars.Store(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
if data.AccessToken.IsSet() {
|
||||
logger.Debug().Msg("Setting access token")
|
||||
a.accessToken.TakeFrom(data.AccessToken)
|
||||
} else if a.accessToken.IsSet() {
|
||||
logger.Debug().Msg("Clearing access token")
|
||||
a.accessToken.Destroy()
|
||||
}
|
||||
|
||||
if data.HyperloopIP != nil {
|
||||
go a.SetupHyperloop(*data.HyperloopIP)
|
||||
}
|
||||
|
||||
if data.DefaultUser != nil && *data.DefaultUser != "" {
|
||||
logger.Debug().Msgf("Setting default user to: %s", *data.DefaultUser)
|
||||
a.defaults.User = *data.DefaultUser
|
||||
}
|
||||
|
||||
if data.DefaultWorkdir != nil && *data.DefaultWorkdir != "" {
|
||||
logger.Debug().Msgf("Setting default workdir to: %s", *data.DefaultWorkdir)
|
||||
a.defaults.Workdir = data.DefaultWorkdir
|
||||
}
|
||||
|
||||
if data.VolumeMounts != nil {
|
||||
for _, volume := range *data.VolumeMounts {
|
||||
logger.Debug().Msgf("Mounting %s at %q", volume.NfsTarget, volume.Path)
|
||||
|
||||
go a.setupNfs(context.WithoutCancel(ctx), volume.NfsTarget, volume.Path)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *API) setupNfs(ctx context.Context, nfsTarget, path string) {
|
||||
commands := [][]string{
|
||||
{"mkdir", "-p", path},
|
||||
{"mount", "-v", "-t", "nfs", "-o", "mountproto=tcp,mountport=2049,proto=tcp,port=2049,nfsvers=3,noacl", nfsTarget, path},
|
||||
}
|
||||
|
||||
for _, command := range commands {
|
||||
data, err := exec.CommandContext(ctx, command[0], command[1:]...).CombinedOutput()
|
||||
|
||||
logger := a.getLogger(err)
|
||||
|
||||
logger.
|
||||
Strs("command", command).
|
||||
Str("output", string(data)).
|
||||
Msg("Mount NFS")
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *API) SetupHyperloop(address string) {
|
||||
a.hyperloopLock.Lock()
|
||||
defer a.hyperloopLock.Unlock()
|
||||
|
||||
if err := rewriteHostsFile(address, "/etc/hosts"); err != nil {
|
||||
a.logger.Error().Err(err).Msg("failed to modify hosts file")
|
||||
} else {
|
||||
a.defaults.EnvVars.Store("WRENN_EVENTS_ADDRESS", fmt.Sprintf("http://%s", address))
|
||||
}
|
||||
}
|
||||
|
||||
// eventsHost is the hostname that rewriteHostsFile maps to the hyperloop
// address in the hosts file.
const eventsHost = "events.wrenn.local"
|
||||
|
||||
func rewriteHostsFile(address, path string) error {
|
||||
hosts, err := txeh.NewHosts(&txeh.HostsConfig{
|
||||
ReadFilePath: path,
|
||||
WriteFilePath: path,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create hosts: %w", err)
|
||||
}
|
||||
|
||||
// Update /etc/hosts to point events.wrenn.local to the hyperloop IP
|
||||
// This will remove any existing entries for events.wrenn.local first
|
||||
ipFamily, err := getIPFamily(address)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get ip family: %w", err)
|
||||
}
|
||||
|
||||
if ok, current, _ := hosts.HostAddressLookup(eventsHost, ipFamily); ok && current == address {
|
||||
return nil // nothing to be done
|
||||
}
|
||||
|
||||
hosts.AddHost(address, eventsHost)
|
||||
|
||||
return hosts.Save()
|
||||
}
|
||||
|
||||
// Sentinel errors for IP address classification.
var (
	// ErrInvalidAddress marks an address that is not a valid IP address.
	// NOTE(review): nothing in the visible part of this file returns it.
	ErrInvalidAddress = errors.New("invalid IP address")

	// ErrUnknownAddressFormat is wrapped by getIPFamily when a parsed address
	// is neither IPv4 nor IPv6.
	ErrUnknownAddressFormat = errors.New("unknown IP address format")
)
|
||||
|
||||
func getIPFamily(address string) (txeh.IPFamily, error) {
|
||||
addressIP, err := netip.ParseAddr(address)
|
||||
if err != nil {
|
||||
return txeh.IPFamilyV4, fmt.Errorf("failed to parse IP address: %w", err)
|
||||
}
|
||||
|
||||
switch {
|
||||
case addressIP.Is4():
|
||||
return txeh.IPFamilyV4, nil
|
||||
case addressIP.Is6():
|
||||
return txeh.IPFamilyV6, nil
|
||||
default:
|
||||
return txeh.IPFamilyV4, fmt.Errorf("%w: %s", ErrUnknownAddressFormat, address)
|
||||
}
|
||||
}
|
||||
|
||||
// shouldSetSystemTime returns true if the current time differs significantly from the received timestamp,
|
||||
// indicating the system clock should be adjusted. Returns true when the sandboxTime is more than
|
||||
// maxTimeInPast before the hostTime or more than maxTimeInFuture after the hostTime.
|
||||
func shouldSetSystemTime(sandboxTime, hostTime time.Time) bool {
|
||||
return sandboxTime.Before(hostTime.Add(-maxTimeInPast)) || sandboxTime.After(hostTime.Add(maxTimeInFuture))
|
||||
}
|
||||
Reference in New Issue
Block a user