Compare commits
296 Commits
v0.1.1
...
feat/migra
| Author | SHA1 | Date | |
|---|---|---|---|
| 62bede5dae | |||
| 74f85ce4e9 | |||
| 124e097e23 | |||
| a5425969ed | |||
| fb16bc9ed1 | |||
| dd8a940431 | |||
| eaa6b8576d | |||
| c2dc382787 | |||
| 3671af2498 | |||
| e34bcedc31 | |||
| ff91ef3edf | |||
| ba3a3db98c | |||
| 6faad45a28 | |||
| c08884fa2c | |||
| 4707f16c76 | |||
| 6164d7cae3 | |||
| dc6776cc8f | |||
| 0bfda08f47 | |||
| 485be22a16 | |||
| ead406bdac | |||
| 1472d77b52 | |||
| 6a0fea30a6 | |||
| 8c34388fc2 | |||
| aca43d51eb | |||
| 522e1c5e90 | |||
| d1d316f35c | |||
| 2af8412cdc | |||
| c93ad5e2db | |||
| 38799770db | |||
| 51b5d7b3ba | |||
| fd5fa28205 | |||
| 1244c08e42 | |||
| 021d709de2 | |||
| cac6fcd626 | |||
| 4954b19d7c | |||
| 01819642cc | |||
| cb28f7759d | |||
| 1178ab8b21 | |||
| 233e747d5d | |||
| f5a23c1fa0 | |||
| 20a228eb8d | |||
| ef5f223863 | |||
| 31456fd169 | |||
| bbcde17d49 | |||
| f328113a2a | |||
| 1143acd37a | |||
| 0b53d34417 | |||
| 3deecbff89 | |||
| bb582deefa | |||
| 7ef9a64613 | |||
| f3572f7356 | |||
| 2e998a26a2 | |||
| 4fcc19e91f | |||
| f3ec626d58 | |||
| f4733e2f7a | |||
| cdacc12a48 | |||
| bd98610153 | |||
| 5e13879954 | |||
| 339cd7bee1 | |||
| 153a54fdcd | |||
| 52ad21c339 | |||
| c3afd0c8a0 | |||
| 11928a172a | |||
| bb2146d838 | |||
| d270ab7752 | |||
| 7fd801c1eb | |||
| edec170652 | |||
| 684c98b0fa | |||
| ebbbde9cd1 | |||
| 6a6b489471 | |||
| dbc6030c17 | |||
| 23dca7d9ff | |||
| 9ee6e3e1a8 | |||
| aa96557d1c | |||
| 47be1143fb | |||
| 8f8638e6db | |||
| 003453fa3c | |||
| 92aab09104 | |||
| e7670e4449 | |||
| 955aa09780 | |||
| ce452c3d11 | |||
| ab034062d3 | |||
| 24f904fa74 | |||
| cc63ed2197 | |||
| 9c4fea93bc | |||
| 977c3a466a | |||
| e6e3975426 | |||
| bba5f80294 | |||
| 44c32587e3 | |||
| b9aa444472 | |||
| fb4b67adb3 | |||
| 9ea847923c | |||
| ed2222c80c | |||
| e91109d69c | |||
| 451d0819cc | |||
| 084c6caa7d | |||
| 43e838c55c | |||
| e1b23f3d79 | |||
| a3f75300a9 | |||
| e8a2217247 | |||
| 93e6fe8160 | |||
| f69fa8cded | |||
| bc8348b199 | |||
| 81715947bb | |||
| d705f83b68 | |||
| 2f0e7fcdc2 | |||
| 970ae2b6b2 | |||
| ded9c15f06 | |||
| 9d68eb5f00 | |||
| 700512b627 | |||
| d1975089f1 | |||
| a5ad3731f2 | |||
| 11d746dcfc | |||
| 5f877afb9e | |||
| 5b4fde055c | |||
| 59507d7553 | |||
| a265c15c4d | |||
| d332630267 | |||
| 587f6ed8ad | |||
| 82d281b5b5 | |||
| 17d5d07b3a | |||
| 71b87020c9 | |||
| 516890c49a | |||
| 962860ba74 | |||
| 117c46a386 | |||
| d828a6be08 | |||
| bbdb44afee | |||
| 784fe5c7a8 | |||
| 60c0de670c | |||
| 90bea52ccd | |||
| f920023ecf | |||
| 19ddb1ab8b | |||
| 5633957b51 | |||
| eb47e22496 | |||
| b1595baa19 | |||
| da06ecb97b | |||
| 0d5007089e | |||
| 0e7b198768 | |||
| 9ad704c12b | |||
| 0189d030bb | |||
| 7b853a05ba | |||
| 108b68c3fa | |||
| 565817273d | |||
| ea65fb584c | |||
| 25b5258841 | |||
| 46c43b95c2 | |||
| 000318f77e | |||
| f5eeb0ffcc | |||
| 75af2a4f66 | |||
| f6c3dc0801 | |||
| f5a9a1209f | |||
| 8d0356e372 | |||
| c3c9ced9dd | |||
| 7d0a21644f | |||
| 26917d432d | |||
| 430fb9e70e | |||
| 0807946d45 | |||
| 11ca6935a6 | |||
| e2f869bfc2 | |||
| 21b82c2283 | |||
| dbad418093 | |||
| 2bad843069 | |||
| 9332f4ac18 | |||
| cf191ca821 | |||
| d2202c4f49 | |||
| 1826af37a5 | |||
| acc721526d | |||
| 4b2ff279f7 | |||
| ab3fc4a807 | |||
| 09f030d202 | |||
| 43c15c86de | |||
| 851f54a9e1 | |||
| 4ed17b2776 | |||
| 0e6daaabe0 | |||
| 82531b735c | |||
| c9283cac70 | |||
| c1987b0bda | |||
| 2b31af8fde | |||
| 831c898b71 | |||
| 0f78982186 | |||
| 84dd15d22b | |||
| 5148b5dd64 | |||
| 37d85ec998 | |||
| e2beef817d | |||
| a9ca13b238 | |||
| e3ffa576ce | |||
| dd50cfdcb1 | |||
| 3675ecba65 | |||
| c8615466be | |||
| 2737288a2b | |||
| 0ea0e7cc70 | |||
| 11e08e5b96 | |||
| 4dc8cc3867 | |||
| 9852f96127 | |||
| bf05677bef | |||
| 4f340b8847 | |||
| f57fe85492 | |||
| 9a52b47786 | |||
| ab38c8372c | |||
| 8b5fa3438e | |||
| 2b4c5e0176 | |||
| 377e856c8f | |||
| 948db13bed | |||
| 25ce0729d5 | |||
| 88f919c4ca | |||
| 8f06fc554a | |||
| 1ca10230a9 | |||
| 46d60fc5a5 | |||
| 906cc42d13 | |||
| 75b28ed899 | |||
| 03e96629c7 | |||
| 34af77e0d8 | |||
| c89a664a37 | |||
| 3509ca90e8 | |||
| c8acac92cc | |||
| 5cb37bf2a0 | |||
| c0d6381bbe | |||
| 4ddd494160 | |||
| cdd89a7cee | |||
| 1ce62934b3 | |||
| 6898528096 | |||
| 12d1e356fa | |||
| 139f86bf9c | |||
| b0a8b498a8 | |||
| 4be65b0abb | |||
| f4675ebfc0 | |||
| 602ee470d9 | |||
| 8cdf91d895 | |||
| ed7880bc6c | |||
| 27ff828e60 | |||
| 6eacf0f735 | |||
| 88cb24bb86 | |||
| 49b0b646a8 | |||
| 9acdbb5ae9 | |||
| 7473c15f52 | |||
| 8d5ba3873a | |||
| b0e6f5ffb3 | |||
| a69b0f579c | |||
| 45793e181c | |||
| e3750f79f9 | |||
| 930da8a578 | |||
| 47b0ed5b52 | |||
| fee66bda50 | |||
| 2349f585ae | |||
| d4eb24be7e | |||
| 0414fbe733 | |||
| 6b76abe38e | |||
| 3ce8fdcb02 | |||
| 1be30034bd | |||
| 9878156798 | |||
| e069b3e679 | |||
| 9bf67aa7f7 | |||
| f968da9768 | |||
| 3932bc056e | |||
| aaeccd32ce | |||
| 915d934c26 | |||
| 336080bb6d | |||
| 90c296f5e1 | |||
| bf494f73fc | |||
| 71a7fdb76f | |||
| b3e8bdd171 | |||
| 1e681da738 | |||
| 8e5d426638 | |||
| 4e26d7a292 | |||
| 79eba782fb | |||
| b786a825d4 | |||
| 71564b202e | |||
| 5f0dbadea6 | |||
| 36782e1b4f | |||
| 97292ba0bf | |||
| 866f3ac012 | |||
| 2c66959b92 | |||
| e4ead076e3 | |||
| 1d59b50e49 | |||
| f38d5812d1 | |||
| 931b7d54b3 | |||
| 477d4f8cf6 | |||
| 88246fac2b | |||
| 1846168736 | |||
| c92cc29b88 | |||
| 712b77b01c | |||
| 80a99eec87 | |||
| a0d635ae5e | |||
| 63e9132d38 | |||
| 778894b488 | |||
| a1bd439c75 | |||
| 9b94df7f56 | |||
| 0c245e9e1c | |||
| b4d8edb65b | |||
| ec3360d9ad | |||
| d7b25b0891 | |||
| 34c89e814d | |||
| 6f0c365d44 | |||
| c31ce90306 | |||
| 7753938044 | |||
| a3898d68fb |
@ -16,7 +16,7 @@ WRENN_HOST_LISTEN_ADDR=:50051
|
||||
WRENN_HOST_INTERFACE=eth0
|
||||
WRENN_CP_URL=http://localhost:9725
|
||||
WRENN_DEFAULT_ROOTFS_SIZE=5Gi
|
||||
WRENN_FIRECRACKER_BIN=/usr/local/bin/firecracker
|
||||
WRENN_CH_BIN=/usr/local/bin/cloud-hypervisor
|
||||
|
||||
# Auth
|
||||
JWT_SECRET=
|
||||
|
||||
7
.gitignore
vendored
7
.gitignore
vendored
@ -36,10 +36,14 @@ go.work.sum
|
||||
e2b/
|
||||
.impeccable.md
|
||||
.gstack
|
||||
.mcp.json
|
||||
|
||||
## Builds
|
||||
builds/
|
||||
|
||||
## Rust
|
||||
envd-rs/target/
|
||||
|
||||
## Frontend
|
||||
frontend/node_modules/
|
||||
frontend/.svelte-kit/
|
||||
@ -49,3 +53,6 @@ frontend/build/
|
||||
internal/dashboard/static/*
|
||||
!internal/dashboard/static/.gitkeep.dual-graph/
|
||||
.dual-graph/
|
||||
# Added by code-review-graph
|
||||
.code-review-graph/
|
||||
.mcp.json
|
||||
|
||||
116
CLAUDE.md
116
CLAUDE.md
@ -4,7 +4,7 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
|
||||
|
||||
## Project Overview
|
||||
|
||||
Wrenn Sandbox is a microVM-based code execution platform. Users create isolated sandboxes (Firecracker microVMs), run code inside them, and get output back via SDKs. Think E2B but with persistent sandboxes, pool-based pricing, and a single-binary deployment story.
|
||||
Wrenn Sandbox is a microVM-based code execution platform. Users create isolated sandboxes (Cloud Hypervisor microVMs), run code inside them, and get output back via SDKs. Think E2B but with persistent sandboxes, pool-based pricing, and a single-binary deployment story.
|
||||
|
||||
## Build & Development Commands
|
||||
|
||||
@ -14,7 +14,7 @@ All commands go through the Makefile. Never use raw `go build` or `go run`.
|
||||
make build # Build all binaries → builds/
|
||||
make build-cp # Control plane only
|
||||
make build-agent # Host agent only
|
||||
make build-envd # envd static binary (verified statically linked)
|
||||
make build-envd # envd static binary (Rust, musl, verified statically linked)
|
||||
make build-frontend # SvelteKit dashboard → frontend/build/ (served by Caddy)
|
||||
|
||||
make dev # Full local dev: infra + migrate + control plane
|
||||
@ -23,13 +23,13 @@ make dev-down # Stop dev infra
|
||||
make dev-cp # Control plane with hot reload (if air installed)
|
||||
make dev-frontend # Vite dev server with HMR (port 5173)
|
||||
make dev-agent # Host agent (sudo required)
|
||||
make dev-envd # envd in TCP debug mode
|
||||
make dev-envd # envd in debug mode (port 49983)
|
||||
|
||||
make check # fmt + vet + lint + test (CI order)
|
||||
make test # Unit tests: go test -race -v ./internal/...
|
||||
make test-integration # Integration tests (require host agent + Firecracker)
|
||||
make fmt # gofmt both modules
|
||||
make vet # go vet both modules
|
||||
make test-integration # Integration tests (require host agent + Cloud Hypervisor)
|
||||
make fmt # gofmt
|
||||
make vet # go vet
|
||||
make lint # golangci-lint
|
||||
|
||||
make migrate-up # Apply pending migrations
|
||||
@ -38,8 +38,8 @@ make migrate-create name=xxx # Scaffold new goose migration (never create manua
|
||||
make migrate-reset # Drop + re-apply all
|
||||
|
||||
make generate # Proto (buf) + sqlc codegen
|
||||
make proto # buf generate for all proto dirs
|
||||
make tidy # go mod tidy both modules
|
||||
make proto # buf generate for proto dirs
|
||||
make tidy # go mod tidy
|
||||
```
|
||||
|
||||
Run a single test: `go test -race -v -run TestName ./internal/path/...`
|
||||
@ -50,15 +50,15 @@ Run a single test: `go test -race -v -run TestName ./internal/path/...`
|
||||
User SDK → HTTPS/WS → Control Plane → Connect RPC → Host Agent → HTTP/Connect RPC over TAP → envd (inside VM)
|
||||
```
|
||||
|
||||
**Three binaries, two Go modules:**
|
||||
**Three binaries:**
|
||||
|
||||
| Binary | Module | Entry point | Runs as |
|
||||
|--------|--------|-------------|---------|
|
||||
| wrenn-cp | `git.omukk.dev/wrenn/wrenn` | `cmd/control-plane/main.go` | Unprivileged |
|
||||
| wrenn-agent | `git.omukk.dev/wrenn/wrenn` | `cmd/host-agent/main.go` | `wrenn` user with capabilities (SYS_ADMIN, NET_ADMIN, NET_RAW, SYS_PTRACE, KILL, DAC_OVERRIDE, MKNOD) via setcap; also accepts root |
|
||||
| envd | `git.omukk.dev/wrenn/wrenn/envd` (standalone `envd/go.mod`) | `envd/main.go` | PID 1 inside guest VM |
|
||||
| Binary | Language | Entry point | Runs as |
|
||||
|--------|----------|-------------|---------|
|
||||
| wrenn-cp | Go (`git.omukk.dev/wrenn/wrenn`) | `cmd/control-plane/main.go` | Unprivileged |
|
||||
| wrenn-agent | Go (`git.omukk.dev/wrenn/wrenn`) | `cmd/host-agent/main.go` | `wrenn` user with capabilities (SYS_ADMIN, NET_ADMIN, NET_RAW, SYS_PTRACE, KILL, DAC_OVERRIDE, MKNOD) via setcap; also accepts root |
|
||||
| envd | Rust (`envd-rs/`) | `envd-rs/src/main.rs` | PID 1 inside guest VM |
|
||||
|
||||
envd is a **completely independent Go module**. It is never imported by the main module. The only connection is the protobuf contract. It compiles to a static binary baked into rootfs images.
|
||||
envd is a standalone Rust binary (Tokio + Axum + connectrpc-rs). It is completely independent from the Go module — the only connection is the protobuf contract. It compiles to a statically linked musl binary baked into rootfs images.
|
||||
|
||||
**Key architectural invariant:** The host agent is **stateful** (in-memory `boxes` map is the source of truth for running VMs). The control plane is **stateless** (all persistent state in PostgreSQL). The reconciler (`internal/api/reconciler.go`) bridges the gap — it periodically compares DB records against the host agent's live state and marks orphaned sandboxes as "stopped".
|
||||
|
||||
@ -92,27 +92,31 @@ Startup (`cmd/host-agent/main.go`) wires: root/capabilities check → enable IP
|
||||
|
||||
- **RPC Server** (`internal/hostagent/server.go`): implements `hostagentv1connect.HostAgentServiceHandler`. Thin wrapper — every method delegates to `sandbox.Manager`. Maps Connect error codes on return.
|
||||
- **Sandbox Manager** (`internal/sandbox/manager.go`): the core orchestration layer. Maintains in-memory state in `boxes map[string]*sandboxState` (protected by `sync.RWMutex`). Each `sandboxState` holds a `models.Sandbox`, a `*network.Slot`, and an `*envdclient.Client`. Runs a TTL reaper (every 10s) that auto-destroys timed-out sandboxes.
|
||||
- **VM Manager** (`internal/vm/manager.go`, `fc.go`, `config.go`): manages Firecracker processes. Uses raw HTTP API over Unix socket (`/tmp/fc-{sandboxID}.sock`), not the firecracker-go-sdk Machine type. Launches Firecracker via `unshare -m` + `ip netns exec`. Configures VM via PUT to `/boot-source`, `/drives/rootfs`, `/network-interfaces/eth0`, `/machine-config`, then starts with PUT `/actions`.
|
||||
- **VM Manager** (`internal/vm/manager.go`, `ch.go`, `config.go`): manages Cloud Hypervisor processes. Uses raw HTTP API over Unix socket (`/tmp/ch-{sandboxID}.sock`). Launches Cloud Hypervisor via `unshare -m` + `ip netns exec` with `--api-socket path=...`. Configures and boots VM via `PUT /vm.create` + `PUT /vm.boot`. Snapshot restore uses `--restore source_url=file://...`.
|
||||
- **Network** (`internal/network/setup.go`, `allocator.go`): per-sandbox network namespace with veth pair + TAP device. See Networking section below.
|
||||
- **Device Mapper** (`internal/devicemapper/devicemapper.go`): CoW rootfs via device-mapper snapshots. Shared read-only loop devices per base template (refcounted `LoopRegistry`), per-sandbox sparse CoW files, dm-snapshot create/restore/remove/flatten operations.
|
||||
- **envd Client** (`internal/envdclient/client.go`, `health.go`): dual interface to the guest agent. Connect RPC for streaming process exec (`process.Start()` bidirectional stream). Plain HTTP for file operations (POST/GET `/files?path=...&username=root`). Health check polls `GET /health` every 100ms until ready (30s timeout).
|
||||
|
||||
### envd (Guest Agent)
|
||||
|
||||
**Module:** `envd/` with its own `go.mod` (`git.omukk.dev/wrenn/wrenn/envd`)
|
||||
**Directory:** `envd-rs/` — standalone Rust crate
|
||||
|
||||
Runs as PID 1 inside the microVM via `wrenn-init.sh` (mounts procfs/sysfs/dev, sets hostname, writes resolv.conf, then execs envd). Extracted from E2B (Apache 2.0), with shared packages internalized into `envd/internal/shared/`. Listens on TCP `0.0.0.0:49983`.
|
||||
Runs as PID 1 inside the microVM via `wrenn-init.sh` (mounts procfs/sysfs/dev, sets hostname, writes resolv.conf, then execs envd via tini). Built with `cargo build --release --target x86_64-unknown-linux-musl`. Listens on TCP `0.0.0.0:49983`.
|
||||
|
||||
- **ProcessService**: start processes, stream stdout/stderr, signal handling, PTY support
|
||||
- **FilesystemService**: stat/list/mkdir/move/remove/watch files
|
||||
- **Health**: GET `/health`
|
||||
- **Stack**: Tokio (async runtime) + Axum (HTTP) + connectrpc-rs (Connect protocol RPC)
|
||||
- **ProcessService** (Connect RPC): start/connect/list/signal processes, stream stdout/stderr, PTY support
|
||||
- **FilesystemService** (Connect RPC): stat/list/mkdir/move/remove/watch files
|
||||
- **HTTP endpoints**: GET `/health`, GET `/metrics`, POST `/init`, POST `/snapshot/prepare`, GET/POST `/files`
|
||||
- **Proto codegen**: `connectrpc-build` compiles `proto/envd/*.proto` at `cargo build` time via `build.rs` — no committed stubs
|
||||
- **Build**: `make build-envd` → static musl binary in `builds/envd`
|
||||
- **Dev**: `make dev-envd` → `cargo run -- --port 49983`
|
||||
|
||||
### Dashboard (Frontend)
|
||||
|
||||
**Directory:** `frontend/` — standalone SvelteKit app (Svelte 5, runes mode)
|
||||
|
||||
- **Stack**: SvelteKit + `adapter-static` + Tailwind CSS v4 + Bits UI (headless accessible components)
|
||||
- **Package manager**: pnpm
|
||||
- **Package manager**: Bun
|
||||
- **Routing**: SvelteKit file-based routing under `frontend/src/routes/`
|
||||
- **Routing layout**: `/login` and `/signup` at root, authenticated pages under `/dashboard/*` (e.g. `/dashboard/capsules`, `/dashboard/keys`)
|
||||
- **Build output**: `frontend/build/` — static files served by Caddy
|
||||
@ -160,7 +164,7 @@ HIBERNATED → RUNNING (cold snapshot resume, slower)
|
||||
**Sandbox creation** (`POST /v1/capsules`):
|
||||
1. API handler generates sandbox ID, inserts into DB as "pending"
|
||||
2. RPC `CreateSandbox` → host agent → `sandbox.Manager.Create()`
|
||||
3. Manager: resolve base rootfs → acquire shared loop device → create dm-snapshot (sparse CoW file) → allocate network slot → `CreateNetwork()` (netns + veth + tap + NAT) → `vm.Create()` (start Firecracker with `/dev/mapper/wrenn-{id}`, configure via HTTP API, boot) → `envdclient.WaitUntilReady()` (poll /health) → store in-memory state
|
||||
3. Manager: resolve base rootfs → acquire shared loop device → create dm-snapshot (sparse CoW file) → allocate network slot → `CreateNetwork()` (netns + veth + tap + NAT) → `vm.Create()` (start Cloud Hypervisor with `/dev/mapper/wrenn-{id}`, configure via `PUT /vm.create` + `PUT /vm.boot`) → `envdclient.WaitUntilReady()` (poll /health) → store in-memory state
|
||||
4. API handler updates DB to "running" with host_ip
|
||||
|
||||
**Command execution** (`POST /v1/capsules/{id}/exec`):
|
||||
@ -185,17 +189,16 @@ Routes defined in `internal/api/server.go`, handlers in `internal/api/handlers_*
|
||||
|
||||
### Proto (Connect RPC)
|
||||
|
||||
Proto source of truth is `proto/envd/*.proto` and `proto/hostagent/*.proto`. Run `make proto` to regenerate. Three `buf.gen.yaml` files control output:
|
||||
Proto source of truth is `proto/envd/*.proto` and `proto/hostagent/*.proto`. Run `make proto` to regenerate Go stubs. Two `buf.gen.yaml` files control Go output:
|
||||
|
||||
| buf.gen.yaml location | Generates to | Used by |
|
||||
|---|---|---|
|
||||
| `proto/envd/buf.gen.yaml` | `proto/envd/gen/` | Main module (host agent's envd client) |
|
||||
| `proto/hostagent/buf.gen.yaml` | `proto/hostagent/gen/` | Main module (control plane ↔ host agent) |
|
||||
| `envd/spec/buf.gen.yaml` | `envd/internal/services/spec/` | envd module (guest agent server) |
|
||||
|
||||
The envd `buf.gen.yaml` reads from `../../proto/envd/` (same source protos) but generates into envd's own module. This means the same `.proto` files produce two independent sets of Go stubs — one for each Go module.
|
||||
The Rust envd (`envd-rs/`) generates its own protobuf stubs at `cargo build` time via `connectrpc-build` in `envd-rs/build.rs`, reading from the same `proto/envd/*.proto` sources. No committed Rust stubs — they live in `OUT_DIR`.
|
||||
|
||||
To add a new RPC method: edit the `.proto` file → `make proto` → implement the handler on both sides.
|
||||
To add a new RPC method: edit the `.proto` file → `make proto` (Go stubs) → rebuild envd-rs (Rust stubs generated automatically) → implement the handler on both sides.
|
||||
|
||||
### sqlc
|
||||
|
||||
@ -206,10 +209,10 @@ To add a new query: add it to the appropriate `.sql` file in `db/queries/` → `
|
||||
## Key Technical Decisions
|
||||
|
||||
- **Connect RPC** (not gRPC) for all RPC communication between components
|
||||
- **Buf + protoc-gen-connect-go** for code generation (not protoc-gen-go-grpc)
|
||||
- **Raw Firecracker HTTP API** via Unix socket (not firecracker-go-sdk Machine type)
|
||||
- **Buf + protoc-gen-connect-go** for Go code generation; **connectrpc-build** for Rust code generation in envd
|
||||
- **Raw Cloud Hypervisor HTTP API** via Unix socket (`PUT /vm.create` + `PUT /vm.boot`)
|
||||
- **TAP networking** (not vsock) for host-to-envd communication
|
||||
- **Device-mapper snapshots** for rootfs CoW — shared read-only loop device per base template, per-sandbox sparse CoW file, Firecracker gets `/dev/mapper/wrenn-{id}`
|
||||
- **Device-mapper snapshots** for rootfs CoW — shared read-only loop device per base template, per-sandbox sparse CoW file, Cloud Hypervisor gets `/dev/mapper/wrenn-{id}`
|
||||
- **PostgreSQL** via pgx/v5 + sqlc (type-safe query generation). Goose for migrations (plain SQL, up/down)
|
||||
- **Dashboard**: SvelteKit (Svelte 5, adapter-static) + Tailwind CSS v4 + Bits UI. Built to static files in `frontend/build/`, served by Caddy (not embedded in the Go binary)
|
||||
- **Lago** for billing (external service, not in this codebase)
|
||||
@ -218,19 +221,15 @@ To add a new query: add it to the appropriate `.sql` file in `db/queries/` → `
|
||||
|
||||
- **Go style**: `gofmt`, `go vet`, `context.Context` everywhere, errors wrapped with `fmt.Errorf("action: %w", err)`, `slog` for logging, no global state
|
||||
- **Naming**: Sandbox IDs `sb-` + 8 hex, API keys `wrn_` + 32 chars, Host IDs `host-` + 8 hex
|
||||
- **Dependencies**: Use `go get` to add deps, never hand-edit go.mod. For envd deps: `cd envd && go get ...` (separate module)
|
||||
- **Dependencies**: Use `go get` to add Go deps, never hand-edit go.mod. For envd-rs deps: edit `envd-rs/Cargo.toml`
|
||||
- **Generated code**: Always commit generated code (proto stubs, sqlc). Never add generated code to .gitignore
|
||||
- **Migrations**: Always use `make migrate-create name=xxx`, never create migration files manually
|
||||
- **Testing**: Table-driven tests for handlers and state machine transitions
|
||||
|
||||
### Two-module gotcha
|
||||
|
||||
The main module (`go.mod`) and envd (`envd/go.mod`) are fully independent. `make tidy`, `make fmt`, `make vet` already operate on both. But when adding dependencies manually, remember to target the correct module (`cd envd && go get ...` for envd deps). `make proto` also generates stubs for both modules from the same proto sources.
|
||||
|
||||
## Rootfs & Guest Init
|
||||
|
||||
- **wrenn-init** (`images/wrenn-init.sh`): the PID 1 init script baked into every rootfs. Mounts virtual filesystems, sets hostname, writes `/etc/resolv.conf`, then execs envd.
|
||||
- **Updating the rootfs** after changing envd or wrenn-init: `bash scripts/update-debug-rootfs.sh [rootfs_path]`. This builds envd via `make build-envd`, mounts the rootfs image, copies in the new binaries, and unmounts. Defaults to `/var/lib/wrenn/images/minimal.ext4`.
|
||||
- **Updating the rootfs** after changing envd or wrenn-init: `bash scripts/update-minimal-rootfs.sh`. This builds envd via `make build-envd` (Rust → static musl binary), mounts the rootfs image, copies in the new binaries, and unmounts. Defaults to `/var/lib/wrenn/images/minimal.ext4`.
|
||||
- Rootfs images are minimal debootstrap — no systemd, no coreutils beyond busybox. Use `/bin/sh -c` for shell builtins inside the guest.
|
||||
|
||||
## Fixed Paths (on host machine)
|
||||
@ -238,19 +237,19 @@ The main module (`go.mod`) and envd (`envd/go.mod`) are fully independent. `make
|
||||
- Kernel: `/var/lib/wrenn/kernels/vmlinux`
|
||||
- Base rootfs images: `/var/lib/wrenn/images/{template}.ext4`
|
||||
- Sandbox clones: `/var/lib/wrenn/sandboxes/`
|
||||
- Firecracker: `/usr/local/bin/firecracker` (e2b's fork of firecracker)
|
||||
- Cloud Hypervisor: `/usr/local/bin/cloud-hypervisor`
|
||||
|
||||
## Design Context
|
||||
|
||||
### Users
|
||||
Developers across the full spectrum — solo engineers building side projects, startup teams integrating sandboxed execution into products, and platform/infra engineers at larger organizations running production workloads on Firecracker microVMs. They arrive with context: they know what a process is, what a rootfs is, what a TTY means. The interface must feel at home for all three: approachable enough not to intimidate a hacker, precise enough to earn the trust of a production ops team. Never condescend, never oversimplify. Trust the user to understand what they're looking at.
|
||||
Developers across the full spectrum — solo engineers building side projects, startup teams integrating sandboxed execution into products, and platform/infra engineers at larger organizations running production workloads on Cloud Hypervisor microVMs. They arrive with context: they know what a process is, what a rootfs is, what a TTY means. The interface must feel at home for all three: approachable enough not to intimidate a hacker, precise enough to earn the trust of a production ops team. Never condescend, never oversimplify. Trust the user to understand what they're looking at.
|
||||
|
||||
**Primary job to be done:** Understand what's running, act on it confidently, and get back to code.
|
||||
|
||||
### Brand Personality
|
||||
**Precise. Warm. Uncompromising.**
|
||||
|
||||
Wrenn is an engineer's favorite tool — built with visible care, not assembled from defaults. It runs real infrastructure (Firecracker microVMs), so the UI should reflect that seriousness without becoming cold or corporate. The warmth comes from the typography and color palette; the precision comes from hierarchy, density, and data fidelity.
|
||||
Wrenn is an engineer's favorite tool — built with visible care, not assembled from defaults. It runs real infrastructure (Cloud Hypervisor microVMs), so the UI should reflect that seriousness without becoming cold or corporate. The warmth comes from the typography and color palette; the precision comes from hierarchy, density, and data fidelity.
|
||||
|
||||
Emotional goal: **in control.** Users leave a session with full confidence in what's running, what happened, and what comes next. Nothing is hidden, nothing is ambiguous.
|
||||
|
||||
@ -372,3 +371,42 @@ All values are CSS custom properties in `frontend/src/app.css`.
|
||||
4. **Legible at speed.** Users scan dashboards in seconds. Strong typographic contrast (serif h1, mono IDs, sans body), consistent patterns, and predictable placement let users orientate instantly without reading everything.
|
||||
|
||||
5. **Craft signals trust.** For infrastructure that runs production code, the quality of the UI is a proxy for the quality of the product. Pixel-level decisions matter. Polish is not decoration — it's a trust signal.
|
||||
|
||||
<!-- code-review-graph MCP tools -->
|
||||
## MCP Tools: code-review-graph
|
||||
|
||||
**IMPORTANT: This project has a knowledge graph. ALWAYS use the
|
||||
code-review-graph MCP tools BEFORE using Grep/Glob/Read to explore
|
||||
the codebase.** The graph is faster, cheaper (fewer tokens), and gives
|
||||
you structural context (callers, dependents, test coverage) that file
|
||||
scanning cannot.
|
||||
|
||||
### When to use graph tools FIRST
|
||||
|
||||
- **Exploring code**: `semantic_search_nodes` or `query_graph` instead of Grep
|
||||
- **Understanding impact**: `get_impact_radius` instead of manually tracing imports
|
||||
- **Code review**: `detect_changes` + `get_review_context` instead of reading entire files
|
||||
- **Finding relationships**: `query_graph` with callers_of/callees_of/imports_of/tests_for
|
||||
- **Architecture questions**: `get_architecture_overview` + `list_communities`
|
||||
|
||||
Fall back to Grep/Glob/Read **only** when the graph doesn't cover what you need.
|
||||
|
||||
### Key Tools
|
||||
|
||||
| Tool | Use when |
|
||||
|------|----------|
|
||||
| `detect_changes` | Reviewing code changes — gives risk-scored analysis |
|
||||
| `get_review_context` | Need source snippets for review — token-efficient |
|
||||
| `get_impact_radius` | Understanding blast radius of a change |
|
||||
| `get_affected_flows` | Finding which execution paths are impacted |
|
||||
| `query_graph` | Tracing callers, callees, imports, tests, dependencies |
|
||||
| `semantic_search_nodes` | Finding functions/classes by name or keyword |
|
||||
| `get_architecture_overview` | Understanding high-level codebase structure |
|
||||
| `refactor_tool` | Planning renames, finding dead code |
|
||||
|
||||
### Workflow
|
||||
|
||||
1. The graph auto-updates on file changes (via hooks).
|
||||
2. Use `detect_changes` for code review.
|
||||
3. Use `get_affected_flows` to understand impact.
|
||||
4. Use `query_graph` pattern="tests_for" to check coverage.
|
||||
|
||||
42
Makefile
42
Makefile
@ -2,12 +2,10 @@
|
||||
# Variables
|
||||
# ═══════════════════════════════════════════════════
|
||||
DATABASE_URL ?= postgres://wrenn:wrenn@localhost:5432/wrenn?sslmode=disable
|
||||
GOBIN := $(shell pwd)/builds
|
||||
ENVD_DIR := envd
|
||||
BIN_DIR := $(shell pwd)/builds
|
||||
COMMIT := $(shell git rev-parse --short HEAD 2>/dev/null || echo "unknown")
|
||||
VERSION_CP := $(shell cat VERSION_CP 2>/dev/null | tr -d '[:space:]' || echo "0.0.0-dev")
|
||||
VERSION_AGENT := $(shell cat VERSION_AGENT 2>/dev/null | tr -d '[:space:]' || echo "0.0.0-dev")
|
||||
VERSION_ENVD := $(shell cat envd/VERSION 2>/dev/null | tr -d '[:space:]' || echo "0.0.0-dev")
|
||||
LDFLAGS := -s -w
|
||||
|
||||
# ═══════════════════════════════════════════════════
|
||||
@ -18,19 +16,23 @@ LDFLAGS := -s -w
|
||||
build: build-cp build-agent build-envd
|
||||
|
||||
build-frontend:
|
||||
cd frontend && pnpm install --frozen-lockfile && pnpm build
|
||||
cd frontend && bun install --frozen-lockfile && bun run build
|
||||
|
||||
build-cp:
|
||||
go build -v -ldflags="$(LDFLAGS) -X main.version=$(VERSION_CP) -X main.commit=$(COMMIT)" -o $(GOBIN)/wrenn-cp ./cmd/control-plane
|
||||
go build -v -ldflags="$(LDFLAGS) -X main.version=$(VERSION_CP) -X main.commit=$(COMMIT)" -o $(BIN_DIR)/wrenn-cp ./cmd/control-plane
|
||||
|
||||
build-agent:
|
||||
go build -v -ldflags="$(LDFLAGS) -X main.version=$(VERSION_AGENT) -X main.commit=$(COMMIT)" -o $(GOBIN)/wrenn-agent ./cmd/host-agent
|
||||
go build -v -ldflags="$(LDFLAGS) -X main.version=$(VERSION_AGENT) -X main.commit=$(COMMIT)" -o $(BIN_DIR)/wrenn-agent ./cmd/host-agent
|
||||
|
||||
build-envd:
|
||||
cd $(ENVD_DIR) && CGO_ENABLED=0 GOOS=linux GOARCH=amd64 \
|
||||
go build -ldflags="$(LDFLAGS) -X main.Version=$(VERSION_ENVD) -X main.commitSHA=$(COMMIT)" -o $(GOBIN)/envd .
|
||||
@file $(GOBIN)/envd | grep -q "statically linked" || \
|
||||
(echo "ERROR: envd is not statically linked!" && exit 1)
|
||||
cd envd-rs && ENVD_COMMIT=$(COMMIT) cargo build --release --target x86_64-unknown-linux-musl
|
||||
@cp envd-rs/target/x86_64-unknown-linux-musl/release/envd $(BIN_DIR)/envd
|
||||
@readelf -h $(BIN_DIR)/envd | grep -q 'Type:.*DYN' && \
|
||||
readelf -d $(BIN_DIR)/envd | grep -q 'FLAGS_1.*PIE' && \
|
||||
! readelf -d $(BIN_DIR)/envd | grep -q '(NEEDED)' && \
|
||||
{ ! readelf -lW $(BIN_DIR)/envd | grep -q 'Requesting program interpreter' || \
|
||||
readelf -lW $(BIN_DIR)/envd | grep -Fq '[Requesting program interpreter: /lib/ld-musl-x86_64.so.1]'; } || \
|
||||
(echo "ERROR: envd must be PIE, have no DT_NEEDED shared libs, and either have no interpreter or use /lib/ld-musl-x86_64.so.1" && exit 1)
|
||||
|
||||
# ═══════════════════════════════════════════════════
|
||||
# Development
|
||||
@ -57,11 +59,10 @@ dev-agent:
|
||||
sudo go run ./cmd/host-agent
|
||||
|
||||
dev-frontend:
|
||||
cd frontend && pnpm dev --port 5173 --host 0.0.0.0
|
||||
cd frontend && bun run dev --port 5173 --host 0.0.0.0
|
||||
|
||||
dev-envd:
|
||||
cd $(ENVD_DIR) && go run . --debug --listen-tcp :3002
|
||||
|
||||
cd envd-rs && cargo run -- --port 49983
|
||||
|
||||
# ═══════════════════════════════════════════════════
|
||||
# Database (goose)
|
||||
@ -94,7 +95,6 @@ generate: proto sqlc
|
||||
proto:
|
||||
cd proto/envd && buf generate
|
||||
cd proto/hostagent && buf generate
|
||||
cd $(ENVD_DIR)/spec && buf generate
|
||||
|
||||
sqlc:
|
||||
sqlc generate
|
||||
@ -106,17 +106,16 @@ sqlc:
|
||||
|
||||
fmt:
|
||||
gofmt -w .
|
||||
cd $(ENVD_DIR) && gofmt -w .
|
||||
|
||||
lint:
|
||||
golangci-lint run ./...
|
||||
|
||||
vet:
|
||||
go vet ./...
|
||||
cd $(ENVD_DIR) && go vet ./...
|
||||
|
||||
test:
|
||||
go test -race -v ./internal/...
|
||||
cd envd-rs && cargo test
|
||||
|
||||
test-integration:
|
||||
go test -race -v -tags=integration ./tests/integration/...
|
||||
@ -125,7 +124,6 @@ test-all: test test-integration
|
||||
|
||||
tidy:
|
||||
go mod tidy
|
||||
cd $(ENVD_DIR) && go mod tidy
|
||||
|
||||
## Run all quality checks in CI order
|
||||
check: fmt vet lint test
|
||||
@ -155,8 +153,8 @@ setup-host:
|
||||
sudo bash scripts/setup-host.sh
|
||||
|
||||
install: build
|
||||
sudo cp $(GOBIN)/wrenn-cp /usr/local/bin/
|
||||
sudo cp $(GOBIN)/wrenn-agent /usr/local/bin/
|
||||
sudo cp $(BIN_DIR)/wrenn-cp /usr/local/bin/
|
||||
sudo cp $(BIN_DIR)/wrenn-agent /usr/local/bin/
|
||||
sudo cp deploy/systemd/*.service /etc/systemd/system/
|
||||
sudo systemctl daemon-reload
|
||||
|
||||
@ -167,7 +165,7 @@ install: build
|
||||
|
||||
clean:
|
||||
rm -rf builds/
|
||||
cd $(ENVD_DIR) && rm -f envd
|
||||
cd envd-rs && cargo clean
|
||||
|
||||
# ═══════════════════════════════════════════════════
|
||||
# Help
|
||||
@ -183,11 +181,11 @@ help:
|
||||
@echo " make dev-cp Control plane (hot reload if air installed)"
|
||||
@echo " make dev-frontend Vite dev server with HMR (port 5173)"
|
||||
@echo " make dev-agent Host agent (sudo required)"
|
||||
@echo " make dev-envd envd in TCP debug mode"
|
||||
@echo " make dev-envd envd in debug mode (port 49983)"
|
||||
@echo ""
|
||||
@echo " make build Build all binaries → builds/"
|
||||
@echo " make build-frontend Build SvelteKit dashboard → frontend/build/"
|
||||
@echo " make build-envd Build envd static binary"
|
||||
@echo " make build-envd Build envd static binary (Rust, musl)"
|
||||
@echo ""
|
||||
@echo " make migrate-up Apply migrations"
|
||||
@echo " make migrate-create name=xxx New migration"
|
||||
|
||||
19
NOTICE
19
NOTICE
@ -1,19 +0,0 @@
|
||||
Wrenn Sandbox
|
||||
Copyright (c) 2026 M/S Omukk, Bangladesh
|
||||
|
||||
This project includes software derived from the following project:
|
||||
|
||||
Project: e2b infra
|
||||
Repository: https://github.com/e2b-dev/infra
|
||||
|
||||
The following files and directories in this repository contain code derived from the above project:
|
||||
|
||||
- envd/
|
||||
- proto/envd/*.proto
|
||||
- internal/snapshot/
|
||||
- internal/uffd/
|
||||
|
||||
Modifications to this code were made by M/S Omukk.
|
||||
|
||||
Copyright (c) 2023 FoundryLabs, Inc.
|
||||
Modifications Copyright (c) 2026 M/S Omukk, Bangladesh
|
||||
@ -5,10 +5,11 @@ Secure infrastructure for AI
|
||||
## Prerequisites
|
||||
|
||||
- Linux host with `/dev/kvm` access (bare metal or nested virt)
|
||||
- Firecracker binary at `/usr/local/bin/firecracker`
|
||||
- Cloud Hypervisor binary at `/usr/local/bin/cloud-hypervisor`
|
||||
- PostgreSQL
|
||||
- Go 1.25+
|
||||
- pnpm (for frontend)
|
||||
- Rust 1.88+ with `x86_64-unknown-linux-musl` target (`rustup target add x86_64-unknown-linux-musl`)
|
||||
- Bun (for frontend)
|
||||
- Docker (for dev infra and rootfs builds)
|
||||
|
||||
## Build
|
||||
|
||||
@ -1 +1 @@
|
||||
0.1.0
|
||||
0.2.0
|
||||
|
||||
@ -1 +1 @@
|
||||
0.1.1
|
||||
0.2.0
|
||||
|
||||
@ -80,6 +80,25 @@ func main() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Register with the control plane before touching rootfs images. If the
|
||||
// agent can't reach the CP there's no point inflating images (and crashing
|
||||
// afterward would leave them in the expanded state).
|
||||
creds, err := hostagent.Register(ctx, hostagent.RegistrationConfig{
|
||||
CPURL: cpURL,
|
||||
RegistrationToken: *registrationToken,
|
||||
TokenFile: credsFile,
|
||||
Address: *advertiseAddr,
|
||||
})
|
||||
if err != nil {
|
||||
slog.Error("host registration failed", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
slog.Info("host registered", "host_id", creds.HostID)
|
||||
|
||||
// Parse default rootfs size from env (e.g. "5G", "2Gi", "1000M").
|
||||
defaultRootfsSizeMB := sandbox.DefaultDiskSizeMB
|
||||
if sizeStr := os.Getenv("WRENN_DEFAULT_ROOTFS_SIZE"); sizeStr != "" {
|
||||
@ -107,48 +126,47 @@ func main() {
|
||||
}
|
||||
slog.Info("resolved kernel", "version", kernelVersion, "path", kernelPath)
|
||||
|
||||
// Detect firecracker version.
|
||||
fcBin := envOrDefault("WRENN_FIRECRACKER_BIN", "/usr/local/bin/firecracker")
|
||||
fcVersion, err := sandbox.DetectFirecrackerVersion(fcBin)
|
||||
// Detect cloud-hypervisor version.
|
||||
chBin := envOrDefault("WRENN_CH_BIN", "/usr/local/bin/cloud-hypervisor")
|
||||
chVersion, err := sandbox.DetectCHVersion(chBin)
|
||||
if err != nil {
|
||||
slog.Error("failed to detect firecracker version", "error", err)
|
||||
slog.Error("failed to detect cloud-hypervisor version", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
slog.Info("resolved firecracker", "version", fcVersion, "path", fcBin)
|
||||
slog.Info("resolved cloud-hypervisor", "version", chVersion, "path", chBin)
|
||||
|
||||
cfg := sandbox.Config{
|
||||
WrennDir: rootDir,
|
||||
DefaultRootfsSizeMB: defaultRootfsSizeMB,
|
||||
KernelPath: kernelPath,
|
||||
KernelVersion: kernelVersion,
|
||||
FirecrackerBin: fcBin,
|
||||
FirecrackerVersion: fcVersion,
|
||||
VMMBin: chBin,
|
||||
VMMVersion: chVersion,
|
||||
AgentVersion: version,
|
||||
}
|
||||
|
||||
mgr := sandbox.New(cfg)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
// Set up lifecycle event callback sender so autonomous events
|
||||
// (auto-pause, auto-destroy) are pushed to the CP proactively.
|
||||
cb := hostagent.NewCallbackSender(cpURL, credsFile, creds.HostID)
|
||||
mgr.SetEventSender(hostagent.NewEventSender(cb))
|
||||
|
||||
mgr.StartTTLReaper(ctx)
|
||||
|
||||
// Register with the control plane and start heartbeating.
|
||||
creds, err := hostagent.Register(ctx, hostagent.RegistrationConfig{
|
||||
CPURL: cpURL,
|
||||
RegistrationToken: *registrationToken,
|
||||
TokenFile: credsFile,
|
||||
Address: *advertiseAddr,
|
||||
})
|
||||
if err != nil {
|
||||
slog.Error("host registration failed", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
slog.Info("host registered", "host_id", creds.HostID)
|
||||
|
||||
// httpServer is declared here so the shutdown func can reference it.
|
||||
httpServer := &http.Server{Addr: listenAddr}
|
||||
// ReadTimeout/WriteTimeout are intentionally omitted — they would kill
|
||||
// long-lived Connect RPC streams and WebSocket proxy connections.
|
||||
httpServer := &http.Server{
|
||||
Addr: listenAddr,
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
IdleTimeout: 620 * time.Second, // > typical LB upstream timeout (600s)
|
||||
// Disable HTTP/2: empty non-nil map prevents Go from registering
|
||||
// the h2 ALPN token. Connect RPC works over HTTP/1.1; HTTP/2
|
||||
// multiplexing causes HOL blocking when a slow sandbox RPC stalls
|
||||
// the shared connection.
|
||||
TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)),
|
||||
}
|
||||
|
||||
// mTLS is mandatory — refuse to start without a valid certificate.
|
||||
var certStore hostagent.CertStore
|
||||
@ -193,6 +211,7 @@ func main() {
|
||||
path, handler := hostagentv1connect.NewHostAgentServiceHandler(srv)
|
||||
|
||||
proxyHandler := hostagent.NewProxyHandler(mgr)
|
||||
mgr.SetOnDestroy(proxyHandler.EvictProxy)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle(path, handler)
|
||||
@ -212,8 +231,9 @@ func main() {
|
||||
func() {
|
||||
doShutdown("host deleted from CP")
|
||||
},
|
||||
// onCredsRefreshed: hot-swap the TLS certificate after a JWT refresh.
|
||||
// onCredsRefreshed: hot-swap the TLS certificate and update callback JWT.
|
||||
func(tf *hostagent.TokenFile) {
|
||||
cb.UpdateJWT(tf.JWT)
|
||||
if tf.CertPEM == "" || tf.KeyPEM == "" {
|
||||
return
|
||||
}
|
||||
@ -225,12 +245,16 @@ func main() {
|
||||
},
|
||||
)
|
||||
|
||||
// Graceful shutdown on SIGINT/SIGTERM.
|
||||
// Graceful shutdown on SIGINT/SIGTERM. A second signal force-exits
|
||||
// so the operator can always kill the process if shutdown hangs.
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
sig := <-sigCh
|
||||
doShutdown("signal: " + sig.String())
|
||||
go doShutdown("signal: " + sig.String())
|
||||
sig = <-sigCh
|
||||
slog.Error("received second signal, force exiting", "signal", sig.String())
|
||||
os.Exit(1)
|
||||
}()
|
||||
|
||||
slog.Info("host agent starting", "addr", listenAddr, "host_id", creds.HostID, "version", version, "commit", commit)
|
||||
@ -272,7 +296,7 @@ func checkPrivileges() error {
|
||||
name string
|
||||
}{
|
||||
{1, "CAP_DAC_OVERRIDE"}, // /dev/loop*, /dev/mapper/*, /dev/net/tun
|
||||
{5, "CAP_KILL"}, // SIGTERM/SIGKILL to Firecracker processes
|
||||
{5, "CAP_KILL"}, // SIGTERM/SIGKILL to cloud-hypervisor processes
|
||||
{12, "CAP_NET_ADMIN"}, // netlink, iptables, routing, TAP/veth
|
||||
{13, "CAP_NET_RAW"}, // raw sockets (iptables)
|
||||
{19, "CAP_SYS_PTRACE"}, // reading /proc/self/ns/net (netns.Get)
|
||||
|
||||
@ -9,4 +9,10 @@ VALUES ('00000000-0000-0000-0000-000000000000', 'Platform', 'platform')
|
||||
ON CONFLICT (id) DO NOTHING;
|
||||
|
||||
-- +goose Down
|
||||
-- Delete dependent rows that reference the platform team via foreign keys.
|
||||
-- Order matters: children before parent.
|
||||
DELETE FROM sandboxes WHERE team_id = '00000000-0000-0000-0000-000000000000';
|
||||
DELETE FROM team_api_keys WHERE team_id = '00000000-0000-0000-0000-000000000000';
|
||||
DELETE FROM users_teams WHERE team_id = '00000000-0000-0000-0000-000000000000';
|
||||
DELETE FROM hosts WHERE team_id = '00000000-0000-0000-0000-000000000000';
|
||||
DELETE FROM teams WHERE id = '00000000-0000-0000-0000-000000000000';
|
||||
|
||||
11
db/migrations/20260418072009_daily_usage.sql
Normal file
11
db/migrations/20260418072009_daily_usage.sql
Normal file
@ -0,0 +1,11 @@
|
||||
-- +goose Up
|
||||
CREATE TABLE daily_usage (
|
||||
team_id UUID NOT NULL,
|
||||
day DATE NOT NULL,
|
||||
cpu_minutes NUMERIC(18, 4) NOT NULL DEFAULT 0,
|
||||
ram_mb_minutes NUMERIC(18, 4) NOT NULL DEFAULT 0,
|
||||
PRIMARY KEY (team_id, day)
|
||||
);
|
||||
|
||||
-- +goose Down
|
||||
DROP TABLE daily_usage;
|
||||
@ -1,5 +1,5 @@
|
||||
// Package migrations embeds the SQL migration files so that external modules
|
||||
// (such as the enterprise edition) can access them programmatically.
|
||||
// (such as the cloud edition) can access them programmatically.
|
||||
package migrations
|
||||
|
||||
import "embed"
|
||||
|
||||
@ -2,6 +2,15 @@
|
||||
INSERT INTO audit_logs (id, team_id, actor_type, actor_id, actor_name, resource_type, resource_id, action, scope, status, metadata)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11);
|
||||
|
||||
-- name: AnonymizeAuditLogsByUserID :exec
|
||||
UPDATE audit_logs
|
||||
SET actor_name = CASE WHEN actor_id = $1 THEN 'deleted-user' ELSE actor_name END,
|
||||
actor_id = CASE WHEN actor_id = $1 THEN NULL ELSE actor_id END,
|
||||
resource_id = CASE WHEN resource_type = 'member' AND resource_id = $1 THEN NULL ELSE resource_id END,
|
||||
metadata = CASE WHEN resource_type = 'member' AND resource_id = $1 AND metadata ? 'email' THEN metadata - 'email' ELSE metadata END
|
||||
WHERE actor_id = $1
|
||||
OR (resource_type = 'member' AND resource_id = $1);
|
||||
|
||||
-- name: ListAuditLogs :many
|
||||
SELECT * FROM audit_logs
|
||||
WHERE team_id = $1
|
||||
|
||||
@ -73,3 +73,35 @@ SELECT
|
||||
+ COALESCE(SUM(CEIL(memory_mb::NUMERIC / 2)) FILTER (WHERE status = 'paused'), 0))::INTEGER AS memory_mb_reserved
|
||||
FROM sandboxes
|
||||
GROUP BY team_id;
|
||||
|
||||
-- name: GetTeamsWithSnapshots :many
|
||||
SELECT DISTINCT team_id
|
||||
FROM sandbox_metrics_snapshots
|
||||
WHERE sampled_at > NOW() - INTERVAL '93 days';
|
||||
|
||||
-- name: ComputeDailyUsageForDay :one
|
||||
SELECT
|
||||
COALESCE(SUM(vcpus_reserved * 10.0 / 60.0), 0)::NUMERIC(18,4) AS cpu_minutes,
|
||||
COALESCE(SUM(memory_mb_reserved * 10.0 / 60.0), 0)::NUMERIC(18,4) AS ram_mb_minutes
|
||||
FROM sandbox_metrics_snapshots
|
||||
WHERE team_id = $1
|
||||
AND sampled_at >= $2
|
||||
AND sampled_at < $3;
|
||||
|
||||
-- name: UpsertDailyUsage :exec
|
||||
INSERT INTO daily_usage (team_id, day, cpu_minutes, ram_mb_minutes)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT (team_id, day) DO UPDATE
|
||||
SET cpu_minutes = EXCLUDED.cpu_minutes,
|
||||
ram_mb_minutes = EXCLUDED.ram_mb_minutes;
|
||||
|
||||
-- name: GetDailyUsage :many
|
||||
SELECT day, cpu_minutes, ram_mb_minutes
|
||||
FROM daily_usage
|
||||
WHERE team_id = $1
|
||||
AND day >= $2
|
||||
AND day <= $3
|
||||
ORDER BY day ASC;
|
||||
|
||||
-- name: DeleteDailyUsageByTeam :exec
|
||||
DELETE FROM daily_usage WHERE team_id = $1;
|
||||
|
||||
@ -72,7 +72,7 @@ ORDER BY created_at DESC;
|
||||
UPDATE sandboxes
|
||||
SET status = 'missing',
|
||||
last_updated = NOW()
|
||||
WHERE host_id = $1 AND status IN ('running', 'starting', 'pending');
|
||||
WHERE host_id = $1 AND status IN ('running', 'starting', 'pending', 'pausing', 'resuming', 'stopping');
|
||||
|
||||
-- name: UpdateSandboxMetadata :exec
|
||||
UPDATE sandboxes
|
||||
@ -80,6 +80,30 @@ SET metadata = $2,
|
||||
last_updated = NOW()
|
||||
WHERE id = $1;
|
||||
|
||||
-- name: UpdateSandboxRunningIf :one
|
||||
-- Conditionally transition a sandbox to running only if the current status
|
||||
-- matches the expected value. Prevents races where a user destroys a sandbox
|
||||
-- while the create/resume goroutine is still in-flight.
|
||||
UPDATE sandboxes
|
||||
SET status = 'running',
|
||||
host_ip = $3,
|
||||
guest_ip = $4,
|
||||
started_at = $5,
|
||||
last_active_at = $5,
|
||||
last_updated = NOW()
|
||||
WHERE id = $1 AND status = $2
|
||||
RETURNING *;
|
||||
|
||||
-- name: UpdateSandboxStatusIf :one
|
||||
-- Atomically update status only when the current status matches the expected value.
|
||||
-- Prevents background goroutines from overwriting a status that has since changed
|
||||
-- (e.g. user destroyed a sandbox while Create was in-flight).
|
||||
UPDATE sandboxes
|
||||
SET status = $3,
|
||||
last_updated = NOW()
|
||||
WHERE id = $1 AND status = $2
|
||||
RETURNING *;
|
||||
|
||||
-- name: BulkRestoreRunning :exec
|
||||
-- Called by the reconciler when a host comes back online and its sandboxes are
|
||||
-- confirmed alive. Restores only sandboxes that are in 'missing' state.
|
||||
|
||||
@ -22,6 +22,12 @@ RETURNING *;
|
||||
-- name: SetUserAdmin :exec
|
||||
UPDATE users SET is_admin = $2, updated_at = NOW() WHERE id = $1;
|
||||
|
||||
-- name: RevokeUserAdmin :execrows
|
||||
UPDATE users u SET is_admin = false, updated_at = NOW()
|
||||
WHERE u.id = $1
|
||||
AND u.is_admin = true
|
||||
AND (SELECT COUNT(*) FROM users WHERE is_admin = true AND status != 'deleted') > 1;
|
||||
|
||||
-- name: GetAdminUsers :many
|
||||
SELECT * FROM users WHERE is_admin = TRUE ORDER BY created_at;
|
||||
|
||||
@ -91,8 +97,8 @@ WHERE ut.user_id = $1
|
||||
WHERE ut2.team_id = ut.team_id AND ut2.user_id <> $1
|
||||
);
|
||||
|
||||
-- name: HardDeleteExpiredUsers :exec
|
||||
DELETE FROM users WHERE deleted_at IS NOT NULL AND deleted_at < NOW() - INTERVAL '15 days';
|
||||
-- name: ListExpiredSoftDeletedUsers :many
|
||||
SELECT id, email FROM users WHERE deleted_at IS NOT NULL AND deleted_at < NOW() - INTERVAL '15 days';
|
||||
|
||||
-- name: HardDeleteUser :exec
|
||||
DELETE FROM users WHERE id = $1;
|
||||
|
||||
2
envd-rs/.cargo/config.toml
Normal file
2
envd-rs/.cargo/config.toml
Normal file
@ -0,0 +1,2 @@
|
||||
[target.x86_64-unknown-linux-musl]
|
||||
linker = "musl-gcc"
|
||||
2199
envd-rs/Cargo.lock
generated
Normal file
2199
envd-rs/Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
83
envd-rs/Cargo.toml
Normal file
83
envd-rs/Cargo.toml
Normal file
@ -0,0 +1,83 @@
|
||||
[package]
|
||||
name = "envd"
|
||||
version = "0.3.0"
|
||||
edition = "2024"
|
||||
rust-version = "1.88"
|
||||
|
||||
[dependencies]
|
||||
# Async runtime
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
|
||||
# HTTP framework
|
||||
axum = { version = "0.8", features = ["multipart"] }
|
||||
tower = { version = "0.5", features = ["util"] }
|
||||
tower-http = { version = "0.6", features = ["cors", "fs"] }
|
||||
tower-service = "0.3"
|
||||
|
||||
# RPC (Connect protocol — serves Connect + gRPC + gRPC-Web on same port)
|
||||
connectrpc = { version = "0.3", features = ["axum"] }
|
||||
buffa-types = { path = "buffa-types-shim" }
|
||||
|
||||
# CLI
|
||||
clap = { version = "4", features = ["derive"] }
|
||||
|
||||
# Serialization
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
|
||||
# Logging
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
|
||||
|
||||
# System metrics
|
||||
sysinfo = "0.33"
|
||||
|
||||
# Unix syscalls
|
||||
nix = { version = "0.30", features = ["fs", "process", "signal", "user", "term", "mount", "ioctl"] }
|
||||
|
||||
# Concurrent map
|
||||
dashmap = "6"
|
||||
|
||||
# Crypto
|
||||
sha2 = "0.10"
|
||||
hmac = "0.12"
|
||||
hex = "0.4"
|
||||
base64 = "0.22"
|
||||
|
||||
# Secure memory
|
||||
zeroize = { version = "1", features = ["derive"] }
|
||||
|
||||
# File watching
|
||||
notify = "7"
|
||||
|
||||
# Compression
|
||||
flate2 = "1"
|
||||
|
||||
# Directory walking
|
||||
walkdir = "2"
|
||||
|
||||
# Misc
|
||||
libc = "0.2"
|
||||
bytes = "1"
|
||||
http = "1"
|
||||
http-body-util = "0.1"
|
||||
futures = "0.3"
|
||||
tokio-util = { version = "0.7", features = ["io"] }
|
||||
subtle = "2"
|
||||
http-body = "1.0.1"
|
||||
buffa = "0.3"
|
||||
async-stream = "0.3.6"
|
||||
mime_guess = "2"
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3"
|
||||
|
||||
[build-dependencies]
|
||||
connectrpc-build = "0.3"
|
||||
|
||||
[profile.release]
|
||||
strip = true
|
||||
lto = true
|
||||
opt-level = "z"
|
||||
codegen-units = 1
|
||||
panic = "abort"
|
||||
140
envd-rs/README.md
Normal file
140
envd-rs/README.md
Normal file
@ -0,0 +1,140 @@
|
||||
# envd (Rust)
|
||||
|
||||
Wrenn guest agent daemon — runs as PID 1 inside Cloud Hypervisor microVMs. Provides process management, filesystem operations, file transfer, port forwarding, and VM lifecycle control over Connect RPC and HTTP.
|
||||
|
||||
Rust rewrite of `envd/` (Go). Drop-in replacement — same wire protocol, same endpoints, same CLI flags.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- Rust 1.88+ (required by `connectrpc` 0.3.3)
|
||||
- `protoc` (protobuf compiler, for proto codegen at build time)
|
||||
- `musl-tools` (for static linking)
|
||||
|
||||
```bash
|
||||
# Ubuntu/Debian
|
||||
sudo apt install musl-tools protobuf-compiler
|
||||
|
||||
# Rust musl target
|
||||
rustup target add x86_64-unknown-linux-musl
|
||||
```
|
||||
|
||||
## Building
|
||||
|
||||
### Static binary (production — what goes into the rootfs)
|
||||
|
||||
```bash
|
||||
cd envd-rs
|
||||
ENVD_COMMIT=$(git rev-parse --short HEAD) \
|
||||
cargo build --release --target x86_64-unknown-linux-musl
|
||||
```
|
||||
|
||||
Output: `target/x86_64-unknown-linux-musl/release/envd`
|
||||
|
||||
Verify static linking:
|
||||
|
||||
```bash
|
||||
file target/x86_64-unknown-linux-musl/release/envd
|
||||
# should say: "statically linked"
|
||||
|
||||
ldd target/x86_64-unknown-linux-musl/release/envd
|
||||
# should say: "not a dynamic executable"
|
||||
```
|
||||
|
||||
### Debug binary (dev machine, dynamically linked)
|
||||
|
||||
```bash
|
||||
cd envd-rs
|
||||
cargo build
|
||||
```
|
||||
|
||||
Run locally (outside a VM):
|
||||
|
||||
```bash
|
||||
./target/debug/envd --port 49983
|
||||
```
|
||||
|
||||
### Via Makefile (from repo root)
|
||||
|
||||
```bash
|
||||
make build-envd # static musl release build
|
||||
make build-envd-go # Go version (for comparison)
|
||||
```
|
||||
|
||||
## CLI Flags
|
||||
|
||||
```
|
||||
--port <PORT> Listen port [default: 49983]
|
||||
--version Print version and exit
|
||||
--commit Print git commit and exit
|
||||
--cmd <CMD> Spawn a process at startup (e.g. --cmd "/bin/bash")
|
||||
--cgroup-root <PATH> Cgroup v2 root [default: /sys/fs/cgroup]
|
||||
```
|
||||
|
||||
## Endpoints
|
||||
|
||||
### HTTP
|
||||
|
||||
| Method | Path | Description |
|
||||
|--------|---------------------|--------------------------------------|
|
||||
| GET | `/health` | Health check, triggers post-restore |
|
||||
| GET | `/metrics` | System metrics (CPU, memory, disk) |
|
||||
| GET | `/envs` | Current environment variables |
|
||||
| POST | `/init` | Host agent init (token, env, mounts) |
|
||||
| POST | `/snapshot/prepare` | Quiesce before Cloud Hypervisor snapshot |
|
||||
| GET | `/files` | Download file (gzip, range support) |
|
||||
| POST | `/files` | Upload file(s) via multipart |
|
||||
|
||||
### Connect RPC (same port)
|
||||
|
||||
| Service | RPCs |
|
||||
|------------|-------------------------------------------------------------------------|
|
||||
| Process | List, Start, Connect, Update, StreamInput, SendInput, SendSignal, CloseStdin |
|
||||
| Filesystem | Stat, MakeDir, Move, ListDir, Remove, WatchDir, CreateWatcher, GetWatcherEvents, RemoveWatcher |
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
42 files, ~4200 LOC Rust
|
||||
Binary: ~4 MB (stripped, LTO, musl static)
|
||||
|
||||
src/
|
||||
├── main.rs # Entry point, CLI, server setup
|
||||
├── state.rs # Shared AppState
|
||||
├── config.rs # Constants
|
||||
├── conntracker.rs # TCP connection tracking for snapshot/restore
|
||||
├── execcontext.rs # Default user/workdir/env
|
||||
├── logging.rs # tracing-subscriber (JSON or pretty)
|
||||
├── util.rs # AtomicMax
|
||||
├── auth/ # Token, signing, middleware
|
||||
├── crypto/ # SHA-256, SHA-512, HMAC
|
||||
├── host/ # System metrics
|
||||
├── http/ # Axum handlers (health, init, snapshot, files, encoding)
|
||||
├── permissions/ # Path resolution, user lookup, chown
|
||||
├── rpc/ # Connect RPC services
|
||||
│ ├── pb.rs # Generated proto types
|
||||
│ ├── process_*.rs # Process service + handler (PTY, pipe, broadcast)
|
||||
│ ├── filesystem_*.rs # Filesystem service (stat, list, watch, mkdir, move, remove)
|
||||
│ └── entry.rs # EntryInfo builder
|
||||
├── port/ # Port subsystem
|
||||
│ ├── conn.rs # /proc/net/tcp parser
|
||||
│ ├── scanner.rs # Periodic TCP port scanner
|
||||
│ ├── forwarder.rs # socat-based port forwarding
|
||||
│ └── subsystem.rs # Lifecycle (start/stop/restart)
|
||||
└── cgroups/ # Cgroup v2 manager (pty/user/socat groups)
|
||||
```
|
||||
|
||||
## Updating the rootfs
|
||||
|
||||
After building the static binary, copy it into the rootfs:
|
||||
|
||||
```bash
|
||||
bash scripts/update-debug-rootfs.sh [rootfs_path]
|
||||
```
|
||||
|
||||
Or manually:
|
||||
|
||||
```bash
|
||||
sudo mount -o loop /var/lib/wrenn/images/minimal.ext4 /mnt
|
||||
sudo cp target/x86_64-unknown-linux-musl/release/envd /mnt/usr/bin/envd
|
||||
sudo umount /mnt
|
||||
```
|
||||
12
envd-rs/buffa-types-shim/Cargo.toml
Normal file
12
envd-rs/buffa-types-shim/Cargo.toml
Normal file
@ -0,0 +1,12 @@
|
||||
[package]
|
||||
name = "buffa-types"
|
||||
version = "0.3.0"
|
||||
edition = "2024"
|
||||
publish = false
|
||||
|
||||
[dependencies]
|
||||
buffa = "0.3"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
|
||||
[build-dependencies]
|
||||
connectrpc-build = "0.3"
|
||||
9
envd-rs/buffa-types-shim/build.rs
Normal file
9
envd-rs/buffa-types-shim/build.rs
Normal file
@ -0,0 +1,9 @@
|
||||
fn main() {
|
||||
connectrpc_build::Config::new()
|
||||
.files(&["/usr/include/google/protobuf/timestamp.proto"])
|
||||
.includes(&["/usr/include"])
|
||||
.include_file("_types.rs")
|
||||
.emit_register_fn(false)
|
||||
.compile()
|
||||
.unwrap();
|
||||
}
|
||||
6
envd-rs/buffa-types-shim/src/lib.rs
Normal file
6
envd-rs/buffa-types-shim/src/lib.rs
Normal file
@ -0,0 +1,6 @@
|
||||
#![allow(dead_code, non_camel_case_types, unused_imports, clippy::derivable_impls)]
|
||||
|
||||
use ::buffa;
|
||||
use ::serde;
|
||||
|
||||
include!(concat!(env!("OUT_DIR"), "/_types.rs"));
|
||||
11
envd-rs/build.rs
Normal file
11
envd-rs/build.rs
Normal file
@ -0,0 +1,11 @@
|
||||
fn main() {
|
||||
connectrpc_build::Config::new()
|
||||
.files(&[
|
||||
"../proto/envd/process.proto",
|
||||
"../proto/envd/filesystem.proto",
|
||||
])
|
||||
.includes(&["../proto/envd", "/usr/include"])
|
||||
.include_file("_connectrpc.rs")
|
||||
.compile()
|
||||
.unwrap();
|
||||
}
|
||||
3
envd-rs/rust-toolchain.toml
Normal file
3
envd-rs/rust-toolchain.toml
Normal file
@ -0,0 +1,3 @@
|
||||
[toolchain]
|
||||
channel = "stable"
|
||||
targets = ["x86_64-unknown-linux-gnu", "x86_64-unknown-linux-musl"]
|
||||
56
envd-rs/src/auth/middleware.rs
Normal file
56
envd-rs/src/auth/middleware.rs
Normal file
@ -0,0 +1,56 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::extract::Request;
|
||||
use axum::http::StatusCode;
|
||||
use axum::middleware::Next;
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use serde_json::json;
|
||||
|
||||
use crate::auth::token::SecureToken;
|
||||
|
||||
const ACCESS_TOKEN_HEADER: &str = "x-access-token";
|
||||
|
||||
/// Paths excluded from general token auth.
|
||||
/// Format: "METHOD/path"
|
||||
const AUTH_EXCLUDED: &[&str] = &[
|
||||
"GET/health",
|
||||
"GET/files",
|
||||
"POST/files",
|
||||
"POST/init",
|
||||
"POST/snapshot/prepare",
|
||||
];
|
||||
|
||||
/// Axum middleware that checks X-Access-Token header.
|
||||
pub async fn auth_layer(
|
||||
request: Request,
|
||||
next: Next,
|
||||
access_token: Arc<SecureToken>,
|
||||
) -> Response {
|
||||
if access_token.is_set() {
|
||||
let method = request.method().as_str();
|
||||
let path = request.uri().path();
|
||||
let key = format!("{method}{path}");
|
||||
|
||||
let is_excluded = AUTH_EXCLUDED.iter().any(|p| *p == key);
|
||||
|
||||
let header_val = request
|
||||
.headers()
|
||||
.get(ACCESS_TOKEN_HEADER)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("");
|
||||
|
||||
if !access_token.equals(header_val) && !is_excluded {
|
||||
tracing::error!("unauthorized access attempt");
|
||||
return (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
axum::Json(json!({
|
||||
"code": 401,
|
||||
"message": "unauthorized access, please provide a valid access token or method signing if supported"
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
}
|
||||
|
||||
next.run(request).await
|
||||
}
|
||||
3
envd-rs/src/auth/mod.rs
Normal file
3
envd-rs/src/auth/mod.rs
Normal file
@ -0,0 +1,3 @@
|
||||
pub mod token;
|
||||
pub mod signing;
|
||||
pub mod middleware;
|
||||
210
envd-rs/src/auth/signing.rs
Normal file
210
envd-rs/src/auth/signing.rs
Normal file
@ -0,0 +1,210 @@
|
||||
use crate::auth::token::SecureToken;
|
||||
use crate::crypto;
|
||||
use zeroize::Zeroize;
|
||||
|
||||
pub const READ_OPERATION: &str = "read";
|
||||
pub const WRITE_OPERATION: &str = "write";
|
||||
|
||||
/// Generate a v1 signature: `v1_{sha256_base64(path:operation:username:token[:expiration])}`
|
||||
pub fn generate_signature(
|
||||
token: &SecureToken,
|
||||
path: &str,
|
||||
username: &str,
|
||||
operation: &str,
|
||||
expiration: Option<i64>,
|
||||
) -> Result<String, &'static str> {
|
||||
let mut token_bytes = token.bytes().ok_or("access token is not set")?;
|
||||
|
||||
let payload = match expiration {
|
||||
Some(exp) => format!(
|
||||
"{}:{}:{}:{}:{}",
|
||||
path,
|
||||
operation,
|
||||
username,
|
||||
String::from_utf8_lossy(&token_bytes),
|
||||
exp
|
||||
),
|
||||
None => format!(
|
||||
"{}:{}:{}:{}",
|
||||
path,
|
||||
operation,
|
||||
username,
|
||||
String::from_utf8_lossy(&token_bytes),
|
||||
),
|
||||
};
|
||||
|
||||
token_bytes.zeroize();
|
||||
|
||||
let hash = crypto::sha256::hash_without_prefix(payload.as_bytes());
|
||||
Ok(format!("v1_{hash}"))
|
||||
}
|
||||
|
||||
/// Validate a request's signing. Returns Ok(()) if valid.
|
||||
pub fn validate_signing(
|
||||
token: &SecureToken,
|
||||
header_token: Option<&str>,
|
||||
signature: Option<&str>,
|
||||
signature_expiration: Option<i64>,
|
||||
username: &str,
|
||||
path: &str,
|
||||
operation: &str,
|
||||
) -> Result<(), String> {
|
||||
if !token.is_set() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if let Some(ht) = header_token {
|
||||
if !ht.is_empty() {
|
||||
if token.equals(ht) {
|
||||
return Ok(());
|
||||
}
|
||||
return Err("access token present in header but does not match".into());
|
||||
}
|
||||
}
|
||||
|
||||
let sig = signature.ok_or("missing signature query parameter")?;
|
||||
|
||||
let expected = generate_signature(token, path, username, operation, signature_expiration)
|
||||
.map_err(|e| format!("error generating signing key: {e}"))?;
|
||||
|
||||
if expected != sig {
|
||||
return Err("invalid signature".into());
|
||||
}
|
||||
|
||||
if let Some(exp) = signature_expiration {
|
||||
let now = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs() as i64;
|
||||
if exp < now {
|
||||
return Err("signature is already expired".into());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn test_token(val: &[u8]) -> SecureToken {
|
||||
let t = SecureToken::new();
|
||||
t.set(val).unwrap();
|
||||
t
|
||||
}
|
||||
|
||||
fn far_future() -> i64 {
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs() as i64
|
||||
+ 3600
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generate_starts_with_v1() {
|
||||
let token = test_token(b"secret");
|
||||
let sig = generate_signature(&token, "/file", "root", READ_OPERATION, None).unwrap();
|
||||
assert!(sig.starts_with("v1_"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generate_deterministic() {
|
||||
let token = test_token(b"secret");
|
||||
let s1 = generate_signature(&token, "/file", "root", READ_OPERATION, None).unwrap();
|
||||
let s2 = generate_signature(&token, "/file", "root", READ_OPERATION, None).unwrap();
|
||||
assert_eq!(s1, s2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generate_with_expiration_differs() {
|
||||
let token = test_token(b"secret");
|
||||
let without = generate_signature(&token, "/f", "u", READ_OPERATION, None).unwrap();
|
||||
let with = generate_signature(&token, "/f", "u", READ_OPERATION, Some(9999)).unwrap();
|
||||
assert_ne!(without, with);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generate_unset_token_errors() {
|
||||
let token = SecureToken::new();
|
||||
assert!(generate_signature(&token, "/f", "u", READ_OPERATION, None).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_no_token_set_passes() {
|
||||
let token = SecureToken::new();
|
||||
assert!(validate_signing(&token, None, None, None, "root", "/f", READ_OPERATION).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_correct_header_token() {
|
||||
let token = test_token(b"secret");
|
||||
assert!(validate_signing(&token, Some("secret"), None, None, "root", "/f", READ_OPERATION).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_wrong_header_token() {
|
||||
let token = test_token(b"secret");
|
||||
let result = validate_signing(&token, Some("wrong"), None, None, "root", "/f", READ_OPERATION);
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().contains("does not match"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_valid_signature() {
|
||||
let token = test_token(b"secret");
|
||||
let exp = far_future();
|
||||
let sig = generate_signature(&token, "/file", "root", READ_OPERATION, Some(exp)).unwrap();
|
||||
assert!(validate_signing(&token, None, Some(&sig), Some(exp), "root", "/file", READ_OPERATION).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_invalid_signature() {
|
||||
let token = test_token(b"secret");
|
||||
let result = validate_signing(&token, None, Some("v1_bad"), Some(far_future()), "root", "/f", READ_OPERATION);
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().contains("invalid signature"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_expired_signature() {
|
||||
let token = test_token(b"secret");
|
||||
let expired: i64 = 1_000_000;
|
||||
let sig = generate_signature(&token, "/f", "root", READ_OPERATION, Some(expired)).unwrap();
|
||||
let result = validate_signing(&token, None, Some(&sig), Some(expired), "root", "/f", READ_OPERATION);
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().contains("expired"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_missing_signature() {
|
||||
let token = test_token(b"secret");
|
||||
let result = validate_signing(&token, None, None, None, "root", "/f", READ_OPERATION);
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().contains("missing signature"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_empty_header_token_falls_through_to_signature() {
|
||||
let token = test_token(b"secret");
|
||||
let result = validate_signing(&token, Some(""), None, None, "root", "/f", READ_OPERATION);
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().contains("missing signature"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_valid_signature_no_expiration() {
|
||||
let token = test_token(b"secret");
|
||||
let sig = generate_signature(&token, "/file", "root", READ_OPERATION, None).unwrap();
|
||||
assert!(validate_signing(&token, None, Some(&sig), None, "root", "/file", READ_OPERATION).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn different_operations_produce_different_signatures() {
|
||||
let token = test_token(b"secret");
|
||||
let r = generate_signature(&token, "/f", "root", READ_OPERATION, None).unwrap();
|
||||
let w = generate_signature(&token, "/f", "root", WRITE_OPERATION, None).unwrap();
|
||||
assert_ne!(r, w);
|
||||
}
|
||||
}
|
||||
256
envd-rs/src/auth/token.rs
Normal file
256
envd-rs/src/auth/token.rs
Normal file
@ -0,0 +1,256 @@
|
||||
use std::sync::RwLock;
|
||||
|
||||
use subtle::ConstantTimeEq;
|
||||
use zeroize::Zeroize;
|
||||
|
||||
/// Secure token storage with constant-time comparison and zeroize-on-drop.
|
||||
///
|
||||
/// Mirrors Go's SecureToken backed by memguard.LockedBuffer.
|
||||
/// In Rust we rely on `zeroize` for Drop-based zeroing.
|
||||
pub struct SecureToken {
|
||||
inner: RwLock<Option<Vec<u8>>>,
|
||||
}
|
||||
|
||||
impl SecureToken {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
inner: RwLock::new(None),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set(&self, token: &[u8]) -> Result<(), &'static str> {
|
||||
if token.is_empty() {
|
||||
return Err("empty token not allowed");
|
||||
}
|
||||
let mut guard = self.inner.write().unwrap();
|
||||
if let Some(ref mut old) = *guard {
|
||||
old.zeroize();
|
||||
}
|
||||
*guard = Some(token.to_vec());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn is_set(&self) -> bool {
|
||||
let guard = self.inner.read().unwrap();
|
||||
guard.is_some()
|
||||
}
|
||||
|
||||
/// Constant-time comparison.
|
||||
pub fn equals(&self, other: &str) -> bool {
|
||||
let guard = self.inner.read().unwrap();
|
||||
match guard.as_ref() {
|
||||
Some(buf) => buf.as_slice().ct_eq(other.as_bytes()).into(),
|
||||
None => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Constant-time comparison with another SecureToken.
|
||||
pub fn equals_secure(&self, other: &SecureToken) -> bool {
|
||||
let other_bytes = match other.bytes() {
|
||||
Some(b) => b,
|
||||
None => return false,
|
||||
};
|
||||
let guard = self.inner.read().unwrap();
|
||||
let result = match guard.as_ref() {
|
||||
Some(buf) => buf.as_slice().ct_eq(&other_bytes).into(),
|
||||
None => false,
|
||||
};
|
||||
// other_bytes dropped here, Vec<u8> doesn't auto-zeroize but
|
||||
// we accept this — same as Go's `defer memguard.WipeBytes(otherBytes)`
|
||||
result
|
||||
}
|
||||
|
||||
/// Returns a copy of the token bytes (for signature generation).
|
||||
pub fn bytes(&self) -> Option<Vec<u8>> {
|
||||
let guard = self.inner.read().unwrap();
|
||||
guard.as_ref().map(|b| b.clone())
|
||||
}
|
||||
|
||||
/// Transfer token from another SecureToken, clearing the source.
|
||||
pub fn take_from(&self, src: &SecureToken) {
|
||||
let taken = {
|
||||
let mut src_guard = src.inner.write().unwrap();
|
||||
src_guard.take()
|
||||
};
|
||||
let mut guard = self.inner.write().unwrap();
|
||||
if let Some(ref mut old) = *guard {
|
||||
old.zeroize();
|
||||
}
|
||||
*guard = taken;
|
||||
}
|
||||
|
||||
pub fn destroy(&self) {
|
||||
let mut guard = self.inner.write().unwrap();
|
||||
if let Some(ref mut buf) = *guard {
|
||||
buf.zeroize();
|
||||
}
|
||||
*guard = None;
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for SecureToken {
|
||||
fn drop(&mut self) {
|
||||
if let Ok(mut guard) = self.inner.write() {
|
||||
if let Some(ref mut buf) = *guard {
|
||||
buf.zeroize();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Deserialize from JSON string, matching Go's UnmarshalJSON behavior.
|
||||
/// Expects a quoted JSON string. Rejects escape sequences.
|
||||
impl SecureToken {
|
||||
pub fn from_json_bytes(data: &mut [u8]) -> Result<Self, &'static str> {
|
||||
if data.len() < 2 || data[0] != b'"' || data[data.len() - 1] != b'"' {
|
||||
data.zeroize();
|
||||
return Err("invalid secure token JSON string");
|
||||
}
|
||||
|
||||
let content = &data[1..data.len() - 1];
|
||||
if content.contains(&b'\\') {
|
||||
data.zeroize();
|
||||
return Err("invalid secure token: unexpected escape sequence");
|
||||
}
|
||||
|
||||
if content.is_empty() {
|
||||
data.zeroize();
|
||||
return Err("empty token not allowed");
|
||||
}
|
||||
|
||||
let token = Self::new();
|
||||
token.set(content).map_err(|_| "failed to set token")?;
|
||||
|
||||
data.zeroize();
|
||||
Ok(token)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn new_is_unset() {
|
||||
let t = SecureToken::new();
|
||||
assert!(!t.is_set());
|
||||
assert!(!t.equals("anything"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn set_and_equals() {
|
||||
let t = SecureToken::new();
|
||||
t.set(b"secret").unwrap();
|
||||
assert!(t.is_set());
|
||||
assert!(t.equals("secret"));
|
||||
assert!(!t.equals("wrong"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn set_empty_errors() {
|
||||
let t = SecureToken::new();
|
||||
assert!(t.set(b"").is_err());
|
||||
assert!(!t.is_set());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn set_overwrites_previous() {
|
||||
let t = SecureToken::new();
|
||||
t.set(b"first").unwrap();
|
||||
t.set(b"second").unwrap();
|
||||
assert!(!t.equals("first"));
|
||||
assert!(t.equals("second"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn destroy_clears() {
|
||||
let t = SecureToken::new();
|
||||
t.set(b"secret").unwrap();
|
||||
t.destroy();
|
||||
assert!(!t.is_set());
|
||||
assert!(!t.equals("secret"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bytes_returns_copy() {
|
||||
let t = SecureToken::new();
|
||||
assert!(t.bytes().is_none());
|
||||
t.set(b"hello").unwrap();
|
||||
assert_eq!(t.bytes().unwrap(), b"hello");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn take_from_transfers_and_clears_source() {
|
||||
let src = SecureToken::new();
|
||||
src.set(b"token").unwrap();
|
||||
let dst = SecureToken::new();
|
||||
dst.take_from(&src);
|
||||
assert!(!src.is_set());
|
||||
assert!(dst.equals("token"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn take_from_overwrites_existing() {
|
||||
let src = SecureToken::new();
|
||||
src.set(b"new").unwrap();
|
||||
let dst = SecureToken::new();
|
||||
dst.set(b"old").unwrap();
|
||||
dst.take_from(&src);
|
||||
assert!(dst.equals("new"));
|
||||
assert!(!dst.equals("old"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn equals_secure_matching() {
|
||||
let a = SecureToken::new();
|
||||
a.set(b"same").unwrap();
|
||||
let b = SecureToken::new();
|
||||
b.set(b"same").unwrap();
|
||||
assert!(a.equals_secure(&b));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn equals_secure_different() {
|
||||
let a = SecureToken::new();
|
||||
a.set(b"one").unwrap();
|
||||
let b = SecureToken::new();
|
||||
b.set(b"two").unwrap();
|
||||
assert!(!a.equals_secure(&b));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn equals_secure_unset() {
|
||||
let a = SecureToken::new();
|
||||
let b = SecureToken::new();
|
||||
assert!(!a.equals_secure(&b));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn from_json_bytes_valid() {
|
||||
let mut data = b"\"mysecret\"".to_vec();
|
||||
let t = SecureToken::from_json_bytes(&mut data).unwrap();
|
||||
assert!(t.equals("mysecret"));
|
||||
assert!(data.iter().all(|&b| b == 0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn from_json_bytes_rejects_missing_quotes() {
|
||||
let mut data = b"noquotes".to_vec();
|
||||
assert!(SecureToken::from_json_bytes(&mut data).is_err());
|
||||
assert!(data.iter().all(|&b| b == 0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn from_json_bytes_rejects_escape_sequences() {
|
||||
let mut data = b"\"has\\nescapes\"".to_vec();
|
||||
assert!(SecureToken::from_json_bytes(&mut data).is_err());
|
||||
assert!(data.iter().all(|&b| b == 0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn from_json_bytes_rejects_empty_content() {
|
||||
let mut data = b"\"\"".to_vec();
|
||||
assert!(SecureToken::from_json_bytes(&mut data).is_err());
|
||||
assert!(data.iter().all(|&b| b == 0));
|
||||
}
|
||||
}
|
||||
66
envd-rs/src/cgroups/mod.rs
Normal file
66
envd-rs/src/cgroups/mod.rs
Normal file
@ -0,0 +1,66 @@
|
||||
use std::collections::HashMap;
|
||||
use std::fs;
|
||||
use std::os::unix::io::{OwnedFd, RawFd};
|
||||
use std::path::PathBuf;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum ProcessType {
|
||||
Pty,
|
||||
User,
|
||||
Socat,
|
||||
}
|
||||
|
||||
pub trait CgroupManager: Send + Sync {
|
||||
fn get_fd(&self, proc_type: ProcessType) -> Option<RawFd>;
|
||||
}
|
||||
|
||||
pub struct Cgroup2Manager {
|
||||
fds: HashMap<ProcessType, OwnedFd>,
|
||||
}
|
||||
|
||||
impl Cgroup2Manager {
|
||||
pub fn new(root: &str, configs: &[(ProcessType, &str, &[(&str, &str)])]) -> Result<Self, String> {
|
||||
let mut fds = HashMap::new();
|
||||
|
||||
for (proc_type, sub_path, properties) in configs {
|
||||
let full_path = PathBuf::from(root).join(sub_path);
|
||||
|
||||
fs::create_dir_all(&full_path).map_err(|e| {
|
||||
format!("failed to create cgroup {}: {e}", full_path.display())
|
||||
})?;
|
||||
|
||||
for (name, value) in *properties {
|
||||
let prop_path = full_path.join(name);
|
||||
fs::write(&prop_path, value).map_err(|e| {
|
||||
format!("failed to write cgroup property {}: {e}", prop_path.display())
|
||||
})?;
|
||||
}
|
||||
|
||||
let fd = nix::fcntl::open(
|
||||
&full_path,
|
||||
nix::fcntl::OFlag::O_RDONLY,
|
||||
nix::sys::stat::Mode::empty(),
|
||||
)
|
||||
.map_err(|e| format!("failed to open cgroup {}: {e}", full_path.display()))?;
|
||||
|
||||
fds.insert(*proc_type, fd);
|
||||
}
|
||||
|
||||
Ok(Self { fds })
|
||||
}
|
||||
}
|
||||
|
||||
impl CgroupManager for Cgroup2Manager {
|
||||
fn get_fd(&self, proc_type: ProcessType) -> Option<RawFd> {
|
||||
use std::os::unix::io::AsRawFd;
|
||||
self.fds.get(&proc_type).map(|fd| fd.as_raw_fd())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct NoopCgroupManager;
|
||||
|
||||
impl CgroupManager for NoopCgroupManager {
|
||||
fn get_fd(&self, _proc_type: ProcessType) -> Option<RawFd> {
|
||||
None
|
||||
}
|
||||
}
|
||||
11
envd-rs/src/config.rs
Normal file
11
envd-rs/src/config.rs
Normal file
@ -0,0 +1,11 @@
|
||||
use std::time::Duration;
|
||||
|
||||
pub const DEFAULT_PORT: u16 = 49983;
|
||||
pub const IDLE_TIMEOUT: Duration = Duration::from_secs(640);
|
||||
pub const CORS_MAX_AGE: Duration = Duration::from_secs(7200);
|
||||
pub const PORT_SCANNER_INTERVAL: Duration = Duration::from_millis(1000);
|
||||
pub const DEFAULT_USER: &str = "root";
|
||||
pub const WRENN_RUN_DIR: &str = "/run/wrenn";
|
||||
|
||||
pub const KILOBYTE: u64 = 1024;
|
||||
pub const MEGABYTE: u64 = 1024 * KILOBYTE;
|
||||
200
envd-rs/src/conntracker.rs
Normal file
200
envd-rs/src/conntracker.rs
Normal file
@ -0,0 +1,200 @@
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Mutex;
|
||||
|
||||
/// Tracks active TCP connections for snapshot/restore lifecycle.
|
||||
///
|
||||
/// Before snapshot: close idle connections, record active ones.
|
||||
/// After restore: close all pre-snapshot connections (zombie TCP sockets).
|
||||
///
|
||||
/// In Rust/axum, we don't have Go's ConnState callback. Instead we track
|
||||
/// connections via a tower middleware that registers connection IDs.
|
||||
/// For the initial implementation, we track by a simple connection counter
|
||||
/// and rely on axum's graceful shutdown mechanics.
|
||||
pub struct ConnTracker {
|
||||
inner: Mutex<ConnTrackerInner>,
|
||||
}
|
||||
|
||||
struct ConnTrackerInner {
|
||||
active: HashSet<u64>,
|
||||
pre_snapshot: Option<HashSet<u64>>,
|
||||
next_id: u64,
|
||||
keepalives_enabled: bool,
|
||||
}
|
||||
|
||||
impl ConnTracker {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
inner: Mutex::new(ConnTrackerInner {
|
||||
active: HashSet::new(),
|
||||
pre_snapshot: None,
|
||||
next_id: 0,
|
||||
keepalives_enabled: true,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn register_connection(&self) -> u64 {
|
||||
let mut inner = self.inner.lock().unwrap();
|
||||
let id = inner.next_id;
|
||||
inner.next_id += 1;
|
||||
inner.active.insert(id);
|
||||
id
|
||||
}
|
||||
|
||||
pub fn remove_connection(&self, id: u64) {
|
||||
let mut inner = self.inner.lock().unwrap();
|
||||
inner.active.remove(&id);
|
||||
if let Some(ref mut pre) = inner.pre_snapshot {
|
||||
pre.remove(&id);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn prepare_for_snapshot(&self) {
|
||||
let mut inner = self.inner.lock().unwrap();
|
||||
inner.keepalives_enabled = false;
|
||||
inner.pre_snapshot = Some(inner.active.clone());
|
||||
tracing::info!(
|
||||
active_connections = inner.active.len(),
|
||||
"snapshot: recorded pre-snapshot connections, keep-alives disabled"
|
||||
);
|
||||
}
|
||||
|
||||
pub fn restore_after_snapshot(&self) {
|
||||
let mut inner = self.inner.lock().unwrap();
|
||||
if let Some(pre) = inner.pre_snapshot.take() {
|
||||
let zombie_count = pre.len();
|
||||
for id in &pre {
|
||||
inner.active.remove(id);
|
||||
}
|
||||
if zombie_count > 0 {
|
||||
tracing::info!(zombie_count, "restore: closed zombie connections");
|
||||
}
|
||||
}
|
||||
inner.keepalives_enabled = true;
|
||||
}
|
||||
|
||||
pub fn keepalives_enabled(&self) -> bool {
|
||||
self.inner.lock().unwrap().keepalives_enabled
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn active_count(&self) -> usize {
|
||||
self.inner.lock().unwrap().active.len()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn register_assigns_sequential_ids() {
|
||||
let ct = ConnTracker::new();
|
||||
assert_eq!(ct.register_connection(), 0);
|
||||
assert_eq!(ct.register_connection(), 1);
|
||||
assert_eq!(ct.register_connection(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn remove_clears_active() {
|
||||
let ct = ConnTracker::new();
|
||||
let id = ct.register_connection();
|
||||
assert_eq!(ct.active_count(), 1);
|
||||
ct.remove_connection(id);
|
||||
assert_eq!(ct.active_count(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn remove_nonexistent_is_noop() {
|
||||
let ct = ConnTracker::new();
|
||||
ct.remove_connection(999);
|
||||
assert_eq!(ct.active_count(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prepare_disables_keepalives() {
|
||||
let ct = ConnTracker::new();
|
||||
assert!(ct.keepalives_enabled());
|
||||
ct.register_connection();
|
||||
ct.prepare_for_snapshot();
|
||||
assert!(!ct.keepalives_enabled());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn restore_removes_zombies_and_reenables_keepalives() {
|
||||
let ct = ConnTracker::new();
|
||||
let id0 = ct.register_connection();
|
||||
let id1 = ct.register_connection();
|
||||
ct.prepare_for_snapshot();
|
||||
ct.restore_after_snapshot();
|
||||
assert!(ct.keepalives_enabled());
|
||||
// Both pre-snapshot connections removed as zombies
|
||||
assert_eq!(ct.active_count(), 0);
|
||||
// IDs don't matter anymore, but remove shouldn't panic
|
||||
ct.remove_connection(id0);
|
||||
ct.remove_connection(id1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn restore_without_prepare_is_noop() {
|
||||
let ct = ConnTracker::new();
|
||||
let _id = ct.register_connection();
|
||||
ct.restore_after_snapshot();
|
||||
assert!(ct.keepalives_enabled());
|
||||
assert_eq!(ct.active_count(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_closed_before_restore_not_zombie() {
|
||||
let ct = ConnTracker::new();
|
||||
let id0 = ct.register_connection();
|
||||
let _id1 = ct.register_connection();
|
||||
ct.prepare_for_snapshot();
|
||||
// Close id0 during snapshot window
|
||||
ct.remove_connection(id0);
|
||||
assert_eq!(ct.active_count(), 1);
|
||||
ct.restore_after_snapshot();
|
||||
// id1 was zombie (still active at restore), id0 already gone
|
||||
assert_eq!(ct.active_count(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn post_snapshot_connection_survives_restore() {
|
||||
let ct = ConnTracker::new();
|
||||
ct.register_connection();
|
||||
ct.prepare_for_snapshot();
|
||||
// New connection after snapshot
|
||||
let _post = ct.register_connection();
|
||||
ct.restore_after_snapshot();
|
||||
// Pre-snapshot connection removed, post-snapshot survives
|
||||
assert_eq!(ct.active_count(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn full_lifecycle() {
|
||||
let ct = ConnTracker::new();
|
||||
let _a = ct.register_connection();
|
||||
let b = ct.register_connection();
|
||||
let _c = ct.register_connection();
|
||||
assert_eq!(ct.active_count(), 3);
|
||||
assert!(ct.keepalives_enabled());
|
||||
|
||||
ct.prepare_for_snapshot();
|
||||
assert!(!ct.keepalives_enabled());
|
||||
|
||||
let d = ct.register_connection();
|
||||
ct.remove_connection(b);
|
||||
|
||||
ct.restore_after_snapshot();
|
||||
assert!(ct.keepalives_enabled());
|
||||
// a and c were zombies, b removed before restore, d is post-snapshot
|
||||
assert_eq!(ct.active_count(), 1);
|
||||
ct.remove_connection(d);
|
||||
assert_eq!(ct.active_count(), 0);
|
||||
|
||||
// Can reuse tracker after restore
|
||||
let e = ct.register_connection();
|
||||
assert_eq!(ct.active_count(), 1);
|
||||
assert!(e > d);
|
||||
}
|
||||
}
|
||||
43
envd-rs/src/crypto/hmac_sha256.rs
Normal file
43
envd-rs/src/crypto/hmac_sha256.rs
Normal file
@ -0,0 +1,43 @@
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha2::Sha256;
|
||||
|
||||
type HmacSha256 = Hmac<Sha256>;
|
||||
|
||||
pub fn compute(key: &[u8], data: &[u8]) -> String {
|
||||
let mut mac = HmacSha256::new_from_slice(key).expect("HMAC accepts any key length");
|
||||
mac.update(data);
|
||||
let result = mac.finalize();
|
||||
hex::encode(result.into_bytes())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn rfc4231_tc1() {
|
||||
let key = &[0x0b; 20];
|
||||
let data = b"Hi There";
|
||||
assert_eq!(
|
||||
compute(key, data),
|
||||
"b0344c61d8db38535ca8afceaf0bf12b881dc200c9833da726e9376c2e32cff7"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rfc4231_tc2() {
|
||||
let key = b"Jefe";
|
||||
let data = b"what do ya want for nothing?";
|
||||
assert_eq!(
|
||||
compute(key, data),
|
||||
"5bdcc146bf60754e6a042426089575c75a003f089d2739839dec58b964ec3843"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn output_is_64_hex_chars() {
|
||||
let result = compute(b"key", b"data");
|
||||
assert_eq!(result.len(), 64);
|
||||
assert!(result.chars().all(|c| c.is_ascii_hexdigit()));
|
||||
}
|
||||
}
|
||||
3
envd-rs/src/crypto/mod.rs
Normal file
3
envd-rs/src/crypto/mod.rs
Normal file
@ -0,0 +1,3 @@
|
||||
pub mod sha256;
|
||||
pub mod sha512;
|
||||
pub mod hmac_sha256;
|
||||
54
envd-rs/src/crypto/sha256.rs
Normal file
54
envd-rs/src/crypto/sha256.rs
Normal file
@ -0,0 +1,54 @@
|
||||
use base64::Engine;
|
||||
use base64::engine::general_purpose::STANDARD_NO_PAD;
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
pub fn hash(data: &[u8]) -> String {
|
||||
let h = Sha256::digest(data);
|
||||
let encoded = STANDARD_NO_PAD.encode(h);
|
||||
format!("$sha256${encoded}")
|
||||
}
|
||||
|
||||
pub fn hash_without_prefix(data: &[u8]) -> String {
|
||||
let h = Sha256::digest(data);
|
||||
STANDARD_NO_PAD.encode(h)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
const VECTORS: &[(&[u8], &str)] = &[
|
||||
(b"", "47DEQpj8HBSa+/TImW+5JCeuQeRkm5NMpJWZG3hSuFU"),
|
||||
(b"abc", "ungWv48Bz+pBQUDeXa4iI7ADYaOWF3qctBD/YfIAFa0"),
|
||||
(b"abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq", "JI1qYdIGOLjlwCaTDD5gOaM85Flk/yFn9uzt1BnbBsE"),
|
||||
];
|
||||
|
||||
#[test]
|
||||
fn known_answer_with_prefix() {
|
||||
for (input, expected_b64) in VECTORS {
|
||||
let result = hash(input);
|
||||
assert_eq!(result, format!("$sha256${expected_b64}"), "input: {:?}", String::from_utf8_lossy(input));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn known_answer_without_prefix() {
|
||||
for (input, expected_b64) in VECTORS {
|
||||
let result = hash_without_prefix(input);
|
||||
assert_eq!(result, *expected_b64, "input: {:?}", String::from_utf8_lossy(input));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_base64_padding() {
|
||||
for (input, _) in VECTORS {
|
||||
assert!(!hash(input).contains('='));
|
||||
assert!(!hash_without_prefix(input).contains('='));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deterministic() {
|
||||
assert_eq!(hash(b"test"), hash(b"test"));
|
||||
}
|
||||
}
|
||||
43
envd-rs/src/crypto/sha512.rs
Normal file
43
envd-rs/src/crypto/sha512.rs
Normal file
@ -0,0 +1,43 @@
|
||||
use sha2::{Digest, Sha512};
|
||||
|
||||
pub fn hash_access_token(token: &str) -> String {
|
||||
let h = Sha512::digest(token.as_bytes());
|
||||
hex::encode(h)
|
||||
}
|
||||
|
||||
pub fn hash_access_token_bytes(token: &[u8]) -> String {
|
||||
let h = Sha512::digest(token);
|
||||
hex::encode(h)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
const VECTORS: &[(&str, &str)] = &[
|
||||
("", "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e"),
|
||||
("abc", "ddaf35a193617abacc417349ae20413112e6fa4e89a97ea20a9eeee64b55d39a2192992a274fc1a836ba3c23a3feebbd454d4423643ce80e2a9ac94fa54ca49f"),
|
||||
("abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq", "204a8fc6dda82f0a0ced7beb8e08a41657c16ef468b228a8279be331a703c33596fd15c13b1b07f9aa1d3bea57789ca031ad85c7a71dd70354ec631238ca3445"),
|
||||
];
|
||||
|
||||
#[test]
|
||||
fn known_answer() {
|
||||
for (input, expected) in VECTORS {
|
||||
assert_eq!(hash_access_token(input), *expected, "input: {input:?}");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn str_and_bytes_agree() {
|
||||
for (input, _) in VECTORS {
|
||||
assert_eq!(hash_access_token(input), hash_access_token_bytes(input.as_bytes()));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn output_is_lowercase_hex_128_chars() {
|
||||
let h = hash_access_token("anything");
|
||||
assert_eq!(h.len(), 128);
|
||||
assert!(h.chars().all(|c| c.is_ascii_hexdigit() && !c.is_ascii_uppercase()));
|
||||
}
|
||||
}
|
||||
118
envd-rs/src/execcontext.rs
Normal file
118
envd-rs/src/execcontext.rs
Normal file
@ -0,0 +1,118 @@
|
||||
use dashmap::DashMap;
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
||||
pub struct Defaults {
|
||||
pub env_vars: Arc<DashMap<String, String>>,
|
||||
user: RwLock<String>,
|
||||
workdir: RwLock<Option<String>>,
|
||||
}
|
||||
|
||||
impl Defaults {
|
||||
pub fn new(user: &str) -> Self {
|
||||
Self {
|
||||
env_vars: Arc::new(DashMap::new()),
|
||||
user: RwLock::new(user.to_string()),
|
||||
workdir: RwLock::new(None),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn user(&self) -> String {
|
||||
self.user.read().unwrap().clone()
|
||||
}
|
||||
|
||||
pub fn set_user(&self, user: String) {
|
||||
*self.user.write().unwrap() = user;
|
||||
}
|
||||
|
||||
pub fn workdir(&self) -> Option<String> {
|
||||
self.workdir.read().unwrap().clone()
|
||||
}
|
||||
|
||||
pub fn set_workdir(&self, workdir: Option<String>) {
|
||||
*self.workdir.write().unwrap() = workdir;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn resolve_default_workdir(workdir: &str, default_workdir: Option<&str>) -> String {
|
||||
if !workdir.is_empty() {
|
||||
return workdir.to_string();
|
||||
}
|
||||
if let Some(dw) = default_workdir {
|
||||
return dw.to_string();
|
||||
}
|
||||
String::new()
|
||||
}
|
||||
|
||||
pub fn resolve_default_username<'a>(
|
||||
username: Option<&'a str>,
|
||||
default_username: &'a str,
|
||||
) -> Result<&'a str, &'static str> {
|
||||
if let Some(u) = username {
|
||||
return Ok(u);
|
||||
}
|
||||
if !default_username.is_empty() {
|
||||
return Ok(default_username);
|
||||
}
|
||||
Err("username not provided")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn workdir_explicit_overrides_default() {
|
||||
assert_eq!(resolve_default_workdir("/explicit", Some("/default")), "/explicit");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn workdir_empty_uses_default() {
|
||||
assert_eq!(resolve_default_workdir("", Some("/default")), "/default");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn workdir_empty_no_default_returns_empty() {
|
||||
assert_eq!(resolve_default_workdir("", None), "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn workdir_explicit_ignores_none_default() {
|
||||
assert_eq!(resolve_default_workdir("/explicit", None), "/explicit");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn username_explicit_returns_explicit() {
|
||||
assert_eq!(resolve_default_username(Some("root"), "wrenn").unwrap(), "root");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn username_none_uses_default() {
|
||||
assert_eq!(resolve_default_username(None, "wrenn").unwrap(), "wrenn");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn username_none_empty_default_errors() {
|
||||
assert!(resolve_default_username(None, "").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn username_some_overrides_empty_default() {
|
||||
assert_eq!(resolve_default_username(Some("root"), "").unwrap(), "root");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn defaults_user_set_and_get() {
|
||||
let d = Defaults::new("initial");
|
||||
assert_eq!(d.user(), "initial");
|
||||
d.set_user("changed".into());
|
||||
assert_eq!(d.user(), "changed");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn defaults_workdir_initially_none() {
|
||||
let d = Defaults::new("user");
|
||||
assert!(d.workdir().is_none());
|
||||
d.set_workdir(Some("/home".into()));
|
||||
assert_eq!(d.workdir().unwrap(), "/home");
|
||||
}
|
||||
}
|
||||
336
envd-rs/src/http/encoding.rs
Normal file
336
envd-rs/src/http/encoding.rs
Normal file
@ -0,0 +1,336 @@
|
||||
use axum::http::Request;
|
||||
|
||||
const ENCODING_GZIP: &str = "gzip";
|
||||
const ENCODING_IDENTITY: &str = "identity";
|
||||
const ENCODING_WILDCARD: &str = "*";
|
||||
|
||||
const SUPPORTED_ENCODINGS: &[&str] = &[ENCODING_GZIP];
|
||||
|
||||
struct EncodingWithQuality {
|
||||
encoding: String,
|
||||
quality: f64,
|
||||
}
|
||||
|
||||
fn parse_encoding_with_quality(value: &str) -> EncodingWithQuality {
|
||||
let value = value.trim();
|
||||
let mut quality = 1.0;
|
||||
|
||||
if let Some(idx) = value.find(';') {
|
||||
let params = &value[idx + 1..];
|
||||
let enc = value[..idx].trim();
|
||||
for param in params.split(';') {
|
||||
let param = param.trim();
|
||||
if let Some(stripped) = param.strip_prefix("q=").or_else(|| param.strip_prefix("Q=")) {
|
||||
if let Ok(q) = stripped.parse::<f64>() {
|
||||
quality = q;
|
||||
}
|
||||
}
|
||||
}
|
||||
return EncodingWithQuality {
|
||||
encoding: enc.to_ascii_lowercase(),
|
||||
quality,
|
||||
};
|
||||
}
|
||||
|
||||
EncodingWithQuality {
|
||||
encoding: value.to_ascii_lowercase(),
|
||||
quality,
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_accept_encoding_header(header: &str) -> (Vec<EncodingWithQuality>, bool) {
|
||||
if header.is_empty() {
|
||||
return (Vec::new(), false);
|
||||
}
|
||||
|
||||
let encodings: Vec<EncodingWithQuality> =
|
||||
header.split(',').map(|v| parse_encoding_with_quality(v)).collect();
|
||||
|
||||
let mut identity_rejected = false;
|
||||
let mut identity_explicitly_accepted = false;
|
||||
let mut wildcard_rejected = false;
|
||||
|
||||
for eq in &encodings {
|
||||
match eq.encoding.as_str() {
|
||||
ENCODING_IDENTITY => {
|
||||
if eq.quality == 0.0 {
|
||||
identity_rejected = true;
|
||||
} else {
|
||||
identity_explicitly_accepted = true;
|
||||
}
|
||||
}
|
||||
ENCODING_WILDCARD => {
|
||||
if eq.quality == 0.0 {
|
||||
wildcard_rejected = true;
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
if wildcard_rejected && !identity_explicitly_accepted {
|
||||
identity_rejected = true;
|
||||
}
|
||||
|
||||
(encodings, identity_rejected)
|
||||
}
|
||||
|
||||
pub fn is_identity_acceptable<B>(r: &Request<B>) -> bool {
|
||||
let header = r
|
||||
.headers()
|
||||
.get("accept-encoding")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("");
|
||||
let (_, rejected) = parse_accept_encoding_header(header);
|
||||
!rejected
|
||||
}
|
||||
|
||||
pub fn parse_accept_encoding<B>(r: &Request<B>) -> Result<&'static str, String> {
|
||||
let header = r
|
||||
.headers()
|
||||
.get("accept-encoding")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("");
|
||||
|
||||
if header.is_empty() {
|
||||
return Ok(ENCODING_IDENTITY);
|
||||
}
|
||||
|
||||
let (mut encodings, identity_rejected) = parse_accept_encoding_header(header);
|
||||
encodings.sort_by(|a, b| b.quality.partial_cmp(&a.quality).unwrap_or(std::cmp::Ordering::Equal));
|
||||
|
||||
for eq in &encodings {
|
||||
if eq.quality == 0.0 {
|
||||
continue;
|
||||
}
|
||||
if eq.encoding == ENCODING_IDENTITY {
|
||||
return Ok(ENCODING_IDENTITY);
|
||||
}
|
||||
if eq.encoding == ENCODING_WILDCARD {
|
||||
if identity_rejected && !SUPPORTED_ENCODINGS.is_empty() {
|
||||
return Ok(SUPPORTED_ENCODINGS[0]);
|
||||
}
|
||||
return Ok(ENCODING_IDENTITY);
|
||||
}
|
||||
if eq.encoding == ENCODING_GZIP {
|
||||
return Ok(ENCODING_GZIP);
|
||||
}
|
||||
}
|
||||
|
||||
if !identity_rejected {
|
||||
return Ok(ENCODING_IDENTITY);
|
||||
}
|
||||
|
||||
Err(format!("no acceptable encoding found, supported: {SUPPORTED_ENCODINGS:?}"))
|
||||
}
|
||||
|
||||
pub fn parse_content_encoding<B>(r: &Request<B>) -> Result<&'static str, String> {
|
||||
let header = r
|
||||
.headers()
|
||||
.get("content-encoding")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("");
|
||||
|
||||
if header.is_empty() {
|
||||
return Ok(ENCODING_IDENTITY);
|
||||
}
|
||||
|
||||
let encoding = header.trim().to_ascii_lowercase();
|
||||
if encoding == ENCODING_IDENTITY {
|
||||
return Ok(ENCODING_IDENTITY);
|
||||
}
|
||||
if SUPPORTED_ENCODINGS.contains(&encoding.as_str()) {
|
||||
return Ok(ENCODING_GZIP);
|
||||
}
|
||||
|
||||
Err(format!("unsupported Content-Encoding: {header}, supported: {SUPPORTED_ENCODINGS:?}"))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use axum::http::Request;
|
||||
|
||||
fn req_with_accept(v: &str) -> Request<()> {
|
||||
Request::builder()
|
||||
.header("accept-encoding", v)
|
||||
.body(())
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn req_with_content(v: &str) -> Request<()> {
|
||||
Request::builder()
|
||||
.header("content-encoding", v)
|
||||
.body(())
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn req_no_headers() -> Request<()> {
|
||||
Request::builder().body(()).unwrap()
|
||||
}
|
||||
|
||||
// parse_encoding_with_quality
|
||||
|
||||
#[test]
|
||||
fn encoding_quality_default_1() {
|
||||
let eq = parse_encoding_with_quality("gzip");
|
||||
assert_eq!(eq.encoding, "gzip");
|
||||
assert_eq!(eq.quality, 1.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encoding_quality_explicit() {
|
||||
let eq = parse_encoding_with_quality("gzip;q=0.8");
|
||||
assert_eq!(eq.encoding, "gzip");
|
||||
assert_eq!(eq.quality, 0.8);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encoding_quality_case_insensitive() {
|
||||
let eq = parse_encoding_with_quality("GZIP;Q=0.5");
|
||||
assert_eq!(eq.encoding, "gzip");
|
||||
assert_eq!(eq.quality, 0.5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encoding_quality_zero() {
|
||||
let eq = parse_encoding_with_quality("gzip;q=0");
|
||||
assert_eq!(eq.quality, 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encoding_quality_whitespace_trimmed() {
|
||||
let eq = parse_encoding_with_quality(" gzip ; q=0.9 ");
|
||||
assert_eq!(eq.encoding, "gzip");
|
||||
assert_eq!(eq.quality, 0.9);
|
||||
}
|
||||
|
||||
// parse_accept_encoding_header
|
||||
|
||||
#[test]
|
||||
fn accept_header_empty() {
|
||||
let (encs, rejected) = parse_accept_encoding_header("");
|
||||
assert!(encs.is_empty());
|
||||
assert!(!rejected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn accept_header_identity_q0_rejects() {
|
||||
let (_, rejected) = parse_accept_encoding_header("identity;q=0");
|
||||
assert!(rejected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn accept_header_wildcard_q0_rejects_identity() {
|
||||
let (_, rejected) = parse_accept_encoding_header("*;q=0");
|
||||
assert!(rejected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn accept_header_wildcard_q0_but_identity_explicit_accepted() {
|
||||
let (_, rejected) = parse_accept_encoding_header("*;q=0, identity");
|
||||
assert!(!rejected);
|
||||
}
|
||||
|
||||
// parse_accept_encoding (full)
|
||||
|
||||
#[test]
|
||||
fn accept_encoding_no_header_returns_identity() {
|
||||
assert_eq!(parse_accept_encoding(&req_no_headers()).unwrap(), "identity");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn accept_encoding_gzip() {
|
||||
assert_eq!(parse_accept_encoding(&req_with_accept("gzip")).unwrap(), "gzip");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn accept_encoding_identity_explicit() {
|
||||
assert_eq!(parse_accept_encoding(&req_with_accept("identity")).unwrap(), "identity");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn accept_encoding_gzip_higher_quality() {
|
||||
assert_eq!(
|
||||
parse_accept_encoding(&req_with_accept("identity;q=0.1, gzip;q=0.9")).unwrap(),
|
||||
"gzip"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn accept_encoding_wildcard_returns_identity() {
|
||||
assert_eq!(parse_accept_encoding(&req_with_accept("*")).unwrap(), "identity");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn accept_encoding_wildcard_identity_rejected_returns_gzip() {
|
||||
assert_eq!(
|
||||
parse_accept_encoding(&req_with_accept("identity;q=0, *")).unwrap(),
|
||||
"gzip"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn accept_encoding_all_rejected_errors() {
|
||||
assert!(parse_accept_encoding(&req_with_accept("identity;q=0, *;q=0")).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn accept_encoding_unsupported_only_falls_to_identity() {
|
||||
assert_eq!(parse_accept_encoding(&req_with_accept("br")).unwrap(), "identity");
|
||||
}
|
||||
|
||||
// is_identity_acceptable
|
||||
|
||||
#[test]
|
||||
fn identity_acceptable_no_header() {
|
||||
assert!(is_identity_acceptable(&req_no_headers()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn identity_acceptable_gzip_only() {
|
||||
assert!(is_identity_acceptable(&req_with_accept("gzip")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn identity_not_acceptable_identity_q0() {
|
||||
assert!(!is_identity_acceptable(&req_with_accept("identity;q=0")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn identity_not_acceptable_wildcard_q0() {
|
||||
assert!(!is_identity_acceptable(&req_with_accept("*;q=0")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn identity_acceptable_wildcard_q0_but_identity_explicit() {
|
||||
assert!(is_identity_acceptable(&req_with_accept("*;q=0, identity")));
|
||||
}
|
||||
|
||||
// parse_content_encoding
|
||||
|
||||
#[test]
|
||||
fn content_encoding_empty_returns_identity() {
|
||||
assert_eq!(parse_content_encoding(&req_no_headers()).unwrap(), "identity");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn content_encoding_gzip() {
|
||||
assert_eq!(parse_content_encoding(&req_with_content("gzip")).unwrap(), "gzip");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn content_encoding_identity_explicit() {
|
||||
assert_eq!(parse_content_encoding(&req_with_content("identity")).unwrap(), "identity");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn content_encoding_unsupported_errors() {
|
||||
assert!(parse_content_encoding(&req_with_content("br")).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn content_encoding_case_insensitive() {
|
||||
assert_eq!(parse_content_encoding(&req_with_content("GZIP")).unwrap(), "gzip");
|
||||
}
|
||||
}
|
||||
25
envd-rs/src/http/envs.rs
Normal file
25
envd-rs/src/http/envs.rs
Normal file
@ -0,0 +1,25 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::Json;
|
||||
use axum::extract::State;
|
||||
use axum::http::header;
|
||||
use axum::response::IntoResponse;
|
||||
|
||||
use crate::state::AppState;
|
||||
|
||||
pub async fn get_envs(State(state): State<Arc<AppState>>) -> impl IntoResponse {
|
||||
tracing::debug!("getting env vars");
|
||||
|
||||
let envs: HashMap<String, String> = state
|
||||
.defaults
|
||||
.env_vars
|
||||
.iter()
|
||||
.map(|entry| (entry.key().clone(), entry.value().clone()))
|
||||
.collect();
|
||||
|
||||
(
|
||||
[(header::CACHE_CONTROL, "no-store")],
|
||||
Json(envs),
|
||||
)
|
||||
}
|
||||
20
envd-rs/src/http/error.rs
Normal file
20
envd-rs/src/http/error.rs
Normal file
@ -0,0 +1,20 @@
|
||||
use axum::Json;
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::IntoResponse;
|
||||
use serde::Serialize;
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ErrorBody {
|
||||
code: u16,
|
||||
message: String,
|
||||
}
|
||||
|
||||
pub fn json_error(status: StatusCode, message: &str) -> impl IntoResponse {
|
||||
(
|
||||
status,
|
||||
Json(ErrorBody {
|
||||
code: status.as_u16(),
|
||||
message: message.to_string(),
|
||||
}),
|
||||
)
|
||||
}
|
||||
447
envd-rs/src/http/files.rs
Normal file
447
envd-rs/src/http/files.rs
Normal file
@ -0,0 +1,447 @@
|
||||
use std::io::Write as _;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::body::Body;
|
||||
use axum::extract::{FromRequest, Query, Request, State};
|
||||
use axum::http::{StatusCode, header};
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::auth::signing;
|
||||
use crate::execcontext;
|
||||
use crate::http::encoding;
|
||||
use crate::permissions::path::{ensure_dirs, expand_and_resolve};
|
||||
use crate::permissions::user::lookup_user;
|
||||
use crate::state::AppState;
|
||||
|
||||
const ACCESS_TOKEN_HEADER: &str = "x-access-token";
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct FileParams {
|
||||
pub path: Option<String>,
|
||||
pub username: Option<String>,
|
||||
pub signature: Option<String>,
|
||||
pub signature_expiration: Option<i64>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct EntryInfo {
|
||||
path: String,
|
||||
name: String,
|
||||
r#type: &'static str,
|
||||
}
|
||||
|
||||
fn json_error(status: StatusCode, msg: &str) -> Response {
|
||||
let body = serde_json::json!({ "code": status.as_u16(), "message": msg });
|
||||
(status, axum::Json(body)).into_response()
|
||||
}
|
||||
|
||||
fn extract_header_token(req: &Request) -> Option<&str> {
|
||||
req.headers()
|
||||
.get(ACCESS_TOKEN_HEADER)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
}
|
||||
|
||||
fn validate_file_signing(
|
||||
state: &AppState,
|
||||
header_token: Option<&str>,
|
||||
params: &FileParams,
|
||||
path: &str,
|
||||
operation: &str,
|
||||
username: &str,
|
||||
) -> Result<(), String> {
|
||||
signing::validate_signing(
|
||||
&state.access_token,
|
||||
header_token,
|
||||
params.signature.as_deref(),
|
||||
params.signature_expiration,
|
||||
username,
|
||||
path,
|
||||
operation,
|
||||
)
|
||||
}
|
||||
|
||||
/// GET /files — download a file
|
||||
pub async fn get_files(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Query(params): Query<FileParams>,
|
||||
req: Request,
|
||||
) -> Response {
|
||||
let path_str = params.path.as_deref().unwrap_or("");
|
||||
let header_token = extract_header_token(&req);
|
||||
|
||||
let default_user = state.defaults.user();
|
||||
let username = match execcontext::resolve_default_username(
|
||||
params.username.as_deref(),
|
||||
&default_user,
|
||||
) {
|
||||
Ok(u) => u.to_string(),
|
||||
Err(e) => return json_error(StatusCode::BAD_REQUEST, e),
|
||||
};
|
||||
|
||||
if let Err(e) = validate_file_signing(
|
||||
&state,
|
||||
header_token,
|
||||
¶ms,
|
||||
path_str,
|
||||
signing::READ_OPERATION,
|
||||
&username,
|
||||
) {
|
||||
return json_error(StatusCode::UNAUTHORIZED, &e);
|
||||
}
|
||||
|
||||
let user = match lookup_user(&username) {
|
||||
Ok(u) => u,
|
||||
Err(e) => return json_error(StatusCode::UNAUTHORIZED, &e),
|
||||
};
|
||||
|
||||
let home_dir = user.dir.to_string_lossy().to_string();
|
||||
let default_workdir = state.defaults.workdir();
|
||||
let resolved = match expand_and_resolve(path_str, &home_dir, default_workdir.as_deref())
|
||||
{
|
||||
Ok(p) => p,
|
||||
Err(e) => return json_error(StatusCode::BAD_REQUEST, &e),
|
||||
};
|
||||
|
||||
let meta = match std::fs::metadata(&resolved) {
|
||||
Ok(m) => m,
|
||||
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
|
||||
return json_error(
|
||||
StatusCode::NOT_FOUND,
|
||||
&format!("path '{}' does not exist", resolved),
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
return json_error(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
&format!("error checking path: {e}"),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
if meta.is_dir() {
|
||||
return json_error(
|
||||
StatusCode::BAD_REQUEST,
|
||||
&format!("path '{}' is a directory", resolved),
|
||||
);
|
||||
}
|
||||
|
||||
if !meta.file_type().is_file() {
|
||||
return json_error(
|
||||
StatusCode::BAD_REQUEST,
|
||||
&format!("path '{}' is not a regular file", resolved),
|
||||
);
|
||||
}
|
||||
|
||||
let accept_enc = match encoding::parse_accept_encoding(&req) {
|
||||
Ok(e) => e,
|
||||
Err(e) => return json_error(StatusCode::NOT_ACCEPTABLE, &e),
|
||||
};
|
||||
|
||||
let has_range_or_conditional = req.headers().get("range").is_some()
|
||||
|| req.headers().get("if-modified-since").is_some()
|
||||
|| req.headers().get("if-none-match").is_some()
|
||||
|| req.headers().get("if-range").is_some();
|
||||
|
||||
let use_encoding = if has_range_or_conditional {
|
||||
if !encoding::is_identity_acceptable(&req) {
|
||||
return json_error(
|
||||
StatusCode::NOT_ACCEPTABLE,
|
||||
"identity encoding not acceptable for Range or conditional request",
|
||||
);
|
||||
}
|
||||
"identity"
|
||||
} else {
|
||||
accept_enc
|
||||
};
|
||||
|
||||
let file_data = match std::fs::read(&resolved) {
|
||||
Ok(d) => d,
|
||||
Err(e) => {
|
||||
return json_error(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
&format!("error reading file: {e}"),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let filename = Path::new(&resolved)
|
||||
.file_name()
|
||||
.map(|n| n.to_string_lossy().to_string())
|
||||
.unwrap_or_default();
|
||||
|
||||
let content_disposition = format!("inline; filename=\"{}\"", filename);
|
||||
let content_type = mime_guess::from_path(&resolved)
|
||||
.first_raw()
|
||||
.unwrap_or("application/octet-stream");
|
||||
|
||||
if use_encoding == "gzip" {
|
||||
let mut encoder =
|
||||
flate2::write::GzEncoder::new(Vec::new(), flate2::Compression::default());
|
||||
if let Err(e) = encoder.write_all(&file_data) {
|
||||
return json_error(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
&format!("gzip encoding error: {e}"),
|
||||
);
|
||||
}
|
||||
let compressed = match encoder.finish() {
|
||||
Ok(d) => d,
|
||||
Err(e) => {
|
||||
return json_error(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
&format!("gzip finish error: {e}"),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
return Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header(header::CONTENT_TYPE, content_type)
|
||||
.header(header::CONTENT_ENCODING, "gzip")
|
||||
.header(header::CONTENT_DISPOSITION, content_disposition)
|
||||
.header(header::VARY, "Accept-Encoding")
|
||||
.body(Body::from(compressed))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header(header::CONTENT_TYPE, content_type)
|
||||
.header(header::CONTENT_DISPOSITION, content_disposition)
|
||||
.header(header::VARY, "Accept-Encoding")
|
||||
.header(header::CONTENT_LENGTH, file_data.len())
|
||||
.body(Body::from(file_data))
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// POST /files — upload file(s) via multipart
|
||||
pub async fn post_files(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Query(params): Query<FileParams>,
|
||||
req: Request,
|
||||
) -> Response {
|
||||
let path_str = params.path.as_deref().unwrap_or("");
|
||||
let header_token = extract_header_token(&req);
|
||||
|
||||
let default_user = state.defaults.user();
|
||||
let username = match execcontext::resolve_default_username(
|
||||
params.username.as_deref(),
|
||||
&default_user,
|
||||
) {
|
||||
Ok(u) => u.to_string(),
|
||||
Err(e) => return json_error(StatusCode::BAD_REQUEST, e),
|
||||
};
|
||||
|
||||
if let Err(e) = validate_file_signing(
|
||||
&state,
|
||||
header_token,
|
||||
¶ms,
|
||||
path_str,
|
||||
signing::WRITE_OPERATION,
|
||||
&username,
|
||||
) {
|
||||
return json_error(StatusCode::UNAUTHORIZED, &e);
|
||||
}
|
||||
|
||||
let user = match lookup_user(&username) {
|
||||
Ok(u) => u,
|
||||
Err(e) => return json_error(StatusCode::UNAUTHORIZED, &e),
|
||||
};
|
||||
|
||||
let home_dir = user.dir.to_string_lossy().to_string();
|
||||
let uid = user.uid;
|
||||
let gid = user.gid;
|
||||
|
||||
let content_enc = match encoding::parse_content_encoding(&req) {
|
||||
Ok(e) => e,
|
||||
Err(e) => return json_error(StatusCode::BAD_REQUEST, &e),
|
||||
};
|
||||
|
||||
let mut multipart = match axum::extract::Multipart::from_request(req, &()).await {
|
||||
Ok(m) => m,
|
||||
Err(e) => {
|
||||
return json_error(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
&format!("error parsing multipart: {e}"),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let mut uploaded: Vec<EntryInfo> = Vec::new();
|
||||
let default_workdir = state.defaults.workdir();
|
||||
|
||||
while let Ok(Some(field)) = multipart.next_field().await {
|
||||
let field_name = field.name().unwrap_or("").to_string();
|
||||
if field_name != "file" {
|
||||
continue;
|
||||
}
|
||||
|
||||
let file_path = if !path_str.is_empty() {
|
||||
match expand_and_resolve(path_str, &home_dir, default_workdir.as_deref()) {
|
||||
Ok(p) => p,
|
||||
Err(e) => return json_error(StatusCode::BAD_REQUEST, &e),
|
||||
}
|
||||
} else {
|
||||
let fname = field
|
||||
.file_name()
|
||||
.unwrap_or("upload")
|
||||
.to_string();
|
||||
match expand_and_resolve(&fname, &home_dir, default_workdir.as_deref()) {
|
||||
Ok(p) => p,
|
||||
Err(e) => return json_error(StatusCode::BAD_REQUEST, &e),
|
||||
}
|
||||
};
|
||||
|
||||
if uploaded.iter().any(|e| e.path == file_path) {
|
||||
return json_error(
|
||||
StatusCode::BAD_REQUEST,
|
||||
&format!("cannot upload multiple files to same path '{}'", file_path),
|
||||
);
|
||||
}
|
||||
|
||||
let raw_bytes = match field.bytes().await {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
return json_error(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
&format!("error reading field: {e}"),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let data = if content_enc == "gzip" {
|
||||
use std::io::Read;
|
||||
let mut decoder = flate2::read::GzDecoder::new(&raw_bytes[..]);
|
||||
let mut buf = Vec::new();
|
||||
match decoder.read_to_end(&mut buf) {
|
||||
Ok(_) => buf,
|
||||
Err(e) => {
|
||||
return json_error(
|
||||
StatusCode::BAD_REQUEST,
|
||||
&format!("gzip decompression failed: {e}"),
|
||||
);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
raw_bytes.to_vec()
|
||||
};
|
||||
|
||||
if let Err(e) = process_file(&file_path, &data, uid, gid) {
|
||||
let (status, msg) = e;
|
||||
return json_error(status, &msg);
|
||||
}
|
||||
|
||||
let name = Path::new(&file_path)
|
||||
.file_name()
|
||||
.map(|n| n.to_string_lossy().to_string())
|
||||
.unwrap_or_default();
|
||||
|
||||
uploaded.push(EntryInfo {
|
||||
path: file_path,
|
||||
name,
|
||||
r#type: "file",
|
||||
});
|
||||
}
|
||||
|
||||
axum::Json(uploaded).into_response()
|
||||
}
|
||||
|
||||
fn process_file(
|
||||
path: &str,
|
||||
data: &[u8],
|
||||
uid: nix::unistd::Uid,
|
||||
gid: nix::unistd::Gid,
|
||||
) -> Result<(), (StatusCode, String)> {
|
||||
let dir = Path::new(path)
|
||||
.parent()
|
||||
.map(|p| p.to_string_lossy().to_string())
|
||||
.unwrap_or_default();
|
||||
|
||||
if !dir.is_empty() {
|
||||
ensure_dirs(&dir, uid, gid).map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("error ensuring directories: {e}"),
|
||||
)
|
||||
})?;
|
||||
}
|
||||
|
||||
let can_pre_chown = match std::fs::metadata(path) {
|
||||
Ok(meta) => {
|
||||
if meta.is_dir() {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
format!("path is a directory: {path}"),
|
||||
));
|
||||
}
|
||||
true
|
||||
}
|
||||
Err(e) if e.kind() == std::io::ErrorKind::NotFound => false,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("error getting file info: {e}"),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let mut chowned = false;
|
||||
if can_pre_chown {
|
||||
match std::os::unix::fs::chown(path, Some(uid.as_raw()), Some(gid.as_raw())) {
|
||||
Ok(()) => chowned = true,
|
||||
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("error changing ownership: {e}"),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut file = std::fs::OpenOptions::new()
|
||||
.write(true)
|
||||
.create(true)
|
||||
.truncate(true)
|
||||
.mode(0o666)
|
||||
.open(path)
|
||||
.map_err(|e| {
|
||||
if e.raw_os_error() == Some(libc::ENOSPC) {
|
||||
return (
|
||||
StatusCode::INSUFFICIENT_STORAGE,
|
||||
"not enough disk space available".to_string(),
|
||||
);
|
||||
}
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("error opening file: {e}"),
|
||||
)
|
||||
})?;
|
||||
|
||||
if !chowned {
|
||||
std::os::unix::fs::chown(path, Some(uid.as_raw()), Some(gid.as_raw())).map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("error changing ownership: {e}"),
|
||||
)
|
||||
})?;
|
||||
}
|
||||
|
||||
file.write_all(data).map_err(|e| {
|
||||
if e.raw_os_error() == Some(libc::ENOSPC) {
|
||||
return (
|
||||
StatusCode::INSUFFICIENT_STORAGE,
|
||||
"not enough disk space available".to_string(),
|
||||
);
|
||||
}
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("error writing file: {e}"),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
use std::os::unix::fs::OpenOptionsExt;
|
||||
20
envd-rs/src/http/health.rs
Normal file
20
envd-rs/src/http/health.rs
Normal file
@ -0,0 +1,20 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::Json;
|
||||
use axum::extract::State;
|
||||
use axum::http::header;
|
||||
use axum::response::IntoResponse;
|
||||
use serde_json::json;
|
||||
|
||||
use crate::state::AppState;
|
||||
|
||||
pub async fn get_health(State(state): State<Arc<AppState>>) -> impl IntoResponse {
|
||||
state.try_restore_recovery();
|
||||
|
||||
tracing::trace!("health check");
|
||||
|
||||
(
|
||||
[(header::CACHE_CONTROL, "no-store")],
|
||||
Json(json!({ "version": state.version })),
|
||||
)
|
||||
}
|
||||
249
envd-rs/src/http/init.rs
Normal file
249
envd-rs/src/http/init.rs
Normal file
@ -0,0 +1,249 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::Json;
|
||||
use axum::extract::State;
|
||||
use axum::http::{StatusCode, header};
|
||||
use axum::response::IntoResponse;
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::state::AppState;
|
||||
|
||||
#[derive(Deserialize, Default)]
|
||||
pub struct InitRequest {
|
||||
#[serde(rename = "access_token")]
|
||||
pub access_token: Option<String>,
|
||||
#[serde(rename = "defaultUser")]
|
||||
pub default_user: Option<String>,
|
||||
#[serde(rename = "defaultWorkdir")]
|
||||
pub default_workdir: Option<String>,
|
||||
#[serde(rename = "envVars")]
|
||||
pub env_vars: Option<HashMap<String, String>>,
|
||||
#[serde(rename = "hyperloop_ip")]
|
||||
pub hyperloop_ip: Option<String>,
|
||||
pub timestamp: Option<String>,
|
||||
#[serde(rename = "volume_mounts")]
|
||||
pub volume_mounts: Option<Vec<VolumeMount>>,
|
||||
pub sandbox_id: Option<String>,
|
||||
pub template_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct VolumeMount {
|
||||
pub nfs_target: String,
|
||||
pub path: String,
|
||||
}
|
||||
|
||||
/// POST /init — called by host agent after boot and after every resume.
|
||||
pub async fn post_init(
|
||||
State(state): State<Arc<AppState>>,
|
||||
body: Option<Json<InitRequest>>,
|
||||
) -> impl IntoResponse {
|
||||
let init_req = body.map(|b| b.0).unwrap_or_default();
|
||||
|
||||
// Validate access token if provided
|
||||
if let Some(ref token_str) = init_req.access_token {
|
||||
if let Err(e) = validate_init_access_token(&state, token_str).await {
|
||||
tracing::error!(error = %e, "init: access token validation failed");
|
||||
return (StatusCode::UNAUTHORIZED, e).into_response();
|
||||
}
|
||||
}
|
||||
|
||||
// Idempotent timestamp check
|
||||
if let Some(ref ts_str) = init_req.timestamp {
|
||||
if let Ok(ts) = chrono_parse_to_nanos(ts_str) {
|
||||
if !state.last_set_time.set_to_greater(ts) {
|
||||
// Stale request, skip data updates
|
||||
return trigger_restore_and_respond(&state).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply env vars
|
||||
if let Some(ref vars) = init_req.env_vars {
|
||||
tracing::debug!(count = vars.len(), "setting env vars");
|
||||
for (k, v) in vars {
|
||||
state.defaults.env_vars.insert(k.clone(), v.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Set access token
|
||||
if let Some(ref token_str) = init_req.access_token {
|
||||
if !token_str.is_empty() {
|
||||
tracing::debug!("setting access token");
|
||||
let _ = state.access_token.set(token_str.as_bytes());
|
||||
} else if state.access_token.is_set() {
|
||||
tracing::debug!("clearing access token");
|
||||
state.access_token.destroy();
|
||||
}
|
||||
}
|
||||
|
||||
// Set default user
|
||||
if let Some(ref user) = init_req.default_user {
|
||||
if !user.is_empty() {
|
||||
tracing::debug!(user = %user, "setting default user");
|
||||
state.defaults.set_user(user.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Set default workdir
|
||||
if let Some(ref workdir) = init_req.default_workdir {
|
||||
if !workdir.is_empty() {
|
||||
tracing::debug!(workdir = %workdir, "setting default workdir");
|
||||
state.defaults.set_workdir(Some(workdir.clone()));
|
||||
}
|
||||
}
|
||||
|
||||
// Hyperloop /etc/hosts setup
|
||||
if let Some(ref ip) = init_req.hyperloop_ip {
|
||||
let ip = ip.clone();
|
||||
let env_vars = Arc::clone(&state.defaults.env_vars);
|
||||
tokio::spawn(async move {
|
||||
setup_hyperloop(&ip, &env_vars).await;
|
||||
});
|
||||
}
|
||||
|
||||
// NFS mounts
|
||||
if let Some(ref mounts) = init_req.volume_mounts {
|
||||
for mount in mounts {
|
||||
let target = mount.nfs_target.clone();
|
||||
let path = mount.path.clone();
|
||||
tokio::spawn(async move {
|
||||
setup_nfs(&target, &path).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Set sandbox/template metadata from request body.
|
||||
if let Some(ref id) = init_req.sandbox_id {
|
||||
tracing::debug!(sandbox_id = %id, "setting sandbox ID from init request");
|
||||
// SAFETY: envd is single-threaded at init time; no concurrent env reads.
|
||||
unsafe { std::env::set_var("WRENN_SANDBOX_ID", id) };
|
||||
write_run_file(".WRENN_SANDBOX_ID", id);
|
||||
state.defaults.env_vars.insert("WRENN_SANDBOX_ID".into(), id.clone());
|
||||
}
|
||||
if let Some(ref id) = init_req.template_id {
|
||||
tracing::debug!(template_id = %id, "setting template ID from init request");
|
||||
// SAFETY: envd is single-threaded at init time; no concurrent env reads.
|
||||
unsafe { std::env::set_var("WRENN_TEMPLATE_ID", id) };
|
||||
write_run_file(".WRENN_TEMPLATE_ID", id);
|
||||
state.defaults.env_vars.insert("WRENN_TEMPLATE_ID".into(), id.clone());
|
||||
}
|
||||
|
||||
trigger_restore_and_respond(&state).await
|
||||
}
|
||||
|
||||
async fn trigger_restore_and_respond(state: &AppState) -> axum::response::Response {
|
||||
state.try_restore_recovery();
|
||||
|
||||
(
|
||||
StatusCode::NO_CONTENT,
|
||||
[(header::CACHE_CONTROL, "no-store")],
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
async fn validate_init_access_token(state: &AppState, request_token: &str) -> Result<(), String> {
|
||||
// Fast path: matches existing token
|
||||
if state.access_token.is_set() && !request_token.is_empty() && state.access_token.equals(request_token) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// First-time setup: no existing token
|
||||
if !state.access_token.is_set() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if request_token.is_empty() {
|
||||
return Err("access token reset not authorized".into());
|
||||
}
|
||||
|
||||
Err("access token validation failed".into())
|
||||
}
|
||||
|
||||
async fn setup_hyperloop(address: &str, env_vars: &dashmap::DashMap<String, String>) {
|
||||
// Write to /etc/hosts: events.wrenn.local → address
|
||||
let entry = format!("{address} events.wrenn.local\n");
|
||||
|
||||
match std::fs::read_to_string("/etc/hosts") {
|
||||
Ok(contents) => {
|
||||
let filtered: String = contents
|
||||
.lines()
|
||||
.filter(|line| !line.contains("events.wrenn.local"))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
let new_contents = format!("{filtered}\n{entry}");
|
||||
if let Err(e) = std::fs::write("/etc/hosts", new_contents) {
|
||||
tracing::error!(error = %e, "failed to modify hosts file");
|
||||
return;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "failed to read hosts file");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
env_vars.insert(
|
||||
"WRENN_EVENTS_ADDRESS".into(),
|
||||
format!("http://{address}"),
|
||||
);
|
||||
}
|
||||
|
||||
async fn setup_nfs(nfs_target: &str, path: &str) {
|
||||
let mkdir = tokio::process::Command::new("mkdir")
|
||||
.args(["-p", path])
|
||||
.output()
|
||||
.await;
|
||||
if let Err(e) = mkdir {
|
||||
tracing::error!(error = %e, path, "nfs: mkdir failed");
|
||||
return;
|
||||
}
|
||||
|
||||
let mount = tokio::process::Command::new("mount")
|
||||
.args([
|
||||
"-v",
|
||||
"-t",
|
||||
"nfs",
|
||||
"-o",
|
||||
"mountproto=tcp,mountport=2049,proto=tcp,port=2049,nfsvers=3,noacl",
|
||||
nfs_target,
|
||||
path,
|
||||
])
|
||||
.output()
|
||||
.await;
|
||||
|
||||
match mount {
|
||||
Ok(output) => {
|
||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
if output.status.success() {
|
||||
tracing::info!(nfs_target, path, stdout = %stdout, "nfs: mount success");
|
||||
} else {
|
||||
tracing::error!(nfs_target, path, stderr = %stderr, "nfs: mount failed");
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, nfs_target, path, "nfs: mount command failed");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn write_run_file(name: &str, value: &str) {
|
||||
let dir = std::path::Path::new("/run/wrenn");
|
||||
if let Err(e) = std::fs::create_dir_all(dir) {
|
||||
tracing::warn!(error = %e, "failed to create /run/wrenn");
|
||||
return;
|
||||
}
|
||||
if let Err(e) = std::fs::write(dir.join(name), value) {
|
||||
tracing::warn!(error = %e, name, "failed to write run file");
|
||||
}
|
||||
}
|
||||
|
||||
fn chrono_parse_to_nanos(ts: &str) -> Result<i64, ()> {
|
||||
let secs = ts.parse::<f64>().ok();
|
||||
if let Some(s) = secs {
|
||||
return Ok((s * 1_000_000_000.0) as i64);
|
||||
}
|
||||
Err(())
|
||||
}
|
||||
89
envd-rs/src/http/metrics.rs
Normal file
89
envd-rs/src/http/metrics.rs
Normal file
@ -0,0 +1,89 @@
|
||||
use std::sync::Arc;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use axum::Json;
|
||||
use axum::extract::State;
|
||||
use axum::http::{StatusCode, header};
|
||||
use axum::response::IntoResponse;
|
||||
use serde::Serialize;
|
||||
|
||||
use crate::state::AppState;
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct Metrics {
|
||||
ts: i64,
|
||||
cpu_count: u32,
|
||||
cpu_used_pct: f32,
|
||||
mem_total_mib: u64,
|
||||
mem_used_mib: u64,
|
||||
mem_total: u64,
|
||||
mem_used: u64,
|
||||
disk_used: u64,
|
||||
disk_total: u64,
|
||||
}
|
||||
|
||||
pub async fn get_metrics(State(state): State<Arc<AppState>>) -> impl IntoResponse {
|
||||
tracing::trace!("get metrics");
|
||||
|
||||
match collect_metrics(&state) {
|
||||
Ok(m) => (
|
||||
StatusCode::OK,
|
||||
[(header::CACHE_CONTROL, "no-store")],
|
||||
Json(m),
|
||||
)
|
||||
.into_response(),
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "failed to get metrics");
|
||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn collect_metrics(state: &AppState) -> Result<Metrics, String> {
|
||||
let cpu_count = state.cpu_count();
|
||||
let cpu_used_pct_rounded = state.cpu_used_pct();
|
||||
|
||||
let mut sys = sysinfo::System::new();
|
||||
sys.refresh_memory();
|
||||
let mem_total = sys.total_memory();
|
||||
let mem_available = sys.available_memory();
|
||||
let mem_used = mem_total.saturating_sub(mem_available);
|
||||
let mem_total_mib = mem_total / 1024 / 1024;
|
||||
let mem_used_mib = mem_used / 1024 / 1024;
|
||||
|
||||
let (disk_total, disk_used) = disk_stats("/").map_err(|e| e.to_string())?;
|
||||
|
||||
let ts = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs() as i64;
|
||||
|
||||
Ok(Metrics {
|
||||
ts,
|
||||
cpu_count,
|
||||
cpu_used_pct: cpu_used_pct_rounded,
|
||||
mem_total_mib,
|
||||
mem_used_mib,
|
||||
mem_total,
|
||||
mem_used,
|
||||
disk_used,
|
||||
disk_total,
|
||||
})
|
||||
}
|
||||
|
||||
fn disk_stats(path: &str) -> Result<(u64, u64), nix::Error> {
|
||||
use std::ffi::CString;
|
||||
|
||||
let c_path = CString::new(path).unwrap();
|
||||
let mut stat: libc::statfs = unsafe { std::mem::zeroed() };
|
||||
let ret = unsafe { libc::statfs(c_path.as_ptr(), &mut stat) };
|
||||
if ret != 0 {
|
||||
return Err(nix::Error::last());
|
||||
}
|
||||
|
||||
let block = stat.f_bsize as u64;
|
||||
let total = stat.f_blocks * block;
|
||||
let available = stat.f_bavail * block;
|
||||
|
||||
Ok((total, total - available))
|
||||
}
|
||||
56
envd-rs/src/http/mod.rs
Normal file
56
envd-rs/src/http/mod.rs
Normal file
@ -0,0 +1,56 @@
|
||||
pub mod encoding;
|
||||
pub mod envs;
|
||||
pub mod error;
|
||||
pub mod files;
|
||||
pub mod health;
|
||||
pub mod init;
|
||||
pub mod metrics;
|
||||
pub mod snapshot;
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use axum::Router;
|
||||
use axum::routing::{get, post};
|
||||
use http::header::{CACHE_CONTROL, HeaderName};
|
||||
use http::Method;
|
||||
use tower_http::cors::{AllowHeaders, AllowMethods, AllowOrigin, CorsLayer};
|
||||
|
||||
use crate::config::CORS_MAX_AGE;
|
||||
use crate::state::AppState;
|
||||
|
||||
pub fn router(state: Arc<AppState>) -> Router {
|
||||
let cors = CorsLayer::new()
|
||||
.allow_origin(AllowOrigin::any())
|
||||
.allow_methods(AllowMethods::list([
|
||||
Method::HEAD,
|
||||
Method::GET,
|
||||
Method::POST,
|
||||
Method::PUT,
|
||||
Method::PATCH,
|
||||
Method::DELETE,
|
||||
]))
|
||||
.allow_headers(AllowHeaders::any())
|
||||
.expose_headers([
|
||||
HeaderName::from_static("location"),
|
||||
CACHE_CONTROL,
|
||||
HeaderName::from_static("x-content-type-options"),
|
||||
HeaderName::from_static("connect-content-encoding"),
|
||||
HeaderName::from_static("connect-protocol-version"),
|
||||
HeaderName::from_static("grpc-encoding"),
|
||||
HeaderName::from_static("grpc-message"),
|
||||
HeaderName::from_static("grpc-status"),
|
||||
HeaderName::from_static("grpc-status-details-bin"),
|
||||
])
|
||||
.max_age(Duration::from_secs(CORS_MAX_AGE.as_secs()));
|
||||
|
||||
Router::new()
|
||||
.route("/health", get(health::get_health))
|
||||
.route("/metrics", get(metrics::get_metrics))
|
||||
.route("/envs", get(envs::get_envs))
|
||||
.route("/init", post(init::post_init))
|
||||
.route("/snapshot/prepare", post(snapshot::post_snapshot_prepare))
|
||||
.route("/files", get(files::get_files).post(files::post_files))
|
||||
.layer(cors)
|
||||
.with_state(state)
|
||||
}
|
||||
49
envd-rs/src/http/snapshot.rs
Normal file
49
envd-rs/src/http/snapshot.rs
Normal file
@ -0,0 +1,49 @@
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::Ordering;
|
||||
|
||||
use axum::extract::State;
|
||||
use axum::http::{StatusCode, header};
|
||||
use axum::response::IntoResponse;
|
||||
|
||||
use crate::state::AppState;
|
||||
|
||||
/// POST /snapshot/prepare — quiesce subsystems before VM snapshot.
|
||||
///
|
||||
/// In Rust there is no GC dance. We just:
|
||||
/// 1. Drop page cache to shrink snapshot size
|
||||
/// 2. Stop port subsystem
|
||||
/// 3. Close idle connections via conntracker
|
||||
/// 4. Set needs_restore flag
|
||||
pub async fn post_snapshot_prepare(State(state): State<Arc<AppState>>) -> impl IntoResponse {
|
||||
// Drop page cache BEFORE blocking the reclaimer — avoids snapshotting
|
||||
// gigabytes of stale cache that inflates the memory dump on disk.
|
||||
// "1" = pagecache only (keep dentries/inodes for faster resume).
|
||||
if let Err(e) = std::fs::write("/proc/sys/vm/drop_caches", "1") {
|
||||
tracing::warn!(error = %e, "snapshot/prepare: drop_caches failed");
|
||||
} else {
|
||||
tracing::info!("snapshot/prepare: page cache dropped");
|
||||
}
|
||||
|
||||
// Block memory reclaimer — prevents drop_caches from running mid-freeze
|
||||
// which would corrupt kernel page table state.
|
||||
state.snapshot_in_progress.store(true, Ordering::Release);
|
||||
|
||||
if let Some(ref ps) = state.port_subsystem {
|
||||
ps.stop();
|
||||
tracing::info!("snapshot/prepare: port subsystem stopped");
|
||||
}
|
||||
|
||||
state.conn_tracker.prepare_for_snapshot();
|
||||
tracing::info!("snapshot/prepare: connections prepared");
|
||||
|
||||
// Sync filesystem buffers so dirty pages are flushed before freeze.
|
||||
unsafe { libc::sync(); }
|
||||
|
||||
state.needs_restore.store(true, Ordering::Release);
|
||||
tracing::info!("snapshot/prepare: ready for freeze");
|
||||
|
||||
(
|
||||
StatusCode::NO_CONTENT,
|
||||
[(header::CACHE_CONTROL, "no-store")],
|
||||
)
|
||||
}
|
||||
17
envd-rs/src/logging.rs
Normal file
17
envd-rs/src/logging.rs
Normal file
@ -0,0 +1,17 @@
|
||||
use tracing_subscriber::{EnvFilter, fmt, layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
pub fn init(json: bool) {
|
||||
let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
|
||||
|
||||
if json {
|
||||
tracing_subscriber::registry()
|
||||
.with(filter)
|
||||
.with(fmt::layer().json().flatten_event(true))
|
||||
.init();
|
||||
} else {
|
||||
tracing_subscriber::registry()
|
||||
.with(filter)
|
||||
.with(fmt::layer())
|
||||
.init();
|
||||
}
|
||||
}
|
||||
269
envd-rs/src/main.rs
Normal file
269
envd-rs/src/main.rs
Normal file
@ -0,0 +1,269 @@
|
||||
#![allow(dead_code)]
|
||||
|
||||
mod auth;
|
||||
mod cgroups;
|
||||
mod config;
|
||||
mod conntracker;
|
||||
mod crypto;
|
||||
mod execcontext;
|
||||
mod http;
|
||||
mod logging;
|
||||
mod permissions;
|
||||
mod port;
|
||||
mod rpc;
|
||||
mod state;
|
||||
mod util;
|
||||
|
||||
use std::fs;
|
||||
use std::net::SocketAddr;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
|
||||
use clap::Parser;
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
use config::{DEFAULT_PORT, DEFAULT_USER, WRENN_RUN_DIR};
|
||||
use execcontext::Defaults;
|
||||
use port::subsystem::PortSubsystem;
|
||||
use state::AppState;
|
||||
|
||||
const VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||
|
||||
const COMMIT: &str = {
|
||||
match option_env!("ENVD_COMMIT") {
|
||||
Some(c) => c,
|
||||
None => "unknown",
|
||||
}
|
||||
};
|
||||
|
||||
#[derive(Parser)]
|
||||
#[command(name = "envd", about = "Wrenn guest agent daemon")]
|
||||
struct Cli {
|
||||
#[arg(long, default_value_t = DEFAULT_PORT)]
|
||||
port: u16,
|
||||
|
||||
#[arg(long)]
|
||||
version: bool,
|
||||
|
||||
#[arg(long)]
|
||||
commit: bool,
|
||||
|
||||
#[arg(long = "cmd", default_value = "")]
|
||||
start_cmd: String,
|
||||
|
||||
#[arg(long = "cgroup-root", default_value = "/sys/fs/cgroup")]
|
||||
cgroup_root: String,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
let cli = Cli::parse();
|
||||
|
||||
if cli.version {
|
||||
println!("{VERSION}");
|
||||
return;
|
||||
}
|
||||
if cli.commit {
|
||||
println!("{COMMIT}");
|
||||
return;
|
||||
}
|
||||
|
||||
logging::init(true);
|
||||
|
||||
if let Err(e) = fs::create_dir_all(WRENN_RUN_DIR) {
|
||||
tracing::error!(error = %e, "failed to create wrenn run directory");
|
||||
}
|
||||
|
||||
let defaults = Defaults::new(DEFAULT_USER);
|
||||
defaults
|
||||
.env_vars
|
||||
.insert("WRENN_SANDBOX".into(), "true".into());
|
||||
|
||||
let wrenn_sandbox_path = Path::new(WRENN_RUN_DIR).join(".WRENN_SANDBOX");
|
||||
if let Err(e) = fs::write(&wrenn_sandbox_path, b"true") {
|
||||
tracing::error!(error = %e, "failed to write sandbox file");
|
||||
}
|
||||
|
||||
// Cgroup manager
|
||||
let cgroup_manager: Arc<dyn cgroups::CgroupManager> =
|
||||
match cgroups::Cgroup2Manager::new(
|
||||
&cli.cgroup_root,
|
||||
&[
|
||||
(
|
||||
cgroups::ProcessType::Pty,
|
||||
"wrenn/pty",
|
||||
&[] as &[(&str, &str)],
|
||||
),
|
||||
(
|
||||
cgroups::ProcessType::User,
|
||||
"wrenn/user",
|
||||
&[] as &[(&str, &str)],
|
||||
),
|
||||
(
|
||||
cgroups::ProcessType::Socat,
|
||||
"wrenn/socat",
|
||||
&[] as &[(&str, &str)],
|
||||
),
|
||||
],
|
||||
) {
|
||||
Ok(m) => {
|
||||
tracing::info!("cgroup2 manager initialized");
|
||||
Arc::new(m)
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "cgroup2 init failed, using noop");
|
||||
Arc::new(cgroups::NoopCgroupManager)
|
||||
}
|
||||
};
|
||||
|
||||
// Port subsystem
|
||||
let port_subsystem = Arc::new(PortSubsystem::new(Arc::clone(&cgroup_manager)));
|
||||
port_subsystem.start();
|
||||
tracing::info!("port subsystem started");
|
||||
|
||||
let state = AppState::new(
|
||||
defaults,
|
||||
VERSION.to_string(),
|
||||
COMMIT.to_string(),
|
||||
Some(Arc::clone(&port_subsystem)),
|
||||
);
|
||||
|
||||
// Memory reclaimer — drop page cache when available memory is low.
|
||||
// The balloon device can only reclaim pages the guest kernel freed.
|
||||
// Pauses during snapshot/prepare to avoid corrupting kernel page table state.
|
||||
{
|
||||
let state_for_reclaimer = Arc::clone(&state);
|
||||
std::thread::spawn(move || memory_reclaimer(state_for_reclaimer));
|
||||
}
|
||||
|
||||
// RPC services (Connect protocol — serves Connect + gRPC + gRPC-Web on same port)
|
||||
let connect_router = rpc::rpc_router(Arc::clone(&state));
|
||||
|
||||
let app = http::router(Arc::clone(&state))
|
||||
.fallback_service(connect_router.into_axum_service());
|
||||
|
||||
// --cmd: spawn initial process if specified
|
||||
if !cli.start_cmd.is_empty() {
|
||||
let cmd = cli.start_cmd.clone();
|
||||
let state_clone = Arc::clone(&state);
|
||||
tokio::spawn(async move {
|
||||
spawn_initial_command(&cmd, &state_clone);
|
||||
});
|
||||
}
|
||||
|
||||
let addr = SocketAddr::from(([0, 0, 0, 0], cli.port));
|
||||
tracing::info!(port = cli.port, version = VERSION, commit = COMMIT, "envd starting");
|
||||
|
||||
let listener = TcpListener::bind(addr).await.expect("failed to bind");
|
||||
|
||||
let graceful = axum::serve(listener, app).with_graceful_shutdown(async move {
|
||||
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
|
||||
.expect("failed to register SIGTERM")
|
||||
.recv()
|
||||
.await;
|
||||
tracing::info!("SIGTERM received, shutting down");
|
||||
});
|
||||
|
||||
if let Err(e) = graceful.await {
|
||||
tracing::error!(error = %e, "server error");
|
||||
}
|
||||
|
||||
port_subsystem.stop();
|
||||
}
|
||||
|
||||
fn spawn_initial_command(cmd: &str, state: &AppState) {
|
||||
use crate::permissions::user::lookup_user;
|
||||
use crate::rpc::process_handler;
|
||||
use std::collections::HashMap;
|
||||
|
||||
let default_user = state.defaults.user();
|
||||
let user = match lookup_user(&default_user) {
|
||||
Ok(u) => u,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "cmd: failed to lookup user");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let home = user.dir.to_string_lossy().to_string();
|
||||
let default_workdir = state.defaults.workdir();
|
||||
let cwd = default_workdir
|
||||
.as_deref()
|
||||
.unwrap_or(&home);
|
||||
|
||||
match process_handler::spawn_process(
|
||||
cmd,
|
||||
&[],
|
||||
&HashMap::new(),
|
||||
cwd,
|
||||
None,
|
||||
false,
|
||||
Some("init-cmd".to_string()),
|
||||
&user,
|
||||
&state.defaults.env_vars,
|
||||
) {
|
||||
Ok(spawned) => {
|
||||
tracing::info!(pid = spawned.handle.pid, cmd, "initial command spawned");
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, cmd, "failed to spawn initial command");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn memory_reclaimer(state: Arc<AppState>) {
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||
|
||||
const CHECK_INTERVAL: Duration = Duration::from_secs(10);
|
||||
const DROP_THRESHOLD_PCT: u64 = 80;
|
||||
const RESTORE_GRACE_SECS: u64 = 30;
|
||||
|
||||
loop {
|
||||
std::thread::sleep(CHECK_INTERVAL);
|
||||
|
||||
if state.snapshot_in_progress.load(Ordering::Acquire) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Skip during post-restore grace period. Balloon deflation causes
|
||||
// transient high memory that resolves on its own — triggering
|
||||
// drop_caches during UFFD page fault storms makes the guest unresponsive.
|
||||
let restore_epoch = state.restore_epoch.load(Ordering::Acquire);
|
||||
if restore_epoch > 0 {
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
if now.saturating_sub(restore_epoch) < RESTORE_GRACE_SECS {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
let mut sys = sysinfo::System::new();
|
||||
sys.refresh_memory();
|
||||
let total = sys.total_memory();
|
||||
let available = sys.available_memory();
|
||||
|
||||
if total == 0 {
|
||||
continue;
|
||||
}
|
||||
|
||||
let used_pct = ((total - available) * 100) / total;
|
||||
if used_pct >= DROP_THRESHOLD_PCT {
|
||||
if state.snapshot_in_progress.load(Ordering::Acquire) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Err(e) = std::fs::write("/proc/sys/vm/drop_caches", "3") {
|
||||
tracing::debug!(error = %e, "drop_caches failed");
|
||||
} else {
|
||||
let mut sys2 = sysinfo::System::new();
|
||||
sys2.refresh_memory();
|
||||
let freed_mb =
|
||||
sys2.available_memory().saturating_sub(available) / (1024 * 1024);
|
||||
tracing::info!(used_pct, freed_mb, "page cache dropped");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
2
envd-rs/src/permissions/mod.rs
Normal file
2
envd-rs/src/permissions/mod.rs
Normal file
@ -0,0 +1,2 @@
|
||||
pub mod user;
|
||||
pub mod path;
|
||||
184
envd-rs/src/permissions/path.rs
Normal file
184
envd-rs/src/permissions/path.rs
Normal file
@ -0,0 +1,184 @@
|
||||
use std::fs;
|
||||
use std::os::unix::fs::chown;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use nix::unistd::{Gid, Uid};
|
||||
|
||||
fn expand_tilde(path: &str, home_dir: &str) -> Result<String, String> {
|
||||
if path.is_empty() || !path.starts_with('~') {
|
||||
return Ok(path.to_string());
|
||||
}
|
||||
if path.len() > 1 && path.as_bytes()[1] != b'/' && path.as_bytes()[1] != b'\\' {
|
||||
return Err("cannot expand user-specific home dir".into());
|
||||
}
|
||||
Ok(format!("{}{}", home_dir, &path[1..]))
|
||||
}
|
||||
|
||||
pub fn expand_and_resolve(
|
||||
path: &str,
|
||||
home_dir: &str,
|
||||
default_path: Option<&str>,
|
||||
) -> Result<String, String> {
|
||||
let path = if path.is_empty() {
|
||||
default_path.unwrap_or("").to_string()
|
||||
} else {
|
||||
path.to_string()
|
||||
};
|
||||
|
||||
let path = expand_tilde(&path, home_dir)?;
|
||||
|
||||
if Path::new(&path).is_absolute() {
|
||||
return Ok(path);
|
||||
}
|
||||
|
||||
let joined = PathBuf::from(home_dir).join(&path);
|
||||
joined
|
||||
.canonicalize()
|
||||
.or_else(|_| Ok(joined))
|
||||
.map(|p| p.to_string_lossy().to_string())
|
||||
}
|
||||
|
||||
pub fn ensure_dirs(path: &str, uid: Uid, gid: Gid) -> Result<(), String> {
|
||||
let path = Path::new(path);
|
||||
let mut current = PathBuf::new();
|
||||
|
||||
for component in path.components() {
|
||||
current.push(component);
|
||||
let current_str = current.to_string_lossy();
|
||||
|
||||
if current_str == "/" {
|
||||
continue;
|
||||
}
|
||||
|
||||
match fs::metadata(¤t) {
|
||||
Ok(meta) => {
|
||||
if !meta.is_dir() {
|
||||
return Err(format!("path is a file: {current_str}"));
|
||||
}
|
||||
}
|
||||
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
|
||||
fs::create_dir(¤t)
|
||||
.map_err(|e| format!("failed to create directory {current_str}: {e}"))?;
|
||||
chown(¤t, Some(uid.as_raw()), Some(gid.as_raw()))
|
||||
.map_err(|e| format!("failed to chown directory {current_str}: {e}"))?;
|
||||
}
|
||||
Err(e) => {
|
||||
return Err(format!("failed to stat directory {current_str}: {e}"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// expand_tilde
|
||||
|
||||
#[test]
|
||||
fn tilde_empty_passthrough() {
|
||||
assert_eq!(expand_tilde("", "/home/u").unwrap(), "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tilde_no_tilde_passthrough() {
|
||||
assert_eq!(expand_tilde("/absolute", "/home/u").unwrap(), "/absolute");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tilde_bare() {
|
||||
assert_eq!(expand_tilde("~", "/home/user").unwrap(), "/home/user");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tilde_slash_path() {
|
||||
assert_eq!(expand_tilde("~/docs", "/home/user").unwrap(), "/home/user/docs");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tilde_nested() {
|
||||
assert_eq!(expand_tilde("~/a/b/c", "/h").unwrap(), "/h/a/b/c");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tilde_other_user_errors() {
|
||||
assert!(expand_tilde("~bob/foo", "/home/user").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tilde_relative_no_tilde() {
|
||||
assert_eq!(expand_tilde("relative/path", "/home/u").unwrap(), "relative/path");
|
||||
}
|
||||
|
||||
// expand_and_resolve
|
||||
|
||||
#[test]
|
||||
fn resolve_absolute_passthrough() {
|
||||
assert_eq!(expand_and_resolve("/abs/path", "/home", None).unwrap(), "/abs/path");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_empty_uses_default() {
|
||||
assert_eq!(expand_and_resolve("", "/home", Some("/default")).unwrap(), "/default");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_empty_no_default_falls_back_to_home() {
|
||||
// Empty path with no default → joins "" with home_dir → returns home_dir
|
||||
let result = expand_and_resolve("", "/home", None).unwrap();
|
||||
assert_eq!(result, "/home");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_tilde_expands() {
|
||||
assert_eq!(expand_and_resolve("~/dir", "/home/u", None).unwrap(), "/home/u/dir");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_relative_joins_home() {
|
||||
let result = expand_and_resolve("subdir", "/tmp", None).unwrap();
|
||||
// Relative path joined with home and canonicalized (or raw join on missing)
|
||||
assert!(result.starts_with("/tmp"));
|
||||
assert!(result.contains("subdir"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_tilde_other_user_errors() {
|
||||
assert!(expand_and_resolve("~bob", "/home/u", None).is_err());
|
||||
}
|
||||
|
||||
// ensure_dirs
|
||||
|
||||
#[test]
|
||||
fn ensure_dirs_creates_nested() {
|
||||
let tmp = tempfile::TempDir::new().unwrap();
|
||||
let path = tmp.path().join("a/b/c");
|
||||
let uid = nix::unistd::getuid();
|
||||
let gid = nix::unistd::getgid();
|
||||
ensure_dirs(path.to_str().unwrap(), uid, gid).unwrap();
|
||||
assert!(path.is_dir());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ensure_dirs_existing_is_ok() {
|
||||
let tmp = tempfile::TempDir::new().unwrap();
|
||||
let uid = nix::unistd::getuid();
|
||||
let gid = nix::unistd::getgid();
|
||||
ensure_dirs(tmp.path().to_str().unwrap(), uid, gid).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ensure_dirs_file_in_path_errors() {
|
||||
let tmp = tempfile::TempDir::new().unwrap();
|
||||
let file_path = tmp.path().join("afile");
|
||||
std::fs::write(&file_path, "").unwrap();
|
||||
let nested = file_path.join("subdir");
|
||||
let uid = nix::unistd::getuid();
|
||||
let gid = nix::unistd::getgid();
|
||||
let result = ensure_dirs(nested.to_str().unwrap(), uid, gid);
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().contains("path is a file"));
|
||||
}
|
||||
}
|
||||
32
envd-rs/src/permissions/user.rs
Normal file
32
envd-rs/src/permissions/user.rs
Normal file
@ -0,0 +1,32 @@
|
||||
use nix::unistd::{Gid, Group, Uid, User};
|
||||
|
||||
pub fn lookup_user(username: &str) -> Result<User, String> {
|
||||
User::from_name(username)
|
||||
.map_err(|e| format!("error looking up user '{username}': {e}"))?
|
||||
.ok_or_else(|| format!("user '{username}' not found"))
|
||||
}
|
||||
|
||||
pub fn get_uid_gid(user: &User) -> (Uid, Gid) {
|
||||
(user.uid, user.gid)
|
||||
}
|
||||
|
||||
pub fn get_user_groups(user: &User) -> Vec<Gid> {
|
||||
let c_name = std::ffi::CString::new(user.name.as_str()).unwrap();
|
||||
nix::unistd::getgrouplist(&c_name, user.gid).unwrap_or_default()
|
||||
}
|
||||
|
||||
pub fn lookup_username_by_uid(uid: Uid) -> String {
|
||||
User::from_uid(uid)
|
||||
.ok()
|
||||
.flatten()
|
||||
.map(|u| u.name)
|
||||
.unwrap_or_else(|| uid.to_string())
|
||||
}
|
||||
|
||||
pub fn lookup_groupname_by_gid(gid: Gid) -> String {
|
||||
Group::from_gid(gid)
|
||||
.ok()
|
||||
.flatten()
|
||||
.map(|g| g.name)
|
||||
.unwrap_or_else(|| gid.to_string())
|
||||
}
|
||||
260
envd-rs/src/port/conn.rs
Normal file
260
envd-rs/src/port/conn.rs
Normal file
@ -0,0 +1,260 @@
|
||||
use std::io::{self, BufRead};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ConnStat {
|
||||
pub local_ip: String,
|
||||
pub local_port: u32,
|
||||
pub status: String,
|
||||
pub family: u32,
|
||||
pub inode: u64,
|
||||
}
|
||||
|
||||
fn tcp_state_name(hex: &str) -> &'static str {
|
||||
match hex {
|
||||
"01" => "ESTABLISHED",
|
||||
"02" => "SYN_SENT",
|
||||
"03" => "SYN_RECV",
|
||||
"04" => "FIN_WAIT1",
|
||||
"05" => "FIN_WAIT2",
|
||||
"06" => "TIME_WAIT",
|
||||
"07" => "CLOSE",
|
||||
"08" => "CLOSE_WAIT",
|
||||
"09" => "LAST_ACK",
|
||||
"0A" => "LISTEN",
|
||||
"0B" => "CLOSING",
|
||||
_ => "UNKNOWN",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn read_tcp_connections() -> Vec<ConnStat> {
|
||||
let mut conns = Vec::new();
|
||||
if let Ok(c) = parse_proc_net_tcp("/proc/net/tcp", libc::AF_INET as u32) {
|
||||
conns.extend(c);
|
||||
}
|
||||
if let Ok(c) = parse_proc_net_tcp("/proc/net/tcp6", libc::AF_INET6 as u32) {
|
||||
conns.extend(c);
|
||||
}
|
||||
conns
|
||||
}
|
||||
|
||||
fn parse_proc_net_tcp(path: &str, family: u32) -> io::Result<Vec<ConnStat>> {
|
||||
let file = std::fs::File::open(path)?;
|
||||
let reader = io::BufReader::new(file);
|
||||
let mut conns = Vec::new();
|
||||
let mut first = true;
|
||||
|
||||
for line in reader.lines() {
|
||||
let line = line?;
|
||||
if first {
|
||||
first = false;
|
||||
continue;
|
||||
}
|
||||
let line = line.trim().to_string();
|
||||
if line.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let fields: Vec<&str> = line.split_whitespace().collect();
|
||||
if fields.len() < 10 {
|
||||
continue;
|
||||
}
|
||||
|
||||
let (ip, port) = match parse_hex_addr(fields[1], family) {
|
||||
Some(v) => v,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
let state = tcp_state_name(fields[3]);
|
||||
|
||||
let inode: u64 = match fields[9].parse() {
|
||||
Ok(v) => v,
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
conns.push(ConnStat {
|
||||
local_ip: ip,
|
||||
local_port: port,
|
||||
status: state.to_string(),
|
||||
family,
|
||||
inode,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(conns)
|
||||
}
|
||||
|
||||
fn parse_hex_addr(s: &str, family: u32) -> Option<(String, u32)> {
|
||||
let (ip_hex, port_hex) = s.split_once(':')?;
|
||||
let port = u32::from_str_radix(port_hex, 16).ok()?;
|
||||
let ip_bytes = hex::decode(ip_hex).ok()?;
|
||||
|
||||
let ip_str = if family == libc::AF_INET as u32 {
|
||||
if ip_bytes.len() != 4 {
|
||||
return None;
|
||||
}
|
||||
format!("{}.{}.{}.{}", ip_bytes[3], ip_bytes[2], ip_bytes[1], ip_bytes[0])
|
||||
} else {
|
||||
if ip_bytes.len() != 16 {
|
||||
return None;
|
||||
}
|
||||
let mut octets = [0u8; 16];
|
||||
for i in 0..4 {
|
||||
octets[i * 4] = ip_bytes[i * 4 + 3];
|
||||
octets[i * 4 + 1] = ip_bytes[i * 4 + 2];
|
||||
octets[i * 4 + 2] = ip_bytes[i * 4 + 1];
|
||||
octets[i * 4 + 3] = ip_bytes[i * 4];
|
||||
}
|
||||
let addr = std::net::Ipv6Addr::from(octets);
|
||||
addr.to_string()
|
||||
};
|
||||
|
||||
Some((ip_str, port))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::io::Write;
|
||||
|
||||
// tcp_state_name
|
||||
|
||||
#[test]
|
||||
fn state_all_known_codes() {
|
||||
assert_eq!(tcp_state_name("01"), "ESTABLISHED");
|
||||
assert_eq!(tcp_state_name("02"), "SYN_SENT");
|
||||
assert_eq!(tcp_state_name("03"), "SYN_RECV");
|
||||
assert_eq!(tcp_state_name("04"), "FIN_WAIT1");
|
||||
assert_eq!(tcp_state_name("05"), "FIN_WAIT2");
|
||||
assert_eq!(tcp_state_name("06"), "TIME_WAIT");
|
||||
assert_eq!(tcp_state_name("07"), "CLOSE");
|
||||
assert_eq!(tcp_state_name("08"), "CLOSE_WAIT");
|
||||
assert_eq!(tcp_state_name("09"), "LAST_ACK");
|
||||
assert_eq!(tcp_state_name("0A"), "LISTEN");
|
||||
assert_eq!(tcp_state_name("0B"), "CLOSING");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn state_unknown_code() {
|
||||
assert_eq!(tcp_state_name("FF"), "UNKNOWN");
|
||||
assert_eq!(tcp_state_name("00"), "UNKNOWN");
|
||||
}
|
||||
|
||||
// parse_hex_addr
|
||||
|
||||
#[test]
|
||||
fn ipv4_localhost() {
|
||||
let (ip, port) = parse_hex_addr("0100007F:0050", libc::AF_INET as u32).unwrap();
|
||||
assert_eq!(ip, "127.0.0.1");
|
||||
assert_eq!(port, 80);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ipv4_any() {
|
||||
let (ip, port) = parse_hex_addr("00000000:0035", libc::AF_INET as u32).unwrap();
|
||||
assert_eq!(ip, "0.0.0.0");
|
||||
assert_eq!(port, 53);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ipv4_real_address() {
|
||||
// 192.168.1.1 in little-endian = 0101A8C0
|
||||
let (ip, port) = parse_hex_addr("0101A8C0:01BB", libc::AF_INET as u32).unwrap();
|
||||
assert_eq!(ip, "192.168.1.1");
|
||||
assert_eq!(port, 443);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ipv4_wrong_byte_count_returns_none() {
|
||||
assert!(parse_hex_addr("0100:0050", libc::AF_INET as u32).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid_hex_returns_none() {
|
||||
assert!(parse_hex_addr("ZZZZZZZZ:0050", libc::AF_INET as u32).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_colon_returns_none() {
|
||||
assert!(parse_hex_addr("0100007F0050", libc::AF_INET as u32).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ipv6_loopback() {
|
||||
// ::1 in /proc/net/tcp6 format: 00000000000000000000000001000000
|
||||
let (ip, port) = parse_hex_addr(
|
||||
"00000000000000000000000001000000:0035",
|
||||
libc::AF_INET6 as u32,
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(ip, "::1");
|
||||
assert_eq!(port, 53);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ipv6_wrong_byte_count_returns_none() {
|
||||
assert!(parse_hex_addr("0100007F:0050", libc::AF_INET6 as u32).is_none());
|
||||
}
|
||||
|
||||
// parse_proc_net_tcp
|
||||
|
||||
fn write_tcp_file(content: &str) -> tempfile::NamedTempFile {
|
||||
let mut f = tempfile::NamedTempFile::new().unwrap();
|
||||
f.write_all(content.as_bytes()).unwrap();
|
||||
f.flush().unwrap();
|
||||
f
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_empty_file() {
|
||||
let f = write_tcp_file(
|
||||
" sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode\n",
|
||||
);
|
||||
let conns = parse_proc_net_tcp(f.path().to_str().unwrap(), libc::AF_INET as u32).unwrap();
|
||||
assert!(conns.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_single_entry() {
|
||||
let content = "\
|
||||
sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode
|
||||
0: 0100007F:0050 00000000:0000 0A 00000000:00000000 00:00000000 00000000 0 0 12345 1 00000000\n";
|
||||
let f = write_tcp_file(content);
|
||||
let conns = parse_proc_net_tcp(f.path().to_str().unwrap(), libc::AF_INET as u32).unwrap();
|
||||
assert_eq!(conns.len(), 1);
|
||||
assert_eq!(conns[0].local_ip, "127.0.0.1");
|
||||
assert_eq!(conns[0].local_port, 80);
|
||||
assert_eq!(conns[0].status, "LISTEN");
|
||||
assert_eq!(conns[0].inode, 12345);
|
||||
assert_eq!(conns[0].family, libc::AF_INET as u32);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_skips_malformed_rows() {
|
||||
let content = "\
|
||||
sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode
|
||||
0: 0100007F:0050 00000000:0000 0A 00000000:00000000 00:00000000 00000000 0 0 12345 1 00000000
|
||||
bad line
|
||||
1: short\n";
|
||||
let f = write_tcp_file(content);
|
||||
let conns = parse_proc_net_tcp(f.path().to_str().unwrap(), libc::AF_INET as u32).unwrap();
|
||||
assert_eq!(conns.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_multiple_entries() {
|
||||
let content = "\
|
||||
sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode
|
||||
0: 0100007F:0050 00000000:0000 0A 00000000:00000000 00:00000000 00000000 0 0 100 1 00000000
|
||||
1: 00000000:01BB 00000000:0000 0A 00000000:00000000 00:00000000 00000000 0 0 200 1 00000000\n";
|
||||
let f = write_tcp_file(content);
|
||||
let conns = parse_proc_net_tcp(f.path().to_str().unwrap(), libc::AF_INET as u32).unwrap();
|
||||
assert_eq!(conns.len(), 2);
|
||||
assert_eq!(conns[0].local_port, 80);
|
||||
assert_eq!(conns[1].local_port, 443);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_nonexistent_file_errors() {
|
||||
assert!(parse_proc_net_tcp("/nonexistent/path", libc::AF_INET as u32).is_err());
|
||||
}
|
||||
}
|
||||
181
envd-rs/src/port/forwarder.rs
Normal file
181
envd-rs/src/port/forwarder.rs
Normal file
@ -0,0 +1,181 @@
|
||||
use std::collections::HashMap;
|
||||
use std::os::unix::process::CommandExt;
|
||||
use std::process::Command;
|
||||
use std::sync::Arc;
|
||||
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use crate::cgroups::{CgroupManager, ProcessType};
|
||||
|
||||
use super::conn::ConnStat;
|
||||
|
||||
const DEFAULT_GATEWAY_IP: &str = "169.254.0.21";
|
||||
|
||||
#[derive(PartialEq)]
|
||||
enum PortState {
|
||||
Forward,
|
||||
Delete,
|
||||
}
|
||||
|
||||
struct PortToForward {
|
||||
pid: Option<u32>,
|
||||
inode: u64,
|
||||
family: u32,
|
||||
state: PortState,
|
||||
port: u32,
|
||||
}
|
||||
|
||||
fn family_to_ip_version(family: u32) -> u32 {
|
||||
if family == libc::AF_INET as u32 {
|
||||
4
|
||||
} else if family == libc::AF_INET6 as u32 {
|
||||
6
|
||||
} else {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Forwarder {
|
||||
cgroup_manager: Arc<dyn CgroupManager>,
|
||||
ports: HashMap<String, PortToForward>,
|
||||
source_ip: String,
|
||||
}
|
||||
|
||||
impl Forwarder {
|
||||
pub fn new(cgroup_manager: Arc<dyn CgroupManager>) -> Self {
|
||||
Self {
|
||||
cgroup_manager,
|
||||
ports: HashMap::new(),
|
||||
source_ip: DEFAULT_GATEWAY_IP.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn start_forwarding(
|
||||
&mut self,
|
||||
mut rx: mpsc::Receiver<Vec<ConnStat>>,
|
||||
cancel: CancellationToken,
|
||||
) {
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = cancel.cancelled() => {
|
||||
self.stop_all();
|
||||
return;
|
||||
}
|
||||
msg = rx.recv() => {
|
||||
match msg {
|
||||
Some(conns) => self.process_scan(conns),
|
||||
None => {
|
||||
self.stop_all();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn process_scan(&mut self, conns: Vec<ConnStat>) {
|
||||
for ptf in self.ports.values_mut() {
|
||||
ptf.state = PortState::Delete;
|
||||
}
|
||||
|
||||
for conn in &conns {
|
||||
let key = format!("{}-{}", conn.inode, conn.local_port);
|
||||
if let Some(ptf) = self.ports.get_mut(&key) {
|
||||
ptf.state = PortState::Forward;
|
||||
} else {
|
||||
tracing::debug!(
|
||||
ip = %conn.local_ip,
|
||||
port = conn.local_port,
|
||||
family = family_to_ip_version(conn.family),
|
||||
"detected new port on localhost"
|
||||
);
|
||||
let mut ptf = PortToForward {
|
||||
pid: None,
|
||||
inode: conn.inode,
|
||||
family: family_to_ip_version(conn.family),
|
||||
state: PortState::Forward,
|
||||
port: conn.local_port,
|
||||
};
|
||||
self.start_port_forwarding(&mut ptf);
|
||||
self.ports.insert(key, ptf);
|
||||
}
|
||||
}
|
||||
|
||||
let to_stop: Vec<String> = self
|
||||
.ports
|
||||
.iter()
|
||||
.filter(|(_, v)| v.state == PortState::Delete)
|
||||
.map(|(k, _)| k.clone())
|
||||
.collect();
|
||||
|
||||
for key in to_stop {
|
||||
if let Some(ptf) = self.ports.get(&key) {
|
||||
stop_port_forwarding(ptf);
|
||||
}
|
||||
self.ports.remove(&key);
|
||||
}
|
||||
}
|
||||
|
||||
fn start_port_forwarding(&self, ptf: &mut PortToForward) {
|
||||
let listen_arg = format!(
|
||||
"TCP4-LISTEN:{},bind={},reuseaddr,fork",
|
||||
ptf.port, self.source_ip
|
||||
);
|
||||
let connect_arg = format!("TCP{}:localhost:{}", ptf.family, ptf.port);
|
||||
|
||||
let mut cmd = Command::new("socat");
|
||||
cmd.args(["-d", "-d", "-d", &listen_arg, &connect_arg]);
|
||||
|
||||
unsafe {
|
||||
let cgroup_fd = self.cgroup_manager.get_fd(ProcessType::Socat);
|
||||
cmd.pre_exec(move || {
|
||||
libc::setpgid(0, 0);
|
||||
if let Some(fd) = cgroup_fd {
|
||||
let pid_str = format!("{}", libc::getpid());
|
||||
let tasks_path = format!("/proc/self/fd/{}/cgroup.procs", fd);
|
||||
let _ = std::fs::write(&tasks_path, pid_str.as_bytes());
|
||||
}
|
||||
Ok(())
|
||||
});
|
||||
}
|
||||
|
||||
tracing::debug!(
|
||||
port = ptf.port,
|
||||
inode = ptf.inode,
|
||||
family = ptf.family,
|
||||
source_ip = %self.source_ip,
|
||||
"starting port forwarding"
|
||||
);
|
||||
|
||||
match cmd.spawn() {
|
||||
Ok(child) => {
|
||||
ptf.pid = Some(child.id());
|
||||
std::thread::spawn(move || {
|
||||
let mut child = child;
|
||||
let _ = child.wait();
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, port = ptf.port, "failed to start socat");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn stop_all(&mut self) {
|
||||
for ptf in self.ports.values() {
|
||||
stop_port_forwarding(ptf);
|
||||
}
|
||||
self.ports.clear();
|
||||
}
|
||||
}
|
||||
|
||||
fn stop_port_forwarding(ptf: &PortToForward) {
|
||||
if let Some(pid) = ptf.pid {
|
||||
tracing::debug!(port = ptf.port, pid, "stopping port forwarding");
|
||||
unsafe {
|
||||
libc::kill(-(pid as i32), libc::SIGKILL);
|
||||
}
|
||||
}
|
||||
}
|
||||
4
envd-rs/src/port/mod.rs
Normal file
4
envd-rs/src/port/mod.rs
Normal file
@ -0,0 +1,4 @@
|
||||
pub mod conn;
|
||||
pub mod forwarder;
|
||||
pub mod scanner;
|
||||
pub mod subsystem;
|
||||
81
envd-rs/src/port/scanner.rs
Normal file
81
envd-rs/src/port/scanner.rs
Normal file
@ -0,0 +1,81 @@
|
||||
use std::sync::{Arc, RwLock};
|
||||
use std::time::Duration;
|
||||
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use super::conn::{ConnStat, read_tcp_connections};
|
||||
|
||||
pub struct ScannerFilter {
|
||||
pub ips: Vec<String>,
|
||||
pub state: String,
|
||||
}
|
||||
|
||||
impl ScannerFilter {
|
||||
pub fn matches(&self, conn: &ConnStat) -> bool {
|
||||
if self.state.is_empty() && self.ips.is_empty() {
|
||||
return false;
|
||||
}
|
||||
self.ips.contains(&conn.local_ip) && self.state == conn.status
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ScannerSubscriber {
|
||||
pub tx: mpsc::Sender<Vec<ConnStat>>,
|
||||
pub filter: Option<ScannerFilter>,
|
||||
}
|
||||
|
||||
pub struct Scanner {
|
||||
period: Duration,
|
||||
subs: RwLock<Vec<(String, Arc<ScannerSubscriber>)>>,
|
||||
}
|
||||
|
||||
impl Scanner {
|
||||
pub fn new(period: Duration) -> Self {
|
||||
Self {
|
||||
period,
|
||||
subs: RwLock::new(Vec::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_subscriber(
|
||||
&self,
|
||||
id: &str,
|
||||
filter: Option<ScannerFilter>,
|
||||
) -> mpsc::Receiver<Vec<ConnStat>> {
|
||||
let (tx, rx) = mpsc::channel(4);
|
||||
let sub = Arc::new(ScannerSubscriber { tx, filter });
|
||||
let mut subs = self.subs.write().unwrap();
|
||||
subs.push((id.to_string(), sub));
|
||||
rx
|
||||
}
|
||||
|
||||
pub fn remove_subscriber(&self, id: &str) {
|
||||
let mut subs = self.subs.write().unwrap();
|
||||
subs.retain(|(sid, _)| sid != id);
|
||||
}
|
||||
|
||||
pub async fn scan_and_broadcast(&self, cancel: CancellationToken) {
|
||||
loop {
|
||||
let conns = tokio::task::spawn_blocking(read_tcp_connections)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
|
||||
{
|
||||
let subs = self.subs.read().unwrap();
|
||||
for (_, sub) in subs.iter() {
|
||||
let payload = match &sub.filter {
|
||||
Some(f) => conns.iter().filter(|c| f.matches(c)).cloned().collect(),
|
||||
None => conns.clone(),
|
||||
};
|
||||
let _ = sub.tx.try_send(payload);
|
||||
}
|
||||
}
|
||||
|
||||
tokio::select! {
|
||||
_ = cancel.cancelled() => return,
|
||||
_ = tokio::time::sleep(self.period) => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
78
envd-rs/src/port/subsystem.rs
Normal file
78
envd-rs/src/port/subsystem.rs
Normal file
@ -0,0 +1,78 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use crate::cgroups::CgroupManager;
|
||||
use crate::config::PORT_SCANNER_INTERVAL;
|
||||
|
||||
use super::forwarder::Forwarder;
|
||||
use super::scanner::{Scanner, ScannerFilter};
|
||||
|
||||
pub struct PortSubsystem {
|
||||
cgroup_manager: Arc<dyn CgroupManager>,
|
||||
cancel: std::sync::Mutex<Option<CancellationToken>>,
|
||||
}
|
||||
|
||||
impl PortSubsystem {
|
||||
pub fn new(cgroup_manager: Arc<dyn CgroupManager>) -> Self {
|
||||
Self {
|
||||
cgroup_manager,
|
||||
cancel: std::sync::Mutex::new(None),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn start(&self) {
|
||||
let mut guard = self.cancel.lock().unwrap();
|
||||
if guard.is_some() {
|
||||
return;
|
||||
}
|
||||
|
||||
let cancel = CancellationToken::new();
|
||||
*guard = Some(cancel.clone());
|
||||
drop(guard);
|
||||
|
||||
let cgroup_manager = Arc::clone(&self.cgroup_manager);
|
||||
let cancel_scanner = cancel.clone();
|
||||
let cancel_forwarder = cancel.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let scanner = Arc::new(Scanner::new(PORT_SCANNER_INTERVAL));
|
||||
let rx = scanner.add_subscriber(
|
||||
"port-forwarder",
|
||||
Some(ScannerFilter {
|
||||
ips: vec![
|
||||
"127.0.0.1".to_string(),
|
||||
"localhost".to_string(),
|
||||
"::1".to_string(),
|
||||
],
|
||||
state: "LISTEN".to_string(),
|
||||
}),
|
||||
);
|
||||
|
||||
let scanner_clone = Arc::clone(&scanner);
|
||||
|
||||
let scanner_handle = tokio::spawn(async move {
|
||||
scanner_clone.scan_and_broadcast(cancel_scanner).await;
|
||||
});
|
||||
|
||||
let forwarder_handle = tokio::spawn(async move {
|
||||
let mut forwarder = Forwarder::new(cgroup_manager);
|
||||
forwarder.start_forwarding(rx, cancel_forwarder).await;
|
||||
});
|
||||
|
||||
let _ = tokio::join!(scanner_handle, forwarder_handle);
|
||||
});
|
||||
}
|
||||
|
||||
pub fn stop(&self) {
|
||||
let mut guard = self.cancel.lock().unwrap();
|
||||
if let Some(cancel) = guard.take() {
|
||||
cancel.cancel();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn restart(&self) {
|
||||
self.stop();
|
||||
self.start();
|
||||
}
|
||||
}
|
||||
231
envd-rs/src/rpc/entry.rs
Normal file
231
envd-rs/src/rpc/entry.rs
Normal file
@ -0,0 +1,231 @@
|
||||
use std::os::unix::fs::MetadataExt;
|
||||
use std::path::Path;
|
||||
|
||||
use connectrpc::{ConnectError, ErrorCode};
|
||||
|
||||
use crate::permissions::user::{lookup_groupname_by_gid, lookup_username_by_uid};
|
||||
use crate::rpc::pb::filesystem::{EntryInfo, FileType};
|
||||
use nix::unistd::{Gid, Uid};
|
||||
|
||||
const NFS_SUPER_MAGIC: i64 = 0x6969;
|
||||
const CIFS_MAGIC: i64 = 0xFF534D42;
|
||||
const SMB_SUPER_MAGIC: i64 = 0x517B;
|
||||
const SMB2_MAGIC_NUMBER: i64 = 0xFE534D42;
|
||||
const FUSE_SUPER_MAGIC: i64 = 0x65735546;
|
||||
|
||||
pub fn is_network_mount(path: &str) -> Result<bool, String> {
|
||||
let c_path = std::ffi::CString::new(path).map_err(|e| e.to_string())?;
|
||||
let mut stat: libc::statfs = unsafe { std::mem::zeroed() };
|
||||
let ret = unsafe { libc::statfs(c_path.as_ptr(), &mut stat) };
|
||||
if ret != 0 {
|
||||
return Err(format!(
|
||||
"statfs {path}: {}",
|
||||
std::io::Error::last_os_error()
|
||||
));
|
||||
}
|
||||
let fs_type = stat.f_type as i64;
|
||||
Ok(matches!(
|
||||
fs_type,
|
||||
NFS_SUPER_MAGIC | CIFS_MAGIC | SMB_SUPER_MAGIC | SMB2_MAGIC_NUMBER | FUSE_SUPER_MAGIC
|
||||
))
|
||||
}
|
||||
|
||||
pub fn build_entry_info(path: &str) -> Result<EntryInfo, ConnectError> {
|
||||
let p = Path::new(path);
|
||||
|
||||
let lstat = std::fs::symlink_metadata(p).map_err(|e| {
|
||||
if e.kind() == std::io::ErrorKind::NotFound {
|
||||
ConnectError::new(ErrorCode::NotFound, format!("file not found: {e}"))
|
||||
} else {
|
||||
ConnectError::new(ErrorCode::Internal, format!("error getting file info: {e}"))
|
||||
}
|
||||
})?;
|
||||
|
||||
let is_symlink = lstat.file_type().is_symlink();
|
||||
|
||||
let (file_type, mode, symlink_target) = if is_symlink {
|
||||
let target = std::fs::canonicalize(p)
|
||||
.map(|t| t.to_string_lossy().to_string())
|
||||
.unwrap_or_else(|_| path.to_string());
|
||||
|
||||
let target_type = match std::fs::metadata(p) {
|
||||
Ok(meta) => meta_to_file_type(&meta),
|
||||
Err(_) => FileType::FILE_TYPE_UNSPECIFIED,
|
||||
};
|
||||
|
||||
let target_mode = std::fs::metadata(p)
|
||||
.map(|m| m.mode() & 0o7777)
|
||||
.unwrap_or(0);
|
||||
|
||||
(target_type, target_mode, Some(target))
|
||||
} else {
|
||||
let ft = meta_to_file_type(&lstat);
|
||||
let mode = lstat.mode() & 0o7777;
|
||||
(ft, mode, None)
|
||||
};
|
||||
|
||||
let uid = lstat.uid();
|
||||
let gid = lstat.gid();
|
||||
let owner = lookup_username_by_uid(Uid::from_raw(uid));
|
||||
let group = lookup_groupname_by_gid(Gid::from_raw(gid));
|
||||
|
||||
let modified_time = {
|
||||
let mtime_sec = lstat.mtime();
|
||||
let mtime_nsec = lstat.mtime_nsec() as i32;
|
||||
if mtime_sec == 0 && mtime_nsec == 0 {
|
||||
None
|
||||
} else {
|
||||
Some(buffa_types::google::protobuf::Timestamp {
|
||||
seconds: mtime_sec,
|
||||
nanos: mtime_nsec,
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
let name = p
|
||||
.file_name()
|
||||
.map(|n| n.to_string_lossy().to_string())
|
||||
.unwrap_or_default();
|
||||
|
||||
let permissions = format_permissions(lstat.mode());
|
||||
|
||||
Ok(EntryInfo {
|
||||
name,
|
||||
r#type: buffa::EnumValue::Known(file_type),
|
||||
path: path.to_string(),
|
||||
size: lstat.len() as i64,
|
||||
mode,
|
||||
permissions,
|
||||
owner,
|
||||
group,
|
||||
modified_time: modified_time.into(),
|
||||
symlink_target: symlink_target,
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
fn meta_to_file_type(meta: &std::fs::Metadata) -> FileType {
|
||||
if meta.is_file() {
|
||||
FileType::FILE_TYPE_FILE
|
||||
} else if meta.is_dir() {
|
||||
FileType::FILE_TYPE_DIRECTORY
|
||||
} else if meta.file_type().is_symlink() {
|
||||
FileType::FILE_TYPE_SYMLINK
|
||||
} else {
|
||||
FileType::FILE_TYPE_UNSPECIFIED
|
||||
}
|
||||
}
|
||||
|
||||
fn format_permissions(mode: u32) -> String {
|
||||
let file_type = match mode & libc::S_IFMT {
|
||||
libc::S_IFDIR => 'd',
|
||||
libc::S_IFLNK => 'L',
|
||||
libc::S_IFREG => '-',
|
||||
libc::S_IFBLK => 'b',
|
||||
libc::S_IFCHR => 'c',
|
||||
libc::S_IFIFO => 'p',
|
||||
libc::S_IFSOCK => 'S',
|
||||
_ => '?',
|
||||
};
|
||||
|
||||
let perms = mode & 0o777;
|
||||
let mut s = String::with_capacity(10);
|
||||
s.push(file_type);
|
||||
for shift in [6, 3, 0] {
|
||||
let bits = (perms >> shift) & 7;
|
||||
s.push(if bits & 4 != 0 { 'r' } else { '-' });
|
||||
s.push(if bits & 2 != 0 { 'w' } else { '-' });
|
||||
s.push(if bits & 1 != 0 { 'x' } else { '-' });
|
||||
}
|
||||
s
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// format_permissions
|
||||
|
||||
#[test]
|
||||
fn regular_file_755() {
|
||||
assert_eq!(format_permissions(libc::S_IFREG | 0o755), "-rwxr-xr-x");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn directory_755() {
|
||||
assert_eq!(format_permissions(libc::S_IFDIR | 0o755), "drwxr-xr-x");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn symlink_777() {
|
||||
assert_eq!(format_permissions(libc::S_IFLNK | 0o777), "Lrwxrwxrwx");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn regular_file_000() {
|
||||
assert_eq!(format_permissions(libc::S_IFREG | 0o000), "----------");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn regular_file_644() {
|
||||
assert_eq!(format_permissions(libc::S_IFREG | 0o644), "-rw-r--r--");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn block_device() {
|
||||
assert_eq!(format_permissions(libc::S_IFBLK | 0o660), "brw-rw----");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn char_device() {
|
||||
assert_eq!(format_permissions(libc::S_IFCHR | 0o666), "crw-rw-rw-");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fifo() {
|
||||
assert_eq!(format_permissions(libc::S_IFIFO | 0o644), "prw-r--r--");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn socket() {
|
||||
assert_eq!(format_permissions(libc::S_IFSOCK | 0o755), "Srwxr-xr-x");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unknown_type() {
|
||||
assert_eq!(format_permissions(0o755), "?rwxr-xr-x");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn setuid_in_mode_only_affects_lower_bits() {
|
||||
// setuid (0o4755) — format_permissions masks with 0o777, so same as 0o755
|
||||
assert_eq!(
|
||||
format_permissions(libc::S_IFREG | 0o4755),
|
||||
format_permissions(libc::S_IFREG | 0o755),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn output_always_10_chars() {
|
||||
for mode in [0o000, 0o777, 0o644, 0o755, 0o4755] {
|
||||
assert_eq!(format_permissions(libc::S_IFREG | mode).len(), 10);
|
||||
}
|
||||
}
|
||||
|
||||
// meta_to_file_type — needs real filesystem
|
||||
|
||||
#[test]
|
||||
fn meta_regular_file() {
|
||||
let f = tempfile::NamedTempFile::new().unwrap();
|
||||
let meta = std::fs::metadata(f.path()).unwrap();
|
||||
assert_eq!(meta_to_file_type(&meta), FileType::FILE_TYPE_FILE);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn meta_directory() {
|
||||
let d = tempfile::TempDir::new().unwrap();
|
||||
let meta = std::fs::metadata(d.path()).unwrap();
|
||||
assert_eq!(meta_to_file_type(&meta), FileType::FILE_TYPE_DIRECTORY);
|
||||
}
|
||||
}
|
||||
402
envd-rs/src/rpc/filesystem_service.rs
Normal file
402
envd-rs/src/rpc/filesystem_service.rs
Normal file
@ -0,0 +1,402 @@
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::pin::Pin;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use connectrpc::{ConnectError, Context, ErrorCode};
|
||||
use dashmap::DashMap;
|
||||
use futures::Stream;
|
||||
|
||||
use crate::permissions::path::{ensure_dirs, expand_and_resolve};
|
||||
use crate::permissions::user::lookup_user;
|
||||
use crate::rpc::entry::build_entry_info;
|
||||
use crate::rpc::pb::filesystem::*;
|
||||
use crate::state::AppState;
|
||||
|
||||
pub struct FilesystemServiceImpl {
|
||||
state: Arc<AppState>,
|
||||
watchers: DashMap<String, WatcherHandle>,
|
||||
}
|
||||
|
||||
struct WatcherHandle {
|
||||
events: Arc<Mutex<Vec<FilesystemEvent>>>,
|
||||
_watcher: notify::RecommendedWatcher,
|
||||
}
|
||||
|
||||
impl FilesystemServiceImpl {
|
||||
pub fn new(state: Arc<AppState>) -> Self {
|
||||
Self {
|
||||
state,
|
||||
watchers: DashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_path(&self, path: &str, ctx: &Context) -> Result<String, ConnectError> {
|
||||
let username = extract_username(ctx).unwrap_or_else(|| self.state.defaults.user());
|
||||
let user = lookup_user(&username).map_err(|e| {
|
||||
ConnectError::new(ErrorCode::Unauthenticated, format!("invalid user: {e}"))
|
||||
})?;
|
||||
|
||||
let home_dir = user.dir.to_string_lossy().to_string();
|
||||
let default_workdir = self.state.defaults.workdir();
|
||||
|
||||
expand_and_resolve(path, &home_dir, default_workdir.as_deref())
|
||||
.map_err(|e| ConnectError::new(ErrorCode::InvalidArgument, e))
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_username(ctx: &Context) -> Option<String> {
|
||||
ctx.extensions.get::<AuthUser>().map(|u| u.0.clone())
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AuthUser(pub String);
|
||||
|
||||
impl Filesystem for FilesystemServiceImpl {
|
||||
async fn stat(
|
||||
&self,
|
||||
ctx: Context,
|
||||
request: buffa::view::OwnedView<StatRequestView<'static>>,
|
||||
) -> Result<(StatResponse, Context), ConnectError> {
|
||||
let path = self.resolve_path(request.path, &ctx)?;
|
||||
let entry = build_entry_info(&path)?;
|
||||
Ok((
|
||||
StatResponse {
|
||||
entry: entry.into(),
|
||||
..Default::default()
|
||||
},
|
||||
ctx,
|
||||
))
|
||||
}
|
||||
|
||||
async fn make_dir(
|
||||
&self,
|
||||
ctx: Context,
|
||||
request: buffa::view::OwnedView<MakeDirRequestView<'static>>,
|
||||
) -> Result<(MakeDirResponse, Context), ConnectError> {
|
||||
let path = self.resolve_path(request.path, &ctx)?;
|
||||
|
||||
match std::fs::metadata(&path) {
|
||||
Ok(meta) => {
|
||||
if meta.is_dir() {
|
||||
return Err(ConnectError::new(
|
||||
ErrorCode::AlreadyExists,
|
||||
format!("directory already exists: {path}"),
|
||||
));
|
||||
}
|
||||
return Err(ConnectError::new(
|
||||
ErrorCode::InvalidArgument,
|
||||
format!("path exists but is not a directory: {path}"),
|
||||
));
|
||||
}
|
||||
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
|
||||
Err(e) => {
|
||||
return Err(ConnectError::new(
|
||||
ErrorCode::Internal,
|
||||
format!("error getting file info: {e}"),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let username = extract_username(&ctx).unwrap_or_else(|| self.state.defaults.user());
|
||||
let user =
|
||||
lookup_user(&username).map_err(|e| ConnectError::new(ErrorCode::Internal, e))?;
|
||||
|
||||
ensure_dirs(&path, user.uid, user.gid)
|
||||
.map_err(|e| ConnectError::new(ErrorCode::Internal, e))?;
|
||||
|
||||
let entry = build_entry_info(&path)?;
|
||||
Ok((
|
||||
MakeDirResponse {
|
||||
entry: entry.into(),
|
||||
..Default::default()
|
||||
},
|
||||
ctx,
|
||||
))
|
||||
}
|
||||
|
||||
async fn r#move(
|
||||
&self,
|
||||
ctx: Context,
|
||||
request: buffa::view::OwnedView<MoveRequestView<'static>>,
|
||||
) -> Result<(MoveResponse, Context), ConnectError> {
|
||||
let source = self.resolve_path(request.source, &ctx)?;
|
||||
let destination = self.resolve_path(request.destination, &ctx)?;
|
||||
|
||||
let username = extract_username(&ctx).unwrap_or_else(|| self.state.defaults.user());
|
||||
let user =
|
||||
lookup_user(&username).map_err(|e| ConnectError::new(ErrorCode::Internal, e))?;
|
||||
|
||||
if let Some(parent) = Path::new(&destination).parent() {
|
||||
ensure_dirs(&parent.to_string_lossy(), user.uid, user.gid)
|
||||
.map_err(|e| ConnectError::new(ErrorCode::Internal, e))?;
|
||||
}
|
||||
|
||||
std::fs::rename(&source, &destination).map_err(|e| {
|
||||
if e.kind() == std::io::ErrorKind::NotFound {
|
||||
ConnectError::new(ErrorCode::NotFound, format!("source not found: {e}"))
|
||||
} else {
|
||||
ConnectError::new(ErrorCode::Internal, format!("error renaming: {e}"))
|
||||
}
|
||||
})?;
|
||||
|
||||
let entry = build_entry_info(&destination)?;
|
||||
Ok((
|
||||
MoveResponse {
|
||||
entry: entry.into(),
|
||||
..Default::default()
|
||||
},
|
||||
ctx,
|
||||
))
|
||||
}
|
||||
|
||||
async fn list_dir(
|
||||
&self,
|
||||
ctx: Context,
|
||||
request: buffa::view::OwnedView<ListDirRequestView<'static>>,
|
||||
) -> Result<(ListDirResponse, Context), ConnectError> {
|
||||
let mut depth = request.depth as usize;
|
||||
if depth == 0 {
|
||||
depth = 1;
|
||||
}
|
||||
|
||||
let path = self.resolve_path(request.path, &ctx)?;
|
||||
|
||||
let resolved = std::fs::canonicalize(&path).map_err(|e| {
|
||||
if e.kind() == std::io::ErrorKind::NotFound {
|
||||
ConnectError::new(ErrorCode::NotFound, format!("path not found: {e}"))
|
||||
} else {
|
||||
ConnectError::new(ErrorCode::Internal, format!("error resolving path: {e}"))
|
||||
}
|
||||
})?;
|
||||
let resolved_str = resolved.to_string_lossy().to_string();
|
||||
|
||||
let meta = std::fs::metadata(&resolved).map_err(|e| {
|
||||
ConnectError::new(ErrorCode::Internal, format!("error getting file info: {e}"))
|
||||
})?;
|
||||
if !meta.is_dir() {
|
||||
return Err(ConnectError::new(
|
||||
ErrorCode::InvalidArgument,
|
||||
format!("path is not a directory: {path}"),
|
||||
));
|
||||
}
|
||||
|
||||
let entries = walk_dir(&path, &resolved_str, depth)?;
|
||||
Ok((
|
||||
ListDirResponse {
|
||||
entries,
|
||||
..Default::default()
|
||||
},
|
||||
ctx,
|
||||
))
|
||||
}
|
||||
|
||||
async fn remove(
|
||||
&self,
|
||||
ctx: Context,
|
||||
request: buffa::view::OwnedView<RemoveRequestView<'static>>,
|
||||
) -> Result<(RemoveResponse, Context), ConnectError> {
|
||||
let path = self.resolve_path(request.path, &ctx)?;
|
||||
|
||||
if let Err(e1) = std::fs::remove_dir_all(&path) {
|
||||
if let Err(e2) = std::fs::remove_file(&path) {
|
||||
return Err(ConnectError::new(
|
||||
ErrorCode::Internal,
|
||||
format!("error removing: {e1}; also tried as file: {e2}"),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
Ok((RemoveResponse { ..Default::default() }, ctx))
|
||||
}
|
||||
|
||||
async fn watch_dir(
|
||||
&self,
|
||||
_ctx: Context,
|
||||
_request: buffa::view::OwnedView<WatchDirRequestView<'static>>,
|
||||
) -> Result<
|
||||
(
|
||||
Pin<Box<dyn Stream<Item = Result<WatchDirResponse, ConnectError>> + Send>>,
|
||||
Context,
|
||||
),
|
||||
ConnectError,
|
||||
> {
|
||||
Err(ConnectError::new(
|
||||
ErrorCode::Unimplemented,
|
||||
"watch_dir streaming not yet implemented",
|
||||
))
|
||||
}
|
||||
|
||||
async fn create_watcher(
|
||||
&self,
|
||||
ctx: Context,
|
||||
request: buffa::view::OwnedView<CreateWatcherRequestView<'static>>,
|
||||
) -> Result<(CreateWatcherResponse, Context), ConnectError> {
|
||||
use notify::{RecursiveMode, Watcher};
|
||||
|
||||
let path = self.resolve_path(request.path, &ctx)?;
|
||||
let recursive = request.recursive;
|
||||
|
||||
if let Ok(true) = crate::rpc::entry::is_network_mount(&path) {
|
||||
return Err(ConnectError::new(
|
||||
ErrorCode::FailedPrecondition,
|
||||
"watching network mounts is not supported",
|
||||
));
|
||||
}
|
||||
|
||||
let watcher_id = simple_id();
|
||||
let events: Arc<Mutex<Vec<FilesystemEvent>>> = Arc::new(Mutex::new(Vec::new()));
|
||||
let events_cb = Arc::clone(&events);
|
||||
|
||||
let mut watcher = notify::recommended_watcher(
|
||||
move |res: Result<notify::Event, notify::Error>| {
|
||||
if let Ok(event) = res {
|
||||
let event_type = match event.kind {
|
||||
notify::EventKind::Create(_) => EventType::EVENT_TYPE_CREATE,
|
||||
notify::EventKind::Modify(notify::event::ModifyKind::Data(_)) => {
|
||||
EventType::EVENT_TYPE_WRITE
|
||||
}
|
||||
notify::EventKind::Modify(notify::event::ModifyKind::Metadata(_)) => {
|
||||
EventType::EVENT_TYPE_CHMOD
|
||||
}
|
||||
notify::EventKind::Remove(_) => EventType::EVENT_TYPE_REMOVE,
|
||||
notify::EventKind::Modify(notify::event::ModifyKind::Name(_)) => {
|
||||
EventType::EVENT_TYPE_RENAME
|
||||
}
|
||||
_ => return,
|
||||
};
|
||||
|
||||
for p in &event.paths {
|
||||
if let Ok(mut guard) = events_cb.lock() {
|
||||
guard.push(FilesystemEvent {
|
||||
name: p.to_string_lossy().to_string(),
|
||||
r#type: buffa::EnumValue::Known(event_type),
|
||||
..Default::default()
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
.map_err(|e| {
|
||||
ConnectError::new(ErrorCode::Internal, format!("failed to create watcher: {e}"))
|
||||
})?;
|
||||
|
||||
let mode = if recursive {
|
||||
RecursiveMode::Recursive
|
||||
} else {
|
||||
RecursiveMode::NonRecursive
|
||||
};
|
||||
|
||||
watcher.watch(Path::new(&path), mode).map_err(|e| {
|
||||
ConnectError::new(ErrorCode::Internal, format!("failed to watch path: {e}"))
|
||||
})?;
|
||||
|
||||
self.watchers.insert(
|
||||
watcher_id.clone(),
|
||||
WatcherHandle {
|
||||
events,
|
||||
_watcher: watcher,
|
||||
},
|
||||
);
|
||||
|
||||
Ok((
|
||||
CreateWatcherResponse {
|
||||
watcher_id,
|
||||
..Default::default()
|
||||
},
|
||||
ctx,
|
||||
))
|
||||
}
|
||||
|
||||
async fn get_watcher_events(
|
||||
&self,
|
||||
ctx: Context,
|
||||
request: buffa::view::OwnedView<GetWatcherEventsRequestView<'static>>,
|
||||
) -> Result<(GetWatcherEventsResponse, Context), ConnectError> {
|
||||
let watcher_id: &str = request.watcher_id;
|
||||
let handle = self.watchers.get(watcher_id).ok_or_else(|| {
|
||||
ConnectError::new(
|
||||
ErrorCode::NotFound,
|
||||
format!("watcher not found: {watcher_id}"),
|
||||
)
|
||||
})?;
|
||||
|
||||
let events = {
|
||||
let mut guard = handle.events.lock().unwrap();
|
||||
std::mem::take(&mut *guard)
|
||||
};
|
||||
|
||||
Ok((
|
||||
GetWatcherEventsResponse {
|
||||
events,
|
||||
..Default::default()
|
||||
},
|
||||
ctx,
|
||||
))
|
||||
}
|
||||
|
||||
async fn remove_watcher(
|
||||
&self,
|
||||
ctx: Context,
|
||||
request: buffa::view::OwnedView<RemoveWatcherRequestView<'static>>,
|
||||
) -> Result<(RemoveWatcherResponse, Context), ConnectError> {
|
||||
let watcher_id: &str = request.watcher_id;
|
||||
self.watchers.remove(watcher_id);
|
||||
Ok((RemoveWatcherResponse { ..Default::default() }, ctx))
|
||||
}
|
||||
}
|
||||
|
||||
fn walk_dir(
|
||||
requested_path: &str,
|
||||
resolved_path: &str,
|
||||
depth: usize,
|
||||
) -> Result<Vec<EntryInfo>, ConnectError> {
|
||||
let mut entries = Vec::new();
|
||||
let base = Path::new(resolved_path);
|
||||
|
||||
for result in walkdir::WalkDir::new(resolved_path)
|
||||
.min_depth(1)
|
||||
.max_depth(depth)
|
||||
.follow_links(false)
|
||||
{
|
||||
let dir_entry = match result {
|
||||
Ok(e) => e,
|
||||
Err(e) => {
|
||||
if e.io_error()
|
||||
.is_some_and(|io| io.kind() == std::io::ErrorKind::NotFound)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
return Err(ConnectError::new(
|
||||
ErrorCode::Internal,
|
||||
format!("error reading directory: {e}"),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let entry_path = dir_entry.path();
|
||||
let mut entry = match build_entry_info(&entry_path.to_string_lossy()) {
|
||||
Ok(e) => e,
|
||||
Err(e) if e.code == ErrorCode::NotFound => continue,
|
||||
Err(e) => return Err(e),
|
||||
};
|
||||
|
||||
if let Ok(rel) = entry_path.strip_prefix(base) {
|
||||
let remapped = PathBuf::from(requested_path).join(rel);
|
||||
entry.path = remapped.to_string_lossy().to_string();
|
||||
}
|
||||
|
||||
entries.push(entry);
|
||||
}
|
||||
|
||||
Ok(entries)
|
||||
}
|
||||
|
||||
fn simple_id() -> String {
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
let nanos = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_nanos();
|
||||
format!("w-{nanos:x}")
|
||||
}
|
||||
26
envd-rs/src/rpc/mod.rs
Normal file
26
envd-rs/src/rpc/mod.rs
Normal file
@ -0,0 +1,26 @@
|
||||
pub mod pb;
|
||||
pub mod entry;
|
||||
pub mod process_handler;
|
||||
pub mod process_service;
|
||||
pub mod filesystem_service;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::rpc::process_service::ProcessServiceImpl;
|
||||
use crate::rpc::filesystem_service::FilesystemServiceImpl;
|
||||
use crate::state::AppState;
|
||||
|
||||
use pb::process::ProcessExt;
|
||||
use pb::filesystem::FilesystemExt;
|
||||
|
||||
/// Build the connect-rust Router with both RPC services registered.
|
||||
pub fn rpc_router(state: Arc<AppState>) -> connectrpc::Router {
|
||||
let process_svc = Arc::new(ProcessServiceImpl::new(Arc::clone(&state)));
|
||||
let filesystem_svc = Arc::new(FilesystemServiceImpl::new(Arc::clone(&state)));
|
||||
|
||||
let router = connectrpc::Router::new();
|
||||
let router = process_svc.register(router);
|
||||
let router = filesystem_svc.register(router);
|
||||
|
||||
router
|
||||
}
|
||||
10
envd-rs/src/rpc/pb.rs
Normal file
10
envd-rs/src/rpc/pb.rs
Normal file
@ -0,0 +1,10 @@
|
||||
#![allow(dead_code, non_camel_case_types, unused_imports, clippy::derivable_impls)]
|
||||
|
||||
use ::buffa;
|
||||
use ::buffa_types;
|
||||
use ::connectrpc;
|
||||
use ::futures;
|
||||
use ::http_body;
|
||||
use ::serde;
|
||||
|
||||
include!(concat!(env!("OUT_DIR"), "/_connectrpc.rs"));
|
||||
419
envd-rs/src/rpc/process_handler.rs
Normal file
419
envd-rs/src/rpc/process_handler.rs
Normal file
@ -0,0 +1,419 @@
|
||||
use std::io::Read;
|
||||
use std::os::unix::process::CommandExt;
|
||||
use std::process::Stdio;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use connectrpc::{ConnectError, ErrorCode};
|
||||
use nix::pty::{openpty, Winsize};
|
||||
use nix::sys::signal::{self, Signal};
|
||||
use nix::unistd::Pid;
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
use crate::rpc::pb::process::*;
|
||||
|
||||
const STD_CHUNK_SIZE: usize = 32768;
|
||||
const PTY_CHUNK_SIZE: usize = 16384;
|
||||
const BROADCAST_CAPACITY: usize = 4096;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub enum DataEvent {
|
||||
Stdout(Vec<u8>),
|
||||
Stderr(Vec<u8>),
|
||||
Pty(Vec<u8>),
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct EndEvent {
|
||||
pub exit_code: i32,
|
||||
pub exited: bool,
|
||||
pub status: String,
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
pub struct ProcessHandle {
|
||||
pub config: ProcessConfig,
|
||||
pub tag: Option<String>,
|
||||
pub pid: u32,
|
||||
|
||||
data_tx: broadcast::Sender<DataEvent>,
|
||||
end_tx: broadcast::Sender<EndEvent>,
|
||||
ended: Mutex<Option<EndEvent>>,
|
||||
|
||||
stdin: Mutex<Option<std::process::ChildStdin>>,
|
||||
pty_master: Mutex<Option<std::fs::File>>,
|
||||
}
|
||||
|
||||
impl ProcessHandle {
|
||||
pub fn subscribe_data(&self) -> broadcast::Receiver<DataEvent> {
|
||||
self.data_tx.subscribe()
|
||||
}
|
||||
|
||||
pub fn subscribe_end(&self) -> broadcast::Receiver<EndEvent> {
|
||||
self.end_tx.subscribe()
|
||||
}
|
||||
|
||||
pub fn cached_end(&self) -> Option<EndEvent> {
|
||||
self.ended.lock().unwrap().clone()
|
||||
}
|
||||
|
||||
pub fn send_signal(&self, sig: Signal) -> Result<(), ConnectError> {
|
||||
signal::kill(Pid::from_raw(self.pid as i32), sig).map_err(|e| {
|
||||
ConnectError::new(ErrorCode::Internal, format!("error sending signal: {e}"))
|
||||
})
|
||||
}
|
||||
|
||||
pub fn write_stdin(&self, data: &[u8]) -> Result<(), ConnectError> {
|
||||
use std::io::Write;
|
||||
let mut guard = self.stdin.lock().unwrap();
|
||||
match guard.as_mut() {
|
||||
Some(stdin) => stdin.write_all(data).map_err(|e| {
|
||||
ConnectError::new(ErrorCode::Internal, format!("error writing to stdin: {e}"))
|
||||
}),
|
||||
None => Err(ConnectError::new(
|
||||
ErrorCode::FailedPrecondition,
|
||||
"stdin not enabled or closed",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn write_pty(&self, data: &[u8]) -> Result<(), ConnectError> {
|
||||
use std::io::Write;
|
||||
let mut guard = self.pty_master.lock().unwrap();
|
||||
match guard.as_mut() {
|
||||
Some(master) => master.write_all(data).map_err(|e| {
|
||||
ConnectError::new(ErrorCode::Internal, format!("error writing to pty: {e}"))
|
||||
}),
|
||||
None => Err(ConnectError::new(
|
||||
ErrorCode::FailedPrecondition,
|
||||
"pty not assigned to process",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn close_stdin(&self) -> Result<(), ConnectError> {
|
||||
if self.pty_master.lock().unwrap().is_some() {
|
||||
return Err(ConnectError::new(
|
||||
ErrorCode::FailedPrecondition,
|
||||
"cannot close stdin for PTY process — send Ctrl+D (0x04) instead",
|
||||
));
|
||||
}
|
||||
let mut guard = self.stdin.lock().unwrap();
|
||||
*guard = None;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn resize_pty(&self, cols: u16, rows: u16) -> Result<(), ConnectError> {
|
||||
let guard = self.pty_master.lock().unwrap();
|
||||
match guard.as_ref() {
|
||||
Some(master) => {
|
||||
use std::os::unix::io::AsRawFd;
|
||||
let ws = libc::winsize {
|
||||
ws_row: rows,
|
||||
ws_col: cols,
|
||||
ws_xpixel: 0,
|
||||
ws_ypixel: 0,
|
||||
};
|
||||
let ret = unsafe { libc::ioctl(master.as_raw_fd(), libc::TIOCSWINSZ, &ws) };
|
||||
if ret != 0 {
|
||||
return Err(ConnectError::new(
|
||||
ErrorCode::Internal,
|
||||
format!(
|
||||
"ioctl TIOCSWINSZ failed: {}",
|
||||
std::io::Error::last_os_error()
|
||||
),
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
None => Err(ConnectError::new(
|
||||
ErrorCode::FailedPrecondition,
|
||||
"tty not assigned to process",
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SpawnedProcess {
|
||||
pub handle: Arc<ProcessHandle>,
|
||||
pub data_rx: broadcast::Receiver<DataEvent>,
|
||||
pub end_rx: broadcast::Receiver<EndEvent>,
|
||||
}
|
||||
|
||||
pub fn spawn_process(
|
||||
cmd_str: &str,
|
||||
args: &[String],
|
||||
envs: &std::collections::HashMap<String, String>,
|
||||
cwd: &str,
|
||||
pty_opts: Option<(u16, u16)>,
|
||||
enable_stdin: bool,
|
||||
tag: Option<String>,
|
||||
user: &nix::unistd::User,
|
||||
default_env_vars: &dashmap::DashMap<String, String>,
|
||||
) -> Result<SpawnedProcess, ConnectError> {
|
||||
let mut env: Vec<(String, String)> = Vec::new();
|
||||
env.push(("PATH".into(), std::env::var("PATH").unwrap_or_default()));
|
||||
let home = user.dir.to_string_lossy().to_string();
|
||||
env.push(("HOME".into(), home));
|
||||
env.push(("USER".into(), user.name.clone()));
|
||||
env.push(("LOGNAME".into(), user.name.clone()));
|
||||
|
||||
default_env_vars.iter().for_each(|entry| {
|
||||
env.push((entry.key().clone(), entry.value().clone()));
|
||||
});
|
||||
|
||||
for (k, v) in envs {
|
||||
env.push((k.clone(), v.clone()));
|
||||
}
|
||||
|
||||
let nice_delta = 0 - current_nice();
|
||||
let oom_script = format!(
|
||||
r#"echo 100 > /proc/$$/oom_score_adj && exec /usr/bin/nice -n {} "${{@}}""#,
|
||||
nice_delta
|
||||
);
|
||||
let mut wrapper_args = vec![
|
||||
"-c".to_string(),
|
||||
oom_script,
|
||||
"--".to_string(),
|
||||
cmd_str.to_string(),
|
||||
];
|
||||
wrapper_args.extend_from_slice(args);
|
||||
|
||||
let uid = user.uid.as_raw();
|
||||
let gid = user.gid.as_raw();
|
||||
|
||||
let (data_tx, _) = broadcast::channel(BROADCAST_CAPACITY);
|
||||
let (end_tx, _) = broadcast::channel(16);
|
||||
|
||||
let config = ProcessConfig {
|
||||
cmd: cmd_str.to_string(),
|
||||
args: args.to_vec(),
|
||||
envs: envs.clone(),
|
||||
cwd: Some(cwd.to_string()),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
if let Some((cols, rows)) = pty_opts {
|
||||
let pty_result = openpty(
|
||||
Some(&Winsize {
|
||||
ws_row: rows,
|
||||
ws_col: cols,
|
||||
ws_xpixel: 0,
|
||||
ws_ypixel: 0,
|
||||
}),
|
||||
None,
|
||||
)
|
||||
.map_err(|e| ConnectError::new(ErrorCode::Internal, format!("openpty failed: {e}")))?;
|
||||
|
||||
let master_fd = pty_result.master;
|
||||
let slave_fd = pty_result.slave;
|
||||
|
||||
let mut command = std::process::Command::new("/bin/sh");
|
||||
command
|
||||
.args(&wrapper_args)
|
||||
.env_clear()
|
||||
.envs(env.iter().map(|(k, v)| (k.as_str(), v.as_str())))
|
||||
.current_dir(cwd);
|
||||
|
||||
unsafe {
|
||||
use std::os::unix::io::AsRawFd;
|
||||
let slave_raw = slave_fd.as_raw_fd();
|
||||
let master_raw = master_fd.as_raw_fd();
|
||||
command.pre_exec(move || {
|
||||
libc::close(master_raw);
|
||||
nix::unistd::setsid()
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
|
||||
libc::ioctl(slave_raw, libc::TIOCSCTTY, 0);
|
||||
libc::dup2(slave_raw, 0);
|
||||
libc::dup2(slave_raw, 1);
|
||||
libc::dup2(slave_raw, 2);
|
||||
if slave_raw > 2 {
|
||||
libc::close(slave_raw);
|
||||
}
|
||||
libc::setgid(gid);
|
||||
libc::setuid(uid);
|
||||
Ok(())
|
||||
});
|
||||
}
|
||||
|
||||
command.stdin(Stdio::null());
|
||||
command.stdout(Stdio::null());
|
||||
command.stderr(Stdio::null());
|
||||
|
||||
let child = command.spawn().map_err(|e| {
|
||||
ConnectError::new(ErrorCode::Internal, format!("error starting pty process: {e}"))
|
||||
})?;
|
||||
|
||||
drop(slave_fd);
|
||||
|
||||
let pid = child.id();
|
||||
let master_file: std::fs::File = master_fd.into();
|
||||
let master_clone = master_file.try_clone().unwrap();
|
||||
|
||||
let handle = Arc::new(ProcessHandle {
|
||||
config,
|
||||
tag,
|
||||
pid,
|
||||
data_tx: data_tx.clone(),
|
||||
end_tx: end_tx.clone(),
|
||||
ended: Mutex::new(None),
|
||||
stdin: Mutex::new(None),
|
||||
pty_master: Mutex::new(Some(master_file)),
|
||||
});
|
||||
|
||||
let data_rx = handle.subscribe_data();
|
||||
let end_rx = handle.subscribe_end();
|
||||
|
||||
let data_tx_clone = data_tx.clone();
|
||||
std::thread::spawn(move || {
|
||||
let mut master = master_clone;
|
||||
let mut buf = vec![0u8; PTY_CHUNK_SIZE];
|
||||
loop {
|
||||
match master.read(&mut buf) {
|
||||
Ok(0) => break,
|
||||
Ok(n) => {
|
||||
let _ = data_tx_clone.send(DataEvent::Pty(buf[..n].to_vec()));
|
||||
}
|
||||
Err(_) => break,
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let end_tx_clone = end_tx.clone();
|
||||
let handle_for_waiter = Arc::clone(&handle);
|
||||
std::thread::spawn(move || {
|
||||
let mut child = child;
|
||||
let end_event = match child.wait() {
|
||||
Ok(s) => EndEvent {
|
||||
exit_code: s.code().unwrap_or(-1),
|
||||
exited: s.code().is_some(),
|
||||
status: format!("{s}"),
|
||||
error: None,
|
||||
},
|
||||
Err(e) => EndEvent {
|
||||
exit_code: -1,
|
||||
exited: false,
|
||||
status: "error".into(),
|
||||
error: Some(e.to_string()),
|
||||
},
|
||||
};
|
||||
*handle_for_waiter.ended.lock().unwrap() = Some(end_event.clone());
|
||||
let _ = end_tx_clone.send(end_event);
|
||||
});
|
||||
|
||||
tracing::info!(pid, cmd = cmd_str, "process started (pty)");
|
||||
Ok(SpawnedProcess { handle, data_rx, end_rx })
|
||||
} else {
|
||||
let mut command = std::process::Command::new("/bin/sh");
|
||||
command
|
||||
.args(&wrapper_args)
|
||||
.env_clear()
|
||||
.envs(env.iter().map(|(k, v)| (k.as_str(), v.as_str())))
|
||||
.current_dir(cwd)
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped());
|
||||
|
||||
if enable_stdin {
|
||||
command.stdin(Stdio::piped());
|
||||
} else {
|
||||
command.stdin(Stdio::null());
|
||||
}
|
||||
|
||||
unsafe {
|
||||
command.pre_exec(move || {
|
||||
libc::setgid(gid);
|
||||
libc::setuid(uid);
|
||||
Ok(())
|
||||
});
|
||||
}
|
||||
|
||||
let mut child = command.spawn().map_err(|e| {
|
||||
ConnectError::new(ErrorCode::Internal, format!("error starting process: {e}"))
|
||||
})?;
|
||||
|
||||
let pid = child.id();
|
||||
let stdin = child.stdin.take();
|
||||
let stdout = child.stdout.take();
|
||||
let stderr = child.stderr.take();
|
||||
|
||||
let handle = Arc::new(ProcessHandle {
|
||||
config,
|
||||
tag,
|
||||
pid,
|
||||
data_tx: data_tx.clone(),
|
||||
end_tx: end_tx.clone(),
|
||||
ended: Mutex::new(None),
|
||||
stdin: Mutex::new(stdin),
|
||||
pty_master: Mutex::new(None),
|
||||
});
|
||||
|
||||
let data_rx = handle.subscribe_data();
|
||||
let end_rx = handle.subscribe_end();
|
||||
|
||||
if let Some(mut out) = stdout {
|
||||
let tx = data_tx.clone();
|
||||
std::thread::spawn(move || {
|
||||
let mut buf = vec![0u8; STD_CHUNK_SIZE];
|
||||
loop {
|
||||
match out.read(&mut buf) {
|
||||
Ok(0) => break,
|
||||
Ok(n) => {
|
||||
let _ = tx.send(DataEvent::Stdout(buf[..n].to_vec()));
|
||||
}
|
||||
Err(_) => break,
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(mut err_pipe) = stderr {
|
||||
let tx = data_tx.clone();
|
||||
std::thread::spawn(move || {
|
||||
let mut buf = vec![0u8; STD_CHUNK_SIZE];
|
||||
loop {
|
||||
match err_pipe.read(&mut buf) {
|
||||
Ok(0) => break,
|
||||
Ok(n) => {
|
||||
let _ = tx.send(DataEvent::Stderr(buf[..n].to_vec()));
|
||||
}
|
||||
Err(_) => break,
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
let end_tx_clone = end_tx.clone();
|
||||
let handle_for_waiter = Arc::clone(&handle);
|
||||
std::thread::spawn(move || {
|
||||
let end_event = match child.wait() {
|
||||
Ok(s) => EndEvent {
|
||||
exit_code: s.code().unwrap_or(-1),
|
||||
exited: s.code().is_some(),
|
||||
status: format!("{s}"),
|
||||
error: None,
|
||||
},
|
||||
Err(e) => EndEvent {
|
||||
exit_code: -1,
|
||||
exited: false,
|
||||
status: "error".into(),
|
||||
error: Some(e.to_string()),
|
||||
},
|
||||
};
|
||||
*handle_for_waiter.ended.lock().unwrap() = Some(end_event.clone());
|
||||
let _ = end_tx_clone.send(end_event);
|
||||
});
|
||||
|
||||
tracing::info!(pid, cmd = cmd_str, "process started (pipe)");
|
||||
Ok(SpawnedProcess { handle, data_rx, end_rx })
|
||||
}
|
||||
}
|
||||
|
||||
fn current_nice() -> i32 {
|
||||
unsafe {
|
||||
*libc::__errno_location() = 0;
|
||||
let prio = libc::getpriority(libc::PRIO_PROCESS, 0);
|
||||
if *libc::__errno_location() != 0 {
|
||||
return 0;
|
||||
}
|
||||
20 - prio
|
||||
}
|
||||
}
|
||||
481
envd-rs/src/rpc/process_service.rs
Normal file
481
envd-rs/src/rpc/process_service.rs
Normal file
@ -0,0 +1,481 @@
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
|
||||
use connectrpc::{ConnectError, Context, ErrorCode};
|
||||
use dashmap::DashMap;
|
||||
use futures::Stream;
|
||||
|
||||
use crate::permissions::path::expand_and_resolve;
|
||||
use crate::permissions::user::lookup_user;
|
||||
use crate::rpc::pb::process::*;
|
||||
use crate::rpc::process_handler::{self, DataEvent, ProcessHandle};
|
||||
use crate::state::AppState;
|
||||
|
||||
pub struct ProcessServiceImpl {
|
||||
state: Arc<AppState>,
|
||||
processes: DashMap<u32, Arc<ProcessHandle>>,
|
||||
}
|
||||
|
||||
impl ProcessServiceImpl {
|
||||
pub fn new(state: Arc<AppState>) -> Self {
|
||||
Self {
|
||||
state,
|
||||
processes: DashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_process_by_selector(
|
||||
&self,
|
||||
selector: &ProcessSelectorView,
|
||||
) -> Result<Arc<ProcessHandle>, ConnectError> {
|
||||
match &selector.selector {
|
||||
Some(process_selector::SelectorView::Pid(pid)) => {
|
||||
let pid_val = *pid;
|
||||
self.processes
|
||||
.get(&pid_val)
|
||||
.map(|entry| Arc::clone(entry.value()))
|
||||
.ok_or_else(|| {
|
||||
ConnectError::new(
|
||||
ErrorCode::NotFound,
|
||||
format!("process with pid {pid_val} not found"),
|
||||
)
|
||||
})
|
||||
}
|
||||
Some(process_selector::SelectorView::Tag(tag)) => {
|
||||
let tag_str: &str = tag;
|
||||
for entry in self.processes.iter() {
|
||||
if let Some(ref t) = entry.value().tag {
|
||||
if t == tag_str {
|
||||
return Ok(Arc::clone(entry.value()));
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(ConnectError::new(
|
||||
ErrorCode::NotFound,
|
||||
format!("process with tag {tag_str} not found"),
|
||||
))
|
||||
}
|
||||
None => Err(ConnectError::new(
|
||||
ErrorCode::InvalidArgument,
|
||||
"process selector required",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn spawn_from_request(
|
||||
&self,
|
||||
request: &StartRequestView<'_>,
|
||||
) -> Result<process_handler::SpawnedProcess, ConnectError> {
|
||||
let proc_config = request.process.as_option().ok_or_else(|| {
|
||||
ConnectError::new(ErrorCode::InvalidArgument, "process config required")
|
||||
})?;
|
||||
|
||||
let username = self.state.defaults.user();
|
||||
let user =
|
||||
lookup_user(&username).map_err(|e| ConnectError::new(ErrorCode::Internal, e))?;
|
||||
|
||||
let cmd: &str = proc_config.cmd;
|
||||
let args: Vec<String> = proc_config.args.iter().map(|s| s.to_string()).collect();
|
||||
let envs: HashMap<String, String> = proc_config
|
||||
.envs
|
||||
.iter()
|
||||
.map(|(k, v)| (k.to_string(), v.to_string()))
|
||||
.collect();
|
||||
|
||||
let home_dir = user.dir.to_string_lossy().to_string();
|
||||
let cwd_str: &str = proc_config.cwd.unwrap_or("");
|
||||
let default_workdir = self.state.defaults.workdir();
|
||||
let cwd = expand_and_resolve(cwd_str, &home_dir, default_workdir.as_deref())
|
||||
.map_err(|e| ConnectError::new(ErrorCode::InvalidArgument, e))?;
|
||||
|
||||
let effective_cwd = if cwd.is_empty() { "/" } else { &cwd };
|
||||
if let Err(_) = std::fs::metadata(effective_cwd) {
|
||||
return Err(ConnectError::new(
|
||||
ErrorCode::InvalidArgument,
|
||||
format!("cwd '{effective_cwd}' does not exist"),
|
||||
));
|
||||
}
|
||||
|
||||
let pty_opts = request.pty.as_option().and_then(|pty| {
|
||||
pty.size
|
||||
.as_option()
|
||||
.map(|sz| (sz.cols as u16, sz.rows as u16))
|
||||
});
|
||||
|
||||
let enable_stdin = request.stdin.unwrap_or(true);
|
||||
let tag = request.tag.map(|s| s.to_string());
|
||||
|
||||
tracing::info!(
|
||||
cmd = cmd,
|
||||
has_pty = pty_opts.is_some(),
|
||||
pty_size = ?pty_opts,
|
||||
tag = ?tag,
|
||||
stdin = enable_stdin,
|
||||
cwd = effective_cwd,
|
||||
user = %username,
|
||||
"process.Start request"
|
||||
);
|
||||
|
||||
let spawned = process_handler::spawn_process(
|
||||
cmd,
|
||||
&args,
|
||||
&envs,
|
||||
effective_cwd,
|
||||
pty_opts,
|
||||
enable_stdin,
|
||||
tag,
|
||||
&user,
|
||||
&self.state.defaults.env_vars,
|
||||
)?;
|
||||
|
||||
self.processes.insert(spawned.handle.pid, Arc::clone(&spawned.handle));
|
||||
|
||||
let processes = self.processes.clone();
|
||||
let pid = spawned.handle.pid;
|
||||
let mut cleanup_end_rx = spawned.handle.subscribe_end();
|
||||
tokio::spawn(async move {
|
||||
let _ = cleanup_end_rx.recv().await;
|
||||
processes.remove(&pid);
|
||||
});
|
||||
|
||||
Ok(spawned)
|
||||
}
|
||||
}
|
||||
|
||||
impl Process for ProcessServiceImpl {
|
||||
async fn list(
|
||||
&self,
|
||||
ctx: Context,
|
||||
_request: buffa::view::OwnedView<ListRequestView<'static>>,
|
||||
) -> Result<(ListResponse, Context), ConnectError> {
|
||||
let processes: Vec<ProcessInfo> = self
|
||||
.processes
|
||||
.iter()
|
||||
.map(|entry| {
|
||||
let h = entry.value();
|
||||
ProcessInfo {
|
||||
config: buffa::MessageField::some(h.config.clone()),
|
||||
pid: h.pid,
|
||||
tag: h.tag.clone(),
|
||||
..Default::default()
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok((
|
||||
ListResponse {
|
||||
processes,
|
||||
..Default::default()
|
||||
},
|
||||
ctx,
|
||||
))
|
||||
}
|
||||
|
||||
async fn start(
|
||||
&self,
|
||||
ctx: Context,
|
||||
request: buffa::view::OwnedView<StartRequestView<'static>>,
|
||||
) -> Result<
|
||||
(
|
||||
Pin<Box<dyn Stream<Item = Result<StartResponse, ConnectError>> + Send>>,
|
||||
Context,
|
||||
),
|
||||
ConnectError,
|
||||
> {
|
||||
let spawned = self.spawn_from_request(&request)?;
|
||||
let pid = spawned.handle.pid;
|
||||
|
||||
let mut data_rx = spawned.data_rx;
|
||||
let mut end_rx = spawned.end_rx;
|
||||
|
||||
let stream = async_stream::stream! {
|
||||
yield Ok(make_start_response(pid));
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
biased;
|
||||
data = data_rx.recv() => {
|
||||
match data {
|
||||
Ok(ev) => yield Ok(make_data_start_response(ev)),
|
||||
Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
|
||||
Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
|
||||
}
|
||||
}
|
||||
end = end_rx.recv() => {
|
||||
while let Ok(ev) = data_rx.try_recv() {
|
||||
yield Ok(make_data_start_response(ev));
|
||||
}
|
||||
if let Ok(end) = end {
|
||||
yield Ok(make_end_start_response(end));
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Ok((Box::pin(stream), ctx))
|
||||
}
|
||||
|
||||
async fn connect(
|
||||
&self,
|
||||
ctx: Context,
|
||||
request: buffa::view::OwnedView<ConnectRequestView<'static>>,
|
||||
) -> Result<
|
||||
(
|
||||
Pin<Box<dyn Stream<Item = Result<ConnectResponse, ConnectError>> + Send>>,
|
||||
Context,
|
||||
),
|
||||
ConnectError,
|
||||
> {
|
||||
let selector = request.process.as_option().ok_or_else(|| {
|
||||
ConnectError::new(ErrorCode::InvalidArgument, "process selector required")
|
||||
})?;
|
||||
let handle = self.get_process_by_selector(selector)?;
|
||||
let pid = handle.pid;
|
||||
|
||||
let mut data_rx = handle.subscribe_data();
|
||||
let mut end_rx = handle.subscribe_end();
|
||||
let cached_end = handle.cached_end();
|
||||
|
||||
let stream = async_stream::stream! {
|
||||
yield Ok(ConnectResponse {
|
||||
event: buffa::MessageField::some(ProcessEvent {
|
||||
event: Some(process_event::Event::Start(Box::new(
|
||||
process_event::StartEvent { pid, ..Default::default() },
|
||||
))),
|
||||
..Default::default()
|
||||
}),
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
if let Some(end) = cached_end {
|
||||
yield Ok(ConnectResponse {
|
||||
event: buffa::MessageField::some(make_end_event(end)),
|
||||
..Default::default()
|
||||
});
|
||||
} else {
|
||||
loop {
|
||||
tokio::select! {
|
||||
biased;
|
||||
data = data_rx.recv() => {
|
||||
match data {
|
||||
Ok(ev) => {
|
||||
yield Ok(ConnectResponse {
|
||||
event: buffa::MessageField::some(make_data_event(ev)),
|
||||
..Default::default()
|
||||
});
|
||||
}
|
||||
Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
|
||||
Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
|
||||
}
|
||||
}
|
||||
end = end_rx.recv() => {
|
||||
while let Ok(ev) = data_rx.try_recv() {
|
||||
yield Ok(ConnectResponse {
|
||||
event: buffa::MessageField::some(make_data_event(ev)),
|
||||
..Default::default()
|
||||
});
|
||||
}
|
||||
if let Ok(end) = end {
|
||||
yield Ok(ConnectResponse {
|
||||
event: buffa::MessageField::some(make_end_event(end)),
|
||||
..Default::default()
|
||||
});
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Ok((Box::pin(stream), ctx))
|
||||
}
|
||||
|
||||
async fn update(
|
||||
&self,
|
||||
ctx: Context,
|
||||
request: buffa::view::OwnedView<UpdateRequestView<'static>>,
|
||||
) -> Result<(UpdateResponse, Context), ConnectError> {
|
||||
let selector = request.process.as_option().ok_or_else(|| {
|
||||
ConnectError::new(ErrorCode::InvalidArgument, "process selector required")
|
||||
})?;
|
||||
let handle = self.get_process_by_selector(selector)?;
|
||||
|
||||
if let Some(pty) = request.pty.as_option() {
|
||||
if let Some(size) = pty.size.as_option() {
|
||||
handle.resize_pty(size.cols as u16, size.rows as u16)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok((UpdateResponse { ..Default::default() }, ctx))
|
||||
}
|
||||
|
||||
async fn stream_input(
|
||||
&self,
|
||||
ctx: Context,
|
||||
mut requests: Pin<
|
||||
Box<
|
||||
dyn Stream<
|
||||
Item = Result<
|
||||
buffa::view::OwnedView<StreamInputRequestView<'static>>,
|
||||
ConnectError,
|
||||
>,
|
||||
> + Send,
|
||||
>,
|
||||
>,
|
||||
) -> Result<(StreamInputResponse, Context), ConnectError> {
|
||||
use futures::StreamExt;
|
||||
|
||||
let mut handle: Option<Arc<ProcessHandle>> = None;
|
||||
|
||||
while let Some(result) = requests.next().await {
|
||||
let req = result?;
|
||||
match &req.event {
|
||||
Some(stream_input_request::EventView::Start(start)) => {
|
||||
if let Some(selector) = start.process.as_option() {
|
||||
handle = Some(self.get_process_by_selector(selector)?);
|
||||
}
|
||||
}
|
||||
Some(stream_input_request::EventView::Data(data)) => {
|
||||
let h = handle.as_ref().ok_or_else(|| {
|
||||
ConnectError::new(ErrorCode::FailedPrecondition, "no start event received")
|
||||
})?;
|
||||
if let Some(input) = data.input.as_option() {
|
||||
write_input(h, input)?;
|
||||
}
|
||||
}
|
||||
Some(stream_input_request::EventView::Keepalive(_)) => {}
|
||||
None => {}
|
||||
}
|
||||
}
|
||||
|
||||
Ok((StreamInputResponse { ..Default::default() }, ctx))
|
||||
}
|
||||
|
||||
async fn send_input(
|
||||
&self,
|
||||
ctx: Context,
|
||||
request: buffa::view::OwnedView<SendInputRequestView<'static>>,
|
||||
) -> Result<(SendInputResponse, Context), ConnectError> {
|
||||
let selector = request.process.as_option().ok_or_else(|| {
|
||||
ConnectError::new(ErrorCode::InvalidArgument, "process selector required")
|
||||
})?;
|
||||
let handle = self.get_process_by_selector(selector)?;
|
||||
|
||||
if let Some(input) = request.input.as_option() {
|
||||
write_input(&handle, input)?;
|
||||
}
|
||||
|
||||
Ok((SendInputResponse { ..Default::default() }, ctx))
|
||||
}
|
||||
|
||||
async fn send_signal(
|
||||
&self,
|
||||
ctx: Context,
|
||||
request: buffa::view::OwnedView<SendSignalRequestView<'static>>,
|
||||
) -> Result<(SendSignalResponse, Context), ConnectError> {
|
||||
let selector = request.process.as_option().ok_or_else(|| {
|
||||
ConnectError::new(ErrorCode::InvalidArgument, "process selector required")
|
||||
})?;
|
||||
let handle = self.get_process_by_selector(selector)?;
|
||||
|
||||
let sig = match request.signal.as_known() {
|
||||
Some(Signal::SIGNAL_SIGKILL) => nix::sys::signal::Signal::SIGKILL,
|
||||
Some(Signal::SIGNAL_SIGTERM) => nix::sys::signal::Signal::SIGTERM,
|
||||
_ => {
|
||||
return Err(ConnectError::new(
|
||||
ErrorCode::InvalidArgument,
|
||||
"invalid or unspecified signal",
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
handle.send_signal(sig)?;
|
||||
Ok((SendSignalResponse { ..Default::default() }, ctx))
|
||||
}
|
||||
|
||||
async fn close_stdin(
|
||||
&self,
|
||||
ctx: Context,
|
||||
request: buffa::view::OwnedView<CloseStdinRequestView<'static>>,
|
||||
) -> Result<(CloseStdinResponse, Context), ConnectError> {
|
||||
let selector = request.process.as_option().ok_or_else(|| {
|
||||
ConnectError::new(ErrorCode::InvalidArgument, "process selector required")
|
||||
})?;
|
||||
let handle = self.get_process_by_selector(selector)?;
|
||||
handle.close_stdin()?;
|
||||
Ok((CloseStdinResponse { ..Default::default() }, ctx))
|
||||
}
|
||||
}
|
||||
|
||||
fn write_input(handle: &ProcessHandle, input: &ProcessInputView) -> Result<(), ConnectError> {
|
||||
match &input.input {
|
||||
Some(process_input::InputView::Pty(d)) => handle.write_pty(d),
|
||||
Some(process_input::InputView::Stdin(d)) => handle.write_stdin(d),
|
||||
None => Ok(()),
|
||||
}
|
||||
}
|
||||
|
||||
fn make_start_response(pid: u32) -> StartResponse {
|
||||
StartResponse {
|
||||
event: buffa::MessageField::some(ProcessEvent {
|
||||
event: Some(process_event::Event::Start(Box::new(
|
||||
process_event::StartEvent {
|
||||
pid,
|
||||
..Default::default()
|
||||
},
|
||||
))),
|
||||
..Default::default()
|
||||
}),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
fn make_data_event(ev: DataEvent) -> ProcessEvent {
|
||||
let output = match ev {
|
||||
DataEvent::Stdout(d) => Some(process_event::data_event::Output::Stdout(d.into())),
|
||||
DataEvent::Stderr(d) => Some(process_event::data_event::Output::Stderr(d.into())),
|
||||
DataEvent::Pty(d) => Some(process_event::data_event::Output::Pty(d.into())),
|
||||
};
|
||||
ProcessEvent {
|
||||
event: Some(process_event::Event::Data(Box::new(
|
||||
process_event::DataEvent {
|
||||
output,
|
||||
..Default::default()
|
||||
},
|
||||
))),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
fn make_data_start_response(ev: DataEvent) -> StartResponse {
|
||||
StartResponse {
|
||||
event: buffa::MessageField::some(make_data_event(ev)),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
fn make_end_event(end: process_handler::EndEvent) -> ProcessEvent {
|
||||
ProcessEvent {
|
||||
event: Some(process_event::Event::End(Box::new(
|
||||
process_event::EndEvent {
|
||||
exit_code: end.exit_code,
|
||||
exited: end.exited,
|
||||
status: end.status,
|
||||
error: end.error,
|
||||
..Default::default()
|
||||
},
|
||||
))),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
fn make_end_start_response(end: process_handler::EndEvent) -> StartResponse {
|
||||
StartResponse {
|
||||
event: buffa::MessageField::some(make_end_event(end)),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
141
envd-rs/src/state.rs
Normal file
141
envd-rs/src/state.rs
Normal file
@ -0,0 +1,141 @@
|
||||
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use crate::auth::token::SecureToken;
|
||||
use crate::conntracker::ConnTracker;
|
||||
use crate::execcontext::Defaults;
|
||||
use crate::port::subsystem::PortSubsystem;
|
||||
use crate::util::AtomicMax;
|
||||
|
||||
pub struct AppState {
|
||||
pub defaults: Defaults,
|
||||
pub version: String,
|
||||
pub commit: String,
|
||||
pub needs_restore: AtomicBool,
|
||||
pub last_set_time: AtomicMax,
|
||||
pub access_token: SecureToken,
|
||||
pub conn_tracker: ConnTracker,
|
||||
pub port_subsystem: Option<Arc<PortSubsystem>>,
|
||||
pub cpu_used_pct: AtomicU32,
|
||||
pub cpu_count: AtomicU32,
|
||||
pub snapshot_in_progress: AtomicBool,
|
||||
pub last_health_epoch: AtomicU64,
|
||||
pub restore_epoch: AtomicU64,
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
pub fn new(
|
||||
defaults: Defaults,
|
||||
version: String,
|
||||
commit: String,
|
||||
port_subsystem: Option<Arc<PortSubsystem>>,
|
||||
) -> Arc<Self> {
|
||||
let state = Arc::new(Self {
|
||||
defaults,
|
||||
version,
|
||||
commit,
|
||||
needs_restore: AtomicBool::new(false),
|
||||
last_set_time: AtomicMax::new(),
|
||||
access_token: SecureToken::new(),
|
||||
conn_tracker: ConnTracker::new(),
|
||||
port_subsystem,
|
||||
cpu_used_pct: AtomicU32::new(0),
|
||||
cpu_count: AtomicU32::new(0),
|
||||
snapshot_in_progress: AtomicBool::new(false),
|
||||
last_health_epoch: AtomicU64::new(0),
|
||||
restore_epoch: AtomicU64::new(0),
|
||||
});
|
||||
|
||||
let state_clone = Arc::clone(&state);
|
||||
std::thread::spawn(move || {
|
||||
cpu_sampler(state_clone);
|
||||
});
|
||||
|
||||
state
|
||||
}
|
||||
|
||||
pub fn cpu_used_pct(&self) -> f32 {
|
||||
f32::from_bits(self.cpu_used_pct.load(Ordering::Relaxed))
|
||||
}
|
||||
|
||||
pub fn cpu_count(&self) -> u32 {
|
||||
self.cpu_count.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Runs post-restore recovery if `needs_restore` is set OR a wall-clock
|
||||
/// gap is detected (catches restores where snapshot/prepare never ran).
|
||||
pub fn try_restore_recovery(&self) {
|
||||
let now_epoch = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
let prev_epoch = self.last_health_epoch.swap(now_epoch, Ordering::AcqRel);
|
||||
|
||||
// Detect restore via wall-clock gap: if >3s passed since last health
|
||||
// check, the VM was frozen and restored. Catches the case where
|
||||
// snapshot/prepare timed out and needs_restore was never set.
|
||||
let gap_detected = prev_epoch > 0 && now_epoch.saturating_sub(prev_epoch) > 3;
|
||||
|
||||
let flag_set = self
|
||||
.needs_restore
|
||||
.compare_exchange(true, false, Ordering::AcqRel, Ordering::Relaxed)
|
||||
.is_ok();
|
||||
|
||||
if !flag_set && !gap_detected {
|
||||
return;
|
||||
}
|
||||
|
||||
if gap_detected && !flag_set {
|
||||
tracing::info!(
|
||||
gap_secs = now_epoch.saturating_sub(prev_epoch),
|
||||
"restore: detected via wall-clock gap (needs_restore was not set)"
|
||||
);
|
||||
}
|
||||
|
||||
tracing::info!("restore: post-restore recovery");
|
||||
self.snapshot_in_progress.store(false, Ordering::Release);
|
||||
self.restore_epoch.store(now_epoch, Ordering::Release);
|
||||
self.conn_tracker.restore_after_snapshot();
|
||||
|
||||
if let Some(ref ps) = self.port_subsystem {
|
||||
ps.restart();
|
||||
tracing::info!("restore: port subsystem restarted");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn cpu_sampler(state: Arc<AppState>) {
|
||||
use sysinfo::System;
|
||||
|
||||
let mut sys = System::new();
|
||||
sys.refresh_cpu_all();
|
||||
|
||||
loop {
|
||||
std::thread::sleep(std::time::Duration::from_secs(1));
|
||||
|
||||
if state.needs_restore.load(Ordering::Acquire) {
|
||||
// After snapshot restore, sysinfo's internal CPU counters are stale.
|
||||
// Reinitialize to get a fresh baseline.
|
||||
sys = System::new();
|
||||
sys.refresh_cpu_all();
|
||||
continue;
|
||||
}
|
||||
|
||||
sys.refresh_cpu_all();
|
||||
|
||||
let pct = sys.global_cpu_usage();
|
||||
let rounded = if pct > 0.0 {
|
||||
(pct * 100.0).round() / 100.0
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
state
|
||||
.cpu_used_pct
|
||||
.store(rounded.to_bits(), Ordering::Relaxed);
|
||||
state
|
||||
.cpu_count
|
||||
.store(sys.cpus().len() as u32, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
102
envd-rs/src/util.rs
Normal file
102
envd-rs/src/util.rs
Normal file
@ -0,0 +1,102 @@
|
||||
use std::sync::atomic::{AtomicI64, Ordering};
|
||||
|
||||
pub struct AtomicMax {
|
||||
val: AtomicI64,
|
||||
}
|
||||
|
||||
impl AtomicMax {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
val: AtomicI64::new(i64::MIN),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get(&self) -> i64 {
|
||||
self.val.load(Ordering::Acquire)
|
||||
}
|
||||
|
||||
/// Sets the stored value to `new` if `new` is strictly greater than
|
||||
/// the current value. Returns `true` if the value was updated.
|
||||
pub fn set_to_greater(&self, new: i64) -> bool {
|
||||
loop {
|
||||
let current = self.val.load(Ordering::Acquire);
|
||||
if new <= current {
|
||||
return false;
|
||||
}
|
||||
match self.val.compare_exchange_weak(
|
||||
current,
|
||||
new,
|
||||
Ordering::Release,
|
||||
Ordering::Relaxed,
|
||||
) {
|
||||
Ok(_) => return true,
|
||||
Err(_) => continue,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[test]
|
||||
fn initial_value_is_i64_min() {
|
||||
let m = AtomicMax::new();
|
||||
assert_eq!(m.get(), i64::MIN);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn updates_when_larger() {
|
||||
let m = AtomicMax::new();
|
||||
assert!(m.set_to_greater(0));
|
||||
assert_eq!(m.get(), 0);
|
||||
assert!(m.set_to_greater(100));
|
||||
assert_eq!(m.get(), 100);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn returns_false_when_equal() {
|
||||
let m = AtomicMax::new();
|
||||
m.set_to_greater(42);
|
||||
assert!(!m.set_to_greater(42));
|
||||
assert_eq!(m.get(), 42);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn returns_false_when_smaller() {
|
||||
let m = AtomicMax::new();
|
||||
m.set_to_greater(100);
|
||||
assert!(!m.set_to_greater(50));
|
||||
assert_eq!(m.get(), 100);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn concurrent_convergence() {
|
||||
let m = Arc::new(AtomicMax::new());
|
||||
let threads: Vec<_> = (0..8)
|
||||
.map(|t| {
|
||||
let m = Arc::clone(&m);
|
||||
std::thread::spawn(move || {
|
||||
for i in (t * 100)..((t + 1) * 100) {
|
||||
m.set_to_greater(i);
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
for t in threads {
|
||||
t.join().unwrap();
|
||||
}
|
||||
assert_eq!(m.get(), 799);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn i64_max_boundary() {
|
||||
let m = AtomicMax::new();
|
||||
assert!(m.set_to_greater(i64::MAX));
|
||||
assert!(!m.set_to_greater(i64::MAX));
|
||||
assert!(!m.set_to_greater(0));
|
||||
assert_eq!(m.get(), i64::MAX);
|
||||
}
|
||||
}
|
||||
202
envd/LICENSE
202
envd/LICENSE
@ -1,202 +0,0 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright 2023 FoundryLabs, Inc.
|
||||
Modifications Copyright (c) 2026 M/S Omukk, Bangladesh
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
@ -1,62 +0,0 @@
|
||||
BUILD := $(shell git rev-parse --short HEAD 2>/dev/null || echo "unknown")
|
||||
LDFLAGS := -s -w -X=main.commitSHA=$(BUILD)
|
||||
BUILDS := ../builds
|
||||
|
||||
# ═══════════════════════════════════════════════════
|
||||
# Build
|
||||
# ═══════════════════════════════════════════════════
|
||||
.PHONY: build build-debug
|
||||
|
||||
build:
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags="$(LDFLAGS)" -o $(BUILDS)/envd .
|
||||
@file $(BUILDS)/envd | grep -q "statically linked" || \
|
||||
(echo "ERROR: envd is not statically linked!" && exit 1)
|
||||
|
||||
build-debug:
|
||||
CGO_ENABLED=1 go build -race -gcflags=all="-N -l" -ldflags="-X=main.commitSHA=$(BUILD)" -o $(BUILDS)/debug/envd .
|
||||
|
||||
# ═══════════════════════════════════════════════════
|
||||
# Run (debug mode, not inside a VM)
|
||||
# ═══════════════════════════════════════════════════
|
||||
.PHONY: run-debug
|
||||
|
||||
run-debug: build-debug
|
||||
$(BUILDS)/debug/envd -isnotfc -port 49983
|
||||
|
||||
# ═══════════════════════════════════════════════════
|
||||
# Code Generation
|
||||
# ═══════════════════════════════════════════════════
|
||||
.PHONY: generate proto openapi
|
||||
|
||||
generate: proto openapi
|
||||
|
||||
proto:
|
||||
cd spec && buf generate --template buf.gen.yaml
|
||||
|
||||
openapi:
|
||||
go generate ./internal/api/...
|
||||
|
||||
# ═══════════════════════════════════════════════════
|
||||
# Quality
|
||||
# ═══════════════════════════════════════════════════
|
||||
.PHONY: fmt vet test tidy
|
||||
|
||||
fmt:
|
||||
gofmt -w .
|
||||
|
||||
vet:
|
||||
go vet ./...
|
||||
|
||||
test:
|
||||
go test -race -v ./...
|
||||
|
||||
tidy:
|
||||
go mod tidy
|
||||
|
||||
# ═══════════════════════════════════════════════════
|
||||
# Clean
|
||||
# ═══════════════════════════════════════════════════
|
||||
.PHONY: clean
|
||||
|
||||
clean:
|
||||
rm -f $(BUILDS)/envd $(BUILDS)/debug/envd
|
||||
@ -1 +0,0 @@
|
||||
0.1.0
|
||||
42
envd/go.mod
42
envd/go.mod
@ -1,42 +0,0 @@
|
||||
module git.omukk.dev/wrenn/sandbox/envd
|
||||
|
||||
go 1.25.8
|
||||
|
||||
require (
|
||||
connectrpc.com/authn v0.1.0
|
||||
connectrpc.com/connect v1.19.1
|
||||
connectrpc.com/cors v0.1.0
|
||||
github.com/awnumar/memguard v0.23.0
|
||||
github.com/creack/pty v1.1.24
|
||||
github.com/dchest/uniuri v1.2.0
|
||||
github.com/e2b-dev/fsnotify v0.0.1
|
||||
github.com/go-chi/chi/v5 v5.2.5
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/oapi-codegen/runtime v1.2.0
|
||||
github.com/orcaman/concurrent-map/v2 v2.0.1
|
||||
github.com/rs/cors v1.11.1
|
||||
github.com/rs/zerolog v1.34.0
|
||||
github.com/shirou/gopsutil/v4 v4.26.2
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/txn2/txeh v1.8.0
|
||||
golang.org/x/sys v0.43.0
|
||||
google.golang.org/protobuf v1.36.11
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect
|
||||
github.com/awnumar/memcall v0.4.0 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/ebitengine/purego v0.10.0 // indirect
|
||||
github.com/go-ole/go-ole v1.2.6 // indirect
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
|
||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
|
||||
github.com/tklauser/go-sysconf v0.3.16 // indirect
|
||||
github.com/tklauser/numcpus v0.11.0 // indirect
|
||||
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||
golang.org/x/crypto v0.50.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
92
envd/go.sum
92
envd/go.sum
@ -1,92 +0,0 @@
|
||||
connectrpc.com/authn v0.1.0 h1:m5weACjLWwgwcjttvUDyTPICJKw74+p2obBVrf8hT9E=
|
||||
connectrpc.com/authn v0.1.0/go.mod h1:AwNZK/KYbqaJzRYadTuAaoz6sYQSPdORPqh1TOPIkgY=
|
||||
connectrpc.com/connect v1.19.1 h1:R5M57z05+90EfEvCY1b7hBxDVOUl45PrtXtAV2fOC14=
|
||||
connectrpc.com/connect v1.19.1/go.mod h1:tN20fjdGlewnSFeZxLKb0xwIZ6ozc3OQs2hTXy4du9w=
|
||||
connectrpc.com/cors v0.1.0 h1:f3gTXJyDZPrDIZCQ567jxfD9PAIpopHiRDnJRt3QuOQ=
|
||||
connectrpc.com/cors v0.1.0/go.mod h1:v8SJZCPfHtGH1zsm+Ttajpozd4cYIUryl4dFB6QEpfg=
|
||||
github.com/RaveNoX/go-jsoncommentstrip v1.0.0/go.mod h1:78ihd09MekBnJnxpICcwzCMzGrKSKYe4AqU6PDYYpjk=
|
||||
github.com/apapsch/go-jsonmerge/v2 v2.0.0 h1:axGnT1gRIfimI7gJifB699GoE/oq+F2MU7Dml6nw9rQ=
|
||||
github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP2+08jFMw88y4klk=
|
||||
github.com/awnumar/memcall v0.4.0 h1:B7hgZYdfH6Ot1Goaz8jGne/7i8xD4taZie/PNSFZ29g=
|
||||
github.com/awnumar/memcall v0.4.0/go.mod h1:8xOx1YbfyuCg3Fy6TO8DK0kZUua3V42/goA5Ru47E8w=
|
||||
github.com/awnumar/memguard v0.23.0 h1:sJ3a1/SWlcuKIQ7MV+R9p0Pvo9CWsMbGZvcZQtmc68A=
|
||||
github.com/awnumar/memguard v0.23.0/go.mod h1:olVofBrsPdITtJ2HgxQKrEYEMyIBAIciVG4wNnZhW9M=
|
||||
github.com/bmatcuk/doublestar v1.1.1/go.mod h1:UD6OnuiIn0yFxxA2le/rnRU1G4RaI4UvFv1sNto9p6w=
|
||||
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||
github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s=
|
||||
github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dchest/uniuri v1.2.0 h1:koIcOUdrTIivZgSLhHQvKgqdWZq5d7KdMEWF1Ud6+5g=
|
||||
github.com/dchest/uniuri v1.2.0/go.mod h1:fSzm4SLHzNZvWLvWJew423PhAzkpNQYq+uNLq4kxhkY=
|
||||
github.com/e2b-dev/fsnotify v0.0.1 h1:7j0I98HD6VehAuK/bcslvW4QDynAULtOuMZtImihjVk=
|
||||
github.com/e2b-dev/fsnotify v0.0.1/go.mod h1:jAuDjregRrUixKneTRQwPI847nNuPFg3+n5QM/ku/JM=
|
||||
github.com/ebitengine/purego v0.10.0 h1:QIw4xfpWT6GWTzaW5XEKy3HXoqrJGx1ijYHzTF0/ISU=
|
||||
github.com/ebitengine/purego v0.10.0/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
|
||||
github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug=
|
||||
github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0=
|
||||
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
|
||||
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/juju/gnuflag v0.0.0-20171113085948-2ce1bb71843d/go.mod h1:2PavIy+JPciBPrBUjwbNvtwB6RQlve+hkpll6QSNmOE=
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
|
||||
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
|
||||
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/oapi-codegen/runtime v1.2.0 h1:RvKc1CVS1QeKSNzO97FBQbSMZyQ8s6rZd+LpmzwHMP4=
|
||||
github.com/oapi-codegen/runtime v1.2.0/go.mod h1:Y7ZhmmlE8ikZOmuHRRndiIm7nf3xcVv+YMweKgG1DT0=
|
||||
github.com/orcaman/concurrent-map/v2 v2.0.1 h1:jOJ5Pg2w1oeB6PeDurIYf6k9PQ+aTITr/6lP/L/zp6c=
|
||||
github.com/orcaman/concurrent-map/v2 v2.0.1/go.mod h1:9Eq3TG2oBe5FirmYWQfYO5iH1q0Jv47PLaNK++uCdOM=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU=
|
||||
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
|
||||
github.com/rs/cors v1.11.1 h1:eU3gRzXLRK57F5rKMGMZURNdIG4EoAmX8k94r9wXWHA=
|
||||
github.com/rs/cors v1.11.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU=
|
||||
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
|
||||
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
|
||||
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
|
||||
github.com/shirou/gopsutil/v4 v4.26.2 h1:X8i6sicvUFih4BmYIGT1m2wwgw2VG9YgrDTi7cIRGUI=
|
||||
github.com/shirou/gopsutil/v4 v4.26.2/go.mod h1:LZ6ewCSkBqUpvSOf+LsTGnRinC6iaNUNMGBtDkJBaLQ=
|
||||
github.com/spkg/bom v0.0.0-20160624110644-59b7046e48ad/go.mod h1:qLr4V1qq6nMqFKkMo8ZTx3f+BZEkzsRUY10Xsm2mwU0=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/tklauser/go-sysconf v0.3.16 h1:frioLaCQSsF5Cy1jgRBrzr6t502KIIwQ0MArYICU0nA=
|
||||
github.com/tklauser/go-sysconf v0.3.16/go.mod h1:/qNL9xxDhc7tx3HSRsLWNnuzbVfh3e7gh/BmM179nYI=
|
||||
github.com/tklauser/numcpus v0.11.0 h1:nSTwhKH5e1dMNsCdVBukSZrURJRoHbSEQjdEbY+9RXw=
|
||||
github.com/tklauser/numcpus v0.11.0/go.mod h1:z+LwcLq54uWZTX0u/bGobaV34u6V7KNlTZejzM6/3MQ=
|
||||
github.com/txn2/txeh v1.8.0 h1:G1vZgom6+P/xWwU53AMOpcZgC5ni382ukcPP1TDVYHk=
|
||||
github.com/txn2/txeh v1.8.0/go.mod h1:rRI3Egi3+AFmEXQjft051YdYbxeCT3nFmBLsNCZZaxM=
|
||||
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
|
||||
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
||||
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
|
||||
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
|
||||
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
pgregory.net/rapid v1.2.0 h1:keKAYRcjm+e1F0oAuU5F5+YPAWcyxNNRK2wud503Gnk=
|
||||
pgregory.net/rapid v1.2.0/go.mod h1:PY5XlDGj0+V1FCq0o192FdRhpKHGTRIWBgqjDBTrq04=
|
||||
@ -1,604 +0,0 @@
|
||||
// Package api provides primitives to interact with the openapi HTTP API.
|
||||
//
|
||||
// Code generated by github.com/oapi-codegen/oapi-codegen/v2 version v2.6.0 DO NOT EDIT.
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/oapi-codegen/runtime"
|
||||
openapi_types "github.com/oapi-codegen/runtime/types"
|
||||
)
|
||||
|
||||
const (
|
||||
AccessTokenAuthScopes = "AccessTokenAuth.Scopes"
|
||||
)
|
||||
|
||||
// Defines values for EntryInfoType.
|
||||
const (
|
||||
File EntryInfoType = "file"
|
||||
)
|
||||
|
||||
// Valid indicates whether the value is a known member of the EntryInfoType enum.
|
||||
func (e EntryInfoType) Valid() bool {
|
||||
switch e {
|
||||
case File:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// EntryInfo defines model for EntryInfo.
|
||||
type EntryInfo struct {
|
||||
// Name Name of the file
|
||||
Name string `json:"name"`
|
||||
|
||||
// Path Path to the file
|
||||
Path string `json:"path"`
|
||||
|
||||
// Type Type of the file
|
||||
Type EntryInfoType `json:"type"`
|
||||
}
|
||||
|
||||
// EntryInfoType Type of the file
|
||||
type EntryInfoType string
|
||||
|
||||
// EnvVars Environment variables to set
|
||||
type EnvVars map[string]string
|
||||
|
||||
// Error defines model for Error.
|
||||
type Error struct {
|
||||
// Code Error code
|
||||
Code int `json:"code"`
|
||||
|
||||
// Message Error message
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// Metrics Resource usage metrics
|
||||
type Metrics struct {
|
||||
// CpuCount Number of CPU cores
|
||||
CpuCount *int `json:"cpu_count,omitempty"`
|
||||
|
||||
// CpuUsedPct CPU usage percentage
|
||||
CpuUsedPct *float32 `json:"cpu_used_pct,omitempty"`
|
||||
|
||||
// DiskTotal Total disk space in bytes
|
||||
DiskTotal *int `json:"disk_total,omitempty"`
|
||||
|
||||
// DiskUsed Used disk space in bytes
|
||||
DiskUsed *int `json:"disk_used,omitempty"`
|
||||
|
||||
// MemTotal Total virtual memory in bytes
|
||||
MemTotal *int `json:"mem_total,omitempty"`
|
||||
|
||||
// MemUsed Used virtual memory in bytes
|
||||
MemUsed *int `json:"mem_used,omitempty"`
|
||||
|
||||
// Ts Unix timestamp in UTC for current sandbox time
|
||||
Ts *int64 `json:"ts,omitempty"`
|
||||
}
|
||||
|
||||
// VolumeMount Volume
|
||||
type VolumeMount struct {
|
||||
NfsTarget string `json:"nfs_target"`
|
||||
Path string `json:"path"`
|
||||
}
|
||||
|
||||
// FilePath defines model for FilePath.
|
||||
type FilePath = string
|
||||
|
||||
// Signature defines model for Signature.
|
||||
type Signature = string
|
||||
|
||||
// SignatureExpiration defines model for SignatureExpiration.
|
||||
type SignatureExpiration = int
|
||||
|
||||
// User defines model for User.
|
||||
type User = string
|
||||
|
||||
// FileNotFound defines model for FileNotFound.
|
||||
type FileNotFound = Error
|
||||
|
||||
// InternalServerError defines model for InternalServerError.
|
||||
type InternalServerError = Error
|
||||
|
||||
// InvalidPath defines model for InvalidPath.
|
||||
type InvalidPath = Error
|
||||
|
||||
// InvalidUser defines model for InvalidUser.
|
||||
type InvalidUser = Error
|
||||
|
||||
// NotEnoughDiskSpace defines model for NotEnoughDiskSpace.
|
||||
type NotEnoughDiskSpace = Error
|
||||
|
||||
// UploadSuccess defines model for UploadSuccess.
|
||||
type UploadSuccess = []EntryInfo
|
||||
|
||||
// GetFilesParams defines parameters for GetFiles.
|
||||
type GetFilesParams struct {
|
||||
// Path Path to the file, URL encoded. Can be relative to user's home directory.
|
||||
Path *FilePath `form:"path,omitempty" json:"path,omitempty"`
|
||||
|
||||
// Username User used for setting the owner, or resolving relative paths.
|
||||
Username *User `form:"username,omitempty" json:"username,omitempty"`
|
||||
|
||||
// Signature Signature used for file access permission verification.
|
||||
Signature *Signature `form:"signature,omitempty" json:"signature,omitempty"`
|
||||
|
||||
// SignatureExpiration Signature expiration used for defining the expiration time of the signature.
|
||||
SignatureExpiration *SignatureExpiration `form:"signature_expiration,omitempty" json:"signature_expiration,omitempty"`
|
||||
}
|
||||
|
||||
// PostFilesMultipartBody defines parameters for PostFiles.
|
||||
type PostFilesMultipartBody struct {
|
||||
File *openapi_types.File `json:"file,omitempty"`
|
||||
}
|
||||
|
||||
// PostFilesParams defines parameters for PostFiles.
|
||||
type PostFilesParams struct {
|
||||
// Path Path to the file, URL encoded. Can be relative to user's home directory.
|
||||
Path *FilePath `form:"path,omitempty" json:"path,omitempty"`
|
||||
|
||||
// Username User used for setting the owner, or resolving relative paths.
|
||||
Username *User `form:"username,omitempty" json:"username,omitempty"`
|
||||
|
||||
// Signature Signature used for file access permission verification.
|
||||
Signature *Signature `form:"signature,omitempty" json:"signature,omitempty"`
|
||||
|
||||
// SignatureExpiration Signature expiration used for defining the expiration time of the signature.
|
||||
SignatureExpiration *SignatureExpiration `form:"signature_expiration,omitempty" json:"signature_expiration,omitempty"`
|
||||
}
|
||||
|
||||
// PostInitJSONBody defines parameters for PostInit.
|
||||
type PostInitJSONBody struct {
|
||||
// AccessToken Access token for secure access to envd service
|
||||
AccessToken *SecureToken `json:"accessToken,omitempty"`
|
||||
|
||||
// DefaultUser The default user to use for operations
|
||||
DefaultUser *string `json:"defaultUser,omitempty"`
|
||||
|
||||
// DefaultWorkdir The default working directory to use for operations
|
||||
DefaultWorkdir *string `json:"defaultWorkdir,omitempty"`
|
||||
|
||||
// EnvVars Environment variables to set
|
||||
EnvVars *EnvVars `json:"envVars,omitempty"`
|
||||
|
||||
// HyperloopIP IP address of the hyperloop server to connect to
|
||||
HyperloopIP *string `json:"hyperloopIP,omitempty"`
|
||||
|
||||
// Timestamp The current timestamp in RFC3339 format
|
||||
Timestamp *time.Time `json:"timestamp,omitempty"`
|
||||
VolumeMounts *[]VolumeMount `json:"volumeMounts,omitempty"`
|
||||
}
|
||||
|
||||
// PostFilesMultipartRequestBody defines body for PostFiles for multipart/form-data ContentType.
|
||||
type PostFilesMultipartRequestBody PostFilesMultipartBody
|
||||
|
||||
// PostInitJSONRequestBody defines body for PostInit for application/json ContentType.
|
||||
type PostInitJSONRequestBody PostInitJSONBody
|
||||
|
||||
// ServerInterface represents all server handlers.
|
||||
type ServerInterface interface {
|
||||
// Get the environment variables
|
||||
// (GET /envs)
|
||||
GetEnvs(w http.ResponseWriter, r *http.Request)
|
||||
// Download a file
|
||||
// (GET /files)
|
||||
GetFiles(w http.ResponseWriter, r *http.Request, params GetFilesParams)
|
||||
// Upload a file and ensure the parent directories exist. If the file exists, it will be overwritten.
|
||||
// (POST /files)
|
||||
PostFiles(w http.ResponseWriter, r *http.Request, params PostFilesParams)
|
||||
// Check the health of the service
|
||||
// (GET /health)
|
||||
GetHealth(w http.ResponseWriter, r *http.Request)
|
||||
// Set initial vars, ensure the time and metadata is synced with the host
|
||||
// (POST /init)
|
||||
PostInit(w http.ResponseWriter, r *http.Request)
|
||||
// Get the stats of the service
|
||||
// (GET /metrics)
|
||||
GetMetrics(w http.ResponseWriter, r *http.Request)
|
||||
// Quiesce continuous goroutines before Firecracker snapshot
|
||||
// (POST /snapshot/prepare)
|
||||
PostSnapshotPrepare(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
// Unimplemented server implementation that returns http.StatusNotImplemented for each endpoint.
|
||||
|
||||
type Unimplemented struct{}
|
||||
|
||||
// Get the environment variables
|
||||
// (GET /envs)
|
||||
func (_ Unimplemented) GetEnvs(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotImplemented)
|
||||
}
|
||||
|
||||
// Download a file
|
||||
// (GET /files)
|
||||
func (_ Unimplemented) GetFiles(w http.ResponseWriter, r *http.Request, params GetFilesParams) {
|
||||
w.WriteHeader(http.StatusNotImplemented)
|
||||
}
|
||||
|
||||
// Upload a file and ensure the parent directories exist. If the file exists, it will be overwritten.
|
||||
// (POST /files)
|
||||
func (_ Unimplemented) PostFiles(w http.ResponseWriter, r *http.Request, params PostFilesParams) {
|
||||
w.WriteHeader(http.StatusNotImplemented)
|
||||
}
|
||||
|
||||
// Check the health of the service
|
||||
// (GET /health)
|
||||
func (_ Unimplemented) GetHealth(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotImplemented)
|
||||
}
|
||||
|
||||
// Set initial vars, ensure the time and metadata is synced with the host
|
||||
// (POST /init)
|
||||
func (_ Unimplemented) PostInit(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotImplemented)
|
||||
}
|
||||
|
||||
// Get the stats of the service
|
||||
// (GET /metrics)
|
||||
func (_ Unimplemented) GetMetrics(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotImplemented)
|
||||
}
|
||||
|
||||
// Quiesce continuous goroutines before Firecracker snapshot
|
||||
// (POST /snapshot/prepare)
|
||||
func (_ Unimplemented) PostSnapshotPrepare(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotImplemented)
|
||||
}
|
||||
|
||||
// ServerInterfaceWrapper converts contexts to parameters.
|
||||
type ServerInterfaceWrapper struct {
|
||||
Handler ServerInterface
|
||||
HandlerMiddlewares []MiddlewareFunc
|
||||
ErrorHandlerFunc func(w http.ResponseWriter, r *http.Request, err error)
|
||||
}
|
||||
|
||||
type MiddlewareFunc func(http.Handler) http.Handler
|
||||
|
||||
// GetEnvs operation middleware
|
||||
func (siw *ServerInterfaceWrapper) GetEnvs(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
ctx = context.WithValue(ctx, AccessTokenAuthScopes, []string{})
|
||||
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
siw.Handler.GetEnvs(w, r)
|
||||
}))
|
||||
|
||||
for _, middleware := range siw.HandlerMiddlewares {
|
||||
handler = middleware(handler)
|
||||
}
|
||||
|
||||
handler.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// GetFiles operation middleware
|
||||
func (siw *ServerInterfaceWrapper) GetFiles(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
var err error
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
ctx = context.WithValue(ctx, AccessTokenAuthScopes, []string{})
|
||||
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
// Parameter object where we will unmarshal all parameters from the context
|
||||
var params GetFilesParams
|
||||
|
||||
// ------------- Optional query parameter "path" -------------
|
||||
|
||||
err = runtime.BindQueryParameterWithOptions("form", true, false, "path", r.URL.Query(), ¶ms.Path, runtime.BindQueryParameterOptions{Type: "string", Format: ""})
|
||||
if err != nil {
|
||||
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "path", Err: err})
|
||||
return
|
||||
}
|
||||
|
||||
// ------------- Optional query parameter "username" -------------
|
||||
|
||||
err = runtime.BindQueryParameterWithOptions("form", true, false, "username", r.URL.Query(), ¶ms.Username, runtime.BindQueryParameterOptions{Type: "string", Format: ""})
|
||||
if err != nil {
|
||||
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "username", Err: err})
|
||||
return
|
||||
}
|
||||
|
||||
// ------------- Optional query parameter "signature" -------------
|
||||
|
||||
err = runtime.BindQueryParameterWithOptions("form", true, false, "signature", r.URL.Query(), ¶ms.Signature, runtime.BindQueryParameterOptions{Type: "string", Format: ""})
|
||||
if err != nil {
|
||||
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "signature", Err: err})
|
||||
return
|
||||
}
|
||||
|
||||
// ------------- Optional query parameter "signature_expiration" -------------
|
||||
|
||||
err = runtime.BindQueryParameterWithOptions("form", true, false, "signature_expiration", r.URL.Query(), ¶ms.SignatureExpiration, runtime.BindQueryParameterOptions{Type: "integer", Format: ""})
|
||||
if err != nil {
|
||||
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "signature_expiration", Err: err})
|
||||
return
|
||||
}
|
||||
|
||||
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
siw.Handler.GetFiles(w, r, params)
|
||||
}))
|
||||
|
||||
for _, middleware := range siw.HandlerMiddlewares {
|
||||
handler = middleware(handler)
|
||||
}
|
||||
|
||||
handler.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// PostFiles operation middleware
|
||||
func (siw *ServerInterfaceWrapper) PostFiles(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
var err error
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
ctx = context.WithValue(ctx, AccessTokenAuthScopes, []string{})
|
||||
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
// Parameter object where we will unmarshal all parameters from the context
|
||||
var params PostFilesParams
|
||||
|
||||
// ------------- Optional query parameter "path" -------------
|
||||
|
||||
err = runtime.BindQueryParameterWithOptions("form", true, false, "path", r.URL.Query(), ¶ms.Path, runtime.BindQueryParameterOptions{Type: "string", Format: ""})
|
||||
if err != nil {
|
||||
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "path", Err: err})
|
||||
return
|
||||
}
|
||||
|
||||
// ------------- Optional query parameter "username" -------------
|
||||
|
||||
err = runtime.BindQueryParameterWithOptions("form", true, false, "username", r.URL.Query(), ¶ms.Username, runtime.BindQueryParameterOptions{Type: "string", Format: ""})
|
||||
if err != nil {
|
||||
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "username", Err: err})
|
||||
return
|
||||
}
|
||||
|
||||
// ------------- Optional query parameter "signature" -------------
|
||||
|
||||
err = runtime.BindQueryParameterWithOptions("form", true, false, "signature", r.URL.Query(), ¶ms.Signature, runtime.BindQueryParameterOptions{Type: "string", Format: ""})
|
||||
if err != nil {
|
||||
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "signature", Err: err})
|
||||
return
|
||||
}
|
||||
|
||||
// ------------- Optional query parameter "signature_expiration" -------------
|
||||
|
||||
err = runtime.BindQueryParameterWithOptions("form", true, false, "signature_expiration", r.URL.Query(), ¶ms.SignatureExpiration, runtime.BindQueryParameterOptions{Type: "integer", Format: ""})
|
||||
if err != nil {
|
||||
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "signature_expiration", Err: err})
|
||||
return
|
||||
}
|
||||
|
||||
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
siw.Handler.PostFiles(w, r, params)
|
||||
}))
|
||||
|
||||
for _, middleware := range siw.HandlerMiddlewares {
|
||||
handler = middleware(handler)
|
||||
}
|
||||
|
||||
handler.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// GetHealth operation middleware
|
||||
func (siw *ServerInterfaceWrapper) GetHealth(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
siw.Handler.GetHealth(w, r)
|
||||
}))
|
||||
|
||||
for _, middleware := range siw.HandlerMiddlewares {
|
||||
handler = middleware(handler)
|
||||
}
|
||||
|
||||
handler.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// PostInit operation middleware
|
||||
func (siw *ServerInterfaceWrapper) PostInit(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
ctx = context.WithValue(ctx, AccessTokenAuthScopes, []string{})
|
||||
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
siw.Handler.PostInit(w, r)
|
||||
}))
|
||||
|
||||
for _, middleware := range siw.HandlerMiddlewares {
|
||||
handler = middleware(handler)
|
||||
}
|
||||
|
||||
handler.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// GetMetrics operation middleware
|
||||
func (siw *ServerInterfaceWrapper) GetMetrics(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
ctx = context.WithValue(ctx, AccessTokenAuthScopes, []string{})
|
||||
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
siw.Handler.GetMetrics(w, r)
|
||||
}))
|
||||
|
||||
for _, middleware := range siw.HandlerMiddlewares {
|
||||
handler = middleware(handler)
|
||||
}
|
||||
|
||||
handler.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// PostSnapshotPrepare operation middleware
|
||||
func (siw *ServerInterfaceWrapper) PostSnapshotPrepare(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
siw.Handler.PostSnapshotPrepare(w, r)
|
||||
}))
|
||||
|
||||
for _, middleware := range siw.HandlerMiddlewares {
|
||||
handler = middleware(handler)
|
||||
}
|
||||
|
||||
handler.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
type UnescapedCookieParamError struct {
|
||||
ParamName string
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *UnescapedCookieParamError) Error() string {
|
||||
return fmt.Sprintf("error unescaping cookie parameter '%s'", e.ParamName)
|
||||
}
|
||||
|
||||
func (e *UnescapedCookieParamError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
type UnmarshalingParamError struct {
|
||||
ParamName string
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *UnmarshalingParamError) Error() string {
|
||||
return fmt.Sprintf("Error unmarshaling parameter %s as JSON: %s", e.ParamName, e.Err.Error())
|
||||
}
|
||||
|
||||
func (e *UnmarshalingParamError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
type RequiredParamError struct {
|
||||
ParamName string
|
||||
}
|
||||
|
||||
func (e *RequiredParamError) Error() string {
|
||||
return fmt.Sprintf("Query argument %s is required, but not found", e.ParamName)
|
||||
}
|
||||
|
||||
type RequiredHeaderError struct {
|
||||
ParamName string
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *RequiredHeaderError) Error() string {
|
||||
return fmt.Sprintf("Header parameter %s is required, but not found", e.ParamName)
|
||||
}
|
||||
|
||||
func (e *RequiredHeaderError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
type InvalidParamFormatError struct {
|
||||
ParamName string
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *InvalidParamFormatError) Error() string {
|
||||
return fmt.Sprintf("Invalid format for parameter %s: %s", e.ParamName, e.Err.Error())
|
||||
}
|
||||
|
||||
func (e *InvalidParamFormatError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
type TooManyValuesForParamError struct {
|
||||
ParamName string
|
||||
Count int
|
||||
}
|
||||
|
||||
func (e *TooManyValuesForParamError) Error() string {
|
||||
return fmt.Sprintf("Expected one value for %s, got %d", e.ParamName, e.Count)
|
||||
}
|
||||
|
||||
// Handler creates http.Handler with routing matching OpenAPI spec.
|
||||
func Handler(si ServerInterface) http.Handler {
|
||||
return HandlerWithOptions(si, ChiServerOptions{})
|
||||
}
|
||||
|
||||
type ChiServerOptions struct {
|
||||
BaseURL string
|
||||
BaseRouter chi.Router
|
||||
Middlewares []MiddlewareFunc
|
||||
ErrorHandlerFunc func(w http.ResponseWriter, r *http.Request, err error)
|
||||
}
|
||||
|
||||
// HandlerFromMux creates http.Handler with routing matching OpenAPI spec based on the provided mux.
|
||||
func HandlerFromMux(si ServerInterface, r chi.Router) http.Handler {
|
||||
return HandlerWithOptions(si, ChiServerOptions{
|
||||
BaseRouter: r,
|
||||
})
|
||||
}
|
||||
|
||||
func HandlerFromMuxWithBaseURL(si ServerInterface, r chi.Router, baseURL string) http.Handler {
|
||||
return HandlerWithOptions(si, ChiServerOptions{
|
||||
BaseURL: baseURL,
|
||||
BaseRouter: r,
|
||||
})
|
||||
}
|
||||
|
||||
// HandlerWithOptions creates http.Handler with additional options
|
||||
func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handler {
|
||||
r := options.BaseRouter
|
||||
|
||||
if r == nil {
|
||||
r = chi.NewRouter()
|
||||
}
|
||||
if options.ErrorHandlerFunc == nil {
|
||||
options.ErrorHandlerFunc = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
wrapper := ServerInterfaceWrapper{
|
||||
Handler: si,
|
||||
HandlerMiddlewares: options.Middlewares,
|
||||
ErrorHandlerFunc: options.ErrorHandlerFunc,
|
||||
}
|
||||
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Get(options.BaseURL+"/envs", wrapper.GetEnvs)
|
||||
})
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Get(options.BaseURL+"/files", wrapper.GetFiles)
|
||||
})
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Post(options.BaseURL+"/files", wrapper.PostFiles)
|
||||
})
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Get(options.BaseURL+"/health", wrapper.GetHealth)
|
||||
})
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Post(options.BaseURL+"/init", wrapper.PostInit)
|
||||
})
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Get(options.BaseURL+"/metrics", wrapper.GetMetrics)
|
||||
})
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Post(options.BaseURL+"/snapshot/prepare", wrapper.PostSnapshotPrepare)
|
||||
})
|
||||
|
||||
return r
|
||||
}
|
||||
@ -1,133 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// Modifications by M/S Omukk
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/awnumar/memguard"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/keys"
|
||||
)
|
||||
|
||||
const (
|
||||
SigningReadOperation = "read"
|
||||
SigningWriteOperation = "write"
|
||||
|
||||
accessTokenHeader = "X-Access-Token"
|
||||
)
|
||||
|
||||
// paths that are always allowed without general authentication
|
||||
// POST/init is secured via MMDS hash validation instead
|
||||
var authExcludedPaths = []string{
|
||||
"GET/health",
|
||||
"GET/files",
|
||||
"POST/files",
|
||||
"POST/init",
|
||||
"POST/snapshot/prepare",
|
||||
}
|
||||
|
||||
func (a *API) WithAuthorization(handler http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
if a.accessToken.IsSet() {
|
||||
authHeader := req.Header.Get(accessTokenHeader)
|
||||
|
||||
// check if this path is allowed without authentication (e.g., health check, endpoints supporting signing)
|
||||
allowedPath := slices.Contains(authExcludedPaths, req.Method+req.URL.Path)
|
||||
|
||||
if !a.accessToken.Equals(authHeader) && !allowedPath {
|
||||
a.logger.Error().Msg("Trying to access secured envd without correct access token")
|
||||
|
||||
err := fmt.Errorf("unauthorized access, please provide a valid access token or method signing if supported")
|
||||
jsonError(w, http.StatusUnauthorized, err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
})
|
||||
}
|
||||
|
||||
func (a *API) generateSignature(path string, username string, operation string, signatureExpiration *int64) (string, error) {
|
||||
tokenBytes, err := a.accessToken.Bytes()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("access token is not set: %w", err)
|
||||
}
|
||||
defer memguard.WipeBytes(tokenBytes)
|
||||
|
||||
var signature string
|
||||
hasher := keys.NewSHA256Hashing()
|
||||
|
||||
if signatureExpiration == nil {
|
||||
signature = strings.Join([]string{path, operation, username, string(tokenBytes)}, ":")
|
||||
} else {
|
||||
signature = strings.Join([]string{path, operation, username, string(tokenBytes), strconv.FormatInt(*signatureExpiration, 10)}, ":")
|
||||
}
|
||||
|
||||
return fmt.Sprintf("v1_%s", hasher.HashWithoutPrefix([]byte(signature))), nil
|
||||
}
|
||||
|
||||
func (a *API) validateSigning(r *http.Request, signature *string, signatureExpiration *int, username *string, path string, operation string) (err error) {
|
||||
var expectedSignature string
|
||||
|
||||
// no need to validate signing key if access token is not set
|
||||
if !a.accessToken.IsSet() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// check if access token is sent in the header
|
||||
tokenFromHeader := r.Header.Get(accessTokenHeader)
|
||||
if tokenFromHeader != "" {
|
||||
if !a.accessToken.Equals(tokenFromHeader) {
|
||||
return fmt.Errorf("access token present in header but does not match")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if signature == nil {
|
||||
return fmt.Errorf("missing signature query parameter")
|
||||
}
|
||||
|
||||
// Empty string is used when no username is provided and the default user should be used
|
||||
signatureUsername := ""
|
||||
if username != nil {
|
||||
signatureUsername = *username
|
||||
}
|
||||
|
||||
if signatureExpiration == nil {
|
||||
expectedSignature, err = a.generateSignature(path, signatureUsername, operation, nil)
|
||||
} else {
|
||||
exp := int64(*signatureExpiration)
|
||||
expectedSignature, err = a.generateSignature(path, signatureUsername, operation, &exp)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
a.logger.Error().Err(err).Msg("error generating signing key")
|
||||
|
||||
return errors.New("invalid signature")
|
||||
}
|
||||
|
||||
// signature validation
|
||||
if expectedSignature != *signature {
|
||||
return fmt.Errorf("invalid signature")
|
||||
}
|
||||
|
||||
// signature expiration
|
||||
if signatureExpiration != nil {
|
||||
exp := int64(*signatureExpiration)
|
||||
if exp < time.Now().Unix() {
|
||||
return fmt.Errorf("signature is already expired")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -1,64 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/keys"
|
||||
)
|
||||
|
||||
func TestKeyGenerationAlgorithmIsStable(t *testing.T) {
|
||||
t.Parallel()
|
||||
apiToken := "secret-access-token"
|
||||
secureToken := &SecureToken{}
|
||||
err := secureToken.Set([]byte(apiToken))
|
||||
require.NoError(t, err)
|
||||
api := &API{accessToken: secureToken}
|
||||
|
||||
path := "/path/to/demo.txt"
|
||||
username := "root"
|
||||
operation := "write"
|
||||
timestamp := time.Now().Unix()
|
||||
|
||||
signature, err := api.generateSignature(path, username, operation, ×tamp)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, signature)
|
||||
|
||||
// locally generated signature
|
||||
hasher := keys.NewSHA256Hashing()
|
||||
localSignatureTmp := fmt.Sprintf("%s:%s:%s:%s:%s", path, operation, username, apiToken, strconv.FormatInt(timestamp, 10))
|
||||
localSignature := fmt.Sprintf("v1_%s", hasher.HashWithoutPrefix([]byte(localSignatureTmp)))
|
||||
|
||||
assert.Equal(t, localSignature, signature)
|
||||
}
|
||||
|
||||
func TestKeyGenerationAlgorithmWithoutExpirationIsStable(t *testing.T) {
|
||||
t.Parallel()
|
||||
apiToken := "secret-access-token"
|
||||
secureToken := &SecureToken{}
|
||||
err := secureToken.Set([]byte(apiToken))
|
||||
require.NoError(t, err)
|
||||
api := &API{accessToken: secureToken}
|
||||
|
||||
path := "/path/to/resource.txt"
|
||||
username := "user"
|
||||
operation := "read"
|
||||
|
||||
signature, err := api.generateSignature(path, username, operation, nil)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, signature)
|
||||
|
||||
// locally generated signature
|
||||
hasher := keys.NewSHA256Hashing()
|
||||
localSignatureTmp := fmt.Sprintf("%s:%s:%s:%s", path, operation, username, apiToken)
|
||||
localSignature := fmt.Sprintf("v1_%s", hasher.HashWithoutPrefix([]byte(localSignatureTmp)))
|
||||
|
||||
assert.Equal(t, localSignature, signature)
|
||||
}
|
||||
@ -1,10 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# yaml-language-server: $schema=https://raw.githubusercontent.com/deepmap/oapi-codegen/HEAD/configuration-schema.json
|
||||
|
||||
package: api
|
||||
output: api.gen.go
|
||||
generate:
|
||||
models: true
|
||||
chi-server: true
|
||||
client: false
|
||||
@ -1,187 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// Modifications by M/S Omukk
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
|
||||
)
|
||||
|
||||
func (a *API) GetFiles(w http.ResponseWriter, r *http.Request, params GetFilesParams) {
|
||||
defer r.Body.Close()
|
||||
|
||||
var errorCode int
|
||||
var errMsg error
|
||||
|
||||
var path string
|
||||
if params.Path != nil {
|
||||
path = *params.Path
|
||||
}
|
||||
|
||||
operationID := logs.AssignOperationID()
|
||||
|
||||
// signing authorization if needed
|
||||
err := a.validateSigning(r, params.Signature, params.SignatureExpiration, params.Username, path, SigningReadOperation)
|
||||
if err != nil {
|
||||
a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("error during auth validation")
|
||||
jsonError(w, http.StatusUnauthorized, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
username, err := execcontext.ResolveDefaultUsername(params.Username, a.defaults.User)
|
||||
if err != nil {
|
||||
a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("no user specified")
|
||||
jsonError(w, http.StatusBadRequest, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
l := a.logger.
|
||||
Err(errMsg).
|
||||
Str("method", r.Method+" "+r.URL.Path).
|
||||
Str(string(logs.OperationIDKey), operationID).
|
||||
Str("path", path).
|
||||
Str("username", username)
|
||||
|
||||
if errMsg != nil {
|
||||
l = l.Int("error_code", errorCode)
|
||||
}
|
||||
|
||||
l.Msg("File read")
|
||||
}()
|
||||
|
||||
u, err := user.Lookup(username)
|
||||
if err != nil {
|
||||
errMsg = fmt.Errorf("error looking up user '%s': %w", username, err)
|
||||
errorCode = http.StatusUnauthorized
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
resolvedPath, err := permissions.ExpandAndResolve(path, u, a.defaults.Workdir)
|
||||
if err != nil {
|
||||
errMsg = fmt.Errorf("error expanding and resolving path '%s': %w", path, err)
|
||||
errorCode = http.StatusBadRequest
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
stat, err := os.Stat(resolvedPath)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
errMsg = fmt.Errorf("path '%s' does not exist", resolvedPath)
|
||||
errorCode = http.StatusNotFound
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
errMsg = fmt.Errorf("error checking if path exists '%s': %w", resolvedPath, err)
|
||||
errorCode = http.StatusInternalServerError
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if stat.IsDir() {
|
||||
errMsg = fmt.Errorf("path '%s' is a directory", resolvedPath)
|
||||
errorCode = http.StatusBadRequest
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Reject anything that isn't a regular file (devices, pipes, sockets, etc.).
|
||||
// Reading device files like /dev/zero or /dev/urandom produces infinite data
|
||||
// and will exhaust memory on all layers of the stack.
|
||||
if !stat.Mode().IsRegular() {
|
||||
errMsg = fmt.Errorf("path '%s' is not a regular file", resolvedPath)
|
||||
errorCode = http.StatusBadRequest
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Validate Accept-Encoding header
|
||||
encoding, err := parseAcceptEncoding(r)
|
||||
if err != nil {
|
||||
errMsg = fmt.Errorf("error parsing Accept-Encoding: %w", err)
|
||||
errorCode = http.StatusNotAcceptable
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Tell caches to store separate variants for different Accept-Encoding values
|
||||
w.Header().Set("Vary", "Accept-Encoding")
|
||||
|
||||
// Fall back to identity for Range or conditional requests to preserve http.ServeContent
|
||||
// behavior (206 Partial Content, 304 Not Modified). However, we must check if identity
|
||||
// is acceptable per the Accept-Encoding header.
|
||||
hasRangeOrConditional := r.Header.Get("Range") != "" ||
|
||||
r.Header.Get("If-Modified-Since") != "" ||
|
||||
r.Header.Get("If-None-Match") != "" ||
|
||||
r.Header.Get("If-Range") != ""
|
||||
if hasRangeOrConditional {
|
||||
if !isIdentityAcceptable(r) {
|
||||
errMsg = fmt.Errorf("identity encoding not acceptable for Range or conditional request")
|
||||
errorCode = http.StatusNotAcceptable
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
encoding = EncodingIdentity
|
||||
}
|
||||
|
||||
file, err := os.Open(resolvedPath)
|
||||
if err != nil {
|
||||
errMsg = fmt.Errorf("error opening file '%s': %w", resolvedPath, err)
|
||||
errorCode = http.StatusInternalServerError
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
w.Header().Set("Content-Disposition", mime.FormatMediaType("inline", map[string]string{"filename": filepath.Base(resolvedPath)}))
|
||||
|
||||
// Serve with gzip encoding if requested.
|
||||
if encoding == EncodingGzip {
|
||||
w.Header().Set("Content-Encoding", EncodingGzip)
|
||||
|
||||
// Set Content-Type based on file extension, preserving the original type
|
||||
contentType := mime.TypeByExtension(filepath.Ext(path))
|
||||
if contentType == "" {
|
||||
contentType = "application/octet-stream"
|
||||
}
|
||||
w.Header().Set("Content-Type", contentType)
|
||||
|
||||
gw := gzip.NewWriter(w)
|
||||
defer gw.Close()
|
||||
|
||||
_, err = io.Copy(gw, file)
|
||||
if err != nil {
|
||||
a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("error writing gzip response")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
http.ServeContent(w, r, path, stat.ModTime(), file)
|
||||
}
|
||||
@ -1,405 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// Modifications by M/S Omukk
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
|
||||
)
|
||||
|
||||
func TestGetFilesContentDisposition(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
filename string
|
||||
expectedHeader string
|
||||
}{
|
||||
{
|
||||
name: "simple filename",
|
||||
filename: "test.txt",
|
||||
expectedHeader: `inline; filename=test.txt`,
|
||||
},
|
||||
{
|
||||
name: "filename with extension",
|
||||
filename: "presentation.pptx",
|
||||
expectedHeader: `inline; filename=presentation.pptx`,
|
||||
},
|
||||
{
|
||||
name: "filename with multiple dots",
|
||||
filename: "archive.tar.gz",
|
||||
expectedHeader: `inline; filename=archive.tar.gz`,
|
||||
},
|
||||
{
|
||||
name: "filename with spaces",
|
||||
filename: "my document.pdf",
|
||||
expectedHeader: `inline; filename="my document.pdf"`,
|
||||
},
|
||||
{
|
||||
name: "filename with quotes",
|
||||
filename: `file"name.txt`,
|
||||
expectedHeader: `inline; filename="file\"name.txt"`,
|
||||
},
|
||||
{
|
||||
name: "filename with backslash",
|
||||
filename: `file\name.txt`,
|
||||
expectedHeader: `inline; filename="file\\name.txt"`,
|
||||
},
|
||||
{
|
||||
name: "unicode filename",
|
||||
filename: "\u6587\u6863.pdf", // 文档.pdf in Chinese
|
||||
expectedHeader: "inline; filename*=utf-8''%E6%96%87%E6%A1%A3.pdf",
|
||||
},
|
||||
{
|
||||
name: "dotfile preserved",
|
||||
filename: ".env",
|
||||
expectedHeader: `inline; filename=.env`,
|
||||
},
|
||||
{
|
||||
name: "dotfile with extension preserved",
|
||||
filename: ".gitignore",
|
||||
expectedHeader: `inline; filename=.gitignore`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a temp directory and file
|
||||
tempDir := t.TempDir()
|
||||
tempFile := filepath.Join(tempDir, tt.filename)
|
||||
err := os.WriteFile(tempFile, []byte("test content"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create test API
|
||||
logger := zerolog.Nop()
|
||||
defaults := &execcontext.Defaults{
|
||||
EnvVars: utils.NewMap[string, string](),
|
||||
User: currentUser.Username,
|
||||
}
|
||||
api := New(&logger, defaults, nil, false, context.Background(), nil, "test")
|
||||
|
||||
// Create request and response recorder
|
||||
req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Call the handler
|
||||
params := GetFilesParams{
|
||||
Path: &tempFile,
|
||||
Username: ¤tUser.Username,
|
||||
}
|
||||
api.GetFiles(w, req, params)
|
||||
|
||||
// Check response
|
||||
resp := w.Result()
|
||||
defer resp.Body.Close()
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// Verify Content-Disposition header
|
||||
contentDisposition := resp.Header.Get("Content-Disposition")
|
||||
assert.Equal(t, tt.expectedHeader, contentDisposition, "Content-Disposition header should be set with correct filename")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFilesContentDispositionWithNestedPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a temp directory with nested structure
|
||||
tempDir := t.TempDir()
|
||||
nestedDir := filepath.Join(tempDir, "subdir", "another")
|
||||
err = os.MkdirAll(nestedDir, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
filename := "document.pdf"
|
||||
tempFile := filepath.Join(nestedDir, filename)
|
||||
err = os.WriteFile(tempFile, []byte("test content"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create test API
|
||||
logger := zerolog.Nop()
|
||||
defaults := &execcontext.Defaults{
|
||||
EnvVars: utils.NewMap[string, string](),
|
||||
User: currentUser.Username,
|
||||
}
|
||||
api := New(&logger, defaults, nil, false, context.Background(), nil, "test")
|
||||
|
||||
// Create request and response recorder
|
||||
req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Call the handler
|
||||
params := GetFilesParams{
|
||||
Path: &tempFile,
|
||||
Username: ¤tUser.Username,
|
||||
}
|
||||
api.GetFiles(w, req, params)
|
||||
|
||||
// Check response
|
||||
resp := w.Result()
|
||||
defer resp.Body.Close()
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// Verify Content-Disposition header uses only the base filename, not the full path
|
||||
contentDisposition := resp.Header.Get("Content-Disposition")
|
||||
assert.Equal(t, `inline; filename=document.pdf`, contentDisposition, "Content-Disposition should contain only the filename, not the path")
|
||||
}
|
||||
|
||||
func TestGetFiles_GzipEncoding_ExplicitIdentityOffWithRange(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a temp directory with a test file
|
||||
tempDir := t.TempDir()
|
||||
filename := "document.pdf"
|
||||
tempFile := filepath.Join(tempDir, filename)
|
||||
err = os.WriteFile(tempFile, []byte("test content"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create test API
|
||||
logger := zerolog.Nop()
|
||||
defaults := &execcontext.Defaults{
|
||||
EnvVars: utils.NewMap[string, string](),
|
||||
User: currentUser.Username,
|
||||
}
|
||||
api := New(&logger, defaults, nil, false, context.Background(), nil, "test")
|
||||
|
||||
// Create request and response recorder
|
||||
req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip; q=1,*; q=0")
|
||||
req.Header.Set("Range", "bytes=0-4") // Request first 5 bytes
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Call the handler
|
||||
params := GetFilesParams{
|
||||
Path: &tempFile,
|
||||
Username: ¤tUser.Username,
|
||||
}
|
||||
api.GetFiles(w, req, params)
|
||||
|
||||
// Check response
|
||||
resp := w.Result()
|
||||
defer resp.Body.Close()
|
||||
|
||||
assert.Equal(t, http.StatusNotAcceptable, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestGetFiles_GzipDownload(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
originalContent := []byte("hello world, this is a test file for gzip compression")
|
||||
|
||||
// Create a temp file with known content
|
||||
tempDir := t.TempDir()
|
||||
tempFile := filepath.Join(tempDir, "test.txt")
|
||||
err = os.WriteFile(tempFile, originalContent, 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
logger := zerolog.Nop()
|
||||
defaults := &execcontext.Defaults{
|
||||
EnvVars: utils.NewMap[string, string](),
|
||||
User: currentUser.Username,
|
||||
}
|
||||
api := New(&logger, defaults, nil, false, context.Background(), nil, "test")
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
params := GetFilesParams{
|
||||
Path: &tempFile,
|
||||
Username: ¤tUser.Username,
|
||||
}
|
||||
api.GetFiles(w, req, params)
|
||||
|
||||
resp := w.Result()
|
||||
defer resp.Body.Close()
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
assert.Equal(t, "gzip", resp.Header.Get("Content-Encoding"))
|
||||
assert.Equal(t, "text/plain; charset=utf-8", resp.Header.Get("Content-Type"))
|
||||
|
||||
// Decompress the gzip response body
|
||||
gzReader, err := gzip.NewReader(resp.Body)
|
||||
require.NoError(t, err)
|
||||
defer gzReader.Close()
|
||||
|
||||
decompressed, err := io.ReadAll(gzReader)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, originalContent, decompressed)
|
||||
}
|
||||
|
||||
func TestPostFiles_GzipUpload(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
originalContent := []byte("hello world, this is a test file uploaded with gzip")
|
||||
|
||||
// Build a multipart body
|
||||
var multipartBuf bytes.Buffer
|
||||
mpWriter := multipart.NewWriter(&multipartBuf)
|
||||
part, err := mpWriter.CreateFormFile("file", "uploaded.txt")
|
||||
require.NoError(t, err)
|
||||
_, err = part.Write(originalContent)
|
||||
require.NoError(t, err)
|
||||
err = mpWriter.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Gzip-compress the entire multipart body
|
||||
var gzBuf bytes.Buffer
|
||||
gzWriter := gzip.NewWriter(&gzBuf)
|
||||
_, err = gzWriter.Write(multipartBuf.Bytes())
|
||||
require.NoError(t, err)
|
||||
err = gzWriter.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create test API
|
||||
tempDir := t.TempDir()
|
||||
destPath := filepath.Join(tempDir, "uploaded.txt")
|
||||
|
||||
logger := zerolog.Nop()
|
||||
defaults := &execcontext.Defaults{
|
||||
EnvVars: utils.NewMap[string, string](),
|
||||
User: currentUser.Username,
|
||||
}
|
||||
api := New(&logger, defaults, nil, false, context.Background(), nil, "test")
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/files?path="+url.QueryEscape(destPath), &gzBuf)
|
||||
req.Header.Set("Content-Type", mpWriter.FormDataContentType())
|
||||
req.Header.Set("Content-Encoding", "gzip")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
params := PostFilesParams{
|
||||
Path: &destPath,
|
||||
Username: ¤tUser.Username,
|
||||
}
|
||||
api.PostFiles(w, req, params)
|
||||
|
||||
resp := w.Result()
|
||||
defer resp.Body.Close()
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// Verify the file was written with the original (decompressed) content
|
||||
data, err := os.ReadFile(destPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, originalContent, data)
|
||||
}
|
||||
|
||||
func TestGzipUploadThenGzipDownload(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
originalContent := []byte("round-trip gzip test: upload compressed, download compressed, verify match")
|
||||
|
||||
// --- Upload with gzip ---
|
||||
|
||||
// Build a multipart body
|
||||
var multipartBuf bytes.Buffer
|
||||
mpWriter := multipart.NewWriter(&multipartBuf)
|
||||
part, err := mpWriter.CreateFormFile("file", "roundtrip.txt")
|
||||
require.NoError(t, err)
|
||||
_, err = part.Write(originalContent)
|
||||
require.NoError(t, err)
|
||||
err = mpWriter.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Gzip-compress the entire multipart body
|
||||
var gzBuf bytes.Buffer
|
||||
gzWriter := gzip.NewWriter(&gzBuf)
|
||||
_, err = gzWriter.Write(multipartBuf.Bytes())
|
||||
require.NoError(t, err)
|
||||
err = gzWriter.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
tempDir := t.TempDir()
|
||||
destPath := filepath.Join(tempDir, "roundtrip.txt")
|
||||
|
||||
logger := zerolog.Nop()
|
||||
defaults := &execcontext.Defaults{
|
||||
EnvVars: utils.NewMap[string, string](),
|
||||
User: currentUser.Username,
|
||||
}
|
||||
api := New(&logger, defaults, nil, false, context.Background(), nil, "test")
|
||||
|
||||
uploadReq := httptest.NewRequest(http.MethodPost, "/files?path="+url.QueryEscape(destPath), &gzBuf)
|
||||
uploadReq.Header.Set("Content-Type", mpWriter.FormDataContentType())
|
||||
uploadReq.Header.Set("Content-Encoding", "gzip")
|
||||
uploadW := httptest.NewRecorder()
|
||||
|
||||
uploadParams := PostFilesParams{
|
||||
Path: &destPath,
|
||||
Username: ¤tUser.Username,
|
||||
}
|
||||
api.PostFiles(uploadW, uploadReq, uploadParams)
|
||||
|
||||
uploadResp := uploadW.Result()
|
||||
defer uploadResp.Body.Close()
|
||||
|
||||
require.Equal(t, http.StatusOK, uploadResp.StatusCode)
|
||||
|
||||
// --- Download with gzip ---
|
||||
|
||||
downloadReq := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(destPath), nil)
|
||||
downloadReq.Header.Set("Accept-Encoding", "gzip")
|
||||
downloadW := httptest.NewRecorder()
|
||||
|
||||
downloadParams := GetFilesParams{
|
||||
Path: &destPath,
|
||||
Username: ¤tUser.Username,
|
||||
}
|
||||
api.GetFiles(downloadW, downloadReq, downloadParams)
|
||||
|
||||
downloadResp := downloadW.Result()
|
||||
defer downloadResp.Body.Close()
|
||||
|
||||
require.Equal(t, http.StatusOK, downloadResp.StatusCode)
|
||||
assert.Equal(t, "gzip", downloadResp.Header.Get("Content-Encoding"))
|
||||
|
||||
// Decompress and verify content matches original
|
||||
gzReader, err := gzip.NewReader(downloadResp.Body)
|
||||
require.NoError(t, err)
|
||||
defer gzReader.Close()
|
||||
|
||||
decompressed, err := io.ReadAll(gzReader)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, originalContent, decompressed)
|
||||
}
|
||||
@ -1,229 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"slices"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
// EncodingGzip is the gzip content encoding.
|
||||
EncodingGzip = "gzip"
|
||||
// EncodingIdentity means no encoding (passthrough).
|
||||
EncodingIdentity = "identity"
|
||||
// EncodingWildcard means any encoding is acceptable.
|
||||
EncodingWildcard = "*"
|
||||
)
|
||||
|
||||
// SupportedEncodings lists the content encodings supported for file transfer.
|
||||
// The order matters - encodings are checked in order of preference.
|
||||
var SupportedEncodings = []string{
|
||||
EncodingGzip,
|
||||
}
|
||||
|
||||
// encodingWithQuality holds an encoding name and its quality value.
|
||||
type encodingWithQuality struct {
|
||||
encoding string
|
||||
quality float64
|
||||
}
|
||||
|
||||
// isSupportedEncoding checks if the given encoding is in the supported list.
|
||||
// Per RFC 7231, content-coding values are case-insensitive.
|
||||
func isSupportedEncoding(encoding string) bool {
|
||||
return slices.Contains(SupportedEncodings, strings.ToLower(encoding))
|
||||
}
|
||||
|
||||
// parseEncodingWithQuality parses an encoding value and extracts the quality.
|
||||
// Returns the encoding name (lowercased) and quality value (default 1.0 if not specified).
|
||||
// Per RFC 7231, content-coding values are case-insensitive.
|
||||
func parseEncodingWithQuality(value string) encodingWithQuality {
|
||||
value = strings.TrimSpace(value)
|
||||
quality := 1.0
|
||||
|
||||
if idx := strings.Index(value, ";"); idx != -1 {
|
||||
params := value[idx+1:]
|
||||
value = strings.TrimSpace(value[:idx])
|
||||
|
||||
// Parse q=X.X parameter
|
||||
for param := range strings.SplitSeq(params, ";") {
|
||||
param = strings.TrimSpace(param)
|
||||
if strings.HasPrefix(strings.ToLower(param), "q=") {
|
||||
if q, err := strconv.ParseFloat(param[2:], 64); err == nil {
|
||||
quality = q
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Normalize encoding to lowercase per RFC 7231
|
||||
return encodingWithQuality{encoding: strings.ToLower(value), quality: quality}
|
||||
}
|
||||
|
||||
// parseEncoding extracts the encoding name from a header value, stripping quality.
|
||||
func parseEncoding(value string) string {
|
||||
return parseEncodingWithQuality(value).encoding
|
||||
}
|
||||
|
||||
// parseContentEncoding parses the Content-Encoding header and returns the encoding.
|
||||
// Returns an error if an unsupported encoding is specified.
|
||||
// If no Content-Encoding header is present, returns empty string.
|
||||
func parseContentEncoding(r *http.Request) (string, error) {
|
||||
header := r.Header.Get("Content-Encoding")
|
||||
if header == "" {
|
||||
return EncodingIdentity, nil
|
||||
}
|
||||
|
||||
encoding := parseEncoding(header)
|
||||
|
||||
if encoding == EncodingIdentity {
|
||||
return EncodingIdentity, nil
|
||||
}
|
||||
|
||||
if !isSupportedEncoding(encoding) {
|
||||
return "", fmt.Errorf("unsupported Content-Encoding: %s, supported: %v", header, SupportedEncodings)
|
||||
}
|
||||
|
||||
return encoding, nil
|
||||
}
|
||||
|
||||
// parseAcceptEncodingHeader parses the Accept-Encoding header and returns
|
||||
// the parsed encodings along with the identity rejection state.
|
||||
// Per RFC 7231 Section 5.3.4, identity is acceptable unless excluded by
|
||||
// "identity;q=0" or "*;q=0" without a more specific entry for identity with q>0.
|
||||
func parseAcceptEncodingHeader(header string) ([]encodingWithQuality, bool) {
|
||||
if header == "" {
|
||||
return nil, false // identity not rejected when header is empty
|
||||
}
|
||||
|
||||
// Parse all encodings with their quality values
|
||||
var encodings []encodingWithQuality
|
||||
for value := range strings.SplitSeq(header, ",") {
|
||||
eq := parseEncodingWithQuality(value)
|
||||
encodings = append(encodings, eq)
|
||||
}
|
||||
|
||||
// Check if identity is rejected per RFC 7231 Section 5.3.4:
|
||||
// identity is acceptable unless excluded by "identity;q=0" or "*;q=0"
|
||||
// without a more specific entry for identity with q>0.
|
||||
identityRejected := false
|
||||
identityExplicitlyAccepted := false
|
||||
wildcardRejected := false
|
||||
|
||||
for _, eq := range encodings {
|
||||
switch eq.encoding {
|
||||
case EncodingIdentity:
|
||||
if eq.quality == 0 {
|
||||
identityRejected = true
|
||||
} else {
|
||||
identityExplicitlyAccepted = true
|
||||
}
|
||||
case EncodingWildcard:
|
||||
if eq.quality == 0 {
|
||||
wildcardRejected = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if wildcardRejected && !identityExplicitlyAccepted {
|
||||
identityRejected = true
|
||||
}
|
||||
|
||||
return encodings, identityRejected
|
||||
}
|
||||
|
||||
// isIdentityAcceptable checks if identity encoding is acceptable based on the
|
||||
// Accept-Encoding header. Per RFC 7231 section 5.3.4, identity is always
|
||||
// implicitly acceptable unless explicitly rejected with q=0.
|
||||
func isIdentityAcceptable(r *http.Request) bool {
|
||||
header := r.Header.Get("Accept-Encoding")
|
||||
_, identityRejected := parseAcceptEncodingHeader(header)
|
||||
|
||||
return !identityRejected
|
||||
}
|
||||
|
||||
// parseAcceptEncoding parses the Accept-Encoding header and returns the best
|
||||
// supported encoding based on quality values. Per RFC 7231 section 5.3.4,
|
||||
// identity is always implicitly acceptable unless explicitly rejected with q=0.
|
||||
// If no Accept-Encoding header is present, returns empty string (identity).
|
||||
func parseAcceptEncoding(r *http.Request) (string, error) {
|
||||
header := r.Header.Get("Accept-Encoding")
|
||||
if header == "" {
|
||||
return EncodingIdentity, nil
|
||||
}
|
||||
|
||||
encodings, identityRejected := parseAcceptEncodingHeader(header)
|
||||
|
||||
// Sort by quality value (highest first)
|
||||
sort.Slice(encodings, func(i, j int) bool {
|
||||
return encodings[i].quality > encodings[j].quality
|
||||
})
|
||||
|
||||
// Find the best supported encoding
|
||||
for _, eq := range encodings {
|
||||
// Skip encodings with q=0 (explicitly rejected)
|
||||
if eq.quality == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if eq.encoding == EncodingIdentity {
|
||||
return EncodingIdentity, nil
|
||||
}
|
||||
|
||||
// Wildcard means any encoding is acceptable - return a supported encoding if identity is rejected
|
||||
if eq.encoding == EncodingWildcard {
|
||||
if identityRejected && len(SupportedEncodings) > 0 {
|
||||
return SupportedEncodings[0], nil
|
||||
}
|
||||
|
||||
return EncodingIdentity, nil
|
||||
}
|
||||
|
||||
if isSupportedEncoding(eq.encoding) {
|
||||
return eq.encoding, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Per RFC 7231, identity is implicitly acceptable unless rejected
|
||||
if !identityRejected {
|
||||
return EncodingIdentity, nil
|
||||
}
|
||||
|
||||
// Identity rejected and no supported encodings found
|
||||
return "", fmt.Errorf("no acceptable encoding found, supported: %v", SupportedEncodings)
|
||||
}
|
||||
|
||||
// getDecompressedBody returns a reader that decompresses the request body based on
|
||||
// Content-Encoding header. Returns the original body if no encoding is specified.
|
||||
// Returns an error if an unsupported encoding is specified.
|
||||
// The caller is responsible for closing both the returned ReadCloser and the
|
||||
// original request body (r.Body) separately.
|
||||
func getDecompressedBody(r *http.Request) (io.ReadCloser, error) {
|
||||
encoding, err := parseContentEncoding(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if encoding == EncodingIdentity {
|
||||
return r.Body, nil
|
||||
}
|
||||
|
||||
switch encoding {
|
||||
case EncodingGzip:
|
||||
gzReader, err := gzip.NewReader(r.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create gzip reader: %w", err)
|
||||
}
|
||||
|
||||
return gzReader, nil
|
||||
default:
|
||||
// This shouldn't happen if isSupportedEncoding is correct
|
||||
return nil, fmt.Errorf("encoding %s is supported but not implemented", encoding)
|
||||
}
|
||||
}
|
||||
@ -1,496 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIsSupportedEncoding(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("gzip is supported", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.True(t, isSupportedEncoding("gzip"))
|
||||
})
|
||||
|
||||
t.Run("GZIP is supported (case-insensitive)", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.True(t, isSupportedEncoding("GZIP"))
|
||||
})
|
||||
|
||||
t.Run("Gzip is supported (case-insensitive)", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.True(t, isSupportedEncoding("Gzip"))
|
||||
})
|
||||
|
||||
t.Run("br is not supported", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.False(t, isSupportedEncoding("br"))
|
||||
})
|
||||
|
||||
t.Run("deflate is not supported", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.False(t, isSupportedEncoding("deflate"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseEncodingWithQuality(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("returns encoding with default quality 1.0", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
eq := parseEncodingWithQuality("gzip")
|
||||
assert.Equal(t, "gzip", eq.encoding)
|
||||
assert.InDelta(t, 1.0, eq.quality, 0.001)
|
||||
})
|
||||
|
||||
t.Run("parses quality value", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
eq := parseEncodingWithQuality("gzip;q=0.5")
|
||||
assert.Equal(t, "gzip", eq.encoding)
|
||||
assert.InDelta(t, 0.5, eq.quality, 0.001)
|
||||
})
|
||||
|
||||
t.Run("parses quality value with whitespace", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
eq := parseEncodingWithQuality("gzip ; q=0.8")
|
||||
assert.Equal(t, "gzip", eq.encoding)
|
||||
assert.InDelta(t, 0.8, eq.quality, 0.001)
|
||||
})
|
||||
|
||||
t.Run("handles q=0", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
eq := parseEncodingWithQuality("gzip;q=0")
|
||||
assert.Equal(t, "gzip", eq.encoding)
|
||||
assert.InDelta(t, 0.0, eq.quality, 0.001)
|
||||
})
|
||||
|
||||
t.Run("handles invalid quality value", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
eq := parseEncodingWithQuality("gzip;q=invalid")
|
||||
assert.Equal(t, "gzip", eq.encoding)
|
||||
assert.InDelta(t, 1.0, eq.quality, 0.001) // defaults to 1.0 on parse error
|
||||
})
|
||||
|
||||
t.Run("trims whitespace from encoding", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
eq := parseEncodingWithQuality(" gzip ")
|
||||
assert.Equal(t, "gzip", eq.encoding)
|
||||
assert.InDelta(t, 1.0, eq.quality, 0.001)
|
||||
})
|
||||
|
||||
t.Run("normalizes encoding to lowercase", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
eq := parseEncodingWithQuality("GZIP")
|
||||
assert.Equal(t, "gzip", eq.encoding)
|
||||
})
|
||||
|
||||
t.Run("normalizes mixed case encoding", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
eq := parseEncodingWithQuality("Gzip;q=0.5")
|
||||
assert.Equal(t, "gzip", eq.encoding)
|
||||
assert.InDelta(t, 0.5, eq.quality, 0.001)
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseEncoding(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("returns encoding as-is", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.Equal(t, "gzip", parseEncoding("gzip"))
|
||||
})
|
||||
|
||||
t.Run("trims whitespace", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.Equal(t, "gzip", parseEncoding(" gzip "))
|
||||
})
|
||||
|
||||
t.Run("strips quality value", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.Equal(t, "gzip", parseEncoding("gzip;q=1.0"))
|
||||
})
|
||||
|
||||
t.Run("strips quality value with whitespace", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.Equal(t, "gzip", parseEncoding("gzip ; q=0.5"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseContentEncoding(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("returns identity when no header", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
|
||||
|
||||
encoding, err := parseContentEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingIdentity, encoding)
|
||||
})
|
||||
|
||||
t.Run("returns gzip when Content-Encoding is gzip", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
|
||||
req.Header.Set("Content-Encoding", "gzip")
|
||||
|
||||
encoding, err := parseContentEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gzip", encoding)
|
||||
})
|
||||
|
||||
t.Run("returns gzip when Content-Encoding is GZIP (case-insensitive)", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
|
||||
req.Header.Set("Content-Encoding", "GZIP")
|
||||
|
||||
encoding, err := parseContentEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gzip", encoding)
|
||||
})
|
||||
|
||||
t.Run("returns gzip when Content-Encoding is Gzip (case-insensitive)", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
|
||||
req.Header.Set("Content-Encoding", "Gzip")
|
||||
|
||||
encoding, err := parseContentEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gzip", encoding)
|
||||
})
|
||||
|
||||
t.Run("returns identity for identity encoding", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
|
||||
req.Header.Set("Content-Encoding", "identity")
|
||||
|
||||
encoding, err := parseContentEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingIdentity, encoding)
|
||||
})
|
||||
|
||||
t.Run("returns error for unsupported encoding", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
|
||||
req.Header.Set("Content-Encoding", "br")
|
||||
|
||||
_, err := parseContentEncoding(req)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unsupported Content-Encoding")
|
||||
assert.Contains(t, err.Error(), "supported: [gzip]")
|
||||
})
|
||||
|
||||
t.Run("handles gzip with quality value", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
|
||||
req.Header.Set("Content-Encoding", "gzip;q=1.0")
|
||||
|
||||
encoding, err := parseContentEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gzip", encoding)
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseAcceptEncoding(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("returns identity when no header", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingIdentity, encoding)
|
||||
})
|
||||
|
||||
t.Run("returns gzip when Accept-Encoding is gzip", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gzip", encoding)
|
||||
})
|
||||
|
||||
t.Run("returns gzip when Accept-Encoding is GZIP (case-insensitive)", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "GZIP")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gzip", encoding)
|
||||
})
|
||||
|
||||
t.Run("returns gzip when gzip is among multiple encodings", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "deflate, gzip, br")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gzip", encoding)
|
||||
})
|
||||
|
||||
t.Run("returns gzip with quality value", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip;q=1.0")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gzip", encoding)
|
||||
})
|
||||
|
||||
t.Run("returns identity for identity encoding", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "identity")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingIdentity, encoding)
|
||||
})
|
||||
|
||||
t.Run("returns identity for wildcard encoding", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "*")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingIdentity, encoding)
|
||||
})
|
||||
|
||||
t.Run("falls back to identity for unsupported encoding only", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "br")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingIdentity, encoding)
|
||||
})
|
||||
|
||||
t.Run("falls back to identity when only unsupported encodings", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "deflate, br")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingIdentity, encoding)
|
||||
})
|
||||
|
||||
t.Run("selects gzip when it has highest quality", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "br;q=0.5, gzip;q=1.0, deflate;q=0.8")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gzip", encoding)
|
||||
})
|
||||
|
||||
t.Run("selects gzip even with lower quality when others unsupported", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "br;q=1.0, gzip;q=0.5")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gzip", encoding)
|
||||
})
|
||||
|
||||
t.Run("returns identity when it has higher quality than gzip", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip;q=0.5, identity;q=1.0")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingIdentity, encoding)
|
||||
})
|
||||
|
||||
t.Run("skips encoding with q=0", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip;q=0, identity")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingIdentity, encoding)
|
||||
})
|
||||
|
||||
t.Run("falls back to identity when gzip rejected and no other supported", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip;q=0, br")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingIdentity, encoding)
|
||||
})
|
||||
|
||||
t.Run("returns error when identity explicitly rejected and no supported encoding", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "br, identity;q=0")
|
||||
|
||||
_, err := parseAcceptEncoding(req)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no acceptable encoding found")
|
||||
})
|
||||
|
||||
t.Run("returns gzip for wildcard when identity rejected", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "*, identity;q=0")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gzip", encoding) // wildcard with identity rejected returns supported encoding
|
||||
})
|
||||
|
||||
t.Run("returns error when wildcard rejected and no explicit identity", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "*;q=0")
|
||||
|
||||
_, err := parseAcceptEncoding(req)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no acceptable encoding found")
|
||||
})
|
||||
|
||||
t.Run("returns identity when wildcard rejected but identity explicitly accepted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "*;q=0, identity")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingIdentity, encoding)
|
||||
})
|
||||
|
||||
t.Run("returns gzip when wildcard rejected but gzip explicitly accepted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "*;q=0, gzip")
|
||||
|
||||
encoding, err := parseAcceptEncoding(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EncodingGzip, encoding)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetDecompressedBody(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("returns original body when no Content-Encoding header", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
content := []byte("test content")
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(content))
|
||||
|
||||
body, err := getDecompressedBody(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, req.Body, body, "should return original body")
|
||||
|
||||
data, err := io.ReadAll(body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, content, data)
|
||||
})
|
||||
|
||||
t.Run("decompresses gzip body when Content-Encoding is gzip", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
originalContent := []byte("test content to compress")
|
||||
|
||||
var compressed bytes.Buffer
|
||||
gw := gzip.NewWriter(&compressed)
|
||||
_, err := gw.Write(originalContent)
|
||||
require.NoError(t, err)
|
||||
err = gw.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(compressed.Bytes()))
|
||||
req.Header.Set("Content-Encoding", "gzip")
|
||||
|
||||
body, err := getDecompressedBody(req)
|
||||
require.NoError(t, err)
|
||||
defer body.Close()
|
||||
|
||||
assert.NotEqual(t, req.Body, body, "should return a new gzip reader")
|
||||
|
||||
data, err := io.ReadAll(body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, originalContent, data)
|
||||
})
|
||||
|
||||
t.Run("returns error for invalid gzip data", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
invalidGzip := []byte("this is not gzip data")
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(invalidGzip))
|
||||
req.Header.Set("Content-Encoding", "gzip")
|
||||
|
||||
_, err := getDecompressedBody(req)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to create gzip reader")
|
||||
})
|
||||
|
||||
t.Run("returns original body for identity encoding", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
content := []byte("test content")
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(content))
|
||||
req.Header.Set("Content-Encoding", "identity")
|
||||
|
||||
body, err := getDecompressedBody(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, req.Body, body, "should return original body")
|
||||
|
||||
data, err := io.ReadAll(body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, content, data)
|
||||
})
|
||||
|
||||
t.Run("returns error for unsupported encoding", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
content := []byte("test content")
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(content))
|
||||
req.Header.Set("Content-Encoding", "br")
|
||||
|
||||
_, err := getDecompressedBody(req)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unsupported Content-Encoding")
|
||||
})
|
||||
|
||||
t.Run("handles gzip with quality value", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
originalContent := []byte("test content to compress")
|
||||
|
||||
var compressed bytes.Buffer
|
||||
gw := gzip.NewWriter(&compressed)
|
||||
_, err := gw.Write(originalContent)
|
||||
require.NoError(t, err)
|
||||
err = gw.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(compressed.Bytes()))
|
||||
req.Header.Set("Content-Encoding", "gzip;q=1.0")
|
||||
|
||||
body, err := getDecompressedBody(req)
|
||||
require.NoError(t, err)
|
||||
defer body.Close()
|
||||
|
||||
data, err := io.ReadAll(body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, originalContent, data)
|
||||
})
|
||||
}
|
||||
@ -1,31 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
||||
)
|
||||
|
||||
func (a *API) GetEnvs(w http.ResponseWriter, _ *http.Request) {
|
||||
operationID := logs.AssignOperationID()
|
||||
|
||||
a.logger.Debug().Str(string(logs.OperationIDKey), operationID).Msg("Getting env vars")
|
||||
|
||||
envs := make(EnvVars)
|
||||
a.defaults.EnvVars.Range(func(key, value string) bool {
|
||||
envs[key] = value
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
w.Header().Set("Cache-Control", "no-store")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if err := json.NewEncoder(w).Encode(envs); err != nil {
|
||||
a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("Failed to encode env vars")
|
||||
}
|
||||
}
|
||||
@ -1,23 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func jsonError(w http.ResponseWriter, code int, err error) {
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
|
||||
w.WriteHeader(code)
|
||||
encodeErr := json.NewEncoder(w).Encode(Error{
|
||||
Code: code,
|
||||
Message: err.Error(),
|
||||
})
|
||||
if encodeErr != nil {
|
||||
http.Error(w, errors.Join(encodeErr, err).Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
@ -1,5 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package api
|
||||
|
||||
//go:generate go run github.com/oapi-codegen/oapi-codegen/v2/cmd/oapi-codegen -config cfg.yaml ../../spec/envd.yaml
|
||||
@ -1,296 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// Modifications by M/S Omukk
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"time"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/host"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/keys"
|
||||
"github.com/awnumar/memguard"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/txn2/txeh"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrAccessTokenMismatch = errors.New("access token validation failed")
|
||||
ErrAccessTokenResetNotAuthorized = errors.New("access token reset not authorized")
|
||||
)
|
||||
|
||||
// validateInitAccessToken validates the access token for /init requests.
|
||||
// Token is valid if it matches the existing token OR the MMDS hash.
|
||||
// If neither exists, first-time setup is allowed.
|
||||
func (a *API) validateInitAccessToken(ctx context.Context, requestToken *SecureToken) error {
|
||||
requestTokenSet := requestToken.IsSet()
|
||||
|
||||
// Fast path: token matches existing
|
||||
if a.accessToken.IsSet() && requestTokenSet && a.accessToken.EqualsSecure(requestToken) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check MMDS only if token didn't match existing
|
||||
matchesMMDS, mmdsExists := a.checkMMDSHash(ctx, requestToken)
|
||||
|
||||
switch {
|
||||
case matchesMMDS:
|
||||
return nil
|
||||
case !a.accessToken.IsSet() && !mmdsExists:
|
||||
return nil // first-time setup
|
||||
case !requestTokenSet:
|
||||
return ErrAccessTokenResetNotAuthorized
|
||||
default:
|
||||
return ErrAccessTokenMismatch
|
||||
}
|
||||
}
|
||||
|
||||
// checkMMDSHash checks if the request token matches the MMDS hash.
|
||||
// Returns (matches, mmdsExists).
|
||||
//
|
||||
// The MMDS hash is set by the orchestrator during Resume:
|
||||
// - hash(token): requires this specific token
|
||||
// - hash(""): explicitly allows nil token (token reset authorized)
|
||||
// - "": MMDS not properly configured, no authorization granted
|
||||
func (a *API) checkMMDSHash(ctx context.Context, requestToken *SecureToken) (bool, bool) {
|
||||
if a.isNotFC {
|
||||
return false, false
|
||||
}
|
||||
|
||||
mmdsHash, err := a.mmdsClient.GetAccessTokenHash(ctx)
|
||||
if err != nil {
|
||||
return false, false
|
||||
}
|
||||
|
||||
if mmdsHash == "" {
|
||||
return false, false
|
||||
}
|
||||
|
||||
if !requestToken.IsSet() {
|
||||
return mmdsHash == keys.HashAccessToken(""), true
|
||||
}
|
||||
|
||||
tokenBytes, err := requestToken.Bytes()
|
||||
if err != nil {
|
||||
return false, true
|
||||
}
|
||||
defer memguard.WipeBytes(tokenBytes)
|
||||
|
||||
return keys.HashAccessTokenBytes(tokenBytes) == mmdsHash, true
|
||||
}
|
||||
|
||||
func (a *API) PostInit(w http.ResponseWriter, r *http.Request) {
|
||||
defer r.Body.Close()
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
operationID := logs.AssignOperationID()
|
||||
logger := a.logger.With().Str(string(logs.OperationIDKey), operationID).Logger()
|
||||
|
||||
if r.Body != nil {
|
||||
// Read raw body so we can wipe it after parsing
|
||||
body, err := io.ReadAll(r.Body)
|
||||
// Ensure body is wiped after we're done
|
||||
defer memguard.WipeBytes(body)
|
||||
if err != nil {
|
||||
logger.Error().Msgf("Failed to read request body: %v", err)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
var initRequest PostInitJSONBody
|
||||
if len(body) > 0 {
|
||||
err = json.Unmarshal(body, &initRequest)
|
||||
if err != nil {
|
||||
logger.Error().Msgf("Failed to decode request: %v", err)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure request token is destroyed if not transferred via TakeFrom.
|
||||
// This handles: validation failures, timestamp-based skips, and any early returns.
|
||||
// Safe because Destroy() is nil-safe and TakeFrom clears the source.
|
||||
defer initRequest.AccessToken.Destroy()
|
||||
|
||||
a.initLock.Lock()
|
||||
defer a.initLock.Unlock()
|
||||
|
||||
// Update data only if the request is newer or if there's no timestamp at all
|
||||
if initRequest.Timestamp == nil || a.lastSetTime.SetToGreater(initRequest.Timestamp.UnixNano()) {
|
||||
err = a.SetData(ctx, logger, initRequest)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, ErrAccessTokenMismatch), errors.Is(err, ErrAccessTokenResetNotAuthorized):
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
default:
|
||||
logger.Error().Msgf("Failed to set data: %v", err)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
}
|
||||
w.Write([]byte(err.Error()))
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
go func() { //nolint:contextcheck // TODO: fix this later
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
host.PollForMMDSOpts(ctx, a.mmdsChan, a.defaults.EnvVars)
|
||||
}()
|
||||
|
||||
// Start the port scanner and forwarder if they were stopped by a
|
||||
// pre-snapshot prepare call. Start is a no-op if already running,
|
||||
// so this is safe on first boot and only takes effect after restore.
|
||||
if a.portSubsystem != nil {
|
||||
a.portSubsystem.Start(a.rootCtx)
|
||||
}
|
||||
|
||||
w.Header().Set("Cache-Control", "no-store")
|
||||
w.Header().Set("Content-Type", "")
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (a *API) SetData(ctx context.Context, logger zerolog.Logger, data PostInitJSONBody) error {
|
||||
// Validate access token before proceeding with any action
|
||||
// The request must provide a token that is either:
|
||||
// 1. Matches the existing access token (if set), OR
|
||||
// 2. Matches the MMDS hash (for token change during resume)
|
||||
if err := a.validateInitAccessToken(ctx, data.AccessToken); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if data.EnvVars != nil {
|
||||
logger.Debug().Msg(fmt.Sprintf("Setting %d env vars", len(*data.EnvVars)))
|
||||
|
||||
for key, value := range *data.EnvVars {
|
||||
logger.Debug().Msgf("Setting env var for %s", key)
|
||||
a.defaults.EnvVars.Store(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
if data.AccessToken.IsSet() {
|
||||
logger.Debug().Msg("Setting access token")
|
||||
a.accessToken.TakeFrom(data.AccessToken)
|
||||
} else if a.accessToken.IsSet() {
|
||||
logger.Debug().Msg("Clearing access token")
|
||||
a.accessToken.Destroy()
|
||||
}
|
||||
|
||||
if data.HyperloopIP != nil {
|
||||
go a.SetupHyperloop(*data.HyperloopIP)
|
||||
}
|
||||
|
||||
if data.DefaultUser != nil && *data.DefaultUser != "" {
|
||||
logger.Debug().Msgf("Setting default user to: %s", *data.DefaultUser)
|
||||
a.defaults.User = *data.DefaultUser
|
||||
}
|
||||
|
||||
if data.DefaultWorkdir != nil && *data.DefaultWorkdir != "" {
|
||||
logger.Debug().Msgf("Setting default workdir to: %s", *data.DefaultWorkdir)
|
||||
a.defaults.Workdir = data.DefaultWorkdir
|
||||
}
|
||||
|
||||
if data.VolumeMounts != nil {
|
||||
for _, volume := range *data.VolumeMounts {
|
||||
logger.Debug().Msgf("Mounting %s at %q", volume.NfsTarget, volume.Path)
|
||||
|
||||
go a.setupNfs(context.WithoutCancel(ctx), volume.NfsTarget, volume.Path)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *API) setupNfs(ctx context.Context, nfsTarget, path string) {
|
||||
commands := [][]string{
|
||||
{"mkdir", "-p", path},
|
||||
{"mount", "-v", "-t", "nfs", "-o", "mountproto=tcp,mountport=2049,proto=tcp,port=2049,nfsvers=3,noacl", nfsTarget, path},
|
||||
}
|
||||
|
||||
for _, command := range commands {
|
||||
data, err := exec.CommandContext(ctx, command[0], command[1:]...).CombinedOutput()
|
||||
|
||||
logger := a.getLogger(err)
|
||||
|
||||
logger.
|
||||
Strs("command", command).
|
||||
Str("output", string(data)).
|
||||
Msg("Mount NFS")
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *API) SetupHyperloop(address string) {
|
||||
a.hyperloopLock.Lock()
|
||||
defer a.hyperloopLock.Unlock()
|
||||
|
||||
if err := rewriteHostsFile(address, "/etc/hosts"); err != nil {
|
||||
a.logger.Error().Err(err).Msg("failed to modify hosts file")
|
||||
} else {
|
||||
a.defaults.EnvVars.Store("WRENN_EVENTS_ADDRESS", fmt.Sprintf("http://%s", address))
|
||||
}
|
||||
}
|
||||
|
||||
const eventsHost = "events.wrenn.local"
|
||||
|
||||
func rewriteHostsFile(address, path string) error {
|
||||
hosts, err := txeh.NewHosts(&txeh.HostsConfig{
|
||||
ReadFilePath: path,
|
||||
WriteFilePath: path,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create hosts: %w", err)
|
||||
}
|
||||
|
||||
// Update /etc/hosts to point events.wrenn.local to the hyperloop IP
|
||||
// This will remove any existing entries for events.wrenn.local first
|
||||
ipFamily, err := getIPFamily(address)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get ip family: %w", err)
|
||||
}
|
||||
|
||||
if ok, current, _ := hosts.HostAddressLookup(eventsHost, ipFamily); ok && current == address {
|
||||
return nil // nothing to be done
|
||||
}
|
||||
|
||||
hosts.AddHost(address, eventsHost)
|
||||
|
||||
return hosts.Save()
|
||||
}
|
||||
|
||||
var (
|
||||
ErrInvalidAddress = errors.New("invalid IP address")
|
||||
ErrUnknownAddressFormat = errors.New("unknown IP address format")
|
||||
)
|
||||
|
||||
func getIPFamily(address string) (txeh.IPFamily, error) {
|
||||
addressIP, err := netip.ParseAddr(address)
|
||||
if err != nil {
|
||||
return txeh.IPFamilyV4, fmt.Errorf("failed to parse IP address: %w", err)
|
||||
}
|
||||
|
||||
switch {
|
||||
case addressIP.Is4():
|
||||
return txeh.IPFamilyV4, nil
|
||||
case addressIP.Is6():
|
||||
return txeh.IPFamilyV6, nil
|
||||
default:
|
||||
return txeh.IPFamilyV4, fmt.Errorf("%w: %s", ErrUnknownAddressFormat, address)
|
||||
}
|
||||
}
|
||||
@ -1,524 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// Modifications by M/S Omukk
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/keys"
|
||||
utilsShared "git.omukk.dev/wrenn/sandbox/envd/internal/shared/utils"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
|
||||
)
|
||||
|
||||
func TestSimpleCases(t *testing.T) {
|
||||
t.Parallel()
|
||||
testCases := map[string]func(string) string{
|
||||
"both newlines": func(s string) string { return s },
|
||||
"no newline prefix": func(s string) string { return strings.TrimPrefix(s, "\n") },
|
||||
"no newline suffix": func(s string) string { return strings.TrimSuffix(s, "\n") },
|
||||
"no newline prefix or suffix": strings.TrimSpace,
|
||||
}
|
||||
|
||||
for name, preprocessor := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tempDir := t.TempDir()
|
||||
|
||||
value := `
|
||||
# comment
|
||||
127.0.0.1 one.host
|
||||
127.0.0.2 two.host
|
||||
`
|
||||
value = preprocessor(value)
|
||||
inputPath := filepath.Join(tempDir, "hosts")
|
||||
err := os.WriteFile(inputPath, []byte(value), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = rewriteHostsFile("127.0.0.3", inputPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
data, err := os.ReadFile(inputPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, `# comment
|
||||
127.0.0.1 one.host
|
||||
127.0.0.2 two.host
|
||||
127.0.0.3 events.wrenn.local`, strings.TrimSpace(string(data)))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func secureTokenPtr(s string) *SecureToken {
|
||||
token := &SecureToken{}
|
||||
_ = token.Set([]byte(s))
|
||||
|
||||
return token
|
||||
}
|
||||
|
||||
type mockMMDSClient struct {
|
||||
hash string
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockMMDSClient) GetAccessTokenHash(_ context.Context) (string, error) {
|
||||
return m.hash, m.err
|
||||
}
|
||||
|
||||
func newTestAPI(accessToken *SecureToken, mmdsClient MMDSClient) *API {
|
||||
logger := zerolog.Nop()
|
||||
defaults := &execcontext.Defaults{
|
||||
EnvVars: utils.NewMap[string, string](),
|
||||
}
|
||||
api := New(&logger, defaults, nil, false, context.Background(), nil, "test")
|
||||
if accessToken != nil {
|
||||
api.accessToken.TakeFrom(accessToken)
|
||||
}
|
||||
api.mmdsClient = mmdsClient
|
||||
|
||||
return api
|
||||
}
|
||||
|
||||
func TestValidateInitAccessToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := t.Context()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
accessToken *SecureToken
|
||||
requestToken *SecureToken
|
||||
mmdsHash string
|
||||
mmdsErr error
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "fast path: token matches existing",
|
||||
accessToken: secureTokenPtr("secret-token"),
|
||||
requestToken: secureTokenPtr("secret-token"),
|
||||
mmdsHash: "",
|
||||
mmdsErr: nil,
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "MMDS match: token hash matches MMDS hash",
|
||||
accessToken: secureTokenPtr("old-token"),
|
||||
requestToken: secureTokenPtr("new-token"),
|
||||
mmdsHash: keys.HashAccessToken("new-token"),
|
||||
mmdsErr: nil,
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "first-time setup: no existing token, MMDS error",
|
||||
accessToken: nil,
|
||||
requestToken: secureTokenPtr("new-token"),
|
||||
mmdsHash: "",
|
||||
mmdsErr: assert.AnError,
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "first-time setup: no existing token, empty MMDS hash",
|
||||
accessToken: nil,
|
||||
requestToken: secureTokenPtr("new-token"),
|
||||
mmdsHash: "",
|
||||
mmdsErr: nil,
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "first-time setup: both tokens nil, no MMDS",
|
||||
accessToken: nil,
|
||||
requestToken: nil,
|
||||
mmdsHash: "",
|
||||
mmdsErr: assert.AnError,
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "mismatch: existing token differs from request, no MMDS",
|
||||
accessToken: secureTokenPtr("existing-token"),
|
||||
requestToken: secureTokenPtr("wrong-token"),
|
||||
mmdsHash: "",
|
||||
mmdsErr: assert.AnError,
|
||||
wantErr: ErrAccessTokenMismatch,
|
||||
},
|
||||
{
|
||||
name: "mismatch: existing token differs from request, MMDS hash mismatch",
|
||||
accessToken: secureTokenPtr("existing-token"),
|
||||
requestToken: secureTokenPtr("wrong-token"),
|
||||
mmdsHash: keys.HashAccessToken("different-token"),
|
||||
mmdsErr: nil,
|
||||
wantErr: ErrAccessTokenMismatch,
|
||||
},
|
||||
{
|
||||
name: "conflict: existing token, nil request, MMDS exists",
|
||||
accessToken: secureTokenPtr("existing-token"),
|
||||
requestToken: nil,
|
||||
mmdsHash: keys.HashAccessToken("some-token"),
|
||||
mmdsErr: nil,
|
||||
wantErr: ErrAccessTokenResetNotAuthorized,
|
||||
},
|
||||
{
|
||||
name: "conflict: existing token, nil request, no MMDS",
|
||||
accessToken: secureTokenPtr("existing-token"),
|
||||
requestToken: nil,
|
||||
mmdsHash: "",
|
||||
mmdsErr: assert.AnError,
|
||||
wantErr: ErrAccessTokenResetNotAuthorized,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mmdsClient := &mockMMDSClient{hash: tt.mmdsHash, err: tt.mmdsErr}
|
||||
api := newTestAPI(tt.accessToken, mmdsClient)
|
||||
|
||||
err := api.validateInitAccessToken(ctx, tt.requestToken)
|
||||
|
||||
if tt.wantErr != nil {
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, tt.wantErr)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckMMDSHash(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := t.Context()
|
||||
|
||||
t.Run("returns match when token hash equals MMDS hash", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
token := "my-secret-token"
|
||||
mmdsClient := &mockMMDSClient{hash: keys.HashAccessToken(token), err: nil}
|
||||
api := newTestAPI(nil, mmdsClient)
|
||||
|
||||
matches, exists := api.checkMMDSHash(ctx, secureTokenPtr(token))
|
||||
|
||||
assert.True(t, matches)
|
||||
assert.True(t, exists)
|
||||
})
|
||||
|
||||
t.Run("returns no match when token hash differs from MMDS hash", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mmdsClient := &mockMMDSClient{hash: keys.HashAccessToken("different-token"), err: nil}
|
||||
api := newTestAPI(nil, mmdsClient)
|
||||
|
||||
matches, exists := api.checkMMDSHash(ctx, secureTokenPtr("my-token"))
|
||||
|
||||
assert.False(t, matches)
|
||||
assert.True(t, exists)
|
||||
})
|
||||
|
||||
t.Run("returns exists but no match when request token is nil", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mmdsClient := &mockMMDSClient{hash: keys.HashAccessToken("some-token"), err: nil}
|
||||
api := newTestAPI(nil, mmdsClient)
|
||||
|
||||
matches, exists := api.checkMMDSHash(ctx, nil)
|
||||
|
||||
assert.False(t, matches)
|
||||
assert.True(t, exists)
|
||||
})
|
||||
|
||||
t.Run("returns false, false when MMDS returns error", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
|
||||
api := newTestAPI(nil, mmdsClient)
|
||||
|
||||
matches, exists := api.checkMMDSHash(ctx, secureTokenPtr("any-token"))
|
||||
|
||||
assert.False(t, matches)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("returns false, false when MMDS returns empty hash with non-nil request", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mmdsClient := &mockMMDSClient{hash: "", err: nil}
|
||||
api := newTestAPI(nil, mmdsClient)
|
||||
|
||||
matches, exists := api.checkMMDSHash(ctx, secureTokenPtr("any-token"))
|
||||
|
||||
assert.False(t, matches)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("returns false, false when MMDS returns empty hash with nil request", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mmdsClient := &mockMMDSClient{hash: "", err: nil}
|
||||
api := newTestAPI(nil, mmdsClient)
|
||||
|
||||
matches, exists := api.checkMMDSHash(ctx, nil)
|
||||
|
||||
assert.False(t, matches)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("returns true, true when MMDS returns hash of empty string with nil request (explicit reset)", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mmdsClient := &mockMMDSClient{hash: keys.HashAccessToken(""), err: nil}
|
||||
api := newTestAPI(nil, mmdsClient)
|
||||
|
||||
matches, exists := api.checkMMDSHash(ctx, nil)
|
||||
|
||||
assert.True(t, matches)
|
||||
assert.True(t, exists)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSetData(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
logger := zerolog.Nop()
|
||||
|
||||
t.Run("access token updates", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
existingToken *SecureToken
|
||||
requestToken *SecureToken
|
||||
mmdsHash string
|
||||
mmdsErr error
|
||||
wantErr error
|
||||
wantFinalToken *SecureToken
|
||||
}{
|
||||
{
|
||||
name: "first-time setup: sets initial token",
|
||||
existingToken: nil,
|
||||
requestToken: secureTokenPtr("initial-token"),
|
||||
mmdsHash: "",
|
||||
mmdsErr: assert.AnError,
|
||||
wantErr: nil,
|
||||
wantFinalToken: secureTokenPtr("initial-token"),
|
||||
},
|
||||
{
|
||||
name: "first-time setup: nil request token leaves token unset",
|
||||
existingToken: nil,
|
||||
requestToken: nil,
|
||||
mmdsHash: "",
|
||||
mmdsErr: assert.AnError,
|
||||
wantErr: nil,
|
||||
wantFinalToken: nil,
|
||||
},
|
||||
{
|
||||
name: "re-init with same token: token unchanged",
|
||||
existingToken: secureTokenPtr("same-token"),
|
||||
requestToken: secureTokenPtr("same-token"),
|
||||
mmdsHash: "",
|
||||
mmdsErr: assert.AnError,
|
||||
wantErr: nil,
|
||||
wantFinalToken: secureTokenPtr("same-token"),
|
||||
},
|
||||
{
|
||||
name: "resume with MMDS: updates token when hash matches",
|
||||
existingToken: secureTokenPtr("old-token"),
|
||||
requestToken: secureTokenPtr("new-token"),
|
||||
mmdsHash: keys.HashAccessToken("new-token"),
|
||||
mmdsErr: nil,
|
||||
wantErr: nil,
|
||||
wantFinalToken: secureTokenPtr("new-token"),
|
||||
},
|
||||
{
|
||||
name: "resume with MMDS: fails when hash doesn't match",
|
||||
existingToken: secureTokenPtr("old-token"),
|
||||
requestToken: secureTokenPtr("new-token"),
|
||||
mmdsHash: keys.HashAccessToken("different-token"),
|
||||
mmdsErr: nil,
|
||||
wantErr: ErrAccessTokenMismatch,
|
||||
wantFinalToken: secureTokenPtr("old-token"),
|
||||
},
|
||||
{
|
||||
name: "fails when existing token and request token mismatch without MMDS",
|
||||
existingToken: secureTokenPtr("existing-token"),
|
||||
requestToken: secureTokenPtr("wrong-token"),
|
||||
mmdsHash: "",
|
||||
mmdsErr: assert.AnError,
|
||||
wantErr: ErrAccessTokenMismatch,
|
||||
wantFinalToken: secureTokenPtr("existing-token"),
|
||||
},
|
||||
{
|
||||
name: "conflict when existing token but nil request token",
|
||||
existingToken: secureTokenPtr("existing-token"),
|
||||
requestToken: nil,
|
||||
mmdsHash: "",
|
||||
mmdsErr: assert.AnError,
|
||||
wantErr: ErrAccessTokenResetNotAuthorized,
|
||||
wantFinalToken: secureTokenPtr("existing-token"),
|
||||
},
|
||||
{
|
||||
name: "conflict when existing token but nil request with MMDS present",
|
||||
existingToken: secureTokenPtr("existing-token"),
|
||||
requestToken: nil,
|
||||
mmdsHash: keys.HashAccessToken("some-token"),
|
||||
mmdsErr: nil,
|
||||
wantErr: ErrAccessTokenResetNotAuthorized,
|
||||
wantFinalToken: secureTokenPtr("existing-token"),
|
||||
},
|
||||
{
|
||||
name: "conflict when MMDS returns empty hash and request is nil (prevents unauthorized reset)",
|
||||
existingToken: secureTokenPtr("existing-token"),
|
||||
requestToken: nil,
|
||||
mmdsHash: "",
|
||||
mmdsErr: nil,
|
||||
wantErr: ErrAccessTokenResetNotAuthorized,
|
||||
wantFinalToken: secureTokenPtr("existing-token"),
|
||||
},
|
||||
{
|
||||
name: "resets token when MMDS returns hash of empty string and request is nil (explicit reset)",
|
||||
existingToken: secureTokenPtr("existing-token"),
|
||||
requestToken: nil,
|
||||
mmdsHash: keys.HashAccessToken(""),
|
||||
mmdsErr: nil,
|
||||
wantErr: nil,
|
||||
wantFinalToken: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mmdsClient := &mockMMDSClient{hash: tt.mmdsHash, err: tt.mmdsErr}
|
||||
api := newTestAPI(tt.existingToken, mmdsClient)
|
||||
|
||||
data := PostInitJSONBody{
|
||||
AccessToken: tt.requestToken,
|
||||
}
|
||||
|
||||
err := api.SetData(ctx, logger, data)
|
||||
|
||||
if tt.wantErr != nil {
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
if tt.wantFinalToken == nil {
|
||||
assert.False(t, api.accessToken.IsSet(), "expected token to not be set")
|
||||
} else {
|
||||
require.True(t, api.accessToken.IsSet(), "expected token to be set")
|
||||
assert.True(t, api.accessToken.EqualsSecure(tt.wantFinalToken), "expected token to match")
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("sets environment variables", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
|
||||
api := newTestAPI(nil, mmdsClient)
|
||||
|
||||
envVars := EnvVars{"FOO": "bar", "BAZ": "qux"}
|
||||
data := PostInitJSONBody{
|
||||
EnvVars: &envVars,
|
||||
}
|
||||
|
||||
err := api.SetData(ctx, logger, data)
|
||||
|
||||
require.NoError(t, err)
|
||||
val, ok := api.defaults.EnvVars.Load("FOO")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "bar", val)
|
||||
val, ok = api.defaults.EnvVars.Load("BAZ")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "qux", val)
|
||||
})
|
||||
|
||||
t.Run("sets default user", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
|
||||
api := newTestAPI(nil, mmdsClient)
|
||||
|
||||
data := PostInitJSONBody{
|
||||
DefaultUser: utilsShared.ToPtr("testuser"),
|
||||
}
|
||||
|
||||
err := api.SetData(ctx, logger, data)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "testuser", api.defaults.User)
|
||||
})
|
||||
|
||||
t.Run("does not set default user when empty", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
|
||||
api := newTestAPI(nil, mmdsClient)
|
||||
api.defaults.User = "original"
|
||||
|
||||
data := PostInitJSONBody{
|
||||
DefaultUser: utilsShared.ToPtr(""),
|
||||
}
|
||||
|
||||
err := api.SetData(ctx, logger, data)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "original", api.defaults.User)
|
||||
})
|
||||
|
||||
t.Run("sets default workdir", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
|
||||
api := newTestAPI(nil, mmdsClient)
|
||||
|
||||
data := PostInitJSONBody{
|
||||
DefaultWorkdir: utilsShared.ToPtr("/home/user"),
|
||||
}
|
||||
|
||||
err := api.SetData(ctx, logger, data)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, api.defaults.Workdir)
|
||||
assert.Equal(t, "/home/user", *api.defaults.Workdir)
|
||||
})
|
||||
|
||||
t.Run("does not set default workdir when empty", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
|
||||
api := newTestAPI(nil, mmdsClient)
|
||||
originalWorkdir := "/original"
|
||||
api.defaults.Workdir = &originalWorkdir
|
||||
|
||||
data := PostInitJSONBody{
|
||||
DefaultWorkdir: utilsShared.ToPtr(""),
|
||||
}
|
||||
|
||||
err := api.SetData(ctx, logger, data)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, api.defaults.Workdir)
|
||||
assert.Equal(t, "/original", *api.defaults.Workdir)
|
||||
})
|
||||
|
||||
t.Run("sets multiple fields at once", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
|
||||
api := newTestAPI(nil, mmdsClient)
|
||||
|
||||
envVars := EnvVars{"KEY": "value"}
|
||||
data := PostInitJSONBody{
|
||||
AccessToken: secureTokenPtr("token"),
|
||||
DefaultUser: utilsShared.ToPtr("user"),
|
||||
DefaultWorkdir: utilsShared.ToPtr("/workdir"),
|
||||
EnvVars: &envVars,
|
||||
}
|
||||
|
||||
err := api.SetData(ctx, logger, data)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.True(t, api.accessToken.Equals("token"), "expected token to match")
|
||||
assert.Equal(t, "user", api.defaults.User)
|
||||
assert.Equal(t, "/workdir", *api.defaults.Workdir)
|
||||
val, ok := api.defaults.EnvVars.Load("KEY")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "value", val)
|
||||
})
|
||||
}
|
||||
@ -1,214 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
"github.com/awnumar/memguard"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrTokenNotSet = errors.New("access token not set")
|
||||
ErrTokenEmpty = errors.New("empty token not allowed")
|
||||
)
|
||||
|
||||
// SecureToken wraps memguard for secure token storage.
|
||||
// It uses LockedBuffer which provides memory locking, guard pages,
|
||||
// and secure zeroing on destroy.
|
||||
type SecureToken struct {
|
||||
mu sync.RWMutex
|
||||
buffer *memguard.LockedBuffer
|
||||
}
|
||||
|
||||
// Set securely replaces the token, destroying the old one first.
|
||||
// The old token memory is zeroed before the new token is stored.
|
||||
// The input byte slice is wiped after copying to secure memory.
|
||||
// Returns ErrTokenEmpty if token is empty - use Destroy() to clear the token instead.
|
||||
func (s *SecureToken) Set(token []byte) error {
|
||||
if len(token) == 0 {
|
||||
return ErrTokenEmpty
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Destroy old token first (zeros memory)
|
||||
if s.buffer != nil {
|
||||
s.buffer.Destroy()
|
||||
s.buffer = nil
|
||||
}
|
||||
|
||||
// Create new LockedBuffer from bytes (source slice is wiped by memguard)
|
||||
s.buffer = memguard.NewBufferFromBytes(token)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler to securely parse a JSON string
|
||||
// directly into memguard, wiping the input bytes after copying.
|
||||
//
|
||||
// Access tokens are hex-encoded HMAC-SHA256 hashes (64 chars of [0-9a-f]),
|
||||
// so they never contain JSON escape sequences.
|
||||
func (s *SecureToken) UnmarshalJSON(data []byte) error {
|
||||
// JSON strings are quoted, so minimum valid is `""` (2 bytes).
|
||||
if len(data) < 2 || data[0] != '"' || data[len(data)-1] != '"' {
|
||||
memguard.WipeBytes(data)
|
||||
|
||||
return errors.New("invalid secure token JSON string")
|
||||
}
|
||||
|
||||
content := data[1 : len(data)-1]
|
||||
|
||||
// Access tokens are hex strings - reject if contains backslash
|
||||
if bytes.ContainsRune(content, '\\') {
|
||||
memguard.WipeBytes(data)
|
||||
|
||||
return errors.New("invalid secure token: unexpected escape sequence")
|
||||
}
|
||||
|
||||
if len(content) == 0 {
|
||||
memguard.WipeBytes(data)
|
||||
|
||||
return ErrTokenEmpty
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.buffer != nil {
|
||||
s.buffer.Destroy()
|
||||
s.buffer = nil
|
||||
}
|
||||
|
||||
// Allocate secure buffer and copy directly into it
|
||||
s.buffer = memguard.NewBuffer(len(content))
|
||||
copy(s.buffer.Bytes(), content)
|
||||
|
||||
// Wipe the input data
|
||||
memguard.WipeBytes(data)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TakeFrom transfers the token from src to this SecureToken, destroying any
|
||||
// existing token. The source token is cleared after transfer.
|
||||
// This avoids copying the underlying bytes.
|
||||
func (s *SecureToken) TakeFrom(src *SecureToken) {
|
||||
if src == nil || s == src {
|
||||
return
|
||||
}
|
||||
|
||||
// Extract buffer from source
|
||||
src.mu.Lock()
|
||||
buffer := src.buffer
|
||||
src.buffer = nil
|
||||
src.mu.Unlock()
|
||||
|
||||
// Install buffer in destination
|
||||
s.mu.Lock()
|
||||
if s.buffer != nil {
|
||||
s.buffer.Destroy()
|
||||
}
|
||||
s.buffer = buffer
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// Equals checks if token matches using constant-time comparison.
|
||||
// Returns false if the receiver is nil.
|
||||
func (s *SecureToken) Equals(token string) bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if s.buffer == nil || !s.buffer.IsAlive() {
|
||||
return false
|
||||
}
|
||||
|
||||
return s.buffer.EqualTo([]byte(token))
|
||||
}
|
||||
|
||||
// EqualsSecure compares this token with another SecureToken using constant-time comparison.
|
||||
// Returns false if either receiver or other is nil.
|
||||
func (s *SecureToken) EqualsSecure(other *SecureToken) bool {
|
||||
if s == nil || other == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if s == other {
|
||||
return s.IsSet()
|
||||
}
|
||||
|
||||
// Get a copy of other's bytes (avoids holding two locks simultaneously)
|
||||
otherBytes, err := other.Bytes()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer memguard.WipeBytes(otherBytes)
|
||||
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if s.buffer == nil || !s.buffer.IsAlive() {
|
||||
return false
|
||||
}
|
||||
|
||||
return s.buffer.EqualTo(otherBytes)
|
||||
}
|
||||
|
||||
// IsSet returns true if a token is stored.
|
||||
// Returns false if the receiver is nil.
|
||||
func (s *SecureToken) IsSet() bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
return s.buffer != nil && s.buffer.IsAlive()
|
||||
}
|
||||
|
||||
// Bytes returns a copy of the token bytes (for signature generation).
|
||||
// The caller should zero the returned slice after use.
|
||||
// Returns ErrTokenNotSet if the receiver is nil.
|
||||
func (s *SecureToken) Bytes() ([]byte, error) {
|
||||
if s == nil {
|
||||
return nil, ErrTokenNotSet
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if s.buffer == nil || !s.buffer.IsAlive() {
|
||||
return nil, ErrTokenNotSet
|
||||
}
|
||||
|
||||
// Return a copy (unavoidable for signature generation)
|
||||
src := s.buffer.Bytes()
|
||||
result := make([]byte, len(src))
|
||||
copy(result, src)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Destroy securely wipes the token from memory.
|
||||
// No-op if the receiver is nil.
|
||||
func (s *SecureToken) Destroy() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.buffer != nil {
|
||||
s.buffer.Destroy()
|
||||
s.buffer = nil
|
||||
}
|
||||
}
|
||||
@ -1,463 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/awnumar/memguard"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSecureTokenSetAndEquals(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
st := &SecureToken{}
|
||||
|
||||
// Initially not set
|
||||
assert.False(t, st.IsSet(), "token should not be set initially")
|
||||
assert.False(t, st.Equals("any-token"), "equals should return false when not set")
|
||||
|
||||
// Set token
|
||||
err := st.Set([]byte("test-token"))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, st.IsSet(), "token should be set after Set()")
|
||||
assert.True(t, st.Equals("test-token"), "equals should return true for correct token")
|
||||
assert.False(t, st.Equals("wrong-token"), "equals should return false for wrong token")
|
||||
assert.False(t, st.Equals(""), "equals should return false for empty token")
|
||||
}
|
||||
|
||||
func TestSecureTokenReplace(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
st := &SecureToken{}
|
||||
|
||||
// Set initial token
|
||||
err := st.Set([]byte("first-token"))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, st.Equals("first-token"))
|
||||
|
||||
// Replace with new token (old one should be destroyed)
|
||||
err = st.Set([]byte("second-token"))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, st.Equals("second-token"), "should match new token")
|
||||
assert.False(t, st.Equals("first-token"), "should not match old token")
|
||||
}
|
||||
|
||||
func TestSecureTokenDestroy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
st := &SecureToken{}
|
||||
|
||||
// Set and then destroy
|
||||
err := st.Set([]byte("test-token"))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, st.IsSet())
|
||||
|
||||
st.Destroy()
|
||||
assert.False(t, st.IsSet(), "token should not be set after Destroy()")
|
||||
assert.False(t, st.Equals("test-token"), "equals should return false after Destroy()")
|
||||
|
||||
// Destroy on already destroyed should be safe
|
||||
st.Destroy()
|
||||
assert.False(t, st.IsSet())
|
||||
|
||||
// Nil receiver should be safe
|
||||
var nilToken *SecureToken
|
||||
assert.False(t, nilToken.IsSet(), "nil receiver should return false for IsSet()")
|
||||
assert.False(t, nilToken.Equals("anything"), "nil receiver should return false for Equals()")
|
||||
assert.False(t, nilToken.EqualsSecure(st), "nil receiver should return false for EqualsSecure()")
|
||||
nilToken.Destroy() // should not panic
|
||||
|
||||
_, err = nilToken.Bytes()
|
||||
require.ErrorIs(t, err, ErrTokenNotSet, "nil receiver should return ErrTokenNotSet for Bytes()")
|
||||
}
|
||||
|
||||
func TestSecureTokenBytes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
st := &SecureToken{}
|
||||
|
||||
// Bytes should return error when not set
|
||||
_, err := st.Bytes()
|
||||
require.ErrorIs(t, err, ErrTokenNotSet)
|
||||
|
||||
// Set token and get bytes
|
||||
err = st.Set([]byte("test-token"))
|
||||
require.NoError(t, err)
|
||||
|
||||
bytes, err := st.Bytes()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("test-token"), bytes)
|
||||
|
||||
// Zero out the bytes (as caller should do)
|
||||
memguard.WipeBytes(bytes)
|
||||
|
||||
// Original should still be intact
|
||||
assert.True(t, st.Equals("test-token"), "original token should still work after zeroing copy")
|
||||
|
||||
// After destroy, bytes should fail
|
||||
st.Destroy()
|
||||
_, err = st.Bytes()
|
||||
assert.ErrorIs(t, err, ErrTokenNotSet)
|
||||
}
|
||||
|
||||
func TestSecureTokenConcurrentAccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
st := &SecureToken{}
|
||||
err := st.Set([]byte("initial-token"))
|
||||
require.NoError(t, err)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
const numGoroutines = 100
|
||||
|
||||
// Concurrent reads
|
||||
for range numGoroutines {
|
||||
wg.Go(func() {
|
||||
st.IsSet()
|
||||
st.Equals("initial-token")
|
||||
})
|
||||
}
|
||||
|
||||
// Concurrent writes
|
||||
for i := range 10 {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
st.Set([]byte("token-" + string(rune('a'+idx))))
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Should still be in a valid state
|
||||
assert.True(t, st.IsSet())
|
||||
}
|
||||
|
||||
func TestSecureTokenEmptyToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
st := &SecureToken{}
|
||||
|
||||
// Setting empty token should return an error
|
||||
err := st.Set([]byte{})
|
||||
require.ErrorIs(t, err, ErrTokenEmpty)
|
||||
assert.False(t, st.IsSet(), "token should not be set after empty token error")
|
||||
|
||||
// Setting nil should also return an error
|
||||
err = st.Set(nil)
|
||||
require.ErrorIs(t, err, ErrTokenEmpty)
|
||||
assert.False(t, st.IsSet(), "token should not be set after nil token error")
|
||||
}
|
||||
|
||||
func TestSecureTokenEmptyTokenDoesNotClearExisting(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
st := &SecureToken{}
|
||||
|
||||
// Set a valid token first
|
||||
err := st.Set([]byte("valid-token"))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, st.IsSet())
|
||||
|
||||
// Attempting to set empty token should fail and preserve existing token
|
||||
err = st.Set([]byte{})
|
||||
require.ErrorIs(t, err, ErrTokenEmpty)
|
||||
assert.True(t, st.IsSet(), "existing token should be preserved after empty token error")
|
||||
assert.True(t, st.Equals("valid-token"), "existing token value should be unchanged")
|
||||
}
|
||||
|
||||
func TestSecureTokenUnmarshalJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("unmarshals valid JSON string", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := &SecureToken{}
|
||||
err := st.UnmarshalJSON([]byte(`"my-secret-token"`))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, st.IsSet())
|
||||
assert.True(t, st.Equals("my-secret-token"))
|
||||
})
|
||||
|
||||
t.Run("returns error for empty string", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := &SecureToken{}
|
||||
err := st.UnmarshalJSON([]byte(`""`))
|
||||
require.ErrorIs(t, err, ErrTokenEmpty)
|
||||
assert.False(t, st.IsSet())
|
||||
})
|
||||
|
||||
t.Run("returns error for invalid JSON", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := &SecureToken{}
|
||||
err := st.UnmarshalJSON([]byte(`not-valid-json`))
|
||||
require.Error(t, err)
|
||||
assert.False(t, st.IsSet())
|
||||
})
|
||||
|
||||
t.Run("replaces existing token", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := &SecureToken{}
|
||||
err := st.Set([]byte("old-token"))
|
||||
require.NoError(t, err)
|
||||
|
||||
err = st.UnmarshalJSON([]byte(`"new-token"`))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, st.Equals("new-token"))
|
||||
assert.False(t, st.Equals("old-token"))
|
||||
})
|
||||
|
||||
t.Run("wipes input buffer after parsing", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Create a buffer with a known token
|
||||
input := []byte(`"secret-token-12345"`)
|
||||
original := make([]byte, len(input))
|
||||
copy(original, input)
|
||||
|
||||
st := &SecureToken{}
|
||||
err := st.UnmarshalJSON(input)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the token was stored correctly
|
||||
assert.True(t, st.Equals("secret-token-12345"))
|
||||
|
||||
// Verify the input buffer was wiped (all zeros)
|
||||
for i, b := range input {
|
||||
assert.Equal(t, byte(0), b, "byte at position %d should be zero, got %d", i, b)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wipes input buffer on error", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Create a buffer with an empty token (will error)
|
||||
input := []byte(`""`)
|
||||
|
||||
st := &SecureToken{}
|
||||
err := st.UnmarshalJSON(input)
|
||||
require.Error(t, err)
|
||||
|
||||
// Verify the input buffer was still wiped
|
||||
for i, b := range input {
|
||||
assert.Equal(t, byte(0), b, "byte at position %d should be zero, got %d", i, b)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("rejects escape sequences", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := &SecureToken{}
|
||||
err := st.UnmarshalJSON([]byte(`"token\nwith\nnewlines"`))
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "escape sequence")
|
||||
assert.False(t, st.IsSet())
|
||||
})
|
||||
}
|
||||
|
||||
func TestSecureTokenSetWipesInput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("wipes input buffer after storing", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Create a buffer with a known token
|
||||
input := []byte("my-secret-token")
|
||||
original := make([]byte, len(input))
|
||||
copy(original, input)
|
||||
|
||||
st := &SecureToken{}
|
||||
err := st.Set(input)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the token was stored correctly
|
||||
assert.True(t, st.Equals("my-secret-token"))
|
||||
|
||||
// Verify the input buffer was wiped (all zeros)
|
||||
for i, b := range input {
|
||||
assert.Equal(t, byte(0), b, "byte at position %d should be zero, got %d", i, b)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSecureTokenTakeFrom(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("transfers token from source to destination", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
src := &SecureToken{}
|
||||
err := src.Set([]byte("source-token"))
|
||||
require.NoError(t, err)
|
||||
|
||||
dst := &SecureToken{}
|
||||
dst.TakeFrom(src)
|
||||
|
||||
assert.True(t, dst.IsSet())
|
||||
assert.True(t, dst.Equals("source-token"))
|
||||
assert.False(t, src.IsSet(), "source should be empty after transfer")
|
||||
})
|
||||
|
||||
t.Run("replaces existing destination token", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
src := &SecureToken{}
|
||||
err := src.Set([]byte("new-token"))
|
||||
require.NoError(t, err)
|
||||
|
||||
dst := &SecureToken{}
|
||||
err = dst.Set([]byte("old-token"))
|
||||
require.NoError(t, err)
|
||||
|
||||
dst.TakeFrom(src)
|
||||
|
||||
assert.True(t, dst.Equals("new-token"))
|
||||
assert.False(t, dst.Equals("old-token"))
|
||||
assert.False(t, src.IsSet())
|
||||
})
|
||||
|
||||
t.Run("handles nil source", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
dst := &SecureToken{}
|
||||
err := dst.Set([]byte("existing-token"))
|
||||
require.NoError(t, err)
|
||||
|
||||
dst.TakeFrom(nil)
|
||||
|
||||
assert.True(t, dst.IsSet(), "destination should be unchanged with nil source")
|
||||
assert.True(t, dst.Equals("existing-token"))
|
||||
})
|
||||
|
||||
t.Run("handles empty source", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
src := &SecureToken{}
|
||||
dst := &SecureToken{}
|
||||
err := dst.Set([]byte("existing-token"))
|
||||
require.NoError(t, err)
|
||||
|
||||
dst.TakeFrom(src)
|
||||
|
||||
assert.False(t, dst.IsSet(), "destination should be cleared when source is empty")
|
||||
})
|
||||
|
||||
t.Run("self-transfer is no-op and does not deadlock", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := &SecureToken{}
|
||||
err := st.Set([]byte("token"))
|
||||
require.NoError(t, err)
|
||||
|
||||
st.TakeFrom(st)
|
||||
|
||||
assert.True(t, st.IsSet(), "token should remain set after self-transfer")
|
||||
assert.True(t, st.Equals("token"), "token value should be unchanged")
|
||||
})
|
||||
}
|
||||
|
||||
func TestSecureTokenEqualsSecure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("returns true for matching tokens", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
st1 := &SecureToken{}
|
||||
err := st1.Set([]byte("same-token"))
|
||||
require.NoError(t, err)
|
||||
|
||||
st2 := &SecureToken{}
|
||||
err = st2.Set([]byte("same-token"))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, st1.EqualsSecure(st2))
|
||||
assert.True(t, st2.EqualsSecure(st1))
|
||||
})
|
||||
|
||||
t.Run("concurrent TakeFrom and EqualsSecure do not deadlock", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// This test verifies the fix for the lock ordering deadlock bug.
|
||||
|
||||
const iterations = 100
|
||||
|
||||
for range iterations {
|
||||
a := &SecureToken{}
|
||||
err := a.Set([]byte("token-a"))
|
||||
require.NoError(t, err)
|
||||
|
||||
b := &SecureToken{}
|
||||
err = b.Set([]byte("token-b"))
|
||||
require.NoError(t, err)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
// Goroutine 1: a.TakeFrom(b)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
a.TakeFrom(b)
|
||||
}()
|
||||
|
||||
// Goroutine 2: b.EqualsSecure(a)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
b.EqualsSecure(a)
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns false for different tokens", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
st1 := &SecureToken{}
|
||||
err := st1.Set([]byte("token-a"))
|
||||
require.NoError(t, err)
|
||||
|
||||
st2 := &SecureToken{}
|
||||
err = st2.Set([]byte("token-b"))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.False(t, st1.EqualsSecure(st2))
|
||||
})
|
||||
|
||||
t.Run("returns false when comparing with nil", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := &SecureToken{}
|
||||
err := st.Set([]byte("token"))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.False(t, st.EqualsSecure(nil))
|
||||
})
|
||||
|
||||
t.Run("returns false when other is not set", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
st1 := &SecureToken{}
|
||||
err := st1.Set([]byte("token"))
|
||||
require.NoError(t, err)
|
||||
|
||||
st2 := &SecureToken{}
|
||||
|
||||
assert.False(t, st1.EqualsSecure(st2))
|
||||
})
|
||||
|
||||
t.Run("returns false when self is not set", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
st1 := &SecureToken{}
|
||||
|
||||
st2 := &SecureToken{}
|
||||
err := st2.Set([]byte("token"))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.False(t, st1.EqualsSecure(st2))
|
||||
})
|
||||
|
||||
t.Run("self-comparison returns true when set", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := &SecureToken{}
|
||||
err := st.Set([]byte("token"))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, st.EqualsSecure(st), "self-comparison should return true and not deadlock")
|
||||
})
|
||||
|
||||
t.Run("self-comparison returns false when not set", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := &SecureToken{}
|
||||
|
||||
assert.False(t, st.EqualsSecure(st), "self-comparison on unset token should return false")
|
||||
})
|
||||
}
|
||||
@ -1,25 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// Modifications by M/S Omukk
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// PostSnapshotPrepare quiesces continuous goroutines (port scanner, forwarder)
|
||||
// and forces a GC cycle before Firecracker takes a VM snapshot. This ensures
|
||||
// the Go runtime's page allocator is in a consistent state when vCPUs are frozen.
|
||||
//
|
||||
// Called by the host agent as a best-effort signal before vm.Pause().
|
||||
func (a *API) PostSnapshotPrepare(w http.ResponseWriter, r *http.Request) {
|
||||
defer r.Body.Close()
|
||||
|
||||
if a.portSubsystem != nil {
|
||||
a.portSubsystem.Stop()
|
||||
a.logger.Info().Msg("snapshot/prepare: port subsystem quiesced")
|
||||
}
|
||||
|
||||
w.Header().Set("Cache-Control", "no-store")
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
@ -1,108 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// Modifications by M/S Omukk
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/host"
|
||||
publicport "git.omukk.dev/wrenn/sandbox/envd/internal/port"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
|
||||
)
|
||||
|
||||
// MMDSClient provides access to MMDS metadata.
|
||||
type MMDSClient interface {
|
||||
GetAccessTokenHash(ctx context.Context) (string, error)
|
||||
}
|
||||
|
||||
// DefaultMMDSClient is the production implementation that calls the real MMDS endpoint.
|
||||
type DefaultMMDSClient struct{}
|
||||
|
||||
func (c *DefaultMMDSClient) GetAccessTokenHash(ctx context.Context) (string, error) {
|
||||
return host.GetAccessTokenHashFromMMDS(ctx)
|
||||
}
|
||||
|
||||
type API struct {
|
||||
isNotFC bool
|
||||
logger *zerolog.Logger
|
||||
accessToken *SecureToken
|
||||
defaults *execcontext.Defaults
|
||||
version string
|
||||
|
||||
mmdsChan chan *host.MMDSOpts
|
||||
hyperloopLock sync.Mutex
|
||||
mmdsClient MMDSClient
|
||||
|
||||
lastSetTime *utils.AtomicMax
|
||||
initLock sync.Mutex
|
||||
|
||||
// rootCtx is the parent context from main(), used to restart
|
||||
// long-lived goroutines after snapshot restore.
|
||||
rootCtx context.Context
|
||||
portSubsystem *publicport.PortSubsystem
|
||||
}
|
||||
|
||||
func New(l *zerolog.Logger, defaults *execcontext.Defaults, mmdsChan chan *host.MMDSOpts, isNotFC bool, rootCtx context.Context, portSubsystem *publicport.PortSubsystem, version string) *API {
|
||||
return &API{
|
||||
logger: l,
|
||||
defaults: defaults,
|
||||
mmdsChan: mmdsChan,
|
||||
isNotFC: isNotFC,
|
||||
mmdsClient: &DefaultMMDSClient{},
|
||||
lastSetTime: utils.NewAtomicMax(),
|
||||
accessToken: &SecureToken{},
|
||||
rootCtx: rootCtx,
|
||||
portSubsystem: portSubsystem,
|
||||
version: version,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *API) GetHealth(w http.ResponseWriter, r *http.Request) {
|
||||
defer r.Body.Close()
|
||||
|
||||
a.logger.Trace().Msg("Health check")
|
||||
|
||||
w.Header().Set("Cache-Control", "no-store")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||
"version": a.version,
|
||||
})
|
||||
}
|
||||
|
||||
func (a *API) GetMetrics(w http.ResponseWriter, r *http.Request) {
|
||||
defer r.Body.Close()
|
||||
|
||||
a.logger.Trace().Msg("Get metrics")
|
||||
|
||||
w.Header().Set("Cache-Control", "no-store")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
metrics, err := host.GetMetrics()
|
||||
if err != nil {
|
||||
a.logger.Error().Err(err).Msg("Failed to get metrics")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if err := json.NewEncoder(w).Encode(metrics); err != nil {
|
||||
a.logger.Error().Err(err).Msg("Failed to encode metrics")
|
||||
}
|
||||
}
|
||||
|
||||
func (a *API) getLogger(err error) *zerolog.Event {
|
||||
if err != nil {
|
||||
return a.logger.Error().Err(err) //nolint:zerologlint // this is only prep
|
||||
}
|
||||
|
||||
return a.logger.Info() //nolint:zerologlint // this is only prep
|
||||
}
|
||||
@ -1,311 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
|
||||
)
|
||||
|
||||
var ErrNoDiskSpace = fmt.Errorf("not enough disk space available")
|
||||
|
||||
func processFile(r *http.Request, path string, part io.Reader, uid, gid int, logger zerolog.Logger) (int, error) {
|
||||
logger.Debug().
|
||||
Str("path", path).
|
||||
Msg("File processing")
|
||||
|
||||
err := permissions.EnsureDirs(filepath.Dir(path), uid, gid)
|
||||
if err != nil {
|
||||
err := fmt.Errorf("error ensuring directories: %w", err)
|
||||
|
||||
return http.StatusInternalServerError, err
|
||||
}
|
||||
|
||||
canBePreChowned := false
|
||||
stat, err := os.Stat(path)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
errMsg := fmt.Errorf("error getting file info: %w", err)
|
||||
|
||||
return http.StatusInternalServerError, errMsg
|
||||
} else if err == nil {
|
||||
if stat.IsDir() {
|
||||
err := fmt.Errorf("path is a directory: %s", path)
|
||||
|
||||
return http.StatusBadRequest, err
|
||||
}
|
||||
canBePreChowned = true
|
||||
}
|
||||
|
||||
hasBeenChowned := false
|
||||
if canBePreChowned {
|
||||
err = os.Chown(path, uid, gid)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
err = fmt.Errorf("error changing file ownership: %w", err)
|
||||
|
||||
return http.StatusInternalServerError, err
|
||||
}
|
||||
} else {
|
||||
hasBeenChowned = true
|
||||
}
|
||||
}
|
||||
|
||||
file, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o666)
|
||||
if err != nil {
|
||||
if errors.Is(err, syscall.ENOSPC) {
|
||||
err = fmt.Errorf("not enough inodes available: %w", err)
|
||||
|
||||
return http.StatusInsufficientStorage, err
|
||||
}
|
||||
|
||||
err := fmt.Errorf("error opening file: %w", err)
|
||||
|
||||
return http.StatusInternalServerError, err
|
||||
}
|
||||
|
||||
defer file.Close()
|
||||
|
||||
if !hasBeenChowned {
|
||||
err = os.Chown(path, uid, gid)
|
||||
if err != nil {
|
||||
err := fmt.Errorf("error changing file ownership: %w", err)
|
||||
|
||||
return http.StatusInternalServerError, err
|
||||
}
|
||||
}
|
||||
|
||||
_, err = file.ReadFrom(part)
|
||||
if err != nil {
|
||||
if errors.Is(err, syscall.ENOSPC) {
|
||||
err = ErrNoDiskSpace
|
||||
if r.ContentLength > 0 {
|
||||
err = fmt.Errorf("attempted to write %d bytes: %w", r.ContentLength, err)
|
||||
}
|
||||
|
||||
return http.StatusInsufficientStorage, err
|
||||
}
|
||||
|
||||
err = fmt.Errorf("error writing file: %w", err)
|
||||
|
||||
return http.StatusInternalServerError, err
|
||||
}
|
||||
|
||||
return http.StatusNoContent, nil
|
||||
}
|
||||
|
||||
func resolvePath(part *multipart.Part, paths *UploadSuccess, u *user.User, defaultPath *string, params PostFilesParams) (string, error) {
|
||||
var pathToResolve string
|
||||
|
||||
if params.Path != nil {
|
||||
pathToResolve = *params.Path
|
||||
} else {
|
||||
var err error
|
||||
customPart := utils.NewCustomPart(part)
|
||||
pathToResolve, err = customPart.FileNameWithPath()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error getting multipart custom part file name: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
filePath, err := permissions.ExpandAndResolve(pathToResolve, u, defaultPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error resolving path: %w", err)
|
||||
}
|
||||
|
||||
for _, entry := range *paths {
|
||||
if entry.Path == filePath {
|
||||
var alreadyUploaded []string
|
||||
for _, uploadedFile := range *paths {
|
||||
if uploadedFile.Path != filePath {
|
||||
alreadyUploaded = append(alreadyUploaded, uploadedFile.Path)
|
||||
}
|
||||
}
|
||||
|
||||
errMsg := fmt.Errorf("you cannot upload multiple files to the same path '%s' in one upload request, only the first specified file was uploaded", filePath)
|
||||
|
||||
if len(alreadyUploaded) > 1 {
|
||||
errMsg = fmt.Errorf("%w, also the following files were uploaded: %v", errMsg, strings.Join(alreadyUploaded, ", "))
|
||||
}
|
||||
|
||||
return "", errMsg
|
||||
}
|
||||
}
|
||||
|
||||
return filePath, nil
|
||||
}
|
||||
|
||||
func (a *API) handlePart(r *http.Request, part *multipart.Part, paths UploadSuccess, u *user.User, uid, gid int, operationID string, params PostFilesParams) (*EntryInfo, int, error) {
|
||||
defer part.Close()
|
||||
|
||||
if part.FormName() != "file" {
|
||||
return nil, http.StatusOK, nil
|
||||
}
|
||||
|
||||
filePath, err := resolvePath(part, &paths, u, a.defaults.Workdir, params)
|
||||
if err != nil {
|
||||
return nil, http.StatusBadRequest, err
|
||||
}
|
||||
|
||||
logger := a.logger.
|
||||
With().
|
||||
Str(string(logs.OperationIDKey), operationID).
|
||||
Str("event_type", "file_processing").
|
||||
Logger()
|
||||
|
||||
status, err := processFile(r, filePath, part, uid, gid, logger)
|
||||
if err != nil {
|
||||
return nil, status, err
|
||||
}
|
||||
|
||||
return &EntryInfo{
|
||||
Path: filePath,
|
||||
Name: filepath.Base(filePath),
|
||||
Type: File,
|
||||
}, http.StatusOK, nil
|
||||
}
|
||||
|
||||
func (a *API) PostFiles(w http.ResponseWriter, r *http.Request, params PostFilesParams) {
|
||||
// Capture original body to ensure it's always closed
|
||||
originalBody := r.Body
|
||||
defer originalBody.Close()
|
||||
|
||||
var errorCode int
|
||||
var errMsg error
|
||||
|
||||
var path string
|
||||
if params.Path != nil {
|
||||
path = *params.Path
|
||||
}
|
||||
|
||||
operationID := logs.AssignOperationID()
|
||||
|
||||
// signing authorization if needed
|
||||
err := a.validateSigning(r, params.Signature, params.SignatureExpiration, params.Username, path, SigningWriteOperation)
|
||||
if err != nil {
|
||||
a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("error during auth validation")
|
||||
jsonError(w, http.StatusUnauthorized, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
username, err := execcontext.ResolveDefaultUsername(params.Username, a.defaults.User)
|
||||
if err != nil {
|
||||
a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("no user specified")
|
||||
jsonError(w, http.StatusBadRequest, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
l := a.logger.
|
||||
Err(errMsg).
|
||||
Str("method", r.Method+" "+r.URL.Path).
|
||||
Str(string(logs.OperationIDKey), operationID).
|
||||
Str("path", path).
|
||||
Str("username", username)
|
||||
|
||||
if errMsg != nil {
|
||||
l = l.Int("error_code", errorCode)
|
||||
}
|
||||
|
||||
l.Msg("File write")
|
||||
}()
|
||||
|
||||
// Handle gzip-encoded request body
|
||||
body, err := getDecompressedBody(r)
|
||||
if err != nil {
|
||||
errMsg = fmt.Errorf("error decompressing request body: %w", err)
|
||||
errorCode = http.StatusBadRequest
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
defer body.Close()
|
||||
r.Body = body
|
||||
|
||||
f, err := r.MultipartReader()
|
||||
if err != nil {
|
||||
errMsg = fmt.Errorf("error parsing multipart form: %w", err)
|
||||
errorCode = http.StatusInternalServerError
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
u, err := user.Lookup(username)
|
||||
if err != nil {
|
||||
errMsg = fmt.Errorf("error looking up user '%s': %w", username, err)
|
||||
errorCode = http.StatusUnauthorized
|
||||
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
uid, gid, err := permissions.GetUserIdInts(u)
|
||||
if err != nil {
|
||||
errMsg = fmt.Errorf("error getting user ids: %w", err)
|
||||
|
||||
jsonError(w, http.StatusInternalServerError, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
paths := UploadSuccess{}
|
||||
|
||||
for {
|
||||
part, partErr := f.NextPart()
|
||||
|
||||
if partErr == io.EOF {
|
||||
// We're done reading the parts.
|
||||
break
|
||||
} else if partErr != nil {
|
||||
errMsg = fmt.Errorf("error reading form: %w", partErr)
|
||||
errorCode = http.StatusInternalServerError
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
entry, status, err := a.handlePart(r, part, paths, u, uid, gid, operationID, params)
|
||||
if err != nil {
|
||||
errorCode = status
|
||||
errMsg = err
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if entry != nil {
|
||||
paths = append(paths, *entry)
|
||||
}
|
||||
}
|
||||
|
||||
data, err := json.Marshal(paths)
|
||||
if err != nil {
|
||||
errMsg = fmt.Errorf("error marshaling response: %w", err)
|
||||
errorCode = http.StatusInternalServerError
|
||||
jsonError(w, errorCode, errMsg)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(data)
|
||||
}
|
||||
@ -1,251 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestProcessFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
uid := os.Getuid()
|
||||
gid := os.Getgid()
|
||||
|
||||
newRequest := func(content []byte) (*http.Request, io.Reader) {
|
||||
request := &http.Request{
|
||||
ContentLength: int64(len(content)),
|
||||
}
|
||||
buffer := bytes.NewBuffer(content)
|
||||
|
||||
return request, buffer
|
||||
}
|
||||
|
||||
var emptyReq http.Request
|
||||
var emptyPart *bytes.Buffer
|
||||
var emptyLogger zerolog.Logger
|
||||
|
||||
t.Run("failed to ensure directories", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
httpStatus, err := processFile(&emptyReq, "/proc/invalid/not-real", emptyPart, uid, gid, emptyLogger)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, http.StatusInternalServerError, httpStatus)
|
||||
assert.ErrorContains(t, err, "error ensuring directories: ")
|
||||
})
|
||||
|
||||
t.Run("attempt to replace directory with a file", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tempDir := t.TempDir()
|
||||
|
||||
httpStatus, err := processFile(&emptyReq, tempDir, emptyPart, uid, gid, emptyLogger)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, http.StatusBadRequest, httpStatus, err.Error())
|
||||
assert.ErrorContains(t, err, "path is a directory: ")
|
||||
})
|
||||
|
||||
t.Run("fail to create file", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
httpStatus, err := processFile(&emptyReq, "/proc/invalid-filename", emptyPart, uid, gid, emptyLogger)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, http.StatusInternalServerError, httpStatus)
|
||||
assert.ErrorContains(t, err, "error opening file: ")
|
||||
})
|
||||
|
||||
t.Run("out of disk space", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// make a tiny tmpfs mount
|
||||
mountSize := 1024
|
||||
tempDir := createTmpfsMount(t, mountSize)
|
||||
|
||||
// create test file
|
||||
firstFileSize := mountSize / 2
|
||||
tempFile1 := filepath.Join(tempDir, "test-file-1")
|
||||
|
||||
// fill it up
|
||||
cmd := exec.CommandContext(t.Context(),
|
||||
"dd", "if=/dev/zero", "of="+tempFile1, fmt.Sprintf("bs=%d", firstFileSize), "count=1")
|
||||
err := cmd.Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
// create a new file that would fill up the
|
||||
secondFileContents := make([]byte, mountSize*2)
|
||||
for index := range secondFileContents {
|
||||
secondFileContents[index] = 'a'
|
||||
}
|
||||
|
||||
// try to replace it
|
||||
request, buffer := newRequest(secondFileContents)
|
||||
tempFile2 := filepath.Join(tempDir, "test-file-2")
|
||||
httpStatus, err := processFile(request, tempFile2, buffer, uid, gid, emptyLogger)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, http.StatusInsufficientStorage, httpStatus)
|
||||
assert.ErrorContains(t, err, "attempted to write 2048 bytes: not enough disk space")
|
||||
})
|
||||
|
||||
t.Run("happy path", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tempDir := t.TempDir()
|
||||
tempFile := filepath.Join(tempDir, "test-file")
|
||||
|
||||
content := []byte("test-file-contents")
|
||||
request, buffer := newRequest(content)
|
||||
|
||||
httpStatus, err := processFile(request, tempFile, buffer, uid, gid, emptyLogger)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusNoContent, httpStatus)
|
||||
|
||||
data, err := os.ReadFile(tempFile)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, content, data)
|
||||
})
|
||||
|
||||
t.Run("overwrite file on full disk", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// make a tiny tmpfs mount
|
||||
sizeInBytes := 1024
|
||||
tempDir := createTmpfsMount(t, 1024)
|
||||
|
||||
// create test file
|
||||
tempFile := filepath.Join(tempDir, "test-file")
|
||||
|
||||
// fill it up
|
||||
cmd := exec.CommandContext(t.Context(), "dd", "if=/dev/zero", "of="+tempFile, fmt.Sprintf("bs=%d", sizeInBytes), "count=1")
|
||||
err := cmd.Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
// try to replace it
|
||||
content := []byte("test-file-contents")
|
||||
request, buffer := newRequest(content)
|
||||
httpStatus, err := processFile(request, tempFile, buffer, uid, gid, emptyLogger)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusNoContent, httpStatus)
|
||||
})
|
||||
|
||||
t.Run("write new file on full disk", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// make a tiny tmpfs mount
|
||||
sizeInBytes := 1024
|
||||
tempDir := createTmpfsMount(t, 1024)
|
||||
|
||||
// create test file
|
||||
tempFile1 := filepath.Join(tempDir, "test-file")
|
||||
|
||||
// fill it up
|
||||
cmd := exec.CommandContext(t.Context(), "dd", "if=/dev/zero", "of="+tempFile1, fmt.Sprintf("bs=%d", sizeInBytes), "count=1")
|
||||
err := cmd.Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
// try to write a new file
|
||||
tempFile2 := filepath.Join(tempDir, "test-file-2")
|
||||
content := []byte("test-file-contents")
|
||||
request, buffer := newRequest(content)
|
||||
httpStatus, err := processFile(request, tempFile2, buffer, uid, gid, emptyLogger)
|
||||
require.ErrorContains(t, err, "not enough disk space available")
|
||||
assert.Equal(t, http.StatusInsufficientStorage, httpStatus)
|
||||
})
|
||||
|
||||
t.Run("write new file with no inodes available", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// make a tiny tmpfs mount
|
||||
tempDir := createTmpfsMountWithInodes(t, 1024, 2)
|
||||
|
||||
// create test file
|
||||
tempFile1 := filepath.Join(tempDir, "test-file")
|
||||
|
||||
// fill it up
|
||||
cmd := exec.CommandContext(t.Context(), "dd", "if=/dev/zero", "of="+tempFile1, fmt.Sprintf("bs=%d", 100), "count=1")
|
||||
err := cmd.Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
// try to write a new file
|
||||
tempFile2 := filepath.Join(tempDir, "test-file-2")
|
||||
content := []byte("test-file-contents")
|
||||
request, buffer := newRequest(content)
|
||||
httpStatus, err := processFile(request, tempFile2, buffer, uid, gid, emptyLogger)
|
||||
require.ErrorContains(t, err, "not enough inodes available")
|
||||
assert.Equal(t, http.StatusInsufficientStorage, httpStatus)
|
||||
})
|
||||
|
||||
t.Run("update sysfs or other virtual fs", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if os.Geteuid() != 0 {
|
||||
t.Skip("skipping sysfs updates: Operation not permitted with non-root user")
|
||||
}
|
||||
|
||||
filePath := "/sys/fs/cgroup/user.slice/cpu.weight"
|
||||
newContent := []byte("102\n")
|
||||
request, buffer := newRequest(newContent)
|
||||
|
||||
httpStatus, err := processFile(request, filePath, buffer, uid, gid, emptyLogger)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusNoContent, httpStatus)
|
||||
|
||||
data, err := os.ReadFile(filePath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, newContent, data)
|
||||
})
|
||||
|
||||
t.Run("replace file", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tempDir := t.TempDir()
|
||||
tempFile := filepath.Join(tempDir, "test-file")
|
||||
|
||||
err := os.WriteFile(tempFile, []byte("old-contents"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
newContent := []byte("new-file-contents")
|
||||
request, buffer := newRequest(newContent)
|
||||
|
||||
httpStatus, err := processFile(request, tempFile, buffer, uid, gid, emptyLogger)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusNoContent, httpStatus)
|
||||
|
||||
data, err := os.ReadFile(tempFile)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, newContent, data)
|
||||
})
|
||||
}
|
||||
|
||||
func createTmpfsMount(t *testing.T, sizeInBytes int) string {
|
||||
t.Helper()
|
||||
|
||||
return createTmpfsMountWithInodes(t, sizeInBytes, 5)
|
||||
}
|
||||
|
||||
func createTmpfsMountWithInodes(t *testing.T, sizeInBytes, inodesCount int) string {
|
||||
t.Helper()
|
||||
|
||||
if os.Geteuid() != 0 {
|
||||
t.Skip("skipping sysfs updates: Operation not permitted with non-root user")
|
||||
}
|
||||
|
||||
tempDir := t.TempDir()
|
||||
|
||||
cmd := exec.CommandContext(t.Context(),
|
||||
"mount",
|
||||
"tmpfs",
|
||||
tempDir,
|
||||
"-t", "tmpfs",
|
||||
"-o", fmt.Sprintf("size=%d,nr_inodes=%d", sizeInBytes, inodesCount))
|
||||
err := cmd.Run()
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
ctx := context.WithoutCancel(t.Context())
|
||||
cmd := exec.CommandContext(ctx, "umount", tempDir)
|
||||
err := cmd.Run()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
return tempDir
|
||||
}
|
||||
@ -1,39 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package execcontext
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
|
||||
)
|
||||
|
||||
type Defaults struct {
|
||||
EnvVars *utils.Map[string, string]
|
||||
User string
|
||||
Workdir *string
|
||||
}
|
||||
|
||||
func ResolveDefaultWorkdir(workdir string, defaultWorkdir *string) string {
|
||||
if workdir != "" {
|
||||
return workdir
|
||||
}
|
||||
|
||||
if defaultWorkdir != nil {
|
||||
return *defaultWorkdir
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func ResolveDefaultUsername(username *string, defaultUsername string) (string, error) {
|
||||
if username != nil {
|
||||
return *username, nil
|
||||
}
|
||||
|
||||
if defaultUsername != "" {
|
||||
return defaultUsername, nil
|
||||
}
|
||||
|
||||
return "", errors.New("username not provided")
|
||||
}
|
||||
@ -1,96 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// Modifications by M/S Omukk
|
||||
|
||||
package host
|
||||
|
||||
import (
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/shirou/gopsutil/v4/cpu"
|
||||
"github.com/shirou/gopsutil/v4/mem"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
type Metrics struct {
|
||||
Timestamp int64 `json:"ts"` // Unix Timestamp in UTC
|
||||
|
||||
CPUCount uint32 `json:"cpu_count"` // Total CPU cores
|
||||
CPUUsedPercent float32 `json:"cpu_used_pct"` // Percent rounded to 2 decimal places
|
||||
|
||||
// Deprecated: kept for backwards compatibility with older orchestrators.
|
||||
MemTotalMiB uint64 `json:"mem_total_mib"` // Total virtual memory in MiB
|
||||
|
||||
// Deprecated: kept for backwards compatibility with older orchestrators.
|
||||
MemUsedMiB uint64 `json:"mem_used_mib"` // Used virtual memory in MiB
|
||||
|
||||
MemTotal uint64 `json:"mem_total"` // Total virtual memory in bytes
|
||||
MemUsed uint64 `json:"mem_used"` // Used virtual memory in bytes
|
||||
|
||||
DiskUsed uint64 `json:"disk_used"` // Used disk space in bytes
|
||||
DiskTotal uint64 `json:"disk_total"` // Total disk space in bytes
|
||||
}
|
||||
|
||||
func GetMetrics() (*Metrics, error) {
|
||||
v, err := mem.VirtualMemory()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
memUsedMiB := v.Used / 1024 / 1024
|
||||
memTotalMiB := v.Total / 1024 / 1024
|
||||
|
||||
cpuTotal, err := cpu.Counts(true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cpuUsedPcts, err := cpu.Percent(0, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cpuUsedPct := cpuUsedPcts[0]
|
||||
cpuUsedPctRounded := float32(cpuUsedPct)
|
||||
if cpuUsedPct > 0 {
|
||||
cpuUsedPctRounded = float32(math.Round(cpuUsedPct*100) / 100)
|
||||
}
|
||||
|
||||
diskMetrics, err := diskStats("/")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Metrics{
|
||||
Timestamp: time.Now().UTC().Unix(),
|
||||
CPUCount: uint32(cpuTotal),
|
||||
CPUUsedPercent: cpuUsedPctRounded,
|
||||
MemUsedMiB: memUsedMiB,
|
||||
MemTotalMiB: memTotalMiB,
|
||||
MemTotal: v.Total,
|
||||
MemUsed: v.Used,
|
||||
DiskUsed: diskMetrics.Total - diskMetrics.Available,
|
||||
DiskTotal: diskMetrics.Total,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type diskSpace struct {
|
||||
Total uint64
|
||||
Available uint64
|
||||
}
|
||||
|
||||
func diskStats(path string) (diskSpace, error) {
|
||||
var st unix.Statfs_t
|
||||
if err := unix.Statfs(path, &st); err != nil {
|
||||
return diskSpace{}, err
|
||||
}
|
||||
|
||||
block := uint64(st.Bsize)
|
||||
|
||||
// all data blocks
|
||||
total := st.Blocks * block
|
||||
// blocks available
|
||||
available := st.Bavail * block
|
||||
|
||||
return diskSpace{Total: total, Available: available}, nil
|
||||
}
|
||||
@ -1,185 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// Modifications by M/S Omukk
|
||||
|
||||
package host
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
WrennRunDir = "/run/wrenn" // store sandbox metadata files here
|
||||
|
||||
mmdsDefaultAddress = "169.254.169.254"
|
||||
mmdsTokenExpiration = 60 * time.Second
|
||||
|
||||
mmdsAccessTokenRequestClientTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
var mmdsAccessTokenClient = &http.Client{
|
||||
Timeout: mmdsAccessTokenRequestClientTimeout,
|
||||
Transport: &http.Transport{
|
||||
DisableKeepAlives: true,
|
||||
},
|
||||
}
|
||||
|
||||
type MMDSOpts struct {
|
||||
SandboxID string `json:"instanceID"`
|
||||
TemplateID string `json:"envID"`
|
||||
LogsCollectorAddress string `json:"address"`
|
||||
AccessTokenHash string `json:"accessTokenHash"`
|
||||
}
|
||||
|
||||
func (opts *MMDSOpts) Update(sandboxID, templateID, collectorAddress string) {
|
||||
opts.SandboxID = sandboxID
|
||||
opts.TemplateID = templateID
|
||||
opts.LogsCollectorAddress = collectorAddress
|
||||
}
|
||||
|
||||
func (opts *MMDSOpts) AddOptsToJSON(jsonLogs []byte) ([]byte, error) {
|
||||
parsed := make(map[string]any)
|
||||
|
||||
err := json.Unmarshal(jsonLogs, &parsed)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parsed["instanceID"] = opts.SandboxID
|
||||
parsed["envID"] = opts.TemplateID
|
||||
|
||||
data, err := json.Marshal(parsed)
|
||||
|
||||
return data, err
|
||||
}
|
||||
|
||||
func getMMDSToken(ctx context.Context, client *http.Client) (string, error) {
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodPut, "http://"+mmdsDefaultAddress+"/latest/api/token", &bytes.Buffer{})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
request.Header["X-metadata-token-ttl-seconds"] = []string{fmt.Sprint(mmdsTokenExpiration.Seconds())}
|
||||
|
||||
response, err := client.Do(request)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
token := string(body)
|
||||
|
||||
if len(token) == 0 {
|
||||
return "", fmt.Errorf("mmds token is an empty string")
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func getMMDSOpts(ctx context.Context, client *http.Client, token string) (*MMDSOpts, error) {
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://"+mmdsDefaultAddress, &bytes.Buffer{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
request.Header["X-metadata-token"] = []string{token}
|
||||
request.Header["Accept"] = []string{"application/json"}
|
||||
|
||||
response, err := client.Do(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer response.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var opts MMDSOpts
|
||||
|
||||
err = json.Unmarshal(body, &opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &opts, nil
|
||||
}
|
||||
|
||||
// GetAccessTokenHashFromMMDS reads the access token hash from MMDS.
|
||||
// This is used to validate that /init requests come from the orchestrator.
|
||||
func GetAccessTokenHashFromMMDS(ctx context.Context) (string, error) {
|
||||
token, err := getMMDSToken(ctx, mmdsAccessTokenClient)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get MMDS token: %w", err)
|
||||
}
|
||||
|
||||
opts, err := getMMDSOpts(ctx, mmdsAccessTokenClient, token)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get MMDS opts: %w", err)
|
||||
}
|
||||
|
||||
return opts.AccessTokenHash, nil
|
||||
}
|
||||
|
||||
func PollForMMDSOpts(ctx context.Context, mmdsChan chan<- *MMDSOpts, envVars *utils.Map[string, string]) {
|
||||
httpClient := &http.Client{}
|
||||
defer httpClient.CloseIdleConnections()
|
||||
|
||||
ticker := time.NewTicker(50 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
fmt.Fprintf(os.Stderr, "context cancelled while waiting for mmds opts")
|
||||
|
||||
return
|
||||
case <-ticker.C:
|
||||
token, err := getMMDSToken(ctx, httpClient)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error getting mmds token: %v\n", err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
mmdsOpts, err := getMMDSOpts(ctx, httpClient, token)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error getting mmds opts: %v\n", err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
envVars.Store("WRENN_SANDBOX_ID", mmdsOpts.SandboxID)
|
||||
envVars.Store("WRENN_TEMPLATE_ID", mmdsOpts.TemplateID)
|
||||
|
||||
if err := os.WriteFile(filepath.Join(WrennRunDir, ".WRENN_SANDBOX_ID"), []byte(mmdsOpts.SandboxID), 0o666); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error writing sandbox ID file: %v\n", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(WrennRunDir, ".WRENN_TEMPLATE_ID"), []byte(mmdsOpts.TemplateID), 0o666); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error writing template ID file: %v\n", err)
|
||||
}
|
||||
|
||||
if mmdsOpts.LogsCollectorAddress != "" {
|
||||
mmdsChan <- mmdsOpts
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,49 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package logs
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultMaxBufferSize = 2 << 15
|
||||
defaultTimeout = 2 * time.Second
|
||||
)
|
||||
|
||||
func LogBufferedDataEvents(dataCh <-chan []byte, logger *zerolog.Logger, eventType string) {
|
||||
timer := time.NewTicker(defaultTimeout)
|
||||
defer timer.Stop()
|
||||
|
||||
var buffer []byte
|
||||
defer func() {
|
||||
if len(buffer) > 0 {
|
||||
logger.Info().Str(eventType, string(buffer)).Msg("Streaming process event (flush)")
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-timer.C:
|
||||
if len(buffer) > 0 {
|
||||
logger.Info().Str(eventType, string(buffer)).Msg("Streaming process event")
|
||||
buffer = nil
|
||||
}
|
||||
case data, ok := <-dataCh:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
buffer = append(buffer, data...)
|
||||
|
||||
if len(buffer) >= defaultMaxBufferSize {
|
||||
logger.Info().Str(eventType, string(buffer)).Msg("Streaming process event")
|
||||
buffer = nil
|
||||
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,174 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package exporter
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/host"
|
||||
)
|
||||
|
||||
const ExporterTimeout = 10 * time.Second
|
||||
|
||||
type HTTPExporter struct {
|
||||
client http.Client
|
||||
logs [][]byte
|
||||
isNotFC bool
|
||||
mmdsOpts *host.MMDSOpts
|
||||
|
||||
// Concurrency coordination
|
||||
triggers chan struct{}
|
||||
logLock sync.RWMutex
|
||||
mmdsLock sync.RWMutex
|
||||
startOnce sync.Once
|
||||
}
|
||||
|
||||
func NewHTTPLogsExporter(ctx context.Context, isNotFC bool, mmdsChan <-chan *host.MMDSOpts) *HTTPExporter {
|
||||
exporter := &HTTPExporter{
|
||||
client: http.Client{
|
||||
Timeout: ExporterTimeout,
|
||||
},
|
||||
triggers: make(chan struct{}, 1),
|
||||
isNotFC: isNotFC,
|
||||
startOnce: sync.Once{},
|
||||
mmdsOpts: &host.MMDSOpts{
|
||||
SandboxID: "unknown",
|
||||
TemplateID: "unknown",
|
||||
LogsCollectorAddress: "",
|
||||
},
|
||||
}
|
||||
|
||||
go exporter.listenForMMDSOptsAndStart(ctx, mmdsChan)
|
||||
|
||||
return exporter
|
||||
}
|
||||
|
||||
func (w *HTTPExporter) sendInstanceLogs(ctx context.Context, logs []byte, address string) error {
|
||||
if address == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodPost, address, bytes.NewBuffer(logs))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
response, err := w.client.Do(request)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func printLog(logs []byte) {
|
||||
fmt.Fprintf(os.Stdout, "%v", string(logs))
|
||||
}
|
||||
|
||||
func (w *HTTPExporter) listenForMMDSOptsAndStart(ctx context.Context, mmdsChan <-chan *host.MMDSOpts) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case mmdsOpts, ok := <-mmdsChan:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
w.mmdsLock.Lock()
|
||||
w.mmdsOpts.Update(mmdsOpts.SandboxID, mmdsOpts.TemplateID, mmdsOpts.LogsCollectorAddress)
|
||||
w.mmdsLock.Unlock()
|
||||
|
||||
w.startOnce.Do(func() {
|
||||
go w.start(ctx)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *HTTPExporter) start(ctx context.Context) {
|
||||
for range w.triggers {
|
||||
logs := w.getAllLogs()
|
||||
|
||||
if len(logs) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if w.isNotFC {
|
||||
for _, log := range logs {
|
||||
fmt.Fprintf(os.Stdout, "%v", string(log))
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
for _, logLine := range logs {
|
||||
w.mmdsLock.RLock()
|
||||
logLineWithOpts, err := w.mmdsOpts.AddOptsToJSON(logLine)
|
||||
w.mmdsLock.RUnlock()
|
||||
if err != nil {
|
||||
log.Printf("error adding instance logging options (%+v) to JSON (%+v) with logs : %v\n", w.mmdsOpts, logLine, err)
|
||||
|
||||
printLog(logLine)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
err = w.sendInstanceLogs(ctx, logLineWithOpts, w.mmdsOpts.LogsCollectorAddress)
|
||||
if err != nil {
|
||||
log.Printf("error sending instance logs: %+v", err)
|
||||
|
||||
printLog(logLine)
|
||||
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *HTTPExporter) resumeProcessing() {
|
||||
select {
|
||||
case w.triggers <- struct{}{}:
|
||||
default:
|
||||
// Exporter processing already triggered
|
||||
// This is expected behavior if the exporter is already processing logs
|
||||
}
|
||||
}
|
||||
|
||||
func (w *HTTPExporter) Write(logs []byte) (int, error) {
|
||||
logsCopy := make([]byte, len(logs))
|
||||
copy(logsCopy, logs)
|
||||
|
||||
go w.addLogs(logsCopy)
|
||||
|
||||
return len(logs), nil
|
||||
}
|
||||
|
||||
func (w *HTTPExporter) getAllLogs() [][]byte {
|
||||
w.logLock.Lock()
|
||||
defer w.logLock.Unlock()
|
||||
|
||||
logs := w.logs
|
||||
w.logs = nil
|
||||
|
||||
return logs
|
||||
}
|
||||
|
||||
func (w *HTTPExporter) addLogs(logs []byte) {
|
||||
w.logLock.Lock()
|
||||
defer w.logLock.Unlock()
|
||||
|
||||
w.logs = append(w.logs, logs)
|
||||
|
||||
w.resumeProcessing()
|
||||
}
|
||||
@ -1,174 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package logs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
type OperationID string
|
||||
|
||||
const (
|
||||
OperationIDKey OperationID = "operation_id"
|
||||
DefaultHTTPMethod string = "POST"
|
||||
)
|
||||
|
||||
var operationID = atomic.Int32{}
|
||||
|
||||
func AssignOperationID() string {
|
||||
id := operationID.Add(1)
|
||||
|
||||
return strconv.Itoa(int(id))
|
||||
}
|
||||
|
||||
func AddRequestIDToContext(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, OperationIDKey, AssignOperationID())
|
||||
}
|
||||
|
||||
func formatMethod(method string) string {
|
||||
parts := strings.Split(method, ".")
|
||||
if len(parts) < 2 {
|
||||
return method
|
||||
}
|
||||
|
||||
split := strings.Split(parts[1], "/")
|
||||
if len(split) < 2 {
|
||||
return method
|
||||
}
|
||||
|
||||
servicePart := split[0]
|
||||
servicePart = strings.ToUpper(servicePart[:1]) + servicePart[1:]
|
||||
|
||||
methodPart := split[1]
|
||||
methodPart = strings.ToLower(methodPart[:1]) + methodPart[1:]
|
||||
|
||||
return fmt.Sprintf("%s %s", servicePart, methodPart)
|
||||
}
|
||||
|
||||
func NewUnaryLogInterceptor(logger *zerolog.Logger) connect.UnaryInterceptorFunc {
|
||||
interceptor := func(next connect.UnaryFunc) connect.UnaryFunc {
|
||||
return connect.UnaryFunc(func(
|
||||
ctx context.Context,
|
||||
req connect.AnyRequest,
|
||||
) (connect.AnyResponse, error) {
|
||||
ctx = AddRequestIDToContext(ctx)
|
||||
|
||||
res, err := next(ctx, req)
|
||||
|
||||
l := logger.
|
||||
Err(err).
|
||||
Str("method", DefaultHTTPMethod+" "+req.Spec().Procedure).
|
||||
Str(string(OperationIDKey), ctx.Value(OperationIDKey).(string))
|
||||
|
||||
if err != nil {
|
||||
l = l.Int("error_code", int(connect.CodeOf(err)))
|
||||
}
|
||||
|
||||
if req != nil {
|
||||
l = l.Interface("request", req.Any())
|
||||
}
|
||||
|
||||
if res != nil && err == nil {
|
||||
l = l.Interface("response", res.Any())
|
||||
}
|
||||
|
||||
if res == nil && err == nil {
|
||||
l = l.Interface("response", nil)
|
||||
}
|
||||
|
||||
l.Msg(formatMethod(req.Spec().Procedure))
|
||||
|
||||
return res, err
|
||||
})
|
||||
}
|
||||
|
||||
return connect.UnaryInterceptorFunc(interceptor)
|
||||
}
|
||||
|
||||
func LogServerStreamWithoutEvents[T any, R any](
|
||||
ctx context.Context,
|
||||
logger *zerolog.Logger,
|
||||
req *connect.Request[R],
|
||||
stream *connect.ServerStream[T],
|
||||
handler func(ctx context.Context, req *connect.Request[R], stream *connect.ServerStream[T]) error,
|
||||
) error {
|
||||
ctx = AddRequestIDToContext(ctx)
|
||||
|
||||
l := logger.Debug().
|
||||
Str("method", DefaultHTTPMethod+" "+req.Spec().Procedure).
|
||||
Str(string(OperationIDKey), ctx.Value(OperationIDKey).(string))
|
||||
|
||||
if req != nil {
|
||||
l = l.Interface("request", req.Any())
|
||||
}
|
||||
|
||||
l.Msg(fmt.Sprintf("%s (server stream start)", formatMethod(req.Spec().Procedure)))
|
||||
|
||||
err := handler(ctx, req, stream)
|
||||
|
||||
logEvent := getErrDebugLogEvent(logger, err).
|
||||
Str("method", DefaultHTTPMethod+" "+req.Spec().Procedure).
|
||||
Str(string(OperationIDKey), ctx.Value(OperationIDKey).(string))
|
||||
|
||||
if err != nil {
|
||||
logEvent = logEvent.Int("error_code", int(connect.CodeOf(err)))
|
||||
} else {
|
||||
logEvent = logEvent.Interface("response", nil)
|
||||
}
|
||||
|
||||
logEvent.Msg(fmt.Sprintf("%s (server stream end)", formatMethod(req.Spec().Procedure)))
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func LogClientStreamWithoutEvents[T any, R any](
|
||||
ctx context.Context,
|
||||
logger *zerolog.Logger,
|
||||
stream *connect.ClientStream[T],
|
||||
handler func(ctx context.Context, stream *connect.ClientStream[T]) (*connect.Response[R], error),
|
||||
) (*connect.Response[R], error) {
|
||||
ctx = AddRequestIDToContext(ctx)
|
||||
|
||||
logger.Debug().
|
||||
Str("method", DefaultHTTPMethod+" "+stream.Spec().Procedure).
|
||||
Str(string(OperationIDKey), ctx.Value(OperationIDKey).(string)).
|
||||
Msg(fmt.Sprintf("%s (client stream start)", formatMethod(stream.Spec().Procedure)))
|
||||
|
||||
res, err := handler(ctx, stream)
|
||||
|
||||
logEvent := getErrDebugLogEvent(logger, err).
|
||||
Str("method", DefaultHTTPMethod+" "+stream.Spec().Procedure).
|
||||
Str(string(OperationIDKey), ctx.Value(OperationIDKey).(string))
|
||||
|
||||
if err != nil {
|
||||
logEvent = logEvent.Int("error_code", int(connect.CodeOf(err)))
|
||||
}
|
||||
|
||||
if res != nil && err == nil {
|
||||
logEvent = logEvent.Interface("response", res.Any())
|
||||
}
|
||||
|
||||
if res == nil && err == nil {
|
||||
logEvent = logEvent.Interface("response", nil)
|
||||
}
|
||||
|
||||
logEvent.Msg(fmt.Sprintf("%s (client stream end)", formatMethod(stream.Spec().Procedure)))
|
||||
|
||||
return res, err
|
||||
}
|
||||
|
||||
// Return logger with error level if err is not nil, otherwise return logger with debug level
|
||||
func getErrDebugLogEvent(logger *zerolog.Logger, err error) *zerolog.Event {
|
||||
if err != nil {
|
||||
return logger.Error().Err(err) //nolint:zerologlint // this builds an event, it is not expected to return it
|
||||
}
|
||||
|
||||
return logger.Debug() //nolint:zerologlint // this builds an event, it is not expected to return it
|
||||
}
|
||||
@ -1,37 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package logs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/host"
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/logs/exporter"
|
||||
)
|
||||
|
||||
func NewLogger(ctx context.Context, isNotFC bool, mmdsChan <-chan *host.MMDSOpts) *zerolog.Logger {
|
||||
zerolog.TimestampFieldName = "timestamp"
|
||||
zerolog.TimeFieldFormat = time.RFC3339Nano
|
||||
|
||||
exporters := []io.Writer{}
|
||||
|
||||
if isNotFC {
|
||||
exporters = append(exporters, os.Stdout)
|
||||
} else {
|
||||
exporters = append(exporters, exporter.NewHTTPLogsExporter(ctx, isNotFC, mmdsChan), os.Stdout)
|
||||
}
|
||||
|
||||
l := zerolog.
|
||||
New(io.MultiWriter(exporters...)).
|
||||
With().
|
||||
Timestamp().
|
||||
Logger().
|
||||
Level(zerolog.DebugLevel)
|
||||
|
||||
return &l
|
||||
}
|
||||
@ -1,49 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package permissions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os/user"
|
||||
|
||||
"connectrpc.com/authn"
|
||||
"connectrpc.com/connect"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
||||
)
|
||||
|
||||
func AuthenticateUsername(_ context.Context, req authn.Request) (any, error) {
|
||||
username, _, ok := req.BasicAuth()
|
||||
if !ok {
|
||||
// When no username is provided, ignore the authentication method (not all endpoints require it)
|
||||
// Missing user is then handled in the GetAuthUser function
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
u, err := GetUser(username)
|
||||
if err != nil {
|
||||
return nil, authn.Errorf("invalid username: '%s'", username)
|
||||
}
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func GetAuthUser(ctx context.Context, defaultUser string) (*user.User, error) {
|
||||
u, ok := authn.GetInfo(ctx).(*user.User)
|
||||
if !ok {
|
||||
username, err := execcontext.ResolveDefaultUsername(nil, defaultUser)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeUnauthenticated, fmt.Errorf("no user specified"))
|
||||
}
|
||||
|
||||
u, err := GetUser(username)
|
||||
if err != nil {
|
||||
return nil, authn.Errorf("invalid default user: '%s'", username)
|
||||
}
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
||||
return u, nil
|
||||
}
|
||||
@ -1,31 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package permissions
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
)
|
||||
|
||||
const defaultKeepAliveInterval = 90 * time.Second
|
||||
|
||||
func GetKeepAliveTicker[T any](req *connect.Request[T]) (*time.Ticker, func()) {
|
||||
keepAliveIntervalHeader := req.Header().Get("Keepalive-Ping-Interval")
|
||||
|
||||
var interval time.Duration
|
||||
|
||||
keepAliveIntervalInt, err := strconv.Atoi(keepAliveIntervalHeader)
|
||||
if err != nil {
|
||||
interval = defaultKeepAliveInterval
|
||||
} else {
|
||||
interval = time.Duration(keepAliveIntervalInt) * time.Second
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(interval)
|
||||
|
||||
return ticker, func() {
|
||||
ticker.Reset(interval)
|
||||
}
|
||||
}
|
||||
@ -1,98 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package permissions
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
||||
)
|
||||
|
||||
func expand(path, homedir string) (string, error) {
|
||||
if len(path) == 0 {
|
||||
return path, nil
|
||||
}
|
||||
|
||||
if path[0] != '~' {
|
||||
return path, nil
|
||||
}
|
||||
|
||||
if len(path) > 1 && path[1] != '/' && path[1] != '\\' {
|
||||
return "", errors.New("cannot expand user-specific home dir")
|
||||
}
|
||||
|
||||
return filepath.Join(homedir, path[1:]), nil
|
||||
}
|
||||
|
||||
func ExpandAndResolve(path string, user *user.User, defaultPath *string) (string, error) {
|
||||
path = execcontext.ResolveDefaultWorkdir(path, defaultPath)
|
||||
|
||||
path, err := expand(path, user.HomeDir)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to expand path '%s' for user '%s': %w", path, user.Username, err)
|
||||
}
|
||||
|
||||
if filepath.IsAbs(path) {
|
||||
return path, nil
|
||||
}
|
||||
|
||||
// The filepath.Abs can correctly resolve paths like /home/user/../file
|
||||
path = filepath.Join(user.HomeDir, path)
|
||||
|
||||
abs, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to resolve path '%s' for user '%s' with home dir '%s': %w", path, user.Username, user.HomeDir, err)
|
||||
}
|
||||
|
||||
return abs, nil
|
||||
}
|
||||
|
||||
func getSubpaths(path string) (subpaths []string) {
|
||||
for {
|
||||
subpaths = append(subpaths, path)
|
||||
|
||||
path = filepath.Dir(path)
|
||||
if path == "/" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
slices.Reverse(subpaths)
|
||||
|
||||
return subpaths
|
||||
}
|
||||
|
||||
func EnsureDirs(path string, uid, gid int) error {
|
||||
subpaths := getSubpaths(path)
|
||||
for _, subpath := range subpaths {
|
||||
info, err := os.Stat(subpath)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to stat directory: %w", err)
|
||||
}
|
||||
|
||||
if err != nil && os.IsNotExist(err) {
|
||||
err = os.Mkdir(subpath, 0o755)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create directory: %w", err)
|
||||
}
|
||||
|
||||
err = os.Chown(subpath, uid, gid)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to chown directory: %w", err)
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if !info.IsDir() {
|
||||
return fmt.Errorf("path is a file: %s", subpath)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -1,46 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package permissions
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/user"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
func GetUserIdUints(u *user.User) (uid, gid uint32, err error) {
|
||||
newUID, err := strconv.ParseUint(u.Uid, 10, 32)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("error parsing uid '%s': %w", u.Uid, err)
|
||||
}
|
||||
|
||||
newGID, err := strconv.ParseUint(u.Gid, 10, 32)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("error parsing gid '%s': %w", u.Gid, err)
|
||||
}
|
||||
|
||||
return uint32(newUID), uint32(newGID), nil
|
||||
}
|
||||
|
||||
func GetUserIdInts(u *user.User) (uid, gid int, err error) {
|
||||
newUID, err := strconv.ParseInt(u.Uid, 10, strconv.IntSize)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("error parsing uid '%s': %w", u.Uid, err)
|
||||
}
|
||||
|
||||
newGID, err := strconv.ParseInt(u.Gid, 10, strconv.IntSize)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("error parsing gid '%s': %w", u.Gid, err)
|
||||
}
|
||||
|
||||
return int(newUID), int(newGID), nil
|
||||
}
|
||||
|
||||
func GetUser(username string) (u *user.User, err error) {
|
||||
u, err = user.Lookup(username)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error looking up user '%s': %w", username, err)
|
||||
}
|
||||
|
||||
return u, nil
|
||||
}
|
||||
@ -1,165 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package port
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// ConnStat represents a single TCP connection read from /proc/net/tcp(6).
|
||||
// It contains only the fields needed by the port scanner and forwarder.
|
||||
type ConnStat struct {
|
||||
LocalIP string
|
||||
LocalPort uint32
|
||||
Status string
|
||||
Family uint32 // syscall.AF_INET or syscall.AF_INET6
|
||||
Inode uint64 // socket inode, unique per connection
|
||||
}
|
||||
|
||||
// tcpStates maps the hex state values from /proc/net/tcp to string names
|
||||
// matching the gopsutil convention used by ScannerFilter.
|
||||
var tcpStates = map[string]string{
|
||||
"01": "ESTABLISHED",
|
||||
"02": "SYN_SENT",
|
||||
"03": "SYN_RECV",
|
||||
"04": "FIN_WAIT1",
|
||||
"05": "FIN_WAIT2",
|
||||
"06": "TIME_WAIT",
|
||||
"07": "CLOSE",
|
||||
"08": "CLOSE_WAIT",
|
||||
"09": "LAST_ACK",
|
||||
"0A": "LISTEN",
|
||||
"0B": "CLOSING",
|
||||
}
|
||||
|
||||
// ReadTCPConnections reads /proc/net/tcp and /proc/net/tcp6 and returns
|
||||
// all TCP connections. This avoids the /proc/{pid}/fd walk that gopsutil
|
||||
// performs, which is unsafe across Firecracker snapshot/restore boundaries.
|
||||
func ReadTCPConnections() ([]ConnStat, error) {
|
||||
var conns []ConnStat
|
||||
|
||||
tcp4, err := parseProcNetTCP("/proc/net/tcp", syscall.AF_INET)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse /proc/net/tcp: %w", err)
|
||||
}
|
||||
conns = append(conns, tcp4...)
|
||||
|
||||
tcp6, err := parseProcNetTCP("/proc/net/tcp6", syscall.AF_INET6)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse /proc/net/tcp6: %w", err)
|
||||
}
|
||||
conns = append(conns, tcp6...)
|
||||
|
||||
return conns, nil
|
||||
}
|
||||
|
||||
// parseProcNetTCP reads a single /proc/net/tcp or /proc/net/tcp6 file.
|
||||
//
|
||||
// Format (fields are whitespace-separated):
|
||||
//
|
||||
// sl local_address rem_address st tx_queue:rx_queue tr:tm->when retrnsmt uid timeout inode
|
||||
// 0: 0100007F:1F90 00000000:0000 0A 00000000:00000000 00:00000000 00000000 1000 0 12345
|
||||
func parseProcNetTCP(path string, family uint32) ([]ConnStat, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var conns []ConnStat
|
||||
scanner := bufio.NewScanner(f)
|
||||
|
||||
// Skip header line.
|
||||
scanner.Scan()
|
||||
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) < 10 {
|
||||
continue
|
||||
}
|
||||
|
||||
// fields[1] = local_address (hex_ip:hex_port)
|
||||
ip, port, err := parseHexAddr(fields[1], family)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// fields[3] = state (hex)
|
||||
state, ok := tcpStates[fields[3]]
|
||||
if !ok {
|
||||
state = "UNKNOWN"
|
||||
}
|
||||
|
||||
// fields[9] = inode
|
||||
inode, err := strconv.ParseUint(fields[9], 10, 64)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
conns = append(conns, ConnStat{
|
||||
LocalIP: ip,
|
||||
LocalPort: port,
|
||||
Status: state,
|
||||
Family: family,
|
||||
Inode: inode,
|
||||
})
|
||||
}
|
||||
|
||||
return conns, scanner.Err()
|
||||
}
|
||||
|
||||
// parseHexAddr parses "HEXIP:HEXPORT" from /proc/net/tcp.
|
||||
// IPv4 addresses are 8 hex chars (4 bytes, little-endian per 32-bit word).
|
||||
// IPv6 addresses are 32 hex chars (16 bytes, little-endian per 32-bit word).
|
||||
func parseHexAddr(s string, family uint32) (string, uint32, error) {
|
||||
parts := strings.SplitN(s, ":", 2)
|
||||
if len(parts) != 2 {
|
||||
return "", 0, fmt.Errorf("invalid address: %s", s)
|
||||
}
|
||||
|
||||
port64, err := strconv.ParseUint(parts[1], 16, 32)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
|
||||
ipHex := parts[0]
|
||||
ipBytes, err := hex.DecodeString(ipHex)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
|
||||
var ip net.IP
|
||||
if family == syscall.AF_INET {
|
||||
if len(ipBytes) != 4 {
|
||||
return "", 0, fmt.Errorf("invalid IPv4 length: %d", len(ipBytes))
|
||||
}
|
||||
// /proc/net/tcp stores IPv4 as a single little-endian 32-bit word.
|
||||
ip = net.IPv4(ipBytes[3], ipBytes[2], ipBytes[1], ipBytes[0])
|
||||
} else {
|
||||
if len(ipBytes) != 16 {
|
||||
return "", 0, fmt.Errorf("invalid IPv6 length: %d", len(ipBytes))
|
||||
}
|
||||
// /proc/net/tcp6 stores IPv6 as four little-endian 32-bit words.
|
||||
ip = make(net.IP, 16)
|
||||
for i := 0; i < 4; i++ {
|
||||
ip[i*4+0] = ipBytes[i*4+3]
|
||||
ip[i*4+1] = ipBytes[i*4+2]
|
||||
ip[i*4+2] = ipBytes[i*4+1]
|
||||
ip[i*4+3] = ipBytes[i*4+0]
|
||||
}
|
||||
}
|
||||
|
||||
return ip.String(), uint32(port64), nil
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user