1
0
forked from wrenn/wrenn

85 Commits

Author SHA1 Message Date
429de4001f refactor(hosts): remove duplicated adminHostResponse, reuse hostResponse for both endpoints 2026-05-23 04:13:17 +06:00
76f8cc1675 feat(hosts): add resource transparency to BYOC dashboard host page 2026-05-23 04:10:13 +06:00
e061095110 feat(templates): resolve actual disk usage via host agent RPC
- Add GetTemplateSize RPC to host agent proto with team_id + template_id
- Implement TemplateRootfsSize using block-level stat (Blocks * 512)
  for accurate sparse file accounting
- Add resolveTemplateSizes helper that queries a host agent and
  persists the result to the DB for subsequent requests
- Wire resolveTemplateSizes into GET /v1/admin/templates and
  GET /v1/snapshots so system base templates show correct sizes
- Add UpdateTemplateSize sqlc query
- Remove unused cert_fingerprint / cert_expires_at from host queries
- Add GET /v1/admin/hosts route
2026-05-23 03:57:30 +06:00
f4eb6645bc feat(admin): host metrics, team details, and admin panel improvements 2026-05-23 03:42:40 +06:00
6014fa72cf Merge pull request 'Added a multi-distro system and asynchronized snapshot action' (#52) from fix/image-creation-and-maintenance into dev
Reviewed-on: wrenn/wrenn#52
2026-05-22 20:07:18 +00:00
c164aadfc9 feat(templates): multi-distro system base images + paused-state snapshotting
- Replace single 'minimal' system template with four well-known distro
  base images: minimal-ubuntu (id 0), minimal-alpine (id 1), minimal-arch
  (id 2), minimal-fedora (id 3) — all platform-owned, reserved IDs 0–1024
- Add id.IsReservedTemplateID() guard and well-known ID constants
- Seed the four system base templates via a goose migration
- On-disk layout: images/teams/{b36(team)}/{b36(tid)}/rootfs.ext4
- Rebuild scripts (update-minimal-rootfs.sh, build-*.sh) handle all four
- rootfs-from-container.sh: fix usr-merged distro tini install, flatten
  IMAGE_NAME slashes for tar path

Snapshotting:
- Allow snapshots from 'paused' state (not just 'running')
- Implement flattenPausedCow for paused sandboxes: re-attach dm-snapshot,
  flatten CoW, tear down; distinct dm name to avoid colliding with Resume
- copyMemorySnapshotFiles + linkOrCopyFile: hardlink CH memory artefacts
  with sparse-preserving cp --sparse=always fallback across filesystems
- Promote staging dir atomically with rename(2)
- Track origStatus through snapshotInBackground so badge returns to
  paused (not running) after paused snapshot
- Expand CleanupOrphanPauseDirs to clean up .stage- prefixes too

Build service:
- Look up all base templates (including system ones) via DB query instead
  of hardcoding the minimal template path
- Insert a sandbox DB record for builder sandboxes
- Destroy builder sandbox + mark DB record stopped on finalize
- Default base template to minimal-ubuntu instead of minimal

Chores:
- Remove stale recipe files (code-runner-beta, jupyter test)
- Remove prepare-wrenn-user.sh (replaced by setup-host.sh)
- Remove old build-rootfs.sh / docker-to-rootfs.sh / per-template build.sh
- Update CLAUDE.md, README.md, envd-rs README with new template info
- Frontend: update admin templates page, admin capsule view, snapshot
  dialog to support paused-state snapshots
2026-05-23 01:58:51 +06:00
a01796a4c3 feat(sandbox): async snapshots with snapshotting state + lifecycle toasts
Snapshots now mirror the async pause flow. POST /v1/snapshots returns 202
with the capsule in a new "snapshotting" state and registers the template
in a background goroutine, instead of blocking the request while Cloud
Hypervisor pauses the VM. Users get clear feedback that the capsule is
busy rather than seeing it appear wedged.

- New "snapshotting" sandbox state: CAS running -> snapshotting -> running
- CreateSnapshot async via snapshotInBackground; the back-to-running CAS
  is gated so a racing destroy wins (no false "running" signal)
- serviceEventToCanonical maps snapshot completion + state-change events;
  state changes are published transiently (SSE only)
- HostMonitor reconciles stuck snapshotting: 15m grace while the VM is
  alive, error if the VM is gone
- Global lifecycle toasts (create/pause/resume/snapshot/destroy) driven by
  terminal system-actor SSE events, exactly one per operation
- openapi: status enum, 202 + Capsule responses, admin snapshot name optional
- Remove dead LogSnapshotCreate; host-callback failure during a snapshot
  marks the sandbox errored instead of removing the row
2026-05-22 21:36:46 +06:00
d98acc2764 fix(sandbox): harden pause/snapshot against wedged CH and IO starvation
quiesceAndPauseCH now bounds PrepareSnapshot (5s), probes ch.vm.info
for "Running" before pause, and bounds ch.pause (30s). Second snapshot
attempts against a dead CH socket no longer hang on a 2m envd timeout.

Punch zero-page scan reads 1 MiB chunks instead of 4 KiB, cutting
read(2) syscalls by 256x during 20 GiB snapshot writes on single-disk
hosts. Zero-run batching preserved across chunk boundaries.

Agent startup flushes orphan conntrack rows in 10.11.0.0/16 and
10.12.0.0/16, and logs pre-existing wrenn loop attachments for
post-crash visibility.

wrenn-agent.service gains IOSchedulingClass=best-effort,
IOSchedulingPriority=5, IOWeight=50 so snapshot IO cannot starve
sshd/journal on contended disks.
2026-05-21 14:20:12 +06:00
35654ccb9a Fixed frontend build error 2026-05-20 06:19:49 +06:00
d856f7e18d fix(gitignore): anchor builds/ to repo root; add admin template builds page 2026-05-20 06:15:41 +06:00
eccfb76550 Merge pull request 'Updated template creation process from admin panel and improved agent graceful shutdown' (#51) from fix/template-creation into dev
Reviewed-on: wrenn/wrenn#51
2026-05-19 23:33:28 +00:00
77392ce997 chore(recipes): add data libs and kernel smoke test to code-runner-beta 2026-05-20 05:30:21 +06:00
a8576121cf feat(auth): hash session IDs at rest
Store sha256(SID) in Postgres + Redis; raw SID lives only in cookie.
Prevents DB/cache dump from yielding usable session tokens.
2026-05-20 05:30:17 +06:00
76587e17a2 fix(sandbox): stop leaking paused sandboxes as stopped
Paused sandboxes were silently turning into 'stopped' rows in the CP
DB whenever the auto_paused callback failed delivery and the agent
later restarted. Snapshot dirs survived on disk but the agent had no
restore path, so HostMonitor's missing->stopped reconcile orphaned
them.

Agent
- Manager.RestorePausedSandboxes: scans WRENN_DIR/sandboxes/ on
  startup, picks newest-per-slot, reserves the slot, registers a
  Paused entry in m.boxes. Losers and corrupt metas are trashed.
- Manager.Destroy: re-insert + bail if a racing Pause completed
  between m.boxes delete and lifecycleMu acquisition, so an
  in-flight user Pause's snapshot is not wiped by a concurrent
  Destroy or by Shutdown's destroy loop. Distinguishes from
  legitimate destroy-of-paused via statusAtEntry.
- Manager.Shutdown: emit sandbox.stopped (SendAsync) after each
  non-paused Destroy so the CP flips DB out of running/error.
- ErrDraining + atomic.Bool guard: Create/Resume bounce with
  CodeUnavailable once Shutdown begins, preventing new lifecycle
  work from racing the destroy loop.
- resumeFromMeta: set sb.baseImagePath so cleanup's loops.Release
  pairs with the Acquire just performed (latent bug for any
  Resume->Destroy after the restore path lands).
- ListSandboxes: nil-guard HostIP so paused entries don't emit
  '<nil>' into proto.

Control plane
- HostMonitor: new reconcile block. When the agent reports a
  sandbox as paused but the DB row is stopped (legacy zombie),
  issue DestroySandbox to release the on-disk snapshot and slot.
  Gated on len(paused) > 0 to avoid unbounded stopped-row query.

Host agent main
- Call RestorePausedSandboxes after SetEventSender, before
  StartTTLReaper and before HTTP serve.
- Keep mgr.Shutdown BEFORE httpServer.Shutdown so main does not
  exit while pause work is still draining.
2026-05-20 05:26:31 +06:00
0e8ba30c7a fix(snapshot): propagate SandboxDir across template chains
A snapshot-of-a-snapshot (t2 from t1 from t0) failed to launch with
"GET /api/v1/vm.info: status 500: VM is not created". Cloud Hypervisor's
saved config.json hardcodes the original sandbox's tmpfs path, but
createFromSnapshotTemplate rebuilt SandboxDir from meta.SandboxID — which
on a chained template was the immediate parent sandbox, not the root
ancestor whose path CH baked in. Mismatch → restore silently failed.

wrenn-snapshot.json now carries the effective SandboxDir verbatim, set
at snapshot time from sandboxDirOverride (inherited across chains) or
SandboxTmpDir(sb.ID). The launcher uses meta.SandboxDir directly.

Other changes to the snapshot metadata file:
- Drop sandbox_id (no longer load-bearing).
- Add template_name (plumbed via the existing, previously-deprecated
  CreateSnapshotRequest.name field).
- Make slot_index omitempty and stop writing it for templates — a
  template allocates a fresh slot per launch. Pause still records it.

Template deletion residuals:
- Manager.DeleteSnapshot now prunes the parent team directory when it
  becomes empty.
- New strict deleteSnapshotEverywhere helper used by both the
  user-facing DELETE /v1/snapshots/{name} and admin
  DELETE /v1/admin/templates/{name}. Aborts the DB-row deletion if any
  active host is offline or returns a non-NotFound error, so a failure
  leaves the template visible and retryable instead of orphaning files.

Hard cutover: snapshot templates created before this change have no
sandbox_dir field and will fail to launch with a clear error. Existing
pause snapshots are unaffected (Resume reuses the original sandbox ID,
so the empty SandboxDir falls back to the same default).
2026-05-20 03:14:41 +06:00
2e71ebe16f feat(builds): real-time WebSocket+PTY template build console
The admin template build experience was poll-based: the frontend fetched
GET /v1/admin/builds/{id} every 3s and per-step logs only appeared after a
step finished. This replaces it with a live console.

Backend:
- recipe.Execute gains streaming callbacks (StepStartFunc, OutputChunkFunc)
  and a StreamExecFunc. RUN steps now stream via a PTY so build tools emit
  unbuffered, colorized output. execRunStreaming treats a stream that ends
  without a terminal chunk as a failure.
- The build worker runs each RUN step through the host agent PtyAttach RPC
  and publishes step-start/output/step-end/build-status events to a per-build
  Redis pub/sub channel (wrenn:build:{id}).
- BuildBroker fans those events from Redis out to in-process WebSocket
  subscribers, with lazy per-build subscriptions.
- New WS endpoint GET /v1/admin/builds/{id}/stream replays the completed-step
  history from the DB log, then live-tails broker events.

run_as_root:
- The non-root build user is now injected as USER/WORKDIR steps prepended to
  the persisted recipe by the create handler, instead of being hardcoded in
  the pre-build phase. "Run as root" simply omits them, so wrenn-user is never
  created in a root template. No build-level column is needed.

Frontend:
- New /admin/templates/builds/[id] route with a hybrid console: an xterm.js
  terminal streaming live PTY output plus a structured step list (status,
  exit code, timing).
- Build rows on the templates page navigate to the console; the inline
  log-expand is removed. A "Run recipe as root" checkbox is added to the
  create modal.
2026-05-20 02:09:03 +06:00
f06d03996a fix: niceness EPERM for non-root processes, flatten sync
envd: spawn_process wrapped every command in `nice -n {delta}`. current_nice()
used `20 - getpriority()`, but getpriority already returns the nice value, so
delta was -20 for a default-nice envd. Non-root users cannot raise priority, so
the wrapper failed with "cannot set niceness: permission denied" for any process
run as a non-root user. current_nice() now returns the raw value; the wrapper
invokes `nice` only when delta > 0. The oom_score_adj write is kept (always
permitted for raising one's own score).

sandbox: FlattenRootfs used a plain ch.pause, which freezes vCPUs but does not
flush the guest VFS page cache, so freshly written files (e.g. pip installs)
had not reached the block device and the flattened rootfs captured empty files.
Switch to quiesceAndPauseCH (envd /snapshot/prepare: sync + drop_caches), as
CreateSnapshot and Pause already do, and reset the connection tracker after
resume on both the success and quiesce-failure paths.
2026-05-20 01:34:15 +06:00
21c837aa02 fix: surface sandbox create failures and stop destroy leak
The createInBackground "sandbox.failed" event had no case in
serviceEventToCanonical, so it was dropped: no SSE push, no webhook,
no audit row. The dashboard stayed on "starting" until a manual
refresh. Add the case and route create/resume/pause failures through
with a non-empty reason so channels.isRedundantSystemFollowup does
not filter them out of webhook delivery.

A DestroySandbox RPC arriving while Create was still booting found
nothing in m.boxes (the sandbox is registered only after the envd
readiness wait) and no-oped, while Create raced on to register a VM
that no DB record owned — a permanent VM/dm/network/loop leak. Create
now registers a cancellable handle in m.creates before acquiring
resources; Destroy aborts the in-flight create and waits for its
rollback. Shutdown drains in-flight creates for the same reason.

The envd readiness wait was a fixed 30s, too tight for large VMs
which cold-boot slower. Scale it at 8s per GiB of guest RAM with a
120s floor.

handleFailed skips its audit write for reconciler-emitted events
(reason=transient_timeout); LogSandboxCreateSystem already wrote that
row before publishing, so auditing again would double-count.
2026-05-20 01:09:26 +06:00
0012088191 fix(envd): avoid lost prune on already-exited process
Check cached_end() after subscribing so the cleanup task does not block
on a broadcast receiver that already missed the end event from a
short-lived process.
2026-05-19 23:41:29 +06:00
28e35bf23e refactor(cpextension): remove LimitsProvider/UsageProvider hooks
Quota enforcement moves to the cloud repo as a MiddlewareProvider
wrapping POST /v1/capsules. Drops the Limits/Usage structs, both
provider interfaces, the enforceLimits gate, and the DB-backed
defaultUsageProvider. OSS deployment is unmetered.
2026-05-19 18:56:47 +06:00
a94a72afa2 Updated cleanup script 2026-05-19 17:28:36 +06:00
75607b4464 fix(envd): signal process groups, drain output before exit event
SendSignal now targets the negative pid so the whole process group is
killed, not just the /bin/sh wrapper — children are no longer orphaned.
The pipe spawn path calls setpgid to become a group leader (pty path
gets this from setsid).

The waiter thread joins the stdout/stderr/pty reader threads before
publishing the end event, and the RPC streams drain the data channel
fully before emitting the end event (and emit a still-buffered end
event when the data channel closes first). Trailing output is no
longer lost to a process-exit/read race.
2026-05-19 17:16:16 +06:00
8563229bf4 fix(processes,fs): prune killed processes, preserve envd error codes
ProcessServiceImpl.processes was a bare DashMap; the cleanup task
cloned it (DashMap::clone deep-copies), so dead PIDs were removed
from a detached copy and ListProcesses showed them forever. Share
the map via Arc.

Host agent flattened every envd filesystem error to CodeInternal,
discarding the Connect code. mkdir on an existing directory thus
surfaced as 500/agent_error instead of 409. Add envdErr() helper
that preserves the envd Connect code (NotFound, AlreadyExists, ...)
and apply it to WriteFile/ReadFile/ListDir/MakeDir/RemovePath.

Split agentErrToHTTP so AlreadyExists maps to 409 already_exists
(distinct from FailedPrecondition -> 409 conflict). Document the
already_exists case on the mkdir endpoint in openapi.yaml.
2026-05-19 14:53:49 +06:00
e87506ce6b Merge pull request 'Migrate to Cloud Hypervisor' (#49) from feat/migrate-to-ch into dev
Reviewed-on: wrenn/wrenn#49
2026-05-18 22:59:57 +00:00
e843be21a4 extensions: auth/sandbox hooks, limits/usage providers, exported session middleware
Cloud-repo extensions could not build on the cookie-session migration —
they still called the removed JWT helpers and duplicated middleware. Move
the session/CSRF middleware and cookie helpers into pkg/auth/session/middleware
as the single source of truth, with thin re-exports on cpextension.

Add hook interfaces so cloud can plug billing without forking OSS:
- AuthHook (OnSignup fail-fast; OnLogin / OnAccount*Delete log+ignore)
- SandboxEventHook (un-acks on error so messages redeliver; idempotent)
- LimitsProvider / UsageProvider (402 on overage; DB-backed usage default)

ServerContext gains OAuthRegistry, Channels, ChannelPub so extensions stop
reimplementing them.
2026-05-19 04:49:19 +06:00
9b34d6a82f events/capsules: dedupe channel notifications, harden create dialog
- channels dispatcher: drop capsule.{create,pause,resume,destroy} events
  with system actor and no reason metadata. Suppresses the goroutine /
  host-callback follow-up that duplicated every user-initiated action in
  notification channels (Telegram, webhooks). Genuinely system-only
  emitters (TTL auto-pause, host monitor reconciler, host failures) all
  set reason, so they continue to notify.
- CreateCapsuleDialog: wrap submit in try/finally so the creating flag
  always clears, and close the dialog before invoking oncreated to avoid
  the parent receiving the new capsule while the dialog is still open.
- capsules page: guard against double-insertion of the same capsule when
  the SSE event arrives before the dialog's oncreated callback resolves.
2026-05-19 04:27:11 +06:00
42af7c4357 auth: replace user JWTs with cookie sessions
User authentication moves from short-lived JWT bearer tokens to opaque
session cookies (wrenn_sid) backed by a Postgres sessions table and a
Redis hot cache. Browsers get a paired wrenn_csrf cookie; all mutating
requests must echo it via X-CSRF-Token (double-submit).

- New pkg/auth/session service: issue/revoke, idle (6h) + absolute
  (24h) lifetimes, switch-team rotation, RevokeAllForUser on password
  events, per-user listing for self-service.
- Middleware: requireSession + requireCSRF replace requireJWT and the
  WS first-message JWT exchange. SSE/WS endpoints rely on the cookie
  flowing on the upgrade — SSE ticket store deleted.
- API keys (wrn_<32hex>) remain for SDK/server use; capsule routes
  accept either via requireSessionOrAPIKey.
- Host-agent JWTs (signed by JWT_SECRET) are unchanged — that channel
  is wrenn-cp ↔ wrenn-agent and unrelated to user identity.
- Frontend client drops bearer-token plumbing, sends credentials and
  the CSRF header on every mutating call.
- OpenAPI + dashboard host-registration docs updated.
2026-05-19 04:01:24 +06:00
4b58dc32ab events: outcome + metadata + transient channel
Rename CapsuleCreated→CapsuleCreate (and pair siblings) into action
verbs, add Outcome (success/error), Metadata, and Error fields to the
canonical Event. Introduce PublishTransient for ephemeral SSE-only
signals (capsule.state.changed) so dashboard transitions don't reach
webhook/telegram subscribers.

Audit logger now publishes the canonical event itself with the derived
outcome, collapsing the old "audit then separately publish" split.
Sandbox event consumer rebuilt around the unified stream: host-agent
callbacks are translated once into canonical events, then fan out via
DB reconciler, channel dispatcher, and SSE relay independently.

Documents the channels subscription model in the README.
2026-05-19 04:01:06 +06:00
09cb78f1b8 feat(frontend): typed capsule/SSE state machine, resilient event stream
Introduce CapsuleStatus union + RESUMABLE_STATUSES / TRANSIENT_STATUSES
sets that mirror the backend state machine; the routes and SnapshotDialog
now derive button enablement from the sets instead of ad-hoc string
checks. Add disk_size_mb + metadata to the Capsule shape.

SSEEventKind union + isSSEEvent guard so malformed wire payloads can't
reach handlers via blind casts. Event stream reconnect now:
  - retries with backoff when the ticket fetch itself fails (previously
    gave up on a single 401/network blip),
  - reconnects immediately on window 'online' and document visibilitychange
    (back to visible) when the EventSource is not OPEN,
  - subscribes to capsule.error.

openapi.yaml: align OAuth paths (/v1/auth/oauth → /auth/oauth to match
the actual mount point), document bearerAuth on capsule routes, fix
'capsulees' typos, and expand schemas for the new state machine surfaces
the frontend now consumes.
2026-05-19 01:29:38 +06:00
f002839c48 feat(api): plumb root ctx through SSE store, surface pause/resume failures, clamp TTL
NewSSETicketStore now takes a context so its cleanup goroutine exits on
server shutdown instead of leaking for the process lifetime. Threaded
through api.New and pkg/cpserver/run.go.

SandboxEventConsumer learns sandbox.pause_failed / sandbox.resume_failed
event types and forwards TeamID from the publisher; server.go propagates
TeamID into the SSE broker so per-team subscribers receive failure
events. resumeInBackground now rolls resuming → paused on failure (was
resuming → error) so the user can retry without manual intervention.

pkg/service/sandbox: mirror internal/sandbox.MinTimeoutSec + clampTimeout
on the control plane so the DB row's timeout_sec agrees with what the
agent runs after its own silent clamp.
2026-05-19 01:29:28 +06:00
cab50db1c1 refactor(sandbox): per-sandbox dir layout, atomic envd client, sentinel errors
Move CoW from sandboxes/{id}.cow to sandboxes/{id}/rootfs.cow so every
per-sandbox artifact lives under one parent. PauseSnapshotDir now aliases
SandboxDir; Pause stages the CoW into the staging dir before swapDir so
the swap carries it through.

Publish sb.client via atomic.Pointer so Exec/Pty/Process callers can load
without holding lifecycleMu; Pause's releaseRuntime stores nil, Resume
stores a fresh client. Funnel every caller through new activeClient()
that nil-checks after Load to close the pause-vs-exec race.

Replace string-sniffing for "not found" / "not running" with sentinel
errors (ErrNotFound, ErrNotRunning, ErrNotPaused, ErrInvalidRange) and a
single mapSandboxError switch in the hostagent server. Add
parseSandboxIDs helper for the repeated team+template UUID parse.

Rewrite ConnTracker off sync.WaitGroup onto an explicit counter + zeroCh
so Drain/ForceClose can select on cancellation and timeout without
leaking the waiter goroutine on repeated pause failures.

Add internal/sandbox/punch.go: post-snapshot SEEK_DATA scan that
fallocate-punches any 4 KiB block of zeros in CH memory-* files (guest
dirty-then-free pages CH writes verbatim). Run after both pause snapshot
and CreateSnapshot. Bump envd quiesce sleep 500ms → 1s so the kernel
fully flushes before CH dumps memory.

Add sandboxDirOverride threaded through snapshotMeta + restoreVMConfig:
sandboxes launched from snapshot templates carry the original source
sandbox's tmpfs path in CH's saved config.json, so every subsequent
restore must reuse it.
2026-05-19 01:29:20 +06:00
802af222ee feat(sandbox): launch sandboxes from snapshot templates
New createFromSnapshotTemplate path branches off Manager.Create when the
template directory contains a CH memory snapshot (state.json + config.json
+ rootfs.ext4). Mirrors the pause/resume restore mechanics — same UFFD
lazy memory + post-restore memory loader — but produces a fresh sandbox
per call (new ID, new slot, new CoW on the shared flattened rootfs).

Shared restore primitives extracted to restore.go (buildRestoreVMConfig,
launchRestoredVM, initAndStartMemoryLoader) and reused by resumeFromMeta.

Chain correctness: descendants of snapshot templates start the memory
loader so subsequent CreateSnapshot from them is self-contained.

Defensive guards:
- CreateSnapshot refuses to overwrite an existing template dir.
- DeleteSnapshot refuses when running sandboxes still reference it.
- TimeoutSec clamped to MinTimeoutSec=60 to keep TTL reaper well clear of
  the post-create startup window.
- Snapshot routing skips minimal template even if a stray state.json lands.

vm.SandboxTmpDir / vm.SandboxSocketPath extracted so launchers don't
re-derive CH disk paths independently.
2026-05-18 15:20:45 +06:00
8262a4999e feat(vm): CH live snapshot, pause/resume with UFFD memory restore
Pause and live-snapshot share one CH primitive (ch.pause + ch.snapshot +
ch.destroy/resume). Pause writes artefacts to a staging dir and
atomically swaps to avoid CH re-reading a memory-ranges file mid-rewrite
across pause-resume-pause chains. Resume uses
memory_restore_mode=ondemand backed by userfaultfd; CH lazily faults
pages from the source file. A new envd /memory/preload endpoint
materialises every physical page (one byte per page via /dev/mem,
fallback /proc/kcore) so a subsequent snapshot writes a self-contained
file instead of holes.

Sandbox manager refactor: lifecycle / pause / resume code extracted to
internal/sandbox/pause.go, leaving manager.go focused on the in-memory
state map and orchestration entry points (-871 / +72). Stale CH process
and dm-snapshot cleanup runs at agent startup (internal/vm/cleanup.go)
and via scripts/cleanup-stale.sh for operator use.

Host monitor honors the agent's reported per-sandbox status when
reconciling missing rows (so an agent-side pause during a CP
disconnect isn't silently promoted back to running). New
BulkRestoreMissingToStatus query replaces the running-only path.
Transient statuses (pausing/resuming/starting/stopping) defer
reconciliation to the next tick.
2026-05-18 14:05:27 +06:00
b9cb3998f8 feat(api): real-time SSE event stream for sandbox lifecycle
In-process broker fans out sandbox state events (created/paused/running/
destroyed) to connected SSE clients, filtered by team. Backend publishes
through the channels Publisher; an SSE relay subscribes to Redis Pub/Sub
and dispatches to subscribers. Browser auth uses short-lived tickets
issued via /v1/events/token; SDKs use header auth. Admin routes get a
parallel stream that sees all teams. Frontend dashboard and admin
capsule pages subscribe to push state changes instead of polling.

Sandbox event publishing moved out of AuditLogger into the service layer
so callbacks from the host agent and direct state changes share one
path.
2026-05-18 14:05:06 +06:00
62bede5dae fix: resolve bugs and DRY violations in sandbox manager and API handlers
- Fix createFromSnapshot discarding memoryMB param (balloon optimization was dead)
- Fix double dm-snapshot removal in Pause() cleanupPauseFailure path
- Fix DestroySandbox RPC mapping all errors to CodeNotFound
- Fix handleFailed event consumer missing pausing/resuming → error transitions
- Fix stream resource leak in StreamUpload on early-return paths
- Add envs/cwd fields to ExecRequest proto for foreground exec parity
- Extract createResources rollback helper to eliminate 4x duplicated teardown
- Remove unused chClient.ping method
- Add .mcp.json to gitignore

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-17 02:30:32 +06:00
74f85ce4e9 refactor: polish control plane and host agent code
- Decompose executeBuild (318 lines) into provisionBuildSandbox and
  finalizeBuild helpers for readability
- Extract cleanupPauseFailure in sandbox manager to unify 3 inconsistent
  inline teardown paths (also fixes CoW file leak on rename failure)
- Remove unused ctx parameter from startProcess/startProcessForRestore
- Add missing MASQUERADE rollback entry in CreateNetwork for symmetry
- Consolidate duplicate writeJSON for UTF-8/base64 exec response
2026-05-17 02:11:48 +06:00
124e097e23 refactor: eliminate DRY violations across control plane and host agent
Extract shared helpers to consolidate repeated patterns:
- requireRunningSandbox: sandbox lookup + running check (10 call sites)
- upgradeAndAuthenticate: WS upgrade + JWT/API-key auth (3 handlers)
- updateLastActive: last_active_at update with background context (5 sites)
- attachCowAndCreate: cow loop attach + dmsetup create (devicemapper)
- issueRegistrationToken: token gen + Redis + audit (host service)
- ErrNotFound sentinel: replaces string matching in hostagent server

Also merges duplicate wsProcessOut/wsOutMsg types into one.

Net: -208 lines, zero behavior change.
2026-05-17 02:03:06 +06:00
a5425969ed fix: assorted bug fixes for CH migration
Fix resource leaks, race conditions, and error handling across host
agent and control plane: proper sparse file cleanup on close error,
connect error wrapping for MakeDir, CoW file cleanup on pause failure,
per-sandbox VM directories, deferred map deletion to avoid race in VM
destroy, and goroutine launch for extension background workers.
2026-05-17 01:47:56 +06:00
fb16bc9ed1 chore: update proto, scripts, and docs for CH migration
- Update hostagent proto: firecracker_version → vmm_version in metadata
- Regenerate hostagent.pb.go
- Update .env.example: WRENN_FIRECRACKER_BIN → WRENN_CH_BIN
- Update Makefile: remove --isnotfc from dev-envd target
- Update prepare-wrenn-user.sh: firecracker → cloud-hypervisor paths
  and capability assignments
- Update wrenn-init.sh: disable write_zeroes on rootfs for dm-snapshot
  compatibility with CH
- Update README.md and CLAUDE.md: Firecracker → Cloud Hypervisor
  throughout
2026-05-17 01:33:35 +06:00
dd8a940431 feat(envd): update guest agent for Cloud Hypervisor
Remove Firecracker-specific MMDS metadata fetching and metrics host
module. CH communicates with the guest purely over TAP networking,
so MMDS (Firecracker's metadata service via MMDS address) is no longer
needed.

- Remove src/host/ module (mmds.rs, metrics.rs)
- Remove reqwest dependency (was only used for MMDS HTTP calls)
- Remove --isnotfc CLI flag (no longer dual-mode)
- Simplify health endpoint and init handler
- Update state management for CH snapshot lifecycle
- Bump version to 0.3.0
2026-05-17 01:33:25 +06:00
eaa6b8576d feat(vm): replace Firecracker with Cloud Hypervisor
Migrate the entire VM layer from Firecracker to Cloud Hypervisor (CH).
CH provides native snapshot/restore via its HTTP API, eliminating the
need for custom UFFD handling, memfile processing, and snapshot header
management that Firecracker required.

Key changes:
- Remove fc.go, jailer.go (FC process management)
- Remove internal/uffd/ package (userfaultfd lazy page loading)
- Remove snapshot/header.go, mapping.go, memfile.go (FC snapshot format)
- Add ch.go (CH HTTP API client over Unix socket)
- Add process.go (CH process lifecycle with unshare+netns)
- Add chversion.go (CH version detection)
- Refactor sandbox manager: remove UFFD socket tracking, snapshot
  parent/diff chaining, FC-specific balloon logic; add crash watcher
- Simplify snapshot/local.go to CH's native snapshot format
- Update VM config: FirecrackerBin → VMMBin, new CH-specific fields
- Update envdclient, devicemapper, network for CH compatibility
2026-05-17 01:33:12 +06:00
c2dc382787 Updated openapi schema 2026-05-16 18:32:37 +06:00
3671af2498 feat: immediate sandbox reconciliation on host reconnect
When a host transitions from unreachable → online via heartbeat, trigger
ReconcileHost in a background goroutine so "missing" sandboxes are
resolved instantly instead of waiting up to 60s for the next monitor tick.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-16 16:15:49 +06:00
e34bcedc31 Merge pull request 'fix/remove-sync-updates' (#47) from fix/remove-sync-updates into dev
Reviewed-on: wrenn/wrenn#47
2026-05-15 08:08:07 +00:00
ff91ef3edf Bump versions 2026-05-15 13:56:04 +06:00
ba3a3db98c Updated openapi specs 2026-05-15 12:39:06 +06:00
6faad45a28 feat: async sandbox lifecycle with Redis Stream events
Replace synchronous RPC-based CP-host communication for sandbox
lifecycle operations (Create, Pause, Resume, Destroy) with an async
pattern. CP handlers now return 202 Accepted immediately, fire agent
RPCs in background goroutines, and publish state events to a Redis
Stream. A background consumer processes events as a fallback writer.

Agent-side auto-pause events are pushed to the CP via HTTP callback
(POST /v1/hosts/sandbox-events), keeping Redis internal to the CP.

All DB status transitions use conditional updates
(UpdateSandboxStatusIf, UpdateSandboxRunningIf) to prevent race
conditions between concurrent operations and background goroutines.

The HostMonitor reconciler is kept at 60s as a safety net, extended
to handle transient statuses (starting, pausing, resuming, stopping).

Frontend updated to handle 202 responses with empty bodies and render
transient statuses with blue indicators.
2026-05-15 12:25:16 +06:00
c08884fa2c Merge branch 'main' of git.omukk.dev:wrenn/wrenn into dev 2026-05-13 11:05:49 +06:00
4707f16c76 v0.1.6 (#45)
## What's New?
Performance updates for large capsules, admin panel enhancement and bug fixes

### Envd
- Fixed bug with sandbox metrics calculation
- Page cache drop and balloon inflation to reduce memfile snapshot
- Updated rpc timeout logic for better control
- Added tests

### Admin Panel
- Add/Remove platform admin
- Updated template deletion logic for fine grained permission

### Others
- Minor frontend visual improvement
- Minor bugfixes
- Version bump

Co-authored-by: Tasnim Kabir Sadik <tksadik92@gmail.com>
Reviewed-on: wrenn/wrenn#45
Co-authored-by: pptx704 <rafeed@omukk.dev>
Co-committed-by: pptx704 <rafeed@omukk.dev>
2026-05-13 05:05:35 +00:00
6164d7cae3 version bump 2026-05-13 10:58:54 +06:00
dc6776cc8f fix(agent): register with CP before inflating rootfs images 2026-05-13 10:52:22 +06:00
0bfda08f47 Merge pull request 'test (envd): add 136 unit tests across 12 modules' (#44) from testing/envd into dev
Reviewed-on: wrenn/wrenn#44
2026-05-13 04:42:06 +00:00
485be22a16 test(envd): add 136 unit tests across 12 modules
Cover all pure-function modules with inline #[cfg(test)] blocks:
crypto (NIST/RFC 4231 known-answer vectors), auth (SecureToken ops,
signature generation/validation), conntracker (snapshot lifecycle),
execcontext, util (AtomicMax concurrent correctness), http/encoding
(RFC 7231 negotiation), port/conn (/proc/net/tcp parsing),
rpc/entry (format_permissions), and permissions/path (tilde expansion,
ensure_dirs). Add tempfile dev-dep for filesystem tests. Update
Makefile test target to include cargo test.
2026-05-13 10:39:54 +06:00
ead406bdac Merge pull request 'fix: resolve large operation reliability — stream hangs, pause races, and memory bloat' (#43) from fix/large-operations into dev
Reviewed-on: wrenn/wrenn#43
2026-05-13 03:44:41 +00:00
1472d77b52 Merge branch 'dev' into fix/large-operations 2026-05-13 03:44:19 +00:00
6a0fea30a6 Rootfs script updated 2026-05-13 09:35:06 +06:00
8c34388fc2 Changed commands to check if envd is statically linked or not 2026-05-12 23:19:30 +06:00
aca43d51eb fix: resolve process stream hangs, pause race, and PTY signal loss
- Cache terminal EndEvent on ProcessHandle so connect() can detect
  already-exited processes instead of hanging forever on broadcast
  receivers that missed the event. Subscribe before checking cache
  to close the TOCTOU window.

- Protect sb.Status writes in Pause with m.mu to prevent data race
  with concurrent readers (AcquireProxyConn, Exec, etc.).

- Restart metrics sampler in restoreRunning so a failed pause attempt
  doesn't permanently kill sandbox metrics collection.

- Return dequeued non-input messages from coalescePtyInput instead of
  dropping them, preventing silent loss of kill/resize signals during
  typing bursts.
2026-05-09 18:11:15 +06:00
522e1c5e90 fix: subscribe to process channels before spawning threads to prevent event loss
Fast-exiting processes (e.g. echo) sent data/end events before
start() subscribed to the broadcast channels, causing the stream
to hang indefinitely and the exec RPC to time out with 502.

Move channel subscription into spawn_process, before reader/waiter
threads start, and return pre-subscribed receivers via SpawnedProcess.
2026-05-09 17:28:37 +06:00
d1d316f35c fix: resolve exec 502 by terminating process streams on exit
The start() and connect() streaming RPCs blocked forever in the data
event loop because ProcessHandle retains a broadcast sender (needed for
reconnection via connect()), preventing the channel from closing.

Race data_rx against end_rx with tokio::select! so the stream terminates
when the process exits. Remaining buffered data is drained before
yielding the end event.
2026-05-09 16:36:33 +06:00
2af8412cdc fix: use RwLock for envd Defaults to fix silent mutation loss
The /init handler's default_user mutation cloned the Defaults struct,
mutated the clone, then dropped it — the actual state was never updated.
This caused processes to always run as "root" regardless of the user
set via POST /init. Additionally, default_workdir was accepted in the
init request but never applied.

Wrap user and workdir fields in RwLock with accessor methods so mutations
propagate correctly through the shared AppState.
2026-05-09 15:28:09 +06:00
c93ad5e2db fix: harden pause flow with connection isolation and UFFD event handling
Restructure pause to: block new operations (StatusPausing), drain proxy
connections with 5s grace, force-close remaining via context cancellation,
drop page cache, inflate balloon, then freeze vCPUs. Previously connections
could arrive during the pause window and API operations weren't blocked.

Handle UFFD_EVENT_REMOVE/UNMAP/REMAP/FORK gracefully instead of crashing
the UFFD server. These events fire during balloon deflation on snapshot
restore, killing the page fault handler and preventing VM boot.

Also adds ConnTracker.ForceClose() with cancellable context propagated
through the proxy handler, so lingering proxy connections are actively
terminated rather than left dangling.
2026-05-09 14:51:19 +06:00
38799770db fix: inflate balloon before snapshot to reduce memfile size
Firecracker dumps the entire VM memory region regardless of guest
usage. A 20GB VM using 500MB still produces a ~20GB memfile because
freed pages retain stale data (non-zero blocks).

Inflate the balloon device before snapshot to reclaim free guest
memory. Balloon pages become zero from FC's perspective, allowing
ProcessMemfile to skip them. This reduces memfile size from ~20GB
to ~1-2GB for lightly-used VMs.

- Pause: read guest memory usage, inflate balloon to reclaim free
  pages, wait 2s for guest kernel to process, then proceed
- Resume: deflate balloon to 0 after PostInit so guest gets full
  memory back
- createFromSnapshot: same deflation since template snapshots
  inherit inflated balloon state
- All balloon ops are best-effort with debug logging on failure
2026-05-05 15:38:04 +06:00
51b5d7b3ba fix: resolve pause/snapshot failures and CoW exhaustion on large VMs
Remove hard 10s timeout from Firecracker HTTP client — callers already
pass context.Context with appropriate deadlines, and 20GB+ memfile
writes easily exceed 10s.

Ensure CoW file is at least as large as the origin rootfs. Previously,
WRENN_DEFAULT_ROOTFS_SIZE=30Gi expanded the base image to 30GB but the
default 5GB CoW could not hold all writes, causing dm-snapshot
invalidation and EIO on all guest I/O.

Destroy frozen VMs in resumeOnError instead of leaving zombies that
report "running" but can't execute. Use fresh context for the resume
attempt so a cancelled caller context doesn't falsely trigger destroy.

Increase CP→Agent ResponseHeaderTimeout from 45s to 5min and
PrepareSnapshot timeout from 3s to 30s for large-memory VMs.

After failed pause, ping agent to detect destroyed sandboxes and mark
DB status as "error" instead of reverting to "running".
2026-05-04 01:46:57 +06:00
fd5fa28205 Merge pull request 'Enhanced frontend ux' (#42) from enhance/frontend into dev
Reviewed-on: wrenn/wrenn#42
2026-05-03 11:08:48 +00:00
1244c08e42 fix: fetch sandbox metrics immediately on page load
Metrics data was only fetched after Chart.js dynamic import completed,
leaving graphs empty until the first poll interval fired. Now
loadMetrics() runs in parallel with the Chart.js import, and
initCharts() resets the dedup key so pre-fetched data populates
newly created chart instances.
2026-05-03 16:43:26 +06:00
021d709de2 feat: show template owner and restrict delete in admin panel
Add Owner column to admin templates table, resolving team IDs to names
via admin teams API. Disable delete for non-platform templates and the
minimal template, with contextual tooltips explaining why.
2026-05-03 15:51:20 +06:00
cac6fcd626 feat: admin grant/revoke from admin panel
Add PUT /v1/admin/users/{id}/admin endpoint and frontend UI for
granting and revoking platform admin status. Uses atomic conditional
SQL (RevokeUserAdmin) to prevent race conditions that could remove
the last admin. Includes idempotency check, audit logging, and
confirmation dialog with self-demotion warning.
2026-05-03 15:24:34 +06:00
4954b19d7c fix: merge capsule data in-place to prevent visual refresh on poll
Replaces full array assignment with granular merge that reuses existing
Svelte proxy objects, so only rows with actual data changes re-render.
2026-05-03 15:09:21 +06:00
01819642cc fix: drop page cache before snapshot to reduce memory dump size
Linux keeps freed memory as page cache, which Firecracker snapshots
as non-zero blocks. A 16GB VM with 12GB stale cache would write all
12GB to disk. Dropping pagecache (not dentries/inodes) in
/snapshot/prepare before blocking the reclaimer shrinks snapshots
to actual working set size with minimal resume latency impact.
2026-05-03 14:27:49 +06:00
cb28f7759d Merge pull request 'fix: accurate sandbox metrics and memory management' (#41) from bugfix/sandbox-metrics-calculations into dev
Reviewed-on: wrenn/wrenn#41
2026-05-03 06:41:41 +00:00
1178ab8b21 fix: accurate sandbox metrics and memory management
Three issues fixed:

1. Memory metrics read host-side VmRSS of the Firecracker process,
   which includes guest page cache and never decreases. Replaced
   readMemRSS(fcPID) with readEnvdMemUsed(client) that queries
   envd's /metrics endpoint for guest-side total - MemAvailable.
   This matches neofetch and reflects actual process memory.

2. Added Firecracker balloon device (deflate_on_oom, 5s stats) and
   envd-side periodic page cache reclaimer (drop_caches when >80%
   used). Reclaimer is gated by snapshot_in_progress flag with
   sync() before freeze to prevent memory corruption during pause.

3. Sampling interval 500ms → 1s, ring buffer capacities adjusted
   to maintain same time windows. Reduces per-host HTTP load from
   240 calls/sec to 120 calls/sec at 120 capsules.

Also: maxDiffGenerations 8 → 1 (merge every re-pause since UFFD
lazy-loads anyway), envd mem_used formula uses total - available.
2026-05-03 12:19:01 +06:00
233e747d5d Merge branch 'main' of git.omukk.dev:wrenn/wrenn into dev 2026-05-03 04:56:14 +06:00
f5a23c1fa0 v0.1.5 (#40)
Reviewed-on: wrenn/wrenn#40
2026-05-02 22:56:00 +00:00
20a228eb8d Merge pull request 'Rewritten envd with rust to improve reliability during pause and resume operations' (#39) from feat/envd-rewrite into dev
Reviewed-on: wrenn/wrenn#39
2026-05-02 22:49:36 +00:00
ef5f223863 fix: improve error feedback for terminal disconnects and host unavailability
Show "[session disconnected]" in terminal when PTY websocket closes cleanly.
Map scheduler and agent unavailability errors to 503 with user-friendly
message instead of leaking internal details.
2026-05-03 04:47:10 +06:00
31456fd169 fix: resolve PTY failure, MMDS file writes, and metrics instability in envd-rs
Three bugs fixed:

1. PTY connections failed because home directory was hardcoded as
   /home/{username} instead of reading from /etc/passwd. For root,
   this produced /home/root/ which doesn't exist — CWD validation
   rejected every PTY Start request without explicit cwd. Fixed all
   6 locations to use user.dir from nix::unistd::User.

2. MMDS polling silently failed to parse metadata because the
   logs_collector_address field lacked #[serde(default)]. The host
   agent only sends instanceID + envID — missing "address" field
   caused every deserialize attempt to fail, so .WRENN_SANDBOX_ID
   and .WRENN_TEMPLATE_ID were never written. Also added error
   logging and create_dir_all before file writes.

3. Metrics CPU values were non-deterministic because a fresh
   sysinfo::System was created per request with a 100ms sleep
   between reads. Replaced with a background thread that samples
   CPU at fixed 1-second intervals via a persistent System instance,
   matching gopsutil's internal caching behavior. Metrics endpoint
   now reads cached atomic values — no blocking, consistent window.

Also: close master PTY fd in child pre_exec, add process.Start
request logging, bump version to 0.2.0.
2026-05-03 04:28:10 +06:00
bbcde17d49 Updated static link check for envd 2026-05-03 03:32:41 +06:00
f328113a2a rename guest hostname from "sandbox" to "capsule"
Terminal prompt inside VMs now shows root@capsule instead of
root@sandbox, aligning with user-facing "capsule" terminology.
2026-05-03 03:32:03 +06:00
1143acd37a refactor: remove Go envd module, update host agent for Rust envd
The Go envd guest agent (`envd/`) is fully replaced by the Rust
implementation (`envd-rs/`). This commit removes the Go module and
updates all references across the codebase.

Makefile: remove ENVD_DIR, VERSION_ENVD, build-envd-go, dev-envd-go,
and Go envd from proto/fmt/vet/tidy/clean targets. Add static-link
verification to build-envd.

Host agent: rewrite snapshot quiesce comments that referenced Go GC
and page allocator corruption — no longer applicable with Rust envd.
Tighten envdclient to expect HTTP 200 (not 204) from health and file
upload endpoints, and require JSON version response from FetchVersion.

Remove NOTICE (no e2b-derived code remains). Update CLAUDE.md and
README.md to reflect Rust envd architecture.
2026-05-03 03:12:25 +06:00
0b53d34417 feat: rewrite envd guest agent in Rust (envd-rs)
Complete Rust rewrite of the Go envd guest daemon that runs as PID 1
inside Firecracker microVMs. Feature-complete across all 8 phases:

- Health, metrics, and env var endpoints
- Crypto (SHA-256/512, HMAC), auth (secure token, signing), init/snapshot
- Connect RPC via connectrpc + buffa (process + filesystem services)
- File transfer (GET/POST /files) with gzip, multipart, chown, ENOSPC
- Port subsystem (/proc/net/tcp scanner, socat forwarder)
- Cgroup2 manager with noop fallback
- Snapshot/restore lifecycle (conntracker, port subsystem stop/restart)
- SIGTERM graceful shutdown, --cmd initial process spawn
- MMDS metadata polling for Firecracker mode

42 source files, ~4200 LOC, 4.1MB stripped release binary.
Makefile updated: build-envd now targets Rust (musl static),
build-envd-go preserved for Go builds.
2026-05-03 02:47:15 +06:00
3deecbff89 fix: prevent Go runtime memory corruption and sandbox halt after snapshot restore
Three root causes addressed:

1. Go page allocator corruption: allocations between the pre-snapshot GC
   and VM freeze leave the summary tree inconsistent. After restore, GC
   reads corrupted metadata — either panicking (killing PID 1 → kernel
   panic) or silently failing to collect, causing unbounded heap growth
   until OOM. Fix: move GC to after all HTTP allocations in
   PostSnapshotPrepare, then set GOMAXPROCS(1) so any remaining
   allocations run sequentially with no concurrent page allocator access.
   GOMAXPROCS is restored on first health check after restore.

2. PostInit timeout starvation: WaitUntilReady and PostInit shared a
   single 30s context. If WaitUntilReady consumed most of it, PostInit
   failed — RestoreAfterSnapshot never ran, leaving envd with keep-alives
   disabled and zombie connections. Fix: separate timeout contexts.

3. CP HTTP server missing timeouts: no ReadHeaderTimeout or IdleTimeout
   caused goroutine leaks from hung proxy connections. Fix: add both,
   matching host agent values.

Also adds UFFD prefetch to proactively load all guest pages after restore,
eliminating on-demand page fault latency for subsequent RPC calls.
2026-05-02 17:22:51 +06:00
bb582deefa fix: prevent sandbox halt after resume by fixing HTTP/2 HOL blocking and adding timeouts
Disable HTTP/2 on both host agent server and CP→agent transport — multiplexing
caused head-of-line blocking when a slow sandbox RPC stalled the shared connection.
Add ResponseHeaderTimeout to envd HTTP clients. Merge SetDefaults into Resume's
PostInit call to eliminate an extra round-trip that could hang on a stale connection.
2026-05-02 13:48:51 +06:00
7ef9a64613 fix: close stale TCP connections across snapshot/restore to prevent envd hangs
After Firecracker snapshot restore, zombie TCP sockets from the previous
session cause Go runtime corruption inside the guest VM, making envd
unresponsive. This manifests as infinite loading in the file browser and
terminal timeouts (524) in production (HTTP/2 + Cloudflare) but not locally.

Four-part fix:
- Add ServerConnTracker to envd that tracks connections via ConnState callback,
  closes idle connections and disables keep-alives before snapshot, then closes
  all pre-snapshot zombie connections on restore (while preserving post-restore
  connections like the /init request)
- Split envdclient into timeout (2min) and streaming (no timeout) HTTP clients;
  use streaming client for file transfers and process RPCs
- Close host-side idle envdclient connections before PrepareSnapshot so FIN
  packets propagate during the 3s quiesce window
- Add StreamingHTTPClient() accessor; streaming file transfer handlers in
  hostagent use it instead of the timeout client
2026-05-02 05:19:37 +06:00
f3572f7356 Fix empty WRENN_TEMPLATE_ID after resuming paused sandbox
Resume() was building VMConfig without TemplateID, so Firecracker MMDS
received an empty string. envd's PostInit then wrote that empty value to
/run/wrenn/.WRENN_TEMPLATE_ID. Fix by persisting the template ID in
snapshot metadata during Pause and reading it back during Resume.
2026-05-02 04:57:08 +06:00
339 changed files with 23087 additions and 28684 deletions

View File

@ -16,7 +16,7 @@ WRENN_HOST_LISTEN_ADDR=:50051
WRENN_HOST_INTERFACE=eth0
WRENN_CP_URL=http://localhost:9725
WRENN_DEFAULT_ROOTFS_SIZE=5Gi
WRENN_FIRECRACKER_BIN=/usr/local/bin/firecracker
WRENN_CH_BIN=/usr/local/bin/cloud-hypervisor
# Auth
JWT_SECRET=

9
.gitignore vendored
View File

@ -36,9 +36,13 @@ go.work.sum
e2b/
.impeccable.md
.gstack
.mcp.json
## Builds
builds/
/builds/
## Rust
envd-rs/target/
## Frontend
frontend/node_modules/
@ -49,3 +53,6 @@ frontend/build/
internal/dashboard/static/*
!internal/dashboard/static/.gitkeep.dual-graph/
.dual-graph/
# Added by code-review-graph
.code-review-graph/
.mcp.json

160
CLAUDE.md
View File

@ -4,7 +4,7 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
## Project Overview
Wrenn Sandbox is a microVM-based code execution platform. Users create isolated sandboxes (Firecracker microVMs), run code inside them, and get output back via SDKs. Think E2B but with persistent sandboxes, pool-based pricing, and a single-binary deployment story.
Wrenn Sandbox is a microVM-based code execution platform. Users create isolated sandboxes (Cloud Hypervisor microVMs), run code inside them, and get output back via SDKs. Think E2B but with persistent sandboxes, pool-based pricing, and a single-binary deployment story.
## Build & Development Commands
@ -14,7 +14,7 @@ All commands go through the Makefile. Never use raw `go build` or `go run`.
make build # Build all binaries → builds/
make build-cp # Control plane only
make build-agent # Host agent only
make build-envd # envd static binary (verified statically linked)
make build-envd # envd static binary (Rust, musl, verified statically linked)
make build-frontend # SvelteKit dashboard → frontend/build/ (served by Caddy)
make dev # Full local dev: infra + migrate + control plane
@ -23,13 +23,13 @@ make dev-down # Stop dev infra
make dev-cp # Control plane with hot reload (if air installed)
make dev-frontend # Vite dev server with HMR (port 5173)
make dev-agent # Host agent (sudo required)
make dev-envd # envd in TCP debug mode
make dev-envd # envd in debug mode (port 49983)
make check # fmt + vet + lint + test (CI order)
make test # Unit tests: go test -race -v ./internal/...
make test-integration # Integration tests (require host agent + Firecracker)
make fmt # gofmt both modules
make vet # go vet both modules
make test-integration # Integration tests (require host agent + Cloud Hypervisor)
make fmt # gofmt
make vet # go vet
make lint # golangci-lint
make migrate-up # Apply pending migrations
@ -38,8 +38,8 @@ make migrate-create name=xxx # Scaffold new goose migration (never create manua
make migrate-reset # Drop + re-apply all
make generate # Proto (buf) + sqlc codegen
make proto # buf generate for all proto dirs
make tidy # go mod tidy both modules
make proto # buf generate for proto dirs
make tidy # go mod tidy
```
Run a single test: `go test -race -v -run TestName ./internal/path/...`
@ -50,15 +50,15 @@ Run a single test: `go test -race -v -run TestName ./internal/path/...`
User SDK → HTTPS/WS → Control Plane → Connect RPC → Host Agent → HTTP/Connect RPC over TAP → envd (inside VM)
```
**Three binaries, two Go modules:**
**Three binaries:**
| Binary | Module | Entry point | Runs as |
|--------|--------|-------------|---------|
| wrenn-cp | `git.omukk.dev/wrenn/wrenn` | `cmd/control-plane/main.go` | Unprivileged |
| wrenn-agent | `git.omukk.dev/wrenn/wrenn` | `cmd/host-agent/main.go` | `wrenn` user with capabilities (SYS_ADMIN, NET_ADMIN, NET_RAW, SYS_PTRACE, KILL, DAC_OVERRIDE, MKNOD) via setcap; also accepts root |
| envd | `git.omukk.dev/wrenn/wrenn/envd` (standalone `envd/go.mod`) | `envd/main.go` | PID 1 inside guest VM |
| Binary | Language | Entry point | Runs as |
|--------|----------|-------------|---------|
| wrenn-cp | Go (`git.omukk.dev/wrenn/wrenn`) | `cmd/control-plane/main.go` | Unprivileged |
| wrenn-agent | Go (`git.omukk.dev/wrenn/wrenn`) | `cmd/host-agent/main.go` | `wrenn` user with capabilities (SYS_ADMIN, NET_ADMIN, NET_RAW, SYS_PTRACE, KILL, DAC_OVERRIDE, MKNOD) via setcap; also accepts root |
| envd | Rust (`envd-rs/`) | `envd-rs/src/main.rs` | PID 1 inside guest VM |
envd is a **completely independent Go module**. It is never imported by the main module. The only connection is the protobuf contract. It compiles to a static binary baked into rootfs images.
envd is a standalone Rust binary (Tokio + Axum + connectrpc-rs). It is completely independent from the Go module — the only connection is the protobuf contract. It compiles to a statically linked musl binary baked into rootfs images.
**Key architectural invariant:** The host agent is **stateful** (in-memory `boxes` map is the source of truth for running VMs). The control plane is **stateless** (all persistent state in PostgreSQL). The reconciler (`internal/api/reconciler.go`) bridges the gap — it periodically compares DB records against the host agent's live state and marks orphaned sandboxes as "stopped".
@ -66,11 +66,22 @@ envd is a **completely independent Go module**. It is never imported by the main
**Internal packages:** `internal/api/`, `internal/email/`
**Public packages (importable by cloud repo):** `pkg/config/`, `pkg/db/`, `pkg/auth/`, `pkg/auth/oauth/`, `pkg/scheduler/`, `pkg/lifecycle/`, `pkg/channels/`, `pkg/audit/`, `pkg/service/`, `pkg/events/`, `pkg/id/`, `pkg/validate/`
**Public packages (importable by cloud repo):** `pkg/config/`, `pkg/db/`, `pkg/auth/`, `pkg/auth/oauth/`, `pkg/auth/session/`, `pkg/auth/session/middleware/`, `pkg/scheduler/`, `pkg/lifecycle/`, `pkg/channels/`, `pkg/audit/`, `pkg/service/`, `pkg/events/`, `pkg/id/`, `pkg/validate/`
**Extension framework:** `pkg/cpextension/` (shared `Extension` interface + `ServerContext`), `pkg/cpserver/` (exported `Run()` entrypoint with functional options for cloud `main.go`)
**Extension framework:** `pkg/cpextension/` (shared `Extension` interface + `ServerContext` + hook interfaces), `pkg/cpserver/` (exported `Run()` entrypoint with functional options for cloud `main.go`)
The cloud repo imports this module as a Go dependency and calls `cpserver.Run(cpserver.WithExtensions(myExt))`. Each extension implements two methods: `RegisterRoutes(r chi.Router, sctx ServerContext)` to add HTTP routes, and `BackgroundWorkers(sctx ServerContext) []func(context.Context)` to add long-running goroutines. `ServerContext` carries all OSS services (DB, scheduler, auth, etc.) so extensions can use them without reimplementing anything. To expose a new OSS service to extensions, add it to `ServerContext` in `pkg/cpextension/extension.go` and populate it in `pkg/cpserver/run.go`.
The cloud repo imports this module as a Go dependency and calls `cpserver.Run(cpserver.WithExtensions(myExt))`. An extension always implements:
- `RegisterRoutes(r chi.Router, sctx ServerContext)` — adds HTTP routes.
- `BackgroundWorkers(sctx ServerContext) []func(context.Context)` — starts long-running goroutines.
It can optionally implement any of these hook interfaces (the OSS server type-asserts at startup):
- `MiddlewareProvider``Middlewares(sctx) []func(http.Handler) http.Handler`, applied before OSS routes so cloud middleware can wrap them (e.g. billing gates).
- `AuthHook``OnSignup` (synchronous, error aborts the request with 500 `signup_hook_failed`), `OnLogin`, `OnAccountSoftDelete`, `OnAccountHardDelete` (the last three log + ignore errors). OnSignup fires after team provisioning in both email-activate and OAuth-new-signup paths.
- `SandboxEventHook``OnSandboxEvent(ctx, SandboxEvent)`, invoked from the unified Redis stream consumer for capsule create/pause/resume/destroy success events. Hook errors leave the message un-acked so it will be redelivered; hooks must be idempotent.
`ServerContext` carries the initialized OSS dependencies: `Queries`, `PgPool`, `Redis`, `HostPool`, `Scheduler`, `CA`, `Audit`, `Mailer`, `OAuthRegistry`, `Channels`, `ChannelPub`, `JWTSecret`, `Sessions`, `Config`. To expose a new OSS service to extensions, add it to `ServerContext` in `pkg/cpextension/extension.go` and populate it in `pkg/cpserver/run.go`.
**Auth helpers for extensions** (`pkg/auth/session/middleware/`, re-exported via `cpextension`): `RequireSession(sctx)`, `RequireSessionOrAPIKey(sctx)`, `RequireAdmin(sctx)`, `RequireCSRF()`, `IssueSession(w, r, sctx, userID, teamID)`, `ClearSessionCookies(w, r)`. Cookie/header names are exported as `SessionCookieName`, `CSRFCookieName`, `CSRFHeaderName`. OSS handlers (`internal/api/middleware_session.go`) are thin shims over this package — single source of truth.
**pkg/ vs internal/ decision rule:** A package belongs in `pkg/` only if the cloud repo needs to import it directly. Everything else stays in `internal/`. New OSS services (e.g. email, notifications) go in `internal/` — the cloud repo accesses them through `ServerContext`, not by importing the package. Do not put a service in `pkg/` just because the cloud repo uses it.
@ -86,33 +97,37 @@ Startup (`cmd/control-plane/main.go`) is a thin wrapper: `cpserver.Run(cpserver.
**Packages:** `internal/hostagent/`, `internal/sandbox/`, `internal/vm/`, `internal/network/`, `internal/devicemapper/`, `internal/envdclient/`, `internal/snapshot/`
**Production deployment:** `scripts/prepare-wrenn-user.sh` creates the `wrenn` system user, sets Linux capabilities (setcap) on wrenn-agent and all child binaries (iptables, losetup, dmsetup, etc.), installs an apt hook to restore capabilities after package updates, configures udev rules for `/dev/net/tun`, loads required kernel modules, and writes systemd unit files for both services. No sudo grants — all privilege is via capabilities.
**Production deployment:** `make setup-host` (→ `scripts/setup-host.sh`) prepares the host: creates the `wrenn` system user, sets Linux capabilities (setcap) on wrenn-agent and all child binaries (iptables, losetup, dmsetup, etc.), installs an apt hook to restore capabilities after package updates, configures udev rules for `/dev/net/tun`, and loads required kernel modules. No sudo grants — all privilege is via capabilities. `make install` then copies the binaries to `/usr/local/bin` and installs the systemd units from `deploy/systemd/`.
Startup (`cmd/host-agent/main.go`) wires: root/capabilities check → enable IP forwarding → clean up stale dm devices → `sandbox.Manager` (containing `vm.Manager` + `network.SlotAllocator` + `devicemapper.LoopRegistry`) → `hostagent.Server` (Connect RPC handler) → HTTP server.
- **RPC Server** (`internal/hostagent/server.go`): implements `hostagentv1connect.HostAgentServiceHandler`. Thin wrapper — every method delegates to `sandbox.Manager`. Maps Connect error codes on return.
- **Sandbox Manager** (`internal/sandbox/manager.go`): the core orchestration layer. Maintains in-memory state in `boxes map[string]*sandboxState` (protected by `sync.RWMutex`). Each `sandboxState` holds a `models.Sandbox`, a `*network.Slot`, and an `*envdclient.Client`. Runs a TTL reaper (every 10s) that auto-destroys timed-out sandboxes.
- **VM Manager** (`internal/vm/manager.go`, `fc.go`, `config.go`): manages Firecracker processes. Uses raw HTTP API over Unix socket (`/tmp/fc-{sandboxID}.sock`), not the firecracker-go-sdk Machine type. Launches Firecracker via `unshare -m` + `ip netns exec`. Configures VM via PUT to `/boot-source`, `/drives/rootfs`, `/network-interfaces/eth0`, `/machine-config`, then starts with PUT `/actions`.
- **VM Manager** (`internal/vm/manager.go`, `ch.go`, `config.go`): manages Cloud Hypervisor processes. Uses raw HTTP API over Unix socket (`/tmp/ch-{sandboxID}.sock`). Launches Cloud Hypervisor via `unshare -m` + `ip netns exec` with `--api-socket path=...`. Configures and boots VM via `PUT /vm.create` + `PUT /vm.boot`. Snapshot restore uses `--restore source_url=file://...`.
- **Network** (`internal/network/setup.go`, `allocator.go`): per-sandbox network namespace with veth pair + TAP device. See Networking section below.
- **Device Mapper** (`internal/devicemapper/devicemapper.go`): CoW rootfs via device-mapper snapshots. Shared read-only loop devices per base template (refcounted `LoopRegistry`), per-sandbox sparse CoW files, dm-snapshot create/restore/remove/flatten operations.
- **envd Client** (`internal/envdclient/client.go`, `health.go`): dual interface to the guest agent. Connect RPC for streaming process exec (`process.Start()` bidirectional stream). Plain HTTP for file operations (POST/GET `/files?path=...&username=root`). Health check polls `GET /health` every 100ms until ready (30s timeout).
### envd (Guest Agent)
**Module:** `envd/` with its own `go.mod` (`git.omukk.dev/wrenn/wrenn/envd`)
**Directory:** `envd-rs/` — standalone Rust crate
Runs as PID 1 inside the microVM via `wrenn-init.sh` (mounts procfs/sysfs/dev, sets hostname, writes resolv.conf, then execs envd). Extracted from E2B (Apache 2.0), with shared packages internalized into `envd/internal/shared/`. Listens on TCP `0.0.0.0:49983`.
Runs as PID 1 inside the microVM via `wrenn-init.sh` (mounts procfs/sysfs/dev, sets hostname, writes resolv.conf, then execs envd via tini). Built with `cargo build --release --target x86_64-unknown-linux-musl`. Listens on TCP `0.0.0.0:49983`.
- **ProcessService**: start processes, stream stdout/stderr, signal handling, PTY support
- **FilesystemService**: stat/list/mkdir/move/remove/watch files
- **Health**: GET `/health`
- **Stack**: Tokio (async runtime) + Axum (HTTP) + connectrpc-rs (Connect protocol RPC)
- **ProcessService** (Connect RPC): start/connect/list/signal processes, stream stdout/stderr, PTY support
- **FilesystemService** (Connect RPC): stat/list/mkdir/move/remove/watch files
- **HTTP endpoints**: GET `/health`, GET `/metrics`, POST `/init`, POST `/snapshot/prepare`, GET/POST `/files`
- **Proto codegen**: `connectrpc-build` compiles `proto/envd/*.proto` at `cargo build` time via `build.rs` — no committed stubs
- **Build**: `make build-envd` → static musl binary in `builds/envd`
- **Dev**: `make dev-envd``cargo run -- --port 49983`
### Dashboard (Frontend)
**Directory:** `frontend/` — standalone SvelteKit app (Svelte 5, runes mode)
- **Stack**: SvelteKit + `adapter-static` + Tailwind CSS v4 + Bits UI (headless accessible components)
- **Package manager**: pnpm
- **Package manager**: Bun
- **Routing**: SvelteKit file-based routing under `frontend/src/routes/`
- **Routing layout**: `/login` and `/signup` at root, authenticated pages under `/dashboard/*` (e.g. `/dashboard/capsules`, `/dashboard/keys`)
- **Build output**: `frontend/build/` — static files served by Caddy
@ -160,7 +175,7 @@ HIBERNATED → RUNNING (cold snapshot resume, slower)
**Sandbox creation** (`POST /v1/capsules`):
1. API handler generates sandbox ID, inserts into DB as "pending"
2. RPC `CreateSandbox` → host agent → `sandbox.Manager.Create()`
3. Manager: resolve base rootfs → acquire shared loop device → create dm-snapshot (sparse CoW file) → allocate network slot → `CreateNetwork()` (netns + veth + tap + NAT) → `vm.Create()` (start Firecracker with `/dev/mapper/wrenn-{id}`, configure via HTTP API, boot) → `envdclient.WaitUntilReady()` (poll /health) → store in-memory state
3. Manager: resolve base rootfs → acquire shared loop device → create dm-snapshot (sparse CoW file) → allocate network slot → `CreateNetwork()` (netns + veth + tap + NAT) → `vm.Create()` (start Cloud Hypervisor with `/dev/mapper/wrenn-{id}`, configure via `PUT /vm.create` + `PUT /vm.boot`) → `envdclient.WaitUntilReady()` (poll /health) → store in-memory state
4. API handler updates DB to "running" with host_ip
**Command execution** (`POST /v1/capsules/{id}/exec`):
@ -179,23 +194,40 @@ HIBERNATED → RUNNING (cold snapshot resume, slower)
## REST API
Routes defined in `internal/api/server.go`, handlers in `internal/api/handlers_*.go`. OpenAPI spec embedded via `//go:embed` and served at `/openapi.yaml` (Swagger UI at `/docs`). JSON request/response. API key auth via `X-API-Key` header. Error responses: `{"error": {"code": "...", "message": "..."}}`.
Routes defined in `internal/api/server.go`, handlers in `internal/api/handlers_*.go`. OpenAPI spec embedded via `//go:embed` and served at `/openapi.yaml` (Swagger UI at `/docs`). JSON request/response. Error responses: `{"error": {"code": "...", "message": "..."}}`.
### Authentication
Two paths, no JWTs for user auth:
- **SDK / server-to-server**: `X-API-Key: wrn_<32hex>` header. Keys are created via the dashboard or `POST /v1/api-keys`, SHA-256-hashed at rest, scoped to a single team. `requireSessionOrAPIKey` middleware accepts this on every capsule lifecycle route.
- **Browser (dashboard)**: opaque cookie session. `POST /v1/auth/login` (or activate / oauth-callback) sets two cookies — `wrenn_sid` (HttpOnly+Secure+SameSite=Strict) and `wrenn_csrf` (readable by JS, SameSite=Strict). All non-GET requests must echo the CSRF token in an `X-CSRF-Token` header (double-submit). Sessions live in Postgres `sessions` plus a Redis cache (`wrenn:session:{sid}`) — see `pkg/auth/session/`.
Session semantics:
- **Idle**: 6h. Each request bumps the Redis TTL. Postgres `last_seen_at` is updated on a debounce.
- **Absolute**: 24h from `created_at` (`expires_at`). Never extended.
- **Rotation**: `POST /v1/auth/switch-team` issues a fresh SID and re-sets both cookies; old SID revoked.
- **Revocation**: password change, password add, and password reset all call `RevokeAllForUser`. Self-service is at `GET /v1/me/sessions`, `DELETE /v1/me/sessions/{id}`, `POST /v1/auth/logout`, `POST /v1/auth/logout-all`.
- **`is_admin` freshness**: the session blob caches it for display, but `requireAdmin` always re-reads Postgres so revoked admins lose access at the next admin request.
Host JWTs (long-lived, signed by `JWT_SECRET`) are unchanged — that is the wrenn-cp ↔ wrenn-agent trust channel and has nothing to do with user auth.
SSE / WebSocket auth: browsers send the `wrenn_sid` cookie automatically on `EventSource` and WS upgrades (same-origin via Caddy); SDKs set `X-API-Key`. No ticket exchange is involved.
## Code Generation
### Proto (Connect RPC)
Proto source of truth is `proto/envd/*.proto` and `proto/hostagent/*.proto`. Run `make proto` to regenerate. Three `buf.gen.yaml` files control output:
Proto source of truth is `proto/envd/*.proto` and `proto/hostagent/*.proto`. Run `make proto` to regenerate Go stubs. Two `buf.gen.yaml` files control Go output:
| buf.gen.yaml location | Generates to | Used by |
|---|---|---|
| `proto/envd/buf.gen.yaml` | `proto/envd/gen/` | Main module (host agent's envd client) |
| `proto/hostagent/buf.gen.yaml` | `proto/hostagent/gen/` | Main module (control plane ↔ host agent) |
| `envd/spec/buf.gen.yaml` | `envd/internal/services/spec/` | envd module (guest agent server) |
The envd `buf.gen.yaml` reads from `../../proto/envd/` (same source protos) but generates into envd's own module. This means the same `.proto` files produce two independent sets of Go stubs — one for each Go module.
The Rust envd (`envd-rs/`) generates its own protobuf stubs at `cargo build` time via `connectrpc-build` in `envd-rs/build.rs`, reading from the same `proto/envd/*.proto` sources. No committed Rust stubs — they live in `OUT_DIR`.
To add a new RPC method: edit the `.proto` file → `make proto` → implement the handler on both sides.
To add a new RPC method: edit the `.proto` file → `make proto` (Go stubs) → rebuild envd-rs (Rust stubs generated automatically) → implement the handler on both sides.
### sqlc
@ -206,10 +238,10 @@ To add a new query: add it to the appropriate `.sql` file in `db/queries/` → `
## Key Technical Decisions
- **Connect RPC** (not gRPC) for all RPC communication between components
- **Buf + protoc-gen-connect-go** for code generation (not protoc-gen-go-grpc)
- **Raw Firecracker HTTP API** via Unix socket (not firecracker-go-sdk Machine type)
- **Buf + protoc-gen-connect-go** for Go code generation; **connectrpc-build** for Rust code generation in envd
- **Raw Cloud Hypervisor HTTP API** via Unix socket (`PUT /vm.create` + `PUT /vm.boot`)
- **TAP networking** (not vsock) for host-to-envd communication
- **Device-mapper snapshots** for rootfs CoW — shared read-only loop device per base template, per-sandbox sparse CoW file, Firecracker gets `/dev/mapper/wrenn-{id}`
- **Device-mapper snapshots** for rootfs CoW — shared read-only loop device per base template, per-sandbox sparse CoW file, Cloud Hypervisor gets `/dev/mapper/wrenn-{id}`
- **PostgreSQL** via pgx/v5 + sqlc (type-safe query generation). Goose for migrations (plain SQL, up/down)
- **Dashboard**: SvelteKit (Svelte 5, adapter-static) + Tailwind CSS v4 + Bits UI. Built to static files in `frontend/build/`, served by Caddy (not embedded in the Go binary)
- **Lago** for billing (external service, not in this codebase)
@ -218,39 +250,36 @@ To add a new query: add it to the appropriate `.sql` file in `db/queries/` → `
- **Go style**: `gofmt`, `go vet`, `context.Context` everywhere, errors wrapped with `fmt.Errorf("action: %w", err)`, `slog` for logging, no global state
- **Naming**: Sandbox IDs `sb-` + 8 hex, API keys `wrn_` + 32 chars, Host IDs `host-` + 8 hex
- **Dependencies**: Use `go get` to add deps, never hand-edit go.mod. For envd deps: `cd envd && go get ...` (separate module)
- **Dependencies**: Use `go get` to add Go deps, never hand-edit go.mod. For envd-rs deps: edit `envd-rs/Cargo.toml`
- **Generated code**: Always commit generated code (proto stubs, sqlc). Never add generated code to .gitignore
- **Migrations**: Always use `make migrate-create name=xxx`, never create migration files manually
- **Testing**: Table-driven tests for handlers and state machine transitions
### Two-module gotcha
The main module (`go.mod`) and envd (`envd/go.mod`) are fully independent. `make tidy`, `make fmt`, `make vet` already operate on both. But when adding dependencies manually, remember to target the correct module (`cd envd && go get ...` for envd deps). `make proto` also generates stubs for both modules from the same proto sources.
## Rootfs & Guest Init
- **wrenn-init** (`images/wrenn-init.sh`): the PID 1 init script baked into every rootfs. Mounts virtual filesystems, sets hostname, writes `/etc/resolv.conf`, then execs envd.
- **Updating the rootfs** after changing envd or wrenn-init: `bash scripts/update-debug-rootfs.sh [rootfs_path]`. This builds envd via `make build-envd`, mounts the rootfs image, copies in the new binaries, and unmounts. Defaults to `/var/lib/wrenn/images/minimal.ext4`.
- Rootfs images are minimal debootstrap — no systemd, no coreutils beyond busybox. Use `/bin/sh -c` for shell builtins inside the guest.
- **System base templates**: four built-in distro images — `minimal-ubuntu` (id 0, default), `minimal-alpine` (1), `minimal-arch` (2), `minimal-fedora` (3) — built via `images/build-{ubuntu,alpine,arch,fedora}.sh` (or `make images`). All platform-owned, protected from deletion (reserved IDs 01024). Same static envd + tini run on all four. Each has a `wrenn-user` with passwordless sudo.
- **Updating the rootfs** after changing envd or wrenn-init: `bash scripts/update-minimal-rootfs.sh`. Builds envd via `make build-envd` (Rust → static musl binary), then re-injects envd + wrenn-init + tini into all four system base images.
- Rootfs images are built from distro containers — no systemd (init is overridden to `wrenn-init`). Use `/bin/sh -c` for shell builtins inside the guest.
## Fixed Paths (on host machine)
- Kernel: `/var/lib/wrenn/kernels/vmlinux`
- Base rootfs images: `/var/lib/wrenn/images/{template}.ext4`
- Base rootfs images: `/var/lib/wrenn/images/teams/{base36(teamID)}/{base36(templateID)}/rootfs.ext4` (system templates use the platform team, base36 all-zeros)
- Sandbox clones: `/var/lib/wrenn/sandboxes/`
- Firecracker: `/usr/local/bin/firecracker` (e2b's fork of firecracker)
- Cloud Hypervisor: `/usr/local/bin/cloud-hypervisor`
## Design Context
### Users
Developers across the full spectrum — solo engineers building side projects, startup teams integrating sandboxed execution into products, and platform/infra engineers at larger organizations running production workloads on Firecracker microVMs. They arrive with context: they know what a process is, what a rootfs is, what a TTY means. The interface must feel at home for all three: approachable enough not to intimidate a hacker, precise enough to earn the trust of a production ops team. Never condescend, never oversimplify. Trust the user to understand what they're looking at.
Developers across the full spectrum — solo engineers building side projects, startup teams integrating sandboxed execution into products, and platform/infra engineers at larger organizations running production workloads on Cloud Hypervisor microVMs. They arrive with context: they know what a process is, what a rootfs is, what a TTY means. The interface must feel at home for all three: approachable enough not to intimidate a hacker, precise enough to earn the trust of a production ops team. Never condescend, never oversimplify. Trust the user to understand what they're looking at.
**Primary job to be done:** Understand what's running, act on it confidently, and get back to code.
### Brand Personality
**Precise. Warm. Uncompromising.**
Wrenn is an engineer's favorite tool — built with visible care, not assembled from defaults. It runs real infrastructure (Firecracker microVMs), so the UI should reflect that seriousness without becoming cold or corporate. The warmth comes from the typography and color palette; the precision comes from hierarchy, density, and data fidelity.
Wrenn is an engineer's favorite tool — built with visible care, not assembled from defaults. It runs real infrastructure (Cloud Hypervisor microVMs), so the UI should reflect that seriousness without becoming cold or corporate. The warmth comes from the typography and color palette; the precision comes from hierarchy, density, and data fidelity.
Emotional goal: **in control.** Users leave a session with full confidence in what's running, what happened, and what comes next. Nothing is hidden, nothing is ambiguous.
@ -372,3 +401,42 @@ All values are CSS custom properties in `frontend/src/app.css`.
4. **Legible at speed.** Users scan dashboards in seconds. Strong typographic contrast (serif h1, mono IDs, sans body), consistent patterns, and predictable placement let users orientate instantly without reading everything.
5. **Craft signals trust.** For infrastructure that runs production code, the quality of the UI is a proxy for the quality of the product. Pixel-level decisions matter. Polish is not decoration — it's a trust signal.
<!-- code-review-graph MCP tools -->
## MCP Tools: code-review-graph
**IMPORTANT: This project has a knowledge graph. ALWAYS use the
code-review-graph MCP tools BEFORE using Grep/Glob/Read to explore
the codebase.** The graph is faster, cheaper (fewer tokens), and gives
you structural context (callers, dependents, test coverage) that file
scanning cannot.
### When to use graph tools FIRST
- **Exploring code**: `semantic_search_nodes` or `query_graph` instead of Grep
- **Understanding impact**: `get_impact_radius` instead of manually tracing imports
- **Code review**: `detect_changes` + `get_review_context` instead of reading entire files
- **Finding relationships**: `query_graph` with callers_of/callees_of/imports_of/tests_for
- **Architecture questions**: `get_architecture_overview` + `list_communities`
Fall back to Grep/Glob/Read **only** when the graph doesn't cover what you need.
### Key Tools
| Tool | Use when |
|------|----------|
| `detect_changes` | Reviewing code changes — gives risk-scored analysis |
| `get_review_context` | Need source snippets for review — token-efficient |
| `get_impact_radius` | Understanding blast radius of a change |
| `get_affected_flows` | Finding which execution paths are impacted |
| `query_graph` | Tracing callers, callees, imports, tests, dependencies |
| `semantic_search_nodes` | Finding functions/classes by name or keyword |
| `get_architecture_overview` | Understanding high-level codebase structure |
| `refactor_tool` | Planning renames, finding dead code |
### Workflow
1. The graph auto-updates on file changes (via hooks).
2. Use `detect_changes` for code review.
3. Use `get_affected_flows` to understand impact.
4. Use `query_graph` pattern="tests_for" to check coverage.

View File

@ -2,12 +2,10 @@
# Variables
# ═══════════════════════════════════════════════════
DATABASE_URL ?= postgres://wrenn:wrenn@localhost:5432/wrenn?sslmode=disable
GOBIN := $(shell pwd)/builds
ENVD_DIR := envd
BIN_DIR := $(shell pwd)/builds
COMMIT := $(shell git rev-parse --short HEAD 2>/dev/null || echo "unknown")
VERSION_CP := $(shell cat VERSION_CP 2>/dev/null | tr -d '[:space:]' || echo "0.0.0-dev")
VERSION_AGENT := $(shell cat VERSION_AGENT 2>/dev/null | tr -d '[:space:]' || echo "0.0.0-dev")
VERSION_ENVD := $(shell cat envd/VERSION 2>/dev/null | tr -d '[:space:]' || echo "0.0.0-dev")
LDFLAGS := -s -w
# ═══════════════════════════════════════════════════
@ -18,19 +16,23 @@ LDFLAGS := -s -w
build: build-cp build-agent build-envd
build-frontend:
cd frontend && pnpm install --frozen-lockfile && pnpm build
cd frontend && bun install --frozen-lockfile && bun run build
build-cp:
go build -v -ldflags="$(LDFLAGS) -X main.version=$(VERSION_CP) -X main.commit=$(COMMIT)" -o $(GOBIN)/wrenn-cp ./cmd/control-plane
go build -v -ldflags="$(LDFLAGS) -X main.version=$(VERSION_CP) -X main.commit=$(COMMIT)" -o $(BIN_DIR)/wrenn-cp ./cmd/control-plane
build-agent:
go build -v -ldflags="$(LDFLAGS) -X main.version=$(VERSION_AGENT) -X main.commit=$(COMMIT)" -o $(GOBIN)/wrenn-agent ./cmd/host-agent
go build -v -ldflags="$(LDFLAGS) -X main.version=$(VERSION_AGENT) -X main.commit=$(COMMIT)" -o $(BIN_DIR)/wrenn-agent ./cmd/host-agent
build-envd:
cd $(ENVD_DIR) && CGO_ENABLED=0 GOOS=linux GOARCH=amd64 \
go build -ldflags="$(LDFLAGS) -X main.Version=$(VERSION_ENVD) -X main.commitSHA=$(COMMIT)" -o $(GOBIN)/envd .
@file $(GOBIN)/envd | grep -q "statically linked" || \
(echo "ERROR: envd is not statically linked!" && exit 1)
cd envd-rs && ENVD_COMMIT=$(COMMIT) cargo build --release --target x86_64-unknown-linux-musl
@cp envd-rs/target/x86_64-unknown-linux-musl/release/envd $(BIN_DIR)/envd
@readelf -h $(BIN_DIR)/envd | grep -q 'Type:.*DYN' && \
readelf -d $(BIN_DIR)/envd | grep -q 'FLAGS_1.*PIE' && \
! readelf -d $(BIN_DIR)/envd | grep -q '(NEEDED)' && \
{ ! readelf -lW $(BIN_DIR)/envd | grep -q 'Requesting program interpreter' || \
readelf -lW $(BIN_DIR)/envd | grep -Fq '[Requesting program interpreter: /lib/ld-musl-x86_64.so.1]'; } || \
(echo "ERROR: envd must be PIE, have no DT_NEEDED shared libs, and either have no interpreter or use /lib/ld-musl-x86_64.so.1" && exit 1)
# ═══════════════════════════════════════════════════
# Development
@ -57,11 +59,10 @@ dev-agent:
sudo go run ./cmd/host-agent
dev-frontend:
cd frontend && pnpm dev --port 5173 --host 0.0.0.0
cd frontend && bun run dev --port 5173 --host 0.0.0.0
dev-envd:
cd $(ENVD_DIR) && go run . --debug --listen-tcp :3002
cd envd-rs && cargo run -- --port 49983
# ═══════════════════════════════════════════════════
# Database (goose)
@ -94,7 +95,6 @@ generate: proto sqlc
proto:
cd proto/envd && buf generate
cd proto/hostagent && buf generate
cd $(ENVD_DIR)/spec && buf generate
sqlc:
sqlc generate
@ -106,17 +106,16 @@ sqlc:
fmt:
gofmt -w .
cd $(ENVD_DIR) && gofmt -w .
lint:
golangci-lint run ./...
vet:
go vet ./...
cd $(ENVD_DIR) && go vet ./...
test:
go test -race -v ./internal/...
cd envd-rs && cargo test
test-integration:
go test -race -v -tags=integration ./tests/integration/...
@ -125,7 +124,6 @@ test-all: test test-integration
tidy:
go mod tidy
cd $(ENVD_DIR) && go mod tidy
## Run all quality checks in CI order
check: fmt vet lint test
@ -133,32 +131,24 @@ check: fmt vet lint test
# ═══════════════════════════════════════════════════
# Rootfs Images
# ═══════════════════════════════════════════════════
.PHONY: images image-minimal image-python image-node
.PHONY: images rootfs-ubuntu rootfs-alpine rootfs-arch rootfs-fedora
images: build-envd image-minimal image-python image-node
# Build all four system base rootfs images (ubuntu/alpine/arch/fedora). Each
# spawns a distro container, installs the required packages + wrenn-user, then
# exports to images/teams/<platform>/<id>/rootfs.ext4. Requires docker + sudo.
images: rootfs-ubuntu rootfs-alpine rootfs-arch rootfs-fedora
image-minimal:
sudo bash images/templates/minimal/build.sh
rootfs-ubuntu:
bash images/build-ubuntu.sh
image-python:
sudo bash images/templates/python312/build.sh
rootfs-alpine:
bash images/build-alpine.sh
image-node:
sudo bash images/templates/node20/build.sh
rootfs-arch:
bash images/build-arch.sh
# ═══════════════════════════════════════════════════
# Deployment
# ═══════════════════════════════════════════════════
.PHONY: setup-host install
setup-host:
sudo bash scripts/setup-host.sh
install: build
sudo cp $(GOBIN)/wrenn-cp /usr/local/bin/
sudo cp $(GOBIN)/wrenn-agent /usr/local/bin/
sudo cp deploy/systemd/*.service /etc/systemd/system/
sudo systemctl daemon-reload
rootfs-fedora:
bash images/build-fedora.sh
# ═══════════════════════════════════════════════════
# Clean
@ -167,7 +157,7 @@ install: build
clean:
rm -rf builds/
cd $(ENVD_DIR) && rm -f envd
cd envd-rs && cargo clean
# ═══════════════════════════════════════════════════
# Help
@ -183,11 +173,11 @@ help:
@echo " make dev-cp Control plane (hot reload if air installed)"
@echo " make dev-frontend Vite dev server with HMR (port 5173)"
@echo " make dev-agent Host agent (sudo required)"
@echo " make dev-envd envd in TCP debug mode"
@echo " make dev-envd envd in debug mode (port 49983)"
@echo ""
@echo " make build Build all binaries → builds/"
@echo " make build-frontend Build SvelteKit dashboard → frontend/build/"
@echo " make build-envd Build envd static binary"
@echo " make build-envd Build envd static binary (Rust, musl)"
@echo ""
@echo " make migrate-up Apply migrations"
@echo " make migrate-create name=xxx New migration"

19
NOTICE
View File

@ -1,19 +0,0 @@
Wrenn Sandbox
Copyright (c) 2026 M/S Omukk, Bangladesh
This project includes software derived from the following project:
Project: e2b infra
Repository: https://github.com/e2b-dev/infra
The following files and directories in this repository contain code derived from the above project:
- envd/
- proto/envd/*.proto
- internal/snapshot/
- internal/uffd/
Modifications to this code were made by M/S Omukk.
Copyright (c) 2023 FoundryLabs, Inc.
Modifications Copyright (c) 2026 M/S Omukk, Bangladesh

222
README.md
View File

@ -5,10 +5,11 @@ Secure infrastructure for AI
## Prerequisites
- Linux host with `/dev/kvm` access (bare metal or nested virt)
- Firecracker binary at `/usr/local/bin/firecracker`
- Cloud Hypervisor binary at `/usr/local/bin/cloud-hypervisor`
- PostgreSQL
- Go 1.25+
- pnpm (for frontend)
- Rust 1.88+ with `x86_64-unknown-linux-musl` target (`rustup target add x86_64-unknown-linux-musl`)
- Bun (for frontend)
- Docker (for dev infra and rootfs builds)
## Build
@ -21,7 +22,7 @@ Produces three binaries: `wrenn-cp` (control plane), `wrenn-agent` (host agent),
## Host setup
The host agent needs a kernel, a minimal rootfs image, and working directories on the host machine.
The host agent needs a kernel, the system base rootfs images, and working directories on the host machine.
### Directory structure
@ -30,59 +31,74 @@ The host agent needs a kernel, a minimal rootfs image, and working directories o
├── kernels/
│ └── vmlinux # uncompressed Linux kernel (not bzImage)
├── images/
│ └── minimal/
│ └── rootfs.ext4 # base rootfs (all other templates snapshot from this)
│ └── teams/
│ └── 0000000000000000000000000/ # platform team (base36 all-zeros)
│ ├── 0000000000000000000000000/rootfs.ext4 # minimal-ubuntu (id 0)
│ ├── 0000000000000000000000001/rootfs.ext4 # minimal-alpine (id 1)
│ ├── 0000000000000000000000002/rootfs.ext4 # minimal-arch (id 2)
│ └── 0000000000000000000000003/rootfs.ext4 # minimal-fedora (id 3)
├── sandboxes/ # per-sandbox CoW files (created at runtime)
└── snapshots/ # pause/hibernate snapshot files (created at runtime)
```
Create the directories:
Create the base directories (the per-template image dirs are created by the build scripts):
```bash
sudo mkdir -p /var/lib/wrenn/{kernels,images/minimal,sandboxes,snapshots}
sudo mkdir -p /var/lib/wrenn/{kernels,images,sandboxes,snapshots}
```
### Kernel
Place an uncompressed `vmlinux` kernel at `/var/lib/wrenn/kernels/vmlinux`. Versioned kernels (`vmlinux-{semver}`) are also supported — the agent picks the latest by semver.
### Minimal rootfs
### System base rootfs images
The minimal rootfs is the base image that all other templates (Python, Node, etc.) are built on top of via device-mapper snapshots. It must contain:
There are four built-in **system base templates** — one per distro — that all other
templates snapshot from via device-mapper. They are platform-owned (visible to every
team) and protected from deletion (reserved template IDs 01024):
| Template | Distro | ID |
|----------|--------|----|
| `minimal-ubuntu` | `ubuntu:26.04` | 0 |
| `minimal-alpine` | `alpine:3.22` | 1 |
| `minimal-arch` | `archlinux:base` | 2 |
| `minimal-fedora` | `fedora:45` | 3 |
`minimal-ubuntu` is the default template for new sandboxes and builds. The same
statically-linked `envd` + `tini` run on all four regardless of the distro's libc
(glibc on Ubuntu/Arch/Fedora, musl on Alpine).
Each image contains these packages plus a `wrenn-user` account with passwordless `sudo`:
| Package | Why |
|---------|-----|
| `socat` | Bidirectional relay for port forwarding |
| `chrony` | Time sync from KVM PTP clock (`/dev/ptp0`) |
| `tini` | PID 1 zombie reaper (injected by build script, not apt) |
| `iproute2` (`iproute` on Fedora) | `ip` for guest network setup in `wrenn-init` |
| `tini` | PID 1 zombie reaper |
| `sudo` | User privilege management inside the guest |
| `wget` | HTTP fetching |
| `curl` | HTTP client |
| `ca-certificates` | TLS certificate verification |
| `git` | Version control |
**To build a rootfs from a Docker container:**
**To build all four images** (each spawns a distro container, installs the packages +
`wrenn-user`, builds `envd`, injects `wrenn-init` + `tini`, and exports to the
team-scoped path). Requires Docker + sudo:
1. Create and configure a container with the required packages:
```bash
docker run -it --name wrenn-minimal debian:bookworm bash
# Inside the container:
apt update && apt install -y socat chrony sudo wget curl ca-certificates
exit
make images
```
2. Export to a rootfs image (builds envd, injects wrenn-init + tini, shrinks to minimum size):
```bash
sudo bash scripts/rootfs-from-container.sh wrenn-minimal minimal
```
Or build a single distro: `make rootfs-ubuntu` / `rootfs-alpine` / `rootfs-arch` / `rootfs-fedora`.
**To update an existing rootfs** after changing envd or `wrenn-init.sh`:
**To update the images** after changing `envd` or `wrenn-init.sh` (rebuilds `envd` once,
then re-injects `envd` + `wrenn-init` + `tini` into every system base image):
```bash
bash scripts/update-minimal-rootfs.sh
```
This rebuilds envd via `make build-envd` and copies the fresh binaries into the mounted rootfs image.
### IP forwarding
```bash
@ -119,14 +135,7 @@ make check # fmt + vet + lint + test
Hosts must be registered with the control plane before they can serve sandboxes.
1. **Create a host record** (via API or dashboard):
```bash
curl -X POST http://localhost:8000/v1/hosts \
-H "Authorization: Bearer $JWT_TOKEN" \
-H "Content-Type: application/json" \
-d '{"type": "regular"}'
```
This returns a `registration_token` (valid for 1 hour).
1. **Create a host record** in the dashboard (admin only — host management is not exposed over the SDK / API keys). Sign in at `/login`, open the admin hosts page, and click **Add host**. The dashboard returns a `registration_token` valid for 1 hour.
2. **Start the host agent** with the registration token and its externally-reachable address:
```bash
@ -142,13 +151,152 @@ Hosts must be registered with the control plane before they can serve sandboxes.
sudo ./builds/wrenn-agent --address <host-ip>:50051
```
4. **If registration fails** (e.g., network error after token was consumed), regenerate a token:
```bash
curl -X POST http://localhost:8000/v1/hosts/$HOST_ID/token \
-H "Authorization: Bearer $JWT_TOKEN"
```
Then restart the agent with the new token.
4. **If registration fails** (e.g., network error after token was consumed), regenerate a token from the dashboard host detail page, then restart the agent with the new token.
The agent sends heartbeats to the control plane every 30 seconds.
## Notification channels
Teams can subscribe to lifecycle events via webhook, Discord, Slack, Teams, Google Chat, Telegram, or Matrix. All providers consume the same event stream (durable Redis stream `wrenn:events`, consumer group `wrenn-channels-v1`, at-least-once delivery with two retries at 10s / 30s).
### Subscribable event types
| Event | Emitted on | Has outcome |
|-------|-----------|-------------|
| `capsule.create` | First boot of a sandbox | yes |
| `capsule.pause` | Manual pause, TTL auto-pause, or reconciler-detected pause | yes |
| `capsule.resume` | Unpause (any subsequent boot after `capsule.create`) | yes |
| `capsule.destroy` | Stop / destroy, including system cleanup-on-error | yes |
| `template.snapshot.create` | Snapshot taken from a running sandbox | yes |
| `template.snapshot.delete` | Snapshot deletion (including cleanup-on-error) | yes |
| `host.up` | Host agent comes online | no |
| `host.down` | Host agent crashes or misses heartbeats | no |
Subscribing to an event type delivers **both success and failure**. The `outcome` field on the payload (`success` or `error`) distinguishes them. `error` events carry an `error` string with the failure reason.
The transient `capsule.state.changed` event (intermediate transitions like `starting`, `pausing`, `resuming`) is **not** subscribable — it is delivered to the dashboard via SSE only and never written to the durable stream.
### Event payload
All channels receive the same canonical JSON shape:
```json
{
"event": "capsule.pause",
"outcome": "success",
"timestamp": "2026-05-19T14:23:01Z",
"team_id": "tm_...",
"actor": {
"type": "user",
"id": "usr_...",
"name": "alice@example.com"
},
"resource": {
"id": "sb_a1b2c3d4",
"type": "sandbox"
},
"metadata": {
"reason": "ttl_expired"
},
"error": ""
}
```
| Field | Type | Notes |
|-------|------|-------|
| `event` | string | Event type (see table above) |
| `outcome` | `"success"` \| `"error"` \| `""` | Omitted for host.up/host.down |
| `timestamp` | RFC3339 UTC | When the event was published |
| `team_id` | string | Owning team |
| `actor.type` | `"user"` \| `"api_key"` \| `"system"` | System = TTL reaper, reconciler, cleanup-on-error |
| `actor.id` | string | User ID, API key ID, or empty for system |
| `actor.name` | string | Display name (email for user, label for api_key) |
| `resource.id` | string | Sandbox ID, snapshot ID, or host ID |
| `resource.type` | `"sandbox"` \| `"snapshot"` \| `"host"` | |
| `metadata` | object\<string,string\> | Event-specific context (e.g., `reason`, `from`/`to`, `inferred`) |
| `error` | string | Failure reason when `outcome == "error"` |
`metadata` keys you may observe:
- `reason` — `ttl_expired` (auto-pause), `orphaned` (reconciler cleanup), `cleanup_after_create_error`, `restored_after_host_recovery`, `host_state_sync`, `transient_timeout`, `transient_timeout_inferred`
- `inferred` — `"true"` when the reconciler derived the event from host state, not a direct host callback
### Webhook delivery
Webhook channels receive a raw `POST` with the JSON payload as the body.
Headers:
| Header | Value |
|--------|-------|
| `Content-Type` | `application/json` |
| `X-Wrenn-Delivery` | UUID, unique per delivery attempt |
| `X-Wrenn-Timestamp` | RFC3339 UTC, used for signature verification |
| `X-WRENN-SIGNATURE` | `sha256=<hex>` HMAC over `<timestamp>.<body>` using the channel's signing secret |
The signing secret is shown **once** at channel creation. Verify signatures by computing `HMAC-SHA256(secret, timestamp + "." + body)` and comparing to the header (constant-time compare). Reject deliveries where `X-Wrenn-Timestamp` is outside your acceptable clock skew window. Redirects are not followed.
Any non-2xx response triggers retry (10s, then 30s). After three total failures the event is dropped (logged on the control plane).
### Other providers
Discord, Slack, Teams, Google Chat, Telegram, and Matrix receive a formatted text message — the same fields, rendered as human-readable text — not the JSON payload. Use webhook if you need the structured event.
## Extending the control plane
The OSS control plane is designed to be embedded by a private cloud distribution without forking. Import this module, implement the `Extension` interface from `pkg/cpextension`, and pass it to `cpserver.Run`:
```go
import (
"git.omukk.dev/wrenn/wrenn/pkg/cpextension"
"git.omukk.dev/wrenn/wrenn/pkg/cpserver"
)
func main() {
cpserver.Run(
cpserver.WithVersion("cloud-1.0.0"),
cpserver.WithExtensions(&myExtension{}),
)
}
```
Every extension implements two methods:
```go
RegisterRoutes(r chi.Router, sctx cpextension.ServerContext)
BackgroundWorkers(sctx cpextension.ServerContext) []func(context.Context)
```
`ServerContext` exposes the initialized OSS services so extensions never re-implement them: `Queries`, `PgPool`, `Redis`, `HostPool`, `Scheduler`, `CA`, `Audit`, `Mailer`, `OAuthRegistry`, `Channels`, `ChannelPub`, `JWTSecret`, `Sessions`, `Config`.
### Optional hook interfaces
An extension can also implement any subset of these — the OSS server type-asserts at startup:
| Interface | When it fires | Failure semantics |
|---|---|---|
| `MiddlewareProvider` | Wraps every OSS route before registration | n/a |
| `AuthHook.OnSignup(ctx, userID, teamID, email)` | After team provisioning on email-activate or OAuth-new-signup | Error aborts signup with 500 `signup_hook_failed` (billing customer creation must succeed) |
| `AuthHook.OnLogin(ctx, userID)` | After a successful login or OAuth callback | Error logged, login still succeeds |
| `AuthHook.OnAccountSoftDelete(ctx, userID)` | After `DELETE /v1/me` commits | Error logged, request still succeeds |
| `AuthHook.OnAccountHardDelete(ctx, userID)` | After the 15-day cleanup goroutine purges a soft-deleted account | Error logged, cleanup continues |
| `SandboxEventHook.OnSandboxEvent(ctx, ev)` | Capsule create/pause/resume/destroy success, from the Redis stream consumer | Error leaves the message un-acked — hooks **must** be idempotent |
| `LimitsProvider.EffectiveLimits(ctx, teamID)` | `POST /v1/capsules` consults before scheduling | Returns 402 (`concurrent_sandbox_limit` / `vcpu_limit` / `memory_limit`) when over |
| `UsageProvider.CurrentUsage(ctx, teamID)` | Feeds `LimitsProvider` checks; falls back to OSS DB-backed default | Error → 402 `usage_unavailable` |
### Auth middleware helpers
For extensions that gate their own routes:
```go
r.With(cpextension.RequireSession(sctx)).Get("/billing", handler)
r.With(cpextension.RequireSessionOrAPIKey(sctx)).Get("/usage", handler)
r.With(cpextension.RequireSession(sctx), cpextension.RequireAdmin(sctx)).Get("/admin/exports", handler)
// Issue a session from a custom flow (e.g. invite-accept):
sess, err := cpextension.IssueSession(w, r, sctx, userID, teamID)
```
Cookie/header names are exported as `cpextension.SessionCookieName`, `CSRFCookieName`, `CSRFHeaderName`.
See `CLAUDE.md` for full architecture documentation.

View File

@ -1 +1 @@
0.1.1
0.2.0

View File

@ -1 +1 @@
0.1.4
0.2.0

View File

@ -24,6 +24,7 @@ import (
"git.omukk.dev/wrenn/wrenn/internal/layout"
"git.omukk.dev/wrenn/wrenn/internal/network"
"git.omukk.dev/wrenn/wrenn/internal/sandbox"
"git.omukk.dev/wrenn/wrenn/internal/vm"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/logging"
"git.omukk.dev/wrenn/wrenn/proto/hostagent/gen/hostagentv1connect"
@ -63,8 +64,12 @@ func main() {
os.Exit(1)
}
// Clean up stale resources from a previous crash.
// Clean up stale resources from a previous crash. Order matters:
// kill stale CH processes first — they hold dm-snapshot devices open and
// would otherwise cause "Device or resource busy" on dmsetup remove.
vm.CleanupStaleProcesses()
devicemapper.CleanupStaleDevices()
devicemapper.LogLoopState()
network.CleanupStaleNamespaces()
listenAddr := envOrDefault("WRENN_HOST_LISTEN_ADDR", ":50051")
@ -80,6 +85,25 @@ func main() {
os.Exit(1)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Register with the control plane before touching rootfs images. If the
// agent can't reach the CP there's no point inflating images (and crashing
// afterward would leave them in the expanded state).
creds, err := hostagent.Register(ctx, hostagent.RegistrationConfig{
CPURL: cpURL,
RegistrationToken: *registrationToken,
TokenFile: credsFile,
Address: *advertiseAddr,
})
if err != nil {
slog.Error("host registration failed", "error", err)
os.Exit(1)
}
slog.Info("host registered", "host_id", creds.HostID)
// Parse default rootfs size from env (e.g. "5G", "2Gi", "1000M").
defaultRootfsSizeMB := sandbox.DefaultDiskSizeMB
if sizeStr := os.Getenv("WRENN_DEFAULT_ROOTFS_SIZE"); sizeStr != "" {
@ -107,46 +131,47 @@ func main() {
}
slog.Info("resolved kernel", "version", kernelVersion, "path", kernelPath)
// Detect firecracker version.
fcBin := envOrDefault("WRENN_FIRECRACKER_BIN", "/usr/local/bin/firecracker")
fcVersion, err := sandbox.DetectFirecrackerVersion(fcBin)
// Detect cloud-hypervisor version.
chBin := envOrDefault("WRENN_CH_BIN", "/usr/local/bin/cloud-hypervisor")
chVersion, err := sandbox.DetectCHVersion(chBin)
if err != nil {
slog.Error("failed to detect firecracker version", "error", err)
slog.Error("failed to detect cloud-hypervisor version", "error", err)
os.Exit(1)
}
slog.Info("resolved firecracker", "version", fcVersion, "path", fcBin)
slog.Info("resolved cloud-hypervisor", "version", chVersion, "path", chBin)
cfg := sandbox.Config{
WrennDir: rootDir,
DefaultRootfsSizeMB: defaultRootfsSizeMB,
KernelPath: kernelPath,
KernelVersion: kernelVersion,
FirecrackerBin: fcBin,
FirecrackerVersion: fcVersion,
VMMBin: chBin,
VMMVersion: chVersion,
AgentVersion: version,
}
// Remove any *.staging-* / *.trash-* directories left behind by a
// previous Pause that crashed before completing the atomic swap. Must
// run before any Resume so we don't race with a sandbox restoration.
sandbox.CleanupOrphanPauseDirs(rootDir)
mgr := sandbox.New(cfg)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Set up lifecycle event callback sender so autonomous events
// (auto-pause, auto-destroy) are pushed to the CP proactively.
cb := hostagent.NewCallbackSender(cpURL, credsFile, creds.HostID)
mgr.SetEventSender(hostagent.NewEventSender(cb))
// Restore paused sandboxes from disk so ListSandboxes reports them as
// 'paused' immediately. Without this, the CP's HostMonitor would mark
// every paused-on-disk sandbox 'stopped' via the missing→stopped
// reconcile path on the first ListSandboxes after agent restart.
// Must run before HTTP server starts serving (an early Create would
// race the slot reservation).
mgr.RestorePausedSandboxes()
mgr.StartTTLReaper(ctx)
// Register with the control plane and start heartbeating.
creds, err := hostagent.Register(ctx, hostagent.RegistrationConfig{
CPURL: cpURL,
RegistrationToken: *registrationToken,
TokenFile: credsFile,
Address: *advertiseAddr,
})
if err != nil {
slog.Error("host registration failed", "error", err)
os.Exit(1)
}
slog.Info("host registered", "host_id", creds.HostID)
// httpServer is declared here so the shutdown func can reference it.
// ReadTimeout/WriteTimeout are intentionally omitted — they would kill
// long-lived Connect RPC streams and WebSocket proxy connections.
@ -154,6 +179,11 @@ func main() {
Addr: listenAddr,
ReadHeaderTimeout: 10 * time.Second,
IdleTimeout: 620 * time.Second, // > typical LB upstream timeout (600s)
// Disable HTTP/2: empty non-nil map prevents Go from registering
// the h2 ALPN token. Connect RPC works over HTTP/1.1; HTTP/2
// multiplexing causes HOL blocking when a slow sandbox RPC stalls
// the shared connection.
TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)),
}
// mTLS is mandatory — refuse to start without a valid certificate.
@ -183,10 +213,22 @@ func main() {
shutdownOnce.Do(func() {
slog.Info("shutting down", "reason", reason)
cancel()
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second)
// Shutdown pauses every running sandbox in parallel (PauseAll uses
// a worker pool). Per-sandbox Pause can take 1030s (memory loader
// wait + ch.snapshot of guest RAM). 5 minutes is enough headroom for
// a busy host while still bounded so a wedged sandbox can't keep the
// process alive indefinitely — a second signal force-exits anyway.
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer shutdownCancel()
// Order matters: mgr.Shutdown FIRST so it runs to completion
// before httpServer.Shutdown unblocks main's Serve and lets the
// process exit. mgr.Shutdown internally flips a draining flag
// that rejects new Create/Resume RPCs with Unavailable so any
// in-flight HTTP handlers can't add sandboxes after PauseAll
// snapshotted state. User-initiated Pauses already running are
// awaited by PauseAll/Destroy's lifecycleMu serialization.
mgr.Shutdown(shutdownCtx)
sandbox.ShrinkMinimalImage(rootDir)
sandbox.ShrinkSystemImages(rootDir)
if err := httpServer.Shutdown(shutdownCtx); err != nil {
slog.Error("http server shutdown error", "error", err)
}
@ -219,8 +261,9 @@ func main() {
func() {
doShutdown("host deleted from CP")
},
// onCredsRefreshed: hot-swap the TLS certificate after a JWT refresh.
// onCredsRefreshed: hot-swap the TLS certificate and update callback JWT.
func(tf *hostagent.TokenFile) {
cb.UpdateJWT(tf.JWT)
if tf.CertPEM == "" || tf.KeyPEM == "" {
return
}
@ -232,12 +275,16 @@ func main() {
},
)
// Graceful shutdown on SIGINT/SIGTERM.
// Graceful shutdown on SIGINT/SIGTERM. A second signal force-exits
// so the operator can always kill the process if shutdown hangs.
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
go func() {
sig := <-sigCh
doShutdown("signal: " + sig.String())
go doShutdown("signal: " + sig.String())
sig = <-sigCh
slog.Error("received second signal, force exiting", "signal", sig.String())
os.Exit(1)
}()
slog.Info("host agent starting", "addr", listenAddr, "host_id", creds.HostID, "version", version, "commit", commit)
@ -279,7 +326,7 @@ func checkPrivileges() error {
name string
}{
{1, "CAP_DAC_OVERRIDE"}, // /dev/loop*, /dev/mapper/*, /dev/net/tun
{5, "CAP_KILL"}, // SIGTERM/SIGKILL to Firecracker processes
{5, "CAP_KILL"}, // SIGTERM/SIGKILL to cloud-hypervisor processes
{12, "CAP_NET_ADMIN"}, // netlink, iptables, routing, TAP/veth
{13, "CAP_NET_RAW"}, // raw sockets (iptables)
{19, "CAP_SYS_PTRACE"}, // reading /proc/self/ns/net (netns.Get)

View File

@ -0,0 +1,21 @@
-- +goose Up
-- +goose StatementBegin
CREATE TABLE sessions (
id TEXT PRIMARY KEY,
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
team_id UUID NOT NULL REFERENCES teams(id) ON DELETE CASCADE,
csrf_token TEXT NOT NULL,
user_agent TEXT NOT NULL DEFAULT '',
ip_address TEXT NOT NULL DEFAULT '',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
last_seen_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
expires_at TIMESTAMPTZ NOT NULL
);
CREATE INDEX sessions_user_id_idx ON sessions(user_id);
CREATE INDEX sessions_expires_at_idx ON sessions(expires_at);
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
DROP TABLE IF EXISTS sessions;
-- +goose StatementEnd

View File

@ -0,0 +1,15 @@
-- +goose Up
-- +goose StatementBegin
-- Session IDs are now stored as sha256(raw_sid) hex so a DB/Redis dump
-- cannot be replayed as session cookies. Existing sessions hold raw SIDs
-- in id; they are unrecoverable under the new scheme and must be wiped.
-- Users will need to log in again after this migration.
TRUNCATE TABLE sessions;
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
-- Down: nothing to do schematically. Hashed rows remain but will never
-- match a raw cookie under the old code path; safest is to wipe again.
TRUNCATE TABLE sessions;
-- +goose StatementEnd

View File

@ -0,0 +1,49 @@
-- +goose Up
-- Replace the old all-zeros "minimal" base template with the four system base
-- templates (ubuntu/alpine/arch/fedora). All are platform-owned (team_id
-- all-zeros) with reserved template IDs 0..3, default user wrenn-user.
--
-- Template IDs are well-known: the all-zeros UUID + low byte = {0,1,2,3}.
-- On disk each lives at images/teams/{base36(0)}/{base36(id)}/rootfs.ext4.
-- 0 → minimal-ubuntu (was "minimal").
UPDATE templates
SET name = 'minimal-ubuntu',
default_user = 'wrenn-user'
WHERE id = '00000000-0000-0000-0000-000000000000';
-- Seed the row if it did not already exist (fresh DBs).
INSERT INTO templates (id, name, type, vcpus, memory_mb, size_bytes, team_id, default_user)
VALUES ('00000000-0000-0000-0000-000000000000', 'minimal-ubuntu', 'base', 1, 512, 0,
'00000000-0000-0000-0000-000000000000', 'wrenn-user')
ON CONFLICT (id) DO NOTHING;
-- 1 → minimal-alpine, 2 → minimal-arch, 3 → minimal-fedora.
INSERT INTO templates (id, name, type, vcpus, memory_mb, size_bytes, team_id, default_user)
VALUES
('00000000-0000-0000-0000-000000000001', 'minimal-alpine', 'base', 1, 512, 0,
'00000000-0000-0000-0000-000000000000', 'wrenn-user'),
('00000000-0000-0000-0000-000000000002', 'minimal-arch', 'base', 1, 512, 0,
'00000000-0000-0000-0000-000000000000', 'wrenn-user'),
('00000000-0000-0000-0000-000000000003', 'minimal-fedora', 'base', 1, 512, 0,
'00000000-0000-0000-0000-000000000000', 'wrenn-user')
ON CONFLICT (id) DO NOTHING;
-- Point the sandboxes.template column default at the new default base template.
ALTER TABLE sandboxes ALTER COLUMN template SET DEFAULT 'minimal-ubuntu';
-- +goose Down
ALTER TABLE sandboxes ALTER COLUMN template SET DEFAULT 'minimal';
DELETE FROM templates WHERE id IN (
'00000000-0000-0000-0000-000000000001',
'00000000-0000-0000-0000-000000000002',
'00000000-0000-0000-0000-000000000003'
);
UPDATE templates
SET name = 'minimal',
default_user = 'root'
WHERE id = '00000000-0000-0000-0000-000000000000';

View File

@ -13,7 +13,36 @@ SELECT * FROM hosts ORDER BY created_at DESC;
SELECT * FROM hosts WHERE type = $1 ORDER BY created_at DESC;
-- name: ListHostsByTeam :many
SELECT * FROM hosts WHERE team_id = $1 AND type = 'byoc' ORDER BY created_at DESC;
-- Returns hosts by team with per-host sandbox resource consumption aggregated.
-- Follows the same aggregation pattern as ListHostsAdmin and GetHostsWithLoad.
SELECT
h.id,
h.type,
h.team_id,
h.provider,
h.availability_zone,
h.arch,
h.cpu_cores,
h.memory_mb,
h.disk_gb,
h.address,
h.status,
h.last_heartbeat_at,
h.metadata,
h.created_by,
h.created_at,
h.updated_at,
COALESCE(SUM(s.vcpus) FILTER (WHERE s.status IN ('running', 'starting', 'pending')), 0)::int AS running_vcpus,
COALESCE(SUM(s.memory_mb) FILTER (WHERE s.status IN ('running', 'starting', 'pending')), 0)::int AS running_memory_mb,
COALESCE(SUM(s.disk_size_mb) FILTER (WHERE s.status IN ('running', 'starting', 'pending')), 0)::int AS running_disk_mb,
COALESCE(SUM(s.memory_mb) FILTER (WHERE s.status = 'paused'), 0)::int AS paused_memory_mb,
COALESCE(SUM(s.disk_size_mb) FILTER (WHERE s.status = 'paused'), 0)::int AS paused_disk_mb
FROM hosts h
LEFT JOIN sandboxes s ON s.host_id = h.id
AND s.status IN ('running', 'paused', 'starting', 'pending')
WHERE h.team_id = $1 AND h.type = 'byoc'
GROUP BY h.id
ORDER BY h.created_at DESC;
-- name: ListHostsByStatus :many
SELECT * FROM hosts WHERE status = $1 ORDER BY created_at DESC;
@ -101,8 +130,6 @@ SELECT
h.created_by,
h.created_at,
h.updated_at,
h.cert_fingerprint,
h.cert_expires_at,
COALESCE(SUM(s.vcpus) FILTER (WHERE s.status IN ('running', 'starting', 'pending')), 0)::int AS running_vcpus,
COALESCE(SUM(s.memory_mb) FILTER (WHERE s.status IN ('running', 'starting', 'pending')), 0)::int AS running_memory_mb,
COALESCE(SUM(s.disk_size_mb) FILTER (WHERE s.status IN ('running', 'starting', 'pending')), 0)::int AS running_disk_mb,
@ -125,5 +152,37 @@ SET last_heartbeat_at = NOW(),
updated_at = NOW()
WHERE id = $1;
-- name: ListHostsAdmin :many
-- Returns all hosts with per-host sandbox resource consumption aggregated.
-- Unlike GetHostsWithLoad, this returns ALL hosts (not just online) so admins
-- can see resource usage across the entire fleet including offline/pending hosts.
SELECT
h.id,
h.type,
h.team_id,
h.provider,
h.availability_zone,
h.arch,
h.cpu_cores,
h.memory_mb,
h.disk_gb,
h.address,
h.status,
h.last_heartbeat_at,
h.metadata,
h.created_by,
h.created_at,
h.updated_at,
COALESCE(SUM(s.vcpus) FILTER (WHERE s.status IN ('running', 'starting', 'pending')), 0)::int AS running_vcpus,
COALESCE(SUM(s.memory_mb) FILTER (WHERE s.status IN ('running', 'starting', 'pending')), 0)::int AS running_memory_mb,
COALESCE(SUM(s.disk_size_mb) FILTER (WHERE s.status IN ('running', 'starting', 'pending')), 0)::int AS running_disk_mb,
COALESCE(SUM(s.memory_mb) FILTER (WHERE s.status = 'paused'), 0)::int AS paused_memory_mb,
COALESCE(SUM(s.disk_size_mb) FILTER (WHERE s.status = 'paused'), 0)::int AS paused_disk_mb
FROM hosts h
LEFT JOIN sandboxes s ON s.host_id = h.id
AND s.status IN ('running', 'paused', 'starting', 'pending')
GROUP BY h.id
ORDER BY h.created_at DESC;
-- name: MarkHostUnreachable :exec
UPDATE hosts SET status = 'unreachable', updated_at = NOW() WHERE id = $1;

View File

@ -42,6 +42,13 @@ ORDER BY ts ASC;
DELETE FROM sandbox_metric_points
WHERE sandbox_id = $1;
-- name: GetLatestSandboxMetricPoint :one
SELECT ts, cpu_pct, mem_bytes, disk_bytes
FROM sandbox_metric_points
WHERE sandbox_id = $1
ORDER BY ts DESC
LIMIT 1;
-- name: DeleteSandboxMetricPointsByTier :exec
DELETE FROM sandbox_metric_points
WHERE sandbox_id = $1 AND tier = $2;

View File

@ -72,7 +72,7 @@ ORDER BY created_at DESC;
UPDATE sandboxes
SET status = 'missing',
last_updated = NOW()
WHERE host_id = $1 AND status IN ('running', 'starting', 'pending');
WHERE host_id = $1 AND status IN ('running', 'starting', 'pending', 'pausing', 'resuming', 'stopping');
-- name: UpdateSandboxMetadata :exec
UPDATE sandboxes
@ -80,10 +80,36 @@ SET metadata = $2,
last_updated = NOW()
WHERE id = $1;
-- name: BulkRestoreRunning :exec
-- Called by the reconciler when a host comes back online and its sandboxes are
-- confirmed alive. Restores only sandboxes that are in 'missing' state.
-- name: UpdateSandboxRunningIf :one
-- Conditionally transition a sandbox to running only if the current status
-- matches the expected value. Prevents races where a user destroys a sandbox
-- while the create/resume goroutine is still in-flight.
UPDATE sandboxes
SET status = 'running',
host_ip = $3,
guest_ip = $4,
started_at = $5,
last_active_at = $5,
last_updated = NOW()
WHERE id = $1 AND status = $2
RETURNING *;
-- name: UpdateSandboxStatusIf :one
-- Atomically update status only when the current status matches the expected value.
-- Prevents background goroutines from overwriting a status that has since changed
-- (e.g. user destroyed a sandbox while Create was in-flight).
UPDATE sandboxes
SET status = $3,
last_updated = NOW()
WHERE id = $1 AND status = $2
RETURNING *;
-- name: BulkRestoreMissingToStatus :exec
-- Called by the reconciler when a host comes back online and its sandboxes are
-- confirmed alive. Restores only sandboxes currently in 'missing' state to the
-- given target status (typically 'running' or 'paused' based on the live state
-- reported by the host agent's ListSandboxes RPC).
UPDATE sandboxes
SET status = $2,
last_updated = NOW()
WHERE id = ANY($1::uuid[]) AND status = 'missing';

28
db/queries/sessions.sql Normal file
View File

@ -0,0 +1,28 @@
-- name: InsertSession :one
INSERT INTO sessions (id, user_id, team_id, csrf_token, user_agent, ip_address, expires_at)
VALUES ($1, $2, $3, $4, $5, $6, $7)
RETURNING *;
-- name: GetSession :one
SELECT * FROM sessions WHERE id = $1;
-- name: TouchSession :exec
UPDATE sessions SET last_seen_at = NOW() WHERE id = $1;
-- name: UpdateSessionTeam :exec
UPDATE sessions SET team_id = $2 WHERE id = $1;
-- name: DeleteSession :exec
DELETE FROM sessions WHERE id = $1;
-- name: DeleteSessionForUser :exec
DELETE FROM sessions WHERE id = $1 AND user_id = $2;
-- name: ListSessionsByUserID :many
SELECT * FROM sessions WHERE user_id = $1 ORDER BY last_seen_at DESC;
-- name: DeleteSessionsByUserID :many
DELETE FROM sessions WHERE user_id = $1 RETURNING id;
-- name: DeleteExpiredSessions :exec
DELETE FROM sessions WHERE expires_at < NOW();

View File

@ -66,7 +66,9 @@ SELECT
COALESCE(owner_u.name, '') AS owner_name,
COALESCE(owner_u.email, '') AS owner_email,
(SELECT COUNT(*) FROM sandboxes s WHERE s.team_id = t.id AND s.status IN ('running', 'paused', 'starting'))::int AS active_sandbox_count,
(SELECT COUNT(*) FROM channels c WHERE c.team_id = t.id)::int AS channel_count
(SELECT COUNT(*) FROM channels c WHERE c.team_id = t.id)::int AS channel_count,
COALESCE((SELECT SUM(s.vcpus) FROM sandboxes s WHERE s.team_id = t.id AND s.status IN ('running', 'paused', 'starting')), 0)::int AS running_vcpus,
COALESCE((SELECT SUM(s.memory_mb) FROM sandboxes s WHERE s.team_id = t.id AND s.status IN ('running', 'paused', 'starting')), 0)::int AS running_memory_mb
FROM teams t
LEFT JOIN users_teams owner_ut ON owner_ut.team_id = t.id AND owner_ut.role = 'owner'
LEFT JOIN users owner_u ON owner_u.id = owner_ut.user_id

View File

@ -42,6 +42,9 @@ DELETE FROM templates WHERE name = $1 AND team_id = $2;
-- Bulk delete all templates owned by a team (for team soft-delete cleanup).
DELETE FROM templates WHERE team_id = $1;
-- name: UpdateTemplateSize :exec
UPDATE templates SET size_bytes = $2 WHERE id = $1;
-- name: ListTemplatesByTeamOnly :many
-- List templates owned by a specific team (NOT including platform templates).
SELECT * FROM templates WHERE team_id = $1 ORDER BY created_at DESC;

View File

@ -22,6 +22,12 @@ RETURNING *;
-- name: SetUserAdmin :exec
UPDATE users SET is_admin = $2, updated_at = NOW() WHERE id = $1;
-- name: RevokeUserAdmin :execrows
UPDATE users u SET is_admin = false, updated_at = NOW()
WHERE u.id = $1
AND u.is_admin = true
AND (SELECT COUNT(*) FROM users WHERE is_admin = true AND status != 'deleted') > 1;
-- name: GetAdminUsers :many
SELECT * FROM users WHERE is_admin = TRUE ORDER BY created_at;

View File

@ -0,0 +1,2 @@
[target.x86_64-unknown-linux-musl]
linker = "musl-gcc"

2328
envd-rs/Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

86
envd-rs/Cargo.toml Normal file
View File

@ -0,0 +1,86 @@
[package]
name = "envd"
version = "0.3.0"
edition = "2024"
rust-version = "1.95"
[dependencies]
# Async runtime
tokio = { version = "1", features = ["full"] }
# HTTP framework
axum = { version = "0.8", features = ["multipart"] }
tower = { version = "0.5", features = ["util"] }
tower-http = { version = "0.6", features = ["cors", "fs"] }
tower-service = "0.3"
# RPC (Connect protocol — serves Connect + gRPC + gRPC-Web on same port)
connectrpc = { version = "0.3", features = ["axum"] }
buffa-types = { path = "buffa-types-shim" }
# CLI
clap = { version = "4", features = ["derive"] }
# Serialization
serde = { version = "1", features = ["derive"] }
serde_json = "1"
# Logging
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
# System metrics
sysinfo = "0.33"
# Unix syscalls
nix = { version = "0.30", features = ["fs", "process", "signal", "user", "term", "mount", "ioctl"] }
# Concurrent map
dashmap = "6"
# Crypto
sha2 = "0.10"
hmac = "0.12"
hex = "0.4"
base64 = "0.22"
# Secure memory
zeroize = { version = "1", features = ["derive"] }
# File watching
notify = "7"
# Compression
flate2 = "1"
# Directory walking
walkdir = "2"
# Time parsing (RFC3339 → epoch nanos for clock fix)
chrono = { version = "0.4", default-features = false, features = ["clock", "std"] }
# Misc
libc = "0.2"
bytes = "1"
http = "1"
http-body-util = "0.1"
futures = "0.3"
tokio-util = { version = "0.7", features = ["io"] }
subtle = "2"
http-body = "1.0.1"
buffa = "0.3"
async-stream = "0.3.6"
mime_guess = "2"
[dev-dependencies]
tempfile = "3"
[build-dependencies]
connectrpc-build = "0.3"
[profile.release]
strip = true
lto = true
opt-level = "z"
codegen-units = 1
panic = "abort"

142
envd-rs/README.md Normal file
View File

@ -0,0 +1,142 @@
# envd (Rust)
Wrenn guest agent daemon — runs as PID 1 inside Cloud Hypervisor microVMs. Provides process management, filesystem operations, file transfer, port forwarding, and VM lifecycle control over Connect RPC and HTTP.
Rust rewrite of `envd/` (Go). Drop-in replacement — same wire protocol, same endpoints, same CLI flags.
## Prerequisites
- Rust 1.88+ (required by `connectrpc` 0.3.3)
- `protoc` (protobuf compiler, for proto codegen at build time)
- `musl-tools` (for static linking)
```bash
# Ubuntu/Debian
sudo apt install musl-tools protobuf-compiler
# Rust musl target
rustup target add x86_64-unknown-linux-musl
```
## Building
### Static binary (production — what goes into the rootfs)
```bash
cd envd-rs
ENVD_COMMIT=$(git rev-parse --short HEAD) \
cargo build --release --target x86_64-unknown-linux-musl
```
Output: `target/x86_64-unknown-linux-musl/release/envd`
Verify static linking:
```bash
file target/x86_64-unknown-linux-musl/release/envd
# should say: "statically linked"
ldd target/x86_64-unknown-linux-musl/release/envd
# should say: "not a dynamic executable"
```
### Debug binary (dev machine, dynamically linked)
```bash
cd envd-rs
cargo build
```
Run locally (outside a VM):
```bash
./target/debug/envd --port 49983
```
### Via Makefile (from repo root)
```bash
make build-envd # static musl release build
make build-envd-go # Go version (for comparison)
```
## CLI Flags
```
--port <PORT> Listen port [default: 49983]
--version Print version and exit
--commit Print git commit and exit
--cmd <CMD> Spawn a process at startup (e.g. --cmd "/bin/bash")
--cgroup-root <PATH> Cgroup v2 root [default: /sys/fs/cgroup]
```
## Endpoints
### HTTP
| Method | Path | Description |
|--------|---------------------|--------------------------------------|
| GET | `/health` | Health check, triggers post-restore |
| GET | `/metrics` | System metrics (CPU, memory, disk) |
| GET | `/envs` | Current environment variables |
| POST | `/init` | Host agent init (token, env, mounts) |
| POST | `/snapshot/prepare` | Quiesce before Cloud Hypervisor snapshot |
| GET | `/files` | Download file (gzip, range support) |
| POST | `/files` | Upload file(s) via multipart |
### Connect RPC (same port)
| Service | RPCs |
|------------|-------------------------------------------------------------------------|
| Process | List, Start, Connect, Update, StreamInput, SendInput, SendSignal, CloseStdin |
| Filesystem | Stat, MakeDir, Move, ListDir, Remove, WatchDir, CreateWatcher, GetWatcherEvents, RemoveWatcher |
## Architecture
```
42 files, ~4200 LOC Rust
Binary: ~4 MB (stripped, LTO, musl static)
src/
├── main.rs # Entry point, CLI, server setup
├── state.rs # Shared AppState
├── config.rs # Constants
├── conntracker.rs # TCP connection tracking for snapshot/restore
├── execcontext.rs # Default user/workdir/env
├── logging.rs # tracing-subscriber (JSON or pretty)
├── util.rs # AtomicMax
├── auth/ # Token, signing, middleware
├── crypto/ # SHA-256, SHA-512, HMAC
├── host/ # System metrics
├── http/ # Axum handlers (health, init, snapshot, files, encoding)
├── permissions/ # Path resolution, user lookup, chown
├── rpc/ # Connect RPC services
│ ├── pb.rs # Generated proto types
│ ├── process_*.rs # Process service + handler (PTY, pipe, broadcast)
│ ├── filesystem_*.rs # Filesystem service (stat, list, watch, mkdir, move, remove)
│ └── entry.rs # EntryInfo builder
├── port/ # Port subsystem
│ ├── conn.rs # /proc/net/tcp parser
│ ├── scanner.rs # Periodic TCP port scanner
│ ├── forwarder.rs # socat-based port forwarding
│ └── subsystem.rs # Lifecycle (start/stop/restart)
└── cgroups/ # Cgroup v2 manager (pty/user/socat groups)
```
## Updating the rootfs
After building the static binary, copy it into the rootfs:
```bash
bash scripts/update-minimal-rootfs.sh [rootfs_path]
```
With no argument it updates all four system base images; pass a path to target one.
Or manually (example path: the minimal-ubuntu image, platform team + template id 0):
```bash
sudo mount -o loop /var/lib/wrenn/images/teams/0000000000000000000000000/0000000000000000000000000/rootfs.ext4 /mnt
sudo cp target/x86_64-unknown-linux-musl/release/envd /mnt/usr/local/bin/envd
sudo umount /mnt
```

View File

@ -0,0 +1,12 @@
[package]
name = "buffa-types"
version = "0.3.0"
edition = "2024"
publish = false
[dependencies]
buffa = "0.3"
serde = { version = "1", features = ["derive"] }
[build-dependencies]
connectrpc-build = "0.3"

View File

@ -0,0 +1,9 @@
fn main() {
connectrpc_build::Config::new()
.files(&["/usr/include/google/protobuf/timestamp.proto"])
.includes(&["/usr/include"])
.include_file("_types.rs")
.emit_register_fn(false)
.compile()
.unwrap();
}

View File

@ -0,0 +1,6 @@
#![allow(dead_code, non_camel_case_types, unused_imports, clippy::derivable_impls)]
use ::buffa;
use ::serde;
include!(concat!(env!("OUT_DIR"), "/_types.rs"));

11
envd-rs/build.rs Normal file
View File

@ -0,0 +1,11 @@
fn main() {
connectrpc_build::Config::new()
.files(&[
"../proto/envd/process.proto",
"../proto/envd/filesystem.proto",
])
.includes(&["../proto/envd", "/usr/include"])
.include_file("_connectrpc.rs")
.compile()
.unwrap();
}

View File

@ -0,0 +1,3 @@
[toolchain]
channel = "stable"
targets = ["x86_64-unknown-linux-gnu", "x86_64-unknown-linux-musl"]

View File

@ -0,0 +1,56 @@
use std::sync::Arc;
use axum::extract::Request;
use axum::http::StatusCode;
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use serde_json::json;
use crate::auth::token::SecureToken;
const ACCESS_TOKEN_HEADER: &str = "x-access-token";
/// Paths excluded from general token auth.
/// Format: "METHOD/path"
const AUTH_EXCLUDED: &[&str] = &[
"GET/health",
"GET/files",
"POST/files",
"POST/init",
"POST/snapshot/prepare",
];
/// Axum middleware that checks X-Access-Token header.
pub async fn auth_layer(
request: Request,
next: Next,
access_token: Arc<SecureToken>,
) -> Response {
if access_token.is_set() {
let method = request.method().as_str();
let path = request.uri().path();
let key = format!("{method}{path}");
let is_excluded = AUTH_EXCLUDED.iter().any(|p| *p == key);
let header_val = request
.headers()
.get(ACCESS_TOKEN_HEADER)
.and_then(|v| v.to_str().ok())
.unwrap_or("");
if !access_token.equals(header_val) && !is_excluded {
tracing::error!("unauthorized access attempt");
return (
StatusCode::UNAUTHORIZED,
axum::Json(json!({
"code": 401,
"message": "unauthorized access, please provide a valid access token or method signing if supported"
})),
)
.into_response();
}
}
next.run(request).await
}

3
envd-rs/src/auth/mod.rs Normal file
View File

@ -0,0 +1,3 @@
pub mod token;
pub mod signing;
pub mod middleware;

210
envd-rs/src/auth/signing.rs Normal file
View File

@ -0,0 +1,210 @@
use crate::auth::token::SecureToken;
use crate::crypto;
use zeroize::Zeroize;
pub const READ_OPERATION: &str = "read";
pub const WRITE_OPERATION: &str = "write";
/// Generate a v1 signature: `v1_{sha256_base64(path:operation:username:token[:expiration])}`
pub fn generate_signature(
token: &SecureToken,
path: &str,
username: &str,
operation: &str,
expiration: Option<i64>,
) -> Result<String, &'static str> {
let mut token_bytes = token.bytes().ok_or("access token is not set")?;
let payload = match expiration {
Some(exp) => format!(
"{}:{}:{}:{}:{}",
path,
operation,
username,
String::from_utf8_lossy(&token_bytes),
exp
),
None => format!(
"{}:{}:{}:{}",
path,
operation,
username,
String::from_utf8_lossy(&token_bytes),
),
};
token_bytes.zeroize();
let hash = crypto::sha256::hash_without_prefix(payload.as_bytes());
Ok(format!("v1_{hash}"))
}
/// Validate a request's signing. Returns Ok(()) if valid.
pub fn validate_signing(
token: &SecureToken,
header_token: Option<&str>,
signature: Option<&str>,
signature_expiration: Option<i64>,
username: &str,
path: &str,
operation: &str,
) -> Result<(), String> {
if !token.is_set() {
return Ok(());
}
if let Some(ht) = header_token {
if !ht.is_empty() {
if token.equals(ht) {
return Ok(());
}
return Err("access token present in header but does not match".into());
}
}
let sig = signature.ok_or("missing signature query parameter")?;
let expected = generate_signature(token, path, username, operation, signature_expiration)
.map_err(|e| format!("error generating signing key: {e}"))?;
if expected != sig {
return Err("invalid signature".into());
}
if let Some(exp) = signature_expiration {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs() as i64;
if exp < now {
return Err("signature is already expired".into());
}
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
fn test_token(val: &[u8]) -> SecureToken {
let t = SecureToken::new();
t.set(val).unwrap();
t
}
fn far_future() -> i64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs() as i64
+ 3600
}
#[test]
fn generate_starts_with_v1() {
let token = test_token(b"secret");
let sig = generate_signature(&token, "/file", "root", READ_OPERATION, None).unwrap();
assert!(sig.starts_with("v1_"));
}
#[test]
fn generate_deterministic() {
let token = test_token(b"secret");
let s1 = generate_signature(&token, "/file", "root", READ_OPERATION, None).unwrap();
let s2 = generate_signature(&token, "/file", "root", READ_OPERATION, None).unwrap();
assert_eq!(s1, s2);
}
#[test]
fn generate_with_expiration_differs() {
let token = test_token(b"secret");
let without = generate_signature(&token, "/f", "u", READ_OPERATION, None).unwrap();
let with = generate_signature(&token, "/f", "u", READ_OPERATION, Some(9999)).unwrap();
assert_ne!(without, with);
}
#[test]
fn generate_unset_token_errors() {
let token = SecureToken::new();
assert!(generate_signature(&token, "/f", "u", READ_OPERATION, None).is_err());
}
#[test]
fn validate_no_token_set_passes() {
let token = SecureToken::new();
assert!(validate_signing(&token, None, None, None, "root", "/f", READ_OPERATION).is_ok());
}
#[test]
fn validate_correct_header_token() {
let token = test_token(b"secret");
assert!(validate_signing(&token, Some("secret"), None, None, "root", "/f", READ_OPERATION).is_ok());
}
#[test]
fn validate_wrong_header_token() {
let token = test_token(b"secret");
let result = validate_signing(&token, Some("wrong"), None, None, "root", "/f", READ_OPERATION);
assert!(result.is_err());
assert!(result.unwrap_err().contains("does not match"));
}
#[test]
fn validate_valid_signature() {
let token = test_token(b"secret");
let exp = far_future();
let sig = generate_signature(&token, "/file", "root", READ_OPERATION, Some(exp)).unwrap();
assert!(validate_signing(&token, None, Some(&sig), Some(exp), "root", "/file", READ_OPERATION).is_ok());
}
#[test]
fn validate_invalid_signature() {
let token = test_token(b"secret");
let result = validate_signing(&token, None, Some("v1_bad"), Some(far_future()), "root", "/f", READ_OPERATION);
assert!(result.is_err());
assert!(result.unwrap_err().contains("invalid signature"));
}
#[test]
fn validate_expired_signature() {
let token = test_token(b"secret");
let expired: i64 = 1_000_000;
let sig = generate_signature(&token, "/f", "root", READ_OPERATION, Some(expired)).unwrap();
let result = validate_signing(&token, None, Some(&sig), Some(expired), "root", "/f", READ_OPERATION);
assert!(result.is_err());
assert!(result.unwrap_err().contains("expired"));
}
#[test]
fn validate_missing_signature() {
let token = test_token(b"secret");
let result = validate_signing(&token, None, None, None, "root", "/f", READ_OPERATION);
assert!(result.is_err());
assert!(result.unwrap_err().contains("missing signature"));
}
#[test]
fn validate_empty_header_token_falls_through_to_signature() {
let token = test_token(b"secret");
let result = validate_signing(&token, Some(""), None, None, "root", "/f", READ_OPERATION);
assert!(result.is_err());
assert!(result.unwrap_err().contains("missing signature"));
}
#[test]
fn validate_valid_signature_no_expiration() {
let token = test_token(b"secret");
let sig = generate_signature(&token, "/file", "root", READ_OPERATION, None).unwrap();
assert!(validate_signing(&token, None, Some(&sig), None, "root", "/file", READ_OPERATION).is_ok());
}
#[test]
fn different_operations_produce_different_signatures() {
let token = test_token(b"secret");
let r = generate_signature(&token, "/f", "root", READ_OPERATION, None).unwrap();
let w = generate_signature(&token, "/f", "root", WRITE_OPERATION, None).unwrap();
assert_ne!(r, w);
}
}

256
envd-rs/src/auth/token.rs Normal file
View File

@ -0,0 +1,256 @@
use std::sync::RwLock;
use subtle::ConstantTimeEq;
use zeroize::Zeroize;
/// Secure token storage with constant-time comparison and zeroize-on-drop.
///
/// Mirrors Go's SecureToken backed by memguard.LockedBuffer.
/// In Rust we rely on `zeroize` for Drop-based zeroing.
pub struct SecureToken {
inner: RwLock<Option<Vec<u8>>>,
}
impl SecureToken {
pub fn new() -> Self {
Self {
inner: RwLock::new(None),
}
}
pub fn set(&self, token: &[u8]) -> Result<(), &'static str> {
if token.is_empty() {
return Err("empty token not allowed");
}
let mut guard = self.inner.write().unwrap();
if let Some(ref mut old) = *guard {
old.zeroize();
}
*guard = Some(token.to_vec());
Ok(())
}
pub fn is_set(&self) -> bool {
let guard = self.inner.read().unwrap();
guard.is_some()
}
/// Constant-time comparison.
pub fn equals(&self, other: &str) -> bool {
let guard = self.inner.read().unwrap();
match guard.as_ref() {
Some(buf) => buf.as_slice().ct_eq(other.as_bytes()).into(),
None => false,
}
}
/// Constant-time comparison with another SecureToken.
pub fn equals_secure(&self, other: &SecureToken) -> bool {
let other_bytes = match other.bytes() {
Some(b) => b,
None => return false,
};
let guard = self.inner.read().unwrap();
let result = match guard.as_ref() {
Some(buf) => buf.as_slice().ct_eq(&other_bytes).into(),
None => false,
};
// other_bytes dropped here, Vec<u8> doesn't auto-zeroize but
// we accept this — same as Go's `defer memguard.WipeBytes(otherBytes)`
result
}
/// Returns a copy of the token bytes (for signature generation).
pub fn bytes(&self) -> Option<Vec<u8>> {
let guard = self.inner.read().unwrap();
guard.as_ref().map(|b| b.clone())
}
/// Transfer token from another SecureToken, clearing the source.
pub fn take_from(&self, src: &SecureToken) {
let taken = {
let mut src_guard = src.inner.write().unwrap();
src_guard.take()
};
let mut guard = self.inner.write().unwrap();
if let Some(ref mut old) = *guard {
old.zeroize();
}
*guard = taken;
}
pub fn destroy(&self) {
let mut guard = self.inner.write().unwrap();
if let Some(ref mut buf) = *guard {
buf.zeroize();
}
*guard = None;
}
}
impl Drop for SecureToken {
fn drop(&mut self) {
if let Ok(mut guard) = self.inner.write() {
if let Some(ref mut buf) = *guard {
buf.zeroize();
}
}
}
}
/// Deserialize from JSON string, matching Go's UnmarshalJSON behavior.
/// Expects a quoted JSON string. Rejects escape sequences.
impl SecureToken {
pub fn from_json_bytes(data: &mut [u8]) -> Result<Self, &'static str> {
if data.len() < 2 || data[0] != b'"' || data[data.len() - 1] != b'"' {
data.zeroize();
return Err("invalid secure token JSON string");
}
let content = &data[1..data.len() - 1];
if content.contains(&b'\\') {
data.zeroize();
return Err("invalid secure token: unexpected escape sequence");
}
if content.is_empty() {
data.zeroize();
return Err("empty token not allowed");
}
let token = Self::new();
token.set(content).map_err(|_| "failed to set token")?;
data.zeroize();
Ok(token)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn new_is_unset() {
let t = SecureToken::new();
assert!(!t.is_set());
assert!(!t.equals("anything"));
}
#[test]
fn set_and_equals() {
let t = SecureToken::new();
t.set(b"secret").unwrap();
assert!(t.is_set());
assert!(t.equals("secret"));
assert!(!t.equals("wrong"));
}
#[test]
fn set_empty_errors() {
let t = SecureToken::new();
assert!(t.set(b"").is_err());
assert!(!t.is_set());
}
#[test]
fn set_overwrites_previous() {
let t = SecureToken::new();
t.set(b"first").unwrap();
t.set(b"second").unwrap();
assert!(!t.equals("first"));
assert!(t.equals("second"));
}
#[test]
fn destroy_clears() {
let t = SecureToken::new();
t.set(b"secret").unwrap();
t.destroy();
assert!(!t.is_set());
assert!(!t.equals("secret"));
}
#[test]
fn bytes_returns_copy() {
let t = SecureToken::new();
assert!(t.bytes().is_none());
t.set(b"hello").unwrap();
assert_eq!(t.bytes().unwrap(), b"hello");
}
#[test]
fn take_from_transfers_and_clears_source() {
let src = SecureToken::new();
src.set(b"token").unwrap();
let dst = SecureToken::new();
dst.take_from(&src);
assert!(!src.is_set());
assert!(dst.equals("token"));
}
#[test]
fn take_from_overwrites_existing() {
let src = SecureToken::new();
src.set(b"new").unwrap();
let dst = SecureToken::new();
dst.set(b"old").unwrap();
dst.take_from(&src);
assert!(dst.equals("new"));
assert!(!dst.equals("old"));
}
#[test]
fn equals_secure_matching() {
let a = SecureToken::new();
a.set(b"same").unwrap();
let b = SecureToken::new();
b.set(b"same").unwrap();
assert!(a.equals_secure(&b));
}
#[test]
fn equals_secure_different() {
let a = SecureToken::new();
a.set(b"one").unwrap();
let b = SecureToken::new();
b.set(b"two").unwrap();
assert!(!a.equals_secure(&b));
}
#[test]
fn equals_secure_unset() {
let a = SecureToken::new();
let b = SecureToken::new();
assert!(!a.equals_secure(&b));
}
#[test]
fn from_json_bytes_valid() {
let mut data = b"\"mysecret\"".to_vec();
let t = SecureToken::from_json_bytes(&mut data).unwrap();
assert!(t.equals("mysecret"));
assert!(data.iter().all(|&b| b == 0));
}
#[test]
fn from_json_bytes_rejects_missing_quotes() {
let mut data = b"noquotes".to_vec();
assert!(SecureToken::from_json_bytes(&mut data).is_err());
assert!(data.iter().all(|&b| b == 0));
}
#[test]
fn from_json_bytes_rejects_escape_sequences() {
let mut data = b"\"has\\nescapes\"".to_vec();
assert!(SecureToken::from_json_bytes(&mut data).is_err());
assert!(data.iter().all(|&b| b == 0));
}
#[test]
fn from_json_bytes_rejects_empty_content() {
let mut data = b"\"\"".to_vec();
assert!(SecureToken::from_json_bytes(&mut data).is_err());
assert!(data.iter().all(|&b| b == 0));
}
}

View File

@ -0,0 +1,66 @@
use std::collections::HashMap;
use std::fs;
use std::os::unix::io::{OwnedFd, RawFd};
use std::path::PathBuf;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ProcessType {
Pty,
User,
Socat,
}
pub trait CgroupManager: Send + Sync {
fn get_fd(&self, proc_type: ProcessType) -> Option<RawFd>;
}
pub struct Cgroup2Manager {
fds: HashMap<ProcessType, OwnedFd>,
}
impl Cgroup2Manager {
pub fn new(root: &str, configs: &[(ProcessType, &str, &[(&str, &str)])]) -> Result<Self, String> {
let mut fds = HashMap::new();
for (proc_type, sub_path, properties) in configs {
let full_path = PathBuf::from(root).join(sub_path);
fs::create_dir_all(&full_path).map_err(|e| {
format!("failed to create cgroup {}: {e}", full_path.display())
})?;
for (name, value) in *properties {
let prop_path = full_path.join(name);
fs::write(&prop_path, value).map_err(|e| {
format!("failed to write cgroup property {}: {e}", prop_path.display())
})?;
}
let fd = nix::fcntl::open(
&full_path,
nix::fcntl::OFlag::O_RDONLY,
nix::sys::stat::Mode::empty(),
)
.map_err(|e| format!("failed to open cgroup {}: {e}", full_path.display()))?;
fds.insert(*proc_type, fd);
}
Ok(Self { fds })
}
}
impl CgroupManager for Cgroup2Manager {
fn get_fd(&self, proc_type: ProcessType) -> Option<RawFd> {
use std::os::unix::io::AsRawFd;
self.fds.get(&proc_type).map(|fd| fd.as_raw_fd())
}
}
pub struct NoopCgroupManager;
impl CgroupManager for NoopCgroupManager {
fn get_fd(&self, _proc_type: ProcessType) -> Option<RawFd> {
None
}
}

11
envd-rs/src/config.rs Normal file
View File

@ -0,0 +1,11 @@
use std::time::Duration;
pub const DEFAULT_PORT: u16 = 49983;
pub const IDLE_TIMEOUT: Duration = Duration::from_secs(640);
pub const CORS_MAX_AGE: Duration = Duration::from_secs(7200);
pub const PORT_SCANNER_INTERVAL: Duration = Duration::from_millis(1000);
pub const DEFAULT_USER: &str = "root";
pub const WRENN_RUN_DIR: &str = "/run/wrenn";
pub const KILOBYTE: u64 = 1024;
pub const MEGABYTE: u64 = 1024 * KILOBYTE;

View File

@ -0,0 +1,70 @@
use std::collections::HashSet;
use std::sync::Mutex;
/// Tracks active TCP connections.
pub struct ConnTracker {
inner: Mutex<ConnTrackerInner>,
}
struct ConnTrackerInner {
active: HashSet<u64>,
next_id: u64,
}
impl ConnTracker {
pub fn new() -> Self {
Self {
inner: Mutex::new(ConnTrackerInner {
active: HashSet::new(),
next_id: 0,
}),
}
}
pub fn register_connection(&self) -> u64 {
let mut inner = self.inner.lock().unwrap();
let id = inner.next_id;
inner.next_id += 1;
inner.active.insert(id);
id
}
pub fn remove_connection(&self, id: u64) {
let mut inner = self.inner.lock().unwrap();
inner.active.remove(&id);
}
#[cfg(test)]
fn active_count(&self) -> usize {
self.inner.lock().unwrap().active.len()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn register_assigns_sequential_ids() {
let ct = ConnTracker::new();
assert_eq!(ct.register_connection(), 0);
assert_eq!(ct.register_connection(), 1);
assert_eq!(ct.register_connection(), 2);
}
#[test]
fn remove_clears_active() {
let ct = ConnTracker::new();
let id = ct.register_connection();
assert_eq!(ct.active_count(), 1);
ct.remove_connection(id);
assert_eq!(ct.active_count(), 0);
}
#[test]
fn remove_nonexistent_is_noop() {
let ct = ConnTracker::new();
ct.remove_connection(999);
assert_eq!(ct.active_count(), 0);
}
}

View File

@ -0,0 +1,43 @@
use hmac::{Hmac, Mac};
use sha2::Sha256;
type HmacSha256 = Hmac<Sha256>;
pub fn compute(key: &[u8], data: &[u8]) -> String {
let mut mac = HmacSha256::new_from_slice(key).expect("HMAC accepts any key length");
mac.update(data);
let result = mac.finalize();
hex::encode(result.into_bytes())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn rfc4231_tc1() {
let key = &[0x0b; 20];
let data = b"Hi There";
assert_eq!(
compute(key, data),
"b0344c61d8db38535ca8afceaf0bf12b881dc200c9833da726e9376c2e32cff7"
);
}
#[test]
fn rfc4231_tc2() {
let key = b"Jefe";
let data = b"what do ya want for nothing?";
assert_eq!(
compute(key, data),
"5bdcc146bf60754e6a042426089575c75a003f089d2739839dec58b964ec3843"
);
}
#[test]
fn output_is_64_hex_chars() {
let result = compute(b"key", b"data");
assert_eq!(result.len(), 64);
assert!(result.chars().all(|c| c.is_ascii_hexdigit()));
}
}

View File

@ -0,0 +1,3 @@
pub mod sha256;
pub mod sha512;
pub mod hmac_sha256;

View File

@ -0,0 +1,54 @@
use base64::Engine;
use base64::engine::general_purpose::STANDARD_NO_PAD;
use sha2::{Digest, Sha256};
pub fn hash(data: &[u8]) -> String {
let h = Sha256::digest(data);
let encoded = STANDARD_NO_PAD.encode(h);
format!("$sha256${encoded}")
}
pub fn hash_without_prefix(data: &[u8]) -> String {
let h = Sha256::digest(data);
STANDARD_NO_PAD.encode(h)
}
#[cfg(test)]
mod tests {
use super::*;
const VECTORS: &[(&[u8], &str)] = &[
(b"", "47DEQpj8HBSa+/TImW+5JCeuQeRkm5NMpJWZG3hSuFU"),
(b"abc", "ungWv48Bz+pBQUDeXa4iI7ADYaOWF3qctBD/YfIAFa0"),
(b"abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq", "JI1qYdIGOLjlwCaTDD5gOaM85Flk/yFn9uzt1BnbBsE"),
];
#[test]
fn known_answer_with_prefix() {
for (input, expected_b64) in VECTORS {
let result = hash(input);
assert_eq!(result, format!("$sha256${expected_b64}"), "input: {:?}", String::from_utf8_lossy(input));
}
}
#[test]
fn known_answer_without_prefix() {
for (input, expected_b64) in VECTORS {
let result = hash_without_prefix(input);
assert_eq!(result, *expected_b64, "input: {:?}", String::from_utf8_lossy(input));
}
}
#[test]
fn no_base64_padding() {
for (input, _) in VECTORS {
assert!(!hash(input).contains('='));
assert!(!hash_without_prefix(input).contains('='));
}
}
#[test]
fn deterministic() {
assert_eq!(hash(b"test"), hash(b"test"));
}
}

View File

@ -0,0 +1,43 @@
use sha2::{Digest, Sha512};
pub fn hash_access_token(token: &str) -> String {
let h = Sha512::digest(token.as_bytes());
hex::encode(h)
}
pub fn hash_access_token_bytes(token: &[u8]) -> String {
let h = Sha512::digest(token);
hex::encode(h)
}
#[cfg(test)]
mod tests {
use super::*;
const VECTORS: &[(&str, &str)] = &[
("", "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e"),
("abc", "ddaf35a193617abacc417349ae20413112e6fa4e89a97ea20a9eeee64b55d39a2192992a274fc1a836ba3c23a3feebbd454d4423643ce80e2a9ac94fa54ca49f"),
("abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq", "204a8fc6dda82f0a0ced7beb8e08a41657c16ef468b228a8279be331a703c33596fd15c13b1b07f9aa1d3bea57789ca031ad85c7a71dd70354ec631238ca3445"),
];
#[test]
fn known_answer() {
for (input, expected) in VECTORS {
assert_eq!(hash_access_token(input), *expected, "input: {input:?}");
}
}
#[test]
fn str_and_bytes_agree() {
for (input, _) in VECTORS {
assert_eq!(hash_access_token(input), hash_access_token_bytes(input.as_bytes()));
}
}
#[test]
fn output_is_lowercase_hex_128_chars() {
let h = hash_access_token("anything");
assert_eq!(h.len(), 128);
assert!(h.chars().all(|c| c.is_ascii_hexdigit() && !c.is_ascii_uppercase()));
}
}

118
envd-rs/src/execcontext.rs Normal file
View File

@ -0,0 +1,118 @@
use dashmap::DashMap;
use std::sync::{Arc, RwLock};
pub struct Defaults {
pub env_vars: Arc<DashMap<String, String>>,
user: RwLock<String>,
workdir: RwLock<Option<String>>,
}
impl Defaults {
pub fn new(user: &str) -> Self {
Self {
env_vars: Arc::new(DashMap::new()),
user: RwLock::new(user.to_string()),
workdir: RwLock::new(None),
}
}
pub fn user(&self) -> String {
self.user.read().unwrap().clone()
}
pub fn set_user(&self, user: String) {
*self.user.write().unwrap() = user;
}
pub fn workdir(&self) -> Option<String> {
self.workdir.read().unwrap().clone()
}
pub fn set_workdir(&self, workdir: Option<String>) {
*self.workdir.write().unwrap() = workdir;
}
}
pub fn resolve_default_workdir(workdir: &str, default_workdir: Option<&str>) -> String {
if !workdir.is_empty() {
return workdir.to_string();
}
if let Some(dw) = default_workdir {
return dw.to_string();
}
String::new()
}
pub fn resolve_default_username<'a>(
username: Option<&'a str>,
default_username: &'a str,
) -> Result<&'a str, &'static str> {
if let Some(u) = username {
return Ok(u);
}
if !default_username.is_empty() {
return Ok(default_username);
}
Err("username not provided")
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn workdir_explicit_overrides_default() {
assert_eq!(resolve_default_workdir("/explicit", Some("/default")), "/explicit");
}
#[test]
fn workdir_empty_uses_default() {
assert_eq!(resolve_default_workdir("", Some("/default")), "/default");
}
#[test]
fn workdir_empty_no_default_returns_empty() {
assert_eq!(resolve_default_workdir("", None), "");
}
#[test]
fn workdir_explicit_ignores_none_default() {
assert_eq!(resolve_default_workdir("/explicit", None), "/explicit");
}
#[test]
fn username_explicit_returns_explicit() {
assert_eq!(resolve_default_username(Some("root"), "wrenn").unwrap(), "root");
}
#[test]
fn username_none_uses_default() {
assert_eq!(resolve_default_username(None, "wrenn").unwrap(), "wrenn");
}
#[test]
fn username_none_empty_default_errors() {
assert!(resolve_default_username(None, "").is_err());
}
#[test]
fn username_some_overrides_empty_default() {
assert_eq!(resolve_default_username(Some("root"), "").unwrap(), "root");
}
#[test]
fn defaults_user_set_and_get() {
let d = Defaults::new("initial");
assert_eq!(d.user(), "initial");
d.set_user("changed".into());
assert_eq!(d.user(), "changed");
}
#[test]
fn defaults_workdir_initially_none() {
let d = Defaults::new("user");
assert!(d.workdir().is_none());
d.set_workdir(Some("/home".into()));
assert_eq!(d.workdir().unwrap(), "/home");
}
}

View File

@ -0,0 +1,336 @@
use axum::http::Request;
const ENCODING_GZIP: &str = "gzip";
const ENCODING_IDENTITY: &str = "identity";
const ENCODING_WILDCARD: &str = "*";
const SUPPORTED_ENCODINGS: &[&str] = &[ENCODING_GZIP];
struct EncodingWithQuality {
encoding: String,
quality: f64,
}
fn parse_encoding_with_quality(value: &str) -> EncodingWithQuality {
let value = value.trim();
let mut quality = 1.0;
if let Some(idx) = value.find(';') {
let params = &value[idx + 1..];
let enc = value[..idx].trim();
for param in params.split(';') {
let param = param.trim();
if let Some(stripped) = param.strip_prefix("q=").or_else(|| param.strip_prefix("Q=")) {
if let Ok(q) = stripped.parse::<f64>() {
quality = q;
}
}
}
return EncodingWithQuality {
encoding: enc.to_ascii_lowercase(),
quality,
};
}
EncodingWithQuality {
encoding: value.to_ascii_lowercase(),
quality,
}
}
fn parse_accept_encoding_header(header: &str) -> (Vec<EncodingWithQuality>, bool) {
if header.is_empty() {
return (Vec::new(), false);
}
let encodings: Vec<EncodingWithQuality> =
header.split(',').map(|v| parse_encoding_with_quality(v)).collect();
let mut identity_rejected = false;
let mut identity_explicitly_accepted = false;
let mut wildcard_rejected = false;
for eq in &encodings {
match eq.encoding.as_str() {
ENCODING_IDENTITY => {
if eq.quality == 0.0 {
identity_rejected = true;
} else {
identity_explicitly_accepted = true;
}
}
ENCODING_WILDCARD => {
if eq.quality == 0.0 {
wildcard_rejected = true;
}
}
_ => {}
}
}
if wildcard_rejected && !identity_explicitly_accepted {
identity_rejected = true;
}
(encodings, identity_rejected)
}
pub fn is_identity_acceptable<B>(r: &Request<B>) -> bool {
let header = r
.headers()
.get("accept-encoding")
.and_then(|v| v.to_str().ok())
.unwrap_or("");
let (_, rejected) = parse_accept_encoding_header(header);
!rejected
}
pub fn parse_accept_encoding<B>(r: &Request<B>) -> Result<&'static str, String> {
let header = r
.headers()
.get("accept-encoding")
.and_then(|v| v.to_str().ok())
.unwrap_or("");
if header.is_empty() {
return Ok(ENCODING_IDENTITY);
}
let (mut encodings, identity_rejected) = parse_accept_encoding_header(header);
encodings.sort_by(|a, b| b.quality.partial_cmp(&a.quality).unwrap_or(std::cmp::Ordering::Equal));
for eq in &encodings {
if eq.quality == 0.0 {
continue;
}
if eq.encoding == ENCODING_IDENTITY {
return Ok(ENCODING_IDENTITY);
}
if eq.encoding == ENCODING_WILDCARD {
if identity_rejected && !SUPPORTED_ENCODINGS.is_empty() {
return Ok(SUPPORTED_ENCODINGS[0]);
}
return Ok(ENCODING_IDENTITY);
}
if eq.encoding == ENCODING_GZIP {
return Ok(ENCODING_GZIP);
}
}
if !identity_rejected {
return Ok(ENCODING_IDENTITY);
}
Err(format!("no acceptable encoding found, supported: {SUPPORTED_ENCODINGS:?}"))
}
pub fn parse_content_encoding<B>(r: &Request<B>) -> Result<&'static str, String> {
let header = r
.headers()
.get("content-encoding")
.and_then(|v| v.to_str().ok())
.unwrap_or("");
if header.is_empty() {
return Ok(ENCODING_IDENTITY);
}
let encoding = header.trim().to_ascii_lowercase();
if encoding == ENCODING_IDENTITY {
return Ok(ENCODING_IDENTITY);
}
if SUPPORTED_ENCODINGS.contains(&encoding.as_str()) {
return Ok(ENCODING_GZIP);
}
Err(format!("unsupported Content-Encoding: {header}, supported: {SUPPORTED_ENCODINGS:?}"))
}
#[cfg(test)]
mod tests {
use super::*;
use axum::http::Request;
fn req_with_accept(v: &str) -> Request<()> {
Request::builder()
.header("accept-encoding", v)
.body(())
.unwrap()
}
fn req_with_content(v: &str) -> Request<()> {
Request::builder()
.header("content-encoding", v)
.body(())
.unwrap()
}
fn req_no_headers() -> Request<()> {
Request::builder().body(()).unwrap()
}
// parse_encoding_with_quality
#[test]
fn encoding_quality_default_1() {
let eq = parse_encoding_with_quality("gzip");
assert_eq!(eq.encoding, "gzip");
assert_eq!(eq.quality, 1.0);
}
#[test]
fn encoding_quality_explicit() {
let eq = parse_encoding_with_quality("gzip;q=0.8");
assert_eq!(eq.encoding, "gzip");
assert_eq!(eq.quality, 0.8);
}
#[test]
fn encoding_quality_case_insensitive() {
let eq = parse_encoding_with_quality("GZIP;Q=0.5");
assert_eq!(eq.encoding, "gzip");
assert_eq!(eq.quality, 0.5);
}
#[test]
fn encoding_quality_zero() {
let eq = parse_encoding_with_quality("gzip;q=0");
assert_eq!(eq.quality, 0.0);
}
#[test]
fn encoding_quality_whitespace_trimmed() {
let eq = parse_encoding_with_quality(" gzip ; q=0.9 ");
assert_eq!(eq.encoding, "gzip");
assert_eq!(eq.quality, 0.9);
}
// parse_accept_encoding_header
#[test]
fn accept_header_empty() {
let (encs, rejected) = parse_accept_encoding_header("");
assert!(encs.is_empty());
assert!(!rejected);
}
#[test]
fn accept_header_identity_q0_rejects() {
let (_, rejected) = parse_accept_encoding_header("identity;q=0");
assert!(rejected);
}
#[test]
fn accept_header_wildcard_q0_rejects_identity() {
let (_, rejected) = parse_accept_encoding_header("*;q=0");
assert!(rejected);
}
#[test]
fn accept_header_wildcard_q0_but_identity_explicit_accepted() {
let (_, rejected) = parse_accept_encoding_header("*;q=0, identity");
assert!(!rejected);
}
// parse_accept_encoding (full)
#[test]
fn accept_encoding_no_header_returns_identity() {
assert_eq!(parse_accept_encoding(&req_no_headers()).unwrap(), "identity");
}
#[test]
fn accept_encoding_gzip() {
assert_eq!(parse_accept_encoding(&req_with_accept("gzip")).unwrap(), "gzip");
}
#[test]
fn accept_encoding_identity_explicit() {
assert_eq!(parse_accept_encoding(&req_with_accept("identity")).unwrap(), "identity");
}
#[test]
fn accept_encoding_gzip_higher_quality() {
assert_eq!(
parse_accept_encoding(&req_with_accept("identity;q=0.1, gzip;q=0.9")).unwrap(),
"gzip"
);
}
#[test]
fn accept_encoding_wildcard_returns_identity() {
assert_eq!(parse_accept_encoding(&req_with_accept("*")).unwrap(), "identity");
}
#[test]
fn accept_encoding_wildcard_identity_rejected_returns_gzip() {
assert_eq!(
parse_accept_encoding(&req_with_accept("identity;q=0, *")).unwrap(),
"gzip"
);
}
#[test]
fn accept_encoding_all_rejected_errors() {
assert!(parse_accept_encoding(&req_with_accept("identity;q=0, *;q=0")).is_err());
}
#[test]
fn accept_encoding_unsupported_only_falls_to_identity() {
assert_eq!(parse_accept_encoding(&req_with_accept("br")).unwrap(), "identity");
}
// is_identity_acceptable
#[test]
fn identity_acceptable_no_header() {
assert!(is_identity_acceptable(&req_no_headers()));
}
#[test]
fn identity_acceptable_gzip_only() {
assert!(is_identity_acceptable(&req_with_accept("gzip")));
}
#[test]
fn identity_not_acceptable_identity_q0() {
assert!(!is_identity_acceptable(&req_with_accept("identity;q=0")));
}
#[test]
fn identity_not_acceptable_wildcard_q0() {
assert!(!is_identity_acceptable(&req_with_accept("*;q=0")));
}
#[test]
fn identity_acceptable_wildcard_q0_but_identity_explicit() {
assert!(is_identity_acceptable(&req_with_accept("*;q=0, identity")));
}
// parse_content_encoding
#[test]
fn content_encoding_empty_returns_identity() {
assert_eq!(parse_content_encoding(&req_no_headers()).unwrap(), "identity");
}
#[test]
fn content_encoding_gzip() {
assert_eq!(parse_content_encoding(&req_with_content("gzip")).unwrap(), "gzip");
}
#[test]
fn content_encoding_identity_explicit() {
assert_eq!(parse_content_encoding(&req_with_content("identity")).unwrap(), "identity");
}
#[test]
fn content_encoding_unsupported_errors() {
assert!(parse_content_encoding(&req_with_content("br")).is_err());
}
#[test]
fn content_encoding_case_insensitive() {
assert_eq!(parse_content_encoding(&req_with_content("GZIP")).unwrap(), "gzip");
}
}

25
envd-rs/src/http/envs.rs Normal file
View File

@ -0,0 +1,25 @@
use std::collections::HashMap;
use std::sync::Arc;
use axum::Json;
use axum::extract::State;
use axum::http::header;
use axum::response::IntoResponse;
use crate::state::AppState;
pub async fn get_envs(State(state): State<Arc<AppState>>) -> impl IntoResponse {
tracing::debug!("getting env vars");
let envs: HashMap<String, String> = state
.defaults
.env_vars
.iter()
.map(|entry| (entry.key().clone(), entry.value().clone()))
.collect();
(
[(header::CACHE_CONTROL, "no-store")],
Json(envs),
)
}

20
envd-rs/src/http/error.rs Normal file
View File

@ -0,0 +1,20 @@
use axum::Json;
use axum::http::StatusCode;
use axum::response::IntoResponse;
use serde::Serialize;
#[derive(Serialize)]
struct ErrorBody {
code: u16,
message: String,
}
pub fn json_error(status: StatusCode, message: &str) -> impl IntoResponse {
(
status,
Json(ErrorBody {
code: status.as_u16(),
message: message.to_string(),
}),
)
}

447
envd-rs/src/http/files.rs Normal file
View File

@ -0,0 +1,447 @@
use std::io::Write as _;
use std::path::Path;
use std::sync::Arc;
use axum::body::Body;
use axum::extract::{FromRequest, Query, Request, State};
use axum::http::{StatusCode, header};
use axum::response::{IntoResponse, Response};
use serde::{Deserialize, Serialize};
use crate::auth::signing;
use crate::execcontext;
use crate::http::encoding;
use crate::permissions::path::{ensure_dirs, expand_and_resolve};
use crate::permissions::user::lookup_user;
use crate::state::AppState;
const ACCESS_TOKEN_HEADER: &str = "x-access-token";
#[derive(Deserialize)]
pub struct FileParams {
pub path: Option<String>,
pub username: Option<String>,
pub signature: Option<String>,
pub signature_expiration: Option<i64>,
}
#[derive(Serialize)]
struct EntryInfo {
path: String,
name: String,
r#type: &'static str,
}
fn json_error(status: StatusCode, msg: &str) -> Response {
let body = serde_json::json!({ "code": status.as_u16(), "message": msg });
(status, axum::Json(body)).into_response()
}
fn extract_header_token(req: &Request) -> Option<&str> {
req.headers()
.get(ACCESS_TOKEN_HEADER)
.and_then(|v| v.to_str().ok())
}
fn validate_file_signing(
state: &AppState,
header_token: Option<&str>,
params: &FileParams,
path: &str,
operation: &str,
username: &str,
) -> Result<(), String> {
signing::validate_signing(
&state.access_token,
header_token,
params.signature.as_deref(),
params.signature_expiration,
username,
path,
operation,
)
}
/// GET /files — download a file
pub async fn get_files(
State(state): State<Arc<AppState>>,
Query(params): Query<FileParams>,
req: Request,
) -> Response {
let path_str = params.path.as_deref().unwrap_or("");
let header_token = extract_header_token(&req);
let default_user = state.defaults.user();
let username = match execcontext::resolve_default_username(
params.username.as_deref(),
&default_user,
) {
Ok(u) => u.to_string(),
Err(e) => return json_error(StatusCode::BAD_REQUEST, e),
};
if let Err(e) = validate_file_signing(
&state,
header_token,
&params,
path_str,
signing::READ_OPERATION,
&username,
) {
return json_error(StatusCode::UNAUTHORIZED, &e);
}
let user = match lookup_user(&username) {
Ok(u) => u,
Err(e) => return json_error(StatusCode::UNAUTHORIZED, &e),
};
let home_dir = user.dir.to_string_lossy().to_string();
let default_workdir = state.defaults.workdir();
let resolved = match expand_and_resolve(path_str, &home_dir, default_workdir.as_deref())
{
Ok(p) => p,
Err(e) => return json_error(StatusCode::BAD_REQUEST, &e),
};
let meta = match std::fs::metadata(&resolved) {
Ok(m) => m,
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
return json_error(
StatusCode::NOT_FOUND,
&format!("path '{}' does not exist", resolved),
);
}
Err(e) => {
return json_error(
StatusCode::INTERNAL_SERVER_ERROR,
&format!("error checking path: {e}"),
);
}
};
if meta.is_dir() {
return json_error(
StatusCode::BAD_REQUEST,
&format!("path '{}' is a directory", resolved),
);
}
if !meta.file_type().is_file() {
return json_error(
StatusCode::BAD_REQUEST,
&format!("path '{}' is not a regular file", resolved),
);
}
let accept_enc = match encoding::parse_accept_encoding(&req) {
Ok(e) => e,
Err(e) => return json_error(StatusCode::NOT_ACCEPTABLE, &e),
};
let has_range_or_conditional = req.headers().get("range").is_some()
|| req.headers().get("if-modified-since").is_some()
|| req.headers().get("if-none-match").is_some()
|| req.headers().get("if-range").is_some();
let use_encoding = if has_range_or_conditional {
if !encoding::is_identity_acceptable(&req) {
return json_error(
StatusCode::NOT_ACCEPTABLE,
"identity encoding not acceptable for Range or conditional request",
);
}
"identity"
} else {
accept_enc
};
let file_data = match std::fs::read(&resolved) {
Ok(d) => d,
Err(e) => {
return json_error(
StatusCode::INTERNAL_SERVER_ERROR,
&format!("error reading file: {e}"),
);
}
};
let filename = Path::new(&resolved)
.file_name()
.map(|n| n.to_string_lossy().to_string())
.unwrap_or_default();
let content_disposition = format!("inline; filename=\"{}\"", filename);
let content_type = mime_guess::from_path(&resolved)
.first_raw()
.unwrap_or("application/octet-stream");
if use_encoding == "gzip" {
let mut encoder =
flate2::write::GzEncoder::new(Vec::new(), flate2::Compression::default());
if let Err(e) = encoder.write_all(&file_data) {
return json_error(
StatusCode::INTERNAL_SERVER_ERROR,
&format!("gzip encoding error: {e}"),
);
}
let compressed = match encoder.finish() {
Ok(d) => d,
Err(e) => {
return json_error(
StatusCode::INTERNAL_SERVER_ERROR,
&format!("gzip finish error: {e}"),
);
}
};
return Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, content_type)
.header(header::CONTENT_ENCODING, "gzip")
.header(header::CONTENT_DISPOSITION, content_disposition)
.header(header::VARY, "Accept-Encoding")
.body(Body::from(compressed))
.unwrap();
}
Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, content_type)
.header(header::CONTENT_DISPOSITION, content_disposition)
.header(header::VARY, "Accept-Encoding")
.header(header::CONTENT_LENGTH, file_data.len())
.body(Body::from(file_data))
.unwrap()
}
/// POST /files — upload file(s) via multipart
pub async fn post_files(
State(state): State<Arc<AppState>>,
Query(params): Query<FileParams>,
req: Request,
) -> Response {
let path_str = params.path.as_deref().unwrap_or("");
let header_token = extract_header_token(&req);
let default_user = state.defaults.user();
let username = match execcontext::resolve_default_username(
params.username.as_deref(),
&default_user,
) {
Ok(u) => u.to_string(),
Err(e) => return json_error(StatusCode::BAD_REQUEST, e),
};
if let Err(e) = validate_file_signing(
&state,
header_token,
&params,
path_str,
signing::WRITE_OPERATION,
&username,
) {
return json_error(StatusCode::UNAUTHORIZED, &e);
}
let user = match lookup_user(&username) {
Ok(u) => u,
Err(e) => return json_error(StatusCode::UNAUTHORIZED, &e),
};
let home_dir = user.dir.to_string_lossy().to_string();
let uid = user.uid;
let gid = user.gid;
let content_enc = match encoding::parse_content_encoding(&req) {
Ok(e) => e,
Err(e) => return json_error(StatusCode::BAD_REQUEST, &e),
};
let mut multipart = match axum::extract::Multipart::from_request(req, &()).await {
Ok(m) => m,
Err(e) => {
return json_error(
StatusCode::INTERNAL_SERVER_ERROR,
&format!("error parsing multipart: {e}"),
);
}
};
let mut uploaded: Vec<EntryInfo> = Vec::new();
let default_workdir = state.defaults.workdir();
while let Ok(Some(field)) = multipart.next_field().await {
let field_name = field.name().unwrap_or("").to_string();
if field_name != "file" {
continue;
}
let file_path = if !path_str.is_empty() {
match expand_and_resolve(path_str, &home_dir, default_workdir.as_deref()) {
Ok(p) => p,
Err(e) => return json_error(StatusCode::BAD_REQUEST, &e),
}
} else {
let fname = field
.file_name()
.unwrap_or("upload")
.to_string();
match expand_and_resolve(&fname, &home_dir, default_workdir.as_deref()) {
Ok(p) => p,
Err(e) => return json_error(StatusCode::BAD_REQUEST, &e),
}
};
if uploaded.iter().any(|e| e.path == file_path) {
return json_error(
StatusCode::BAD_REQUEST,
&format!("cannot upload multiple files to same path '{}'", file_path),
);
}
let raw_bytes = match field.bytes().await {
Ok(b) => b,
Err(e) => {
return json_error(
StatusCode::INTERNAL_SERVER_ERROR,
&format!("error reading field: {e}"),
);
}
};
let data = if content_enc == "gzip" {
use std::io::Read;
let mut decoder = flate2::read::GzDecoder::new(&raw_bytes[..]);
let mut buf = Vec::new();
match decoder.read_to_end(&mut buf) {
Ok(_) => buf,
Err(e) => {
return json_error(
StatusCode::BAD_REQUEST,
&format!("gzip decompression failed: {e}"),
);
}
}
} else {
raw_bytes.to_vec()
};
if let Err(e) = process_file(&file_path, &data, uid, gid) {
let (status, msg) = e;
return json_error(status, &msg);
}
let name = Path::new(&file_path)
.file_name()
.map(|n| n.to_string_lossy().to_string())
.unwrap_or_default();
uploaded.push(EntryInfo {
path: file_path,
name,
r#type: "file",
});
}
axum::Json(uploaded).into_response()
}
fn process_file(
path: &str,
data: &[u8],
uid: nix::unistd::Uid,
gid: nix::unistd::Gid,
) -> Result<(), (StatusCode, String)> {
let dir = Path::new(path)
.parent()
.map(|p| p.to_string_lossy().to_string())
.unwrap_or_default();
if !dir.is_empty() {
ensure_dirs(&dir, uid, gid).map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("error ensuring directories: {e}"),
)
})?;
}
let can_pre_chown = match std::fs::metadata(path) {
Ok(meta) => {
if meta.is_dir() {
return Err((
StatusCode::BAD_REQUEST,
format!("path is a directory: {path}"),
));
}
true
}
Err(e) if e.kind() == std::io::ErrorKind::NotFound => false,
Err(e) => {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
format!("error getting file info: {e}"),
))
}
};
let mut chowned = false;
if can_pre_chown {
match std::os::unix::fs::chown(path, Some(uid.as_raw()), Some(gid.as_raw())) {
Ok(()) => chowned = true,
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
Err(e) => {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
format!("error changing ownership: {e}"),
))
}
}
}
let mut file = std::fs::OpenOptions::new()
.write(true)
.create(true)
.truncate(true)
.mode(0o666)
.open(path)
.map_err(|e| {
if e.raw_os_error() == Some(libc::ENOSPC) {
return (
StatusCode::INSUFFICIENT_STORAGE,
"not enough disk space available".to_string(),
);
}
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("error opening file: {e}"),
)
})?;
if !chowned {
std::os::unix::fs::chown(path, Some(uid.as_raw()), Some(gid.as_raw())).map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("error changing ownership: {e}"),
)
})?;
}
file.write_all(data).map_err(|e| {
if e.raw_os_error() == Some(libc::ENOSPC) {
return (
StatusCode::INSUFFICIENT_STORAGE,
"not enough disk space available".to_string(),
);
}
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("error writing file: {e}"),
)
})?;
Ok(())
}
use std::os::unix::fs::OpenOptionsExt;

View File

@ -0,0 +1,18 @@
use std::sync::Arc;
use axum::Json;
use axum::extract::State;
use axum::http::header;
use axum::response::IntoResponse;
use serde_json::json;
use crate::state::AppState;
pub async fn get_health(State(state): State<Arc<AppState>>) -> impl IntoResponse {
tracing::trace!("health check");
(
[(header::CACHE_CONTROL, "no-store")],
Json(json!({ "version": state.version })),
)
}

312
envd-rs/src/http/init.rs Normal file
View File

@ -0,0 +1,312 @@
use std::collections::HashMap;
use std::sync::Arc;
use axum::Json;
use axum::extract::State;
use axum::http::{StatusCode, header};
use axum::response::IntoResponse;
use serde::Deserialize;
use crate::state::AppState;
#[derive(Deserialize, Default)]
pub struct InitRequest {
#[serde(rename = "access_token")]
pub access_token: Option<String>,
#[serde(rename = "defaultUser")]
pub default_user: Option<String>,
#[serde(rename = "defaultWorkdir")]
pub default_workdir: Option<String>,
#[serde(rename = "envVars")]
pub env_vars: Option<HashMap<String, String>>,
#[serde(rename = "hyperloop_ip")]
pub hyperloop_ip: Option<String>,
pub timestamp: Option<String>,
#[serde(rename = "volume_mounts")]
pub volume_mounts: Option<Vec<VolumeMount>>,
pub sandbox_id: Option<String>,
pub template_id: Option<String>,
/// New lifecycle identifier for this resume. When it changes between
/// /init calls, envd treats the call as a post-resume hook: port
/// forwarder is restarted and NFS mounts are refreshed.
pub lifecycle_id: Option<String>,
}
#[derive(Deserialize)]
pub struct VolumeMount {
pub nfs_target: String,
pub path: String,
}
/// POST /init — called by host agent after boot.
pub async fn post_init(
State(state): State<Arc<AppState>>,
body: Option<Json<InitRequest>>,
) -> impl IntoResponse {
let init_req = body.map(|b| b.0).unwrap_or_default();
// Validate access token if provided
if let Some(ref token_str) = init_req.access_token {
if let Err(e) = validate_init_access_token(&state, token_str).await {
tracing::error!(error = %e, "init: access token validation failed");
return (StatusCode::UNAUTHORIZED, e).into_response();
}
}
// Post-resume lifecycle hook: restart port forwarder so socat children
// are reaped + respawned against the new wall clock and any rotated
// listeners. Must run BEFORE the stale-timestamp early-return so a
// resume with an out-of-order timestamp still refreshes the subsystem.
let lifecycle_changed = if let Some(ref new_id) = init_req.lifecycle_id {
state.bump_lifecycle(new_id)
} else {
false
};
if lifecycle_changed {
// Each new lifecycle (i.e. a snapshot restore) requires a fresh memory
// preload pass — pages materialised before the previous pause are now
// back in the source memory-ranges file as the host re-restored them
// lazily. Reset the flags so the next POST /memory/preload kicks off
// a new loader instead of returning the stale "already-done".
use std::sync::atomic::Ordering;
state.mem_preload_cancel.store(false, Ordering::SeqCst);
state.mem_preload_done.store(false, Ordering::SeqCst);
state.mem_preload_started.store(false, Ordering::SeqCst);
state.mem_preload_regions.store(0, Ordering::SeqCst);
state.mem_preload_pages.store(0, Ordering::SeqCst);
state.mem_preload_bytes.store(0, Ordering::SeqCst);
state.mem_preload_elapsed_us.store(0, Ordering::SeqCst);
state.mem_preload_source.store(0, Ordering::SeqCst);
*state.mem_preload_error.lock().unwrap() = None;
if let Some(ref port_sub) = state.port_subsystem {
tracing::info!("lifecycle changed, restarting port subsystem");
port_sub.restart();
}
// Force chrony to step the clock immediately. chronyd is launched by
// wrenn-init.sh and disciplines against PHC (/dev/ptp0), so the host
// wall time is already available — `makestep` just bypasses chrony's
// normal slewing and snaps the clock in one go. Best effort.
tokio::spawn(async {
match tokio::process::Command::new("chronyc")
.args(["makestep"])
.output()
.await
{
Ok(out) if out.status.success() => {
tracing::info!("chronyc makestep ok");
}
Ok(out) => {
let stderr = String::from_utf8_lossy(&out.stderr);
tracing::warn!(stderr = %stderr, "chronyc makestep failed");
}
Err(e) => {
tracing::warn!(error = %e, "chronyc makestep spawn failed");
}
}
});
}
// Idempotent timestamp check. Run after lifecycle handling so a
// stale-timestamp /init still gets to refresh ports + step clock.
// No userspace clock_settime here — chrony owns time discipline.
if let Some(ref ts_str) = init_req.timestamp {
if let Ok(ts) = parse_timestamp_to_nanos(ts_str) {
if !state.last_set_time.set_to_greater(ts) {
return (
StatusCode::NO_CONTENT,
[(header::CACHE_CONTROL, "no-store")],
)
.into_response();
}
}
}
// Apply env vars
if let Some(ref vars) = init_req.env_vars {
tracing::debug!(count = vars.len(), "setting env vars");
for (k, v) in vars {
state.defaults.env_vars.insert(k.clone(), v.clone());
}
}
// Set access token
if let Some(ref token_str) = init_req.access_token {
if !token_str.is_empty() {
tracing::debug!("setting access token");
let _ = state.access_token.set(token_str.as_bytes());
} else if state.access_token.is_set() {
tracing::debug!("clearing access token");
state.access_token.destroy();
}
}
// Set default user
if let Some(ref user) = init_req.default_user {
if !user.is_empty() {
tracing::debug!(user = %user, "setting default user");
state.defaults.set_user(user.clone());
}
}
// Set default workdir
if let Some(ref workdir) = init_req.default_workdir {
if !workdir.is_empty() {
tracing::debug!(workdir = %workdir, "setting default workdir");
state.defaults.set_workdir(Some(workdir.clone()));
}
}
// Hyperloop /etc/hosts setup. Awaited so callers that immediately
// resolve events.wrenn.local see the entry. Cheap (two file ops).
if let Some(ref ip) = init_req.hyperloop_ip {
setup_hyperloop(ip, &state.defaults.env_vars).await;
}
// NFS mounts. Awaited in parallel so callers that immediately access the
// mount path don't race the mount(2). Previously these were detached via
// tokio::spawn, which let /init return success before mounts existed.
if let Some(ref mounts) = init_req.volume_mounts {
let futs = mounts.iter().map(|m| {
let target = m.nfs_target.clone();
let path = m.path.clone();
async move {
setup_nfs(&target, &path).await;
}
});
futures::future::join_all(futs).await;
}
// Set sandbox/template metadata from request body.
if let Some(ref id) = init_req.sandbox_id {
tracing::debug!(sandbox_id = %id, "setting sandbox ID from init request");
// SAFETY: envd is single-threaded at init time; no concurrent env reads.
unsafe { std::env::set_var("WRENN_SANDBOX_ID", id) };
write_run_file(".WRENN_SANDBOX_ID", id);
state.defaults.env_vars.insert("WRENN_SANDBOX_ID".into(), id.clone());
}
if let Some(ref id) = init_req.template_id {
tracing::debug!(template_id = %id, "setting template ID from init request");
// SAFETY: envd is single-threaded at init time; no concurrent env reads.
unsafe { std::env::set_var("WRENN_TEMPLATE_ID", id) };
write_run_file(".WRENN_TEMPLATE_ID", id);
state.defaults.env_vars.insert("WRENN_TEMPLATE_ID".into(), id.clone());
}
(
StatusCode::NO_CONTENT,
[(header::CACHE_CONTROL, "no-store")],
)
.into_response()
}
async fn validate_init_access_token(state: &AppState, request_token: &str) -> Result<(), String> {
// Fast path: matches existing token
if state.access_token.is_set() && !request_token.is_empty() && state.access_token.equals(request_token) {
return Ok(());
}
// First-time setup: no existing token
if !state.access_token.is_set() {
return Ok(());
}
if request_token.is_empty() {
return Err("access token reset not authorized".into());
}
Err("access token validation failed".into())
}
async fn setup_hyperloop(address: &str, env_vars: &dashmap::DashMap<String, String>) {
// Write to /etc/hosts: events.wrenn.local → address
let entry = format!("{address} events.wrenn.local\n");
match std::fs::read_to_string("/etc/hosts") {
Ok(contents) => {
let filtered: String = contents
.lines()
.filter(|line| !line.contains("events.wrenn.local"))
.collect::<Vec<_>>()
.join("\n");
let new_contents = format!("{filtered}\n{entry}");
if let Err(e) = std::fs::write("/etc/hosts", new_contents) {
tracing::error!(error = %e, "failed to modify hosts file");
return;
}
}
Err(e) => {
tracing::error!(error = %e, "failed to read hosts file");
return;
}
}
env_vars.insert(
"WRENN_EVENTS_ADDRESS".into(),
format!("http://{address}"),
);
}
async fn setup_nfs(nfs_target: &str, path: &str) {
let mkdir = tokio::process::Command::new("mkdir")
.args(["-p", path])
.output()
.await;
if let Err(e) = mkdir {
tracing::error!(error = %e, path, "nfs: mkdir failed");
return;
}
let mount = tokio::process::Command::new("mount")
.args([
"-v",
"-t",
"nfs",
"-o",
"mountproto=tcp,mountport=2049,proto=tcp,port=2049,nfsvers=3,noacl",
nfs_target,
path,
])
.output()
.await;
match mount {
Ok(output) => {
let stdout = String::from_utf8_lossy(&output.stdout);
let stderr = String::from_utf8_lossy(&output.stderr);
if output.status.success() {
tracing::info!(nfs_target, path, stdout = %stdout, "nfs: mount success");
} else {
tracing::error!(nfs_target, path, stderr = %stderr, "nfs: mount failed");
}
}
Err(e) => {
tracing::error!(error = %e, nfs_target, path, "nfs: mount command failed");
}
}
}
fn write_run_file(name: &str, value: &str) {
let dir = std::path::Path::new("/run/wrenn");
if let Err(e) = std::fs::create_dir_all(dir) {
tracing::warn!(error = %e, "failed to create /run/wrenn");
return;
}
if let Err(e) = std::fs::write(dir.join(name), value) {
tracing::warn!(error = %e, name, "failed to write run file");
}
}
/// Parses a host-provided timestamp into nanoseconds since the Unix epoch.
/// Accepts either RFC3339 (`2026-05-17T16:13:03.123456Z`) or a float-seconds
/// string (legacy callers).
fn parse_timestamp_to_nanos(ts: &str) -> Result<i64, ()> {
if let Ok(parsed) = chrono::DateTime::parse_from_rfc3339(ts) {
return Ok(parsed.timestamp_nanos_opt().ok_or(())?);
}
if let Ok(secs) = ts.parse::<f64>() {
return Ok((secs * 1_000_000_000.0) as i64);
}
Err(())
}

350
envd-rs/src/http/memory.rs Normal file
View File

@ -0,0 +1,350 @@
// POST /memory/preload — guest-side helper that materialises every physical
// RAM page so a subsequent ch.snapshot writes a self-contained memory-ranges
// file. Required after a restore with memory_restore_mode=ondemand: pages
// that were never demand-faulted live only in the source memory-ranges file
// and would become holes (read back as zero) in the new snapshot.
//
// Trigger is one-byte-per-page reads through /dev/mem (fallback /proc/kcore
// PT_LOAD segments). The guest kernel walks its direct map → accesses the
// physical page → host kernel handles the EPT entry → CH's userfaultfd
// handler fills the page from the source memory-ranges file.
//
// Wire protocol:
// POST /memory/preload — starts the loader (idempotent) and returns
// the current status JSON immediately
// GET /memory/preload — returns the current status JSON
// POST /memory/preload/cancel — signals the loader to stop early
//
// Returning immediately avoids any HTTP-level header timeout in the caller
// while materialisation (hundreds of MiB at one byte per page) runs in a
// background blocking thread.
use std::fs;
use std::io::{Read, Seek, SeekFrom};
use std::os::unix::fs::FileExt;
use std::sync::Arc;
use std::sync::atomic::Ordering;
use std::time::Instant;
use axum::Json;
use axum::extract::State;
use axum::http::StatusCode;
use axum::response::IntoResponse;
use serde::Serialize;
use crate::state::AppState;
const PAGE_SIZE: u64 = 4096;
#[derive(Serialize, Clone)]
pub struct PreloadStatus {
/// One of: "idle", "running", "done", "failed", "cancelled".
pub state: &'static str,
pub regions: u64,
pub pages: u64,
pub bytes: u64,
pub elapsed_sec: f64,
pub source: &'static str,
pub error: Option<String>,
}
pub async fn post_memory_preload(State(state): State<Arc<AppState>>) -> impl IntoResponse {
// First caller wins the CAS and spawns the loader; subsequent callers
// just report the existing status.
let we_start = state
.mem_preload_started
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
.is_ok();
if we_start {
let state_clone = Arc::clone(&state);
// Detached blocking thread — no axum task lifetime ties it to the
// request, so the connection can close immediately without aborting
// materialisation. Lifecycle bump on the next /init clears the flags
// for a fresh run after a restore.
std::thread::spawn(move || {
let started = Instant::now();
match preload_blocking(&state_clone) {
Ok((source, regions, pages, bytes)) => {
let elapsed = started.elapsed().as_secs_f64();
state_clone
.mem_preload_regions
.store(regions, Ordering::SeqCst);
state_clone.mem_preload_pages.store(pages, Ordering::SeqCst);
state_clone.mem_preload_bytes.store(bytes, Ordering::SeqCst);
state_clone
.mem_preload_elapsed_us
.store((elapsed * 1_000_000.0) as u64, Ordering::SeqCst);
set_source(&state_clone, source);
*state_clone.mem_preload_error.lock().unwrap() = None;
state_clone.mem_preload_done.store(true, Ordering::SeqCst);
tracing::info!(
regions,
pages,
bytes,
elapsed_sec = elapsed,
source,
"memory preload complete"
);
}
Err(e) => {
let elapsed = started.elapsed().as_secs_f64();
state_clone
.mem_preload_elapsed_us
.store((elapsed * 1_000_000.0) as u64, Ordering::SeqCst);
*state_clone.mem_preload_error.lock().unwrap() = Some(e.clone());
state_clone.mem_preload_done.store(true, Ordering::SeqCst);
tracing::warn!(error = %e, "memory preload failed");
}
}
});
}
let status = read_status(&state);
(StatusCode::OK, Json(status))
}
pub async fn get_memory_preload(State(state): State<Arc<AppState>>) -> impl IntoResponse {
(StatusCode::OK, Json(read_status(&state)))
}
pub async fn post_memory_preload_cancel(State(state): State<Arc<AppState>>) -> impl IntoResponse {
state.mem_preload_cancel.store(true, Ordering::SeqCst);
StatusCode::NO_CONTENT
}
fn read_status(state: &AppState) -> PreloadStatus {
let started = state.mem_preload_started.load(Ordering::SeqCst);
let done = state.mem_preload_done.load(Ordering::SeqCst);
let cancelled = state.mem_preload_cancel.load(Ordering::SeqCst);
let error = state.mem_preload_error.lock().unwrap().clone();
let lane = if !started {
"idle"
} else if !done {
"running"
} else if let Some(_) = &error {
"failed"
} else if cancelled {
"cancelled"
} else {
"done"
};
PreloadStatus {
state: lane,
regions: state.mem_preload_regions.load(Ordering::SeqCst),
pages: state.mem_preload_pages.load(Ordering::SeqCst),
bytes: state.mem_preload_bytes.load(Ordering::SeqCst),
elapsed_sec: state.mem_preload_elapsed_us.load(Ordering::SeqCst) as f64 / 1_000_000.0,
source: get_source(state),
error,
}
}
fn set_source(state: &AppState, src: &'static str) {
let code: u8 = match src {
"/dev/mem" => 1,
"/proc/kcore" => 2,
_ => 0,
};
state.mem_preload_source.store(code, Ordering::SeqCst);
}
fn get_source(state: &AppState) -> &'static str {
match state.mem_preload_source.load(Ordering::SeqCst) {
1 => "/dev/mem",
2 => "/proc/kcore",
_ => "",
}
}
fn preload_blocking(state: &AppState) -> Result<(&'static str, u64, u64, u64), String> {
let ranges = parse_system_ram_ranges().map_err(|e| format!("iomem: {e}"))?;
if ranges.is_empty() {
return Err("no System RAM ranges found in /proc/iomem".into());
}
let mut pages: u64 = 0;
let mut bytes: u64 = 0;
match preload_via_devmem(&ranges, state, &mut pages, &mut bytes) {
Ok(()) => Ok(("/dev/mem", ranges.len() as u64, pages, bytes)),
Err(devmem_err) => {
tracing::warn!(
error = %devmem_err,
"/dev/mem preload failed, falling back to /proc/kcore"
);
pages = 0;
bytes = 0;
preload_via_kcore(state, &mut pages, &mut bytes)
.map_err(|e| format!("/dev/mem: {devmem_err}; /proc/kcore: {e}"))?;
Ok(("/proc/kcore", ranges.len() as u64, pages, bytes))
}
}
}
fn parse_system_ram_ranges() -> std::io::Result<Vec<(u64, u64)>> {
let data = fs::read_to_string("/proc/iomem")?;
let mut out = Vec::new();
for line in data.lines() {
if line.starts_with(|c: char| c.is_whitespace()) {
continue;
}
let Some((range, label)) = line.split_once(" : ") else {
continue;
};
if label.trim() != "System RAM" {
continue;
}
let Some((start, end)) = range.split_once('-') else {
continue;
};
let start = u64::from_str_radix(start.trim(), 16).ok();
let end = u64::from_str_radix(end.trim(), 16).ok();
if let (Some(s), Some(e)) = (start, end) {
out.push((s, e.saturating_add(1)));
}
}
Ok(out)
}
fn preload_via_devmem(
ranges: &[(u64, u64)],
state: &AppState,
pages: &mut u64,
bytes: &mut u64,
) -> std::io::Result<()> {
let f = fs::File::open("/dev/mem")?;
let mut buf = [0u8; 1];
for (start, end) in ranges {
let mut off = *start;
while off < *end {
if state.mem_preload_cancel.load(Ordering::SeqCst) {
return Ok(());
}
f.read_at(&mut buf, off)?;
*pages += 1;
*bytes += PAGE_SIZE;
// Publish progress so GET /memory/preload reports useful numbers
// while the loader is still running.
if *pages % 1024 == 0 {
state.mem_preload_pages.store(*pages, Ordering::SeqCst);
state.mem_preload_bytes.store(*bytes, Ordering::SeqCst);
}
off = off.saturating_add(PAGE_SIZE);
}
}
Ok(())
}
// Read /proc/kcore's direct-map segment to materialise physical RAM. The
// direct map's PT_LOAD covers the kernel's *maximum* possible direct-map
// region (64TB on x86_64), not just the present physical RAM — iterating the
// whole segment would loop for billions of pages. Bound the walk to the sum
// of System RAM ranges from /proc/iomem; sequential reads through the
// segment touch consecutive physical pages 1:1, which is what we need.
fn preload_via_kcore(state: &AppState, pages: &mut u64, bytes: &mut u64) -> std::io::Result<()> {
let ram_ranges = parse_system_ram_ranges()?;
if ram_ranges.is_empty() {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"no System RAM ranges to bound kcore walk",
));
}
let total_ram_bytes: u64 = ram_ranges.iter().map(|(s, e)| e - s).sum();
let mut f = fs::File::open("/proc/kcore")?;
let segments = parse_kcore_pt_load(&mut f)?;
if segments.is_empty() {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"no PT_LOAD segments in /proc/kcore",
));
}
// Pick the direct map: highest-vaddr kernel-space segment large enough
// to plausibly cover RAM. KASLR randomises the base, so don't hardcode it.
// Kernel virtual addresses start at 0xffff800000000000 on x86_64; vmalloc
// / modules sit above the direct map and are usually smaller.
const KERNEL_SPACE_MIN: u64 = 0xffff_8000_0000_0000;
let direct_map = segments
.iter()
.filter(|s| s.vaddr >= KERNEL_SPACE_MIN && s.file_size >= total_ram_bytes)
.min_by_key(|s| s.vaddr)
.ok_or_else(|| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!(
"no PT_LOAD segment large enough for {} bytes of RAM in /proc/kcore",
total_ram_bytes
),
)
})?;
let mut buf = [0u8; 1];
let read_bytes = total_ram_bytes.min(direct_map.file_size);
let start = direct_map.file_offset;
let end = start.saturating_add(read_bytes);
let mut off = start;
while off < end {
if state.mem_preload_cancel.load(Ordering::SeqCst) {
return Ok(());
}
// Reads into MMIO holes within the direct map can fail; ignore so the
// loop keeps making progress over the present RAM ranges either side.
if f.read_at(&mut buf, off).is_ok() {
*pages += 1;
*bytes += PAGE_SIZE;
if *pages % 256 == 0 {
state.mem_preload_pages.store(*pages, Ordering::SeqCst);
state.mem_preload_bytes.store(*bytes, Ordering::SeqCst);
}
}
off = off.saturating_add(PAGE_SIZE);
}
Ok(())
}
struct KcoreSegment {
file_offset: u64,
file_size: u64,
vaddr: u64,
}
fn parse_kcore_pt_load(f: &mut fs::File) -> std::io::Result<Vec<KcoreSegment>> {
let mut hdr = [0u8; 64];
f.seek(SeekFrom::Start(0))?;
f.read_exact(&mut hdr)?;
if &hdr[0..4] != b"\x7fELF" || hdr[4] != 2 {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"not an ELF64 file",
));
}
let e_phoff = u64::from_le_bytes(hdr[32..40].try_into().unwrap());
let e_phentsize = u16::from_le_bytes(hdr[54..56].try_into().unwrap()) as u64;
let e_phnum = u16::from_le_bytes(hdr[56..58].try_into().unwrap()) as u64;
let mut out = Vec::new();
let mut entry = vec![0u8; e_phentsize as usize];
for i in 0..e_phnum {
f.seek(SeekFrom::Start(e_phoff + i * e_phentsize))?;
f.read_exact(&mut entry)?;
let p_type = u32::from_le_bytes(entry[0..4].try_into().unwrap());
if p_type != 1 {
continue;
}
let p_offset = u64::from_le_bytes(entry[8..16].try_into().unwrap());
let p_vaddr = u64::from_le_bytes(entry[16..24].try_into().unwrap());
let p_filesz = u64::from_le_bytes(entry[32..40].try_into().unwrap());
out.push(KcoreSegment {
file_offset: p_offset,
file_size: p_filesz,
vaddr: p_vaddr,
});
}
Ok(out)
}

View File

@ -0,0 +1,89 @@
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use axum::Json;
use axum::extract::State;
use axum::http::{StatusCode, header};
use axum::response::IntoResponse;
use serde::Serialize;
use crate::state::AppState;
#[derive(Serialize)]
pub struct Metrics {
ts: i64,
cpu_count: u32,
cpu_used_pct: f32,
mem_total_mib: u64,
mem_used_mib: u64,
mem_total: u64,
mem_used: u64,
disk_used: u64,
disk_total: u64,
}
pub async fn get_metrics(State(state): State<Arc<AppState>>) -> impl IntoResponse {
tracing::trace!("get metrics");
match collect_metrics(&state) {
Ok(m) => (
StatusCode::OK,
[(header::CACHE_CONTROL, "no-store")],
Json(m),
)
.into_response(),
Err(e) => {
tracing::error!(error = %e, "failed to get metrics");
StatusCode::INTERNAL_SERVER_ERROR.into_response()
}
}
}
fn collect_metrics(state: &AppState) -> Result<Metrics, String> {
let cpu_count = state.cpu_count();
let cpu_used_pct_rounded = state.cpu_used_pct();
let mut sys = sysinfo::System::new();
sys.refresh_memory();
let mem_total = sys.total_memory();
let mem_available = sys.available_memory();
let mem_used = mem_total.saturating_sub(mem_available);
let mem_total_mib = mem_total / 1024 / 1024;
let mem_used_mib = mem_used / 1024 / 1024;
let (disk_total, disk_used) = disk_stats("/").map_err(|e| e.to_string())?;
let ts = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs() as i64;
Ok(Metrics {
ts,
cpu_count,
cpu_used_pct: cpu_used_pct_rounded,
mem_total_mib,
mem_used_mib,
mem_total,
mem_used,
disk_used,
disk_total,
})
}
fn disk_stats(path: &str) -> Result<(u64, u64), nix::Error> {
use std::ffi::CString;
let c_path = CString::new(path).unwrap();
let mut stat: libc::statfs = unsafe { std::mem::zeroed() };
let ret = unsafe { libc::statfs(c_path.as_ptr(), &mut stat) };
if ret != 0 {
return Err(nix::Error::last());
}
let block = stat.f_bsize as u64;
let total = stat.f_blocks * block;
let available = stat.f_bavail * block;
Ok((total, total - available))
}

65
envd-rs/src/http/mod.rs Normal file
View File

@ -0,0 +1,65 @@
pub mod encoding;
pub mod envs;
pub mod error;
pub mod files;
pub mod health;
pub mod init;
pub mod memory;
pub mod metrics;
pub mod snapshot;
use std::sync::Arc;
use std::time::Duration;
use axum::Router;
use axum::routing::{get, post};
use http::header::{CACHE_CONTROL, HeaderName};
use http::Method;
use tower_http::cors::{AllowHeaders, AllowMethods, AllowOrigin, CorsLayer};
use crate::config::CORS_MAX_AGE;
use crate::state::AppState;
pub fn router(state: Arc<AppState>) -> Router {
let cors = CorsLayer::new()
.allow_origin(AllowOrigin::any())
.allow_methods(AllowMethods::list([
Method::HEAD,
Method::GET,
Method::POST,
Method::PUT,
Method::PATCH,
Method::DELETE,
]))
.allow_headers(AllowHeaders::any())
.expose_headers([
HeaderName::from_static("location"),
CACHE_CONTROL,
HeaderName::from_static("x-content-type-options"),
HeaderName::from_static("connect-content-encoding"),
HeaderName::from_static("connect-protocol-version"),
HeaderName::from_static("grpc-encoding"),
HeaderName::from_static("grpc-message"),
HeaderName::from_static("grpc-status"),
HeaderName::from_static("grpc-status-details-bin"),
])
.max_age(Duration::from_secs(CORS_MAX_AGE.as_secs()));
Router::new()
.route("/health", get(health::get_health))
.route("/metrics", get(metrics::get_metrics))
.route("/envs", get(envs::get_envs))
.route("/init", post(init::post_init))
.route("/snapshot/prepare", post(snapshot::post_snapshot_prepare))
.route(
"/memory/preload",
get(memory::get_memory_preload).post(memory::post_memory_preload),
)
.route(
"/memory/preload/cancel",
post(memory::post_memory_preload_cancel),
)
.route("/files", get(files::get_files).post(files::post_files))
.layer(cors)
.with_state(state)
}

View File

@ -0,0 +1,60 @@
use std::sync::Arc;
use axum::extract::State;
use axum::http::{StatusCode, header};
use axum::response::IntoResponse;
use nix::unistd::sync;
use crate::state::AppState;
/// POST /snapshot/prepare — called by the host agent immediately before it
/// invokes vm.pause + vm.snapshot. The handler quiesces guest state so the
/// resulting snapshot is clean: outstanding writes are flushed to disk, the
/// VFS page cache is dropped (so the dm-snapshot CoW is the source of truth),
/// and the port forwarder is stopped to prevent socat children from being
/// frozen mid-handshake.
pub async fn post_snapshot_prepare(State(state): State<Arc<AppState>>) -> impl IntoResponse {
// Stop port forwarder + scanner so no socat process is captured in the
// snapshot with a half-open TCP connection. /init on resume restarts it.
if let Some(ref port_sub) = state.port_subsystem {
port_sub.stop();
}
// sync(2) flushes the in-memory FS state. Done before drop_caches so the
// pages we drop are clean.
sync();
// Drop the VFS page cache + dentries/inodes. Reduces snapshot size by
// ensuring CH only persists memory pages that the guest actually needs.
if let Err(e) = std::fs::write("/proc/sys/vm/drop_caches", "3") {
tracing::warn!(error = %e, "drop_caches (first pass) failed (continuing)");
}
// Best-effort fstrim on the rootfs so unused blocks are returned to the
// dm-snapshot, keeping CoW size minimal.
let _ = tokio::process::Command::new("fstrim")
.arg("/")
.output()
.await;
// Second drop_caches pass after fstrim: fstrim re-reads superblock /
// group descriptor pages that we just evicted, putting them back in the
// page cache. A second pass drops those and any other late readers (e.g.
// sync flushers).
sync();
if let Err(e) = std::fs::write("/proc/sys/vm/drop_caches", "3") {
tracing::warn!(error = %e, "drop_caches (second pass) failed (continuing)");
}
// Free-page reporting drains asynchronously: the balloon driver hands
// freed pages to the host in batches and CH punches holes in the backing
// memfile. Without a brief settle window most of the pages freed by the
// drop_caches passes above would still be present in the snapshot.
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
tracing::info!("snapshot/prepare: quiesced");
(
StatusCode::NO_CONTENT,
[(header::CACHE_CONTROL, "no-store")],
)
}

17
envd-rs/src/logging.rs Normal file
View File

@ -0,0 +1,17 @@
use tracing_subscriber::{EnvFilter, fmt, layer::SubscriberExt, util::SubscriberInitExt};
pub fn init(json: bool) {
let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
if json {
tracing_subscriber::registry()
.with(filter)
.with(fmt::layer().json().flatten_event(true))
.init();
} else {
tracing_subscriber::registry()
.with(filter)
.with(fmt::layer())
.init();
}
}

244
envd-rs/src/main.rs Normal file
View File

@ -0,0 +1,244 @@
#![allow(dead_code)]
mod auth;
mod cgroups;
mod config;
mod conntracker;
mod crypto;
mod execcontext;
mod http;
mod logging;
mod permissions;
mod port;
mod rpc;
mod state;
mod util;
use std::fs;
use std::net::SocketAddr;
use std::path::Path;
use std::sync::Arc;
use clap::Parser;
use tokio::net::TcpListener;
use config::{DEFAULT_PORT, DEFAULT_USER, WRENN_RUN_DIR};
use execcontext::Defaults;
use port::subsystem::PortSubsystem;
use state::AppState;
const VERSION: &str = env!("CARGO_PKG_VERSION");
const COMMIT: &str = {
match option_env!("ENVD_COMMIT") {
Some(c) => c,
None => "unknown",
}
};
#[derive(Parser)]
#[command(name = "envd", about = "Wrenn guest agent daemon")]
struct Cli {
#[arg(long, default_value_t = DEFAULT_PORT)]
port: u16,
#[arg(long)]
version: bool,
#[arg(long)]
commit: bool,
#[arg(long = "cmd", default_value = "")]
start_cmd: String,
#[arg(long = "cgroup-root", default_value = "/sys/fs/cgroup")]
cgroup_root: String,
}
#[tokio::main]
async fn main() {
let cli = Cli::parse();
if cli.version {
println!("{VERSION}");
return;
}
if cli.commit {
println!("{COMMIT}");
return;
}
logging::init(true);
if let Err(e) = fs::create_dir_all(WRENN_RUN_DIR) {
tracing::error!(error = %e, "failed to create wrenn run directory");
}
let defaults = Defaults::new(DEFAULT_USER);
defaults
.env_vars
.insert("WRENN_SANDBOX".into(), "true".into());
let wrenn_sandbox_path = Path::new(WRENN_RUN_DIR).join(".WRENN_SANDBOX");
if let Err(e) = fs::write(&wrenn_sandbox_path, b"true") {
tracing::error!(error = %e, "failed to write sandbox file");
}
// Cgroup manager
let cgroup_manager: Arc<dyn cgroups::CgroupManager> =
match cgroups::Cgroup2Manager::new(
&cli.cgroup_root,
&[
(
cgroups::ProcessType::Pty,
"wrenn/pty",
&[] as &[(&str, &str)],
),
(
cgroups::ProcessType::User,
"wrenn/user",
&[] as &[(&str, &str)],
),
(
cgroups::ProcessType::Socat,
"wrenn/socat",
&[] as &[(&str, &str)],
),
],
) {
Ok(m) => {
tracing::info!("cgroup2 manager initialized");
Arc::new(m)
}
Err(e) => {
tracing::warn!(error = %e, "cgroup2 init failed, using noop");
Arc::new(cgroups::NoopCgroupManager)
}
};
// Port subsystem
let port_subsystem = Arc::new(PortSubsystem::new(Arc::clone(&cgroup_manager)));
port_subsystem.start();
tracing::info!("port subsystem started");
let state = AppState::new(
defaults,
VERSION.to_string(),
COMMIT.to_string(),
Some(Arc::clone(&port_subsystem)),
);
// Memory reclaimer — drop page cache when available memory is low.
// The balloon device can only reclaim pages the guest kernel freed.
{
let state_for_reclaimer = Arc::clone(&state);
std::thread::spawn(move || memory_reclaimer(state_for_reclaimer));
}
// RPC services (Connect protocol — serves Connect + gRPC + gRPC-Web on same port)
let connect_router = rpc::rpc_router(Arc::clone(&state));
let app = http::router(Arc::clone(&state))
.fallback_service(connect_router.into_axum_service());
// --cmd: spawn initial process if specified
if !cli.start_cmd.is_empty() {
let cmd = cli.start_cmd.clone();
let state_clone = Arc::clone(&state);
tokio::spawn(async move {
spawn_initial_command(&cmd, &state_clone);
});
}
let addr = SocketAddr::from(([0, 0, 0, 0], cli.port));
tracing::info!(port = cli.port, version = VERSION, commit = COMMIT, "envd starting");
let listener = TcpListener::bind(addr).await.expect("failed to bind");
let graceful = axum::serve(listener, app).with_graceful_shutdown(async move {
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
.expect("failed to register SIGTERM")
.recv()
.await;
tracing::info!("SIGTERM received, shutting down");
});
if let Err(e) = graceful.await {
tracing::error!(error = %e, "server error");
}
port_subsystem.stop();
}
fn spawn_initial_command(cmd: &str, state: &AppState) {
use crate::permissions::user::lookup_user;
use crate::rpc::process_handler;
use std::collections::HashMap;
let default_user = state.defaults.user();
let user = match lookup_user(&default_user) {
Ok(u) => u,
Err(e) => {
tracing::error!(error = %e, "cmd: failed to lookup user");
return;
}
};
let home = user.dir.to_string_lossy().to_string();
let default_workdir = state.defaults.workdir();
let cwd = default_workdir
.as_deref()
.unwrap_or(&home);
match process_handler::spawn_process(
cmd,
&[],
&HashMap::new(),
cwd,
None,
false,
Some("init-cmd".to_string()),
&user,
&state.defaults.env_vars,
) {
Ok(spawned) => {
tracing::info!(pid = spawned.handle.pid, cmd, "initial command spawned");
}
Err(e) => {
tracing::error!(error = %e, cmd, "failed to spawn initial command");
}
}
}
fn memory_reclaimer(_state: Arc<AppState>) {
use std::time::Duration;
const CHECK_INTERVAL: Duration = Duration::from_secs(10);
const DROP_THRESHOLD_PCT: u64 = 80;
loop {
std::thread::sleep(CHECK_INTERVAL);
let mut sys = sysinfo::System::new();
sys.refresh_memory();
let total = sys.total_memory();
let available = sys.available_memory();
if total == 0 {
continue;
}
let used_pct = ((total - available) * 100) / total;
if used_pct >= DROP_THRESHOLD_PCT {
if let Err(e) = std::fs::write("/proc/sys/vm/drop_caches", "3") {
tracing::debug!(error = %e, "drop_caches failed");
} else {
let mut sys2 = sysinfo::System::new();
sys2.refresh_memory();
let freed_mb =
sys2.available_memory().saturating_sub(available) / (1024 * 1024);
tracing::info!(used_pct, freed_mb, "page cache dropped");
}
}
}
}

View File

@ -0,0 +1,2 @@
pub mod user;
pub mod path;

View File

@ -0,0 +1,184 @@
use std::fs;
use std::os::unix::fs::chown;
use std::path::{Path, PathBuf};
use nix::unistd::{Gid, Uid};
fn expand_tilde(path: &str, home_dir: &str) -> Result<String, String> {
if path.is_empty() || !path.starts_with('~') {
return Ok(path.to_string());
}
if path.len() > 1 && path.as_bytes()[1] != b'/' && path.as_bytes()[1] != b'\\' {
return Err("cannot expand user-specific home dir".into());
}
Ok(format!("{}{}", home_dir, &path[1..]))
}
pub fn expand_and_resolve(
path: &str,
home_dir: &str,
default_path: Option<&str>,
) -> Result<String, String> {
let path = if path.is_empty() {
default_path.unwrap_or("").to_string()
} else {
path.to_string()
};
let path = expand_tilde(&path, home_dir)?;
if Path::new(&path).is_absolute() {
return Ok(path);
}
let joined = PathBuf::from(home_dir).join(&path);
joined
.canonicalize()
.or_else(|_| Ok(joined))
.map(|p| p.to_string_lossy().to_string())
}
pub fn ensure_dirs(path: &str, uid: Uid, gid: Gid) -> Result<(), String> {
let path = Path::new(path);
let mut current = PathBuf::new();
for component in path.components() {
current.push(component);
let current_str = current.to_string_lossy();
if current_str == "/" {
continue;
}
match fs::metadata(&current) {
Ok(meta) => {
if !meta.is_dir() {
return Err(format!("path is a file: {current_str}"));
}
}
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
fs::create_dir(&current)
.map_err(|e| format!("failed to create directory {current_str}: {e}"))?;
chown(&current, Some(uid.as_raw()), Some(gid.as_raw()))
.map_err(|e| format!("failed to chown directory {current_str}: {e}"))?;
}
Err(e) => {
return Err(format!("failed to stat directory {current_str}: {e}"));
}
}
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
// expand_tilde
#[test]
fn tilde_empty_passthrough() {
assert_eq!(expand_tilde("", "/home/u").unwrap(), "");
}
#[test]
fn tilde_no_tilde_passthrough() {
assert_eq!(expand_tilde("/absolute", "/home/u").unwrap(), "/absolute");
}
#[test]
fn tilde_bare() {
assert_eq!(expand_tilde("~", "/home/user").unwrap(), "/home/user");
}
#[test]
fn tilde_slash_path() {
assert_eq!(expand_tilde("~/docs", "/home/user").unwrap(), "/home/user/docs");
}
#[test]
fn tilde_nested() {
assert_eq!(expand_tilde("~/a/b/c", "/h").unwrap(), "/h/a/b/c");
}
#[test]
fn tilde_other_user_errors() {
assert!(expand_tilde("~bob/foo", "/home/user").is_err());
}
#[test]
fn tilde_relative_no_tilde() {
assert_eq!(expand_tilde("relative/path", "/home/u").unwrap(), "relative/path");
}
// expand_and_resolve
#[test]
fn resolve_absolute_passthrough() {
assert_eq!(expand_and_resolve("/abs/path", "/home", None).unwrap(), "/abs/path");
}
#[test]
fn resolve_empty_uses_default() {
assert_eq!(expand_and_resolve("", "/home", Some("/default")).unwrap(), "/default");
}
#[test]
fn resolve_empty_no_default_falls_back_to_home() {
// Empty path with no default → joins "" with home_dir → returns home_dir
let result = expand_and_resolve("", "/home", None).unwrap();
assert_eq!(result, "/home");
}
#[test]
fn resolve_tilde_expands() {
assert_eq!(expand_and_resolve("~/dir", "/home/u", None).unwrap(), "/home/u/dir");
}
#[test]
fn resolve_relative_joins_home() {
let result = expand_and_resolve("subdir", "/tmp", None).unwrap();
// Relative path joined with home and canonicalized (or raw join on missing)
assert!(result.starts_with("/tmp"));
assert!(result.contains("subdir"));
}
#[test]
fn resolve_tilde_other_user_errors() {
assert!(expand_and_resolve("~bob", "/home/u", None).is_err());
}
// ensure_dirs
#[test]
fn ensure_dirs_creates_nested() {
let tmp = tempfile::TempDir::new().unwrap();
let path = tmp.path().join("a/b/c");
let uid = nix::unistd::getuid();
let gid = nix::unistd::getgid();
ensure_dirs(path.to_str().unwrap(), uid, gid).unwrap();
assert!(path.is_dir());
}
#[test]
fn ensure_dirs_existing_is_ok() {
let tmp = tempfile::TempDir::new().unwrap();
let uid = nix::unistd::getuid();
let gid = nix::unistd::getgid();
ensure_dirs(tmp.path().to_str().unwrap(), uid, gid).unwrap();
}
#[test]
fn ensure_dirs_file_in_path_errors() {
let tmp = tempfile::TempDir::new().unwrap();
let file_path = tmp.path().join("afile");
std::fs::write(&file_path, "").unwrap();
let nested = file_path.join("subdir");
let uid = nix::unistd::getuid();
let gid = nix::unistd::getgid();
let result = ensure_dirs(nested.to_str().unwrap(), uid, gid);
assert!(result.is_err());
assert!(result.unwrap_err().contains("path is a file"));
}
}

View File

@ -0,0 +1,32 @@
use nix::unistd::{Gid, Group, Uid, User};
pub fn lookup_user(username: &str) -> Result<User, String> {
User::from_name(username)
.map_err(|e| format!("error looking up user '{username}': {e}"))?
.ok_or_else(|| format!("user '{username}' not found"))
}
pub fn get_uid_gid(user: &User) -> (Uid, Gid) {
(user.uid, user.gid)
}
pub fn get_user_groups(user: &User) -> Vec<Gid> {
let c_name = std::ffi::CString::new(user.name.as_str()).unwrap();
nix::unistd::getgrouplist(&c_name, user.gid).unwrap_or_default()
}
pub fn lookup_username_by_uid(uid: Uid) -> String {
User::from_uid(uid)
.ok()
.flatten()
.map(|u| u.name)
.unwrap_or_else(|| uid.to_string())
}
pub fn lookup_groupname_by_gid(gid: Gid) -> String {
Group::from_gid(gid)
.ok()
.flatten()
.map(|g| g.name)
.unwrap_or_else(|| gid.to_string())
}

260
envd-rs/src/port/conn.rs Normal file
View File

@ -0,0 +1,260 @@
use std::io::{self, BufRead};
#[derive(Debug, Clone)]
pub struct ConnStat {
pub local_ip: String,
pub local_port: u32,
pub status: String,
pub family: u32,
pub inode: u64,
}
fn tcp_state_name(hex: &str) -> &'static str {
match hex {
"01" => "ESTABLISHED",
"02" => "SYN_SENT",
"03" => "SYN_RECV",
"04" => "FIN_WAIT1",
"05" => "FIN_WAIT2",
"06" => "TIME_WAIT",
"07" => "CLOSE",
"08" => "CLOSE_WAIT",
"09" => "LAST_ACK",
"0A" => "LISTEN",
"0B" => "CLOSING",
_ => "UNKNOWN",
}
}
pub fn read_tcp_connections() -> Vec<ConnStat> {
let mut conns = Vec::new();
if let Ok(c) = parse_proc_net_tcp("/proc/net/tcp", libc::AF_INET as u32) {
conns.extend(c);
}
if let Ok(c) = parse_proc_net_tcp("/proc/net/tcp6", libc::AF_INET6 as u32) {
conns.extend(c);
}
conns
}
fn parse_proc_net_tcp(path: &str, family: u32) -> io::Result<Vec<ConnStat>> {
let file = std::fs::File::open(path)?;
let reader = io::BufReader::new(file);
let mut conns = Vec::new();
let mut first = true;
for line in reader.lines() {
let line = line?;
if first {
first = false;
continue;
}
let line = line.trim().to_string();
if line.is_empty() {
continue;
}
let fields: Vec<&str> = line.split_whitespace().collect();
if fields.len() < 10 {
continue;
}
let (ip, port) = match parse_hex_addr(fields[1], family) {
Some(v) => v,
None => continue,
};
let state = tcp_state_name(fields[3]);
let inode: u64 = match fields[9].parse() {
Ok(v) => v,
Err(_) => continue,
};
conns.push(ConnStat {
local_ip: ip,
local_port: port,
status: state.to_string(),
family,
inode,
});
}
Ok(conns)
}
fn parse_hex_addr(s: &str, family: u32) -> Option<(String, u32)> {
let (ip_hex, port_hex) = s.split_once(':')?;
let port = u32::from_str_radix(port_hex, 16).ok()?;
let ip_bytes = hex::decode(ip_hex).ok()?;
let ip_str = if family == libc::AF_INET as u32 {
if ip_bytes.len() != 4 {
return None;
}
format!("{}.{}.{}.{}", ip_bytes[3], ip_bytes[2], ip_bytes[1], ip_bytes[0])
} else {
if ip_bytes.len() != 16 {
return None;
}
let mut octets = [0u8; 16];
for i in 0..4 {
octets[i * 4] = ip_bytes[i * 4 + 3];
octets[i * 4 + 1] = ip_bytes[i * 4 + 2];
octets[i * 4 + 2] = ip_bytes[i * 4 + 1];
octets[i * 4 + 3] = ip_bytes[i * 4];
}
let addr = std::net::Ipv6Addr::from(octets);
addr.to_string()
};
Some((ip_str, port))
}
#[cfg(test)]
mod tests {
use super::*;
use std::io::Write;
// tcp_state_name
#[test]
fn state_all_known_codes() {
assert_eq!(tcp_state_name("01"), "ESTABLISHED");
assert_eq!(tcp_state_name("02"), "SYN_SENT");
assert_eq!(tcp_state_name("03"), "SYN_RECV");
assert_eq!(tcp_state_name("04"), "FIN_WAIT1");
assert_eq!(tcp_state_name("05"), "FIN_WAIT2");
assert_eq!(tcp_state_name("06"), "TIME_WAIT");
assert_eq!(tcp_state_name("07"), "CLOSE");
assert_eq!(tcp_state_name("08"), "CLOSE_WAIT");
assert_eq!(tcp_state_name("09"), "LAST_ACK");
assert_eq!(tcp_state_name("0A"), "LISTEN");
assert_eq!(tcp_state_name("0B"), "CLOSING");
}
#[test]
fn state_unknown_code() {
assert_eq!(tcp_state_name("FF"), "UNKNOWN");
assert_eq!(tcp_state_name("00"), "UNKNOWN");
}
// parse_hex_addr
#[test]
fn ipv4_localhost() {
let (ip, port) = parse_hex_addr("0100007F:0050", libc::AF_INET as u32).unwrap();
assert_eq!(ip, "127.0.0.1");
assert_eq!(port, 80);
}
#[test]
fn ipv4_any() {
let (ip, port) = parse_hex_addr("00000000:0035", libc::AF_INET as u32).unwrap();
assert_eq!(ip, "0.0.0.0");
assert_eq!(port, 53);
}
#[test]
fn ipv4_real_address() {
// 192.168.1.1 in little-endian = 0101A8C0
let (ip, port) = parse_hex_addr("0101A8C0:01BB", libc::AF_INET as u32).unwrap();
assert_eq!(ip, "192.168.1.1");
assert_eq!(port, 443);
}
#[test]
fn ipv4_wrong_byte_count_returns_none() {
assert!(parse_hex_addr("0100:0050", libc::AF_INET as u32).is_none());
}
#[test]
fn invalid_hex_returns_none() {
assert!(parse_hex_addr("ZZZZZZZZ:0050", libc::AF_INET as u32).is_none());
}
#[test]
fn no_colon_returns_none() {
assert!(parse_hex_addr("0100007F0050", libc::AF_INET as u32).is_none());
}
#[test]
fn ipv6_loopback() {
// ::1 in /proc/net/tcp6 format: 00000000000000000000000001000000
let (ip, port) = parse_hex_addr(
"00000000000000000000000001000000:0035",
libc::AF_INET6 as u32,
)
.unwrap();
assert_eq!(ip, "::1");
assert_eq!(port, 53);
}
#[test]
fn ipv6_wrong_byte_count_returns_none() {
assert!(parse_hex_addr("0100007F:0050", libc::AF_INET6 as u32).is_none());
}
// parse_proc_net_tcp
fn write_tcp_file(content: &str) -> tempfile::NamedTempFile {
let mut f = tempfile::NamedTempFile::new().unwrap();
f.write_all(content.as_bytes()).unwrap();
f.flush().unwrap();
f
}
#[test]
fn parse_empty_file() {
let f = write_tcp_file(
" sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode\n",
);
let conns = parse_proc_net_tcp(f.path().to_str().unwrap(), libc::AF_INET as u32).unwrap();
assert!(conns.is_empty());
}
#[test]
fn parse_single_entry() {
let content = "\
sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode
0: 0100007F:0050 00000000:0000 0A 00000000:00000000 00:00000000 00000000 0 0 12345 1 00000000\n";
let f = write_tcp_file(content);
let conns = parse_proc_net_tcp(f.path().to_str().unwrap(), libc::AF_INET as u32).unwrap();
assert_eq!(conns.len(), 1);
assert_eq!(conns[0].local_ip, "127.0.0.1");
assert_eq!(conns[0].local_port, 80);
assert_eq!(conns[0].status, "LISTEN");
assert_eq!(conns[0].inode, 12345);
assert_eq!(conns[0].family, libc::AF_INET as u32);
}
#[test]
fn parse_skips_malformed_rows() {
let content = "\
sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode
0: 0100007F:0050 00000000:0000 0A 00000000:00000000 00:00000000 00000000 0 0 12345 1 00000000
bad line
1: short\n";
let f = write_tcp_file(content);
let conns = parse_proc_net_tcp(f.path().to_str().unwrap(), libc::AF_INET as u32).unwrap();
assert_eq!(conns.len(), 1);
}
#[test]
fn parse_multiple_entries() {
let content = "\
sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode
0: 0100007F:0050 00000000:0000 0A 00000000:00000000 00:00000000 00000000 0 0 100 1 00000000
1: 00000000:01BB 00000000:0000 0A 00000000:00000000 00:00000000 00000000 0 0 200 1 00000000\n";
let f = write_tcp_file(content);
let conns = parse_proc_net_tcp(f.path().to_str().unwrap(), libc::AF_INET as u32).unwrap();
assert_eq!(conns.len(), 2);
assert_eq!(conns[0].local_port, 80);
assert_eq!(conns[1].local_port, 443);
}
#[test]
fn parse_nonexistent_file_errors() {
assert!(parse_proc_net_tcp("/nonexistent/path", libc::AF_INET as u32).is_err());
}
}

View File

@ -0,0 +1,181 @@
use std::collections::HashMap;
use std::os::unix::process::CommandExt;
use std::process::Command;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use crate::cgroups::{CgroupManager, ProcessType};
use super::conn::ConnStat;
const DEFAULT_GATEWAY_IP: &str = "169.254.0.21";
#[derive(PartialEq)]
enum PortState {
Forward,
Delete,
}
struct PortToForward {
pid: Option<u32>,
inode: u64,
family: u32,
state: PortState,
port: u32,
}
fn family_to_ip_version(family: u32) -> u32 {
if family == libc::AF_INET as u32 {
4
} else if family == libc::AF_INET6 as u32 {
6
} else {
0
}
}
pub struct Forwarder {
cgroup_manager: Arc<dyn CgroupManager>,
ports: HashMap<String, PortToForward>,
source_ip: String,
}
impl Forwarder {
pub fn new(cgroup_manager: Arc<dyn CgroupManager>) -> Self {
Self {
cgroup_manager,
ports: HashMap::new(),
source_ip: DEFAULT_GATEWAY_IP.to_string(),
}
}
pub async fn start_forwarding(
&mut self,
mut rx: mpsc::Receiver<Vec<ConnStat>>,
cancel: CancellationToken,
) {
loop {
tokio::select! {
_ = cancel.cancelled() => {
self.stop_all();
return;
}
msg = rx.recv() => {
match msg {
Some(conns) => self.process_scan(conns),
None => {
self.stop_all();
return;
}
}
}
}
}
}
fn process_scan(&mut self, conns: Vec<ConnStat>) {
for ptf in self.ports.values_mut() {
ptf.state = PortState::Delete;
}
for conn in &conns {
let key = format!("{}-{}", conn.inode, conn.local_port);
if let Some(ptf) = self.ports.get_mut(&key) {
ptf.state = PortState::Forward;
} else {
tracing::debug!(
ip = %conn.local_ip,
port = conn.local_port,
family = family_to_ip_version(conn.family),
"detected new port on localhost"
);
let mut ptf = PortToForward {
pid: None,
inode: conn.inode,
family: family_to_ip_version(conn.family),
state: PortState::Forward,
port: conn.local_port,
};
self.start_port_forwarding(&mut ptf);
self.ports.insert(key, ptf);
}
}
let to_stop: Vec<String> = self
.ports
.iter()
.filter(|(_, v)| v.state == PortState::Delete)
.map(|(k, _)| k.clone())
.collect();
for key in to_stop {
if let Some(ptf) = self.ports.get(&key) {
stop_port_forwarding(ptf);
}
self.ports.remove(&key);
}
}
fn start_port_forwarding(&self, ptf: &mut PortToForward) {
let listen_arg = format!(
"TCP4-LISTEN:{},bind={},reuseaddr,fork",
ptf.port, self.source_ip
);
let connect_arg = format!("TCP{}:localhost:{}", ptf.family, ptf.port);
let mut cmd = Command::new("socat");
cmd.args(["-d", "-d", "-d", &listen_arg, &connect_arg]);
unsafe {
let cgroup_fd = self.cgroup_manager.get_fd(ProcessType::Socat);
cmd.pre_exec(move || {
libc::setpgid(0, 0);
if let Some(fd) = cgroup_fd {
let pid_str = format!("{}", libc::getpid());
let tasks_path = format!("/proc/self/fd/{}/cgroup.procs", fd);
let _ = std::fs::write(&tasks_path, pid_str.as_bytes());
}
Ok(())
});
}
tracing::debug!(
port = ptf.port,
inode = ptf.inode,
family = ptf.family,
source_ip = %self.source_ip,
"starting port forwarding"
);
match cmd.spawn() {
Ok(child) => {
ptf.pid = Some(child.id());
std::thread::spawn(move || {
let mut child = child;
let _ = child.wait();
});
}
Err(e) => {
tracing::error!(error = %e, port = ptf.port, "failed to start socat");
}
}
}
fn stop_all(&mut self) {
for ptf in self.ports.values() {
stop_port_forwarding(ptf);
}
self.ports.clear();
}
}
fn stop_port_forwarding(ptf: &PortToForward) {
if let Some(pid) = ptf.pid {
tracing::debug!(port = ptf.port, pid, "stopping port forwarding");
unsafe {
libc::kill(-(pid as i32), libc::SIGKILL);
}
}
}

4
envd-rs/src/port/mod.rs Normal file
View File

@ -0,0 +1,4 @@
pub mod conn;
pub mod forwarder;
pub mod scanner;
pub mod subsystem;

View File

@ -0,0 +1,81 @@
use std::sync::{Arc, RwLock};
use std::time::Duration;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use super::conn::{ConnStat, read_tcp_connections};
pub struct ScannerFilter {
pub ips: Vec<String>,
pub state: String,
}
impl ScannerFilter {
pub fn matches(&self, conn: &ConnStat) -> bool {
if self.state.is_empty() && self.ips.is_empty() {
return false;
}
self.ips.contains(&conn.local_ip) && self.state == conn.status
}
}
pub struct ScannerSubscriber {
pub tx: mpsc::Sender<Vec<ConnStat>>,
pub filter: Option<ScannerFilter>,
}
pub struct Scanner {
period: Duration,
subs: RwLock<Vec<(String, Arc<ScannerSubscriber>)>>,
}
impl Scanner {
pub fn new(period: Duration) -> Self {
Self {
period,
subs: RwLock::new(Vec::new()),
}
}
pub fn add_subscriber(
&self,
id: &str,
filter: Option<ScannerFilter>,
) -> mpsc::Receiver<Vec<ConnStat>> {
let (tx, rx) = mpsc::channel(4);
let sub = Arc::new(ScannerSubscriber { tx, filter });
let mut subs = self.subs.write().unwrap();
subs.push((id.to_string(), sub));
rx
}
pub fn remove_subscriber(&self, id: &str) {
let mut subs = self.subs.write().unwrap();
subs.retain(|(sid, _)| sid != id);
}
pub async fn scan_and_broadcast(&self, cancel: CancellationToken) {
loop {
let conns = tokio::task::spawn_blocking(read_tcp_connections)
.await
.unwrap_or_default();
{
let subs = self.subs.read().unwrap();
for (_, sub) in subs.iter() {
let payload = match &sub.filter {
Some(f) => conns.iter().filter(|c| f.matches(c)).cloned().collect(),
None => conns.clone(),
};
let _ = sub.tx.try_send(payload);
}
}
tokio::select! {
_ = cancel.cancelled() => return,
_ = tokio::time::sleep(self.period) => {}
}
}
}
}

View File

@ -0,0 +1,78 @@
use std::sync::Arc;
use tokio_util::sync::CancellationToken;
use crate::cgroups::CgroupManager;
use crate::config::PORT_SCANNER_INTERVAL;
use super::forwarder::Forwarder;
use super::scanner::{Scanner, ScannerFilter};
pub struct PortSubsystem {
cgroup_manager: Arc<dyn CgroupManager>,
cancel: std::sync::Mutex<Option<CancellationToken>>,
}
impl PortSubsystem {
pub fn new(cgroup_manager: Arc<dyn CgroupManager>) -> Self {
Self {
cgroup_manager,
cancel: std::sync::Mutex::new(None),
}
}
pub fn start(&self) {
let mut guard = self.cancel.lock().unwrap();
if guard.is_some() {
return;
}
let cancel = CancellationToken::new();
*guard = Some(cancel.clone());
drop(guard);
let cgroup_manager = Arc::clone(&self.cgroup_manager);
let cancel_scanner = cancel.clone();
let cancel_forwarder = cancel.clone();
tokio::spawn(async move {
let scanner = Arc::new(Scanner::new(PORT_SCANNER_INTERVAL));
let rx = scanner.add_subscriber(
"port-forwarder",
Some(ScannerFilter {
ips: vec![
"127.0.0.1".to_string(),
"localhost".to_string(),
"::1".to_string(),
],
state: "LISTEN".to_string(),
}),
);
let scanner_clone = Arc::clone(&scanner);
let scanner_handle = tokio::spawn(async move {
scanner_clone.scan_and_broadcast(cancel_scanner).await;
});
let forwarder_handle = tokio::spawn(async move {
let mut forwarder = Forwarder::new(cgroup_manager);
forwarder.start_forwarding(rx, cancel_forwarder).await;
});
let _ = tokio::join!(scanner_handle, forwarder_handle);
});
}
pub fn stop(&self) {
let mut guard = self.cancel.lock().unwrap();
if let Some(cancel) = guard.take() {
cancel.cancel();
}
}
pub fn restart(&self) {
self.stop();
self.start();
}
}

231
envd-rs/src/rpc/entry.rs Normal file
View File

@ -0,0 +1,231 @@
use std::os::unix::fs::MetadataExt;
use std::path::Path;
use connectrpc::{ConnectError, ErrorCode};
use crate::permissions::user::{lookup_groupname_by_gid, lookup_username_by_uid};
use crate::rpc::pb::filesystem::{EntryInfo, FileType};
use nix::unistd::{Gid, Uid};
const NFS_SUPER_MAGIC: i64 = 0x6969;
const CIFS_MAGIC: i64 = 0xFF534D42;
const SMB_SUPER_MAGIC: i64 = 0x517B;
const SMB2_MAGIC_NUMBER: i64 = 0xFE534D42;
const FUSE_SUPER_MAGIC: i64 = 0x65735546;
pub fn is_network_mount(path: &str) -> Result<bool, String> {
let c_path = std::ffi::CString::new(path).map_err(|e| e.to_string())?;
let mut stat: libc::statfs = unsafe { std::mem::zeroed() };
let ret = unsafe { libc::statfs(c_path.as_ptr(), &mut stat) };
if ret != 0 {
return Err(format!(
"statfs {path}: {}",
std::io::Error::last_os_error()
));
}
let fs_type = stat.f_type as i64;
Ok(matches!(
fs_type,
NFS_SUPER_MAGIC | CIFS_MAGIC | SMB_SUPER_MAGIC | SMB2_MAGIC_NUMBER | FUSE_SUPER_MAGIC
))
}
pub fn build_entry_info(path: &str) -> Result<EntryInfo, ConnectError> {
let p = Path::new(path);
let lstat = std::fs::symlink_metadata(p).map_err(|e| {
if e.kind() == std::io::ErrorKind::NotFound {
ConnectError::new(ErrorCode::NotFound, format!("file not found: {e}"))
} else {
ConnectError::new(ErrorCode::Internal, format!("error getting file info: {e}"))
}
})?;
let is_symlink = lstat.file_type().is_symlink();
let (file_type, mode, symlink_target) = if is_symlink {
let target = std::fs::canonicalize(p)
.map(|t| t.to_string_lossy().to_string())
.unwrap_or_else(|_| path.to_string());
let target_type = match std::fs::metadata(p) {
Ok(meta) => meta_to_file_type(&meta),
Err(_) => FileType::FILE_TYPE_UNSPECIFIED,
};
let target_mode = std::fs::metadata(p)
.map(|m| m.mode() & 0o7777)
.unwrap_or(0);
(target_type, target_mode, Some(target))
} else {
let ft = meta_to_file_type(&lstat);
let mode = lstat.mode() & 0o7777;
(ft, mode, None)
};
let uid = lstat.uid();
let gid = lstat.gid();
let owner = lookup_username_by_uid(Uid::from_raw(uid));
let group = lookup_groupname_by_gid(Gid::from_raw(gid));
let modified_time = {
let mtime_sec = lstat.mtime();
let mtime_nsec = lstat.mtime_nsec() as i32;
if mtime_sec == 0 && mtime_nsec == 0 {
None
} else {
Some(buffa_types::google::protobuf::Timestamp {
seconds: mtime_sec,
nanos: mtime_nsec,
..Default::default()
})
}
};
let name = p
.file_name()
.map(|n| n.to_string_lossy().to_string())
.unwrap_or_default();
let permissions = format_permissions(lstat.mode());
Ok(EntryInfo {
name,
r#type: buffa::EnumValue::Known(file_type),
path: path.to_string(),
size: lstat.len() as i64,
mode,
permissions,
owner,
group,
modified_time: modified_time.into(),
symlink_target: symlink_target,
..Default::default()
})
}
fn meta_to_file_type(meta: &std::fs::Metadata) -> FileType {
if meta.is_file() {
FileType::FILE_TYPE_FILE
} else if meta.is_dir() {
FileType::FILE_TYPE_DIRECTORY
} else if meta.file_type().is_symlink() {
FileType::FILE_TYPE_SYMLINK
} else {
FileType::FILE_TYPE_UNSPECIFIED
}
}
fn format_permissions(mode: u32) -> String {
let file_type = match mode & libc::S_IFMT {
libc::S_IFDIR => 'd',
libc::S_IFLNK => 'L',
libc::S_IFREG => '-',
libc::S_IFBLK => 'b',
libc::S_IFCHR => 'c',
libc::S_IFIFO => 'p',
libc::S_IFSOCK => 'S',
_ => '?',
};
let perms = mode & 0o777;
let mut s = String::with_capacity(10);
s.push(file_type);
for shift in [6, 3, 0] {
let bits = (perms >> shift) & 7;
s.push(if bits & 4 != 0 { 'r' } else { '-' });
s.push(if bits & 2 != 0 { 'w' } else { '-' });
s.push(if bits & 1 != 0 { 'x' } else { '-' });
}
s
}
#[cfg(test)]
mod tests {
use super::*;
// format_permissions
#[test]
fn regular_file_755() {
assert_eq!(format_permissions(libc::S_IFREG | 0o755), "-rwxr-xr-x");
}
#[test]
fn directory_755() {
assert_eq!(format_permissions(libc::S_IFDIR | 0o755), "drwxr-xr-x");
}
#[test]
fn symlink_777() {
assert_eq!(format_permissions(libc::S_IFLNK | 0o777), "Lrwxrwxrwx");
}
#[test]
fn regular_file_000() {
assert_eq!(format_permissions(libc::S_IFREG | 0o000), "----------");
}
#[test]
fn regular_file_644() {
assert_eq!(format_permissions(libc::S_IFREG | 0o644), "-rw-r--r--");
}
#[test]
fn block_device() {
assert_eq!(format_permissions(libc::S_IFBLK | 0o660), "brw-rw----");
}
#[test]
fn char_device() {
assert_eq!(format_permissions(libc::S_IFCHR | 0o666), "crw-rw-rw-");
}
#[test]
fn fifo() {
assert_eq!(format_permissions(libc::S_IFIFO | 0o644), "prw-r--r--");
}
#[test]
fn socket() {
assert_eq!(format_permissions(libc::S_IFSOCK | 0o755), "Srwxr-xr-x");
}
#[test]
fn unknown_type() {
assert_eq!(format_permissions(0o755), "?rwxr-xr-x");
}
#[test]
fn setuid_in_mode_only_affects_lower_bits() {
// setuid (0o4755) — format_permissions masks with 0o777, so same as 0o755
assert_eq!(
format_permissions(libc::S_IFREG | 0o4755),
format_permissions(libc::S_IFREG | 0o755),
);
}
#[test]
fn output_always_10_chars() {
for mode in [0o000, 0o777, 0o644, 0o755, 0o4755] {
assert_eq!(format_permissions(libc::S_IFREG | mode).len(), 10);
}
}
// meta_to_file_type — needs real filesystem
#[test]
fn meta_regular_file() {
let f = tempfile::NamedTempFile::new().unwrap();
let meta = std::fs::metadata(f.path()).unwrap();
assert_eq!(meta_to_file_type(&meta), FileType::FILE_TYPE_FILE);
}
#[test]
fn meta_directory() {
let d = tempfile::TempDir::new().unwrap();
let meta = std::fs::metadata(d.path()).unwrap();
assert_eq!(meta_to_file_type(&meta), FileType::FILE_TYPE_DIRECTORY);
}
}

View File

@ -0,0 +1,402 @@
use std::path::{Path, PathBuf};
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use connectrpc::{ConnectError, Context, ErrorCode};
use dashmap::DashMap;
use futures::Stream;
use crate::permissions::path::{ensure_dirs, expand_and_resolve};
use crate::permissions::user::lookup_user;
use crate::rpc::entry::build_entry_info;
use crate::rpc::pb::filesystem::*;
use crate::state::AppState;
pub struct FilesystemServiceImpl {
state: Arc<AppState>,
watchers: DashMap<String, WatcherHandle>,
}
struct WatcherHandle {
events: Arc<Mutex<Vec<FilesystemEvent>>>,
_watcher: notify::RecommendedWatcher,
}
impl FilesystemServiceImpl {
pub fn new(state: Arc<AppState>) -> Self {
Self {
state,
watchers: DashMap::new(),
}
}
fn resolve_path(&self, path: &str, ctx: &Context) -> Result<String, ConnectError> {
let username = extract_username(ctx).unwrap_or_else(|| self.state.defaults.user());
let user = lookup_user(&username).map_err(|e| {
ConnectError::new(ErrorCode::Unauthenticated, format!("invalid user: {e}"))
})?;
let home_dir = user.dir.to_string_lossy().to_string();
let default_workdir = self.state.defaults.workdir();
expand_and_resolve(path, &home_dir, default_workdir.as_deref())
.map_err(|e| ConnectError::new(ErrorCode::InvalidArgument, e))
}
}
fn extract_username(ctx: &Context) -> Option<String> {
ctx.extensions.get::<AuthUser>().map(|u| u.0.clone())
}
#[derive(Clone)]
pub struct AuthUser(pub String);
impl Filesystem for FilesystemServiceImpl {
async fn stat(
&self,
ctx: Context,
request: buffa::view::OwnedView<StatRequestView<'static>>,
) -> Result<(StatResponse, Context), ConnectError> {
let path = self.resolve_path(request.path, &ctx)?;
let entry = build_entry_info(&path)?;
Ok((
StatResponse {
entry: entry.into(),
..Default::default()
},
ctx,
))
}
async fn make_dir(
&self,
ctx: Context,
request: buffa::view::OwnedView<MakeDirRequestView<'static>>,
) -> Result<(MakeDirResponse, Context), ConnectError> {
let path = self.resolve_path(request.path, &ctx)?;
match std::fs::metadata(&path) {
Ok(meta) => {
if meta.is_dir() {
return Err(ConnectError::new(
ErrorCode::AlreadyExists,
format!("directory already exists: {path}"),
));
}
return Err(ConnectError::new(
ErrorCode::InvalidArgument,
format!("path exists but is not a directory: {path}"),
));
}
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
Err(e) => {
return Err(ConnectError::new(
ErrorCode::Internal,
format!("error getting file info: {e}"),
));
}
}
let username = extract_username(&ctx).unwrap_or_else(|| self.state.defaults.user());
let user =
lookup_user(&username).map_err(|e| ConnectError::new(ErrorCode::Internal, e))?;
ensure_dirs(&path, user.uid, user.gid)
.map_err(|e| ConnectError::new(ErrorCode::Internal, e))?;
let entry = build_entry_info(&path)?;
Ok((
MakeDirResponse {
entry: entry.into(),
..Default::default()
},
ctx,
))
}
async fn r#move(
&self,
ctx: Context,
request: buffa::view::OwnedView<MoveRequestView<'static>>,
) -> Result<(MoveResponse, Context), ConnectError> {
let source = self.resolve_path(request.source, &ctx)?;
let destination = self.resolve_path(request.destination, &ctx)?;
let username = extract_username(&ctx).unwrap_or_else(|| self.state.defaults.user());
let user =
lookup_user(&username).map_err(|e| ConnectError::new(ErrorCode::Internal, e))?;
if let Some(parent) = Path::new(&destination).parent() {
ensure_dirs(&parent.to_string_lossy(), user.uid, user.gid)
.map_err(|e| ConnectError::new(ErrorCode::Internal, e))?;
}
std::fs::rename(&source, &destination).map_err(|e| {
if e.kind() == std::io::ErrorKind::NotFound {
ConnectError::new(ErrorCode::NotFound, format!("source not found: {e}"))
} else {
ConnectError::new(ErrorCode::Internal, format!("error renaming: {e}"))
}
})?;
let entry = build_entry_info(&destination)?;
Ok((
MoveResponse {
entry: entry.into(),
..Default::default()
},
ctx,
))
}
async fn list_dir(
&self,
ctx: Context,
request: buffa::view::OwnedView<ListDirRequestView<'static>>,
) -> Result<(ListDirResponse, Context), ConnectError> {
let mut depth = request.depth as usize;
if depth == 0 {
depth = 1;
}
let path = self.resolve_path(request.path, &ctx)?;
let resolved = std::fs::canonicalize(&path).map_err(|e| {
if e.kind() == std::io::ErrorKind::NotFound {
ConnectError::new(ErrorCode::NotFound, format!("path not found: {e}"))
} else {
ConnectError::new(ErrorCode::Internal, format!("error resolving path: {e}"))
}
})?;
let resolved_str = resolved.to_string_lossy().to_string();
let meta = std::fs::metadata(&resolved).map_err(|e| {
ConnectError::new(ErrorCode::Internal, format!("error getting file info: {e}"))
})?;
if !meta.is_dir() {
return Err(ConnectError::new(
ErrorCode::InvalidArgument,
format!("path is not a directory: {path}"),
));
}
let entries = walk_dir(&path, &resolved_str, depth)?;
Ok((
ListDirResponse {
entries,
..Default::default()
},
ctx,
))
}
async fn remove(
&self,
ctx: Context,
request: buffa::view::OwnedView<RemoveRequestView<'static>>,
) -> Result<(RemoveResponse, Context), ConnectError> {
let path = self.resolve_path(request.path, &ctx)?;
if let Err(e1) = std::fs::remove_dir_all(&path) {
if let Err(e2) = std::fs::remove_file(&path) {
return Err(ConnectError::new(
ErrorCode::Internal,
format!("error removing: {e1}; also tried as file: {e2}"),
));
}
}
Ok((RemoveResponse { ..Default::default() }, ctx))
}
async fn watch_dir(
&self,
_ctx: Context,
_request: buffa::view::OwnedView<WatchDirRequestView<'static>>,
) -> Result<
(
Pin<Box<dyn Stream<Item = Result<WatchDirResponse, ConnectError>> + Send>>,
Context,
),
ConnectError,
> {
Err(ConnectError::new(
ErrorCode::Unimplemented,
"watch_dir streaming not yet implemented",
))
}
async fn create_watcher(
&self,
ctx: Context,
request: buffa::view::OwnedView<CreateWatcherRequestView<'static>>,
) -> Result<(CreateWatcherResponse, Context), ConnectError> {
use notify::{RecursiveMode, Watcher};
let path = self.resolve_path(request.path, &ctx)?;
let recursive = request.recursive;
if let Ok(true) = crate::rpc::entry::is_network_mount(&path) {
return Err(ConnectError::new(
ErrorCode::FailedPrecondition,
"watching network mounts is not supported",
));
}
let watcher_id = simple_id();
let events: Arc<Mutex<Vec<FilesystemEvent>>> = Arc::new(Mutex::new(Vec::new()));
let events_cb = Arc::clone(&events);
let mut watcher = notify::recommended_watcher(
move |res: Result<notify::Event, notify::Error>| {
if let Ok(event) = res {
let event_type = match event.kind {
notify::EventKind::Create(_) => EventType::EVENT_TYPE_CREATE,
notify::EventKind::Modify(notify::event::ModifyKind::Data(_)) => {
EventType::EVENT_TYPE_WRITE
}
notify::EventKind::Modify(notify::event::ModifyKind::Metadata(_)) => {
EventType::EVENT_TYPE_CHMOD
}
notify::EventKind::Remove(_) => EventType::EVENT_TYPE_REMOVE,
notify::EventKind::Modify(notify::event::ModifyKind::Name(_)) => {
EventType::EVENT_TYPE_RENAME
}
_ => return,
};
for p in &event.paths {
if let Ok(mut guard) = events_cb.lock() {
guard.push(FilesystemEvent {
name: p.to_string_lossy().to_string(),
r#type: buffa::EnumValue::Known(event_type),
..Default::default()
});
}
}
}
},
)
.map_err(|e| {
ConnectError::new(ErrorCode::Internal, format!("failed to create watcher: {e}"))
})?;
let mode = if recursive {
RecursiveMode::Recursive
} else {
RecursiveMode::NonRecursive
};
watcher.watch(Path::new(&path), mode).map_err(|e| {
ConnectError::new(ErrorCode::Internal, format!("failed to watch path: {e}"))
})?;
self.watchers.insert(
watcher_id.clone(),
WatcherHandle {
events,
_watcher: watcher,
},
);
Ok((
CreateWatcherResponse {
watcher_id,
..Default::default()
},
ctx,
))
}
async fn get_watcher_events(
&self,
ctx: Context,
request: buffa::view::OwnedView<GetWatcherEventsRequestView<'static>>,
) -> Result<(GetWatcherEventsResponse, Context), ConnectError> {
let watcher_id: &str = request.watcher_id;
let handle = self.watchers.get(watcher_id).ok_or_else(|| {
ConnectError::new(
ErrorCode::NotFound,
format!("watcher not found: {watcher_id}"),
)
})?;
let events = {
let mut guard = handle.events.lock().unwrap();
std::mem::take(&mut *guard)
};
Ok((
GetWatcherEventsResponse {
events,
..Default::default()
},
ctx,
))
}
async fn remove_watcher(
&self,
ctx: Context,
request: buffa::view::OwnedView<RemoveWatcherRequestView<'static>>,
) -> Result<(RemoveWatcherResponse, Context), ConnectError> {
let watcher_id: &str = request.watcher_id;
self.watchers.remove(watcher_id);
Ok((RemoveWatcherResponse { ..Default::default() }, ctx))
}
}
fn walk_dir(
requested_path: &str,
resolved_path: &str,
depth: usize,
) -> Result<Vec<EntryInfo>, ConnectError> {
let mut entries = Vec::new();
let base = Path::new(resolved_path);
for result in walkdir::WalkDir::new(resolved_path)
.min_depth(1)
.max_depth(depth)
.follow_links(false)
{
let dir_entry = match result {
Ok(e) => e,
Err(e) => {
if e.io_error()
.is_some_and(|io| io.kind() == std::io::ErrorKind::NotFound)
{
continue;
}
return Err(ConnectError::new(
ErrorCode::Internal,
format!("error reading directory: {e}"),
));
}
};
let entry_path = dir_entry.path();
let mut entry = match build_entry_info(&entry_path.to_string_lossy()) {
Ok(e) => e,
Err(e) if e.code == ErrorCode::NotFound => continue,
Err(e) => return Err(e),
};
if let Ok(rel) = entry_path.strip_prefix(base) {
let remapped = PathBuf::from(requested_path).join(rel);
entry.path = remapped.to_string_lossy().to_string();
}
entries.push(entry);
}
Ok(entries)
}
fn simple_id() -> String {
use std::time::{SystemTime, UNIX_EPOCH};
let nanos = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_nanos();
format!("w-{nanos:x}")
}

26
envd-rs/src/rpc/mod.rs Normal file
View File

@ -0,0 +1,26 @@
pub mod pb;
pub mod entry;
pub mod process_handler;
pub mod process_service;
pub mod filesystem_service;
use std::sync::Arc;
use crate::rpc::process_service::ProcessServiceImpl;
use crate::rpc::filesystem_service::FilesystemServiceImpl;
use crate::state::AppState;
use pb::process::ProcessExt;
use pb::filesystem::FilesystemExt;
/// Build the connect-rust Router with both RPC services registered.
pub fn rpc_router(state: Arc<AppState>) -> connectrpc::Router {
let process_svc = Arc::new(ProcessServiceImpl::new(Arc::clone(&state)));
let filesystem_svc = Arc::new(FilesystemServiceImpl::new(Arc::clone(&state)));
let router = connectrpc::Router::new();
let router = process_svc.register(router);
let router = filesystem_svc.register(router);
router
}

10
envd-rs/src/rpc/pb.rs Normal file
View File

@ -0,0 +1,10 @@
#![allow(dead_code, non_camel_case_types, unused_imports, clippy::derivable_impls)]
use ::buffa;
use ::buffa_types;
use ::connectrpc;
use ::futures;
use ::http_body;
use ::serde;
include!(concat!(env!("OUT_DIR"), "/_connectrpc.rs"));

View File

@ -0,0 +1,453 @@
use std::io::Read;
use std::os::unix::process::CommandExt;
use std::process::Stdio;
use std::sync::{Arc, Mutex};
use connectrpc::{ConnectError, ErrorCode};
use nix::pty::{openpty, Winsize};
use nix::sys::signal::{self, Signal};
use nix::unistd::Pid;
use tokio::sync::broadcast;
use crate::rpc::pb::process::*;
const STD_CHUNK_SIZE: usize = 32768;
const PTY_CHUNK_SIZE: usize = 16384;
const BROADCAST_CAPACITY: usize = 4096;
#[derive(Clone)]
pub enum DataEvent {
Stdout(Vec<u8>),
Stderr(Vec<u8>),
Pty(Vec<u8>),
}
#[derive(Clone)]
pub struct EndEvent {
pub exit_code: i32,
pub exited: bool,
pub status: String,
pub error: Option<String>,
}
pub struct ProcessHandle {
pub config: ProcessConfig,
pub tag: Option<String>,
pub pid: u32,
data_tx: broadcast::Sender<DataEvent>,
end_tx: broadcast::Sender<EndEvent>,
ended: Mutex<Option<EndEvent>>,
stdin: Mutex<Option<std::process::ChildStdin>>,
pty_master: Mutex<Option<std::fs::File>>,
}
impl ProcessHandle {
pub fn subscribe_data(&self) -> broadcast::Receiver<DataEvent> {
self.data_tx.subscribe()
}
pub fn subscribe_end(&self) -> broadcast::Receiver<EndEvent> {
self.end_tx.subscribe()
}
pub fn cached_end(&self) -> Option<EndEvent> {
self.ended.lock().unwrap().clone()
}
pub fn send_signal(&self, sig: Signal) -> Result<(), ConnectError> {
// Signal the whole process group (negative pid), not just the immediate
// /bin/sh wrapper. Otherwise children the process spawned are orphaned
// and keep running. Both spawn paths make the process a group leader
// (setsid for pty, setpgid for pipe), so pgid == pid.
signal::kill(Pid::from_raw(-(self.pid as i32)), sig).map_err(|e| {
ConnectError::new(ErrorCode::Internal, format!("error sending signal: {e}"))
})
}
pub fn write_stdin(&self, data: &[u8]) -> Result<(), ConnectError> {
use std::io::Write;
let mut guard = self.stdin.lock().unwrap();
match guard.as_mut() {
Some(stdin) => stdin.write_all(data).map_err(|e| {
ConnectError::new(ErrorCode::Internal, format!("error writing to stdin: {e}"))
}),
None => Err(ConnectError::new(
ErrorCode::FailedPrecondition,
"stdin not enabled or closed",
)),
}
}
pub fn write_pty(&self, data: &[u8]) -> Result<(), ConnectError> {
use std::io::Write;
let mut guard = self.pty_master.lock().unwrap();
match guard.as_mut() {
Some(master) => master.write_all(data).map_err(|e| {
ConnectError::new(ErrorCode::Internal, format!("error writing to pty: {e}"))
}),
None => Err(ConnectError::new(
ErrorCode::FailedPrecondition,
"pty not assigned to process",
)),
}
}
pub fn close_stdin(&self) -> Result<(), ConnectError> {
if self.pty_master.lock().unwrap().is_some() {
return Err(ConnectError::new(
ErrorCode::FailedPrecondition,
"cannot close stdin for PTY process — send Ctrl+D (0x04) instead",
));
}
let mut guard = self.stdin.lock().unwrap();
*guard = None;
Ok(())
}
pub fn resize_pty(&self, cols: u16, rows: u16) -> Result<(), ConnectError> {
let guard = self.pty_master.lock().unwrap();
match guard.as_ref() {
Some(master) => {
use std::os::unix::io::AsRawFd;
let ws = libc::winsize {
ws_row: rows,
ws_col: cols,
ws_xpixel: 0,
ws_ypixel: 0,
};
let ret = unsafe { libc::ioctl(master.as_raw_fd(), libc::TIOCSWINSZ, &ws) };
if ret != 0 {
return Err(ConnectError::new(
ErrorCode::Internal,
format!(
"ioctl TIOCSWINSZ failed: {}",
std::io::Error::last_os_error()
),
));
}
Ok(())
}
None => Err(ConnectError::new(
ErrorCode::FailedPrecondition,
"tty not assigned to process",
)),
}
}
}
pub struct SpawnedProcess {
pub handle: Arc<ProcessHandle>,
pub data_rx: broadcast::Receiver<DataEvent>,
pub end_rx: broadcast::Receiver<EndEvent>,
}
pub fn spawn_process(
cmd_str: &str,
args: &[String],
envs: &std::collections::HashMap<String, String>,
cwd: &str,
pty_opts: Option<(u16, u16)>,
enable_stdin: bool,
tag: Option<String>,
user: &nix::unistd::User,
default_env_vars: &dashmap::DashMap<String, String>,
) -> Result<SpawnedProcess, ConnectError> {
let mut env: Vec<(String, String)> = Vec::new();
env.push(("PATH".into(), std::env::var("PATH").unwrap_or_default()));
let home = user.dir.to_string_lossy().to_string();
env.push(("HOME".into(), home));
env.push(("USER".into(), user.name.clone()));
env.push(("LOGNAME".into(), user.name.clone()));
default_env_vars.iter().for_each(|entry| {
env.push((entry.key().clone(), entry.value().clone()));
});
for (k, v) in envs {
env.push((k.clone(), v.clone()));
}
// Reset the child's nice value only when envd itself was started at an
// elevated nice value (delta > 0 means raising the nice number / lowering
// priority, which is permitted for non-root processes). A non-root process
// cannot improve its priority, so skip the `nice` wrapper otherwise — it
// would fail with EPERM ("cannot set niceness: permission denied") for
// commands run as a non-root user. Writing 100 to the process's own
// oom_score_adj is always permitted (raising the score).
let nice_delta = 0 - current_nice();
let oom_script = if nice_delta > 0 {
format!(
r#"echo 100 > /proc/$$/oom_score_adj && exec /usr/bin/nice -n {} "${{@}}""#,
nice_delta
)
} else {
r#"echo 100 > /proc/$$/oom_score_adj && exec "$@""#.to_string()
};
let mut wrapper_args = vec![
"-c".to_string(),
oom_script,
"--".to_string(),
cmd_str.to_string(),
];
wrapper_args.extend_from_slice(args);
let uid = user.uid.as_raw();
let gid = user.gid.as_raw();
let (data_tx, _) = broadcast::channel(BROADCAST_CAPACITY);
let (end_tx, _) = broadcast::channel(16);
let config = ProcessConfig {
cmd: cmd_str.to_string(),
args: args.to_vec(),
envs: envs.clone(),
cwd: Some(cwd.to_string()),
..Default::default()
};
if let Some((cols, rows)) = pty_opts {
let pty_result = openpty(
Some(&Winsize {
ws_row: rows,
ws_col: cols,
ws_xpixel: 0,
ws_ypixel: 0,
}),
None,
)
.map_err(|e| ConnectError::new(ErrorCode::Internal, format!("openpty failed: {e}")))?;
let master_fd = pty_result.master;
let slave_fd = pty_result.slave;
let mut command = std::process::Command::new("/bin/sh");
command
.args(&wrapper_args)
.env_clear()
.envs(env.iter().map(|(k, v)| (k.as_str(), v.as_str())))
.current_dir(cwd);
unsafe {
use std::os::unix::io::AsRawFd;
let slave_raw = slave_fd.as_raw_fd();
let master_raw = master_fd.as_raw_fd();
command.pre_exec(move || {
libc::close(master_raw);
nix::unistd::setsid()
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
libc::ioctl(slave_raw, libc::TIOCSCTTY, 0);
libc::dup2(slave_raw, 0);
libc::dup2(slave_raw, 1);
libc::dup2(slave_raw, 2);
if slave_raw > 2 {
libc::close(slave_raw);
}
libc::setgid(gid);
libc::setuid(uid);
Ok(())
});
}
command.stdin(Stdio::null());
command.stdout(Stdio::null());
command.stderr(Stdio::null());
let child = command.spawn().map_err(|e| {
ConnectError::new(ErrorCode::Internal, format!("error starting pty process: {e}"))
})?;
drop(slave_fd);
let pid = child.id();
let master_file: std::fs::File = master_fd.into();
let master_clone = master_file.try_clone().unwrap();
let handle = Arc::new(ProcessHandle {
config,
tag,
pid,
data_tx: data_tx.clone(),
end_tx: end_tx.clone(),
ended: Mutex::new(None),
stdin: Mutex::new(None),
pty_master: Mutex::new(Some(master_file)),
});
let data_rx = handle.subscribe_data();
let end_rx = handle.subscribe_end();
let data_tx_clone = data_tx.clone();
let pty_reader = std::thread::spawn(move || {
let mut master = master_clone;
let mut buf = vec![0u8; PTY_CHUNK_SIZE];
loop {
match master.read(&mut buf) {
Ok(0) => break,
Ok(n) => {
let _ = data_tx_clone.send(DataEvent::Pty(buf[..n].to_vec()));
}
Err(_) => break,
}
}
});
let end_tx_clone = end_tx.clone();
let handle_for_waiter = Arc::clone(&handle);
std::thread::spawn(move || {
let mut child = child;
let status = child.wait();
// Drain the pty to EOF before publishing the end event so trailing
// output is never lost to a process-exit/pty-read race.
let _ = pty_reader.join();
let end_event = match status {
Ok(s) => EndEvent {
exit_code: s.code().unwrap_or(-1),
exited: s.code().is_some(),
status: format!("{s}"),
error: None,
},
Err(e) => EndEvent {
exit_code: -1,
exited: false,
status: "error".into(),
error: Some(e.to_string()),
},
};
*handle_for_waiter.ended.lock().unwrap() = Some(end_event.clone());
let _ = end_tx_clone.send(end_event);
});
tracing::info!(pid, cmd = cmd_str, "process started (pty)");
Ok(SpawnedProcess { handle, data_rx, end_rx })
} else {
let mut command = std::process::Command::new("/bin/sh");
command
.args(&wrapper_args)
.env_clear()
.envs(env.iter().map(|(k, v)| (k.as_str(), v.as_str())))
.current_dir(cwd)
.stdout(Stdio::piped())
.stderr(Stdio::piped());
if enable_stdin {
command.stdin(Stdio::piped());
} else {
command.stdin(Stdio::null());
}
unsafe {
command.pre_exec(move || {
// Become a process-group leader so SendSignal can kill the
// whole group, not just this wrapper. The pty path gets this
// for free via setsid().
nix::unistd::setpgid(Pid::from_raw(0), Pid::from_raw(0))
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
libc::setgid(gid);
libc::setuid(uid);
Ok(())
});
}
let mut child = command.spawn().map_err(|e| {
ConnectError::new(ErrorCode::Internal, format!("error starting process: {e}"))
})?;
let pid = child.id();
let stdin = child.stdin.take();
let stdout = child.stdout.take();
let stderr = child.stderr.take();
let handle = Arc::new(ProcessHandle {
config,
tag,
pid,
data_tx: data_tx.clone(),
end_tx: end_tx.clone(),
ended: Mutex::new(None),
stdin: Mutex::new(stdin),
pty_master: Mutex::new(None),
});
let data_rx = handle.subscribe_data();
let end_rx = handle.subscribe_end();
let mut output_readers: Vec<std::thread::JoinHandle<()>> = Vec::new();
if let Some(mut out) = stdout {
let tx = data_tx.clone();
output_readers.push(std::thread::spawn(move || {
let mut buf = vec![0u8; STD_CHUNK_SIZE];
loop {
match out.read(&mut buf) {
Ok(0) => break,
Ok(n) => {
let _ = tx.send(DataEvent::Stdout(buf[..n].to_vec()));
}
Err(_) => break,
}
}
}));
}
if let Some(mut err_pipe) = stderr {
let tx = data_tx.clone();
output_readers.push(std::thread::spawn(move || {
let mut buf = vec![0u8; STD_CHUNK_SIZE];
loop {
match err_pipe.read(&mut buf) {
Ok(0) => break,
Ok(n) => {
let _ = tx.send(DataEvent::Stderr(buf[..n].to_vec()));
}
Err(_) => break,
}
}
}));
}
let end_tx_clone = end_tx.clone();
let handle_for_waiter = Arc::clone(&handle);
std::thread::spawn(move || {
let status = child.wait();
// Drain stdout/stderr to EOF before publishing the end event so
// trailing output is never lost to a process-exit/pipe-read race.
for reader in output_readers {
let _ = reader.join();
}
let end_event = match status {
Ok(s) => EndEvent {
exit_code: s.code().unwrap_or(-1),
exited: s.code().is_some(),
status: format!("{s}"),
error: None,
},
Err(e) => EndEvent {
exit_code: -1,
exited: false,
status: "error".into(),
error: Some(e.to_string()),
},
};
*handle_for_waiter.ended.lock().unwrap() = Some(end_event.clone());
let _ = end_tx_clone.send(end_event);
});
tracing::info!(pid, cmd = cmd_str, "process started (pipe)");
Ok(SpawnedProcess { handle, data_rx, end_rx })
}
}
fn current_nice() -> i32 {
unsafe {
*libc::__errno_location() = 0;
let prio = libc::getpriority(libc::PRIO_PROCESS, 0);
if *libc::__errno_location() != 0 {
return 0;
}
// getpriority(PRIO_PROCESS, 0) returns the nice value directly,
// in the range [-20, 19]; the normal default is 0.
prio
}
}

View File

@ -0,0 +1,527 @@
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use connectrpc::{ConnectError, Context, ErrorCode};
use dashmap::DashMap;
use futures::Stream;
use crate::permissions::path::expand_and_resolve;
use crate::permissions::user::lookup_user;
use crate::rpc::pb::process::*;
use crate::rpc::process_handler::{self, DataEvent, ProcessHandle};
use crate::state::AppState;
pub struct ProcessServiceImpl {
state: Arc<AppState>,
processes: Arc<DashMap<u32, Arc<ProcessHandle>>>,
}
impl ProcessServiceImpl {
pub fn new(state: Arc<AppState>) -> Self {
Self {
state,
processes: Arc::new(DashMap::new()),
}
}
fn get_process_by_selector(
&self,
selector: &ProcessSelectorView,
) -> Result<Arc<ProcessHandle>, ConnectError> {
match &selector.selector {
Some(process_selector::SelectorView::Pid(pid)) => {
let pid_val = *pid;
self.processes
.get(&pid_val)
.map(|entry| Arc::clone(entry.value()))
.ok_or_else(|| {
ConnectError::new(
ErrorCode::NotFound,
format!("process with pid {pid_val} not found"),
)
})
}
Some(process_selector::SelectorView::Tag(tag)) => {
let tag_str: &str = tag;
for entry in self.processes.iter() {
if let Some(ref t) = entry.value().tag {
if t == tag_str {
return Ok(Arc::clone(entry.value()));
}
}
}
Err(ConnectError::new(
ErrorCode::NotFound,
format!("process with tag {tag_str} not found"),
))
}
None => Err(ConnectError::new(
ErrorCode::InvalidArgument,
"process selector required",
)),
}
}
fn spawn_from_request(
&self,
request: &StartRequestView<'_>,
) -> Result<process_handler::SpawnedProcess, ConnectError> {
let proc_config = request.process.as_option().ok_or_else(|| {
ConnectError::new(ErrorCode::InvalidArgument, "process config required")
})?;
let username = self.state.defaults.user();
let user =
lookup_user(&username).map_err(|e| ConnectError::new(ErrorCode::Internal, e))?;
let cmd: &str = proc_config.cmd;
let args: Vec<String> = proc_config.args.iter().map(|s| s.to_string()).collect();
let envs: HashMap<String, String> = proc_config
.envs
.iter()
.map(|(k, v)| (k.to_string(), v.to_string()))
.collect();
let home_dir = user.dir.to_string_lossy().to_string();
let cwd_str: &str = proc_config.cwd.unwrap_or("");
let default_workdir = self.state.defaults.workdir();
let cwd = expand_and_resolve(cwd_str, &home_dir, default_workdir.as_deref())
.map_err(|e| ConnectError::new(ErrorCode::InvalidArgument, e))?;
let effective_cwd = if cwd.is_empty() { "/" } else { &cwd };
if let Err(_) = std::fs::metadata(effective_cwd) {
return Err(ConnectError::new(
ErrorCode::InvalidArgument,
format!("cwd '{effective_cwd}' does not exist"),
));
}
let pty_opts = request.pty.as_option().and_then(|pty| {
pty.size
.as_option()
.map(|sz| (sz.cols as u16, sz.rows as u16))
});
let enable_stdin = request.stdin.unwrap_or(true);
let tag = request.tag.map(|s| s.to_string());
tracing::info!(
cmd = cmd,
has_pty = pty_opts.is_some(),
pty_size = ?pty_opts,
tag = ?tag,
stdin = enable_stdin,
cwd = effective_cwd,
user = %username,
"process.Start request"
);
let spawned = process_handler::spawn_process(
cmd,
&args,
&envs,
effective_cwd,
pty_opts,
enable_stdin,
tag,
&user,
&self.state.defaults.env_vars,
)?;
self.processes.insert(spawned.handle.pid, Arc::clone(&spawned.handle));
let processes = Arc::clone(&self.processes);
let pid = spawned.handle.pid;
// Subscribe before checking cached_end so the prune cannot be lost to a
// race: a short-lived process can exit and broadcast its end event
// before this task runs. A broadcast receiver only sees messages sent
// after subscribe(), so a late subscribe would miss the event forever
// (recv() never returns Closed either — the handle keeps end_tx alive
// until it leaves the map, which only this task does). The waiter sets
// ended before sending end_tx, so cached_end() is a reliable fallback.
let mut cleanup_end_rx = spawned.handle.subscribe_end();
let already_ended = spawned.handle.cached_end().is_some();
tokio::spawn(async move {
if !already_ended {
let _ = cleanup_end_rx.recv().await;
}
processes.remove(&pid);
});
Ok(spawned)
}
}
impl Process for ProcessServiceImpl {
async fn list(
&self,
ctx: Context,
_request: buffa::view::OwnedView<ListRequestView<'static>>,
) -> Result<(ListResponse, Context), ConnectError> {
let processes: Vec<ProcessInfo> = self
.processes
.iter()
.map(|entry| {
let h = entry.value();
ProcessInfo {
config: buffa::MessageField::some(h.config.clone()),
pid: h.pid,
tag: h.tag.clone(),
..Default::default()
}
})
.collect();
Ok((
ListResponse {
processes,
..Default::default()
},
ctx,
))
}
async fn start(
&self,
ctx: Context,
request: buffa::view::OwnedView<StartRequestView<'static>>,
) -> Result<
(
Pin<Box<dyn Stream<Item = Result<StartResponse, ConnectError>> + Send>>,
Context,
),
ConnectError,
> {
let spawned = self.spawn_from_request(&request)?;
let pid = spawned.handle.pid;
let mut data_rx = spawned.data_rx;
let mut end_rx = spawned.end_rx;
let stream = async_stream::stream! {
yield Ok(make_start_response(pid));
loop {
tokio::select! {
biased;
data = data_rx.recv() => {
match data {
Ok(ev) => yield Ok(make_data_start_response(ev)),
Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
Err(tokio::sync::broadcast::error::RecvError::Closed) => {
// Data channel closed: the process ended and its
// handle was dropped. The end event is published
// before the handle drop, so it is still buffered
// — emit it rather than losing the exit code.
if let Ok(end) = end_rx.try_recv() {
yield Ok(make_end_start_response(end));
}
break;
}
}
}
end = end_rx.recv() => {
// Process ended. The waiter joins the output readers
// before sending this event, so every byte is already
// in the data channel — drain it fully before the end.
loop {
match data_rx.try_recv() {
Ok(ev) => yield Ok(make_data_start_response(ev)),
Err(tokio::sync::broadcast::error::TryRecvError::Lagged(_)) => continue,
Err(_) => break,
}
}
if let Ok(end) = end {
yield Ok(make_end_start_response(end));
}
break;
}
}
}
};
Ok((Box::pin(stream), ctx))
}
async fn connect(
&self,
ctx: Context,
request: buffa::view::OwnedView<ConnectRequestView<'static>>,
) -> Result<
(
Pin<Box<dyn Stream<Item = Result<ConnectResponse, ConnectError>> + Send>>,
Context,
),
ConnectError,
> {
let selector = request.process.as_option().ok_or_else(|| {
ConnectError::new(ErrorCode::InvalidArgument, "process selector required")
})?;
let handle = self.get_process_by_selector(selector)?;
let pid = handle.pid;
let mut data_rx = handle.subscribe_data();
let mut end_rx = handle.subscribe_end();
let cached_end = handle.cached_end();
let stream = async_stream::stream! {
yield Ok(ConnectResponse {
event: buffa::MessageField::some(ProcessEvent {
event: Some(process_event::Event::Start(Box::new(
process_event::StartEvent { pid, ..Default::default() },
))),
..Default::default()
}),
..Default::default()
});
if let Some(end) = cached_end {
yield Ok(ConnectResponse {
event: buffa::MessageField::some(make_end_event(end)),
..Default::default()
});
} else {
loop {
tokio::select! {
biased;
data = data_rx.recv() => {
match data {
Ok(ev) => {
yield Ok(ConnectResponse {
event: buffa::MessageField::some(make_data_event(ev)),
..Default::default()
});
}
Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
Err(tokio::sync::broadcast::error::RecvError::Closed) => {
// Data channel closed: the process ended and
// its handle was dropped. The end event is
// published before the handle drop, so it is
// still buffered — emit it rather than losing
// the exit code.
if let Ok(end) = end_rx.try_recv() {
yield Ok(ConnectResponse {
event: buffa::MessageField::some(make_end_event(end)),
..Default::default()
});
}
break;
}
}
}
end = end_rx.recv() => {
// Process ended. The waiter joins the output readers
// before sending this event, so every byte is already
// in the data channel — drain it fully before the end.
loop {
match data_rx.try_recv() {
Ok(ev) => yield Ok(ConnectResponse {
event: buffa::MessageField::some(make_data_event(ev)),
..Default::default()
}),
Err(tokio::sync::broadcast::error::TryRecvError::Lagged(_)) => continue,
Err(_) => break,
}
}
if let Ok(end) = end {
yield Ok(ConnectResponse {
event: buffa::MessageField::some(make_end_event(end)),
..Default::default()
});
}
break;
}
}
}
}
};
Ok((Box::pin(stream), ctx))
}
async fn update(
&self,
ctx: Context,
request: buffa::view::OwnedView<UpdateRequestView<'static>>,
) -> Result<(UpdateResponse, Context), ConnectError> {
let selector = request.process.as_option().ok_or_else(|| {
ConnectError::new(ErrorCode::InvalidArgument, "process selector required")
})?;
let handle = self.get_process_by_selector(selector)?;
if let Some(pty) = request.pty.as_option() {
if let Some(size) = pty.size.as_option() {
handle.resize_pty(size.cols as u16, size.rows as u16)?;
}
}
Ok((UpdateResponse { ..Default::default() }, ctx))
}
async fn stream_input(
&self,
ctx: Context,
mut requests: Pin<
Box<
dyn Stream<
Item = Result<
buffa::view::OwnedView<StreamInputRequestView<'static>>,
ConnectError,
>,
> + Send,
>,
>,
) -> Result<(StreamInputResponse, Context), ConnectError> {
use futures::StreamExt;
let mut handle: Option<Arc<ProcessHandle>> = None;
while let Some(result) = requests.next().await {
let req = result?;
match &req.event {
Some(stream_input_request::EventView::Start(start)) => {
if let Some(selector) = start.process.as_option() {
handle = Some(self.get_process_by_selector(selector)?);
}
}
Some(stream_input_request::EventView::Data(data)) => {
let h = handle.as_ref().ok_or_else(|| {
ConnectError::new(ErrorCode::FailedPrecondition, "no start event received")
})?;
if let Some(input) = data.input.as_option() {
write_input(h, input)?;
}
}
Some(stream_input_request::EventView::Keepalive(_)) => {}
None => {}
}
}
Ok((StreamInputResponse { ..Default::default() }, ctx))
}
async fn send_input(
&self,
ctx: Context,
request: buffa::view::OwnedView<SendInputRequestView<'static>>,
) -> Result<(SendInputResponse, Context), ConnectError> {
let selector = request.process.as_option().ok_or_else(|| {
ConnectError::new(ErrorCode::InvalidArgument, "process selector required")
})?;
let handle = self.get_process_by_selector(selector)?;
if let Some(input) = request.input.as_option() {
write_input(&handle, input)?;
}
Ok((SendInputResponse { ..Default::default() }, ctx))
}
async fn send_signal(
&self,
ctx: Context,
request: buffa::view::OwnedView<SendSignalRequestView<'static>>,
) -> Result<(SendSignalResponse, Context), ConnectError> {
let selector = request.process.as_option().ok_or_else(|| {
ConnectError::new(ErrorCode::InvalidArgument, "process selector required")
})?;
let handle = self.get_process_by_selector(selector)?;
let sig = match request.signal.as_known() {
Some(Signal::SIGNAL_SIGKILL) => nix::sys::signal::Signal::SIGKILL,
Some(Signal::SIGNAL_SIGTERM) => nix::sys::signal::Signal::SIGTERM,
_ => {
return Err(ConnectError::new(
ErrorCode::InvalidArgument,
"invalid or unspecified signal",
))
}
};
handle.send_signal(sig)?;
Ok((SendSignalResponse { ..Default::default() }, ctx))
}
async fn close_stdin(
&self,
ctx: Context,
request: buffa::view::OwnedView<CloseStdinRequestView<'static>>,
) -> Result<(CloseStdinResponse, Context), ConnectError> {
let selector = request.process.as_option().ok_or_else(|| {
ConnectError::new(ErrorCode::InvalidArgument, "process selector required")
})?;
let handle = self.get_process_by_selector(selector)?;
handle.close_stdin()?;
Ok((CloseStdinResponse { ..Default::default() }, ctx))
}
}
fn write_input(handle: &ProcessHandle, input: &ProcessInputView) -> Result<(), ConnectError> {
match &input.input {
Some(process_input::InputView::Pty(d)) => handle.write_pty(d),
Some(process_input::InputView::Stdin(d)) => handle.write_stdin(d),
None => Ok(()),
}
}
fn make_start_response(pid: u32) -> StartResponse {
StartResponse {
event: buffa::MessageField::some(ProcessEvent {
event: Some(process_event::Event::Start(Box::new(
process_event::StartEvent {
pid,
..Default::default()
},
))),
..Default::default()
}),
..Default::default()
}
}
fn make_data_event(ev: DataEvent) -> ProcessEvent {
let output = match ev {
DataEvent::Stdout(d) => Some(process_event::data_event::Output::Stdout(d.into())),
DataEvent::Stderr(d) => Some(process_event::data_event::Output::Stderr(d.into())),
DataEvent::Pty(d) => Some(process_event::data_event::Output::Pty(d.into())),
};
ProcessEvent {
event: Some(process_event::Event::Data(Box::new(
process_event::DataEvent {
output,
..Default::default()
},
))),
..Default::default()
}
}
fn make_data_start_response(ev: DataEvent) -> StartResponse {
StartResponse {
event: buffa::MessageField::some(make_data_event(ev)),
..Default::default()
}
}
fn make_end_event(end: process_handler::EndEvent) -> ProcessEvent {
ProcessEvent {
event: Some(process_event::Event::End(Box::new(
process_event::EndEvent {
exit_code: end.exit_code,
exited: end.exited,
status: end.status,
error: end.error,
..Default::default()
},
))),
..Default::default()
}
}
fn make_end_start_response(end: process_handler::EndEvent) -> StartResponse {
StartResponse {
event: buffa::MessageField::some(make_end_event(end)),
..Default::default()
}
}

127
envd-rs/src/state.rs Normal file
View File

@ -0,0 +1,127 @@
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, AtomicU8, Ordering};
use std::sync::{Arc, Mutex};
use crate::auth::token::SecureToken;
use crate::conntracker::ConnTracker;
use crate::execcontext::Defaults;
use crate::port::subsystem::PortSubsystem;
use crate::util::AtomicMax;
pub struct AppState {
pub defaults: Defaults,
pub version: String,
pub commit: String,
pub last_set_time: AtomicMax,
pub access_token: SecureToken,
pub conn_tracker: ConnTracker,
pub port_subsystem: Option<Arc<PortSubsystem>>,
pub cpu_used_pct: AtomicU32,
pub cpu_count: AtomicU32,
/// Memory preload coordination. The host agent POSTs /memory/preload after
/// a snapshot restore to materialise every physical page (so the next
/// ch.snapshot writes a self-contained memory-ranges). `mem_preload_started`
/// ensures only one loader runs; `mem_preload_done` lets concurrent callers
/// rendezvous; `mem_preload_cancel` lets a teardown abort the loader.
pub mem_preload_started: AtomicBool,
pub mem_preload_done: AtomicBool,
pub mem_preload_cancel: AtomicBool,
pub mem_preload_regions: AtomicU64,
pub mem_preload_pages: AtomicU64,
pub mem_preload_bytes: AtomicU64,
pub mem_preload_elapsed_us: AtomicU64,
/// 0 = unset, 1 = /dev/mem, 2 = /proc/kcore.
pub mem_preload_source: AtomicU8,
pub mem_preload_error: Mutex<Option<String>>,
/// Last lifecycle ID seen on /init. Used to detect post-resume calls so
/// envd can refresh port forwarders and remount NFS volumes.
lifecycle_id: Mutex<Option<String>>,
}
impl AppState {
pub fn new(
defaults: Defaults,
version: String,
commit: String,
port_subsystem: Option<Arc<PortSubsystem>>,
) -> Arc<Self> {
let state = Arc::new(Self {
defaults,
version,
commit,
last_set_time: AtomicMax::new(),
access_token: SecureToken::new(),
conn_tracker: ConnTracker::new(),
port_subsystem,
cpu_used_pct: AtomicU32::new(0),
cpu_count: AtomicU32::new(0),
mem_preload_started: AtomicBool::new(false),
mem_preload_done: AtomicBool::new(false),
mem_preload_cancel: AtomicBool::new(false),
mem_preload_regions: AtomicU64::new(0),
mem_preload_pages: AtomicU64::new(0),
mem_preload_bytes: AtomicU64::new(0),
mem_preload_elapsed_us: AtomicU64::new(0),
mem_preload_source: AtomicU8::new(0),
mem_preload_error: Mutex::new(None),
lifecycle_id: Mutex::new(None),
});
let state_clone = Arc::clone(&state);
std::thread::spawn(move || {
cpu_sampler(state_clone);
});
state
}
pub fn cpu_used_pct(&self) -> f32 {
f32::from_bits(self.cpu_used_pct.load(Ordering::Relaxed))
}
pub fn cpu_count(&self) -> u32 {
self.cpu_count.load(Ordering::Relaxed)
}
/// Records a new lifecycle ID, returning true if it changed (i.e. this
/// is the first /init since a resume). First-ever call returns false:
/// boot-time /init doesn't need port-subsystem restart since the
/// subsystem hasn't been started yet by anything else.
pub fn bump_lifecycle(&self, new_id: &str) -> bool {
let mut guard = self.lifecycle_id.lock().unwrap();
let changed = match guard.as_deref() {
Some(existing) => existing != new_id,
None => false,
};
*guard = Some(new_id.to_owned());
changed
}
}
fn cpu_sampler(state: Arc<AppState>) {
use sysinfo::System;
let mut sys = System::new();
sys.refresh_cpu_all();
loop {
std::thread::sleep(std::time::Duration::from_secs(1));
sys.refresh_cpu_all();
let pct = sys.global_cpu_usage();
let rounded = if pct > 0.0 {
(pct * 100.0).round() / 100.0
} else {
0.0
};
state
.cpu_used_pct
.store(rounded.to_bits(), Ordering::Relaxed);
state
.cpu_count
.store(sys.cpus().len() as u32, Ordering::Relaxed);
}
}

102
envd-rs/src/util.rs Normal file
View File

@ -0,0 +1,102 @@
use std::sync::atomic::{AtomicI64, Ordering};
pub struct AtomicMax {
val: AtomicI64,
}
impl AtomicMax {
pub fn new() -> Self {
Self {
val: AtomicI64::new(i64::MIN),
}
}
pub fn get(&self) -> i64 {
self.val.load(Ordering::Acquire)
}
/// Sets the stored value to `new` if `new` is strictly greater than
/// the current value. Returns `true` if the value was updated.
pub fn set_to_greater(&self, new: i64) -> bool {
loop {
let current = self.val.load(Ordering::Acquire);
if new <= current {
return false;
}
match self.val.compare_exchange_weak(
current,
new,
Ordering::Release,
Ordering::Relaxed,
) {
Ok(_) => return true,
Err(_) => continue,
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
#[test]
fn initial_value_is_i64_min() {
let m = AtomicMax::new();
assert_eq!(m.get(), i64::MIN);
}
#[test]
fn updates_when_larger() {
let m = AtomicMax::new();
assert!(m.set_to_greater(0));
assert_eq!(m.get(), 0);
assert!(m.set_to_greater(100));
assert_eq!(m.get(), 100);
}
#[test]
fn returns_false_when_equal() {
let m = AtomicMax::new();
m.set_to_greater(42);
assert!(!m.set_to_greater(42));
assert_eq!(m.get(), 42);
}
#[test]
fn returns_false_when_smaller() {
let m = AtomicMax::new();
m.set_to_greater(100);
assert!(!m.set_to_greater(50));
assert_eq!(m.get(), 100);
}
#[test]
fn concurrent_convergence() {
let m = Arc::new(AtomicMax::new());
let threads: Vec<_> = (0..8)
.map(|t| {
let m = Arc::clone(&m);
std::thread::spawn(move || {
for i in (t * 100)..((t + 1) * 100) {
m.set_to_greater(i);
}
})
})
.collect();
for t in threads {
t.join().unwrap();
}
assert_eq!(m.get(), 799);
}
#[test]
fn i64_max_boundary() {
let m = AtomicMax::new();
assert!(m.set_to_greater(i64::MAX));
assert!(!m.set_to_greater(i64::MAX));
assert!(!m.set_to_greater(0));
assert_eq!(m.get(), i64::MAX);
}
}

View File

@ -1,202 +0,0 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2023 FoundryLabs, Inc.
Modifications Copyright (c) 2026 M/S Omukk, Bangladesh
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@ -1,62 +0,0 @@
BUILD := $(shell git rev-parse --short HEAD 2>/dev/null || echo "unknown")
LDFLAGS := -s -w -X=main.commitSHA=$(BUILD)
BUILDS := ../builds
# ═══════════════════════════════════════════════════
# Build
# ═══════════════════════════════════════════════════
.PHONY: build build-debug
build:
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags="$(LDFLAGS)" -o $(BUILDS)/envd .
@file $(BUILDS)/envd | grep -q "statically linked" || \
(echo "ERROR: envd is not statically linked!" && exit 1)
build-debug:
CGO_ENABLED=1 go build -race -gcflags=all="-N -l" -ldflags="-X=main.commitSHA=$(BUILD)" -o $(BUILDS)/debug/envd .
# ═══════════════════════════════════════════════════
# Run (debug mode, not inside a VM)
# ═══════════════════════════════════════════════════
.PHONY: run-debug
run-debug: build-debug
$(BUILDS)/debug/envd -isnotfc -port 49983
# ═══════════════════════════════════════════════════
# Code Generation
# ═══════════════════════════════════════════════════
.PHONY: generate proto openapi
generate: proto openapi
proto:
cd spec && buf generate --template buf.gen.yaml
openapi:
go generate ./internal/api/...
# ═══════════════════════════════════════════════════
# Quality
# ═══════════════════════════════════════════════════
.PHONY: fmt vet test tidy
fmt:
gofmt -w .
vet:
go vet ./...
test:
go test -race -v ./...
tidy:
go mod tidy
# ═══════════════════════════════════════════════════
# Clean
# ═══════════════════════════════════════════════════
.PHONY: clean
clean:
rm -f $(BUILDS)/envd $(BUILDS)/debug/envd

View File

@ -1 +0,0 @@
0.1.1

View File

@ -1,42 +0,0 @@
module git.omukk.dev/wrenn/sandbox/envd
go 1.25.8
require (
connectrpc.com/authn v0.1.0
connectrpc.com/connect v1.19.1
connectrpc.com/cors v0.1.0
github.com/awnumar/memguard v0.23.0
github.com/creack/pty v1.1.24
github.com/dchest/uniuri v1.2.0
github.com/e2b-dev/fsnotify v0.0.1
github.com/go-chi/chi/v5 v5.2.5
github.com/google/uuid v1.6.0
github.com/oapi-codegen/runtime v1.2.0
github.com/orcaman/concurrent-map/v2 v2.0.1
github.com/rs/cors v1.11.1
github.com/rs/zerolog v1.34.0
github.com/shirou/gopsutil/v4 v4.26.2
github.com/stretchr/testify v1.11.1
github.com/txn2/txeh v1.8.0
golang.org/x/sys v0.43.0
google.golang.org/protobuf v1.36.11
)
require (
github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect
github.com/awnumar/memcall v0.4.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/ebitengine/purego v0.10.0 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
github.com/tklauser/go-sysconf v0.3.16 // indirect
github.com/tklauser/numcpus v0.11.0 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect
golang.org/x/crypto v0.50.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@ -1,92 +0,0 @@
connectrpc.com/authn v0.1.0 h1:m5weACjLWwgwcjttvUDyTPICJKw74+p2obBVrf8hT9E=
connectrpc.com/authn v0.1.0/go.mod h1:AwNZK/KYbqaJzRYadTuAaoz6sYQSPdORPqh1TOPIkgY=
connectrpc.com/connect v1.19.1 h1:R5M57z05+90EfEvCY1b7hBxDVOUl45PrtXtAV2fOC14=
connectrpc.com/connect v1.19.1/go.mod h1:tN20fjdGlewnSFeZxLKb0xwIZ6ozc3OQs2hTXy4du9w=
connectrpc.com/cors v0.1.0 h1:f3gTXJyDZPrDIZCQ567jxfD9PAIpopHiRDnJRt3QuOQ=
connectrpc.com/cors v0.1.0/go.mod h1:v8SJZCPfHtGH1zsm+Ttajpozd4cYIUryl4dFB6QEpfg=
github.com/RaveNoX/go-jsoncommentstrip v1.0.0/go.mod h1:78ihd09MekBnJnxpICcwzCMzGrKSKYe4AqU6PDYYpjk=
github.com/apapsch/go-jsonmerge/v2 v2.0.0 h1:axGnT1gRIfimI7gJifB699GoE/oq+F2MU7Dml6nw9rQ=
github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP2+08jFMw88y4klk=
github.com/awnumar/memcall v0.4.0 h1:B7hgZYdfH6Ot1Goaz8jGne/7i8xD4taZie/PNSFZ29g=
github.com/awnumar/memcall v0.4.0/go.mod h1:8xOx1YbfyuCg3Fy6TO8DK0kZUua3V42/goA5Ru47E8w=
github.com/awnumar/memguard v0.23.0 h1:sJ3a1/SWlcuKIQ7MV+R9p0Pvo9CWsMbGZvcZQtmc68A=
github.com/awnumar/memguard v0.23.0/go.mod h1:olVofBrsPdITtJ2HgxQKrEYEMyIBAIciVG4wNnZhW9M=
github.com/bmatcuk/doublestar v1.1.1/go.mod h1:UD6OnuiIn0yFxxA2le/rnRU1G4RaI4UvFv1sNto9p6w=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s=
github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dchest/uniuri v1.2.0 h1:koIcOUdrTIivZgSLhHQvKgqdWZq5d7KdMEWF1Ud6+5g=
github.com/dchest/uniuri v1.2.0/go.mod h1:fSzm4SLHzNZvWLvWJew423PhAzkpNQYq+uNLq4kxhkY=
github.com/e2b-dev/fsnotify v0.0.1 h1:7j0I98HD6VehAuK/bcslvW4QDynAULtOuMZtImihjVk=
github.com/e2b-dev/fsnotify v0.0.1/go.mod h1:jAuDjregRrUixKneTRQwPI847nNuPFg3+n5QM/ku/JM=
github.com/ebitengine/purego v0.10.0 h1:QIw4xfpWT6GWTzaW5XEKy3HXoqrJGx1ijYHzTF0/ISU=
github.com/ebitengine/purego v0.10.0/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug=
github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0=
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/juju/gnuflag v0.0.0-20171113085948-2ce1bb71843d/go.mod h1:2PavIy+JPciBPrBUjwbNvtwB6RQlve+hkpll6QSNmOE=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/oapi-codegen/runtime v1.2.0 h1:RvKc1CVS1QeKSNzO97FBQbSMZyQ8s6rZd+LpmzwHMP4=
github.com/oapi-codegen/runtime v1.2.0/go.mod h1:Y7ZhmmlE8ikZOmuHRRndiIm7nf3xcVv+YMweKgG1DT0=
github.com/orcaman/concurrent-map/v2 v2.0.1 h1:jOJ5Pg2w1oeB6PeDurIYf6k9PQ+aTITr/6lP/L/zp6c=
github.com/orcaman/concurrent-map/v2 v2.0.1/go.mod h1:9Eq3TG2oBe5FirmYWQfYO5iH1q0Jv47PLaNK++uCdOM=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU=
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/rs/cors v1.11.1 h1:eU3gRzXLRK57F5rKMGMZURNdIG4EoAmX8k94r9wXWHA=
github.com/rs/cors v1.11.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU=
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
github.com/shirou/gopsutil/v4 v4.26.2 h1:X8i6sicvUFih4BmYIGT1m2wwgw2VG9YgrDTi7cIRGUI=
github.com/shirou/gopsutil/v4 v4.26.2/go.mod h1:LZ6ewCSkBqUpvSOf+LsTGnRinC6iaNUNMGBtDkJBaLQ=
github.com/spkg/bom v0.0.0-20160624110644-59b7046e48ad/go.mod h1:qLr4V1qq6nMqFKkMo8ZTx3f+BZEkzsRUY10Xsm2mwU0=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/tklauser/go-sysconf v0.3.16 h1:frioLaCQSsF5Cy1jgRBrzr6t502KIIwQ0MArYICU0nA=
github.com/tklauser/go-sysconf v0.3.16/go.mod h1:/qNL9xxDhc7tx3HSRsLWNnuzbVfh3e7gh/BmM179nYI=
github.com/tklauser/numcpus v0.11.0 h1:nSTwhKH5e1dMNsCdVBukSZrURJRoHbSEQjdEbY+9RXw=
github.com/tklauser/numcpus v0.11.0/go.mod h1:z+LwcLq54uWZTX0u/bGobaV34u6V7KNlTZejzM6/3MQ=
github.com/txn2/txeh v1.8.0 h1:G1vZgom6+P/xWwU53AMOpcZgC5ni382ukcPP1TDVYHk=
github.com/txn2/txeh v1.8.0/go.mod h1:rRI3Egi3+AFmEXQjft051YdYbxeCT3nFmBLsNCZZaxM=
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
pgregory.net/rapid v1.2.0 h1:keKAYRcjm+e1F0oAuU5F5+YPAWcyxNNRK2wud503Gnk=
pgregory.net/rapid v1.2.0/go.mod h1:PY5XlDGj0+V1FCq0o192FdRhpKHGTRIWBgqjDBTrq04=

View File

@ -1,604 +0,0 @@
// Package api provides primitives to interact with the openapi HTTP API.
//
// Code generated by github.com/oapi-codegen/oapi-codegen/v2 version v2.6.0 DO NOT EDIT.
package api
import (
"context"
"fmt"
"net/http"
"time"
"github.com/go-chi/chi/v5"
"github.com/oapi-codegen/runtime"
openapi_types "github.com/oapi-codegen/runtime/types"
)
const (
AccessTokenAuthScopes = "AccessTokenAuth.Scopes"
)
// Defines values for EntryInfoType.
const (
File EntryInfoType = "file"
)
// Valid indicates whether the value is a known member of the EntryInfoType enum.
func (e EntryInfoType) Valid() bool {
switch e {
case File:
return true
default:
return false
}
}
// EntryInfo defines model for EntryInfo.
type EntryInfo struct {
// Name Name of the file
Name string `json:"name"`
// Path Path to the file
Path string `json:"path"`
// Type Type of the file
Type EntryInfoType `json:"type"`
}
// EntryInfoType Type of the file
type EntryInfoType string
// EnvVars Environment variables to set
type EnvVars map[string]string
// Error defines model for Error.
type Error struct {
// Code Error code
Code int `json:"code"`
// Message Error message
Message string `json:"message"`
}
// Metrics Resource usage metrics
type Metrics struct {
// CpuCount Number of CPU cores
CpuCount *int `json:"cpu_count,omitempty"`
// CpuUsedPct CPU usage percentage
CpuUsedPct *float32 `json:"cpu_used_pct,omitempty"`
// DiskTotal Total disk space in bytes
DiskTotal *int `json:"disk_total,omitempty"`
// DiskUsed Used disk space in bytes
DiskUsed *int `json:"disk_used,omitempty"`
// MemTotal Total virtual memory in bytes
MemTotal *int `json:"mem_total,omitempty"`
// MemUsed Used virtual memory in bytes
MemUsed *int `json:"mem_used,omitempty"`
// Ts Unix timestamp in UTC for current sandbox time
Ts *int64 `json:"ts,omitempty"`
}
// VolumeMount Volume
type VolumeMount struct {
NfsTarget string `json:"nfs_target"`
Path string `json:"path"`
}
// FilePath defines model for FilePath.
type FilePath = string
// Signature defines model for Signature.
type Signature = string
// SignatureExpiration defines model for SignatureExpiration.
type SignatureExpiration = int
// User defines model for User.
type User = string
// FileNotFound defines model for FileNotFound.
type FileNotFound = Error
// InternalServerError defines model for InternalServerError.
type InternalServerError = Error
// InvalidPath defines model for InvalidPath.
type InvalidPath = Error
// InvalidUser defines model for InvalidUser.
type InvalidUser = Error
// NotEnoughDiskSpace defines model for NotEnoughDiskSpace.
type NotEnoughDiskSpace = Error
// UploadSuccess defines model for UploadSuccess.
type UploadSuccess = []EntryInfo
// GetFilesParams defines parameters for GetFiles.
type GetFilesParams struct {
// Path Path to the file, URL encoded. Can be relative to user's home directory.
Path *FilePath `form:"path,omitempty" json:"path,omitempty"`
// Username User used for setting the owner, or resolving relative paths.
Username *User `form:"username,omitempty" json:"username,omitempty"`
// Signature Signature used for file access permission verification.
Signature *Signature `form:"signature,omitempty" json:"signature,omitempty"`
// SignatureExpiration Signature expiration used for defining the expiration time of the signature.
SignatureExpiration *SignatureExpiration `form:"signature_expiration,omitempty" json:"signature_expiration,omitempty"`
}
// PostFilesMultipartBody defines parameters for PostFiles.
type PostFilesMultipartBody struct {
File *openapi_types.File `json:"file,omitempty"`
}
// PostFilesParams defines parameters for PostFiles.
type PostFilesParams struct {
// Path Path to the file, URL encoded. Can be relative to user's home directory.
Path *FilePath `form:"path,omitempty" json:"path,omitempty"`
// Username User used for setting the owner, or resolving relative paths.
Username *User `form:"username,omitempty" json:"username,omitempty"`
// Signature Signature used for file access permission verification.
Signature *Signature `form:"signature,omitempty" json:"signature,omitempty"`
// SignatureExpiration Signature expiration used for defining the expiration time of the signature.
SignatureExpiration *SignatureExpiration `form:"signature_expiration,omitempty" json:"signature_expiration,omitempty"`
}
// PostInitJSONBody defines parameters for PostInit.
type PostInitJSONBody struct {
// AccessToken Access token for secure access to envd service
AccessToken *SecureToken `json:"accessToken,omitempty"`
// DefaultUser The default user to use for operations
DefaultUser *string `json:"defaultUser,omitempty"`
// DefaultWorkdir The default working directory to use for operations
DefaultWorkdir *string `json:"defaultWorkdir,omitempty"`
// EnvVars Environment variables to set
EnvVars *EnvVars `json:"envVars,omitempty"`
// HyperloopIP IP address of the hyperloop server to connect to
HyperloopIP *string `json:"hyperloopIP,omitempty"`
// Timestamp The current timestamp in RFC3339 format
Timestamp *time.Time `json:"timestamp,omitempty"`
VolumeMounts *[]VolumeMount `json:"volumeMounts,omitempty"`
}
// PostFilesMultipartRequestBody defines body for PostFiles for multipart/form-data ContentType.
type PostFilesMultipartRequestBody PostFilesMultipartBody
// PostInitJSONRequestBody defines body for PostInit for application/json ContentType.
type PostInitJSONRequestBody PostInitJSONBody
// ServerInterface represents all server handlers.
type ServerInterface interface {
// Get the environment variables
// (GET /envs)
GetEnvs(w http.ResponseWriter, r *http.Request)
// Download a file
// (GET /files)
GetFiles(w http.ResponseWriter, r *http.Request, params GetFilesParams)
// Upload a file and ensure the parent directories exist. If the file exists, it will be overwritten.
// (POST /files)
PostFiles(w http.ResponseWriter, r *http.Request, params PostFilesParams)
// Check the health of the service
// (GET /health)
GetHealth(w http.ResponseWriter, r *http.Request)
// Set initial vars, ensure the time and metadata is synced with the host
// (POST /init)
PostInit(w http.ResponseWriter, r *http.Request)
// Get the stats of the service
// (GET /metrics)
GetMetrics(w http.ResponseWriter, r *http.Request)
// Quiesce continuous goroutines before Firecracker snapshot
// (POST /snapshot/prepare)
PostSnapshotPrepare(w http.ResponseWriter, r *http.Request)
}
// Unimplemented server implementation that returns http.StatusNotImplemented for each endpoint.
type Unimplemented struct{}
// Get the environment variables
// (GET /envs)
func (_ Unimplemented) GetEnvs(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotImplemented)
}
// Download a file
// (GET /files)
func (_ Unimplemented) GetFiles(w http.ResponseWriter, r *http.Request, params GetFilesParams) {
w.WriteHeader(http.StatusNotImplemented)
}
// Upload a file and ensure the parent directories exist. If the file exists, it will be overwritten.
// (POST /files)
func (_ Unimplemented) PostFiles(w http.ResponseWriter, r *http.Request, params PostFilesParams) {
w.WriteHeader(http.StatusNotImplemented)
}
// Check the health of the service
// (GET /health)
func (_ Unimplemented) GetHealth(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotImplemented)
}
// Set initial vars, ensure the time and metadata is synced with the host
// (POST /init)
func (_ Unimplemented) PostInit(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotImplemented)
}
// Get the stats of the service
// (GET /metrics)
func (_ Unimplemented) GetMetrics(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotImplemented)
}
// Quiesce continuous goroutines before Firecracker snapshot
// (POST /snapshot/prepare)
func (_ Unimplemented) PostSnapshotPrepare(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotImplemented)
}
// ServerInterfaceWrapper converts contexts to parameters.
type ServerInterfaceWrapper struct {
Handler ServerInterface
HandlerMiddlewares []MiddlewareFunc
ErrorHandlerFunc func(w http.ResponseWriter, r *http.Request, err error)
}
type MiddlewareFunc func(http.Handler) http.Handler
// GetEnvs operation middleware
func (siw *ServerInterfaceWrapper) GetEnvs(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ctx = context.WithValue(ctx, AccessTokenAuthScopes, []string{})
r = r.WithContext(ctx)
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
siw.Handler.GetEnvs(w, r)
}))
for _, middleware := range siw.HandlerMiddlewares {
handler = middleware(handler)
}
handler.ServeHTTP(w, r)
}
// GetFiles operation middleware
func (siw *ServerInterfaceWrapper) GetFiles(w http.ResponseWriter, r *http.Request) {
var err error
ctx := r.Context()
ctx = context.WithValue(ctx, AccessTokenAuthScopes, []string{})
r = r.WithContext(ctx)
// Parameter object where we will unmarshal all parameters from the context
var params GetFilesParams
// ------------- Optional query parameter "path" -------------
err = runtime.BindQueryParameterWithOptions("form", true, false, "path", r.URL.Query(), &params.Path, runtime.BindQueryParameterOptions{Type: "string", Format: ""})
if err != nil {
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "path", Err: err})
return
}
// ------------- Optional query parameter "username" -------------
err = runtime.BindQueryParameterWithOptions("form", true, false, "username", r.URL.Query(), &params.Username, runtime.BindQueryParameterOptions{Type: "string", Format: ""})
if err != nil {
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "username", Err: err})
return
}
// ------------- Optional query parameter "signature" -------------
err = runtime.BindQueryParameterWithOptions("form", true, false, "signature", r.URL.Query(), &params.Signature, runtime.BindQueryParameterOptions{Type: "string", Format: ""})
if err != nil {
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "signature", Err: err})
return
}
// ------------- Optional query parameter "signature_expiration" -------------
err = runtime.BindQueryParameterWithOptions("form", true, false, "signature_expiration", r.URL.Query(), &params.SignatureExpiration, runtime.BindQueryParameterOptions{Type: "integer", Format: ""})
if err != nil {
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "signature_expiration", Err: err})
return
}
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
siw.Handler.GetFiles(w, r, params)
}))
for _, middleware := range siw.HandlerMiddlewares {
handler = middleware(handler)
}
handler.ServeHTTP(w, r)
}
// PostFiles operation middleware
func (siw *ServerInterfaceWrapper) PostFiles(w http.ResponseWriter, r *http.Request) {
var err error
ctx := r.Context()
ctx = context.WithValue(ctx, AccessTokenAuthScopes, []string{})
r = r.WithContext(ctx)
// Parameter object where we will unmarshal all parameters from the context
var params PostFilesParams
// ------------- Optional query parameter "path" -------------
err = runtime.BindQueryParameterWithOptions("form", true, false, "path", r.URL.Query(), &params.Path, runtime.BindQueryParameterOptions{Type: "string", Format: ""})
if err != nil {
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "path", Err: err})
return
}
// ------------- Optional query parameter "username" -------------
err = runtime.BindQueryParameterWithOptions("form", true, false, "username", r.URL.Query(), &params.Username, runtime.BindQueryParameterOptions{Type: "string", Format: ""})
if err != nil {
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "username", Err: err})
return
}
// ------------- Optional query parameter "signature" -------------
err = runtime.BindQueryParameterWithOptions("form", true, false, "signature", r.URL.Query(), &params.Signature, runtime.BindQueryParameterOptions{Type: "string", Format: ""})
if err != nil {
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "signature", Err: err})
return
}
// ------------- Optional query parameter "signature_expiration" -------------
err = runtime.BindQueryParameterWithOptions("form", true, false, "signature_expiration", r.URL.Query(), &params.SignatureExpiration, runtime.BindQueryParameterOptions{Type: "integer", Format: ""})
if err != nil {
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "signature_expiration", Err: err})
return
}
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
siw.Handler.PostFiles(w, r, params)
}))
for _, middleware := range siw.HandlerMiddlewares {
handler = middleware(handler)
}
handler.ServeHTTP(w, r)
}
// GetHealth operation middleware
func (siw *ServerInterfaceWrapper) GetHealth(w http.ResponseWriter, r *http.Request) {
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
siw.Handler.GetHealth(w, r)
}))
for _, middleware := range siw.HandlerMiddlewares {
handler = middleware(handler)
}
handler.ServeHTTP(w, r)
}
// PostInit operation middleware
func (siw *ServerInterfaceWrapper) PostInit(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ctx = context.WithValue(ctx, AccessTokenAuthScopes, []string{})
r = r.WithContext(ctx)
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
siw.Handler.PostInit(w, r)
}))
for _, middleware := range siw.HandlerMiddlewares {
handler = middleware(handler)
}
handler.ServeHTTP(w, r)
}
// GetMetrics operation middleware
func (siw *ServerInterfaceWrapper) GetMetrics(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ctx = context.WithValue(ctx, AccessTokenAuthScopes, []string{})
r = r.WithContext(ctx)
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
siw.Handler.GetMetrics(w, r)
}))
for _, middleware := range siw.HandlerMiddlewares {
handler = middleware(handler)
}
handler.ServeHTTP(w, r)
}
// PostSnapshotPrepare operation middleware
func (siw *ServerInterfaceWrapper) PostSnapshotPrepare(w http.ResponseWriter, r *http.Request) {
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
siw.Handler.PostSnapshotPrepare(w, r)
}))
for _, middleware := range siw.HandlerMiddlewares {
handler = middleware(handler)
}
handler.ServeHTTP(w, r)
}
type UnescapedCookieParamError struct {
ParamName string
Err error
}
func (e *UnescapedCookieParamError) Error() string {
return fmt.Sprintf("error unescaping cookie parameter '%s'", e.ParamName)
}
func (e *UnescapedCookieParamError) Unwrap() error {
return e.Err
}
type UnmarshalingParamError struct {
ParamName string
Err error
}
func (e *UnmarshalingParamError) Error() string {
return fmt.Sprintf("Error unmarshaling parameter %s as JSON: %s", e.ParamName, e.Err.Error())
}
func (e *UnmarshalingParamError) Unwrap() error {
return e.Err
}
type RequiredParamError struct {
ParamName string
}
func (e *RequiredParamError) Error() string {
return fmt.Sprintf("Query argument %s is required, but not found", e.ParamName)
}
type RequiredHeaderError struct {
ParamName string
Err error
}
func (e *RequiredHeaderError) Error() string {
return fmt.Sprintf("Header parameter %s is required, but not found", e.ParamName)
}
func (e *RequiredHeaderError) Unwrap() error {
return e.Err
}
type InvalidParamFormatError struct {
ParamName string
Err error
}
func (e *InvalidParamFormatError) Error() string {
return fmt.Sprintf("Invalid format for parameter %s: %s", e.ParamName, e.Err.Error())
}
func (e *InvalidParamFormatError) Unwrap() error {
return e.Err
}
type TooManyValuesForParamError struct {
ParamName string
Count int
}
func (e *TooManyValuesForParamError) Error() string {
return fmt.Sprintf("Expected one value for %s, got %d", e.ParamName, e.Count)
}
// Handler creates http.Handler with routing matching OpenAPI spec.
func Handler(si ServerInterface) http.Handler {
return HandlerWithOptions(si, ChiServerOptions{})
}
type ChiServerOptions struct {
BaseURL string
BaseRouter chi.Router
Middlewares []MiddlewareFunc
ErrorHandlerFunc func(w http.ResponseWriter, r *http.Request, err error)
}
// HandlerFromMux creates http.Handler with routing matching OpenAPI spec based on the provided mux.
func HandlerFromMux(si ServerInterface, r chi.Router) http.Handler {
return HandlerWithOptions(si, ChiServerOptions{
BaseRouter: r,
})
}
func HandlerFromMuxWithBaseURL(si ServerInterface, r chi.Router, baseURL string) http.Handler {
return HandlerWithOptions(si, ChiServerOptions{
BaseURL: baseURL,
BaseRouter: r,
})
}
// HandlerWithOptions creates http.Handler with additional options
func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handler {
r := options.BaseRouter
if r == nil {
r = chi.NewRouter()
}
if options.ErrorHandlerFunc == nil {
options.ErrorHandlerFunc = func(w http.ResponseWriter, r *http.Request, err error) {
http.Error(w, err.Error(), http.StatusBadRequest)
}
}
wrapper := ServerInterfaceWrapper{
Handler: si,
HandlerMiddlewares: options.Middlewares,
ErrorHandlerFunc: options.ErrorHandlerFunc,
}
r.Group(func(r chi.Router) {
r.Get(options.BaseURL+"/envs", wrapper.GetEnvs)
})
r.Group(func(r chi.Router) {
r.Get(options.BaseURL+"/files", wrapper.GetFiles)
})
r.Group(func(r chi.Router) {
r.Post(options.BaseURL+"/files", wrapper.PostFiles)
})
r.Group(func(r chi.Router) {
r.Get(options.BaseURL+"/health", wrapper.GetHealth)
})
r.Group(func(r chi.Router) {
r.Post(options.BaseURL+"/init", wrapper.PostInit)
})
r.Group(func(r chi.Router) {
r.Get(options.BaseURL+"/metrics", wrapper.GetMetrics)
})
r.Group(func(r chi.Router) {
r.Post(options.BaseURL+"/snapshot/prepare", wrapper.PostSnapshotPrepare)
})
return r
}

View File

@ -1,133 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// Modifications by M/S Omukk
package api
import (
"errors"
"fmt"
"net/http"
"slices"
"strconv"
"strings"
"time"
"github.com/awnumar/memguard"
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/keys"
)
const (
SigningReadOperation = "read"
SigningWriteOperation = "write"
accessTokenHeader = "X-Access-Token"
)
// paths that are always allowed without general authentication
// POST/init is secured via MMDS hash validation instead
var authExcludedPaths = []string{
"GET/health",
"GET/files",
"POST/files",
"POST/init",
"POST/snapshot/prepare",
}
func (a *API) WithAuthorization(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if a.accessToken.IsSet() {
authHeader := req.Header.Get(accessTokenHeader)
// check if this path is allowed without authentication (e.g., health check, endpoints supporting signing)
allowedPath := slices.Contains(authExcludedPaths, req.Method+req.URL.Path)
if !a.accessToken.Equals(authHeader) && !allowedPath {
a.logger.Error().Msg("Trying to access secured envd without correct access token")
err := fmt.Errorf("unauthorized access, please provide a valid access token or method signing if supported")
jsonError(w, http.StatusUnauthorized, err)
return
}
}
handler.ServeHTTP(w, req)
})
}
func (a *API) generateSignature(path string, username string, operation string, signatureExpiration *int64) (string, error) {
tokenBytes, err := a.accessToken.Bytes()
if err != nil {
return "", fmt.Errorf("access token is not set: %w", err)
}
defer memguard.WipeBytes(tokenBytes)
var signature string
hasher := keys.NewSHA256Hashing()
if signatureExpiration == nil {
signature = strings.Join([]string{path, operation, username, string(tokenBytes)}, ":")
} else {
signature = strings.Join([]string{path, operation, username, string(tokenBytes), strconv.FormatInt(*signatureExpiration, 10)}, ":")
}
return fmt.Sprintf("v1_%s", hasher.HashWithoutPrefix([]byte(signature))), nil
}
func (a *API) validateSigning(r *http.Request, signature *string, signatureExpiration *int, username *string, path string, operation string) (err error) {
var expectedSignature string
// no need to validate signing key if access token is not set
if !a.accessToken.IsSet() {
return nil
}
// check if access token is sent in the header
tokenFromHeader := r.Header.Get(accessTokenHeader)
if tokenFromHeader != "" {
if !a.accessToken.Equals(tokenFromHeader) {
return fmt.Errorf("access token present in header but does not match")
}
return nil
}
if signature == nil {
return fmt.Errorf("missing signature query parameter")
}
// Empty string is used when no username is provided and the default user should be used
signatureUsername := ""
if username != nil {
signatureUsername = *username
}
if signatureExpiration == nil {
expectedSignature, err = a.generateSignature(path, signatureUsername, operation, nil)
} else {
exp := int64(*signatureExpiration)
expectedSignature, err = a.generateSignature(path, signatureUsername, operation, &exp)
}
if err != nil {
a.logger.Error().Err(err).Msg("error generating signing key")
return errors.New("invalid signature")
}
// signature validation
if expectedSignature != *signature {
return fmt.Errorf("invalid signature")
}
// signature expiration
if signatureExpiration != nil {
exp := int64(*signatureExpiration)
if exp < time.Now().Unix() {
return fmt.Errorf("signature is already expired")
}
}
return nil
}

View File

@ -1,64 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
package api
import (
"fmt"
"strconv"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/keys"
)
func TestKeyGenerationAlgorithmIsStable(t *testing.T) {
t.Parallel()
apiToken := "secret-access-token"
secureToken := &SecureToken{}
err := secureToken.Set([]byte(apiToken))
require.NoError(t, err)
api := &API{accessToken: secureToken}
path := "/path/to/demo.txt"
username := "root"
operation := "write"
timestamp := time.Now().Unix()
signature, err := api.generateSignature(path, username, operation, &timestamp)
require.NoError(t, err)
assert.NotEmpty(t, signature)
// locally generated signature
hasher := keys.NewSHA256Hashing()
localSignatureTmp := fmt.Sprintf("%s:%s:%s:%s:%s", path, operation, username, apiToken, strconv.FormatInt(timestamp, 10))
localSignature := fmt.Sprintf("v1_%s", hasher.HashWithoutPrefix([]byte(localSignatureTmp)))
assert.Equal(t, localSignature, signature)
}
func TestKeyGenerationAlgorithmWithoutExpirationIsStable(t *testing.T) {
t.Parallel()
apiToken := "secret-access-token"
secureToken := &SecureToken{}
err := secureToken.Set([]byte(apiToken))
require.NoError(t, err)
api := &API{accessToken: secureToken}
path := "/path/to/resource.txt"
username := "user"
operation := "read"
signature, err := api.generateSignature(path, username, operation, nil)
require.NoError(t, err)
assert.NotEmpty(t, signature)
// locally generated signature
hasher := keys.NewSHA256Hashing()
localSignatureTmp := fmt.Sprintf("%s:%s:%s:%s", path, operation, username, apiToken)
localSignature := fmt.Sprintf("v1_%s", hasher.HashWithoutPrefix([]byte(localSignatureTmp)))
assert.Equal(t, localSignature, signature)
}

View File

@ -1,10 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# yaml-language-server: $schema=https://raw.githubusercontent.com/deepmap/oapi-codegen/HEAD/configuration-schema.json
package: api
output: api.gen.go
generate:
models: true
chi-server: true
client: false

View File

@ -1,187 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// Modifications by M/S Omukk
package api
import (
"compress/gzip"
"errors"
"fmt"
"io"
"mime"
"net/http"
"os"
"os/user"
"path/filepath"
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
)
func (a *API) GetFiles(w http.ResponseWriter, r *http.Request, params GetFilesParams) {
defer r.Body.Close()
var errorCode int
var errMsg error
var path string
if params.Path != nil {
path = *params.Path
}
operationID := logs.AssignOperationID()
// signing authorization if needed
err := a.validateSigning(r, params.Signature, params.SignatureExpiration, params.Username, path, SigningReadOperation)
if err != nil {
a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("error during auth validation")
jsonError(w, http.StatusUnauthorized, err)
return
}
username, err := execcontext.ResolveDefaultUsername(params.Username, a.defaults.User)
if err != nil {
a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("no user specified")
jsonError(w, http.StatusBadRequest, err)
return
}
defer func() {
l := a.logger.
Err(errMsg).
Str("method", r.Method+" "+r.URL.Path).
Str(string(logs.OperationIDKey), operationID).
Str("path", path).
Str("username", username)
if errMsg != nil {
l = l.Int("error_code", errorCode)
}
l.Msg("File read")
}()
u, err := user.Lookup(username)
if err != nil {
errMsg = fmt.Errorf("error looking up user '%s': %w", username, err)
errorCode = http.StatusUnauthorized
jsonError(w, errorCode, errMsg)
return
}
resolvedPath, err := permissions.ExpandAndResolve(path, u, a.defaults.Workdir)
if err != nil {
errMsg = fmt.Errorf("error expanding and resolving path '%s': %w", path, err)
errorCode = http.StatusBadRequest
jsonError(w, errorCode, errMsg)
return
}
stat, err := os.Stat(resolvedPath)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
errMsg = fmt.Errorf("path '%s' does not exist", resolvedPath)
errorCode = http.StatusNotFound
jsonError(w, errorCode, errMsg)
return
}
errMsg = fmt.Errorf("error checking if path exists '%s': %w", resolvedPath, err)
errorCode = http.StatusInternalServerError
jsonError(w, errorCode, errMsg)
return
}
if stat.IsDir() {
errMsg = fmt.Errorf("path '%s' is a directory", resolvedPath)
errorCode = http.StatusBadRequest
jsonError(w, errorCode, errMsg)
return
}
// Reject anything that isn't a regular file (devices, pipes, sockets, etc.).
// Reading device files like /dev/zero or /dev/urandom produces infinite data
// and will exhaust memory on all layers of the stack.
if !stat.Mode().IsRegular() {
errMsg = fmt.Errorf("path '%s' is not a regular file", resolvedPath)
errorCode = http.StatusBadRequest
jsonError(w, errorCode, errMsg)
return
}
// Validate Accept-Encoding header
encoding, err := parseAcceptEncoding(r)
if err != nil {
errMsg = fmt.Errorf("error parsing Accept-Encoding: %w", err)
errorCode = http.StatusNotAcceptable
jsonError(w, errorCode, errMsg)
return
}
// Tell caches to store separate variants for different Accept-Encoding values
w.Header().Set("Vary", "Accept-Encoding")
// Fall back to identity for Range or conditional requests to preserve http.ServeContent
// behavior (206 Partial Content, 304 Not Modified). However, we must check if identity
// is acceptable per the Accept-Encoding header.
hasRangeOrConditional := r.Header.Get("Range") != "" ||
r.Header.Get("If-Modified-Since") != "" ||
r.Header.Get("If-None-Match") != "" ||
r.Header.Get("If-Range") != ""
if hasRangeOrConditional {
if !isIdentityAcceptable(r) {
errMsg = fmt.Errorf("identity encoding not acceptable for Range or conditional request")
errorCode = http.StatusNotAcceptable
jsonError(w, errorCode, errMsg)
return
}
encoding = EncodingIdentity
}
file, err := os.Open(resolvedPath)
if err != nil {
errMsg = fmt.Errorf("error opening file '%s': %w", resolvedPath, err)
errorCode = http.StatusInternalServerError
jsonError(w, errorCode, errMsg)
return
}
defer file.Close()
w.Header().Set("Content-Disposition", mime.FormatMediaType("inline", map[string]string{"filename": filepath.Base(resolvedPath)}))
// Serve with gzip encoding if requested.
if encoding == EncodingGzip {
w.Header().Set("Content-Encoding", EncodingGzip)
// Set Content-Type based on file extension, preserving the original type
contentType := mime.TypeByExtension(filepath.Ext(path))
if contentType == "" {
contentType = "application/octet-stream"
}
w.Header().Set("Content-Type", contentType)
gw := gzip.NewWriter(w)
defer gw.Close()
_, err = io.Copy(gw, file)
if err != nil {
a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("error writing gzip response")
}
return
}
http.ServeContent(w, r, path, stat.ModTime(), file)
}

View File

@ -1,405 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// Modifications by M/S Omukk
package api
import (
"bytes"
"compress/gzip"
"context"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"net/url"
"os"
"os/user"
"path/filepath"
"testing"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
)
func TestGetFilesContentDisposition(t *testing.T) {
t.Parallel()
currentUser, err := user.Current()
require.NoError(t, err)
tests := []struct {
name string
filename string
expectedHeader string
}{
{
name: "simple filename",
filename: "test.txt",
expectedHeader: `inline; filename=test.txt`,
},
{
name: "filename with extension",
filename: "presentation.pptx",
expectedHeader: `inline; filename=presentation.pptx`,
},
{
name: "filename with multiple dots",
filename: "archive.tar.gz",
expectedHeader: `inline; filename=archive.tar.gz`,
},
{
name: "filename with spaces",
filename: "my document.pdf",
expectedHeader: `inline; filename="my document.pdf"`,
},
{
name: "filename with quotes",
filename: `file"name.txt`,
expectedHeader: `inline; filename="file\"name.txt"`,
},
{
name: "filename with backslash",
filename: `file\name.txt`,
expectedHeader: `inline; filename="file\\name.txt"`,
},
{
name: "unicode filename",
filename: "\u6587\u6863.pdf", // 文档.pdf in Chinese
expectedHeader: "inline; filename*=utf-8''%E6%96%87%E6%A1%A3.pdf",
},
{
name: "dotfile preserved",
filename: ".env",
expectedHeader: `inline; filename=.env`,
},
{
name: "dotfile with extension preserved",
filename: ".gitignore",
expectedHeader: `inline; filename=.gitignore`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
// Create a temp directory and file
tempDir := t.TempDir()
tempFile := filepath.Join(tempDir, tt.filename)
err := os.WriteFile(tempFile, []byte("test content"), 0o644)
require.NoError(t, err)
// Create test API
logger := zerolog.Nop()
defaults := &execcontext.Defaults{
EnvVars: utils.NewMap[string, string](),
User: currentUser.Username,
}
api := New(&logger, defaults, nil, false, context.Background(), nil, "test")
// Create request and response recorder
req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil)
w := httptest.NewRecorder()
// Call the handler
params := GetFilesParams{
Path: &tempFile,
Username: &currentUser.Username,
}
api.GetFiles(w, req, params)
// Check response
resp := w.Result()
defer resp.Body.Close()
assert.Equal(t, http.StatusOK, resp.StatusCode)
// Verify Content-Disposition header
contentDisposition := resp.Header.Get("Content-Disposition")
assert.Equal(t, tt.expectedHeader, contentDisposition, "Content-Disposition header should be set with correct filename")
})
}
}
func TestGetFilesContentDispositionWithNestedPath(t *testing.T) {
t.Parallel()
currentUser, err := user.Current()
require.NoError(t, err)
// Create a temp directory with nested structure
tempDir := t.TempDir()
nestedDir := filepath.Join(tempDir, "subdir", "another")
err = os.MkdirAll(nestedDir, 0o755)
require.NoError(t, err)
filename := "document.pdf"
tempFile := filepath.Join(nestedDir, filename)
err = os.WriteFile(tempFile, []byte("test content"), 0o644)
require.NoError(t, err)
// Create test API
logger := zerolog.Nop()
defaults := &execcontext.Defaults{
EnvVars: utils.NewMap[string, string](),
User: currentUser.Username,
}
api := New(&logger, defaults, nil, false, context.Background(), nil, "test")
// Create request and response recorder
req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil)
w := httptest.NewRecorder()
// Call the handler
params := GetFilesParams{
Path: &tempFile,
Username: &currentUser.Username,
}
api.GetFiles(w, req, params)
// Check response
resp := w.Result()
defer resp.Body.Close()
assert.Equal(t, http.StatusOK, resp.StatusCode)
// Verify Content-Disposition header uses only the base filename, not the full path
contentDisposition := resp.Header.Get("Content-Disposition")
assert.Equal(t, `inline; filename=document.pdf`, contentDisposition, "Content-Disposition should contain only the filename, not the path")
}
func TestGetFiles_GzipEncoding_ExplicitIdentityOffWithRange(t *testing.T) {
t.Parallel()
currentUser, err := user.Current()
require.NoError(t, err)
// Create a temp directory with a test file
tempDir := t.TempDir()
filename := "document.pdf"
tempFile := filepath.Join(tempDir, filename)
err = os.WriteFile(tempFile, []byte("test content"), 0o644)
require.NoError(t, err)
// Create test API
logger := zerolog.Nop()
defaults := &execcontext.Defaults{
EnvVars: utils.NewMap[string, string](),
User: currentUser.Username,
}
api := New(&logger, defaults, nil, false, context.Background(), nil, "test")
// Create request and response recorder
req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil)
req.Header.Set("Accept-Encoding", "gzip; q=1,*; q=0")
req.Header.Set("Range", "bytes=0-4") // Request first 5 bytes
w := httptest.NewRecorder()
// Call the handler
params := GetFilesParams{
Path: &tempFile,
Username: &currentUser.Username,
}
api.GetFiles(w, req, params)
// Check response
resp := w.Result()
defer resp.Body.Close()
assert.Equal(t, http.StatusNotAcceptable, resp.StatusCode)
}
func TestGetFiles_GzipDownload(t *testing.T) {
t.Parallel()
currentUser, err := user.Current()
require.NoError(t, err)
originalContent := []byte("hello world, this is a test file for gzip compression")
// Create a temp file with known content
tempDir := t.TempDir()
tempFile := filepath.Join(tempDir, "test.txt")
err = os.WriteFile(tempFile, originalContent, 0o644)
require.NoError(t, err)
logger := zerolog.Nop()
defaults := &execcontext.Defaults{
EnvVars: utils.NewMap[string, string](),
User: currentUser.Username,
}
api := New(&logger, defaults, nil, false, context.Background(), nil, "test")
req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil)
req.Header.Set("Accept-Encoding", "gzip")
w := httptest.NewRecorder()
params := GetFilesParams{
Path: &tempFile,
Username: &currentUser.Username,
}
api.GetFiles(w, req, params)
resp := w.Result()
defer resp.Body.Close()
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "gzip", resp.Header.Get("Content-Encoding"))
assert.Equal(t, "text/plain; charset=utf-8", resp.Header.Get("Content-Type"))
// Decompress the gzip response body
gzReader, err := gzip.NewReader(resp.Body)
require.NoError(t, err)
defer gzReader.Close()
decompressed, err := io.ReadAll(gzReader)
require.NoError(t, err)
assert.Equal(t, originalContent, decompressed)
}
func TestPostFiles_GzipUpload(t *testing.T) {
t.Parallel()
currentUser, err := user.Current()
require.NoError(t, err)
originalContent := []byte("hello world, this is a test file uploaded with gzip")
// Build a multipart body
var multipartBuf bytes.Buffer
mpWriter := multipart.NewWriter(&multipartBuf)
part, err := mpWriter.CreateFormFile("file", "uploaded.txt")
require.NoError(t, err)
_, err = part.Write(originalContent)
require.NoError(t, err)
err = mpWriter.Close()
require.NoError(t, err)
// Gzip-compress the entire multipart body
var gzBuf bytes.Buffer
gzWriter := gzip.NewWriter(&gzBuf)
_, err = gzWriter.Write(multipartBuf.Bytes())
require.NoError(t, err)
err = gzWriter.Close()
require.NoError(t, err)
// Create test API
tempDir := t.TempDir()
destPath := filepath.Join(tempDir, "uploaded.txt")
logger := zerolog.Nop()
defaults := &execcontext.Defaults{
EnvVars: utils.NewMap[string, string](),
User: currentUser.Username,
}
api := New(&logger, defaults, nil, false, context.Background(), nil, "test")
req := httptest.NewRequest(http.MethodPost, "/files?path="+url.QueryEscape(destPath), &gzBuf)
req.Header.Set("Content-Type", mpWriter.FormDataContentType())
req.Header.Set("Content-Encoding", "gzip")
w := httptest.NewRecorder()
params := PostFilesParams{
Path: &destPath,
Username: &currentUser.Username,
}
api.PostFiles(w, req, params)
resp := w.Result()
defer resp.Body.Close()
assert.Equal(t, http.StatusOK, resp.StatusCode)
// Verify the file was written with the original (decompressed) content
data, err := os.ReadFile(destPath)
require.NoError(t, err)
assert.Equal(t, originalContent, data)
}
func TestGzipUploadThenGzipDownload(t *testing.T) {
t.Parallel()
currentUser, err := user.Current()
require.NoError(t, err)
originalContent := []byte("round-trip gzip test: upload compressed, download compressed, verify match")
// --- Upload with gzip ---
// Build a multipart body
var multipartBuf bytes.Buffer
mpWriter := multipart.NewWriter(&multipartBuf)
part, err := mpWriter.CreateFormFile("file", "roundtrip.txt")
require.NoError(t, err)
_, err = part.Write(originalContent)
require.NoError(t, err)
err = mpWriter.Close()
require.NoError(t, err)
// Gzip-compress the entire multipart body
var gzBuf bytes.Buffer
gzWriter := gzip.NewWriter(&gzBuf)
_, err = gzWriter.Write(multipartBuf.Bytes())
require.NoError(t, err)
err = gzWriter.Close()
require.NoError(t, err)
tempDir := t.TempDir()
destPath := filepath.Join(tempDir, "roundtrip.txt")
logger := zerolog.Nop()
defaults := &execcontext.Defaults{
EnvVars: utils.NewMap[string, string](),
User: currentUser.Username,
}
api := New(&logger, defaults, nil, false, context.Background(), nil, "test")
uploadReq := httptest.NewRequest(http.MethodPost, "/files?path="+url.QueryEscape(destPath), &gzBuf)
uploadReq.Header.Set("Content-Type", mpWriter.FormDataContentType())
uploadReq.Header.Set("Content-Encoding", "gzip")
uploadW := httptest.NewRecorder()
uploadParams := PostFilesParams{
Path: &destPath,
Username: &currentUser.Username,
}
api.PostFiles(uploadW, uploadReq, uploadParams)
uploadResp := uploadW.Result()
defer uploadResp.Body.Close()
require.Equal(t, http.StatusOK, uploadResp.StatusCode)
// --- Download with gzip ---
downloadReq := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(destPath), nil)
downloadReq.Header.Set("Accept-Encoding", "gzip")
downloadW := httptest.NewRecorder()
downloadParams := GetFilesParams{
Path: &destPath,
Username: &currentUser.Username,
}
api.GetFiles(downloadW, downloadReq, downloadParams)
downloadResp := downloadW.Result()
defer downloadResp.Body.Close()
require.Equal(t, http.StatusOK, downloadResp.StatusCode)
assert.Equal(t, "gzip", downloadResp.Header.Get("Content-Encoding"))
// Decompress and verify content matches original
gzReader, err := gzip.NewReader(downloadResp.Body)
require.NoError(t, err)
defer gzReader.Close()
decompressed, err := io.ReadAll(gzReader)
require.NoError(t, err)
assert.Equal(t, originalContent, decompressed)
}

View File

@ -1,229 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
package api
import (
"compress/gzip"
"fmt"
"io"
"net/http"
"slices"
"sort"
"strconv"
"strings"
)
const (
// EncodingGzip is the gzip content encoding.
EncodingGzip = "gzip"
// EncodingIdentity means no encoding (passthrough).
EncodingIdentity = "identity"
// EncodingWildcard means any encoding is acceptable.
EncodingWildcard = "*"
)
// SupportedEncodings lists the content encodings supported for file transfer.
// The order matters - encodings are checked in order of preference.
var SupportedEncodings = []string{
EncodingGzip,
}
// encodingWithQuality holds an encoding name and its quality value.
type encodingWithQuality struct {
encoding string
quality float64
}
// isSupportedEncoding checks if the given encoding is in the supported list.
// Per RFC 7231, content-coding values are case-insensitive.
func isSupportedEncoding(encoding string) bool {
return slices.Contains(SupportedEncodings, strings.ToLower(encoding))
}
// parseEncodingWithQuality parses an encoding value and extracts the quality.
// Returns the encoding name (lowercased) and quality value (default 1.0 if not specified).
// Per RFC 7231, content-coding values are case-insensitive.
func parseEncodingWithQuality(value string) encodingWithQuality {
value = strings.TrimSpace(value)
quality := 1.0
if idx := strings.Index(value, ";"); idx != -1 {
params := value[idx+1:]
value = strings.TrimSpace(value[:idx])
// Parse q=X.X parameter
for param := range strings.SplitSeq(params, ";") {
param = strings.TrimSpace(param)
if strings.HasPrefix(strings.ToLower(param), "q=") {
if q, err := strconv.ParseFloat(param[2:], 64); err == nil {
quality = q
}
}
}
}
// Normalize encoding to lowercase per RFC 7231
return encodingWithQuality{encoding: strings.ToLower(value), quality: quality}
}
// parseEncoding extracts the encoding name from a header value, stripping quality.
func parseEncoding(value string) string {
return parseEncodingWithQuality(value).encoding
}
// parseContentEncoding parses the Content-Encoding header and returns the encoding.
// Returns an error if an unsupported encoding is specified.
// If no Content-Encoding header is present, returns empty string.
func parseContentEncoding(r *http.Request) (string, error) {
header := r.Header.Get("Content-Encoding")
if header == "" {
return EncodingIdentity, nil
}
encoding := parseEncoding(header)
if encoding == EncodingIdentity {
return EncodingIdentity, nil
}
if !isSupportedEncoding(encoding) {
return "", fmt.Errorf("unsupported Content-Encoding: %s, supported: %v", header, SupportedEncodings)
}
return encoding, nil
}
// parseAcceptEncodingHeader parses the Accept-Encoding header and returns
// the parsed encodings along with the identity rejection state.
// Per RFC 7231 Section 5.3.4, identity is acceptable unless excluded by
// "identity;q=0" or "*;q=0" without a more specific entry for identity with q>0.
func parseAcceptEncodingHeader(header string) ([]encodingWithQuality, bool) {
if header == "" {
return nil, false // identity not rejected when header is empty
}
// Parse all encodings with their quality values
var encodings []encodingWithQuality
for value := range strings.SplitSeq(header, ",") {
eq := parseEncodingWithQuality(value)
encodings = append(encodings, eq)
}
// Check if identity is rejected per RFC 7231 Section 5.3.4:
// identity is acceptable unless excluded by "identity;q=0" or "*;q=0"
// without a more specific entry for identity with q>0.
identityRejected := false
identityExplicitlyAccepted := false
wildcardRejected := false
for _, eq := range encodings {
switch eq.encoding {
case EncodingIdentity:
if eq.quality == 0 {
identityRejected = true
} else {
identityExplicitlyAccepted = true
}
case EncodingWildcard:
if eq.quality == 0 {
wildcardRejected = true
}
}
}
if wildcardRejected && !identityExplicitlyAccepted {
identityRejected = true
}
return encodings, identityRejected
}
// isIdentityAcceptable checks if identity encoding is acceptable based on the
// Accept-Encoding header. Per RFC 7231 section 5.3.4, identity is always
// implicitly acceptable unless explicitly rejected with q=0.
func isIdentityAcceptable(r *http.Request) bool {
header := r.Header.Get("Accept-Encoding")
_, identityRejected := parseAcceptEncodingHeader(header)
return !identityRejected
}
// parseAcceptEncoding parses the Accept-Encoding header and returns the best
// supported encoding based on quality values. Per RFC 7231 section 5.3.4,
// identity is always implicitly acceptable unless explicitly rejected with q=0.
// If no Accept-Encoding header is present, returns empty string (identity).
func parseAcceptEncoding(r *http.Request) (string, error) {
header := r.Header.Get("Accept-Encoding")
if header == "" {
return EncodingIdentity, nil
}
encodings, identityRejected := parseAcceptEncodingHeader(header)
// Sort by quality value (highest first)
sort.Slice(encodings, func(i, j int) bool {
return encodings[i].quality > encodings[j].quality
})
// Find the best supported encoding
for _, eq := range encodings {
// Skip encodings with q=0 (explicitly rejected)
if eq.quality == 0 {
continue
}
if eq.encoding == EncodingIdentity {
return EncodingIdentity, nil
}
// Wildcard means any encoding is acceptable - return a supported encoding if identity is rejected
if eq.encoding == EncodingWildcard {
if identityRejected && len(SupportedEncodings) > 0 {
return SupportedEncodings[0], nil
}
return EncodingIdentity, nil
}
if isSupportedEncoding(eq.encoding) {
return eq.encoding, nil
}
}
// Per RFC 7231, identity is implicitly acceptable unless rejected
if !identityRejected {
return EncodingIdentity, nil
}
// Identity rejected and no supported encodings found
return "", fmt.Errorf("no acceptable encoding found, supported: %v", SupportedEncodings)
}
// getDecompressedBody returns a reader that decompresses the request body based on
// Content-Encoding header. Returns the original body if no encoding is specified.
// Returns an error if an unsupported encoding is specified.
// The caller is responsible for closing both the returned ReadCloser and the
// original request body (r.Body) separately.
func getDecompressedBody(r *http.Request) (io.ReadCloser, error) {
encoding, err := parseContentEncoding(r)
if err != nil {
return nil, err
}
if encoding == EncodingIdentity {
return r.Body, nil
}
switch encoding {
case EncodingGzip:
gzReader, err := gzip.NewReader(r.Body)
if err != nil {
return nil, fmt.Errorf("failed to create gzip reader: %w", err)
}
return gzReader, nil
default:
// This shouldn't happen if isSupportedEncoding is correct
return nil, fmt.Errorf("encoding %s is supported but not implemented", encoding)
}
}

View File

@ -1,496 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
package api
import (
"bytes"
"compress/gzip"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestIsSupportedEncoding(t *testing.T) {
t.Parallel()
t.Run("gzip is supported", func(t *testing.T) {
t.Parallel()
assert.True(t, isSupportedEncoding("gzip"))
})
t.Run("GZIP is supported (case-insensitive)", func(t *testing.T) {
t.Parallel()
assert.True(t, isSupportedEncoding("GZIP"))
})
t.Run("Gzip is supported (case-insensitive)", func(t *testing.T) {
t.Parallel()
assert.True(t, isSupportedEncoding("Gzip"))
})
t.Run("br is not supported", func(t *testing.T) {
t.Parallel()
assert.False(t, isSupportedEncoding("br"))
})
t.Run("deflate is not supported", func(t *testing.T) {
t.Parallel()
assert.False(t, isSupportedEncoding("deflate"))
})
}
func TestParseEncodingWithQuality(t *testing.T) {
t.Parallel()
t.Run("returns encoding with default quality 1.0", func(t *testing.T) {
t.Parallel()
eq := parseEncodingWithQuality("gzip")
assert.Equal(t, "gzip", eq.encoding)
assert.InDelta(t, 1.0, eq.quality, 0.001)
})
t.Run("parses quality value", func(t *testing.T) {
t.Parallel()
eq := parseEncodingWithQuality("gzip;q=0.5")
assert.Equal(t, "gzip", eq.encoding)
assert.InDelta(t, 0.5, eq.quality, 0.001)
})
t.Run("parses quality value with whitespace", func(t *testing.T) {
t.Parallel()
eq := parseEncodingWithQuality("gzip ; q=0.8")
assert.Equal(t, "gzip", eq.encoding)
assert.InDelta(t, 0.8, eq.quality, 0.001)
})
t.Run("handles q=0", func(t *testing.T) {
t.Parallel()
eq := parseEncodingWithQuality("gzip;q=0")
assert.Equal(t, "gzip", eq.encoding)
assert.InDelta(t, 0.0, eq.quality, 0.001)
})
t.Run("handles invalid quality value", func(t *testing.T) {
t.Parallel()
eq := parseEncodingWithQuality("gzip;q=invalid")
assert.Equal(t, "gzip", eq.encoding)
assert.InDelta(t, 1.0, eq.quality, 0.001) // defaults to 1.0 on parse error
})
t.Run("trims whitespace from encoding", func(t *testing.T) {
t.Parallel()
eq := parseEncodingWithQuality(" gzip ")
assert.Equal(t, "gzip", eq.encoding)
assert.InDelta(t, 1.0, eq.quality, 0.001)
})
t.Run("normalizes encoding to lowercase", func(t *testing.T) {
t.Parallel()
eq := parseEncodingWithQuality("GZIP")
assert.Equal(t, "gzip", eq.encoding)
})
t.Run("normalizes mixed case encoding", func(t *testing.T) {
t.Parallel()
eq := parseEncodingWithQuality("Gzip;q=0.5")
assert.Equal(t, "gzip", eq.encoding)
assert.InDelta(t, 0.5, eq.quality, 0.001)
})
}
func TestParseEncoding(t *testing.T) {
t.Parallel()
t.Run("returns encoding as-is", func(t *testing.T) {
t.Parallel()
assert.Equal(t, "gzip", parseEncoding("gzip"))
})
t.Run("trims whitespace", func(t *testing.T) {
t.Parallel()
assert.Equal(t, "gzip", parseEncoding(" gzip "))
})
t.Run("strips quality value", func(t *testing.T) {
t.Parallel()
assert.Equal(t, "gzip", parseEncoding("gzip;q=1.0"))
})
t.Run("strips quality value with whitespace", func(t *testing.T) {
t.Parallel()
assert.Equal(t, "gzip", parseEncoding("gzip ; q=0.5"))
})
}
func TestParseContentEncoding(t *testing.T) {
t.Parallel()
t.Run("returns identity when no header", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
encoding, err := parseContentEncoding(req)
require.NoError(t, err)
assert.Equal(t, EncodingIdentity, encoding)
})
t.Run("returns gzip when Content-Encoding is gzip", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
req.Header.Set("Content-Encoding", "gzip")
encoding, err := parseContentEncoding(req)
require.NoError(t, err)
assert.Equal(t, "gzip", encoding)
})
t.Run("returns gzip when Content-Encoding is GZIP (case-insensitive)", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
req.Header.Set("Content-Encoding", "GZIP")
encoding, err := parseContentEncoding(req)
require.NoError(t, err)
assert.Equal(t, "gzip", encoding)
})
t.Run("returns gzip when Content-Encoding is Gzip (case-insensitive)", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
req.Header.Set("Content-Encoding", "Gzip")
encoding, err := parseContentEncoding(req)
require.NoError(t, err)
assert.Equal(t, "gzip", encoding)
})
t.Run("returns identity for identity encoding", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
req.Header.Set("Content-Encoding", "identity")
encoding, err := parseContentEncoding(req)
require.NoError(t, err)
assert.Equal(t, EncodingIdentity, encoding)
})
t.Run("returns error for unsupported encoding", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
req.Header.Set("Content-Encoding", "br")
_, err := parseContentEncoding(req)
require.Error(t, err)
assert.Contains(t, err.Error(), "unsupported Content-Encoding")
assert.Contains(t, err.Error(), "supported: [gzip]")
})
t.Run("handles gzip with quality value", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
req.Header.Set("Content-Encoding", "gzip;q=1.0")
encoding, err := parseContentEncoding(req)
require.NoError(t, err)
assert.Equal(t, "gzip", encoding)
})
}
func TestParseAcceptEncoding(t *testing.T) {
t.Parallel()
t.Run("returns identity when no header", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
encoding, err := parseAcceptEncoding(req)
require.NoError(t, err)
assert.Equal(t, EncodingIdentity, encoding)
})
t.Run("returns gzip when Accept-Encoding is gzip", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
req.Header.Set("Accept-Encoding", "gzip")
encoding, err := parseAcceptEncoding(req)
require.NoError(t, err)
assert.Equal(t, "gzip", encoding)
})
t.Run("returns gzip when Accept-Encoding is GZIP (case-insensitive)", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
req.Header.Set("Accept-Encoding", "GZIP")
encoding, err := parseAcceptEncoding(req)
require.NoError(t, err)
assert.Equal(t, "gzip", encoding)
})
t.Run("returns gzip when gzip is among multiple encodings", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
req.Header.Set("Accept-Encoding", "deflate, gzip, br")
encoding, err := parseAcceptEncoding(req)
require.NoError(t, err)
assert.Equal(t, "gzip", encoding)
})
t.Run("returns gzip with quality value", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
req.Header.Set("Accept-Encoding", "gzip;q=1.0")
encoding, err := parseAcceptEncoding(req)
require.NoError(t, err)
assert.Equal(t, "gzip", encoding)
})
t.Run("returns identity for identity encoding", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
req.Header.Set("Accept-Encoding", "identity")
encoding, err := parseAcceptEncoding(req)
require.NoError(t, err)
assert.Equal(t, EncodingIdentity, encoding)
})
t.Run("returns identity for wildcard encoding", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
req.Header.Set("Accept-Encoding", "*")
encoding, err := parseAcceptEncoding(req)
require.NoError(t, err)
assert.Equal(t, EncodingIdentity, encoding)
})
t.Run("falls back to identity for unsupported encoding only", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
req.Header.Set("Accept-Encoding", "br")
encoding, err := parseAcceptEncoding(req)
require.NoError(t, err)
assert.Equal(t, EncodingIdentity, encoding)
})
t.Run("falls back to identity when only unsupported encodings", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
req.Header.Set("Accept-Encoding", "deflate, br")
encoding, err := parseAcceptEncoding(req)
require.NoError(t, err)
assert.Equal(t, EncodingIdentity, encoding)
})
t.Run("selects gzip when it has highest quality", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
req.Header.Set("Accept-Encoding", "br;q=0.5, gzip;q=1.0, deflate;q=0.8")
encoding, err := parseAcceptEncoding(req)
require.NoError(t, err)
assert.Equal(t, "gzip", encoding)
})
t.Run("selects gzip even with lower quality when others unsupported", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
req.Header.Set("Accept-Encoding", "br;q=1.0, gzip;q=0.5")
encoding, err := parseAcceptEncoding(req)
require.NoError(t, err)
assert.Equal(t, "gzip", encoding)
})
t.Run("returns identity when it has higher quality than gzip", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
req.Header.Set("Accept-Encoding", "gzip;q=0.5, identity;q=1.0")
encoding, err := parseAcceptEncoding(req)
require.NoError(t, err)
assert.Equal(t, EncodingIdentity, encoding)
})
t.Run("skips encoding with q=0", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
req.Header.Set("Accept-Encoding", "gzip;q=0, identity")
encoding, err := parseAcceptEncoding(req)
require.NoError(t, err)
assert.Equal(t, EncodingIdentity, encoding)
})
t.Run("falls back to identity when gzip rejected and no other supported", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
req.Header.Set("Accept-Encoding", "gzip;q=0, br")
encoding, err := parseAcceptEncoding(req)
require.NoError(t, err)
assert.Equal(t, EncodingIdentity, encoding)
})
t.Run("returns error when identity explicitly rejected and no supported encoding", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
req.Header.Set("Accept-Encoding", "br, identity;q=0")
_, err := parseAcceptEncoding(req)
require.Error(t, err)
assert.Contains(t, err.Error(), "no acceptable encoding found")
})
t.Run("returns gzip for wildcard when identity rejected", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
req.Header.Set("Accept-Encoding", "*, identity;q=0")
encoding, err := parseAcceptEncoding(req)
require.NoError(t, err)
assert.Equal(t, "gzip", encoding) // wildcard with identity rejected returns supported encoding
})
t.Run("returns error when wildcard rejected and no explicit identity", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
req.Header.Set("Accept-Encoding", "*;q=0")
_, err := parseAcceptEncoding(req)
require.Error(t, err)
assert.Contains(t, err.Error(), "no acceptable encoding found")
})
t.Run("returns identity when wildcard rejected but identity explicitly accepted", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
req.Header.Set("Accept-Encoding", "*;q=0, identity")
encoding, err := parseAcceptEncoding(req)
require.NoError(t, err)
assert.Equal(t, EncodingIdentity, encoding)
})
t.Run("returns gzip when wildcard rejected but gzip explicitly accepted", func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
req.Header.Set("Accept-Encoding", "*;q=0, gzip")
encoding, err := parseAcceptEncoding(req)
require.NoError(t, err)
assert.Equal(t, EncodingGzip, encoding)
})
}
func TestGetDecompressedBody(t *testing.T) {
t.Parallel()
t.Run("returns original body when no Content-Encoding header", func(t *testing.T) {
t.Parallel()
content := []byte("test content")
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(content))
body, err := getDecompressedBody(req)
require.NoError(t, err)
assert.Equal(t, req.Body, body, "should return original body")
data, err := io.ReadAll(body)
require.NoError(t, err)
assert.Equal(t, content, data)
})
t.Run("decompresses gzip body when Content-Encoding is gzip", func(t *testing.T) {
t.Parallel()
originalContent := []byte("test content to compress")
var compressed bytes.Buffer
gw := gzip.NewWriter(&compressed)
_, err := gw.Write(originalContent)
require.NoError(t, err)
err = gw.Close()
require.NoError(t, err)
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(compressed.Bytes()))
req.Header.Set("Content-Encoding", "gzip")
body, err := getDecompressedBody(req)
require.NoError(t, err)
defer body.Close()
assert.NotEqual(t, req.Body, body, "should return a new gzip reader")
data, err := io.ReadAll(body)
require.NoError(t, err)
assert.Equal(t, originalContent, data)
})
t.Run("returns error for invalid gzip data", func(t *testing.T) {
t.Parallel()
invalidGzip := []byte("this is not gzip data")
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(invalidGzip))
req.Header.Set("Content-Encoding", "gzip")
_, err := getDecompressedBody(req)
require.Error(t, err)
assert.Contains(t, err.Error(), "failed to create gzip reader")
})
t.Run("returns original body for identity encoding", func(t *testing.T) {
t.Parallel()
content := []byte("test content")
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(content))
req.Header.Set("Content-Encoding", "identity")
body, err := getDecompressedBody(req)
require.NoError(t, err)
assert.Equal(t, req.Body, body, "should return original body")
data, err := io.ReadAll(body)
require.NoError(t, err)
assert.Equal(t, content, data)
})
t.Run("returns error for unsupported encoding", func(t *testing.T) {
t.Parallel()
content := []byte("test content")
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(content))
req.Header.Set("Content-Encoding", "br")
_, err := getDecompressedBody(req)
require.Error(t, err)
assert.Contains(t, err.Error(), "unsupported Content-Encoding")
})
t.Run("handles gzip with quality value", func(t *testing.T) {
t.Parallel()
originalContent := []byte("test content to compress")
var compressed bytes.Buffer
gw := gzip.NewWriter(&compressed)
_, err := gw.Write(originalContent)
require.NoError(t, err)
err = gw.Close()
require.NoError(t, err)
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(compressed.Bytes()))
req.Header.Set("Content-Encoding", "gzip;q=1.0")
body, err := getDecompressedBody(req)
require.NoError(t, err)
defer body.Close()
data, err := io.ReadAll(body)
require.NoError(t, err)
assert.Equal(t, originalContent, data)
})
}

View File

@ -1,31 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
package api
import (
"encoding/json"
"net/http"
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
)
func (a *API) GetEnvs(w http.ResponseWriter, _ *http.Request) {
operationID := logs.AssignOperationID()
a.logger.Debug().Str(string(logs.OperationIDKey), operationID).Msg("Getting env vars")
envs := make(EnvVars)
a.defaults.EnvVars.Range(func(key, value string) bool {
envs[key] = value
return true
})
w.Header().Set("Cache-Control", "no-store")
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(envs); err != nil {
a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("Failed to encode env vars")
}
}

View File

@ -1,23 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
package api
import (
"encoding/json"
"errors"
"net/http"
)
func jsonError(w http.ResponseWriter, code int, err error) {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Header().Set("X-Content-Type-Options", "nosniff")
w.WriteHeader(code)
encodeErr := json.NewEncoder(w).Encode(Error{
Code: code,
Message: err.Error(),
})
if encodeErr != nil {
http.Error(w, errors.Join(encodeErr, err).Error(), http.StatusInternalServerError)
}
}

View File

@ -1,5 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
package api
//go:generate go run github.com/oapi-codegen/oapi-codegen/v2/cmd/oapi-codegen -config cfg.yaml ../../spec/envd.yaml

View File

@ -1,296 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// Modifications by M/S Omukk
package api
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/netip"
"os/exec"
"time"
"git.omukk.dev/wrenn/sandbox/envd/internal/host"
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/keys"
"github.com/awnumar/memguard"
"github.com/rs/zerolog"
"github.com/txn2/txeh"
)
var (
ErrAccessTokenMismatch = errors.New("access token validation failed")
ErrAccessTokenResetNotAuthorized = errors.New("access token reset not authorized")
)
// validateInitAccessToken validates the access token for /init requests.
// Token is valid if it matches the existing token OR the MMDS hash.
// If neither exists, first-time setup is allowed.
func (a *API) validateInitAccessToken(ctx context.Context, requestToken *SecureToken) error {
requestTokenSet := requestToken.IsSet()
// Fast path: token matches existing
if a.accessToken.IsSet() && requestTokenSet && a.accessToken.EqualsSecure(requestToken) {
return nil
}
// Check MMDS only if token didn't match existing
matchesMMDS, mmdsExists := a.checkMMDSHash(ctx, requestToken)
switch {
case matchesMMDS:
return nil
case !a.accessToken.IsSet() && !mmdsExists:
return nil // first-time setup
case !requestTokenSet:
return ErrAccessTokenResetNotAuthorized
default:
return ErrAccessTokenMismatch
}
}
// checkMMDSHash checks if the request token matches the MMDS hash.
// Returns (matches, mmdsExists).
//
// The MMDS hash is set by the orchestrator during Resume:
// - hash(token): requires this specific token
// - hash(""): explicitly allows nil token (token reset authorized)
// - "": MMDS not properly configured, no authorization granted
func (a *API) checkMMDSHash(ctx context.Context, requestToken *SecureToken) (bool, bool) {
if a.isNotFC {
return false, false
}
mmdsHash, err := a.mmdsClient.GetAccessTokenHash(ctx)
if err != nil {
return false, false
}
if mmdsHash == "" {
return false, false
}
if !requestToken.IsSet() {
return mmdsHash == keys.HashAccessToken(""), true
}
tokenBytes, err := requestToken.Bytes()
if err != nil {
return false, true
}
defer memguard.WipeBytes(tokenBytes)
return keys.HashAccessTokenBytes(tokenBytes) == mmdsHash, true
}
func (a *API) PostInit(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close()
ctx := r.Context()
operationID := logs.AssignOperationID()
logger := a.logger.With().Str(string(logs.OperationIDKey), operationID).Logger()
if r.Body != nil {
// Read raw body so we can wipe it after parsing
body, err := io.ReadAll(r.Body)
// Ensure body is wiped after we're done
defer memguard.WipeBytes(body)
if err != nil {
logger.Error().Msgf("Failed to read request body: %v", err)
w.WriteHeader(http.StatusBadRequest)
return
}
var initRequest PostInitJSONBody
if len(body) > 0 {
err = json.Unmarshal(body, &initRequest)
if err != nil {
logger.Error().Msgf("Failed to decode request: %v", err)
w.WriteHeader(http.StatusBadRequest)
return
}
}
// Ensure request token is destroyed if not transferred via TakeFrom.
// This handles: validation failures, timestamp-based skips, and any early returns.
// Safe because Destroy() is nil-safe and TakeFrom clears the source.
defer initRequest.AccessToken.Destroy()
a.initLock.Lock()
defer a.initLock.Unlock()
// Update data only if the request is newer or if there's no timestamp at all
if initRequest.Timestamp == nil || a.lastSetTime.SetToGreater(initRequest.Timestamp.UnixNano()) {
err = a.SetData(ctx, logger, initRequest)
if err != nil {
switch {
case errors.Is(err, ErrAccessTokenMismatch), errors.Is(err, ErrAccessTokenResetNotAuthorized):
w.WriteHeader(http.StatusUnauthorized)
default:
logger.Error().Msgf("Failed to set data: %v", err)
w.WriteHeader(http.StatusBadRequest)
}
w.Write([]byte(err.Error()))
return
}
}
}
go func() { //nolint:contextcheck // TODO: fix this later
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
host.PollForMMDSOpts(ctx, a.mmdsChan, a.defaults.EnvVars)
}()
// Start the port scanner and forwarder if they were stopped by a
// pre-snapshot prepare call. Start is a no-op if already running,
// so this is safe on first boot and only takes effect after restore.
if a.portSubsystem != nil {
a.portSubsystem.Start(a.rootCtx)
}
w.Header().Set("Cache-Control", "no-store")
w.Header().Set("Content-Type", "")
w.WriteHeader(http.StatusNoContent)
}
func (a *API) SetData(ctx context.Context, logger zerolog.Logger, data PostInitJSONBody) error {
// Validate access token before proceeding with any action
// The request must provide a token that is either:
// 1. Matches the existing access token (if set), OR
// 2. Matches the MMDS hash (for token change during resume)
if err := a.validateInitAccessToken(ctx, data.AccessToken); err != nil {
return err
}
if data.EnvVars != nil {
logger.Debug().Msg(fmt.Sprintf("Setting %d env vars", len(*data.EnvVars)))
for key, value := range *data.EnvVars {
logger.Debug().Msgf("Setting env var for %s", key)
a.defaults.EnvVars.Store(key, value)
}
}
if data.AccessToken.IsSet() {
logger.Debug().Msg("Setting access token")
a.accessToken.TakeFrom(data.AccessToken)
} else if a.accessToken.IsSet() {
logger.Debug().Msg("Clearing access token")
a.accessToken.Destroy()
}
if data.HyperloopIP != nil {
go a.SetupHyperloop(*data.HyperloopIP)
}
if data.DefaultUser != nil && *data.DefaultUser != "" {
logger.Debug().Msgf("Setting default user to: %s", *data.DefaultUser)
a.defaults.User = *data.DefaultUser
}
if data.DefaultWorkdir != nil && *data.DefaultWorkdir != "" {
logger.Debug().Msgf("Setting default workdir to: %s", *data.DefaultWorkdir)
a.defaults.Workdir = data.DefaultWorkdir
}
if data.VolumeMounts != nil {
for _, volume := range *data.VolumeMounts {
logger.Debug().Msgf("Mounting %s at %q", volume.NfsTarget, volume.Path)
go a.setupNfs(context.WithoutCancel(ctx), volume.NfsTarget, volume.Path)
}
}
return nil
}
func (a *API) setupNfs(ctx context.Context, nfsTarget, path string) {
commands := [][]string{
{"mkdir", "-p", path},
{"mount", "-v", "-t", "nfs", "-o", "mountproto=tcp,mountport=2049,proto=tcp,port=2049,nfsvers=3,noacl", nfsTarget, path},
}
for _, command := range commands {
data, err := exec.CommandContext(ctx, command[0], command[1:]...).CombinedOutput()
logger := a.getLogger(err)
logger.
Strs("command", command).
Str("output", string(data)).
Msg("Mount NFS")
if err != nil {
return
}
}
}
func (a *API) SetupHyperloop(address string) {
a.hyperloopLock.Lock()
defer a.hyperloopLock.Unlock()
if err := rewriteHostsFile(address, "/etc/hosts"); err != nil {
a.logger.Error().Err(err).Msg("failed to modify hosts file")
} else {
a.defaults.EnvVars.Store("WRENN_EVENTS_ADDRESS", fmt.Sprintf("http://%s", address))
}
}
const eventsHost = "events.wrenn.local"
func rewriteHostsFile(address, path string) error {
hosts, err := txeh.NewHosts(&txeh.HostsConfig{
ReadFilePath: path,
WriteFilePath: path,
})
if err != nil {
return fmt.Errorf("failed to create hosts: %w", err)
}
// Update /etc/hosts to point events.wrenn.local to the hyperloop IP
// This will remove any existing entries for events.wrenn.local first
ipFamily, err := getIPFamily(address)
if err != nil {
return fmt.Errorf("failed to get ip family: %w", err)
}
if ok, current, _ := hosts.HostAddressLookup(eventsHost, ipFamily); ok && current == address {
return nil // nothing to be done
}
hosts.AddHost(address, eventsHost)
return hosts.Save()
}
var (
ErrInvalidAddress = errors.New("invalid IP address")
ErrUnknownAddressFormat = errors.New("unknown IP address format")
)
func getIPFamily(address string) (txeh.IPFamily, error) {
addressIP, err := netip.ParseAddr(address)
if err != nil {
return txeh.IPFamilyV4, fmt.Errorf("failed to parse IP address: %w", err)
}
switch {
case addressIP.Is4():
return txeh.IPFamilyV4, nil
case addressIP.Is6():
return txeh.IPFamilyV6, nil
default:
return txeh.IPFamilyV4, fmt.Errorf("%w: %s", ErrUnknownAddressFormat, address)
}
}

View File

@ -1,524 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// Modifications by M/S Omukk
package api
import (
"context"
"os"
"path/filepath"
"strings"
"testing"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/keys"
utilsShared "git.omukk.dev/wrenn/sandbox/envd/internal/shared/utils"
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
)
func TestSimpleCases(t *testing.T) {
t.Parallel()
testCases := map[string]func(string) string{
"both newlines": func(s string) string { return s },
"no newline prefix": func(s string) string { return strings.TrimPrefix(s, "\n") },
"no newline suffix": func(s string) string { return strings.TrimSuffix(s, "\n") },
"no newline prefix or suffix": strings.TrimSpace,
}
for name, preprocessor := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
value := `
# comment
127.0.0.1 one.host
127.0.0.2 two.host
`
value = preprocessor(value)
inputPath := filepath.Join(tempDir, "hosts")
err := os.WriteFile(inputPath, []byte(value), 0o644)
require.NoError(t, err)
err = rewriteHostsFile("127.0.0.3", inputPath)
require.NoError(t, err)
data, err := os.ReadFile(inputPath)
require.NoError(t, err)
assert.Equal(t, `# comment
127.0.0.1 one.host
127.0.0.2 two.host
127.0.0.3 events.wrenn.local`, strings.TrimSpace(string(data)))
})
}
}
func secureTokenPtr(s string) *SecureToken {
token := &SecureToken{}
_ = token.Set([]byte(s))
return token
}
type mockMMDSClient struct {
hash string
err error
}
func (m *mockMMDSClient) GetAccessTokenHash(_ context.Context) (string, error) {
return m.hash, m.err
}
func newTestAPI(accessToken *SecureToken, mmdsClient MMDSClient) *API {
logger := zerolog.Nop()
defaults := &execcontext.Defaults{
EnvVars: utils.NewMap[string, string](),
}
api := New(&logger, defaults, nil, false, context.Background(), nil, "test")
if accessToken != nil {
api.accessToken.TakeFrom(accessToken)
}
api.mmdsClient = mmdsClient
return api
}
func TestValidateInitAccessToken(t *testing.T) {
t.Parallel()
ctx := t.Context()
tests := []struct {
name string
accessToken *SecureToken
requestToken *SecureToken
mmdsHash string
mmdsErr error
wantErr error
}{
{
name: "fast path: token matches existing",
accessToken: secureTokenPtr("secret-token"),
requestToken: secureTokenPtr("secret-token"),
mmdsHash: "",
mmdsErr: nil,
wantErr: nil,
},
{
name: "MMDS match: token hash matches MMDS hash",
accessToken: secureTokenPtr("old-token"),
requestToken: secureTokenPtr("new-token"),
mmdsHash: keys.HashAccessToken("new-token"),
mmdsErr: nil,
wantErr: nil,
},
{
name: "first-time setup: no existing token, MMDS error",
accessToken: nil,
requestToken: secureTokenPtr("new-token"),
mmdsHash: "",
mmdsErr: assert.AnError,
wantErr: nil,
},
{
name: "first-time setup: no existing token, empty MMDS hash",
accessToken: nil,
requestToken: secureTokenPtr("new-token"),
mmdsHash: "",
mmdsErr: nil,
wantErr: nil,
},
{
name: "first-time setup: both tokens nil, no MMDS",
accessToken: nil,
requestToken: nil,
mmdsHash: "",
mmdsErr: assert.AnError,
wantErr: nil,
},
{
name: "mismatch: existing token differs from request, no MMDS",
accessToken: secureTokenPtr("existing-token"),
requestToken: secureTokenPtr("wrong-token"),
mmdsHash: "",
mmdsErr: assert.AnError,
wantErr: ErrAccessTokenMismatch,
},
{
name: "mismatch: existing token differs from request, MMDS hash mismatch",
accessToken: secureTokenPtr("existing-token"),
requestToken: secureTokenPtr("wrong-token"),
mmdsHash: keys.HashAccessToken("different-token"),
mmdsErr: nil,
wantErr: ErrAccessTokenMismatch,
},
{
name: "conflict: existing token, nil request, MMDS exists",
accessToken: secureTokenPtr("existing-token"),
requestToken: nil,
mmdsHash: keys.HashAccessToken("some-token"),
mmdsErr: nil,
wantErr: ErrAccessTokenResetNotAuthorized,
},
{
name: "conflict: existing token, nil request, no MMDS",
accessToken: secureTokenPtr("existing-token"),
requestToken: nil,
mmdsHash: "",
mmdsErr: assert.AnError,
wantErr: ErrAccessTokenResetNotAuthorized,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
mmdsClient := &mockMMDSClient{hash: tt.mmdsHash, err: tt.mmdsErr}
api := newTestAPI(tt.accessToken, mmdsClient)
err := api.validateInitAccessToken(ctx, tt.requestToken)
if tt.wantErr != nil {
require.Error(t, err)
assert.ErrorIs(t, err, tt.wantErr)
} else {
require.NoError(t, err)
}
})
}
}
func TestCheckMMDSHash(t *testing.T) {
t.Parallel()
ctx := t.Context()
t.Run("returns match when token hash equals MMDS hash", func(t *testing.T) {
t.Parallel()
token := "my-secret-token"
mmdsClient := &mockMMDSClient{hash: keys.HashAccessToken(token), err: nil}
api := newTestAPI(nil, mmdsClient)
matches, exists := api.checkMMDSHash(ctx, secureTokenPtr(token))
assert.True(t, matches)
assert.True(t, exists)
})
t.Run("returns no match when token hash differs from MMDS hash", func(t *testing.T) {
t.Parallel()
mmdsClient := &mockMMDSClient{hash: keys.HashAccessToken("different-token"), err: nil}
api := newTestAPI(nil, mmdsClient)
matches, exists := api.checkMMDSHash(ctx, secureTokenPtr("my-token"))
assert.False(t, matches)
assert.True(t, exists)
})
t.Run("returns exists but no match when request token is nil", func(t *testing.T) {
t.Parallel()
mmdsClient := &mockMMDSClient{hash: keys.HashAccessToken("some-token"), err: nil}
api := newTestAPI(nil, mmdsClient)
matches, exists := api.checkMMDSHash(ctx, nil)
assert.False(t, matches)
assert.True(t, exists)
})
t.Run("returns false, false when MMDS returns error", func(t *testing.T) {
t.Parallel()
mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
api := newTestAPI(nil, mmdsClient)
matches, exists := api.checkMMDSHash(ctx, secureTokenPtr("any-token"))
assert.False(t, matches)
assert.False(t, exists)
})
t.Run("returns false, false when MMDS returns empty hash with non-nil request", func(t *testing.T) {
t.Parallel()
mmdsClient := &mockMMDSClient{hash: "", err: nil}
api := newTestAPI(nil, mmdsClient)
matches, exists := api.checkMMDSHash(ctx, secureTokenPtr("any-token"))
assert.False(t, matches)
assert.False(t, exists)
})
t.Run("returns false, false when MMDS returns empty hash with nil request", func(t *testing.T) {
t.Parallel()
mmdsClient := &mockMMDSClient{hash: "", err: nil}
api := newTestAPI(nil, mmdsClient)
matches, exists := api.checkMMDSHash(ctx, nil)
assert.False(t, matches)
assert.False(t, exists)
})
t.Run("returns true, true when MMDS returns hash of empty string with nil request (explicit reset)", func(t *testing.T) {
t.Parallel()
mmdsClient := &mockMMDSClient{hash: keys.HashAccessToken(""), err: nil}
api := newTestAPI(nil, mmdsClient)
matches, exists := api.checkMMDSHash(ctx, nil)
assert.True(t, matches)
assert.True(t, exists)
})
}
func TestSetData(t *testing.T) {
t.Parallel()
ctx := context.Background()
logger := zerolog.Nop()
t.Run("access token updates", func(t *testing.T) {
t.Parallel()
tests := []struct {
name string
existingToken *SecureToken
requestToken *SecureToken
mmdsHash string
mmdsErr error
wantErr error
wantFinalToken *SecureToken
}{
{
name: "first-time setup: sets initial token",
existingToken: nil,
requestToken: secureTokenPtr("initial-token"),
mmdsHash: "",
mmdsErr: assert.AnError,
wantErr: nil,
wantFinalToken: secureTokenPtr("initial-token"),
},
{
name: "first-time setup: nil request token leaves token unset",
existingToken: nil,
requestToken: nil,
mmdsHash: "",
mmdsErr: assert.AnError,
wantErr: nil,
wantFinalToken: nil,
},
{
name: "re-init with same token: token unchanged",
existingToken: secureTokenPtr("same-token"),
requestToken: secureTokenPtr("same-token"),
mmdsHash: "",
mmdsErr: assert.AnError,
wantErr: nil,
wantFinalToken: secureTokenPtr("same-token"),
},
{
name: "resume with MMDS: updates token when hash matches",
existingToken: secureTokenPtr("old-token"),
requestToken: secureTokenPtr("new-token"),
mmdsHash: keys.HashAccessToken("new-token"),
mmdsErr: nil,
wantErr: nil,
wantFinalToken: secureTokenPtr("new-token"),
},
{
name: "resume with MMDS: fails when hash doesn't match",
existingToken: secureTokenPtr("old-token"),
requestToken: secureTokenPtr("new-token"),
mmdsHash: keys.HashAccessToken("different-token"),
mmdsErr: nil,
wantErr: ErrAccessTokenMismatch,
wantFinalToken: secureTokenPtr("old-token"),
},
{
name: "fails when existing token and request token mismatch without MMDS",
existingToken: secureTokenPtr("existing-token"),
requestToken: secureTokenPtr("wrong-token"),
mmdsHash: "",
mmdsErr: assert.AnError,
wantErr: ErrAccessTokenMismatch,
wantFinalToken: secureTokenPtr("existing-token"),
},
{
name: "conflict when existing token but nil request token",
existingToken: secureTokenPtr("existing-token"),
requestToken: nil,
mmdsHash: "",
mmdsErr: assert.AnError,
wantErr: ErrAccessTokenResetNotAuthorized,
wantFinalToken: secureTokenPtr("existing-token"),
},
{
name: "conflict when existing token but nil request with MMDS present",
existingToken: secureTokenPtr("existing-token"),
requestToken: nil,
mmdsHash: keys.HashAccessToken("some-token"),
mmdsErr: nil,
wantErr: ErrAccessTokenResetNotAuthorized,
wantFinalToken: secureTokenPtr("existing-token"),
},
{
name: "conflict when MMDS returns empty hash and request is nil (prevents unauthorized reset)",
existingToken: secureTokenPtr("existing-token"),
requestToken: nil,
mmdsHash: "",
mmdsErr: nil,
wantErr: ErrAccessTokenResetNotAuthorized,
wantFinalToken: secureTokenPtr("existing-token"),
},
{
name: "resets token when MMDS returns hash of empty string and request is nil (explicit reset)",
existingToken: secureTokenPtr("existing-token"),
requestToken: nil,
mmdsHash: keys.HashAccessToken(""),
mmdsErr: nil,
wantErr: nil,
wantFinalToken: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
mmdsClient := &mockMMDSClient{hash: tt.mmdsHash, err: tt.mmdsErr}
api := newTestAPI(tt.existingToken, mmdsClient)
data := PostInitJSONBody{
AccessToken: tt.requestToken,
}
err := api.SetData(ctx, logger, data)
if tt.wantErr != nil {
require.ErrorIs(t, err, tt.wantErr)
} else {
require.NoError(t, err)
}
if tt.wantFinalToken == nil {
assert.False(t, api.accessToken.IsSet(), "expected token to not be set")
} else {
require.True(t, api.accessToken.IsSet(), "expected token to be set")
assert.True(t, api.accessToken.EqualsSecure(tt.wantFinalToken), "expected token to match")
}
})
}
})
t.Run("sets environment variables", func(t *testing.T) {
t.Parallel()
mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
api := newTestAPI(nil, mmdsClient)
envVars := EnvVars{"FOO": "bar", "BAZ": "qux"}
data := PostInitJSONBody{
EnvVars: &envVars,
}
err := api.SetData(ctx, logger, data)
require.NoError(t, err)
val, ok := api.defaults.EnvVars.Load("FOO")
assert.True(t, ok)
assert.Equal(t, "bar", val)
val, ok = api.defaults.EnvVars.Load("BAZ")
assert.True(t, ok)
assert.Equal(t, "qux", val)
})
t.Run("sets default user", func(t *testing.T) {
t.Parallel()
mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
api := newTestAPI(nil, mmdsClient)
data := PostInitJSONBody{
DefaultUser: utilsShared.ToPtr("testuser"),
}
err := api.SetData(ctx, logger, data)
require.NoError(t, err)
assert.Equal(t, "testuser", api.defaults.User)
})
t.Run("does not set default user when empty", func(t *testing.T) {
t.Parallel()
mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
api := newTestAPI(nil, mmdsClient)
api.defaults.User = "original"
data := PostInitJSONBody{
DefaultUser: utilsShared.ToPtr(""),
}
err := api.SetData(ctx, logger, data)
require.NoError(t, err)
assert.Equal(t, "original", api.defaults.User)
})
t.Run("sets default workdir", func(t *testing.T) {
t.Parallel()
mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
api := newTestAPI(nil, mmdsClient)
data := PostInitJSONBody{
DefaultWorkdir: utilsShared.ToPtr("/home/user"),
}
err := api.SetData(ctx, logger, data)
require.NoError(t, err)
require.NotNil(t, api.defaults.Workdir)
assert.Equal(t, "/home/user", *api.defaults.Workdir)
})
t.Run("does not set default workdir when empty", func(t *testing.T) {
t.Parallel()
mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
api := newTestAPI(nil, mmdsClient)
originalWorkdir := "/original"
api.defaults.Workdir = &originalWorkdir
data := PostInitJSONBody{
DefaultWorkdir: utilsShared.ToPtr(""),
}
err := api.SetData(ctx, logger, data)
require.NoError(t, err)
require.NotNil(t, api.defaults.Workdir)
assert.Equal(t, "/original", *api.defaults.Workdir)
})
t.Run("sets multiple fields at once", func(t *testing.T) {
t.Parallel()
mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
api := newTestAPI(nil, mmdsClient)
envVars := EnvVars{"KEY": "value"}
data := PostInitJSONBody{
AccessToken: secureTokenPtr("token"),
DefaultUser: utilsShared.ToPtr("user"),
DefaultWorkdir: utilsShared.ToPtr("/workdir"),
EnvVars: &envVars,
}
err := api.SetData(ctx, logger, data)
require.NoError(t, err)
assert.True(t, api.accessToken.Equals("token"), "expected token to match")
assert.Equal(t, "user", api.defaults.User)
assert.Equal(t, "/workdir", *api.defaults.Workdir)
val, ok := api.defaults.EnvVars.Load("KEY")
assert.True(t, ok)
assert.Equal(t, "value", val)
})
}

View File

@ -1,214 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
package api
import (
"bytes"
"errors"
"sync"
"github.com/awnumar/memguard"
)
var (
ErrTokenNotSet = errors.New("access token not set")
ErrTokenEmpty = errors.New("empty token not allowed")
)
// SecureToken wraps memguard for secure token storage.
// It uses LockedBuffer which provides memory locking, guard pages,
// and secure zeroing on destroy.
type SecureToken struct {
mu sync.RWMutex
buffer *memguard.LockedBuffer
}
// Set securely replaces the token, destroying the old one first.
// The old token memory is zeroed before the new token is stored.
// The input byte slice is wiped after copying to secure memory.
// Returns ErrTokenEmpty if token is empty - use Destroy() to clear the token instead.
func (s *SecureToken) Set(token []byte) error {
if len(token) == 0 {
return ErrTokenEmpty
}
s.mu.Lock()
defer s.mu.Unlock()
// Destroy old token first (zeros memory)
if s.buffer != nil {
s.buffer.Destroy()
s.buffer = nil
}
// Create new LockedBuffer from bytes (source slice is wiped by memguard)
s.buffer = memguard.NewBufferFromBytes(token)
return nil
}
// UnmarshalJSON implements json.Unmarshaler to securely parse a JSON string
// directly into memguard, wiping the input bytes after copying.
//
// Access tokens are hex-encoded HMAC-SHA256 hashes (64 chars of [0-9a-f]),
// so they never contain JSON escape sequences.
func (s *SecureToken) UnmarshalJSON(data []byte) error {
// JSON strings are quoted, so minimum valid is `""` (2 bytes).
if len(data) < 2 || data[0] != '"' || data[len(data)-1] != '"' {
memguard.WipeBytes(data)
return errors.New("invalid secure token JSON string")
}
content := data[1 : len(data)-1]
// Access tokens are hex strings - reject if contains backslash
if bytes.ContainsRune(content, '\\') {
memguard.WipeBytes(data)
return errors.New("invalid secure token: unexpected escape sequence")
}
if len(content) == 0 {
memguard.WipeBytes(data)
return ErrTokenEmpty
}
s.mu.Lock()
defer s.mu.Unlock()
if s.buffer != nil {
s.buffer.Destroy()
s.buffer = nil
}
// Allocate secure buffer and copy directly into it
s.buffer = memguard.NewBuffer(len(content))
copy(s.buffer.Bytes(), content)
// Wipe the input data
memguard.WipeBytes(data)
return nil
}
// TakeFrom transfers the token from src to this SecureToken, destroying any
// existing token. The source token is cleared after transfer.
// This avoids copying the underlying bytes.
func (s *SecureToken) TakeFrom(src *SecureToken) {
if src == nil || s == src {
return
}
// Extract buffer from source
src.mu.Lock()
buffer := src.buffer
src.buffer = nil
src.mu.Unlock()
// Install buffer in destination
s.mu.Lock()
if s.buffer != nil {
s.buffer.Destroy()
}
s.buffer = buffer
s.mu.Unlock()
}
// Equals checks if token matches using constant-time comparison.
// Returns false if the receiver is nil.
func (s *SecureToken) Equals(token string) bool {
if s == nil {
return false
}
s.mu.RLock()
defer s.mu.RUnlock()
if s.buffer == nil || !s.buffer.IsAlive() {
return false
}
return s.buffer.EqualTo([]byte(token))
}
// EqualsSecure compares this token with another SecureToken using constant-time comparison.
// Returns false if either receiver or other is nil.
func (s *SecureToken) EqualsSecure(other *SecureToken) bool {
if s == nil || other == nil {
return false
}
if s == other {
return s.IsSet()
}
// Get a copy of other's bytes (avoids holding two locks simultaneously)
otherBytes, err := other.Bytes()
if err != nil {
return false
}
defer memguard.WipeBytes(otherBytes)
s.mu.RLock()
defer s.mu.RUnlock()
if s.buffer == nil || !s.buffer.IsAlive() {
return false
}
return s.buffer.EqualTo(otherBytes)
}
// IsSet returns true if a token is stored.
// Returns false if the receiver is nil.
func (s *SecureToken) IsSet() bool {
if s == nil {
return false
}
s.mu.RLock()
defer s.mu.RUnlock()
return s.buffer != nil && s.buffer.IsAlive()
}
// Bytes returns a copy of the token bytes (for signature generation).
// The caller should zero the returned slice after use.
// Returns ErrTokenNotSet if the receiver is nil.
func (s *SecureToken) Bytes() ([]byte, error) {
if s == nil {
return nil, ErrTokenNotSet
}
s.mu.RLock()
defer s.mu.RUnlock()
if s.buffer == nil || !s.buffer.IsAlive() {
return nil, ErrTokenNotSet
}
// Return a copy (unavoidable for signature generation)
src := s.buffer.Bytes()
result := make([]byte, len(src))
copy(result, src)
return result, nil
}
// Destroy securely wipes the token from memory.
// No-op if the receiver is nil.
func (s *SecureToken) Destroy() {
if s == nil {
return
}
s.mu.Lock()
defer s.mu.Unlock()
if s.buffer != nil {
s.buffer.Destroy()
s.buffer = nil
}
}

View File

@ -1,463 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
package api
import (
"sync"
"testing"
"github.com/awnumar/memguard"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSecureTokenSetAndEquals(t *testing.T) {
t.Parallel()
st := &SecureToken{}
// Initially not set
assert.False(t, st.IsSet(), "token should not be set initially")
assert.False(t, st.Equals("any-token"), "equals should return false when not set")
// Set token
err := st.Set([]byte("test-token"))
require.NoError(t, err)
assert.True(t, st.IsSet(), "token should be set after Set()")
assert.True(t, st.Equals("test-token"), "equals should return true for correct token")
assert.False(t, st.Equals("wrong-token"), "equals should return false for wrong token")
assert.False(t, st.Equals(""), "equals should return false for empty token")
}
func TestSecureTokenReplace(t *testing.T) {
t.Parallel()
st := &SecureToken{}
// Set initial token
err := st.Set([]byte("first-token"))
require.NoError(t, err)
assert.True(t, st.Equals("first-token"))
// Replace with new token (old one should be destroyed)
err = st.Set([]byte("second-token"))
require.NoError(t, err)
assert.True(t, st.Equals("second-token"), "should match new token")
assert.False(t, st.Equals("first-token"), "should not match old token")
}
func TestSecureTokenDestroy(t *testing.T) {
t.Parallel()
st := &SecureToken{}
// Set and then destroy
err := st.Set([]byte("test-token"))
require.NoError(t, err)
assert.True(t, st.IsSet())
st.Destroy()
assert.False(t, st.IsSet(), "token should not be set after Destroy()")
assert.False(t, st.Equals("test-token"), "equals should return false after Destroy()")
// Destroy on already destroyed should be safe
st.Destroy()
assert.False(t, st.IsSet())
// Nil receiver should be safe
var nilToken *SecureToken
assert.False(t, nilToken.IsSet(), "nil receiver should return false for IsSet()")
assert.False(t, nilToken.Equals("anything"), "nil receiver should return false for Equals()")
assert.False(t, nilToken.EqualsSecure(st), "nil receiver should return false for EqualsSecure()")
nilToken.Destroy() // should not panic
_, err = nilToken.Bytes()
require.ErrorIs(t, err, ErrTokenNotSet, "nil receiver should return ErrTokenNotSet for Bytes()")
}
func TestSecureTokenBytes(t *testing.T) {
t.Parallel()
st := &SecureToken{}
// Bytes should return error when not set
_, err := st.Bytes()
require.ErrorIs(t, err, ErrTokenNotSet)
// Set token and get bytes
err = st.Set([]byte("test-token"))
require.NoError(t, err)
bytes, err := st.Bytes()
require.NoError(t, err)
assert.Equal(t, []byte("test-token"), bytes)
// Zero out the bytes (as caller should do)
memguard.WipeBytes(bytes)
// Original should still be intact
assert.True(t, st.Equals("test-token"), "original token should still work after zeroing copy")
// After destroy, bytes should fail
st.Destroy()
_, err = st.Bytes()
assert.ErrorIs(t, err, ErrTokenNotSet)
}
func TestSecureTokenConcurrentAccess(t *testing.T) {
t.Parallel()
st := &SecureToken{}
err := st.Set([]byte("initial-token"))
require.NoError(t, err)
var wg sync.WaitGroup
const numGoroutines = 100
// Concurrent reads
for range numGoroutines {
wg.Go(func() {
st.IsSet()
st.Equals("initial-token")
})
}
// Concurrent writes
for i := range 10 {
wg.Add(1)
go func(idx int) {
defer wg.Done()
st.Set([]byte("token-" + string(rune('a'+idx))))
}(i)
}
wg.Wait()
// Should still be in a valid state
assert.True(t, st.IsSet())
}
func TestSecureTokenEmptyToken(t *testing.T) {
t.Parallel()
st := &SecureToken{}
// Setting empty token should return an error
err := st.Set([]byte{})
require.ErrorIs(t, err, ErrTokenEmpty)
assert.False(t, st.IsSet(), "token should not be set after empty token error")
// Setting nil should also return an error
err = st.Set(nil)
require.ErrorIs(t, err, ErrTokenEmpty)
assert.False(t, st.IsSet(), "token should not be set after nil token error")
}
func TestSecureTokenEmptyTokenDoesNotClearExisting(t *testing.T) {
t.Parallel()
st := &SecureToken{}
// Set a valid token first
err := st.Set([]byte("valid-token"))
require.NoError(t, err)
assert.True(t, st.IsSet())
// Attempting to set empty token should fail and preserve existing token
err = st.Set([]byte{})
require.ErrorIs(t, err, ErrTokenEmpty)
assert.True(t, st.IsSet(), "existing token should be preserved after empty token error")
assert.True(t, st.Equals("valid-token"), "existing token value should be unchanged")
}
func TestSecureTokenUnmarshalJSON(t *testing.T) {
t.Parallel()
t.Run("unmarshals valid JSON string", func(t *testing.T) {
t.Parallel()
st := &SecureToken{}
err := st.UnmarshalJSON([]byte(`"my-secret-token"`))
require.NoError(t, err)
assert.True(t, st.IsSet())
assert.True(t, st.Equals("my-secret-token"))
})
t.Run("returns error for empty string", func(t *testing.T) {
t.Parallel()
st := &SecureToken{}
err := st.UnmarshalJSON([]byte(`""`))
require.ErrorIs(t, err, ErrTokenEmpty)
assert.False(t, st.IsSet())
})
t.Run("returns error for invalid JSON", func(t *testing.T) {
t.Parallel()
st := &SecureToken{}
err := st.UnmarshalJSON([]byte(`not-valid-json`))
require.Error(t, err)
assert.False(t, st.IsSet())
})
t.Run("replaces existing token", func(t *testing.T) {
t.Parallel()
st := &SecureToken{}
err := st.Set([]byte("old-token"))
require.NoError(t, err)
err = st.UnmarshalJSON([]byte(`"new-token"`))
require.NoError(t, err)
assert.True(t, st.Equals("new-token"))
assert.False(t, st.Equals("old-token"))
})
t.Run("wipes input buffer after parsing", func(t *testing.T) {
t.Parallel()
// Create a buffer with a known token
input := []byte(`"secret-token-12345"`)
original := make([]byte, len(input))
copy(original, input)
st := &SecureToken{}
err := st.UnmarshalJSON(input)
require.NoError(t, err)
// Verify the token was stored correctly
assert.True(t, st.Equals("secret-token-12345"))
// Verify the input buffer was wiped (all zeros)
for i, b := range input {
assert.Equal(t, byte(0), b, "byte at position %d should be zero, got %d", i, b)
}
})
t.Run("wipes input buffer on error", func(t *testing.T) {
t.Parallel()
// Create a buffer with an empty token (will error)
input := []byte(`""`)
st := &SecureToken{}
err := st.UnmarshalJSON(input)
require.Error(t, err)
// Verify the input buffer was still wiped
for i, b := range input {
assert.Equal(t, byte(0), b, "byte at position %d should be zero, got %d", i, b)
}
})
t.Run("rejects escape sequences", func(t *testing.T) {
t.Parallel()
st := &SecureToken{}
err := st.UnmarshalJSON([]byte(`"token\nwith\nnewlines"`))
require.Error(t, err)
assert.Contains(t, err.Error(), "escape sequence")
assert.False(t, st.IsSet())
})
}
func TestSecureTokenSetWipesInput(t *testing.T) {
t.Parallel()
t.Run("wipes input buffer after storing", func(t *testing.T) {
t.Parallel()
// Create a buffer with a known token
input := []byte("my-secret-token")
original := make([]byte, len(input))
copy(original, input)
st := &SecureToken{}
err := st.Set(input)
require.NoError(t, err)
// Verify the token was stored correctly
assert.True(t, st.Equals("my-secret-token"))
// Verify the input buffer was wiped (all zeros)
for i, b := range input {
assert.Equal(t, byte(0), b, "byte at position %d should be zero, got %d", i, b)
}
})
}
func TestSecureTokenTakeFrom(t *testing.T) {
t.Parallel()
t.Run("transfers token from source to destination", func(t *testing.T) {
t.Parallel()
src := &SecureToken{}
err := src.Set([]byte("source-token"))
require.NoError(t, err)
dst := &SecureToken{}
dst.TakeFrom(src)
assert.True(t, dst.IsSet())
assert.True(t, dst.Equals("source-token"))
assert.False(t, src.IsSet(), "source should be empty after transfer")
})
t.Run("replaces existing destination token", func(t *testing.T) {
t.Parallel()
src := &SecureToken{}
err := src.Set([]byte("new-token"))
require.NoError(t, err)
dst := &SecureToken{}
err = dst.Set([]byte("old-token"))
require.NoError(t, err)
dst.TakeFrom(src)
assert.True(t, dst.Equals("new-token"))
assert.False(t, dst.Equals("old-token"))
assert.False(t, src.IsSet())
})
t.Run("handles nil source", func(t *testing.T) {
t.Parallel()
dst := &SecureToken{}
err := dst.Set([]byte("existing-token"))
require.NoError(t, err)
dst.TakeFrom(nil)
assert.True(t, dst.IsSet(), "destination should be unchanged with nil source")
assert.True(t, dst.Equals("existing-token"))
})
t.Run("handles empty source", func(t *testing.T) {
t.Parallel()
src := &SecureToken{}
dst := &SecureToken{}
err := dst.Set([]byte("existing-token"))
require.NoError(t, err)
dst.TakeFrom(src)
assert.False(t, dst.IsSet(), "destination should be cleared when source is empty")
})
t.Run("self-transfer is no-op and does not deadlock", func(t *testing.T) {
t.Parallel()
st := &SecureToken{}
err := st.Set([]byte("token"))
require.NoError(t, err)
st.TakeFrom(st)
assert.True(t, st.IsSet(), "token should remain set after self-transfer")
assert.True(t, st.Equals("token"), "token value should be unchanged")
})
}
func TestSecureTokenEqualsSecure(t *testing.T) {
t.Parallel()
t.Run("returns true for matching tokens", func(t *testing.T) {
t.Parallel()
st1 := &SecureToken{}
err := st1.Set([]byte("same-token"))
require.NoError(t, err)
st2 := &SecureToken{}
err = st2.Set([]byte("same-token"))
require.NoError(t, err)
assert.True(t, st1.EqualsSecure(st2))
assert.True(t, st2.EqualsSecure(st1))
})
t.Run("concurrent TakeFrom and EqualsSecure do not deadlock", func(t *testing.T) {
t.Parallel()
// This test verifies the fix for the lock ordering deadlock bug.
const iterations = 100
for range iterations {
a := &SecureToken{}
err := a.Set([]byte("token-a"))
require.NoError(t, err)
b := &SecureToken{}
err = b.Set([]byte("token-b"))
require.NoError(t, err)
var wg sync.WaitGroup
wg.Add(2)
// Goroutine 1: a.TakeFrom(b)
go func() {
defer wg.Done()
a.TakeFrom(b)
}()
// Goroutine 2: b.EqualsSecure(a)
go func() {
defer wg.Done()
b.EqualsSecure(a)
}()
wg.Wait()
}
})
t.Run("returns false for different tokens", func(t *testing.T) {
t.Parallel()
st1 := &SecureToken{}
err := st1.Set([]byte("token-a"))
require.NoError(t, err)
st2 := &SecureToken{}
err = st2.Set([]byte("token-b"))
require.NoError(t, err)
assert.False(t, st1.EqualsSecure(st2))
})
t.Run("returns false when comparing with nil", func(t *testing.T) {
t.Parallel()
st := &SecureToken{}
err := st.Set([]byte("token"))
require.NoError(t, err)
assert.False(t, st.EqualsSecure(nil))
})
t.Run("returns false when other is not set", func(t *testing.T) {
t.Parallel()
st1 := &SecureToken{}
err := st1.Set([]byte("token"))
require.NoError(t, err)
st2 := &SecureToken{}
assert.False(t, st1.EqualsSecure(st2))
})
t.Run("returns false when self is not set", func(t *testing.T) {
t.Parallel()
st1 := &SecureToken{}
st2 := &SecureToken{}
err := st2.Set([]byte("token"))
require.NoError(t, err)
assert.False(t, st1.EqualsSecure(st2))
})
t.Run("self-comparison returns true when set", func(t *testing.T) {
t.Parallel()
st := &SecureToken{}
err := st.Set([]byte("token"))
require.NoError(t, err)
assert.True(t, st.EqualsSecure(st), "self-comparison should return true and not deadlock")
})
t.Run("self-comparison returns false when not set", func(t *testing.T) {
t.Parallel()
st := &SecureToken{}
assert.False(t, st.EqualsSecure(st), "self-comparison on unset token should return false")
})
}

View File

@ -1,25 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// Modifications by M/S Omukk
package api
import (
"net/http"
)
// PostSnapshotPrepare quiesces continuous goroutines (port scanner, forwarder)
// and forces a GC cycle before Firecracker takes a VM snapshot. This ensures
// the Go runtime's page allocator is in a consistent state when vCPUs are frozen.
//
// Called by the host agent as a best-effort signal before vm.Pause().
func (a *API) PostSnapshotPrepare(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close()
if a.portSubsystem != nil {
a.portSubsystem.Stop()
a.logger.Info().Msg("snapshot/prepare: port subsystem quiesced")
}
w.Header().Set("Cache-Control", "no-store")
w.WriteHeader(http.StatusNoContent)
}

View File

@ -1,108 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// Modifications by M/S Omukk
package api
import (
"context"
"encoding/json"
"net/http"
"sync"
"github.com/rs/zerolog"
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
"git.omukk.dev/wrenn/sandbox/envd/internal/host"
publicport "git.omukk.dev/wrenn/sandbox/envd/internal/port"
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
)
// MMDSClient provides access to MMDS metadata.
type MMDSClient interface {
GetAccessTokenHash(ctx context.Context) (string, error)
}
// DefaultMMDSClient is the production implementation that calls the real MMDS endpoint.
type DefaultMMDSClient struct{}
func (c *DefaultMMDSClient) GetAccessTokenHash(ctx context.Context) (string, error) {
return host.GetAccessTokenHashFromMMDS(ctx)
}
type API struct {
isNotFC bool
logger *zerolog.Logger
accessToken *SecureToken
defaults *execcontext.Defaults
version string
mmdsChan chan *host.MMDSOpts
hyperloopLock sync.Mutex
mmdsClient MMDSClient
lastSetTime *utils.AtomicMax
initLock sync.Mutex
// rootCtx is the parent context from main(), used to restart
// long-lived goroutines after snapshot restore.
rootCtx context.Context
portSubsystem *publicport.PortSubsystem
}
func New(l *zerolog.Logger, defaults *execcontext.Defaults, mmdsChan chan *host.MMDSOpts, isNotFC bool, rootCtx context.Context, portSubsystem *publicport.PortSubsystem, version string) *API {
return &API{
logger: l,
defaults: defaults,
mmdsChan: mmdsChan,
isNotFC: isNotFC,
mmdsClient: &DefaultMMDSClient{},
lastSetTime: utils.NewAtomicMax(),
accessToken: &SecureToken{},
rootCtx: rootCtx,
portSubsystem: portSubsystem,
version: version,
}
}
func (a *API) GetHealth(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close()
a.logger.Trace().Msg("Health check")
w.Header().Set("Cache-Control", "no-store")
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]string{
"version": a.version,
})
}
func (a *API) GetMetrics(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close()
a.logger.Trace().Msg("Get metrics")
w.Header().Set("Cache-Control", "no-store")
w.Header().Set("Content-Type", "application/json")
metrics, err := host.GetMetrics()
if err != nil {
a.logger.Error().Err(err).Msg("Failed to get metrics")
w.WriteHeader(http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(metrics); err != nil {
a.logger.Error().Err(err).Msg("Failed to encode metrics")
}
}
func (a *API) getLogger(err error) *zerolog.Event {
if err != nil {
return a.logger.Error().Err(err) //nolint:zerologlint // this is only prep
}
return a.logger.Info() //nolint:zerologlint // this is only prep
}

View File

@ -1,311 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
package api
import (
"encoding/json"
"errors"
"fmt"
"io"
"mime/multipart"
"net/http"
"os"
"os/user"
"path/filepath"
"strings"
"syscall"
"github.com/rs/zerolog"
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
)
var ErrNoDiskSpace = fmt.Errorf("not enough disk space available")
func processFile(r *http.Request, path string, part io.Reader, uid, gid int, logger zerolog.Logger) (int, error) {
logger.Debug().
Str("path", path).
Msg("File processing")
err := permissions.EnsureDirs(filepath.Dir(path), uid, gid)
if err != nil {
err := fmt.Errorf("error ensuring directories: %w", err)
return http.StatusInternalServerError, err
}
canBePreChowned := false
stat, err := os.Stat(path)
if err != nil && !os.IsNotExist(err) {
errMsg := fmt.Errorf("error getting file info: %w", err)
return http.StatusInternalServerError, errMsg
} else if err == nil {
if stat.IsDir() {
err := fmt.Errorf("path is a directory: %s", path)
return http.StatusBadRequest, err
}
canBePreChowned = true
}
hasBeenChowned := false
if canBePreChowned {
err = os.Chown(path, uid, gid)
if err != nil {
if !os.IsNotExist(err) {
err = fmt.Errorf("error changing file ownership: %w", err)
return http.StatusInternalServerError, err
}
} else {
hasBeenChowned = true
}
}
file, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o666)
if err != nil {
if errors.Is(err, syscall.ENOSPC) {
err = fmt.Errorf("not enough inodes available: %w", err)
return http.StatusInsufficientStorage, err
}
err := fmt.Errorf("error opening file: %w", err)
return http.StatusInternalServerError, err
}
defer file.Close()
if !hasBeenChowned {
err = os.Chown(path, uid, gid)
if err != nil {
err := fmt.Errorf("error changing file ownership: %w", err)
return http.StatusInternalServerError, err
}
}
_, err = file.ReadFrom(part)
if err != nil {
if errors.Is(err, syscall.ENOSPC) {
err = ErrNoDiskSpace
if r.ContentLength > 0 {
err = fmt.Errorf("attempted to write %d bytes: %w", r.ContentLength, err)
}
return http.StatusInsufficientStorage, err
}
err = fmt.Errorf("error writing file: %w", err)
return http.StatusInternalServerError, err
}
return http.StatusNoContent, nil
}
func resolvePath(part *multipart.Part, paths *UploadSuccess, u *user.User, defaultPath *string, params PostFilesParams) (string, error) {
var pathToResolve string
if params.Path != nil {
pathToResolve = *params.Path
} else {
var err error
customPart := utils.NewCustomPart(part)
pathToResolve, err = customPart.FileNameWithPath()
if err != nil {
return "", fmt.Errorf("error getting multipart custom part file name: %w", err)
}
}
filePath, err := permissions.ExpandAndResolve(pathToResolve, u, defaultPath)
if err != nil {
return "", fmt.Errorf("error resolving path: %w", err)
}
for _, entry := range *paths {
if entry.Path == filePath {
var alreadyUploaded []string
for _, uploadedFile := range *paths {
if uploadedFile.Path != filePath {
alreadyUploaded = append(alreadyUploaded, uploadedFile.Path)
}
}
errMsg := fmt.Errorf("you cannot upload multiple files to the same path '%s' in one upload request, only the first specified file was uploaded", filePath)
if len(alreadyUploaded) > 1 {
errMsg = fmt.Errorf("%w, also the following files were uploaded: %v", errMsg, strings.Join(alreadyUploaded, ", "))
}
return "", errMsg
}
}
return filePath, nil
}
func (a *API) handlePart(r *http.Request, part *multipart.Part, paths UploadSuccess, u *user.User, uid, gid int, operationID string, params PostFilesParams) (*EntryInfo, int, error) {
defer part.Close()
if part.FormName() != "file" {
return nil, http.StatusOK, nil
}
filePath, err := resolvePath(part, &paths, u, a.defaults.Workdir, params)
if err != nil {
return nil, http.StatusBadRequest, err
}
logger := a.logger.
With().
Str(string(logs.OperationIDKey), operationID).
Str("event_type", "file_processing").
Logger()
status, err := processFile(r, filePath, part, uid, gid, logger)
if err != nil {
return nil, status, err
}
return &EntryInfo{
Path: filePath,
Name: filepath.Base(filePath),
Type: File,
}, http.StatusOK, nil
}
func (a *API) PostFiles(w http.ResponseWriter, r *http.Request, params PostFilesParams) {
// Capture original body to ensure it's always closed
originalBody := r.Body
defer originalBody.Close()
var errorCode int
var errMsg error
var path string
if params.Path != nil {
path = *params.Path
}
operationID := logs.AssignOperationID()
// signing authorization if needed
err := a.validateSigning(r, params.Signature, params.SignatureExpiration, params.Username, path, SigningWriteOperation)
if err != nil {
a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("error during auth validation")
jsonError(w, http.StatusUnauthorized, err)
return
}
username, err := execcontext.ResolveDefaultUsername(params.Username, a.defaults.User)
if err != nil {
a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("no user specified")
jsonError(w, http.StatusBadRequest, err)
return
}
defer func() {
l := a.logger.
Err(errMsg).
Str("method", r.Method+" "+r.URL.Path).
Str(string(logs.OperationIDKey), operationID).
Str("path", path).
Str("username", username)
if errMsg != nil {
l = l.Int("error_code", errorCode)
}
l.Msg("File write")
}()
// Handle gzip-encoded request body
body, err := getDecompressedBody(r)
if err != nil {
errMsg = fmt.Errorf("error decompressing request body: %w", err)
errorCode = http.StatusBadRequest
jsonError(w, errorCode, errMsg)
return
}
defer body.Close()
r.Body = body
f, err := r.MultipartReader()
if err != nil {
errMsg = fmt.Errorf("error parsing multipart form: %w", err)
errorCode = http.StatusInternalServerError
jsonError(w, errorCode, errMsg)
return
}
u, err := user.Lookup(username)
if err != nil {
errMsg = fmt.Errorf("error looking up user '%s': %w", username, err)
errorCode = http.StatusUnauthorized
jsonError(w, errorCode, errMsg)
return
}
uid, gid, err := permissions.GetUserIdInts(u)
if err != nil {
errMsg = fmt.Errorf("error getting user ids: %w", err)
jsonError(w, http.StatusInternalServerError, errMsg)
return
}
paths := UploadSuccess{}
for {
part, partErr := f.NextPart()
if partErr == io.EOF {
// We're done reading the parts.
break
} else if partErr != nil {
errMsg = fmt.Errorf("error reading form: %w", partErr)
errorCode = http.StatusInternalServerError
jsonError(w, errorCode, errMsg)
break
}
entry, status, err := a.handlePart(r, part, paths, u, uid, gid, operationID, params)
if err != nil {
errorCode = status
errMsg = err
jsonError(w, errorCode, errMsg)
return
}
if entry != nil {
paths = append(paths, *entry)
}
}
data, err := json.Marshal(paths)
if err != nil {
errMsg = fmt.Errorf("error marshaling response: %w", err)
errorCode = http.StatusInternalServerError
jsonError(w, errorCode, errMsg)
return
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write(data)
}

View File

@ -1,251 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
package api
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"os"
"os/exec"
"path/filepath"
"testing"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestProcessFile(t *testing.T) {
t.Parallel()
uid := os.Getuid()
gid := os.Getgid()
newRequest := func(content []byte) (*http.Request, io.Reader) {
request := &http.Request{
ContentLength: int64(len(content)),
}
buffer := bytes.NewBuffer(content)
return request, buffer
}
var emptyReq http.Request
var emptyPart *bytes.Buffer
var emptyLogger zerolog.Logger
t.Run("failed to ensure directories", func(t *testing.T) {
t.Parallel()
httpStatus, err := processFile(&emptyReq, "/proc/invalid/not-real", emptyPart, uid, gid, emptyLogger)
require.Error(t, err)
assert.Equal(t, http.StatusInternalServerError, httpStatus)
assert.ErrorContains(t, err, "error ensuring directories: ")
})
t.Run("attempt to replace directory with a file", func(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
httpStatus, err := processFile(&emptyReq, tempDir, emptyPart, uid, gid, emptyLogger)
require.Error(t, err)
assert.Equal(t, http.StatusBadRequest, httpStatus, err.Error())
assert.ErrorContains(t, err, "path is a directory: ")
})
t.Run("fail to create file", func(t *testing.T) {
t.Parallel()
httpStatus, err := processFile(&emptyReq, "/proc/invalid-filename", emptyPart, uid, gid, emptyLogger)
require.Error(t, err)
assert.Equal(t, http.StatusInternalServerError, httpStatus)
assert.ErrorContains(t, err, "error opening file: ")
})
t.Run("out of disk space", func(t *testing.T) {
t.Parallel()
// make a tiny tmpfs mount
mountSize := 1024
tempDir := createTmpfsMount(t, mountSize)
// create test file
firstFileSize := mountSize / 2
tempFile1 := filepath.Join(tempDir, "test-file-1")
// fill it up
cmd := exec.CommandContext(t.Context(),
"dd", "if=/dev/zero", "of="+tempFile1, fmt.Sprintf("bs=%d", firstFileSize), "count=1")
err := cmd.Run()
require.NoError(t, err)
// create a new file that would fill up the
secondFileContents := make([]byte, mountSize*2)
for index := range secondFileContents {
secondFileContents[index] = 'a'
}
// try to replace it
request, buffer := newRequest(secondFileContents)
tempFile2 := filepath.Join(tempDir, "test-file-2")
httpStatus, err := processFile(request, tempFile2, buffer, uid, gid, emptyLogger)
require.Error(t, err)
assert.Equal(t, http.StatusInsufficientStorage, httpStatus)
assert.ErrorContains(t, err, "attempted to write 2048 bytes: not enough disk space")
})
t.Run("happy path", func(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
tempFile := filepath.Join(tempDir, "test-file")
content := []byte("test-file-contents")
request, buffer := newRequest(content)
httpStatus, err := processFile(request, tempFile, buffer, uid, gid, emptyLogger)
require.NoError(t, err)
assert.Equal(t, http.StatusNoContent, httpStatus)
data, err := os.ReadFile(tempFile)
require.NoError(t, err)
assert.Equal(t, content, data)
})
t.Run("overwrite file on full disk", func(t *testing.T) {
t.Parallel()
// make a tiny tmpfs mount
sizeInBytes := 1024
tempDir := createTmpfsMount(t, 1024)
// create test file
tempFile := filepath.Join(tempDir, "test-file")
// fill it up
cmd := exec.CommandContext(t.Context(), "dd", "if=/dev/zero", "of="+tempFile, fmt.Sprintf("bs=%d", sizeInBytes), "count=1")
err := cmd.Run()
require.NoError(t, err)
// try to replace it
content := []byte("test-file-contents")
request, buffer := newRequest(content)
httpStatus, err := processFile(request, tempFile, buffer, uid, gid, emptyLogger)
require.NoError(t, err)
assert.Equal(t, http.StatusNoContent, httpStatus)
})
t.Run("write new file on full disk", func(t *testing.T) {
t.Parallel()
// make a tiny tmpfs mount
sizeInBytes := 1024
tempDir := createTmpfsMount(t, 1024)
// create test file
tempFile1 := filepath.Join(tempDir, "test-file")
// fill it up
cmd := exec.CommandContext(t.Context(), "dd", "if=/dev/zero", "of="+tempFile1, fmt.Sprintf("bs=%d", sizeInBytes), "count=1")
err := cmd.Run()
require.NoError(t, err)
// try to write a new file
tempFile2 := filepath.Join(tempDir, "test-file-2")
content := []byte("test-file-contents")
request, buffer := newRequest(content)
httpStatus, err := processFile(request, tempFile2, buffer, uid, gid, emptyLogger)
require.ErrorContains(t, err, "not enough disk space available")
assert.Equal(t, http.StatusInsufficientStorage, httpStatus)
})
t.Run("write new file with no inodes available", func(t *testing.T) {
t.Parallel()
// make a tiny tmpfs mount
tempDir := createTmpfsMountWithInodes(t, 1024, 2)
// create test file
tempFile1 := filepath.Join(tempDir, "test-file")
// fill it up
cmd := exec.CommandContext(t.Context(), "dd", "if=/dev/zero", "of="+tempFile1, fmt.Sprintf("bs=%d", 100), "count=1")
err := cmd.Run()
require.NoError(t, err)
// try to write a new file
tempFile2 := filepath.Join(tempDir, "test-file-2")
content := []byte("test-file-contents")
request, buffer := newRequest(content)
httpStatus, err := processFile(request, tempFile2, buffer, uid, gid, emptyLogger)
require.ErrorContains(t, err, "not enough inodes available")
assert.Equal(t, http.StatusInsufficientStorage, httpStatus)
})
t.Run("update sysfs or other virtual fs", func(t *testing.T) {
t.Parallel()
if os.Geteuid() != 0 {
t.Skip("skipping sysfs updates: Operation not permitted with non-root user")
}
filePath := "/sys/fs/cgroup/user.slice/cpu.weight"
newContent := []byte("102\n")
request, buffer := newRequest(newContent)
httpStatus, err := processFile(request, filePath, buffer, uid, gid, emptyLogger)
require.NoError(t, err)
assert.Equal(t, http.StatusNoContent, httpStatus)
data, err := os.ReadFile(filePath)
require.NoError(t, err)
assert.Equal(t, newContent, data)
})
t.Run("replace file", func(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
tempFile := filepath.Join(tempDir, "test-file")
err := os.WriteFile(tempFile, []byte("old-contents"), 0o644)
require.NoError(t, err)
newContent := []byte("new-file-contents")
request, buffer := newRequest(newContent)
httpStatus, err := processFile(request, tempFile, buffer, uid, gid, emptyLogger)
require.NoError(t, err)
assert.Equal(t, http.StatusNoContent, httpStatus)
data, err := os.ReadFile(tempFile)
require.NoError(t, err)
assert.Equal(t, newContent, data)
})
}
func createTmpfsMount(t *testing.T, sizeInBytes int) string {
t.Helper()
return createTmpfsMountWithInodes(t, sizeInBytes, 5)
}
func createTmpfsMountWithInodes(t *testing.T, sizeInBytes, inodesCount int) string {
t.Helper()
if os.Geteuid() != 0 {
t.Skip("skipping sysfs updates: Operation not permitted with non-root user")
}
tempDir := t.TempDir()
cmd := exec.CommandContext(t.Context(),
"mount",
"tmpfs",
tempDir,
"-t", "tmpfs",
"-o", fmt.Sprintf("size=%d,nr_inodes=%d", sizeInBytes, inodesCount))
err := cmd.Run()
require.NoError(t, err)
t.Cleanup(func() {
ctx := context.WithoutCancel(t.Context())
cmd := exec.CommandContext(ctx, "umount", tempDir)
err := cmd.Run()
require.NoError(t, err)
})
return tempDir
}

View File

@ -1,39 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
package execcontext
import (
"errors"
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
)
type Defaults struct {
EnvVars *utils.Map[string, string]
User string
Workdir *string
}
func ResolveDefaultWorkdir(workdir string, defaultWorkdir *string) string {
if workdir != "" {
return workdir
}
if defaultWorkdir != nil {
return *defaultWorkdir
}
return ""
}
func ResolveDefaultUsername(username *string, defaultUsername string) (string, error) {
if username != nil {
return *username, nil
}
if defaultUsername != "" {
return defaultUsername, nil
}
return "", errors.New("username not provided")
}

View File

@ -1,96 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// Modifications by M/S Omukk
package host
import (
"math"
"time"
"github.com/shirou/gopsutil/v4/cpu"
"github.com/shirou/gopsutil/v4/mem"
"golang.org/x/sys/unix"
)
type Metrics struct {
Timestamp int64 `json:"ts"` // Unix Timestamp in UTC
CPUCount uint32 `json:"cpu_count"` // Total CPU cores
CPUUsedPercent float32 `json:"cpu_used_pct"` // Percent rounded to 2 decimal places
// Deprecated: kept for backwards compatibility with older orchestrators.
MemTotalMiB uint64 `json:"mem_total_mib"` // Total virtual memory in MiB
// Deprecated: kept for backwards compatibility with older orchestrators.
MemUsedMiB uint64 `json:"mem_used_mib"` // Used virtual memory in MiB
MemTotal uint64 `json:"mem_total"` // Total virtual memory in bytes
MemUsed uint64 `json:"mem_used"` // Used virtual memory in bytes
DiskUsed uint64 `json:"disk_used"` // Used disk space in bytes
DiskTotal uint64 `json:"disk_total"` // Total disk space in bytes
}
func GetMetrics() (*Metrics, error) {
v, err := mem.VirtualMemory()
if err != nil {
return nil, err
}
memUsedMiB := v.Used / 1024 / 1024
memTotalMiB := v.Total / 1024 / 1024
cpuTotal, err := cpu.Counts(true)
if err != nil {
return nil, err
}
cpuUsedPcts, err := cpu.Percent(0, false)
if err != nil {
return nil, err
}
cpuUsedPct := cpuUsedPcts[0]
cpuUsedPctRounded := float32(cpuUsedPct)
if cpuUsedPct > 0 {
cpuUsedPctRounded = float32(math.Round(cpuUsedPct*100) / 100)
}
diskMetrics, err := diskStats("/")
if err != nil {
return nil, err
}
return &Metrics{
Timestamp: time.Now().UTC().Unix(),
CPUCount: uint32(cpuTotal),
CPUUsedPercent: cpuUsedPctRounded,
MemUsedMiB: memUsedMiB,
MemTotalMiB: memTotalMiB,
MemTotal: v.Total,
MemUsed: v.Used,
DiskUsed: diskMetrics.Total - diskMetrics.Available,
DiskTotal: diskMetrics.Total,
}, nil
}
type diskSpace struct {
Total uint64
Available uint64
}
func diskStats(path string) (diskSpace, error) {
var st unix.Statfs_t
if err := unix.Statfs(path, &st); err != nil {
return diskSpace{}, err
}
block := uint64(st.Bsize)
// all data blocks
total := st.Blocks * block
// blocks available
available := st.Bavail * block
return diskSpace{Total: total, Available: available}, nil
}

View File

@ -1,185 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// Modifications by M/S Omukk
package host
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"time"
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
)
const (
WrennRunDir = "/run/wrenn" // store sandbox metadata files here
mmdsDefaultAddress = "169.254.169.254"
mmdsTokenExpiration = 60 * time.Second
mmdsAccessTokenRequestClientTimeout = 10 * time.Second
)
var mmdsAccessTokenClient = &http.Client{
Timeout: mmdsAccessTokenRequestClientTimeout,
Transport: &http.Transport{
DisableKeepAlives: true,
},
}
type MMDSOpts struct {
SandboxID string `json:"instanceID"`
TemplateID string `json:"envID"`
LogsCollectorAddress string `json:"address"`
AccessTokenHash string `json:"accessTokenHash"`
}
func (opts *MMDSOpts) Update(sandboxID, templateID, collectorAddress string) {
opts.SandboxID = sandboxID
opts.TemplateID = templateID
opts.LogsCollectorAddress = collectorAddress
}
func (opts *MMDSOpts) AddOptsToJSON(jsonLogs []byte) ([]byte, error) {
parsed := make(map[string]any)
err := json.Unmarshal(jsonLogs, &parsed)
if err != nil {
return nil, err
}
parsed["instanceID"] = opts.SandboxID
parsed["envID"] = opts.TemplateID
data, err := json.Marshal(parsed)
return data, err
}
func getMMDSToken(ctx context.Context, client *http.Client) (string, error) {
request, err := http.NewRequestWithContext(ctx, http.MethodPut, "http://"+mmdsDefaultAddress+"/latest/api/token", &bytes.Buffer{})
if err != nil {
return "", err
}
request.Header["X-metadata-token-ttl-seconds"] = []string{fmt.Sprint(mmdsTokenExpiration.Seconds())}
response, err := client.Do(request)
if err != nil {
return "", err
}
defer response.Body.Close()
body, err := io.ReadAll(response.Body)
if err != nil {
return "", err
}
token := string(body)
if len(token) == 0 {
return "", fmt.Errorf("mmds token is an empty string")
}
return token, nil
}
func getMMDSOpts(ctx context.Context, client *http.Client, token string) (*MMDSOpts, error) {
request, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://"+mmdsDefaultAddress, &bytes.Buffer{})
if err != nil {
return nil, err
}
request.Header["X-metadata-token"] = []string{token}
request.Header["Accept"] = []string{"application/json"}
response, err := client.Do(request)
if err != nil {
return nil, err
}
defer response.Body.Close()
body, err := io.ReadAll(response.Body)
if err != nil {
return nil, err
}
var opts MMDSOpts
err = json.Unmarshal(body, &opts)
if err != nil {
return nil, err
}
return &opts, nil
}
// GetAccessTokenHashFromMMDS reads the access token hash from MMDS.
// This is used to validate that /init requests come from the orchestrator.
func GetAccessTokenHashFromMMDS(ctx context.Context) (string, error) {
token, err := getMMDSToken(ctx, mmdsAccessTokenClient)
if err != nil {
return "", fmt.Errorf("failed to get MMDS token: %w", err)
}
opts, err := getMMDSOpts(ctx, mmdsAccessTokenClient, token)
if err != nil {
return "", fmt.Errorf("failed to get MMDS opts: %w", err)
}
return opts.AccessTokenHash, nil
}
func PollForMMDSOpts(ctx context.Context, mmdsChan chan<- *MMDSOpts, envVars *utils.Map[string, string]) {
httpClient := &http.Client{}
defer httpClient.CloseIdleConnections()
ticker := time.NewTicker(50 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
fmt.Fprintf(os.Stderr, "context cancelled while waiting for mmds opts")
return
case <-ticker.C:
token, err := getMMDSToken(ctx, httpClient)
if err != nil {
fmt.Fprintf(os.Stderr, "error getting mmds token: %v\n", err)
continue
}
mmdsOpts, err := getMMDSOpts(ctx, httpClient, token)
if err != nil {
fmt.Fprintf(os.Stderr, "error getting mmds opts: %v\n", err)
continue
}
envVars.Store("WRENN_SANDBOX_ID", mmdsOpts.SandboxID)
envVars.Store("WRENN_TEMPLATE_ID", mmdsOpts.TemplateID)
if err := os.WriteFile(filepath.Join(WrennRunDir, ".WRENN_SANDBOX_ID"), []byte(mmdsOpts.SandboxID), 0o666); err != nil {
fmt.Fprintf(os.Stderr, "error writing sandbox ID file: %v\n", err)
}
if err := os.WriteFile(filepath.Join(WrennRunDir, ".WRENN_TEMPLATE_ID"), []byte(mmdsOpts.TemplateID), 0o666); err != nil {
fmt.Fprintf(os.Stderr, "error writing template ID file: %v\n", err)
}
if mmdsOpts.LogsCollectorAddress != "" {
mmdsChan <- mmdsOpts
}
return
}
}
}

View File

@ -1,49 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
package logs
import (
"time"
"github.com/rs/zerolog"
)
const (
defaultMaxBufferSize = 2 << 15
defaultTimeout = 2 * time.Second
)
func LogBufferedDataEvents(dataCh <-chan []byte, logger *zerolog.Logger, eventType string) {
timer := time.NewTicker(defaultTimeout)
defer timer.Stop()
var buffer []byte
defer func() {
if len(buffer) > 0 {
logger.Info().Str(eventType, string(buffer)).Msg("Streaming process event (flush)")
}
}()
for {
select {
case <-timer.C:
if len(buffer) > 0 {
logger.Info().Str(eventType, string(buffer)).Msg("Streaming process event")
buffer = nil
}
case data, ok := <-dataCh:
if !ok {
return
}
buffer = append(buffer, data...)
if len(buffer) >= defaultMaxBufferSize {
logger.Info().Str(eventType, string(buffer)).Msg("Streaming process event")
buffer = nil
continue
}
}
}
}

View File

@ -1,174 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
package exporter
import (
"bytes"
"context"
"fmt"
"log"
"net/http"
"os"
"sync"
"time"
"git.omukk.dev/wrenn/sandbox/envd/internal/host"
)
const ExporterTimeout = 10 * time.Second
type HTTPExporter struct {
client http.Client
logs [][]byte
isNotFC bool
mmdsOpts *host.MMDSOpts
// Concurrency coordination
triggers chan struct{}
logLock sync.RWMutex
mmdsLock sync.RWMutex
startOnce sync.Once
}
func NewHTTPLogsExporter(ctx context.Context, isNotFC bool, mmdsChan <-chan *host.MMDSOpts) *HTTPExporter {
exporter := &HTTPExporter{
client: http.Client{
Timeout: ExporterTimeout,
},
triggers: make(chan struct{}, 1),
isNotFC: isNotFC,
startOnce: sync.Once{},
mmdsOpts: &host.MMDSOpts{
SandboxID: "unknown",
TemplateID: "unknown",
LogsCollectorAddress: "",
},
}
go exporter.listenForMMDSOptsAndStart(ctx, mmdsChan)
return exporter
}
func (w *HTTPExporter) sendInstanceLogs(ctx context.Context, logs []byte, address string) error {
if address == "" {
return nil
}
request, err := http.NewRequestWithContext(ctx, http.MethodPost, address, bytes.NewBuffer(logs))
if err != nil {
return err
}
request.Header.Set("Content-Type", "application/json")
response, err := w.client.Do(request)
if err != nil {
return err
}
defer response.Body.Close()
return nil
}
func printLog(logs []byte) {
fmt.Fprintf(os.Stdout, "%v", string(logs))
}
func (w *HTTPExporter) listenForMMDSOptsAndStart(ctx context.Context, mmdsChan <-chan *host.MMDSOpts) {
for {
select {
case <-ctx.Done():
return
case mmdsOpts, ok := <-mmdsChan:
if !ok {
return
}
w.mmdsLock.Lock()
w.mmdsOpts.Update(mmdsOpts.SandboxID, mmdsOpts.TemplateID, mmdsOpts.LogsCollectorAddress)
w.mmdsLock.Unlock()
w.startOnce.Do(func() {
go w.start(ctx)
})
}
}
}
func (w *HTTPExporter) start(ctx context.Context) {
for range w.triggers {
logs := w.getAllLogs()
if len(logs) == 0 {
continue
}
if w.isNotFC {
for _, log := range logs {
fmt.Fprintf(os.Stdout, "%v", string(log))
}
continue
}
for _, logLine := range logs {
w.mmdsLock.RLock()
logLineWithOpts, err := w.mmdsOpts.AddOptsToJSON(logLine)
w.mmdsLock.RUnlock()
if err != nil {
log.Printf("error adding instance logging options (%+v) to JSON (%+v) with logs : %v\n", w.mmdsOpts, logLine, err)
printLog(logLine)
continue
}
err = w.sendInstanceLogs(ctx, logLineWithOpts, w.mmdsOpts.LogsCollectorAddress)
if err != nil {
log.Printf("error sending instance logs: %+v", err)
printLog(logLine)
continue
}
}
}
}
func (w *HTTPExporter) resumeProcessing() {
select {
case w.triggers <- struct{}{}:
default:
// Exporter processing already triggered
// This is expected behavior if the exporter is already processing logs
}
}
func (w *HTTPExporter) Write(logs []byte) (int, error) {
logsCopy := make([]byte, len(logs))
copy(logsCopy, logs)
go w.addLogs(logsCopy)
return len(logs), nil
}
func (w *HTTPExporter) getAllLogs() [][]byte {
w.logLock.Lock()
defer w.logLock.Unlock()
logs := w.logs
w.logs = nil
return logs
}
func (w *HTTPExporter) addLogs(logs []byte) {
w.logLock.Lock()
defer w.logLock.Unlock()
w.logs = append(w.logs, logs)
w.resumeProcessing()
}

View File

@ -1,174 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
package logs
import (
"context"
"fmt"
"strconv"
"strings"
"sync/atomic"
"connectrpc.com/connect"
"github.com/rs/zerolog"
)
type OperationID string
const (
OperationIDKey OperationID = "operation_id"
DefaultHTTPMethod string = "POST"
)
var operationID = atomic.Int32{}
func AssignOperationID() string {
id := operationID.Add(1)
return strconv.Itoa(int(id))
}
func AddRequestIDToContext(ctx context.Context) context.Context {
return context.WithValue(ctx, OperationIDKey, AssignOperationID())
}
func formatMethod(method string) string {
parts := strings.Split(method, ".")
if len(parts) < 2 {
return method
}
split := strings.Split(parts[1], "/")
if len(split) < 2 {
return method
}
servicePart := split[0]
servicePart = strings.ToUpper(servicePart[:1]) + servicePart[1:]
methodPart := split[1]
methodPart = strings.ToLower(methodPart[:1]) + methodPart[1:]
return fmt.Sprintf("%s %s", servicePart, methodPart)
}
func NewUnaryLogInterceptor(logger *zerolog.Logger) connect.UnaryInterceptorFunc {
interceptor := func(next connect.UnaryFunc) connect.UnaryFunc {
return connect.UnaryFunc(func(
ctx context.Context,
req connect.AnyRequest,
) (connect.AnyResponse, error) {
ctx = AddRequestIDToContext(ctx)
res, err := next(ctx, req)
l := logger.
Err(err).
Str("method", DefaultHTTPMethod+" "+req.Spec().Procedure).
Str(string(OperationIDKey), ctx.Value(OperationIDKey).(string))
if err != nil {
l = l.Int("error_code", int(connect.CodeOf(err)))
}
if req != nil {
l = l.Interface("request", req.Any())
}
if res != nil && err == nil {
l = l.Interface("response", res.Any())
}
if res == nil && err == nil {
l = l.Interface("response", nil)
}
l.Msg(formatMethod(req.Spec().Procedure))
return res, err
})
}
return connect.UnaryInterceptorFunc(interceptor)
}
func LogServerStreamWithoutEvents[T any, R any](
ctx context.Context,
logger *zerolog.Logger,
req *connect.Request[R],
stream *connect.ServerStream[T],
handler func(ctx context.Context, req *connect.Request[R], stream *connect.ServerStream[T]) error,
) error {
ctx = AddRequestIDToContext(ctx)
l := logger.Debug().
Str("method", DefaultHTTPMethod+" "+req.Spec().Procedure).
Str(string(OperationIDKey), ctx.Value(OperationIDKey).(string))
if req != nil {
l = l.Interface("request", req.Any())
}
l.Msg(fmt.Sprintf("%s (server stream start)", formatMethod(req.Spec().Procedure)))
err := handler(ctx, req, stream)
logEvent := getErrDebugLogEvent(logger, err).
Str("method", DefaultHTTPMethod+" "+req.Spec().Procedure).
Str(string(OperationIDKey), ctx.Value(OperationIDKey).(string))
if err != nil {
logEvent = logEvent.Int("error_code", int(connect.CodeOf(err)))
} else {
logEvent = logEvent.Interface("response", nil)
}
logEvent.Msg(fmt.Sprintf("%s (server stream end)", formatMethod(req.Spec().Procedure)))
return err
}
func LogClientStreamWithoutEvents[T any, R any](
ctx context.Context,
logger *zerolog.Logger,
stream *connect.ClientStream[T],
handler func(ctx context.Context, stream *connect.ClientStream[T]) (*connect.Response[R], error),
) (*connect.Response[R], error) {
ctx = AddRequestIDToContext(ctx)
logger.Debug().
Str("method", DefaultHTTPMethod+" "+stream.Spec().Procedure).
Str(string(OperationIDKey), ctx.Value(OperationIDKey).(string)).
Msg(fmt.Sprintf("%s (client stream start)", formatMethod(stream.Spec().Procedure)))
res, err := handler(ctx, stream)
logEvent := getErrDebugLogEvent(logger, err).
Str("method", DefaultHTTPMethod+" "+stream.Spec().Procedure).
Str(string(OperationIDKey), ctx.Value(OperationIDKey).(string))
if err != nil {
logEvent = logEvent.Int("error_code", int(connect.CodeOf(err)))
}
if res != nil && err == nil {
logEvent = logEvent.Interface("response", res.Any())
}
if res == nil && err == nil {
logEvent = logEvent.Interface("response", nil)
}
logEvent.Msg(fmt.Sprintf("%s (client stream end)", formatMethod(stream.Spec().Procedure)))
return res, err
}
// Return logger with error level if err is not nil, otherwise return logger with debug level
func getErrDebugLogEvent(logger *zerolog.Logger, err error) *zerolog.Event {
if err != nil {
return logger.Error().Err(err) //nolint:zerologlint // this builds an event, it is not expected to return it
}
return logger.Debug() //nolint:zerologlint // this builds an event, it is not expected to return it
}

View File

@ -1,37 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
package logs
import (
"context"
"io"
"os"
"time"
"github.com/rs/zerolog"
"git.omukk.dev/wrenn/sandbox/envd/internal/host"
"git.omukk.dev/wrenn/sandbox/envd/internal/logs/exporter"
)
func NewLogger(ctx context.Context, isNotFC bool, mmdsChan <-chan *host.MMDSOpts) *zerolog.Logger {
zerolog.TimestampFieldName = "timestamp"
zerolog.TimeFieldFormat = time.RFC3339Nano
exporters := []io.Writer{}
if isNotFC {
exporters = append(exporters, os.Stdout)
} else {
exporters = append(exporters, exporter.NewHTTPLogsExporter(ctx, isNotFC, mmdsChan), os.Stdout)
}
l := zerolog.
New(io.MultiWriter(exporters...)).
With().
Timestamp().
Logger().
Level(zerolog.DebugLevel)
return &l
}

View File

@ -1,49 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
package permissions
import (
"context"
"fmt"
"os/user"
"connectrpc.com/authn"
"connectrpc.com/connect"
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
)
func AuthenticateUsername(_ context.Context, req authn.Request) (any, error) {
username, _, ok := req.BasicAuth()
if !ok {
// When no username is provided, ignore the authentication method (not all endpoints require it)
// Missing user is then handled in the GetAuthUser function
return nil, nil
}
u, err := GetUser(username)
if err != nil {
return nil, authn.Errorf("invalid username: '%s'", username)
}
return u, nil
}
func GetAuthUser(ctx context.Context, defaultUser string) (*user.User, error) {
u, ok := authn.GetInfo(ctx).(*user.User)
if !ok {
username, err := execcontext.ResolveDefaultUsername(nil, defaultUser)
if err != nil {
return nil, connect.NewError(connect.CodeUnauthenticated, fmt.Errorf("no user specified"))
}
u, err := GetUser(username)
if err != nil {
return nil, authn.Errorf("invalid default user: '%s'", username)
}
return u, nil
}
return u, nil
}

Some files were not shown because too many files have changed in this diff Show More