Compare commits
26 Commits
main
...
866f3ac012
| Author | SHA1 | Date | |
|---|---|---|---|
| 866f3ac012 | |||
| 2c66959b92 | |||
| e4ead076e3 | |||
| 1d59b50e49 | |||
| f38d5812d1 | |||
| 931b7d54b3 | |||
| 477d4f8cf6 | |||
| 88246fac2b | |||
| 1846168736 | |||
| c92cc29b88 | |||
| 712b77b01c | |||
| 80a99eec87 | |||
| a0d635ae5e | |||
| 63e9132d38 | |||
| 778894b488 | |||
| a1bd439c75 | |||
| 9b94df7f56 | |||
| 0c245e9e1c | |||
| b4d8edb65b | |||
| ec3360d9ad | |||
| d7b25b0891 | |||
| 34c89e814d | |||
| 6f0c365d44 | |||
| c31ce90306 | |||
| 7753938044 | |||
| a3898d68fb |
17
.env.example
17
.env.example
@ -1,16 +1,18 @@
|
|||||||
# Database
|
# Database
|
||||||
DATABASE_URL=postgres://wrenn:wrenn@localhost:5432/wrenn?sslmode=disable
|
DATABASE_URL=postgres://wrenn:wrenn@localhost:5432/wrenn?sslmode=disable
|
||||||
|
|
||||||
|
# Redis
|
||||||
|
REDIS_URL=redis://localhost:6379/0
|
||||||
|
|
||||||
# Control Plane
|
# Control Plane
|
||||||
CP_LISTEN_ADDR=:8000
|
CP_LISTEN_ADDR=:8000
|
||||||
CP_HOST_AGENT_ADDR=localhost:50051
|
CP_HOST_AGENT_ADDR=localhost:50051
|
||||||
|
|
||||||
# Host Agent
|
# Host Agent
|
||||||
AGENT_LISTEN_ADDR=:50051
|
AGENT_LISTEN_ADDR=:50051
|
||||||
AGENT_KERNEL_PATH=/var/lib/wrenn/kernels/vmlinux
|
AGENT_FILES_ROOTDIR=/var/lib/wrenn
|
||||||
AGENT_IMAGES_PATH=/var/lib/wrenn/images
|
|
||||||
AGENT_SANDBOXES_PATH=/var/lib/wrenn/sandboxes
|
|
||||||
AGENT_HOST_INTERFACE=eth0
|
AGENT_HOST_INTERFACE=eth0
|
||||||
|
AGENT_CP_URL=http://localhost:8000
|
||||||
|
|
||||||
# Lago (billing — external service)
|
# Lago (billing — external service)
|
||||||
LAGO_API_URL=http://localhost:3000
|
LAGO_API_URL=http://localhost:3000
|
||||||
@ -22,3 +24,12 @@ S3_REGION=fsn1
|
|||||||
S3_ENDPOINT=https://fsn1.your-objectstorage.com
|
S3_ENDPOINT=https://fsn1.your-objectstorage.com
|
||||||
AWS_ACCESS_KEY_ID=
|
AWS_ACCESS_KEY_ID=
|
||||||
AWS_SECRET_ACCESS_KEY=
|
AWS_SECRET_ACCESS_KEY=
|
||||||
|
|
||||||
|
# Auth
|
||||||
|
JWT_SECRET=
|
||||||
|
|
||||||
|
# OAuth
|
||||||
|
OAUTH_GITHUB_CLIENT_ID=
|
||||||
|
OAUTH_GITHUB_CLIENT_SECRET=
|
||||||
|
OAUTH_REDIRECT_URL=https://app.wrenn.dev
|
||||||
|
CP_PUBLIC_URL=https://api.wrenn.dev
|
||||||
|
|||||||
97
LICENSE
Normal file
97
LICENSE
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
Wrenn Sandbox License
|
||||||
|
|
||||||
|
Business Source License 1.1
|
||||||
|
|
||||||
|
Copyright (c) 2026 M/S Omukk, Bangladesh
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
Licensor
|
||||||
|
|
||||||
|
M/S Omukk, Bangladesh
|
||||||
|
|
||||||
|
Contact: [contact@omukk.dev](mailto:contact@omukk.dev)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
Licensed Work
|
||||||
|
|
||||||
|
The Licensed Work is the software project known as "Wrenn Sandbox", including all source code and associated files in this repository, except the directory `envd/`, which is licensed separately under the Apache License Version 2.0.
|
||||||
|
|
||||||
|
Initial development of the Licensed Work began in March 2026.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
Change Date
|
||||||
|
|
||||||
|
January 1, 2030
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
Change License
|
||||||
|
|
||||||
|
On the Change Date, the Licensed Work will automatically become available under the terms of the GNU General Public License, Version 3 (GPL-3.0).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
Additional Use Grant (SaaS Restriction)
|
||||||
|
|
||||||
|
The Licensor grants you the right to copy, modify, create derivative works, redistribute, and make non-production use of the Licensed Work, provided that you comply with the limitations of this License.
|
||||||
|
|
||||||
|
You may:
|
||||||
|
|
||||||
|
* Use the software for personal use
|
||||||
|
* Use the software internally within your organization
|
||||||
|
* Modify the source code
|
||||||
|
* Experiment, test, and evaluate the software
|
||||||
|
* Distribute unmodified copies of the source code for evaluation
|
||||||
|
|
||||||
|
You may not:
|
||||||
|
|
||||||
|
Provide the Licensed Work to third parties as a managed service, hosted service, software-as-a-service (SaaS), platform service, or any similar commercial offering where the primary value of the service derives from the Licensed Work.
|
||||||
|
|
||||||
|
You may not sell the Licensed Work or offer paid services primarily based on the Licensed Work without a commercial license from M/S Omukk.
|
||||||
|
|
||||||
|
Commercial licenses may be obtained by contacting:
|
||||||
|
|
||||||
|
[contact@omukk.dev](mailto:contact@omukk.dev)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
Contributions
|
||||||
|
|
||||||
|
Unless otherwise stated, any Contribution intentionally submitted for inclusion in the Licensed Work shall be licensed under the terms of this Business Source License 1.1.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
Business Source License Terms
|
||||||
|
|
||||||
|
Use of the Licensed Work is governed by the Business Source License included in this file.
|
||||||
|
|
||||||
|
The Business Source License is not an Open Source license. However, the Licensed Work will automatically become available under the Change License on the Change Date.
|
||||||
|
|
||||||
|
Licensor grants you a non-exclusive, worldwide, royalty-free license to use, copy, modify, create derivative works, redistribute, and make non-production use of the Licensed Work, provided that you comply with the limitations stated in this License.
|
||||||
|
|
||||||
|
All copies of the Licensed Work must include this License file.
|
||||||
|
|
||||||
|
Any use of the Licensed Work in violation of this License will automatically terminate your rights under this License.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
Disclaimer of Warranty
|
||||||
|
|
||||||
|
THE LICENSED WORK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
Limitation of Liability
|
||||||
|
|
||||||
|
IN NO EVENT SHALL THE LICENSOR OR CONTRIBUTORS BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY ARISING FROM THE USE OF THE LICENSED WORK.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
Third-Party Components
|
||||||
|
|
||||||
|
Portions of this project include software licensed under separate open-source licenses.
|
||||||
|
|
||||||
|
See the NOTICE file and THIRD_PARTY_LICENSES directory for details.
|
||||||
20
Makefile
20
Makefile
@ -21,17 +21,17 @@ build-agent:
|
|||||||
|
|
||||||
build-envd:
|
build-envd:
|
||||||
cd $(ENVD_DIR) && CGO_ENABLED=0 GOOS=linux GOARCH=amd64 \
|
cd $(ENVD_DIR) && CGO_ENABLED=0 GOOS=linux GOARCH=amd64 \
|
||||||
go build -ldflags="$(LDFLAGS)" -o ../$(GOBIN)/envd .
|
go build -ldflags="$(LDFLAGS)" -o $(GOBIN)/envd .
|
||||||
@file $(GOBIN)/envd | grep -q "statically linked" || \
|
@file $(GOBIN)/envd | grep -q "statically linked" || \
|
||||||
(echo "ERROR: envd is not statically linked!" && exit 1)
|
(echo "ERROR: envd is not statically linked!" && exit 1)
|
||||||
|
|
||||||
# ═══════════════════════════════════════════════════
|
# ═══════════════════════════════════════════════════
|
||||||
# Development
|
# Development
|
||||||
# ═══════════════════════════════════════════════════
|
# ═══════════════════════════════════════════════════
|
||||||
.PHONY: dev dev-cp dev-agent dev-envd dev-infra dev-down dev-seed
|
.PHONY: dev dev-cp dev-agent dev-envd dev-infra dev-down
|
||||||
|
|
||||||
## One command to start everything for local dev
|
## One command to start everything for local dev
|
||||||
dev: dev-infra migrate-up dev-seed dev-cp
|
dev: dev-infra migrate-up dev-cp
|
||||||
|
|
||||||
dev-infra:
|
dev-infra:
|
||||||
docker compose -f deploy/docker-compose.dev.yml up -d
|
docker compose -f deploy/docker-compose.dev.yml up -d
|
||||||
@ -52,8 +52,6 @@ dev-agent:
|
|||||||
dev-envd:
|
dev-envd:
|
||||||
cd $(ENVD_DIR) && go run . --debug --listen-tcp :3002
|
cd $(ENVD_DIR) && go run . --debug --listen-tcp :3002
|
||||||
|
|
||||||
dev-seed:
|
|
||||||
go run ./scripts/seed.go
|
|
||||||
|
|
||||||
# ═══════════════════════════════════════════════════
|
# ═══════════════════════════════════════════════════
|
||||||
# Database (goose)
|
# Database (goose)
|
||||||
@ -84,16 +82,12 @@ migrate-reset:
|
|||||||
generate: proto sqlc
|
generate: proto sqlc
|
||||||
|
|
||||||
proto:
|
proto:
|
||||||
protoc --go_out=. --go_opt=paths=source_relative \
|
cd proto/envd && buf generate
|
||||||
--go-grpc_out=. --go-grpc_opt=paths=source_relative \
|
cd proto/hostagent && buf generate
|
||||||
proto/hostagent/hostagent.proto
|
cd $(ENVD_DIR)/spec && buf generate
|
||||||
protoc --go_out=. --go_opt=paths=source_relative \
|
|
||||||
--go-grpc_out=. --go-grpc_opt=paths=source_relative \
|
|
||||||
proto/envd/process.proto proto/envd/filesystem.proto
|
|
||||||
|
|
||||||
sqlc:
|
sqlc:
|
||||||
@if command -v sqlc > /dev/null; then sqlc generate; \
|
sqlc generate
|
||||||
else echo "sqlc not installed, skipping"; fi
|
|
||||||
|
|
||||||
# ═══════════════════════════════════════════════════
|
# ═══════════════════════════════════════════════════
|
||||||
# Quality & Testing
|
# Quality & Testing
|
||||||
|
|||||||
19
NOTICE
Normal file
19
NOTICE
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
Wrenn Sandbox
|
||||||
|
Copyright (c) 2026 M/S Omukk, Bangladesh
|
||||||
|
|
||||||
|
This project includes software derived from the following project:
|
||||||
|
|
||||||
|
Project: e2b infra
|
||||||
|
Repository: https://github.com/e2b-dev/infra
|
||||||
|
|
||||||
|
The following files and directories in this repository contain code derived from the above project and are licensed under the Apache License Version 2.0:
|
||||||
|
|
||||||
|
- envd/
|
||||||
|
- proto/envd/*.proto
|
||||||
|
- internal/snapshot/
|
||||||
|
- internal/uffd/
|
||||||
|
|
||||||
|
Modifications to this code were made by M/S Omukk.
|
||||||
|
|
||||||
|
Copyright (c) 2023 FoundryLabs, Inc.
|
||||||
|
Modifications Copyright (c) 2026 M/S Omukk, Bangladesh
|
||||||
277
README.md
277
README.md
@ -2,211 +2,128 @@
|
|||||||
|
|
||||||
MicroVM-based code execution platform. Firecracker VMs, not containers. Pool-based pricing, persistent sandboxes, Python/TS/Go SDKs.
|
MicroVM-based code execution platform. Firecracker VMs, not containers. Pool-based pricing, persistent sandboxes, Python/TS/Go SDKs.
|
||||||
|
|
||||||
## Stack
|
## Deployment
|
||||||
|
|
||||||
| Component | Tech |
|
### Prerequisites
|
||||||
|---|---|
|
|
||||||
| Control plane | Go, chi, pgx, goose, htmx |
|
|
||||||
| Host agent | Go, Firecracker Go SDK, vsock |
|
|
||||||
| Guest agent (envd) | Go (extracted from E2B, standalone binary) |
|
|
||||||
| Database | PostgreSQL |
|
|
||||||
| Cache | Redis |
|
|
||||||
| Billing | Lago (external) |
|
|
||||||
| Snapshot storage | S3 (Seaweedfs for dev) |
|
|
||||||
| Monitoring | Prometheus + Grafana |
|
|
||||||
| Admin UI | htmx + Go html/template |
|
|
||||||
|
|
||||||
## Architecture
|
- Linux host with `/dev/kvm` access (bare metal or nested virt)
|
||||||
|
- Firecracker binary at `/usr/local/bin/firecracker`
|
||||||
|
- PostgreSQL
|
||||||
|
- Go 1.25+
|
||||||
|
|
||||||
```
|
### Build
|
||||||
SDK → HTTPS → Control Plane → gRPC → Host Agent → vsock → envd (inside VM)
|
|
||||||
│ │
|
|
||||||
├── PostgreSQL ├── Firecracker
|
|
||||||
├── Redis ├── TAP/NAT networking
|
|
||||||
└── Lago (billing) ├── CoW rootfs clones
|
|
||||||
└── Prometheus /metrics
|
|
||||||
```
|
|
||||||
|
|
||||||
Control plane is stateless (state in Postgres + Redis). Host agent is stateful (manages VMs on the local machine). envd is a static binary baked into rootfs images — separate Go module, separate build, never imported by anything.
|
|
||||||
|
|
||||||
## Prerequisites
|
|
||||||
|
|
||||||
- Linux with `/dev/kvm` (bare metal or nested virt)
|
|
||||||
- Go 1.22+
|
|
||||||
- Docker (for dev infra)
|
|
||||||
- Firecracker + jailer installed at `/usr/local/bin/`
|
|
||||||
- `protoc` + Go plugins for proto generation
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Firecracker
|
make build # outputs to builds/
|
||||||
ARCH=$(uname -m) VERSION="v1.6.0"
|
|
||||||
curl -L "https://github.com/firecracker-microvm/firecracker/releases/download/${VERSION}/firecracker-${VERSION}-${ARCH}.tgz" | tar xz
|
|
||||||
sudo mv release-*/firecracker-* /usr/local/bin/firecracker
|
|
||||||
sudo mv release-*/jailer-* /usr/local/bin/jailer
|
|
||||||
|
|
||||||
# Go tools
|
|
||||||
go install github.com/pressly/goose/v3/cmd/goose@latest
|
|
||||||
go install google.golang.org/protobuf/cmd/protoc-gen-go@latest
|
|
||||||
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@latest
|
|
||||||
go install github.com/air-verse/air@latest
|
|
||||||
go install github.com/fullstorydev/grpcurl/cmd/grpcurl@latest
|
|
||||||
|
|
||||||
# KVM
|
|
||||||
ls /dev/kvm && sudo setfacl -m u:${USER}:rw /dev/kvm
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## Quick Start
|
Produces three binaries: `wrenn-cp` (control plane), `wrenn-agent` (host agent), `envd` (guest agent).
|
||||||
|
|
||||||
|
### Host setup
|
||||||
|
|
||||||
|
The host agent machine needs:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cp .env.example .env
|
# Kernel for guest VMs
|
||||||
make tidy
|
mkdir -p /var/lib/wrenn/kernels
|
||||||
make dev-infra # Postgres, Redis, Prometheus, Grafana
|
# Place a vmlinux kernel at /var/lib/wrenn/kernels/vmlinux
|
||||||
|
|
||||||
|
# Rootfs images
|
||||||
|
mkdir -p /var/lib/wrenn/images
|
||||||
|
# Build or place .ext4 rootfs images (e.g., minimal.ext4)
|
||||||
|
|
||||||
|
# Sandbox working directory
|
||||||
|
mkdir -p /var/lib/wrenn/sandboxes
|
||||||
|
|
||||||
|
# Snapshots directory
|
||||||
|
mkdir -p /var/lib/wrenn/snapshots
|
||||||
|
|
||||||
|
# Enable IP forwarding
|
||||||
|
sysctl -w net.ipv4.ip_forward=1
|
||||||
|
```
|
||||||
|
|
||||||
|
### Configure
|
||||||
|
|
||||||
|
Copy `.env.example` to `.env` and edit:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Required
|
||||||
|
DATABASE_URL=postgres://wrenn:wrenn@localhost:5432/wrenn?sslmode=disable
|
||||||
|
|
||||||
|
# Control plane
|
||||||
|
CP_LISTEN_ADDR=:8000
|
||||||
|
CP_HOST_AGENT_ADDR=http://localhost:50051
|
||||||
|
|
||||||
|
# Host agent
|
||||||
|
AGENT_LISTEN_ADDR=:50051
|
||||||
|
AGENT_FILES_ROOTDIR=/var/lib/wrenn
|
||||||
|
```
|
||||||
|
|
||||||
|
### Run
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Apply database migrations
|
||||||
make migrate-up
|
make migrate-up
|
||||||
make dev-seed
|
|
||||||
|
|
||||||
# Terminal 1
|
# Start control plane
|
||||||
make dev-cp # :8000
|
./builds/wrenn-cp
|
||||||
|
|
||||||
# Terminal 2
|
|
||||||
make dev-agent # :50051 (sudo)
|
|
||||||
```
|
```
|
||||||
|
|
||||||
- API: `http://localhost:8000/v1/sandboxes`
|
Control plane listens on `CP_LISTEN_ADDR` (default `:8000`).
|
||||||
- Admin: `http://localhost:8000/admin/`
|
|
||||||
- Grafana: `http://localhost:3001` (admin/admin)
|
|
||||||
- Prometheus: `http://localhost:9090`
|
|
||||||
|
|
||||||
## Layout
|
### Host registration
|
||||||
|
|
||||||
```
|
Hosts must be registered with the control plane before they can serve sandboxes.
|
||||||
cmd/
|
|
||||||
control-plane/ REST API + admin UI + gRPC client + lifecycle manager
|
|
||||||
host-agent/ gRPC server + Firecracker + networking + metrics
|
|
||||||
|
|
||||||
envd/ standalone Go module — separate go.mod, static binary
|
1. **Create a host record** (via API or admin UI):
|
||||||
extracted from e2b-dev/infra, talks gRPC over vsock
|
```bash
|
||||||
|
# As an admin (JWT auth)
|
||||||
|
curl -X POST http://localhost:8000/v1/hosts \
|
||||||
|
-H "Authorization: Bearer $JWT_TOKEN" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{"type": "regular"}'
|
||||||
|
```
|
||||||
|
This returns a `registration_token` (valid for 1 hour).
|
||||||
|
|
||||||
proto/
|
2. **Start the host agent** with the registration token and its externally-reachable address:
|
||||||
hostagent/ control plane ↔ host agent
|
```bash
|
||||||
envd/ host agent ↔ guest agent (from E2B spec/)
|
sudo AGENT_CP_URL=http://cp-host:8000 \
|
||||||
|
./builds/wrenn-agent \
|
||||||
|
--register <token-from-step-1> \
|
||||||
|
--address 10.0.1.5:50051
|
||||||
|
```
|
||||||
|
On first startup the agent sends its specs (arch, CPU, memory, disk) to the control plane, receives a long-lived host JWT, and saves it to `$AGENT_FILES_ROOTDIR/host-token`.
|
||||||
|
|
||||||
internal/
|
3. **Subsequent startups** don't need `--register` — the agent loads the saved JWT automatically:
|
||||||
api/ chi handlers
|
```bash
|
||||||
admin/ htmx + Go templates
|
sudo AGENT_CP_URL=http://cp-host:8000 \
|
||||||
auth/ API key + rate limiting
|
./builds/wrenn-agent --address 10.0.1.5:50051
|
||||||
scheduler/ SingleHost → LeastLoaded
|
```
|
||||||
lifecycle/ auto-pause, auto-hibernate, auto-destroy
|
|
||||||
vm/ Firecracker config, boot, stop, jailer
|
|
||||||
network/ TAP, NAT, IP allocator (/30 subnets)
|
|
||||||
filesystem/ base images, CoW clones (cp --reflink)
|
|
||||||
envdclient/ vsock dialer + gRPC client to envd
|
|
||||||
snapshot/ pause/resume + S3 offload
|
|
||||||
metrics/ cgroup stats + Prometheus exporter
|
|
||||||
models/ Sandbox, Host structs
|
|
||||||
config/ env + YAML loading
|
|
||||||
id/ sb-xxxxxxxx generation
|
|
||||||
|
|
||||||
db/migrations/ goose SQL (00001_initial.sql, ...)
|
4. **If registration fails** (e.g., network error after token was consumed), regenerate a token:
|
||||||
db/queries/ raw SQL or sqlc
|
```bash
|
||||||
|
curl -X POST http://localhost:8000/v1/hosts/$HOST_ID/token \
|
||||||
|
-H "Authorization: Bearer $JWT_TOKEN"
|
||||||
|
```
|
||||||
|
Then restart the agent with the new token.
|
||||||
|
|
||||||
images/templates/ rootfs build scripts (minimal, python311, node20)
|
The agent sends heartbeats to the control plane every 30 seconds. Host agent listens on `AGENT_LISTEN_ADDR` (default `:50051`).
|
||||||
sdk/ Python, TypeScript, Go client SDKs
|
|
||||||
deploy/ systemd units, ansible, docker-compose.dev.yml
|
|
||||||
```
|
|
||||||
|
|
||||||
## Commands
|
### Rootfs images
|
||||||
|
|
||||||
|
envd must be baked into every rootfs image. After building:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Dev
|
make build-envd
|
||||||
make dev # everything: infra + migrate + seed + control plane
|
bash scripts/update-debug-rootfs.sh /var/lib/wrenn/images/minimal.ext4
|
||||||
make dev-infra # just Postgres/Redis/Prometheus/Grafana
|
```
|
||||||
make dev-down # tear down
|
|
||||||
make dev-cp # control plane (hot reload with air)
|
|
||||||
make dev-agent # host agent (sudo)
|
|
||||||
make dev-envd # envd in TCP debug mode (no Firecracker)
|
|
||||||
make dev-seed # test API key + data
|
|
||||||
|
|
||||||
# Build
|
## Development
|
||||||
make build # all → bin/
|
|
||||||
make build-envd # static binary, verified
|
|
||||||
|
|
||||||
# DB
|
```bash
|
||||||
make migrate-up
|
make dev # Start PostgreSQL (Docker), run migrations, start control plane
|
||||||
make migrate-down
|
make dev-agent # Start host agent (separate terminal, sudo)
|
||||||
make migrate-create name=xxx
|
|
||||||
make migrate-reset # drop + re-apply
|
|
||||||
|
|
||||||
# Codegen
|
|
||||||
make generate # proto + sqlc
|
|
||||||
make proto
|
|
||||||
|
|
||||||
# Quality
|
|
||||||
make check # fmt + vet + lint + test
|
make check # fmt + vet + lint + test
|
||||||
make test # unit
|
|
||||||
make test-all # unit + integration
|
|
||||||
make tidy # go mod tidy (both modules)
|
|
||||||
|
|
||||||
# Images
|
|
||||||
make images # all rootfs (needs sudo + envd)
|
|
||||||
|
|
||||||
# Deploy
|
|
||||||
make setup-host # one-time KVM/networking setup
|
|
||||||
make install # binaries + systemd
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## Database
|
See `CLAUDE.md` for full architecture documentation.
|
||||||
|
|
||||||
Postgres via pgx. No ORM. Migrations via goose (plain SQL).
|
|
||||||
|
|
||||||
Tables: `sandboxes`, `hosts`, `audit_events`, `api_keys`.
|
|
||||||
|
|
||||||
States: `pending → starting → running → paused → hibernated → stopped`. Any → `error`.
|
|
||||||
|
|
||||||
## envd
|
|
||||||
|
|
||||||
From [e2b-dev/infra](https://github.com/e2b-dev/infra) (Apache 2.0). PID 1 inside every VM. Exposes ProcessService + FilesystemService over gRPC on vsock.
|
|
||||||
|
|
||||||
Own `go.mod`. Must be `CGO_ENABLED=0`. Baked into rootfs at `/usr/local/bin/envd`. Kernel args: `init=/usr/local/bin/envd`.
|
|
||||||
|
|
||||||
Host agent connects via Firecracker vsock UDS using `CONNECT <port>\n` handshake.
|
|
||||||
|
|
||||||
## Networking
|
|
||||||
|
|
||||||
Each sandbox: `/30` from `10.0.0.0/16` (~16K per host).
|
|
||||||
|
|
||||||
```
|
|
||||||
Host: tap-sb-a1b2c3d4 (10.0.0.1/30) ↔ Guest eth0 (10.0.0.2/30)
|
|
||||||
NAT: iptables MASQUERADE via host internet interface
|
|
||||||
```
|
|
||||||
|
|
||||||
## Snapshots
|
|
||||||
|
|
||||||
- **Warm pause**: Firecracker snapshot on local NVMe. Resume <1s.
|
|
||||||
- **Cold hibernate**: zstd compressed, uploaded to S3/MinIO. Resume 5-10s.
|
|
||||||
|
|
||||||
## API
|
|
||||||
|
|
||||||
```
|
|
||||||
POST /v1/sandboxes create
|
|
||||||
GET /v1/sandboxes list
|
|
||||||
GET /v1/sandboxes/{id} status
|
|
||||||
POST /v1/sandboxes/{id}/exec exec
|
|
||||||
PUT /v1/sandboxes/{id}/files upload
|
|
||||||
GET /v1/sandboxes/{id}/files/* download
|
|
||||||
POST /v1/sandboxes/{id}/pause pause
|
|
||||||
POST /v1/sandboxes/{id}/resume resume
|
|
||||||
DELETE /v1/sandboxes/{id} destroy
|
|
||||||
WS /v1/sandboxes/{id}/terminal shell
|
|
||||||
```
|
|
||||||
|
|
||||||
Auth: `X-API-Key` header. Prefix: `wrn_`.
|
|
||||||
|
|
||||||
## Phases
|
|
||||||
|
|
||||||
1. Boot VM + exec via vsock (W1)
|
|
||||||
2. Host agent + networking (W2)
|
|
||||||
3. Control plane + DB + REST (W3)
|
|
||||||
4. Admin UI / htmx (W4)
|
|
||||||
5. Pause / hibernate / resume (W5)
|
|
||||||
6. SDKs (W6)
|
|
||||||
7. Jailer, cgroups, egress, metrics (W7-8)
|
|
||||||
|
|||||||
@@ -0,0 +1,124 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
|
||||||
|
"git.omukk.dev/wrenn/sandbox/internal/api"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/internal/auth/oauth"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/internal/config"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
|
||||||
|
Level: slog.LevelDebug,
|
||||||
|
})))
|
||||||
|
|
||||||
|
cfg := config.Load()
|
||||||
|
|
||||||
|
if len(cfg.JWTSecret) < 32 {
|
||||||
|
slog.Error("JWT_SECRET must be at least 32 characters")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Database connection pool.
|
||||||
|
pool, err := pgxpool.New(ctx, cfg.DatabaseURL)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("failed to connect to database", "error", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
defer pool.Close()
|
||||||
|
|
||||||
|
if err := pool.Ping(ctx); err != nil {
|
||||||
|
slog.Error("failed to ping database", "error", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
slog.Info("connected to database")
|
||||||
|
|
||||||
|
queries := db.New(pool)
|
||||||
|
|
||||||
|
// Redis client.
|
||||||
|
redisOpts, err := redis.ParseURL(cfg.RedisURL)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("failed to parse REDIS_URL", "error", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
rdb := redis.NewClient(redisOpts)
|
||||||
|
defer rdb.Close()
|
||||||
|
|
||||||
|
if err := rdb.Ping(ctx).Err(); err != nil {
|
||||||
|
slog.Error("failed to ping redis", "error", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
slog.Info("connected to redis")
|
||||||
|
|
||||||
|
// Connect RPC client for the host agent.
|
||||||
|
agentHTTP := &http.Client{Timeout: 10 * time.Minute}
|
||||||
|
agentClient := hostagentv1connect.NewHostAgentServiceClient(
|
||||||
|
agentHTTP,
|
||||||
|
cfg.HostAgentAddr,
|
||||||
|
)
|
||||||
|
|
||||||
|
// OAuth provider registry.
|
||||||
|
oauthRegistry := oauth.NewRegistry()
|
||||||
|
if cfg.OAuthGitHubClientID != "" && cfg.OAuthGitHubClientSecret != "" {
|
||||||
|
if cfg.CPPublicURL == "" {
|
||||||
|
slog.Error("CP_PUBLIC_URL must be set when OAuth providers are configured")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
callbackURL := strings.TrimRight(cfg.CPPublicURL, "/") + "/v1/auth/oauth/github/callback"
|
||||||
|
ghProvider := oauth.NewGitHubProvider(cfg.OAuthGitHubClientID, cfg.OAuthGitHubClientSecret, callbackURL)
|
||||||
|
oauthRegistry.Register(ghProvider)
|
||||||
|
slog.Info("registered OAuth provider", "provider", "github")
|
||||||
|
}
|
||||||
|
|
||||||
|
// API server.
|
||||||
|
srv := api.New(queries, agentClient, pool, rdb, []byte(cfg.JWTSecret), oauthRegistry, cfg.OAuthRedirectURL)
|
||||||
|
|
||||||
|
// Start reconciler.
|
||||||
|
reconciler := api.NewReconciler(queries, agentClient, "default", 5*time.Second)
|
||||||
|
reconciler.Start(ctx)
|
||||||
|
|
||||||
|
httpServer := &http.Server{
|
||||||
|
Addr: cfg.ListenAddr,
|
||||||
|
Handler: srv.Handler(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Graceful shutdown on signal.
|
||||||
|
sigCh := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
go func() {
|
||||||
|
sig := <-sigCh
|
||||||
|
slog.Info("received signal, shutting down", "signal", sig)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
defer shutdownCancel()
|
||||||
|
|
||||||
|
if err := httpServer.Shutdown(shutdownCtx); err != nil {
|
||||||
|
slog.Error("http server shutdown error", "error", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
slog.Info("control plane starting", "addr", cfg.ListenAddr, "agent", cfg.HostAgentAddr)
|
||||||
|
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||||
|
slog.Error("http server error", "error", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Info("control plane stopped")
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,132 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"flag"
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"path/filepath"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.omukk.dev/wrenn/sandbox/internal/devicemapper"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/internal/hostagent"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/internal/sandbox"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
registrationToken := flag.String("register", "", "One-time registration token from the control plane")
|
||||||
|
advertiseAddr := flag.String("address", "", "Externally-reachable address (ip:port) for this host agent")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
|
||||||
|
Level: slog.LevelDebug,
|
||||||
|
})))
|
||||||
|
|
||||||
|
if os.Geteuid() != 0 {
|
||||||
|
slog.Error("host agent must run as root")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enable IP forwarding (required for NAT).
|
||||||
|
if err := os.WriteFile("/proc/sys/net/ipv4/ip_forward", []byte("1"), 0644); err != nil {
|
||||||
|
slog.Warn("failed to enable ip_forward", "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean up any stale dm-snapshot devices from a previous crash.
|
||||||
|
devicemapper.CleanupStaleDevices()
|
||||||
|
|
||||||
|
listenAddr := envOrDefault("AGENT_LISTEN_ADDR", ":50051")
|
||||||
|
rootDir := envOrDefault("AGENT_FILES_ROOTDIR", "/var/lib/wrenn")
|
||||||
|
cpURL := os.Getenv("AGENT_CP_URL")
|
||||||
|
tokenFile := filepath.Join(rootDir, "host-token")
|
||||||
|
|
||||||
|
cfg := sandbox.Config{
|
||||||
|
KernelPath: filepath.Join(rootDir, "kernels", "vmlinux"),
|
||||||
|
ImagesDir: filepath.Join(rootDir, "images"),
|
||||||
|
SandboxesDir: filepath.Join(rootDir, "sandboxes"),
|
||||||
|
SnapshotsDir: filepath.Join(rootDir, "snapshots"),
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr := sandbox.New(cfg)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
mgr.StartTTLReaper(ctx)
|
||||||
|
|
||||||
|
if *advertiseAddr == "" {
|
||||||
|
slog.Error("--address flag is required (externally-reachable ip:port)")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register with the control plane (if configured).
|
||||||
|
if cpURL != "" {
|
||||||
|
hostToken, err := hostagent.Register(ctx, hostagent.RegistrationConfig{
|
||||||
|
CPURL: cpURL,
|
||||||
|
RegistrationToken: *registrationToken,
|
||||||
|
TokenFile: tokenFile,
|
||||||
|
Address: *advertiseAddr,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("host registration failed", "error", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
hostID, err := hostagent.HostIDFromToken(hostToken)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("failed to extract host ID from token", "error", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Info("host registered", "host_id", hostID)
|
||||||
|
hostagent.StartHeartbeat(ctx, cpURL, hostID, hostToken, 30*time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := hostagent.NewServer(mgr)
|
||||||
|
path, handler := hostagentv1connect.NewHostAgentServiceHandler(srv)
|
||||||
|
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.Handle(path, handler)
|
||||||
|
|
||||||
|
httpServer := &http.Server{
|
||||||
|
Addr: listenAddr,
|
||||||
|
Handler: mux,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Graceful shutdown on signal.
|
||||||
|
sigCh := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
go func() {
|
||||||
|
sig := <-sigCh
|
||||||
|
slog.Info("received signal, shutting down", "signal", sig)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
defer shutdownCancel()
|
||||||
|
|
||||||
|
mgr.Shutdown(shutdownCtx)
|
||||||
|
|
||||||
|
if err := httpServer.Shutdown(shutdownCtx); err != nil {
|
||||||
|
slog.Error("http server shutdown error", "error", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
slog.Info("host agent starting", "addr", listenAddr)
|
||||||
|
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||||
|
slog.Error("http server error", "error", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Info("host agent stopped")
|
||||||
|
}
|
||||||
|
|
||||||
|
func envOrDefault(key, def string) string {
|
||||||
|
if v := os.Getenv(key); v != "" {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
|||||||
25
db/migrations/20260310094104_initial.sql
Normal file
25
db/migrations/20260310094104_initial.sql
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
-- +goose Up

-- Sandboxes: one row per microVM sandbox managed by the platform.
-- `status` holds the lifecycle state as free text (default 'pending');
-- IPs are stored as text and default to empty until networking is set up.
CREATE TABLE sandboxes (
    id TEXT PRIMARY KEY,
    owner_id TEXT NOT NULL DEFAULT '',
    host_id TEXT NOT NULL DEFAULT 'default',
    template TEXT NOT NULL DEFAULT 'minimal',
    status TEXT NOT NULL DEFAULT 'pending',
    vcpus INTEGER NOT NULL DEFAULT 1,
    memory_mb INTEGER NOT NULL DEFAULT 512,
    -- 0 presumably means "no timeout"; confirm against the reaper logic.
    timeout_sec INTEGER NOT NULL DEFAULT 0,
    guest_ip TEXT NOT NULL DEFAULT '',
    host_ip TEXT NOT NULL DEFAULT '',
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    started_at TIMESTAMPTZ,
    last_active_at TIMESTAMPTZ,
    last_updated TIMESTAMPTZ NOT NULL DEFAULT NOW()
);

-- Support lookups by status alone and by (host, status).
CREATE INDEX idx_sandboxes_status ON sandboxes(status);
CREATE INDEX idx_sandboxes_host_status ON sandboxes(host_id, status);

-- +goose Down

DROP TABLE sandboxes;
|
||||||
14
db/migrations/20260311224925_snapshots.sql
Normal file
14
db/migrations/20260311224925_snapshots.sql
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
-- +goose Up

-- Templates: named images a sandbox can be created from.
CREATE TABLE templates (
    name       TEXT PRIMARY KEY,
    type       TEXT NOT NULL DEFAULT 'base', -- 'base' or 'snapshot'
    vcpus      INTEGER,
    memory_mb  INTEGER,
    size_bytes BIGINT NOT NULL DEFAULT 0,
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);

-- +goose Down

DROP TABLE templates;
|
||||||
46
db/migrations/20260313210608_auth.sql
Normal file
46
db/migrations/20260313210608_auth.sql
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
-- +goose Up

CREATE TABLE users (
    id            TEXT PRIMARY KEY,
    email         TEXT NOT NULL UNIQUE,
    password_hash TEXT NOT NULL,
    created_at    TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    updated_at    TIMESTAMPTZ NOT NULL DEFAULT NOW()
);

CREATE TABLE teams (
    id         TEXT PRIMARY KEY,
    name       TEXT NOT NULL,
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);

-- Membership join table; a user's default team is flagged with is_default.
CREATE TABLE users_teams (
    user_id    TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
    team_id    TEXT NOT NULL REFERENCES teams(id) ON DELETE CASCADE,
    is_default BOOLEAN NOT NULL DEFAULT TRUE,
    role       TEXT NOT NULL DEFAULT 'owner',
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    PRIMARY KEY (team_id, user_id)
);

CREATE INDEX idx_users_teams_user ON users_teams(user_id);

-- API keys are stored hashed; key_prefix keeps a displayable fragment.
CREATE TABLE team_api_keys (
    id         TEXT PRIMARY KEY,
    team_id    TEXT NOT NULL REFERENCES teams(id) ON DELETE CASCADE,
    name       TEXT NOT NULL DEFAULT '',
    key_hash   TEXT NOT NULL UNIQUE,
    key_prefix TEXT NOT NULL DEFAULT '',
    created_by TEXT NOT NULL REFERENCES users(id),
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    last_used  TIMESTAMPTZ
);

CREATE INDEX idx_team_api_keys_team ON team_api_keys(team_id);

-- +goose Down

-- Drop in reverse dependency order.
DROP TABLE team_api_keys;
DROP TABLE users_teams;
DROP TABLE teams;
DROP TABLE users;
|
||||||
31
db/migrations/20260313210611_team_ownership.sql
Normal file
31
db/migrations/20260313210611_team_ownership.sql
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
-- +goose Up

-- Move sandbox ownership from individual users (owner_id) to teams (team_id).
ALTER TABLE sandboxes
    ADD COLUMN team_id TEXT NOT NULL DEFAULT '';

UPDATE sandboxes SET team_id = owner_id WHERE owner_id != '';

ALTER TABLE sandboxes
    DROP COLUMN owner_id;

ALTER TABLE templates
    ADD COLUMN team_id TEXT NOT NULL DEFAULT '';

CREATE INDEX idx_sandboxes_team ON sandboxes(team_id);
CREATE INDEX idx_templates_team ON templates(team_id);

-- +goose Down

ALTER TABLE sandboxes
    ADD COLUMN owner_id TEXT NOT NULL DEFAULT '';

UPDATE sandboxes SET owner_id = team_id WHERE team_id != '';

ALTER TABLE sandboxes
    DROP COLUMN team_id;

ALTER TABLE templates
    DROP COLUMN team_id;

-- Dropping the columns above also removes their indexes, hence IF EXISTS.
DROP INDEX IF EXISTS idx_sandboxes_team;
DROP INDEX IF EXISTS idx_templates_team;
|
||||||
22
db/migrations/20260315001514_oauth.sql
Normal file
22
db/migrations/20260315001514_oauth.sql
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
-- +goose Up

-- OAuth-only accounts have no password, so the hash becomes nullable.
ALTER TABLE users
    ALTER COLUMN password_hash DROP NOT NULL;

-- Links an external identity (provider, provider_id) to a local user.
CREATE TABLE oauth_providers (
    provider    TEXT NOT NULL,
    provider_id TEXT NOT NULL,
    user_id     TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
    email       TEXT NOT NULL,
    created_at  TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    PRIMARY KEY (provider, provider_id)
);

CREATE INDEX idx_oauth_providers_user ON oauth_providers(user_id);

-- +goose Down

DROP TABLE oauth_providers;

-- Backfill NULL hashes before restoring the NOT NULL constraint.
UPDATE users SET password_hash = '' WHERE password_hash IS NULL;
ALTER TABLE users ALTER COLUMN password_hash SET NOT NULL;
|
||||||
21
db/migrations/20260316203135_admin_users.sql
Normal file
21
db/migrations/20260316203135_admin_users.sql
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
-- +goose Up

ALTER TABLE users
    ADD COLUMN is_admin BOOLEAN NOT NULL DEFAULT FALSE;

-- Fine-grained permissions granted to admin users; one row per grant.
CREATE TABLE admin_permissions (
    id         TEXT PRIMARY KEY,
    user_id    TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
    permission TEXT NOT NULL,
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    UNIQUE (user_id, permission)
);

CREATE INDEX idx_admin_permissions_user ON admin_permissions(user_id);

-- +goose Down

DROP TABLE admin_permissions;

ALTER TABLE users
    DROP COLUMN is_admin;
|
||||||
9
db/migrations/20260316203138_byoc_teams.sql
Normal file
9
db/migrations/20260316203138_byoc_teams.sql
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
-- +goose Up

-- Marks teams that bring their own compute (BYOC).
ALTER TABLE teams
    ADD COLUMN is_byoc BOOLEAN NOT NULL DEFAULT FALSE;

-- +goose Down

ALTER TABLE teams
    DROP COLUMN is_byoc;
|
||||||
47
db/migrations/20260316203142_hosts.sql
Normal file
47
db/migrations/20260316203142_hosts.sql
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
-- +goose Up

-- Hosts: machines running the host agent. BYOC hosts belong to a team.
CREATE TABLE hosts (
    id                TEXT PRIMARY KEY,
    type              TEXT NOT NULL DEFAULT 'regular', -- 'regular' or 'byoc'
    team_id           TEXT REFERENCES teams(id) ON DELETE SET NULL,
    provider          TEXT,
    availability_zone TEXT,
    arch              TEXT,
    cpu_cores         INTEGER,
    memory_mb         INTEGER,
    disk_gb           INTEGER,
    address           TEXT, -- ip:port of host agent
    status            TEXT NOT NULL DEFAULT 'pending', -- 'pending', 'online', 'offline', 'draining'
    last_heartbeat_at TIMESTAMPTZ,
    metadata          JSONB NOT NULL DEFAULT '{}',
    created_by        TEXT NOT NULL REFERENCES users(id),
    created_at        TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    updated_at        TIMESTAMPTZ NOT NULL DEFAULT NOW()
);

-- One-time registration tokens issued for a host; used_at marks consumption.
CREATE TABLE host_tokens (
    id         TEXT PRIMARY KEY,
    host_id    TEXT NOT NULL REFERENCES hosts(id) ON DELETE CASCADE,
    created_by TEXT NOT NULL REFERENCES users(id),
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    expires_at TIMESTAMPTZ NOT NULL,
    used_at    TIMESTAMPTZ
);

-- Free-form labels attached to hosts.
CREATE TABLE host_tags (
    host_id TEXT NOT NULL REFERENCES hosts(id) ON DELETE CASCADE,
    tag     TEXT NOT NULL,
    PRIMARY KEY (host_id, tag)
);

CREATE INDEX idx_hosts_type ON hosts(type);
CREATE INDEX idx_hosts_team ON hosts(team_id);
CREATE INDEX idx_hosts_status ON hosts(status);
CREATE INDEX idx_host_tokens_host ON host_tokens(host_id);
CREATE INDEX idx_host_tags_tag ON host_tags(tag);

-- +goose Down

-- Drop dependents before hosts.
DROP TABLE host_tags;
DROP TABLE host_tokens;
DROP TABLE hosts;
|
||||||
11
db/migrations/20260316223629_host_mtls.sql
Normal file
11
db/migrations/20260316223629_host_mtls.sql
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
-- +goose Up

-- mTLS support for host agents: pinned certificate fingerprint + enable flag.
ALTER TABLE hosts
    ADD COLUMN cert_fingerprint TEXT,
    ADD COLUMN mtls_enabled BOOLEAN NOT NULL DEFAULT FALSE;

-- +goose Down

ALTER TABLE hosts
    DROP COLUMN cert_fingerprint,
    DROP COLUMN mtls_enabled;
|
||||||
16
db/queries/api_keys.sql
Normal file
16
db/queries/api_keys.sql
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
-- Team API key queries (sqlc).

-- name: InsertAPIKey :one
INSERT INTO team_api_keys (id, team_id, name, key_hash, key_prefix, created_by)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING *;

-- name: GetAPIKeyByHash :one
SELECT * FROM team_api_keys WHERE key_hash = $1;

-- name: ListAPIKeysByTeam :many
SELECT * FROM team_api_keys WHERE team_id = $1 ORDER BY created_at DESC;

-- Deletion is scoped to the team so one team cannot delete another's key.
-- name: DeleteAPIKey :exec
DELETE FROM team_api_keys WHERE id = $1 AND team_id = $2;

-- name: UpdateAPIKeyLastUsed :exec
UPDATE team_api_keys SET last_used = NOW() WHERE id = $1;
|
||||||
@ -0,0 +1,69 @@
|
|||||||
|
-- Host, host-tag, and host-token queries (sqlc).

-- name: InsertHost :one
INSERT INTO hosts (id, type, team_id, provider, availability_zone, created_by)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING *;

-- name: GetHost :one
SELECT * FROM hosts WHERE id = $1;

-- name: ListHosts :many
SELECT * FROM hosts ORDER BY created_at DESC;

-- name: ListHostsByType :many
SELECT * FROM hosts WHERE type = $1 ORDER BY created_at DESC;

-- name: ListHostsByTeam :many
SELECT * FROM hosts WHERE team_id = $1 AND type = 'byoc' ORDER BY created_at DESC;

-- name: ListHostsByStatus :many
SELECT * FROM hosts WHERE status = $1 ORDER BY created_at DESC;

-- Registration only succeeds while the host is still 'pending'; the
-- :execrows annotation lets the caller detect a no-op (0 rows).
-- name: RegisterHost :execrows
UPDATE hosts
SET arch = $2,
    cpu_cores = $3,
    memory_mb = $4,
    disk_gb = $5,
    address = $6,
    status = 'online',
    last_heartbeat_at = NOW(),
    updated_at = NOW()
WHERE id = $1 AND status = 'pending';

-- name: UpdateHostStatus :exec
UPDATE hosts SET status = $2, updated_at = NOW() WHERE id = $1;

-- name: UpdateHostHeartbeat :exec
UPDATE hosts SET last_heartbeat_at = NOW(), updated_at = NOW() WHERE id = $1;

-- name: DeleteHost :exec
DELETE FROM hosts WHERE id = $1;

-- name: AddHostTag :exec
INSERT INTO host_tags (host_id, tag) VALUES ($1, $2) ON CONFLICT DO NOTHING;

-- name: RemoveHostTag :exec
DELETE FROM host_tags WHERE host_id = $1 AND tag = $2;

-- name: GetHostTags :many
SELECT tag FROM host_tags WHERE host_id = $1 ORDER BY tag;

-- name: ListHostsByTag :many
SELECT h.* FROM hosts h
JOIN host_tags ht ON ht.host_id = h.id
WHERE ht.tag = $1
ORDER BY h.created_at DESC;

-- name: InsertHostToken :one
INSERT INTO host_tokens (id, host_id, created_by, expires_at)
VALUES ($1, $2, $3, $4)
RETURNING *;

-- name: MarkHostTokenUsed :exec
UPDATE host_tokens SET used_at = NOW() WHERE id = $1;

-- name: GetHostTokensByHost :many
SELECT * FROM host_tokens WHERE host_id = $1 ORDER BY created_at DESC;

-- name: GetHostByTeam :one
SELECT * FROM hosts WHERE id = $1 AND team_id = $2;
|
||||||
|
|||||||
7
db/queries/oauth.sql
Normal file
7
db/queries/oauth.sql
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
-- OAuth identity-link queries (sqlc).

-- name: InsertOAuthProvider :exec
INSERT INTO oauth_providers (provider, provider_id, user_id, email)
VALUES ($1, $2, $3, $4);

-- name: GetOAuthProvider :one
SELECT * FROM oauth_providers
WHERE provider = $1 AND provider_id = $2;
|
||||||
@ -0,0 +1,51 @@
|
|||||||
|
-- Sandbox queries (sqlc).

-- name: InsertSandbox :one
INSERT INTO sandboxes (id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
RETURNING *;

-- name: GetSandbox :one
SELECT * FROM sandboxes WHERE id = $1;

-- Team-scoped read: returns no row unless the sandbox belongs to the team.
-- name: GetSandboxByTeam :one
SELECT * FROM sandboxes WHERE id = $1 AND team_id = $2;

-- name: ListSandboxes :many
SELECT * FROM sandboxes ORDER BY created_at DESC;

-- name: ListSandboxesByTeam :many
SELECT * FROM sandboxes WHERE team_id = $1 ORDER BY created_at DESC;

-- name: ListSandboxesByHostAndStatus :many
SELECT * FROM sandboxes
WHERE host_id = $1 AND status = ANY($2::text[])
ORDER BY created_at DESC;

-- Transition to 'running': records network addresses and stamps
-- started_at / last_active_at with the same timestamp ($4).
-- name: UpdateSandboxRunning :one
UPDATE sandboxes
SET status = 'running',
    host_ip = $2,
    guest_ip = $3,
    started_at = $4,
    last_active_at = $4,
    last_updated = NOW()
WHERE id = $1
RETURNING *;

-- name: UpdateSandboxStatus :one
UPDATE sandboxes
SET status = $2,
    last_updated = NOW()
WHERE id = $1
RETURNING *;

-- name: UpdateLastActive :exec
UPDATE sandboxes
SET last_active_at = $2,
    last_updated = NOW()
WHERE id = $1;

-- name: BulkUpdateStatusByIDs :exec
UPDATE sandboxes
SET status = $2,
    last_updated = NOW()
WHERE id = ANY($1::text[]);
|
||||||
|
|||||||
26
db/queries/teams.sql
Normal file
26
db/queries/teams.sql
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
-- Team and membership queries (sqlc).

-- name: InsertTeam :one
INSERT INTO teams (id, name)
VALUES ($1, $2)
RETURNING *;

-- name: GetTeam :one
SELECT * FROM teams WHERE id = $1;

-- name: InsertTeamMember :exec
INSERT INTO users_teams (user_id, team_id, is_default, role)
VALUES ($1, $2, $3, $4);

-- name: GetDefaultTeamForUser :one
SELECT t.* FROM teams t
JOIN users_teams ut ON ut.team_id = t.id
WHERE ut.user_id = $1 AND ut.is_default = TRUE
LIMIT 1;

-- name: SetTeamBYOC :exec
UPDATE teams SET is_byoc = $2 WHERE id = $1;

-- name: GetBYOCTeams :many
SELECT * FROM teams WHERE is_byoc = TRUE ORDER BY created_at;

-- name: GetTeamMembership :one
SELECT * FROM users_teams WHERE user_id = $1 AND team_id = $2;
|
||||||
28
db/queries/templates.sql
Normal file
28
db/queries/templates.sql
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
-- Template queries (sqlc).

-- name: InsertTemplate :one
INSERT INTO templates (name, type, vcpus, memory_mb, size_bytes, team_id)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING *;

-- name: GetTemplate :one
SELECT * FROM templates WHERE name = $1;

-- name: GetTemplateByTeam :one
SELECT * FROM templates WHERE name = $1 AND team_id = $2;

-- name: ListTemplates :many
SELECT * FROM templates ORDER BY created_at DESC;

-- name: ListTemplatesByType :many
SELECT * FROM templates WHERE type = $1 ORDER BY created_at DESC;

-- name: ListTemplatesByTeam :many
SELECT * FROM templates WHERE team_id = $1 ORDER BY created_at DESC;

-- name: ListTemplatesByTeamAndType :many
SELECT * FROM templates WHERE team_id = $1 AND type = $2 ORDER BY created_at DESC;

-- name: DeleteTemplate :exec
DELETE FROM templates WHERE name = $1;

-- Team-scoped delete: no-op unless the template belongs to the team.
-- name: DeleteTemplateByTeam :exec
DELETE FROM templates WHERE name = $1 AND team_id = $2;
|
||||||
36
db/queries/users.sql
Normal file
36
db/queries/users.sql
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
-- User and admin-permission queries (sqlc).

-- name: InsertUser :one
INSERT INTO users (id, email, password_hash)
VALUES ($1, $2, $3)
RETURNING *;

-- name: GetUserByEmail :one
SELECT * FROM users WHERE email = $1;

-- name: GetUserByID :one
SELECT * FROM users WHERE id = $1;

-- OAuth signups have no password; password_hash is left NULL.
-- name: InsertUserOAuth :one
INSERT INTO users (id, email)
VALUES ($1, $2)
RETURNING *;

-- name: SetUserAdmin :exec
UPDATE users SET is_admin = $2, updated_at = NOW() WHERE id = $1;

-- name: GetAdminUsers :many
SELECT * FROM users WHERE is_admin = TRUE ORDER BY created_at;

-- name: InsertAdminPermission :exec
INSERT INTO admin_permissions (id, user_id, permission)
VALUES ($1, $2, $3);

-- name: DeleteAdminPermission :exec
DELETE FROM admin_permissions WHERE user_id = $1 AND permission = $2;

-- name: GetAdminPermissions :many
SELECT * FROM admin_permissions WHERE user_id = $1 ORDER BY permission;

-- name: HasAdminPermission :one
SELECT EXISTS(
    SELECT 1 FROM admin_permissions WHERE user_id = $1 AND permission = $2
) AS has_permission;
|
||||||
@ -10,6 +10,11 @@ services:
|
|||||||
volumes:
|
volumes:
|
||||||
- pgdata:/var/lib/postgresql/data
|
- pgdata:/var/lib/postgresql/data
|
||||||
|
|
||||||
|
redis:
|
||||||
|
image: redis:7-alpine
|
||||||
|
ports:
|
||||||
|
- "6379:6379"
|
||||||
|
|
||||||
prometheus:
|
prometheus:
|
||||||
image: prom/prometheus:latest
|
image: prom/prometheus:latest
|
||||||
ports:
|
ports:
|
||||||
|
|||||||
201
envd/LICENSE
Normal file
201
envd/LICENSE
Normal file
@ -0,0 +1,201 @@
|
|||||||
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
|
1. Definitions.
|
||||||
|
|
||||||
|
"License" shall mean the terms and conditions for use, reproduction,
|
||||||
|
and distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
|
"Licensor" shall mean the copyright owner or entity authorized by
|
||||||
|
the copyright owner that is granting the License.
|
||||||
|
|
||||||
|
"Legal Entity" shall mean the union of the acting entity and all
|
||||||
|
other entities that control, are controlled by, or are under common
|
||||||
|
control with that entity. For the purposes of this definition,
|
||||||
|
"control" means (i) the power, direct or indirect, to cause the
|
||||||
|
direction or management of such entity, whether by contract or
|
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity
|
||||||
|
exercising permissions granted by this License.
|
||||||
|
|
||||||
|
"Source" form shall mean the preferred form for making modifications,
|
||||||
|
including but not limited to software source code, documentation
|
||||||
|
source, and configuration files.
|
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical
|
||||||
|
transformation or translation of a Source form, including but
|
||||||
|
not limited to compiled object code, generated documentation,
|
||||||
|
and conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or
|
||||||
|
Object form, made available under the License, as indicated by a
|
||||||
|
copyright notice that is included in or attached to the work
|
||||||
|
(an example is provided in the Appendix below).
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object
|
||||||
|
form, that is based on (or derived from) the Work and for which the
|
||||||
|
editorial revisions, annotations, elaborations, or other modifications
|
||||||
|
represent, as a whole, an original work of authorship. For the purposes
|
||||||
|
of this License, Derivative Works shall not include works that remain
|
||||||
|
separable from, or merely link (or bind by name) to the interfaces of,
|
||||||
|
the Work and Derivative Works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including
|
||||||
|
the original version of the Work and any modifications or additions
|
||||||
|
to that Work or Derivative Works thereof, that is intentionally
|
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of
|
||||||
|
the copyright owner. For the purposes of this definition, "submitted"
|
||||||
|
means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control systems,
|
||||||
|
and issue tracking systems that are managed by, or on behalf of, the
|
||||||
|
Licensor for the purpose of discussing and improving the Work, but
|
||||||
|
excluding communication that is conspicuously marked or otherwise
|
||||||
|
designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||||
|
on behalf of whom a Contribution has been received by Licensor and
|
||||||
|
subsequently incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
copyright license to reproduce, prepare Derivative Works of,
|
||||||
|
publicly display, publicly perform, sublicense, and distribute the
|
||||||
|
Work and such Derivative Works in Source or Object form.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
(except as stated in this section) patent license to make, have made,
|
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||||
|
where such license applies only to those patent claims licensable
|
||||||
|
by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s)
|
||||||
|
with the Work to which such Contribution(s) was submitted. If You
|
||||||
|
institute patent litigation against any entity (including a
|
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||||
|
or a Contribution incorporated within the Work constitutes direct
|
||||||
|
or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate
|
||||||
|
as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the
|
||||||
|
Work or Derivative Works thereof in any medium, with or without
|
||||||
|
modifications, and in Source or Object form, provided that You
|
||||||
|
meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or
|
||||||
|
Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices
|
||||||
|
stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works
|
||||||
|
that You distribute, all copyright, patent, trademark, and
|
||||||
|
attribution notices from the Source form of the Work,
|
||||||
|
excluding those notices that do not pertain to any part of
|
||||||
|
the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its
|
||||||
|
distribution, then any Derivative Works that You distribute must
|
||||||
|
include a readable copy of the attribution notices contained
|
||||||
|
within such NOTICE file, excluding those notices that do not
|
||||||
|
pertain to any part of the Derivative Works, in at least one
|
||||||
|
of the following places: within a NOTICE text file distributed
|
||||||
|
as part of the Derivative Works; within the Source form or
|
||||||
|
documentation, if provided along with the Derivative Works; or,
|
||||||
|
within a display generated by the Derivative Works, if and
|
||||||
|
wherever such third-party notices normally appear. The contents
|
||||||
|
of the NOTICE file are for informational purposes only and
|
||||||
|
do not modify the License. You may add Your own attribution
|
||||||
|
notices within Derivative Works that You distribute, alongside
|
||||||
|
or as an addendum to the NOTICE text from the Work, provided
|
||||||
|
that such additional attribution notices cannot be construed
|
||||||
|
as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and
|
||||||
|
may provide additional or different license terms and conditions
|
||||||
|
for use, reproduction, or distribution of Your modifications, or
|
||||||
|
for any such Derivative Works as a whole, provided Your use,
|
||||||
|
reproduction, and distribution of the Work otherwise complies with
|
||||||
|
the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||||
|
any Contribution intentionally submitted for inclusion in the Work
|
||||||
|
by You to the Licensor shall be under the terms and conditions of
|
||||||
|
this License, without any additional terms or conditions.
|
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify
|
||||||
|
the terms of any separate license agreement you may have executed
|
||||||
|
with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade
|
||||||
|
names, trademarks, service marks, or product names of the Licensor,
|
||||||
|
except as required for reasonable and customary use in describing the
|
||||||
|
origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||||
|
agreed to in writing, Licensor provides the Work (and each
|
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
implied, including, without limitation, any warranties or conditions
|
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any
|
||||||
|
risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory,
|
||||||
|
whether in tort (including negligence), contract, or otherwise,
|
||||||
|
unless required by applicable law (such as deliberate and grossly
|
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be
|
||||||
|
liable to You for damages, including any direct, indirect, special,
|
||||||
|
incidental, or consequential damages of any character arising as a
|
||||||
|
result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill,
|
||||||
|
work stoppage, computer failure or malfunction, or any and all
|
||||||
|
other commercial damages or losses), even if such Contributor
|
||||||
|
has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing
|
||||||
|
the Work or Derivative Works thereof, You may choose to offer,
|
||||||
|
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||||
|
or other liability obligations and/or rights consistent with this
|
||||||
|
License. However, in accepting such obligations, You may act only
|
||||||
|
on Your own behalf and on Your sole responsibility, not on behalf
|
||||||
|
of any other Contributor, and only if You agree to indemnify,
|
||||||
|
defend, and hold each Contributor harmless for any liability
|
||||||
|
incurred by, or claims asserted against, such Contributor by reason
|
||||||
|
of your accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
APPENDIX: How to apply the Apache License to your work.
|
||||||
|
|
||||||
|
To apply the Apache License to your work, attach the following
|
||||||
|
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||||
|
replaced with your own identifying information. (Don't include
|
||||||
|
the brackets!) The text should be enclosed in the appropriate
|
||||||
|
comment syntax for the file format. We also recommend that a
|
||||||
|
file or class name and description of purpose be included on the
|
||||||
|
same "printed page" as the copyright notice for easier
|
||||||
|
identification within third-party archives.
|
||||||
|
|
||||||
|
Copyright 2023 FoundryLabs, Inc.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
@ -1,17 +1,62 @@
|
|||||||
LDFLAGS := -s -w
|
BUILD := $(shell git rev-parse --short HEAD 2>/dev/null || echo "unknown")
|
||||||
|
LDFLAGS := -s -w -X=main.commitSHA=$(BUILD)
|
||||||
|
BUILDS := ../builds
|
||||||
|
|
||||||
.PHONY: build clean fmt vet
|
# ═══════════════════════════════════════════════════
|
||||||
|
# Build
|
||||||
|
# ═══════════════════════════════════════════════════
|
||||||
|
.PHONY: build build-debug
|
||||||
|
|
||||||
build:
|
build:
|
||||||
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags="$(LDFLAGS)" -o envd .
|
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags="$(LDFLAGS)" -o $(BUILDS)/envd .
|
||||||
@file envd | grep -q "statically linked" || \
|
@file $(BUILDS)/envd | grep -q "statically linked" || \
|
||||||
(echo "ERROR: envd is not statically linked!" && exit 1)
|
(echo "ERROR: envd is not statically linked!" && exit 1)
|
||||||
|
|
||||||
clean:
|
build-debug:
|
||||||
rm -f envd
|
CGO_ENABLED=1 go build -race -gcflags=all="-N -l" -ldflags="-X=main.commitSHA=$(BUILD)" -o $(BUILDS)/debug/envd .
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════
|
||||||
|
# Run (debug mode, not inside a VM)
|
||||||
|
# ═══════════════════════════════════════════════════
|
||||||
|
.PHONY: run-debug
|
||||||
|
|
||||||
|
run-debug: build-debug
|
||||||
|
$(BUILDS)/debug/envd -isnotfc -port 49983
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════
|
||||||
|
# Code Generation
|
||||||
|
# ═══════════════════════════════════════════════════
|
||||||
|
.PHONY: generate proto openapi
|
||||||
|
|
||||||
|
generate: proto openapi
|
||||||
|
|
||||||
|
proto:
|
||||||
|
cd spec && buf generate --template buf.gen.yaml
|
||||||
|
|
||||||
|
openapi:
|
||||||
|
go generate ./internal/api/...
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════
|
||||||
|
# Quality
|
||||||
|
# ═══════════════════════════════════════════════════
|
||||||
|
.PHONY: fmt vet test tidy
|
||||||
|
|
||||||
fmt:
|
fmt:
|
||||||
gofmt -w .
|
gofmt -w .
|
||||||
|
|
||||||
vet:
|
vet:
|
||||||
go vet ./...
|
go vet ./...
|
||||||
|
|
||||||
|
test:
|
||||||
|
go test -race -v ./...
|
||||||
|
|
||||||
|
tidy:
|
||||||
|
go mod tidy
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════
|
||||||
|
# Clean
|
||||||
|
# ═══════════════════════════════════════════════════
|
||||||
|
.PHONY: clean
|
||||||
|
|
||||||
|
clean:
|
||||||
|
rm -f $(BUILDS)/envd $(BUILDS)/debug/envd
|
||||||
|
|||||||
43
envd/go.mod
43
envd/go.mod
@ -1,9 +1,42 @@
|
|||||||
module github.com/wrenn-dev/envd
|
module git.omukk.dev/wrenn/sandbox/envd
|
||||||
|
|
||||||
go 1.23.0
|
go 1.25.5
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/mdlayher/vsock v1.2.1
|
connectrpc.com/authn v0.1.0
|
||||||
google.golang.org/grpc v1.71.0
|
connectrpc.com/connect v1.19.1
|
||||||
google.golang.org/protobuf v1.36.5
|
connectrpc.com/cors v0.1.0
|
||||||
|
github.com/awnumar/memguard v0.23.0
|
||||||
|
github.com/creack/pty v1.1.24
|
||||||
|
github.com/dchest/uniuri v1.2.0
|
||||||
|
github.com/e2b-dev/fsnotify v0.0.1
|
||||||
|
github.com/go-chi/chi/v5 v5.2.5
|
||||||
|
github.com/google/uuid v1.6.0
|
||||||
|
github.com/oapi-codegen/runtime v1.2.0
|
||||||
|
github.com/orcaman/concurrent-map/v2 v2.0.1
|
||||||
|
github.com/rs/cors v1.11.1
|
||||||
|
github.com/rs/zerolog v1.34.0
|
||||||
|
github.com/shirou/gopsutil/v4 v4.26.2
|
||||||
|
github.com/stretchr/testify v1.11.1
|
||||||
|
github.com/txn2/txeh v1.8.0
|
||||||
|
golang.org/x/sys v0.42.0
|
||||||
|
google.golang.org/protobuf v1.36.11
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect
|
||||||
|
github.com/awnumar/memcall v0.4.0 // indirect
|
||||||
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
|
github.com/ebitengine/purego v0.10.0 // indirect
|
||||||
|
github.com/go-ole/go-ole v1.2.6 // indirect
|
||||||
|
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
|
||||||
|
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||||
|
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
|
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
|
||||||
|
github.com/tklauser/go-sysconf v0.3.16 // indirect
|
||||||
|
github.com/tklauser/numcpus v0.11.0 // indirect
|
||||||
|
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||||
|
golang.org/x/crypto v0.41.0 // indirect
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
)
|
)
|
||||||
|
|||||||
92
envd/go.sum
Normal file
92
envd/go.sum
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
connectrpc.com/authn v0.1.0 h1:m5weACjLWwgwcjttvUDyTPICJKw74+p2obBVrf8hT9E=
|
||||||
|
connectrpc.com/authn v0.1.0/go.mod h1:AwNZK/KYbqaJzRYadTuAaoz6sYQSPdORPqh1TOPIkgY=
|
||||||
|
connectrpc.com/connect v1.19.1 h1:R5M57z05+90EfEvCY1b7hBxDVOUl45PrtXtAV2fOC14=
|
||||||
|
connectrpc.com/connect v1.19.1/go.mod h1:tN20fjdGlewnSFeZxLKb0xwIZ6ozc3OQs2hTXy4du9w=
|
||||||
|
connectrpc.com/cors v0.1.0 h1:f3gTXJyDZPrDIZCQ567jxfD9PAIpopHiRDnJRt3QuOQ=
|
||||||
|
connectrpc.com/cors v0.1.0/go.mod h1:v8SJZCPfHtGH1zsm+Ttajpozd4cYIUryl4dFB6QEpfg=
|
||||||
|
github.com/RaveNoX/go-jsoncommentstrip v1.0.0/go.mod h1:78ihd09MekBnJnxpICcwzCMzGrKSKYe4AqU6PDYYpjk=
|
||||||
|
github.com/apapsch/go-jsonmerge/v2 v2.0.0 h1:axGnT1gRIfimI7gJifB699GoE/oq+F2MU7Dml6nw9rQ=
|
||||||
|
github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP2+08jFMw88y4klk=
|
||||||
|
github.com/awnumar/memcall v0.4.0 h1:B7hgZYdfH6Ot1Goaz8jGne/7i8xD4taZie/PNSFZ29g=
|
||||||
|
github.com/awnumar/memcall v0.4.0/go.mod h1:8xOx1YbfyuCg3Fy6TO8DK0kZUua3V42/goA5Ru47E8w=
|
||||||
|
github.com/awnumar/memguard v0.23.0 h1:sJ3a1/SWlcuKIQ7MV+R9p0Pvo9CWsMbGZvcZQtmc68A=
|
||||||
|
github.com/awnumar/memguard v0.23.0/go.mod h1:olVofBrsPdITtJ2HgxQKrEYEMyIBAIciVG4wNnZhW9M=
|
||||||
|
github.com/bmatcuk/doublestar v1.1.1/go.mod h1:UD6OnuiIn0yFxxA2le/rnRU1G4RaI4UvFv1sNto9p6w=
|
||||||
|
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||||
|
github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s=
|
||||||
|
github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE=
|
||||||
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/dchest/uniuri v1.2.0 h1:koIcOUdrTIivZgSLhHQvKgqdWZq5d7KdMEWF1Ud6+5g=
|
||||||
|
github.com/dchest/uniuri v1.2.0/go.mod h1:fSzm4SLHzNZvWLvWJew423PhAzkpNQYq+uNLq4kxhkY=
|
||||||
|
github.com/e2b-dev/fsnotify v0.0.1 h1:7j0I98HD6VehAuK/bcslvW4QDynAULtOuMZtImihjVk=
|
||||||
|
github.com/e2b-dev/fsnotify v0.0.1/go.mod h1:jAuDjregRrUixKneTRQwPI847nNuPFg3+n5QM/ku/JM=
|
||||||
|
github.com/ebitengine/purego v0.10.0 h1:QIw4xfpWT6GWTzaW5XEKy3HXoqrJGx1ijYHzTF0/ISU=
|
||||||
|
github.com/ebitengine/purego v0.10.0/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
|
||||||
|
github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug=
|
||||||
|
github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0=
|
||||||
|
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
|
||||||
|
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
||||||
|
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||||
|
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||||
|
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||||
|
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||||
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
|
github.com/juju/gnuflag v0.0.0-20171113085948-2ce1bb71843d/go.mod h1:2PavIy+JPciBPrBUjwbNvtwB6RQlve+hkpll6QSNmOE=
|
||||||
|
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
|
||||||
|
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
|
||||||
|
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
|
||||||
|
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||||
|
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||||
|
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
|
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||||
|
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
|
github.com/oapi-codegen/runtime v1.2.0 h1:RvKc1CVS1QeKSNzO97FBQbSMZyQ8s6rZd+LpmzwHMP4=
|
||||||
|
github.com/oapi-codegen/runtime v1.2.0/go.mod h1:Y7ZhmmlE8ikZOmuHRRndiIm7nf3xcVv+YMweKgG1DT0=
|
||||||
|
github.com/orcaman/concurrent-map/v2 v2.0.1 h1:jOJ5Pg2w1oeB6PeDurIYf6k9PQ+aTITr/6lP/L/zp6c=
|
||||||
|
github.com/orcaman/concurrent-map/v2 v2.0.1/go.mod h1:9Eq3TG2oBe5FirmYWQfYO5iH1q0Jv47PLaNK++uCdOM=
|
||||||
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU=
|
||||||
|
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
|
||||||
|
github.com/rs/cors v1.11.1 h1:eU3gRzXLRK57F5rKMGMZURNdIG4EoAmX8k94r9wXWHA=
|
||||||
|
github.com/rs/cors v1.11.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU=
|
||||||
|
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
|
||||||
|
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
|
||||||
|
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
|
||||||
|
github.com/shirou/gopsutil/v4 v4.26.2 h1:X8i6sicvUFih4BmYIGT1m2wwgw2VG9YgrDTi7cIRGUI=
|
||||||
|
github.com/shirou/gopsutil/v4 v4.26.2/go.mod h1:LZ6ewCSkBqUpvSOf+LsTGnRinC6iaNUNMGBtDkJBaLQ=
|
||||||
|
github.com/spkg/bom v0.0.0-20160624110644-59b7046e48ad/go.mod h1:qLr4V1qq6nMqFKkMo8ZTx3f+BZEkzsRUY10Xsm2mwU0=
|
||||||
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
|
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||||
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
|
github.com/tklauser/go-sysconf v0.3.16 h1:frioLaCQSsF5Cy1jgRBrzr6t502KIIwQ0MArYICU0nA=
|
||||||
|
github.com/tklauser/go-sysconf v0.3.16/go.mod h1:/qNL9xxDhc7tx3HSRsLWNnuzbVfh3e7gh/BmM179nYI=
|
||||||
|
github.com/tklauser/numcpus v0.11.0 h1:nSTwhKH5e1dMNsCdVBukSZrURJRoHbSEQjdEbY+9RXw=
|
||||||
|
github.com/tklauser/numcpus v0.11.0/go.mod h1:z+LwcLq54uWZTX0u/bGobaV34u6V7KNlTZejzM6/3MQ=
|
||||||
|
github.com/txn2/txeh v1.8.0 h1:G1vZgom6+P/xWwU53AMOpcZgC5ni382ukcPP1TDVYHk=
|
||||||
|
github.com/txn2/txeh v1.8.0/go.mod h1:rRI3Egi3+AFmEXQjft051YdYbxeCT3nFmBLsNCZZaxM=
|
||||||
|
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
|
||||||
|
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
||||||
|
golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4=
|
||||||
|
golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc=
|
||||||
|
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
|
||||||
|
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||||
|
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
|
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||||
|
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
pgregory.net/rapid v1.2.0 h1:keKAYRcjm+e1F0oAuU5F5+YPAWcyxNNRK2wud503Gnk=
|
||||||
|
pgregory.net/rapid v1.2.0/go.mod h1:PY5XlDGj0+V1FCq0o192FdRhpKHGTRIWBgqjDBTrq04=
|
||||||
568
envd/internal/api/api.gen.go
Normal file
568
envd/internal/api/api.gen.go
Normal file
@ -0,0 +1,568 @@
|
|||||||
|
// Package api provides primitives to interact with the openapi HTTP API.
|
||||||
|
//
|
||||||
|
// Code generated by github.com/oapi-codegen/oapi-codegen/v2 version v2.5.1 DO NOT EDIT.
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
"github.com/oapi-codegen/runtime"
|
||||||
|
openapi_types "github.com/oapi-codegen/runtime/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
AccessTokenAuthScopes = "AccessTokenAuth.Scopes"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Defines values for EntryInfoType.
|
||||||
|
const (
|
||||||
|
File EntryInfoType = "file"
|
||||||
|
)
|
||||||
|
|
||||||
|
// EntryInfo defines model for EntryInfo.
|
||||||
|
type EntryInfo struct {
|
||||||
|
// Name Name of the file
|
||||||
|
Name string `json:"name"`
|
||||||
|
|
||||||
|
// Path Path to the file
|
||||||
|
Path string `json:"path"`
|
||||||
|
|
||||||
|
// Type Type of the file
|
||||||
|
Type EntryInfoType `json:"type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// EntryInfoType Type of the file
|
||||||
|
type EntryInfoType string
|
||||||
|
|
||||||
|
// EnvVars Environment variables to set
|
||||||
|
type EnvVars map[string]string
|
||||||
|
|
||||||
|
// Error defines model for Error.
|
||||||
|
type Error struct {
|
||||||
|
// Code Error code
|
||||||
|
Code int `json:"code"`
|
||||||
|
|
||||||
|
// Message Error message
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Metrics Resource usage metrics
|
||||||
|
type Metrics struct {
|
||||||
|
// CpuCount Number of CPU cores
|
||||||
|
CpuCount *int `json:"cpu_count,omitempty"`
|
||||||
|
|
||||||
|
// CpuUsedPct CPU usage percentage
|
||||||
|
CpuUsedPct *float32 `json:"cpu_used_pct,omitempty"`
|
||||||
|
|
||||||
|
// DiskTotal Total disk space in bytes
|
||||||
|
DiskTotal *int `json:"disk_total,omitempty"`
|
||||||
|
|
||||||
|
// DiskUsed Used disk space in bytes
|
||||||
|
DiskUsed *int `json:"disk_used,omitempty"`
|
||||||
|
|
||||||
|
// MemTotal Total virtual memory in bytes
|
||||||
|
MemTotal *int `json:"mem_total,omitempty"`
|
||||||
|
|
||||||
|
// MemUsed Used virtual memory in bytes
|
||||||
|
MemUsed *int `json:"mem_used,omitempty"`
|
||||||
|
|
||||||
|
// Ts Unix timestamp in UTC for current sandbox time
|
||||||
|
Ts *int64 `json:"ts,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// VolumeMount Volume
|
||||||
|
type VolumeMount struct {
|
||||||
|
NfsTarget string `json:"nfs_target"`
|
||||||
|
Path string `json:"path"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// FilePath defines model for FilePath.
|
||||||
|
type FilePath = string
|
||||||
|
|
||||||
|
// Signature defines model for Signature.
|
||||||
|
type Signature = string
|
||||||
|
|
||||||
|
// SignatureExpiration defines model for SignatureExpiration.
|
||||||
|
type SignatureExpiration = int
|
||||||
|
|
||||||
|
// User defines model for User.
|
||||||
|
type User = string
|
||||||
|
|
||||||
|
// FileNotFound defines model for FileNotFound.
|
||||||
|
type FileNotFound = Error
|
||||||
|
|
||||||
|
// InternalServerError defines model for InternalServerError.
|
||||||
|
type InternalServerError = Error
|
||||||
|
|
||||||
|
// InvalidPath defines model for InvalidPath.
|
||||||
|
type InvalidPath = Error
|
||||||
|
|
||||||
|
// InvalidUser defines model for InvalidUser.
|
||||||
|
type InvalidUser = Error
|
||||||
|
|
||||||
|
// NotEnoughDiskSpace defines model for NotEnoughDiskSpace.
|
||||||
|
type NotEnoughDiskSpace = Error
|
||||||
|
|
||||||
|
// UploadSuccess defines model for UploadSuccess.
|
||||||
|
type UploadSuccess = []EntryInfo
|
||||||
|
|
||||||
|
// GetFilesParams defines parameters for GetFiles.
|
||||||
|
type GetFilesParams struct {
|
||||||
|
// Path Path to the file, URL encoded. Can be relative to user's home directory.
|
||||||
|
Path *FilePath `form:"path,omitempty" json:"path,omitempty"`
|
||||||
|
|
||||||
|
// Username User used for setting the owner, or resolving relative paths.
|
||||||
|
Username *User `form:"username,omitempty" json:"username,omitempty"`
|
||||||
|
|
||||||
|
// Signature Signature used for file access permission verification.
|
||||||
|
Signature *Signature `form:"signature,omitempty" json:"signature,omitempty"`
|
||||||
|
|
||||||
|
// SignatureExpiration Signature expiration used for defining the expiration time of the signature.
|
||||||
|
SignatureExpiration *SignatureExpiration `form:"signature_expiration,omitempty" json:"signature_expiration,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PostFilesMultipartBody defines parameters for PostFiles.
|
||||||
|
type PostFilesMultipartBody struct {
|
||||||
|
File *openapi_types.File `json:"file,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PostFilesParams defines parameters for PostFiles.
|
||||||
|
type PostFilesParams struct {
|
||||||
|
// Path Path to the file, URL encoded. Can be relative to user's home directory.
|
||||||
|
Path *FilePath `form:"path,omitempty" json:"path,omitempty"`
|
||||||
|
|
||||||
|
// Username User used for setting the owner, or resolving relative paths.
|
||||||
|
Username *User `form:"username,omitempty" json:"username,omitempty"`
|
||||||
|
|
||||||
|
// Signature Signature used for file access permission verification.
|
||||||
|
Signature *Signature `form:"signature,omitempty" json:"signature,omitempty"`
|
||||||
|
|
||||||
|
// SignatureExpiration Signature expiration used for defining the expiration time of the signature.
|
||||||
|
SignatureExpiration *SignatureExpiration `form:"signature_expiration,omitempty" json:"signature_expiration,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PostInitJSONBody defines parameters for PostInit.
|
||||||
|
type PostInitJSONBody struct {
|
||||||
|
// AccessToken Access token for secure access to envd service
|
||||||
|
AccessToken *SecureToken `json:"accessToken,omitempty"`
|
||||||
|
|
||||||
|
// DefaultUser The default user to use for operations
|
||||||
|
DefaultUser *string `json:"defaultUser,omitempty"`
|
||||||
|
|
||||||
|
// DefaultWorkdir The default working directory to use for operations
|
||||||
|
DefaultWorkdir *string `json:"defaultWorkdir,omitempty"`
|
||||||
|
|
||||||
|
// EnvVars Environment variables to set
|
||||||
|
EnvVars *EnvVars `json:"envVars,omitempty"`
|
||||||
|
|
||||||
|
// HyperloopIP IP address of the hyperloop server to connect to
|
||||||
|
HyperloopIP *string `json:"hyperloopIP,omitempty"`
|
||||||
|
|
||||||
|
// Timestamp The current timestamp in RFC3339 format
|
||||||
|
Timestamp *time.Time `json:"timestamp,omitempty"`
|
||||||
|
VolumeMounts *[]VolumeMount `json:"volumeMounts,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PostFilesMultipartRequestBody defines body for PostFiles for multipart/form-data ContentType.
|
||||||
|
type PostFilesMultipartRequestBody PostFilesMultipartBody
|
||||||
|
|
||||||
|
// PostInitJSONRequestBody defines body for PostInit for application/json ContentType.
|
||||||
|
type PostInitJSONRequestBody PostInitJSONBody
|
||||||
|
|
||||||
|
// ServerInterface represents all server handlers.
|
||||||
|
type ServerInterface interface {
|
||||||
|
// Get the environment variables
|
||||||
|
// (GET /envs)
|
||||||
|
GetEnvs(w http.ResponseWriter, r *http.Request)
|
||||||
|
// Download a file
|
||||||
|
// (GET /files)
|
||||||
|
GetFiles(w http.ResponseWriter, r *http.Request, params GetFilesParams)
|
||||||
|
// Upload a file and ensure the parent directories exist. If the file exists, it will be overwritten.
|
||||||
|
// (POST /files)
|
||||||
|
PostFiles(w http.ResponseWriter, r *http.Request, params PostFilesParams)
|
||||||
|
// Check the health of the service
|
||||||
|
// (GET /health)
|
||||||
|
GetHealth(w http.ResponseWriter, r *http.Request)
|
||||||
|
// Set initial vars, ensure the time and metadata is synced with the host
|
||||||
|
// (POST /init)
|
||||||
|
PostInit(w http.ResponseWriter, r *http.Request)
|
||||||
|
// Get the stats of the service
|
||||||
|
// (GET /metrics)
|
||||||
|
GetMetrics(w http.ResponseWriter, r *http.Request)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unimplemented server implementation that returns http.StatusNotImplemented for each endpoint.
|
||||||
|
|
||||||
|
type Unimplemented struct{}
|
||||||
|
|
||||||
|
// Get the environment variables
|
||||||
|
// (GET /envs)
|
||||||
|
func (_ Unimplemented) GetEnvs(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusNotImplemented)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Download a file
|
||||||
|
// (GET /files)
|
||||||
|
func (_ Unimplemented) GetFiles(w http.ResponseWriter, r *http.Request, params GetFilesParams) {
|
||||||
|
w.WriteHeader(http.StatusNotImplemented)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Upload a file and ensure the parent directories exist. If the file exists, it will be overwritten.
|
||||||
|
// (POST /files)
|
||||||
|
func (_ Unimplemented) PostFiles(w http.ResponseWriter, r *http.Request, params PostFilesParams) {
|
||||||
|
w.WriteHeader(http.StatusNotImplemented)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check the health of the service
|
||||||
|
// (GET /health)
|
||||||
|
func (_ Unimplemented) GetHealth(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusNotImplemented)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set initial vars, ensure the time and metadata is synced with the host
|
||||||
|
// (POST /init)
|
||||||
|
func (_ Unimplemented) PostInit(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusNotImplemented)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the stats of the service
|
||||||
|
// (GET /metrics)
|
||||||
|
func (_ Unimplemented) GetMetrics(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusNotImplemented)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServerInterfaceWrapper converts contexts to parameters.
|
||||||
|
type ServerInterfaceWrapper struct {
|
||||||
|
Handler ServerInterface
|
||||||
|
HandlerMiddlewares []MiddlewareFunc
|
||||||
|
ErrorHandlerFunc func(w http.ResponseWriter, r *http.Request, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type MiddlewareFunc func(http.Handler) http.Handler
|
||||||
|
|
||||||
|
// GetEnvs operation middleware
|
||||||
|
func (siw *ServerInterfaceWrapper) GetEnvs(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
|
ctx := r.Context()
|
||||||
|
|
||||||
|
ctx = context.WithValue(ctx, AccessTokenAuthScopes, []string{})
|
||||||
|
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
|
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
siw.Handler.GetEnvs(w, r)
|
||||||
|
}))
|
||||||
|
|
||||||
|
for _, middleware := range siw.HandlerMiddlewares {
|
||||||
|
handler = middleware(handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFiles operation middleware
|
||||||
|
func (siw *ServerInterfaceWrapper) GetFiles(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
|
var err error
|
||||||
|
|
||||||
|
ctx := r.Context()
|
||||||
|
|
||||||
|
ctx = context.WithValue(ctx, AccessTokenAuthScopes, []string{})
|
||||||
|
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
|
// Parameter object where we will unmarshal all parameters from the context
|
||||||
|
var params GetFilesParams
|
||||||
|
|
||||||
|
// ------------- Optional query parameter "path" -------------
|
||||||
|
|
||||||
|
err = runtime.BindQueryParameter("form", true, false, "path", r.URL.Query(), ¶ms.Path)
|
||||||
|
if err != nil {
|
||||||
|
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "path", Err: err})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------- Optional query parameter "username" -------------
|
||||||
|
|
||||||
|
err = runtime.BindQueryParameter("form", true, false, "username", r.URL.Query(), ¶ms.Username)
|
||||||
|
if err != nil {
|
||||||
|
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "username", Err: err})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------- Optional query parameter "signature" -------------
|
||||||
|
|
||||||
|
err = runtime.BindQueryParameter("form", true, false, "signature", r.URL.Query(), ¶ms.Signature)
|
||||||
|
if err != nil {
|
||||||
|
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "signature", Err: err})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------- Optional query parameter "signature_expiration" -------------
|
||||||
|
|
||||||
|
err = runtime.BindQueryParameter("form", true, false, "signature_expiration", r.URL.Query(), ¶ms.SignatureExpiration)
|
||||||
|
if err != nil {
|
||||||
|
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "signature_expiration", Err: err})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
siw.Handler.GetFiles(w, r, params)
|
||||||
|
}))
|
||||||
|
|
||||||
|
for _, middleware := range siw.HandlerMiddlewares {
|
||||||
|
handler = middleware(handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PostFiles operation middleware
|
||||||
|
func (siw *ServerInterfaceWrapper) PostFiles(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
|
var err error
|
||||||
|
|
||||||
|
ctx := r.Context()
|
||||||
|
|
||||||
|
ctx = context.WithValue(ctx, AccessTokenAuthScopes, []string{})
|
||||||
|
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
|
// Parameter object where we will unmarshal all parameters from the context
|
||||||
|
var params PostFilesParams
|
||||||
|
|
||||||
|
// ------------- Optional query parameter "path" -------------
|
||||||
|
|
||||||
|
err = runtime.BindQueryParameter("form", true, false, "path", r.URL.Query(), ¶ms.Path)
|
||||||
|
if err != nil {
|
||||||
|
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "path", Err: err})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------- Optional query parameter "username" -------------
|
||||||
|
|
||||||
|
err = runtime.BindQueryParameter("form", true, false, "username", r.URL.Query(), ¶ms.Username)
|
||||||
|
if err != nil {
|
||||||
|
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "username", Err: err})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------- Optional query parameter "signature" -------------
|
||||||
|
|
||||||
|
err = runtime.BindQueryParameter("form", true, false, "signature", r.URL.Query(), ¶ms.Signature)
|
||||||
|
if err != nil {
|
||||||
|
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "signature", Err: err})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------- Optional query parameter "signature_expiration" -------------
|
||||||
|
|
||||||
|
err = runtime.BindQueryParameter("form", true, false, "signature_expiration", r.URL.Query(), ¶ms.SignatureExpiration)
|
||||||
|
if err != nil {
|
||||||
|
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "signature_expiration", Err: err})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
siw.Handler.PostFiles(w, r, params)
|
||||||
|
}))
|
||||||
|
|
||||||
|
for _, middleware := range siw.HandlerMiddlewares {
|
||||||
|
handler = middleware(handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHealth operation middleware
|
||||||
|
func (siw *ServerInterfaceWrapper) GetHealth(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
|
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
siw.Handler.GetHealth(w, r)
|
||||||
|
}))
|
||||||
|
|
||||||
|
for _, middleware := range siw.HandlerMiddlewares {
|
||||||
|
handler = middleware(handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PostInit operation middleware
|
||||||
|
func (siw *ServerInterfaceWrapper) PostInit(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
|
ctx := r.Context()
|
||||||
|
|
||||||
|
ctx = context.WithValue(ctx, AccessTokenAuthScopes, []string{})
|
||||||
|
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
|
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
siw.Handler.PostInit(w, r)
|
||||||
|
}))
|
||||||
|
|
||||||
|
for _, middleware := range siw.HandlerMiddlewares {
|
||||||
|
handler = middleware(handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMetrics operation middleware
|
||||||
|
func (siw *ServerInterfaceWrapper) GetMetrics(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
|
ctx := r.Context()
|
||||||
|
|
||||||
|
ctx = context.WithValue(ctx, AccessTokenAuthScopes, []string{})
|
||||||
|
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
|
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
siw.Handler.GetMetrics(w, r)
|
||||||
|
}))
|
||||||
|
|
||||||
|
for _, middleware := range siw.HandlerMiddlewares {
|
||||||
|
handler = middleware(handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
type UnescapedCookieParamError struct {
|
||||||
|
ParamName string
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *UnescapedCookieParamError) Error() string {
|
||||||
|
return fmt.Sprintf("error unescaping cookie parameter '%s'", e.ParamName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *UnescapedCookieParamError) Unwrap() error {
|
||||||
|
return e.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
type UnmarshalingParamError struct {
|
||||||
|
ParamName string
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *UnmarshalingParamError) Error() string {
|
||||||
|
return fmt.Sprintf("Error unmarshaling parameter %s as JSON: %s", e.ParamName, e.Err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *UnmarshalingParamError) Unwrap() error {
|
||||||
|
return e.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
type RequiredParamError struct {
|
||||||
|
ParamName string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *RequiredParamError) Error() string {
|
||||||
|
return fmt.Sprintf("Query argument %s is required, but not found", e.ParamName)
|
||||||
|
}
|
||||||
|
|
||||||
|
type RequiredHeaderError struct {
|
||||||
|
ParamName string
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *RequiredHeaderError) Error() string {
|
||||||
|
return fmt.Sprintf("Header parameter %s is required, but not found", e.ParamName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *RequiredHeaderError) Unwrap() error {
|
||||||
|
return e.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
type InvalidParamFormatError struct {
|
||||||
|
ParamName string
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *InvalidParamFormatError) Error() string {
|
||||||
|
return fmt.Sprintf("Invalid format for parameter %s: %s", e.ParamName, e.Err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *InvalidParamFormatError) Unwrap() error {
|
||||||
|
return e.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
type TooManyValuesForParamError struct {
|
||||||
|
ParamName string
|
||||||
|
Count int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *TooManyValuesForParamError) Error() string {
|
||||||
|
return fmt.Sprintf("Expected one value for %s, got %d", e.ParamName, e.Count)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handler creates http.Handler with routing matching OpenAPI spec.
|
||||||
|
func Handler(si ServerInterface) http.Handler {
|
||||||
|
return HandlerWithOptions(si, ChiServerOptions{})
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChiServerOptions struct {
|
||||||
|
BaseURL string
|
||||||
|
BaseRouter chi.Router
|
||||||
|
Middlewares []MiddlewareFunc
|
||||||
|
ErrorHandlerFunc func(w http.ResponseWriter, r *http.Request, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandlerFromMux creates http.Handler with routing matching OpenAPI spec based on the provided mux.
|
||||||
|
func HandlerFromMux(si ServerInterface, r chi.Router) http.Handler {
|
||||||
|
return HandlerWithOptions(si, ChiServerOptions{
|
||||||
|
BaseRouter: r,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func HandlerFromMuxWithBaseURL(si ServerInterface, r chi.Router, baseURL string) http.Handler {
|
||||||
|
return HandlerWithOptions(si, ChiServerOptions{
|
||||||
|
BaseURL: baseURL,
|
||||||
|
BaseRouter: r,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandlerWithOptions creates http.Handler with additional options
|
||||||
|
func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handler {
|
||||||
|
r := options.BaseRouter
|
||||||
|
|
||||||
|
if r == nil {
|
||||||
|
r = chi.NewRouter()
|
||||||
|
}
|
||||||
|
if options.ErrorHandlerFunc == nil {
|
||||||
|
options.ErrorHandlerFunc = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||||
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
wrapper := ServerInterfaceWrapper{
|
||||||
|
Handler: si,
|
||||||
|
HandlerMiddlewares: options.Middlewares,
|
||||||
|
ErrorHandlerFunc: options.ErrorHandlerFunc,
|
||||||
|
}
|
||||||
|
|
||||||
|
r.Group(func(r chi.Router) {
|
||||||
|
r.Get(options.BaseURL+"/envs", wrapper.GetEnvs)
|
||||||
|
})
|
||||||
|
r.Group(func(r chi.Router) {
|
||||||
|
r.Get(options.BaseURL+"/files", wrapper.GetFiles)
|
||||||
|
})
|
||||||
|
r.Group(func(r chi.Router) {
|
||||||
|
r.Post(options.BaseURL+"/files", wrapper.PostFiles)
|
||||||
|
})
|
||||||
|
r.Group(func(r chi.Router) {
|
||||||
|
r.Get(options.BaseURL+"/health", wrapper.GetHealth)
|
||||||
|
})
|
||||||
|
r.Group(func(r chi.Router) {
|
||||||
|
r.Post(options.BaseURL+"/init", wrapper.PostInit)
|
||||||
|
})
|
||||||
|
r.Group(func(r chi.Router) {
|
||||||
|
r.Get(options.BaseURL+"/metrics", wrapper.GetMetrics)
|
||||||
|
})
|
||||||
|
|
||||||
|
return r
|
||||||
|
}
|
||||||
131
envd/internal/api/auth.go
Normal file
131
envd/internal/api/auth.go
Normal file
@ -0,0 +1,131 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"slices"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/awnumar/memguard"
|
||||||
|
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/keys"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
SigningReadOperation = "read"
|
||||||
|
SigningWriteOperation = "write"
|
||||||
|
|
||||||
|
accessTokenHeader = "X-Access-Token"
|
||||||
|
)
|
||||||
|
|
||||||
|
// paths that are always allowed without general authentication
|
||||||
|
// POST/init is secured via MMDS hash validation instead
|
||||||
|
var authExcludedPaths = []string{
|
||||||
|
"GET/health",
|
||||||
|
"GET/files",
|
||||||
|
"POST/files",
|
||||||
|
"POST/init",
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *API) WithAuthorization(handler http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
|
if a.accessToken.IsSet() {
|
||||||
|
authHeader := req.Header.Get(accessTokenHeader)
|
||||||
|
|
||||||
|
// check if this path is allowed without authentication (e.g., health check, endpoints supporting signing)
|
||||||
|
allowedPath := slices.Contains(authExcludedPaths, req.Method+req.URL.Path)
|
||||||
|
|
||||||
|
if !a.accessToken.Equals(authHeader) && !allowedPath {
|
||||||
|
a.logger.Error().Msg("Trying to access secured envd without correct access token")
|
||||||
|
|
||||||
|
err := fmt.Errorf("unauthorized access, please provide a valid access token or method signing if supported")
|
||||||
|
jsonError(w, http.StatusUnauthorized, err)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *API) generateSignature(path string, username string, operation string, signatureExpiration *int64) (string, error) {
|
||||||
|
tokenBytes, err := a.accessToken.Bytes()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("access token is not set: %w", err)
|
||||||
|
}
|
||||||
|
defer memguard.WipeBytes(tokenBytes)
|
||||||
|
|
||||||
|
var signature string
|
||||||
|
hasher := keys.NewSHA256Hashing()
|
||||||
|
|
||||||
|
if signatureExpiration == nil {
|
||||||
|
signature = strings.Join([]string{path, operation, username, string(tokenBytes)}, ":")
|
||||||
|
} else {
|
||||||
|
signature = strings.Join([]string{path, operation, username, string(tokenBytes), strconv.FormatInt(*signatureExpiration, 10)}, ":")
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("v1_%s", hasher.HashWithoutPrefix([]byte(signature))), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *API) validateSigning(r *http.Request, signature *string, signatureExpiration *int, username *string, path string, operation string) (err error) {
|
||||||
|
var expectedSignature string
|
||||||
|
|
||||||
|
// no need to validate signing key if access token is not set
|
||||||
|
if !a.accessToken.IsSet() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if access token is sent in the header
|
||||||
|
tokenFromHeader := r.Header.Get(accessTokenHeader)
|
||||||
|
if tokenFromHeader != "" {
|
||||||
|
if !a.accessToken.Equals(tokenFromHeader) {
|
||||||
|
return fmt.Errorf("access token present in header but does not match")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if signature == nil {
|
||||||
|
return fmt.Errorf("missing signature query parameter")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Empty string is used when no username is provided and the default user should be used
|
||||||
|
signatureUsername := ""
|
||||||
|
if username != nil {
|
||||||
|
signatureUsername = *username
|
||||||
|
}
|
||||||
|
|
||||||
|
if signatureExpiration == nil {
|
||||||
|
expectedSignature, err = a.generateSignature(path, signatureUsername, operation, nil)
|
||||||
|
} else {
|
||||||
|
exp := int64(*signatureExpiration)
|
||||||
|
expectedSignature, err = a.generateSignature(path, signatureUsername, operation, &exp)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
a.logger.Error().Err(err).Msg("error generating signing key")
|
||||||
|
|
||||||
|
return errors.New("invalid signature")
|
||||||
|
}
|
||||||
|
|
||||||
|
// signature validation
|
||||||
|
if expectedSignature != *signature {
|
||||||
|
return fmt.Errorf("invalid signature")
|
||||||
|
}
|
||||||
|
|
||||||
|
// signature expiration
|
||||||
|
if signatureExpiration != nil {
|
||||||
|
exp := int64(*signatureExpiration)
|
||||||
|
if exp < time.Now().Unix() {
|
||||||
|
return fmt.Errorf("signature is already expired")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
64
envd/internal/api/auth_test.go
Normal file
64
envd/internal/api/auth_test.go
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/keys"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestKeyGenerationAlgorithmIsStable(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
apiToken := "secret-access-token"
|
||||||
|
secureToken := &SecureToken{}
|
||||||
|
err := secureToken.Set([]byte(apiToken))
|
||||||
|
require.NoError(t, err)
|
||||||
|
api := &API{accessToken: secureToken}
|
||||||
|
|
||||||
|
path := "/path/to/demo.txt"
|
||||||
|
username := "root"
|
||||||
|
operation := "write"
|
||||||
|
timestamp := time.Now().Unix()
|
||||||
|
|
||||||
|
signature, err := api.generateSignature(path, username, operation, ×tamp)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotEmpty(t, signature)
|
||||||
|
|
||||||
|
// locally generated signature
|
||||||
|
hasher := keys.NewSHA256Hashing()
|
||||||
|
localSignatureTmp := fmt.Sprintf("%s:%s:%s:%s:%s", path, operation, username, apiToken, strconv.FormatInt(timestamp, 10))
|
||||||
|
localSignature := fmt.Sprintf("v1_%s", hasher.HashWithoutPrefix([]byte(localSignatureTmp)))
|
||||||
|
|
||||||
|
assert.Equal(t, localSignature, signature)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKeyGenerationAlgorithmWithoutExpirationIsStable(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
apiToken := "secret-access-token"
|
||||||
|
secureToken := &SecureToken{}
|
||||||
|
err := secureToken.Set([]byte(apiToken))
|
||||||
|
require.NoError(t, err)
|
||||||
|
api := &API{accessToken: secureToken}
|
||||||
|
|
||||||
|
path := "/path/to/resource.txt"
|
||||||
|
username := "user"
|
||||||
|
operation := "read"
|
||||||
|
|
||||||
|
signature, err := api.generateSignature(path, username, operation, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotEmpty(t, signature)
|
||||||
|
|
||||||
|
// locally generated signature
|
||||||
|
hasher := keys.NewSHA256Hashing()
|
||||||
|
localSignatureTmp := fmt.Sprintf("%s:%s:%s:%s", path, operation, username, apiToken)
|
||||||
|
localSignature := fmt.Sprintf("v1_%s", hasher.HashWithoutPrefix([]byte(localSignatureTmp)))
|
||||||
|
|
||||||
|
assert.Equal(t, localSignature, signature)
|
||||||
|
}
|
||||||
10
envd/internal/api/cfg.yaml
Normal file
10
envd/internal/api/cfg.yaml
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
# yaml-language-server: $schema=https://raw.githubusercontent.com/deepmap/oapi-codegen/HEAD/configuration-schema.json
|
||||||
|
|
||||||
|
package: api
|
||||||
|
output: api.gen.go
|
||||||
|
generate:
|
||||||
|
models: true
|
||||||
|
chi-server: true
|
||||||
|
client: false
|
||||||
175
envd/internal/api/download.go
Normal file
175
envd/internal/api/download.go
Normal file
@ -0,0 +1,175 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"compress/gzip"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"mime"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"os/user"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (a *API) GetFiles(w http.ResponseWriter, r *http.Request, params GetFilesParams) {
|
||||||
|
defer r.Body.Close()
|
||||||
|
|
||||||
|
var errorCode int
|
||||||
|
var errMsg error
|
||||||
|
|
||||||
|
var path string
|
||||||
|
if params.Path != nil {
|
||||||
|
path = *params.Path
|
||||||
|
}
|
||||||
|
|
||||||
|
operationID := logs.AssignOperationID()
|
||||||
|
|
||||||
|
// signing authorization if needed
|
||||||
|
err := a.validateSigning(r, params.Signature, params.SignatureExpiration, params.Username, path, SigningReadOperation)
|
||||||
|
if err != nil {
|
||||||
|
a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("error during auth validation")
|
||||||
|
jsonError(w, http.StatusUnauthorized, err)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
username, err := execcontext.ResolveDefaultUsername(params.Username, a.defaults.User)
|
||||||
|
if err != nil {
|
||||||
|
a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("no user specified")
|
||||||
|
jsonError(w, http.StatusBadRequest, err)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
l := a.logger.
|
||||||
|
Err(errMsg).
|
||||||
|
Str("method", r.Method+" "+r.URL.Path).
|
||||||
|
Str(string(logs.OperationIDKey), operationID).
|
||||||
|
Str("path", path).
|
||||||
|
Str("username", username)
|
||||||
|
|
||||||
|
if errMsg != nil {
|
||||||
|
l = l.Int("error_code", errorCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
l.Msg("File read")
|
||||||
|
}()
|
||||||
|
|
||||||
|
u, err := user.Lookup(username)
|
||||||
|
if err != nil {
|
||||||
|
errMsg = fmt.Errorf("error looking up user '%s': %w", username, err)
|
||||||
|
errorCode = http.StatusUnauthorized
|
||||||
|
jsonError(w, errorCode, errMsg)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resolvedPath, err := permissions.ExpandAndResolve(path, u, a.defaults.Workdir)
|
||||||
|
if err != nil {
|
||||||
|
errMsg = fmt.Errorf("error expanding and resolving path '%s': %w", path, err)
|
||||||
|
errorCode = http.StatusBadRequest
|
||||||
|
jsonError(w, errorCode, errMsg)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
stat, err := os.Stat(resolvedPath)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
|
errMsg = fmt.Errorf("path '%s' does not exist", resolvedPath)
|
||||||
|
errorCode = http.StatusNotFound
|
||||||
|
jsonError(w, errorCode, errMsg)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
errMsg = fmt.Errorf("error checking if path exists '%s': %w", resolvedPath, err)
|
||||||
|
errorCode = http.StatusInternalServerError
|
||||||
|
jsonError(w, errorCode, errMsg)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if stat.IsDir() {
|
||||||
|
errMsg = fmt.Errorf("path '%s' is a directory", resolvedPath)
|
||||||
|
errorCode = http.StatusBadRequest
|
||||||
|
jsonError(w, errorCode, errMsg)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate Accept-Encoding header
|
||||||
|
encoding, err := parseAcceptEncoding(r)
|
||||||
|
if err != nil {
|
||||||
|
errMsg = fmt.Errorf("error parsing Accept-Encoding: %w", err)
|
||||||
|
errorCode = http.StatusNotAcceptable
|
||||||
|
jsonError(w, errorCode, errMsg)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tell caches to store separate variants for different Accept-Encoding values
|
||||||
|
w.Header().Set("Vary", "Accept-Encoding")
|
||||||
|
|
||||||
|
// Fall back to identity for Range or conditional requests to preserve http.ServeContent
|
||||||
|
// behavior (206 Partial Content, 304 Not Modified). However, we must check if identity
|
||||||
|
// is acceptable per the Accept-Encoding header.
|
||||||
|
hasRangeOrConditional := r.Header.Get("Range") != "" ||
|
||||||
|
r.Header.Get("If-Modified-Since") != "" ||
|
||||||
|
r.Header.Get("If-None-Match") != "" ||
|
||||||
|
r.Header.Get("If-Range") != ""
|
||||||
|
if hasRangeOrConditional {
|
||||||
|
if !isIdentityAcceptable(r) {
|
||||||
|
errMsg = fmt.Errorf("identity encoding not acceptable for Range or conditional request")
|
||||||
|
errorCode = http.StatusNotAcceptable
|
||||||
|
jsonError(w, errorCode, errMsg)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
encoding = EncodingIdentity
|
||||||
|
}
|
||||||
|
|
||||||
|
file, err := os.Open(resolvedPath)
|
||||||
|
if err != nil {
|
||||||
|
errMsg = fmt.Errorf("error opening file '%s': %w", resolvedPath, err)
|
||||||
|
errorCode = http.StatusInternalServerError
|
||||||
|
jsonError(w, errorCode, errMsg)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
w.Header().Set("Content-Disposition", mime.FormatMediaType("inline", map[string]string{"filename": filepath.Base(resolvedPath)}))
|
||||||
|
|
||||||
|
// Serve with gzip encoding if requested.
|
||||||
|
if encoding == EncodingGzip {
|
||||||
|
w.Header().Set("Content-Encoding", EncodingGzip)
|
||||||
|
|
||||||
|
// Set Content-Type based on file extension, preserving the original type
|
||||||
|
contentType := mime.TypeByExtension(filepath.Ext(path))
|
||||||
|
if contentType == "" {
|
||||||
|
contentType = "application/octet-stream"
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", contentType)
|
||||||
|
|
||||||
|
gw := gzip.NewWriter(w)
|
||||||
|
defer gw.Close()
|
||||||
|
|
||||||
|
_, err = io.Copy(gw, file)
|
||||||
|
if err != nil {
|
||||||
|
a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("error writing gzip response")
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
http.ServeContent(w, r, path, stat.ModTime(), file)
|
||||||
|
}
|
||||||
403
envd/internal/api/download_test.go
Normal file
403
envd/internal/api/download_test.go
Normal file
@ -0,0 +1,403 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"compress/gzip"
|
||||||
|
"io"
|
||||||
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"os/user"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetFilesContentDisposition(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
currentUser, err := user.Current()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
filename string
|
||||||
|
expectedHeader string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple filename",
|
||||||
|
filename: "test.txt",
|
||||||
|
expectedHeader: `inline; filename=test.txt`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "filename with extension",
|
||||||
|
filename: "presentation.pptx",
|
||||||
|
expectedHeader: `inline; filename=presentation.pptx`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "filename with multiple dots",
|
||||||
|
filename: "archive.tar.gz",
|
||||||
|
expectedHeader: `inline; filename=archive.tar.gz`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "filename with spaces",
|
||||||
|
filename: "my document.pdf",
|
||||||
|
expectedHeader: `inline; filename="my document.pdf"`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "filename with quotes",
|
||||||
|
filename: `file"name.txt`,
|
||||||
|
expectedHeader: `inline; filename="file\"name.txt"`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "filename with backslash",
|
||||||
|
filename: `file\name.txt`,
|
||||||
|
expectedHeader: `inline; filename="file\\name.txt"`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unicode filename",
|
||||||
|
filename: "\u6587\u6863.pdf", // 文档.pdf in Chinese
|
||||||
|
expectedHeader: "inline; filename*=utf-8''%E6%96%87%E6%A1%A3.pdf",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "dotfile preserved",
|
||||||
|
filename: ".env",
|
||||||
|
expectedHeader: `inline; filename=.env`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "dotfile with extension preserved",
|
||||||
|
filename: ".gitignore",
|
||||||
|
expectedHeader: `inline; filename=.gitignore`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// Create a temp directory and file
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
tempFile := filepath.Join(tempDir, tt.filename)
|
||||||
|
err := os.WriteFile(tempFile, []byte("test content"), 0o644)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Create test API
|
||||||
|
logger := zerolog.Nop()
|
||||||
|
defaults := &execcontext.Defaults{
|
||||||
|
EnvVars: utils.NewMap[string, string](),
|
||||||
|
User: currentUser.Username,
|
||||||
|
}
|
||||||
|
api := New(&logger, defaults, nil, false)
|
||||||
|
|
||||||
|
// Create request and response recorder
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Call the handler
|
||||||
|
params := GetFilesParams{
|
||||||
|
Path: &tempFile,
|
||||||
|
Username: ¤tUser.Username,
|
||||||
|
}
|
||||||
|
api.GetFiles(w, req, params)
|
||||||
|
|
||||||
|
// Check response
|
||||||
|
resp := w.Result()
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
// Verify Content-Disposition header
|
||||||
|
contentDisposition := resp.Header.Get("Content-Disposition")
|
||||||
|
assert.Equal(t, tt.expectedHeader, contentDisposition, "Content-Disposition header should be set with correct filename")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFilesContentDispositionWithNestedPath(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
currentUser, err := user.Current()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Create a temp directory with nested structure
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
nestedDir := filepath.Join(tempDir, "subdir", "another")
|
||||||
|
err = os.MkdirAll(nestedDir, 0o755)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
filename := "document.pdf"
|
||||||
|
tempFile := filepath.Join(nestedDir, filename)
|
||||||
|
err = os.WriteFile(tempFile, []byte("test content"), 0o644)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Create test API
|
||||||
|
logger := zerolog.Nop()
|
||||||
|
defaults := &execcontext.Defaults{
|
||||||
|
EnvVars: utils.NewMap[string, string](),
|
||||||
|
User: currentUser.Username,
|
||||||
|
}
|
||||||
|
api := New(&logger, defaults, nil, false)
|
||||||
|
|
||||||
|
// Create request and response recorder
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Call the handler
|
||||||
|
params := GetFilesParams{
|
||||||
|
Path: &tempFile,
|
||||||
|
Username: ¤tUser.Username,
|
||||||
|
}
|
||||||
|
api.GetFiles(w, req, params)
|
||||||
|
|
||||||
|
// Check response
|
||||||
|
resp := w.Result()
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
// Verify Content-Disposition header uses only the base filename, not the full path
|
||||||
|
contentDisposition := resp.Header.Get("Content-Disposition")
|
||||||
|
assert.Equal(t, `inline; filename=document.pdf`, contentDisposition, "Content-Disposition should contain only the filename, not the path")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFiles_GzipEncoding_ExplicitIdentityOffWithRange(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
currentUser, err := user.Current()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Create a temp directory with a test file
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
filename := "document.pdf"
|
||||||
|
tempFile := filepath.Join(tempDir, filename)
|
||||||
|
err = os.WriteFile(tempFile, []byte("test content"), 0o644)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Create test API
|
||||||
|
logger := zerolog.Nop()
|
||||||
|
defaults := &execcontext.Defaults{
|
||||||
|
EnvVars: utils.NewMap[string, string](),
|
||||||
|
User: currentUser.Username,
|
||||||
|
}
|
||||||
|
api := New(&logger, defaults, nil, false)
|
||||||
|
|
||||||
|
// Create request and response recorder
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil)
|
||||||
|
req.Header.Set("Accept-Encoding", "gzip; q=1,*; q=0")
|
||||||
|
req.Header.Set("Range", "bytes=0-4") // Request first 5 bytes
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Call the handler
|
||||||
|
params := GetFilesParams{
|
||||||
|
Path: &tempFile,
|
||||||
|
Username: ¤tUser.Username,
|
||||||
|
}
|
||||||
|
api.GetFiles(w, req, params)
|
||||||
|
|
||||||
|
// Check response
|
||||||
|
resp := w.Result()
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusNotAcceptable, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFiles_GzipDownload(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
currentUser, err := user.Current()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
originalContent := []byte("hello world, this is a test file for gzip compression")
|
||||||
|
|
||||||
|
// Create a temp file with known content
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
tempFile := filepath.Join(tempDir, "test.txt")
|
||||||
|
err = os.WriteFile(tempFile, originalContent, 0o644)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
logger := zerolog.Nop()
|
||||||
|
defaults := &execcontext.Defaults{
|
||||||
|
EnvVars: utils.NewMap[string, string](),
|
||||||
|
User: currentUser.Username,
|
||||||
|
}
|
||||||
|
api := New(&logger, defaults, nil, false)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil)
|
||||||
|
req.Header.Set("Accept-Encoding", "gzip")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
params := GetFilesParams{
|
||||||
|
Path: &tempFile,
|
||||||
|
Username: ¤tUser.Username,
|
||||||
|
}
|
||||||
|
api.GetFiles(w, req, params)
|
||||||
|
|
||||||
|
resp := w.Result()
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
assert.Equal(t, "gzip", resp.Header.Get("Content-Encoding"))
|
||||||
|
assert.Equal(t, "text/plain; charset=utf-8", resp.Header.Get("Content-Type"))
|
||||||
|
|
||||||
|
// Decompress the gzip response body
|
||||||
|
gzReader, err := gzip.NewReader(resp.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer gzReader.Close()
|
||||||
|
|
||||||
|
decompressed, err := io.ReadAll(gzReader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, originalContent, decompressed)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostFiles_GzipUpload(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
currentUser, err := user.Current()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
originalContent := []byte("hello world, this is a test file uploaded with gzip")
|
||||||
|
|
||||||
|
// Build a multipart body
|
||||||
|
var multipartBuf bytes.Buffer
|
||||||
|
mpWriter := multipart.NewWriter(&multipartBuf)
|
||||||
|
part, err := mpWriter.CreateFormFile("file", "uploaded.txt")
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = part.Write(originalContent)
|
||||||
|
require.NoError(t, err)
|
||||||
|
err = mpWriter.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Gzip-compress the entire multipart body
|
||||||
|
var gzBuf bytes.Buffer
|
||||||
|
gzWriter := gzip.NewWriter(&gzBuf)
|
||||||
|
_, err = gzWriter.Write(multipartBuf.Bytes())
|
||||||
|
require.NoError(t, err)
|
||||||
|
err = gzWriter.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Create test API
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
destPath := filepath.Join(tempDir, "uploaded.txt")
|
||||||
|
|
||||||
|
logger := zerolog.Nop()
|
||||||
|
defaults := &execcontext.Defaults{
|
||||||
|
EnvVars: utils.NewMap[string, string](),
|
||||||
|
User: currentUser.Username,
|
||||||
|
}
|
||||||
|
api := New(&logger, defaults, nil, false)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/files?path="+url.QueryEscape(destPath), &gzBuf)
|
||||||
|
req.Header.Set("Content-Type", mpWriter.FormDataContentType())
|
||||||
|
req.Header.Set("Content-Encoding", "gzip")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
params := PostFilesParams{
|
||||||
|
Path: &destPath,
|
||||||
|
Username: ¤tUser.Username,
|
||||||
|
}
|
||||||
|
api.PostFiles(w, req, params)
|
||||||
|
|
||||||
|
resp := w.Result()
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
// Verify the file was written with the original (decompressed) content
|
||||||
|
data, err := os.ReadFile(destPath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, originalContent, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGzipUploadThenGzipDownload(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
currentUser, err := user.Current()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
originalContent := []byte("round-trip gzip test: upload compressed, download compressed, verify match")
|
||||||
|
|
||||||
|
// --- Upload with gzip ---
|
||||||
|
|
||||||
|
// Build a multipart body
|
||||||
|
var multipartBuf bytes.Buffer
|
||||||
|
mpWriter := multipart.NewWriter(&multipartBuf)
|
||||||
|
part, err := mpWriter.CreateFormFile("file", "roundtrip.txt")
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = part.Write(originalContent)
|
||||||
|
require.NoError(t, err)
|
||||||
|
err = mpWriter.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Gzip-compress the entire multipart body
|
||||||
|
var gzBuf bytes.Buffer
|
||||||
|
gzWriter := gzip.NewWriter(&gzBuf)
|
||||||
|
_, err = gzWriter.Write(multipartBuf.Bytes())
|
||||||
|
require.NoError(t, err)
|
||||||
|
err = gzWriter.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
destPath := filepath.Join(tempDir, "roundtrip.txt")
|
||||||
|
|
||||||
|
logger := zerolog.Nop()
|
||||||
|
defaults := &execcontext.Defaults{
|
||||||
|
EnvVars: utils.NewMap[string, string](),
|
||||||
|
User: currentUser.Username,
|
||||||
|
}
|
||||||
|
api := New(&logger, defaults, nil, false)
|
||||||
|
|
||||||
|
uploadReq := httptest.NewRequest(http.MethodPost, "/files?path="+url.QueryEscape(destPath), &gzBuf)
|
||||||
|
uploadReq.Header.Set("Content-Type", mpWriter.FormDataContentType())
|
||||||
|
uploadReq.Header.Set("Content-Encoding", "gzip")
|
||||||
|
uploadW := httptest.NewRecorder()
|
||||||
|
|
||||||
|
uploadParams := PostFilesParams{
|
||||||
|
Path: &destPath,
|
||||||
|
Username: ¤tUser.Username,
|
||||||
|
}
|
||||||
|
api.PostFiles(uploadW, uploadReq, uploadParams)
|
||||||
|
|
||||||
|
uploadResp := uploadW.Result()
|
||||||
|
defer uploadResp.Body.Close()
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, uploadResp.StatusCode)
|
||||||
|
|
||||||
|
// --- Download with gzip ---
|
||||||
|
|
||||||
|
downloadReq := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(destPath), nil)
|
||||||
|
downloadReq.Header.Set("Accept-Encoding", "gzip")
|
||||||
|
downloadW := httptest.NewRecorder()
|
||||||
|
|
||||||
|
downloadParams := GetFilesParams{
|
||||||
|
Path: &destPath,
|
||||||
|
Username: ¤tUser.Username,
|
||||||
|
}
|
||||||
|
api.GetFiles(downloadW, downloadReq, downloadParams)
|
||||||
|
|
||||||
|
downloadResp := downloadW.Result()
|
||||||
|
defer downloadResp.Body.Close()
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, downloadResp.StatusCode)
|
||||||
|
assert.Equal(t, "gzip", downloadResp.Header.Get("Content-Encoding"))
|
||||||
|
|
||||||
|
// Decompress and verify content matches original
|
||||||
|
gzReader, err := gzip.NewReader(downloadResp.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer gzReader.Close()
|
||||||
|
|
||||||
|
decompressed, err := io.ReadAll(gzReader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, originalContent, decompressed)
|
||||||
|
}
|
||||||
229
envd/internal/api/encoding.go
Normal file
229
envd/internal/api/encoding.go
Normal file
@ -0,0 +1,229 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"compress/gzip"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"slices"
|
||||||
|
"sort"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Content-coding names used throughout the encoding negotiation logic.
const (
	// EncodingGzip is the gzip content encoding.
	EncodingGzip = "gzip"
	// EncodingIdentity means no encoding (passthrough).
	EncodingIdentity = "identity"
	// EncodingWildcard means any encoding is acceptable.
	EncodingWildcard = "*"
)

// SupportedEncodings lists the content encodings supported for file transfer.
// The order matters - encodings are checked in order of preference.
var SupportedEncodings = []string{
	EncodingGzip,
}

// encodingWithQuality holds an encoding name and its quality value.
type encodingWithQuality struct {
	encoding string  // lowercased content-coding name, e.g. "gzip"
	quality  float64 // RFC 7231 qvalue; 1.0 when no q parameter is given
}
|
||||||
|
|
||||||
|
// isSupportedEncoding checks if the given encoding is in the supported list.
|
||||||
|
// Per RFC 7231, content-coding values are case-insensitive.
|
||||||
|
func isSupportedEncoding(encoding string) bool {
|
||||||
|
return slices.Contains(SupportedEncodings, strings.ToLower(encoding))
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseEncodingWithQuality parses an encoding value and extracts the quality.
|
||||||
|
// Returns the encoding name (lowercased) and quality value (default 1.0 if not specified).
|
||||||
|
// Per RFC 7231, content-coding values are case-insensitive.
|
||||||
|
func parseEncodingWithQuality(value string) encodingWithQuality {
|
||||||
|
value = strings.TrimSpace(value)
|
||||||
|
quality := 1.0
|
||||||
|
|
||||||
|
if idx := strings.Index(value, ";"); idx != -1 {
|
||||||
|
params := value[idx+1:]
|
||||||
|
value = strings.TrimSpace(value[:idx])
|
||||||
|
|
||||||
|
// Parse q=X.X parameter
|
||||||
|
for param := range strings.SplitSeq(params, ";") {
|
||||||
|
param = strings.TrimSpace(param)
|
||||||
|
if strings.HasPrefix(strings.ToLower(param), "q=") {
|
||||||
|
if q, err := strconv.ParseFloat(param[2:], 64); err == nil {
|
||||||
|
quality = q
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normalize encoding to lowercase per RFC 7231
|
||||||
|
return encodingWithQuality{encoding: strings.ToLower(value), quality: quality}
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseEncoding extracts the encoding name from a header value, stripping quality.
|
||||||
|
func parseEncoding(value string) string {
|
||||||
|
return parseEncodingWithQuality(value).encoding
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseContentEncoding parses the Content-Encoding header and returns the encoding.
|
||||||
|
// Returns an error if an unsupported encoding is specified.
|
||||||
|
// If no Content-Encoding header is present, returns empty string.
|
||||||
|
func parseContentEncoding(r *http.Request) (string, error) {
|
||||||
|
header := r.Header.Get("Content-Encoding")
|
||||||
|
if header == "" {
|
||||||
|
return EncodingIdentity, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
encoding := parseEncoding(header)
|
||||||
|
|
||||||
|
if encoding == EncodingIdentity {
|
||||||
|
return EncodingIdentity, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isSupportedEncoding(encoding) {
|
||||||
|
return "", fmt.Errorf("unsupported Content-Encoding: %s, supported: %v", header, SupportedEncodings)
|
||||||
|
}
|
||||||
|
|
||||||
|
return encoding, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseAcceptEncodingHeader parses the Accept-Encoding header and returns
|
||||||
|
// the parsed encodings along with the identity rejection state.
|
||||||
|
// Per RFC 7231 Section 5.3.4, identity is acceptable unless excluded by
|
||||||
|
// "identity;q=0" or "*;q=0" without a more specific entry for identity with q>0.
|
||||||
|
func parseAcceptEncodingHeader(header string) ([]encodingWithQuality, bool) {
|
||||||
|
if header == "" {
|
||||||
|
return nil, false // identity not rejected when header is empty
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse all encodings with their quality values
|
||||||
|
var encodings []encodingWithQuality
|
||||||
|
for value := range strings.SplitSeq(header, ",") {
|
||||||
|
eq := parseEncodingWithQuality(value)
|
||||||
|
encodings = append(encodings, eq)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if identity is rejected per RFC 7231 Section 5.3.4:
|
||||||
|
// identity is acceptable unless excluded by "identity;q=0" or "*;q=0"
|
||||||
|
// without a more specific entry for identity with q>0.
|
||||||
|
identityRejected := false
|
||||||
|
identityExplicitlyAccepted := false
|
||||||
|
wildcardRejected := false
|
||||||
|
|
||||||
|
for _, eq := range encodings {
|
||||||
|
switch eq.encoding {
|
||||||
|
case EncodingIdentity:
|
||||||
|
if eq.quality == 0 {
|
||||||
|
identityRejected = true
|
||||||
|
} else {
|
||||||
|
identityExplicitlyAccepted = true
|
||||||
|
}
|
||||||
|
case EncodingWildcard:
|
||||||
|
if eq.quality == 0 {
|
||||||
|
wildcardRejected = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if wildcardRejected && !identityExplicitlyAccepted {
|
||||||
|
identityRejected = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return encodings, identityRejected
|
||||||
|
}
|
||||||
|
|
||||||
|
// isIdentityAcceptable checks if identity encoding is acceptable based on the
|
||||||
|
// Accept-Encoding header. Per RFC 7231 section 5.3.4, identity is always
|
||||||
|
// implicitly acceptable unless explicitly rejected with q=0.
|
||||||
|
func isIdentityAcceptable(r *http.Request) bool {
|
||||||
|
header := r.Header.Get("Accept-Encoding")
|
||||||
|
_, identityRejected := parseAcceptEncodingHeader(header)
|
||||||
|
|
||||||
|
return !identityRejected
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseAcceptEncoding parses the Accept-Encoding header and returns the best
|
||||||
|
// supported encoding based on quality values. Per RFC 7231 section 5.3.4,
|
||||||
|
// identity is always implicitly acceptable unless explicitly rejected with q=0.
|
||||||
|
// If no Accept-Encoding header is present, returns empty string (identity).
|
||||||
|
func parseAcceptEncoding(r *http.Request) (string, error) {
|
||||||
|
header := r.Header.Get("Accept-Encoding")
|
||||||
|
if header == "" {
|
||||||
|
return EncodingIdentity, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
encodings, identityRejected := parseAcceptEncodingHeader(header)
|
||||||
|
|
||||||
|
// Sort by quality value (highest first)
|
||||||
|
sort.Slice(encodings, func(i, j int) bool {
|
||||||
|
return encodings[i].quality > encodings[j].quality
|
||||||
|
})
|
||||||
|
|
||||||
|
// Find the best supported encoding
|
||||||
|
for _, eq := range encodings {
|
||||||
|
// Skip encodings with q=0 (explicitly rejected)
|
||||||
|
if eq.quality == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if eq.encoding == EncodingIdentity {
|
||||||
|
return EncodingIdentity, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wildcard means any encoding is acceptable - return a supported encoding if identity is rejected
|
||||||
|
if eq.encoding == EncodingWildcard {
|
||||||
|
if identityRejected && len(SupportedEncodings) > 0 {
|
||||||
|
return SupportedEncodings[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return EncodingIdentity, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if isSupportedEncoding(eq.encoding) {
|
||||||
|
return eq.encoding, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Per RFC 7231, identity is implicitly acceptable unless rejected
|
||||||
|
if !identityRejected {
|
||||||
|
return EncodingIdentity, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Identity rejected and no supported encodings found
|
||||||
|
return "", fmt.Errorf("no acceptable encoding found, supported: %v", SupportedEncodings)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getDecompressedBody returns a reader that decompresses the request body based on
|
||||||
|
// Content-Encoding header. Returns the original body if no encoding is specified.
|
||||||
|
// Returns an error if an unsupported encoding is specified.
|
||||||
|
// The caller is responsible for closing both the returned ReadCloser and the
|
||||||
|
// original request body (r.Body) separately.
|
||||||
|
func getDecompressedBody(r *http.Request) (io.ReadCloser, error) {
|
||||||
|
encoding, err := parseContentEncoding(r)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if encoding == EncodingIdentity {
|
||||||
|
return r.Body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch encoding {
|
||||||
|
case EncodingGzip:
|
||||||
|
gzReader, err := gzip.NewReader(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create gzip reader: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return gzReader, nil
|
||||||
|
default:
|
||||||
|
// This shouldn't happen if isSupportedEncoding is correct
|
||||||
|
return nil, fmt.Errorf("encoding %s is supported but not implemented", encoding)
|
||||||
|
}
|
||||||
|
}
|
||||||
496
envd/internal/api/encoding_test.go
Normal file
496
envd/internal/api/encoding_test.go
Normal file
@ -0,0 +1,496 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"compress/gzip"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestIsSupportedEncoding verifies the supported-encoding check, including
// RFC 7231 case-insensitive matching of content-coding names and rejection
// of codings outside SupportedEncodings.
func TestIsSupportedEncoding(t *testing.T) {
	t.Parallel()

	t.Run("gzip is supported", func(t *testing.T) {
		t.Parallel()
		assert.True(t, isSupportedEncoding("gzip"))
	})

	t.Run("GZIP is supported (case-insensitive)", func(t *testing.T) {
		t.Parallel()
		assert.True(t, isSupportedEncoding("GZIP"))
	})

	t.Run("Gzip is supported (case-insensitive)", func(t *testing.T) {
		t.Parallel()
		assert.True(t, isSupportedEncoding("Gzip"))
	})

	t.Run("br is not supported", func(t *testing.T) {
		t.Parallel()
		assert.False(t, isSupportedEncoding("br"))
	})

	t.Run("deflate is not supported", func(t *testing.T) {
		t.Parallel()
		assert.False(t, isSupportedEncoding("deflate"))
	})
}
|
||||||
|
|
||||||
|
// TestParseEncodingWithQuality covers quality parsing (including q=0 and
// malformed q values), whitespace trimming, and lowercase normalization of
// the content-coding name.
func TestParseEncodingWithQuality(t *testing.T) {
	t.Parallel()

	t.Run("returns encoding with default quality 1.0", func(t *testing.T) {
		t.Parallel()
		eq := parseEncodingWithQuality("gzip")
		assert.Equal(t, "gzip", eq.encoding)
		assert.InDelta(t, 1.0, eq.quality, 0.001)
	})

	t.Run("parses quality value", func(t *testing.T) {
		t.Parallel()
		eq := parseEncodingWithQuality("gzip;q=0.5")
		assert.Equal(t, "gzip", eq.encoding)
		assert.InDelta(t, 0.5, eq.quality, 0.001)
	})

	t.Run("parses quality value with whitespace", func(t *testing.T) {
		t.Parallel()
		eq := parseEncodingWithQuality("gzip ; q=0.8")
		assert.Equal(t, "gzip", eq.encoding)
		assert.InDelta(t, 0.8, eq.quality, 0.001)
	})

	t.Run("handles q=0", func(t *testing.T) {
		t.Parallel()
		eq := parseEncodingWithQuality("gzip;q=0")
		assert.Equal(t, "gzip", eq.encoding)
		assert.InDelta(t, 0.0, eq.quality, 0.001)
	})

	t.Run("handles invalid quality value", func(t *testing.T) {
		t.Parallel()
		eq := parseEncodingWithQuality("gzip;q=invalid")
		assert.Equal(t, "gzip", eq.encoding)
		assert.InDelta(t, 1.0, eq.quality, 0.001) // defaults to 1.0 on parse error
	})

	t.Run("trims whitespace from encoding", func(t *testing.T) {
		t.Parallel()
		eq := parseEncodingWithQuality(" gzip ")
		assert.Equal(t, "gzip", eq.encoding)
		assert.InDelta(t, 1.0, eq.quality, 0.001)
	})

	t.Run("normalizes encoding to lowercase", func(t *testing.T) {
		t.Parallel()
		eq := parseEncodingWithQuality("GZIP")
		assert.Equal(t, "gzip", eq.encoding)
	})

	t.Run("normalizes mixed case encoding", func(t *testing.T) {
		t.Parallel()
		eq := parseEncodingWithQuality("Gzip;q=0.5")
		assert.Equal(t, "gzip", eq.encoding)
		assert.InDelta(t, 0.5, eq.quality, 0.001)
	})
}
|
||||||
|
|
||||||
|
// TestParseEncoding verifies that quality parameters and surrounding
// whitespace are stripped, leaving just the coding name.
func TestParseEncoding(t *testing.T) {
	t.Parallel()

	t.Run("returns encoding as-is", func(t *testing.T) {
		t.Parallel()
		assert.Equal(t, "gzip", parseEncoding("gzip"))
	})

	t.Run("trims whitespace", func(t *testing.T) {
		t.Parallel()
		assert.Equal(t, "gzip", parseEncoding(" gzip "))
	})

	t.Run("strips quality value", func(t *testing.T) {
		t.Parallel()
		assert.Equal(t, "gzip", parseEncoding("gzip;q=1.0"))
	})

	t.Run("strips quality value with whitespace", func(t *testing.T) {
		t.Parallel()
		assert.Equal(t, "gzip", parseEncoding("gzip ; q=0.5"))
	})
}
|
||||||
|
|
||||||
|
// TestParseContentEncoding exercises Content-Encoding negotiation: missing
// header and explicit identity map to EncodingIdentity, gzip matches
// case-insensitively (with or without a q parameter), and unsupported
// codings produce an error naming the supported set.
func TestParseContentEncoding(t *testing.T) {
	t.Parallel()

	t.Run("returns identity when no header", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)

		encoding, err := parseContentEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, EncodingIdentity, encoding)
	})

	t.Run("returns gzip when Content-Encoding is gzip", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
		req.Header.Set("Content-Encoding", "gzip")

		encoding, err := parseContentEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, "gzip", encoding)
	})

	t.Run("returns gzip when Content-Encoding is GZIP (case-insensitive)", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
		req.Header.Set("Content-Encoding", "GZIP")

		encoding, err := parseContentEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, "gzip", encoding)
	})

	t.Run("returns gzip when Content-Encoding is Gzip (case-insensitive)", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
		req.Header.Set("Content-Encoding", "Gzip")

		encoding, err := parseContentEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, "gzip", encoding)
	})

	t.Run("returns identity for identity encoding", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
		req.Header.Set("Content-Encoding", "identity")

		encoding, err := parseContentEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, EncodingIdentity, encoding)
	})

	t.Run("returns error for unsupported encoding", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
		req.Header.Set("Content-Encoding", "br")

		_, err := parseContentEncoding(req)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "unsupported Content-Encoding")
		assert.Contains(t, err.Error(), "supported: [gzip]")
	})

	t.Run("handles gzip with quality value", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
		req.Header.Set("Content-Encoding", "gzip;q=1.0")

		encoding, err := parseContentEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, "gzip", encoding)
	})
}
|
||||||
|
|
||||||
|
// TestParseAcceptEncoding exercises Accept-Encoding negotiation end to end:
// default/identity/wildcard handling, quality-based selection, q=0
// rejection, RFC 7231's implicit acceptability of identity, and the error
// paths where identity is rejected with no supported alternative.
func TestParseAcceptEncoding(t *testing.T) {
	t.Parallel()

	t.Run("returns identity when no header", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)

		encoding, err := parseAcceptEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, EncodingIdentity, encoding)
	})

	t.Run("returns gzip when Accept-Encoding is gzip", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
		req.Header.Set("Accept-Encoding", "gzip")

		encoding, err := parseAcceptEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, "gzip", encoding)
	})

	t.Run("returns gzip when Accept-Encoding is GZIP (case-insensitive)", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
		req.Header.Set("Accept-Encoding", "GZIP")

		encoding, err := parseAcceptEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, "gzip", encoding)
	})

	t.Run("returns gzip when gzip is among multiple encodings", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
		req.Header.Set("Accept-Encoding", "deflate, gzip, br")

		encoding, err := parseAcceptEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, "gzip", encoding)
	})

	t.Run("returns gzip with quality value", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
		req.Header.Set("Accept-Encoding", "gzip;q=1.0")

		encoding, err := parseAcceptEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, "gzip", encoding)
	})

	t.Run("returns identity for identity encoding", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
		req.Header.Set("Accept-Encoding", "identity")

		encoding, err := parseAcceptEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, EncodingIdentity, encoding)
	})

	t.Run("returns identity for wildcard encoding", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
		req.Header.Set("Accept-Encoding", "*")

		encoding, err := parseAcceptEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, EncodingIdentity, encoding)
	})

	t.Run("falls back to identity for unsupported encoding only", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
		req.Header.Set("Accept-Encoding", "br")

		encoding, err := parseAcceptEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, EncodingIdentity, encoding)
	})

	t.Run("falls back to identity when only unsupported encodings", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
		req.Header.Set("Accept-Encoding", "deflate, br")

		encoding, err := parseAcceptEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, EncodingIdentity, encoding)
	})

	t.Run("selects gzip when it has highest quality", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
		req.Header.Set("Accept-Encoding", "br;q=0.5, gzip;q=1.0, deflate;q=0.8")

		encoding, err := parseAcceptEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, "gzip", encoding)
	})

	t.Run("selects gzip even with lower quality when others unsupported", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
		req.Header.Set("Accept-Encoding", "br;q=1.0, gzip;q=0.5")

		encoding, err := parseAcceptEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, "gzip", encoding)
	})

	t.Run("returns identity when it has higher quality than gzip", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
		req.Header.Set("Accept-Encoding", "gzip;q=0.5, identity;q=1.0")

		encoding, err := parseAcceptEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, EncodingIdentity, encoding)
	})

	t.Run("skips encoding with q=0", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
		req.Header.Set("Accept-Encoding", "gzip;q=0, identity")

		encoding, err := parseAcceptEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, EncodingIdentity, encoding)
	})

	t.Run("falls back to identity when gzip rejected and no other supported", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
		req.Header.Set("Accept-Encoding", "gzip;q=0, br")

		encoding, err := parseAcceptEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, EncodingIdentity, encoding)
	})

	t.Run("returns error when identity explicitly rejected and no supported encoding", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
		req.Header.Set("Accept-Encoding", "br, identity;q=0")

		_, err := parseAcceptEncoding(req)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "no acceptable encoding found")
	})

	t.Run("returns gzip for wildcard when identity rejected", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
		req.Header.Set("Accept-Encoding", "*, identity;q=0")

		encoding, err := parseAcceptEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, "gzip", encoding) // wildcard with identity rejected returns supported encoding
	})

	t.Run("returns error when wildcard rejected and no explicit identity", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
		req.Header.Set("Accept-Encoding", "*;q=0")

		_, err := parseAcceptEncoding(req)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "no acceptable encoding found")
	})

	t.Run("returns identity when wildcard rejected but identity explicitly accepted", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
		req.Header.Set("Accept-Encoding", "*;q=0, identity")

		encoding, err := parseAcceptEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, EncodingIdentity, encoding)
	})

	t.Run("returns gzip when wildcard rejected but gzip explicitly accepted", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
		req.Header.Set("Accept-Encoding", "*;q=0, gzip")

		encoding, err := parseAcceptEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, EncodingGzip, encoding)
	})
}
|
||||||
|
|
||||||
|
// TestGetDecompressedBody verifies request-body decoding: identity and the
// no-header case return the untouched body, gzip bodies (with or without a
// q parameter) are transparently decompressed, corrupt gzip data fails at
// reader construction, and unsupported codings surface an error.
func TestGetDecompressedBody(t *testing.T) {
	t.Parallel()

	t.Run("returns original body when no Content-Encoding header", func(t *testing.T) {
		t.Parallel()
		content := []byte("test content")
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(content))

		body, err := getDecompressedBody(req)
		require.NoError(t, err)
		assert.Equal(t, req.Body, body, "should return original body")

		data, err := io.ReadAll(body)
		require.NoError(t, err)
		assert.Equal(t, content, data)
	})

	t.Run("decompresses gzip body when Content-Encoding is gzip", func(t *testing.T) {
		t.Parallel()
		originalContent := []byte("test content to compress")

		// Build a valid gzip payload for the request body.
		var compressed bytes.Buffer
		gw := gzip.NewWriter(&compressed)
		_, err := gw.Write(originalContent)
		require.NoError(t, err)
		err = gw.Close()
		require.NoError(t, err)

		req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(compressed.Bytes()))
		req.Header.Set("Content-Encoding", "gzip")

		body, err := getDecompressedBody(req)
		require.NoError(t, err)
		defer body.Close()

		assert.NotEqual(t, req.Body, body, "should return a new gzip reader")

		data, err := io.ReadAll(body)
		require.NoError(t, err)
		assert.Equal(t, originalContent, data)
	})

	t.Run("returns error for invalid gzip data", func(t *testing.T) {
		t.Parallel()
		invalidGzip := []byte("this is not gzip data")
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(invalidGzip))
		req.Header.Set("Content-Encoding", "gzip")

		_, err := getDecompressedBody(req)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "failed to create gzip reader")
	})

	t.Run("returns original body for identity encoding", func(t *testing.T) {
		t.Parallel()
		content := []byte("test content")
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(content))
		req.Header.Set("Content-Encoding", "identity")

		body, err := getDecompressedBody(req)
		require.NoError(t, err)
		assert.Equal(t, req.Body, body, "should return original body")

		data, err := io.ReadAll(body)
		require.NoError(t, err)
		assert.Equal(t, content, data)
	})

	t.Run("returns error for unsupported encoding", func(t *testing.T) {
		t.Parallel()
		content := []byte("test content")
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(content))
		req.Header.Set("Content-Encoding", "br")

		_, err := getDecompressedBody(req)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "unsupported Content-Encoding")
	})

	t.Run("handles gzip with quality value", func(t *testing.T) {
		t.Parallel()
		originalContent := []byte("test content to compress")

		// Build a valid gzip payload for the request body.
		var compressed bytes.Buffer
		gw := gzip.NewWriter(&compressed)
		_, err := gw.Write(originalContent)
		require.NoError(t, err)
		err = gw.Close()
		require.NoError(t, err)

		req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(compressed.Bytes()))
		req.Header.Set("Content-Encoding", "gzip;q=1.0")

		body, err := getDecompressedBody(req)
		require.NoError(t, err)
		defer body.Close()

		data, err := io.ReadAll(body)
		require.NoError(t, err)
		assert.Equal(t, originalContent, data)
	})
}
|
||||||
31
envd/internal/api/envs.go
Normal file
31
envd/internal/api/envs.go
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (a *API) GetEnvs(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
operationID := logs.AssignOperationID()
|
||||||
|
|
||||||
|
a.logger.Debug().Str(string(logs.OperationIDKey), operationID).Msg("Getting env vars")
|
||||||
|
|
||||||
|
envs := make(EnvVars)
|
||||||
|
a.defaults.EnvVars.Range(func(key, value string) bool {
|
||||||
|
envs[key] = value
|
||||||
|
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
w.Header().Set("Cache-Control", "no-store")
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(envs); err != nil {
|
||||||
|
a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("Failed to encode env vars")
|
||||||
|
}
|
||||||
|
}
|
||||||
23
envd/internal/api/error.go
Normal file
23
envd/internal/api/error.go
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
func jsonError(w http.ResponseWriter, code int, err error) {
|
||||||
|
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||||
|
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||||
|
|
||||||
|
w.WriteHeader(code)
|
||||||
|
encodeErr := json.NewEncoder(w).Encode(Error{
|
||||||
|
Code: code,
|
||||||
|
Message: err.Error(),
|
||||||
|
})
|
||||||
|
if encodeErr != nil {
|
||||||
|
http.Error(w, errors.Join(encodeErr, err).Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
}
|
||||||
5
envd/internal/api/generate.go
Normal file
5
envd/internal/api/generate.go
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package api
|
||||||
|
|
||||||
|
//go:generate go run github.com/oapi-codegen/oapi-codegen/v2/cmd/oapi-codegen -config cfg.yaml ../../spec/envd.yaml
|
||||||
317
envd/internal/api/init.go
Normal file
317
envd/internal/api/init.go
Normal file
@ -0,0 +1,317 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
// Modifications by M/S Omukk
|
||||||
|
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/netip"
|
||||||
|
"os/exec"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/awnumar/memguard"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/txn2/txeh"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/host"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/keys"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Sentinel errors returned by access-token validation during /init.
var (
	// ErrAccessTokenMismatch is returned when a supplied token matches
	// neither the stored token nor the MMDS-authorized hash.
	ErrAccessTokenMismatch = errors.New("access token validation failed")
	// ErrAccessTokenResetNotAuthorized is returned when a request omits the
	// token while one is set and no reset has been authorized via MMDS.
	ErrAccessTokenResetNotAuthorized = errors.New("access token reset not authorized")
)
|
||||||
|
|
||||||
|
// Acceptable drift window between the sandbox clock and the timestamp
// supplied in the init request; outside this window the clock is reset
// (see shouldSetSystemTime).
const (
	maxTimeInPast   = 50 * time.Millisecond
	maxTimeInFuture = 5 * time.Second
)
|
||||||
|
|
||||||
|
// validateInitAccessToken validates the access token for /init requests.
// Token is valid if it matches the existing token OR the MMDS hash.
// If neither exists, first-time setup is allowed.
//
// Returns nil on success, ErrAccessTokenResetNotAuthorized when a token is
// set but the request provides none, and ErrAccessTokenMismatch otherwise.
func (a *API) validateInitAccessToken(ctx context.Context, requestToken *SecureToken) error {
	requestTokenSet := requestToken.IsSet()

	// Fast path: token matches existing
	if a.accessToken.IsSet() && requestTokenSet && a.accessToken.EqualsSecure(requestToken) {
		return nil
	}

	// Check MMDS only if token didn't match existing
	matchesMMDS, mmdsExists := a.checkMMDSHash(ctx, requestToken)

	switch {
	case matchesMMDS:
		// Orchestrator explicitly authorized this token (or a reset) via MMDS.
		return nil
	case !a.accessToken.IsSet() && !mmdsExists:
		return nil // first-time setup
	case !requestTokenSet:
		// A token exists but the request carries none and MMDS did not
		// authorize clearing it.
		return ErrAccessTokenResetNotAuthorized
	default:
		return ErrAccessTokenMismatch
	}
}
|
||||||
|
|
||||||
|
// checkMMDSHash checks if the request token matches the MMDS hash.
// Returns (matches, mmdsExists).
//
// The MMDS hash is set by the orchestrator during Resume:
// - hash(token): requires this specific token
// - hash(""): explicitly allows nil token (token reset authorized)
// - "": MMDS not properly configured, no authorization granted
func (a *API) checkMMDSHash(ctx context.Context, requestToken *SecureToken) (bool, bool) {
	// Outside Firecracker there is no MMDS to consult.
	if a.isNotFC {
		return false, false
	}

	mmdsHash, err := a.mmdsClient.GetAccessTokenHash(ctx)
	if err != nil {
		return false, false
	}

	if mmdsHash == "" {
		return false, false
	}

	// Nil request token matches only an explicit reset authorization (hash of "").
	if !requestToken.IsSet() {
		return mmdsHash == keys.HashAccessToken(""), true
	}

	tokenBytes, err := requestToken.Bytes()
	if err != nil {
		return false, true
	}
	// Zero the plaintext copy of the token once the comparison is done.
	defer memguard.WipeBytes(tokenBytes)

	return keys.HashAccessTokenBytes(tokenBytes) == mmdsHash, true
}
|
||||||
|
|
||||||
|
// PostInit handles POST /init: it reads and parses the optional JSON body,
// applies it via SetData under initLock (guarded by a timestamp so only the
// newest request wins), kicks off background MMDS option polling, and
// responds 204 on success. Secret material (raw body, access token) is
// wiped/destroyed via defers on every exit path.
func (a *API) PostInit(w http.ResponseWriter, r *http.Request) {
	defer r.Body.Close()

	ctx := r.Context()

	operationID := logs.AssignOperationID()
	logger := a.logger.With().Str(string(logs.OperationIDKey), operationID).Logger()

	if r.Body != nil {
		// Read raw body so we can wipe it after parsing
		body, err := io.ReadAll(r.Body)
		// Ensure body is wiped after we're done
		defer memguard.WipeBytes(body)
		if err != nil {
			logger.Error().Msgf("Failed to read request body: %v", err)
			w.WriteHeader(http.StatusBadRequest)

			return
		}

		var initRequest PostInitJSONBody
		if len(body) > 0 {
			err = json.Unmarshal(body, &initRequest)
			if err != nil {
				logger.Error().Msgf("Failed to decode request: %v", err)
				w.WriteHeader(http.StatusBadRequest)

				return
			}
		}

		// Ensure request token is destroyed if not transferred via TakeFrom.
		// This handles: validation failures, timestamp-based skips, and any early returns.
		// Safe because Destroy() is nil-safe and TakeFrom clears the source.
		defer initRequest.AccessToken.Destroy()

		a.initLock.Lock()
		defer a.initLock.Unlock()

		// Update data only if the request is newer or if there's no timestamp at all
		if initRequest.Timestamp == nil || a.lastSetTime.SetToGreater(initRequest.Timestamp.UnixNano()) {
			err = a.SetData(ctx, logger, initRequest)
			if err != nil {
				switch {
				case errors.Is(err, ErrAccessTokenMismatch), errors.Is(err, ErrAccessTokenResetNotAuthorized):
					w.WriteHeader(http.StatusUnauthorized)
				default:
					logger.Error().Msgf("Failed to set data: %v", err)
					w.WriteHeader(http.StatusBadRequest)
				}
				// NOTE(review): write error is intentionally unchecked here —
				// the error body is best-effort after the status is set.
				w.Write([]byte(err.Error()))

				return
			}
		}
	}

	// Poll MMDS for options in the background; detached from the request
	// context on purpose, bounded by its own 60s timeout.
	go func() { //nolint:contextcheck // TODO: fix this later
		ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
		defer cancel()
		host.PollForMMDSOpts(ctx, a.mmdsChan, a.defaults.EnvVars)
	}()

	w.Header().Set("Cache-Control", "no-store")
	w.Header().Set("Content-Type", "")

	w.WriteHeader(http.StatusNoContent)
}
|
||||||
|
|
||||||
|
// SetData applies an init payload to the sandbox after validating the access
// token: it may adjust the system clock, merge env vars, replace or clear the
// stored access token, set the default user/workdir, and asynchronously set
// up the hyperloop hosts entry and NFS volume mounts. It returns a token
// validation error before making any changes; later per-field failures (e.g.
// clock set) are logged but do not abort the rest.
func (a *API) SetData(ctx context.Context, logger zerolog.Logger, data PostInitJSONBody) error {
	// Validate access token before proceeding with any action
	// The request must provide a token that is either:
	// 1. Matches the existing access token (if set), OR
	// 2. Matches the MMDS hash (for token change during resume)
	if err := a.validateInitAccessToken(ctx, data.AccessToken); err != nil {
		return err
	}

	if data.Timestamp != nil {
		// Check if current time differs significantly from the received timestamp
		if shouldSetSystemTime(time.Now(), *data.Timestamp) {
			logger.Debug().Msgf("Setting sandbox start time to: %v", *data.Timestamp)
			ts := unix.NsecToTimespec(data.Timestamp.UnixNano())
			err := unix.ClockSettime(unix.CLOCK_REALTIME, &ts)
			if err != nil {
				// Non-fatal: continue applying the rest of the payload.
				logger.Error().Msgf("Failed to set system time: %v", err)
			}
		} else {
			logger.Debug().Msgf("Current time is within acceptable range of timestamp %v, not setting system time", *data.Timestamp)
		}
	}

	if data.EnvVars != nil {
		logger.Debug().Msg(fmt.Sprintf("Setting %d env vars", len(*data.EnvVars)))

		for key, value := range *data.EnvVars {
			logger.Debug().Msgf("Setting env var for %s", key)
			a.defaults.EnvVars.Store(key, value)
		}
	}

	// Token transfer: TakeFrom moves ownership out of the request token so
	// the caller's deferred Destroy does not wipe the stored copy.
	if data.AccessToken.IsSet() {
		logger.Debug().Msg("Setting access token")
		a.accessToken.TakeFrom(data.AccessToken)
	} else if a.accessToken.IsSet() {
		// Reaching here with a nil request token means validation authorized
		// an explicit reset (MMDS hash of "").
		logger.Debug().Msg("Clearing access token")
		a.accessToken.Destroy()
	}

	if data.HyperloopIP != nil {
		go a.SetupHyperloop(*data.HyperloopIP)
	}

	if data.DefaultUser != nil && *data.DefaultUser != "" {
		logger.Debug().Msgf("Setting default user to: %s", *data.DefaultUser)
		a.defaults.User = *data.DefaultUser
	}

	if data.DefaultWorkdir != nil && *data.DefaultWorkdir != "" {
		logger.Debug().Msgf("Setting default workdir to: %s", *data.DefaultWorkdir)
		a.defaults.Workdir = data.DefaultWorkdir
	}

	if data.VolumeMounts != nil {
		for _, volume := range *data.VolumeMounts {
			logger.Debug().Msgf("Mounting %s at %q", volume.NfsTarget, volume.Path)

			// Detached from request cancellation so mounts survive the request.
			go a.setupNfs(context.WithoutCancel(ctx), volume.NfsTarget, volume.Path)
		}
	}

	return nil
}
|
||||||
|
|
||||||
|
func (a *API) setupNfs(ctx context.Context, nfsTarget, path string) {
|
||||||
|
commands := [][]string{
|
||||||
|
{"mkdir", "-p", path},
|
||||||
|
{"mount", "-v", "-t", "nfs", "-o", "mountproto=tcp,mountport=2049,proto=tcp,port=2049,nfsvers=3,noacl", nfsTarget, path},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, command := range commands {
|
||||||
|
data, err := exec.CommandContext(ctx, command[0], command[1:]...).CombinedOutput()
|
||||||
|
|
||||||
|
logger := a.getLogger(err)
|
||||||
|
|
||||||
|
logger.
|
||||||
|
Strs("command", command).
|
||||||
|
Str("output", string(data)).
|
||||||
|
Msg("Mount NFS")
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *API) SetupHyperloop(address string) {
|
||||||
|
a.hyperloopLock.Lock()
|
||||||
|
defer a.hyperloopLock.Unlock()
|
||||||
|
|
||||||
|
if err := rewriteHostsFile(address, "/etc/hosts"); err != nil {
|
||||||
|
a.logger.Error().Err(err).Msg("failed to modify hosts file")
|
||||||
|
} else {
|
||||||
|
a.defaults.EnvVars.Store("WRENN_EVENTS_ADDRESS", fmt.Sprintf("http://%s", address))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// eventsHost is the hostname mapped to the hyperloop IP in /etc/hosts.
const eventsHost = "events.wrenn.local"
|
||||||
|
|
||||||
|
// rewriteHostsFile maps eventsHost to address in the hosts file at path.
// It is a no-op when the mapping already exists; otherwise it adds the
// entry and saves the file in place.
func rewriteHostsFile(address, path string) error {
	hosts, err := txeh.NewHosts(&txeh.HostsConfig{
		ReadFilePath:  path,
		WriteFilePath: path,
	})
	if err != nil {
		return fmt.Errorf("failed to create hosts: %w", err)
	}

	// Update /etc/hosts to point events.wrenn.local to the hyperloop IP
	// This will remove any existing entries for events.wrenn.local first
	ipFamily, err := getIPFamily(address)
	if err != nil {
		return fmt.Errorf("failed to get ip family: %w", err)
	}

	// Already pointing at the requested address: avoid a needless rewrite.
	if ok, current, _ := hosts.HostAddressLookup(eventsHost, ipFamily); ok && current == address {
		return nil // nothing to be done
	}

	hosts.AddHost(address, eventsHost)

	return hosts.Save()
}
|
||||||
|
|
||||||
|
// Errors for IP address classification.
var (
	// ErrInvalidAddress marks an unparseable IP address.
	// NOTE(review): not referenced in this file's visible code — confirm it
	// is used elsewhere before removing.
	ErrInvalidAddress = errors.New("invalid IP address")
	// ErrUnknownAddressFormat marks an address that is neither IPv4 nor IPv6.
	ErrUnknownAddressFormat = errors.New("unknown IP address format")
)
|
||||||
|
|
||||||
|
func getIPFamily(address string) (txeh.IPFamily, error) {
|
||||||
|
addressIP, err := netip.ParseAddr(address)
|
||||||
|
if err != nil {
|
||||||
|
return txeh.IPFamilyV4, fmt.Errorf("failed to parse IP address: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case addressIP.Is4():
|
||||||
|
return txeh.IPFamilyV4, nil
|
||||||
|
case addressIP.Is6():
|
||||||
|
return txeh.IPFamilyV6, nil
|
||||||
|
default:
|
||||||
|
return txeh.IPFamilyV4, fmt.Errorf("%w: %s", ErrUnknownAddressFormat, address)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// shouldSetSystemTime returns true if the current time differs significantly from the received timestamp,
|
||||||
|
// indicating the system clock should be adjusted. Returns true when the sandboxTime is more than
|
||||||
|
// maxTimeInPast before the hostTime or more than maxTimeInFuture after the hostTime.
|
||||||
|
func shouldSetSystemTime(sandboxTime, hostTime time.Time) bool {
|
||||||
|
return sandboxTime.Before(hostTime.Add(-maxTimeInPast)) || sandboxTime.After(hostTime.Add(maxTimeInFuture))
|
||||||
|
}
|
||||||
590
envd/internal/api/init_test.go
Normal file
590
envd/internal/api/init_test.go
Normal file
@ -0,0 +1,590 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
// Modifications by M/S Omukk
|
||||||
|
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/keys"
|
||||||
|
utilsShared "git.omukk.dev/wrenn/sandbox/envd/internal/shared/utils"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestSimpleCases verifies rewriteHostsFile appends the events.wrenn.local
// entry while preserving existing comments and host lines, regardless of
// leading/trailing newlines in the input file.
func TestSimpleCases(t *testing.T) {
	t.Parallel()
	testCases := map[string]func(string) string{
		"both newlines":               func(s string) string { return s },
		"no newline prefix":           func(s string) string { return strings.TrimPrefix(s, "\n") },
		"no newline suffix":           func(s string) string { return strings.TrimSuffix(s, "\n") },
		"no newline prefix or suffix": strings.TrimSpace,
	}

	for name, preprocessor := range testCases {
		t.Run(name, func(t *testing.T) {
			t.Parallel()
			tempDir := t.TempDir()

			value := `
# comment
127.0.0.1 one.host
127.0.0.2 two.host
`
			value = preprocessor(value)
			inputPath := filepath.Join(tempDir, "hosts")
			err := os.WriteFile(inputPath, []byte(value), 0o644)
			require.NoError(t, err)

			err = rewriteHostsFile("127.0.0.3", inputPath)
			require.NoError(t, err)

			data, err := os.ReadFile(inputPath)
			require.NoError(t, err)

			assert.Equal(t, `# comment
127.0.0.1 one.host
127.0.0.2 two.host
127.0.0.3 events.wrenn.local`, strings.TrimSpace(string(data)))
		})
	}
}
|
||||||
|
|
||||||
|
// TestShouldSetSystemTime exercises the drift window around the host time.
// NOTE(review): the call below passes (tt.hostTime, sandboxTime) — the
// reverse of shouldSetSystemTime's (sandboxTime, hostTime) parameter order.
// The expectations are consistent with the swapped order (e.g. the
// maxTimeInFuture boundary case would flip with the documented order);
// confirm which orientation is intended.
func TestShouldSetSystemTime(t *testing.T) {
	t.Parallel()
	sandboxTime := time.Now()

	tests := []struct {
		name     string
		hostTime time.Time
		want     bool
	}{
		{
			name:     "sandbox time far ahead of host time (should set)",
			hostTime: sandboxTime.Add(-10 * time.Second),
			want:     true,
		},
		{
			name:     "sandbox time at maxTimeInPast boundary ahead of host time (should not set)",
			hostTime: sandboxTime.Add(-50 * time.Millisecond),
			want:     false,
		},
		{
			name:     "sandbox time just within maxTimeInPast ahead of host time (should not set)",
			hostTime: sandboxTime.Add(-40 * time.Millisecond),
			want:     false,
		},
		{
			name:     "sandbox time slightly ahead of host time (should not set)",
			hostTime: sandboxTime.Add(-10 * time.Millisecond),
			want:     false,
		},
		{
			name:     "sandbox time equals host time (should not set)",
			hostTime: sandboxTime,
			want:     false,
		},
		{
			name:     "sandbox time slightly behind host time (should not set)",
			hostTime: sandboxTime.Add(1 * time.Second),
			want:     false,
		},
		{
			name:     "sandbox time just within maxTimeInFuture behind host time (should not set)",
			hostTime: sandboxTime.Add(4 * time.Second),
			want:     false,
		},
		{
			name:     "sandbox time at maxTimeInFuture boundary behind host time (should not set)",
			hostTime: sandboxTime.Add(5 * time.Second),
			want:     false,
		},
		{
			name:     "sandbox time far behind host time (should set)",
			hostTime: sandboxTime.Add(1 * time.Minute),
			want:     true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			got := shouldSetSystemTime(tt.hostTime, sandboxTime)
			assert.Equal(t, tt.want, got)
		})
	}
}
|
||||||
|
|
||||||
|
// secureTokenPtr builds a *SecureToken holding s; the Set error is ignored
// because test inputs are always non-empty.
func secureTokenPtr(s string) *SecureToken {
	token := &SecureToken{}
	_ = token.Set([]byte(s))

	return token
}
|
||||||
|
|
||||||
|
// mockMMDSClient is a canned MMDSClient: it returns the configured hash
// and error from GetAccessTokenHash.
type mockMMDSClient struct {
	hash string // hash to return
	err  error  // error to return
}
|
||||||
|
|
||||||
|
// GetAccessTokenHash returns the mock's configured hash and error.
func (m *mockMMDSClient) GetAccessTokenHash(_ context.Context) (string, error) {
	return m.hash, m.err
}
|
||||||
|
|
||||||
|
// newTestAPI builds an API with a no-op logger and empty env defaults,
// optionally seeding it with an access token, and wires in the given
// (mock) MMDS client.
func newTestAPI(accessToken *SecureToken, mmdsClient MMDSClient) *API {
	logger := zerolog.Nop()
	defaults := &execcontext.Defaults{
		EnvVars: utils.NewMap[string, string](),
	}
	api := New(&logger, defaults, nil, false)
	if accessToken != nil {
		api.accessToken.TakeFrom(accessToken)
	}
	api.mmdsClient = mmdsClient

	return api
}
|
||||||
|
|
||||||
|
// TestValidateInitAccessToken covers the token validation matrix: fast-path
// match, MMDS-authorized change, first-time setup, mismatches, and
// unauthorized resets.
func TestValidateInitAccessToken(t *testing.T) {
	t.Parallel()
	ctx := t.Context()

	tests := []struct {
		name         string
		accessToken  *SecureToken // token already stored on the API (nil = unset)
		requestToken *SecureToken // token supplied by the request (nil = unset)
		mmdsHash     string
		mmdsErr      error
		wantErr      error
	}{
		{
			name:         "fast path: token matches existing",
			accessToken:  secureTokenPtr("secret-token"),
			requestToken: secureTokenPtr("secret-token"),
			mmdsHash:     "",
			mmdsErr:      nil,
			wantErr:      nil,
		},
		{
			name:         "MMDS match: token hash matches MMDS hash",
			accessToken:  secureTokenPtr("old-token"),
			requestToken: secureTokenPtr("new-token"),
			mmdsHash:     keys.HashAccessToken("new-token"),
			mmdsErr:      nil,
			wantErr:      nil,
		},
		{
			name:         "first-time setup: no existing token, MMDS error",
			accessToken:  nil,
			requestToken: secureTokenPtr("new-token"),
			mmdsHash:     "",
			mmdsErr:      assert.AnError,
			wantErr:      nil,
		},
		{
			name:         "first-time setup: no existing token, empty MMDS hash",
			accessToken:  nil,
			requestToken: secureTokenPtr("new-token"),
			mmdsHash:     "",
			mmdsErr:      nil,
			wantErr:      nil,
		},
		{
			name:         "first-time setup: both tokens nil, no MMDS",
			accessToken:  nil,
			requestToken: nil,
			mmdsHash:     "",
			mmdsErr:      assert.AnError,
			wantErr:      nil,
		},
		{
			name:         "mismatch: existing token differs from request, no MMDS",
			accessToken:  secureTokenPtr("existing-token"),
			requestToken: secureTokenPtr("wrong-token"),
			mmdsHash:     "",
			mmdsErr:      assert.AnError,
			wantErr:      ErrAccessTokenMismatch,
		},
		{
			name:         "mismatch: existing token differs from request, MMDS hash mismatch",
			accessToken:  secureTokenPtr("existing-token"),
			requestToken: secureTokenPtr("wrong-token"),
			mmdsHash:     keys.HashAccessToken("different-token"),
			mmdsErr:      nil,
			wantErr:      ErrAccessTokenMismatch,
		},
		{
			name:         "conflict: existing token, nil request, MMDS exists",
			accessToken:  secureTokenPtr("existing-token"),
			requestToken: nil,
			mmdsHash:     keys.HashAccessToken("some-token"),
			mmdsErr:      nil,
			wantErr:      ErrAccessTokenResetNotAuthorized,
		},
		{
			name:         "conflict: existing token, nil request, no MMDS",
			accessToken:  secureTokenPtr("existing-token"),
			requestToken: nil,
			mmdsHash:     "",
			mmdsErr:      assert.AnError,
			wantErr:      ErrAccessTokenResetNotAuthorized,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			mmdsClient := &mockMMDSClient{hash: tt.mmdsHash, err: tt.mmdsErr}
			api := newTestAPI(tt.accessToken, mmdsClient)

			err := api.validateInitAccessToken(ctx, tt.requestToken)

			if tt.wantErr != nil {
				require.Error(t, err)
				assert.ErrorIs(t, err, tt.wantErr)
			} else {
				require.NoError(t, err)
			}
		})
	}
}
|
||||||
|
|
||||||
|
// TestCheckMMDSHash covers checkMMDSHash's (matches, mmdsExists) results for
// matching/mismatching hashes, nil request tokens, MMDS errors, empty
// hashes, and the explicit-reset hash("") case.
func TestCheckMMDSHash(t *testing.T) {
	t.Parallel()
	ctx := t.Context()

	t.Run("returns match when token hash equals MMDS hash", func(t *testing.T) {
		t.Parallel()
		token := "my-secret-token"
		mmdsClient := &mockMMDSClient{hash: keys.HashAccessToken(token), err: nil}
		api := newTestAPI(nil, mmdsClient)

		matches, exists := api.checkMMDSHash(ctx, secureTokenPtr(token))

		assert.True(t, matches)
		assert.True(t, exists)
	})

	t.Run("returns no match when token hash differs from MMDS hash", func(t *testing.T) {
		t.Parallel()
		mmdsClient := &mockMMDSClient{hash: keys.HashAccessToken("different-token"), err: nil}
		api := newTestAPI(nil, mmdsClient)

		matches, exists := api.checkMMDSHash(ctx, secureTokenPtr("my-token"))

		assert.False(t, matches)
		assert.True(t, exists)
	})

	t.Run("returns exists but no match when request token is nil", func(t *testing.T) {
		t.Parallel()
		mmdsClient := &mockMMDSClient{hash: keys.HashAccessToken("some-token"), err: nil}
		api := newTestAPI(nil, mmdsClient)

		matches, exists := api.checkMMDSHash(ctx, nil)

		assert.False(t, matches)
		assert.True(t, exists)
	})

	t.Run("returns false, false when MMDS returns error", func(t *testing.T) {
		t.Parallel()
		mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
		api := newTestAPI(nil, mmdsClient)

		matches, exists := api.checkMMDSHash(ctx, secureTokenPtr("any-token"))

		assert.False(t, matches)
		assert.False(t, exists)
	})

	t.Run("returns false, false when MMDS returns empty hash with non-nil request", func(t *testing.T) {
		t.Parallel()
		mmdsClient := &mockMMDSClient{hash: "", err: nil}
		api := newTestAPI(nil, mmdsClient)

		matches, exists := api.checkMMDSHash(ctx, secureTokenPtr("any-token"))

		assert.False(t, matches)
		assert.False(t, exists)
	})

	t.Run("returns false, false when MMDS returns empty hash with nil request", func(t *testing.T) {
		t.Parallel()
		mmdsClient := &mockMMDSClient{hash: "", err: nil}
		api := newTestAPI(nil, mmdsClient)

		matches, exists := api.checkMMDSHash(ctx, nil)

		assert.False(t, matches)
		assert.False(t, exists)
	})

	t.Run("returns true, true when MMDS returns hash of empty string with nil request (explicit reset)", func(t *testing.T) {
		t.Parallel()
		mmdsClient := &mockMMDSClient{hash: keys.HashAccessToken(""), err: nil}
		api := newTestAPI(nil, mmdsClient)

		matches, exists := api.checkMMDSHash(ctx, nil)

		assert.True(t, matches)
		assert.True(t, exists)
	})
}
|
||||||
|
|
||||||
|
// TestSetData covers SetData end-to-end: access-token transitions (setup,
// re-init, MMDS-authorized change/reset, rejections) plus env vars,
// default user, and default workdir handling.
func TestSetData(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	logger := zerolog.Nop()

	t.Run("access token updates", func(t *testing.T) {
		t.Parallel()

		tests := []struct {
			name           string
			existingToken  *SecureToken // token stored on the API before the call
			requestToken   *SecureToken // token carried by the init payload
			mmdsHash       string
			mmdsErr        error
			wantErr        error
			wantFinalToken *SecureToken // expected stored token afterwards (nil = unset)
		}{
			{
				name:           "first-time setup: sets initial token",
				existingToken:  nil,
				requestToken:   secureTokenPtr("initial-token"),
				mmdsHash:       "",
				mmdsErr:        assert.AnError,
				wantErr:        nil,
				wantFinalToken: secureTokenPtr("initial-token"),
			},
			{
				name:           "first-time setup: nil request token leaves token unset",
				existingToken:  nil,
				requestToken:   nil,
				mmdsHash:       "",
				mmdsErr:        assert.AnError,
				wantErr:        nil,
				wantFinalToken: nil,
			},
			{
				name:           "re-init with same token: token unchanged",
				existingToken:  secureTokenPtr("same-token"),
				requestToken:   secureTokenPtr("same-token"),
				mmdsHash:       "",
				mmdsErr:        assert.AnError,
				wantErr:        nil,
				wantFinalToken: secureTokenPtr("same-token"),
			},
			{
				name:           "resume with MMDS: updates token when hash matches",
				existingToken:  secureTokenPtr("old-token"),
				requestToken:   secureTokenPtr("new-token"),
				mmdsHash:       keys.HashAccessToken("new-token"),
				mmdsErr:        nil,
				wantErr:        nil,
				wantFinalToken: secureTokenPtr("new-token"),
			},
			{
				name:           "resume with MMDS: fails when hash doesn't match",
				existingToken:  secureTokenPtr("old-token"),
				requestToken:   secureTokenPtr("new-token"),
				mmdsHash:       keys.HashAccessToken("different-token"),
				mmdsErr:        nil,
				wantErr:        ErrAccessTokenMismatch,
				wantFinalToken: secureTokenPtr("old-token"),
			},
			{
				name:           "fails when existing token and request token mismatch without MMDS",
				existingToken:  secureTokenPtr("existing-token"),
				requestToken:   secureTokenPtr("wrong-token"),
				mmdsHash:       "",
				mmdsErr:        assert.AnError,
				wantErr:        ErrAccessTokenMismatch,
				wantFinalToken: secureTokenPtr("existing-token"),
			},
			{
				name:           "conflict when existing token but nil request token",
				existingToken:  secureTokenPtr("existing-token"),
				requestToken:   nil,
				mmdsHash:       "",
				mmdsErr:        assert.AnError,
				wantErr:        ErrAccessTokenResetNotAuthorized,
				wantFinalToken: secureTokenPtr("existing-token"),
			},
			{
				name:           "conflict when existing token but nil request with MMDS present",
				existingToken:  secureTokenPtr("existing-token"),
				requestToken:   nil,
				mmdsHash:       keys.HashAccessToken("some-token"),
				mmdsErr:        nil,
				wantErr:        ErrAccessTokenResetNotAuthorized,
				wantFinalToken: secureTokenPtr("existing-token"),
			},
			{
				name:           "conflict when MMDS returns empty hash and request is nil (prevents unauthorized reset)",
				existingToken:  secureTokenPtr("existing-token"),
				requestToken:   nil,
				mmdsHash:       "",
				mmdsErr:        nil,
				wantErr:        ErrAccessTokenResetNotAuthorized,
				wantFinalToken: secureTokenPtr("existing-token"),
			},
			{
				name:           "resets token when MMDS returns hash of empty string and request is nil (explicit reset)",
				existingToken:  secureTokenPtr("existing-token"),
				requestToken:   nil,
				mmdsHash:       keys.HashAccessToken(""),
				mmdsErr:        nil,
				wantErr:        nil,
				wantFinalToken: nil,
			},
		}

		for _, tt := range tests {
			t.Run(tt.name, func(t *testing.T) {
				t.Parallel()
				mmdsClient := &mockMMDSClient{hash: tt.mmdsHash, err: tt.mmdsErr}
				api := newTestAPI(tt.existingToken, mmdsClient)

				data := PostInitJSONBody{
					AccessToken: tt.requestToken,
				}

				err := api.SetData(ctx, logger, data)

				if tt.wantErr != nil {
					require.ErrorIs(t, err, tt.wantErr)
				} else {
					require.NoError(t, err)
				}

				if tt.wantFinalToken == nil {
					assert.False(t, api.accessToken.IsSet(), "expected token to not be set")
				} else {
					require.True(t, api.accessToken.IsSet(), "expected token to be set")
					assert.True(t, api.accessToken.EqualsSecure(tt.wantFinalToken), "expected token to match")
				}
			})
		}
	})

	t.Run("sets environment variables", func(t *testing.T) {
		t.Parallel()
		mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
		api := newTestAPI(nil, mmdsClient)

		envVars := EnvVars{"FOO": "bar", "BAZ": "qux"}
		data := PostInitJSONBody{
			EnvVars: &envVars,
		}

		err := api.SetData(ctx, logger, data)

		require.NoError(t, err)
		val, ok := api.defaults.EnvVars.Load("FOO")
		assert.True(t, ok)
		assert.Equal(t, "bar", val)
		val, ok = api.defaults.EnvVars.Load("BAZ")
		assert.True(t, ok)
		assert.Equal(t, "qux", val)
	})

	t.Run("sets default user", func(t *testing.T) {
		t.Parallel()
		mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
		api := newTestAPI(nil, mmdsClient)

		data := PostInitJSONBody{
			DefaultUser: utilsShared.ToPtr("testuser"),
		}

		err := api.SetData(ctx, logger, data)

		require.NoError(t, err)
		assert.Equal(t, "testuser", api.defaults.User)
	})

	t.Run("does not set default user when empty", func(t *testing.T) {
		t.Parallel()
		mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
		api := newTestAPI(nil, mmdsClient)
		api.defaults.User = "original"

		data := PostInitJSONBody{
			DefaultUser: utilsShared.ToPtr(""),
		}

		err := api.SetData(ctx, logger, data)

		require.NoError(t, err)
		assert.Equal(t, "original", api.defaults.User)
	})

	t.Run("sets default workdir", func(t *testing.T) {
		t.Parallel()
		mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
		api := newTestAPI(nil, mmdsClient)

		data := PostInitJSONBody{
			DefaultWorkdir: utilsShared.ToPtr("/home/user"),
		}

		err := api.SetData(ctx, logger, data)

		require.NoError(t, err)
		require.NotNil(t, api.defaults.Workdir)
		assert.Equal(t, "/home/user", *api.defaults.Workdir)
	})

	t.Run("does not set default workdir when empty", func(t *testing.T) {
		t.Parallel()
		mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
		api := newTestAPI(nil, mmdsClient)
		originalWorkdir := "/original"
		api.defaults.Workdir = &originalWorkdir

		data := PostInitJSONBody{
			DefaultWorkdir: utilsShared.ToPtr(""),
		}

		err := api.SetData(ctx, logger, data)

		require.NoError(t, err)
		require.NotNil(t, api.defaults.Workdir)
		assert.Equal(t, "/original", *api.defaults.Workdir)
	})

	t.Run("sets multiple fields at once", func(t *testing.T) {
		t.Parallel()
		mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
		api := newTestAPI(nil, mmdsClient)

		envVars := EnvVars{"KEY": "value"}
		data := PostInitJSONBody{
			AccessToken:    secureTokenPtr("token"),
			DefaultUser:    utilsShared.ToPtr("user"),
			DefaultWorkdir: utilsShared.ToPtr("/workdir"),
			EnvVars:        &envVars,
		}

		err := api.SetData(ctx, logger, data)

		require.NoError(t, err)
		assert.True(t, api.accessToken.Equals("token"), "expected token to match")
		assert.Equal(t, "user", api.defaults.User)
		assert.Equal(t, "/workdir", *api.defaults.Workdir)
		val, ok := api.defaults.EnvVars.Load("KEY")
		assert.True(t, ok)
		assert.Equal(t, "value", val)
	})
}
|
||||||
214
envd/internal/api/secure_token.go
Normal file
214
envd/internal/api/secure_token.go
Normal file
@ -0,0 +1,214 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/awnumar/memguard"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Sentinel errors returned by SecureToken operations.
var (
	// ErrTokenNotSet is returned when an operation requires a stored token
	// but none is present (or the receiver is nil).
	ErrTokenNotSet = errors.New("access token not set")
	// ErrTokenEmpty is returned when an empty token is supplied; use
	// Destroy() to clear the token instead of setting it to empty.
	ErrTokenEmpty = errors.New("empty token not allowed")
)

// SecureToken wraps memguard for secure token storage.
// It uses LockedBuffer which provides memory locking, guard pages,
// and secure zeroing on destroy.
type SecureToken struct {
	mu     sync.RWMutex           // guards buffer
	buffer *memguard.LockedBuffer // nil when no token is stored
}
|
||||||
|
|
||||||
|
// Set securely replaces the token, destroying the old one first.
|
||||||
|
// The old token memory is zeroed before the new token is stored.
|
||||||
|
// The input byte slice is wiped after copying to secure memory.
|
||||||
|
// Returns ErrTokenEmpty if token is empty - use Destroy() to clear the token instead.
|
||||||
|
func (s *SecureToken) Set(token []byte) error {
|
||||||
|
if len(token) == 0 {
|
||||||
|
return ErrTokenEmpty
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
// Destroy old token first (zeros memory)
|
||||||
|
if s.buffer != nil {
|
||||||
|
s.buffer.Destroy()
|
||||||
|
s.buffer = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create new LockedBuffer from bytes (source slice is wiped by memguard)
|
||||||
|
s.buffer = memguard.NewBufferFromBytes(token)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements json.Unmarshaler to securely parse a JSON string
|
||||||
|
// directly into memguard, wiping the input bytes after copying.
|
||||||
|
//
|
||||||
|
// Access tokens are hex-encoded HMAC-SHA256 hashes (64 chars of [0-9a-f]),
|
||||||
|
// so they never contain JSON escape sequences.
|
||||||
|
func (s *SecureToken) UnmarshalJSON(data []byte) error {
|
||||||
|
// JSON strings are quoted, so minimum valid is `""` (2 bytes).
|
||||||
|
if len(data) < 2 || data[0] != '"' || data[len(data)-1] != '"' {
|
||||||
|
memguard.WipeBytes(data)
|
||||||
|
|
||||||
|
return errors.New("invalid secure token JSON string")
|
||||||
|
}
|
||||||
|
|
||||||
|
content := data[1 : len(data)-1]
|
||||||
|
|
||||||
|
// Access tokens are hex strings - reject if contains backslash
|
||||||
|
if bytes.ContainsRune(content, '\\') {
|
||||||
|
memguard.WipeBytes(data)
|
||||||
|
|
||||||
|
return errors.New("invalid secure token: unexpected escape sequence")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(content) == 0 {
|
||||||
|
memguard.WipeBytes(data)
|
||||||
|
|
||||||
|
return ErrTokenEmpty
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
if s.buffer != nil {
|
||||||
|
s.buffer.Destroy()
|
||||||
|
s.buffer = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allocate secure buffer and copy directly into it
|
||||||
|
s.buffer = memguard.NewBuffer(len(content))
|
||||||
|
copy(s.buffer.Bytes(), content)
|
||||||
|
|
||||||
|
// Wipe the input data
|
||||||
|
memguard.WipeBytes(data)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TakeFrom transfers the token from src to this SecureToken, destroying any
|
||||||
|
// existing token. The source token is cleared after transfer.
|
||||||
|
// This avoids copying the underlying bytes.
|
||||||
|
func (s *SecureToken) TakeFrom(src *SecureToken) {
|
||||||
|
if src == nil || s == src {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract buffer from source
|
||||||
|
src.mu.Lock()
|
||||||
|
buffer := src.buffer
|
||||||
|
src.buffer = nil
|
||||||
|
src.mu.Unlock()
|
||||||
|
|
||||||
|
// Install buffer in destination
|
||||||
|
s.mu.Lock()
|
||||||
|
if s.buffer != nil {
|
||||||
|
s.buffer.Destroy()
|
||||||
|
}
|
||||||
|
s.buffer = buffer
|
||||||
|
s.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Equals checks if token matches using constant-time comparison.
|
||||||
|
// Returns false if the receiver is nil.
|
||||||
|
func (s *SecureToken) Equals(token string) bool {
|
||||||
|
if s == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
|
if s.buffer == nil || !s.buffer.IsAlive() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.buffer.EqualTo([]byte(token))
|
||||||
|
}
|
||||||
|
|
||||||
|
// EqualsSecure compares this token with another SecureToken using constant-time comparison.
|
||||||
|
// Returns false if either receiver or other is nil.
|
||||||
|
func (s *SecureToken) EqualsSecure(other *SecureToken) bool {
|
||||||
|
if s == nil || other == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if s == other {
|
||||||
|
return s.IsSet()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get a copy of other's bytes (avoids holding two locks simultaneously)
|
||||||
|
otherBytes, err := other.Bytes()
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
defer memguard.WipeBytes(otherBytes)
|
||||||
|
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
|
if s.buffer == nil || !s.buffer.IsAlive() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.buffer.EqualTo(otherBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsSet returns true if a token is stored.
|
||||||
|
// Returns false if the receiver is nil.
|
||||||
|
func (s *SecureToken) IsSet() bool {
|
||||||
|
if s == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
|
return s.buffer != nil && s.buffer.IsAlive()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bytes returns a copy of the token bytes (for signature generation).
|
||||||
|
// The caller should zero the returned slice after use.
|
||||||
|
// Returns ErrTokenNotSet if the receiver is nil.
|
||||||
|
func (s *SecureToken) Bytes() ([]byte, error) {
|
||||||
|
if s == nil {
|
||||||
|
return nil, ErrTokenNotSet
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
|
if s.buffer == nil || !s.buffer.IsAlive() {
|
||||||
|
return nil, ErrTokenNotSet
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return a copy (unavoidable for signature generation)
|
||||||
|
src := s.buffer.Bytes()
|
||||||
|
result := make([]byte, len(src))
|
||||||
|
copy(result, src)
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Destroy securely wipes the token from memory.
|
||||||
|
// No-op if the receiver is nil.
|
||||||
|
func (s *SecureToken) Destroy() {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
if s.buffer != nil {
|
||||||
|
s.buffer.Destroy()
|
||||||
|
s.buffer = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
463
envd/internal/api/secure_token_test.go
Normal file
463
envd/internal/api/secure_token_test.go
Normal file
@ -0,0 +1,463 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/awnumar/memguard"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestSecureTokenSetAndEquals covers the basic lifecycle: a fresh token
// is unset, and after Set only the exact stored value compares equal.
func TestSecureTokenSetAndEquals(t *testing.T) {
	t.Parallel()

	st := &SecureToken{}

	// Initially not set
	assert.False(t, st.IsSet(), "token should not be set initially")
	assert.False(t, st.Equals("any-token"), "equals should return false when not set")

	// Set token
	err := st.Set([]byte("test-token"))
	require.NoError(t, err)
	assert.True(t, st.IsSet(), "token should be set after Set()")
	assert.True(t, st.Equals("test-token"), "equals should return true for correct token")
	assert.False(t, st.Equals("wrong-token"), "equals should return false for wrong token")
	assert.False(t, st.Equals(""), "equals should return false for empty token")
}
|
||||||
|
|
||||||
|
// TestSecureTokenReplace verifies that Set replaces a previously stored
// token and that the old value no longer compares equal.
func TestSecureTokenReplace(t *testing.T) {
	t.Parallel()

	st := &SecureToken{}

	// Set initial token
	err := st.Set([]byte("first-token"))
	require.NoError(t, err)
	assert.True(t, st.Equals("first-token"))

	// Replace with new token (old one should be destroyed)
	err = st.Set([]byte("second-token"))
	require.NoError(t, err)
	assert.True(t, st.Equals("second-token"), "should match new token")
	assert.False(t, st.Equals("first-token"), "should not match old token")
}
|
||||||
|
|
||||||
|
// TestSecureTokenDestroy verifies Destroy clears the token, that repeated
// Destroy calls are safe, and that all methods tolerate a nil receiver.
func TestSecureTokenDestroy(t *testing.T) {
	t.Parallel()

	st := &SecureToken{}

	// Set and then destroy
	err := st.Set([]byte("test-token"))
	require.NoError(t, err)
	assert.True(t, st.IsSet())

	st.Destroy()
	assert.False(t, st.IsSet(), "token should not be set after Destroy()")
	assert.False(t, st.Equals("test-token"), "equals should return false after Destroy()")

	// Destroy on already destroyed should be safe
	st.Destroy()
	assert.False(t, st.IsSet())

	// Nil receiver should be safe
	var nilToken *SecureToken
	assert.False(t, nilToken.IsSet(), "nil receiver should return false for IsSet()")
	assert.False(t, nilToken.Equals("anything"), "nil receiver should return false for Equals()")
	assert.False(t, nilToken.EqualsSecure(st), "nil receiver should return false for EqualsSecure()")
	nilToken.Destroy() // should not panic

	_, err = nilToken.Bytes()
	require.ErrorIs(t, err, ErrTokenNotSet, "nil receiver should return ErrTokenNotSet for Bytes()")
}
|
||||||
|
|
||||||
|
// TestSecureTokenBytes verifies Bytes returns an independent copy of the
// stored token (wiping the copy leaves the original intact) and returns
// ErrTokenNotSet both before Set and after Destroy.
func TestSecureTokenBytes(t *testing.T) {
	t.Parallel()

	st := &SecureToken{}

	// Bytes should return error when not set
	_, err := st.Bytes()
	require.ErrorIs(t, err, ErrTokenNotSet)

	// Set token and get bytes
	err = st.Set([]byte("test-token"))
	require.NoError(t, err)

	bytes, err := st.Bytes()
	require.NoError(t, err)
	assert.Equal(t, []byte("test-token"), bytes)

	// Zero out the bytes (as caller should do)
	memguard.WipeBytes(bytes)

	// Original should still be intact
	assert.True(t, st.Equals("test-token"), "original token should still work after zeroing copy")

	// After destroy, bytes should fail
	st.Destroy()
	_, err = st.Bytes()
	assert.ErrorIs(t, err, ErrTokenNotSet)
}
|
||||||
|
|
||||||
|
func TestSecureTokenConcurrentAccess(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
st := &SecureToken{}
|
||||||
|
err := st.Set([]byte("initial-token"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
const numGoroutines = 100
|
||||||
|
|
||||||
|
// Concurrent reads
|
||||||
|
for range numGoroutines {
|
||||||
|
wg.Go(func() {
|
||||||
|
st.IsSet()
|
||||||
|
st.Equals("initial-token")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Concurrent writes
|
||||||
|
for i := range 10 {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(idx int) {
|
||||||
|
defer wg.Done()
|
||||||
|
st.Set([]byte("token-" + string(rune('a'+idx))))
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// Should still be in a valid state
|
||||||
|
assert.True(t, st.IsSet())
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSecureTokenEmptyToken verifies that empty and nil inputs are both
// rejected with ErrTokenEmpty and leave the token unset.
func TestSecureTokenEmptyToken(t *testing.T) {
	t.Parallel()

	st := &SecureToken{}

	// Setting empty token should return an error
	err := st.Set([]byte{})
	require.ErrorIs(t, err, ErrTokenEmpty)
	assert.False(t, st.IsSet(), "token should not be set after empty token error")

	// Setting nil should also return an error
	err = st.Set(nil)
	require.ErrorIs(t, err, ErrTokenEmpty)
	assert.False(t, st.IsSet(), "token should not be set after nil token error")
}
|
||||||
|
|
||||||
|
// TestSecureTokenEmptyTokenDoesNotClearExisting verifies that a rejected
// empty Set leaves a previously stored token fully intact.
func TestSecureTokenEmptyTokenDoesNotClearExisting(t *testing.T) {
	t.Parallel()

	st := &SecureToken{}

	// Set a valid token first
	err := st.Set([]byte("valid-token"))
	require.NoError(t, err)
	assert.True(t, st.IsSet())

	// Attempting to set empty token should fail and preserve existing token
	err = st.Set([]byte{})
	require.ErrorIs(t, err, ErrTokenEmpty)
	assert.True(t, st.IsSet(), "existing token should be preserved after empty token error")
	assert.True(t, st.Equals("valid-token"), "existing token value should be unchanged")
}
|
||||||
|
|
||||||
|
func TestSecureTokenUnmarshalJSON(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
t.Run("unmarshals valid JSON string", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
st := &SecureToken{}
|
||||||
|
err := st.UnmarshalJSON([]byte(`"my-secret-token"`))
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, st.IsSet())
|
||||||
|
assert.True(t, st.Equals("my-secret-token"))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns error for empty string", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
st := &SecureToken{}
|
||||||
|
err := st.UnmarshalJSON([]byte(`""`))
|
||||||
|
require.ErrorIs(t, err, ErrTokenEmpty)
|
||||||
|
assert.False(t, st.IsSet())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns error for invalid JSON", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
st := &SecureToken{}
|
||||||
|
err := st.UnmarshalJSON([]byte(`not-valid-json`))
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.False(t, st.IsSet())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("replaces existing token", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
st := &SecureToken{}
|
||||||
|
err := st.Set([]byte("old-token"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = st.UnmarshalJSON([]byte(`"new-token"`))
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, st.Equals("new-token"))
|
||||||
|
assert.False(t, st.Equals("old-token"))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("wipes input buffer after parsing", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
// Create a buffer with a known token
|
||||||
|
input := []byte(`"secret-token-12345"`)
|
||||||
|
original := make([]byte, len(input))
|
||||||
|
copy(original, input)
|
||||||
|
|
||||||
|
st := &SecureToken{}
|
||||||
|
err := st.UnmarshalJSON(input)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify the token was stored correctly
|
||||||
|
assert.True(t, st.Equals("secret-token-12345"))
|
||||||
|
|
||||||
|
// Verify the input buffer was wiped (all zeros)
|
||||||
|
for i, b := range input {
|
||||||
|
assert.Equal(t, byte(0), b, "byte at position %d should be zero, got %d", i, b)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("wipes input buffer on error", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
// Create a buffer with an empty token (will error)
|
||||||
|
input := []byte(`""`)
|
||||||
|
|
||||||
|
st := &SecureToken{}
|
||||||
|
err := st.UnmarshalJSON(input)
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
// Verify the input buffer was still wiped
|
||||||
|
for i, b := range input {
|
||||||
|
assert.Equal(t, byte(0), b, "byte at position %d should be zero, got %d", i, b)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("rejects escape sequences", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
st := &SecureToken{}
|
||||||
|
err := st.UnmarshalJSON([]byte(`"token\nwith\nnewlines"`))
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "escape sequence")
|
||||||
|
assert.False(t, st.IsSet())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSecureTokenSetWipesInput(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
t.Run("wipes input buffer after storing", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
// Create a buffer with a known token
|
||||||
|
input := []byte("my-secret-token")
|
||||||
|
original := make([]byte, len(input))
|
||||||
|
copy(original, input)
|
||||||
|
|
||||||
|
st := &SecureToken{}
|
||||||
|
err := st.Set(input)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify the token was stored correctly
|
||||||
|
assert.True(t, st.Equals("my-secret-token"))
|
||||||
|
|
||||||
|
// Verify the input buffer was wiped (all zeros)
|
||||||
|
for i, b := range input {
|
||||||
|
assert.Equal(t, byte(0), b, "byte at position %d should be zero, got %d", i, b)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSecureTokenTakeFrom verifies move semantics: the buffer transfers
// without copying, the source is emptied, and nil, empty, and self
// sources are handled safely (including no self-deadlock).
func TestSecureTokenTakeFrom(t *testing.T) {
	t.Parallel()

	t.Run("transfers token from source to destination", func(t *testing.T) {
		t.Parallel()
		src := &SecureToken{}
		err := src.Set([]byte("source-token"))
		require.NoError(t, err)

		dst := &SecureToken{}
		dst.TakeFrom(src)

		assert.True(t, dst.IsSet())
		assert.True(t, dst.Equals("source-token"))
		assert.False(t, src.IsSet(), "source should be empty after transfer")
	})

	t.Run("replaces existing destination token", func(t *testing.T) {
		t.Parallel()
		src := &SecureToken{}
		err := src.Set([]byte("new-token"))
		require.NoError(t, err)

		dst := &SecureToken{}
		err = dst.Set([]byte("old-token"))
		require.NoError(t, err)

		dst.TakeFrom(src)

		assert.True(t, dst.Equals("new-token"))
		assert.False(t, dst.Equals("old-token"))
		assert.False(t, src.IsSet())
	})

	t.Run("handles nil source", func(t *testing.T) {
		t.Parallel()
		dst := &SecureToken{}
		err := dst.Set([]byte("existing-token"))
		require.NoError(t, err)

		dst.TakeFrom(nil)

		assert.True(t, dst.IsSet(), "destination should be unchanged with nil source")
		assert.True(t, dst.Equals("existing-token"))
	})

	t.Run("handles empty source", func(t *testing.T) {
		t.Parallel()
		src := &SecureToken{}
		dst := &SecureToken{}
		err := dst.Set([]byte("existing-token"))
		require.NoError(t, err)

		// Taking from an unset source clears the destination (the nil
		// buffer is moved over).
		dst.TakeFrom(src)

		assert.False(t, dst.IsSet(), "destination should be cleared when source is empty")
	})

	t.Run("self-transfer is no-op and does not deadlock", func(t *testing.T) {
		t.Parallel()
		st := &SecureToken{}
		err := st.Set([]byte("token"))
		require.NoError(t, err)

		st.TakeFrom(st)

		assert.True(t, st.IsSet(), "token should remain set after self-transfer")
		assert.True(t, st.Equals("token"), "token value should be unchanged")
	})
}
|
||||||
|
|
||||||
|
// TestSecureTokenEqualsSecure verifies constant-time comparison between
// two SecureTokens: matching/different values, nil and unset operands,
// self-comparison, and absence of lock-ordering deadlocks against a
// concurrent TakeFrom.
func TestSecureTokenEqualsSecure(t *testing.T) {
	t.Parallel()

	t.Run("returns true for matching tokens", func(t *testing.T) {
		t.Parallel()
		st1 := &SecureToken{}
		err := st1.Set([]byte("same-token"))
		require.NoError(t, err)

		st2 := &SecureToken{}
		err = st2.Set([]byte("same-token"))
		require.NoError(t, err)

		assert.True(t, st1.EqualsSecure(st2))
		assert.True(t, st2.EqualsSecure(st1))
	})

	t.Run("concurrent TakeFrom and EqualsSecure do not deadlock", func(t *testing.T) {
		t.Parallel()
		// This test verifies the fix for the lock ordering deadlock bug.

		const iterations = 100

		for range iterations {
			a := &SecureToken{}
			err := a.Set([]byte("token-a"))
			require.NoError(t, err)

			b := &SecureToken{}
			err = b.Set([]byte("token-b"))
			require.NoError(t, err)

			var wg sync.WaitGroup
			wg.Add(2)

			// Goroutine 1: a.TakeFrom(b)
			go func() {
				defer wg.Done()
				a.TakeFrom(b)
			}()

			// Goroutine 2: b.EqualsSecure(a)
			go func() {
				defer wg.Done()
				b.EqualsSecure(a)
			}()

			wg.Wait()
		}
	})

	t.Run("returns false for different tokens", func(t *testing.T) {
		t.Parallel()
		st1 := &SecureToken{}
		err := st1.Set([]byte("token-a"))
		require.NoError(t, err)

		st2 := &SecureToken{}
		err = st2.Set([]byte("token-b"))
		require.NoError(t, err)

		assert.False(t, st1.EqualsSecure(st2))
	})

	t.Run("returns false when comparing with nil", func(t *testing.T) {
		t.Parallel()
		st := &SecureToken{}
		err := st.Set([]byte("token"))
		require.NoError(t, err)

		assert.False(t, st.EqualsSecure(nil))
	})

	t.Run("returns false when other is not set", func(t *testing.T) {
		t.Parallel()
		st1 := &SecureToken{}
		err := st1.Set([]byte("token"))
		require.NoError(t, err)

		st2 := &SecureToken{}

		assert.False(t, st1.EqualsSecure(st2))
	})

	t.Run("returns false when self is not set", func(t *testing.T) {
		t.Parallel()
		st1 := &SecureToken{}

		st2 := &SecureToken{}
		err := st2.Set([]byte("token"))
		require.NoError(t, err)

		assert.False(t, st1.EqualsSecure(st2))
	})

	t.Run("self-comparison returns true when set", func(t *testing.T) {
		t.Parallel()
		st := &SecureToken{}
		err := st.Set([]byte("token"))
		require.NoError(t, err)

		assert.True(t, st.EqualsSecure(st), "self-comparison should return true and not deadlock")
	})

	t.Run("self-comparison returns false when not set", func(t *testing.T) {
		t.Parallel()
		st := &SecureToken{}

		assert.False(t, st.EqualsSecure(st), "self-comparison on unset token should return false")
	})
}
|
||||||
95
envd/internal/api/store.go
Normal file
95
envd/internal/api/store.go
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/host"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MMDSClient provides access to MMDS metadata.
// It exists as an interface so tests can inject a mock in place of the
// real Firecracker MMDS endpoint.
type MMDSClient interface {
	// GetAccessTokenHash returns the access-token hash published via
	// MMDS, or an error if it cannot be retrieved.
	GetAccessTokenHash(ctx context.Context) (string, error)
}

// DefaultMMDSClient is the production implementation that calls the real MMDS endpoint.
type DefaultMMDSClient struct{}

// GetAccessTokenHash fetches the access-token hash from the MMDS
// endpoint via the host package.
func (c *DefaultMMDSClient) GetAccessTokenHash(ctx context.Context) (string, error) {
	return host.GetAccessTokenHashFromMMDS(ctx)
}
|
||||||
|
|
||||||
|
// API holds the state of the envd HTTP API server.
type API struct {
	isNotFC     bool // true when not running under Firecracker — NOTE(review): confirm against callers of New
	logger      *zerolog.Logger
	accessToken *SecureToken          // sandbox access token, kept in locked memory
	defaults    *execcontext.Defaults // default user/workdir/env applied to executions

	mmdsChan      chan *host.MMDSOpts // presumably delivers MMDS option updates — verify against producer
	hyperloopLock sync.Mutex
	mmdsClient    MMDSClient // injectable for tests; DefaultMMDSClient in production

	lastSetTime *utils.AtomicMax // tracks the latest accepted set time (monotonic max)
	initLock    sync.Mutex       // NOTE(review): appears to serialize init/SetData — confirm
}
|
||||||
|
|
||||||
|
func New(l *zerolog.Logger, defaults *execcontext.Defaults, mmdsChan chan *host.MMDSOpts, isNotFC bool) *API {
|
||||||
|
return &API{
|
||||||
|
logger: l,
|
||||||
|
defaults: defaults,
|
||||||
|
mmdsChan: mmdsChan,
|
||||||
|
isNotFC: isNotFC,
|
||||||
|
mmdsClient: &DefaultMMDSClient{},
|
||||||
|
lastSetTime: utils.NewAtomicMax(),
|
||||||
|
accessToken: &SecureToken{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *API) GetHealth(w http.ResponseWriter, r *http.Request) {
|
||||||
|
defer r.Body.Close()
|
||||||
|
|
||||||
|
a.logger.Trace().Msg("Health check")
|
||||||
|
|
||||||
|
w.Header().Set("Cache-Control", "no-store")
|
||||||
|
w.Header().Set("Content-Type", "")
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *API) GetMetrics(w http.ResponseWriter, r *http.Request) {
|
||||||
|
defer r.Body.Close()
|
||||||
|
|
||||||
|
a.logger.Trace().Msg("Get metrics")
|
||||||
|
|
||||||
|
w.Header().Set("Cache-Control", "no-store")
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
metrics, err := host.GetMetrics()
|
||||||
|
if err != nil {
|
||||||
|
a.logger.Error().Err(err).Msg("Failed to get metrics")
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(metrics); err != nil {
|
||||||
|
a.logger.Error().Err(err).Msg("Failed to encode metrics")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *API) getLogger(err error) *zerolog.Event {
|
||||||
|
if err != nil {
|
||||||
|
return a.logger.Error().Err(err) //nolint:zerologlint // this is only prep
|
||||||
|
}
|
||||||
|
|
||||||
|
return a.logger.Info() //nolint:zerologlint // this is only prep
|
||||||
|
}
|
||||||
311
envd/internal/api/upload.go
Normal file
311
envd/internal/api/upload.go
Normal file
@ -0,0 +1,311 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"os/user"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrNoDiskSpace = fmt.Errorf("not enough disk space available")
|
||||||
|
|
||||||
|
func processFile(r *http.Request, path string, part io.Reader, uid, gid int, logger zerolog.Logger) (int, error) {
|
||||||
|
logger.Debug().
|
||||||
|
Str("path", path).
|
||||||
|
Msg("File processing")
|
||||||
|
|
||||||
|
err := permissions.EnsureDirs(filepath.Dir(path), uid, gid)
|
||||||
|
if err != nil {
|
||||||
|
err := fmt.Errorf("error ensuring directories: %w", err)
|
||||||
|
|
||||||
|
return http.StatusInternalServerError, err
|
||||||
|
}
|
||||||
|
|
||||||
|
canBePreChowned := false
|
||||||
|
stat, err := os.Stat(path)
|
||||||
|
if err != nil && !os.IsNotExist(err) {
|
||||||
|
errMsg := fmt.Errorf("error getting file info: %w", err)
|
||||||
|
|
||||||
|
return http.StatusInternalServerError, errMsg
|
||||||
|
} else if err == nil {
|
||||||
|
if stat.IsDir() {
|
||||||
|
err := fmt.Errorf("path is a directory: %s", path)
|
||||||
|
|
||||||
|
return http.StatusBadRequest, err
|
||||||
|
}
|
||||||
|
canBePreChowned = true
|
||||||
|
}
|
||||||
|
|
||||||
|
hasBeenChowned := false
|
||||||
|
if canBePreChowned {
|
||||||
|
err = os.Chown(path, uid, gid)
|
||||||
|
if err != nil {
|
||||||
|
if !os.IsNotExist(err) {
|
||||||
|
err = fmt.Errorf("error changing file ownership: %w", err)
|
||||||
|
|
||||||
|
return http.StatusInternalServerError, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
hasBeenChowned = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
file, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o666)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, syscall.ENOSPC) {
|
||||||
|
err = fmt.Errorf("not enough inodes available: %w", err)
|
||||||
|
|
||||||
|
return http.StatusInsufficientStorage, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err := fmt.Errorf("error opening file: %w", err)
|
||||||
|
|
||||||
|
return http.StatusInternalServerError, err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
if !hasBeenChowned {
|
||||||
|
err = os.Chown(path, uid, gid)
|
||||||
|
if err != nil {
|
||||||
|
err := fmt.Errorf("error changing file ownership: %w", err)
|
||||||
|
|
||||||
|
return http.StatusInternalServerError, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = file.ReadFrom(part)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, syscall.ENOSPC) {
|
||||||
|
err = ErrNoDiskSpace
|
||||||
|
if r.ContentLength > 0 {
|
||||||
|
err = fmt.Errorf("attempted to write %d bytes: %w", r.ContentLength, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return http.StatusInsufficientStorage, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = fmt.Errorf("error writing file: %w", err)
|
||||||
|
|
||||||
|
return http.StatusInternalServerError, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return http.StatusNoContent, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolvePath(part *multipart.Part, paths *UploadSuccess, u *user.User, defaultPath *string, params PostFilesParams) (string, error) {
|
||||||
|
var pathToResolve string
|
||||||
|
|
||||||
|
if params.Path != nil {
|
||||||
|
pathToResolve = *params.Path
|
||||||
|
} else {
|
||||||
|
var err error
|
||||||
|
customPart := utils.NewCustomPart(part)
|
||||||
|
pathToResolve, err = customPart.FileNameWithPath()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("error getting multipart custom part file name: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
filePath, err := permissions.ExpandAndResolve(pathToResolve, u, defaultPath)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("error resolving path: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, entry := range *paths {
|
||||||
|
if entry.Path == filePath {
|
||||||
|
var alreadyUploaded []string
|
||||||
|
for _, uploadedFile := range *paths {
|
||||||
|
if uploadedFile.Path != filePath {
|
||||||
|
alreadyUploaded = append(alreadyUploaded, uploadedFile.Path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
errMsg := fmt.Errorf("you cannot upload multiple files to the same path '%s' in one upload request, only the first specified file was uploaded", filePath)
|
||||||
|
|
||||||
|
if len(alreadyUploaded) > 1 {
|
||||||
|
errMsg = fmt.Errorf("%w, also the following files were uploaded: %v", errMsg, strings.Join(alreadyUploaded, ", "))
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", errMsg
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return filePath, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// handlePart writes one multipart form part to disk as the target user.
// Parts whose form field name is not "file" are skipped without error
// (nil entry, 200). On success it returns the created entry's metadata;
// on failure it returns the HTTP status the caller should report.
func (a *API) handlePart(r *http.Request, part *multipart.Part, paths UploadSuccess, u *user.User, uid, gid int, operationID string, params PostFilesParams) (*EntryInfo, int, error) {
	defer part.Close()

	// Only "file" fields carry payloads we persist.
	if part.FormName() != "file" {
		return nil, http.StatusOK, nil
	}

	// Resolve the destination path and reject duplicates within this
	// upload request.
	filePath, err := resolvePath(part, &paths, u, a.defaults.Workdir, params)
	if err != nil {
		return nil, http.StatusBadRequest, err
	}

	// Scoped logger so every line from processFile carries the
	// operation ID and event type.
	logger := a.logger.
		With().
		Str(string(logs.OperationIDKey), operationID).
		Str("event_type", "file_processing").
		Logger()

	status, err := processFile(r, filePath, part, uid, gid, logger)
	if err != nil {
		return nil, status, err
	}

	return &EntryInfo{
		Path: filePath,
		Name: filepath.Base(filePath),
		Type: File,
	}, http.StatusOK, nil
}
|
||||||
|
|
||||||
|
// PostFiles handles multipart file uploads. It authenticates the
// request (signature, then username resolution), streams each "file"
// part to disk under the resolved user's uid/gid, and responds with a
// JSON array describing the entries written. Failures are reported via
// jsonError with the status chosen by the failing step.
func (a *API) PostFiles(w http.ResponseWriter, r *http.Request, params PostFilesParams) {
	// Capture original body to ensure it's always closed
	originalBody := r.Body
	defer originalBody.Close()

	// Populated by the failure paths below; consumed by the deferred
	// access-log line.
	var errorCode int
	var errMsg error

	var path string
	if params.Path != nil {
		path = *params.Path
	}

	operationID := logs.AssignOperationID()

	// signing authorization if needed
	err := a.validateSigning(r, params.Signature, params.SignatureExpiration, params.Username, path, SigningWriteOperation)
	if err != nil {
		a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("error during auth validation")
		jsonError(w, http.StatusUnauthorized, err)

		return
	}

	username, err := execcontext.ResolveDefaultUsername(params.Username, a.defaults.User)
	if err != nil {
		a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("no user specified")
		jsonError(w, http.StatusBadRequest, err)

		return
	}

	// One structured "File write" log line is emitted whatever the
	// outcome; error_code is attached only on failure.
	defer func() {
		l := a.logger.
			Err(errMsg).
			Str("method", r.Method+" "+r.URL.Path).
			Str(string(logs.OperationIDKey), operationID).
			Str("path", path).
			Str("username", username)

		if errMsg != nil {
			l = l.Int("error_code", errorCode)
		}

		l.Msg("File write")
	}()

	// Handle gzip-encoded request body
	body, err := getDecompressedBody(r)
	if err != nil {
		errMsg = fmt.Errorf("error decompressing request body: %w", err)
		errorCode = http.StatusBadRequest
		jsonError(w, errorCode, errMsg)

		return
	}
	defer body.Close()
	// Swap in the (possibly wrapped) body so MultipartReader reads the
	// decompressed stream.
	r.Body = body

	f, err := r.MultipartReader()
	if err != nil {
		errMsg = fmt.Errorf("error parsing multipart form: %w", err)
		errorCode = http.StatusInternalServerError
		jsonError(w, errorCode, errMsg)

		return
	}

	u, err := user.Lookup(username)
	if err != nil {
		errMsg = fmt.Errorf("error looking up user '%s': %w", username, err)
		errorCode = http.StatusUnauthorized

		jsonError(w, errorCode, errMsg)

		return
	}

	uid, gid, err := permissions.GetUserIdInts(u)
	if err != nil {
		errMsg = fmt.Errorf("error getting user ids: %w", err)

		jsonError(w, http.StatusInternalServerError, errMsg)

		return
	}

	paths := UploadSuccess{}

	for {
		part, partErr := f.NextPart()

		if partErr == io.EOF {
			// We're done reading the parts.
			break
		} else if partErr != nil {
			errMsg = fmt.Errorf("error reading form: %w", partErr)
			errorCode = http.StatusInternalServerError
			jsonError(w, errorCode, errMsg)

			break
		}

		entry, status, err := a.handlePart(r, part, paths, u, uid, gid, operationID, params)
		if err != nil {
			errorCode = status
			errMsg = err
			jsonError(w, errorCode, errMsg)

			return
		}

		// A nil entry means the part was skipped (non-"file" field).
		if entry != nil {
			paths = append(paths, *entry)
		}
	}

	data, err := json.Marshal(paths)
	if err != nil {
		errMsg = fmt.Errorf("error marshaling response: %w", err)
		errorCode = http.StatusInternalServerError
		jsonError(w, errorCode, errMsg)

		return
	}

	w.WriteHeader(http.StatusOK)
	_, _ = w.Write(data)
}
|
||||||
251
envd/internal/api/upload_test.go
Normal file
251
envd/internal/api/upload_test.go
Normal file
@ -0,0 +1,251 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestProcessFile exercises processFile's status codes across disk
// conditions: directory conflicts, unwritable paths, exhausted space or
// inodes (via tiny tmpfs mounts — those subtests need root and are
// skipped otherwise), virtual filesystems, and plain happy paths.
func TestProcessFile(t *testing.T) {
	t.Parallel()
	uid := os.Getuid()
	gid := os.Getgid()

	// newRequest builds a minimal request whose ContentLength matches
	// the payload, plus a reader serving that payload as the part body.
	newRequest := func(content []byte) (*http.Request, io.Reader) {
		request := &http.Request{
			ContentLength: int64(len(content)),
		}
		buffer := bytes.NewBuffer(content)

		return request, buffer
	}

	// Zero values are fine for subtests that fail before reading input.
	var emptyReq http.Request
	var emptyPart *bytes.Buffer
	var emptyLogger zerolog.Logger

	t.Run("failed to ensure directories", func(t *testing.T) {
		t.Parallel()
		httpStatus, err := processFile(&emptyReq, "/proc/invalid/not-real", emptyPart, uid, gid, emptyLogger)
		require.Error(t, err)
		assert.Equal(t, http.StatusInternalServerError, httpStatus)
		assert.ErrorContains(t, err, "error ensuring directories: ")
	})

	t.Run("attempt to replace directory with a file", func(t *testing.T) {
		t.Parallel()
		tempDir := t.TempDir()

		httpStatus, err := processFile(&emptyReq, tempDir, emptyPart, uid, gid, emptyLogger)
		require.Error(t, err)
		assert.Equal(t, http.StatusBadRequest, httpStatus, err.Error())
		assert.ErrorContains(t, err, "path is a directory: ")
	})

	t.Run("fail to create file", func(t *testing.T) {
		t.Parallel()
		httpStatus, err := processFile(&emptyReq, "/proc/invalid-filename", emptyPart, uid, gid, emptyLogger)
		require.Error(t, err)
		assert.Equal(t, http.StatusInternalServerError, httpStatus)
		assert.ErrorContains(t, err, "error opening file: ")
	})

	t.Run("out of disk space", func(t *testing.T) {
		t.Parallel()
		// make a tiny tmpfs mount
		mountSize := 1024
		tempDir := createTmpfsMount(t, mountSize)

		// create test file
		firstFileSize := mountSize / 2
		tempFile1 := filepath.Join(tempDir, "test-file-1")

		// fill it up
		cmd := exec.CommandContext(t.Context(),
			"dd", "if=/dev/zero", "of="+tempFile1, fmt.Sprintf("bs=%d", firstFileSize), "count=1")
		err := cmd.Run()
		require.NoError(t, err)

		// create a new file that would fill up the remaining space
		secondFileContents := make([]byte, mountSize*2)
		for index := range secondFileContents {
			secondFileContents[index] = 'a'
		}

		// try to replace it
		request, buffer := newRequest(secondFileContents)
		tempFile2 := filepath.Join(tempDir, "test-file-2")
		httpStatus, err := processFile(request, tempFile2, buffer, uid, gid, emptyLogger)
		require.Error(t, err)
		assert.Equal(t, http.StatusInsufficientStorage, httpStatus)
		assert.ErrorContains(t, err, "attempted to write 2048 bytes: not enough disk space")
	})

	t.Run("happy path", func(t *testing.T) {
		t.Parallel()
		tempDir := t.TempDir()
		tempFile := filepath.Join(tempDir, "test-file")

		content := []byte("test-file-contents")
		request, buffer := newRequest(content)

		httpStatus, err := processFile(request, tempFile, buffer, uid, gid, emptyLogger)
		require.NoError(t, err)
		assert.Equal(t, http.StatusNoContent, httpStatus)

		data, err := os.ReadFile(tempFile)
		require.NoError(t, err)
		assert.Equal(t, content, data)
	})

	// Overwriting an existing file must succeed even with no free space,
	// since the existing blocks are reclaimed.
	t.Run("overwrite file on full disk", func(t *testing.T) {
		t.Parallel()
		// make a tiny tmpfs mount
		sizeInBytes := 1024
		tempDir := createTmpfsMount(t, 1024)

		// create test file
		tempFile := filepath.Join(tempDir, "test-file")

		// fill it up
		cmd := exec.CommandContext(t.Context(), "dd", "if=/dev/zero", "of="+tempFile, fmt.Sprintf("bs=%d", sizeInBytes), "count=1")
		err := cmd.Run()
		require.NoError(t, err)

		// try to replace it
		content := []byte("test-file-contents")
		request, buffer := newRequest(content)
		httpStatus, err := processFile(request, tempFile, buffer, uid, gid, emptyLogger)
		require.NoError(t, err)
		assert.Equal(t, http.StatusNoContent, httpStatus)
	})

	t.Run("write new file on full disk", func(t *testing.T) {
		t.Parallel()
		// make a tiny tmpfs mount
		sizeInBytes := 1024
		tempDir := createTmpfsMount(t, 1024)

		// create test file
		tempFile1 := filepath.Join(tempDir, "test-file")

		// fill it up
		cmd := exec.CommandContext(t.Context(), "dd", "if=/dev/zero", "of="+tempFile1, fmt.Sprintf("bs=%d", sizeInBytes), "count=1")
		err := cmd.Run()
		require.NoError(t, err)

		// try to write a new file
		tempFile2 := filepath.Join(tempDir, "test-file-2")
		content := []byte("test-file-contents")
		request, buffer := newRequest(content)
		httpStatus, err := processFile(request, tempFile2, buffer, uid, gid, emptyLogger)
		require.ErrorContains(t, err, "not enough disk space available")
		assert.Equal(t, http.StatusInsufficientStorage, httpStatus)
	})

	t.Run("write new file with no inodes available", func(t *testing.T) {
		t.Parallel()
		// make a tiny tmpfs mount with only 2 inodes (root + 1 file)
		tempDir := createTmpfsMountWithInodes(t, 1024, 2)

		// create test file
		tempFile1 := filepath.Join(tempDir, "test-file")

		// fill it up
		cmd := exec.CommandContext(t.Context(), "dd", "if=/dev/zero", "of="+tempFile1, fmt.Sprintf("bs=%d", 100), "count=1")
		err := cmd.Run()
		require.NoError(t, err)

		// try to write a new file
		tempFile2 := filepath.Join(tempDir, "test-file-2")
		content := []byte("test-file-contents")
		request, buffer := newRequest(content)
		httpStatus, err := processFile(request, tempFile2, buffer, uid, gid, emptyLogger)
		require.ErrorContains(t, err, "not enough inodes available")
		assert.Equal(t, http.StatusInsufficientStorage, httpStatus)
	})

	// Virtual filesystems report no free space, yet writes to them must
	// still succeed (e.g. cgroup control files).
	t.Run("update sysfs or other virtual fs", func(t *testing.T) {
		t.Parallel()
		if os.Geteuid() != 0 {
			t.Skip("skipping sysfs updates: Operation not permitted with non-root user")
		}

		filePath := "/sys/fs/cgroup/user.slice/cpu.weight"
		newContent := []byte("102\n")
		request, buffer := newRequest(newContent)

		httpStatus, err := processFile(request, filePath, buffer, uid, gid, emptyLogger)
		require.NoError(t, err)
		assert.Equal(t, http.StatusNoContent, httpStatus)

		data, err := os.ReadFile(filePath)
		require.NoError(t, err)
		assert.Equal(t, newContent, data)
	})

	t.Run("replace file", func(t *testing.T) {
		t.Parallel()
		tempDir := t.TempDir()
		tempFile := filepath.Join(tempDir, "test-file")

		err := os.WriteFile(tempFile, []byte("old-contents"), 0o644)
		require.NoError(t, err)

		newContent := []byte("new-file-contents")
		request, buffer := newRequest(newContent)

		httpStatus, err := processFile(request, tempFile, buffer, uid, gid, emptyLogger)
		require.NoError(t, err)
		assert.Equal(t, http.StatusNoContent, httpStatus)

		data, err := os.ReadFile(tempFile)
		require.NoError(t, err)
		assert.Equal(t, newContent, data)
	})
}
|
||||||
|
|
||||||
|
func createTmpfsMount(t *testing.T, sizeInBytes int) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
return createTmpfsMountWithInodes(t, sizeInBytes, 5)
|
||||||
|
}
|
||||||
|
|
||||||
|
func createTmpfsMountWithInodes(t *testing.T, sizeInBytes, inodesCount int) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
if os.Geteuid() != 0 {
|
||||||
|
t.Skip("skipping sysfs updates: Operation not permitted with non-root user")
|
||||||
|
}
|
||||||
|
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
|
cmd := exec.CommandContext(t.Context(),
|
||||||
|
"mount",
|
||||||
|
"tmpfs",
|
||||||
|
tempDir,
|
||||||
|
"-t", "tmpfs",
|
||||||
|
"-o", fmt.Sprintf("size=%d,nr_inodes=%d", sizeInBytes, inodesCount))
|
||||||
|
err := cmd.Run()
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
ctx := context.WithoutCancel(t.Context())
|
||||||
|
cmd := exec.CommandContext(ctx, "umount", tempDir)
|
||||||
|
err := cmd.Run()
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
return tempDir
|
||||||
|
}
|
||||||
39
envd/internal/execcontext/context.go
Normal file
39
envd/internal/execcontext/context.go
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package execcontext
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Defaults carries the sandbox-wide fallback values applied when a
// request does not specify its own environment, user, or workdir.
type Defaults struct {
	EnvVars *utils.Map[string, string] // default environment variables
	User    string                     // default username when a request omits one
	Workdir *string                    // default working directory; nil means unset
}
|
||||||
|
|
||||||
|
// ResolveDefaultWorkdir returns workdir when it is non-empty, otherwise
// the configured default (if any), otherwise the empty string.
func ResolveDefaultWorkdir(workdir string, defaultWorkdir *string) string {
	switch {
	case workdir != "":
		return workdir
	case defaultWorkdir != nil:
		return *defaultWorkdir
	default:
		return ""
	}
}
|
||||||
|
|
||||||
|
// ResolveDefaultUsername picks the request-supplied username when
// present, falling back to the configured default. An error is
// returned when neither source provides a name.
func ResolveDefaultUsername(username *string, defaultUsername string) (string, error) {
	if username != nil {
		return *username, nil
	}

	if defaultUsername == "" {
		return "", errors.New("username not provided")
	}

	return defaultUsername, nil
}
|
||||||
96
envd/internal/host/metrics.go
Normal file
96
envd/internal/host/metrics.go
Normal file
@ -0,0 +1,96 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
// Modifications by M/S Omukk
|
||||||
|
|
||||||
|
package host
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/shirou/gopsutil/v4/cpu"
|
||||||
|
"github.com/shirou/gopsutil/v4/mem"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Metrics is a point-in-time snapshot of host resource usage,
// serialized to JSON for the control plane / orchestrator.
type Metrics struct {
	Timestamp int64 `json:"ts"` // Unix Timestamp in UTC

	CPUCount       uint32  `json:"cpu_count"`    // Total CPU cores
	CPUUsedPercent float32 `json:"cpu_used_pct"` // Percent rounded to 2 decimal places

	// Deprecated: kept for backwards compatibility with older orchestrators.
	MemTotalMiB uint64 `json:"mem_total_mib"` // Total virtual memory in MiB

	// Deprecated: kept for backwards compatibility with older orchestrators.
	MemUsedMiB uint64 `json:"mem_used_mib"` // Used virtual memory in MiB

	MemTotal uint64 `json:"mem_total"` // Total virtual memory in bytes
	MemUsed  uint64 `json:"mem_used"`  // Used virtual memory in bytes

	DiskUsed  uint64 `json:"disk_used"`  // Used disk space in bytes
	DiskTotal uint64 `json:"disk_total"` // Total disk space in bytes
}
|
||||||
|
|
||||||
|
func GetMetrics() (*Metrics, error) {
|
||||||
|
v, err := mem.VirtualMemory()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
memUsedMiB := v.Used / 1024 / 1024
|
||||||
|
memTotalMiB := v.Total / 1024 / 1024
|
||||||
|
|
||||||
|
cpuTotal, err := cpu.Counts(true)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
cpuUsedPcts, err := cpu.Percent(0, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
cpuUsedPct := cpuUsedPcts[0]
|
||||||
|
cpuUsedPctRounded := float32(cpuUsedPct)
|
||||||
|
if cpuUsedPct > 0 {
|
||||||
|
cpuUsedPctRounded = float32(math.Round(cpuUsedPct*100) / 100)
|
||||||
|
}
|
||||||
|
|
||||||
|
diskMetrics, err := diskStats("/")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Metrics{
|
||||||
|
Timestamp: time.Now().UTC().Unix(),
|
||||||
|
CPUCount: uint32(cpuTotal),
|
||||||
|
CPUUsedPercent: cpuUsedPctRounded,
|
||||||
|
MemUsedMiB: memUsedMiB,
|
||||||
|
MemTotalMiB: memTotalMiB,
|
||||||
|
MemTotal: v.Total,
|
||||||
|
MemUsed: v.Used,
|
||||||
|
DiskUsed: diskMetrics.Total - diskMetrics.Available,
|
||||||
|
DiskTotal: diskMetrics.Total,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// diskSpace holds raw byte counts for a single filesystem.
type diskSpace struct {
	Total     uint64 // total size of the filesystem in bytes
	Available uint64 // bytes available to unprivileged users
}
|
||||||
|
|
||||||
|
func diskStats(path string) (diskSpace, error) {
|
||||||
|
var st unix.Statfs_t
|
||||||
|
if err := unix.Statfs(path, &st); err != nil {
|
||||||
|
return diskSpace{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
block := uint64(st.Bsize)
|
||||||
|
|
||||||
|
// all data blocks
|
||||||
|
total := st.Blocks * block
|
||||||
|
// blocks available
|
||||||
|
available := st.Bavail * block
|
||||||
|
|
||||||
|
return diskSpace{Total: total, Available: available}, nil
|
||||||
|
}
|
||||||
185
envd/internal/host/mmds.go
Normal file
185
envd/internal/host/mmds.go
Normal file
@ -0,0 +1,185 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
// Modifications by M/S Omukk
|
||||||
|
|
||||||
|
package host
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
	WrennRunDir = "/run/wrenn" // store sandbox metadata files here

	// mmdsDefaultAddress is the link-local address the metadata service
	// (MMDS) is served on inside the VM.
	mmdsDefaultAddress  = "169.254.169.254"
	mmdsTokenExpiration = 60 * time.Second // TTL requested for each MMDS session token

	// Client timeout for the one-shot access-token-hash fetch.
	mmdsAccessTokenRequestClientTimeout = 10 * time.Second
)
|
||||||
|
|
||||||
|
// mmdsAccessTokenClient is a dedicated client for the one-shot
// access-token-hash fetch; keep-alives are disabled because the
// connection is not reused afterwards.
var mmdsAccessTokenClient = &http.Client{
	Timeout: mmdsAccessTokenRequestClientTimeout,
	Transport: &http.Transport{
		DisableKeepAlives: true,
	},
}
|
||||||
|
|
||||||
|
// MMDSOpts is the sandbox metadata document served over MMDS by the
// orchestrator. JSON field names mirror the orchestrator's schema.
type MMDSOpts struct {
	SandboxID            string `json:"instanceID"`
	TemplateID           string `json:"envID"`
	LogsCollectorAddress string `json:"address"`         // where logs are POSTed; empty disables export
	AccessTokenHash      string `json:"accessTokenHash"` // validates that /init requests come from the orchestrator
}
|
||||||
|
|
||||||
|
func (opts *MMDSOpts) Update(sandboxID, templateID, collectorAddress string) {
|
||||||
|
opts.SandboxID = sandboxID
|
||||||
|
opts.TemplateID = templateID
|
||||||
|
opts.LogsCollectorAddress = collectorAddress
|
||||||
|
}
|
||||||
|
|
||||||
|
func (opts *MMDSOpts) AddOptsToJSON(jsonLogs []byte) ([]byte, error) {
|
||||||
|
parsed := make(map[string]any)
|
||||||
|
|
||||||
|
err := json.Unmarshal(jsonLogs, &parsed)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
parsed["instanceID"] = opts.SandboxID
|
||||||
|
parsed["envID"] = opts.TemplateID
|
||||||
|
|
||||||
|
data, err := json.Marshal(parsed)
|
||||||
|
|
||||||
|
return data, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func getMMDSToken(ctx context.Context, client *http.Client) (string, error) {
|
||||||
|
request, err := http.NewRequestWithContext(ctx, http.MethodPut, "http://"+mmdsDefaultAddress+"/latest/api/token", &bytes.Buffer{})
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
request.Header["X-metadata-token-ttl-seconds"] = []string{fmt.Sprint(mmdsTokenExpiration.Seconds())}
|
||||||
|
|
||||||
|
response, err := client.Do(request)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer response.Body.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(response.Body)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
token := string(body)
|
||||||
|
|
||||||
|
if len(token) == 0 {
|
||||||
|
return "", fmt.Errorf("mmds token is an empty string")
|
||||||
|
}
|
||||||
|
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getMMDSOpts(ctx context.Context, client *http.Client, token string) (*MMDSOpts, error) {
|
||||||
|
request, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://"+mmdsDefaultAddress, &bytes.Buffer{})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
request.Header["X-metadata-token"] = []string{token}
|
||||||
|
request.Header["Accept"] = []string{"application/json"}
|
||||||
|
|
||||||
|
response, err := client.Do(request)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer response.Body.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(response.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var opts MMDSOpts
|
||||||
|
|
||||||
|
err = json.Unmarshal(body, &opts)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &opts, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccessTokenHashFromMMDS reads the access token hash from MMDS.
|
||||||
|
// This is used to validate that /init requests come from the orchestrator.
|
||||||
|
func GetAccessTokenHashFromMMDS(ctx context.Context) (string, error) {
|
||||||
|
token, err := getMMDSToken(ctx, mmdsAccessTokenClient)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to get MMDS token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
opts, err := getMMDSOpts(ctx, mmdsAccessTokenClient, token)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to get MMDS opts: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return opts.AccessTokenHash, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func PollForMMDSOpts(ctx context.Context, mmdsChan chan<- *MMDSOpts, envVars *utils.Map[string, string]) {
|
||||||
|
httpClient := &http.Client{}
|
||||||
|
defer httpClient.CloseIdleConnections()
|
||||||
|
|
||||||
|
ticker := time.NewTicker(50 * time.Millisecond)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
fmt.Fprintf(os.Stderr, "context cancelled while waiting for mmds opts")
|
||||||
|
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
token, err := getMMDSToken(ctx, httpClient)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "error getting mmds token: %v\n", err)
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
mmdsOpts, err := getMMDSOpts(ctx, httpClient, token)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "error getting mmds opts: %v\n", err)
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
envVars.Store("WRENN_SANDBOX_ID", mmdsOpts.SandboxID)
|
||||||
|
envVars.Store("WRENN_TEMPLATE_ID", mmdsOpts.TemplateID)
|
||||||
|
|
||||||
|
if err := os.WriteFile(filepath.Join(WrennRunDir, ".WRENN_SANDBOX_ID"), []byte(mmdsOpts.SandboxID), 0o666); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "error writing sandbox ID file: %v\n", err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(filepath.Join(WrennRunDir, ".WRENN_TEMPLATE_ID"), []byte(mmdsOpts.TemplateID), 0o666); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "error writing template ID file: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if mmdsOpts.LogsCollectorAddress != "" {
|
||||||
|
mmdsChan <- mmdsOpts
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
49
envd/internal/logs/bufferedEvents.go
Normal file
49
envd/internal/logs/bufferedEvents.go
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package logs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
	// defaultMaxBufferSize (64 KiB) caps how many bytes accumulate
	// before a buffered event is emitted early.
	defaultMaxBufferSize = 2 << 15
	// defaultTimeout is the flush interval for partially filled buffers.
	defaultTimeout = 2 * time.Second
)
|
||||||
|
|
||||||
|
func LogBufferedDataEvents(dataCh <-chan []byte, logger *zerolog.Logger, eventType string) {
|
||||||
|
timer := time.NewTicker(defaultTimeout)
|
||||||
|
defer timer.Stop()
|
||||||
|
|
||||||
|
var buffer []byte
|
||||||
|
defer func() {
|
||||||
|
if len(buffer) > 0 {
|
||||||
|
logger.Info().Str(eventType, string(buffer)).Msg("Streaming process event (flush)")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-timer.C:
|
||||||
|
if len(buffer) > 0 {
|
||||||
|
logger.Info().Str(eventType, string(buffer)).Msg("Streaming process event")
|
||||||
|
buffer = nil
|
||||||
|
}
|
||||||
|
case data, ok := <-dataCh:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
buffer = append(buffer, data...)
|
||||||
|
|
||||||
|
if len(buffer) >= defaultMaxBufferSize {
|
||||||
|
logger.Info().Str(eventType, string(buffer)).Msg("Streaming process event")
|
||||||
|
buffer = nil
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
174
envd/internal/logs/exporter/exporter.go
Normal file
174
envd/internal/logs/exporter/exporter.go
Normal file
@ -0,0 +1,174 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package exporter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/host"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ExporterTimeout bounds each HTTP request made when shipping logs.
const ExporterTimeout = 10 * time.Second
|
// HTTPExporter buffers log lines written to it and ships them to the
// logs collector address learned from MMDS. When isNotFC is set, lines
// are echoed to stdout instead of being POSTed.
type HTTPExporter struct {
	client   http.Client
	logs     [][]byte       // pending log lines, guarded by logLock
	isNotFC  bool           // true when not running under Firecracker — TODO confirm with callers
	mmdsOpts *host.MMDSOpts // sandbox metadata, guarded by mmdsLock

	// Concurrency coordination
	triggers  chan struct{} // capacity-1 wakeup signal for the start loop
	logLock   sync.RWMutex
	mmdsLock  sync.RWMutex
	startOnce sync.Once // ensures the shipping goroutine starts exactly once
}
|
||||||
|
|
||||||
|
func NewHTTPLogsExporter(ctx context.Context, isNotFC bool, mmdsChan <-chan *host.MMDSOpts) *HTTPExporter {
|
||||||
|
exporter := &HTTPExporter{
|
||||||
|
client: http.Client{
|
||||||
|
Timeout: ExporterTimeout,
|
||||||
|
},
|
||||||
|
triggers: make(chan struct{}, 1),
|
||||||
|
isNotFC: isNotFC,
|
||||||
|
startOnce: sync.Once{},
|
||||||
|
mmdsOpts: &host.MMDSOpts{
|
||||||
|
SandboxID: "unknown",
|
||||||
|
TemplateID: "unknown",
|
||||||
|
LogsCollectorAddress: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
go exporter.listenForMMDSOptsAndStart(ctx, mmdsChan)
|
||||||
|
|
||||||
|
return exporter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *HTTPExporter) sendInstanceLogs(ctx context.Context, logs []byte, address string) error {
|
||||||
|
if address == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
request, err := http.NewRequestWithContext(ctx, http.MethodPost, address, bytes.NewBuffer(logs))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
response, err := w.client.Do(request)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer response.Body.Close()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// printLog falls back to echoing a raw log line on stdout.
func printLog(logs []byte) {
	fmt.Fprint(os.Stdout, string(logs))
}
|
||||||
|
|
||||||
|
// listenForMMDSOptsAndStart consumes metadata updates from mmdsChan,
// copying each into w.mmdsOpts under the write lock, and launches the
// shipping loop (start) exactly once after the first update arrives.
// Runs until ctx is cancelled or the channel closes.
func (w *HTTPExporter) listenForMMDSOptsAndStart(ctx context.Context, mmdsChan <-chan *host.MMDSOpts) {
	for {
		select {
		case <-ctx.Done():
			return
		case mmdsOpts, ok := <-mmdsChan:
			if !ok {
				return
			}

			// Mutate the shared struct's fields under the write lock
			// rather than swapping the pointer.
			w.mmdsLock.Lock()
			w.mmdsOpts.Update(mmdsOpts.SandboxID, mmdsOpts.TemplateID, mmdsOpts.LogsCollectorAddress)
			w.mmdsLock.Unlock()

			w.startOnce.Do(func() {
				go w.start(ctx)
			})
		}
	}
}
|
||||||
|
|
||||||
|
func (w *HTTPExporter) start(ctx context.Context) {
|
||||||
|
for range w.triggers {
|
||||||
|
logs := w.getAllLogs()
|
||||||
|
|
||||||
|
if len(logs) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if w.isNotFC {
|
||||||
|
for _, log := range logs {
|
||||||
|
fmt.Fprintf(os.Stdout, "%v", string(log))
|
||||||
|
}
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, logLine := range logs {
|
||||||
|
w.mmdsLock.RLock()
|
||||||
|
logLineWithOpts, err := w.mmdsOpts.AddOptsToJSON(logLine)
|
||||||
|
w.mmdsLock.RUnlock()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("error adding instance logging options (%+v) to JSON (%+v) with logs : %v\n", w.mmdsOpts, logLine, err)
|
||||||
|
|
||||||
|
printLog(logLine)
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
err = w.sendInstanceLogs(ctx, logLineWithOpts, w.mmdsOpts.LogsCollectorAddress)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("error sending instance logs: %+v", err)
|
||||||
|
|
||||||
|
printLog(logLine)
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// resumeProcessing wakes the shipping loop without blocking. The
// triggers channel has capacity 1, so a send that would block means a
// wakeup is already queued and can be safely dropped.
func (w *HTTPExporter) resumeProcessing() {
	select {
	case w.triggers <- struct{}{}:
	default:
		// Exporter processing already triggered
		// This is expected behavior if the exporter is already processing logs
	}
}
|
||||||
|
|
||||||
|
func (w *HTTPExporter) Write(logs []byte) (int, error) {
|
||||||
|
logsCopy := make([]byte, len(logs))
|
||||||
|
copy(logsCopy, logs)
|
||||||
|
|
||||||
|
go w.addLogs(logsCopy)
|
||||||
|
|
||||||
|
return len(logs), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *HTTPExporter) getAllLogs() [][]byte {
|
||||||
|
w.logLock.Lock()
|
||||||
|
defer w.logLock.Unlock()
|
||||||
|
|
||||||
|
logs := w.logs
|
||||||
|
w.logs = nil
|
||||||
|
|
||||||
|
return logs
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *HTTPExporter) addLogs(logs []byte) {
|
||||||
|
w.logLock.Lock()
|
||||||
|
defer w.logLock.Unlock()
|
||||||
|
|
||||||
|
w.logs = append(w.logs, logs)
|
||||||
|
|
||||||
|
w.resumeProcessing()
|
||||||
|
}
|
||||||
174
envd/internal/logs/interceptor.go
Normal file
174
envd/internal/logs/interceptor.go
Normal file
@ -0,0 +1,174 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package logs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
"connectrpc.com/connect"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OperationID is the typed context key under which a request's operation ID
// is stored (typed to avoid collisions with other context keys).
type OperationID string

const (
	// OperationIDKey is the context key and structured-log field name for
	// the per-request operation ID.
	OperationIDKey OperationID = "operation_id"
	// DefaultHTTPMethod is prefixed to procedure names in log "method"
	// fields; connect RPCs are carried over HTTP POST.
	DefaultHTTPMethod string = "POST"
)
|
||||||
|
|
||||||
|
// operationID is the process-wide monotonically increasing request counter.
var operationID = atomic.Int32{}

// AssignOperationID returns the next operation ID as a decimal string.
// Safe for concurrent use.
func AssignOperationID() string {
	return strconv.Itoa(int(operationID.Add(1)))
}
|
||||||
|
|
||||||
|
func AddRequestIDToContext(ctx context.Context) context.Context {
|
||||||
|
return context.WithValue(ctx, OperationIDKey, AssignOperationID())
|
||||||
|
}
|
||||||
|
|
||||||
|
// formatMethod turns a connect procedure string such as
// "/filesystem.Filesystem/ListDir" into a human-friendly log message
// "Filesystem listDir" (service TitleCased, method lowerCased). Inputs that
// don't match the expected shape are returned unchanged.
//
// Fix: the original indexed servicePart[:1] / methodPart[:1] without checking
// for emptiness, panicking on inputs like "/a./x" where a split part is "".
func formatMethod(method string) string {
	parts := strings.Split(method, ".")
	if len(parts) < 2 {
		return method
	}

	split := strings.Split(parts[1], "/")
	if len(split) < 2 {
		return method
	}

	servicePart := split[0]
	methodPart := split[1]

	// Guard against empty components before slicing their first byte.
	if servicePart == "" || methodPart == "" {
		return method
	}

	servicePart = strings.ToUpper(servicePart[:1]) + servicePart[1:]
	methodPart = strings.ToLower(methodPart[:1]) + methodPart[1:]

	return fmt.Sprintf("%s %s", servicePart, methodPart)
}
|
||||||
|
|
||||||
|
// NewUnaryLogInterceptor returns a connect unary interceptor that tags each
// request with an operation ID and emits one structured log line per call,
// containing the method, operation ID, request payload, and either the
// response payload or the error code.
func NewUnaryLogInterceptor(logger *zerolog.Logger) connect.UnaryInterceptorFunc {
	interceptor := func(next connect.UnaryFunc) connect.UnaryFunc {
		return connect.UnaryFunc(func(
			ctx context.Context,
			req connect.AnyRequest,
		) (connect.AnyResponse, error) {
			// Attach a fresh operation ID so the handler's own logs correlate
			// with this summary line.
			ctx = AddRequestIDToContext(ctx)

			res, err := next(ctx, req)

			// zerolog's Err picks the event level from err (error vs. non-error).
			l := logger.
				Err(err).
				Str("method", DefaultHTTPMethod+" "+req.Spec().Procedure).
				Str(string(OperationIDKey), ctx.Value(OperationIDKey).(string))

			if err != nil {
				l = l.Int("error_code", int(connect.CodeOf(err)))
			}

			if req != nil {
				l = l.Interface("request", req.Any())
			}

			if res != nil && err == nil {
				l = l.Interface("response", res.Any())
			}

			if res == nil && err == nil {
				// Record an explicit nil so the "response" field is always
				// present on successful calls.
				l = l.Interface("response", nil)
			}

			l.Msg(formatMethod(req.Spec().Procedure))

			return res, err
		})
	}

	return connect.UnaryInterceptorFunc(interceptor)
}
|
||||||
|
|
||||||
|
// LogServerStreamWithoutEvents wraps a server-stream handler with start/end
// log lines (method, operation ID, request payload, error code). Individual
// streamed events are deliberately not logged.
func LogServerStreamWithoutEvents[T any, R any](
	ctx context.Context,
	logger *zerolog.Logger,
	req *connect.Request[R],
	stream *connect.ServerStream[T],
	handler func(ctx context.Context, req *connect.Request[R], stream *connect.ServerStream[T]) error,
) error {
	ctx = AddRequestIDToContext(ctx)

	l := logger.Debug().
		Str("method", DefaultHTTPMethod+" "+req.Spec().Procedure).
		Str(string(OperationIDKey), ctx.Value(OperationIDKey).(string))

	if req != nil {
		l = l.Interface("request", req.Any())
	}

	l.Msg(fmt.Sprintf("%s (server stream start)", formatMethod(req.Spec().Procedure)))

	err := handler(ctx, req, stream)

	// End-of-stream line: error level when the handler failed, debug otherwise.
	logEvent := getErrDebugLogEvent(logger, err).
		Str("method", DefaultHTTPMethod+" "+req.Spec().Procedure).
		Str(string(OperationIDKey), ctx.Value(OperationIDKey).(string))

	if err != nil {
		logEvent = logEvent.Int("error_code", int(connect.CodeOf(err)))
	} else {
		logEvent = logEvent.Interface("response", nil)
	}

	logEvent.Msg(fmt.Sprintf("%s (server stream end)", formatMethod(req.Spec().Procedure)))

	return err
}
|
||||||
|
|
||||||
|
// LogClientStreamWithoutEvents wraps a client-stream handler with start/end
// log lines (method, operation ID, final response or error code). Individual
// streamed events are deliberately not logged.
func LogClientStreamWithoutEvents[T any, R any](
	ctx context.Context,
	logger *zerolog.Logger,
	stream *connect.ClientStream[T],
	handler func(ctx context.Context, stream *connect.ClientStream[T]) (*connect.Response[R], error),
) (*connect.Response[R], error) {
	ctx = AddRequestIDToContext(ctx)

	logger.Debug().
		Str("method", DefaultHTTPMethod+" "+stream.Spec().Procedure).
		Str(string(OperationIDKey), ctx.Value(OperationIDKey).(string)).
		Msg(fmt.Sprintf("%s (client stream start)", formatMethod(stream.Spec().Procedure)))

	res, err := handler(ctx, stream)

	// End-of-stream line: error level when the handler failed, debug otherwise.
	logEvent := getErrDebugLogEvent(logger, err).
		Str("method", DefaultHTTPMethod+" "+stream.Spec().Procedure).
		Str(string(OperationIDKey), ctx.Value(OperationIDKey).(string))

	if err != nil {
		logEvent = logEvent.Int("error_code", int(connect.CodeOf(err)))
	}

	if res != nil && err == nil {
		logEvent = logEvent.Interface("response", res.Any())
	}

	if res == nil && err == nil {
		// Keep the "response" field present even when the handler returned nothing.
		logEvent = logEvent.Interface("response", nil)
	}

	logEvent.Msg(fmt.Sprintf("%s (client stream end)", formatMethod(stream.Spec().Procedure)))

	return res, err
}
|
||||||
|
|
||||||
|
// Return logger with error level if err is not nil, otherwise return logger with debug level
|
||||||
|
func getErrDebugLogEvent(logger *zerolog.Logger, err error) *zerolog.Event {
|
||||||
|
if err != nil {
|
||||||
|
return logger.Error().Err(err) //nolint:zerologlint // this builds an event, it is not expected to return it
|
||||||
|
}
|
||||||
|
|
||||||
|
return logger.Debug() //nolint:zerologlint // this builds an event, it is not expected to return it
|
||||||
|
}
|
||||||
37
envd/internal/logs/logger.go
Normal file
37
envd/internal/logs/logger.go
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package logs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/host"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/logs/exporter"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewLogger(ctx context.Context, isNotFC bool, mmdsChan <-chan *host.MMDSOpts) *zerolog.Logger {
|
||||||
|
zerolog.TimestampFieldName = "timestamp"
|
||||||
|
zerolog.TimeFieldFormat = time.RFC3339Nano
|
||||||
|
|
||||||
|
exporters := []io.Writer{}
|
||||||
|
|
||||||
|
if isNotFC {
|
||||||
|
exporters = append(exporters, os.Stdout)
|
||||||
|
} else {
|
||||||
|
exporters = append(exporters, exporter.NewHTTPLogsExporter(ctx, isNotFC, mmdsChan), os.Stdout)
|
||||||
|
}
|
||||||
|
|
||||||
|
l := zerolog.
|
||||||
|
New(io.MultiWriter(exporters...)).
|
||||||
|
With().
|
||||||
|
Timestamp().
|
||||||
|
Logger().
|
||||||
|
Level(zerolog.DebugLevel)
|
||||||
|
|
||||||
|
return &l
|
||||||
|
}
|
||||||
49
envd/internal/permissions/authenticate.go
Normal file
49
envd/internal/permissions/authenticate.go
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package permissions
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os/user"
|
||||||
|
|
||||||
|
"connectrpc.com/authn"
|
||||||
|
"connectrpc.com/connect"
|
||||||
|
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
||||||
|
)
|
||||||
|
|
||||||
|
func AuthenticateUsername(_ context.Context, req authn.Request) (any, error) {
|
||||||
|
username, _, ok := req.BasicAuth()
|
||||||
|
if !ok {
|
||||||
|
// When no username is provided, ignore the authentication method (not all endpoints require it)
|
||||||
|
// Missing user is then handled in the GetAuthUser function
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
u, err := GetUser(username)
|
||||||
|
if err != nil {
|
||||||
|
return nil, authn.Errorf("invalid username: '%s'", username)
|
||||||
|
}
|
||||||
|
|
||||||
|
return u, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetAuthUser(ctx context.Context, defaultUser string) (*user.User, error) {
|
||||||
|
u, ok := authn.GetInfo(ctx).(*user.User)
|
||||||
|
if !ok {
|
||||||
|
username, err := execcontext.ResolveDefaultUsername(nil, defaultUser)
|
||||||
|
if err != nil {
|
||||||
|
return nil, connect.NewError(connect.CodeUnauthenticated, fmt.Errorf("no user specified"))
|
||||||
|
}
|
||||||
|
|
||||||
|
u, err := GetUser(username)
|
||||||
|
if err != nil {
|
||||||
|
return nil, authn.Errorf("invalid default user: '%s'", username)
|
||||||
|
}
|
||||||
|
|
||||||
|
return u, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return u, nil
|
||||||
|
}
|
||||||
31
envd/internal/permissions/keepalive.go
Normal file
31
envd/internal/permissions/keepalive.go
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package permissions
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"connectrpc.com/connect"
|
||||||
|
)
|
||||||
|
|
||||||
|
const defaultKeepAliveInterval = 90 * time.Second
|
||||||
|
|
||||||
|
func GetKeepAliveTicker[T any](req *connect.Request[T]) (*time.Ticker, func()) {
|
||||||
|
keepAliveIntervalHeader := req.Header().Get("Keepalive-Ping-Interval")
|
||||||
|
|
||||||
|
var interval time.Duration
|
||||||
|
|
||||||
|
keepAliveIntervalInt, err := strconv.Atoi(keepAliveIntervalHeader)
|
||||||
|
if err != nil {
|
||||||
|
interval = defaultKeepAliveInterval
|
||||||
|
} else {
|
||||||
|
interval = time.Duration(keepAliveIntervalInt) * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
ticker := time.NewTicker(interval)
|
||||||
|
|
||||||
|
return ticker, func() {
|
||||||
|
ticker.Reset(interval)
|
||||||
|
}
|
||||||
|
}
|
||||||
98
envd/internal/permissions/path.go
Normal file
98
envd/internal/permissions/path.go
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package permissions
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/user"
|
||||||
|
"path/filepath"
|
||||||
|
"slices"
|
||||||
|
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
||||||
|
)
|
||||||
|
|
||||||
|
// expand replaces a leading "~" in path with homedir. Paths of the form
// "~otheruser/..." are rejected — only the calling user's home can be
// expanded. Empty and non-tilde paths pass through unchanged.
func expand(path, homedir string) (string, error) {
	switch {
	case path == "":
		return path, nil
	case path[0] != '~':
		return path, nil
	case len(path) > 1 && path[1] != '/' && path[1] != '\\':
		return "", errors.New("cannot expand user-specific home dir")
	}

	return filepath.Join(homedir, path[1:]), nil
}
|
||||||
|
|
||||||
|
// ExpandAndResolve normalizes path for the given user: applies the default
// workdir (presumably when path is empty — semantics live in execcontext;
// TODO confirm), expands a leading "~", and anchors relative paths at the
// user's home directory, returning an absolute, cleaned path.
func ExpandAndResolve(path string, user *user.User, defaultPath *string) (string, error) {
	path = execcontext.ResolveDefaultWorkdir(path, defaultPath)

	path, err := expand(path, user.HomeDir)
	if err != nil {
		return "", fmt.Errorf("failed to expand path '%s' for user '%s': %w", path, user.Username, err)
	}

	if filepath.IsAbs(path) {
		return path, nil
	}

	// The filepath.Abs can correctly resolve paths like /home/user/../file
	path = filepath.Join(user.HomeDir, path)

	abs, err := filepath.Abs(path)
	if err != nil {
		return "", fmt.Errorf("failed to resolve path '%s' for user '%s' with home dir '%s': %w", path, user.Username, user.HomeDir, err)
	}

	return abs, nil
}
|
||||||
|
|
||||||
|
func getSubpaths(path string) (subpaths []string) {
|
||||||
|
for {
|
||||||
|
subpaths = append(subpaths, path)
|
||||||
|
|
||||||
|
path = filepath.Dir(path)
|
||||||
|
if path == "/" {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
slices.Reverse(subpaths)
|
||||||
|
|
||||||
|
return subpaths
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnsureDirs creates every missing directory along path (mode 0o755) and
// chowns each newly created one to uid:gid. Pre-existing directories are
// left untouched (including their ownership); encountering a non-directory
// on the way is an error.
func EnsureDirs(path string, uid, gid int) error {
	subpaths := getSubpaths(path)
	for _, subpath := range subpaths {
		info, err := os.Stat(subpath)
		if err != nil && !os.IsNotExist(err) {
			return fmt.Errorf("failed to stat directory: %w", err)
		}

		if err != nil && os.IsNotExist(err) {
			// Parents were handled in earlier iterations (subpaths is
			// ordered shallowest-first), so a plain Mkdir suffices.
			err = os.Mkdir(subpath, 0o755)
			if err != nil {
				return fmt.Errorf("failed to create directory: %w", err)
			}

			err = os.Chown(subpath, uid, gid)
			if err != nil {
				return fmt.Errorf("failed to chown directory: %w", err)
			}

			continue
		}

		if !info.IsDir() {
			return fmt.Errorf("path is a file: %s", subpath)
		}
	}

	return nil
}
|
||||||
46
envd/internal/permissions/user.go
Normal file
46
envd/internal/permissions/user.go
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package permissions
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os/user"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetUserIdUints(u *user.User) (uid, gid uint32, err error) {
|
||||||
|
newUID, err := strconv.ParseUint(u.Uid, 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, fmt.Errorf("error parsing uid '%s': %w", u.Uid, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
newGID, err := strconv.ParseUint(u.Gid, 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, fmt.Errorf("error parsing gid '%s': %w", u.Gid, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return uint32(newUID), uint32(newGID), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetUserIdInts(u *user.User) (uid, gid int, err error) {
|
||||||
|
newUID, err := strconv.ParseInt(u.Uid, 10, strconv.IntSize)
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, fmt.Errorf("error parsing uid '%s': %w", u.Uid, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
newGID, err := strconv.ParseInt(u.Gid, 10, strconv.IntSize)
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, fmt.Errorf("error parsing gid '%s': %w", u.Gid, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return int(newUID), int(newGID), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUser looks up an OS user by name, wrapping lookup failures with the
// offending username for context.
func GetUser(username string) (u *user.User, err error) {
	if u, err = user.Lookup(username); err != nil {
		return nil, fmt.Errorf("error looking up user '%s': %w", username, err)
	}

	return u, nil
}
|
||||||
220
envd/internal/port/forward.go
Normal file
220
envd/internal/port/forward.go
Normal file
@ -0,0 +1,220 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
// portf (port forward) periodaically scans opened TCP ports on the 127.0.0.1 (or localhost)
|
||||||
|
// and launches `socat` process for every such port in the background.
|
||||||
|
// socat forward traffic from `sourceIP`:port to the 127.0.0.1:port.
|
||||||
|
|
||||||
|
// WARNING: portf isn't thread safe!
|
||||||
|
|
||||||
|
package port
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os/exec"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/services/cgroups"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PortState marks what should happen to a tracked port on the next
// refresh cycle of the forwarder.
type PortState string

const (
	// PortStateForward: the port is (still) open; keep its socat running.
	PortStateForward PortState = "FORWARD"
	// PortStateDelete: the port was not seen this cycle; stop forwarding.
	PortStateDelete PortState = "DELETE"
)

// defaultGatewayIP is the link-local address socat binds to as the
// externally reachable side of each forward.
var defaultGatewayIP = net.IPv4(169, 254, 0, 21)

// PortToForward tracks one forwarded localhost port and its socat process.
type PortToForward struct {
	socat *exec.Cmd
	// Process ID of the process that's listening on port.
	pid int32
	// family version of the ip.
	family uint32
	state  PortState
	port   uint32
}

// Forwarder launches a socat process per port opened on localhost so the
// port becomes reachable from sourceIP. WARNING: not thread safe (see the
// package comment).
type Forwarder struct {
	logger        *zerolog.Logger
	cgroupManager cgroups.Manager
	// Map of ports that are being currently forwarded.
	ports             map[string]*PortToForward
	scannerSubscriber *ScannerSubscriber
	sourceIP          net.IP
}
|
||||||
|
|
||||||
|
func NewForwarder(
|
||||||
|
logger *zerolog.Logger,
|
||||||
|
scanner *Scanner,
|
||||||
|
cgroupManager cgroups.Manager,
|
||||||
|
) *Forwarder {
|
||||||
|
scannerSub := scanner.AddSubscriber(
|
||||||
|
logger,
|
||||||
|
"port-forwarder",
|
||||||
|
// We only want to forward ports that are actively listening on localhost.
|
||||||
|
&ScannerFilter{
|
||||||
|
IPs: []string{"127.0.0.1", "localhost", "::1"},
|
||||||
|
State: "LISTEN",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return &Forwarder{
|
||||||
|
logger: logger,
|
||||||
|
sourceIP: defaultGatewayIP,
|
||||||
|
ports: make(map[string]*PortToForward),
|
||||||
|
scannerSubscriber: scannerSub,
|
||||||
|
cgroupManager: cgroupManager,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) StartForwarding(ctx context.Context) {
|
||||||
|
if f.scannerSubscriber == nil {
|
||||||
|
f.logger.Error().Msg("Cannot start forwarding because scanner subscriber is nil")
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
// procs is an array of currently opened ports.
|
||||||
|
if procs, ok := <-f.scannerSubscriber.Messages; ok {
|
||||||
|
// Now we are going to refresh all ports that are being forwarded in the `ports` map. Maybe add new ones
|
||||||
|
// and maybe remove some.
|
||||||
|
|
||||||
|
// Go through the ports that are currently being forwarded and set all of them
|
||||||
|
// to the `DELETE` state. We don't know yet if they will be there after refresh.
|
||||||
|
for _, v := range f.ports {
|
||||||
|
v.state = PortStateDelete
|
||||||
|
}
|
||||||
|
|
||||||
|
// Let's refresh our map of currently forwarded ports and mark the currently opened ones with the "FORWARD" state.
|
||||||
|
// This will make sure we won't delete them later.
|
||||||
|
for _, p := range procs {
|
||||||
|
key := fmt.Sprintf("%d-%d", p.Pid, p.Laddr.Port)
|
||||||
|
|
||||||
|
// We check if the opened port is in our map of forwarded ports.
|
||||||
|
val, portOk := f.ports[key]
|
||||||
|
if portOk {
|
||||||
|
// Just mark the port as being forwarded so we don't delete it.
|
||||||
|
// The actual socat process that handles forwarding should be running from the last iteration.
|
||||||
|
val.state = PortStateForward
|
||||||
|
} else {
|
||||||
|
f.logger.Debug().
|
||||||
|
Str("ip", p.Laddr.IP).
|
||||||
|
Uint32("port", p.Laddr.Port).
|
||||||
|
Uint32("family", familyToIPVersion(p.Family)).
|
||||||
|
Str("state", p.Status).
|
||||||
|
Msg("Detected new opened port on localhost that is not forwarded")
|
||||||
|
|
||||||
|
// The opened port wasn't in the map so we create a new PortToForward and start forwarding.
|
||||||
|
ptf := &PortToForward{
|
||||||
|
pid: p.Pid,
|
||||||
|
port: p.Laddr.Port,
|
||||||
|
state: PortStateForward,
|
||||||
|
family: familyToIPVersion(p.Family),
|
||||||
|
}
|
||||||
|
f.ports[key] = ptf
|
||||||
|
f.startPortForwarding(ctx, ptf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// We go through the ports map one more time and stop forwarding all ports
|
||||||
|
// that stayed marked as "DELETE".
|
||||||
|
for _, v := range f.ports {
|
||||||
|
if v.state == PortStateDelete {
|
||||||
|
f.stopPortForwarding(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// startPortForwarding launches a background socat process that accepts
// connections on sourceIP:port and relays them to localhost:port. The
// process is placed in its own process group (so stopPortForwarding can kill
// the whole group) and, when available, into the socat cgroup.
func (f *Forwarder) startPortForwarding(ctx context.Context, p *PortToForward) {
	// https://unix.stackexchange.com/questions/311492/redirect-application-listening-on-localhost-to-listening-on-external-interface
	// socat -d -d TCP4-LISTEN:4000,bind=169.254.0.21,fork TCP4:localhost:4000
	// reuseaddr is used to fix the "Address already in use" error when restarting socat quickly.
	// NOTE(review): the listen side is always TCP4 while the connect side uses
	// TCP{family}; family 0 (unknown) would produce an invalid "TCP0:" spec —
	// confirm unknown families can't reach this point.
	cmd := exec.CommandContext(ctx,
		"socat", "-d", "-d", "-d",
		fmt.Sprintf("TCP4-LISTEN:%v,bind=%s,reuseaddr,fork", p.port, f.sourceIP.To4()),
		fmt.Sprintf("TCP%d:localhost:%v", p.family, p.port),
	)

	cgroupFD, ok := f.cgroupManager.GetFileDescriptor(cgroups.ProcessTypeSocat)

	cmd.SysProcAttr = &syscall.SysProcAttr{
		Setpgid:     true,
		CgroupFD:    cgroupFD,
		UseCgroupFD: ok,
	}

	f.logger.Debug().
		Str("socatCmd", cmd.String()).
		Int32("pid", p.pid).
		Uint32("family", p.family).
		IPAddr("sourceIP", f.sourceIP.To4()).
		Uint32("port", p.port).
		Msg("About to start port forwarding")

	if err := cmd.Start(); err != nil {
		f.logger.
			Error().
			Str("socatCmd", cmd.String()).
			Err(err).
			Msg("Failed to start port forwarding - failed to start socat")

		return
	}

	// Reap the child in the background so it doesn't become a zombie.
	go func() {
		if err := cmd.Wait(); err != nil {
			f.logger.
				Debug().
				Str("socatCmd", cmd.String()).
				Err(err).
				Msg("Port forwarding socat process exited")
		}
	}()

	p.socat = cmd
}
|
||||||
|
|
||||||
|
// stopPortForwarding kills the socat process group for p (socat forks per
// connection, hence the negative-pid group kill) and clears p.socat so the
// call is idempotent. No-op when no socat is running.
func (f *Forwarder) stopPortForwarding(p *PortToForward) {
	if p.socat == nil {
		return
	}

	// Clear even on kill failure so we never re-kill a stale handle.
	defer func() { p.socat = nil }()

	logger := f.logger.With().
		Str("socatCmd", p.socat.String()).
		Int32("pid", p.pid).
		Uint32("family", p.family).
		IPAddr("sourceIP", f.sourceIP.To4()).
		Uint32("port", p.port).
		Logger()

	logger.Debug().Msg("Stopping port forwarding")

	// Negative pid targets the whole process group created via Setpgid.
	if err := syscall.Kill(-p.socat.Process.Pid, syscall.SIGKILL); err != nil {
		logger.Error().Err(err).Msg("Failed to kill process group")

		return
	}

	logger.Debug().Msg("Stopped port forwarding")
}
|
||||||
|
|
||||||
|
func familyToIPVersion(family uint32) uint32 {
|
||||||
|
switch family {
|
||||||
|
case syscall.AF_INET:
|
||||||
|
return 4
|
||||||
|
case syscall.AF_INET6:
|
||||||
|
return 6
|
||||||
|
default:
|
||||||
|
return 0 // Unknown or unsupported family
|
||||||
|
}
|
||||||
|
}
|
||||||
61
envd/internal/port/scan.go
Normal file
61
envd/internal/port/scan.go
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package port
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/shirou/gopsutil/v4/net"
|
||||||
|
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/smap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Scanner periodically enumerates open TCP connections and fans them out to
// registered subscribers.
type Scanner struct {
	Processes chan net.ConnectionStat
	// scanExit, when closed, stops the ScanAndBroadcast loop.
	scanExit chan struct{}
	// subs holds subscribers keyed by their id.
	subs   *smap.Map[*ScannerSubscriber]
	period time.Duration
}
|
||||||
|
|
||||||
|
// Destroy signals the ScanAndBroadcast loop to exit by closing scanExit.
// Must not be called more than once (a second close panics).
func (s *Scanner) Destroy() {
	close(s.scanExit)
}
|
||||||
|
|
||||||
|
func NewScanner(period time.Duration) *Scanner {
|
||||||
|
return &Scanner{
|
||||||
|
period: period,
|
||||||
|
subs: smap.New[*ScannerSubscriber](),
|
||||||
|
scanExit: make(chan struct{}),
|
||||||
|
Processes: make(chan net.ConnectionStat),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Scanner) AddSubscriber(logger *zerolog.Logger, id string, filter *ScannerFilter) *ScannerSubscriber {
|
||||||
|
subscriber := NewScannerSubscriber(logger, id, filter)
|
||||||
|
s.subs.Insert(id, subscriber)
|
||||||
|
|
||||||
|
return subscriber
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unsubscribe removes sub from the scanner and closes its Messages channel.
// NOTE(review): a scan iteration that already grabbed the subscriber list
// may still call Signal on a destroyed subscriber (send on closed channel
// panics) — confirm callers serialize this with the scan loop.
func (s *Scanner) Unsubscribe(sub *ScannerSubscriber) {
	s.subs.Remove(sub.ID())
	sub.Destroy()
}
|
||||||
|
|
||||||
|
// ScanAndBroadcast starts scanning open TCP ports and broadcasts every open port to all subscribers.
// Blocks until Destroy is called; shutdown can lag by up to one period
// because the exit channel is only checked between sleeps.
func (s *Scanner) ScanAndBroadcast() {
	for {
		// tcp monitors both ipv4 and ipv6 connections.
		// NOTE(review): the error from net.Connections is silently dropped,
		// so a failed scan is indistinguishable from "no connections" —
		// confirm this best-effort behavior is intended.
		processes, _ := net.Connections("tcp")
		for _, sub := range s.subs.Items() {
			sub.Signal(processes)
		}
		select {
		case <-s.scanExit:
			return
		default:
			time.Sleep(s.period)
		}
	}
}
|
||||||
52
envd/internal/port/scanSubscriber.go
Normal file
52
envd/internal/port/scanSubscriber.go
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package port
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/shirou/gopsutil/v4/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
// If we want to create a listener/subscriber pattern somewhere else we should move
|
||||||
|
// from a concrete implementation to combination of generics and interfaces.
|
||||||
|
|
||||||
|
// ScannerSubscriber receives (optionally filtered) scan results from a
// Scanner over its Messages channel.
type ScannerSubscriber struct {
	logger *zerolog.Logger
	// filter, when non-nil, selects which connections are delivered.
	filter   *ScannerFilter
	Messages chan ([]net.ConnectionStat)
	id       string
}
|
||||||
|
|
||||||
|
func NewScannerSubscriber(logger *zerolog.Logger, id string, filter *ScannerFilter) *ScannerSubscriber {
|
||||||
|
return &ScannerSubscriber{
|
||||||
|
logger: logger,
|
||||||
|
id: id,
|
||||||
|
filter: filter,
|
||||||
|
Messages: make(chan []net.ConnectionStat),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ID returns the identifier this subscriber was registered under.
func (ss *ScannerSubscriber) ID() string {
	return ss.id
}
|
||||||
|
|
||||||
|
// Destroy closes the Messages channel, signaling consumers that no further
// scan results will arrive. Must be called at most once.
func (ss *ScannerSubscriber) Destroy() {
	close(ss.Messages)
}
|
||||||
|
|
||||||
|
func (ss *ScannerSubscriber) Signal(proc []net.ConnectionStat) {
|
||||||
|
// Filter isn't specified. Accept everything.
|
||||||
|
if ss.filter == nil {
|
||||||
|
ss.Messages <- proc
|
||||||
|
} else {
|
||||||
|
filtered := []net.ConnectionStat{}
|
||||||
|
for i := range proc {
|
||||||
|
// We need to access the list directly otherwise there will be implicit memory aliasing
|
||||||
|
// If the filter matched a process, we will send it to a channel.
|
||||||
|
if ss.filter.Match(&proc[i]) {
|
||||||
|
filtered = append(filtered, proc[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ss.Messages <- filtered
|
||||||
|
}
|
||||||
|
}
|
||||||
29
envd/internal/port/scanfilter.go
Normal file
29
envd/internal/port/scanfilter.go
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package port
|
||||||
|
|
||||||
|
import (
|
||||||
|
"slices"
|
||||||
|
|
||||||
|
"github.com/shirou/gopsutil/v4/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ScannerFilter selects connections whose status equals State AND whose
// local address is one of IPs. An empty filter matches nothing.
type ScannerFilter struct {
	State string
	IPs   []string
}
|
||||||
|
|
||||||
|
func (sf *ScannerFilter) Match(proc *net.ConnectionStat) bool {
|
||||||
|
// Filter is an empty struct.
|
||||||
|
if sf.State == "" && len(sf.IPs) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
ipMatch := slices.Contains(sf.IPs, proc.Laddr.IP)
|
||||||
|
|
||||||
|
if ipMatch && sf.State == proc.Status {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
129
envd/internal/services/cgroups/cgroup2.go
Normal file
129
envd/internal/services/cgroups/cgroup2.go
Normal file
@ -0,0 +1,129 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package cgroups
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Cgroup2Manager holds one open cgroup v2 directory fd per process type,
// suitable for SysProcAttr.CgroupFD placement of child processes.
type Cgroup2Manager struct {
	cgroupFDs map[ProcessType]int
}

// Compile-time check that Cgroup2Manager implements Manager.
var _ Manager = (*Cgroup2Manager)(nil)
|
||||||
|
|
||||||
|
// cgroup2Config accumulates the options passed to NewCgroup2Manager.
type cgroup2Config struct {
	// rootPath is the cgroup2 filesystem mount point.
	rootPath     string
	processTypes map[ProcessType]Cgroup2Config
}

// Cgroup2ManagerOption mutates the manager configuration (functional option).
type Cgroup2ManagerOption func(*cgroup2Config)
|
||||||
|
|
||||||
|
func WithCgroup2RootSysFSPath(path string) Cgroup2ManagerOption {
|
||||||
|
return func(config *cgroup2Config) {
|
||||||
|
config.rootPath = path
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithCgroup2ProcessType(processType ProcessType, path string, properties map[string]string) Cgroup2ManagerOption {
|
||||||
|
return func(config *cgroup2Config) {
|
||||||
|
if config.processTypes == nil {
|
||||||
|
config.processTypes = make(map[ProcessType]Cgroup2Config)
|
||||||
|
}
|
||||||
|
config.processTypes[processType] = Cgroup2Config{Path: path, Properties: properties}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cgroup2Config describes one cgroup: its path (relative to the cgroup2
// root) and the property files to write into it (e.g. memory.max).
type Cgroup2Config struct {
	Path       string
	Properties map[string]string
}
|
||||||
|
|
||||||
|
func NewCgroup2Manager(opts ...Cgroup2ManagerOption) (*Cgroup2Manager, error) {
|
||||||
|
config := cgroup2Config{
|
||||||
|
rootPath: "/sys/fs/cgroup",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(&config)
|
||||||
|
}
|
||||||
|
|
||||||
|
cgroupFDs, err := createCgroups(config)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create cgroups: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Cgroup2Manager{cgroupFDs: cgroupFDs}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// createCgroups creates one cgroup per configured process type and returns
// the map of open directory fds. On any failure it attempts every cgroup
// first, then closes all successfully opened fds and returns the joined
// errors (all-or-nothing).
func createCgroups(configs cgroup2Config) (map[ProcessType]int, error) {
	var (
		results = make(map[ProcessType]int)
		errs    []error
	)

	for procType, config := range configs.processTypes {
		fullPath := filepath.Join(configs.rootPath, config.Path)
		fd, err := createCgroup(fullPath, config.Properties)
		if err != nil {
			errs = append(errs, fmt.Errorf("failed to create %s cgroup: %w", procType, err))

			continue
		}
		results[procType] = fd
	}

	if len(errs) > 0 {
		// Roll back: don't leak fds for the cgroups that did succeed.
		for procType, fd := range results {
			err := unix.Close(fd)
			if err != nil {
				errs = append(errs, fmt.Errorf("failed to close cgroup fd for %s: %w", procType, err))
			}
		}

		return nil, errors.Join(errs...)
	}

	return results, nil
}
|
||||||
|
|
||||||
|
// createCgroup creates the cgroup directory at fullPath, writes each
// property file (e.g. memory.max), and returns a read-only fd for the
// directory. Property-write failures are collected and joined; on failure
// the already-created directory is left in place.
func createCgroup(fullPath string, properties map[string]string) (int, error) {
	if err := os.MkdirAll(fullPath, 0o755); err != nil {
		return -1, fmt.Errorf("failed to create cgroup root: %w", err)
	}

	var errs []error
	for name, value := range properties {
		if err := os.WriteFile(filepath.Join(fullPath, name), []byte(value), 0o644); err != nil {
			errs = append(errs, fmt.Errorf("failed to write cgroup property: %w", err))
		}
	}
	if len(errs) > 0 {
		return -1, errors.Join(errs...)
	}

	// The directory fd is what SysProcAttr.CgroupFD expects.
	return unix.Open(fullPath, unix.O_RDONLY, 0)
}
|
||||||
|
|
||||||
|
func (c Cgroup2Manager) GetFileDescriptor(procType ProcessType) (int, bool) {
|
||||||
|
fd, ok := c.cgroupFDs[procType]
|
||||||
|
|
||||||
|
return fd, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c Cgroup2Manager) Close() error {
|
||||||
|
var errs []error
|
||||||
|
for procType, fd := range c.cgroupFDs {
|
||||||
|
if err := unix.Close(fd); err != nil {
|
||||||
|
errs = append(errs, fmt.Errorf("failed to close cgroup fd for %s: %w", procType, err))
|
||||||
|
}
|
||||||
|
delete(c.cgroupFDs, procType)
|
||||||
|
}
|
||||||
|
|
||||||
|
return errors.Join(errs...)
|
||||||
|
}
|
||||||
187
envd/internal/services/cgroups/cgroup2_test.go
Normal file
187
envd/internal/services/cgroups/cgroup2_test.go
Normal file
@ -0,0 +1,187 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package cgroups
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"strconv"
|
||||||
|
"syscall"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
oneByte = 1
|
||||||
|
kilobyte = 1024 * oneByte
|
||||||
|
megabyte = 1024 * kilobyte
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCgroupRoundTrip(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
if os.Geteuid() != 0 {
|
||||||
|
t.Skip("must run as root")
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
maxTimeout := time.Second * 5
|
||||||
|
|
||||||
|
t.Run("process does not die without cgroups", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// create manager
|
||||||
|
m, err := NewCgroup2Manager()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// create new child process
|
||||||
|
cmd := startProcess(t, m, "not-a-real-one")
|
||||||
|
|
||||||
|
// wait for child process to die
|
||||||
|
err = waitForProcess(t, cmd, maxTimeout)
|
||||||
|
|
||||||
|
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("process dies with cgroups", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
cgroupPath := createCgroupPath(t, "real-one")
|
||||||
|
|
||||||
|
// create manager
|
||||||
|
m, err := NewCgroup2Manager(
|
||||||
|
WithCgroup2ProcessType(ProcessTypePTY, cgroupPath, map[string]string{
|
||||||
|
"memory.max": strconv.Itoa(1 * megabyte),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err := m.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
// create new child process
|
||||||
|
cmd := startProcess(t, m, ProcessTypePTY)
|
||||||
|
|
||||||
|
// wait for child process to die
|
||||||
|
err = waitForProcess(t, cmd, maxTimeout)
|
||||||
|
|
||||||
|
// verify process exited correctly
|
||||||
|
var exitErr *exec.ExitError
|
||||||
|
require.ErrorAs(t, err, &exitErr)
|
||||||
|
assert.Equal(t, "signal: killed", exitErr.Error())
|
||||||
|
assert.False(t, exitErr.Exited())
|
||||||
|
assert.False(t, exitErr.Success())
|
||||||
|
assert.Equal(t, -1, exitErr.ExitCode())
|
||||||
|
|
||||||
|
// dig a little deeper
|
||||||
|
ws, ok := exitErr.Sys().(syscall.WaitStatus)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, syscall.SIGKILL, ws.Signal())
|
||||||
|
assert.True(t, ws.Signaled())
|
||||||
|
assert.False(t, ws.Stopped())
|
||||||
|
assert.False(t, ws.Continued())
|
||||||
|
assert.False(t, ws.CoreDump())
|
||||||
|
assert.False(t, ws.Exited())
|
||||||
|
assert.Equal(t, -1, ws.ExitStatus())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("process cannot be spawned because memory limit is too low", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
cgroupPath := createCgroupPath(t, "real-one")
|
||||||
|
|
||||||
|
// create manager
|
||||||
|
m, err := NewCgroup2Manager(
|
||||||
|
WithCgroup2ProcessType(ProcessTypeSocat, cgroupPath, map[string]string{
|
||||||
|
"memory.max": strconv.Itoa(1 * kilobyte),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err := m.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
// create new child process
|
||||||
|
cmd := startProcess(t, m, ProcessTypeSocat)
|
||||||
|
|
||||||
|
// wait for child process to die
|
||||||
|
err = waitForProcess(t, cmd, maxTimeout)
|
||||||
|
|
||||||
|
// verify process exited correctly
|
||||||
|
var exitErr *exec.ExitError
|
||||||
|
require.ErrorAs(t, err, &exitErr)
|
||||||
|
assert.Equal(t, "exit status 253", exitErr.Error())
|
||||||
|
assert.True(t, exitErr.Exited())
|
||||||
|
assert.False(t, exitErr.Success())
|
||||||
|
assert.Equal(t, 253, exitErr.ExitCode())
|
||||||
|
|
||||||
|
// dig a little deeper
|
||||||
|
ws, ok := exitErr.Sys().(syscall.WaitStatus)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, syscall.Signal(-1), ws.Signal())
|
||||||
|
assert.False(t, ws.Signaled())
|
||||||
|
assert.False(t, ws.Stopped())
|
||||||
|
assert.False(t, ws.Continued())
|
||||||
|
assert.False(t, ws.CoreDump())
|
||||||
|
assert.True(t, ws.Exited())
|
||||||
|
assert.Equal(t, 253, ws.ExitStatus())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func createCgroupPath(t *testing.T, s string) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
randPart := rand.Int()
|
||||||
|
|
||||||
|
return fmt.Sprintf("envd-test-%s-%d", s, randPart)
|
||||||
|
}
|
||||||
|
|
||||||
|
func startProcess(t *testing.T, m *Cgroup2Manager, pt ProcessType) *exec.Cmd {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
cmdName, args := "bash", []string{"-c", `sleep 1 && tail /dev/zero`}
|
||||||
|
cmd := exec.CommandContext(t.Context(), cmdName, args...)
|
||||||
|
|
||||||
|
fd, ok := m.GetFileDescriptor(pt)
|
||||||
|
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||||
|
UseCgroupFD: ok,
|
||||||
|
CgroupFD: fd,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := cmd.Start()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitForProcess(t *testing.T, cmd *exec.Cmd, timeout time.Duration) error {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
done := make(chan error, 1)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer close(done)
|
||||||
|
done <- cmd.Wait()
|
||||||
|
}()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(t.Context(), timeout)
|
||||||
|
t.Cleanup(cancel)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case err := <-done:
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
16
envd/internal/services/cgroups/iface.go
Normal file
16
envd/internal/services/cgroups/iface.go
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package cgroups
|
||||||
|
|
||||||
|
type ProcessType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
ProcessTypePTY ProcessType = "pty"
|
||||||
|
ProcessTypeUser ProcessType = "user"
|
||||||
|
ProcessTypeSocat ProcessType = "socat"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Manager interface {
|
||||||
|
GetFileDescriptor(procType ProcessType) (int, bool)
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
19
envd/internal/services/cgroups/noop.go
Normal file
19
envd/internal/services/cgroups/noop.go
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package cgroups
|
||||||
|
|
||||||
|
type NoopManager struct{}
|
||||||
|
|
||||||
|
var _ Manager = (*NoopManager)(nil)
|
||||||
|
|
||||||
|
func NewNoopManager() *NoopManager {
|
||||||
|
return &NoopManager{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n NoopManager) GetFileDescriptor(ProcessType) (int, bool) {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n NoopManager) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
186
envd/internal/services/filesystem/dir.go
Normal file
186
envd/internal/services/filesystem/dir.go
Normal file
@ -0,0 +1,186 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package filesystem
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"connectrpc.com/connect"
|
||||||
|
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
|
||||||
|
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s Service) ListDir(ctx context.Context, req *connect.Request[rpc.ListDirRequest]) (*connect.Response[rpc.ListDirResponse], error) {
|
||||||
|
depth := req.Msg.GetDepth()
|
||||||
|
if depth == 0 {
|
||||||
|
depth = 1 // default depth to current directory
|
||||||
|
}
|
||||||
|
|
||||||
|
u, err := permissions.GetAuthUser(ctx, s.defaults.User)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
requestedPath := req.Msg.GetPath()
|
||||||
|
|
||||||
|
// Expand the path so we can return absolute paths in the response.
|
||||||
|
requestedPath, err = permissions.ExpandAndResolve(requestedPath, u, s.defaults.Workdir)
|
||||||
|
if err != nil {
|
||||||
|
return nil, connect.NewError(connect.CodeInvalidArgument, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resolvedPath, err := followSymlink(requestedPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = checkIfDirectory(resolvedPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
entries, err := walkDir(requestedPath, resolvedPath, int(depth))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return connect.NewResponse(&rpc.ListDirResponse{
|
||||||
|
Entries: entries,
|
||||||
|
}), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s Service) MakeDir(ctx context.Context, req *connect.Request[rpc.MakeDirRequest]) (*connect.Response[rpc.MakeDirResponse], error) {
|
||||||
|
u, err := permissions.GetAuthUser(ctx, s.defaults.User)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
dirPath, err := permissions.ExpandAndResolve(req.Msg.GetPath(), u, s.defaults.Workdir)
|
||||||
|
if err != nil {
|
||||||
|
return nil, connect.NewError(connect.CodeInvalidArgument, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
stat, err := os.Stat(dirPath)
|
||||||
|
if err != nil && !os.IsNotExist(err) {
|
||||||
|
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error getting file info: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
if stat.IsDir() {
|
||||||
|
return nil, connect.NewError(connect.CodeAlreadyExists, fmt.Errorf("directory already exists: %s", dirPath))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("path already exists but it is not a directory: %s", dirPath))
|
||||||
|
}
|
||||||
|
|
||||||
|
uid, gid, userErr := permissions.GetUserIdInts(u)
|
||||||
|
if userErr != nil {
|
||||||
|
return nil, connect.NewError(connect.CodeInternal, userErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
userErr = permissions.EnsureDirs(dirPath, uid, gid)
|
||||||
|
if userErr != nil {
|
||||||
|
return nil, connect.NewError(connect.CodeInternal, userErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
entry, err := entryInfo(dirPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return connect.NewResponse(&rpc.MakeDirResponse{
|
||||||
|
Entry: entry,
|
||||||
|
}), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// followSymlink resolves a symbolic link to its target path.
|
||||||
|
func followSymlink(path string) (string, error) {
|
||||||
|
// Resolve symlinks
|
||||||
|
resolvedPath, err := filepath.EvalSymlinks(path)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return "", connect.NewError(connect.CodeNotFound, fmt.Errorf("path not found: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(err.Error(), "too many links") {
|
||||||
|
return "", connect.NewError(connect.CodeFailedPrecondition, fmt.Errorf("cyclic symlink or chain >255 links at %q", path))
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", connect.NewError(connect.CodeInternal, fmt.Errorf("error resolving symlink: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return resolvedPath, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkIfDirectory checks if the given path is a directory.
|
||||||
|
func checkIfDirectory(path string) error {
|
||||||
|
stat, err := os.Stat(path)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return connect.NewError(connect.CodeNotFound, fmt.Errorf("directory not found: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return connect.NewError(connect.CodeInternal, fmt.Errorf("error getting file info: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if !stat.IsDir() {
|
||||||
|
return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("path is not a directory: %s", path))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// walkDir walks the directory tree starting from dirPath up to the specified depth (doesn't follow symlinks).
|
||||||
|
func walkDir(requestedPath string, dirPath string, depth int) (entries []*rpc.EntryInfo, err error) {
|
||||||
|
err = filepath.WalkDir(dirPath, func(path string, _ os.DirEntry, err error) error {
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip the root directory itself
|
||||||
|
if path == dirPath {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate current depth
|
||||||
|
relPath, err := filepath.Rel(dirPath, path)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
currentDepth := len(strings.Split(relPath, string(os.PathSeparator)))
|
||||||
|
|
||||||
|
if currentDepth > depth {
|
||||||
|
return filepath.SkipDir
|
||||||
|
}
|
||||||
|
|
||||||
|
entryInfo, err := entryInfo(path)
|
||||||
|
if err != nil {
|
||||||
|
var connectErr *connect.Error
|
||||||
|
if errors.As(err, &connectErr) && connectErr.Code() == connect.CodeNotFound {
|
||||||
|
// Skip entries that don't exist anymore
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the requested path as the base path instead of the symlink-resolved path
|
||||||
|
path = filepath.Join(requestedPath, relPath)
|
||||||
|
entryInfo.Path = path
|
||||||
|
|
||||||
|
entries = append(entries, entryInfo)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error reading directory %s: %w", dirPath, err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return entries, nil
|
||||||
|
}
|
||||||
407
envd/internal/services/filesystem/dir_test.go
Normal file
407
envd/internal/services/filesystem/dir_test.go
Normal file
@ -0,0 +1,407 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package filesystem
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/user"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"connectrpc.com/authn"
|
||||||
|
"connectrpc.com/connect"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestListDir(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// Setup temp root and user
|
||||||
|
root := t.TempDir()
|
||||||
|
u, err := user.Current()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Setup directory structure
|
||||||
|
testFolder := filepath.Join(root, "test")
|
||||||
|
require.NoError(t, os.MkdirAll(filepath.Join(testFolder, "test-dir", "sub-dir-1"), 0o755))
|
||||||
|
require.NoError(t, os.MkdirAll(filepath.Join(testFolder, "test-dir", "sub-dir-2"), 0o755))
|
||||||
|
filePath := filepath.Join(testFolder, "test-dir", "sub-dir-1", "file.txt")
|
||||||
|
require.NoError(t, os.WriteFile(filePath, []byte("Hello, World!"), 0o644))
|
||||||
|
|
||||||
|
// Service instance
|
||||||
|
svc := mockService()
|
||||||
|
|
||||||
|
// Helper to inject user into context
|
||||||
|
injectUser := func(ctx context.Context, u *user.User) context.Context {
|
||||||
|
return authn.SetInfo(ctx, u)
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
depth uint32
|
||||||
|
expectedPaths []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "depth 0 lists only root directory",
|
||||||
|
depth: 0,
|
||||||
|
expectedPaths: []string{
|
||||||
|
filepath.Join(testFolder, "test-dir"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "depth 1 lists root directory",
|
||||||
|
depth: 1,
|
||||||
|
expectedPaths: []string{
|
||||||
|
filepath.Join(testFolder, "test-dir"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "depth 2 lists first level of subdirectories (in this case the root directory)",
|
||||||
|
depth: 2,
|
||||||
|
expectedPaths: []string{
|
||||||
|
filepath.Join(testFolder, "test-dir"),
|
||||||
|
filepath.Join(testFolder, "test-dir", "sub-dir-1"),
|
||||||
|
filepath.Join(testFolder, "test-dir", "sub-dir-2"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "depth 3 lists all directories and files",
|
||||||
|
depth: 3,
|
||||||
|
expectedPaths: []string{
|
||||||
|
filepath.Join(testFolder, "test-dir"),
|
||||||
|
filepath.Join(testFolder, "test-dir", "sub-dir-1"),
|
||||||
|
filepath.Join(testFolder, "test-dir", "sub-dir-2"),
|
||||||
|
filepath.Join(testFolder, "test-dir", "sub-dir-1", "file.txt"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx := injectUser(t.Context(), u)
|
||||||
|
req := connect.NewRequest(&filesystem.ListDirRequest{
|
||||||
|
Path: testFolder,
|
||||||
|
Depth: tt.depth,
|
||||||
|
})
|
||||||
|
resp, err := svc.ListDir(ctx, req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotEmpty(t, resp.Msg)
|
||||||
|
assert.Len(t, resp.Msg.GetEntries(), len(tt.expectedPaths))
|
||||||
|
actualPaths := make([]string, len(resp.Msg.GetEntries()))
|
||||||
|
for i, entry := range resp.Msg.GetEntries() {
|
||||||
|
actualPaths[i] = entry.GetPath()
|
||||||
|
}
|
||||||
|
assert.ElementsMatch(t, tt.expectedPaths, actualPaths)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListDirNonExistingPath(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
svc := mockService()
|
||||||
|
u, err := user.Current()
|
||||||
|
require.NoError(t, err)
|
||||||
|
ctx := authn.SetInfo(t.Context(), u)
|
||||||
|
|
||||||
|
req := connect.NewRequest(&filesystem.ListDirRequest{
|
||||||
|
Path: "/non-existing-path",
|
||||||
|
Depth: 1,
|
||||||
|
})
|
||||||
|
_, err = svc.ListDir(ctx, req)
|
||||||
|
require.Error(t, err)
|
||||||
|
var connectErr *connect.Error
|
||||||
|
ok := errors.As(err, &connectErr)
|
||||||
|
assert.True(t, ok, "expected error to be of type *connect.Error")
|
||||||
|
assert.Equal(t, connect.CodeNotFound, connectErr.Code())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListDirRelativePath(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// Setup temp root and user
|
||||||
|
u, err := user.Current()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Setup directory structure
|
||||||
|
testRelativePath := fmt.Sprintf("test-%s", uuid.New())
|
||||||
|
testFolderPath := filepath.Join(u.HomeDir, testRelativePath)
|
||||||
|
filePath := filepath.Join(testFolderPath, "file.txt")
|
||||||
|
require.NoError(t, os.MkdirAll(testFolderPath, 0o755))
|
||||||
|
require.NoError(t, os.WriteFile(filePath, []byte("Hello, World!"), 0o644))
|
||||||
|
|
||||||
|
// Service instance
|
||||||
|
svc := mockService()
|
||||||
|
ctx := authn.SetInfo(t.Context(), u)
|
||||||
|
|
||||||
|
req := connect.NewRequest(&filesystem.ListDirRequest{
|
||||||
|
Path: testRelativePath,
|
||||||
|
Depth: 1,
|
||||||
|
})
|
||||||
|
resp, err := svc.ListDir(ctx, req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotEmpty(t, resp.Msg)
|
||||||
|
|
||||||
|
expectedPaths := []string{
|
||||||
|
filepath.Join(testFolderPath, "file.txt"),
|
||||||
|
}
|
||||||
|
assert.Len(t, resp.Msg.GetEntries(), len(expectedPaths))
|
||||||
|
|
||||||
|
actualPaths := make([]string, len(resp.Msg.GetEntries()))
|
||||||
|
for i, entry := range resp.Msg.GetEntries() {
|
||||||
|
actualPaths[i] = entry.GetPath()
|
||||||
|
}
|
||||||
|
assert.ElementsMatch(t, expectedPaths, actualPaths)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListDir_Symlinks(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
root := t.TempDir()
|
||||||
|
u, err := user.Current()
|
||||||
|
require.NoError(t, err)
|
||||||
|
ctx := authn.SetInfo(t.Context(), u)
|
||||||
|
|
||||||
|
symlinkRoot := filepath.Join(root, "test-symlinks")
|
||||||
|
require.NoError(t, os.MkdirAll(symlinkRoot, 0o755))
|
||||||
|
|
||||||
|
// 1. Prepare a real directory + file that a symlink will point to
|
||||||
|
realDir := filepath.Join(symlinkRoot, "real-dir")
|
||||||
|
require.NoError(t, os.MkdirAll(realDir, 0o755))
|
||||||
|
filePath := filepath.Join(realDir, "file.txt")
|
||||||
|
require.NoError(t, os.WriteFile(filePath, []byte("hello via symlink"), 0o644))
|
||||||
|
|
||||||
|
// 2. Prepare a standalone real file (points-to-file scenario)
|
||||||
|
realFile := filepath.Join(symlinkRoot, "real-file.txt")
|
||||||
|
require.NoError(t, os.WriteFile(realFile, []byte("i am a plain file"), 0o644))
|
||||||
|
|
||||||
|
// 3. Create the three symlinks
|
||||||
|
linkToDir := filepath.Join(symlinkRoot, "link-dir") // → directory
|
||||||
|
linkToFile := filepath.Join(symlinkRoot, "link-file") // → file
|
||||||
|
cyclicLink := filepath.Join(symlinkRoot, "cyclic") // → itself
|
||||||
|
require.NoError(t, os.Symlink(realDir, linkToDir))
|
||||||
|
require.NoError(t, os.Symlink(realFile, linkToFile))
|
||||||
|
require.NoError(t, os.Symlink(cyclicLink, cyclicLink))
|
||||||
|
|
||||||
|
svc := mockService()
|
||||||
|
|
||||||
|
t.Run("symlink to directory behaves like directory and the content looks like inside the directory", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
req := connect.NewRequest(&filesystem.ListDirRequest{
|
||||||
|
Path: linkToDir,
|
||||||
|
Depth: 1,
|
||||||
|
})
|
||||||
|
resp, err := svc.ListDir(ctx, req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
expected := []string{
|
||||||
|
filepath.Join(linkToDir, "file.txt"),
|
||||||
|
}
|
||||||
|
actual := make([]string, len(resp.Msg.GetEntries()))
|
||||||
|
for i, e := range resp.Msg.GetEntries() {
|
||||||
|
actual[i] = e.GetPath()
|
||||||
|
}
|
||||||
|
assert.ElementsMatch(t, expected, actual)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("link to file", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
req := connect.NewRequest(&filesystem.ListDirRequest{
|
||||||
|
Path: linkToFile,
|
||||||
|
Depth: 1,
|
||||||
|
})
|
||||||
|
_, err := svc.ListDir(ctx, req)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "not a directory")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("cyclic symlink surfaces 'too many links' → invalid-argument", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
req := connect.NewRequest(&filesystem.ListDirRequest{
|
||||||
|
Path: cyclicLink,
|
||||||
|
})
|
||||||
|
_, err := svc.ListDir(ctx, req)
|
||||||
|
require.Error(t, err)
|
||||||
|
var connectErr *connect.Error
|
||||||
|
ok := errors.As(err, &connectErr)
|
||||||
|
assert.True(t, ok, "expected error to be of type *connect.Error")
|
||||||
|
assert.Equal(t, connect.CodeFailedPrecondition, connectErr.Code())
|
||||||
|
assert.Contains(t, connectErr.Error(), "cyclic symlink")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("symlink not resolved if not root", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
req := connect.NewRequest(&filesystem.ListDirRequest{
|
||||||
|
Path: symlinkRoot,
|
||||||
|
Depth: 3,
|
||||||
|
})
|
||||||
|
res, err := svc.ListDir(ctx, req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
expected := []string{
|
||||||
|
filepath.Join(symlinkRoot, "cyclic"),
|
||||||
|
filepath.Join(symlinkRoot, "link-dir"),
|
||||||
|
filepath.Join(symlinkRoot, "link-file"),
|
||||||
|
filepath.Join(symlinkRoot, "real-dir"),
|
||||||
|
filepath.Join(symlinkRoot, "real-dir", "file.txt"),
|
||||||
|
filepath.Join(symlinkRoot, "real-file.txt"),
|
||||||
|
}
|
||||||
|
actual := make([]string, len(res.Msg.GetEntries()))
|
||||||
|
for i, e := range res.Msg.GetEntries() {
|
||||||
|
actual[i] = e.GetPath()
|
||||||
|
}
|
||||||
|
assert.ElementsMatch(t, expected, actual, "symlinks should not be resolved when listing the symlink root directory")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFollowSymlink_Success makes sure that followSymlink resolves symlinks,
|
||||||
|
// while also being robust to the /var → /private/var indirection that exists on macOS.
|
||||||
|
func TestFollowSymlink_Success(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// Base temporary directory. On macOS this lives under /var/folders/…
|
||||||
|
// which itself is a symlink to /private/var/folders/….
|
||||||
|
base := t.TempDir()
|
||||||
|
|
||||||
|
// Create a real directory that we ultimately want to resolve to.
|
||||||
|
target := filepath.Join(base, "target")
|
||||||
|
require.NoError(t, os.MkdirAll(target, 0o755))
|
||||||
|
|
||||||
|
// Create a symlink pointing at the real directory so we can verify that
|
||||||
|
// followSymlink follows it.
|
||||||
|
link := filepath.Join(base, "link")
|
||||||
|
require.NoError(t, os.Symlink(target, link))
|
||||||
|
|
||||||
|
got, err := followSymlink(link)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Canonicalise the expected path too, so that /var → /private/var (macOS)
|
||||||
|
// or any other benign symlink indirections don’t cause flaky tests.
|
||||||
|
want, err := filepath.EvalSymlinks(link)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, want, got, "followSymlink should resolve and canonicalise symlinks")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFollowSymlink_MultiSymlinkChain verifies that followSymlink follows a chain
|
||||||
|
// of several symlinks (non‑cyclic) correctly.
|
||||||
|
func TestFollowSymlink_MultiSymlinkChain(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
base := t.TempDir()
|
||||||
|
|
||||||
|
// Final destination directory.
|
||||||
|
target := filepath.Join(base, "target")
|
||||||
|
require.NoError(t, os.MkdirAll(target, 0o755))
|
||||||
|
|
||||||
|
// Build a 3‑link chain: link1 → link2 → link3 → target.
|
||||||
|
link3 := filepath.Join(base, "link3")
|
||||||
|
require.NoError(t, os.Symlink(target, link3))
|
||||||
|
|
||||||
|
link2 := filepath.Join(base, "link2")
|
||||||
|
require.NoError(t, os.Symlink(link3, link2))
|
||||||
|
|
||||||
|
link1 := filepath.Join(base, "link1")
|
||||||
|
require.NoError(t, os.Symlink(link2, link1))
|
||||||
|
|
||||||
|
got, err := followSymlink(link1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
want, err := filepath.EvalSymlinks(link1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, want, got, "followSymlink should resolve an arbitrary symlink chain")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFollowSymlink_NotFound(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
_, err := followSymlink("/definitely/does/not/exist")
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
var cerr *connect.Error
|
||||||
|
require.ErrorAs(t, err, &cerr)
|
||||||
|
require.Equal(t, connect.CodeNotFound, cerr.Code())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFollowSymlink_CyclicSymlink(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
dir := t.TempDir()
|
||||||
|
a := filepath.Join(dir, "a")
|
||||||
|
b := filepath.Join(dir, "b")
|
||||||
|
require.NoError(t, os.MkdirAll(a, 0o755))
|
||||||
|
require.NoError(t, os.MkdirAll(b, 0o755))
|
||||||
|
|
||||||
|
// Create a two‑node loop: a/loop → b/loop, b/loop → a/loop.
|
||||||
|
require.NoError(t, os.Symlink(filepath.Join(b, "loop"), filepath.Join(a, "loop")))
|
||||||
|
require.NoError(t, os.Symlink(filepath.Join(a, "loop"), filepath.Join(b, "loop")))
|
||||||
|
|
||||||
|
_, err := followSymlink(filepath.Join(a, "loop"))
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
var cerr *connect.Error
|
||||||
|
require.ErrorAs(t, err, &cerr)
|
||||||
|
require.Equal(t, connect.CodeFailedPrecondition, cerr.Code())
|
||||||
|
require.Contains(t, cerr.Message(), "cyclic")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckIfDirectory(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
dir := t.TempDir()
|
||||||
|
require.NoError(t, checkIfDirectory(dir))
|
||||||
|
|
||||||
|
file := filepath.Join(dir, "file.txt")
|
||||||
|
require.NoError(t, os.WriteFile(file, []byte("hello"), 0o644))
|
||||||
|
|
||||||
|
err := checkIfDirectory(file)
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
var cerr *connect.Error
|
||||||
|
require.ErrorAs(t, err, &cerr)
|
||||||
|
require.Equal(t, connect.CodeInvalidArgument, cerr.Code())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWalkDir_Depth(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
root := t.TempDir()
|
||||||
|
sub := filepath.Join(root, "sub")
|
||||||
|
subsub := filepath.Join(sub, "subsub")
|
||||||
|
require.NoError(t, os.MkdirAll(subsub, 0o755))
|
||||||
|
|
||||||
|
entries, err := walkDir(root, root, 1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Collect the names for easier assertions.
|
||||||
|
names := make([]string, 0, len(entries))
|
||||||
|
for _, e := range entries {
|
||||||
|
names = append(names, e.GetName())
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Contains(t, names, "sub")
|
||||||
|
require.NotContains(t, names, "subsub", "entries beyond depth should be excluded")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWalkDir_Error(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
_, err := walkDir("/does/not/exist", "/does/not/exist", 1)
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
var cerr *connect.Error
|
||||||
|
require.ErrorAs(t, err, &cerr)
|
||||||
|
require.Equal(t, connect.CodeInternal, cerr.Code())
|
||||||
|
}
|
||||||
60
envd/internal/services/filesystem/move.go
Normal file
60
envd/internal/services/filesystem/move.go
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package filesystem
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"connectrpc.com/connect"
|
||||||
|
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
|
||||||
|
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s Service) Move(ctx context.Context, req *connect.Request[rpc.MoveRequest]) (*connect.Response[rpc.MoveResponse], error) {
|
||||||
|
u, err := permissions.GetAuthUser(ctx, s.defaults.User)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
source, err := permissions.ExpandAndResolve(req.Msg.GetSource(), u, s.defaults.Workdir)
|
||||||
|
if err != nil {
|
||||||
|
return nil, connect.NewError(connect.CodeInvalidArgument, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
destination, err := permissions.ExpandAndResolve(req.Msg.GetDestination(), u, s.defaults.Workdir)
|
||||||
|
if err != nil {
|
||||||
|
return nil, connect.NewError(connect.CodeInvalidArgument, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
uid, gid, userErr := permissions.GetUserIdInts(u)
|
||||||
|
if userErr != nil {
|
||||||
|
return nil, connect.NewError(connect.CodeInternal, userErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
userErr = permissions.EnsureDirs(filepath.Dir(destination), uid, gid)
|
||||||
|
if userErr != nil {
|
||||||
|
return nil, connect.NewError(connect.CodeInternal, userErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = os.Rename(source, destination)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("source file not found: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error renaming: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
entry, err := entryInfo(destination)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return connect.NewResponse(&rpc.MoveResponse{
|
||||||
|
Entry: entry,
|
||||||
|
}), nil
|
||||||
|
}
|
||||||
366
envd/internal/services/filesystem/move_test.go
Normal file
366
envd/internal/services/filesystem/move_test.go
Normal file
@ -0,0 +1,366 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package filesystem
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/user"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"connectrpc.com/authn"
|
||||||
|
"connectrpc.com/connect"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestMove moves a single file between two directories through the RPC
// handler and verifies the response entry path, that the file only
// exists at the destination afterwards, and that its content survived.
func TestMove(t *testing.T) {
	t.Parallel()

	// Setup temp root and user
	root := t.TempDir()
	u, err := user.Current()
	require.NoError(t, err)

	// Setup source and destination directories
	sourceDir := filepath.Join(root, "source")
	destDir := filepath.Join(root, "destination")
	require.NoError(t, os.MkdirAll(sourceDir, 0o755))
	require.NoError(t, os.MkdirAll(destDir, 0o755))

	// Create a test file to move
	sourceFile := filepath.Join(sourceDir, "test-file.txt")
	testContent := []byte("Hello, World!")
	require.NoError(t, os.WriteFile(sourceFile, testContent, 0o644))

	// Destination file path
	destFile := filepath.Join(destDir, "test-file.txt")

	// Service instance
	svc := mockService()

	// Call the Move function
	ctx := authn.SetInfo(t.Context(), u)
	req := connect.NewRequest(&filesystem.MoveRequest{
		Source:      sourceFile,
		Destination: destFile,
	})
	resp, err := svc.Move(ctx, req)

	// Verify the move was successful
	require.NoError(t, err)
	assert.NotNil(t, resp)
	assert.Equal(t, destFile, resp.Msg.GetEntry().GetPath())

	// Verify the file exists at the destination
	_, err = os.Stat(destFile)
	require.NoError(t, err)

	// Verify the file no longer exists at the source
	_, err = os.Stat(sourceFile)
	assert.True(t, os.IsNotExist(err))

	// Verify the content of the moved file
	content, err := os.ReadFile(destFile)
	require.NoError(t, err)
	assert.Equal(t, testContent, content)
}
|
||||||
|
|
||||||
|
// TestMoveDirectory moves a whole directory tree (including a
// subdirectory with a file) through the RPC handler and verifies the
// structure and file contents are preserved at the destination and
// removed from the source.
func TestMoveDirectory(t *testing.T) {
	t.Parallel()

	// Setup temp root and user
	root := t.TempDir()
	u, err := user.Current()
	require.NoError(t, err)

	// Setup source and destination directories
	sourceParent := filepath.Join(root, "source-parent")
	destParent := filepath.Join(root, "dest-parent")
	require.NoError(t, os.MkdirAll(sourceParent, 0o755))
	require.NoError(t, os.MkdirAll(destParent, 0o755))

	// Create a test directory with files to move
	sourceDir := filepath.Join(sourceParent, "test-dir")
	require.NoError(t, os.MkdirAll(filepath.Join(sourceDir, "subdir"), 0o755))

	// Create some files in the directory
	file1 := filepath.Join(sourceDir, "file1.txt")
	file2 := filepath.Join(sourceDir, "subdir", "file2.txt")
	require.NoError(t, os.WriteFile(file1, []byte("File 1 content"), 0o644))
	require.NoError(t, os.WriteFile(file2, []byte("File 2 content"), 0o644))

	// Destination directory path
	destDir := filepath.Join(destParent, "test-dir")

	// Service instance
	svc := mockService()

	// Call the Move function
	ctx := authn.SetInfo(t.Context(), u)
	req := connect.NewRequest(&filesystem.MoveRequest{
		Source:      sourceDir,
		Destination: destDir,
	})
	resp, err := svc.Move(ctx, req)

	// Verify the move was successful
	require.NoError(t, err)
	assert.NotNil(t, resp)
	assert.Equal(t, destDir, resp.Msg.GetEntry().GetPath())

	// Verify the directory exists at the destination
	_, err = os.Stat(destDir)
	require.NoError(t, err)

	// Verify the files exist at the destination
	destFile1 := filepath.Join(destDir, "file1.txt")
	destFile2 := filepath.Join(destDir, "subdir", "file2.txt")
	_, err = os.Stat(destFile1)
	require.NoError(t, err)
	_, err = os.Stat(destFile2)
	require.NoError(t, err)

	// Verify the directory no longer exists at the source
	_, err = os.Stat(sourceDir)
	assert.True(t, os.IsNotExist(err))

	// Verify the content of the moved files
	content1, err := os.ReadFile(destFile1)
	require.NoError(t, err)
	assert.Equal(t, []byte("File 1 content"), content1)

	content2, err := os.ReadFile(destFile2)
	require.NoError(t, err)
	assert.Equal(t, []byte("File 2 content"), content2)
}
|
||||||
|
|
||||||
|
// TestMoveNonExistingFile verifies that moving a missing source path
// surfaces a connect CodeNotFound error carrying the "source file not
// found" message produced by Service.Move.
func TestMoveNonExistingFile(t *testing.T) {
	t.Parallel()

	// Setup temp root and user
	root := t.TempDir()
	u, err := user.Current()
	require.NoError(t, err)

	// Setup destination directory
	destDir := filepath.Join(root, "destination")
	require.NoError(t, os.MkdirAll(destDir, 0o755))

	// Non-existing source file
	sourceFile := filepath.Join(root, "non-existing-file.txt")

	// Destination file path
	destFile := filepath.Join(destDir, "moved-file.txt")

	// Service instance
	svc := mockService()

	// Call the Move function
	ctx := authn.SetInfo(t.Context(), u)
	req := connect.NewRequest(&filesystem.MoveRequest{
		Source:      sourceFile,
		Destination: destFile,
	})
	_, err = svc.Move(ctx, req)

	// Verify the correct error is returned
	require.Error(t, err)

	var connectErr *connect.Error
	ok := errors.As(err, &connectErr)
	assert.True(t, ok, "expected error to be of type *connect.Error")
	assert.Equal(t, connect.CodeNotFound, connectErr.Code())
	assert.Contains(t, connectErr.Message(), "source file not found")
}
|
||||||
|
|
||||||
|
func TestMoveRelativePath(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// Setup user
|
||||||
|
u, err := user.Current()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Setup directory structure with unique name to avoid conflicts
|
||||||
|
testRelativePath := fmt.Sprintf("test-move-%s", uuid.New())
|
||||||
|
testFolderPath := filepath.Join(u.HomeDir, testRelativePath)
|
||||||
|
require.NoError(t, os.MkdirAll(testFolderPath, 0o755))
|
||||||
|
|
||||||
|
// Create a test file to move
|
||||||
|
sourceFile := filepath.Join(testFolderPath, "source-file.txt")
|
||||||
|
testContent := []byte("Hello from relative path!")
|
||||||
|
require.NoError(t, os.WriteFile(sourceFile, testContent, 0o644))
|
||||||
|
|
||||||
|
// Destination file path (also relative)
|
||||||
|
destRelativePath := fmt.Sprintf("test-move-dest-%s", uuid.New())
|
||||||
|
destFolderPath := filepath.Join(u.HomeDir, destRelativePath)
|
||||||
|
require.NoError(t, os.MkdirAll(destFolderPath, 0o755))
|
||||||
|
destFile := filepath.Join(destFolderPath, "moved-file.txt")
|
||||||
|
|
||||||
|
// Service instance
|
||||||
|
svc := mockService()
|
||||||
|
|
||||||
|
// Call the Move function with relative paths
|
||||||
|
ctx := authn.SetInfo(t.Context(), u)
|
||||||
|
req := connect.NewRequest(&filesystem.MoveRequest{
|
||||||
|
Source: filepath.Join(testRelativePath, "source-file.txt"), // Relative path
|
||||||
|
Destination: filepath.Join(destRelativePath, "moved-file.txt"), // Relative path
|
||||||
|
})
|
||||||
|
resp, err := svc.Move(ctx, req)
|
||||||
|
|
||||||
|
// Verify the move was successful
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, resp)
|
||||||
|
assert.Equal(t, destFile, resp.Msg.GetEntry().GetPath())
|
||||||
|
|
||||||
|
// Verify the file exists at the destination
|
||||||
|
_, err = os.Stat(destFile)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify the file no longer exists at the source
|
||||||
|
_, err = os.Stat(sourceFile)
|
||||||
|
assert.True(t, os.IsNotExist(err))
|
||||||
|
|
||||||
|
// Verify the content of the moved file
|
||||||
|
content, err := os.ReadFile(destFile)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, testContent, content)
|
||||||
|
|
||||||
|
// Clean up
|
||||||
|
os.RemoveAll(testFolderPath)
|
||||||
|
os.RemoveAll(destFolderPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMove_Symlinks covers three symlink scenarios for Move: moving a
// symlink that points at a directory, moving a symlink that points at a
// file (both must remain symlinks with their original targets), and
// moving the real file that a symlink points at (leaving the symlink
// dangling). The subtests share the fixture paths created here, which
// is why the outer test is not parallel.
func TestMove_Symlinks(t *testing.T) { //nolint:tparallel // this test cannot be executed in parallel
	root := t.TempDir()
	u, err := user.Current()
	require.NoError(t, err)
	ctx := authn.SetInfo(t.Context(), u)

	// Setup source and destination directories
	sourceRoot := filepath.Join(root, "source")
	destRoot := filepath.Join(root, "destination")
	require.NoError(t, os.MkdirAll(sourceRoot, 0o755))
	require.NoError(t, os.MkdirAll(destRoot, 0o755))

	// 1. Prepare a real directory + file that a symlink will point to
	realDir := filepath.Join(sourceRoot, "real-dir")
	require.NoError(t, os.MkdirAll(realDir, 0o755))
	filePath := filepath.Join(realDir, "file.txt")
	require.NoError(t, os.WriteFile(filePath, []byte("hello via symlink"), 0o644))

	// 2. Prepare a standalone real file (points-to-file scenario)
	realFile := filepath.Join(sourceRoot, "real-file.txt")
	require.NoError(t, os.WriteFile(realFile, []byte("i am a plain file"), 0o644))

	// 3. Create symlinks
	linkToDir := filepath.Join(sourceRoot, "link-dir")   // → directory
	linkToFile := filepath.Join(sourceRoot, "link-file") // → file
	require.NoError(t, os.Symlink(realDir, linkToDir))
	require.NoError(t, os.Symlink(realFile, linkToFile))

	svc := mockService()

	t.Run("move symlink to directory", func(t *testing.T) {
		t.Parallel()
		destPath := filepath.Join(destRoot, "moved-link-dir")

		req := connect.NewRequest(&filesystem.MoveRequest{
			Source:      linkToDir,
			Destination: destPath,
		})
		resp, err := svc.Move(ctx, req)
		require.NoError(t, err)
		assert.Equal(t, destPath, resp.Msg.GetEntry().GetPath())

		// Verify the symlink was moved
		_, err = os.Stat(destPath)
		require.NoError(t, err)

		// Verify it's still a symlink
		info, err := os.Lstat(destPath)
		require.NoError(t, err)
		assert.NotEqual(t, 0, info.Mode()&os.ModeSymlink, "expected a symlink")

		// Verify the symlink target is still correct
		target, err := os.Readlink(destPath)
		require.NoError(t, err)
		assert.Equal(t, realDir, target)

		// Verify the original symlink is gone
		_, err = os.Stat(linkToDir)
		assert.True(t, os.IsNotExist(err))

		// Verify the real directory still exists
		_, err = os.Stat(realDir)
		assert.NoError(t, err)
	})

	t.Run("move symlink to file", func(t *testing.T) { //nolint:paralleltest
		destPath := filepath.Join(destRoot, "moved-link-file")

		req := connect.NewRequest(&filesystem.MoveRequest{
			Source:      linkToFile,
			Destination: destPath,
		})
		resp, err := svc.Move(ctx, req)
		require.NoError(t, err)
		assert.Equal(t, destPath, resp.Msg.GetEntry().GetPath())

		// Verify the symlink was moved
		_, err = os.Stat(destPath)
		require.NoError(t, err)

		// Verify it's still a symlink
		info, err := os.Lstat(destPath)
		require.NoError(t, err)
		assert.NotEqual(t, 0, info.Mode()&os.ModeSymlink, "expected a symlink")

		// Verify the symlink target is still correct
		target, err := os.Readlink(destPath)
		require.NoError(t, err)
		assert.Equal(t, realFile, target)

		// Verify the original symlink is gone
		_, err = os.Stat(linkToFile)
		assert.True(t, os.IsNotExist(err))

		// Verify the real file still exists
		_, err = os.Stat(realFile)
		assert.NoError(t, err)
	})

	t.Run("move real file that is target of symlink", func(t *testing.T) {
		t.Parallel()
		// Create a new symlink to the real file
		newLinkToFile := filepath.Join(sourceRoot, "new-link-file")
		require.NoError(t, os.Symlink(realFile, newLinkToFile))

		destPath := filepath.Join(destRoot, "moved-real-file.txt")

		req := connect.NewRequest(&filesystem.MoveRequest{
			Source:      realFile,
			Destination: destPath,
		})
		resp, err := svc.Move(ctx, req)
		require.NoError(t, err)
		assert.Equal(t, destPath, resp.Msg.GetEntry().GetPath())

		// Verify the real file was moved
		_, err = os.Stat(destPath)
		require.NoError(t, err)

		// Verify the original file is gone
		_, err = os.Stat(realFile)
		assert.True(t, os.IsNotExist(err))

		// Verify the symlink still exists but now points to a non-existent file
		// (os.Stat follows the link, so it errors on the dangling target)
		_, err = os.Stat(newLinkToFile)
		require.Error(t, err, "symlink should point to non-existent file")
	})
}
|
||||||
33
envd/internal/services/filesystem/remove.go
Normal file
33
envd/internal/services/filesystem/remove.go
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package filesystem
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"connectrpc.com/connect"
|
||||||
|
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
|
||||||
|
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s Service) Remove(ctx context.Context, req *connect.Request[rpc.RemoveRequest]) (*connect.Response[rpc.RemoveResponse], error) {
|
||||||
|
u, err := permissions.GetAuthUser(ctx, s.defaults.User)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
path, err := permissions.ExpandAndResolve(req.Msg.GetPath(), u, s.defaults.Workdir)
|
||||||
|
if err != nil {
|
||||||
|
return nil, connect.NewError(connect.CodeInvalidArgument, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = os.RemoveAll(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error removing file or directory: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return connect.NewResponse(&rpc.RemoveResponse{}), nil
|
||||||
|
}
|
||||||
37
envd/internal/services/filesystem/service.go
Normal file
37
envd/internal/services/filesystem/service.go
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
// Modifications by M/S Omukk
|
||||||
|
|
||||||
|
package filesystem
|
||||||
|
|
||||||
|
import (
|
||||||
|
"connectrpc.com/connect"
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
||||||
|
spec "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem/filesystemconnect"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Service implements the filesystem RPC service (stat, move, remove,
// and file watching) for the sandbox guest daemon.
type Service struct {
	// logger is the shared structured logger injected by Handle.
	logger *zerolog.Logger
	// watchers holds the active FileWatcher instances; presumably keyed
	// by a watcher/subscription ID — TODO confirm against the watch RPCs.
	watchers *utils.Map[string, *FileWatcher]
	// defaults carries the fallback user and workdir used to resolve
	// request paths when the caller does not specify them.
	defaults *execcontext.Defaults
}
|
||||||
|
|
||||||
|
func Handle(server *chi.Mux, l *zerolog.Logger, defaults *execcontext.Defaults) {
|
||||||
|
service := Service{
|
||||||
|
logger: l,
|
||||||
|
watchers: utils.NewMap[string, *FileWatcher](),
|
||||||
|
defaults: defaults,
|
||||||
|
}
|
||||||
|
|
||||||
|
interceptors := connect.WithInterceptors(
|
||||||
|
logs.NewUnaryLogInterceptor(l),
|
||||||
|
)
|
||||||
|
|
||||||
|
path, handler := spec.NewFilesystemHandler(service, interceptors)
|
||||||
|
|
||||||
|
server.Mount(path, handler)
|
||||||
|
}
|
||||||
16
envd/internal/services/filesystem/service_test.go
Normal file
16
envd/internal/services/filesystem/service_test.go
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package filesystem
|
||||||
|
|
||||||
|
import (
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
func mockService() Service {
|
||||||
|
return Service{
|
||||||
|
defaults: &execcontext.Defaults{
|
||||||
|
EnvVars: utils.NewMap[string, string](),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
31
envd/internal/services/filesystem/stat.go
Normal file
31
envd/internal/services/filesystem/stat.go
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package filesystem
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"connectrpc.com/connect"
|
||||||
|
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
|
||||||
|
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s Service) Stat(ctx context.Context, req *connect.Request[rpc.StatRequest]) (*connect.Response[rpc.StatResponse], error) {
|
||||||
|
u, err := permissions.GetAuthUser(ctx, s.defaults.User)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
path, err := permissions.ExpandAndResolve(req.Msg.GetPath(), u, s.defaults.Workdir)
|
||||||
|
if err != nil {
|
||||||
|
return nil, connect.NewError(connect.CodeInvalidArgument, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
entry, err := entryInfo(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return connect.NewResponse(&rpc.StatResponse{Entry: entry}), nil
|
||||||
|
}
|
||||||
116
envd/internal/services/filesystem/stat_test.go
Normal file
116
envd/internal/services/filesystem/stat_test.go
Normal file
@ -0,0 +1,116 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package filesystem
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"os/user"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"connectrpc.com/authn"
|
||||||
|
"connectrpc.com/connect"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestStat stats a regular file and a symlink to it, verifying the
// returned entry fields (path, type, owner, group, mode, symlink
// target). Both cases assert FILE_TYPE_FILE — per these expectations
// the service reports the target's type for symlinks while still
// populating SymlinkTarget.
func TestStat(t *testing.T) {
	t.Parallel()

	// Setup temp root and user
	root := t.TempDir()
	// Get the actual path to the temp directory (symlinks can cause issues)
	root, err := filepath.EvalSymlinks(root)
	require.NoError(t, err)

	u, err := user.Current()
	require.NoError(t, err)

	group, err := user.LookupGroupId(u.Gid)
	require.NoError(t, err)

	// Setup directory structure
	testFolder := filepath.Join(root, "test")
	err = os.MkdirAll(testFolder, 0o755)
	require.NoError(t, err)

	testFile := filepath.Join(testFolder, "file.txt")
	err = os.WriteFile(testFile, []byte("Hello, World!"), 0o644)
	require.NoError(t, err)

	linkedFile := filepath.Join(testFolder, "linked-file.txt")
	err = os.Symlink(testFile, linkedFile)
	require.NoError(t, err)

	// Service instance
	svc := mockService()

	// Helper to inject user into context
	injectUser := func(ctx context.Context, u *user.User) context.Context {
		return authn.SetInfo(ctx, u)
	}

	tests := []struct {
		name string
		path string
	}{
		{
			name: "Stat file directory",
			path: testFile,
		},
		{
			name: "Stat symlink to file",
			path: linkedFile,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			ctx := injectUser(t.Context(), u)
			req := connect.NewRequest(&filesystem.StatRequest{
				Path: tt.path,
			})
			resp, err := svc.Stat(ctx, req)
			require.NoError(t, err)
			require.NotEmpty(t, resp.Msg)
			require.NotNil(t, resp.Msg.GetEntry())
			assert.Equal(t, tt.path, resp.Msg.GetEntry().GetPath())
			assert.Equal(t, filesystem.FileType_FILE_TYPE_FILE, resp.Msg.GetEntry().GetType())
			assert.Equal(t, u.Username, resp.Msg.GetEntry().GetOwner())
			assert.Equal(t, group.Name, resp.Msg.GetEntry().GetGroup())
			assert.Equal(t, uint32(0o644), resp.Msg.GetEntry().GetMode())
			if tt.path == linkedFile {
				require.NotNil(t, resp.Msg.GetEntry().GetSymlinkTarget())
				assert.Equal(t, testFile, resp.Msg.GetEntry().GetSymlinkTarget())
			} else {
				assert.Empty(t, resp.Msg.GetEntry().GetSymlinkTarget())
			}
		})
	}
}
|
||||||
|
|
||||||
|
func TestStatMissingPathReturnsNotFound(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
u, err := user.Current()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
svc := mockService()
|
||||||
|
ctx := authn.SetInfo(t.Context(), u)
|
||||||
|
|
||||||
|
req := connect.NewRequest(&filesystem.StatRequest{
|
||||||
|
Path: filepath.Join(t.TempDir(), "missing.txt"),
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err = svc.Stat(ctx, req)
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
var connectErr *connect.Error
|
||||||
|
require.ErrorAs(t, err, &connectErr)
|
||||||
|
assert.Equal(t, connect.CodeNotFound, connectErr.Code())
|
||||||
|
}
|
||||||
109
envd/internal/services/filesystem/utils.go
Normal file
109
envd/internal/services/filesystem/utils.go
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package filesystem
|
||||||
|
|
||||||
|
import (
	"fmt"
	"os"
	"os/user"
	"strconv"
	"syscall"
	"time"

	"connectrpc.com/connect"
	"google.golang.org/protobuf/types/known/timestamppb"

	rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem"
	"git.omukk.dev/wrenn/sandbox/envd/internal/shared/filesystem"
)
|
||||||
|
|
||||||
|
// Filesystem magic numbers from Linux kernel (include/uapi/linux/magic.h)
const (
	nfsSuperMagic   = 0x6969     // NFS_SUPER_MAGIC
	cifsMagic       = 0xFF534D42 // CIFS_MAGIC_NUMBER
	smbSuperMagic   = 0x517B     // SMB_SUPER_MAGIC
	smb2MagicNumber = 0xFE534D42 // SMB2_MAGIC_NUMBER
	fuseSuperMagic  = 0x65735546 // FUSE_SUPER_MAGIC
)
|
||||||
|
|
||||||
|
// IsPathOnNetworkMount checks if the given path is on a network filesystem mount.
|
||||||
|
// Returns true if the path is on NFS, CIFS, SMB, or FUSE filesystem.
|
||||||
|
func IsPathOnNetworkMount(path string) (bool, error) {
|
||||||
|
var statfs syscall.Statfs_t
|
||||||
|
if err := syscall.Statfs(path, &statfs); err != nil {
|
||||||
|
return false, fmt.Errorf("failed to statfs %s: %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch statfs.Type {
|
||||||
|
case nfsSuperMagic, cifsMagic, smbSuperMagic, smb2MagicNumber, fuseSuperMagic:
|
||||||
|
return true, nil
|
||||||
|
default:
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func entryInfo(path string) (*rpc.EntryInfo, error) {
|
||||||
|
info, err := filesystem.GetEntryFromPath(path)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("file not found: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error getting file info: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
owner, group := getFileOwnership(info)
|
||||||
|
|
||||||
|
return &rpc.EntryInfo{
|
||||||
|
Name: info.Name,
|
||||||
|
Type: getEntryType(info.Type),
|
||||||
|
Path: info.Path,
|
||||||
|
Size: info.Size,
|
||||||
|
Mode: uint32(info.Mode),
|
||||||
|
Permissions: info.Permissions,
|
||||||
|
Owner: owner,
|
||||||
|
Group: group,
|
||||||
|
ModifiedTime: toTimestamp(info.ModifiedTime),
|
||||||
|
SymlinkTarget: info.SymlinkTarget,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func toTimestamp(time time.Time) *timestamppb.Timestamp {
|
||||||
|
if time.IsZero() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return timestamppb.New(time)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getFileOwnership returns the owner and group names for a file.
|
||||||
|
// If the lookup fails, it returns the numeric UID and GID as strings.
|
||||||
|
func getFileOwnership(fileInfo filesystem.EntryInfo) (owner, group string) {
|
||||||
|
// Look up username
|
||||||
|
owner = fmt.Sprintf("%d", fileInfo.UID)
|
||||||
|
if u, err := user.LookupId(owner); err == nil {
|
||||||
|
owner = u.Username
|
||||||
|
}
|
||||||
|
|
||||||
|
// Look up group name
|
||||||
|
group = fmt.Sprintf("%d", fileInfo.GID)
|
||||||
|
if g, err := user.LookupGroupId(group); err == nil {
|
||||||
|
group = g.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
return owner, group
|
||||||
|
}
|
||||||
|
|
||||||
|
// getEntryType determines the type of file entry based on its mode and path.
|
||||||
|
// If the file is a symlink, it follows the symlink to determine the actual type.
|
||||||
|
func getEntryType(fileType filesystem.FileType) rpc.FileType {
|
||||||
|
switch fileType {
|
||||||
|
case filesystem.FileFileType:
|
||||||
|
return rpc.FileType_FILE_TYPE_FILE
|
||||||
|
case filesystem.DirectoryFileType:
|
||||||
|
return rpc.FileType_FILE_TYPE_DIRECTORY
|
||||||
|
case filesystem.SymlinkFileType:
|
||||||
|
return rpc.FileType_FILE_TYPE_SYMLINK
|
||||||
|
default:
|
||||||
|
return rpc.FileType_FILE_TYPE_UNSPECIFIED
|
||||||
|
}
|
||||||
|
}
|
||||||
151
envd/internal/services/filesystem/utils_test.go
Normal file
151
envd/internal/services/filesystem/utils_test.go
Normal file
@ -0,0 +1,151 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package filesystem
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os/exec"
|
||||||
|
osuser "os/user"
|
||||||
|
"strconv"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
fsmodel "git.omukk.dev/wrenn/sandbox/envd/internal/shared/filesystem"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIsPathOnNetworkMount(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// Test with a regular directory (should not be on network mount)
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
isNetwork, err := IsPathOnNetworkMount(tempDir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.False(t, isNetwork, "temp directory should not be on a network mount")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsPathOnNetworkMount_FuseMount(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// Require bindfs to be available
|
||||||
|
_, err := exec.LookPath("bindfs")
|
||||||
|
require.NoError(t, err, "bindfs must be installed for this test")
|
||||||
|
|
||||||
|
// Require fusermount to be available (needed for unmounting)
|
||||||
|
_, err = exec.LookPath("fusermount")
|
||||||
|
require.NoError(t, err, "fusermount must be installed for this test")
|
||||||
|
|
||||||
|
// Create source and mount directories
|
||||||
|
sourceDir := t.TempDir()
|
||||||
|
mountDir := t.TempDir()
|
||||||
|
|
||||||
|
// Mount sourceDir onto mountDir using bindfs (FUSE)
|
||||||
|
ctx := context.Background()
|
||||||
|
cmd := exec.CommandContext(ctx, "bindfs", sourceDir, mountDir)
|
||||||
|
require.NoError(t, cmd.Run(), "failed to mount bindfs")
|
||||||
|
|
||||||
|
// Ensure we unmount on cleanup
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = exec.CommandContext(context.Background(), "fusermount", "-u", mountDir).Run()
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test that the FUSE mount is detected
|
||||||
|
isNetwork, err := IsPathOnNetworkMount(mountDir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, isNetwork, "FUSE mount should be detected as network filesystem")
|
||||||
|
|
||||||
|
// Test that the source directory is NOT detected as network mount
|
||||||
|
isNetworkSource, err := IsPathOnNetworkMount(sourceDir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.False(t, isNetworkSource, "source directory should not be detected as network filesystem")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFileOwnership_CurrentUser(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
t.Run("current user", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// Get current user running the tests
|
||||||
|
cur, err := osuser.Current()
|
||||||
|
if err != nil {
|
||||||
|
t.Skipf("unable to determine current user: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine expected owner/group using the same lookup logic
|
||||||
|
expectedOwner := cur.Uid
|
||||||
|
if u, err := osuser.LookupId(cur.Uid); err == nil {
|
||||||
|
expectedOwner = u.Username
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedGroup := cur.Gid
|
||||||
|
if g, err := osuser.LookupGroupId(cur.Gid); err == nil {
|
||||||
|
expectedGroup = g.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse UID/GID strings to uint32 for EntryInfo
|
||||||
|
uid64, err := strconv.ParseUint(cur.Uid, 10, 32)
|
||||||
|
require.NoError(t, err)
|
||||||
|
gid64, err := strconv.ParseUint(cur.Gid, 10, 32)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Build a minimal EntryInfo with current UID/GID
|
||||||
|
info := fsmodel.EntryInfo{ // from shared pkg
|
||||||
|
UID: uint32(uid64),
|
||||||
|
GID: uint32(gid64),
|
||||||
|
}
|
||||||
|
|
||||||
|
owner, group := getFileOwnership(info)
|
||||||
|
assert.Equal(t, expectedOwner, owner)
|
||||||
|
assert.Equal(t, expectedGroup, group)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("no user", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// Find a UID that does not exist on this system
|
||||||
|
var unknownUIDStr string
|
||||||
|
for i := 60001; i < 70000; i++ { // search a high range typically unused
|
||||||
|
idStr := strconv.Itoa(i)
|
||||||
|
if _, err := osuser.LookupId(idStr); err != nil {
|
||||||
|
unknownUIDStr = idStr
|
||||||
|
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if unknownUIDStr == "" {
|
||||||
|
t.Skip("could not find a non-existent UID in the probed range")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find a GID that does not exist on this system
|
||||||
|
var unknownGIDStr string
|
||||||
|
for i := 60001; i < 70000; i++ { // search a high range typically unused
|
||||||
|
idStr := strconv.Itoa(i)
|
||||||
|
if _, err := osuser.LookupGroupId(idStr); err != nil {
|
||||||
|
unknownGIDStr = idStr
|
||||||
|
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if unknownGIDStr == "" {
|
||||||
|
t.Skip("could not find a non-existent GID in the probed range")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse to uint32 for EntryInfo construction
|
||||||
|
uid64, err := strconv.ParseUint(unknownUIDStr, 10, 32)
|
||||||
|
require.NoError(t, err)
|
||||||
|
gid64, err := strconv.ParseUint(unknownGIDStr, 10, 32)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
info := fsmodel.EntryInfo{
|
||||||
|
UID: uint32(uid64),
|
||||||
|
GID: uint32(gid64),
|
||||||
|
}
|
||||||
|
|
||||||
|
owner, group := getFileOwnership(info)
|
||||||
|
// Expect numeric fallbacks because lookups should fail for unknown IDs
|
||||||
|
assert.Equal(t, unknownUIDStr, owner)
|
||||||
|
assert.Equal(t, unknownGIDStr, group)
|
||||||
|
})
|
||||||
|
}
|
||||||
161
envd/internal/services/filesystem/watch.go
Normal file
161
envd/internal/services/filesystem/watch.go
Normal file
@ -0,0 +1,161 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package filesystem
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"connectrpc.com/connect"
|
||||||
|
"github.com/e2b-dev/fsnotify"
|
||||||
|
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
|
||||||
|
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
// WatchDir streams filesystem events for the requested directory to the
// client, wrapping the actual handler with request/stream logging.
func (s Service) WatchDir(ctx context.Context, req *connect.Request[rpc.WatchDirRequest], stream *connect.ServerStream[rpc.WatchDirResponse]) error {
	return logs.LogServerStreamWithoutEvents(ctx, s.logger, req, stream, s.watchHandler)
}
|
||||||
|
|
||||||
|
func (s Service) watchHandler(ctx context.Context, req *connect.Request[rpc.WatchDirRequest], stream *connect.ServerStream[rpc.WatchDirResponse]) error {
|
||||||
|
u, err := permissions.GetAuthUser(ctx, s.defaults.User)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
watchPath, err := permissions.ExpandAndResolve(req.Msg.GetPath(), u, s.defaults.Workdir)
|
||||||
|
if err != nil {
|
||||||
|
return connect.NewError(connect.CodeInvalidArgument, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
info, err := os.Stat(watchPath)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return connect.NewError(connect.CodeNotFound, fmt.Errorf("path %s not found: %w", watchPath, err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return connect.NewError(connect.CodeInternal, fmt.Errorf("error statting path %s: %w", watchPath, err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if !info.IsDir() {
|
||||||
|
return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("path %s not a directory: %w", watchPath, err))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if path is on a network filesystem mount
|
||||||
|
isNetworkMount, err := IsPathOnNetworkMount(watchPath)
|
||||||
|
if err != nil {
|
||||||
|
return connect.NewError(connect.CodeInternal, fmt.Errorf("error checking mount status: %w", err))
|
||||||
|
}
|
||||||
|
if isNetworkMount {
|
||||||
|
return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("cannot watch path on network filesystem: %s", watchPath))
|
||||||
|
}
|
||||||
|
|
||||||
|
w, err := fsnotify.NewWatcher()
|
||||||
|
if err != nil {
|
||||||
|
return connect.NewError(connect.CodeInternal, fmt.Errorf("error creating watcher: %w", err))
|
||||||
|
}
|
||||||
|
defer w.Close()
|
||||||
|
|
||||||
|
err = w.Add(utils.FsnotifyPath(watchPath, req.Msg.GetRecursive()))
|
||||||
|
if err != nil {
|
||||||
|
return connect.NewError(connect.CodeInternal, fmt.Errorf("error adding path %s to watcher: %w", watchPath, err))
|
||||||
|
}
|
||||||
|
|
||||||
|
err = stream.Send(&rpc.WatchDirResponse{
|
||||||
|
Event: &rpc.WatchDirResponse_Start{
|
||||||
|
Start: &rpc.WatchDirResponse_StartEvent{},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending start event: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
keepaliveTicker, resetKeepalive := permissions.GetKeepAliveTicker(req)
|
||||||
|
defer keepaliveTicker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-keepaliveTicker.C:
|
||||||
|
streamErr := stream.Send(&rpc.WatchDirResponse{
|
||||||
|
Event: &rpc.WatchDirResponse_Keepalive{
|
||||||
|
Keepalive: &rpc.WatchDirResponse_KeepAlive{},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if streamErr != nil {
|
||||||
|
return connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending keepalive: %w", streamErr))
|
||||||
|
}
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case chErr, ok := <-w.Errors:
|
||||||
|
if !ok {
|
||||||
|
return connect.NewError(connect.CodeInternal, fmt.Errorf("watcher error channel closed"))
|
||||||
|
}
|
||||||
|
|
||||||
|
return connect.NewError(connect.CodeInternal, fmt.Errorf("watcher error: %w", chErr))
|
||||||
|
case e, ok := <-w.Events:
|
||||||
|
if !ok {
|
||||||
|
return connect.NewError(connect.CodeInternal, fmt.Errorf("watcher event channel closed"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// One event can have multiple operations.
|
||||||
|
ops := []rpc.EventType{}
|
||||||
|
|
||||||
|
if fsnotify.Create.Has(e.Op) {
|
||||||
|
ops = append(ops, rpc.EventType_EVENT_TYPE_CREATE)
|
||||||
|
}
|
||||||
|
|
||||||
|
if fsnotify.Rename.Has(e.Op) {
|
||||||
|
ops = append(ops, rpc.EventType_EVENT_TYPE_RENAME)
|
||||||
|
}
|
||||||
|
|
||||||
|
if fsnotify.Chmod.Has(e.Op) {
|
||||||
|
ops = append(ops, rpc.EventType_EVENT_TYPE_CHMOD)
|
||||||
|
}
|
||||||
|
|
||||||
|
if fsnotify.Write.Has(e.Op) {
|
||||||
|
ops = append(ops, rpc.EventType_EVENT_TYPE_WRITE)
|
||||||
|
}
|
||||||
|
|
||||||
|
if fsnotify.Remove.Has(e.Op) {
|
||||||
|
ops = append(ops, rpc.EventType_EVENT_TYPE_REMOVE)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, op := range ops {
|
||||||
|
name, nameErr := filepath.Rel(watchPath, e.Name)
|
||||||
|
if nameErr != nil {
|
||||||
|
return connect.NewError(connect.CodeInternal, fmt.Errorf("error getting relative path: %w", nameErr))
|
||||||
|
}
|
||||||
|
|
||||||
|
filesystemEvent := &rpc.WatchDirResponse_Filesystem{
|
||||||
|
Filesystem: &rpc.FilesystemEvent{
|
||||||
|
Name: name,
|
||||||
|
Type: op,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
event := &rpc.WatchDirResponse{
|
||||||
|
Event: filesystemEvent,
|
||||||
|
}
|
||||||
|
|
||||||
|
streamErr := stream.Send(event)
|
||||||
|
|
||||||
|
s.logger.
|
||||||
|
Debug().
|
||||||
|
Str("event_type", "filesystem_event").
|
||||||
|
Str(string(logs.OperationIDKey), ctx.Value(logs.OperationIDKey).(string)).
|
||||||
|
Interface("filesystem_event", event).
|
||||||
|
Msg("Streaming filesystem event")
|
||||||
|
|
||||||
|
if streamErr != nil {
|
||||||
|
return connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending filesystem event: %w", streamErr))
|
||||||
|
}
|
||||||
|
|
||||||
|
resetKeepalive()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
226
envd/internal/services/filesystem/watch_sync.go
Normal file
226
envd/internal/services/filesystem/watch_sync.go
Normal file
@ -0,0 +1,226 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package filesystem
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"connectrpc.com/connect"
|
||||||
|
"github.com/e2b-dev/fsnotify"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
|
||||||
|
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/id"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
// FileWatcher collects filesystem events from an fsnotify watcher in a
// background goroutine so clients can poll them later via GetWatcherEvents.
type FileWatcher struct {
	// watcher is the underlying fsnotify watcher.
	watcher *fsnotify.Watcher
	// Events accumulates filesystem events until a caller drains them.
	Events []*rpc.FilesystemEvent
	// cancel stops the background collection goroutine.
	cancel func()
	// Error records the first fatal watcher failure, if any.
	// NOTE(review): the collection goroutine writes this without holding
	// Lock while GetWatcherEvents reads it — confirm synchronization.
	Error error

	// Lock guards Events (appended by the goroutine, drained by readers).
	Lock sync.Mutex
}
|
||||||
|
|
||||||
|
func CreateFileWatcher(ctx context.Context, watchPath string, recursive bool, operationID string, logger *zerolog.Logger) (*FileWatcher, error) {
|
||||||
|
w, err := fsnotify.NewWatcher()
|
||||||
|
if err != nil {
|
||||||
|
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error creating watcher: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
// We don't want to cancel the context when the request is finished
|
||||||
|
ctx, cancel := context.WithCancel(context.WithoutCancel(ctx))
|
||||||
|
|
||||||
|
err = w.Add(utils.FsnotifyPath(watchPath, recursive))
|
||||||
|
if err != nil {
|
||||||
|
_ = w.Close()
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error adding path %s to watcher: %w", watchPath, err))
|
||||||
|
}
|
||||||
|
fw := &FileWatcher{
|
||||||
|
watcher: w,
|
||||||
|
cancel: cancel,
|
||||||
|
Events: []*rpc.FilesystemEvent{},
|
||||||
|
Error: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case chErr, ok := <-w.Errors:
|
||||||
|
if !ok {
|
||||||
|
fw.Error = connect.NewError(connect.CodeInternal, fmt.Errorf("watcher error channel closed"))
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fw.Error = connect.NewError(connect.CodeInternal, fmt.Errorf("watcher error: %w", chErr))
|
||||||
|
|
||||||
|
return
|
||||||
|
case e, ok := <-w.Events:
|
||||||
|
if !ok {
|
||||||
|
fw.Error = connect.NewError(connect.CodeInternal, fmt.Errorf("watcher event channel closed"))
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// One event can have multiple operations.
|
||||||
|
ops := []rpc.EventType{}
|
||||||
|
|
||||||
|
if fsnotify.Create.Has(e.Op) {
|
||||||
|
ops = append(ops, rpc.EventType_EVENT_TYPE_CREATE)
|
||||||
|
}
|
||||||
|
|
||||||
|
if fsnotify.Rename.Has(e.Op) {
|
||||||
|
ops = append(ops, rpc.EventType_EVENT_TYPE_RENAME)
|
||||||
|
}
|
||||||
|
|
||||||
|
if fsnotify.Chmod.Has(e.Op) {
|
||||||
|
ops = append(ops, rpc.EventType_EVENT_TYPE_CHMOD)
|
||||||
|
}
|
||||||
|
|
||||||
|
if fsnotify.Write.Has(e.Op) {
|
||||||
|
ops = append(ops, rpc.EventType_EVENT_TYPE_WRITE)
|
||||||
|
}
|
||||||
|
|
||||||
|
if fsnotify.Remove.Has(e.Op) {
|
||||||
|
ops = append(ops, rpc.EventType_EVENT_TYPE_REMOVE)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, op := range ops {
|
||||||
|
name, nameErr := filepath.Rel(watchPath, e.Name)
|
||||||
|
if nameErr != nil {
|
||||||
|
fw.Error = connect.NewError(connect.CodeInternal, fmt.Errorf("error getting relative path: %w", nameErr))
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fw.Lock.Lock()
|
||||||
|
fw.Events = append(fw.Events, &rpc.FilesystemEvent{
|
||||||
|
Name: name,
|
||||||
|
Type: op,
|
||||||
|
})
|
||||||
|
fw.Lock.Unlock()
|
||||||
|
|
||||||
|
// these are only used for logging
|
||||||
|
filesystemEvent := &rpc.WatchDirResponse_Filesystem{
|
||||||
|
Filesystem: &rpc.FilesystemEvent{
|
||||||
|
Name: name,
|
||||||
|
Type: op,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
event := &rpc.WatchDirResponse{
|
||||||
|
Event: filesystemEvent,
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.
|
||||||
|
Debug().
|
||||||
|
Str("event_type", "filesystem_event").
|
||||||
|
Str(string(logs.OperationIDKey), operationID).
|
||||||
|
Interface("filesystem_event", event).
|
||||||
|
Msg("Streaming filesystem event")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return fw, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close shuts down the underlying fsnotify watcher and stops the
// background event-collection goroutine.
func (fw *FileWatcher) Close() {
	_ = fw.watcher.Close() // best-effort; the error is not actionable here
	fw.cancel()
}
|
||||||
|
|
||||||
|
func (s Service) CreateWatcher(ctx context.Context, req *connect.Request[rpc.CreateWatcherRequest]) (*connect.Response[rpc.CreateWatcherResponse], error) {
|
||||||
|
u, err := permissions.GetAuthUser(ctx, s.defaults.User)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
watchPath, err := permissions.ExpandAndResolve(req.Msg.GetPath(), u, s.defaults.Workdir)
|
||||||
|
if err != nil {
|
||||||
|
return nil, connect.NewError(connect.CodeInvalidArgument, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
info, err := os.Stat(watchPath)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("path %s not found: %w", watchPath, err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error statting path %s: %w", watchPath, err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if !info.IsDir() {
|
||||||
|
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("path %s not a directory: %w", watchPath, err))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if path is on a network filesystem mount
|
||||||
|
isNetworkMount, err := IsPathOnNetworkMount(watchPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error checking mount status: %w", err))
|
||||||
|
}
|
||||||
|
if isNetworkMount {
|
||||||
|
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("cannot watch path on network filesystem: %s", watchPath))
|
||||||
|
}
|
||||||
|
|
||||||
|
watcherId := "w" + id.Generate()
|
||||||
|
|
||||||
|
w, err := CreateFileWatcher(ctx, watchPath, req.Msg.GetRecursive(), watcherId, s.logger)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.watchers.Store(watcherId, w)
|
||||||
|
|
||||||
|
return connect.NewResponse(&rpc.CreateWatcherResponse{
|
||||||
|
WatcherId: watcherId,
|
||||||
|
}), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s Service) GetWatcherEvents(_ context.Context, req *connect.Request[rpc.GetWatcherEventsRequest]) (*connect.Response[rpc.GetWatcherEventsResponse], error) {
|
||||||
|
watcherId := req.Msg.GetWatcherId()
|
||||||
|
|
||||||
|
w, ok := s.watchers.Load(watcherId)
|
||||||
|
if !ok {
|
||||||
|
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("watcher with id %s not found", watcherId))
|
||||||
|
}
|
||||||
|
|
||||||
|
if w.Error != nil {
|
||||||
|
return nil, w.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Lock.Lock()
|
||||||
|
defer w.Lock.Unlock()
|
||||||
|
events := w.Events
|
||||||
|
w.Events = []*rpc.FilesystemEvent{}
|
||||||
|
|
||||||
|
return connect.NewResponse(&rpc.GetWatcherEventsResponse{
|
||||||
|
Events: events,
|
||||||
|
}), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveWatcher closes the watcher with the given ID and removes it from
// the service's registry.
func (s Service) RemoveWatcher(_ context.Context, req *connect.Request[rpc.RemoveWatcherRequest]) (*connect.Response[rpc.RemoveWatcherResponse], error) {
	watcherId := req.Msg.GetWatcherId()

	w, ok := s.watchers.Load(watcherId)
	if !ok {
		return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("watcher with id %s not found", watcherId))
	}

	// NOTE(review): Load followed by Delete is not atomic — two concurrent
	// removals of the same ID could both Close the same watcher. Confirm
	// whether callers can race here (LoadAndDelete would close the window).
	w.Close()
	s.watchers.Delete(watcherId)

	return connect.NewResponse(&rpc.RemoveWatcherResponse{}), nil
}
|
||||||
128
envd/internal/services/process/connect.go
Normal file
128
envd/internal/services/process/connect.go
Normal file
@ -0,0 +1,128 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package process
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"connectrpc.com/connect"
|
||||||
|
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
|
||||||
|
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Connect attaches the caller to an existing process and streams its start
// event, output data, keepalives, and end event, wrapping the actual
// handler with request/stream logging.
func (s *Service) Connect(ctx context.Context, req *connect.Request[rpc.ConnectRequest], stream *connect.ServerStream[rpc.ConnectResponse]) error {
	return logs.LogServerStreamWithoutEvents(ctx, s.logger, req, stream, s.handleConnect)
}
|
||||||
|
|
||||||
|
// handleConnect streams the lifecycle of an already-running process to the
// client: an initial start event with the PID, forked data events as they
// arrive, keepalives while the stream is idle, and finally the end event.
//
// A goroutine drives the sends; the handler itself blocks until either the
// context is cancelled or the goroutine finishes (signalled via exitChan).
// Any send failure is recorded as the context's cancel cause.
func (s *Service) handleConnect(ctx context.Context, req *connect.Request[rpc.ConnectRequest], stream *connect.ServerStream[rpc.ConnectResponse]) error {
	ctx, cancel := context.WithCancelCause(ctx)
	defer cancel(nil)

	proc, err := s.getProcess(req.Msg.GetProcess())
	if err != nil {
		return err
	}

	// Closed by the sender goroutine when it is completely done.
	exitChan := make(chan struct{})

	// Fork per-subscriber views of the process's data and end events.
	data, dataCancel := proc.DataEvent.Fork()
	defer dataCancel()

	end, endCancel := proc.EndEvent.Fork()
	defer endCancel()

	// Announce the connection with the process's PID before any data.
	streamErr := stream.Send(&rpc.ConnectResponse{
		Event: &rpc.ProcessEvent{
			Event: &rpc.ProcessEvent_Start{
				Start: &rpc.ProcessEvent_StartEvent{
					Pid: proc.Pid(),
				},
			},
		},
	})
	if streamErr != nil {
		return connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending start event: %w", streamErr))
	}

	go func() {
		defer close(exitChan)

		keepaliveTicker, resetKeepalive := permissions.GetKeepAliveTicker(req)
		defer keepaliveTicker.Stop()

		// Phase 1: forward data events (and keepalives) until the data
		// channel closes, then fall through to wait for the end event.
	dataLoop:
		for {
			select {
			case <-keepaliveTicker.C:
				streamErr := stream.Send(&rpc.ConnectResponse{
					Event: &rpc.ProcessEvent{
						Event: &rpc.ProcessEvent_Keepalive{
							Keepalive: &rpc.ProcessEvent_KeepAlive{},
						},
					},
				})
				if streamErr != nil {
					cancel(connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending keepalive: %w", streamErr)))

					return
				}
			case <-ctx.Done():
				cancel(ctx.Err())

				return
			case event, ok := <-data:
				if !ok {
					break dataLoop
				}

				streamErr := stream.Send(&rpc.ConnectResponse{
					Event: &rpc.ProcessEvent{
						Event: &event,
					},
				})
				if streamErr != nil {
					cancel(connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending data event: %w", streamErr)))

					return
				}

				// A real event counts as liveness; push the keepalive out.
				resetKeepalive()
			}
		}

		// Phase 2: deliver the single end event (or give up on cancel).
		select {
		case <-ctx.Done():
			cancel(ctx.Err())

			return
		case event, ok := <-end:
			if !ok {
				cancel(connect.NewError(connect.CodeUnknown, errors.New("end event channel closed before sending end event")))

				return
			}

			streamErr := stream.Send(&rpc.ConnectResponse{
				Event: &rpc.ProcessEvent{
					Event: &event,
				},
			})
			if streamErr != nil {
				cancel(connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending end event: %w", streamErr)))

				return
			}
		}
	}()

	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-exitChan:
		return nil
	}
}
|
||||||
480
envd/internal/services/process/handler/handler.go
Normal file
480
envd/internal/services/process/handler/handler.go
Normal file
@ -0,0 +1,480 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"os/user"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"connectrpc.com/connect"
|
||||||
|
"github.com/creack/pty"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/services/cgroups"
|
||||||
|
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
	// defaultNice is the target nice value for spawned user processes.
	defaultNice = 0
	// defaultOomScore is written to /proc/<pid>/oom_score_adj by the shell
	// wrapper so user processes are preferred OOM-kill victims over envd.
	defaultOomScore = 100
	// outputBufferSize is the buffer size for output event channels.
	outputBufferSize = 64
	// stdChunkSize (32 KiB) is the per-read chunk for stdout/stderr pipes.
	stdChunkSize = 2 << 14
	// ptyChunkSize (16 KiB) is the per-read chunk for pty output.
	ptyChunkSize = 2 << 13
)
|
||||||
|
|
||||||
|
// ProcessExit describes how a managed process terminated.
type ProcessExit struct {
	// Error holds an error message when termination was abnormal;
	// presumably nil on a clean exit — confirm at the populating site.
	Error *string
	// Status is the textual exit status.
	Status string
	// Exited reports whether the process has terminated.
	Exited bool
	// Code is the process exit code.
	Code int32
}
|
||||||
|
|
||||||
|
// Handler wraps a single managed child process: the underlying exec.Cmd,
// its optional pty, the stdin pipe, and the multiplexed channels that fan
// its output and end events out to subscribers.
type Handler struct {
	// Config is the process configuration from the start request.
	Config *rpc.ProcessConfig

	logger *zerolog.Logger

	// Tag is an optional client-supplied label for the process.
	Tag *string
	cmd *exec.Cmd
	// tty is the pty master when the process runs under a pty; nil otherwise.
	tty *os.File

	// cancel cancels the process's context.
	cancel context.CancelFunc

	// outCtx tracks the lifetime of the output pipe readers.
	outCtx    context.Context //nolint:containedctx // todo: refactor so this can be removed
	outCancel context.CancelFunc

	// stdinMu guards stdin.
	stdinMu sync.Mutex
	stdin   io.WriteCloser

	// DataEvent fans stdout/stderr/pty output out to subscribers.
	DataEvent *MultiplexedChannel[rpc.ProcessEvent_Data]
	// EndEvent fans the process end event out to subscribers.
	EndEvent *MultiplexedChannel[rpc.ProcessEvent_End]
}
|
||||||
|
|
||||||
|
// Pid returns the operating-system process ID of the child.
// This method must be called only after the process has been started
func (p *Handler) Pid() uint32 {
	return uint32(p.cmd.Process.Pid)
}
|
||||||
|
|
||||||
|
// userCommand returns a human-readable representation of the user's original command,
|
||||||
|
// without the internal OOM/nice wrapper that is prepended to the actual exec.
|
||||||
|
func (p *Handler) userCommand() string {
|
||||||
|
return strings.Join(append([]string{p.Config.GetCmd()}, p.Config.GetArgs()...), " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
// currentNice returns the nice value of the current process, or 0 when the
// priority cannot be read.
func currentNice() int {
	p, err := syscall.Getpriority(syscall.PRIO_PROCESS, 0)
	if err != nil {
		return 0
	}

	// Getpriority returns 20 - nice on Linux, so invert the mapping.
	return 20 - p
}
|
||||||
|
|
||||||
|
// New builds a Handler for the requested process: it prepares the wrapped
// /bin/sh command (OOM score + nice adjustment), credentials, cgroup,
// working directory, and environment, and wires either a pty or
// stdout/stderr/stdin pipes into the output multiplexer.
//
// In pty mode the command is started immediately by pty.StartWithSize (the
// pty package offers no deferred start); in pipe mode the command is only
// configured here. Returns connect-coded errors for invalid requests or
// internal failures.
func New(
	ctx context.Context,
	user *user.User,
	req *rpc.StartRequest,
	logger *zerolog.Logger,
	defaults *execcontext.Defaults,
	cgroupManager cgroups.Manager,
	cancel context.CancelFunc,
) (*Handler, error) {
	// User command string for logging (without the internal wrapper details).
	userCmd := strings.Join(append([]string{req.GetProcess().GetCmd()}, req.GetProcess().GetArgs()...), " ")

	// Wrap the command in a shell that sets the OOM score and nice value before exec-ing the actual command.
	// This eliminates the race window where grandchildren could inherit the parent's protected OOM score (-1000)
	// or high CPU priority (nice -20) before the post-start calls had a chance to correct them.
	// nice(1) applies a relative adjustment, so we compute the delta from the current (inherited) nice to the target.
	niceDelta := defaultNice - currentNice()
	oomWrapperScript := fmt.Sprintf(`echo %d > /proc/$$/oom_score_adj && exec /usr/bin/nice -n %d "${@}"`, defaultOomScore, niceDelta)
	wrapperArgs := append([]string{"-c", oomWrapperScript, "--", req.GetProcess().GetCmd()}, req.GetProcess().GetArgs()...)
	cmd := exec.CommandContext(ctx, "/bin/sh", wrapperArgs...)

	uid, gid, err := permissions.GetUserIdUints(user)
	if err != nil {
		return nil, connect.NewError(connect.CodeInternal, err)
	}

	// Start with the primary group; supplementary groups are best-effort.
	groups := []uint32{gid}
	if gids, err := user.GroupIds(); err != nil {
		logger.Warn().Err(err).Str("user", user.Username).Msg("failed to get supplementary groups")
	} else {
		for _, g := range gids {
			if parsed, err := strconv.ParseUint(g, 10, 32); err == nil {
				groups = append(groups, uint32(parsed))
			}
		}
	}

	// Place the child directly into the appropriate cgroup via fd, when one
	// is available for this process type.
	cgroupFD, ok := cgroupManager.GetFileDescriptor(getProcType(req))

	cmd.SysProcAttr = &syscall.SysProcAttr{
		UseCgroupFD: ok,
		CgroupFD:    cgroupFD,
		Credential: &syscall.Credential{
			Uid:    uid,
			Gid:    gid,
			Groups: groups,
		},
	}

	resolvedPath, err := permissions.ExpandAndResolve(req.GetProcess().GetCwd(), user, defaults.Workdir)
	if err != nil {
		return nil, connect.NewError(connect.CodeInvalidArgument, err)
	}

	// Check if the cwd resolved path exists
	if _, err := os.Stat(resolvedPath); errors.Is(err, os.ErrNotExist) {
		return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("cwd '%s' does not exist", resolvedPath))
	}

	cmd.Dir = resolvedPath

	var formattedVars []string

	// Take only 'PATH' variable from the current environment
	// The 'PATH' should ideally be set in the environment
	formattedVars = append(formattedVars, "PATH="+os.Getenv("PATH"))
	formattedVars = append(formattedVars, "HOME="+user.HomeDir)
	formattedVars = append(formattedVars, "USER="+user.Username)
	formattedVars = append(formattedVars, "LOGNAME="+user.Username)

	// Add the environment variables from the global environment
	if defaults.EnvVars != nil {
		defaults.EnvVars.Range(func(key string, value string) bool {
			formattedVars = append(formattedVars, key+"="+value)

			return true
		})
	}

	// Only the last values of the env vars are used - this allows for overwriting defaults
	for key, value := range req.GetProcess().GetEnvs() {
		formattedVars = append(formattedVars, key+"="+value)
	}

	cmd.Env = formattedVars

	outMultiplex := NewMultiplexedChannel[rpc.ProcessEvent_Data](outputBufferSize)

	var outWg sync.WaitGroup

	// Create a context for waiting for and cancelling output pipes.
	// Cancellation of the process via timeout will propagate and cancel this context too.
	outCtx, outCancel := context.WithCancel(ctx)

	h := &Handler{
		Config:    req.GetProcess(),
		cmd:       cmd,
		Tag:       req.Tag,
		DataEvent: outMultiplex,
		cancel:    cancel,
		outCtx:    outCtx,
		outCancel: outCancel,
		EndEvent:  NewMultiplexedChannel[rpc.ProcessEvent_End](0),
		logger:    logger,
	}

	if req.GetPty() != nil {
		// The pty should ideally start only in the Start method, but the package does not support that and we would have to code it manually.
		// The output of the pty should correctly be passed though.
		tty, err := pty.StartWithSize(cmd, &pty.Winsize{
			Cols: uint16(req.GetPty().GetSize().GetCols()),
			Rows: uint16(req.GetPty().GetSize().GetRows()),
		})
		if err != nil {
			return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("error starting pty with command '%s' in dir '%s' with '%d' cols and '%d' rows: %w", userCmd, cmd.Dir, req.GetPty().GetSize().GetCols(), req.GetPty().GetSize().GetRows(), err))
		}

		// Pump pty output into the multiplexer until EOF or a read error.
		outWg.Go(func() {
			for {
				// Fresh buffer per read: consumers keep the slice after send.
				buf := make([]byte, ptyChunkSize)

				n, readErr := tty.Read(buf)

				if n > 0 {
					outMultiplex.Source <- rpc.ProcessEvent_Data{
						Data: &rpc.ProcessEvent_DataEvent{
							Output: &rpc.ProcessEvent_DataEvent_Pty{
								Pty: buf[:n],
							},
						},
					}
				}

				if errors.Is(readErr, io.EOF) {
					break
				}

				if readErr != nil {
					fmt.Fprintf(os.Stderr, "error reading from pty: %s\n", readErr)

					break
				}
			}
		})

		h.tty = tty
	} else {
		stdout, err := cmd.StdoutPipe()
		if err != nil {
			return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error creating stdout pipe for command '%s': %w", userCmd, err))
		}

		// Pump stdout into the multiplexer, mirroring chunks to the logger.
		outWg.Go(func() {
			stdoutLogs := make(chan []byte, outputBufferSize)
			defer close(stdoutLogs)

			stdoutLogger := logger.With().Str("event_type", "stdout").Logger()

			go logs.LogBufferedDataEvents(stdoutLogs, &stdoutLogger, "data")

			for {
				// Fresh buffer per read: consumers keep the slice after send.
				buf := make([]byte, stdChunkSize)

				n, readErr := stdout.Read(buf)

				if n > 0 {
					outMultiplex.Source <- rpc.ProcessEvent_Data{
						Data: &rpc.ProcessEvent_DataEvent{
							Output: &rpc.ProcessEvent_DataEvent_Stdout{
								Stdout: buf[:n],
							},
						},
					}

					stdoutLogs <- buf[:n]
				}

				if errors.Is(readErr, io.EOF) {
					break
				}

				if readErr != nil {
					fmt.Fprintf(os.Stderr, "error reading from stdout: %s\n", readErr)

					break
				}
			}
		})

		stderr, err := cmd.StderrPipe()
		if err != nil {
			return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error creating stderr pipe for command '%s': %w", userCmd, err))
		}

		// Pump stderr into the multiplexer, mirroring chunks to the logger.
		outWg.Go(func() {
			stderrLogs := make(chan []byte, outputBufferSize)
			defer close(stderrLogs)

			stderrLogger := logger.With().Str("event_type", "stderr").Logger()

			go logs.LogBufferedDataEvents(stderrLogs, &stderrLogger, "data")

			for {
				// Fresh buffer per read: consumers keep the slice after send.
				buf := make([]byte, stdChunkSize)

				n, readErr := stderr.Read(buf)

				if n > 0 {
					outMultiplex.Source <- rpc.ProcessEvent_Data{
						Data: &rpc.ProcessEvent_DataEvent{
							Output: &rpc.ProcessEvent_DataEvent_Stderr{
								Stderr: buf[:n],
							},
						},
					}

					stderrLogs <- buf[:n]
				}

				if errors.Is(readErr, io.EOF) {
					break
				}

				if readErr != nil {
					fmt.Fprintf(os.Stderr, "error reading from stderr: %s\n", readErr)

					break
				}
			}
		})

		// For backwards compatibility we still set the stdin if not explicitly disabled
		// If stdin is disabled, the process will use /dev/null as stdin
		if req.Stdin == nil || req.GetStdin() == true {
			stdin, err := cmd.StdinPipe()
			if err != nil {
				return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error creating stdin pipe for command '%s': %w", userCmd, err))
			}

			h.stdin = stdin
		}
	}

	// Once all output pumps finish, close the multiplexer source and signal
	// output completion via outCtx.
	go func() {
		outWg.Wait()

		close(outMultiplex.Source)

		outCancel()
	}()

	return h, nil
}
|
||||||
|
|
||||||
|
func getProcType(req *rpc.StartRequest) cgroups.ProcessType {
|
||||||
|
if req != nil && req.GetPty() != nil {
|
||||||
|
return cgroups.ProcessTypePTY
|
||||||
|
}
|
||||||
|
|
||||||
|
return cgroups.ProcessTypeUser
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Handler) SendSignal(signal syscall.Signal) error {
|
||||||
|
if p.cmd.Process == nil {
|
||||||
|
return fmt.Errorf("process not started")
|
||||||
|
}
|
||||||
|
|
||||||
|
if signal == syscall.SIGKILL || signal == syscall.SIGTERM {
|
||||||
|
p.outCancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
return p.cmd.Process.Signal(signal)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Handler) ResizeTty(size *pty.Winsize) error {
|
||||||
|
if p.tty == nil {
|
||||||
|
return fmt.Errorf("tty not assigned to process")
|
||||||
|
}
|
||||||
|
|
||||||
|
return pty.Setsize(p.tty, size)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Handler) WriteStdin(data []byte) error {
|
||||||
|
if p.tty != nil {
|
||||||
|
return fmt.Errorf("tty assigned to process — input should be written to the pty, not the stdin")
|
||||||
|
}
|
||||||
|
|
||||||
|
p.stdinMu.Lock()
|
||||||
|
defer p.stdinMu.Unlock()
|
||||||
|
|
||||||
|
if p.stdin == nil {
|
||||||
|
return fmt.Errorf("stdin not enabled or closed")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := p.stdin.Write(data)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error writing to stdin of process '%d': %w", p.cmd.Process.Pid, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseStdin closes the stdin pipe to signal EOF to the process.
|
||||||
|
// Only works for non-PTY processes.
|
||||||
|
func (p *Handler) CloseStdin() error {
|
||||||
|
if p.tty != nil {
|
||||||
|
return fmt.Errorf("cannot close stdin for PTY process — send Ctrl+D (0x04) instead")
|
||||||
|
}
|
||||||
|
|
||||||
|
p.stdinMu.Lock()
|
||||||
|
defer p.stdinMu.Unlock()
|
||||||
|
|
||||||
|
if p.stdin == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err := p.stdin.Close()
|
||||||
|
// We still set the stdin to nil even on error as there are no errors,
|
||||||
|
// for which it is really safe to retry close across all distributions.
|
||||||
|
p.stdin = nil
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Handler) WriteTty(data []byte) error {
|
||||||
|
if p.tty == nil {
|
||||||
|
return fmt.Errorf("tty not assigned to process — input should be written to the stdin, not the tty")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := p.tty.Write(data)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error writing to tty of process '%d': %w", p.cmd.Process.Pid, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Handler) Start() (uint32, error) {
|
||||||
|
// Pty is already started in the New method
|
||||||
|
if p.tty == nil {
|
||||||
|
err := p.cmd.Start()
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("error starting process '%s': %w", p.userCommand(), err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
p.logger.
|
||||||
|
Info().
|
||||||
|
Str("event_type", "process_start").
|
||||||
|
Int("pid", p.cmd.Process.Pid).
|
||||||
|
Str("command", p.userCommand()).
|
||||||
|
Msg(fmt.Sprintf("Process with pid %d started", p.cmd.Process.Pid))
|
||||||
|
|
||||||
|
return uint32(p.cmd.Process.Pid), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Handler) Wait() {
|
||||||
|
// Wait for the output pipes to be closed or cancelled.
|
||||||
|
<-p.outCtx.Done()
|
||||||
|
|
||||||
|
err := p.cmd.Wait()
|
||||||
|
|
||||||
|
p.tty.Close()
|
||||||
|
|
||||||
|
var errMsg *string
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
msg := err.Error()
|
||||||
|
errMsg = &msg
|
||||||
|
}
|
||||||
|
|
||||||
|
endEvent := &rpc.ProcessEvent_EndEvent{
|
||||||
|
Error: errMsg,
|
||||||
|
ExitCode: int32(p.cmd.ProcessState.ExitCode()),
|
||||||
|
Exited: p.cmd.ProcessState.Exited(),
|
||||||
|
Status: p.cmd.ProcessState.String(),
|
||||||
|
}
|
||||||
|
|
||||||
|
event := rpc.ProcessEvent_End{
|
||||||
|
End: endEvent,
|
||||||
|
}
|
||||||
|
|
||||||
|
p.EndEvent.Source <- event
|
||||||
|
|
||||||
|
p.logger.
|
||||||
|
Info().
|
||||||
|
Str("event_type", "process_end").
|
||||||
|
Interface("process_result", endEvent).
|
||||||
|
Msg(fmt.Sprintf("Process with pid %d ended", p.cmd.Process.Pid))
|
||||||
|
|
||||||
|
// Ensure the process cancel is called to cleanup resources.
|
||||||
|
// As it is called after end event and Wait, it should not affect command execution or returned events.
|
||||||
|
p.cancel()
|
||||||
|
}
|
||||||
75
envd/internal/services/process/handler/multiplex.go
Normal file
75
envd/internal/services/process/handler/multiplex.go
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MultiplexedChannel fans values sent on Source out to every forked
// consumer channel. A background goroutine started by NewMultiplexedChannel
// drains Source and forwards each value to all registered consumers;
// closing Source shuts the fan-out down and closes the consumers.
type MultiplexedChannel[T any] struct {
	// Source is the single producer-facing channel. The owner writes events
	// here and closes it when no more events will be produced.
	Source chan T
	// channels holds the currently-forked consumer channels; guarded by mu.
	channels []chan T
	mu       sync.RWMutex
	// exited becomes true once Source has been closed and drained.
	exited atomic.Bool
}
|
||||||
|
|
||||||
|
func NewMultiplexedChannel[T any](buffer int) *MultiplexedChannel[T] {
|
||||||
|
c := &MultiplexedChannel[T]{
|
||||||
|
channels: nil,
|
||||||
|
Source: make(chan T, buffer),
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for v := range c.Source {
|
||||||
|
c.mu.RLock()
|
||||||
|
|
||||||
|
for _, cons := range c.channels {
|
||||||
|
cons <- v
|
||||||
|
}
|
||||||
|
|
||||||
|
c.mu.RUnlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
c.exited.Store(true)
|
||||||
|
|
||||||
|
for _, cons := range c.channels {
|
||||||
|
close(cons)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MultiplexedChannel[T]) Fork() (chan T, func()) {
|
||||||
|
if m.exited.Load() {
|
||||||
|
ch := make(chan T)
|
||||||
|
close(ch)
|
||||||
|
|
||||||
|
return ch, func() {}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
consumer := make(chan T)
|
||||||
|
|
||||||
|
m.channels = append(m.channels, consumer)
|
||||||
|
|
||||||
|
return consumer, func() {
|
||||||
|
m.remove(consumer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MultiplexedChannel[T]) remove(consumer chan T) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
for i, ch := range m.channels {
|
||||||
|
if ch == consumer {
|
||||||
|
m.channels = append(m.channels[:i], m.channels[i+1:]...)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
109
envd/internal/services/process/input.go
Normal file
109
envd/internal/services/process/input.go
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package process
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"connectrpc.com/connect"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/services/process/handler"
|
||||||
|
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
|
||||||
|
)
|
||||||
|
|
||||||
|
func handleInput(ctx context.Context, process *handler.Handler, in *rpc.ProcessInput, logger *zerolog.Logger) error {
|
||||||
|
switch in.GetInput().(type) {
|
||||||
|
case *rpc.ProcessInput_Pty:
|
||||||
|
err := process.WriteTty(in.GetPty())
|
||||||
|
if err != nil {
|
||||||
|
return connect.NewError(connect.CodeInternal, fmt.Errorf("error writing to tty: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
case *rpc.ProcessInput_Stdin:
|
||||||
|
err := process.WriteStdin(in.GetStdin())
|
||||||
|
if err != nil {
|
||||||
|
return connect.NewError(connect.CodeInternal, fmt.Errorf("error writing to stdin: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug().
|
||||||
|
Str("event_type", "stdin").
|
||||||
|
Interface("stdin", in.GetStdin()).
|
||||||
|
Str(string(logs.OperationIDKey), ctx.Value(logs.OperationIDKey).(string)).
|
||||||
|
Msg("Streaming input to process")
|
||||||
|
|
||||||
|
default:
|
||||||
|
return connect.NewError(connect.CodeUnimplemented, fmt.Errorf("invalid input type %T", in.GetInput()))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) SendInput(ctx context.Context, req *connect.Request[rpc.SendInputRequest]) (*connect.Response[rpc.SendInputResponse], error) {
|
||||||
|
proc, err := s.getProcess(req.Msg.GetProcess())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = handleInput(ctx, proc, req.Msg.GetInput(), s.logger)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return connect.NewResponse(&rpc.SendInputResponse{}), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// StreamInput handles the client-streaming input RPC. The actual work is
// done by streamInputHandler; this wrapper only adds stream logging via
// logs.LogClientStreamWithoutEvents.
func (s *Service) StreamInput(ctx context.Context, stream *connect.ClientStream[rpc.StreamInputRequest]) (*connect.Response[rpc.StreamInputResponse], error) {
	return logs.LogClientStreamWithoutEvents(ctx, s.logger, stream, s.streamInputHandler)
}
|
||||||
|
|
||||||
|
func (s *Service) streamInputHandler(ctx context.Context, stream *connect.ClientStream[rpc.StreamInputRequest]) (*connect.Response[rpc.StreamInputResponse], error) {
|
||||||
|
var proc *handler.Handler
|
||||||
|
|
||||||
|
for stream.Receive() {
|
||||||
|
req := stream.Msg()
|
||||||
|
|
||||||
|
switch req.GetEvent().(type) {
|
||||||
|
case *rpc.StreamInputRequest_Start:
|
||||||
|
p, err := s.getProcess(req.GetStart().GetProcess())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
proc = p
|
||||||
|
case *rpc.StreamInputRequest_Data:
|
||||||
|
err := handleInput(ctx, proc, req.GetData().GetInput(), s.logger)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
case *rpc.StreamInputRequest_Keepalive:
|
||||||
|
default:
|
||||||
|
return nil, connect.NewError(connect.CodeUnimplemented, fmt.Errorf("invalid event type %T", req.GetEvent()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err := stream.Err()
|
||||||
|
if err != nil {
|
||||||
|
return nil, connect.NewError(connect.CodeUnknown, fmt.Errorf("error streaming input: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return connect.NewResponse(&rpc.StreamInputResponse{}), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) CloseStdin(
|
||||||
|
_ context.Context,
|
||||||
|
req *connect.Request[rpc.CloseStdinRequest],
|
||||||
|
) (*connect.Response[rpc.CloseStdinResponse], error) {
|
||||||
|
handler, err := s.getProcess(req.Msg.GetProcess())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := handler.CloseStdin(); err != nil {
|
||||||
|
return nil, connect.NewError(connect.CodeUnknown, fmt.Errorf("error closing stdin: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return connect.NewResponse(&rpc.CloseStdinResponse{}), nil
|
||||||
|
}
|
||||||
30
envd/internal/services/process/list.go
Normal file
30
envd/internal/services/process/list.go
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package process
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"connectrpc.com/connect"
|
||||||
|
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/services/process/handler"
|
||||||
|
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *Service) List(context.Context, *connect.Request[rpc.ListRequest]) (*connect.Response[rpc.ListResponse], error) {
|
||||||
|
processes := make([]*rpc.ProcessInfo, 0)
|
||||||
|
|
||||||
|
s.processes.Range(func(pid uint32, value *handler.Handler) bool {
|
||||||
|
processes = append(processes, &rpc.ProcessInfo{
|
||||||
|
Pid: pid,
|
||||||
|
Tag: value.Tag,
|
||||||
|
Config: value.Config,
|
||||||
|
})
|
||||||
|
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
return connect.NewResponse(&rpc.ListResponse{
|
||||||
|
Processes: processes,
|
||||||
|
}), nil
|
||||||
|
}
|
||||||
86
envd/internal/services/process/service.go
Normal file
86
envd/internal/services/process/service.go
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package process
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"connectrpc.com/connect"
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/services/cgroups"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/services/process/handler"
|
||||||
|
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
|
||||||
|
spec "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process/processconnect"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Service implements the process RPC service: starting processes,
// signalling them, resizing PTYs, and streaming their I/O.
type Service struct {
	// processes maps pid -> running process handler.
	processes *utils.Map[uint32, *handler.Handler]
	logger    *zerolog.Logger
	// defaults holds execution-context defaults (e.g. the fallback user
	// passed to permissions.GetAuthUser in handleStart).
	defaults      *execcontext.Defaults
	cgroupManager cgroups.Manager
}
|
||||||
|
|
||||||
|
func newService(l *zerolog.Logger, defaults *execcontext.Defaults, cgroupManager cgroups.Manager) *Service {
|
||||||
|
return &Service{
|
||||||
|
logger: l,
|
||||||
|
processes: utils.NewMap[uint32, *handler.Handler](),
|
||||||
|
defaults: defaults,
|
||||||
|
cgroupManager: cgroupManager,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Handle(server *chi.Mux, l *zerolog.Logger, defaults *execcontext.Defaults, cgroupManager cgroups.Manager) *Service {
|
||||||
|
service := newService(l, defaults, cgroupManager)
|
||||||
|
|
||||||
|
interceptors := connect.WithInterceptors(logs.NewUnaryLogInterceptor(l))
|
||||||
|
|
||||||
|
path, h := spec.NewProcessHandler(service, interceptors)
|
||||||
|
|
||||||
|
server.Mount(path, h)
|
||||||
|
|
||||||
|
return service
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) getProcess(selector *rpc.ProcessSelector) (*handler.Handler, error) {
|
||||||
|
var proc *handler.Handler
|
||||||
|
|
||||||
|
switch selector.GetSelector().(type) {
|
||||||
|
case *rpc.ProcessSelector_Pid:
|
||||||
|
p, ok := s.processes.Load(selector.GetPid())
|
||||||
|
if !ok {
|
||||||
|
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("process with pid %d not found", selector.GetPid()))
|
||||||
|
}
|
||||||
|
|
||||||
|
proc = p
|
||||||
|
case *rpc.ProcessSelector_Tag:
|
||||||
|
tag := selector.GetTag()
|
||||||
|
|
||||||
|
s.processes.Range(func(_ uint32, value *handler.Handler) bool {
|
||||||
|
if value.Tag == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if *value.Tag == tag {
|
||||||
|
proc = value
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
})
|
||||||
|
|
||||||
|
if proc == nil {
|
||||||
|
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("process with tag %s not found", tag))
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, connect.NewError(connect.CodeUnimplemented, fmt.Errorf("invalid input type %T", selector))
|
||||||
|
}
|
||||||
|
|
||||||
|
return proc, nil
|
||||||
|
}
|
||||||
40
envd/internal/services/process/signal.go
Normal file
40
envd/internal/services/process/signal.go
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package process
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"connectrpc.com/connect"
|
||||||
|
|
||||||
|
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *Service) SendSignal(
|
||||||
|
_ context.Context,
|
||||||
|
req *connect.Request[rpc.SendSignalRequest],
|
||||||
|
) (*connect.Response[rpc.SendSignalResponse], error) {
|
||||||
|
handler, err := s.getProcess(req.Msg.GetProcess())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var signal syscall.Signal
|
||||||
|
switch req.Msg.GetSignal() {
|
||||||
|
case rpc.Signal_SIGNAL_SIGKILL:
|
||||||
|
signal = syscall.SIGKILL
|
||||||
|
case rpc.Signal_SIGNAL_SIGTERM:
|
||||||
|
signal = syscall.SIGTERM
|
||||||
|
default:
|
||||||
|
return nil, connect.NewError(connect.CodeUnimplemented, fmt.Errorf("invalid signal: %s", req.Msg.GetSignal()))
|
||||||
|
}
|
||||||
|
|
||||||
|
err = handler.SendSignal(signal)
|
||||||
|
if err != nil {
|
||||||
|
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error sending signal: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return connect.NewResponse(&rpc.SendSignalResponse{}), nil
|
||||||
|
}
|
||||||
249
envd/internal/services/process/start.go
Normal file
249
envd/internal/services/process/start.go
Normal file
@ -0,0 +1,249 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package process
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"os/user"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"connectrpc.com/connect"
|
||||||
|
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/services/process/handler"
|
||||||
|
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *Service) InitializeStartProcess(ctx context.Context, user *user.User, req *rpc.StartRequest) error {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
ctx = logs.AddRequestIDToContext(ctx)
|
||||||
|
|
||||||
|
defer s.logger.
|
||||||
|
Err(err).
|
||||||
|
Interface("request", req).
|
||||||
|
Str(string(logs.OperationIDKey), ctx.Value(logs.OperationIDKey).(string)).
|
||||||
|
Msg("Initialized startCmd")
|
||||||
|
|
||||||
|
handlerL := s.logger.With().Str(string(logs.OperationIDKey), ctx.Value(logs.OperationIDKey).(string)).Logger()
|
||||||
|
|
||||||
|
startProcCtx, startProcCancel := context.WithCancel(ctx)
|
||||||
|
proc, err := handler.New(startProcCtx, user, req, &handlerL, s.defaults, s.cgroupManager, startProcCancel)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
pid, err := proc.Start()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.processes.Store(pid, proc)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer s.processes.Delete(pid)
|
||||||
|
|
||||||
|
proc.Wait()
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start handles the server-streaming start RPC. Logging is delegated to
// logs.LogServerStreamWithoutEvents; the real work happens in handleStart.
func (s *Service) Start(ctx context.Context, req *connect.Request[rpc.StartRequest], stream *connect.ServerStream[rpc.StartResponse]) error {
	return logs.LogServerStreamWithoutEvents(ctx, s.logger, req, stream, s.handleStart)
}
|
||||||
|
|
||||||
|
// handleStart starts a process and streams its lifecycle events (start,
// data, keepalive, end) back to the client. A forwarding goroutine owns all
// stream.Send calls; this function starts the process, publishes the start
// event, and then blocks until either the request context ends or the
// forwarder finishes after the end event.
func (s *Service) handleStart(ctx context.Context, req *connect.Request[rpc.StartRequest], stream *connect.ServerStream[rpc.StartResponse]) error {
	// CancelCause lets the forwarding goroutine record why the stream ended.
	ctx, cancel := context.WithCancelCause(ctx)
	defer cancel(nil)

	handlerL := s.logger.With().Str(string(logs.OperationIDKey), ctx.Value(logs.OperationIDKey).(string)).Logger()

	u, err := permissions.GetAuthUser(ctx, s.defaults.User)
	if err != nil {
		return err
	}

	timeout, err := determineTimeoutFromHeader(stream.Conn().RequestHeader())
	if err != nil {
		return connect.NewError(connect.CodeInvalidArgument, err)
	}

	// Create a new context with a timeout if provided.
	// We do not want the command to be killed if the request context is cancelled
	procCtx, cancelProc := context.Background(), func() {}
	if timeout > 0 { // zero timeout means no timeout
		procCtx, cancelProc = context.WithTimeout(procCtx, timeout)
	}

	proc, err := handler.New( //nolint:contextcheck // TODO: fix this later
		procCtx,
		u,
		req.Msg,
		&handlerL,
		s.defaults,
		s.cgroupManager,
		cancelProc,
	)
	if err != nil {
		// Ensure the process cancel is called to cleanup resources.
		cancelProc()

		return err
	}

	// exitChan closes when the forwarding goroutine below finishes.
	exitChan := make(chan struct{})

	// The start event is produced locally (after proc.Start below) and
	// consumed by the forwarding goroutine through a fork of this channel.
	startMultiplexer := handler.NewMultiplexedChannel[rpc.ProcessEvent_Start](0)
	defer close(startMultiplexer.Source)

	start, startCancel := startMultiplexer.Fork()
	defer startCancel()

	data, dataCancel := proc.DataEvent.Fork()
	defer dataCancel()

	end, endCancel := proc.EndEvent.Fork()
	defer endCancel()

	// Forwarding goroutine: sends start, then data/keepalive until the data
	// channel closes, then the final end event. All exits record their cause
	// via cancel.
	go func() {
		defer close(exitChan)

		select {
		case <-ctx.Done():
			cancel(ctx.Err())

			return
		case event, ok := <-start:
			if !ok {
				cancel(connect.NewError(connect.CodeUnknown, errors.New("start event channel closed before sending start event")))

				return
			}

			streamErr := stream.Send(&rpc.StartResponse{
				Event: &rpc.ProcessEvent{
					Event: &event,
				},
			})
			if streamErr != nil {
				cancel(connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending start event: %w", streamErr)))

				return
			}
		}

		keepaliveTicker, resetKeepalive := permissions.GetKeepAliveTicker(req)
		defer keepaliveTicker.Stop()

		// Forward data events until the channel closes (process output done),
		// interleaving keepalives while the stream is idle.
	dataLoop:
		for {
			select {
			case <-keepaliveTicker.C:
				streamErr := stream.Send(&rpc.StartResponse{
					Event: &rpc.ProcessEvent{
						Event: &rpc.ProcessEvent_Keepalive{
							Keepalive: &rpc.ProcessEvent_KeepAlive{},
						},
					},
				})
				if streamErr != nil {
					cancel(connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending keepalive: %w", streamErr)))

					return
				}
			case <-ctx.Done():
				cancel(ctx.Err())

				return
			case event, ok := <-data:
				if !ok {
					break dataLoop
				}

				streamErr := stream.Send(&rpc.StartResponse{
					Event: &rpc.ProcessEvent{
						Event: &event,
					},
				})
				if streamErr != nil {
					cancel(connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending data event: %w", streamErr)))

					return
				}

				resetKeepalive()
			}
		}

		// Data channel closed — forward the terminal end event.
		select {
		case <-ctx.Done():
			cancel(ctx.Err())

			return
		case event, ok := <-end:
			if !ok {
				cancel(connect.NewError(connect.CodeUnknown, errors.New("end event channel closed before sending end event")))

				return
			}

			streamErr := stream.Send(&rpc.StartResponse{
				Event: &rpc.ProcessEvent{
					Event: &event,
				},
			})
			if streamErr != nil {
				cancel(connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending end event: %w", streamErr)))

				return
			}
		}
	}()

	pid, err := proc.Start()
	if err != nil {
		return connect.NewError(connect.CodeInvalidArgument, err)
	}

	s.processes.Store(pid, proc)

	// Hand the start event to the forwarder via the forked channel.
	start <- rpc.ProcessEvent_Start{
		Start: &rpc.ProcessEvent_StartEvent{
			Pid: pid,
		},
	}

	go func() {
		defer s.processes.Delete(pid)

		proc.Wait()
	}()

	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-exitChan:
		return nil
	}
}
|
||||||
|
|
||||||
|
func determineTimeoutFromHeader(header http.Header) (time.Duration, error) {
|
||||||
|
timeoutHeader := header.Get("Connect-Timeout-Ms")
|
||||||
|
|
||||||
|
if timeoutHeader == "" {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
timeout, err := strconv.Atoi(timeoutHeader)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return time.Duration(timeout) * time.Millisecond, nil
|
||||||
|
}
|
||||||
32
envd/internal/services/process/update.go
Normal file
32
envd/internal/services/process/update.go
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package process
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"connectrpc.com/connect"
|
||||||
|
"github.com/creack/pty"
|
||||||
|
|
||||||
|
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *Service) Update(_ context.Context, req *connect.Request[rpc.UpdateRequest]) (*connect.Response[rpc.UpdateResponse], error) {
|
||||||
|
proc, err := s.getProcess(req.Msg.GetProcess())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Msg.GetPty() != nil {
|
||||||
|
err := proc.ResizeTty(&pty.Winsize{
|
||||||
|
Rows: uint16(req.Msg.GetPty().GetSize().GetRows()),
|
||||||
|
Cols: uint16(req.Msg.GetPty().GetSize().GetCols()),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error resizing tty: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return connect.NewResponse(&rpc.UpdateResponse{}), nil
|
||||||
|
}
|
||||||
1446
envd/internal/services/spec/filesystem.pb.go
Normal file
1446
envd/internal/services/spec/filesystem.pb.go
Normal file
File diff suppressed because it is too large
Load Diff
1444
envd/internal/services/spec/filesystem/filesystem.pb.go
Normal file
1444
envd/internal/services/spec/filesystem/filesystem.pb.go
Normal file
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,337 @@
|
|||||||
|
// Code generated by protoc-gen-connect-go. DO NOT EDIT.
|
||||||
|
//
|
||||||
|
// Source: filesystem/filesystem.proto
|
||||||
|
|
||||||
|
package filesystemconnect
|
||||||
|
|
||||||
|
import (
|
||||||
|
connect "connectrpc.com/connect"
|
||||||
|
context "context"
|
||||||
|
errors "errors"
|
||||||
|
filesystem "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem"
|
||||||
|
http "net/http"
|
||||||
|
strings "strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// This is a compile-time assertion to ensure that this generated file and the connect package are
|
||||||
|
// compatible. If you get a compiler error that this constant is not defined, this code was
|
||||||
|
// generated with a version of connect newer than the one compiled into your binary. You can fix the
|
||||||
|
// problem by either regenerating this code with an older version of connect or updating the connect
|
||||||
|
// version compiled into your binary.
|
||||||
|
const _ = connect.IsAtLeastVersion1_13_0
|
||||||
|
|
||||||
|
const (
|
||||||
|
// FilesystemName is the fully-qualified name of the Filesystem service.
|
||||||
|
FilesystemName = "filesystem.Filesystem"
|
||||||
|
)
|
||||||
|
|
||||||
|
// These constants are the fully-qualified names of the RPCs defined in this package. They're
|
||||||
|
// exposed at runtime as Spec.Procedure and as the final two segments of the HTTP route.
|
||||||
|
//
|
||||||
|
// Note that these are different from the fully-qualified method names used by
|
||||||
|
// google.golang.org/protobuf/reflect/protoreflect. To convert from these constants to
|
||||||
|
// reflection-formatted method names, remove the leading slash and convert the remaining slash to a
|
||||||
|
// period.
|
||||||
|
const (
|
||||||
|
// FilesystemStatProcedure is the fully-qualified name of the Filesystem's Stat RPC.
|
||||||
|
FilesystemStatProcedure = "/filesystem.Filesystem/Stat"
|
||||||
|
// FilesystemMakeDirProcedure is the fully-qualified name of the Filesystem's MakeDir RPC.
|
||||||
|
FilesystemMakeDirProcedure = "/filesystem.Filesystem/MakeDir"
|
||||||
|
// FilesystemMoveProcedure is the fully-qualified name of the Filesystem's Move RPC.
|
||||||
|
FilesystemMoveProcedure = "/filesystem.Filesystem/Move"
|
||||||
|
// FilesystemListDirProcedure is the fully-qualified name of the Filesystem's ListDir RPC.
|
||||||
|
FilesystemListDirProcedure = "/filesystem.Filesystem/ListDir"
|
||||||
|
// FilesystemRemoveProcedure is the fully-qualified name of the Filesystem's Remove RPC.
|
||||||
|
FilesystemRemoveProcedure = "/filesystem.Filesystem/Remove"
|
||||||
|
// FilesystemWatchDirProcedure is the fully-qualified name of the Filesystem's WatchDir RPC.
|
||||||
|
FilesystemWatchDirProcedure = "/filesystem.Filesystem/WatchDir"
|
||||||
|
// FilesystemCreateWatcherProcedure is the fully-qualified name of the Filesystem's CreateWatcher
|
||||||
|
// RPC.
|
||||||
|
FilesystemCreateWatcherProcedure = "/filesystem.Filesystem/CreateWatcher"
|
||||||
|
// FilesystemGetWatcherEventsProcedure is the fully-qualified name of the Filesystem's
|
||||||
|
// GetWatcherEvents RPC.
|
||||||
|
FilesystemGetWatcherEventsProcedure = "/filesystem.Filesystem/GetWatcherEvents"
|
||||||
|
// FilesystemRemoveWatcherProcedure is the fully-qualified name of the Filesystem's RemoveWatcher
|
||||||
|
// RPC.
|
||||||
|
FilesystemRemoveWatcherProcedure = "/filesystem.Filesystem/RemoveWatcher"
|
||||||
|
)
|
||||||
|
|
||||||
|
// FilesystemClient is a client for the filesystem.Filesystem service.
|
||||||
|
type FilesystemClient interface {
|
||||||
|
Stat(context.Context, *connect.Request[filesystem.StatRequest]) (*connect.Response[filesystem.StatResponse], error)
|
||||||
|
MakeDir(context.Context, *connect.Request[filesystem.MakeDirRequest]) (*connect.Response[filesystem.MakeDirResponse], error)
|
||||||
|
Move(context.Context, *connect.Request[filesystem.MoveRequest]) (*connect.Response[filesystem.MoveResponse], error)
|
||||||
|
ListDir(context.Context, *connect.Request[filesystem.ListDirRequest]) (*connect.Response[filesystem.ListDirResponse], error)
|
||||||
|
Remove(context.Context, *connect.Request[filesystem.RemoveRequest]) (*connect.Response[filesystem.RemoveResponse], error)
|
||||||
|
WatchDir(context.Context, *connect.Request[filesystem.WatchDirRequest]) (*connect.ServerStreamForClient[filesystem.WatchDirResponse], error)
|
||||||
|
// Non-streaming versions of WatchDir
|
||||||
|
CreateWatcher(context.Context, *connect.Request[filesystem.CreateWatcherRequest]) (*connect.Response[filesystem.CreateWatcherResponse], error)
|
||||||
|
GetWatcherEvents(context.Context, *connect.Request[filesystem.GetWatcherEventsRequest]) (*connect.Response[filesystem.GetWatcherEventsResponse], error)
|
||||||
|
RemoveWatcher(context.Context, *connect.Request[filesystem.RemoveWatcherRequest]) (*connect.Response[filesystem.RemoveWatcherResponse], error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFilesystemClient constructs a client for the filesystem.Filesystem service. By default, it
|
||||||
|
// uses the Connect protocol with the binary Protobuf Codec, asks for gzipped responses, and sends
|
||||||
|
// uncompressed requests. To use the gRPC or gRPC-Web protocols, supply the connect.WithGRPC() or
|
||||||
|
// connect.WithGRPCWeb() options.
|
||||||
|
//
|
||||||
|
// The URL supplied here should be the base URL for the Connect or gRPC server (for example,
|
||||||
|
// http://api.acme.com or https://acme.com/grpc).
|
||||||
|
func NewFilesystemClient(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) FilesystemClient {
|
||||||
|
baseURL = strings.TrimRight(baseURL, "/")
|
||||||
|
filesystemMethods := filesystem.File_filesystem_filesystem_proto.Services().ByName("Filesystem").Methods()
|
||||||
|
return &filesystemClient{
|
||||||
|
stat: connect.NewClient[filesystem.StatRequest, filesystem.StatResponse](
|
||||||
|
httpClient,
|
||||||
|
baseURL+FilesystemStatProcedure,
|
||||||
|
connect.WithSchema(filesystemMethods.ByName("Stat")),
|
||||||
|
connect.WithClientOptions(opts...),
|
||||||
|
),
|
||||||
|
makeDir: connect.NewClient[filesystem.MakeDirRequest, filesystem.MakeDirResponse](
|
||||||
|
httpClient,
|
||||||
|
baseURL+FilesystemMakeDirProcedure,
|
||||||
|
connect.WithSchema(filesystemMethods.ByName("MakeDir")),
|
||||||
|
connect.WithClientOptions(opts...),
|
||||||
|
),
|
||||||
|
move: connect.NewClient[filesystem.MoveRequest, filesystem.MoveResponse](
|
||||||
|
httpClient,
|
||||||
|
baseURL+FilesystemMoveProcedure,
|
||||||
|
connect.WithSchema(filesystemMethods.ByName("Move")),
|
||||||
|
connect.WithClientOptions(opts...),
|
||||||
|
),
|
||||||
|
listDir: connect.NewClient[filesystem.ListDirRequest, filesystem.ListDirResponse](
|
||||||
|
httpClient,
|
||||||
|
baseURL+FilesystemListDirProcedure,
|
||||||
|
connect.WithSchema(filesystemMethods.ByName("ListDir")),
|
||||||
|
connect.WithClientOptions(opts...),
|
||||||
|
),
|
||||||
|
remove: connect.NewClient[filesystem.RemoveRequest, filesystem.RemoveResponse](
|
||||||
|
httpClient,
|
||||||
|
baseURL+FilesystemRemoveProcedure,
|
||||||
|
connect.WithSchema(filesystemMethods.ByName("Remove")),
|
||||||
|
connect.WithClientOptions(opts...),
|
||||||
|
),
|
||||||
|
watchDir: connect.NewClient[filesystem.WatchDirRequest, filesystem.WatchDirResponse](
|
||||||
|
httpClient,
|
||||||
|
baseURL+FilesystemWatchDirProcedure,
|
||||||
|
connect.WithSchema(filesystemMethods.ByName("WatchDir")),
|
||||||
|
connect.WithClientOptions(opts...),
|
||||||
|
),
|
||||||
|
createWatcher: connect.NewClient[filesystem.CreateWatcherRequest, filesystem.CreateWatcherResponse](
|
||||||
|
httpClient,
|
||||||
|
baseURL+FilesystemCreateWatcherProcedure,
|
||||||
|
connect.WithSchema(filesystemMethods.ByName("CreateWatcher")),
|
||||||
|
connect.WithClientOptions(opts...),
|
||||||
|
),
|
||||||
|
getWatcherEvents: connect.NewClient[filesystem.GetWatcherEventsRequest, filesystem.GetWatcherEventsResponse](
|
||||||
|
httpClient,
|
||||||
|
baseURL+FilesystemGetWatcherEventsProcedure,
|
||||||
|
connect.WithSchema(filesystemMethods.ByName("GetWatcherEvents")),
|
||||||
|
connect.WithClientOptions(opts...),
|
||||||
|
),
|
||||||
|
removeWatcher: connect.NewClient[filesystem.RemoveWatcherRequest, filesystem.RemoveWatcherResponse](
|
||||||
|
httpClient,
|
||||||
|
baseURL+FilesystemRemoveWatcherProcedure,
|
||||||
|
connect.WithSchema(filesystemMethods.ByName("RemoveWatcher")),
|
||||||
|
connect.WithClientOptions(opts...),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// filesystemClient implements FilesystemClient.
|
||||||
|
type filesystemClient struct {
|
||||||
|
stat *connect.Client[filesystem.StatRequest, filesystem.StatResponse]
|
||||||
|
makeDir *connect.Client[filesystem.MakeDirRequest, filesystem.MakeDirResponse]
|
||||||
|
move *connect.Client[filesystem.MoveRequest, filesystem.MoveResponse]
|
||||||
|
listDir *connect.Client[filesystem.ListDirRequest, filesystem.ListDirResponse]
|
||||||
|
remove *connect.Client[filesystem.RemoveRequest, filesystem.RemoveResponse]
|
||||||
|
watchDir *connect.Client[filesystem.WatchDirRequest, filesystem.WatchDirResponse]
|
||||||
|
createWatcher *connect.Client[filesystem.CreateWatcherRequest, filesystem.CreateWatcherResponse]
|
||||||
|
getWatcherEvents *connect.Client[filesystem.GetWatcherEventsRequest, filesystem.GetWatcherEventsResponse]
|
||||||
|
removeWatcher *connect.Client[filesystem.RemoveWatcherRequest, filesystem.RemoveWatcherResponse]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stat calls filesystem.Filesystem.Stat.
|
||||||
|
func (c *filesystemClient) Stat(ctx context.Context, req *connect.Request[filesystem.StatRequest]) (*connect.Response[filesystem.StatResponse], error) {
|
||||||
|
return c.stat.CallUnary(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MakeDir calls filesystem.Filesystem.MakeDir.
|
||||||
|
func (c *filesystemClient) MakeDir(ctx context.Context, req *connect.Request[filesystem.MakeDirRequest]) (*connect.Response[filesystem.MakeDirResponse], error) {
|
||||||
|
return c.makeDir.CallUnary(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Move calls filesystem.Filesystem.Move.
|
||||||
|
func (c *filesystemClient) Move(ctx context.Context, req *connect.Request[filesystem.MoveRequest]) (*connect.Response[filesystem.MoveResponse], error) {
|
||||||
|
return c.move.CallUnary(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListDir calls filesystem.Filesystem.ListDir.
|
||||||
|
func (c *filesystemClient) ListDir(ctx context.Context, req *connect.Request[filesystem.ListDirRequest]) (*connect.Response[filesystem.ListDirResponse], error) {
|
||||||
|
return c.listDir.CallUnary(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove calls filesystem.Filesystem.Remove.
|
||||||
|
func (c *filesystemClient) Remove(ctx context.Context, req *connect.Request[filesystem.RemoveRequest]) (*connect.Response[filesystem.RemoveResponse], error) {
|
||||||
|
return c.remove.CallUnary(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WatchDir calls filesystem.Filesystem.WatchDir.
|
||||||
|
func (c *filesystemClient) WatchDir(ctx context.Context, req *connect.Request[filesystem.WatchDirRequest]) (*connect.ServerStreamForClient[filesystem.WatchDirResponse], error) {
|
||||||
|
return c.watchDir.CallServerStream(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateWatcher calls filesystem.Filesystem.CreateWatcher.
|
||||||
|
func (c *filesystemClient) CreateWatcher(ctx context.Context, req *connect.Request[filesystem.CreateWatcherRequest]) (*connect.Response[filesystem.CreateWatcherResponse], error) {
|
||||||
|
return c.createWatcher.CallUnary(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetWatcherEvents calls filesystem.Filesystem.GetWatcherEvents.
|
||||||
|
func (c *filesystemClient) GetWatcherEvents(ctx context.Context, req *connect.Request[filesystem.GetWatcherEventsRequest]) (*connect.Response[filesystem.GetWatcherEventsResponse], error) {
|
||||||
|
return c.getWatcherEvents.CallUnary(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveWatcher calls filesystem.Filesystem.RemoveWatcher.
|
||||||
|
func (c *filesystemClient) RemoveWatcher(ctx context.Context, req *connect.Request[filesystem.RemoveWatcherRequest]) (*connect.Response[filesystem.RemoveWatcherResponse], error) {
|
||||||
|
return c.removeWatcher.CallUnary(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FilesystemHandler is an implementation of the filesystem.Filesystem service.
|
||||||
|
type FilesystemHandler interface {
|
||||||
|
Stat(context.Context, *connect.Request[filesystem.StatRequest]) (*connect.Response[filesystem.StatResponse], error)
|
||||||
|
MakeDir(context.Context, *connect.Request[filesystem.MakeDirRequest]) (*connect.Response[filesystem.MakeDirResponse], error)
|
||||||
|
Move(context.Context, *connect.Request[filesystem.MoveRequest]) (*connect.Response[filesystem.MoveResponse], error)
|
||||||
|
ListDir(context.Context, *connect.Request[filesystem.ListDirRequest]) (*connect.Response[filesystem.ListDirResponse], error)
|
||||||
|
Remove(context.Context, *connect.Request[filesystem.RemoveRequest]) (*connect.Response[filesystem.RemoveResponse], error)
|
||||||
|
WatchDir(context.Context, *connect.Request[filesystem.WatchDirRequest], *connect.ServerStream[filesystem.WatchDirResponse]) error
|
||||||
|
// Non-streaming versions of WatchDir
|
||||||
|
CreateWatcher(context.Context, *connect.Request[filesystem.CreateWatcherRequest]) (*connect.Response[filesystem.CreateWatcherResponse], error)
|
||||||
|
GetWatcherEvents(context.Context, *connect.Request[filesystem.GetWatcherEventsRequest]) (*connect.Response[filesystem.GetWatcherEventsResponse], error)
|
||||||
|
RemoveWatcher(context.Context, *connect.Request[filesystem.RemoveWatcherRequest]) (*connect.Response[filesystem.RemoveWatcherResponse], error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFilesystemHandler builds an HTTP handler from the service implementation. It returns the path
|
||||||
|
// on which to mount the handler and the handler itself.
|
||||||
|
//
|
||||||
|
// By default, handlers support the Connect, gRPC, and gRPC-Web protocols with the binary Protobuf
|
||||||
|
// and JSON codecs. They also support gzip compression.
|
||||||
|
func NewFilesystemHandler(svc FilesystemHandler, opts ...connect.HandlerOption) (string, http.Handler) {
|
||||||
|
filesystemMethods := filesystem.File_filesystem_filesystem_proto.Services().ByName("Filesystem").Methods()
|
||||||
|
filesystemStatHandler := connect.NewUnaryHandler(
|
||||||
|
FilesystemStatProcedure,
|
||||||
|
svc.Stat,
|
||||||
|
connect.WithSchema(filesystemMethods.ByName("Stat")),
|
||||||
|
connect.WithHandlerOptions(opts...),
|
||||||
|
)
|
||||||
|
filesystemMakeDirHandler := connect.NewUnaryHandler(
|
||||||
|
FilesystemMakeDirProcedure,
|
||||||
|
svc.MakeDir,
|
||||||
|
connect.WithSchema(filesystemMethods.ByName("MakeDir")),
|
||||||
|
connect.WithHandlerOptions(opts...),
|
||||||
|
)
|
||||||
|
filesystemMoveHandler := connect.NewUnaryHandler(
|
||||||
|
FilesystemMoveProcedure,
|
||||||
|
svc.Move,
|
||||||
|
connect.WithSchema(filesystemMethods.ByName("Move")),
|
||||||
|
connect.WithHandlerOptions(opts...),
|
||||||
|
)
|
||||||
|
filesystemListDirHandler := connect.NewUnaryHandler(
|
||||||
|
FilesystemListDirProcedure,
|
||||||
|
svc.ListDir,
|
||||||
|
connect.WithSchema(filesystemMethods.ByName("ListDir")),
|
||||||
|
connect.WithHandlerOptions(opts...),
|
||||||
|
)
|
||||||
|
filesystemRemoveHandler := connect.NewUnaryHandler(
|
||||||
|
FilesystemRemoveProcedure,
|
||||||
|
svc.Remove,
|
||||||
|
connect.WithSchema(filesystemMethods.ByName("Remove")),
|
||||||
|
connect.WithHandlerOptions(opts...),
|
||||||
|
)
|
||||||
|
filesystemWatchDirHandler := connect.NewServerStreamHandler(
|
||||||
|
FilesystemWatchDirProcedure,
|
||||||
|
svc.WatchDir,
|
||||||
|
connect.WithSchema(filesystemMethods.ByName("WatchDir")),
|
||||||
|
connect.WithHandlerOptions(opts...),
|
||||||
|
)
|
||||||
|
filesystemCreateWatcherHandler := connect.NewUnaryHandler(
|
||||||
|
FilesystemCreateWatcherProcedure,
|
||||||
|
svc.CreateWatcher,
|
||||||
|
connect.WithSchema(filesystemMethods.ByName("CreateWatcher")),
|
||||||
|
connect.WithHandlerOptions(opts...),
|
||||||
|
)
|
||||||
|
filesystemGetWatcherEventsHandler := connect.NewUnaryHandler(
|
||||||
|
FilesystemGetWatcherEventsProcedure,
|
||||||
|
svc.GetWatcherEvents,
|
||||||
|
connect.WithSchema(filesystemMethods.ByName("GetWatcherEvents")),
|
||||||
|
connect.WithHandlerOptions(opts...),
|
||||||
|
)
|
||||||
|
filesystemRemoveWatcherHandler := connect.NewUnaryHandler(
|
||||||
|
FilesystemRemoveWatcherProcedure,
|
||||||
|
svc.RemoveWatcher,
|
||||||
|
connect.WithSchema(filesystemMethods.ByName("RemoveWatcher")),
|
||||||
|
connect.WithHandlerOptions(opts...),
|
||||||
|
)
|
||||||
|
return "/filesystem.Filesystem/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case FilesystemStatProcedure:
|
||||||
|
filesystemStatHandler.ServeHTTP(w, r)
|
||||||
|
case FilesystemMakeDirProcedure:
|
||||||
|
filesystemMakeDirHandler.ServeHTTP(w, r)
|
||||||
|
case FilesystemMoveProcedure:
|
||||||
|
filesystemMoveHandler.ServeHTTP(w, r)
|
||||||
|
case FilesystemListDirProcedure:
|
||||||
|
filesystemListDirHandler.ServeHTTP(w, r)
|
||||||
|
case FilesystemRemoveProcedure:
|
||||||
|
filesystemRemoveHandler.ServeHTTP(w, r)
|
||||||
|
case FilesystemWatchDirProcedure:
|
||||||
|
filesystemWatchDirHandler.ServeHTTP(w, r)
|
||||||
|
case FilesystemCreateWatcherProcedure:
|
||||||
|
filesystemCreateWatcherHandler.ServeHTTP(w, r)
|
||||||
|
case FilesystemGetWatcherEventsProcedure:
|
||||||
|
filesystemGetWatcherEventsHandler.ServeHTTP(w, r)
|
||||||
|
case FilesystemRemoveWatcherProcedure:
|
||||||
|
filesystemRemoveWatcherHandler.ServeHTTP(w, r)
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnimplementedFilesystemHandler returns CodeUnimplemented from all methods.
|
||||||
|
type UnimplementedFilesystemHandler struct{}
|
||||||
|
|
||||||
|
func (UnimplementedFilesystemHandler) Stat(context.Context, *connect.Request[filesystem.StatRequest]) (*connect.Response[filesystem.StatResponse], error) {
|
||||||
|
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.Stat is not implemented"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UnimplementedFilesystemHandler) MakeDir(context.Context, *connect.Request[filesystem.MakeDirRequest]) (*connect.Response[filesystem.MakeDirResponse], error) {
|
||||||
|
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.MakeDir is not implemented"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UnimplementedFilesystemHandler) Move(context.Context, *connect.Request[filesystem.MoveRequest]) (*connect.Response[filesystem.MoveResponse], error) {
|
||||||
|
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.Move is not implemented"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UnimplementedFilesystemHandler) ListDir(context.Context, *connect.Request[filesystem.ListDirRequest]) (*connect.Response[filesystem.ListDirResponse], error) {
|
||||||
|
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.ListDir is not implemented"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UnimplementedFilesystemHandler) Remove(context.Context, *connect.Request[filesystem.RemoveRequest]) (*connect.Response[filesystem.RemoveResponse], error) {
|
||||||
|
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.Remove is not implemented"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UnimplementedFilesystemHandler) WatchDir(context.Context, *connect.Request[filesystem.WatchDirRequest], *connect.ServerStream[filesystem.WatchDirResponse]) error {
|
||||||
|
return connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.WatchDir is not implemented"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UnimplementedFilesystemHandler) CreateWatcher(context.Context, *connect.Request[filesystem.CreateWatcherRequest]) (*connect.Response[filesystem.CreateWatcherResponse], error) {
|
||||||
|
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.CreateWatcher is not implemented"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UnimplementedFilesystemHandler) GetWatcherEvents(context.Context, *connect.Request[filesystem.GetWatcherEventsRequest]) (*connect.Response[filesystem.GetWatcherEventsResponse], error) {
|
||||||
|
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.GetWatcherEvents is not implemented"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UnimplementedFilesystemHandler) RemoveWatcher(context.Context, *connect.Request[filesystem.RemoveWatcherRequest]) (*connect.Response[filesystem.RemoveWatcherResponse], error) {
|
||||||
|
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.RemoveWatcher is not implemented"))
|
||||||
|
}
|
||||||
1972
envd/internal/services/spec/process.pb.go
Normal file
1972
envd/internal/services/spec/process.pb.go
Normal file
File diff suppressed because it is too large
Load Diff
1970
envd/internal/services/spec/process/process.pb.go
Normal file
1970
envd/internal/services/spec/process/process.pb.go
Normal file
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,310 @@
|
|||||||
|
// Code generated by protoc-gen-connect-go. DO NOT EDIT.
|
||||||
|
//
|
||||||
|
// Source: process/process.proto
|
||||||
|
|
||||||
|
package processconnect
|
||||||
|
|
||||||
|
import (
|
||||||
|
connect "connectrpc.com/connect"
|
||||||
|
context "context"
|
||||||
|
errors "errors"
|
||||||
|
process "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
|
||||||
|
http "net/http"
|
||||||
|
strings "strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// This is a compile-time assertion to ensure that this generated file and the connect package are
|
||||||
|
// compatible. If you get a compiler error that this constant is not defined, this code was
|
||||||
|
// generated with a version of connect newer than the one compiled into your binary. You can fix the
|
||||||
|
// problem by either regenerating this code with an older version of connect or updating the connect
|
||||||
|
// version compiled into your binary.
|
||||||
|
const _ = connect.IsAtLeastVersion1_13_0
|
||||||
|
|
||||||
|
const (
|
||||||
|
// ProcessName is the fully-qualified name of the Process service.
|
||||||
|
ProcessName = "process.Process"
|
||||||
|
)
|
||||||
|
|
||||||
|
// These constants are the fully-qualified names of the RPCs defined in this package. They're
|
||||||
|
// exposed at runtime as Spec.Procedure and as the final two segments of the HTTP route.
|
||||||
|
//
|
||||||
|
// Note that these are different from the fully-qualified method names used by
|
||||||
|
// google.golang.org/protobuf/reflect/protoreflect. To convert from these constants to
|
||||||
|
// reflection-formatted method names, remove the leading slash and convert the remaining slash to a
|
||||||
|
// period.
|
||||||
|
const (
|
||||||
|
// ProcessListProcedure is the fully-qualified name of the Process's List RPC.
|
||||||
|
ProcessListProcedure = "/process.Process/List"
|
||||||
|
// ProcessConnectProcedure is the fully-qualified name of the Process's Connect RPC.
|
||||||
|
ProcessConnectProcedure = "/process.Process/Connect"
|
||||||
|
// ProcessStartProcedure is the fully-qualified name of the Process's Start RPC.
|
||||||
|
ProcessStartProcedure = "/process.Process/Start"
|
||||||
|
// ProcessUpdateProcedure is the fully-qualified name of the Process's Update RPC.
|
||||||
|
ProcessUpdateProcedure = "/process.Process/Update"
|
||||||
|
// ProcessStreamInputProcedure is the fully-qualified name of the Process's StreamInput RPC.
|
||||||
|
ProcessStreamInputProcedure = "/process.Process/StreamInput"
|
||||||
|
// ProcessSendInputProcedure is the fully-qualified name of the Process's SendInput RPC.
|
||||||
|
ProcessSendInputProcedure = "/process.Process/SendInput"
|
||||||
|
// ProcessSendSignalProcedure is the fully-qualified name of the Process's SendSignal RPC.
|
||||||
|
ProcessSendSignalProcedure = "/process.Process/SendSignal"
|
||||||
|
// ProcessCloseStdinProcedure is the fully-qualified name of the Process's CloseStdin RPC.
|
||||||
|
ProcessCloseStdinProcedure = "/process.Process/CloseStdin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProcessClient is a client for the process.Process service.
|
||||||
|
type ProcessClient interface {
|
||||||
|
List(context.Context, *connect.Request[process.ListRequest]) (*connect.Response[process.ListResponse], error)
|
||||||
|
Connect(context.Context, *connect.Request[process.ConnectRequest]) (*connect.ServerStreamForClient[process.ConnectResponse], error)
|
||||||
|
Start(context.Context, *connect.Request[process.StartRequest]) (*connect.ServerStreamForClient[process.StartResponse], error)
|
||||||
|
Update(context.Context, *connect.Request[process.UpdateRequest]) (*connect.Response[process.UpdateResponse], error)
|
||||||
|
// Client input stream ensures ordering of messages
|
||||||
|
StreamInput(context.Context) *connect.ClientStreamForClient[process.StreamInputRequest, process.StreamInputResponse]
|
||||||
|
SendInput(context.Context, *connect.Request[process.SendInputRequest]) (*connect.Response[process.SendInputResponse], error)
|
||||||
|
SendSignal(context.Context, *connect.Request[process.SendSignalRequest]) (*connect.Response[process.SendSignalResponse], error)
|
||||||
|
// Close stdin to signal EOF to the process.
|
||||||
|
// Only works for non-PTY processes. For PTY, send Ctrl+D (0x04) instead.
|
||||||
|
CloseStdin(context.Context, *connect.Request[process.CloseStdinRequest]) (*connect.Response[process.CloseStdinResponse], error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewProcessClient constructs a client for the process.Process service. By default, it uses the
|
||||||
|
// Connect protocol with the binary Protobuf Codec, asks for gzipped responses, and sends
|
||||||
|
// uncompressed requests. To use the gRPC or gRPC-Web protocols, supply the connect.WithGRPC() or
|
||||||
|
// connect.WithGRPCWeb() options.
|
||||||
|
//
|
||||||
|
// The URL supplied here should be the base URL for the Connect or gRPC server (for example,
|
||||||
|
// http://api.acme.com or https://acme.com/grpc).
|
||||||
|
func NewProcessClient(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) ProcessClient {
|
||||||
|
baseURL = strings.TrimRight(baseURL, "/")
|
||||||
|
processMethods := process.File_process_process_proto.Services().ByName("Process").Methods()
|
||||||
|
return &processClient{
|
||||||
|
list: connect.NewClient[process.ListRequest, process.ListResponse](
|
||||||
|
httpClient,
|
||||||
|
baseURL+ProcessListProcedure,
|
||||||
|
connect.WithSchema(processMethods.ByName("List")),
|
||||||
|
connect.WithClientOptions(opts...),
|
||||||
|
),
|
||||||
|
connect: connect.NewClient[process.ConnectRequest, process.ConnectResponse](
|
||||||
|
httpClient,
|
||||||
|
baseURL+ProcessConnectProcedure,
|
||||||
|
connect.WithSchema(processMethods.ByName("Connect")),
|
||||||
|
connect.WithClientOptions(opts...),
|
||||||
|
),
|
||||||
|
start: connect.NewClient[process.StartRequest, process.StartResponse](
|
||||||
|
httpClient,
|
||||||
|
baseURL+ProcessStartProcedure,
|
||||||
|
connect.WithSchema(processMethods.ByName("Start")),
|
||||||
|
connect.WithClientOptions(opts...),
|
||||||
|
),
|
||||||
|
update: connect.NewClient[process.UpdateRequest, process.UpdateResponse](
|
||||||
|
httpClient,
|
||||||
|
baseURL+ProcessUpdateProcedure,
|
||||||
|
connect.WithSchema(processMethods.ByName("Update")),
|
||||||
|
connect.WithClientOptions(opts...),
|
||||||
|
),
|
||||||
|
streamInput: connect.NewClient[process.StreamInputRequest, process.StreamInputResponse](
|
||||||
|
httpClient,
|
||||||
|
baseURL+ProcessStreamInputProcedure,
|
||||||
|
connect.WithSchema(processMethods.ByName("StreamInput")),
|
||||||
|
connect.WithClientOptions(opts...),
|
||||||
|
),
|
||||||
|
sendInput: connect.NewClient[process.SendInputRequest, process.SendInputResponse](
|
||||||
|
httpClient,
|
||||||
|
baseURL+ProcessSendInputProcedure,
|
||||||
|
connect.WithSchema(processMethods.ByName("SendInput")),
|
||||||
|
connect.WithClientOptions(opts...),
|
||||||
|
),
|
||||||
|
sendSignal: connect.NewClient[process.SendSignalRequest, process.SendSignalResponse](
|
||||||
|
httpClient,
|
||||||
|
baseURL+ProcessSendSignalProcedure,
|
||||||
|
connect.WithSchema(processMethods.ByName("SendSignal")),
|
||||||
|
connect.WithClientOptions(opts...),
|
||||||
|
),
|
||||||
|
closeStdin: connect.NewClient[process.CloseStdinRequest, process.CloseStdinResponse](
|
||||||
|
httpClient,
|
||||||
|
baseURL+ProcessCloseStdinProcedure,
|
||||||
|
connect.WithSchema(processMethods.ByName("CloseStdin")),
|
||||||
|
connect.WithClientOptions(opts...),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// processClient implements ProcessClient.
|
||||||
|
type processClient struct {
|
||||||
|
list *connect.Client[process.ListRequest, process.ListResponse]
|
||||||
|
connect *connect.Client[process.ConnectRequest, process.ConnectResponse]
|
||||||
|
start *connect.Client[process.StartRequest, process.StartResponse]
|
||||||
|
update *connect.Client[process.UpdateRequest, process.UpdateResponse]
|
||||||
|
streamInput *connect.Client[process.StreamInputRequest, process.StreamInputResponse]
|
||||||
|
sendInput *connect.Client[process.SendInputRequest, process.SendInputResponse]
|
||||||
|
sendSignal *connect.Client[process.SendSignalRequest, process.SendSignalResponse]
|
||||||
|
closeStdin *connect.Client[process.CloseStdinRequest, process.CloseStdinResponse]
|
||||||
|
}
|
||||||
|
|
||||||
|
// List calls process.Process.List.
|
||||||
|
func (c *processClient) List(ctx context.Context, req *connect.Request[process.ListRequest]) (*connect.Response[process.ListResponse], error) {
|
||||||
|
return c.list.CallUnary(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect calls process.Process.Connect.
|
||||||
|
func (c *processClient) Connect(ctx context.Context, req *connect.Request[process.ConnectRequest]) (*connect.ServerStreamForClient[process.ConnectResponse], error) {
|
||||||
|
return c.connect.CallServerStream(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start calls process.Process.Start.
|
||||||
|
func (c *processClient) Start(ctx context.Context, req *connect.Request[process.StartRequest]) (*connect.ServerStreamForClient[process.StartResponse], error) {
|
||||||
|
return c.start.CallServerStream(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update calls process.Process.Update.
|
||||||
|
func (c *processClient) Update(ctx context.Context, req *connect.Request[process.UpdateRequest]) (*connect.Response[process.UpdateResponse], error) {
|
||||||
|
return c.update.CallUnary(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// StreamInput calls process.Process.StreamInput.
|
||||||
|
func (c *processClient) StreamInput(ctx context.Context) *connect.ClientStreamForClient[process.StreamInputRequest, process.StreamInputResponse] {
|
||||||
|
return c.streamInput.CallClientStream(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendInput calls process.Process.SendInput.
|
||||||
|
func (c *processClient) SendInput(ctx context.Context, req *connect.Request[process.SendInputRequest]) (*connect.Response[process.SendInputResponse], error) {
|
||||||
|
return c.sendInput.CallUnary(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendSignal calls process.Process.SendSignal.
|
||||||
|
func (c *processClient) SendSignal(ctx context.Context, req *connect.Request[process.SendSignalRequest]) (*connect.Response[process.SendSignalResponse], error) {
|
||||||
|
return c.sendSignal.CallUnary(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseStdin calls process.Process.CloseStdin.
|
||||||
|
func (c *processClient) CloseStdin(ctx context.Context, req *connect.Request[process.CloseStdinRequest]) (*connect.Response[process.CloseStdinResponse], error) {
|
||||||
|
return c.closeStdin.CallUnary(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProcessHandler is an implementation of the process.Process service.
|
||||||
|
type ProcessHandler interface {
|
||||||
|
List(context.Context, *connect.Request[process.ListRequest]) (*connect.Response[process.ListResponse], error)
|
||||||
|
Connect(context.Context, *connect.Request[process.ConnectRequest], *connect.ServerStream[process.ConnectResponse]) error
|
||||||
|
Start(context.Context, *connect.Request[process.StartRequest], *connect.ServerStream[process.StartResponse]) error
|
||||||
|
Update(context.Context, *connect.Request[process.UpdateRequest]) (*connect.Response[process.UpdateResponse], error)
|
||||||
|
// Client input stream ensures ordering of messages
|
||||||
|
StreamInput(context.Context, *connect.ClientStream[process.StreamInputRequest]) (*connect.Response[process.StreamInputResponse], error)
|
||||||
|
SendInput(context.Context, *connect.Request[process.SendInputRequest]) (*connect.Response[process.SendInputResponse], error)
|
||||||
|
SendSignal(context.Context, *connect.Request[process.SendSignalRequest]) (*connect.Response[process.SendSignalResponse], error)
|
||||||
|
// Close stdin to signal EOF to the process.
|
||||||
|
// Only works for non-PTY processes. For PTY, send Ctrl+D (0x04) instead.
|
||||||
|
CloseStdin(context.Context, *connect.Request[process.CloseStdinRequest]) (*connect.Response[process.CloseStdinResponse], error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewProcessHandler builds an HTTP handler from the service implementation. It returns the path on
|
||||||
|
// which to mount the handler and the handler itself.
|
||||||
|
//
|
||||||
|
// By default, handlers support the Connect, gRPC, and gRPC-Web protocols with the binary Protobuf
|
||||||
|
// and JSON codecs. They also support gzip compression.
|
||||||
|
func NewProcessHandler(svc ProcessHandler, opts ...connect.HandlerOption) (string, http.Handler) {
|
||||||
|
processMethods := process.File_process_process_proto.Services().ByName("Process").Methods()
|
||||||
|
processListHandler := connect.NewUnaryHandler(
|
||||||
|
ProcessListProcedure,
|
||||||
|
svc.List,
|
||||||
|
connect.WithSchema(processMethods.ByName("List")),
|
||||||
|
connect.WithHandlerOptions(opts...),
|
||||||
|
)
|
||||||
|
processConnectHandler := connect.NewServerStreamHandler(
|
||||||
|
ProcessConnectProcedure,
|
||||||
|
svc.Connect,
|
||||||
|
connect.WithSchema(processMethods.ByName("Connect")),
|
||||||
|
connect.WithHandlerOptions(opts...),
|
||||||
|
)
|
||||||
|
processStartHandler := connect.NewServerStreamHandler(
|
||||||
|
ProcessStartProcedure,
|
||||||
|
svc.Start,
|
||||||
|
connect.WithSchema(processMethods.ByName("Start")),
|
||||||
|
connect.WithHandlerOptions(opts...),
|
||||||
|
)
|
||||||
|
processUpdateHandler := connect.NewUnaryHandler(
|
||||||
|
ProcessUpdateProcedure,
|
||||||
|
svc.Update,
|
||||||
|
connect.WithSchema(processMethods.ByName("Update")),
|
||||||
|
connect.WithHandlerOptions(opts...),
|
||||||
|
)
|
||||||
|
processStreamInputHandler := connect.NewClientStreamHandler(
|
||||||
|
ProcessStreamInputProcedure,
|
||||||
|
svc.StreamInput,
|
||||||
|
connect.WithSchema(processMethods.ByName("StreamInput")),
|
||||||
|
connect.WithHandlerOptions(opts...),
|
||||||
|
)
|
||||||
|
processSendInputHandler := connect.NewUnaryHandler(
|
||||||
|
ProcessSendInputProcedure,
|
||||||
|
svc.SendInput,
|
||||||
|
connect.WithSchema(processMethods.ByName("SendInput")),
|
||||||
|
connect.WithHandlerOptions(opts...),
|
||||||
|
)
|
||||||
|
processSendSignalHandler := connect.NewUnaryHandler(
|
||||||
|
ProcessSendSignalProcedure,
|
||||||
|
svc.SendSignal,
|
||||||
|
connect.WithSchema(processMethods.ByName("SendSignal")),
|
||||||
|
connect.WithHandlerOptions(opts...),
|
||||||
|
)
|
||||||
|
processCloseStdinHandler := connect.NewUnaryHandler(
|
||||||
|
ProcessCloseStdinProcedure,
|
||||||
|
svc.CloseStdin,
|
||||||
|
connect.WithSchema(processMethods.ByName("CloseStdin")),
|
||||||
|
connect.WithHandlerOptions(opts...),
|
||||||
|
)
|
||||||
|
return "/process.Process/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case ProcessListProcedure:
|
||||||
|
processListHandler.ServeHTTP(w, r)
|
||||||
|
case ProcessConnectProcedure:
|
||||||
|
processConnectHandler.ServeHTTP(w, r)
|
||||||
|
case ProcessStartProcedure:
|
||||||
|
processStartHandler.ServeHTTP(w, r)
|
||||||
|
case ProcessUpdateProcedure:
|
||||||
|
processUpdateHandler.ServeHTTP(w, r)
|
||||||
|
case ProcessStreamInputProcedure:
|
||||||
|
processStreamInputHandler.ServeHTTP(w, r)
|
||||||
|
case ProcessSendInputProcedure:
|
||||||
|
processSendInputHandler.ServeHTTP(w, r)
|
||||||
|
case ProcessSendSignalProcedure:
|
||||||
|
processSendSignalHandler.ServeHTTP(w, r)
|
||||||
|
case ProcessCloseStdinProcedure:
|
||||||
|
processCloseStdinHandler.ServeHTTP(w, r)
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnimplementedProcessHandler returns CodeUnimplemented from all methods.
|
||||||
|
type UnimplementedProcessHandler struct{}
|
||||||
|
|
||||||
|
func (UnimplementedProcessHandler) List(context.Context, *connect.Request[process.ListRequest]) (*connect.Response[process.ListResponse], error) {
|
||||||
|
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.List is not implemented"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UnimplementedProcessHandler) Connect(context.Context, *connect.Request[process.ConnectRequest], *connect.ServerStream[process.ConnectResponse]) error {
|
||||||
|
return connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.Connect is not implemented"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UnimplementedProcessHandler) Start(context.Context, *connect.Request[process.StartRequest], *connect.ServerStream[process.StartResponse]) error {
|
||||||
|
return connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.Start is not implemented"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UnimplementedProcessHandler) Update(context.Context, *connect.Request[process.UpdateRequest]) (*connect.Response[process.UpdateResponse], error) {
|
||||||
|
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.Update is not implemented"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UnimplementedProcessHandler) StreamInput(context.Context, *connect.ClientStream[process.StreamInputRequest]) (*connect.Response[process.StreamInputResponse], error) {
|
||||||
|
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.StreamInput is not implemented"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UnimplementedProcessHandler) SendInput(context.Context, *connect.Request[process.SendInputRequest]) (*connect.Response[process.SendInputResponse], error) {
|
||||||
|
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.SendInput is not implemented"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UnimplementedProcessHandler) SendSignal(context.Context, *connect.Request[process.SendSignalRequest]) (*connect.Response[process.SendSignalResponse], error) {
|
||||||
|
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.SendSignal is not implemented"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UnimplementedProcessHandler) CloseStdin(context.Context, *connect.Request[process.CloseStdinRequest]) (*connect.Response[process.CloseStdinResponse], error) {
|
||||||
|
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.CloseStdin is not implemented"))
|
||||||
|
}
|
||||||
339
envd/internal/services/spec/specconnect/filesystem.connect.go
Normal file
339
envd/internal/services/spec/specconnect/filesystem.connect.go
Normal file
@ -0,0 +1,339 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
// Code generated by protoc-gen-connect-go. DO NOT EDIT.
|
||||||
|
//
|
||||||
|
// Source: filesystem.proto
|
||||||
|
|
||||||
|
package specconnect
|
||||||
|
|
||||||
|
import (
|
||||||
|
connect "connectrpc.com/connect"
|
||||||
|
context "context"
|
||||||
|
errors "errors"
|
||||||
|
spec "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec"
|
||||||
|
http "net/http"
|
||||||
|
strings "strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// This is a compile-time assertion to ensure that this generated file and the connect package are
|
||||||
|
// compatible. If you get a compiler error that this constant is not defined, this code was
|
||||||
|
// generated with a version of connect newer than the one compiled into your binary. You can fix the
|
||||||
|
// problem by either regenerating this code with an older version of connect or updating the connect
|
||||||
|
// version compiled into your binary.
|
||||||
|
const _ = connect.IsAtLeastVersion1_13_0
|
||||||
|
|
||||||
|
const (
|
||||||
|
// FilesystemName is the fully-qualified name of the Filesystem service.
|
||||||
|
FilesystemName = "filesystem.Filesystem"
|
||||||
|
)
|
||||||
|
|
||||||
|
// These constants are the fully-qualified names of the RPCs defined in this package. They're
|
||||||
|
// exposed at runtime as Spec.Procedure and as the final two segments of the HTTP route.
|
||||||
|
//
|
||||||
|
// Note that these are different from the fully-qualified method names used by
|
||||||
|
// google.golang.org/protobuf/reflect/protoreflect. To convert from these constants to
|
||||||
|
// reflection-formatted method names, remove the leading slash and convert the remaining slash to a
|
||||||
|
// period.
|
||||||
|
const (
|
||||||
|
// FilesystemStatProcedure is the fully-qualified name of the Filesystem's Stat RPC.
|
||||||
|
FilesystemStatProcedure = "/filesystem.Filesystem/Stat"
|
||||||
|
// FilesystemMakeDirProcedure is the fully-qualified name of the Filesystem's MakeDir RPC.
|
||||||
|
FilesystemMakeDirProcedure = "/filesystem.Filesystem/MakeDir"
|
||||||
|
// FilesystemMoveProcedure is the fully-qualified name of the Filesystem's Move RPC.
|
||||||
|
FilesystemMoveProcedure = "/filesystem.Filesystem/Move"
|
||||||
|
// FilesystemListDirProcedure is the fully-qualified name of the Filesystem's ListDir RPC.
|
||||||
|
FilesystemListDirProcedure = "/filesystem.Filesystem/ListDir"
|
||||||
|
// FilesystemRemoveProcedure is the fully-qualified name of the Filesystem's Remove RPC.
|
||||||
|
FilesystemRemoveProcedure = "/filesystem.Filesystem/Remove"
|
||||||
|
// FilesystemWatchDirProcedure is the fully-qualified name of the Filesystem's WatchDir RPC.
|
||||||
|
FilesystemWatchDirProcedure = "/filesystem.Filesystem/WatchDir"
|
||||||
|
// FilesystemCreateWatcherProcedure is the fully-qualified name of the Filesystem's CreateWatcher
|
||||||
|
// RPC.
|
||||||
|
FilesystemCreateWatcherProcedure = "/filesystem.Filesystem/CreateWatcher"
|
||||||
|
// FilesystemGetWatcherEventsProcedure is the fully-qualified name of the Filesystem's
|
||||||
|
// GetWatcherEvents RPC.
|
||||||
|
FilesystemGetWatcherEventsProcedure = "/filesystem.Filesystem/GetWatcherEvents"
|
||||||
|
// FilesystemRemoveWatcherProcedure is the fully-qualified name of the Filesystem's RemoveWatcher
|
||||||
|
// RPC.
|
||||||
|
FilesystemRemoveWatcherProcedure = "/filesystem.Filesystem/RemoveWatcher"
|
||||||
|
)
|
||||||
|
|
||||||
|
// FilesystemClient is a client for the filesystem.Filesystem service.
|
||||||
|
type FilesystemClient interface {
|
||||||
|
Stat(context.Context, *connect.Request[spec.StatRequest]) (*connect.Response[spec.StatResponse], error)
|
||||||
|
MakeDir(context.Context, *connect.Request[spec.MakeDirRequest]) (*connect.Response[spec.MakeDirResponse], error)
|
||||||
|
Move(context.Context, *connect.Request[spec.MoveRequest]) (*connect.Response[spec.MoveResponse], error)
|
||||||
|
ListDir(context.Context, *connect.Request[spec.ListDirRequest]) (*connect.Response[spec.ListDirResponse], error)
|
||||||
|
Remove(context.Context, *connect.Request[spec.RemoveRequest]) (*connect.Response[spec.RemoveResponse], error)
|
||||||
|
WatchDir(context.Context, *connect.Request[spec.WatchDirRequest]) (*connect.ServerStreamForClient[spec.WatchDirResponse], error)
|
||||||
|
// Non-streaming versions of WatchDir
|
||||||
|
CreateWatcher(context.Context, *connect.Request[spec.CreateWatcherRequest]) (*connect.Response[spec.CreateWatcherResponse], error)
|
||||||
|
GetWatcherEvents(context.Context, *connect.Request[spec.GetWatcherEventsRequest]) (*connect.Response[spec.GetWatcherEventsResponse], error)
|
||||||
|
RemoveWatcher(context.Context, *connect.Request[spec.RemoveWatcherRequest]) (*connect.Response[spec.RemoveWatcherResponse], error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFilesystemClient constructs a client for the filesystem.Filesystem service. By default, it
|
||||||
|
// uses the Connect protocol with the binary Protobuf Codec, asks for gzipped responses, and sends
|
||||||
|
// uncompressed requests. To use the gRPC or gRPC-Web protocols, supply the connect.WithGRPC() or
|
||||||
|
// connect.WithGRPCWeb() options.
|
||||||
|
//
|
||||||
|
// The URL supplied here should be the base URL for the Connect or gRPC server (for example,
|
||||||
|
// http://api.acme.com or https://acme.com/grpc).
|
||||||
|
func NewFilesystemClient(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) FilesystemClient {
|
||||||
|
baseURL = strings.TrimRight(baseURL, "/")
|
||||||
|
filesystemMethods := spec.File_filesystem_proto.Services().ByName("Filesystem").Methods()
|
||||||
|
return &filesystemClient{
|
||||||
|
stat: connect.NewClient[spec.StatRequest, spec.StatResponse](
|
||||||
|
httpClient,
|
||||||
|
baseURL+FilesystemStatProcedure,
|
||||||
|
connect.WithSchema(filesystemMethods.ByName("Stat")),
|
||||||
|
connect.WithClientOptions(opts...),
|
||||||
|
),
|
||||||
|
makeDir: connect.NewClient[spec.MakeDirRequest, spec.MakeDirResponse](
|
||||||
|
httpClient,
|
||||||
|
baseURL+FilesystemMakeDirProcedure,
|
||||||
|
connect.WithSchema(filesystemMethods.ByName("MakeDir")),
|
||||||
|
connect.WithClientOptions(opts...),
|
||||||
|
),
|
||||||
|
move: connect.NewClient[spec.MoveRequest, spec.MoveResponse](
|
||||||
|
httpClient,
|
||||||
|
baseURL+FilesystemMoveProcedure,
|
||||||
|
connect.WithSchema(filesystemMethods.ByName("Move")),
|
||||||
|
connect.WithClientOptions(opts...),
|
||||||
|
),
|
||||||
|
listDir: connect.NewClient[spec.ListDirRequest, spec.ListDirResponse](
|
||||||
|
httpClient,
|
||||||
|
baseURL+FilesystemListDirProcedure,
|
||||||
|
connect.WithSchema(filesystemMethods.ByName("ListDir")),
|
||||||
|
connect.WithClientOptions(opts...),
|
||||||
|
),
|
||||||
|
remove: connect.NewClient[spec.RemoveRequest, spec.RemoveResponse](
|
||||||
|
httpClient,
|
||||||
|
baseURL+FilesystemRemoveProcedure,
|
||||||
|
connect.WithSchema(filesystemMethods.ByName("Remove")),
|
||||||
|
connect.WithClientOptions(opts...),
|
||||||
|
),
|
||||||
|
watchDir: connect.NewClient[spec.WatchDirRequest, spec.WatchDirResponse](
|
||||||
|
httpClient,
|
||||||
|
baseURL+FilesystemWatchDirProcedure,
|
||||||
|
connect.WithSchema(filesystemMethods.ByName("WatchDir")),
|
||||||
|
connect.WithClientOptions(opts...),
|
||||||
|
),
|
||||||
|
createWatcher: connect.NewClient[spec.CreateWatcherRequest, spec.CreateWatcherResponse](
|
||||||
|
httpClient,
|
||||||
|
baseURL+FilesystemCreateWatcherProcedure,
|
||||||
|
connect.WithSchema(filesystemMethods.ByName("CreateWatcher")),
|
||||||
|
connect.WithClientOptions(opts...),
|
||||||
|
),
|
||||||
|
getWatcherEvents: connect.NewClient[spec.GetWatcherEventsRequest, spec.GetWatcherEventsResponse](
|
||||||
|
httpClient,
|
||||||
|
baseURL+FilesystemGetWatcherEventsProcedure,
|
||||||
|
connect.WithSchema(filesystemMethods.ByName("GetWatcherEvents")),
|
||||||
|
connect.WithClientOptions(opts...),
|
||||||
|
),
|
||||||
|
removeWatcher: connect.NewClient[spec.RemoveWatcherRequest, spec.RemoveWatcherResponse](
|
||||||
|
httpClient,
|
||||||
|
baseURL+FilesystemRemoveWatcherProcedure,
|
||||||
|
connect.WithSchema(filesystemMethods.ByName("RemoveWatcher")),
|
||||||
|
connect.WithClientOptions(opts...),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// filesystemClient implements FilesystemClient.
|
||||||
|
type filesystemClient struct {
|
||||||
|
stat *connect.Client[spec.StatRequest, spec.StatResponse]
|
||||||
|
makeDir *connect.Client[spec.MakeDirRequest, spec.MakeDirResponse]
|
||||||
|
move *connect.Client[spec.MoveRequest, spec.MoveResponse]
|
||||||
|
listDir *connect.Client[spec.ListDirRequest, spec.ListDirResponse]
|
||||||
|
remove *connect.Client[spec.RemoveRequest, spec.RemoveResponse]
|
||||||
|
watchDir *connect.Client[spec.WatchDirRequest, spec.WatchDirResponse]
|
||||||
|
createWatcher *connect.Client[spec.CreateWatcherRequest, spec.CreateWatcherResponse]
|
||||||
|
getWatcherEvents *connect.Client[spec.GetWatcherEventsRequest, spec.GetWatcherEventsResponse]
|
||||||
|
removeWatcher *connect.Client[spec.RemoveWatcherRequest, spec.RemoveWatcherResponse]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stat calls filesystem.Filesystem.Stat.
|
||||||
|
func (c *filesystemClient) Stat(ctx context.Context, req *connect.Request[spec.StatRequest]) (*connect.Response[spec.StatResponse], error) {
|
||||||
|
return c.stat.CallUnary(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MakeDir calls filesystem.Filesystem.MakeDir.
|
||||||
|
func (c *filesystemClient) MakeDir(ctx context.Context, req *connect.Request[spec.MakeDirRequest]) (*connect.Response[spec.MakeDirResponse], error) {
|
||||||
|
return c.makeDir.CallUnary(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Move calls filesystem.Filesystem.Move.
|
||||||
|
func (c *filesystemClient) Move(ctx context.Context, req *connect.Request[spec.MoveRequest]) (*connect.Response[spec.MoveResponse], error) {
|
||||||
|
return c.move.CallUnary(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListDir calls filesystem.Filesystem.ListDir.
|
||||||
|
func (c *filesystemClient) ListDir(ctx context.Context, req *connect.Request[spec.ListDirRequest]) (*connect.Response[spec.ListDirResponse], error) {
|
||||||
|
return c.listDir.CallUnary(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove calls filesystem.Filesystem.Remove.
|
||||||
|
func (c *filesystemClient) Remove(ctx context.Context, req *connect.Request[spec.RemoveRequest]) (*connect.Response[spec.RemoveResponse], error) {
|
||||||
|
return c.remove.CallUnary(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WatchDir calls filesystem.Filesystem.WatchDir.
|
||||||
|
func (c *filesystemClient) WatchDir(ctx context.Context, req *connect.Request[spec.WatchDirRequest]) (*connect.ServerStreamForClient[spec.WatchDirResponse], error) {
|
||||||
|
return c.watchDir.CallServerStream(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateWatcher calls filesystem.Filesystem.CreateWatcher.
|
||||||
|
func (c *filesystemClient) CreateWatcher(ctx context.Context, req *connect.Request[spec.CreateWatcherRequest]) (*connect.Response[spec.CreateWatcherResponse], error) {
|
||||||
|
return c.createWatcher.CallUnary(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetWatcherEvents calls filesystem.Filesystem.GetWatcherEvents.
|
||||||
|
func (c *filesystemClient) GetWatcherEvents(ctx context.Context, req *connect.Request[spec.GetWatcherEventsRequest]) (*connect.Response[spec.GetWatcherEventsResponse], error) {
|
||||||
|
return c.getWatcherEvents.CallUnary(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveWatcher calls filesystem.Filesystem.RemoveWatcher.
|
||||||
|
func (c *filesystemClient) RemoveWatcher(ctx context.Context, req *connect.Request[spec.RemoveWatcherRequest]) (*connect.Response[spec.RemoveWatcherResponse], error) {
|
||||||
|
return c.removeWatcher.CallUnary(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FilesystemHandler is an implementation of the filesystem.Filesystem service.
|
||||||
|
type FilesystemHandler interface {
|
||||||
|
Stat(context.Context, *connect.Request[spec.StatRequest]) (*connect.Response[spec.StatResponse], error)
|
||||||
|
MakeDir(context.Context, *connect.Request[spec.MakeDirRequest]) (*connect.Response[spec.MakeDirResponse], error)
|
||||||
|
Move(context.Context, *connect.Request[spec.MoveRequest]) (*connect.Response[spec.MoveResponse], error)
|
||||||
|
ListDir(context.Context, *connect.Request[spec.ListDirRequest]) (*connect.Response[spec.ListDirResponse], error)
|
||||||
|
Remove(context.Context, *connect.Request[spec.RemoveRequest]) (*connect.Response[spec.RemoveResponse], error)
|
||||||
|
WatchDir(context.Context, *connect.Request[spec.WatchDirRequest], *connect.ServerStream[spec.WatchDirResponse]) error
|
||||||
|
// Non-streaming versions of WatchDir
|
||||||
|
CreateWatcher(context.Context, *connect.Request[spec.CreateWatcherRequest]) (*connect.Response[spec.CreateWatcherResponse], error)
|
||||||
|
GetWatcherEvents(context.Context, *connect.Request[spec.GetWatcherEventsRequest]) (*connect.Response[spec.GetWatcherEventsResponse], error)
|
||||||
|
RemoveWatcher(context.Context, *connect.Request[spec.RemoveWatcherRequest]) (*connect.Response[spec.RemoveWatcherResponse], error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFilesystemHandler builds an HTTP handler from the service implementation. It returns the path
|
||||||
|
// on which to mount the handler and the handler itself.
|
||||||
|
//
|
||||||
|
// By default, handlers support the Connect, gRPC, and gRPC-Web protocols with the binary Protobuf
|
||||||
|
// and JSON codecs. They also support gzip compression.
|
||||||
|
func NewFilesystemHandler(svc FilesystemHandler, opts ...connect.HandlerOption) (string, http.Handler) {
|
||||||
|
filesystemMethods := spec.File_filesystem_proto.Services().ByName("Filesystem").Methods()
|
||||||
|
filesystemStatHandler := connect.NewUnaryHandler(
|
||||||
|
FilesystemStatProcedure,
|
||||||
|
svc.Stat,
|
||||||
|
connect.WithSchema(filesystemMethods.ByName("Stat")),
|
||||||
|
connect.WithHandlerOptions(opts...),
|
||||||
|
)
|
||||||
|
filesystemMakeDirHandler := connect.NewUnaryHandler(
|
||||||
|
FilesystemMakeDirProcedure,
|
||||||
|
svc.MakeDir,
|
||||||
|
connect.WithSchema(filesystemMethods.ByName("MakeDir")),
|
||||||
|
connect.WithHandlerOptions(opts...),
|
||||||
|
)
|
||||||
|
filesystemMoveHandler := connect.NewUnaryHandler(
|
||||||
|
FilesystemMoveProcedure,
|
||||||
|
svc.Move,
|
||||||
|
connect.WithSchema(filesystemMethods.ByName("Move")),
|
||||||
|
connect.WithHandlerOptions(opts...),
|
||||||
|
)
|
||||||
|
filesystemListDirHandler := connect.NewUnaryHandler(
|
||||||
|
FilesystemListDirProcedure,
|
||||||
|
svc.ListDir,
|
||||||
|
connect.WithSchema(filesystemMethods.ByName("ListDir")),
|
||||||
|
connect.WithHandlerOptions(opts...),
|
||||||
|
)
|
||||||
|
filesystemRemoveHandler := connect.NewUnaryHandler(
|
||||||
|
FilesystemRemoveProcedure,
|
||||||
|
svc.Remove,
|
||||||
|
connect.WithSchema(filesystemMethods.ByName("Remove")),
|
||||||
|
connect.WithHandlerOptions(opts...),
|
||||||
|
)
|
||||||
|
filesystemWatchDirHandler := connect.NewServerStreamHandler(
|
||||||
|
FilesystemWatchDirProcedure,
|
||||||
|
svc.WatchDir,
|
||||||
|
connect.WithSchema(filesystemMethods.ByName("WatchDir")),
|
||||||
|
connect.WithHandlerOptions(opts...),
|
||||||
|
)
|
||||||
|
filesystemCreateWatcherHandler := connect.NewUnaryHandler(
|
||||||
|
FilesystemCreateWatcherProcedure,
|
||||||
|
svc.CreateWatcher,
|
||||||
|
connect.WithSchema(filesystemMethods.ByName("CreateWatcher")),
|
||||||
|
connect.WithHandlerOptions(opts...),
|
||||||
|
)
|
||||||
|
filesystemGetWatcherEventsHandler := connect.NewUnaryHandler(
|
||||||
|
FilesystemGetWatcherEventsProcedure,
|
||||||
|
svc.GetWatcherEvents,
|
||||||
|
connect.WithSchema(filesystemMethods.ByName("GetWatcherEvents")),
|
||||||
|
connect.WithHandlerOptions(opts...),
|
||||||
|
)
|
||||||
|
filesystemRemoveWatcherHandler := connect.NewUnaryHandler(
|
||||||
|
FilesystemRemoveWatcherProcedure,
|
||||||
|
svc.RemoveWatcher,
|
||||||
|
connect.WithSchema(filesystemMethods.ByName("RemoveWatcher")),
|
||||||
|
connect.WithHandlerOptions(opts...),
|
||||||
|
)
|
||||||
|
return "/filesystem.Filesystem/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case FilesystemStatProcedure:
|
||||||
|
filesystemStatHandler.ServeHTTP(w, r)
|
||||||
|
case FilesystemMakeDirProcedure:
|
||||||
|
filesystemMakeDirHandler.ServeHTTP(w, r)
|
||||||
|
case FilesystemMoveProcedure:
|
||||||
|
filesystemMoveHandler.ServeHTTP(w, r)
|
||||||
|
case FilesystemListDirProcedure:
|
||||||
|
filesystemListDirHandler.ServeHTTP(w, r)
|
||||||
|
case FilesystemRemoveProcedure:
|
||||||
|
filesystemRemoveHandler.ServeHTTP(w, r)
|
||||||
|
case FilesystemWatchDirProcedure:
|
||||||
|
filesystemWatchDirHandler.ServeHTTP(w, r)
|
||||||
|
case FilesystemCreateWatcherProcedure:
|
||||||
|
filesystemCreateWatcherHandler.ServeHTTP(w, r)
|
||||||
|
case FilesystemGetWatcherEventsProcedure:
|
||||||
|
filesystemGetWatcherEventsHandler.ServeHTTP(w, r)
|
||||||
|
case FilesystemRemoveWatcherProcedure:
|
||||||
|
filesystemRemoveWatcherHandler.ServeHTTP(w, r)
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnimplementedFilesystemHandler returns CodeUnimplemented from all methods.
|
||||||
|
type UnimplementedFilesystemHandler struct{}
|
||||||
|
|
||||||
|
func (UnimplementedFilesystemHandler) Stat(context.Context, *connect.Request[spec.StatRequest]) (*connect.Response[spec.StatResponse], error) {
|
||||||
|
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.Stat is not implemented"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UnimplementedFilesystemHandler) MakeDir(context.Context, *connect.Request[spec.MakeDirRequest]) (*connect.Response[spec.MakeDirResponse], error) {
|
||||||
|
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.MakeDir is not implemented"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UnimplementedFilesystemHandler) Move(context.Context, *connect.Request[spec.MoveRequest]) (*connect.Response[spec.MoveResponse], error) {
|
||||||
|
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.Move is not implemented"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UnimplementedFilesystemHandler) ListDir(context.Context, *connect.Request[spec.ListDirRequest]) (*connect.Response[spec.ListDirResponse], error) {
|
||||||
|
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.ListDir is not implemented"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UnimplementedFilesystemHandler) Remove(context.Context, *connect.Request[spec.RemoveRequest]) (*connect.Response[spec.RemoveResponse], error) {
|
||||||
|
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.Remove is not implemented"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UnimplementedFilesystemHandler) WatchDir(context.Context, *connect.Request[spec.WatchDirRequest], *connect.ServerStream[spec.WatchDirResponse]) error {
|
||||||
|
return connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.WatchDir is not implemented"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UnimplementedFilesystemHandler) CreateWatcher(context.Context, *connect.Request[spec.CreateWatcherRequest]) (*connect.Response[spec.CreateWatcherResponse], error) {
|
||||||
|
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.CreateWatcher is not implemented"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UnimplementedFilesystemHandler) GetWatcherEvents(context.Context, *connect.Request[spec.GetWatcherEventsRequest]) (*connect.Response[spec.GetWatcherEventsResponse], error) {
|
||||||
|
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.GetWatcherEvents is not implemented"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UnimplementedFilesystemHandler) RemoveWatcher(context.Context, *connect.Request[spec.RemoveWatcherRequest]) (*connect.Response[spec.RemoveWatcherResponse], error) {
|
||||||
|
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.RemoveWatcher is not implemented"))
|
||||||
|
}
|
||||||
312
envd/internal/services/spec/specconnect/process.connect.go
Normal file
312
envd/internal/services/spec/specconnect/process.connect.go
Normal file
@ -0,0 +1,312 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
// Code generated by protoc-gen-connect-go. DO NOT EDIT.
|
||||||
|
//
|
||||||
|
// Source: process.proto
|
||||||
|
|
||||||
|
package specconnect
|
||||||
|
|
||||||
|
import (
|
||||||
|
connect "connectrpc.com/connect"
|
||||||
|
context "context"
|
||||||
|
errors "errors"
|
||||||
|
spec "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec"
|
||||||
|
http "net/http"
|
||||||
|
strings "strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// This is a compile-time assertion to ensure that this generated file and the connect package are
|
||||||
|
// compatible. If you get a compiler error that this constant is not defined, this code was
|
||||||
|
// generated with a version of connect newer than the one compiled into your binary. You can fix the
|
||||||
|
// problem by either regenerating this code with an older version of connect or updating the connect
|
||||||
|
// version compiled into your binary.
|
||||||
|
const _ = connect.IsAtLeastVersion1_13_0
|
||||||
|
|
||||||
|
const (
|
||||||
|
// ProcessName is the fully-qualified name of the Process service.
|
||||||
|
ProcessName = "process.Process"
|
||||||
|
)
|
||||||
|
|
||||||
|
// These constants are the fully-qualified names of the RPCs defined in this package. They're
|
||||||
|
// exposed at runtime as Spec.Procedure and as the final two segments of the HTTP route.
|
||||||
|
//
|
||||||
|
// Note that these are different from the fully-qualified method names used by
|
||||||
|
// google.golang.org/protobuf/reflect/protoreflect. To convert from these constants to
|
||||||
|
// reflection-formatted method names, remove the leading slash and convert the remaining slash to a
|
||||||
|
// period.
|
||||||
|
const (
|
||||||
|
// ProcessListProcedure is the fully-qualified name of the Process's List RPC.
|
||||||
|
ProcessListProcedure = "/process.Process/List"
|
||||||
|
// ProcessConnectProcedure is the fully-qualified name of the Process's Connect RPC.
|
||||||
|
ProcessConnectProcedure = "/process.Process/Connect"
|
||||||
|
// ProcessStartProcedure is the fully-qualified name of the Process's Start RPC.
|
||||||
|
ProcessStartProcedure = "/process.Process/Start"
|
||||||
|
// ProcessUpdateProcedure is the fully-qualified name of the Process's Update RPC.
|
||||||
|
ProcessUpdateProcedure = "/process.Process/Update"
|
||||||
|
// ProcessStreamInputProcedure is the fully-qualified name of the Process's StreamInput RPC.
|
||||||
|
ProcessStreamInputProcedure = "/process.Process/StreamInput"
|
||||||
|
// ProcessSendInputProcedure is the fully-qualified name of the Process's SendInput RPC.
|
||||||
|
ProcessSendInputProcedure = "/process.Process/SendInput"
|
||||||
|
// ProcessSendSignalProcedure is the fully-qualified name of the Process's SendSignal RPC.
|
||||||
|
ProcessSendSignalProcedure = "/process.Process/SendSignal"
|
||||||
|
// ProcessCloseStdinProcedure is the fully-qualified name of the Process's CloseStdin RPC.
|
||||||
|
ProcessCloseStdinProcedure = "/process.Process/CloseStdin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProcessClient is a client for the process.Process service.
|
||||||
|
type ProcessClient interface {
|
||||||
|
List(context.Context, *connect.Request[spec.ListRequest]) (*connect.Response[spec.ListResponse], error)
|
||||||
|
Connect(context.Context, *connect.Request[spec.ConnectRequest]) (*connect.ServerStreamForClient[spec.ConnectResponse], error)
|
||||||
|
Start(context.Context, *connect.Request[spec.StartRequest]) (*connect.ServerStreamForClient[spec.StartResponse], error)
|
||||||
|
Update(context.Context, *connect.Request[spec.UpdateRequest]) (*connect.Response[spec.UpdateResponse], error)
|
||||||
|
// Client input stream ensures ordering of messages
|
||||||
|
StreamInput(context.Context) *connect.ClientStreamForClient[spec.StreamInputRequest, spec.StreamInputResponse]
|
||||||
|
SendInput(context.Context, *connect.Request[spec.SendInputRequest]) (*connect.Response[spec.SendInputResponse], error)
|
||||||
|
SendSignal(context.Context, *connect.Request[spec.SendSignalRequest]) (*connect.Response[spec.SendSignalResponse], error)
|
||||||
|
// Close stdin to signal EOF to the process.
|
||||||
|
// Only works for non-PTY processes. For PTY, send Ctrl+D (0x04) instead.
|
||||||
|
CloseStdin(context.Context, *connect.Request[spec.CloseStdinRequest]) (*connect.Response[spec.CloseStdinResponse], error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewProcessClient constructs a client for the process.Process service. By default, it uses the
|
||||||
|
// Connect protocol with the binary Protobuf Codec, asks for gzipped responses, and sends
|
||||||
|
// uncompressed requests. To use the gRPC or gRPC-Web protocols, supply the connect.WithGRPC() or
|
||||||
|
// connect.WithGRPCWeb() options.
|
||||||
|
//
|
||||||
|
// The URL supplied here should be the base URL for the Connect or gRPC server (for example,
|
||||||
|
// http://api.acme.com or https://acme.com/grpc).
|
||||||
|
func NewProcessClient(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) ProcessClient {
|
||||||
|
baseURL = strings.TrimRight(baseURL, "/")
|
||||||
|
processMethods := spec.File_process_proto.Services().ByName("Process").Methods()
|
||||||
|
return &processClient{
|
||||||
|
list: connect.NewClient[spec.ListRequest, spec.ListResponse](
|
||||||
|
httpClient,
|
||||||
|
baseURL+ProcessListProcedure,
|
||||||
|
connect.WithSchema(processMethods.ByName("List")),
|
||||||
|
connect.WithClientOptions(opts...),
|
||||||
|
),
|
||||||
|
connect: connect.NewClient[spec.ConnectRequest, spec.ConnectResponse](
|
||||||
|
httpClient,
|
||||||
|
baseURL+ProcessConnectProcedure,
|
||||||
|
connect.WithSchema(processMethods.ByName("Connect")),
|
||||||
|
connect.WithClientOptions(opts...),
|
||||||
|
),
|
||||||
|
start: connect.NewClient[spec.StartRequest, spec.StartResponse](
|
||||||
|
httpClient,
|
||||||
|
baseURL+ProcessStartProcedure,
|
||||||
|
connect.WithSchema(processMethods.ByName("Start")),
|
||||||
|
connect.WithClientOptions(opts...),
|
||||||
|
),
|
||||||
|
update: connect.NewClient[spec.UpdateRequest, spec.UpdateResponse](
|
||||||
|
httpClient,
|
||||||
|
baseURL+ProcessUpdateProcedure,
|
||||||
|
connect.WithSchema(processMethods.ByName("Update")),
|
||||||
|
connect.WithClientOptions(opts...),
|
||||||
|
),
|
||||||
|
streamInput: connect.NewClient[spec.StreamInputRequest, spec.StreamInputResponse](
|
||||||
|
httpClient,
|
||||||
|
baseURL+ProcessStreamInputProcedure,
|
||||||
|
connect.WithSchema(processMethods.ByName("StreamInput")),
|
||||||
|
connect.WithClientOptions(opts...),
|
||||||
|
),
|
||||||
|
sendInput: connect.NewClient[spec.SendInputRequest, spec.SendInputResponse](
|
||||||
|
httpClient,
|
||||||
|
baseURL+ProcessSendInputProcedure,
|
||||||
|
connect.WithSchema(processMethods.ByName("SendInput")),
|
||||||
|
connect.WithClientOptions(opts...),
|
||||||
|
),
|
||||||
|
sendSignal: connect.NewClient[spec.SendSignalRequest, spec.SendSignalResponse](
|
||||||
|
httpClient,
|
||||||
|
baseURL+ProcessSendSignalProcedure,
|
||||||
|
connect.WithSchema(processMethods.ByName("SendSignal")),
|
||||||
|
connect.WithClientOptions(opts...),
|
||||||
|
),
|
||||||
|
closeStdin: connect.NewClient[spec.CloseStdinRequest, spec.CloseStdinResponse](
|
||||||
|
httpClient,
|
||||||
|
baseURL+ProcessCloseStdinProcedure,
|
||||||
|
connect.WithSchema(processMethods.ByName("CloseStdin")),
|
||||||
|
connect.WithClientOptions(opts...),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// processClient implements ProcessClient.
|
||||||
|
type processClient struct {
|
||||||
|
list *connect.Client[spec.ListRequest, spec.ListResponse]
|
||||||
|
connect *connect.Client[spec.ConnectRequest, spec.ConnectResponse]
|
||||||
|
start *connect.Client[spec.StartRequest, spec.StartResponse]
|
||||||
|
update *connect.Client[spec.UpdateRequest, spec.UpdateResponse]
|
||||||
|
streamInput *connect.Client[spec.StreamInputRequest, spec.StreamInputResponse]
|
||||||
|
sendInput *connect.Client[spec.SendInputRequest, spec.SendInputResponse]
|
||||||
|
sendSignal *connect.Client[spec.SendSignalRequest, spec.SendSignalResponse]
|
||||||
|
closeStdin *connect.Client[spec.CloseStdinRequest, spec.CloseStdinResponse]
|
||||||
|
}
|
||||||
|
|
||||||
|
// List calls process.Process.List.
|
||||||
|
func (c *processClient) List(ctx context.Context, req *connect.Request[spec.ListRequest]) (*connect.Response[spec.ListResponse], error) {
|
||||||
|
return c.list.CallUnary(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect calls process.Process.Connect.
|
||||||
|
func (c *processClient) Connect(ctx context.Context, req *connect.Request[spec.ConnectRequest]) (*connect.ServerStreamForClient[spec.ConnectResponse], error) {
|
||||||
|
return c.connect.CallServerStream(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start calls process.Process.Start.
|
||||||
|
func (c *processClient) Start(ctx context.Context, req *connect.Request[spec.StartRequest]) (*connect.ServerStreamForClient[spec.StartResponse], error) {
|
||||||
|
return c.start.CallServerStream(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update calls process.Process.Update.
|
||||||
|
func (c *processClient) Update(ctx context.Context, req *connect.Request[spec.UpdateRequest]) (*connect.Response[spec.UpdateResponse], error) {
|
||||||
|
return c.update.CallUnary(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// StreamInput calls process.Process.StreamInput.
|
||||||
|
func (c *processClient) StreamInput(ctx context.Context) *connect.ClientStreamForClient[spec.StreamInputRequest, spec.StreamInputResponse] {
|
||||||
|
return c.streamInput.CallClientStream(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendInput calls process.Process.SendInput.
|
||||||
|
func (c *processClient) SendInput(ctx context.Context, req *connect.Request[spec.SendInputRequest]) (*connect.Response[spec.SendInputResponse], error) {
|
||||||
|
return c.sendInput.CallUnary(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendSignal calls process.Process.SendSignal.
|
||||||
|
func (c *processClient) SendSignal(ctx context.Context, req *connect.Request[spec.SendSignalRequest]) (*connect.Response[spec.SendSignalResponse], error) {
|
||||||
|
return c.sendSignal.CallUnary(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseStdin calls process.Process.CloseStdin.
|
||||||
|
func (c *processClient) CloseStdin(ctx context.Context, req *connect.Request[spec.CloseStdinRequest]) (*connect.Response[spec.CloseStdinResponse], error) {
|
||||||
|
return c.closeStdin.CallUnary(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProcessHandler is an implementation of the process.Process service.
|
||||||
|
type ProcessHandler interface {
|
||||||
|
List(context.Context, *connect.Request[spec.ListRequest]) (*connect.Response[spec.ListResponse], error)
|
||||||
|
Connect(context.Context, *connect.Request[spec.ConnectRequest], *connect.ServerStream[spec.ConnectResponse]) error
|
||||||
|
Start(context.Context, *connect.Request[spec.StartRequest], *connect.ServerStream[spec.StartResponse]) error
|
||||||
|
Update(context.Context, *connect.Request[spec.UpdateRequest]) (*connect.Response[spec.UpdateResponse], error)
|
||||||
|
// Client input stream ensures ordering of messages
|
||||||
|
StreamInput(context.Context, *connect.ClientStream[spec.StreamInputRequest]) (*connect.Response[spec.StreamInputResponse], error)
|
||||||
|
SendInput(context.Context, *connect.Request[spec.SendInputRequest]) (*connect.Response[spec.SendInputResponse], error)
|
||||||
|
SendSignal(context.Context, *connect.Request[spec.SendSignalRequest]) (*connect.Response[spec.SendSignalResponse], error)
|
||||||
|
// Close stdin to signal EOF to the process.
|
||||||
|
// Only works for non-PTY processes. For PTY, send Ctrl+D (0x04) instead.
|
||||||
|
CloseStdin(context.Context, *connect.Request[spec.CloseStdinRequest]) (*connect.Response[spec.CloseStdinResponse], error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewProcessHandler builds an HTTP handler from the service implementation. It returns the path on
|
||||||
|
// which to mount the handler and the handler itself.
|
||||||
|
//
|
||||||
|
// By default, handlers support the Connect, gRPC, and gRPC-Web protocols with the binary Protobuf
|
||||||
|
// and JSON codecs. They also support gzip compression.
|
||||||
|
func NewProcessHandler(svc ProcessHandler, opts ...connect.HandlerOption) (string, http.Handler) {
|
||||||
|
processMethods := spec.File_process_proto.Services().ByName("Process").Methods()
|
||||||
|
processListHandler := connect.NewUnaryHandler(
|
||||||
|
ProcessListProcedure,
|
||||||
|
svc.List,
|
||||||
|
connect.WithSchema(processMethods.ByName("List")),
|
||||||
|
connect.WithHandlerOptions(opts...),
|
||||||
|
)
|
||||||
|
processConnectHandler := connect.NewServerStreamHandler(
|
||||||
|
ProcessConnectProcedure,
|
||||||
|
svc.Connect,
|
||||||
|
connect.WithSchema(processMethods.ByName("Connect")),
|
||||||
|
connect.WithHandlerOptions(opts...),
|
||||||
|
)
|
||||||
|
processStartHandler := connect.NewServerStreamHandler(
|
||||||
|
ProcessStartProcedure,
|
||||||
|
svc.Start,
|
||||||
|
connect.WithSchema(processMethods.ByName("Start")),
|
||||||
|
connect.WithHandlerOptions(opts...),
|
||||||
|
)
|
||||||
|
processUpdateHandler := connect.NewUnaryHandler(
|
||||||
|
ProcessUpdateProcedure,
|
||||||
|
svc.Update,
|
||||||
|
connect.WithSchema(processMethods.ByName("Update")),
|
||||||
|
connect.WithHandlerOptions(opts...),
|
||||||
|
)
|
||||||
|
processStreamInputHandler := connect.NewClientStreamHandler(
|
||||||
|
ProcessStreamInputProcedure,
|
||||||
|
svc.StreamInput,
|
||||||
|
connect.WithSchema(processMethods.ByName("StreamInput")),
|
||||||
|
connect.WithHandlerOptions(opts...),
|
||||||
|
)
|
||||||
|
processSendInputHandler := connect.NewUnaryHandler(
|
||||||
|
ProcessSendInputProcedure,
|
||||||
|
svc.SendInput,
|
||||||
|
connect.WithSchema(processMethods.ByName("SendInput")),
|
||||||
|
connect.WithHandlerOptions(opts...),
|
||||||
|
)
|
||||||
|
processSendSignalHandler := connect.NewUnaryHandler(
|
||||||
|
ProcessSendSignalProcedure,
|
||||||
|
svc.SendSignal,
|
||||||
|
connect.WithSchema(processMethods.ByName("SendSignal")),
|
||||||
|
connect.WithHandlerOptions(opts...),
|
||||||
|
)
|
||||||
|
processCloseStdinHandler := connect.NewUnaryHandler(
|
||||||
|
ProcessCloseStdinProcedure,
|
||||||
|
svc.CloseStdin,
|
||||||
|
connect.WithSchema(processMethods.ByName("CloseStdin")),
|
||||||
|
connect.WithHandlerOptions(opts...),
|
||||||
|
)
|
||||||
|
return "/process.Process/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case ProcessListProcedure:
|
||||||
|
processListHandler.ServeHTTP(w, r)
|
||||||
|
case ProcessConnectProcedure:
|
||||||
|
processConnectHandler.ServeHTTP(w, r)
|
||||||
|
case ProcessStartProcedure:
|
||||||
|
processStartHandler.ServeHTTP(w, r)
|
||||||
|
case ProcessUpdateProcedure:
|
||||||
|
processUpdateHandler.ServeHTTP(w, r)
|
||||||
|
case ProcessStreamInputProcedure:
|
||||||
|
processStreamInputHandler.ServeHTTP(w, r)
|
||||||
|
case ProcessSendInputProcedure:
|
||||||
|
processSendInputHandler.ServeHTTP(w, r)
|
||||||
|
case ProcessSendSignalProcedure:
|
||||||
|
processSendSignalHandler.ServeHTTP(w, r)
|
||||||
|
case ProcessCloseStdinProcedure:
|
||||||
|
processCloseStdinHandler.ServeHTTP(w, r)
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnimplementedProcessHandler returns CodeUnimplemented from all methods.
|
||||||
|
type UnimplementedProcessHandler struct{}
|
||||||
|
|
||||||
|
func (UnimplementedProcessHandler) List(context.Context, *connect.Request[spec.ListRequest]) (*connect.Response[spec.ListResponse], error) {
|
||||||
|
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.List is not implemented"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UnimplementedProcessHandler) Connect(context.Context, *connect.Request[spec.ConnectRequest], *connect.ServerStream[spec.ConnectResponse]) error {
|
||||||
|
return connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.Connect is not implemented"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UnimplementedProcessHandler) Start(context.Context, *connect.Request[spec.StartRequest], *connect.ServerStream[spec.StartResponse]) error {
|
||||||
|
return connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.Start is not implemented"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UnimplementedProcessHandler) Update(context.Context, *connect.Request[spec.UpdateRequest]) (*connect.Response[spec.UpdateResponse], error) {
|
||||||
|
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.Update is not implemented"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UnimplementedProcessHandler) StreamInput(context.Context, *connect.ClientStream[spec.StreamInputRequest]) (*connect.Response[spec.StreamInputResponse], error) {
|
||||||
|
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.StreamInput is not implemented"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UnimplementedProcessHandler) SendInput(context.Context, *connect.Request[spec.SendInputRequest]) (*connect.Response[spec.SendInputResponse], error) {
|
||||||
|
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.SendInput is not implemented"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UnimplementedProcessHandler) SendSignal(context.Context, *connect.Request[spec.SendSignalRequest]) (*connect.Response[spec.SendSignalResponse], error) {
|
||||||
|
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.SendSignal is not implemented"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UnimplementedProcessHandler) CloseStdin(context.Context, *connect.Request[spec.CloseStdinRequest]) (*connect.Response[spec.CloseStdinResponse], error) {
|
||||||
|
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.CloseStdin is not implemented"))
|
||||||
|
}
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user