Port envd from e2b with internalized shared packages and Connect RPC

- Copy envd source from e2b-dev/infra, internalize shared dependencies
  into envd/internal/shared/ (keys, filesystem, id, smap, utils)
- Switch from gRPC to Connect RPC for all envd services
- Update module paths to git.omukk.dev/wrenn/{sandbox,sandbox/envd}
- Add proto specs (process, filesystem) with buf-based code generation
- Implement full envd: process exec, filesystem ops, port forwarding,
  cgroup management, MMDS integration, and HTTP API
- Update main module dependencies (firecracker SDK, pgx, goose, etc.)
- Remove placeholder .gitkeep files replaced by real implementations
This commit is contained in:
2026-03-09 21:03:19 +06:00
parent bd78cc068c
commit a3898d68fb
99 changed files with 17185 additions and 24 deletions

View File

@@ -1,17 +1,62 @@
# envd Makefile — builds the static guest daemon into $(BUILDS).
# (Reconstructed: the previous rendering had old/new diff columns merged.)

BUILD := $(shell git rev-parse --short HEAD 2>/dev/null || echo "unknown")
LDFLAGS := -s -w -X=main.commitSHA=$(BUILD)
BUILDS := ../builds

# Remove half-written targets when a recipe fails so they never look up to date.
.DELETE_ON_ERROR:

# ═══════════════════════════════════════════════════
# Build
# ═══════════════════════════════════════════════════
.PHONY: build build-debug
build:
	CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags="$(LDFLAGS)" -o $(BUILDS)/envd .
	@file $(BUILDS)/envd | grep -q "statically linked" || \
		(echo "ERROR: envd is not statically linked!" && exit 1)

build-debug:
	CGO_ENABLED=1 go build -race -gcflags=all="-N -l" -ldflags="-X=main.commitSHA=$(BUILD)" -o $(BUILDS)/debug/envd .

# ═══════════════════════════════════════════════════
# Run (debug mode, not inside a VM)
# ═══════════════════════════════════════════════════
.PHONY: run-debug
run-debug: build-debug
	$(BUILDS)/debug/envd -isnotfc -port 49983

# ═══════════════════════════════════════════════════
# Code Generation
# ═══════════════════════════════════════════════════
.PHONY: generate proto openapi
generate: proto openapi

proto:
	cd spec && buf generate --template buf.gen.yaml

openapi:
	go generate ./internal/api/...

# ═══════════════════════════════════════════════════
# Quality
# ═══════════════════════════════════════════════════
.PHONY: fmt vet test tidy
fmt:
	gofmt -w .

vet:
	go vet ./...

test:
	go test -race -v ./...

tidy:
	go mod tidy

# ═══════════════════════════════════════════════════
# Clean
# ═══════════════════════════════════════════════════
.PHONY: clean
clean:
	rm -f $(BUILDS)/envd $(BUILDS)/debug/envd

View File

@@ -1,9 +1,42 @@
module git.omukk.dev/wrenn/sandbox/envd

go 1.25.5

require (
	connectrpc.com/authn v0.1.0
	connectrpc.com/connect v1.19.1
	connectrpc.com/cors v0.1.0
	github.com/awnumar/memguard v0.23.0
	github.com/creack/pty v1.1.24
	github.com/dchest/uniuri v1.2.0
	github.com/e2b-dev/fsnotify v0.0.1
	github.com/go-chi/chi/v5 v5.2.5
	github.com/google/uuid v1.6.0
	github.com/oapi-codegen/runtime v1.2.0
	github.com/orcaman/concurrent-map/v2 v2.0.1
	github.com/rs/cors v1.11.1
	github.com/rs/zerolog v1.34.0
	github.com/shirou/gopsutil/v4 v4.26.2
	github.com/stretchr/testify v1.11.1
	github.com/txn2/txeh v1.8.0
	golang.org/x/sys v0.42.0
	google.golang.org/protobuf v1.36.11
)

require (
	github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect
	github.com/awnumar/memcall v0.4.0 // indirect
	github.com/davecgh/go-spew v1.1.1 // indirect
	github.com/ebitengine/purego v0.10.0 // indirect
	github.com/go-ole/go-ole v1.2.6 // indirect
	github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
	github.com/mattn/go-colorable v0.1.13 // indirect
	github.com/mattn/go-isatty v0.0.20 // indirect
	github.com/pmezard/go-difflib v1.0.0 // indirect
	github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
	github.com/tklauser/go-sysconf v0.3.16 // indirect
	github.com/tklauser/numcpus v0.11.0 // indirect
	github.com/yusufpapurcu/wmi v1.2.4 // indirect
	golang.org/x/crypto v0.41.0 // indirect
	gopkg.in/yaml.v3 v3.0.1 // indirect
)

92
envd/go.sum Normal file
View File

@@ -0,0 +1,92 @@
connectrpc.com/authn v0.1.0 h1:m5weACjLWwgwcjttvUDyTPICJKw74+p2obBVrf8hT9E=
connectrpc.com/authn v0.1.0/go.mod h1:AwNZK/KYbqaJzRYadTuAaoz6sYQSPdORPqh1TOPIkgY=
connectrpc.com/connect v1.19.1 h1:R5M57z05+90EfEvCY1b7hBxDVOUl45PrtXtAV2fOC14=
connectrpc.com/connect v1.19.1/go.mod h1:tN20fjdGlewnSFeZxLKb0xwIZ6ozc3OQs2hTXy4du9w=
connectrpc.com/cors v0.1.0 h1:f3gTXJyDZPrDIZCQ567jxfD9PAIpopHiRDnJRt3QuOQ=
connectrpc.com/cors v0.1.0/go.mod h1:v8SJZCPfHtGH1zsm+Ttajpozd4cYIUryl4dFB6QEpfg=
github.com/RaveNoX/go-jsoncommentstrip v1.0.0/go.mod h1:78ihd09MekBnJnxpICcwzCMzGrKSKYe4AqU6PDYYpjk=
github.com/apapsch/go-jsonmerge/v2 v2.0.0 h1:axGnT1gRIfimI7gJifB699GoE/oq+F2MU7Dml6nw9rQ=
github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP2+08jFMw88y4klk=
github.com/awnumar/memcall v0.4.0 h1:B7hgZYdfH6Ot1Goaz8jGne/7i8xD4taZie/PNSFZ29g=
github.com/awnumar/memcall v0.4.0/go.mod h1:8xOx1YbfyuCg3Fy6TO8DK0kZUua3V42/goA5Ru47E8w=
github.com/awnumar/memguard v0.23.0 h1:sJ3a1/SWlcuKIQ7MV+R9p0Pvo9CWsMbGZvcZQtmc68A=
github.com/awnumar/memguard v0.23.0/go.mod h1:olVofBrsPdITtJ2HgxQKrEYEMyIBAIciVG4wNnZhW9M=
github.com/bmatcuk/doublestar v1.1.1/go.mod h1:UD6OnuiIn0yFxxA2le/rnRU1G4RaI4UvFv1sNto9p6w=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s=
github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dchest/uniuri v1.2.0 h1:koIcOUdrTIivZgSLhHQvKgqdWZq5d7KdMEWF1Ud6+5g=
github.com/dchest/uniuri v1.2.0/go.mod h1:fSzm4SLHzNZvWLvWJew423PhAzkpNQYq+uNLq4kxhkY=
github.com/e2b-dev/fsnotify v0.0.1 h1:7j0I98HD6VehAuK/bcslvW4QDynAULtOuMZtImihjVk=
github.com/e2b-dev/fsnotify v0.0.1/go.mod h1:jAuDjregRrUixKneTRQwPI847nNuPFg3+n5QM/ku/JM=
github.com/ebitengine/purego v0.10.0 h1:QIw4xfpWT6GWTzaW5XEKy3HXoqrJGx1ijYHzTF0/ISU=
github.com/ebitengine/purego v0.10.0/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug=
github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0=
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/juju/gnuflag v0.0.0-20171113085948-2ce1bb71843d/go.mod h1:2PavIy+JPciBPrBUjwbNvtwB6RQlve+hkpll6QSNmOE=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/oapi-codegen/runtime v1.2.0 h1:RvKc1CVS1QeKSNzO97FBQbSMZyQ8s6rZd+LpmzwHMP4=
github.com/oapi-codegen/runtime v1.2.0/go.mod h1:Y7ZhmmlE8ikZOmuHRRndiIm7nf3xcVv+YMweKgG1DT0=
github.com/orcaman/concurrent-map/v2 v2.0.1 h1:jOJ5Pg2w1oeB6PeDurIYf6k9PQ+aTITr/6lP/L/zp6c=
github.com/orcaman/concurrent-map/v2 v2.0.1/go.mod h1:9Eq3TG2oBe5FirmYWQfYO5iH1q0Jv47PLaNK++uCdOM=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU=
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/rs/cors v1.11.1 h1:eU3gRzXLRK57F5rKMGMZURNdIG4EoAmX8k94r9wXWHA=
github.com/rs/cors v1.11.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU=
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
github.com/shirou/gopsutil/v4 v4.26.2 h1:X8i6sicvUFih4BmYIGT1m2wwgw2VG9YgrDTi7cIRGUI=
github.com/shirou/gopsutil/v4 v4.26.2/go.mod h1:LZ6ewCSkBqUpvSOf+LsTGnRinC6iaNUNMGBtDkJBaLQ=
github.com/spkg/bom v0.0.0-20160624110644-59b7046e48ad/go.mod h1:qLr4V1qq6nMqFKkMo8ZTx3f+BZEkzsRUY10Xsm2mwU0=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/tklauser/go-sysconf v0.3.16 h1:frioLaCQSsF5Cy1jgRBrzr6t502KIIwQ0MArYICU0nA=
github.com/tklauser/go-sysconf v0.3.16/go.mod h1:/qNL9xxDhc7tx3HSRsLWNnuzbVfh3e7gh/BmM179nYI=
github.com/tklauser/numcpus v0.11.0 h1:nSTwhKH5e1dMNsCdVBukSZrURJRoHbSEQjdEbY+9RXw=
github.com/tklauser/numcpus v0.11.0/go.mod h1:z+LwcLq54uWZTX0u/bGobaV34u6V7KNlTZejzM6/3MQ=
github.com/txn2/txeh v1.8.0 h1:G1vZgom6+P/xWwU53AMOpcZgC5ni382ukcPP1TDVYHk=
github.com/txn2/txeh v1.8.0/go.mod h1:rRI3Egi3+AFmEXQjft051YdYbxeCT3nFmBLsNCZZaxM=
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4=
golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
pgregory.net/rapid v1.2.0 h1:keKAYRcjm+e1F0oAuU5F5+YPAWcyxNNRK2wud503Gnk=
pgregory.net/rapid v1.2.0/go.mod h1:PY5XlDGj0+V1FCq0o192FdRhpKHGTRIWBgqjDBTrq04=

View File

@@ -0,0 +1,568 @@
// Package api provides primitives to interact with the openapi HTTP API.
//
// Code generated by github.com/oapi-codegen/oapi-codegen/v2 version v2.5.1 DO NOT EDIT.
package api
import (
"context"
"fmt"
"net/http"
"time"
"github.com/go-chi/chi/v5"
"github.com/oapi-codegen/runtime"
openapi_types "github.com/oapi-codegen/runtime/types"
)
// NOTE(review): everything below is generated by oapi-codegen from the
// OpenAPI spec — change the spec and regenerate rather than editing here.

const (
	// AccessTokenAuthScopes is the context key under which the scope list
	// for the AccessTokenAuth security scheme is stored on each request.
	AccessTokenAuthScopes = "AccessTokenAuth.Scopes"
)

// Defines values for EntryInfoType.
const (
	File EntryInfoType = "file"
)

// EntryInfo defines model for EntryInfo.
type EntryInfo struct {
	// Name Name of the file
	Name string `json:"name"`

	// Path Path to the file
	Path string `json:"path"`

	// Type Type of the file
	Type EntryInfoType `json:"type"`
}

// EntryInfoType Type of the file
type EntryInfoType string

// EnvVars Environment variables to set
type EnvVars map[string]string

// Error defines model for Error.
type Error struct {
	// Code Error code
	Code int `json:"code"`

	// Message Error message
	Message string `json:"message"`
}

// Metrics Resource usage metrics
type Metrics struct {
	// CpuCount Number of CPU cores
	CpuCount *int `json:"cpu_count,omitempty"`

	// CpuUsedPct CPU usage percentage
	CpuUsedPct *float32 `json:"cpu_used_pct,omitempty"`

	// DiskTotal Total disk space in bytes
	DiskTotal *int `json:"disk_total,omitempty"`

	// DiskUsed Used disk space in bytes
	DiskUsed *int `json:"disk_used,omitempty"`

	// MemTotal Total virtual memory in bytes
	MemTotal *int `json:"mem_total,omitempty"`

	// MemUsed Used virtual memory in bytes
	MemUsed *int `json:"mem_used,omitempty"`

	// Ts Unix timestamp in UTC for current sandbox time
	Ts *int64 `json:"ts,omitempty"`
}

// VolumeMount Volume
type VolumeMount struct {
	NfsTarget string `json:"nfs_target"`
	Path      string `json:"path"`
}

// FilePath defines model for FilePath.
type FilePath = string

// Signature defines model for Signature.
type Signature = string

// SignatureExpiration defines model for SignatureExpiration.
type SignatureExpiration = int

// User defines model for User.
type User = string

// FileNotFound defines model for FileNotFound.
type FileNotFound = Error

// InternalServerError defines model for InternalServerError.
type InternalServerError = Error

// InvalidPath defines model for InvalidPath.
type InvalidPath = Error

// InvalidUser defines model for InvalidUser.
type InvalidUser = Error

// NotEnoughDiskSpace defines model for NotEnoughDiskSpace.
type NotEnoughDiskSpace = Error

// UploadSuccess defines model for UploadSuccess.
type UploadSuccess = []EntryInfo

// GetFilesParams defines parameters for GetFiles.
type GetFilesParams struct {
	// Path Path to the file, URL encoded. Can be relative to user's home directory.
	Path *FilePath `form:"path,omitempty" json:"path,omitempty"`

	// Username User used for setting the owner, or resolving relative paths.
	Username *User `form:"username,omitempty" json:"username,omitempty"`

	// Signature Signature used for file access permission verification.
	Signature *Signature `form:"signature,omitempty" json:"signature,omitempty"`

	// SignatureExpiration Signature expiration used for defining the expiration time of the signature.
	SignatureExpiration *SignatureExpiration `form:"signature_expiration,omitempty" json:"signature_expiration,omitempty"`
}

// PostFilesMultipartBody defines parameters for PostFiles.
type PostFilesMultipartBody struct {
	File *openapi_types.File `json:"file,omitempty"`
}

// PostFilesParams defines parameters for PostFiles.
type PostFilesParams struct {
	// Path Path to the file, URL encoded. Can be relative to user's home directory.
	Path *FilePath `form:"path,omitempty" json:"path,omitempty"`

	// Username User used for setting the owner, or resolving relative paths.
	Username *User `form:"username,omitempty" json:"username,omitempty"`

	// Signature Signature used for file access permission verification.
	Signature *Signature `form:"signature,omitempty" json:"signature,omitempty"`

	// SignatureExpiration Signature expiration used for defining the expiration time of the signature.
	SignatureExpiration *SignatureExpiration `form:"signature_expiration,omitempty" json:"signature_expiration,omitempty"`
}

// PostInitJSONBody defines parameters for PostInit.
type PostInitJSONBody struct {
	// AccessToken Access token for secure access to envd service
	// NOTE(review): SecureToken is declared elsewhere in this package — confirm.
	AccessToken *SecureToken `json:"accessToken,omitempty"`

	// DefaultUser The default user to use for operations
	DefaultUser *string `json:"defaultUser,omitempty"`

	// DefaultWorkdir The default working directory to use for operations
	DefaultWorkdir *string `json:"defaultWorkdir,omitempty"`

	// EnvVars Environment variables to set
	EnvVars *EnvVars `json:"envVars,omitempty"`

	// HyperloopIP IP address of the hyperloop server to connect to
	HyperloopIP *string `json:"hyperloopIP,omitempty"`

	// Timestamp The current timestamp in RFC3339 format
	Timestamp *time.Time `json:"timestamp,omitempty"`

	VolumeMounts *[]VolumeMount `json:"volumeMounts,omitempty"`
}

// PostFilesMultipartRequestBody defines body for PostFiles for multipart/form-data ContentType.
type PostFilesMultipartRequestBody PostFilesMultipartBody

// PostInitJSONRequestBody defines body for PostInit for application/json ContentType.
type PostInitJSONRequestBody PostInitJSONBody
// ServerInterface represents all server handlers.
// Generated from the OpenAPI spec — implementations must satisfy every method.
type ServerInterface interface {
	// Get the environment variables
	// (GET /envs)
	GetEnvs(w http.ResponseWriter, r *http.Request)
	// Download a file
	// (GET /files)
	GetFiles(w http.ResponseWriter, r *http.Request, params GetFilesParams)
	// Upload a file and ensure the parent directories exist. If the file exists, it will be overwritten.
	// (POST /files)
	PostFiles(w http.ResponseWriter, r *http.Request, params PostFilesParams)
	// Check the health of the service
	// (GET /health)
	GetHealth(w http.ResponseWriter, r *http.Request)
	// Set initial vars, ensure the time and metadata is synced with the host
	// (POST /init)
	PostInit(w http.ResponseWriter, r *http.Request)
	// Get the stats of the service
	// (GET /metrics)
	GetMetrics(w http.ResponseWriter, r *http.Request)
}

// Unimplemented server implementation that returns http.StatusNotImplemented for each endpoint.
// Embed it to get default 501 responses for any handlers you have not written yet.
type Unimplemented struct{}

// Get the environment variables
// (GET /envs)
func (_ Unimplemented) GetEnvs(w http.ResponseWriter, r *http.Request) {
	w.WriteHeader(http.StatusNotImplemented)
}

// Download a file
// (GET /files)
func (_ Unimplemented) GetFiles(w http.ResponseWriter, r *http.Request, params GetFilesParams) {
	w.WriteHeader(http.StatusNotImplemented)
}

// Upload a file and ensure the parent directories exist. If the file exists, it will be overwritten.
// (POST /files)
func (_ Unimplemented) PostFiles(w http.ResponseWriter, r *http.Request, params PostFilesParams) {
	w.WriteHeader(http.StatusNotImplemented)
}

// Check the health of the service
// (GET /health)
func (_ Unimplemented) GetHealth(w http.ResponseWriter, r *http.Request) {
	w.WriteHeader(http.StatusNotImplemented)
}

// Set initial vars, ensure the time and metadata is synced with the host
// (POST /init)
func (_ Unimplemented) PostInit(w http.ResponseWriter, r *http.Request) {
	w.WriteHeader(http.StatusNotImplemented)
}

// Get the stats of the service
// (GET /metrics)
func (_ Unimplemented) GetMetrics(w http.ResponseWriter, r *http.Request) {
	w.WriteHeader(http.StatusNotImplemented)
}

// ServerInterfaceWrapper converts contexts to parameters.
type ServerInterfaceWrapper struct {
	Handler            ServerInterface
	HandlerMiddlewares []MiddlewareFunc
	// ErrorHandlerFunc is invoked for query-parameter binding failures.
	ErrorHandlerFunc func(w http.ResponseWriter, r *http.Request, err error)
}

// MiddlewareFunc is the signature of per-route middleware applied by the wrapper.
type MiddlewareFunc func(http.Handler) http.Handler
// GetEnvs operation middleware
func (siw *ServerInterfaceWrapper) GetEnvs(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	// NOTE(review): a plain string constant is used as a context key
	// (staticcheck SA1029); generated code — fix in the generator if needed.
	ctx = context.WithValue(ctx, AccessTokenAuthScopes, []string{})

	r = r.WithContext(ctx)

	handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		siw.Handler.GetEnvs(w, r)
	}))

	// Middlewares are applied in reverse registration order (last wraps first).
	for _, middleware := range siw.HandlerMiddlewares {
		handler = middleware(handler)
	}

	handler.ServeHTTP(w, r)
}

// GetFiles operation middleware
func (siw *ServerInterfaceWrapper) GetFiles(w http.ResponseWriter, r *http.Request) {
	var err error

	ctx := r.Context()

	ctx = context.WithValue(ctx, AccessTokenAuthScopes, []string{})

	r = r.WithContext(ctx)

	// Parameter object where we will unmarshal all parameters from the context
	var params GetFilesParams

	// ------------- Optional query parameter "path" -------------

	err = runtime.BindQueryParameter("form", true, false, "path", r.URL.Query(), &params.Path)
	if err != nil {
		siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "path", Err: err})
		return
	}

	// ------------- Optional query parameter "username" -------------

	err = runtime.BindQueryParameter("form", true, false, "username", r.URL.Query(), &params.Username)
	if err != nil {
		siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "username", Err: err})
		return
	}

	// ------------- Optional query parameter "signature" -------------

	err = runtime.BindQueryParameter("form", true, false, "signature", r.URL.Query(), &params.Signature)
	if err != nil {
		siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "signature", Err: err})
		return
	}

	// ------------- Optional query parameter "signature_expiration" -------------

	err = runtime.BindQueryParameter("form", true, false, "signature_expiration", r.URL.Query(), &params.SignatureExpiration)
	if err != nil {
		siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "signature_expiration", Err: err})
		return
	}

	handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		siw.Handler.GetFiles(w, r, params)
	}))

	for _, middleware := range siw.HandlerMiddlewares {
		handler = middleware(handler)
	}

	handler.ServeHTTP(w, r)
}

// PostFiles operation middleware
func (siw *ServerInterfaceWrapper) PostFiles(w http.ResponseWriter, r *http.Request) {
	var err error

	ctx := r.Context()

	ctx = context.WithValue(ctx, AccessTokenAuthScopes, []string{})

	r = r.WithContext(ctx)

	// Parameter object where we will unmarshal all parameters from the context
	var params PostFilesParams

	// ------------- Optional query parameter "path" -------------

	err = runtime.BindQueryParameter("form", true, false, "path", r.URL.Query(), &params.Path)
	if err != nil {
		siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "path", Err: err})
		return
	}

	// ------------- Optional query parameter "username" -------------

	err = runtime.BindQueryParameter("form", true, false, "username", r.URL.Query(), &params.Username)
	if err != nil {
		siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "username", Err: err})
		return
	}

	// ------------- Optional query parameter "signature" -------------

	err = runtime.BindQueryParameter("form", true, false, "signature", r.URL.Query(), &params.Signature)
	if err != nil {
		siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "signature", Err: err})
		return
	}

	// ------------- Optional query parameter "signature_expiration" -------------

	err = runtime.BindQueryParameter("form", true, false, "signature_expiration", r.URL.Query(), &params.SignatureExpiration)
	if err != nil {
		siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "signature_expiration", Err: err})
		return
	}

	handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		siw.Handler.PostFiles(w, r, params)
	}))

	for _, middleware := range siw.HandlerMiddlewares {
		handler = middleware(handler)
	}

	handler.ServeHTTP(w, r)
}

// GetHealth operation middleware
// No auth-scope context value is attached here: the spec marks /health as
// requiring no security scheme.
func (siw *ServerInterfaceWrapper) GetHealth(w http.ResponseWriter, r *http.Request) {
	handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		siw.Handler.GetHealth(w, r)
	}))

	for _, middleware := range siw.HandlerMiddlewares {
		handler = middleware(handler)
	}

	handler.ServeHTTP(w, r)
}

// PostInit operation middleware
func (siw *ServerInterfaceWrapper) PostInit(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	ctx = context.WithValue(ctx, AccessTokenAuthScopes, []string{})

	r = r.WithContext(ctx)

	handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		siw.Handler.PostInit(w, r)
	}))

	for _, middleware := range siw.HandlerMiddlewares {
		handler = middleware(handler)
	}

	handler.ServeHTTP(w, r)
}

// GetMetrics operation middleware
func (siw *ServerInterfaceWrapper) GetMetrics(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	ctx = context.WithValue(ctx, AccessTokenAuthScopes, []string{})

	r = r.WithContext(ctx)

	handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		siw.Handler.GetMetrics(w, r)
	}))

	for _, middleware := range siw.HandlerMiddlewares {
		handler = middleware(handler)
	}

	handler.ServeHTTP(w, r)
}
// Error types below are passed to ServerInterfaceWrapper.ErrorHandlerFunc
// when request-parameter binding fails; each wraps the underlying cause.

// UnescapedCookieParamError reports a cookie parameter that failed URL-unescaping.
type UnescapedCookieParamError struct {
	ParamName string
	Err       error
}

func (e *UnescapedCookieParamError) Error() string {
	return fmt.Sprintf("error unescaping cookie parameter '%s'", e.ParamName)
}

func (e *UnescapedCookieParamError) Unwrap() error {
	return e.Err
}

// UnmarshalingParamError reports a parameter whose JSON payload could not be decoded.
type UnmarshalingParamError struct {
	ParamName string
	Err       error
}

func (e *UnmarshalingParamError) Error() string {
	return fmt.Sprintf("Error unmarshaling parameter %s as JSON: %s", e.ParamName, e.Err.Error())
}

func (e *UnmarshalingParamError) Unwrap() error {
	return e.Err
}

// RequiredParamError reports a missing required query parameter.
type RequiredParamError struct {
	ParamName string
}

func (e *RequiredParamError) Error() string {
	return fmt.Sprintf("Query argument %s is required, but not found", e.ParamName)
}

// RequiredHeaderError reports a missing required header parameter.
type RequiredHeaderError struct {
	ParamName string
	Err       error
}

func (e *RequiredHeaderError) Error() string {
	return fmt.Sprintf("Header parameter %s is required, but not found", e.ParamName)
}

func (e *RequiredHeaderError) Unwrap() error {
	return e.Err
}

// InvalidParamFormatError reports a parameter present but in the wrong format.
type InvalidParamFormatError struct {
	ParamName string
	Err       error
}

func (e *InvalidParamFormatError) Error() string {
	return fmt.Sprintf("Invalid format for parameter %s: %s", e.ParamName, e.Err.Error())
}

func (e *InvalidParamFormatError) Unwrap() error {
	return e.Err
}

// TooManyValuesForParamError reports a single-valued parameter supplied more than once.
type TooManyValuesForParamError struct {
	ParamName string
	Count     int
}

func (e *TooManyValuesForParamError) Error() string {
	return fmt.Sprintf("Expected one value for %s, got %d", e.ParamName, e.Count)
}
// Handler creates http.Handler with routing matching OpenAPI spec.
func Handler(si ServerInterface) http.Handler {
	return HandlerWithOptions(si, ChiServerOptions{})
}

// ChiServerOptions configures Handler construction: an optional URL prefix,
// an existing chi router to mount on, per-route middlewares, and a custom
// binding-error handler.
type ChiServerOptions struct {
	BaseURL          string
	BaseRouter       chi.Router
	Middlewares      []MiddlewareFunc
	ErrorHandlerFunc func(w http.ResponseWriter, r *http.Request, err error)
}

// HandlerFromMux creates http.Handler with routing matching OpenAPI spec based on the provided mux.
func HandlerFromMux(si ServerInterface, r chi.Router) http.Handler {
	return HandlerWithOptions(si, ChiServerOptions{
		BaseRouter: r,
	})
}

func HandlerFromMuxWithBaseURL(si ServerInterface, r chi.Router, baseURL string) http.Handler {
	return HandlerWithOptions(si, ChiServerOptions{
		BaseURL:    baseURL,
		BaseRouter: r,
	})
}

// HandlerWithOptions creates http.Handler with additional options
func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handler {
	r := options.BaseRouter

	if r == nil {
		r = chi.NewRouter()
	}
	// Default binding-error handler: respond 400 with the error text.
	if options.ErrorHandlerFunc == nil {
		options.ErrorHandlerFunc = func(w http.ResponseWriter, r *http.Request, err error) {
			http.Error(w, err.Error(), http.StatusBadRequest)
		}
	}

	wrapper := ServerInterfaceWrapper{
		Handler:            si,
		HandlerMiddlewares: options.Middlewares,
		ErrorHandlerFunc:   options.ErrorHandlerFunc,
	}

	// Each route gets its own chi group so middleware can be scoped per route.
	r.Group(func(r chi.Router) {
		r.Get(options.BaseURL+"/envs", wrapper.GetEnvs)
	})
	r.Group(func(r chi.Router) {
		r.Get(options.BaseURL+"/files", wrapper.GetFiles)
	})
	r.Group(func(r chi.Router) {
		r.Post(options.BaseURL+"/files", wrapper.PostFiles)
	})
	r.Group(func(r chi.Router) {
		r.Get(options.BaseURL+"/health", wrapper.GetHealth)
	})
	r.Group(func(r chi.Router) {
		r.Post(options.BaseURL+"/init", wrapper.PostInit)
	})
	r.Group(func(r chi.Router) {
		r.Get(options.BaseURL+"/metrics", wrapper.GetMetrics)
	})

	return r
}

129
envd/internal/api/auth.go Normal file
View File

@@ -0,0 +1,129 @@
package api
import (
"errors"
"fmt"
"net/http"
"slices"
"strconv"
"strings"
"time"
"github.com/awnumar/memguard"
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/keys"
)
const (
	// Operation names folded into signature material (see generateSignature).
	SigningReadOperation  = "read"
	SigningWriteOperation = "write"

	// Header carrying the envd access token on authenticated requests.
	accessTokenHeader = "X-Access-Token"
)

// paths that are always allowed without general authentication
// POST/init is secured via MMDS hash validation instead
// GET/files and POST/files are presumably covered by per-request signature
// validation (validateSigning) rather than the blanket token check — confirm.
var authExcludedPaths = []string{
	"GET/health",
	"GET/files",
	"POST/files",
	"POST/init",
}
// WithAuthorization wraps handler with access-token enforcement.
// If no access token has been configured the middleware is a pass-through.
// Requests whose "METHOD/path" key appears in authExcludedPaths skip the
// header check; everything else must present the token in X-Access-Token
// or receives 401 with a JSON error body.
func (a *API) WithAuthorization(handler http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		if a.accessToken.IsSet() {
			authHeader := req.Header.Get(accessTokenHeader)

			// check if this path is allowed without authentication (e.g., health check, endpoints supporting signing)
			allowedPath := slices.Contains(authExcludedPaths, req.Method+req.URL.Path)
			if !a.accessToken.Equals(authHeader) && !allowedPath {
				a.logger.Error().Msg("Trying to access secured envd without correct access token")
				err := fmt.Errorf("unauthorized access, please provide a valid access token or method signing if supported")
				jsonError(w, http.StatusUnauthorized, err)

				return
			}
		}

		handler.ServeHTTP(w, req)
	})
}
// generateSignature derives the v1 signing key for a file operation.
// The material is "path:operation:username:token" with the decimal
// expiration appended as a fifth component when one is supplied; the
// result is "v1_" + SHA-256 hash of that string. The plaintext token
// bytes are wiped from memory before returning. Returns an error when
// no access token has been set.
func (a *API) generateSignature(path string, username string, operation string, signatureExpiration *int64) (string, error) {
	tokenBytes, err := a.accessToken.Bytes()
	if err != nil {
		return "", fmt.Errorf("access token is not set: %w", err)
	}
	defer memguard.WipeBytes(tokenBytes)

	// Assemble the colon-joined signature material.
	parts := []string{path, operation, username, string(tokenBytes)}
	if signatureExpiration != nil {
		parts = append(parts, strconv.FormatInt(*signatureExpiration, 10))
	}

	digest := keys.NewSHA256Hashing().HashWithoutPrefix([]byte(strings.Join(parts, ":")))

	return fmt.Sprintf("v1_%s", digest), nil
}
// validateSigning authorizes a file operation either by access token or by
// a per-request signature. Order of checks:
//  1. no access token configured -> allow;
//  2. token present in X-Access-Token header -> must match exactly;
//  3. otherwise a signature query parameter is required, recomputed with
//     generateSignature and compared, then checked for expiration.
// Returns nil when the request is authorized.
func (a *API) validateSigning(r *http.Request, signature *string, signatureExpiration *int, username *string, path string, operation string) (err error) {
	var expectedSignature string

	// no need to validate signing key if access token is not set
	if !a.accessToken.IsSet() {
		return nil
	}

	// check if access token is sent in the header
	tokenFromHeader := r.Header.Get(accessTokenHeader)
	if tokenFromHeader != "" {
		if !a.accessToken.Equals(tokenFromHeader) {
			return fmt.Errorf("access token present in header but does not match")
		}

		return nil
	}

	if signature == nil {
		return fmt.Errorf("missing signature query parameter")
	}

	// Empty string is used when no username is provided and the default user should be used
	signatureUsername := ""
	if username != nil {
		signatureUsername = *username
	}

	if signatureExpiration == nil {
		expectedSignature, err = a.generateSignature(path, signatureUsername, operation, nil)
	} else {
		exp := int64(*signatureExpiration)
		expectedSignature, err = a.generateSignature(path, signatureUsername, operation, &exp)
	}
	if err != nil {
		a.logger.Error().Err(err).Msg("error generating signing key")
		// Deliberately vague to the caller; details are only logged.
		return errors.New("invalid signature")
	}

	// signature validation
	// NOTE(review): plain string comparison, not constant-time — consider
	// crypto/subtle.ConstantTimeCompare if timing leaks are a concern here.
	if expectedSignature != *signature {
		return fmt.Errorf("invalid signature")
	}

	// signature expiration
	if signatureExpiration != nil {
		exp := int64(*signatureExpiration)
		if exp < time.Now().Unix() {
			return fmt.Errorf("signature is already expired")
		}
	}

	return nil
}

View File

@@ -0,0 +1,62 @@
package api
import (
"fmt"
"strconv"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/keys"
)
// TestKeyGenerationAlgorithmIsStable pins the signature format produced when
// an expiration is supplied: v1_<sha256(path:operation:username:token:exp)>.
func TestKeyGenerationAlgorithmIsStable(t *testing.T) {
	t.Parallel()
	const (
		token     = "secret-access-token"
		path      = "/path/to/demo.txt"
		username  = "root"
		operation = "write"
	)
	tok := &SecureToken{}
	require.NoError(t, tok.Set([]byte(token)))
	api := &API{accessToken: tok}
	expiration := time.Now().Unix()
	got, err := api.generateSignature(path, username, operation, &expiration)
	require.NoError(t, err)
	assert.NotEmpty(t, got)
	// Recompute the signature by hand and compare against the API's result.
	payload := fmt.Sprintf("%s:%s:%s:%s:%s", path, operation, username, token, strconv.FormatInt(expiration, 10))
	want := fmt.Sprintf("v1_%s", keys.NewSHA256Hashing().HashWithoutPrefix([]byte(payload)))
	assert.Equal(t, want, got)
}
// TestKeyGenerationAlgorithmWithoutExpirationIsStable pins the signature
// format when no expiration is supplied: v1_<sha256(path:operation:username:token)>.
func TestKeyGenerationAlgorithmWithoutExpirationIsStable(t *testing.T) {
	t.Parallel()
	const (
		token     = "secret-access-token"
		path      = "/path/to/resource.txt"
		username  = "user"
		operation = "read"
	)
	tok := &SecureToken{}
	require.NoError(t, tok.Set([]byte(token)))
	api := &API{accessToken: tok}
	got, err := api.generateSignature(path, username, operation, nil)
	require.NoError(t, err)
	assert.NotEmpty(t, got)
	// Recompute the signature by hand — no expiration component this time.
	payload := fmt.Sprintf("%s:%s:%s:%s", path, operation, username, token)
	want := fmt.Sprintf("v1_%s", keys.NewSHA256Hashing().HashWithoutPrefix([]byte(payload)))
	assert.Equal(t, want, got)
}

View File

@ -0,0 +1,8 @@
# yaml-language-server: $schema=https://raw.githubusercontent.com/deepmap/oapi-codegen/HEAD/configuration-schema.json
# oapi-codegen configuration: generate Go models and a chi-router server
# (no client stubs) into api.gen.go, inside package "api".
package: api
output: api.gen.go
generate:
  models: true
  chi-server: true
  client: false

View File

@ -0,0 +1,173 @@
package api
import (
"compress/gzip"
"errors"
"fmt"
"io"
"mime"
"net/http"
"os"
"os/user"
"path/filepath"
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
)
// GetFiles streams the file at params.Path back to the client, optionally
// gzip-compressed when the Accept-Encoding header asks for it.
//
// Request flow:
//  1. Validate the signing query parameters / access-token header.
//  2. Resolve the effective username and look it up on the system.
//  3. Expand the path relative to that user and stat it (must be a regular
//     file, not a directory).
//  4. Negotiate the response encoding from Accept-Encoding; Range and
//     conditional requests force identity (http.ServeContent handles them).
//  5. Serve either through a gzip writer or via http.ServeContent.
func (a *API) GetFiles(w http.ResponseWriter, r *http.Request, params GetFilesParams) {
	defer r.Body.Close()
	var errorCode int
	var errMsg error
	var path string
	if params.Path != nil {
		path = *params.Path
	}
	// Operation ID correlates every log line emitted for this request.
	operationID := logs.AssignOperationID()
	// signing authorization if needed
	err := a.validateSigning(r, params.Signature, params.SignatureExpiration, params.Username, path, SigningReadOperation)
	if err != nil {
		a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("error during auth validation")
		jsonError(w, http.StatusUnauthorized, err)
		return
	}
	username, err := execcontext.ResolveDefaultUsername(params.Username, a.defaults.User)
	if err != nil {
		a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("no user specified")
		jsonError(w, http.StatusBadRequest, err)
		return
	}
	// Emit a single outcome log line (success or failure) when the handler returns.
	defer func() {
		l := a.logger.
			Err(errMsg).
			Str("method", r.Method+" "+r.URL.Path).
			Str(string(logs.OperationIDKey), operationID).
			Str("path", path).
			Str("username", username)
		if errMsg != nil {
			l = l.Int("error_code", errorCode)
		}
		l.Msg("File read")
	}()
	u, err := user.Lookup(username)
	if err != nil {
		errMsg = fmt.Errorf("error looking up user '%s': %w", username, err)
		errorCode = http.StatusUnauthorized
		jsonError(w, errorCode, errMsg)
		return
	}
	resolvedPath, err := permissions.ExpandAndResolve(path, u, a.defaults.Workdir)
	if err != nil {
		errMsg = fmt.Errorf("error expanding and resolving path '%s': %w", path, err)
		errorCode = http.StatusBadRequest
		jsonError(w, errorCode, errMsg)
		return
	}
	stat, err := os.Stat(resolvedPath)
	if err != nil {
		if errors.Is(err, os.ErrNotExist) {
			errMsg = fmt.Errorf("path '%s' does not exist", resolvedPath)
			errorCode = http.StatusNotFound
			jsonError(w, errorCode, errMsg)
			return
		}
		errMsg = fmt.Errorf("error checking if path exists '%s': %w", resolvedPath, err)
		errorCode = http.StatusInternalServerError
		jsonError(w, errorCode, errMsg)
		return
	}
	if stat.IsDir() {
		errMsg = fmt.Errorf("path '%s' is a directory", resolvedPath)
		errorCode = http.StatusBadRequest
		jsonError(w, errorCode, errMsg)
		return
	}
	// Validate Accept-Encoding header
	encoding, err := parseAcceptEncoding(r)
	if err != nil {
		errMsg = fmt.Errorf("error parsing Accept-Encoding: %w", err)
		errorCode = http.StatusNotAcceptable
		jsonError(w, errorCode, errMsg)
		return
	}
	// Tell caches to store separate variants for different Accept-Encoding values
	w.Header().Set("Vary", "Accept-Encoding")
	// Fall back to identity for Range or conditional requests to preserve http.ServeContent
	// behavior (206 Partial Content, 304 Not Modified). However, we must check if identity
	// is acceptable per the Accept-Encoding header.
	hasRangeOrConditional := r.Header.Get("Range") != "" ||
		r.Header.Get("If-Modified-Since") != "" ||
		r.Header.Get("If-None-Match") != "" ||
		r.Header.Get("If-Range") != ""
	if hasRangeOrConditional {
		if !isIdentityAcceptable(r) {
			errMsg = fmt.Errorf("identity encoding not acceptable for Range or conditional request")
			errorCode = http.StatusNotAcceptable
			jsonError(w, errorCode, errMsg)
			return
		}
		encoding = EncodingIdentity
	}
	file, err := os.Open(resolvedPath)
	if err != nil {
		errMsg = fmt.Errorf("error opening file '%s': %w", resolvedPath, err)
		errorCode = http.StatusInternalServerError
		jsonError(w, errorCode, errMsg)
		return
	}
	defer file.Close()
	// Only the base filename is exposed, never the full resolved path.
	w.Header().Set("Content-Disposition", mime.FormatMediaType("inline", map[string]string{"filename": filepath.Base(resolvedPath)}))
	// Serve with gzip encoding if requested.
	if encoding == EncodingGzip {
		w.Header().Set("Content-Encoding", EncodingGzip)
		// Set Content-Type based on file extension, preserving the original type.
		// Note: the extension is taken from the request `path`, not `resolvedPath`.
		contentType := mime.TypeByExtension(filepath.Ext(path))
		if contentType == "" {
			contentType = "application/octet-stream"
		}
		w.Header().Set("Content-Type", contentType)
		gw := gzip.NewWriter(w)
		defer gw.Close()
		_, err = io.Copy(gw, file)
		if err != nil {
			// Headers are already sent at this point; the error can only be logged.
			a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("error writing gzip response")
		}
		return
	}
	// Identity path: ServeContent handles Range/conditional semantics and
	// content-type sniffing from the provided name.
	http.ServeContent(w, r, path, stat.ModTime(), file)
}

View File

@ -0,0 +1,401 @@
package api
import (
"bytes"
"compress/gzip"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"net/url"
"os"
"os/user"
"path/filepath"
"testing"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
)
// TestGetFilesContentDisposition checks that GetFiles emits a
// Content-Disposition header whose filename parameter is correctly quoted or
// RFC 5987 percent-encoded for a variety of tricky filenames (spaces, quotes,
// backslashes, Unicode, dotfiles).
func TestGetFilesContentDisposition(t *testing.T) {
	t.Parallel()
	currentUser, err := user.Current()
	require.NoError(t, err)
	tests := []struct {
		name           string
		filename       string
		expectedHeader string
	}{
		{
			name:           "simple filename",
			filename:       "test.txt",
			expectedHeader: `inline; filename=test.txt`,
		},
		{
			name:           "filename with extension",
			filename:       "presentation.pptx",
			expectedHeader: `inline; filename=presentation.pptx`,
		},
		{
			name:           "filename with multiple dots",
			filename:       "archive.tar.gz",
			expectedHeader: `inline; filename=archive.tar.gz`,
		},
		{
			name:           "filename with spaces",
			filename:       "my document.pdf",
			expectedHeader: `inline; filename="my document.pdf"`,
		},
		{
			name:           "filename with quotes",
			filename:       `file"name.txt`,
			expectedHeader: `inline; filename="file\"name.txt"`,
		},
		{
			name:           "filename with backslash",
			filename:       `file\name.txt`,
			expectedHeader: `inline; filename="file\\name.txt"`,
		},
		{
			name:           "unicode filename",
			filename:       "\u6587\u6863.pdf", // 文档.pdf in Chinese
			expectedHeader: "inline; filename*=utf-8''%E6%96%87%E6%A1%A3.pdf",
		},
		{
			name:           "dotfile preserved",
			filename:       ".env",
			expectedHeader: `inline; filename=.env`,
		},
		{
			name:           "dotfile with extension preserved",
			filename:       ".gitignore",
			expectedHeader: `inline; filename=.gitignore`,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			// Create a temp directory and file
			tempDir := t.TempDir()
			tempFile := filepath.Join(tempDir, tt.filename)
			err := os.WriteFile(tempFile, []byte("test content"), 0o644)
			require.NoError(t, err)
			// Create test API
			logger := zerolog.Nop()
			defaults := &execcontext.Defaults{
				EnvVars: utils.NewMap[string, string](),
				User:    currentUser.Username,
			}
			api := New(&logger, defaults, nil, false)
			// Create request and response recorder
			req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil)
			w := httptest.NewRecorder()
			// Call the handler
			params := GetFilesParams{
				Path:     &tempFile,
				Username: &currentUser.Username,
			}
			api.GetFiles(w, req, params)
			// Check response
			resp := w.Result()
			defer resp.Body.Close()
			assert.Equal(t, http.StatusOK, resp.StatusCode)
			// Verify Content-Disposition header
			contentDisposition := resp.Header.Get("Content-Disposition")
			assert.Equal(t, tt.expectedHeader, contentDisposition, "Content-Disposition header should be set with correct filename")
		})
	}
}
// TestGetFilesContentDispositionWithNestedPath verifies that for a file in a
// nested directory, Content-Disposition carries only the base filename and
// never leaks the directory components of the path.
func TestGetFilesContentDispositionWithNestedPath(t *testing.T) {
	t.Parallel()
	currentUser, err := user.Current()
	require.NoError(t, err)
	// Create a temp directory with nested structure
	tempDir := t.TempDir()
	nestedDir := filepath.Join(tempDir, "subdir", "another")
	err = os.MkdirAll(nestedDir, 0o755)
	require.NoError(t, err)
	filename := "document.pdf"
	tempFile := filepath.Join(nestedDir, filename)
	err = os.WriteFile(tempFile, []byte("test content"), 0o644)
	require.NoError(t, err)
	// Create test API
	logger := zerolog.Nop()
	defaults := &execcontext.Defaults{
		EnvVars: utils.NewMap[string, string](),
		User:    currentUser.Username,
	}
	api := New(&logger, defaults, nil, false)
	// Create request and response recorder
	req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil)
	w := httptest.NewRecorder()
	// Call the handler
	params := GetFilesParams{
		Path:     &tempFile,
		Username: &currentUser.Username,
	}
	api.GetFiles(w, req, params)
	// Check response
	resp := w.Result()
	defer resp.Body.Close()
	assert.Equal(t, http.StatusOK, resp.StatusCode)
	// Verify Content-Disposition header uses only the base filename, not the full path
	contentDisposition := resp.Header.Get("Content-Disposition")
	assert.Equal(t, `inline; filename=document.pdf`, contentDisposition, "Content-Disposition should contain only the filename, not the path")
}
// TestGetFiles_GzipEncoding_ExplicitIdentityOffWithRange verifies that a Range
// request combined with an Accept-Encoding header that rejects identity
// ("gzip; q=1,*; q=0") is refused with 406, because Range requests are only
// served with identity encoding.
func TestGetFiles_GzipEncoding_ExplicitIdentityOffWithRange(t *testing.T) {
	t.Parallel()
	currentUser, err := user.Current()
	require.NoError(t, err)
	// Create a temp directory with a test file
	tempDir := t.TempDir()
	filename := "document.pdf"
	tempFile := filepath.Join(tempDir, filename)
	err = os.WriteFile(tempFile, []byte("test content"), 0o644)
	require.NoError(t, err)
	// Create test API
	logger := zerolog.Nop()
	defaults := &execcontext.Defaults{
		EnvVars: utils.NewMap[string, string](),
		User:    currentUser.Username,
	}
	api := New(&logger, defaults, nil, false)
	// Create request and response recorder
	req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil)
	req.Header.Set("Accept-Encoding", "gzip; q=1,*; q=0")
	req.Header.Set("Range", "bytes=0-4") // Request first 5 bytes
	w := httptest.NewRecorder()
	// Call the handler
	params := GetFilesParams{
		Path:     &tempFile,
		Username: &currentUser.Username,
	}
	api.GetFiles(w, req, params)
	// Check response
	resp := w.Result()
	defer resp.Body.Close()
	assert.Equal(t, http.StatusNotAcceptable, resp.StatusCode)
}
// TestGetFiles_GzipDownload verifies the happy-path gzip download: the
// response is gzip-encoded, keeps the extension-derived Content-Type, and
// decompresses back to the original file content.
func TestGetFiles_GzipDownload(t *testing.T) {
	t.Parallel()
	currentUser, err := user.Current()
	require.NoError(t, err)
	originalContent := []byte("hello world, this is a test file for gzip compression")
	// Create a temp file with known content
	tempDir := t.TempDir()
	tempFile := filepath.Join(tempDir, "test.txt")
	err = os.WriteFile(tempFile, originalContent, 0o644)
	require.NoError(t, err)
	logger := zerolog.Nop()
	defaults := &execcontext.Defaults{
		EnvVars: utils.NewMap[string, string](),
		User:    currentUser.Username,
	}
	api := New(&logger, defaults, nil, false)
	req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil)
	req.Header.Set("Accept-Encoding", "gzip")
	w := httptest.NewRecorder()
	params := GetFilesParams{
		Path:     &tempFile,
		Username: &currentUser.Username,
	}
	api.GetFiles(w, req, params)
	resp := w.Result()
	defer resp.Body.Close()
	assert.Equal(t, http.StatusOK, resp.StatusCode)
	assert.Equal(t, "gzip", resp.Header.Get("Content-Encoding"))
	assert.Equal(t, "text/plain; charset=utf-8", resp.Header.Get("Content-Type"))
	// Decompress the gzip response body
	gzReader, err := gzip.NewReader(resp.Body)
	require.NoError(t, err)
	defer gzReader.Close()
	decompressed, err := io.ReadAll(gzReader)
	require.NoError(t, err)
	assert.Equal(t, originalContent, decompressed)
}
// TestPostFiles_GzipUpload verifies that a multipart upload whose entire body
// is gzip-compressed (Content-Encoding: gzip) is transparently decompressed
// and the original content is written to disk.
func TestPostFiles_GzipUpload(t *testing.T) {
	t.Parallel()
	currentUser, err := user.Current()
	require.NoError(t, err)
	originalContent := []byte("hello world, this is a test file uploaded with gzip")
	// Build a multipart body
	var multipartBuf bytes.Buffer
	mpWriter := multipart.NewWriter(&multipartBuf)
	part, err := mpWriter.CreateFormFile("file", "uploaded.txt")
	require.NoError(t, err)
	_, err = part.Write(originalContent)
	require.NoError(t, err)
	err = mpWriter.Close()
	require.NoError(t, err)
	// Gzip-compress the entire multipart body
	var gzBuf bytes.Buffer
	gzWriter := gzip.NewWriter(&gzBuf)
	_, err = gzWriter.Write(multipartBuf.Bytes())
	require.NoError(t, err)
	err = gzWriter.Close()
	require.NoError(t, err)
	// Create test API
	tempDir := t.TempDir()
	destPath := filepath.Join(tempDir, "uploaded.txt")
	logger := zerolog.Nop()
	defaults := &execcontext.Defaults{
		EnvVars: utils.NewMap[string, string](),
		User:    currentUser.Username,
	}
	api := New(&logger, defaults, nil, false)
	req := httptest.NewRequest(http.MethodPost, "/files?path="+url.QueryEscape(destPath), &gzBuf)
	req.Header.Set("Content-Type", mpWriter.FormDataContentType())
	req.Header.Set("Content-Encoding", "gzip")
	w := httptest.NewRecorder()
	params := PostFilesParams{
		Path:     &destPath,
		Username: &currentUser.Username,
	}
	api.PostFiles(w, req, params)
	resp := w.Result()
	defer resp.Body.Close()
	assert.Equal(t, http.StatusOK, resp.StatusCode)
	// Verify the file was written with the original (decompressed) content
	data, err := os.ReadFile(destPath)
	require.NoError(t, err)
	assert.Equal(t, originalContent, data)
}
// TestGzipUploadThenGzipDownload is a round-trip test: upload a file with a
// gzip-compressed multipart body, then download it with Accept-Encoding: gzip
// and verify that the decompressed response matches the original bytes.
func TestGzipUploadThenGzipDownload(t *testing.T) {
	t.Parallel()
	currentUser, err := user.Current()
	require.NoError(t, err)
	originalContent := []byte("round-trip gzip test: upload compressed, download compressed, verify match")
	// --- Upload with gzip ---
	// Build a multipart body
	var multipartBuf bytes.Buffer
	mpWriter := multipart.NewWriter(&multipartBuf)
	part, err := mpWriter.CreateFormFile("file", "roundtrip.txt")
	require.NoError(t, err)
	_, err = part.Write(originalContent)
	require.NoError(t, err)
	err = mpWriter.Close()
	require.NoError(t, err)
	// Gzip-compress the entire multipart body
	var gzBuf bytes.Buffer
	gzWriter := gzip.NewWriter(&gzBuf)
	_, err = gzWriter.Write(multipartBuf.Bytes())
	require.NoError(t, err)
	err = gzWriter.Close()
	require.NoError(t, err)
	tempDir := t.TempDir()
	destPath := filepath.Join(tempDir, "roundtrip.txt")
	logger := zerolog.Nop()
	defaults := &execcontext.Defaults{
		EnvVars: utils.NewMap[string, string](),
		User:    currentUser.Username,
	}
	api := New(&logger, defaults, nil, false)
	uploadReq := httptest.NewRequest(http.MethodPost, "/files?path="+url.QueryEscape(destPath), &gzBuf)
	uploadReq.Header.Set("Content-Type", mpWriter.FormDataContentType())
	uploadReq.Header.Set("Content-Encoding", "gzip")
	uploadW := httptest.NewRecorder()
	uploadParams := PostFilesParams{
		Path:     &destPath,
		Username: &currentUser.Username,
	}
	api.PostFiles(uploadW, uploadReq, uploadParams)
	uploadResp := uploadW.Result()
	defer uploadResp.Body.Close()
	require.Equal(t, http.StatusOK, uploadResp.StatusCode)
	// --- Download with gzip ---
	downloadReq := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(destPath), nil)
	downloadReq.Header.Set("Accept-Encoding", "gzip")
	downloadW := httptest.NewRecorder()
	downloadParams := GetFilesParams{
		Path:     &destPath,
		Username: &currentUser.Username,
	}
	api.GetFiles(downloadW, downloadReq, downloadParams)
	downloadResp := downloadW.Result()
	defer downloadResp.Body.Close()
	require.Equal(t, http.StatusOK, downloadResp.StatusCode)
	assert.Equal(t, "gzip", downloadResp.Header.Get("Content-Encoding"))
	// Decompress and verify content matches original
	gzReader, err := gzip.NewReader(downloadResp.Body)
	require.NoError(t, err)
	defer gzReader.Close()
	decompressed, err := io.ReadAll(gzReader)
	require.NoError(t, err)
	assert.Equal(t, originalContent, decompressed)
}

View File

@ -0,0 +1,227 @@
package api
import (
"compress/gzip"
"fmt"
"io"
"net/http"
"slices"
"sort"
"strconv"
"strings"
)
// Content-coding tokens per RFC 7231; comparisons elsewhere in this file
// treat them case-insensitively.
const (
	// EncodingGzip is the gzip content encoding.
	EncodingGzip = "gzip"
	// EncodingIdentity means no encoding (passthrough).
	EncodingIdentity = "identity"
	// EncodingWildcard means any encoding is acceptable.
	EncodingWildcard = "*"
)

// SupportedEncodings lists the content encodings supported for file transfer.
// The order matters - encodings are checked in order of preference
// (the first entry is the fallback chosen for a bare wildcard).
var SupportedEncodings = []string{
	EncodingGzip,
}
// encodingWithQuality holds an encoding name and its quality value.
type encodingWithQuality struct {
	encoding string
	quality  float64
}

// isSupportedEncoding reports whether the given coding is in SupportedEncodings.
// Content-coding tokens are case-insensitive per RFC 7231.
func isSupportedEncoding(encoding string) bool {
	return slices.Contains(SupportedEncodings, strings.ToLower(encoding))
}

// parseEncodingWithQuality parses one Accept-/Content-Encoding list item and
// returns the lowercased coding name plus its q-value. The quality defaults
// to 1.0 when absent or unparsable. Content-coding tokens are
// case-insensitive per RFC 7231.
func parseEncodingWithQuality(value string) encodingWithQuality {
	name := strings.TrimSpace(value)
	q := 1.0
	if sep := strings.Index(name, ";"); sep >= 0 {
		rawParams := name[sep+1:]
		name = strings.TrimSpace(name[:sep])
		// Scan the remaining attributes for a q=X.X parameter.
		for _, attr := range strings.Split(rawParams, ";") {
			attr = strings.TrimSpace(attr)
			if strings.HasPrefix(strings.ToLower(attr), "q=") {
				if parsed, perr := strconv.ParseFloat(attr[2:], 64); perr == nil {
					q = parsed
				}
			}
		}
	}
	// Normalize the coding name to lowercase per RFC 7231.
	return encodingWithQuality{encoding: strings.ToLower(name), quality: q}
}

// parseEncoding returns only the lowercased coding name of a header list
// item, discarding any quality parameter.
func parseEncoding(value string) string {
	return parseEncodingWithQuality(value).encoding
}
// parseContentEncoding inspects the Content-Encoding request header and
// returns the coding of the request body. A missing header or explicit
// "identity" yields EncodingIdentity; any coding outside SupportedEncodings
// is rejected with an error.
func parseContentEncoding(r *http.Request) (string, error) {
	raw := r.Header.Get("Content-Encoding")
	if raw == "" {
		return EncodingIdentity, nil
	}
	coding := parseEncoding(raw)
	switch {
	case coding == EncodingIdentity:
		return EncodingIdentity, nil
	case isSupportedEncoding(coding):
		return coding, nil
	default:
		return "", fmt.Errorf("unsupported Content-Encoding: %s, supported: %v", raw, SupportedEncodings)
	}
}
// parseAcceptEncodingHeader splits an Accept-Encoding header into its parsed
// codings and reports whether identity has been rejected.
// Per RFC 7231 Section 5.3.4, identity is acceptable unless excluded by
// "identity;q=0" or "*;q=0" without a more specific entry for identity with
// q>0. An empty header yields (nil, false): identity is not rejected.
func parseAcceptEncodingHeader(header string) ([]encodingWithQuality, bool) {
	if header == "" {
		return nil, false // identity not rejected when header is empty
	}
	// Parse every comma-separated list item with its quality value.
	var parsed []encodingWithQuality
	for _, item := range strings.Split(header, ",") {
		parsed = append(parsed, parseEncodingWithQuality(item))
	}
	// Track the three facts that decide identity's fate.
	identityZero := false     // saw "identity;q=0"
	identityPositive := false // saw "identity" with q > 0
	wildcardZero := false     // saw "*;q=0"
	for _, eq := range parsed {
		if eq.encoding == EncodingIdentity {
			if eq.quality == 0 {
				identityZero = true
			} else {
				identityPositive = true
			}
		} else if eq.encoding == EncodingWildcard && eq.quality == 0 {
			wildcardZero = true
		}
	}
	// A rejected wildcard only rejects identity when identity was not
	// explicitly listed with a positive quality.
	rejected := identityZero || (wildcardZero && !identityPositive)
	return parsed, rejected
}
// isIdentityAcceptable reports whether the identity (no-op) encoding is
// acceptable for this request. Per RFC 7231 section 5.3.4, identity is always
// implicitly acceptable unless explicitly rejected with q=0.
func isIdentityAcceptable(r *http.Request) bool {
	_, rejected := parseAcceptEncodingHeader(r.Header.Get("Accept-Encoding"))
	return !rejected
}
// parseAcceptEncoding parses the Accept-Encoding header and returns the best
// supported encoding based on quality values. Per RFC 7231 section 5.3.4,
// identity is always implicitly acceptable unless explicitly rejected with
// q=0. If no Accept-Encoding header is present, identity is returned.
//
// Returns an error only when identity is rejected and no supported encoding
// is acceptable.
func parseAcceptEncoding(r *http.Request) (string, error) {
	header := r.Header.Get("Accept-Encoding")
	if header == "" {
		return EncodingIdentity, nil
	}
	encodings, identityRejected := parseAcceptEncodingHeader(header)
	// Sort by quality value (highest first). A stable sort keeps the client's
	// original order among entries with equal q-values, making the selection
	// deterministic — sort.Slice gives no ordering guarantee for ties.
	sort.SliceStable(encodings, func(i, j int) bool {
		return encodings[i].quality > encodings[j].quality
	})
	// Find the best supported encoding
	for _, eq := range encodings {
		// Skip encodings with q=0 (explicitly rejected)
		if eq.quality == 0 {
			continue
		}
		if eq.encoding == EncodingIdentity {
			return EncodingIdentity, nil
		}
		// Wildcard means any encoding is acceptable - return a supported encoding if identity is rejected
		if eq.encoding == EncodingWildcard {
			if identityRejected && len(SupportedEncodings) > 0 {
				return SupportedEncodings[0], nil
			}
			return EncodingIdentity, nil
		}
		if isSupportedEncoding(eq.encoding) {
			return eq.encoding, nil
		}
	}
	// Per RFC 7231, identity is implicitly acceptable unless rejected
	if !identityRejected {
		return EncodingIdentity, nil
	}
	// Identity rejected and no supported encodings found
	return "", fmt.Errorf("no acceptable encoding found, supported: %v", SupportedEncodings)
}
// getDecompressedBody returns a reader over the request body, transparently
// decompressing it according to the Content-Encoding header. With no (or
// identity) encoding the original body is returned untouched; unsupported
// encodings yield an error. The caller is responsible for closing both the
// returned ReadCloser and the original request body (r.Body) separately.
func getDecompressedBody(r *http.Request) (io.ReadCloser, error) {
	encoding, err := parseContentEncoding(r)
	if err != nil {
		return nil, err
	}
	switch encoding {
	case EncodingIdentity:
		// Passthrough: hand back the raw body.
		return r.Body, nil
	case EncodingGzip:
		gz, gzErr := gzip.NewReader(r.Body)
		if gzErr != nil {
			return nil, fmt.Errorf("failed to create gzip reader: %w", gzErr)
		}
		return gz, nil
	default:
		// Unreachable if isSupportedEncoding and this switch stay in sync.
		return nil, fmt.Errorf("encoding %s is supported but not implemented", encoding)
	}
}

View File

@ -0,0 +1,494 @@
package api
import (
"bytes"
"compress/gzip"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestIsSupportedEncoding verifies that gzip is recognized case-insensitively
// and that unsupported codings (br, deflate) are rejected.
func TestIsSupportedEncoding(t *testing.T) {
	t.Parallel()
	t.Run("gzip is supported", func(t *testing.T) {
		t.Parallel()
		assert.True(t, isSupportedEncoding("gzip"))
	})
	t.Run("GZIP is supported (case-insensitive)", func(t *testing.T) {
		t.Parallel()
		assert.True(t, isSupportedEncoding("GZIP"))
	})
	t.Run("Gzip is supported (case-insensitive)", func(t *testing.T) {
		t.Parallel()
		assert.True(t, isSupportedEncoding("Gzip"))
	})
	t.Run("br is not supported", func(t *testing.T) {
		t.Parallel()
		assert.False(t, isSupportedEncoding("br"))
	})
	t.Run("deflate is not supported", func(t *testing.T) {
		t.Parallel()
		assert.False(t, isSupportedEncoding("deflate"))
	})
}
// TestParseEncodingWithQuality covers q-value parsing for a single header
// list item: default 1.0, explicit values, q=0, malformed values falling back
// to 1.0, whitespace trimming, and case normalization.
func TestParseEncodingWithQuality(t *testing.T) {
	t.Parallel()
	t.Run("returns encoding with default quality 1.0", func(t *testing.T) {
		t.Parallel()
		eq := parseEncodingWithQuality("gzip")
		assert.Equal(t, "gzip", eq.encoding)
		assert.InDelta(t, 1.0, eq.quality, 0.001)
	})
	t.Run("parses quality value", func(t *testing.T) {
		t.Parallel()
		eq := parseEncodingWithQuality("gzip;q=0.5")
		assert.Equal(t, "gzip", eq.encoding)
		assert.InDelta(t, 0.5, eq.quality, 0.001)
	})
	t.Run("parses quality value with whitespace", func(t *testing.T) {
		t.Parallel()
		eq := parseEncodingWithQuality("gzip ; q=0.8")
		assert.Equal(t, "gzip", eq.encoding)
		assert.InDelta(t, 0.8, eq.quality, 0.001)
	})
	t.Run("handles q=0", func(t *testing.T) {
		t.Parallel()
		eq := parseEncodingWithQuality("gzip;q=0")
		assert.Equal(t, "gzip", eq.encoding)
		assert.InDelta(t, 0.0, eq.quality, 0.001)
	})
	t.Run("handles invalid quality value", func(t *testing.T) {
		t.Parallel()
		eq := parseEncodingWithQuality("gzip;q=invalid")
		assert.Equal(t, "gzip", eq.encoding)
		assert.InDelta(t, 1.0, eq.quality, 0.001) // defaults to 1.0 on parse error
	})
	t.Run("trims whitespace from encoding", func(t *testing.T) {
		t.Parallel()
		eq := parseEncodingWithQuality(" gzip ")
		assert.Equal(t, "gzip", eq.encoding)
		assert.InDelta(t, 1.0, eq.quality, 0.001)
	})
	t.Run("normalizes encoding to lowercase", func(t *testing.T) {
		t.Parallel()
		eq := parseEncodingWithQuality("GZIP")
		assert.Equal(t, "gzip", eq.encoding)
	})
	t.Run("normalizes mixed case encoding", func(t *testing.T) {
		t.Parallel()
		eq := parseEncodingWithQuality("Gzip;q=0.5")
		assert.Equal(t, "gzip", eq.encoding)
		assert.InDelta(t, 0.5, eq.quality, 0.001)
	})
}
// TestParseEncoding verifies that parseEncoding strips whitespace and any
// quality parameter, leaving just the coding name.
func TestParseEncoding(t *testing.T) {
	t.Parallel()
	t.Run("returns encoding as-is", func(t *testing.T) {
		t.Parallel()
		assert.Equal(t, "gzip", parseEncoding("gzip"))
	})
	t.Run("trims whitespace", func(t *testing.T) {
		t.Parallel()
		assert.Equal(t, "gzip", parseEncoding(" gzip "))
	})
	t.Run("strips quality value", func(t *testing.T) {
		t.Parallel()
		assert.Equal(t, "gzip", parseEncoding("gzip;q=1.0"))
	})
	t.Run("strips quality value with whitespace", func(t *testing.T) {
		t.Parallel()
		assert.Equal(t, "gzip", parseEncoding("gzip ; q=0.5"))
	})
}
// TestParseContentEncoding checks Content-Encoding request-header handling:
// absent header → identity, gzip in any letter case, explicit identity, an
// error for unsupported codings, and tolerance of a quality parameter.
func TestParseContentEncoding(t *testing.T) {
	t.Parallel()
	t.Run("returns identity when no header", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
		encoding, err := parseContentEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, EncodingIdentity, encoding)
	})
	t.Run("returns gzip when Content-Encoding is gzip", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
		req.Header.Set("Content-Encoding", "gzip")
		encoding, err := parseContentEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, "gzip", encoding)
	})
	t.Run("returns gzip when Content-Encoding is GZIP (case-insensitive)", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
		req.Header.Set("Content-Encoding", "GZIP")
		encoding, err := parseContentEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, "gzip", encoding)
	})
	t.Run("returns gzip when Content-Encoding is Gzip (case-insensitive)", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
		req.Header.Set("Content-Encoding", "Gzip")
		encoding, err := parseContentEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, "gzip", encoding)
	})
	t.Run("returns identity for identity encoding", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
		req.Header.Set("Content-Encoding", "identity")
		encoding, err := parseContentEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, EncodingIdentity, encoding)
	})
	t.Run("returns error for unsupported encoding", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
		req.Header.Set("Content-Encoding", "br")
		_, err := parseContentEncoding(req)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "unsupported Content-Encoding")
		assert.Contains(t, err.Error(), "supported: [gzip]")
	})
	t.Run("handles gzip with quality value", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", nil)
		req.Header.Set("Content-Encoding", "gzip;q=1.0")
		encoding, err := parseContentEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, "gzip", encoding)
	})
}
// TestParseAcceptEncoding exercises Accept-Encoding content negotiation:
// case-insensitivity, quality-ordered selection, q=0 rejection, wildcard
// semantics, and the RFC 7231 rule that identity is implicitly acceptable
// unless explicitly rejected.
func TestParseAcceptEncoding(t *testing.T) {
	t.Parallel()
	t.Run("returns identity when no header", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
		encoding, err := parseAcceptEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, EncodingIdentity, encoding)
	})
	t.Run("returns gzip when Accept-Encoding is gzip", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
		req.Header.Set("Accept-Encoding", "gzip")
		encoding, err := parseAcceptEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, "gzip", encoding)
	})
	t.Run("returns gzip when Accept-Encoding is GZIP (case-insensitive)", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
		req.Header.Set("Accept-Encoding", "GZIP")
		encoding, err := parseAcceptEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, "gzip", encoding)
	})
	t.Run("returns gzip when gzip is among multiple encodings", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
		req.Header.Set("Accept-Encoding", "deflate, gzip, br")
		encoding, err := parseAcceptEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, "gzip", encoding)
	})
	t.Run("returns gzip with quality value", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
		req.Header.Set("Accept-Encoding", "gzip;q=1.0")
		encoding, err := parseAcceptEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, "gzip", encoding)
	})
	t.Run("returns identity for identity encoding", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
		req.Header.Set("Accept-Encoding", "identity")
		encoding, err := parseAcceptEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, EncodingIdentity, encoding)
	})
	t.Run("returns identity for wildcard encoding", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
		req.Header.Set("Accept-Encoding", "*")
		encoding, err := parseAcceptEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, EncodingIdentity, encoding)
	})
	t.Run("falls back to identity for unsupported encoding only", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
		req.Header.Set("Accept-Encoding", "br")
		encoding, err := parseAcceptEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, EncodingIdentity, encoding)
	})
	t.Run("falls back to identity when only unsupported encodings", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
		req.Header.Set("Accept-Encoding", "deflate, br")
		encoding, err := parseAcceptEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, EncodingIdentity, encoding)
	})
	t.Run("selects gzip when it has highest quality", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
		req.Header.Set("Accept-Encoding", "br;q=0.5, gzip;q=1.0, deflate;q=0.8")
		encoding, err := parseAcceptEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, "gzip", encoding)
	})
	t.Run("selects gzip even with lower quality when others unsupported", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
		req.Header.Set("Accept-Encoding", "br;q=1.0, gzip;q=0.5")
		encoding, err := parseAcceptEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, "gzip", encoding)
	})
	t.Run("returns identity when it has higher quality than gzip", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
		req.Header.Set("Accept-Encoding", "gzip;q=0.5, identity;q=1.0")
		encoding, err := parseAcceptEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, EncodingIdentity, encoding)
	})
	t.Run("skips encoding with q=0", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
		req.Header.Set("Accept-Encoding", "gzip;q=0, identity")
		encoding, err := parseAcceptEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, EncodingIdentity, encoding)
	})
	t.Run("falls back to identity when gzip rejected and no other supported", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
		req.Header.Set("Accept-Encoding", "gzip;q=0, br")
		encoding, err := parseAcceptEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, EncodingIdentity, encoding)
	})
	t.Run("returns error when identity explicitly rejected and no supported encoding", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
		req.Header.Set("Accept-Encoding", "br, identity;q=0")
		_, err := parseAcceptEncoding(req)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "no acceptable encoding found")
	})
	t.Run("returns gzip for wildcard when identity rejected", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
		req.Header.Set("Accept-Encoding", "*, identity;q=0")
		encoding, err := parseAcceptEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, "gzip", encoding) // wildcard with identity rejected returns supported encoding
	})
	t.Run("returns error when wildcard rejected and no explicit identity", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
		req.Header.Set("Accept-Encoding", "*;q=0")
		_, err := parseAcceptEncoding(req)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "no acceptable encoding found")
	})
	t.Run("returns identity when wildcard rejected but identity explicitly accepted", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
		req.Header.Set("Accept-Encoding", "*;q=0, identity")
		encoding, err := parseAcceptEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, EncodingIdentity, encoding)
	})
	t.Run("returns gzip when wildcard rejected but gzip explicitly accepted", func(t *testing.T) {
		t.Parallel()
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
		req.Header.Set("Accept-Encoding", "*;q=0, gzip")
		encoding, err := parseAcceptEncoding(req)
		require.NoError(t, err)
		assert.Equal(t, EncodingGzip, encoding)
	})
}
// TestGetDecompressedBody verifies getDecompressedBody's handling of the
// Content-Encoding request header: the body is passed through untouched
// when the header is absent or "identity", wrapped in a gzip reader for
// "gzip" (with or without a quality value), and rejected with an error
// for invalid gzip data or unsupported encodings.
func TestGetDecompressedBody(t *testing.T) {
	t.Parallel()
	t.Run("returns original body when no Content-Encoding header", func(t *testing.T) {
		t.Parallel()
		content := []byte("test content")
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(content))
		body, err := getDecompressedBody(req)
		require.NoError(t, err)
		assert.Equal(t, req.Body, body, "should return original body")
		data, err := io.ReadAll(body)
		require.NoError(t, err)
		assert.Equal(t, content, data)
	})
	t.Run("decompresses gzip body when Content-Encoding is gzip", func(t *testing.T) {
		t.Parallel()
		originalContent := []byte("test content to compress")
		// Build a real gzip payload so the round-trip can be asserted.
		var compressed bytes.Buffer
		gw := gzip.NewWriter(&compressed)
		_, err := gw.Write(originalContent)
		require.NoError(t, err)
		err = gw.Close()
		require.NoError(t, err)
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(compressed.Bytes()))
		req.Header.Set("Content-Encoding", "gzip")
		body, err := getDecompressedBody(req)
		require.NoError(t, err)
		defer body.Close()
		assert.NotEqual(t, req.Body, body, "should return a new gzip reader")
		data, err := io.ReadAll(body)
		require.NoError(t, err)
		assert.Equal(t, originalContent, data)
	})
	t.Run("returns error for invalid gzip data", func(t *testing.T) {
		t.Parallel()
		invalidGzip := []byte("this is not gzip data")
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(invalidGzip))
		req.Header.Set("Content-Encoding", "gzip")
		_, err := getDecompressedBody(req)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "failed to create gzip reader")
	})
	t.Run("returns original body for identity encoding", func(t *testing.T) {
		t.Parallel()
		content := []byte("test content")
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(content))
		req.Header.Set("Content-Encoding", "identity")
		body, err := getDecompressedBody(req)
		require.NoError(t, err)
		assert.Equal(t, req.Body, body, "should return original body")
		data, err := io.ReadAll(body)
		require.NoError(t, err)
		assert.Equal(t, content, data)
	})
	t.Run("returns error for unsupported encoding", func(t *testing.T) {
		t.Parallel()
		content := []byte("test content")
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(content))
		// br is not supported for request bodies.
		req.Header.Set("Content-Encoding", "br")
		_, err := getDecompressedBody(req)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "unsupported Content-Encoding")
	})
	t.Run("handles gzip with quality value", func(t *testing.T) {
		t.Parallel()
		originalContent := []byte("test content to compress")
		var compressed bytes.Buffer
		gw := gzip.NewWriter(&compressed)
		_, err := gw.Write(originalContent)
		require.NoError(t, err)
		err = gw.Close()
		require.NoError(t, err)
		req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "/test", bytes.NewReader(compressed.Bytes()))
		// A quality suffix on the encoding token must still be treated as gzip.
		req.Header.Set("Content-Encoding", "gzip;q=1.0")
		body, err := getDecompressedBody(req)
		require.NoError(t, err)
		defer body.Close()
		data, err := io.ReadAll(body)
		require.NoError(t, err)
		assert.Equal(t, originalContent, data)
	})
}

29
envd/internal/api/envs.go Normal file
View File

@ -0,0 +1,29 @@
package api
import (
"encoding/json"
"net/http"
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
)
// GetEnvs responds with the current default environment variables as a
// JSON object. The response is marked non-cacheable so clients always
// see the latest values.
func (a *API) GetEnvs(w http.ResponseWriter, _ *http.Request) {
	opID := logs.AssignOperationID()
	a.logger.Debug().Str(string(logs.OperationIDKey), opID).Msg("Getting env vars")

	// Snapshot the concurrent map into a plain map for JSON encoding.
	snapshot := make(EnvVars)
	a.defaults.EnvVars.Range(func(k, v string) bool {
		snapshot[k] = v
		return true
	})

	headers := w.Header()
	headers.Set("Cache-Control", "no-store")
	headers.Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)

	err := json.NewEncoder(w).Encode(snapshot)
	if err != nil {
		a.logger.Error().Err(err).Str(string(logs.OperationIDKey), opID).Msg("Failed to encode env vars")
	}
}

View File

@ -0,0 +1,21 @@
package api
import (
"encoding/json"
"errors"
"net/http"
)
// jsonError writes err as a JSON-encoded Error payload with the given
// HTTP status code. If encoding fails after the status line has been
// written, both errors are surfaced to the client as a best effort.
func jsonError(w http.ResponseWriter, code int, err error) {
	h := w.Header()
	h.Set("Content-Type", "application/json; charset=utf-8")
	h.Set("X-Content-Type-Options", "nosniff")
	w.WriteHeader(code)

	payload := Error{
		Code:    code,
		Message: err.Error(),
	}
	if encodeErr := json.NewEncoder(w).Encode(payload); encodeErr != nil {
		http.Error(w, errors.Join(encodeErr, err).Error(), http.StatusInternalServerError)
	}
}

View File

@ -0,0 +1,3 @@
package api
//go:generate go run github.com/oapi-codegen/oapi-codegen/v2/cmd/oapi-codegen -config cfg.yaml ../../spec/envd.yaml

314
envd/internal/api/init.go Normal file
View File

@ -0,0 +1,314 @@
package api
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/netip"
"os/exec"
"time"
"github.com/awnumar/memguard"
"github.com/rs/zerolog"
"github.com/txn2/txeh"
"golang.org/x/sys/unix"
"git.omukk.dev/wrenn/sandbox/envd/internal/host"
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/keys"
)
// Sentinel errors returned by the /init access-token validation path;
// PostInit maps both to HTTP 401.
var (
	ErrAccessTokenMismatch           = errors.New("access token validation failed")
	ErrAccessTokenResetNotAuthorized = errors.New("access token reset not authorized")
)
// Clock-drift tolerances used by shouldSetSystemTime when deciding
// whether to reset the sandbox clock from a host-provided timestamp.
const (
	maxTimeInPast   = 50 * time.Millisecond
	maxTimeInFuture = 5 * time.Second
)
// validateInitAccessToken validates the access token for /init requests.
// Token is valid if it matches the existing token OR the MMDS hash.
// If neither exists, first-time setup is allowed.
//
// Returns nil on success, ErrAccessTokenResetNotAuthorized when an unset
// request token is not allowed to clear an existing token, and
// ErrAccessTokenMismatch for any other failure.
func (a *API) validateInitAccessToken(ctx context.Context, requestToken *SecureToken) error {
	requestTokenSet := requestToken.IsSet()
	// Fast path: token matches existing
	if a.accessToken.IsSet() && requestTokenSet && a.accessToken.EqualsSecure(requestToken) {
		return nil
	}
	// Check MMDS only if token didn't match existing
	matchesMMDS, mmdsExists := a.checkMMDSHash(ctx, requestToken)
	switch {
	case matchesMMDS:
		return nil
	case !a.accessToken.IsSet() && !mmdsExists:
		return nil // first-time setup
	case !requestTokenSet:
		return ErrAccessTokenResetNotAuthorized
	default:
		return ErrAccessTokenMismatch
	}
}
// checkMMDSHash checks if the request token matches the MMDS hash.
// Returns (matches, mmdsExists).
//
// The MMDS hash is set by the orchestrator during Resume:
//   - hash(token): requires this specific token
//   - hash(""): explicitly allows nil token (token reset authorized)
//   - "": MMDS not properly configured, no authorization granted
func (a *API) checkMMDSHash(ctx context.Context, requestToken *SecureToken) (bool, bool) {
	// Outside Firecracker there is no MMDS to consult.
	if a.isNotFC {
		return false, false
	}
	mmdsHash, err := a.mmdsClient.GetAccessTokenHash(ctx)
	if err != nil {
		return false, false
	}
	if mmdsHash == "" {
		return false, false
	}
	// A nil request token only matches the explicit hash("") reset marker.
	if !requestToken.IsSet() {
		return mmdsHash == keys.HashAccessToken(""), true
	}
	tokenBytes, err := requestToken.Bytes()
	if err != nil {
		return false, true
	}
	// Wipe the plaintext copy once the comparison is done.
	defer memguard.WipeBytes(tokenBytes)
	return keys.HashAccessTokenBytes(tokenBytes) == mmdsHash, true
}
// PostInit handles POST /init: it parses the optional JSON body, applies
// the contained settings via SetData when the request is at least as new
// as the last applied one, and then starts background MMDS polling for
// env-var options. Raw request bytes and any untransferred token are
// wiped from memory before returning. Responds 204 on success, 400 on
// parse/set failures, and 401 on token validation failures.
func (a *API) PostInit(w http.ResponseWriter, r *http.Request) {
	defer r.Body.Close()
	ctx := r.Context()
	operationID := logs.AssignOperationID()
	logger := a.logger.With().Str(string(logs.OperationIDKey), operationID).Logger()
	if r.Body != nil {
		// Read raw body so we can wipe it after parsing
		body, err := io.ReadAll(r.Body)
		// Ensure body is wiped after we're done
		defer memguard.WipeBytes(body)
		if err != nil {
			logger.Error().Msgf("Failed to read request body: %v", err)
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		var initRequest PostInitJSONBody
		if len(body) > 0 {
			err = json.Unmarshal(body, &initRequest)
			if err != nil {
				logger.Error().Msgf("Failed to decode request: %v", err)
				w.WriteHeader(http.StatusBadRequest)
				return
			}
		}
		// Ensure request token is destroyed if not transferred via TakeFrom.
		// This handles: validation failures, timestamp-based skips, and any early returns.
		// Safe because Destroy() is nil-safe and TakeFrom clears the source.
		defer initRequest.AccessToken.Destroy()
		a.initLock.Lock()
		defer a.initLock.Unlock()
		// Update data only if the request is newer or if there's no timestamp at all
		if initRequest.Timestamp == nil || a.lastSetTime.SetToGreater(initRequest.Timestamp.UnixNano()) {
			err = a.SetData(ctx, logger, initRequest)
			if err != nil {
				switch {
				case errors.Is(err, ErrAccessTokenMismatch), errors.Is(err, ErrAccessTokenResetNotAuthorized):
					w.WriteHeader(http.StatusUnauthorized)
				default:
					logger.Error().Msgf("Failed to set data: %v", err)
					w.WriteHeader(http.StatusBadRequest)
				}
				w.Write([]byte(err.Error()))
				return
			}
		}
	}
	// Poll MMDS in the background with its own deadline, detached from the
	// request context.
	go func() { //nolint:contextcheck // TODO: fix this later
		ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
		defer cancel()
		host.PollForMMDSOpts(ctx, a.mmdsChan, a.defaults.EnvVars)
	}()
	w.Header().Set("Cache-Control", "no-store")
	w.Header().Set("Content-Type", "")
	w.WriteHeader(http.StatusNoContent)
}
// SetData validates the access token in data and then applies every field
// that is present: system time, env vars, access token, hyperloop address,
// default user/workdir, and NFS volume mounts. Hyperloop setup and NFS
// mounts run asynchronously. Returns a validation error before applying
// anything; later per-field failures are only logged.
func (a *API) SetData(ctx context.Context, logger zerolog.Logger, data PostInitJSONBody) error {
	// Validate access token before proceeding with any action
	// The request must provide a token that is either:
	// 1. Matches the existing access token (if set), OR
	// 2. Matches the MMDS hash (for token change during resume)
	if err := a.validateInitAccessToken(ctx, data.AccessToken); err != nil {
		return err
	}
	if data.Timestamp != nil {
		// Check if current time differs significantly from the received timestamp
		if shouldSetSystemTime(time.Now(), *data.Timestamp) {
			logger.Debug().Msgf("Setting sandbox start time to: %v", *data.Timestamp)
			ts := unix.NsecToTimespec(data.Timestamp.UnixNano())
			err := unix.ClockSettime(unix.CLOCK_REALTIME, &ts)
			if err != nil {
				logger.Error().Msgf("Failed to set system time: %v", err)
			}
		} else {
			logger.Debug().Msgf("Current time is within acceptable range of timestamp %v, not setting system time", *data.Timestamp)
		}
	}
	if data.EnvVars != nil {
		logger.Debug().Msg(fmt.Sprintf("Setting %d env vars", len(*data.EnvVars)))
		for key, value := range *data.EnvVars {
			logger.Debug().Msgf("Setting env var for %s", key)
			a.defaults.EnvVars.Store(key, value)
		}
	}
	if data.AccessToken.IsSet() {
		logger.Debug().Msg("Setting access token")
		// TakeFrom moves the buffer, leaving the request token cleared.
		a.accessToken.TakeFrom(data.AccessToken)
	} else if a.accessToken.IsSet() {
		// Reaching here with an unset request token means validation
		// authorized an explicit reset (MMDS hash("")).
		logger.Debug().Msg("Clearing access token")
		a.accessToken.Destroy()
	}
	if data.HyperloopIP != nil {
		go a.SetupHyperloop(*data.HyperloopIP)
	}
	if data.DefaultUser != nil && *data.DefaultUser != "" {
		logger.Debug().Msgf("Setting default user to: %s", *data.DefaultUser)
		a.defaults.User = *data.DefaultUser
	}
	if data.DefaultWorkdir != nil && *data.DefaultWorkdir != "" {
		logger.Debug().Msgf("Setting default workdir to: %s", *data.DefaultWorkdir)
		a.defaults.Workdir = data.DefaultWorkdir
	}
	if data.VolumeMounts != nil {
		for _, volume := range *data.VolumeMounts {
			logger.Debug().Msgf("Mounting %s at %q", volume.NfsTarget, volume.Path)
			// WithoutCancel: the mount must outlive the request context.
			go a.setupNfs(context.WithoutCancel(ctx), volume.NfsTarget, volume.Path)
		}
	}
	return nil
}
// setupNfs creates the mount point and mounts the NFS export nfsTarget at
// path. Every step is logged with its combined output; the sequence stops
// at the first failing command.
func (a *API) setupNfs(ctx context.Context, nfsTarget, path string) {
	steps := [][]string{
		{"mkdir", "-p", path},
		{"mount", "-v", "-t", "nfs", "-o", "mountproto=tcp,mountport=2049,proto=tcp,port=2049,nfsvers=3,noacl", nfsTarget, path},
	}
	for _, argv := range steps {
		out, runErr := exec.CommandContext(ctx, argv[0], argv[1:]...).CombinedOutput()
		a.getLogger(runErr).
			Strs("command", argv).
			Str("output", string(out)).
			Msg("Mount NFS")
		if runErr != nil {
			return
		}
	}
}
// SetupHyperloop points the events host at the given hyperloop address in
// /etc/hosts and, on success, exports WRENN_EVENTS_ADDRESS into the
// default environment. Calls are serialized by hyperloopLock.
func (a *API) SetupHyperloop(address string) {
	a.hyperloopLock.Lock()
	defer a.hyperloopLock.Unlock()

	err := rewriteHostsFile(address, "/etc/hosts")
	if err != nil {
		a.logger.Error().Err(err).Msg("failed to modify hosts file")
		return
	}
	a.defaults.EnvVars.Store("WRENN_EVENTS_ADDRESS", fmt.Sprintf("http://%s", address))
}
// eventsHost is the hostname rewritten in /etc/hosts to reach the
// hyperloop events endpoint from inside the sandbox.
const eventsHost = "events.wrenn.local"
// rewriteHostsFile maps eventsHost to address in the hosts file at path,
// replacing any existing entry for that host. It is a no-op when the
// mapping is already current.
func rewriteHostsFile(address, path string) error {
	cfg := &txeh.HostsConfig{
		ReadFilePath:  path,
		WriteFilePath: path,
	}
	hostsFile, err := txeh.NewHosts(cfg)
	if err != nil {
		return fmt.Errorf("failed to create hosts: %w", err)
	}

	family, famErr := getIPFamily(address)
	if famErr != nil {
		return fmt.Errorf("failed to get ip family: %w", famErr)
	}

	// Skip the write entirely when the entry already points at address.
	found, existing, _ := hostsFile.HostAddressLookup(eventsHost, family)
	if found && existing == address {
		return nil
	}

	// AddHost removes any previous entries for eventsHost before adding.
	hostsFile.AddHost(address, eventsHost)
	return hostsFile.Save()
}
// Sentinel errors for IP address classification in getIPFamily.
var (
	ErrInvalidAddress       = errors.New("invalid IP address")
	ErrUnknownAddressFormat = errors.New("unknown IP address format")
)
// getIPFamily classifies address as IPv4 or IPv6 for hosts-file lookups.
//
// Returns ErrInvalidAddress (wrapped, with the underlying parse error)
// when the address cannot be parsed, and ErrUnknownAddressFormat when it
// parses but is neither v4 nor v6. The txeh.IPFamilyV4 value returned
// alongside an error is a placeholder and must not be used.
func getIPFamily(address string) (txeh.IPFamily, error) {
	addressIP, err := netip.ParseAddr(address)
	if err != nil {
		// Wrap the ErrInvalidAddress sentinel (declared above and otherwise
		// unused) so callers can detect this case with errors.Is, while
		// preserving the parse error and the original message text.
		return txeh.IPFamilyV4, fmt.Errorf("%w: failed to parse IP address: %w", ErrInvalidAddress, err)
	}
	switch {
	case addressIP.Is4():
		return txeh.IPFamilyV4, nil
	case addressIP.Is6():
		return txeh.IPFamilyV6, nil
	default:
		return txeh.IPFamilyV4, fmt.Errorf("%w: %s", ErrUnknownAddressFormat, address)
	}
}
// shouldSetSystemTime reports whether the sandbox clock has drifted far
// enough from the host-provided timestamp that it should be reset: more
// than maxTimeInPast behind hostTime, or more than maxTimeInFuture ahead
// of it.
func shouldSetSystemTime(sandboxTime, hostTime time.Time) bool {
	drift := sandboxTime.Sub(hostTime)
	return drift < -maxTimeInPast || drift > maxTimeInFuture
}

View File

@ -0,0 +1,587 @@
package api
import (
"context"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/keys"
utilsShared "git.omukk.dev/wrenn/sandbox/envd/internal/shared/utils"
)
// TestSimpleCases exercises rewriteHostsFile against hosts files with and
// without leading/trailing newlines, asserting the events host entry is
// appended while existing entries and comments are preserved.
func TestSimpleCases(t *testing.T) {
	t.Parallel()
	// Each case transforms the fixture before writing it to disk.
	testCases := map[string]func(string) string{
		"both newlines":               func(s string) string { return s },
		"no newline prefix":           func(s string) string { return strings.TrimPrefix(s, "\n") },
		"no newline suffix":           func(s string) string { return strings.TrimSuffix(s, "\n") },
		"no newline prefix or suffix": strings.TrimSpace,
	}
	for name, preprocessor := range testCases {
		t.Run(name, func(t *testing.T) {
			t.Parallel()
			tempDir := t.TempDir()
			value := `
# comment
127.0.0.1 one.host
127.0.0.2 two.host
`
			value = preprocessor(value)
			inputPath := filepath.Join(tempDir, "hosts")
			err := os.WriteFile(inputPath, []byte(value), 0o644)
			require.NoError(t, err)
			err = rewriteHostsFile("127.0.0.3", inputPath)
			require.NoError(t, err)
			data, err := os.ReadFile(inputPath)
			require.NoError(t, err)
			assert.Equal(t, `# comment
127.0.0.1 one.host
127.0.0.2 two.host
127.0.0.3 events.wrenn.local`, strings.TrimSpace(string(data)))
		})
	}
}
// TestShouldSetSystemTime checks the drift thresholds of
// shouldSetSystemTime around both boundaries (maxTimeInPast and
// maxTimeInFuture).
//
// NOTE(review): the call below passes (tt.hostTime, sandboxTime), i.e.
// tt.hostTime is bound to the function's sandboxTime parameter and vice
// versa — swapped relative to the signature. The expected values are
// consistent with this swapped order, but the case names read the other
// way around; confirm the intended argument order with the author.
func TestShouldSetSystemTime(t *testing.T) {
	t.Parallel()
	sandboxTime := time.Now()
	tests := []struct {
		name     string
		hostTime time.Time
		want     bool
	}{
		{
			name:     "sandbox time far ahead of host time (should set)",
			hostTime: sandboxTime.Add(-10 * time.Second),
			want:     true,
		},
		{
			name:     "sandbox time at maxTimeInPast boundary ahead of host time (should not set)",
			hostTime: sandboxTime.Add(-50 * time.Millisecond),
			want:     false,
		},
		{
			name:     "sandbox time just within maxTimeInPast ahead of host time (should not set)",
			hostTime: sandboxTime.Add(-40 * time.Millisecond),
			want:     false,
		},
		{
			name:     "sandbox time slightly ahead of host time (should not set)",
			hostTime: sandboxTime.Add(-10 * time.Millisecond),
			want:     false,
		},
		{
			name:     "sandbox time equals host time (should not set)",
			hostTime: sandboxTime,
			want:     false,
		},
		{
			name:     "sandbox time slightly behind host time (should not set)",
			hostTime: sandboxTime.Add(1 * time.Second),
			want:     false,
		},
		{
			name:     "sandbox time just within maxTimeInFuture behind host time (should not set)",
			hostTime: sandboxTime.Add(4 * time.Second),
			want:     false,
		},
		{
			name:     "sandbox time at maxTimeInFuture boundary behind host time (should not set)",
			hostTime: sandboxTime.Add(5 * time.Second),
			want:     false,
		},
		{
			name:     "sandbox time far behind host time (should set)",
			hostTime: sandboxTime.Add(1 * time.Minute),
			want:     true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			got := shouldSetSystemTime(tt.hostTime, sandboxTime)
			assert.Equal(t, tt.want, got)
		})
	}
}
// secureTokenPtr is a test helper that builds a *SecureToken holding s.
// Set's error (only possible for empty input) is deliberately ignored.
func secureTokenPtr(s string) *SecureToken {
	st := new(SecureToken)
	_ = st.Set([]byte(s))
	return st
}
// mockMMDSClient is a test double for MMDSClient that returns canned
// values from GetAccessTokenHash.
type mockMMDSClient struct {
	hash string // hash to return
	err  error  // error to return
}
// GetAccessTokenHash returns the canned hash/err pair, ignoring ctx.
func (m *mockMMDSClient) GetAccessTokenHash(_ context.Context) (string, error) {
	return m.hash, m.err
}
// newTestAPI constructs an API with a no-op logger and fresh defaults,
// seeds it with accessToken (when non-nil, moved in via TakeFrom) and
// installs the given MMDS client.
func newTestAPI(accessToken *SecureToken, mmdsClient MMDSClient) *API {
	logger := zerolog.Nop()
	defaults := &execcontext.Defaults{
		EnvVars: utils.NewMap[string, string](),
	}
	api := New(&logger, defaults, nil, false)
	if accessToken != nil {
		api.accessToken.TakeFrom(accessToken)
	}
	api.mmdsClient = mmdsClient
	return api
}
// TestValidateInitAccessToken covers the decision table of
// validateInitAccessToken: fast-path match, MMDS hash match, first-time
// setup, mismatches, and unauthorized reset attempts.
func TestValidateInitAccessToken(t *testing.T) {
	t.Parallel()
	ctx := t.Context()
	tests := []struct {
		name         string
		accessToken  *SecureToken // token already stored in the API
		requestToken *SecureToken // token supplied by the request
		mmdsHash     string       // canned MMDS response
		mmdsErr      error        // canned MMDS error
		wantErr      error
	}{
		{
			name:         "fast path: token matches existing",
			accessToken:  secureTokenPtr("secret-token"),
			requestToken: secureTokenPtr("secret-token"),
			mmdsHash:     "",
			mmdsErr:      nil,
			wantErr:      nil,
		},
		{
			name:         "MMDS match: token hash matches MMDS hash",
			accessToken:  secureTokenPtr("old-token"),
			requestToken: secureTokenPtr("new-token"),
			mmdsHash:     keys.HashAccessToken("new-token"),
			mmdsErr:      nil,
			wantErr:      nil,
		},
		{
			name:         "first-time setup: no existing token, MMDS error",
			accessToken:  nil,
			requestToken: secureTokenPtr("new-token"),
			mmdsHash:     "",
			mmdsErr:      assert.AnError,
			wantErr:      nil,
		},
		{
			name:         "first-time setup: no existing token, empty MMDS hash",
			accessToken:  nil,
			requestToken: secureTokenPtr("new-token"),
			mmdsHash:     "",
			mmdsErr:      nil,
			wantErr:      nil,
		},
		{
			name:         "first-time setup: both tokens nil, no MMDS",
			accessToken:  nil,
			requestToken: nil,
			mmdsHash:     "",
			mmdsErr:      assert.AnError,
			wantErr:      nil,
		},
		{
			name:         "mismatch: existing token differs from request, no MMDS",
			accessToken:  secureTokenPtr("existing-token"),
			requestToken: secureTokenPtr("wrong-token"),
			mmdsHash:     "",
			mmdsErr:      assert.AnError,
			wantErr:      ErrAccessTokenMismatch,
		},
		{
			name:         "mismatch: existing token differs from request, MMDS hash mismatch",
			accessToken:  secureTokenPtr("existing-token"),
			requestToken: secureTokenPtr("wrong-token"),
			mmdsHash:     keys.HashAccessToken("different-token"),
			mmdsErr:      nil,
			wantErr:      ErrAccessTokenMismatch,
		},
		{
			name:         "conflict: existing token, nil request, MMDS exists",
			accessToken:  secureTokenPtr("existing-token"),
			requestToken: nil,
			mmdsHash:     keys.HashAccessToken("some-token"),
			mmdsErr:      nil,
			wantErr:      ErrAccessTokenResetNotAuthorized,
		},
		{
			name:         "conflict: existing token, nil request, no MMDS",
			accessToken:  secureTokenPtr("existing-token"),
			requestToken: nil,
			mmdsHash:     "",
			mmdsErr:      assert.AnError,
			wantErr:      ErrAccessTokenResetNotAuthorized,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			mmdsClient := &mockMMDSClient{hash: tt.mmdsHash, err: tt.mmdsErr}
			api := newTestAPI(tt.accessToken, mmdsClient)
			err := api.validateInitAccessToken(ctx, tt.requestToken)
			if tt.wantErr != nil {
				require.Error(t, err)
				assert.ErrorIs(t, err, tt.wantErr)
			} else {
				require.NoError(t, err)
			}
		})
	}
}
// TestCheckMMDSHash verifies the (matches, exists) results of
// checkMMDSHash across hash matches, mismatches, nil request tokens,
// MMDS errors, empty hashes, and the explicit hash("") reset marker.
func TestCheckMMDSHash(t *testing.T) {
	t.Parallel()
	ctx := t.Context()
	t.Run("returns match when token hash equals MMDS hash", func(t *testing.T) {
		t.Parallel()
		token := "my-secret-token"
		mmdsClient := &mockMMDSClient{hash: keys.HashAccessToken(token), err: nil}
		api := newTestAPI(nil, mmdsClient)
		matches, exists := api.checkMMDSHash(ctx, secureTokenPtr(token))
		assert.True(t, matches)
		assert.True(t, exists)
	})
	t.Run("returns no match when token hash differs from MMDS hash", func(t *testing.T) {
		t.Parallel()
		mmdsClient := &mockMMDSClient{hash: keys.HashAccessToken("different-token"), err: nil}
		api := newTestAPI(nil, mmdsClient)
		matches, exists := api.checkMMDSHash(ctx, secureTokenPtr("my-token"))
		assert.False(t, matches)
		assert.True(t, exists)
	})
	t.Run("returns exists but no match when request token is nil", func(t *testing.T) {
		t.Parallel()
		mmdsClient := &mockMMDSClient{hash: keys.HashAccessToken("some-token"), err: nil}
		api := newTestAPI(nil, mmdsClient)
		matches, exists := api.checkMMDSHash(ctx, nil)
		assert.False(t, matches)
		assert.True(t, exists)
	})
	t.Run("returns false, false when MMDS returns error", func(t *testing.T) {
		t.Parallel()
		mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
		api := newTestAPI(nil, mmdsClient)
		matches, exists := api.checkMMDSHash(ctx, secureTokenPtr("any-token"))
		assert.False(t, matches)
		assert.False(t, exists)
	})
	t.Run("returns false, false when MMDS returns empty hash with non-nil request", func(t *testing.T) {
		t.Parallel()
		mmdsClient := &mockMMDSClient{hash: "", err: nil}
		api := newTestAPI(nil, mmdsClient)
		matches, exists := api.checkMMDSHash(ctx, secureTokenPtr("any-token"))
		assert.False(t, matches)
		assert.False(t, exists)
	})
	t.Run("returns false, false when MMDS returns empty hash with nil request", func(t *testing.T) {
		t.Parallel()
		mmdsClient := &mockMMDSClient{hash: "", err: nil}
		api := newTestAPI(nil, mmdsClient)
		matches, exists := api.checkMMDSHash(ctx, nil)
		assert.False(t, matches)
		assert.False(t, exists)
	})
	t.Run("returns true, true when MMDS returns hash of empty string with nil request (explicit reset)", func(t *testing.T) {
		t.Parallel()
		mmdsClient := &mockMMDSClient{hash: keys.HashAccessToken(""), err: nil}
		api := newTestAPI(nil, mmdsClient)
		matches, exists := api.checkMMDSHash(ctx, nil)
		assert.True(t, matches)
		assert.True(t, exists)
	})
}
// TestSetData exercises SetData end-to-end against a mock MMDS client:
// the access-token update/reset decision table, env var application,
// default user/workdir handling (including empty-string no-ops), and
// applying several fields in one request.
func TestSetData(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	logger := zerolog.Nop()
	t.Run("access token updates", func(t *testing.T) {
		t.Parallel()
		tests := []struct {
			name           string
			existingToken  *SecureToken // token pre-installed in the API
			requestToken   *SecureToken // token carried by the request
			mmdsHash       string       // canned MMDS response
			mmdsErr        error        // canned MMDS error
			wantErr        error
			wantFinalToken *SecureToken // expected stored token afterwards (nil = unset)
		}{
			{
				name:           "first-time setup: sets initial token",
				existingToken:  nil,
				requestToken:   secureTokenPtr("initial-token"),
				mmdsHash:       "",
				mmdsErr:        assert.AnError,
				wantErr:        nil,
				wantFinalToken: secureTokenPtr("initial-token"),
			},
			{
				name:           "first-time setup: nil request token leaves token unset",
				existingToken:  nil,
				requestToken:   nil,
				mmdsHash:       "",
				mmdsErr:        assert.AnError,
				wantErr:        nil,
				wantFinalToken: nil,
			},
			{
				name:           "re-init with same token: token unchanged",
				existingToken:  secureTokenPtr("same-token"),
				requestToken:   secureTokenPtr("same-token"),
				mmdsHash:       "",
				mmdsErr:        assert.AnError,
				wantErr:        nil,
				wantFinalToken: secureTokenPtr("same-token"),
			},
			{
				name:           "resume with MMDS: updates token when hash matches",
				existingToken:  secureTokenPtr("old-token"),
				requestToken:   secureTokenPtr("new-token"),
				mmdsHash:       keys.HashAccessToken("new-token"),
				mmdsErr:        nil,
				wantErr:        nil,
				wantFinalToken: secureTokenPtr("new-token"),
			},
			{
				name:           "resume with MMDS: fails when hash doesn't match",
				existingToken:  secureTokenPtr("old-token"),
				requestToken:   secureTokenPtr("new-token"),
				mmdsHash:       keys.HashAccessToken("different-token"),
				mmdsErr:        nil,
				wantErr:        ErrAccessTokenMismatch,
				wantFinalToken: secureTokenPtr("old-token"),
			},
			{
				name:           "fails when existing token and request token mismatch without MMDS",
				existingToken:  secureTokenPtr("existing-token"),
				requestToken:   secureTokenPtr("wrong-token"),
				mmdsHash:       "",
				mmdsErr:        assert.AnError,
				wantErr:        ErrAccessTokenMismatch,
				wantFinalToken: secureTokenPtr("existing-token"),
			},
			{
				name:           "conflict when existing token but nil request token",
				existingToken:  secureTokenPtr("existing-token"),
				requestToken:   nil,
				mmdsHash:       "",
				mmdsErr:        assert.AnError,
				wantErr:        ErrAccessTokenResetNotAuthorized,
				wantFinalToken: secureTokenPtr("existing-token"),
			},
			{
				name:           "conflict when existing token but nil request with MMDS present",
				existingToken:  secureTokenPtr("existing-token"),
				requestToken:   nil,
				mmdsHash:       keys.HashAccessToken("some-token"),
				mmdsErr:        nil,
				wantErr:        ErrAccessTokenResetNotAuthorized,
				wantFinalToken: secureTokenPtr("existing-token"),
			},
			{
				name:           "conflict when MMDS returns empty hash and request is nil (prevents unauthorized reset)",
				existingToken:  secureTokenPtr("existing-token"),
				requestToken:   nil,
				mmdsHash:       "",
				mmdsErr:        nil,
				wantErr:        ErrAccessTokenResetNotAuthorized,
				wantFinalToken: secureTokenPtr("existing-token"),
			},
			{
				name:           "resets token when MMDS returns hash of empty string and request is nil (explicit reset)",
				existingToken:  secureTokenPtr("existing-token"),
				requestToken:   nil,
				mmdsHash:       keys.HashAccessToken(""),
				mmdsErr:        nil,
				wantErr:        nil,
				wantFinalToken: nil,
			},
		}
		for _, tt := range tests {
			t.Run(tt.name, func(t *testing.T) {
				t.Parallel()
				mmdsClient := &mockMMDSClient{hash: tt.mmdsHash, err: tt.mmdsErr}
				api := newTestAPI(tt.existingToken, mmdsClient)
				data := PostInitJSONBody{
					AccessToken: tt.requestToken,
				}
				err := api.SetData(ctx, logger, data)
				if tt.wantErr != nil {
					require.ErrorIs(t, err, tt.wantErr)
				} else {
					require.NoError(t, err)
				}
				if tt.wantFinalToken == nil {
					assert.False(t, api.accessToken.IsSet(), "expected token to not be set")
				} else {
					require.True(t, api.accessToken.IsSet(), "expected token to be set")
					assert.True(t, api.accessToken.EqualsSecure(tt.wantFinalToken), "expected token to match")
				}
			})
		}
	})
	t.Run("sets environment variables", func(t *testing.T) {
		t.Parallel()
		mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
		api := newTestAPI(nil, mmdsClient)
		envVars := EnvVars{"FOO": "bar", "BAZ": "qux"}
		data := PostInitJSONBody{
			EnvVars: &envVars,
		}
		err := api.SetData(ctx, logger, data)
		require.NoError(t, err)
		val, ok := api.defaults.EnvVars.Load("FOO")
		assert.True(t, ok)
		assert.Equal(t, "bar", val)
		val, ok = api.defaults.EnvVars.Load("BAZ")
		assert.True(t, ok)
		assert.Equal(t, "qux", val)
	})
	t.Run("sets default user", func(t *testing.T) {
		t.Parallel()
		mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
		api := newTestAPI(nil, mmdsClient)
		data := PostInitJSONBody{
			DefaultUser: utilsShared.ToPtr("testuser"),
		}
		err := api.SetData(ctx, logger, data)
		require.NoError(t, err)
		assert.Equal(t, "testuser", api.defaults.User)
	})
	t.Run("does not set default user when empty", func(t *testing.T) {
		t.Parallel()
		mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
		api := newTestAPI(nil, mmdsClient)
		api.defaults.User = "original"
		data := PostInitJSONBody{
			DefaultUser: utilsShared.ToPtr(""),
		}
		err := api.SetData(ctx, logger, data)
		require.NoError(t, err)
		assert.Equal(t, "original", api.defaults.User)
	})
	t.Run("sets default workdir", func(t *testing.T) {
		t.Parallel()
		mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
		api := newTestAPI(nil, mmdsClient)
		data := PostInitJSONBody{
			DefaultWorkdir: utilsShared.ToPtr("/home/user"),
		}
		err := api.SetData(ctx, logger, data)
		require.NoError(t, err)
		require.NotNil(t, api.defaults.Workdir)
		assert.Equal(t, "/home/user", *api.defaults.Workdir)
	})
	t.Run("does not set default workdir when empty", func(t *testing.T) {
		t.Parallel()
		mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
		api := newTestAPI(nil, mmdsClient)
		originalWorkdir := "/original"
		api.defaults.Workdir = &originalWorkdir
		data := PostInitJSONBody{
			DefaultWorkdir: utilsShared.ToPtr(""),
		}
		err := api.SetData(ctx, logger, data)
		require.NoError(t, err)
		require.NotNil(t, api.defaults.Workdir)
		assert.Equal(t, "/original", *api.defaults.Workdir)
	})
	t.Run("sets multiple fields at once", func(t *testing.T) {
		t.Parallel()
		mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
		api := newTestAPI(nil, mmdsClient)
		envVars := EnvVars{"KEY": "value"}
		data := PostInitJSONBody{
			AccessToken:    secureTokenPtr("token"),
			DefaultUser:    utilsShared.ToPtr("user"),
			DefaultWorkdir: utilsShared.ToPtr("/workdir"),
			EnvVars:        &envVars,
		}
		err := api.SetData(ctx, logger, data)
		require.NoError(t, err)
		assert.True(t, api.accessToken.Equals("token"), "expected token to match")
		assert.Equal(t, "user", api.defaults.User)
		assert.Equal(t, "/workdir", *api.defaults.Workdir)
		val, ok := api.defaults.EnvVars.Load("KEY")
		assert.True(t, ok)
		assert.Equal(t, "value", val)
	})
}

View File

@ -0,0 +1,212 @@
package api
import (
"bytes"
"errors"
"sync"
"github.com/awnumar/memguard"
)
// Errors returned by SecureToken operations.
var (
	ErrTokenNotSet = errors.New("access token not set")
	ErrTokenEmpty  = errors.New("empty token not allowed")
)
// SecureToken wraps memguard for secure token storage.
// It uses LockedBuffer which provides memory locking, guard pages,
// and secure zeroing on destroy.
//
// The zero value is ready to use and represents "no token set".
type SecureToken struct {
	mu     sync.RWMutex           // guards buffer
	buffer *memguard.LockedBuffer // nil when no token is stored
}
// Set securely replaces the token, destroying the old one first.
// The old token memory is zeroed before the new token is stored.
// The input byte slice is wiped after copying to secure memory.
// Returns ErrTokenEmpty if token is empty - use Destroy() to clear the token instead.
func (s *SecureToken) Set(token []byte) error {
	if len(token) == 0 {
		return ErrTokenEmpty
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	// Destroy old token first (zeros memory)
	if s.buffer != nil {
		s.buffer.Destroy()
		s.buffer = nil
	}
	// Create new LockedBuffer from bytes (source slice is wiped by memguard)
	s.buffer = memguard.NewBufferFromBytes(token)
	return nil
}
// UnmarshalJSON implements json.Unmarshaler to securely parse a JSON string
// directly into memguard, wiping the input bytes after copying.
//
// Access tokens are hex-encoded HMAC-SHA256 hashes (64 chars of [0-9a-f]),
// so they never contain JSON escape sequences.
//
// The input data is wiped on every path, success or failure, so the raw
// token never lingers in the decoder's buffer.
func (s *SecureToken) UnmarshalJSON(data []byte) error {
	// JSON strings are quoted, so minimum valid is `""` (2 bytes).
	if len(data) < 2 || data[0] != '"' || data[len(data)-1] != '"' {
		memguard.WipeBytes(data)
		return errors.New("invalid secure token JSON string")
	}
	content := data[1 : len(data)-1]
	// Access tokens are hex strings - reject if contains backslash
	if bytes.ContainsRune(content, '\\') {
		memguard.WipeBytes(data)
		return errors.New("invalid secure token: unexpected escape sequence")
	}
	if len(content) == 0 {
		memguard.WipeBytes(data)
		return ErrTokenEmpty
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	// Replace any previously stored token (zeros its memory first).
	if s.buffer != nil {
		s.buffer.Destroy()
		s.buffer = nil
	}
	// Allocate secure buffer and copy directly into it
	s.buffer = memguard.NewBuffer(len(content))
	copy(s.buffer.Bytes(), content)
	// Wipe the input data
	memguard.WipeBytes(data)
	return nil
}
// TakeFrom transfers the token from src to this SecureToken, destroying any
// existing token. The source token is cleared after transfer.
// This avoids copying the underlying bytes.
//
// The two locks are taken sequentially, never simultaneously, so no
// lock-ordering deadlock is possible.
func (s *SecureToken) TakeFrom(src *SecureToken) {
	if src == nil || s == src {
		return
	}
	// Extract buffer from source
	src.mu.Lock()
	buffer := src.buffer
	src.buffer = nil
	src.mu.Unlock()
	// Install buffer in destination
	s.mu.Lock()
	if s.buffer != nil {
		s.buffer.Destroy()
	}
	s.buffer = buffer
	s.mu.Unlock()
}
// Equals checks if token matches using constant-time comparison.
// Returns false if the receiver is nil.
//
// Also returns false when no token is stored (so an empty token never
// matches an unset SecureToken).
func (s *SecureToken) Equals(token string) bool {
	if s == nil {
		return false
	}
	s.mu.RLock()
	defer s.mu.RUnlock()
	if s.buffer == nil || !s.buffer.IsAlive() {
		return false
	}
	return s.buffer.EqualTo([]byte(token))
}
// EqualsSecure compares this token with another SecureToken using constant-time comparison.
// Returns false if either receiver or other is nil.
func (s *SecureToken) EqualsSecure(other *SecureToken) bool {
	if s == nil || other == nil {
		return false
	}
	// Same object: equal iff a token is actually stored.
	if s == other {
		return s.IsSet()
	}
	// Get a copy of other's bytes (avoids holding two locks simultaneously)
	otherBytes, err := other.Bytes()
	if err != nil {
		return false
	}
	// Wipe the temporary plaintext copy when done.
	defer memguard.WipeBytes(otherBytes)
	s.mu.RLock()
	defer s.mu.RUnlock()
	if s.buffer == nil || !s.buffer.IsAlive() {
		return false
	}
	return s.buffer.EqualTo(otherBytes)
}
// IsSet returns true if a token is stored.
// Returns false if the receiver is nil.
func (s *SecureToken) IsSet() bool {
	if s == nil {
		return false
	}
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.buffer != nil && s.buffer.IsAlive()
}
// Bytes returns a copy of the token bytes (for signature generation).
// The caller should zero the returned slice after use.
// Returns ErrTokenNotSet if the receiver is nil.
//
// Also returns ErrTokenNotSet when no token is currently stored.
func (s *SecureToken) Bytes() ([]byte, error) {
	if s == nil {
		return nil, ErrTokenNotSet
	}
	s.mu.RLock()
	defer s.mu.RUnlock()
	if s.buffer == nil || !s.buffer.IsAlive() {
		return nil, ErrTokenNotSet
	}
	// Return a copy (unavoidable for signature generation)
	src := s.buffer.Bytes()
	result := make([]byte, len(src))
	copy(result, src)
	return result, nil
}
// Destroy securely wipes and releases the stored token, if any.
// Safe to call repeatedly and on a nil receiver.
func (s *SecureToken) Destroy() {
	if s == nil {
		return
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	if b := s.buffer; b != nil {
		s.buffer = nil
		b.Destroy()
	}
}

View File

@ -0,0 +1,461 @@
package api
import (
"sync"
"testing"
"github.com/awnumar/memguard"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestSecureTokenSetAndEquals covers the basic lifecycle: a zero-value
// SecureToken is unset and matches nothing; after Set only the exact token
// string matches.
func TestSecureTokenSetAndEquals(t *testing.T) {
	t.Parallel()
	st := &SecureToken{}
	// Initially not set
	assert.False(t, st.IsSet(), "token should not be set initially")
	assert.False(t, st.Equals("any-token"), "equals should return false when not set")
	// Set token
	err := st.Set([]byte("test-token"))
	require.NoError(t, err)
	assert.True(t, st.IsSet(), "token should be set after Set()")
	assert.True(t, st.Equals("test-token"), "equals should return true for correct token")
	assert.False(t, st.Equals("wrong-token"), "equals should return false for wrong token")
	assert.False(t, st.Equals(""), "equals should return false for empty token")
}
// TestSecureTokenReplace verifies that Set on an already-populated token
// destroys the old value so that only the new token matches afterwards.
func TestSecureTokenReplace(t *testing.T) {
	t.Parallel()
	st := &SecureToken{}
	// Set initial token
	err := st.Set([]byte("first-token"))
	require.NoError(t, err)
	assert.True(t, st.Equals("first-token"))
	// Replace with new token (old one should be destroyed)
	err = st.Set([]byte("second-token"))
	require.NoError(t, err)
	assert.True(t, st.Equals("second-token"), "should match new token")
	assert.False(t, st.Equals("first-token"), "should not match old token")
}
// TestSecureTokenDestroy verifies that Destroy wipes the token, that calling
// it again is a safe no-op, and that every method tolerates a nil
// *SecureToken receiver without panicking.
func TestSecureTokenDestroy(t *testing.T) {
	t.Parallel()
	st := &SecureToken{}
	// Set and then destroy
	err := st.Set([]byte("test-token"))
	require.NoError(t, err)
	assert.True(t, st.IsSet())
	st.Destroy()
	assert.False(t, st.IsSet(), "token should not be set after Destroy()")
	assert.False(t, st.Equals("test-token"), "equals should return false after Destroy()")
	// Destroy on already destroyed should be safe
	st.Destroy()
	assert.False(t, st.IsSet())
	// Nil receiver should be safe
	var nilToken *SecureToken
	assert.False(t, nilToken.IsSet(), "nil receiver should return false for IsSet()")
	assert.False(t, nilToken.Equals("anything"), "nil receiver should return false for Equals()")
	assert.False(t, nilToken.EqualsSecure(st), "nil receiver should return false for EqualsSecure()")
	nilToken.Destroy() // should not panic
	_, err = nilToken.Bytes()
	require.ErrorIs(t, err, ErrTokenNotSet, "nil receiver should return ErrTokenNotSet for Bytes()")
}
// TestSecureTokenBytes verifies that Bytes errors when no token is stored,
// returns an independent copy when one is (wiping the copy must not affect
// the original), and errors again after Destroy.
func TestSecureTokenBytes(t *testing.T) {
	t.Parallel()
	st := &SecureToken{}
	// Bytes should return error when not set
	_, err := st.Bytes()
	require.ErrorIs(t, err, ErrTokenNotSet)
	// Set token and get bytes
	err = st.Set([]byte("test-token"))
	require.NoError(t, err)
	bytes, err := st.Bytes()
	require.NoError(t, err)
	assert.Equal(t, []byte("test-token"), bytes)
	// Zero out the bytes (as caller should do)
	memguard.WipeBytes(bytes)
	// Original should still be intact
	assert.True(t, st.Equals("test-token"), "original token should still work after zeroing copy")
	// After destroy, bytes should fail
	st.Destroy()
	_, err = st.Bytes()
	assert.ErrorIs(t, err, ErrTokenNotSet)
}
// TestSecureTokenConcurrentAccess hammers a SecureToken with parallel readers
// and writers to surface data races (meaningful under `go test -race`). The
// final state only needs to be *some* valid token, since writer ordering is
// unspecified.
func TestSecureTokenConcurrentAccess(t *testing.T) {
	t.Parallel()
	st := &SecureToken{}
	err := st.Set([]byte("initial-token"))
	require.NoError(t, err)
	var wg sync.WaitGroup
	const numGoroutines = 100
	// Concurrent reads
	for range numGoroutines {
		wg.Go(func() {
			st.IsSet()
			st.Equals("initial-token")
		})
	}
	// Concurrent writes. Consistency fix: use wg.Go like the reader loop
	// above instead of the manual Add/go/Done pattern; since Go 1.22 the
	// loop variable is per-iteration, so capturing i directly is safe.
	for i := range 10 {
		wg.Go(func() {
			st.Set([]byte("token-" + string(rune('a'+i))))
		})
	}
	wg.Wait()
	// Should still be in a valid state
	assert.True(t, st.IsSet())
}
// TestSecureTokenEmptyToken verifies that Set rejects both an empty slice and
// nil with ErrTokenEmpty, leaving the token unset.
func TestSecureTokenEmptyToken(t *testing.T) {
	t.Parallel()
	st := &SecureToken{}
	// Setting empty token should return an error
	err := st.Set([]byte{})
	require.ErrorIs(t, err, ErrTokenEmpty)
	assert.False(t, st.IsSet(), "token should not be set after empty token error")
	// Setting nil should also return an error
	err = st.Set(nil)
	require.ErrorIs(t, err, ErrTokenEmpty)
	assert.False(t, st.IsSet(), "token should not be set after nil token error")
}
// TestSecureTokenEmptyTokenDoesNotClearExisting verifies that a failed Set
// (empty input) leaves a previously stored token fully intact.
func TestSecureTokenEmptyTokenDoesNotClearExisting(t *testing.T) {
	t.Parallel()
	st := &SecureToken{}
	// Set a valid token first
	err := st.Set([]byte("valid-token"))
	require.NoError(t, err)
	assert.True(t, st.IsSet())
	// Attempting to set empty token should fail and preserve existing token
	err = st.Set([]byte{})
	require.ErrorIs(t, err, ErrTokenEmpty)
	assert.True(t, st.IsSet(), "existing token should be preserved after empty token error")
	assert.True(t, st.Equals("valid-token"), "existing token value should be unchanged")
}
// TestSecureTokenUnmarshalJSON covers JSON decoding of tokens: valid strings,
// empty/invalid input, replacement of an existing token, wiping of the input
// buffer on both success and failure, and rejection of escape sequences.
func TestSecureTokenUnmarshalJSON(t *testing.T) {
	t.Parallel()
	t.Run("unmarshals valid JSON string", func(t *testing.T) {
		t.Parallel()
		st := &SecureToken{}
		err := st.UnmarshalJSON([]byte(`"my-secret-token"`))
		require.NoError(t, err)
		assert.True(t, st.IsSet())
		assert.True(t, st.Equals("my-secret-token"))
	})
	t.Run("returns error for empty string", func(t *testing.T) {
		t.Parallel()
		st := &SecureToken{}
		err := st.UnmarshalJSON([]byte(`""`))
		require.ErrorIs(t, err, ErrTokenEmpty)
		assert.False(t, st.IsSet())
	})
	t.Run("returns error for invalid JSON", func(t *testing.T) {
		t.Parallel()
		st := &SecureToken{}
		err := st.UnmarshalJSON([]byte(`not-valid-json`))
		require.Error(t, err)
		assert.False(t, st.IsSet())
	})
	t.Run("replaces existing token", func(t *testing.T) {
		t.Parallel()
		st := &SecureToken{}
		err := st.Set([]byte("old-token"))
		require.NoError(t, err)
		err = st.UnmarshalJSON([]byte(`"new-token"`))
		require.NoError(t, err)
		assert.True(t, st.Equals("new-token"))
		assert.False(t, st.Equals("old-token"))
	})
	t.Run("wipes input buffer after parsing", func(t *testing.T) {
		t.Parallel()
		// Create a buffer with a known token
		// (dead-code fix: removed an unused `original` snapshot of the input)
		input := []byte(`"secret-token-12345"`)
		st := &SecureToken{}
		err := st.UnmarshalJSON(input)
		require.NoError(t, err)
		// Verify the token was stored correctly
		assert.True(t, st.Equals("secret-token-12345"))
		// Verify the input buffer was wiped (all zeros)
		for i, b := range input {
			assert.Equal(t, byte(0), b, "byte at position %d should be zero, got %d", i, b)
		}
	})
	t.Run("wipes input buffer on error", func(t *testing.T) {
		t.Parallel()
		// Create a buffer with an empty token (will error)
		input := []byte(`""`)
		st := &SecureToken{}
		err := st.UnmarshalJSON(input)
		require.Error(t, err)
		// Verify the input buffer was still wiped
		for i, b := range input {
			assert.Equal(t, byte(0), b, "byte at position %d should be zero, got %d", i, b)
		}
	})
	t.Run("rejects escape sequences", func(t *testing.T) {
		t.Parallel()
		st := &SecureToken{}
		err := st.UnmarshalJSON([]byte(`"token\nwith\nnewlines"`))
		require.Error(t, err)
		assert.Contains(t, err.Error(), "escape sequence")
		assert.False(t, st.IsSet())
	})
}
// TestSecureTokenSetWipesInput verifies that Set zeroes the caller's input
// slice after copying it into the secure buffer.
func TestSecureTokenSetWipesInput(t *testing.T) {
	t.Parallel()
	t.Run("wipes input buffer after storing", func(t *testing.T) {
		t.Parallel()
		// Create a buffer with a known token
		// (dead-code fix: removed an unused `original` snapshot of the input)
		input := []byte("my-secret-token")
		st := &SecureToken{}
		err := st.Set(input)
		require.NoError(t, err)
		// Verify the token was stored correctly
		assert.True(t, st.Equals("my-secret-token"))
		// Verify the input buffer was wiped (all zeros)
		for i, b := range input {
			assert.Equal(t, byte(0), b, "byte at position %d should be zero, got %d", i, b)
		}
	})
}
// TestSecureTokenTakeFrom verifies ownership transfer between tokens:
// the destination takes the source's value (replacing its own), the source is
// left empty, an empty source clears the destination, and nil/self transfers
// are safe no-ops.
func TestSecureTokenTakeFrom(t *testing.T) {
	t.Parallel()
	t.Run("transfers token from source to destination", func(t *testing.T) {
		t.Parallel()
		src := &SecureToken{}
		err := src.Set([]byte("source-token"))
		require.NoError(t, err)
		dst := &SecureToken{}
		dst.TakeFrom(src)
		assert.True(t, dst.IsSet())
		assert.True(t, dst.Equals("source-token"))
		assert.False(t, src.IsSet(), "source should be empty after transfer")
	})
	t.Run("replaces existing destination token", func(t *testing.T) {
		t.Parallel()
		src := &SecureToken{}
		err := src.Set([]byte("new-token"))
		require.NoError(t, err)
		dst := &SecureToken{}
		err = dst.Set([]byte("old-token"))
		require.NoError(t, err)
		dst.TakeFrom(src)
		assert.True(t, dst.Equals("new-token"))
		assert.False(t, dst.Equals("old-token"))
		assert.False(t, src.IsSet())
	})
	t.Run("handles nil source", func(t *testing.T) {
		t.Parallel()
		dst := &SecureToken{}
		err := dst.Set([]byte("existing-token"))
		require.NoError(t, err)
		dst.TakeFrom(nil)
		assert.True(t, dst.IsSet(), "destination should be unchanged with nil source")
		assert.True(t, dst.Equals("existing-token"))
	})
	t.Run("handles empty source", func(t *testing.T) {
		t.Parallel()
		src := &SecureToken{}
		dst := &SecureToken{}
		err := dst.Set([]byte("existing-token"))
		require.NoError(t, err)
		// Taking from an unset source moves its nil buffer in, clearing dst.
		dst.TakeFrom(src)
		assert.False(t, dst.IsSet(), "destination should be cleared when source is empty")
	})
	t.Run("self-transfer is no-op and does not deadlock", func(t *testing.T) {
		t.Parallel()
		st := &SecureToken{}
		err := st.Set([]byte("token"))
		require.NoError(t, err)
		st.TakeFrom(st)
		assert.True(t, st.IsSet(), "token should remain set after self-transfer")
		assert.True(t, st.Equals("token"), "token value should be unchanged")
	})
}
// TestSecureTokenEqualsSecure verifies constant-time token-to-token
// comparison: matching/mismatched values, nil and unset operands,
// self-comparison, and absence of lock-ordering deadlocks when comparison
// races against TakeFrom.
func TestSecureTokenEqualsSecure(t *testing.T) {
	t.Parallel()
	t.Run("returns true for matching tokens", func(t *testing.T) {
		t.Parallel()
		st1 := &SecureToken{}
		err := st1.Set([]byte("same-token"))
		require.NoError(t, err)
		st2 := &SecureToken{}
		err = st2.Set([]byte("same-token"))
		require.NoError(t, err)
		assert.True(t, st1.EqualsSecure(st2))
		assert.True(t, st2.EqualsSecure(st1))
	})
	t.Run("concurrent TakeFrom and EqualsSecure do not deadlock", func(t *testing.T) {
		t.Parallel()
		// This test verifies the fix for the lock ordering deadlock bug.
		// Repeated iterations raise the chance of hitting the bad interleaving.
		const iterations = 100
		for range iterations {
			a := &SecureToken{}
			err := a.Set([]byte("token-a"))
			require.NoError(t, err)
			b := &SecureToken{}
			err = b.Set([]byte("token-b"))
			require.NoError(t, err)
			var wg sync.WaitGroup
			wg.Add(2)
			// Goroutine 1: a.TakeFrom(b)
			go func() {
				defer wg.Done()
				a.TakeFrom(b)
			}()
			// Goroutine 2: b.EqualsSecure(a)
			go func() {
				defer wg.Done()
				b.EqualsSecure(a)
			}()
			wg.Wait()
		}
	})
	t.Run("returns false for different tokens", func(t *testing.T) {
		t.Parallel()
		st1 := &SecureToken{}
		err := st1.Set([]byte("token-a"))
		require.NoError(t, err)
		st2 := &SecureToken{}
		err = st2.Set([]byte("token-b"))
		require.NoError(t, err)
		assert.False(t, st1.EqualsSecure(st2))
	})
	t.Run("returns false when comparing with nil", func(t *testing.T) {
		t.Parallel()
		st := &SecureToken{}
		err := st.Set([]byte("token"))
		require.NoError(t, err)
		assert.False(t, st.EqualsSecure(nil))
	})
	t.Run("returns false when other is not set", func(t *testing.T) {
		t.Parallel()
		st1 := &SecureToken{}
		err := st1.Set([]byte("token"))
		require.NoError(t, err)
		st2 := &SecureToken{}
		assert.False(t, st1.EqualsSecure(st2))
	})
	t.Run("returns false when self is not set", func(t *testing.T) {
		t.Parallel()
		st1 := &SecureToken{}
		st2 := &SecureToken{}
		err := st2.Set([]byte("token"))
		require.NoError(t, err)
		assert.False(t, st1.EqualsSecure(st2))
	})
	t.Run("self-comparison returns true when set", func(t *testing.T) {
		t.Parallel()
		st := &SecureToken{}
		err := st.Set([]byte("token"))
		require.NoError(t, err)
		assert.True(t, st.EqualsSecure(st), "self-comparison should return true and not deadlock")
	})
	t.Run("self-comparison returns false when not set", func(t *testing.T) {
		t.Parallel()
		st := &SecureToken{}
		assert.False(t, st.EqualsSecure(st), "self-comparison on unset token should return false")
	})
}

View File

@ -0,0 +1,93 @@
package api
import (
"context"
"encoding/json"
"net/http"
"sync"
"github.com/rs/zerolog"
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
"git.omukk.dev/wrenn/sandbox/envd/internal/host"
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
)
// MMDSClient provides access to MMDS metadata.
// Defined as an interface so tests can substitute a fake for the real
// Firecracker MMDS endpoint.
type MMDSClient interface {
	// GetAccessTokenHash returns the access token hash stored in MMDS.
	GetAccessTokenHash(ctx context.Context) (string, error)
}
// DefaultMMDSClient is the production implementation that calls the real MMDS endpoint.
type DefaultMMDSClient struct{}

// GetAccessTokenHash fetches the access token hash from the Firecracker MMDS
// endpoint via the host package.
func (c *DefaultMMDSClient) GetAccessTokenHash(ctx context.Context) (string, error) {
	return host.GetAccessTokenHashFromMMDS(ctx)
}
// API holds the shared state behind envd's HTTP handlers.
type API struct {
	isNotFC     bool                  // set from the -isnotfc flag; presumably true when running outside a Firecracker VM — TODO confirm
	logger      *zerolog.Logger       // structured logger for all handlers
	accessToken *SecureToken          // securely stored access token; starts unset (see New)
	defaults    *execcontext.Defaults // fallback user/workdir/env applied to requests
	mmdsChan    chan *host.MMDSOpts   // channel carrying sandbox metadata updates
	hyperloopLock sync.Mutex
	mmdsClient  MMDSClient            // pluggable MMDS access (production: DefaultMMDSClient; swappable in tests)
	lastSetTime *utils.AtomicMax
	initLock    sync.Mutex
}
// New constructs an API wired to the production MMDS client, with an empty
// (unset) access token and a fresh last-set-time tracker.
func New(l *zerolog.Logger, defaults *execcontext.Defaults, mmdsChan chan *host.MMDSOpts, isNotFC bool) *API {
	a := &API{
		isNotFC:     isNotFC,
		logger:      l,
		accessToken: &SecureToken{},
		defaults:    defaults,
		mmdsChan:    mmdsChan,
		mmdsClient:  &DefaultMMDSClient{},
		lastSetTime: utils.NewAtomicMax(),
	}
	return a
}
// GetHealth answers liveness probes with 204 No Content and an empty body.
func (a *API) GetHealth(w http.ResponseWriter, r *http.Request) {
	defer r.Body.Close()

	a.logger.Trace().Msg("Health check")

	h := w.Header()
	h.Set("Cache-Control", "no-store")
	h.Set("Content-Type", "")
	w.WriteHeader(http.StatusNoContent)
}
// GetMetrics reports current host resource usage as a JSON document.
// Responds 500 with no body if the metrics cannot be collected.
func (a *API) GetMetrics(w http.ResponseWriter, r *http.Request) {
	defer r.Body.Close()

	a.logger.Trace().Msg("Get metrics")

	hdr := w.Header()
	hdr.Set("Cache-Control", "no-store")
	hdr.Set("Content-Type", "application/json")

	metrics, metricsErr := host.GetMetrics()
	if metricsErr != nil {
		a.logger.Error().Err(metricsErr).Msg("Failed to get metrics")
		w.WriteHeader(http.StatusInternalServerError)
		return
	}

	w.WriteHeader(http.StatusOK)
	if encodeErr := json.NewEncoder(w).Encode(metrics); encodeErr != nil {
		// Headers are already sent; the best we can do is log the failure.
		a.logger.Error().Err(encodeErr).Msg("Failed to encode metrics")
	}
}
// getLogger picks the log event level for a finished operation: Info on
// success, Error (with the error attached) on failure.
func (a *API) getLogger(err error) *zerolog.Event {
	if err == nil {
		return a.logger.Info() //nolint:zerologlint // this is only prep
	}
	return a.logger.Error().Err(err) //nolint:zerologlint // this is only prep
}

309
envd/internal/api/upload.go Normal file
View File

@ -0,0 +1,309 @@
package api
import (
"encoding/json"
"errors"
"fmt"
"io"
"mime/multipart"
"net/http"
"os"
"os/user"
"path/filepath"
"strings"
"syscall"
"github.com/rs/zerolog"
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
)
// ErrNoDiskSpace is returned when writing an uploaded file fails with ENOSPC.
var ErrNoDiskSpace = fmt.Errorf("not enough disk space available")
// processFile writes the contents of part to path as a regular file owned by
// uid/gid, creating missing parent directories first. It returns the HTTP
// status to report along with a non-nil error on failure; on success the
// status is http.StatusNoContent.
//
// Ownership handling is deliberately two-phase: an already-existing file is
// chowned *before* it is opened with O_TRUNC, so a pre-existing file is never
// left truncated but wrongly owned; a newly created file is chowned right
// after creation instead.
func processFile(r *http.Request, path string, part io.Reader, uid, gid int, logger zerolog.Logger) (int, error) {
	logger.Debug().
		Str("path", path).
		Msg("File processing")
	err := permissions.EnsureDirs(filepath.Dir(path), uid, gid)
	if err != nil {
		err := fmt.Errorf("error ensuring directories: %w", err)
		return http.StatusInternalServerError, err
	}
	// Determine whether the target already exists (and is not a directory);
	// only then can ownership be fixed up before truncation.
	canBePreChowned := false
	stat, err := os.Stat(path)
	if err != nil && !os.IsNotExist(err) {
		errMsg := fmt.Errorf("error getting file info: %w", err)
		return http.StatusInternalServerError, errMsg
	} else if err == nil {
		if stat.IsDir() {
			err := fmt.Errorf("path is a directory: %s", path)
			return http.StatusBadRequest, err
		}
		canBePreChowned = true
	}
	hasBeenChowned := false
	if canBePreChowned {
		err = os.Chown(path, uid, gid)
		if err != nil {
			if !os.IsNotExist(err) {
				err = fmt.Errorf("error changing file ownership: %w", err)
				return http.StatusInternalServerError, err
			}
			// File vanished between Stat and Chown — fall through and chown
			// after creation instead.
		} else {
			hasBeenChowned = true
		}
	}
	file, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o666)
	if err != nil {
		// ENOSPC at open/create time indicates inode exhaustion.
		if errors.Is(err, syscall.ENOSPC) {
			err = fmt.Errorf("not enough inodes available: %w", err)
			return http.StatusInsufficientStorage, err
		}
		err := fmt.Errorf("error opening file: %w", err)
		return http.StatusInternalServerError, err
	}
	defer file.Close()
	if !hasBeenChowned {
		err = os.Chown(path, uid, gid)
		if err != nil {
			err := fmt.Errorf("error changing file ownership: %w", err)
			return http.StatusInternalServerError, err
		}
	}
	_, err = file.ReadFrom(part)
	if err != nil {
		// ENOSPC during the copy means we ran out of disk blocks.
		if errors.Is(err, syscall.ENOSPC) {
			err = ErrNoDiskSpace
			if r.ContentLength > 0 {
				err = fmt.Errorf("attempted to write %d bytes: %w", r.ContentLength, err)
			}
			return http.StatusInsufficientStorage, err
		}
		err = fmt.Errorf("error writing file: %w", err)
		return http.StatusInternalServerError, err
	}
	return http.StatusNoContent, nil
}
// resolvePath determines the destination path for an uploaded part — from the
// explicit path parameter when given, otherwise from the part's filename —
// and expands/resolves it for the user against the default workdir. A path
// already written earlier in the same upload request is rejected with an
// error listing the files that were successfully uploaded so far.
func resolvePath(part *multipart.Part, paths *UploadSuccess, u *user.User, defaultPath *string, params PostFilesParams) (string, error) {
	var pathToResolve string
	if params.Path != nil {
		pathToResolve = *params.Path
	} else {
		var err error
		customPart := utils.NewCustomPart(part)
		pathToResolve, err = customPart.FileNameWithPath()
		if err != nil {
			return "", fmt.Errorf("error getting multipart custom part file name: %w", err)
		}
	}
	filePath, err := permissions.ExpandAndResolve(pathToResolve, u, defaultPath)
	if err != nil {
		return "", fmt.Errorf("error resolving path: %w", err)
	}
	for _, entry := range *paths {
		if entry.Path == filePath {
			// Collect the other paths written so far so the client can tell
			// which files did make it to disk.
			var alreadyUploaded []string
			for _, uploadedFile := range *paths {
				if uploadedFile.Path != filePath {
					alreadyUploaded = append(alreadyUploaded, uploadedFile.Path)
				}
			}
			errMsg := fmt.Errorf("you cannot upload multiple files to the same path '%s' in one upload request, only the first specified file was uploaded", filePath)
			// Bug fix: report the other uploaded files whenever there is at
			// least one (previously `> 1`, which silently dropped the list
			// when exactly one other file had been uploaded).
			if len(alreadyUploaded) > 0 {
				errMsg = fmt.Errorf("%w, also the following files were uploaded: %v", errMsg, strings.Join(alreadyUploaded, ", "))
			}
			return "", errMsg
		}
	}
	return filePath, nil
}
// handlePart processes one multipart section of an upload request. Sections
// whose form name is not "file" are skipped (nil entry, 200). On success it
// returns the EntryInfo describing the written file; on failure it returns a
// nil entry plus the HTTP status to report and the error.
func (a *API) handlePart(r *http.Request, part *multipart.Part, paths UploadSuccess, u *user.User, uid, gid int, operationID string, params PostFilesParams) (*EntryInfo, int, error) {
	defer part.Close()
	// Only form fields named "file" carry uploads; anything else is ignored.
	if part.FormName() != "file" {
		return nil, http.StatusOK, nil
	}
	filePath, err := resolvePath(part, &paths, u, a.defaults.Workdir, params)
	if err != nil {
		return nil, http.StatusBadRequest, err
	}
	// Scoped logger so file-processing logs carry the operation ID.
	logger := a.logger.
		With().
		Str(string(logs.OperationIDKey), operationID).
		Str("event_type", "file_processing").
		Logger()
	status, err := processFile(r, filePath, part, uid, gid, logger)
	if err != nil {
		return nil, status, err
	}
	return &EntryInfo{
		Path: filePath,
		Name: filepath.Base(filePath),
		Type: File,
	}, http.StatusOK, nil
}
// PostFiles handles multipart file uploads. After signature and user
// validation, each "file" part is written to its resolved destination as the
// requested (or default) user. On success it responds 200 with a JSON array
// of the written entries; any failure aborts the upload with a JSON error.
func (a *API) PostFiles(w http.ResponseWriter, r *http.Request, params PostFilesParams) {
	// Capture original body to ensure it's always closed
	originalBody := r.Body
	defer originalBody.Close()
	var errorCode int
	var errMsg error
	var path string
	if params.Path != nil {
		path = *params.Path
	}
	operationID := logs.AssignOperationID()
	// signing authorization if needed
	err := a.validateSigning(r, params.Signature, params.SignatureExpiration, params.Username, path, SigningWriteOperation)
	if err != nil {
		a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("error during auth validation")
		jsonError(w, http.StatusUnauthorized, err)
		return
	}
	username, err := execcontext.ResolveDefaultUsername(params.Username, a.defaults.User)
	if err != nil {
		a.logger.Error().Err(err).Str(string(logs.OperationIDKey), operationID).Msg("no user specified")
		jsonError(w, http.StatusBadRequest, err)
		return
	}
	// Log the outcome once, whichever path the handler exits through.
	defer func() {
		l := a.logger.
			Err(errMsg).
			Str("method", r.Method+" "+r.URL.Path).
			Str(string(logs.OperationIDKey), operationID).
			Str("path", path).
			Str("username", username)
		if errMsg != nil {
			l = l.Int("error_code", errorCode)
		}
		l.Msg("File write")
	}()
	// Handle gzip-encoded request body
	body, err := getDecompressedBody(r)
	if err != nil {
		errMsg = fmt.Errorf("error decompressing request body: %w", err)
		errorCode = http.StatusBadRequest
		jsonError(w, errorCode, errMsg)
		return
	}
	defer body.Close()
	r.Body = body
	f, err := r.MultipartReader()
	if err != nil {
		errMsg = fmt.Errorf("error parsing multipart form: %w", err)
		errorCode = http.StatusInternalServerError
		jsonError(w, errorCode, errMsg)
		return
	}
	u, err := user.Lookup(username)
	if err != nil {
		errMsg = fmt.Errorf("error looking up user '%s': %w", username, err)
		errorCode = http.StatusUnauthorized
		jsonError(w, errorCode, errMsg)
		return
	}
	uid, gid, err := permissions.GetUserIdInts(u)
	if err != nil {
		errMsg = fmt.Errorf("error getting user ids: %w", err)
		jsonError(w, http.StatusInternalServerError, errMsg)
		return
	}
	paths := UploadSuccess{}
	for {
		part, partErr := f.NextPart()
		if partErr == io.EOF {
			// We're done reading the parts.
			break
		} else if partErr != nil {
			errMsg = fmt.Errorf("error reading form: %w", partErr)
			errorCode = http.StatusInternalServerError
			jsonError(w, errorCode, errMsg)
			// Bug fix: return instead of break — jsonError has already
			// written the error response; falling through would emit a
			// superfluous second WriteHeader(200) and append the success
			// payload to the error body.
			return
		}
		entry, status, err := a.handlePart(r, part, paths, u, uid, gid, operationID, params)
		if err != nil {
			errorCode = status
			errMsg = err
			jsonError(w, errorCode, errMsg)
			return
		}
		if entry != nil {
			paths = append(paths, *entry)
		}
	}
	data, err := json.Marshal(paths)
	if err != nil {
		errMsg = fmt.Errorf("error marshaling response: %w", err)
		errorCode = http.StatusInternalServerError
		jsonError(w, errorCode, errMsg)
		return
	}
	w.WriteHeader(http.StatusOK)
	_, _ = w.Write(data)
}

View File

@ -0,0 +1,249 @@
package api
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"os"
"os/exec"
"path/filepath"
"testing"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestProcessFile exercises processFile end to end: directory-creation
// failures, directory targets, unwritable paths, disk/inode exhaustion on
// tiny tmpfs mounts (root only), virtual filesystems, and plain create /
// overwrite happy paths.
func TestProcessFile(t *testing.T) {
	t.Parallel()
	uid := os.Getuid()
	gid := os.Getgid()
	newRequest := func(content []byte) (*http.Request, io.Reader) {
		request := &http.Request{
			ContentLength: int64(len(content)),
		}
		buffer := bytes.NewBuffer(content)
		return request, buffer
	}
	var emptyReq http.Request
	var emptyPart *bytes.Buffer
	var emptyLogger zerolog.Logger
	t.Run("failed to ensure directories", func(t *testing.T) {
		t.Parallel()
		httpStatus, err := processFile(&emptyReq, "/proc/invalid/not-real", emptyPart, uid, gid, emptyLogger)
		require.Error(t, err)
		assert.Equal(t, http.StatusInternalServerError, httpStatus)
		assert.ErrorContains(t, err, "error ensuring directories: ")
	})
	t.Run("attempt to replace directory with a file", func(t *testing.T) {
		t.Parallel()
		tempDir := t.TempDir()
		httpStatus, err := processFile(&emptyReq, tempDir, emptyPart, uid, gid, emptyLogger)
		require.Error(t, err)
		assert.Equal(t, http.StatusBadRequest, httpStatus, err.Error())
		assert.ErrorContains(t, err, "path is a directory: ")
	})
	t.Run("fail to create file", func(t *testing.T) {
		t.Parallel()
		httpStatus, err := processFile(&emptyReq, "/proc/invalid-filename", emptyPart, uid, gid, emptyLogger)
		require.Error(t, err)
		assert.Equal(t, http.StatusInternalServerError, httpStatus)
		assert.ErrorContains(t, err, "error opening file: ")
	})
	t.Run("out of disk space", func(t *testing.T) {
		t.Parallel()
		// make a tiny tmpfs mount
		mountSize := 1024
		tempDir := createTmpfsMount(t, mountSize)
		// create test file
		firstFileSize := mountSize / 2
		tempFile1 := filepath.Join(tempDir, "test-file-1")
		// fill it up
		cmd := exec.CommandContext(t.Context(),
			"dd", "if=/dev/zero", "of="+tempFile1, fmt.Sprintf("bs=%d", firstFileSize), "count=1")
		err := cmd.Run()
		require.NoError(t, err)
		// create contents for a new file that cannot fit in the mount
		secondFileContents := make([]byte, mountSize*2)
		for index := range secondFileContents {
			secondFileContents[index] = 'a'
		}
		// try to write it
		request, buffer := newRequest(secondFileContents)
		tempFile2 := filepath.Join(tempDir, "test-file-2")
		httpStatus, err := processFile(request, tempFile2, buffer, uid, gid, emptyLogger)
		require.Error(t, err)
		assert.Equal(t, http.StatusInsufficientStorage, httpStatus)
		assert.ErrorContains(t, err, "attempted to write 2048 bytes: not enough disk space")
	})
	t.Run("happy path", func(t *testing.T) {
		t.Parallel()
		tempDir := t.TempDir()
		tempFile := filepath.Join(tempDir, "test-file")
		content := []byte("test-file-contents")
		request, buffer := newRequest(content)
		httpStatus, err := processFile(request, tempFile, buffer, uid, gid, emptyLogger)
		require.NoError(t, err)
		assert.Equal(t, http.StatusNoContent, httpStatus)
		data, err := os.ReadFile(tempFile)
		require.NoError(t, err)
		assert.Equal(t, content, data)
	})
	t.Run("overwrite file on full disk", func(t *testing.T) {
		t.Parallel()
		// make a tiny tmpfs mount
		// (consistency fix: pass sizeInBytes instead of repeating the literal)
		sizeInBytes := 1024
		tempDir := createTmpfsMount(t, sizeInBytes)
		// create test file
		tempFile := filepath.Join(tempDir, "test-file")
		// fill it up
		cmd := exec.CommandContext(t.Context(), "dd", "if=/dev/zero", "of="+tempFile, fmt.Sprintf("bs=%d", sizeInBytes), "count=1")
		err := cmd.Run()
		require.NoError(t, err)
		// try to replace it
		content := []byte("test-file-contents")
		request, buffer := newRequest(content)
		httpStatus, err := processFile(request, tempFile, buffer, uid, gid, emptyLogger)
		require.NoError(t, err)
		assert.Equal(t, http.StatusNoContent, httpStatus)
	})
	t.Run("write new file on full disk", func(t *testing.T) {
		t.Parallel()
		// make a tiny tmpfs mount
		// (consistency fix: pass sizeInBytes instead of repeating the literal)
		sizeInBytes := 1024
		tempDir := createTmpfsMount(t, sizeInBytes)
		// create test file
		tempFile1 := filepath.Join(tempDir, "test-file")
		// fill it up
		cmd := exec.CommandContext(t.Context(), "dd", "if=/dev/zero", "of="+tempFile1, fmt.Sprintf("bs=%d", sizeInBytes), "count=1")
		err := cmd.Run()
		require.NoError(t, err)
		// try to write a new file
		tempFile2 := filepath.Join(tempDir, "test-file-2")
		content := []byte("test-file-contents")
		request, buffer := newRequest(content)
		httpStatus, err := processFile(request, tempFile2, buffer, uid, gid, emptyLogger)
		require.ErrorContains(t, err, "not enough disk space available")
		assert.Equal(t, http.StatusInsufficientStorage, httpStatus)
	})
	t.Run("write new file with no inodes available", func(t *testing.T) {
		t.Parallel()
		// make a tiny tmpfs mount
		tempDir := createTmpfsMountWithInodes(t, 1024, 2)
		// create test file
		tempFile1 := filepath.Join(tempDir, "test-file")
		// fill it up
		cmd := exec.CommandContext(t.Context(), "dd", "if=/dev/zero", "of="+tempFile1, fmt.Sprintf("bs=%d", 100), "count=1")
		err := cmd.Run()
		require.NoError(t, err)
		// try to write a new file
		tempFile2 := filepath.Join(tempDir, "test-file-2")
		content := []byte("test-file-contents")
		request, buffer := newRequest(content)
		httpStatus, err := processFile(request, tempFile2, buffer, uid, gid, emptyLogger)
		require.ErrorContains(t, err, "not enough inodes available")
		assert.Equal(t, http.StatusInsufficientStorage, httpStatus)
	})
	t.Run("update sysfs or other virtual fs", func(t *testing.T) {
		t.Parallel()
		if os.Geteuid() != 0 {
			t.Skip("skipping sysfs updates: Operation not permitted with non-root user")
		}
		filePath := "/sys/fs/cgroup/user.slice/cpu.weight"
		newContent := []byte("102\n")
		request, buffer := newRequest(newContent)
		httpStatus, err := processFile(request, filePath, buffer, uid, gid, emptyLogger)
		require.NoError(t, err)
		assert.Equal(t, http.StatusNoContent, httpStatus)
		data, err := os.ReadFile(filePath)
		require.NoError(t, err)
		assert.Equal(t, newContent, data)
	})
	t.Run("replace file", func(t *testing.T) {
		t.Parallel()
		tempDir := t.TempDir()
		tempFile := filepath.Join(tempDir, "test-file")
		err := os.WriteFile(tempFile, []byte("old-contents"), 0o644)
		require.NoError(t, err)
		newContent := []byte("new-file-contents")
		request, buffer := newRequest(newContent)
		httpStatus, err := processFile(request, tempFile, buffer, uid, gid, emptyLogger)
		require.NoError(t, err)
		assert.Equal(t, http.StatusNoContent, httpStatus)
		data, err := os.ReadFile(tempFile)
		require.NoError(t, err)
		assert.Equal(t, newContent, data)
	})
}
// createTmpfsMount mounts a tiny tmpfs with a default inode budget and
// returns its directory. See createTmpfsMountWithInodes for details.
func createTmpfsMount(t *testing.T, sizeInBytes int) string {
	t.Helper()
	const defaultInodeCount = 5
	return createTmpfsMountWithInodes(t, sizeInBytes, defaultInodeCount)
}
// createTmpfsMountWithInodes mounts a tmpfs of the given size and inode count
// under a fresh temp dir, registers an unmount cleanup, and returns the mount
// point. Skips the test when not running as root (mount requires it).
func createTmpfsMountWithInodes(t *testing.T, sizeInBytes, inodesCount int) string {
	t.Helper()
	if os.Geteuid() != 0 {
		t.Skip("skipping sysfs updates: Operation not permitted with non-root user")
	}
	mountPoint := t.TempDir()
	mountOpts := fmt.Sprintf("size=%d,nr_inodes=%d", sizeInBytes, inodesCount)
	mountCmd := exec.CommandContext(t.Context(),
		"mount",
		"tmpfs",
		mountPoint,
		"-t", "tmpfs",
		"-o", mountOpts)
	require.NoError(t, mountCmd.Run())
	t.Cleanup(func() {
		// The test's context is canceled by cleanup time, so detach from it
		// before running umount.
		ctx := context.WithoutCancel(t.Context())
		require.NoError(t, exec.CommandContext(ctx, "umount", mountPoint).Run())
	})
	return mountPoint
}

View File

@ -0,0 +1,37 @@
package execcontext
import (
"errors"
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
)
// Defaults carries the fallback execution context applied to requests that do
// not specify their own user, working directory, or environment.
type Defaults struct {
	EnvVars *utils.Map[string, string] // default environment variables
	User    string                     // default username; empty means none configured
	Workdir *string                    // default working directory; nil means none configured
}
// ResolveDefaultWorkdir returns workdir when it is non-empty; otherwise the
// configured default when one exists; otherwise the empty string.
func ResolveDefaultWorkdir(workdir string, defaultWorkdir *string) string {
	switch {
	case workdir != "":
		return workdir
	case defaultWorkdir != nil:
		return *defaultWorkdir
	default:
		return ""
	}
}
// ResolveDefaultUsername returns *username when the caller provided one
// (even an empty string), otherwise the non-empty configured default,
// otherwise an error.
func ResolveDefaultUsername(username *string, defaultUsername string) (string, error) {
	if username != nil {
		return *username, nil
	}
	if defaultUsername == "" {
		return "", errors.New("username not provided")
	}
	return defaultUsername, nil
}

View File

@ -0,0 +1,93 @@
package host
import (
"math"
"time"
"github.com/shirou/gopsutil/v4/cpu"
"github.com/shirou/gopsutil/v4/mem"
"golang.org/x/sys/unix"
)
// Metrics is a point-in-time snapshot of host resource usage, serialized as
// JSON by the /metrics endpoint.
type Metrics struct {
	Timestamp      int64   `json:"ts"`           // Unix Timestamp in UTC
	CPUCount       uint32  `json:"cpu_count"`    // Total CPU cores
	CPUUsedPercent float32 `json:"cpu_used_pct"` // Percent rounded to 2 decimal places
	// Deprecated: kept for backwards compatibility with older orchestrators.
	MemTotalMiB uint64 `json:"mem_total_mib"` // Total virtual memory in MiB
	// Deprecated: kept for backwards compatibility with older orchestrators.
	MemUsedMiB uint64 `json:"mem_used_mib"` // Used virtual memory in MiB
	MemTotal   uint64 `json:"mem_total"`    // Total virtual memory in bytes
	MemUsed    uint64 `json:"mem_used"`     // Used virtual memory in bytes
	DiskUsed   uint64 `json:"disk_used"`    // Used disk space in bytes
	DiskTotal  uint64 `json:"disk_total"`   // Total disk space in bytes
}
// GetMetrics samples current host resource usage: CPU count and aggregate
// utilization, virtual memory, and disk usage of the root filesystem.
func GetMetrics() (*Metrics, error) {
	v, err := mem.VirtualMemory()
	if err != nil {
		return nil, err
	}
	memUsedMiB := v.Used / 1024 / 1024
	memTotalMiB := v.Total / 1024 / 1024
	cpuTotal, err := cpu.Counts(true)
	if err != nil {
		return nil, err
	}
	cpuUsedPcts, err := cpu.Percent(0, false)
	if err != nil {
		return nil, err
	}
	// Robustness fix: guard against an empty slice before indexing, instead
	// of panicking when no aggregate CPU sample is reported.
	var cpuUsedPct float64
	if len(cpuUsedPcts) > 0 {
		cpuUsedPct = cpuUsedPcts[0]
	}
	cpuUsedPctRounded := float32(cpuUsedPct)
	if cpuUsedPct > 0 {
		// Round to 2 decimal places.
		cpuUsedPctRounded = float32(math.Round(cpuUsedPct*100) / 100)
	}
	diskMetrics, err := diskStats("/")
	if err != nil {
		return nil, err
	}
	return &Metrics{
		Timestamp:      time.Now().UTC().Unix(),
		CPUCount:       uint32(cpuTotal),
		CPUUsedPercent: cpuUsedPctRounded,
		MemUsedMiB:     memUsedMiB,
		MemTotalMiB:    memTotalMiB,
		MemTotal:       v.Total,
		MemUsed:        v.Used,
		DiskUsed:       diskMetrics.Total - diskMetrics.Available,
		DiskTotal:      diskMetrics.Total,
	}, nil
}
// diskSpace holds total and available byte counts for one filesystem.
type diskSpace struct {
	Total     uint64
	Available uint64
}

// diskStats returns the total and available space (in bytes) of the
// filesystem containing path, via statfs.
func diskStats(path string) (diskSpace, error) {
	var stats unix.Statfs_t
	if err := unix.Statfs(path, &stats); err != nil {
		return diskSpace{}, err
	}
	blockSize := uint64(stats.Bsize)
	return diskSpace{
		Total:     stats.Blocks * blockSize, // all data blocks
		Available: stats.Bavail * blockSize, // blocks available
	}, nil
}

182
envd/internal/host/mmds.go Normal file
View File

@ -0,0 +1,182 @@
package host
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"time"
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
)
const (
	WrennRunDir = "/run/wrenn" // store sandbox metadata files here
	// mmdsDefaultAddress is the MMDS link-local address queried for metadata.
	mmdsDefaultAddress = "169.254.169.254"
	// mmdsTokenExpiration is the TTL requested for MMDS session tokens.
	mmdsTokenExpiration = 60 * time.Second
	// mmdsAccessTokenRequestClientTimeout bounds each access-token-hash lookup.
	mmdsAccessTokenRequestClientTimeout = 10 * time.Second
)

// mmdsAccessTokenClient is a short-timeout, keep-alive-disabled client used
// only for access-token-hash lookups against MMDS.
var mmdsAccessTokenClient = &http.Client{
	Timeout: mmdsAccessTokenRequestClientTimeout,
	Transport: &http.Transport{
		DisableKeepAlives: true,
	},
}
// MMDSOpts is the sandbox metadata stored in (and read back from) MMDS.
// The JSON tags match the key names written by the orchestrator.
type MMDSOpts struct {
	SandboxID            string `json:"instanceID"`
	TemplateID           string `json:"envID"`
	LogsCollectorAddress string `json:"address"`
	AccessTokenHash      string `json:"accessTokenHash"`
}
// Update overwrites the sandbox identity and log-collector address in place.
func (opts *MMDSOpts) Update(sandboxID, templateID, collectorAddress string) {
	opts.LogsCollectorAddress = collectorAddress
	opts.TemplateID = templateID
	opts.SandboxID = sandboxID
}
// AddOptsToJSON injects the sandbox ID and template ID into an arbitrary JSON
// object (e.g. a log line) and returns the re-encoded bytes.
func (opts *MMDSOpts) AddOptsToJSON(jsonLogs []byte) ([]byte, error) {
	var fields map[string]any
	if err := json.Unmarshal(jsonLogs, &fields); err != nil {
		return nil, err
	}
	fields["instanceID"] = opts.SandboxID
	fields["envID"] = opts.TemplateID
	return json.Marshal(fields)
}
// getMMDSToken requests a short-lived session token from the Firecracker MMDS
// (IMDSv2-style PUT); the token must accompany subsequent metadata reads.
//
// The header is set via the map directly (not Header.Set) to preserve the
// exact lower-case key expected by the MMDS endpoint — Header.Set would
// canonicalize the casing.
func getMMDSToken(ctx context.Context, client *http.Client) (string, error) {
	request, err := http.NewRequestWithContext(ctx, http.MethodPut, "http://"+mmdsDefaultAddress+"/latest/api/token", &bytes.Buffer{})
	if err != nil {
		return "", err
	}
	request.Header["X-metadata-token-ttl-seconds"] = []string{fmt.Sprint(mmdsTokenExpiration.Seconds())}
	response, err := client.Do(request)
	if err != nil {
		return "", err
	}
	defer response.Body.Close()
	body, err := io.ReadAll(response.Body)
	if err != nil {
		return "", err
	}
	// Fix: previously any status was accepted, so an error page could be
	// treated as a valid token. Fail fast on non-2xx.
	if response.StatusCode < 200 || response.StatusCode >= 300 {
		return "", fmt.Errorf("mmds token request failed with status %d: %s", response.StatusCode, body)
	}
	token := string(body)
	if len(token) == 0 {
		return "", fmt.Errorf("mmds token is an empty string")
	}
	return token, nil
}
// getMMDSOpts fetches the sandbox metadata document from MMDS using a session
// token previously obtained via getMMDSToken. Headers are set via the map
// directly to keep the exact key casing expected by MMDS.
func getMMDSOpts(ctx context.Context, client *http.Client, token string) (*MMDSOpts, error) {
	request, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://"+mmdsDefaultAddress, &bytes.Buffer{})
	if err != nil {
		return nil, err
	}
	request.Header["X-metadata-token"] = []string{token}
	request.Header["Accept"] = []string{"application/json"}
	response, err := client.Do(request)
	if err != nil {
		return nil, err
	}
	defer response.Body.Close()
	body, err := io.ReadAll(response.Body)
	if err != nil {
		return nil, err
	}
	// Fix: previously a non-2xx error body was fed straight into the JSON
	// decoder, producing a confusing unmarshal error instead of the real cause.
	if response.StatusCode < 200 || response.StatusCode >= 300 {
		return nil, fmt.Errorf("mmds opts request failed with status %d: %s", response.StatusCode, body)
	}
	var opts MMDSOpts
	if err := json.Unmarshal(body, &opts); err != nil {
		return nil, err
	}
	return &opts, nil
}
// GetAccessTokenHashFromMMDS reads the access token hash from MMDS.
// This is used to validate that /init requests come from the orchestrator.
func GetAccessTokenHashFromMMDS(ctx context.Context) (string, error) {
	sessionToken, err := getMMDSToken(ctx, mmdsAccessTokenClient)
	if err != nil {
		return "", fmt.Errorf("failed to get MMDS token: %w", err)
	}
	metadata, err := getMMDSOpts(ctx, mmdsAccessTokenClient, sessionToken)
	if err != nil {
		return "", fmt.Errorf("failed to get MMDS opts: %w", err)
	}
	return metadata.AccessTokenHash, nil
}
// PollForMMDSOpts polls the MMDS every 50ms until sandbox metadata is
// available, publishes the IDs as env vars and as files under WrennRunDir,
// and forwards the opts to mmdsChan when a log-collector address is set.
// Returns after the first successful fetch or when ctx is cancelled.
//
// Fixes: the cancellation message was missing its trailing newline, and the
// run directory is now created before the metadata files are written (the
// writes would otherwise fail if /run/wrenn did not exist).
func PollForMMDSOpts(ctx context.Context, mmdsChan chan<- *MMDSOpts, envVars *utils.Map[string, string]) {
	httpClient := &http.Client{}
	defer httpClient.CloseIdleConnections()

	ticker := time.NewTicker(50 * time.Millisecond)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			fmt.Fprintf(os.Stderr, "context cancelled while waiting for mmds opts\n")
			return
		case <-ticker.C:
			token, err := getMMDSToken(ctx, httpClient)
			if err != nil {
				fmt.Fprintf(os.Stderr, "error getting mmds token: %v\n", err)
				continue
			}
			mmdsOpts, err := getMMDSOpts(ctx, httpClient, token)
			if err != nil {
				fmt.Fprintf(os.Stderr, "error getting mmds opts: %v\n", err)
				continue
			}
			envVars.Store("WRENN_SANDBOX_ID", mmdsOpts.SandboxID)
			envVars.Store("WRENN_TEMPLATE_ID", mmdsOpts.TemplateID)
			// Ensure the run directory exists before writing metadata files.
			if err := os.MkdirAll(WrennRunDir, 0o755); err != nil {
				fmt.Fprintf(os.Stderr, "error creating run dir: %v\n", err)
			}
			// NOTE(review): 0o666 leaves these files world-writable — confirm intended.
			if err := os.WriteFile(filepath.Join(WrennRunDir, ".WRENN_SANDBOX_ID"), []byte(mmdsOpts.SandboxID), 0o666); err != nil {
				fmt.Fprintf(os.Stderr, "error writing sandbox ID file: %v\n", err)
			}
			if err := os.WriteFile(filepath.Join(WrennRunDir, ".WRENN_TEMPLATE_ID"), []byte(mmdsOpts.TemplateID), 0o666); err != nil {
				fmt.Fprintf(os.Stderr, "error writing template ID file: %v\n", err)
			}
			// Only hand the opts to the exporter when a collector is configured.
			if mmdsOpts.LogsCollectorAddress != "" {
				mmdsChan <- mmdsOpts
			}
			return
		}
	}
}

View File

@ -0,0 +1,47 @@
package logs
import (
"time"
"github.com/rs/zerolog"
)
const (
	// defaultMaxBufferSize (64 KiB) triggers an immediate flush when exceeded.
	defaultMaxBufferSize = 2 << 15
	// defaultTimeout is the periodic flush interval for buffered event data.
	defaultTimeout = 2 * time.Second
)
// LogBufferedDataEvents drains dataCh into a buffer and emits it as a single
// log event either periodically (defaultTimeout) or when the buffer reaches
// defaultMaxBufferSize. Any remainder is flushed when dataCh closes.
func LogBufferedDataEvents(dataCh <-chan []byte, logger *zerolog.Logger, eventType string) {
	flushTicker := time.NewTicker(defaultTimeout)
	defer flushTicker.Stop()

	var pending []byte
	emit := func(msg string) {
		logger.Info().Str(eventType, string(pending)).Msg(msg)
		pending = nil
	}

	// Flush whatever is left when the channel closes.
	defer func() {
		if len(pending) > 0 {
			emit("Streaming process event (flush)")
		}
	}()

	for {
		select {
		case <-flushTicker.C:
			if len(pending) > 0 {
				emit("Streaming process event")
			}
		case chunk, open := <-dataCh:
			if !open {
				return
			}
			pending = append(pending, chunk...)
			// Size-based flush so a chatty process cannot grow the buffer unbounded.
			if len(pending) >= defaultMaxBufferSize {
				emit("Streaming process event")
			}
		}
	}
}

View File

@ -0,0 +1,172 @@
package exporter
import (
"bytes"
"context"
"fmt"
"log"
"net/http"
"os"
"sync"
"time"
"git.omukk.dev/wrenn/sandbox/envd/internal/host"
)
const ExporterTimeout = 10 * time.Second
// HTTPExporter is an io.Writer that buffers JSON log lines and ships them to
// the collector address learned from MMDS. When running outside Firecracker
// (isNotFC), lines are printed to stdout instead.
type HTTPExporter struct {
	client http.Client
	// logs is the queue of pending log lines, guarded by logLock.
	logs    [][]byte
	isNotFC bool
	// mmdsOpts holds the current sandbox identifiers, guarded by mmdsLock.
	mmdsOpts *host.MMDSOpts

	// Concurrency coordination
	triggers  chan struct{} // buffered(1): non-blocking wake-ups for the sender loop
	logLock   sync.RWMutex
	mmdsLock  sync.RWMutex
	startOnce sync.Once // ensures the sender loop is started at most once
}
// NewHTTPLogsExporter builds an exporter with placeholder MMDS options and
// starts a background listener that begins shipping logs once real options
// arrive on mmdsChan.
func NewHTTPLogsExporter(ctx context.Context, isNotFC bool, mmdsChan <-chan *host.MMDSOpts) *HTTPExporter {
	placeholderOpts := &host.MMDSOpts{
		SandboxID:            "unknown",
		TemplateID:           "unknown",
		LogsCollectorAddress: "",
	}
	e := &HTTPExporter{
		client:    http.Client{Timeout: ExporterTimeout},
		triggers:  make(chan struct{}, 1), // capacity 1 makes wake-ups non-blocking
		isNotFC:   isNotFC,
		startOnce: sync.Once{},
		mmdsOpts:  placeholderOpts,
	}
	go e.listenForMMDSOptsAndStart(ctx, mmdsChan)
	return e
}
// sendInstanceLogs POSTs a single JSON log line to the collector address.
// An empty address means no collector is configured yet; the line is dropped
// here (the caller still prints it locally on error paths only).
func (w *HTTPExporter) sendInstanceLogs(ctx context.Context, logs []byte, address string) error {
	if address == "" {
		return nil
	}
	request, err := http.NewRequestWithContext(ctx, http.MethodPost, address, bytes.NewBuffer(logs))
	if err != nil {
		return err
	}
	request.Header.Set("Content-Type", "application/json")
	response, err := w.client.Do(request)
	if err != nil {
		return err
	}
	defer response.Body.Close()
	// Fix: previously any HTTP status counted as success. Surface non-2xx so
	// the caller can fall back to printing the line locally.
	if response.StatusCode < 200 || response.StatusCode >= 300 {
		return fmt.Errorf("log collector returned status %d", response.StatusCode)
	}
	return nil
}
// printLog writes a raw log line to stdout as the local fallback path.
func printLog(logs []byte) {
	os.Stdout.WriteString(string(logs))
}
// listenForMMDSOptsAndStart updates the exporter's MMDS options whenever new
// ones arrive and lazily starts the sender loop on the first delivery.
// Runs until ctx is cancelled or mmdsChan is closed.
func (w *HTTPExporter) listenForMMDSOptsAndStart(ctx context.Context, mmdsChan <-chan *host.MMDSOpts) {
	for {
		select {
		case <-ctx.Done():
			return
		case mmdsOpts, ok := <-mmdsChan:
			if !ok {
				return
			}
			// Swap in the new identifiers under the write lock; readers in
			// start() take the corresponding read lock.
			w.mmdsLock.Lock()
			w.mmdsOpts.Update(mmdsOpts.SandboxID, mmdsOpts.TemplateID, mmdsOpts.LogsCollectorAddress)
			w.mmdsLock.Unlock()
			// Start the sender loop exactly once, after real opts exist.
			w.startOnce.Do(func() {
				go w.start(ctx)
			})
		}
	}
}
// start is the sender loop: each trigger drains the queued log lines and
// either prints them (non-FC mode) or enriches them with sandbox IDs and
// ships them to the collector, printing locally on failure.
//
// Fix: LogsCollectorAddress was previously read after releasing mmdsLock,
// racing with Update() in listenForMMDSOptsAndStart. The address is now
// captured under the same read lock as the rest of the options.
func (w *HTTPExporter) start(ctx context.Context) {
	for range w.triggers {
		logs := w.getAllLogs()
		if len(logs) == 0 {
			continue
		}
		if w.isNotFC {
			for _, log := range logs {
				fmt.Fprintf(os.Stdout, "%v", string(log))
			}
			continue
		}
		for _, logLine := range logs {
			w.mmdsLock.RLock()
			collectorAddress := w.mmdsOpts.LogsCollectorAddress
			logLineWithOpts, err := w.mmdsOpts.AddOptsToJSON(logLine)
			w.mmdsLock.RUnlock()
			if err != nil {
				log.Printf("error adding instance logging options (%+v) to JSON (%+v) with logs : %v\n", w.mmdsOpts, logLine, err)
				printLog(logLine)
				continue
			}
			err = w.sendInstanceLogs(ctx, logLineWithOpts, collectorAddress)
			if err != nil {
				log.Printf("error sending instance logs: %+v", err)
				printLog(logLine)
				continue
			}
		}
	}
}
// resumeProcessing wakes the sender loop without blocking: the triggers
// channel has capacity 1, so an already-pending wake-up is simply kept.
func (w *HTTPExporter) resumeProcessing() {
	select {
	case w.triggers <- struct{}{}:
	default:
		// Exporter processing already triggered
		// This is expected behavior if the exporter is already processing logs
	}
}
// Write implements io.Writer: it copies the line (zerolog reuses its buffer)
// and appends it to the pending queue. Never returns an error.
//
// Fix: the append previously ran in a per-call goroutine, so concurrent or
// rapid writes could be enqueued out of order. addLogs only takes a mutex and
// does a slice append, so running it synchronously is cheap and preserves
// log-line ordering.
func (w *HTTPExporter) Write(logs []byte) (int, error) {
	logsCopy := make([]byte, len(logs))
	copy(logsCopy, logs)
	w.addLogs(logsCopy)
	return len(logs), nil
}
// getAllLogs atomically takes ownership of the pending queue, leaving an
// empty queue behind for subsequent writers.
func (w *HTTPExporter) getAllLogs() [][]byte {
	w.logLock.Lock()
	defer w.logLock.Unlock()
	drained := w.logs
	w.logs = nil
	return drained
}
// addLogs enqueues one log line and wakes the sender loop.
func (w *HTTPExporter) addLogs(logs []byte) {
	w.logLock.Lock()
	w.logs = append(w.logs, logs)
	w.logLock.Unlock()
	// The wake-up is a non-blocking send, so doing it outside the lock is safe.
	w.resumeProcessing()
}

View File

@ -0,0 +1,172 @@
package logs
import (
"context"
"fmt"
"strconv"
"strings"
"sync/atomic"
"connectrpc.com/connect"
"github.com/rs/zerolog"
)
// OperationID is the context-key type for the per-request operation ID.
type OperationID string

const (
	// OperationIDKey is the context key under which the operation ID is stored.
	OperationIDKey OperationID = "operation_id"
	// DefaultHTTPMethod is prepended to procedure names in logs; Connect RPCs
	// as used here are always POST.
	DefaultHTTPMethod string = "POST"
)
// operationID is a process-wide monotonically increasing counter used to tag
// each RPC with a unique ID for log correlation.
var operationID atomic.Int32

// AssignOperationID returns the next operation ID as a decimal string.
func AssignOperationID() string {
	next := operationID.Add(1)
	return strconv.Itoa(int(next))
}
func AddRequestIDToContext(ctx context.Context) context.Context {
return context.WithValue(ctx, OperationIDKey, AssignOperationID())
}
// formatMethod turns a Connect procedure string like
// "/process.Process/StartProcess" into a compact "Process startProcess"
// label. Inputs that do not match the expected shape pass through unchanged.
//
// Fix: the original indexed servicePart[:1] / methodPart[:1] without checking
// for empty segments, so inputs like "a./x" caused a slice-bounds panic.
func formatMethod(method string) string {
	parts := strings.Split(method, ".")
	if len(parts) < 2 {
		return method
	}
	split := strings.Split(parts[1], "/")
	if len(split) < 2 || split[0] == "" || split[1] == "" {
		return method
	}
	servicePart := strings.ToUpper(split[0][:1]) + split[0][1:]
	methodPart := strings.ToLower(split[1][:1]) + split[1][1:]
	return fmt.Sprintf("%s %s", servicePart, methodPart)
}
// NewUnaryLogInterceptor returns a Connect interceptor that tags each unary
// RPC with an operation ID and logs method, request, response/error and the
// Connect error code after the handler returns.
func NewUnaryLogInterceptor(logger *zerolog.Logger) connect.UnaryInterceptorFunc {
	interceptor := func(next connect.UnaryFunc) connect.UnaryFunc {
		return connect.UnaryFunc(func(
			ctx context.Context,
			req connect.AnyRequest,
		) (connect.AnyResponse, error) {
			ctx = AddRequestIDToContext(ctx)
			res, err := next(ctx, req)
			// logger.Err picks error level when err != nil, debug otherwise.
			l := logger.
				Err(err).
				Str("method", DefaultHTTPMethod+" "+req.Spec().Procedure).
				Str(string(OperationIDKey), ctx.Value(OperationIDKey).(string))
			if err != nil {
				l = l.Int("error_code", int(connect.CodeOf(err)))
			}
			if req != nil {
				l = l.Interface("request", req.Any())
			}
			// Log the response payload only on success; log an explicit nil
			// response when the handler returned neither value nor error.
			if res != nil && err == nil {
				l = l.Interface("response", res.Any())
			}
			if res == nil && err == nil {
				l = l.Interface("response", nil)
			}
			l.Msg(formatMethod(req.Spec().Procedure))
			return res, err
		})
	}
	return connect.UnaryInterceptorFunc(interceptor)
}
// LogServerStreamWithoutEvents wraps a server-streaming handler with start/end
// log lines (individual stream events are intentionally not logged). The same
// operation ID is attached to both lines so they can be correlated.
func LogServerStreamWithoutEvents[T any, R any](
	ctx context.Context,
	logger *zerolog.Logger,
	req *connect.Request[R],
	stream *connect.ServerStream[T],
	handler func(ctx context.Context, req *connect.Request[R], stream *connect.ServerStream[T]) error,
) error {
	ctx = AddRequestIDToContext(ctx)
	l := logger.Debug().
		Str("method", DefaultHTTPMethod+" "+req.Spec().Procedure).
		Str(string(OperationIDKey), ctx.Value(OperationIDKey).(string))
	if req != nil {
		l = l.Interface("request", req.Any())
	}
	l.Msg(fmt.Sprintf("%s (server stream start)", formatMethod(req.Spec().Procedure)))
	// Run the actual streaming handler; everything below logs its outcome.
	err := handler(ctx, req, stream)
	// Error level on failure, debug on success (see getErrDebugLogEvent).
	logEvent := getErrDebugLogEvent(logger, err).
		Str("method", DefaultHTTPMethod+" "+req.Spec().Procedure).
		Str(string(OperationIDKey), ctx.Value(OperationIDKey).(string))
	if err != nil {
		logEvent = logEvent.Int("error_code", int(connect.CodeOf(err)))
	} else {
		logEvent = logEvent.Interface("response", nil)
	}
	logEvent.Msg(fmt.Sprintf("%s (server stream end)", formatMethod(req.Spec().Procedure)))
	return err
}
// LogClientStreamWithoutEvents wraps a client-streaming handler with start/end
// log lines (individual stream messages are intentionally not logged). The
// same operation ID is attached to both lines.
func LogClientStreamWithoutEvents[T any, R any](
	ctx context.Context,
	logger *zerolog.Logger,
	stream *connect.ClientStream[T],
	handler func(ctx context.Context, stream *connect.ClientStream[T]) (*connect.Response[R], error),
) (*connect.Response[R], error) {
	ctx = AddRequestIDToContext(ctx)
	logger.Debug().
		Str("method", DefaultHTTPMethod+" "+stream.Spec().Procedure).
		Str(string(OperationIDKey), ctx.Value(OperationIDKey).(string)).
		Msg(fmt.Sprintf("%s (client stream start)", formatMethod(stream.Spec().Procedure)))
	// Run the actual streaming handler; everything below logs its outcome.
	res, err := handler(ctx, stream)
	// Error level on failure, debug on success (see getErrDebugLogEvent).
	logEvent := getErrDebugLogEvent(logger, err).
		Str("method", DefaultHTTPMethod+" "+stream.Spec().Procedure).
		Str(string(OperationIDKey), ctx.Value(OperationIDKey).(string))
	if err != nil {
		logEvent = logEvent.Int("error_code", int(connect.CodeOf(err)))
	}
	// Log the response payload only on success; explicit nil when both are nil.
	if res != nil && err == nil {
		logEvent = logEvent.Interface("response", res.Any())
	}
	if res == nil && err == nil {
		logEvent = logEvent.Interface("response", nil)
	}
	logEvent.Msg(fmt.Sprintf("%s (client stream end)", formatMethod(stream.Spec().Procedure)))
	return res, err
}
// Return logger with error level if err is not nil, otherwise return logger with debug level
func getErrDebugLogEvent(logger *zerolog.Logger, err error) *zerolog.Event {
	if err == nil {
		return logger.Debug() //nolint:zerologlint // this builds an event, it is not expected to return it
	}
	return logger.Error().Err(err) //nolint:zerologlint // this builds an event, it is not expected to return it
}

View File

@ -0,0 +1,35 @@
package logs
import (
"context"
"io"
"os"
"time"
"github.com/rs/zerolog"
"git.omukk.dev/wrenn/sandbox/envd/internal/host"
"git.omukk.dev/wrenn/sandbox/envd/internal/logs/exporter"
)
// NewLogger builds the envd root logger. Outside Firecracker it logs to
// stdout only; inside a VM it additionally ships lines through the HTTP
// exporter fed by MMDS options from mmdsChan.
func NewLogger(ctx context.Context, isNotFC bool, mmdsChan <-chan *host.MMDSOpts) *zerolog.Logger {
	zerolog.TimestampFieldName = "timestamp"
	zerolog.TimeFieldFormat = time.RFC3339Nano

	var sink io.Writer
	if isNotFC {
		sink = io.MultiWriter(os.Stdout)
	} else {
		sink = io.MultiWriter(exporter.NewHTTPLogsExporter(ctx, isNotFC, mmdsChan), os.Stdout)
	}

	logger := zerolog.
		New(sink).
		With().
		Timestamp().
		Logger().
		Level(zerolog.DebugLevel)
	return &logger
}

View File

@ -0,0 +1,47 @@
package permissions
import (
"context"
"fmt"
"os/user"
"connectrpc.com/authn"
"connectrpc.com/connect"
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
)
// AuthenticateUsername is the connectrpc authn middleware callback: it maps
// the HTTP basic-auth username (password ignored) to a local OS user.
// Returning (nil, nil) when no credentials are present is deliberate — not
// every endpoint requires authentication; the missing user is resolved later
// in GetAuthUser.
func AuthenticateUsername(_ context.Context, req authn.Request) (any, error) {
	username, _, ok := req.BasicAuth()
	if !ok {
		// When no username is provided, ignore the authentication method (not all endpoints require it)
		// Missing user is then handled in the GetAuthUser function
		return nil, nil
	}
	u, err := GetUser(username)
	if err != nil {
		return nil, authn.Errorf("invalid username: '%s'", username)
	}
	return u, nil
}
// GetAuthUser returns the authenticated OS user stored in ctx by the authn
// middleware, falling back to defaultUser (resolved via execcontext) when the
// request carried no credentials.
//
// NOTE(review): the two failure paths use different error constructors
// (connect.NewError vs authn.Errorf) — confirm both surface as the intended
// status to clients.
func GetAuthUser(ctx context.Context, defaultUser string) (*user.User, error) {
	u, ok := authn.GetInfo(ctx).(*user.User)
	if !ok {
		// No authenticated user in the context: resolve the default username.
		username, err := execcontext.ResolveDefaultUsername(nil, defaultUser)
		if err != nil {
			return nil, connect.NewError(connect.CodeUnauthenticated, fmt.Errorf("no user specified"))
		}
		u, err := GetUser(username)
		if err != nil {
			return nil, authn.Errorf("invalid default user: '%s'", username)
		}
		return u, nil
	}
	return u, nil
}

View File

@ -0,0 +1,29 @@
package permissions
import (
"strconv"
"time"
"connectrpc.com/connect"
)
// defaultKeepAliveInterval is used when the client does not request one.
const defaultKeepAliveInterval = 90 * time.Second

// GetKeepAliveTicker returns a ticker driven by the client-supplied
// "Keepalive-Ping-Interval" header (whole seconds), plus a reset callback to
// re-arm the interval after activity. Falls back to the default when the
// header is absent or invalid.
//
// Fix: time.NewTicker panics on a non-positive duration, so a header value of
// "0" (or a negative number) previously crashed the handler; such values now
// fall back to the default interval as well.
func GetKeepAliveTicker[T any](req *connect.Request[T]) (*time.Ticker, func()) {
	keepAliveIntervalHeader := req.Header().Get("Keepalive-Ping-Interval")

	interval := defaultKeepAliveInterval
	if parsed, err := strconv.Atoi(keepAliveIntervalHeader); err == nil && parsed > 0 {
		interval = time.Duration(parsed) * time.Second
	}

	ticker := time.NewTicker(interval)
	return ticker, func() {
		ticker.Reset(interval)
	}
}

View File

@ -0,0 +1,96 @@
package permissions
import (
"errors"
"fmt"
"os"
"os/user"
"path/filepath"
"slices"
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
)
// expand replaces a leading "~" in path with homedir. User-specific forms
// like "~alice/…" are rejected; empty and non-tilde paths pass through.
func expand(path, homedir string) (string, error) {
	if path == "" {
		return path, nil
	}
	if path[0] != '~' {
		return path, nil
	}
	// "~x..." (anything but "~", "~/" or "~\") would name another user's home.
	if len(path) > 1 && path[1] != '/' && path[1] != '\\' {
		return "", errors.New("cannot expand user-specific home dir")
	}
	return filepath.Join(homedir, path[1:]), nil
}
// ExpandAndResolve turns a user-supplied path into an absolute one: it applies
// the default workdir (when path is empty), expands a leading "~" against the
// user's home directory, and anchors relative paths at the home directory.
func ExpandAndResolve(path string, user *user.User, defaultPath *string) (string, error) {
	path = execcontext.ResolveDefaultWorkdir(path, defaultPath)
	path, err := expand(path, user.HomeDir)
	if err != nil {
		return "", fmt.Errorf("failed to expand path '%s' for user '%s': %w", path, user.Username, err)
	}
	if filepath.IsAbs(path) {
		return path, nil
	}
	// Relative paths are interpreted relative to the user's home directory.
	// The filepath.Abs can correctly resolve paths like /home/user/../file
	path = filepath.Join(user.HomeDir, path)
	abs, err := filepath.Abs(path)
	if err != nil {
		return "", fmt.Errorf("failed to resolve path '%s' for user '%s' with home dir '%s': %w", path, user.Username, user.HomeDir, err)
	}
	return abs, nil
}
func getSubpaths(path string) (subpaths []string) {
for {
subpaths = append(subpaths, path)
path = filepath.Dir(path)
if path == "/" {
break
}
}
slices.Reverse(subpaths)
return subpaths
}
// EnsureDirs creates every missing directory on the way to path, chowning
// only the directories it creates to uid:gid (pre-existing directories keep
// their ownership). Fails if any existing path component is a regular file.
//
// NOTE(review): the stat-then-mkdir sequence is not atomic — a concurrent
// creator could make os.Mkdir fail with EEXIST; confirm callers are
// effectively serialized.
func EnsureDirs(path string, uid, gid int) error {
	subpaths := getSubpaths(path)
	for _, subpath := range subpaths {
		info, err := os.Stat(subpath)
		if err != nil && !os.IsNotExist(err) {
			return fmt.Errorf("failed to stat directory: %w", err)
		}
		if err != nil && os.IsNotExist(err) {
			// Missing component: create it and hand ownership to the user.
			err = os.Mkdir(subpath, 0o755)
			if err != nil {
				return fmt.Errorf("failed to create directory: %w", err)
			}
			err = os.Chown(subpath, uid, gid)
			if err != nil {
				return fmt.Errorf("failed to chown directory: %w", err)
			}
			continue
		}
		// Existing component must be a directory, never a file.
		if !info.IsDir() {
			return fmt.Errorf("path is a file: %s", subpath)
		}
	}
	return nil
}

View File

@ -0,0 +1,44 @@
package permissions
import (
"fmt"
"os/user"
"strconv"
)
func GetUserIdUints(u *user.User) (uid, gid uint32, err error) {
newUID, err := strconv.ParseUint(u.Uid, 10, 32)
if err != nil {
return 0, 0, fmt.Errorf("error parsing uid '%s': %w", u.Uid, err)
}
newGID, err := strconv.ParseUint(u.Gid, 10, 32)
if err != nil {
return 0, 0, fmt.Errorf("error parsing gid '%s': %w", u.Gid, err)
}
return uint32(newUID), uint32(newGID), nil
}
func GetUserIdInts(u *user.User) (uid, gid int, err error) {
newUID, err := strconv.ParseInt(u.Uid, 10, strconv.IntSize)
if err != nil {
return 0, 0, fmt.Errorf("error parsing uid '%s': %w", u.Uid, err)
}
newGID, err := strconv.ParseInt(u.Gid, 10, strconv.IntSize)
if err != nil {
return 0, 0, fmt.Errorf("error parsing gid '%s': %w", u.Gid, err)
}
return int(newUID), int(newGID), nil
}
// GetUser looks up a local OS user by name, wrapping the lookup error with
// the offending username for context.
func GetUser(username string) (u *user.User, err error) {
	looked, lookupErr := user.Lookup(username)
	if lookupErr != nil {
		return nil, fmt.Errorf("error looking up user '%s': %w", username, lookupErr)
	}
	return looked, nil
}

View File

@ -0,0 +1,218 @@
// portf (port forward) periodaically scans opened TCP ports on the 127.0.0.1 (or localhost)
// and launches `socat` process for every such port in the background.
// socat forward traffic from `sourceIP`:port to the 127.0.0.1:port.
// WARNING: portf isn't thread safe!
package port
import (
"context"
"fmt"
"net"
"os/exec"
"syscall"
"github.com/rs/zerolog"
"git.omukk.dev/wrenn/sandbox/envd/internal/services/cgroups"
)
// PortState is the mark-and-sweep state used during each scan pass.
type PortState string

const (
	// PortStateForward: port is still open; keep (or start) forwarding it.
	PortStateForward PortState = "FORWARD"
	// PortStateDelete: port was not seen in the latest scan; stop forwarding.
	PortStateDelete PortState = "DELETE"
)

// defaultGatewayIP is the sandbox-side address socat binds to; traffic from
// outside the VM arrives on this interface.
var defaultGatewayIP = net.IPv4(169, 254, 0, 21)
// PortToForward tracks one forwarded localhost port and the socat process
// that carries its traffic.
type PortToForward struct {
	// socat is the running forwarder process; nil when forwarding is stopped.
	socat *exec.Cmd
	// Process ID of the process that's listening on port.
	pid int32
	// family version of the ip.
	family uint32
	// state is the mark used by StartForwarding's mark-and-sweep pass.
	state PortState
	port  uint32
}
// Forwarder launches one socat per open localhost port so that traffic
// arriving on the gateway IP reaches servers bound only to 127.0.0.1.
// WARNING: not thread safe (see the package comment).
type Forwarder struct {
	logger        *zerolog.Logger
	cgroupManager cgroups.Manager

	// Map of ports that are being currently forwarded.
	ports             map[string]*PortToForward
	scannerSubscriber *ScannerSubscriber
	// sourceIP is the address socat listens on (the sandbox gateway IP).
	sourceIP net.IP
}
// NewForwarder subscribes to the port scanner (LISTEN sockets on localhost
// only) and returns a forwarder bound to the default sandbox gateway IP.
func NewForwarder(
	logger *zerolog.Logger,
	scanner *Scanner,
	cgroupManager cgroups.Manager,
) *Forwarder {
	// We only want to forward ports that are actively listening on localhost.
	localListenFilter := &ScannerFilter{
		IPs:   []string{"127.0.0.1", "localhost", "::1"},
		State: "LISTEN",
	}
	subscription := scanner.AddSubscriber(logger, "port-forwarder", localListenFilter)

	return &Forwarder{
		logger:            logger,
		sourceIP:          defaultGatewayIP,
		ports:             map[string]*PortToForward{},
		scannerSubscriber: subscription,
		cgroupManager:     cgroupManager,
	}
}
// StartForwarding consumes scanner snapshots and reconciles the set of socat
// forwarders with the set of open localhost ports (mark-and-sweep: mark all
// DELETE, re-mark seen ports FORWARD, sweep the rest). Blocks until the
// subscriber channel is closed.
//
// Fixes: (1) the original used a bare receive inside `for {}`, which
// busy-spins forever once the subscriber channel is closed — `range` exits
// cleanly; (2) swept entries were never removed from the ports map, so a
// port re-opened by the same pid was marked FORWARD with a nil socat and
// never re-forwarded, and the map grew without bound.
func (f *Forwarder) StartForwarding(ctx context.Context) {
	if f.scannerSubscriber == nil {
		f.logger.Error().Msg("Cannot start forwarding because scanner subscriber is nil")
		return
	}
	// procs is an array of currently opened ports.
	for procs := range f.scannerSubscriber.Messages {
		// Mark every known port for deletion; ports still open are re-marked below.
		for _, v := range f.ports {
			v.state = PortStateDelete
		}
		for _, p := range procs {
			key := fmt.Sprintf("%d-%d", p.Pid, p.Laddr.Port)
			// We check if the opened port is in our map of forwarded ports.
			val, portOk := f.ports[key]
			if portOk {
				// Already forwarded — the socat from the last iteration keeps running.
				val.state = PortStateForward
			} else {
				f.logger.Debug().
					Str("ip", p.Laddr.IP).
					Uint32("port", p.Laddr.Port).
					Uint32("family", familyToIPVersion(p.Family)).
					Str("state", p.Status).
					Msg("Detected new opened port on localhost that is not forwarded")
				// New port: create a PortToForward and start a socat for it.
				ptf := &PortToForward{
					pid:    p.Pid,
					port:   p.Laddr.Port,
					state:  PortStateForward,
					family: familyToIPVersion(p.Family),
				}
				f.ports[key] = ptf
				f.startPortForwarding(ctx, ptf)
			}
		}
		// Sweep: stop and forget ports that stayed marked DELETE.
		for key, v := range f.ports {
			if v.state == PortStateDelete {
				f.stopPortForwarding(v)
				delete(f.ports, key)
			}
		}
	}
}
// startPortForwarding launches a socat that listens on sourceIP:port and
// relays to localhost:port, placed in its own process group (for clean
// group-kill later) and, when available, in the socat cgroup.
// On Start failure p.socat stays nil, so stopPortForwarding is a no-op.
func (f *Forwarder) startPortForwarding(ctx context.Context, p *PortToForward) {
	// https://unix.stackexchange.com/questions/311492/redirect-application-listening-on-localhost-to-listening-on-external-interface
	// socat -d -d TCP4-LISTEN:4000,bind=169.254.0.21,fork TCP4:localhost:4000
	// reuseaddr is used to fix the "Address already in use" error when restarting socat quickly.
	cmd := exec.CommandContext(ctx,
		"socat", "-d", "-d", "-d",
		fmt.Sprintf("TCP4-LISTEN:%v,bind=%s,reuseaddr,fork", p.port, f.sourceIP.To4()),
		fmt.Sprintf("TCP%d:localhost:%v", p.family, p.port),
	)
	// Place socat into the configured cgroup (if any) and its own process group.
	cgroupFD, ok := f.cgroupManager.GetFileDescriptor(cgroups.ProcessTypeSocat)
	cmd.SysProcAttr = &syscall.SysProcAttr{
		Setpgid:     true,
		CgroupFD:    cgroupFD,
		UseCgroupFD: ok,
	}
	f.logger.Debug().
		Str("socatCmd", cmd.String()).
		Int32("pid", p.pid).
		Uint32("family", p.family).
		IPAddr("sourceIP", f.sourceIP.To4()).
		Uint32("port", p.port).
		Msg("About to start port forwarding")
	if err := cmd.Start(); err != nil {
		f.logger.
			Error().
			Str("socatCmd", cmd.String()).
			Err(err).
			Msg("Failed to start port forwarding - failed to start socat")
		return
	}
	// Reap the socat process so it does not linger as a zombie.
	go func() {
		if err := cmd.Wait(); err != nil {
			f.logger.
				Debug().
				Str("socatCmd", cmd.String()).
				Err(err).
				Msg("Port forwarding socat process exited")
		}
	}()
	p.socat = cmd
}
// stopPortForwarding kills the socat process group for p (socat forks per
// connection, hence the negative-pid group kill) and clears p.socat.
// No-op when nothing is being forwarded.
func (f *Forwarder) stopPortForwarding(p *PortToForward) {
	if p.socat == nil {
		return
	}
	// Clear the handle even if the kill fails, so we never retry a dead process.
	defer func() { p.socat = nil }()
	logger := f.logger.With().
		Str("socatCmd", p.socat.String()).
		Int32("pid", p.pid).
		Uint32("family", p.family).
		IPAddr("sourceIP", f.sourceIP.To4()).
		Uint32("port", p.port).
		Logger()
	logger.Debug().Msg("Stopping port forwarding")
	// Negative pid targets the whole process group created via Setpgid.
	if err := syscall.Kill(-p.socat.Process.Pid, syscall.SIGKILL); err != nil {
		logger.Error().Err(err).Msg("Failed to kill process group")
		return
	}
	logger.Debug().Msg("Stopped port forwarding")
}
func familyToIPVersion(family uint32) uint32 {
switch family {
case syscall.AF_INET:
return 4
case syscall.AF_INET6:
return 6
default:
return 0 // Unknown or unsupported family
}
}

View File

@ -0,0 +1,59 @@
package port
import (
"time"
"github.com/rs/zerolog"
"github.com/shirou/gopsutil/v4/net"
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/smap"
)
// Scanner periodically polls open TCP connections and fans each snapshot out
// to all registered subscribers.
type Scanner struct {
	// Processes is exported but not consumed in this file.
	// NOTE(review): verify it is actually used elsewhere; otherwise remove it.
	Processes chan net.ConnectionStat
	// scanExit, once closed, stops ScanAndBroadcast.
	scanExit chan struct{}
	// subs maps subscriber ID -> subscriber.
	subs *smap.Map[*ScannerSubscriber]
	// period is the delay between consecutive scans.
	period time.Duration
}
// Destroy stops the scan loop by closing the exit channel.
// Must be called at most once — a second call would panic on double close.
func (s *Scanner) Destroy() {
	close(s.scanExit)
}
// NewScanner builds a scanner that polls open TCP connections every period.
// Call ScanAndBroadcast to start it and Destroy to stop it.
func NewScanner(period time.Duration) *Scanner {
	scanner := &Scanner{
		period:    period,
		subs:      smap.New[*ScannerSubscriber](),
		scanExit:  make(chan struct{}),
		Processes: make(chan net.ConnectionStat),
	}
	return scanner
}
// AddSubscriber registers a new subscriber under id; a nil filter receives
// every snapshot unfiltered.
func (s *Scanner) AddSubscriber(logger *zerolog.Logger, id string, filter *ScannerFilter) *ScannerSubscriber {
	sub := NewScannerSubscriber(logger, id, filter)
	s.subs.Insert(id, sub)
	return sub
}
// Unsubscribe removes sub from the broadcast set and closes its channel.
// NOTE(review): ScanAndBroadcast may still be sending to this subscriber
// concurrently; closing the Messages channel then would panic — confirm
// Items() snapshots exclude removed subscribers before Destroy runs.
func (s *Scanner) Unsubscribe(sub *ScannerSubscriber) {
	s.subs.Remove(sub.ID())
	sub.Destroy()
}
// ScanAndBroadcast starts scanning open TCP ports and broadcasts every open port to all subscribers.
// Blocks until Destroy is called.
//
// Fix: the original waited with time.Sleep and only checked the exit channel
// between sleeps, so shutdown could lag by up to one full period. A select on
// a ticker makes the wait interruptible by Destroy.
func (s *Scanner) ScanAndBroadcast() {
	ticker := time.NewTicker(s.period)
	defer ticker.Stop()

	for {
		// tcp monitors both ipv4 and ipv6 connections. The error is dropped
		// deliberately: a failed scan simply yields an empty snapshot.
		processes, _ := net.Connections("tcp")
		for _, sub := range s.subs.Items() {
			sub.Signal(processes)
		}
		select {
		case <-s.scanExit:
			return
		case <-ticker.C:
			// next scan
		}
	}
}

View File

@ -0,0 +1,50 @@
package port
import (
"github.com/rs/zerolog"
"github.com/shirou/gopsutil/v4/net"
)
// If we want to create a listener/subscriber pattern somewhere else we should move
// from a concrete implementation to combination of generics and interfaces.
// If we want to create a listener/subscriber pattern somewhere else we should move
// from a concrete implementation to combination of generics and interfaces.

// ScannerSubscriber receives (optionally filtered) connection snapshots from
// a Scanner on its Messages channel.
type ScannerSubscriber struct {
	logger *zerolog.Logger
	// filter selects which connections are delivered; nil means deliver all.
	filter *ScannerFilter
	// Messages carries one (possibly filtered) snapshot per scan; unbuffered.
	Messages chan ([]net.ConnectionStat)
	id       string
}
// NewScannerSubscriber builds a subscriber with an unbuffered delivery
// channel; a nil filter means every snapshot is delivered unfiltered.
func NewScannerSubscriber(logger *zerolog.Logger, id string, filter *ScannerFilter) *ScannerSubscriber {
	sub := &ScannerSubscriber{
		logger:   logger,
		id:       id,
		filter:   filter,
		Messages: make(chan []net.ConnectionStat),
	}
	return sub
}
// ID returns the identifier this subscriber was registered under.
func (ss *ScannerSubscriber) ID() string {
	return ss.id
}
// Destroy closes the delivery channel, unblocking any range over Messages.
// Must only be called after the subscriber is removed from the scanner,
// otherwise a concurrent Signal would send on a closed channel.
func (ss *ScannerSubscriber) Destroy() {
	close(ss.Messages)
}
// Signal delivers a snapshot to the subscriber, applying the filter when one
// is configured. Blocks until the consumer receives the message.
func (ss *ScannerSubscriber) Signal(proc []net.ConnectionStat) {
	// Filter isn't specified. Accept everything.
	if ss.filter == nil {
		ss.Messages <- proc
		return
	}
	matched := []net.ConnectionStat{}
	for i := range proc {
		// Index into the slice directly to avoid implicit memory aliasing of
		// the loop variable when taking its address.
		if ss.filter.Match(&proc[i]) {
			matched = append(matched, proc[i])
		}
	}
	ss.Messages <- matched
}

View File

@ -0,0 +1,27 @@
package port
import (
"slices"
"github.com/shirou/gopsutil/v4/net"
)
// ScannerFilter selects connections by socket state (e.g. "LISTEN") and
// local address. An entirely empty filter matches nothing.
type ScannerFilter struct {
	// State must equal the connection's status exactly.
	State string
	// IPs is the allow-list of local addresses.
	IPs []string
}
// Match reports whether proc's local address is in the IP allow-list AND its
// status equals the configured state. An empty filter matches nothing (not
// everything).
func (sf *ScannerFilter) Match(proc *net.ConnectionStat) bool {
	// Filter is an empty struct.
	if sf.State == "" && len(sf.IPs) == 0 {
		return false
	}
	return slices.Contains(sf.IPs, proc.Laddr.IP) && sf.State == proc.Status
}

View File

@ -0,0 +1,127 @@
package cgroups
import (
"errors"
"fmt"
"os"
"path/filepath"
"golang.org/x/sys/unix"
)
// Cgroup2Manager hands out cgroup-v2 directory file descriptors keyed by
// process type, for spawning children directly into cgroups via
// SysProcAttr.CgroupFD.
type Cgroup2Manager struct {
	// cgroupFDs maps process type -> open fd of the cgroup directory.
	cgroupFDs map[ProcessType]int
}

// Compile-time check that Cgroup2Manager satisfies Manager.
var _ Manager = (*Cgroup2Manager)(nil)

// cgroup2Config is the accumulated result of the functional options.
type cgroup2Config struct {
	// rootPath is the cgroup2 sysfs mount point (default /sys/fs/cgroup).
	rootPath string
	// processTypes maps each process type to its cgroup path and properties.
	processTypes map[ProcessType]Cgroup2Config
}
// Cgroup2ManagerOption mutates the manager configuration at construction time.
type Cgroup2ManagerOption func(*cgroup2Config)

// WithCgroup2RootSysFSPath overrides the cgroup2 sysfs mount point
// (default: /sys/fs/cgroup). Mainly useful in tests.
func WithCgroup2RootSysFSPath(path string) Cgroup2ManagerOption {
	return func(cfg *cgroup2Config) {
		cfg.rootPath = path
	}
}
// WithCgroup2ProcessType registers a cgroup for a process type: path is
// relative to the sysfs root, properties are controller files to write
// (e.g. "memory.max").
func WithCgroup2ProcessType(processType ProcessType, path string, properties map[string]string) Cgroup2ManagerOption {
	return func(cfg *cgroup2Config) {
		if cfg.processTypes == nil {
			cfg.processTypes = map[ProcessType]Cgroup2Config{}
		}
		cfg.processTypes[processType] = Cgroup2Config{
			Path:       path,
			Properties: properties,
		}
	}
}
// Cgroup2Config describes one cgroup: its path relative to the sysfs root and
// the controller properties (file name -> value) to write into it.
type Cgroup2Config struct {
	Path       string
	Properties map[string]string
}
// NewCgroup2Manager creates the configured cgroups under the sysfs root and
// returns a manager holding an open fd per process type. Callers own the fds
// and must eventually Close the manager.
func NewCgroup2Manager(opts ...Cgroup2ManagerOption) (*Cgroup2Manager, error) {
	cfg := cgroup2Config{rootPath: "/sys/fs/cgroup"}
	for _, apply := range opts {
		apply(&cfg)
	}
	fds, err := createCgroups(cfg)
	if err != nil {
		return nil, fmt.Errorf("failed to create cgroups: %w", err)
	}
	return &Cgroup2Manager{cgroupFDs: fds}, nil
}
// createCgroups creates every configured cgroup and returns an open fd per
// process type. On any failure all already-opened fds are closed so nothing
// leaks, and the joined errors are returned.
func createCgroups(configs cgroup2Config) (map[ProcessType]int, error) {
	results := make(map[ProcessType]int)
	var errs []error

	for procType, config := range configs.processTypes {
		fd, err := createCgroup(filepath.Join(configs.rootPath, config.Path), config.Properties)
		if err != nil {
			errs = append(errs, fmt.Errorf("failed to create %s cgroup: %w", procType, err))
			continue
		}
		results[procType] = fd
	}

	if len(errs) == 0 {
		return results, nil
	}

	// Partial failure: release whatever was opened before reporting.
	for procType, fd := range results {
		if closeErr := unix.Close(fd); closeErr != nil {
			errs = append(errs, fmt.Errorf("failed to close cgroup fd for %s: %w", procType, closeErr))
		}
	}
	return nil, errors.Join(errs...)
}
// createCgroup creates the cgroup directory, writes each controller property
// file, and returns a read-only fd of the directory (for SysProcAttr.CgroupFD).
// All property-write failures are collected and joined rather than failing on
// the first one.
func createCgroup(fullPath string, properties map[string]string) (int, error) {
	if err := os.MkdirAll(fullPath, 0o755); err != nil {
		return -1, fmt.Errorf("failed to create cgroup root: %w", err)
	}

	var errs []error
	for name, value := range properties {
		writeErr := os.WriteFile(filepath.Join(fullPath, name), []byte(value), 0o644)
		if writeErr != nil {
			errs = append(errs, fmt.Errorf("failed to write cgroup property: %w", writeErr))
		}
	}
	if len(errs) != 0 {
		return -1, errors.Join(errs...)
	}

	return unix.Open(fullPath, unix.O_RDONLY, 0)
}
// GetFileDescriptor returns the open cgroup directory fd for procType and
// whether one was configured for that type.
func (c Cgroup2Manager) GetFileDescriptor(procType ProcessType) (int, bool) {
	fd, ok := c.cgroupFDs[procType]
	return fd, ok
}
// Close releases every held cgroup fd, removing each entry from the map so a
// second Close is a no-op, and returns all close errors joined.
func (c Cgroup2Manager) Close() error {
	var errs []error
	for procType, fd := range c.cgroupFDs {
		if closeErr := unix.Close(fd); closeErr != nil {
			errs = append(errs, fmt.Errorf("failed to close cgroup fd for %s: %w", procType, closeErr))
		}
		// Deleting during range is well-defined in Go.
		delete(c.cgroupFDs, procType)
	}
	return errors.Join(errs...)
}

View File

@ -0,0 +1,185 @@
package cgroups
import (
"context"
"fmt"
"math/rand"
"os"
"os/exec"
"strconv"
"syscall"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
	// Byte-size helpers for the memory.max limits used in the tests below.
	oneByte  = 1
	kilobyte = 1024 * oneByte
	megabyte = 1024 * kilobyte
)
// TestCgroupRoundTrip exercises Cgroup2Manager end to end against the real
// cgroup2 sysfs: it verifies that a memory-hungry child survives without a
// cgroup, is OOM-killed (SIGKILL) under a 1 MiB memory.max, and cannot even
// start under a 1 KiB limit. Requires root; skipped otherwise.
func TestCgroupRoundTrip(t *testing.T) {
	t.Parallel()
	if os.Geteuid() != 0 {
		t.Skip("must run as root")
		return
	}
	maxTimeout := time.Second * 5
	t.Run("process does not die without cgroups", func(t *testing.T) {
		t.Parallel()
		// create manager
		m, err := NewCgroup2Manager()
		require.NoError(t, err)
		// create new child process — the process type is unregistered, so no
		// cgroup fd is attached and the child runs unconstrained.
		cmd := startProcess(t, m, "not-a-real-one")
		// wait for child process to die — it shouldn't; expect the timeout.
		err = waitForProcess(t, cmd, maxTimeout)
		require.ErrorIs(t, err, context.DeadlineExceeded)
	})
	t.Run("process dies with cgroups", func(t *testing.T) {
		t.Parallel()
		cgroupPath := createCgroupPath(t, "real-one")
		// create manager with a 1 MiB memory cap for PTY processes
		m, err := NewCgroup2Manager(
			WithCgroup2ProcessType(ProcessTypePTY, cgroupPath, map[string]string{
				"memory.max": strconv.Itoa(1 * megabyte),
			}),
		)
		require.NoError(t, err)
		t.Cleanup(func() {
			err := m.Close()
			assert.NoError(t, err)
		})
		// create new child process inside the limited cgroup
		cmd := startProcess(t, m, ProcessTypePTY)
		// wait for child process to die (OOM kill expected)
		err = waitForProcess(t, cmd, maxTimeout)
		// verify process exited correctly: killed by SIGKILL, not a clean exit
		var exitErr *exec.ExitError
		require.ErrorAs(t, err, &exitErr)
		assert.Equal(t, "signal: killed", exitErr.Error())
		assert.False(t, exitErr.Exited())
		assert.False(t, exitErr.Success())
		assert.Equal(t, -1, exitErr.ExitCode())
		// dig a little deeper into the raw wait status
		ws, ok := exitErr.Sys().(syscall.WaitStatus)
		require.True(t, ok)
		assert.Equal(t, syscall.SIGKILL, ws.Signal())
		assert.True(t, ws.Signaled())
		assert.False(t, ws.Stopped())
		assert.False(t, ws.Continued())
		assert.False(t, ws.CoreDump())
		assert.False(t, ws.Exited())
		assert.Equal(t, -1, ws.ExitStatus())
	})
	t.Run("process cannot be spawned because memory limit is too low", func(t *testing.T) {
		t.Parallel()
		cgroupPath := createCgroupPath(t, "real-one")
		// create manager with an unusably small 1 KiB memory cap
		m, err := NewCgroup2Manager(
			WithCgroup2ProcessType(ProcessTypeSocat, cgroupPath, map[string]string{
				"memory.max": strconv.Itoa(1 * kilobyte),
			}),
		)
		require.NoError(t, err)
		t.Cleanup(func() {
			err := m.Close()
			assert.NoError(t, err)
		})
		// create new child process — bash cannot even initialize in 1 KiB
		cmd := startProcess(t, m, ProcessTypeSocat)
		// wait for child process to die
		err = waitForProcess(t, cmd, maxTimeout)
		// verify process exited correctly: a real exit (status 253), no signal
		var exitErr *exec.ExitError
		require.ErrorAs(t, err, &exitErr)
		assert.Equal(t, "exit status 253", exitErr.Error())
		assert.True(t, exitErr.Exited())
		assert.False(t, exitErr.Success())
		assert.Equal(t, 253, exitErr.ExitCode())
		// dig a little deeper into the raw wait status
		ws, ok := exitErr.Sys().(syscall.WaitStatus)
		require.True(t, ok)
		assert.Equal(t, syscall.Signal(-1), ws.Signal())
		assert.False(t, ws.Signaled())
		assert.False(t, ws.Stopped())
		assert.False(t, ws.Continued())
		assert.False(t, ws.CoreDump())
		assert.True(t, ws.Exited())
		assert.Equal(t, 253, ws.ExitStatus())
	})
}
// createCgroupPath builds a unique, test-scoped cgroup name so parallel
// subtests cannot collide on the same sysfs directory.
func createCgroupPath(t *testing.T, s string) string {
	t.Helper()
	return fmt.Sprintf("envd-test-%s-%d", s, rand.Int())
}
// startProcess launches a memory-hungry child (`tail /dev/zero` after a short
// sleep) inside the cgroup registered for pt, if any. When pt is not
// registered, ok is false and the child runs without a cgroup.
func startProcess(t *testing.T, m *Cgroup2Manager, pt ProcessType) *exec.Cmd {
	t.Helper()
	// `tail /dev/zero` grows memory without bound, triggering memory.max.
	cmdName, args := "bash", []string{"-c", `sleep 1 && tail /dev/zero`}
	cmd := exec.CommandContext(t.Context(), cmdName, args...)
	fd, ok := m.GetFileDescriptor(pt)
	cmd.SysProcAttr = &syscall.SysProcAttr{
		UseCgroupFD: ok,
		CgroupFD:    fd,
	}
	err := cmd.Start()
	require.NoError(t, err)
	return cmd
}
// waitForProcess waits for cmd to exit, returning its Wait error, or
// context.DeadlineExceeded when the timeout elapses first.
func waitForProcess(t *testing.T, cmd *exec.Cmd, timeout time.Duration) error {
	t.Helper()

	// Buffered so the waiter goroutine never leaks if the timeout fires.
	result := make(chan error, 1)
	go func() {
		result <- cmd.Wait()
		close(result)
	}()

	ctx, cancel := context.WithTimeout(t.Context(), timeout)
	t.Cleanup(cancel)

	select {
	case err := <-result:
		return err
	case <-ctx.Done():
		return ctx.Err()
	}
}

View File

@ -0,0 +1,14 @@
package cgroups
// ProcessType labels the kind of child process for cgroup placement.
type ProcessType string

const (
	ProcessTypePTY   ProcessType = "pty"   // interactive PTY sessions
	ProcessTypeUser  ProcessType = "user"  // user-started processes
	ProcessTypeSocat ProcessType = "socat" // port-forwarding socat helpers
)
// Manager supplies cgroup directory fds for spawning children into cgroups
// via SysProcAttr.CgroupFD.
type Manager interface {
	// GetFileDescriptor returns the cgroup fd for procType and whether one exists.
	GetFileDescriptor(procType ProcessType) (int, bool)
	// Close releases any held cgroup fds.
	Close() error
}

View File

@ -0,0 +1,17 @@
package cgroups
// NoopManager is the Manager used when cgroups are unavailable (e.g. debug
// runs outside a VM): it never supplies a cgroup fd, so children are spawned
// without cgroup placement.
type NoopManager struct{}

// Compile-time check that NoopManager satisfies Manager.
var _ Manager = (*NoopManager)(nil)

// NewNoopManager returns a manager that never supplies a cgroup fd.
func NewNoopManager() *NoopManager {
	return new(NoopManager)
}

// GetFileDescriptor always reports that no cgroup fd is available.
func (n NoopManager) GetFileDescriptor(ProcessType) (int, bool) {
	return 0, false
}

// Close is a no-op and never fails.
func (n NoopManager) Close() error {
	return nil
}

View File

@ -0,0 +1,184 @@
package filesystem
import (
"context"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"connectrpc.com/connect"
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem"
)
// ListDir lists entries under the requested directory up to the requested
// depth (a depth of 0 is treated as 1). The root path may be a symlink; it is
// resolved for traversal, but entry paths in the response are reported
// relative to the path as the caller supplied it.
func (s Service) ListDir(ctx context.Context, req *connect.Request[rpc.ListDirRequest]) (*connect.Response[rpc.ListDirResponse], error) {
	u, err := permissions.GetAuthUser(ctx, s.defaults.User)
	if err != nil {
		return nil, err
	}

	// Zero means "just this directory".
	depth := req.Msg.GetDepth()
	if depth == 0 {
		depth = 1
	}

	// Expand the path so we can return absolute paths in the response.
	requestedPath, err := permissions.ExpandAndResolve(req.Msg.GetPath(), u, s.defaults.Workdir)
	if err != nil {
		return nil, connect.NewError(connect.CodeInvalidArgument, err)
	}

	resolvedPath, err := followSymlink(requestedPath)
	if err != nil {
		return nil, err
	}
	if err := checkIfDirectory(resolvedPath); err != nil {
		return nil, err
	}

	entries, err := walkDir(requestedPath, resolvedPath, int(depth))
	if err != nil {
		return nil, err
	}

	return connect.NewResponse(&rpc.ListDirResponse{Entries: entries}), nil
}
// MakeDir creates the directory at the requested path (including missing
// parents) owned by the authenticated user. It fails with AlreadyExists if a
// directory is already there, and InvalidArgument if a non-directory is.
func (s Service) MakeDir(ctx context.Context, req *connect.Request[rpc.MakeDirRequest]) (*connect.Response[rpc.MakeDirResponse], error) {
	u, err := permissions.GetAuthUser(ctx, s.defaults.User)
	if err != nil {
		return nil, err
	}

	dirPath, err := permissions.ExpandAndResolve(req.Msg.GetPath(), u, s.defaults.Workdir)
	if err != nil {
		return nil, connect.NewError(connect.CodeInvalidArgument, err)
	}

	// Reject existing paths; only a "does not exist" stat result proceeds.
	switch stat, statErr := os.Stat(dirPath); {
	case statErr == nil && stat.IsDir():
		return nil, connect.NewError(connect.CodeAlreadyExists, fmt.Errorf("directory already exists: %s", dirPath))
	case statErr == nil:
		return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("path already exists but it is not a directory: %s", dirPath))
	case !os.IsNotExist(statErr):
		return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error getting file info: %w", statErr))
	}

	uid, gid, userErr := permissions.GetUserIdInts(u)
	if userErr != nil {
		return nil, connect.NewError(connect.CodeInternal, userErr)
	}
	if userErr = permissions.EnsureDirs(dirPath, uid, gid); userErr != nil {
		return nil, connect.NewError(connect.CodeInternal, userErr)
	}

	entry, err := entryInfo(dirPath)
	if err != nil {
		return nil, err
	}
	return connect.NewResponse(&rpc.MakeDirResponse{Entry: entry}), nil
}
// followSymlink resolves a symbolic link (or chain of links) to its
// canonical target path. Errors map to Connect codes: missing path →
// NotFound, symlink cycle / over-long chain → FailedPrecondition, anything
// else → Internal.
func followSymlink(path string) (string, error) {
	// Resolve symlinks
	resolvedPath, err := filepath.EvalSymlinks(path)
	if err != nil {
		if os.IsNotExist(err) {
			return "", connect.NewError(connect.CodeNotFound, fmt.Errorf("path not found: %w", err))
		}
		// NOTE(review): this matches the message text of the error returned
		// by filepath.EvalSymlinks for cyclic/over-long chains, which is not
		// a stable API — re-verify on Go upgrades.
		if strings.Contains(err.Error(), "too many links") {
			return "", connect.NewError(connect.CodeFailedPrecondition, fmt.Errorf("cyclic symlink or chain >255 links at %q", path))
		}
		return "", connect.NewError(connect.CodeInternal, fmt.Errorf("error resolving symlink: %w", err))
	}
	return resolvedPath, nil
}
// checkIfDirectory verifies that path exists and refers to a directory,
// mapping failures to Connect codes: missing → NotFound, stat failure →
// Internal, non-directory → InvalidArgument.
func checkIfDirectory(path string) error {
	stat, err := os.Stat(path)
	switch {
	case os.IsNotExist(err):
		return connect.NewError(connect.CodeNotFound, fmt.Errorf("directory not found: %w", err))
	case err != nil:
		return connect.NewError(connect.CodeInternal, fmt.Errorf("error getting file info: %w", err))
	case !stat.IsDir():
		return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("path is not a directory: %s", path))
	}
	return nil
}
// walkDir walks the directory tree starting from dirPath up to the specified depth (doesn't follow symlinks).
// requestedPath is the path exactly as the caller supplied it; returned entry
// paths are rewritten to use it as the base so the symlink-resolved prefix
// never leaks into the response. Entries removed mid-walk are skipped.
func walkDir(requestedPath string, dirPath string, depth int) (entries []*rpc.EntryInfo, err error) {
	err = filepath.WalkDir(dirPath, func(path string, _ os.DirEntry, err error) error {
		if err != nil {
			return err
		}
		// Skip the root directory itself
		if path == dirPath {
			return nil
		}
		// Calculate current depth
		relPath, err := filepath.Rel(dirPath, path)
		if err != nil {
			return err
		}
		currentDepth := len(strings.Split(relPath, string(os.PathSeparator)))
		if currentDepth > depth {
			// Too deep: stop descending. When this fires on a regular file,
			// SkipDir also skips its remaining siblings — harmless, since
			// they are at the same (excluded) depth.
			return filepath.SkipDir
		}
		entryInfo, err := entryInfo(path)
		if err != nil {
			var connectErr *connect.Error
			if errors.As(err, &connectErr) && connectErr.Code() == connect.CodeNotFound {
				// Skip entries that don't exist anymore
				return nil
			}
			return err
		}
		// Return the requested path as the base path instead of the symlink-resolved path
		path = filepath.Join(requestedPath, relPath)
		entryInfo.Path = path
		entries = append(entries, entryInfo)
		return nil
	})
	if err != nil {
		return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error reading directory %s: %w", dirPath, err))
	}
	return entries, nil
}

View File

@ -0,0 +1,405 @@
package filesystem
import (
"context"
"errors"
"fmt"
"os"
"os/user"
"path/filepath"
"testing"
"connectrpc.com/authn"
"connectrpc.com/connect"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem"
)
// TestListDir exercises ListDir at several depths over a small fixture tree
// and asserts the exact set of returned absolute paths.
func TestListDir(t *testing.T) {
	t.Parallel()
	// Setup temp root and user
	root := t.TempDir()
	u, err := user.Current()
	require.NoError(t, err)
	// Setup directory structure
	testFolder := filepath.Join(root, "test")
	require.NoError(t, os.MkdirAll(filepath.Join(testFolder, "test-dir", "sub-dir-1"), 0o755))
	require.NoError(t, os.MkdirAll(filepath.Join(testFolder, "test-dir", "sub-dir-2"), 0o755))
	filePath := filepath.Join(testFolder, "test-dir", "sub-dir-1", "file.txt")
	require.NoError(t, os.WriteFile(filePath, []byte("Hello, World!"), 0o644))
	// Service instance
	svc := mockService()
	// Helper to inject user into context
	injectUser := func(ctx context.Context, u *user.User) context.Context {
		return authn.SetInfo(ctx, u)
	}
	tests := []struct {
		name          string
		depth         uint32
		expectedPaths []string
	}{
		{
			name:  "depth 0 lists only root directory",
			depth: 0,
			expectedPaths: []string{
				filepath.Join(testFolder, "test-dir"),
			},
		},
		{
			name:  "depth 1 lists root directory",
			depth: 1,
			expectedPaths: []string{
				filepath.Join(testFolder, "test-dir"),
			},
		},
		{
			name:  "depth 2 lists first level of subdirectories (in this case the root directory)",
			depth: 2,
			expectedPaths: []string{
				filepath.Join(testFolder, "test-dir"),
				filepath.Join(testFolder, "test-dir", "sub-dir-1"),
				filepath.Join(testFolder, "test-dir", "sub-dir-2"),
			},
		},
		{
			name:  "depth 3 lists all directories and files",
			depth: 3,
			expectedPaths: []string{
				filepath.Join(testFolder, "test-dir"),
				filepath.Join(testFolder, "test-dir", "sub-dir-1"),
				filepath.Join(testFolder, "test-dir", "sub-dir-2"),
				filepath.Join(testFolder, "test-dir", "sub-dir-1", "file.txt"),
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			ctx := injectUser(t.Context(), u)
			req := connect.NewRequest(&filesystem.ListDirRequest{
				Path:  testFolder,
				Depth: tt.depth,
			})
			resp, err := svc.ListDir(ctx, req)
			require.NoError(t, err)
			assert.NotEmpty(t, resp.Msg)
			assert.Len(t, resp.Msg.GetEntries(), len(tt.expectedPaths))
			actualPaths := make([]string, len(resp.Msg.GetEntries()))
			for i, entry := range resp.Msg.GetEntries() {
				actualPaths[i] = entry.GetPath()
			}
			// Order of entries is not part of the contract; compare as sets.
			assert.ElementsMatch(t, tt.expectedPaths, actualPaths)
		})
	}
}
// TestListDirNonExistingPath verifies that listing a missing path returns a
// Connect NotFound error.
func TestListDirNonExistingPath(t *testing.T) {
	t.Parallel()
	svc := mockService()
	u, err := user.Current()
	require.NoError(t, err)
	ctx := authn.SetInfo(t.Context(), u)
	req := connect.NewRequest(&filesystem.ListDirRequest{
		Path:  "/non-existing-path",
		Depth: 1,
	})
	_, err = svc.ListDir(ctx, req)
	require.Error(t, err)
	var connectErr *connect.Error
	ok := errors.As(err, &connectErr)
	assert.True(t, ok, "expected error to be of type *connect.Error")
	assert.Equal(t, connect.CodeNotFound, connectErr.Code())
}
// TestListDirRelativePath verifies that ListDir resolves a relative path
// against the authenticated user's home directory.
//
// Fix: the fixture directory is created in the user's real $HOME and was
// never removed, leaking one directory per run; cleanup is now registered via
// t.Cleanup immediately after creation so it runs even if an assertion fails.
func TestListDirRelativePath(t *testing.T) {
	t.Parallel()
	// Setup temp root and user
	u, err := user.Current()
	require.NoError(t, err)
	// Setup directory structure (unique name to avoid cross-run conflicts)
	testRelativePath := fmt.Sprintf("test-%s", uuid.New())
	testFolderPath := filepath.Join(u.HomeDir, testRelativePath)
	require.NoError(t, os.MkdirAll(testFolderPath, 0o755))
	// Remove the fixture from $HOME no matter how the test exits.
	t.Cleanup(func() { os.RemoveAll(testFolderPath) })
	filePath := filepath.Join(testFolderPath, "file.txt")
	require.NoError(t, os.WriteFile(filePath, []byte("Hello, World!"), 0o644))
	// Service instance
	svc := mockService()
	ctx := authn.SetInfo(t.Context(), u)
	req := connect.NewRequest(&filesystem.ListDirRequest{
		Path:  testRelativePath,
		Depth: 1,
	})
	resp, err := svc.ListDir(ctx, req)
	require.NoError(t, err)
	assert.NotEmpty(t, resp.Msg)
	// Returned paths must be absolute even though the request was relative.
	expectedPaths := []string{
		filepath.Join(testFolderPath, "file.txt"),
	}
	assert.Len(t, resp.Msg.GetEntries(), len(expectedPaths))
	actualPaths := make([]string, len(resp.Msg.GetEntries()))
	for i, entry := range resp.Msg.GetEntries() {
		actualPaths[i] = entry.GetPath()
	}
	assert.ElementsMatch(t, expectedPaths, actualPaths)
}
// TestListDir_Symlinks covers ListDir's symlink behavior: a symlinked root is
// followed, a symlink to a file is rejected, a cyclic link fails with
// FailedPrecondition, and symlinks below the root are listed without being
// resolved.
func TestListDir_Symlinks(t *testing.T) {
	t.Parallel()
	root := t.TempDir()
	u, err := user.Current()
	require.NoError(t, err)
	ctx := authn.SetInfo(t.Context(), u)
	symlinkRoot := filepath.Join(root, "test-symlinks")
	require.NoError(t, os.MkdirAll(symlinkRoot, 0o755))
	// 1. Prepare a real directory + file that a symlink will point to
	realDir := filepath.Join(symlinkRoot, "real-dir")
	require.NoError(t, os.MkdirAll(realDir, 0o755))
	filePath := filepath.Join(realDir, "file.txt")
	require.NoError(t, os.WriteFile(filePath, []byte("hello via symlink"), 0o644))
	// 2. Prepare a standalone real file (points-to-file scenario)
	realFile := filepath.Join(symlinkRoot, "real-file.txt")
	require.NoError(t, os.WriteFile(realFile, []byte("i am a plain file"), 0o644))
	// 3. Create the three symlinks
	linkToDir := filepath.Join(symlinkRoot, "link-dir")   // → directory
	linkToFile := filepath.Join(symlinkRoot, "link-file") // → file
	cyclicLink := filepath.Join(symlinkRoot, "cyclic")    // → itself
	require.NoError(t, os.Symlink(realDir, linkToDir))
	require.NoError(t, os.Symlink(realFile, linkToFile))
	require.NoError(t, os.Symlink(cyclicLink, cyclicLink))
	svc := mockService()
	t.Run("symlink to directory behaves like directory and the content looks like inside the directory", func(t *testing.T) {
		t.Parallel()
		req := connect.NewRequest(&filesystem.ListDirRequest{
			Path:  linkToDir,
			Depth: 1,
		})
		resp, err := svc.ListDir(ctx, req)
		require.NoError(t, err)
		// Entry paths keep the symlink prefix the caller used.
		expected := []string{
			filepath.Join(linkToDir, "file.txt"),
		}
		actual := make([]string, len(resp.Msg.GetEntries()))
		for i, e := range resp.Msg.GetEntries() {
			actual[i] = e.GetPath()
		}
		assert.ElementsMatch(t, expected, actual)
	})
	t.Run("link to file", func(t *testing.T) {
		t.Parallel()
		req := connect.NewRequest(&filesystem.ListDirRequest{
			Path:  linkToFile,
			Depth: 1,
		})
		_, err := svc.ListDir(ctx, req)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "not a directory")
	})
	// NOTE(review): the subtest name says "invalid-argument" but the asserted
	// code is FailedPrecondition — the name is stale, the assertion is what
	// the implementation returns.
	t.Run("cyclic symlink surfaces 'too many links' → invalid-argument", func(t *testing.T) {
		t.Parallel()
		req := connect.NewRequest(&filesystem.ListDirRequest{
			Path: cyclicLink,
		})
		_, err := svc.ListDir(ctx, req)
		require.Error(t, err)
		var connectErr *connect.Error
		ok := errors.As(err, &connectErr)
		assert.True(t, ok, "expected error to be of type *connect.Error")
		assert.Equal(t, connect.CodeFailedPrecondition, connectErr.Code())
		assert.Contains(t, connectErr.Error(), "cyclic symlink")
	})
	t.Run("symlink not resolved if not root", func(t *testing.T) {
		t.Parallel()
		req := connect.NewRequest(&filesystem.ListDirRequest{
			Path:  symlinkRoot,
			Depth: 3,
		})
		res, err := svc.ListDir(ctx, req)
		require.NoError(t, err)
		expected := []string{
			filepath.Join(symlinkRoot, "cyclic"),
			filepath.Join(symlinkRoot, "link-dir"),
			filepath.Join(symlinkRoot, "link-file"),
			filepath.Join(symlinkRoot, "real-dir"),
			filepath.Join(symlinkRoot, "real-dir", "file.txt"),
			filepath.Join(symlinkRoot, "real-file.txt"),
		}
		actual := make([]string, len(res.Msg.GetEntries()))
		for i, e := range res.Msg.GetEntries() {
			actual[i] = e.GetPath()
		}
		assert.ElementsMatch(t, expected, actual, "symlinks should not be resolved when listing the symlink root directory")
	})
}
// TestFollowSymlink_Success makes sure that followSymlink resolves symlinks,
// while also being robust to the /var → /private/var indirection that exists on macOS.
func TestFollowSymlink_Success(t *testing.T) {
	t.Parallel()
	// Base temporary directory. On macOS this lives under /var/folders/…
	// which itself is a symlink to /private/var/folders/….
	base := t.TempDir()
	// Create a real directory that we ultimately want to resolve to.
	target := filepath.Join(base, "target")
	require.NoError(t, os.MkdirAll(target, 0o755))
	// Create a symlink pointing at the real directory so we can verify that
	// followSymlink follows it.
	link := filepath.Join(base, "link")
	require.NoError(t, os.Symlink(target, link))
	got, err := followSymlink(link)
	require.NoError(t, err)
	// Canonicalise the expected path too, so that /var → /private/var (macOS)
	// or any other benign symlink indirections don't cause flaky tests.
	// (Comparing against EvalSymlinks rather than `target` keeps the test
	// platform-agnostic.)
	want, err := filepath.EvalSymlinks(link)
	require.NoError(t, err)
	require.Equal(t, want, got, "followSymlink should resolve and canonicalise symlinks")
}
// TestFollowSymlink_MultiSymlinkChain verifies that followSymlink follows a
// chain of several non-cyclic symlinks all the way to the final target.
func TestFollowSymlink_MultiSymlinkChain(t *testing.T) {
	t.Parallel()

	base := t.TempDir()

	// Final destination directory.
	target := filepath.Join(base, "target")
	require.NoError(t, os.MkdirAll(target, 0o755))

	// Build a 3-link chain: link1 → link2 → link3 → target.
	head := target
	for _, name := range []string{"link3", "link2", "link1"} {
		link := filepath.Join(base, name)
		require.NoError(t, os.Symlink(head, link))
		head = link
	}

	got, err := followSymlink(head)
	require.NoError(t, err)

	want, err := filepath.EvalSymlinks(head)
	require.NoError(t, err)
	require.Equal(t, want, got, "followSymlink should resolve an arbitrary symlink chain")
}
// TestFollowSymlink_NotFound checks that a missing path maps to a Connect
// NotFound error.
func TestFollowSymlink_NotFound(t *testing.T) {
	t.Parallel()
	_, err := followSymlink("/definitely/does/not/exist")
	require.Error(t, err)
	var cerr *connect.Error
	require.ErrorAs(t, err, &cerr)
	require.Equal(t, connect.CodeNotFound, cerr.Code())
}
// TestFollowSymlink_CyclicSymlink checks that a symlink loop maps to a
// Connect FailedPrecondition error mentioning the cycle.
func TestFollowSymlink_CyclicSymlink(t *testing.T) {
	t.Parallel()
	dir := t.TempDir()
	a := filepath.Join(dir, "a")
	b := filepath.Join(dir, "b")
	require.NoError(t, os.MkdirAll(a, 0o755))
	require.NoError(t, os.MkdirAll(b, 0o755))
	// Create a two-node loop: a/loop → b/loop, b/loop → a/loop.
	require.NoError(t, os.Symlink(filepath.Join(b, "loop"), filepath.Join(a, "loop")))
	require.NoError(t, os.Symlink(filepath.Join(a, "loop"), filepath.Join(b, "loop")))
	_, err := followSymlink(filepath.Join(a, "loop"))
	require.Error(t, err)
	var cerr *connect.Error
	require.ErrorAs(t, err, &cerr)
	require.Equal(t, connect.CodeFailedPrecondition, cerr.Code())
	require.Contains(t, cerr.Message(), "cyclic")
}
// TestCheckIfDirectory covers the accepting (directory) and rejecting
// (regular file → InvalidArgument) cases of checkIfDirectory.
func TestCheckIfDirectory(t *testing.T) {
	t.Parallel()
	dir := t.TempDir()
	require.NoError(t, checkIfDirectory(dir))
	file := filepath.Join(dir, "file.txt")
	require.NoError(t, os.WriteFile(file, []byte("hello"), 0o644))
	err := checkIfDirectory(file)
	require.Error(t, err)
	var cerr *connect.Error
	require.ErrorAs(t, err, &cerr)
	require.Equal(t, connect.CodeInvalidArgument, cerr.Code())
}
// TestWalkDir_Depth ensures walkDir excludes entries deeper than the
// requested depth.
func TestWalkDir_Depth(t *testing.T) {
	t.Parallel()

	root := t.TempDir()
	require.NoError(t, os.MkdirAll(filepath.Join(root, "sub", "subsub"), 0o755))

	entries, err := walkDir(root, root, 1)
	require.NoError(t, err)

	// Collect entry names for easier assertions.
	names := make([]string, 0, len(entries))
	for _, entry := range entries {
		names = append(names, entry.GetName())
	}
	require.Contains(t, names, "sub")
	require.NotContains(t, names, "subsub", "entries beyond depth should be excluded")
}
// TestWalkDir_Error checks that walkDir wraps walk failures (here: a missing
// root) in a Connect Internal error.
func TestWalkDir_Error(t *testing.T) {
	t.Parallel()
	_, err := walkDir("/does/not/exist", "/does/not/exist", 1)
	require.Error(t, err)
	var cerr *connect.Error
	require.ErrorAs(t, err, &cerr)
	require.Equal(t, connect.CodeInternal, cerr.Code())
}

View File

@ -0,0 +1,58 @@
package filesystem
import (
"context"
"fmt"
"os"
"path/filepath"
"connectrpc.com/connect"
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem"
)
// Move renames source to destination on behalf of the authenticated user,
// creating the destination's parent directories (owned by that user) first.
// Symlinks are moved as links, not dereferenced (os.Rename semantics).
func (s Service) Move(ctx context.Context, req *connect.Request[rpc.MoveRequest]) (*connect.Response[rpc.MoveResponse], error) {
	u, err := permissions.GetAuthUser(ctx, s.defaults.User)
	if err != nil {
		return nil, err
	}

	source, err := permissions.ExpandAndResolve(req.Msg.GetSource(), u, s.defaults.Workdir)
	if err != nil {
		return nil, connect.NewError(connect.CodeInvalidArgument, err)
	}
	destination, err := permissions.ExpandAndResolve(req.Msg.GetDestination(), u, s.defaults.Workdir)
	if err != nil {
		return nil, connect.NewError(connect.CodeInvalidArgument, err)
	}

	// Ensure the destination's parent exists and belongs to the caller.
	uid, gid, userErr := permissions.GetUserIdInts(u)
	if userErr != nil {
		return nil, connect.NewError(connect.CodeInternal, userErr)
	}
	if userErr = permissions.EnsureDirs(filepath.Dir(destination), uid, gid); userErr != nil {
		return nil, connect.NewError(connect.CodeInternal, userErr)
	}

	if renameErr := os.Rename(source, destination); renameErr != nil {
		if os.IsNotExist(renameErr) {
			return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("source file not found: %w", renameErr))
		}
		return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error renaming: %w", renameErr))
	}

	entry, err := entryInfo(destination)
	if err != nil {
		return nil, err
	}
	return connect.NewResponse(&rpc.MoveResponse{Entry: entry}), nil
}

View File

@ -0,0 +1,364 @@
package filesystem
import (
"errors"
"fmt"
"os"
"os/user"
"path/filepath"
"testing"
"connectrpc.com/authn"
"connectrpc.com/connect"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem"
)
// TestMove moves a single file between two directories and verifies the
// response entry, the file's new location, the old path's absence, and the
// preserved content.
func TestMove(t *testing.T) {
	t.Parallel()
	// Setup temp root and user
	root := t.TempDir()
	u, err := user.Current()
	require.NoError(t, err)
	// Setup source and destination directories
	sourceDir := filepath.Join(root, "source")
	destDir := filepath.Join(root, "destination")
	require.NoError(t, os.MkdirAll(sourceDir, 0o755))
	require.NoError(t, os.MkdirAll(destDir, 0o755))
	// Create a test file to move
	sourceFile := filepath.Join(sourceDir, "test-file.txt")
	testContent := []byte("Hello, World!")
	require.NoError(t, os.WriteFile(sourceFile, testContent, 0o644))
	// Destination file path
	destFile := filepath.Join(destDir, "test-file.txt")
	// Service instance
	svc := mockService()
	// Call the Move function
	ctx := authn.SetInfo(t.Context(), u)
	req := connect.NewRequest(&filesystem.MoveRequest{
		Source:      sourceFile,
		Destination: destFile,
	})
	resp, err := svc.Move(ctx, req)
	// Verify the move was successful
	require.NoError(t, err)
	assert.NotNil(t, resp)
	assert.Equal(t, destFile, resp.Msg.GetEntry().GetPath())
	// Verify the file exists at the destination
	_, err = os.Stat(destFile)
	require.NoError(t, err)
	// Verify the file no longer exists at the source
	_, err = os.Stat(sourceFile)
	assert.True(t, os.IsNotExist(err))
	// Verify the content of the moved file
	content, err := os.ReadFile(destFile)
	require.NoError(t, err)
	assert.Equal(t, testContent, content)
}
// TestMoveDirectory moves a whole directory tree (with a nested file) and
// verifies structure, contents, and removal of the source.
func TestMoveDirectory(t *testing.T) {
	t.Parallel()
	// Setup temp root and user
	root := t.TempDir()
	u, err := user.Current()
	require.NoError(t, err)
	// Setup source and destination directories
	sourceParent := filepath.Join(root, "source-parent")
	destParent := filepath.Join(root, "dest-parent")
	require.NoError(t, os.MkdirAll(sourceParent, 0o755))
	require.NoError(t, os.MkdirAll(destParent, 0o755))
	// Create a test directory with files to move
	sourceDir := filepath.Join(sourceParent, "test-dir")
	require.NoError(t, os.MkdirAll(filepath.Join(sourceDir, "subdir"), 0o755))
	// Create some files in the directory
	file1 := filepath.Join(sourceDir, "file1.txt")
	file2 := filepath.Join(sourceDir, "subdir", "file2.txt")
	require.NoError(t, os.WriteFile(file1, []byte("File 1 content"), 0o644))
	require.NoError(t, os.WriteFile(file2, []byte("File 2 content"), 0o644))
	// Destination directory path
	destDir := filepath.Join(destParent, "test-dir")
	// Service instance
	svc := mockService()
	// Call the Move function
	ctx := authn.SetInfo(t.Context(), u)
	req := connect.NewRequest(&filesystem.MoveRequest{
		Source:      sourceDir,
		Destination: destDir,
	})
	resp, err := svc.Move(ctx, req)
	// Verify the move was successful
	require.NoError(t, err)
	assert.NotNil(t, resp)
	assert.Equal(t, destDir, resp.Msg.GetEntry().GetPath())
	// Verify the directory exists at the destination
	_, err = os.Stat(destDir)
	require.NoError(t, err)
	// Verify the files exist at the destination
	destFile1 := filepath.Join(destDir, "file1.txt")
	destFile2 := filepath.Join(destDir, "subdir", "file2.txt")
	_, err = os.Stat(destFile1)
	require.NoError(t, err)
	_, err = os.Stat(destFile2)
	require.NoError(t, err)
	// Verify the directory no longer exists at the source
	_, err = os.Stat(sourceDir)
	assert.True(t, os.IsNotExist(err))
	// Verify the content of the moved files
	content1, err := os.ReadFile(destFile1)
	require.NoError(t, err)
	assert.Equal(t, []byte("File 1 content"), content1)
	content2, err := os.ReadFile(destFile2)
	require.NoError(t, err)
	assert.Equal(t, []byte("File 2 content"), content2)
}
// TestMoveNonExistingFile verifies that moving a missing source yields a
// Connect NotFound error with the "source file not found" message.
func TestMoveNonExistingFile(t *testing.T) {
	t.Parallel()
	// Setup temp root and user
	root := t.TempDir()
	u, err := user.Current()
	require.NoError(t, err)
	// Setup destination directory
	destDir := filepath.Join(root, "destination")
	require.NoError(t, os.MkdirAll(destDir, 0o755))
	// Non-existing source file
	sourceFile := filepath.Join(root, "non-existing-file.txt")
	// Destination file path
	destFile := filepath.Join(destDir, "moved-file.txt")
	// Service instance
	svc := mockService()
	// Call the Move function
	ctx := authn.SetInfo(t.Context(), u)
	req := connect.NewRequest(&filesystem.MoveRequest{
		Source:      sourceFile,
		Destination: destFile,
	})
	_, err = svc.Move(ctx, req)
	// Verify the correct error is returned
	require.Error(t, err)
	var connectErr *connect.Error
	ok := errors.As(err, &connectErr)
	assert.True(t, ok, "expected error to be of type *connect.Error")
	assert.Equal(t, connect.CodeNotFound, connectErr.Code())
	assert.Contains(t, connectErr.Message(), "source file not found")
}
// TestMoveRelativePath verifies that Move resolves relative source and
// destination paths against the authenticated user's home directory.
//
// Fix: cleanup of the $HOME fixtures was done with trailing os.RemoveAll
// calls that never ran if any earlier require failed, leaking directories in
// the user's real home. Cleanup is now registered with t.Cleanup right after
// each directory is created, so it runs on every exit path.
func TestMoveRelativePath(t *testing.T) {
	t.Parallel()
	// Setup user
	u, err := user.Current()
	require.NoError(t, err)
	// Setup directory structure with unique name to avoid conflicts
	testRelativePath := fmt.Sprintf("test-move-%s", uuid.New())
	testFolderPath := filepath.Join(u.HomeDir, testRelativePath)
	require.NoError(t, os.MkdirAll(testFolderPath, 0o755))
	t.Cleanup(func() { os.RemoveAll(testFolderPath) })
	// Create a test file to move
	sourceFile := filepath.Join(testFolderPath, "source-file.txt")
	testContent := []byte("Hello from relative path!")
	require.NoError(t, os.WriteFile(sourceFile, testContent, 0o644))
	// Destination file path (also relative)
	destRelativePath := fmt.Sprintf("test-move-dest-%s", uuid.New())
	destFolderPath := filepath.Join(u.HomeDir, destRelativePath)
	require.NoError(t, os.MkdirAll(destFolderPath, 0o755))
	t.Cleanup(func() { os.RemoveAll(destFolderPath) })
	destFile := filepath.Join(destFolderPath, "moved-file.txt")
	// Service instance
	svc := mockService()
	// Call the Move function with relative paths
	ctx := authn.SetInfo(t.Context(), u)
	req := connect.NewRequest(&filesystem.MoveRequest{
		Source:      filepath.Join(testRelativePath, "source-file.txt"),   // Relative path
		Destination: filepath.Join(destRelativePath, "moved-file.txt"),    // Relative path
	})
	resp, err := svc.Move(ctx, req)
	// Verify the move was successful
	require.NoError(t, err)
	assert.NotNil(t, resp)
	assert.Equal(t, destFile, resp.Msg.GetEntry().GetPath())
	// Verify the file exists at the destination
	_, err = os.Stat(destFile)
	require.NoError(t, err)
	// Verify the file no longer exists at the source
	_, err = os.Stat(sourceFile)
	assert.True(t, os.IsNotExist(err))
	// Verify the content of the moved file
	content, err := os.ReadFile(destFile)
	require.NoError(t, err)
	assert.Equal(t, testContent, content)
}
// TestMove_Symlinks verifies that Move relocates symlinks as links (the link
// itself moves, its target is untouched) and that moving a symlink's target
// leaves the link dangling. Subtests share fixtures, so the whole test is
// not parallel and the file-link subtest must run after its siblings set up.
func TestMove_Symlinks(t *testing.T) { //nolint:tparallel // this test cannot be executed in parallel
	root := t.TempDir()
	u, err := user.Current()
	require.NoError(t, err)
	ctx := authn.SetInfo(t.Context(), u)
	// Setup source and destination directories
	sourceRoot := filepath.Join(root, "source")
	destRoot := filepath.Join(root, "destination")
	require.NoError(t, os.MkdirAll(sourceRoot, 0o755))
	require.NoError(t, os.MkdirAll(destRoot, 0o755))
	// 1. Prepare a real directory + file that a symlink will point to
	realDir := filepath.Join(sourceRoot, "real-dir")
	require.NoError(t, os.MkdirAll(realDir, 0o755))
	filePath := filepath.Join(realDir, "file.txt")
	require.NoError(t, os.WriteFile(filePath, []byte("hello via symlink"), 0o644))
	// 2. Prepare a standalone real file (points-to-file scenario)
	realFile := filepath.Join(sourceRoot, "real-file.txt")
	require.NoError(t, os.WriteFile(realFile, []byte("i am a plain file"), 0o644))
	// 3. Create symlinks
	linkToDir := filepath.Join(sourceRoot, "link-dir")   // → directory
	linkToFile := filepath.Join(sourceRoot, "link-file") // → file
	require.NoError(t, os.Symlink(realDir, linkToDir))
	require.NoError(t, os.Symlink(realFile, linkToFile))
	svc := mockService()
	t.Run("move symlink to directory", func(t *testing.T) {
		t.Parallel()
		destPath := filepath.Join(destRoot, "moved-link-dir")
		req := connect.NewRequest(&filesystem.MoveRequest{
			Source:      linkToDir,
			Destination: destPath,
		})
		resp, err := svc.Move(ctx, req)
		require.NoError(t, err)
		assert.Equal(t, destPath, resp.Msg.GetEntry().GetPath())
		// Verify the symlink was moved
		_, err = os.Stat(destPath)
		require.NoError(t, err)
		// Verify it's still a symlink
		info, err := os.Lstat(destPath)
		require.NoError(t, err)
		assert.NotEqual(t, 0, info.Mode()&os.ModeSymlink, "expected a symlink")
		// Verify the symlink target is still correct
		target, err := os.Readlink(destPath)
		require.NoError(t, err)
		assert.Equal(t, realDir, target)
		// Verify the original symlink is gone
		_, err = os.Stat(linkToDir)
		assert.True(t, os.IsNotExist(err))
		// Verify the real directory still exists
		_, err = os.Stat(realDir)
		assert.NoError(t, err)
	})
	t.Run("move symlink to file", func(t *testing.T) { //nolint:paralleltest
		destPath := filepath.Join(destRoot, "moved-link-file")
		req := connect.NewRequest(&filesystem.MoveRequest{
			Source:      linkToFile,
			Destination: destPath,
		})
		resp, err := svc.Move(ctx, req)
		require.NoError(t, err)
		assert.Equal(t, destPath, resp.Msg.GetEntry().GetPath())
		// Verify the symlink was moved
		_, err = os.Stat(destPath)
		require.NoError(t, err)
		// Verify it's still a symlink
		info, err := os.Lstat(destPath)
		require.NoError(t, err)
		assert.NotEqual(t, 0, info.Mode()&os.ModeSymlink, "expected a symlink")
		// Verify the symlink target is still correct
		target, err := os.Readlink(destPath)
		require.NoError(t, err)
		assert.Equal(t, realFile, target)
		// Verify the original symlink is gone
		_, err = os.Stat(linkToFile)
		assert.True(t, os.IsNotExist(err))
		// Verify the real file still exists
		_, err = os.Stat(realFile)
		assert.NoError(t, err)
	})
	t.Run("move real file that is target of symlink", func(t *testing.T) {
		t.Parallel()
		// Create a new symlink to the real file
		newLinkToFile := filepath.Join(sourceRoot, "new-link-file")
		require.NoError(t, os.Symlink(realFile, newLinkToFile))
		destPath := filepath.Join(destRoot, "moved-real-file.txt")
		req := connect.NewRequest(&filesystem.MoveRequest{
			Source:      realFile,
			Destination: destPath,
		})
		resp, err := svc.Move(ctx, req)
		require.NoError(t, err)
		assert.Equal(t, destPath, resp.Msg.GetEntry().GetPath())
		// Verify the real file was moved
		_, err = os.Stat(destPath)
		require.NoError(t, err)
		// Verify the original file is gone
		_, err = os.Stat(realFile)
		assert.True(t, os.IsNotExist(err))
		// Verify the symlink still exists but now points to a non-existent file
		_, err = os.Stat(newLinkToFile)
		require.Error(t, err, "symlink should point to non-existent file")
	})
}

View File

@ -0,0 +1,31 @@
package filesystem
import (
"context"
"fmt"
"os"
"connectrpc.com/connect"
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem"
)
// Remove deletes the file or directory tree at the requested path. Removing a
// path that does not exist succeeds (os.RemoveAll semantics).
func (s Service) Remove(ctx context.Context, req *connect.Request[rpc.RemoveRequest]) (*connect.Response[rpc.RemoveResponse], error) {
	u, err := permissions.GetAuthUser(ctx, s.defaults.User)
	if err != nil {
		return nil, err
	}

	path, err := permissions.ExpandAndResolve(req.Msg.GetPath(), u, s.defaults.Workdir)
	if err != nil {
		return nil, connect.NewError(connect.CodeInvalidArgument, err)
	}

	if removeErr := os.RemoveAll(path); removeErr != nil {
		return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error removing file or directory: %w", removeErr))
	}

	return connect.NewResponse(&rpc.RemoveResponse{}), nil
}

View File

@ -0,0 +1,34 @@
package filesystem
import (
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
"github.com/rs/zerolog"
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
spec "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem/filesystemconnect"
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
)
// Service implements the filesystem Connect RPC service.
type Service struct {
	logger *zerolog.Logger
	// watchers tracks active file watchers by key.
	watchers *utils.Map[string, *FileWatcher]
	// defaults supplies fallback user/workdir/env values for requests.
	defaults *execcontext.Defaults
}

// Handle constructs the filesystem Service and mounts its Connect handler
// (wrapped with a unary logging interceptor) on the given chi router.
func Handle(server *chi.Mux, l *zerolog.Logger, defaults *execcontext.Defaults) {
	service := Service{
		logger:   l,
		watchers: utils.NewMap[string, *FileWatcher](),
		defaults: defaults,
	}
	interceptors := connect.WithInterceptors(
		logs.NewUnaryLogInterceptor(l),
	)
	path, handler := spec.NewFilesystemHandler(service, interceptors)
	server.Mount(path, handler)
}

View File

@ -0,0 +1,14 @@
package filesystem
import (
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
)
// mockService builds a Service with empty defaults for unit tests; requests
// are then resolved against the user injected into the context via authn.
func mockService() Service {
	return Service{
		defaults: &execcontext.Defaults{
			EnvVars: utils.NewMap[string, string](),
		},
	}
}

View File

@ -0,0 +1,29 @@
package filesystem
import (
"context"
"connectrpc.com/connect"
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem"
)
// Stat returns entry metadata for the requested path after expanding and
// resolving it relative to the authenticated user's context.
func (s Service) Stat(ctx context.Context, req *connect.Request[rpc.StatRequest]) (*connect.Response[rpc.StatResponse], error) {
	u, userErr := permissions.GetAuthUser(ctx, s.defaults.User)
	if userErr != nil {
		return nil, userErr
	}

	path, pathErr := permissions.ExpandAndResolve(req.Msg.GetPath(), u, s.defaults.Workdir)
	if pathErr != nil {
		return nil, connect.NewError(connect.CodeInvalidArgument, pathErr)
	}

	entry, entryErr := entryInfo(path)
	if entryErr != nil {
		return nil, entryErr
	}

	return connect.NewResponse(&rpc.StatResponse{Entry: entry}), nil
}

View File

@ -0,0 +1,114 @@
package filesystem
import (
"context"
"os"
"os/user"
"path/filepath"
"testing"
"connectrpc.com/authn"
"connectrpc.com/connect"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem"
)
// TestStat checks Stat's reported path, type, owner, group, mode, and
// symlink target for a regular file and for a symlink to it.
func TestStat(t *testing.T) {
	t.Parallel()
	// Setup temp root and user
	root := t.TempDir()
	// Get the actual path to the temp directory (symlinks can cause issues)
	root, err := filepath.EvalSymlinks(root)
	require.NoError(t, err)
	u, err := user.Current()
	require.NoError(t, err)
	group, err := user.LookupGroupId(u.Gid)
	require.NoError(t, err)
	// Setup directory structure
	testFolder := filepath.Join(root, "test")
	err = os.MkdirAll(testFolder, 0o755)
	require.NoError(t, err)
	testFile := filepath.Join(testFolder, "file.txt")
	err = os.WriteFile(testFile, []byte("Hello, World!"), 0o644)
	require.NoError(t, err)
	linkedFile := filepath.Join(testFolder, "linked-file.txt")
	err = os.Symlink(testFile, linkedFile)
	require.NoError(t, err)
	// Service instance
	svc := mockService()
	// Helper to inject user into context
	injectUser := func(ctx context.Context, u *user.User) context.Context {
		return authn.SetInfo(ctx, u)
	}
	tests := []struct {
		name string
		path string
	}{
		{
			name: "Stat file directory",
			path: testFile,
		},
		{
			name: "Stat symlink to file",
			path: linkedFile,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			ctx := injectUser(t.Context(), u)
			req := connect.NewRequest(&filesystem.StatRequest{
				Path: tt.path,
			})
			resp, err := svc.Stat(ctx, req)
			require.NoError(t, err)
			require.NotEmpty(t, resp.Msg)
			require.NotNil(t, resp.Msg.GetEntry())
			assert.Equal(t, tt.path, resp.Msg.GetEntry().GetPath())
			assert.Equal(t, filesystem.FileType_FILE_TYPE_FILE, resp.Msg.GetEntry().GetType())
			assert.Equal(t, u.Username, resp.Msg.GetEntry().GetOwner())
			assert.Equal(t, group.Name, resp.Msg.GetEntry().GetGroup())
			assert.Equal(t, uint32(0o644), resp.Msg.GetEntry().GetMode())
			// Only the symlink case should report a target.
			if tt.path == linkedFile {
				require.NotNil(t, resp.Msg.GetEntry().GetSymlinkTarget())
				assert.Equal(t, testFile, resp.Msg.GetEntry().GetSymlinkTarget())
			} else {
				assert.Empty(t, resp.Msg.GetEntry().GetSymlinkTarget())
			}
		})
	}
}
// TestStatMissingPathReturnsNotFound verifies that statting a nonexistent
// file yields a connect error with code CodeNotFound.
func TestStatMissingPathReturnsNotFound(t *testing.T) {
	t.Parallel()

	cur, err := user.Current()
	require.NoError(t, err)

	svc := mockService()

	// A file that was never created inside a fresh temp dir.
	missing := filepath.Join(t.TempDir(), "missing.txt")
	req := connect.NewRequest(&filesystem.StatRequest{
		Path: missing,
	})

	_, err = svc.Stat(authn.SetInfo(t.Context(), cur), req)
	require.Error(t, err)

	var cerr *connect.Error
	require.ErrorAs(t, err, &cerr)
	assert.Equal(t, connect.CodeNotFound, cerr.Code())
}

View File

@ -0,0 +1,107 @@
package filesystem
import (
"fmt"
"os"
"os/user"
"syscall"
"time"
"connectrpc.com/connect"
"google.golang.org/protobuf/types/known/timestamppb"
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem"
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/filesystem"
)
// Filesystem magic numbers from the Linux kernel (include/uapi/linux/magic.h).
const (
	nfsSuperMagic   = 0x6969
	cifsMagic       = 0xFF534D42
	smbSuperMagic   = 0x517B
	smb2MagicNumber = 0xFE534D42
	fuseSuperMagic  = 0x65735546
)

// IsPathOnNetworkMount reports whether path sits on a network-backed
// filesystem (NFS, CIFS, SMB, SMB2) or a FUSE mount, judged by the
// superblock magic number returned by statfs(2).
func IsPathOnNetworkMount(path string) (bool, error) {
	var fs syscall.Statfs_t
	if err := syscall.Statfs(path, &fs); err != nil {
		return false, fmt.Errorf("failed to statfs %s: %w", path, err)
	}

	for _, magic := range []int64{nfsSuperMagic, cifsMagic, smbSuperMagic, smb2MagicNumber, fuseSuperMagic} {
		if fs.Type == magic {
			return true, nil
		}
	}

	return false, nil
}
// entryInfo stats path via the shared filesystem helper and converts the
// result into the RPC EntryInfo message. A missing path maps to
// CodeNotFound; any other stat failure maps to CodeInternal.
func entryInfo(path string) (*rpc.EntryInfo, error) {
	info, err := filesystem.GetEntryFromPath(path)
	switch {
	case os.IsNotExist(err):
		return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("file not found: %w", err))
	case err != nil:
		return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error getting file info: %w", err))
	}

	owner, group := getFileOwnership(info)

	entry := &rpc.EntryInfo{
		Name:          info.Name,
		Type:          getEntryType(info.Type),
		Path:          info.Path,
		Size:          info.Size,
		Mode:          uint32(info.Mode),
		Permissions:   info.Permissions,
		Owner:         owner,
		Group:         group,
		ModifiedTime:  toTimestamp(info.ModifiedTime),
		SymlinkTarget: info.SymlinkTarget,
	}

	return entry, nil
}
// toTimestamp converts t into a protobuf Timestamp, mapping the zero
// time.Time to nil so the RPC field is simply omitted.
//
// Fix: the parameter was previously named `time`, shadowing the imported
// time package inside the function body; renamed to `t`.
func toTimestamp(t time.Time) *timestamppb.Timestamp {
	if t.IsZero() {
		return nil
	}

	return timestamppb.New(t)
}
// getFileOwnership returns the owner and group names for a file.
// When a name lookup fails, the numeric UID/GID string is returned instead.
func getFileOwnership(fileInfo filesystem.EntryInfo) (owner, group string) {
	// Start from the numeric fallbacks; upgrade to names when resolvable.
	owner = fmt.Sprintf("%d", fileInfo.UID)
	group = fmt.Sprintf("%d", fileInfo.GID)

	if u, uErr := user.LookupId(owner); uErr == nil {
		owner = u.Username
	}

	if g, gErr := user.LookupGroupId(group); gErr == nil {
		group = g.Name
	}

	return owner, group
}
// getEntryType maps the shared filesystem package's FileType enum onto the
// RPC FileType enum; unknown values map to FILE_TYPE_UNSPECIFIED.
// (No symlink resolution happens here — the input type is used as-is.)
func getEntryType(fileType filesystem.FileType) rpc.FileType {
	mapping := map[filesystem.FileType]rpc.FileType{
		filesystem.FileFileType:      rpc.FileType_FILE_TYPE_FILE,
		filesystem.DirectoryFileType: rpc.FileType_FILE_TYPE_DIRECTORY,
		filesystem.SymlinkFileType:   rpc.FileType_FILE_TYPE_SYMLINK,
	}

	if mapped, ok := mapping[fileType]; ok {
		return mapped
	}

	return rpc.FileType_FILE_TYPE_UNSPECIFIED
}

View File

@ -0,0 +1,149 @@
package filesystem
import (
"context"
"os/exec"
osuser "os/user"
"strconv"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
fsmodel "git.omukk.dev/wrenn/sandbox/envd/internal/shared/filesystem"
)
// TestIsPathOnNetworkMount confirms a local temp directory is not
// classified as a network filesystem.
func TestIsPathOnNetworkMount(t *testing.T) {
	t.Parallel()

	localDir := t.TempDir()

	onNetwork, err := IsPathOnNetworkMount(localDir)
	require.NoError(t, err)
	assert.False(t, onNetwork, "temp directory should not be on a network mount")
}
// TestIsPathOnNetworkMount_FuseMount verifies that a bindfs (FUSE) mount is
// classified as a network filesystem while its backing directory is not.
//
// Fix: the test previously hard-failed (require.NoError on LookPath) on
// machines without bindfs/fusermount. It now skips instead, matching the
// t.Skip fallbacks used elsewhere in this package for environment gaps.
func TestIsPathOnNetworkMount_FuseMount(t *testing.T) {
	t.Parallel()

	// bindfs provides the FUSE mount; skip when the tool is unavailable.
	if _, err := exec.LookPath("bindfs"); err != nil {
		t.Skipf("bindfs is not installed: %v", err)
	}

	// fusermount is needed for unmounting during cleanup.
	if _, err := exec.LookPath("fusermount"); err != nil {
		t.Skipf("fusermount is not installed: %v", err)
	}

	// Create source and mount directories
	sourceDir := t.TempDir()
	mountDir := t.TempDir()

	// Mount sourceDir onto mountDir using bindfs (FUSE)
	ctx := context.Background()
	cmd := exec.CommandContext(ctx, "bindfs", sourceDir, mountDir)
	require.NoError(t, cmd.Run(), "failed to mount bindfs")

	// Ensure we unmount on cleanup
	t.Cleanup(func() {
		_ = exec.CommandContext(context.Background(), "fusermount", "-u", mountDir).Run()
	})

	// Test that the FUSE mount is detected
	isNetwork, err := IsPathOnNetworkMount(mountDir)
	require.NoError(t, err)
	assert.True(t, isNetwork, "FUSE mount should be detected as network filesystem")

	// Test that the source directory is NOT detected as network mount
	isNetworkSource, err := IsPathOnNetworkMount(sourceDir)
	require.NoError(t, err)
	assert.False(t, isNetworkSource, "source directory should not be detected as network filesystem")
}
// TestGetFileOwnership_CurrentUser covers both lookup paths of
// getFileOwnership: a resolvable UID/GID (the current user) and unknown
// IDs that must fall back to their numeric string form.
func TestGetFileOwnership_CurrentUser(t *testing.T) {
	t.Parallel()

	t.Run("current user", func(t *testing.T) {
		t.Parallel()

		// Get current user running the tests
		cur, err := osuser.Current()
		if err != nil {
			t.Skipf("unable to determine current user: %v", err)
		}

		// Determine expected owner/group using the same lookup logic
		expectedOwner := cur.Uid
		if u, err := osuser.LookupId(cur.Uid); err == nil {
			expectedOwner = u.Username
		}

		expectedGroup := cur.Gid
		if g, err := osuser.LookupGroupId(cur.Gid); err == nil {
			expectedGroup = g.Name
		}

		// Parse UID/GID strings to uint32 for EntryInfo
		uid64, err := strconv.ParseUint(cur.Uid, 10, 32)
		require.NoError(t, err)
		gid64, err := strconv.ParseUint(cur.Gid, 10, 32)
		require.NoError(t, err)

		// Build a minimal EntryInfo with current UID/GID
		info := fsmodel.EntryInfo{ // from shared pkg
			UID: uint32(uid64),
			GID: uint32(gid64),
		}

		owner, group := getFileOwnership(info)
		assert.Equal(t, expectedOwner, owner)
		assert.Equal(t, expectedGroup, group)
	})

	t.Run("no user", func(t *testing.T) {
		t.Parallel()

		// Find a UID that does not exist on this system
		var unknownUIDStr string
		for i := 60001; i < 70000; i++ { // search a high range typically unused
			idStr := strconv.Itoa(i)
			if _, err := osuser.LookupId(idStr); err != nil {
				unknownUIDStr = idStr
				break
			}
		}
		if unknownUIDStr == "" {
			t.Skip("could not find a non-existent UID in the probed range")
		}

		// Find a GID that does not exist on this system
		var unknownGIDStr string
		for i := 60001; i < 70000; i++ { // search a high range typically unused
			idStr := strconv.Itoa(i)
			if _, err := osuser.LookupGroupId(idStr); err != nil {
				unknownGIDStr = idStr
				break
			}
		}
		if unknownGIDStr == "" {
			t.Skip("could not find a non-existent GID in the probed range")
		}

		// Parse to uint32 for EntryInfo construction
		uid64, err := strconv.ParseUint(unknownUIDStr, 10, 32)
		require.NoError(t, err)
		gid64, err := strconv.ParseUint(unknownGIDStr, 10, 32)
		require.NoError(t, err)

		info := fsmodel.EntryInfo{
			UID: uint32(uid64),
			GID: uint32(gid64),
		}

		owner, group := getFileOwnership(info)

		// Expect numeric fallbacks because lookups should fail for unknown IDs
		assert.Equal(t, unknownUIDStr, owner)
		assert.Equal(t, unknownGIDStr, group)
	})
}

View File

@ -0,0 +1,159 @@
package filesystem
import (
"context"
"fmt"
"os"
"path/filepath"
"connectrpc.com/connect"
"github.com/e2b-dev/fsnotify"
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem"
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
)
// WatchDir streams filesystem events for a directory to the client,
// delegating to watchHandler wrapped with per-request stream logging.
func (s Service) WatchDir(ctx context.Context, req *connect.Request[rpc.WatchDirRequest], stream *connect.ServerStream[rpc.WatchDirResponse]) error {
	return logs.LogServerStreamWithoutEvents(ctx, s.logger, req, stream, s.watchHandler)
}
// watchHandler streams fsnotify events for a directory until the client
// disconnects or the watcher fails.
//
// Flow: resolve the path for the authenticated user, validate that it is an
// existing directory not on a network mount, then send a Start event and
// stream Filesystem/Keepalive events until ctx ends or the watcher errors.
func (s Service) watchHandler(ctx context.Context, req *connect.Request[rpc.WatchDirRequest], stream *connect.ServerStream[rpc.WatchDirResponse]) error {
	u, err := permissions.GetAuthUser(ctx, s.defaults.User)
	if err != nil {
		return err
	}

	watchPath, err := permissions.ExpandAndResolve(req.Msg.GetPath(), u, s.defaults.Workdir)
	if err != nil {
		return connect.NewError(connect.CodeInvalidArgument, err)
	}

	info, err := os.Stat(watchPath)
	if err != nil {
		if os.IsNotExist(err) {
			return connect.NewError(connect.CodeNotFound, fmt.Errorf("path %s not found: %w", watchPath, err))
		}

		return connect.NewError(connect.CodeInternal, fmt.Errorf("error statting path %s: %w", watchPath, err))
	}

	if !info.IsDir() {
		// Fix: this branch previously wrapped err with %w although err is
		// nil here (os.Stat succeeded), rendering "%!w(<nil>)" in the message.
		return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("path %s is not a directory", watchPath))
	}

	// Check if path is on a network filesystem mount — fsnotify is rejected
	// for those (see IsPathOnNetworkMount).
	isNetworkMount, err := IsPathOnNetworkMount(watchPath)
	if err != nil {
		return connect.NewError(connect.CodeInternal, fmt.Errorf("error checking mount status: %w", err))
	}
	if isNetworkMount {
		return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("cannot watch path on network filesystem: %s", watchPath))
	}

	w, err := fsnotify.NewWatcher()
	if err != nil {
		return connect.NewError(connect.CodeInternal, fmt.Errorf("error creating watcher: %w", err))
	}
	defer w.Close()

	err = w.Add(utils.FsnotifyPath(watchPath, req.Msg.GetRecursive()))
	if err != nil {
		return connect.NewError(connect.CodeInternal, fmt.Errorf("error adding path %s to watcher: %w", watchPath, err))
	}

	// Tell the client the watch is live before any events are emitted.
	err = stream.Send(&rpc.WatchDirResponse{
		Event: &rpc.WatchDirResponse_Start{
			Start: &rpc.WatchDirResponse_StartEvent{},
		},
	})
	if err != nil {
		return connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending start event: %w", err))
	}

	keepaliveTicker, resetKeepalive := permissions.GetKeepAliveTicker(req)
	defer keepaliveTicker.Stop()

	for {
		select {
		case <-keepaliveTicker.C:
			// Idle connection: send a keepalive so proxies don't drop it.
			streamErr := stream.Send(&rpc.WatchDirResponse{
				Event: &rpc.WatchDirResponse_Keepalive{
					Keepalive: &rpc.WatchDirResponse_KeepAlive{},
				},
			})
			if streamErr != nil {
				return connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending keepalive: %w", streamErr))
			}
		case <-ctx.Done():
			return ctx.Err()
		case chErr, ok := <-w.Errors:
			if !ok {
				return connect.NewError(connect.CodeInternal, fmt.Errorf("watcher error channel closed"))
			}

			return connect.NewError(connect.CodeInternal, fmt.Errorf("watcher error: %w", chErr))
		case e, ok := <-w.Events:
			if !ok {
				return connect.NewError(connect.CodeInternal, fmt.Errorf("watcher event channel closed"))
			}

			// One event can have multiple operations.
			ops := []rpc.EventType{}

			if fsnotify.Create.Has(e.Op) {
				ops = append(ops, rpc.EventType_EVENT_TYPE_CREATE)
			}

			if fsnotify.Rename.Has(e.Op) {
				ops = append(ops, rpc.EventType_EVENT_TYPE_RENAME)
			}

			if fsnotify.Chmod.Has(e.Op) {
				ops = append(ops, rpc.EventType_EVENT_TYPE_CHMOD)
			}

			if fsnotify.Write.Has(e.Op) {
				ops = append(ops, rpc.EventType_EVENT_TYPE_WRITE)
			}

			if fsnotify.Remove.Has(e.Op) {
				ops = append(ops, rpc.EventType_EVENT_TYPE_REMOVE)
			}

			for _, op := range ops {
				// Clients receive paths relative to the watched root.
				name, nameErr := filepath.Rel(watchPath, e.Name)
				if nameErr != nil {
					return connect.NewError(connect.CodeInternal, fmt.Errorf("error getting relative path: %w", nameErr))
				}

				filesystemEvent := &rpc.WatchDirResponse_Filesystem{
					Filesystem: &rpc.FilesystemEvent{
						Name: name,
						Type: op,
					},
				}

				event := &rpc.WatchDirResponse{
					Event: filesystemEvent,
				}

				streamErr := stream.Send(event)

				// NOTE(review): this type assertion panics if the operation
				// ID is missing from ctx — presumably guaranteed by the
				// logging middleware; confirm.
				s.logger.
					Debug().
					Str("event_type", "filesystem_event").
					Str(string(logs.OperationIDKey), ctx.Value(logs.OperationIDKey).(string)).
					Interface("filesystem_event", event).
					Msg("Streaming filesystem event")

				if streamErr != nil {
					return connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending filesystem event: %w", streamErr))
				}

				resetKeepalive()
			}
		}
	}
}

View File

@ -0,0 +1,224 @@
package filesystem
import (
"context"
"fmt"
"os"
"path/filepath"
"sync"
"connectrpc.com/connect"
"github.com/e2b-dev/fsnotify"
"github.com/rs/zerolog"
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem"
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/id"
)
// FileWatcher buffers fsnotify events for poll-based consumption
// (CreateWatcher / GetWatcherEvents / RemoveWatcher).
type FileWatcher struct {
	watcher *fsnotify.Watcher
	// Events accumulates unread events; guarded by Lock.
	Events []*rpc.FilesystemEvent
	// cancel stops the background collection goroutine.
	cancel func()
	// Error records a terminal watcher failure.
	// NOTE(review): written by the collection goroutine without holding
	// Lock, and read unsynchronized in GetWatcherEvents — looks like a data
	// race; verify and guard both sides if so.
	Error error
	Lock  sync.Mutex
}
// CreateFileWatcher starts a background goroutine that collects fsnotify
// events for watchPath into the returned FileWatcher's Events buffer.
// The watcher deliberately outlives the originating request
// (context.WithoutCancel) and runs until Close is called or the underlying
// fsnotify watcher fails.
//
// operationID is only used to tag debug log lines.
func CreateFileWatcher(ctx context.Context, watchPath string, recursive bool, operationID string, logger *zerolog.Logger) (*FileWatcher, error) {
	w, err := fsnotify.NewWatcher()
	if err != nil {
		return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error creating watcher: %w", err))
	}

	// We don't want to cancel the context when the request is finished
	ctx, cancel := context.WithCancel(context.WithoutCancel(ctx))

	err = w.Add(utils.FsnotifyPath(watchPath, recursive))
	if err != nil {
		// Clean up both the watcher and the derived context on failure.
		_ = w.Close()
		cancel()

		return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error adding path %s to watcher: %w", watchPath, err))
	}

	fw := &FileWatcher{
		watcher: w,
		cancel:  cancel,
		Events:  []*rpc.FilesystemEvent{},
		Error:   nil,
	}

	go func() {
		for {
			select {
			case <-ctx.Done():
				return
			case chErr, ok := <-w.Errors:
				// NOTE(review): fw.Error assignments below happen without
				// holding fw.Lock while GetWatcherEvents reads it without
				// the lock — potential data race; confirm.
				if !ok {
					fw.Error = connect.NewError(connect.CodeInternal, fmt.Errorf("watcher error channel closed"))

					return
				}

				fw.Error = connect.NewError(connect.CodeInternal, fmt.Errorf("watcher error: %w", chErr))

				return
			case e, ok := <-w.Events:
				if !ok {
					fw.Error = connect.NewError(connect.CodeInternal, fmt.Errorf("watcher event channel closed"))

					return
				}

				// One event can have multiple operations.
				ops := []rpc.EventType{}

				if fsnotify.Create.Has(e.Op) {
					ops = append(ops, rpc.EventType_EVENT_TYPE_CREATE)
				}

				if fsnotify.Rename.Has(e.Op) {
					ops = append(ops, rpc.EventType_EVENT_TYPE_RENAME)
				}

				if fsnotify.Chmod.Has(e.Op) {
					ops = append(ops, rpc.EventType_EVENT_TYPE_CHMOD)
				}

				if fsnotify.Write.Has(e.Op) {
					ops = append(ops, rpc.EventType_EVENT_TYPE_WRITE)
				}

				if fsnotify.Remove.Has(e.Op) {
					ops = append(ops, rpc.EventType_EVENT_TYPE_REMOVE)
				}

				for _, op := range ops {
					// Buffered events carry paths relative to the watch root.
					name, nameErr := filepath.Rel(watchPath, e.Name)
					if nameErr != nil {
						fw.Error = connect.NewError(connect.CodeInternal, fmt.Errorf("error getting relative path: %w", nameErr))

						return
					}

					fw.Lock.Lock()
					fw.Events = append(fw.Events, &rpc.FilesystemEvent{
						Name: name,
						Type: op,
					})
					fw.Lock.Unlock()

					// these are only used for logging
					filesystemEvent := &rpc.WatchDirResponse_Filesystem{
						Filesystem: &rpc.FilesystemEvent{
							Name: name,
							Type: op,
						},
					}

					event := &rpc.WatchDirResponse{
						Event: filesystemEvent,
					}

					logger.
						Debug().
						Str("event_type", "filesystem_event").
						Str(string(logs.OperationIDKey), operationID).
						Interface("filesystem_event", event).
						Msg("Streaming filesystem event")
				}
			}
		}
	}()

	return fw, nil
}
// Close shuts down the underlying fsnotify watcher and then cancels the
// background collection goroutine started by CreateFileWatcher.
func (fw *FileWatcher) Close() {
	_ = fw.watcher.Close()
	fw.cancel()
}
// CreateWatcher registers a poll-based directory watcher for the resolved
// path and returns its generated id. Events are later retrieved with
// GetWatcherEvents and the watcher is torn down with RemoveWatcher.
func (s Service) CreateWatcher(ctx context.Context, req *connect.Request[rpc.CreateWatcherRequest]) (*connect.Response[rpc.CreateWatcherResponse], error) {
	u, err := permissions.GetAuthUser(ctx, s.defaults.User)
	if err != nil {
		return nil, err
	}

	watchPath, err := permissions.ExpandAndResolve(req.Msg.GetPath(), u, s.defaults.Workdir)
	if err != nil {
		return nil, connect.NewError(connect.CodeInvalidArgument, err)
	}

	info, err := os.Stat(watchPath)
	if err != nil {
		if os.IsNotExist(err) {
			return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("path %s not found: %w", watchPath, err))
		}

		return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error statting path %s: %w", watchPath, err))
	}

	if !info.IsDir() {
		// Fix: this branch previously wrapped err with %w although err is
		// nil here (os.Stat succeeded), rendering "%!w(<nil>)" in the message.
		return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("path %s is not a directory", watchPath))
	}

	// Check if path is on a network filesystem mount
	isNetworkMount, err := IsPathOnNetworkMount(watchPath)
	if err != nil {
		return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error checking mount status: %w", err))
	}
	if isNetworkMount {
		return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("cannot watch path on network filesystem: %s", watchPath))
	}

	// "w" prefix marks watcher ids.
	watcherId := "w" + id.Generate()

	w, err := CreateFileWatcher(ctx, watchPath, req.Msg.GetRecursive(), watcherId, s.logger)
	if err != nil {
		return nil, err
	}

	s.watchers.Store(watcherId, w)

	return connect.NewResponse(&rpc.CreateWatcherResponse{
		WatcherId: watcherId,
	}), nil
}
// GetWatcherEvents drains and returns the events a watcher buffered since
// the previous call; the buffer is reset to empty on every read.
//
// NOTE(review): w.Error is read here without w.Lock while the watcher
// goroutine writes it without the lock — potential data race; verify.
func (s Service) GetWatcherEvents(_ context.Context, req *connect.Request[rpc.GetWatcherEventsRequest]) (*connect.Response[rpc.GetWatcherEventsResponse], error) {
	watcherId := req.Msg.GetWatcherId()

	w, ok := s.watchers.Load(watcherId)
	if !ok {
		return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("watcher with id %s not found", watcherId))
	}

	if w.Error != nil {
		return nil, w.Error
	}

	w.Lock.Lock()
	defer w.Lock.Unlock()

	// Swap the accumulated slice out under the lock so concurrent appends
	// from the watcher goroutine are not lost.
	events := w.Events
	w.Events = []*rpc.FilesystemEvent{}

	return connect.NewResponse(&rpc.GetWatcherEventsResponse{
		Events: events,
	}), nil
}
// RemoveWatcher stops the identified watcher and unregisters it.
//
// NOTE(review): Load followed by Delete is not atomic — two concurrent
// removals of the same id could both reach w.Close(); confirm whether a
// double Close is safe here or switch to an atomic load-and-delete.
func (s Service) RemoveWatcher(_ context.Context, req *connect.Request[rpc.RemoveWatcherRequest]) (*connect.Response[rpc.RemoveWatcherResponse], error) {
	watcherId := req.Msg.GetWatcherId()

	w, ok := s.watchers.Load(watcherId)
	if !ok {
		return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("watcher with id %s not found", watcherId))
	}

	w.Close()
	s.watchers.Delete(watcherId)

	return connect.NewResponse(&rpc.RemoveWatcherResponse{}), nil
}

View File

@ -0,0 +1,126 @@
package process
import (
"context"
"errors"
"fmt"
"connectrpc.com/connect"
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
)
// Connect attaches the caller to an existing process's event stream,
// delegating to handleConnect wrapped with per-request stream logging.
func (s *Service) Connect(ctx context.Context, req *connect.Request[rpc.ConnectRequest], stream *connect.ServerStream[rpc.ConnectResponse]) error {
	return logs.LogServerStreamWithoutEvents(ctx, s.logger, req, stream, s.handleConnect)
}
// handleConnect streams a process's lifecycle to the client: an immediate
// Start event (with the PID), then forwarded data events interleaved with
// keepalives, and finally the single End event. It returns when the end
// event has been delivered, the stream breaks, or ctx is cancelled.
func (s *Service) handleConnect(ctx context.Context, req *connect.Request[rpc.ConnectRequest], stream *connect.ServerStream[rpc.ConnectResponse]) error {
	ctx, cancel := context.WithCancelCause(ctx)
	defer cancel(nil)

	proc, err := s.getProcess(req.Msg.GetProcess())
	if err != nil {
		return err
	}

	// Closed by the forwarding goroutine once it has finished (either by
	// delivering the end event or by hitting an error/cancellation).
	exitChan := make(chan struct{})

	// Fork per-subscriber views of the process's data and end streams.
	data, dataCancel := proc.DataEvent.Fork()
	defer dataCancel()

	end, endCancel := proc.EndEvent.Fork()
	defer endCancel()

	streamErr := stream.Send(&rpc.ConnectResponse{
		Event: &rpc.ProcessEvent{
			Event: &rpc.ProcessEvent_Start{
				Start: &rpc.ProcessEvent_StartEvent{
					Pid: proc.Pid(),
				},
			},
		},
	})
	if streamErr != nil {
		return connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending start event: %w", streamErr))
	}

	go func() {
		defer close(exitChan)

		keepaliveTicker, resetKeepalive := permissions.GetKeepAliveTicker(req)
		defer keepaliveTicker.Stop()

		// Phase 1: forward data events until the data channel closes.
	dataLoop:
		for {
			select {
			case <-keepaliveTicker.C:
				streamErr := stream.Send(&rpc.ConnectResponse{
					Event: &rpc.ProcessEvent{
						Event: &rpc.ProcessEvent_Keepalive{
							Keepalive: &rpc.ProcessEvent_KeepAlive{},
						},
					},
				})
				if streamErr != nil {
					cancel(connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending keepalive: %w", streamErr)))

					return
				}
			case <-ctx.Done():
				cancel(ctx.Err())

				return
			case event, ok := <-data:
				if !ok {
					break dataLoop
				}

				streamErr := stream.Send(&rpc.ConnectResponse{
					Event: &rpc.ProcessEvent{
						Event: &event,
					},
				})
				if streamErr != nil {
					cancel(connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending data event: %w", streamErr)))

					return
				}

				resetKeepalive()
			}
		}

		// Phase 2: after all data is delivered, forward the end event.
		select {
		case <-ctx.Done():
			cancel(ctx.Err())

			return
		case event, ok := <-end:
			if !ok {
				cancel(connect.NewError(connect.CodeUnknown, errors.New("end event channel closed before sending end event")))

				return
			}

			streamErr := stream.Send(&rpc.ConnectResponse{
				Event: &rpc.ProcessEvent{
					Event: &event,
				},
			})
			if streamErr != nil {
				cancel(connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending end event: %w", streamErr)))

				return
			}
		}
	}()

	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-exitChan:
		return nil
	}
}

View File

@ -0,0 +1,478 @@
package handler
import (
"context"
"errors"
"fmt"
"io"
"os"
"os/exec"
"os/user"
"strconv"
"strings"
"sync"
"syscall"
"connectrpc.com/connect"
"github.com/creack/pty"
"github.com/rs/zerolog"
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
"git.omukk.dev/wrenn/sandbox/envd/internal/services/cgroups"
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
)
const (
	// defaultNice is the target nice value applied to spawned user
	// processes via the nice(1) wrapper in New.
	defaultNice = 0
	// defaultOomScore is written to /proc/$$/oom_score_adj by the shell
	// wrapper in New before exec-ing the user command.
	defaultOomScore = 100
	// outputBufferSize is the buffer for multiplexed data events and the
	// per-stream log channels.
	outputBufferSize = 64
	// stdChunkSize is the stdout/stderr read-buffer size (2<<14 = 32 KiB).
	stdChunkSize = 2 << 14
	// ptyChunkSize is the PTY read-buffer size (2<<13 = 16 KiB).
	ptyChunkSize = 2 << 13
)
// ProcessExit describes the terminal state of a process.
// NOTE(review): not referenced by the code visible in this file; confirm
// its consumers (e.g. the HTTP API) before changing it.
type ProcessExit struct {
	// Error holds the wait error message, if any.
	Error *string
	// Status is a human-readable process state string.
	Status string
	// Exited reports whether the process terminated on its own.
	Exited bool
	// Code is the process exit code.
	Code int32
}
// Handler owns a single spawned process: its configuration, the underlying
// exec.Cmd, optional PTY, stdin pipe, and the multiplexed output/end event
// streams that RPC handlers fork to observe the process.
type Handler struct {
	// Config is the original RPC process configuration.
	Config *rpc.ProcessConfig
	logger *zerolog.Logger
	// Tag is an optional client-supplied label for the process.
	Tag *string

	cmd *exec.Cmd
	// tty is non-nil only for PTY processes; it carries both output reads
	// and input writes (see WriteTty).
	tty *os.File

	// cancel releases process resources; invoked at the end of Wait.
	cancel context.CancelFunc

	// outCtx is done once all output readers finished or were cancelled.
	outCtx    context.Context //nolint:containedctx // todo: refactor so this can be removed
	outCancel context.CancelFunc

	// stdinMu guards stdin against concurrent Write/Close.
	stdinMu sync.Mutex
	stdin   io.WriteCloser

	// DataEvent multiplexes stdout/stderr/pty chunks to subscribers.
	DataEvent *MultiplexedChannel[rpc.ProcessEvent_Data]
	// EndEvent publishes the single end-of-process event (unbuffered).
	EndEvent *MultiplexedChannel[rpc.ProcessEvent_End]
}
// Pid returns the process id of the spawned command.
// This method must be called only after the process has been started
// (cmd.Process is nil until Start / pty.StartWithSize runs).
func (p *Handler) Pid() uint32 {
	return uint32(p.cmd.Process.Pid)
}
// userCommand renders the user's original command line (command followed by
// its args), without the internal OOM/nice wrapper that is prepended to the
// actual exec.
func (p *Handler) userCommand() string {
	parts := make([]string, 0, len(p.Config.GetArgs())+1)
	parts = append(parts, p.Config.GetCmd())
	parts = append(parts, p.Config.GetArgs()...)

	return strings.Join(parts, " ")
}
// currentNice reports the nice value of the calling process, falling back
// to 0 when the priority cannot be read.
func currentNice() int {
	// On Linux, Getpriority yields 20 - nice, so invert to recover nice.
	if prio, err := syscall.Getpriority(syscall.PRIO_PROCESS, 0); err == nil {
		return 20 - prio
	}

	return 0
}
// New builds a process Handler from a StartRequest: it wraps the user
// command in a /bin/sh script that fixes the OOM score and nice value,
// applies credentials and cgroup placement, resolves the working directory,
// assembles the environment, and wires up output (pipes or PTY) and stdin.
// The returned handler is not yet started unless a PTY was requested (the
// pty package starts the command as part of allocation).
//
// NOTE(review): the parameter named `user` shadows the os/user package
// inside this function; only the parameter is referenced here, so it works,
// but a rename would avoid confusion.
func New(
	ctx context.Context,
	user *user.User,
	req *rpc.StartRequest,
	logger *zerolog.Logger,
	defaults *execcontext.Defaults,
	cgroupManager cgroups.Manager,
	cancel context.CancelFunc,
) (*Handler, error) {
	// User command string for logging (without the internal wrapper details).
	userCmd := strings.Join(append([]string{req.GetProcess().GetCmd()}, req.GetProcess().GetArgs()...), " ")

	// Wrap the command in a shell that sets the OOM score and nice value before exec-ing the actual command.
	// This eliminates the race window where grandchildren could inherit the parent's protected OOM score (-1000)
	// or high CPU priority (nice -20) before the post-start calls had a chance to correct them.
	// nice(1) applies a relative adjustment, so we compute the delta from the current (inherited) nice to the target.
	niceDelta := defaultNice - currentNice()
	oomWrapperScript := fmt.Sprintf(`echo %d > /proc/$$/oom_score_adj && exec /usr/bin/nice -n %d "${@}"`, defaultOomScore, niceDelta)
	wrapperArgs := append([]string{"-c", oomWrapperScript, "--", req.GetProcess().GetCmd()}, req.GetProcess().GetArgs()...)
	cmd := exec.CommandContext(ctx, "/bin/sh", wrapperArgs...)

	uid, gid, err := permissions.GetUserIdUints(user)
	if err != nil {
		return nil, connect.NewError(connect.CodeInternal, err)
	}

	// Supplementary groups are best-effort: a lookup failure only logs.
	groups := []uint32{gid}
	if gids, err := user.GroupIds(); err != nil {
		logger.Warn().Err(err).Str("user", user.Username).Msg("failed to get supplementary groups")
	} else {
		for _, g := range gids {
			if parsed, err := strconv.ParseUint(g, 10, 32); err == nil {
				groups = append(groups, uint32(parsed))
			}
		}
	}

	// Place the child directly into the managed cgroup via clone3's
	// CgroupFD when the manager can supply a descriptor.
	cgroupFD, ok := cgroupManager.GetFileDescriptor(getProcType(req))
	cmd.SysProcAttr = &syscall.SysProcAttr{
		UseCgroupFD: ok,
		CgroupFD:    cgroupFD,
		Credential: &syscall.Credential{
			Uid:    uid,
			Gid:    gid,
			Groups: groups,
		},
	}

	resolvedPath, err := permissions.ExpandAndResolve(req.GetProcess().GetCwd(), user, defaults.Workdir)
	if err != nil {
		return nil, connect.NewError(connect.CodeInvalidArgument, err)
	}

	// Check if the cwd resolved path exists
	if _, err := os.Stat(resolvedPath); errors.Is(err, os.ErrNotExist) {
		return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("cwd '%s' does not exist", resolvedPath))
	}
	cmd.Dir = resolvedPath

	var formattedVars []string

	// Take only 'PATH' variable from the current environment
	// The 'PATH' should ideally be set in the environment
	formattedVars = append(formattedVars, "PATH="+os.Getenv("PATH"))
	formattedVars = append(formattedVars, "HOME="+user.HomeDir)
	formattedVars = append(formattedVars, "USER="+user.Username)
	formattedVars = append(formattedVars, "LOGNAME="+user.Username)

	// Add the environment variables from the global environment
	if defaults.EnvVars != nil {
		defaults.EnvVars.Range(func(key string, value string) bool {
			formattedVars = append(formattedVars, key+"="+value)

			return true
		})
	}

	// Only the last values of the env vars are used - this allows for overwriting defaults
	for key, value := range req.GetProcess().GetEnvs() {
		formattedVars = append(formattedVars, key+"="+value)
	}

	cmd.Env = formattedVars

	outMultiplex := NewMultiplexedChannel[rpc.ProcessEvent_Data](outputBufferSize)

	var outWg sync.WaitGroup

	// Create a context for waiting for and cancelling output pipes.
	// Cancellation of the process via timeout will propagate and cancel this context too.
	outCtx, outCancel := context.WithCancel(ctx)

	h := &Handler{
		Config:    req.GetProcess(),
		cmd:       cmd,
		Tag:       req.Tag,
		DataEvent: outMultiplex,
		cancel:    cancel,
		outCtx:    outCtx,
		outCancel: outCancel,
		EndEvent:  NewMultiplexedChannel[rpc.ProcessEvent_End](0),
		logger:    logger,
	}

	if req.GetPty() != nil {
		// The pty should ideally start only in the Start method, but the package does not support that and we would have to code it manually.
		// The output of the pty should correctly be passed though.
		tty, err := pty.StartWithSize(cmd, &pty.Winsize{
			Cols: uint16(req.GetPty().GetSize().GetCols()),
			Rows: uint16(req.GetPty().GetSize().GetRows()),
		})
		if err != nil {
			return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("error starting pty with command '%s' in dir '%s' with '%d' cols and '%d' rows: %w", userCmd, cmd.Dir, req.GetPty().GetSize().GetCols(), req.GetPty().GetSize().GetRows(), err))
		}

		// Pump PTY output into the data multiplexer until EOF or error.
		outWg.Go(func() {
			for {
				buf := make([]byte, ptyChunkSize)

				n, readErr := tty.Read(buf)
				if n > 0 {
					outMultiplex.Source <- rpc.ProcessEvent_Data{
						Data: &rpc.ProcessEvent_DataEvent{
							Output: &rpc.ProcessEvent_DataEvent_Pty{
								Pty: buf[:n],
							},
						},
					}
				}

				if errors.Is(readErr, io.EOF) {
					break
				}

				if readErr != nil {
					fmt.Fprintf(os.Stderr, "error reading from pty: %s\n", readErr)

					break
				}
			}
		})

		h.tty = tty
	} else {
		stdout, err := cmd.StdoutPipe()
		if err != nil {
			return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error creating stdout pipe for command '%s': %w", userCmd, err))
		}

		// Pump stdout into the multiplexer and a buffered logging channel.
		outWg.Go(func() {
			stdoutLogs := make(chan []byte, outputBufferSize)
			defer close(stdoutLogs)

			stdoutLogger := logger.With().Str("event_type", "stdout").Logger()
			go logs.LogBufferedDataEvents(stdoutLogs, &stdoutLogger, "data")

			for {
				buf := make([]byte, stdChunkSize)

				n, readErr := stdout.Read(buf)
				if n > 0 {
					outMultiplex.Source <- rpc.ProcessEvent_Data{
						Data: &rpc.ProcessEvent_DataEvent{
							Output: &rpc.ProcessEvent_DataEvent_Stdout{
								Stdout: buf[:n],
							},
						},
					}

					stdoutLogs <- buf[:n]
				}

				if errors.Is(readErr, io.EOF) {
					break
				}

				if readErr != nil {
					fmt.Fprintf(os.Stderr, "error reading from stdout: %s\n", readErr)

					break
				}
			}
		})

		stderr, err := cmd.StderrPipe()
		if err != nil {
			return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error creating stderr pipe for command '%s': %w", userCmd, err))
		}

		// Same pump for stderr.
		outWg.Go(func() {
			stderrLogs := make(chan []byte, outputBufferSize)
			defer close(stderrLogs)

			stderrLogger := logger.With().Str("event_type", "stderr").Logger()
			go logs.LogBufferedDataEvents(stderrLogs, &stderrLogger, "data")

			for {
				buf := make([]byte, stdChunkSize)

				n, readErr := stderr.Read(buf)
				if n > 0 {
					outMultiplex.Source <- rpc.ProcessEvent_Data{
						Data: &rpc.ProcessEvent_DataEvent{
							Output: &rpc.ProcessEvent_DataEvent_Stderr{
								Stderr: buf[:n],
							},
						},
					}

					stderrLogs <- buf[:n]
				}

				if errors.Is(readErr, io.EOF) {
					break
				}

				if readErr != nil {
					fmt.Fprintf(os.Stderr, "error reading from stderr: %s\n", readErr)

					break
				}
			}
		})

		// For backwards compatibility we still set the stdin if not explicitly disabled
		// If stdin is disabled, the process will use /dev/null as stdin
		if req.Stdin == nil || req.GetStdin() == true {
			stdin, err := cmd.StdinPipe()
			if err != nil {
				return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error creating stdin pipe for command '%s': %w", userCmd, err))
			}

			h.stdin = stdin
		}
	}

	// Once every output pump finishes, close the multiplexer source and
	// release outCtx so Wait can reap the process.
	go func() {
		outWg.Wait()
		close(outMultiplex.Source)
		outCancel()
	}()

	return h, nil
}
// getProcType picks the cgroup bucket for a start request: PTY processes go
// to the PTY cgroup, everything else (including a nil request) to the user
// cgroup.
func getProcType(req *rpc.StartRequest) cgroups.ProcessType {
	if req == nil || req.GetPty() == nil {
		return cgroups.ProcessTypeUser
	}

	return cgroups.ProcessTypePTY
}
// SendSignal delivers signal to the running process. Terminal signals
// (SIGKILL/SIGTERM) additionally cancel the output context so readers stop.
func (p *Handler) SendSignal(signal syscall.Signal) error {
	if p.cmd.Process == nil {
		return fmt.Errorf("process not started")
	}

	switch signal {
	case syscall.SIGKILL, syscall.SIGTERM:
		p.outCancel()
	}

	return p.cmd.Process.Signal(signal)
}
// ResizeTty changes the window size of the process's PTY.
// Fails for non-PTY processes.
func (p *Handler) ResizeTty(size *pty.Winsize) error {
	if p.tty == nil {
		return fmt.Errorf("tty not assigned to process")
	}
	return pty.Setsize(p.tty, size)
}
// WriteStdin writes data to the process's stdin pipe. It fails for PTY
// processes (input belongs on the TTY) and when stdin is disabled or
// already closed.
func (p *Handler) WriteStdin(data []byte) error {
	if p.tty != nil {
		return fmt.Errorf("tty assigned to process — input should be written to the pty, not the stdin")
	}

	p.stdinMu.Lock()
	defer p.stdinMu.Unlock()

	if p.stdin == nil {
		return fmt.Errorf("stdin not enabled or closed")
	}

	if _, writeErr := p.stdin.Write(data); writeErr != nil {
		return fmt.Errorf("error writing to stdin of process '%d': %w", p.cmd.Process.Pid, writeErr)
	}

	return nil
}
// CloseStdin closes the stdin pipe to signal EOF to the process.
// Only valid for non-PTY processes; for a PTY, send Ctrl+D instead.
// Idempotent: a second call returns nil.
func (p *Handler) CloseStdin() error {
	if p.tty != nil {
		return fmt.Errorf("cannot close stdin for PTY process — send Ctrl+D (0x04) instead")
	}

	p.stdinMu.Lock()
	defer p.stdinMu.Unlock()

	if p.stdin == nil {
		return nil
	}

	closeErr := p.stdin.Close()
	// Drop the reference even if Close failed: retrying Close is not
	// reliably safe everywhere, so stdin is treated as gone either way.
	p.stdin = nil

	return closeErr
}
// WriteTty writes raw input bytes to the process's PTY.
// Fails for non-PTY processes (use WriteStdin for those).
func (p *Handler) WriteTty(data []byte) error {
	if p.tty == nil {
		return fmt.Errorf("tty not assigned to process — input should be written to the stdin, not the tty")
	}

	if _, writeErr := p.tty.Write(data); writeErr != nil {
		return fmt.Errorf("error writing to tty of process '%d': %w", p.cmd.Process.Pid, writeErr)
	}

	return nil
}
// Start launches the underlying command and returns its PID.
// PTY-backed commands were already started by pty.StartWithSize in New,
// so only non-PTY commands are started here.
func (p *Handler) Start() (uint32, error) {
	// Pty is already started in the New method
	if p.tty == nil {
		err := p.cmd.Start()
		if err != nil {
			return 0, fmt.Errorf("error starting process '%s': %w", p.userCommand(), err)
		}
	}

	p.logger.
		Info().
		Str("event_type", "process_start").
		Int("pid", p.cmd.Process.Pid).
		Str("command", p.userCommand()).
		Msg(fmt.Sprintf("Process with pid %d started", p.cmd.Process.Pid))

	return uint32(p.cmd.Process.Pid), nil
}
// Wait blocks until all output readers are done (or cancelled), reaps the
// process, publishes the End event to subscribers, and releases resources.
func (p *Handler) Wait() {
	// Wait for the output pipes to be closed or cancelled.
	<-p.outCtx.Done()

	err := p.cmd.Wait()

	// Safe for non-PTY processes: (*os.File).Close on a nil receiver
	// returns ErrInvalid rather than panicking.
	p.tty.Close()

	var errMsg *string
	if err != nil {
		msg := err.Error()
		errMsg = &msg
	}

	endEvent := &rpc.ProcessEvent_EndEvent{
		Error:    errMsg,
		ExitCode: int32(p.cmd.ProcessState.ExitCode()),
		Exited:   p.cmd.ProcessState.Exited(),
		Status:   p.cmd.ProcessState.String(),
	}

	event := rpc.ProcessEvent_End{
		End: endEvent,
	}

	// EndEvent's Source is unbuffered, so this hands the event directly to
	// the multiplexer's dispatch goroutine.
	p.EndEvent.Source <- event

	p.logger.
		Info().
		Str("event_type", "process_end").
		Interface("process_result", endEvent).
		Msg(fmt.Sprintf("Process with pid %d ended", p.cmd.Process.Pid))

	// Ensure the process cancel is called to cleanup resources.
	// As it is called after end event and Wait, it should not affect command execution or returned events.
	p.cancel()
}

View File

@ -0,0 +1,73 @@
package handler
import (
"sync"
"sync/atomic"
)
// MultiplexedChannel fans a single Source channel out to any number of
// dynamically registered consumer channels. Closing Source shuts the
// multiplexer down: every registered consumer is closed, and later Forks
// return an already-closed channel.
type MultiplexedChannel[T any] struct {
	Source   chan T       // producer side; close it to shut the multiplexer down
	channels []chan T     // registered consumers; guarded by mu
	mu       sync.RWMutex // protects channels (and orders shutdown vs. Fork)
	exited   atomic.Bool  // set once Source is closed and fan-out has stopped
}

// NewMultiplexedChannel creates a multiplexer whose Source has the given
// buffer size and starts the fan-out goroutine. The goroutine forwards each
// Source value to every consumer and, once Source closes, closes all
// consumers.
func NewMultiplexedChannel[T any](buffer int) *MultiplexedChannel[T] {
	c := &MultiplexedChannel[T]{
		channels: nil,
		Source:   make(chan T, buffer),
	}
	go func() {
		for v := range c.Source {
			c.mu.RLock()
			for _, cons := range c.channels {
				cons <- v
			}
			c.mu.RUnlock()
		}
		// Hold the write lock while flipping exited and closing consumers:
		// the original read c.channels here without any lock, racing with
		// Fork/remove mutating the slice.
		c.mu.Lock()
		defer c.mu.Unlock()
		c.exited.Store(true)
		for _, cons := range c.channels {
			close(cons)
		}
	}()
	return c
}

// Fork registers a new consumer channel and returns it along with a cancel
// function that unregisters it. After the multiplexer has exited, Fork
// returns an already-closed channel and a no-op cancel.
func (m *MultiplexedChannel[T]) Fork() (chan T, func()) {
	m.mu.Lock()
	defer m.mu.Unlock()
	// Check exited under the lock: previously a Fork racing with shutdown
	// could register a consumer after the close loop had already run, leaving
	// a channel that was never closed and a reader blocked forever.
	if m.exited.Load() {
		ch := make(chan T)
		close(ch)
		return ch, func() {}
	}
	consumer := make(chan T)
	m.channels = append(m.channels, consumer)
	return consumer, func() {
		m.remove(consumer)
	}
}

// remove unregisters a consumer channel previously returned by Fork.
func (m *MultiplexedChannel[T]) remove(consumer chan T) {
	m.mu.Lock()
	defer m.mu.Unlock()
	for i, ch := range m.channels {
		if ch == consumer {
			m.channels = append(m.channels[:i], m.channels[i+1:]...)
			return
		}
	}
}

View File

@ -0,0 +1,107 @@
package process
import (
"context"
"fmt"
"connectrpc.com/connect"
"github.com/rs/zerolog"
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
"git.omukk.dev/wrenn/sandbox/envd/internal/services/process/handler"
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
)
// handleInput routes a single ProcessInput message to the target process:
// PTY bytes go to the tty, stdin bytes go to the stdin pipe (with a debug
// log). Unknown input variants are rejected with CodeUnimplemented.
func handleInput(ctx context.Context, process *handler.Handler, in *rpc.ProcessInput, logger *zerolog.Logger) error {
	switch in.GetInput().(type) {
	case *rpc.ProcessInput_Pty:
		if err := process.WriteTty(in.GetPty()); err != nil {
			return connect.NewError(connect.CodeInternal, fmt.Errorf("error writing to tty: %w", err))
		}
		return nil
	case *rpc.ProcessInput_Stdin:
		if err := process.WriteStdin(in.GetStdin()); err != nil {
			return connect.NewError(connect.CodeInternal, fmt.Errorf("error writing to stdin: %w", err))
		}
		logger.Debug().
			Str("event_type", "stdin").
			Interface("stdin", in.GetStdin()).
			Str(string(logs.OperationIDKey), ctx.Value(logs.OperationIDKey).(string)).
			Msg("Streaming input to process")
		return nil
	default:
		return connect.NewError(connect.CodeUnimplemented, fmt.Errorf("invalid input type %T", in.GetInput()))
	}
}
// SendInput delivers a single chunk of input (stdin or tty bytes) to the
// selected process.
func (s *Service) SendInput(ctx context.Context, req *connect.Request[rpc.SendInputRequest]) (*connect.Response[rpc.SendInputResponse], error) {
	proc, findErr := s.getProcess(req.Msg.GetProcess())
	if findErr != nil {
		return nil, findErr
	}
	if writeErr := handleInput(ctx, proc, req.Msg.GetInput(), s.logger); writeErr != nil {
		return nil, writeErr
	}
	return connect.NewResponse(&rpc.SendInputResponse{}), nil
}
// StreamInput handles a client stream of input events, wrapping
// streamInputHandler with request logging.
func (s *Service) StreamInput(ctx context.Context, stream *connect.ClientStream[rpc.StreamInputRequest]) (*connect.Response[rpc.StreamInputResponse], error) {
	return logs.LogClientStreamWithoutEvents(ctx, s.logger, stream, s.streamInputHandler)
}
// streamInputHandler consumes a client stream of input events: a Start event
// selects the target process, Data events forward stdin/tty bytes to it, and
// Keepalive events are ignored (they only hold the stream open).
// A Data event arriving before any Start event is rejected with
// CodeInvalidArgument (previously it dereferenced a nil handler and panicked).
func (s *Service) streamInputHandler(ctx context.Context, stream *connect.ClientStream[rpc.StreamInputRequest]) (*connect.Response[rpc.StreamInputResponse], error) {
	var proc *handler.Handler
	for stream.Receive() {
		req := stream.Msg()
		switch req.GetEvent().(type) {
		case *rpc.StreamInputRequest_Start:
			p, err := s.getProcess(req.GetStart().GetProcess())
			if err != nil {
				return nil, err
			}
			proc = p
		case *rpc.StreamInputRequest_Data:
			// Guard: the client must send a Start event before any Data.
			if proc == nil {
				return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("data event received before start event selected a process"))
			}
			err := handleInput(ctx, proc, req.GetData().GetInput(), s.logger)
			if err != nil {
				return nil, err
			}
		case *rpc.StreamInputRequest_Keepalive:
			// No-op.
		default:
			return nil, connect.NewError(connect.CodeUnimplemented, fmt.Errorf("invalid event type %T", req.GetEvent()))
		}
	}
	err := stream.Err()
	if err != nil {
		return nil, connect.NewError(connect.CodeUnknown, fmt.Errorf("error streaming input: %w", err))
	}
	return connect.NewResponse(&rpc.StreamInputResponse{}), nil
}
// CloseStdin closes the stdin pipe of the selected process, signalling EOF.
func (s *Service) CloseStdin(
	_ context.Context,
	req *connect.Request[rpc.CloseStdinRequest],
) (*connect.Response[rpc.CloseStdinResponse], error) {
	// Named proc (not handler) to avoid shadowing the handler package.
	proc, err := s.getProcess(req.Msg.GetProcess())
	if err != nil {
		return nil, err
	}
	if closeErr := proc.CloseStdin(); closeErr != nil {
		return nil, connect.NewError(connect.CodeUnknown, fmt.Errorf("error closing stdin: %w", closeErr))
	}
	return connect.NewResponse(&rpc.CloseStdinResponse{}), nil
}

View File

@ -0,0 +1,28 @@
package process
import (
"context"
"connectrpc.com/connect"
"git.omukk.dev/wrenn/sandbox/envd/internal/services/process/handler"
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
)
// List returns a snapshot of all currently tracked processes
// (PID, optional tag, and the config they were started with).
func (s *Service) List(context.Context, *connect.Request[rpc.ListRequest]) (*connect.Response[rpc.ListResponse], error) {
	infos := make([]*rpc.ProcessInfo, 0)
	s.processes.Range(func(pid uint32, proc *handler.Handler) bool {
		info := &rpc.ProcessInfo{
			Pid:    pid,
			Tag:    proc.Tag,
			Config: proc.Config,
		}
		infos = append(infos, info)
		return true
	})
	resp := &rpc.ListResponse{
		Processes: infos,
	}
	return connect.NewResponse(resp), nil
}

View File

@ -0,0 +1,84 @@
package process
import (
"fmt"
"connectrpc.com/connect"
"github.com/go-chi/chi/v5"
"github.com/rs/zerolog"
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
"git.omukk.dev/wrenn/sandbox/envd/internal/services/cgroups"
"git.omukk.dev/wrenn/sandbox/envd/internal/services/process/handler"
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
spec "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process/processconnect"
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
)
// Service implements the process Connect RPC service: starting processes,
// listing them, signalling them, and streaming their I/O.
type Service struct {
	processes     *utils.Map[uint32, *handler.Handler] // live processes, keyed by PID
	logger        *zerolog.Logger
	defaults      *execcontext.Defaults // defaults applied when creating process handlers
	cgroupManager cgroups.Manager       // passed to handler.New for each started process
}
// newService builds a process Service with an empty process table.
func newService(l *zerolog.Logger, defaults *execcontext.Defaults, cgroupManager cgroups.Manager) *Service {
	svc := &Service{
		processes:     utils.NewMap[uint32, *handler.Handler](),
		logger:        l,
		defaults:      defaults,
		cgroupManager: cgroupManager,
	}
	return svc
}
// Handle registers the process Connect RPC handler (with unary request
// logging) on the given router and returns the backing Service.
func Handle(server *chi.Mux, l *zerolog.Logger, defaults *execcontext.Defaults, cgroupManager cgroups.Manager) *Service {
	service := newService(l, defaults, cgroupManager)
	logInterceptor := connect.WithInterceptors(logs.NewUnaryLogInterceptor(l))
	mountPath, httpHandler := spec.NewProcessHandler(service, logInterceptor)
	server.Mount(mountPath, httpHandler)
	return service
}
// getProcess resolves a process selector (by PID or by tag) to its live
// handler. Returns CodeNotFound when no matching process exists and
// CodeUnimplemented for unknown selector variants.
func (s *Service) getProcess(selector *rpc.ProcessSelector) (*handler.Handler, error) {
	var proc *handler.Handler
	switch selector.GetSelector().(type) {
	case *rpc.ProcessSelector_Pid:
		p, ok := s.processes.Load(selector.GetPid())
		if !ok {
			return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("process with pid %d not found", selector.GetPid()))
		}
		proc = p
	case *rpc.ProcessSelector_Tag:
		tag := selector.GetTag()
		// Range follows sync.Map semantics: returning false stops iteration.
		// The original inverted the returns (false on mismatch, true on
		// match), which aborted the scan at the first non-matching tagged
		// process and could miss an existing match.
		s.processes.Range(func(_ uint32, value *handler.Handler) bool {
			if value.Tag == nil {
				return true // untagged — keep scanning
			}
			if *value.Tag == tag {
				proc = value
				return false // found — stop scanning
			}
			return true // no match — keep scanning
		})
		if proc == nil {
			return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("process with tag %s not found", tag))
		}
	default:
		return nil, connect.NewError(connect.CodeUnimplemented, fmt.Errorf("invalid input type %T", selector))
	}
	return proc, nil
}

View File

@ -0,0 +1,38 @@
package process
import (
"context"
"fmt"
"syscall"
"connectrpc.com/connect"
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
)
// SendSignal delivers SIGKILL or SIGTERM to the selected process; any other
// signal value is rejected with CodeUnimplemented.
func (s *Service) SendSignal(
	_ context.Context,
	req *connect.Request[rpc.SendSignalRequest],
) (*connect.Response[rpc.SendSignalResponse], error) {
	proc, err := s.getProcess(req.Msg.GetProcess())
	if err != nil {
		return nil, err
	}
	var sig syscall.Signal
	switch req.Msg.GetSignal() {
	case rpc.Signal_SIGNAL_SIGKILL:
		sig = syscall.SIGKILL
	case rpc.Signal_SIGNAL_SIGTERM:
		sig = syscall.SIGTERM
	default:
		return nil, connect.NewError(connect.CodeUnimplemented, fmt.Errorf("invalid signal: %s", req.Msg.GetSignal()))
	}
	if sendErr := proc.SendSignal(sig); sendErr != nil {
		return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error sending signal: %w", sendErr))
	}
	return connect.NewResponse(&rpc.SendSignalResponse{}), nil
}

View File

@ -0,0 +1,247 @@
package process
import (
"context"
"errors"
"fmt"
"net/http"
"os/user"
"strconv"
"time"
"connectrpc.com/connect"
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
"git.omukk.dev/wrenn/sandbox/envd/internal/services/process/handler"
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
)
// InitializeStartProcess starts a process outside of a streaming RPC (e.g.
// the sandbox's configured start command). The process is registered in the
// process table and reaped (and deregistered) by a background goroutine.
func (s *Service) InitializeStartProcess(ctx context.Context, user *user.User, req *rpc.StartRequest) error {
	var err error
	ctx = logs.AddRequestIDToContext(ctx)
	// Wrap the deferred log in a closure: the original
	// `defer s.logger.Err(err)...Msg(...)` evaluated the whole chain (and
	// therefore `err`) at defer time, when err was still nil, so failures
	// were never reflected in the log.
	defer func() {
		s.logger.
			Err(err).
			Interface("request", req).
			Str(string(logs.OperationIDKey), ctx.Value(logs.OperationIDKey).(string)).
			Msg("Initialized startCmd")
	}()
	handlerL := s.logger.With().Str(string(logs.OperationIDKey), ctx.Value(logs.OperationIDKey).(string)).Logger()
	startProcCtx, startProcCancel := context.WithCancel(ctx)
	proc, err := handler.New(startProcCtx, user, req, &handlerL, s.defaults, s.cgroupManager, startProcCancel)
	if err != nil {
		return err
	}
	pid, err := proc.Start()
	if err != nil {
		return err
	}
	s.processes.Store(pid, proc)
	// Reap the process in the background and drop it from the table when done.
	go func() {
		defer s.processes.Delete(pid)
		proc.Wait()
	}()
	return nil
}
// Start streams a command's lifecycle (start, data, keepalive, end events)
// to the client, wrapping handleStart with request/stream logging.
func (s *Service) Start(ctx context.Context, req *connect.Request[rpc.StartRequest], stream *connect.ServerStream[rpc.StartResponse]) error {
	return logs.LogServerStreamWithoutEvents(ctx, s.logger, req, stream, s.handleStart)
}
// handleStart creates and starts the requested command, then streams its
// lifecycle to the client in three phases: one start event carrying the PID,
// data events as output arrives (interleaved with keepalives while idle),
// and a single end event once the process exits.
func (s *Service) handleStart(ctx context.Context, req *connect.Request[rpc.StartRequest], stream *connect.ServerStream[rpc.StartResponse]) error {
	ctx, cancel := context.WithCancelCause(ctx)
	defer cancel(nil)
	// NOTE(review): assumes the logging wrapper placed an operation ID in ctx;
	// the type assertion panics otherwise — confirm LogServerStreamWithoutEvents sets it.
	handlerL := s.logger.With().Str(string(logs.OperationIDKey), ctx.Value(logs.OperationIDKey).(string)).Logger()
	u, err := permissions.GetAuthUser(ctx, s.defaults.User)
	if err != nil {
		return err
	}
	timeout, err := determineTimeoutFromHeader(stream.Conn().RequestHeader())
	if err != nil {
		return connect.NewError(connect.CodeInvalidArgument, err)
	}
	// Create a new context with a timeout if provided.
	// We do not want the command to be killed if the request context is cancelled
	procCtx, cancelProc := context.Background(), func() {}
	if timeout > 0 { // zero timeout means no timeout
		procCtx, cancelProc = context.WithTimeout(procCtx, timeout)
	}
	proc, err := handler.New( //nolint:contextcheck // TODO: fix this later
		procCtx,
		u,
		req.Msg,
		&handlerL,
		s.defaults,
		s.cgroupManager,
		cancelProc,
	)
	if err != nil {
		// Ensure the process cancel is called to cleanup resources.
		cancelProc()
		return err
	}
	// exitChan is closed by the streaming goroutine when it has finished
	// (end event sent, or an error/cancellation occurred).
	exitChan := make(chan struct{})
	// The multiplexer here is used to mint the start-event channel; the start
	// event below is sent directly into the fork. Closing Source (deferred)
	// makes the multiplexer close its forks on exit.
	startMultiplexer := handler.NewMultiplexedChannel[rpc.ProcessEvent_Start](0)
	defer close(startMultiplexer.Source)
	start, startCancel := startMultiplexer.Fork()
	defer startCancel()
	data, dataCancel := proc.DataEvent.Fork()
	defer dataCancel()
	end, endCancel := proc.EndEvent.Fork()
	defer endCancel()
	go func() {
		defer close(exitChan)
		// Phase 1: forward the start event (or bail out on cancellation).
		select {
		case <-ctx.Done():
			cancel(ctx.Err())
			return
		case event, ok := <-start:
			if !ok {
				cancel(connect.NewError(connect.CodeUnknown, errors.New("start event channel closed before sending start event")))
				return
			}
			streamErr := stream.Send(&rpc.StartResponse{
				Event: &rpc.ProcessEvent{
					Event: &event,
				},
			})
			if streamErr != nil {
				cancel(connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending start event: %w", streamErr)))
				return
			}
		}
		// Phase 2: forward data events until the data fork is closed,
		// sending keepalives while idle to hold the stream open.
		keepaliveTicker, resetKeepalive := permissions.GetKeepAliveTicker(req)
		defer keepaliveTicker.Stop()
	dataLoop:
		for {
			select {
			case <-keepaliveTicker.C:
				streamErr := stream.Send(&rpc.StartResponse{
					Event: &rpc.ProcessEvent{
						Event: &rpc.ProcessEvent_Keepalive{
							Keepalive: &rpc.ProcessEvent_KeepAlive{},
						},
					},
				})
				if streamErr != nil {
					cancel(connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending keepalive: %w", streamErr)))
					return
				}
			case <-ctx.Done():
				cancel(ctx.Err())
				return
			case event, ok := <-data:
				if !ok {
					// Data fork closed — the process's output is finished.
					break dataLoop
				}
				streamErr := stream.Send(&rpc.StartResponse{
					Event: &rpc.ProcessEvent{
						Event: &event,
					},
				})
				if streamErr != nil {
					cancel(connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending data event: %w", streamErr)))
					return
				}
				resetKeepalive()
			}
		}
		// Phase 3: forward the single end event.
		select {
		case <-ctx.Done():
			cancel(ctx.Err())
			return
		case event, ok := <-end:
			if !ok {
				cancel(connect.NewError(connect.CodeUnknown, errors.New("end event channel closed before sending end event")))
				return
			}
			streamErr := stream.Send(&rpc.StartResponse{
				Event: &rpc.ProcessEvent{
					Event: &event,
				},
			})
			if streamErr != nil {
				cancel(connect.NewError(connect.CodeUnknown, fmt.Errorf("error sending end event: %w", streamErr)))
				return
			}
		}
	}()
	pid, err := proc.Start()
	if err != nil {
		return connect.NewError(connect.CodeInvalidArgument, err)
	}
	s.processes.Store(pid, proc)
	// Hand the PID to the streaming goroutine (phase 1 above reads `start`).
	start <- rpc.ProcessEvent_Start{
		Start: &rpc.ProcessEvent_StartEvent{
			Pid: pid,
		},
	}
	// Reap the process in the background and drop it from the table when done.
	go func() {
		defer s.processes.Delete(pid)
		proc.Wait()
	}()
	// Block until the stream finishes or the request context is cancelled.
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-exitChan:
		return nil
	}
}
func determineTimeoutFromHeader(header http.Header) (time.Duration, error) {
timeoutHeader := header.Get("Connect-Timeout-Ms")
if timeoutHeader == "" {
return 0, nil
}
timeout, err := strconv.Atoi(timeoutHeader)
if err != nil {
return 0, err
}
return time.Duration(timeout) * time.Millisecond, nil
}

View File

@ -0,0 +1,30 @@
package process
import (
"context"
"fmt"
"connectrpc.com/connect"
"github.com/creack/pty"
rpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
)
// Update applies mutable settings to a running process; currently only PTY
// window resizing is supported (requests without a pty section are no-ops).
func (s *Service) Update(_ context.Context, req *connect.Request[rpc.UpdateRequest]) (*connect.Response[rpc.UpdateResponse], error) {
	proc, err := s.getProcess(req.Msg.GetProcess())
	if err != nil {
		return nil, err
	}
	ptyUpdate := req.Msg.GetPty()
	if ptyUpdate == nil {
		return connect.NewResponse(&rpc.UpdateResponse{}), nil
	}
	size := &pty.Winsize{
		Rows: uint16(ptyUpdate.GetSize().GetRows()),
		Cols: uint16(ptyUpdate.GetSize().GetCols()),
	}
	if resizeErr := proc.ResizeTty(size); resizeErr != nil {
		return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("error resizing tty: %w", resizeErr))
	}
	return connect.NewResponse(&rpc.UpdateResponse{}), nil
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,337 @@
// Code generated by protoc-gen-connect-go. DO NOT EDIT.
//
// Source: filesystem/filesystem.proto
package filesystemconnect
import (
connect "connectrpc.com/connect"
context "context"
errors "errors"
filesystem "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/filesystem"
http "net/http"
strings "strings"
)
// This is a compile-time assertion to ensure that this generated file and the connect package are
// compatible. If you get a compiler error that this constant is not defined, this code was
// generated with a version of connect newer than the one compiled into your binary. You can fix the
// problem by either regenerating this code with an older version of connect or updating the connect
// version compiled into your binary.
const _ = connect.IsAtLeastVersion1_13_0
const (
// FilesystemName is the fully-qualified name of the Filesystem service.
FilesystemName = "filesystem.Filesystem"
)
// These constants are the fully-qualified names of the RPCs defined in this package. They're
// exposed at runtime as Spec.Procedure and as the final two segments of the HTTP route.
//
// Note that these are different from the fully-qualified method names used by
// google.golang.org/protobuf/reflect/protoreflect. To convert from these constants to
// reflection-formatted method names, remove the leading slash and convert the remaining slash to a
// period.
const (
// FilesystemStatProcedure is the fully-qualified name of the Filesystem's Stat RPC.
FilesystemStatProcedure = "/filesystem.Filesystem/Stat"
// FilesystemMakeDirProcedure is the fully-qualified name of the Filesystem's MakeDir RPC.
FilesystemMakeDirProcedure = "/filesystem.Filesystem/MakeDir"
// FilesystemMoveProcedure is the fully-qualified name of the Filesystem's Move RPC.
FilesystemMoveProcedure = "/filesystem.Filesystem/Move"
// FilesystemListDirProcedure is the fully-qualified name of the Filesystem's ListDir RPC.
FilesystemListDirProcedure = "/filesystem.Filesystem/ListDir"
// FilesystemRemoveProcedure is the fully-qualified name of the Filesystem's Remove RPC.
FilesystemRemoveProcedure = "/filesystem.Filesystem/Remove"
// FilesystemWatchDirProcedure is the fully-qualified name of the Filesystem's WatchDir RPC.
FilesystemWatchDirProcedure = "/filesystem.Filesystem/WatchDir"
// FilesystemCreateWatcherProcedure is the fully-qualified name of the Filesystem's CreateWatcher
// RPC.
FilesystemCreateWatcherProcedure = "/filesystem.Filesystem/CreateWatcher"
// FilesystemGetWatcherEventsProcedure is the fully-qualified name of the Filesystem's
// GetWatcherEvents RPC.
FilesystemGetWatcherEventsProcedure = "/filesystem.Filesystem/GetWatcherEvents"
// FilesystemRemoveWatcherProcedure is the fully-qualified name of the Filesystem's RemoveWatcher
// RPC.
FilesystemRemoveWatcherProcedure = "/filesystem.Filesystem/RemoveWatcher"
)
// FilesystemClient is a client for the filesystem.Filesystem service.
type FilesystemClient interface {
Stat(context.Context, *connect.Request[filesystem.StatRequest]) (*connect.Response[filesystem.StatResponse], error)
MakeDir(context.Context, *connect.Request[filesystem.MakeDirRequest]) (*connect.Response[filesystem.MakeDirResponse], error)
Move(context.Context, *connect.Request[filesystem.MoveRequest]) (*connect.Response[filesystem.MoveResponse], error)
ListDir(context.Context, *connect.Request[filesystem.ListDirRequest]) (*connect.Response[filesystem.ListDirResponse], error)
Remove(context.Context, *connect.Request[filesystem.RemoveRequest]) (*connect.Response[filesystem.RemoveResponse], error)
WatchDir(context.Context, *connect.Request[filesystem.WatchDirRequest]) (*connect.ServerStreamForClient[filesystem.WatchDirResponse], error)
// Non-streaming versions of WatchDir
CreateWatcher(context.Context, *connect.Request[filesystem.CreateWatcherRequest]) (*connect.Response[filesystem.CreateWatcherResponse], error)
GetWatcherEvents(context.Context, *connect.Request[filesystem.GetWatcherEventsRequest]) (*connect.Response[filesystem.GetWatcherEventsResponse], error)
RemoveWatcher(context.Context, *connect.Request[filesystem.RemoveWatcherRequest]) (*connect.Response[filesystem.RemoveWatcherResponse], error)
}
// NewFilesystemClient constructs a client for the filesystem.Filesystem service. By default, it
// uses the Connect protocol with the binary Protobuf Codec, asks for gzipped responses, and sends
// uncompressed requests. To use the gRPC or gRPC-Web protocols, supply the connect.WithGRPC() or
// connect.WithGRPCWeb() options.
//
// The URL supplied here should be the base URL for the Connect or gRPC server (for example,
// http://api.acme.com or https://acme.com/grpc).
func NewFilesystemClient(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) FilesystemClient {
baseURL = strings.TrimRight(baseURL, "/")
filesystemMethods := filesystem.File_filesystem_filesystem_proto.Services().ByName("Filesystem").Methods()
return &filesystemClient{
stat: connect.NewClient[filesystem.StatRequest, filesystem.StatResponse](
httpClient,
baseURL+FilesystemStatProcedure,
connect.WithSchema(filesystemMethods.ByName("Stat")),
connect.WithClientOptions(opts...),
),
makeDir: connect.NewClient[filesystem.MakeDirRequest, filesystem.MakeDirResponse](
httpClient,
baseURL+FilesystemMakeDirProcedure,
connect.WithSchema(filesystemMethods.ByName("MakeDir")),
connect.WithClientOptions(opts...),
),
move: connect.NewClient[filesystem.MoveRequest, filesystem.MoveResponse](
httpClient,
baseURL+FilesystemMoveProcedure,
connect.WithSchema(filesystemMethods.ByName("Move")),
connect.WithClientOptions(opts...),
),
listDir: connect.NewClient[filesystem.ListDirRequest, filesystem.ListDirResponse](
httpClient,
baseURL+FilesystemListDirProcedure,
connect.WithSchema(filesystemMethods.ByName("ListDir")),
connect.WithClientOptions(opts...),
),
remove: connect.NewClient[filesystem.RemoveRequest, filesystem.RemoveResponse](
httpClient,
baseURL+FilesystemRemoveProcedure,
connect.WithSchema(filesystemMethods.ByName("Remove")),
connect.WithClientOptions(opts...),
),
watchDir: connect.NewClient[filesystem.WatchDirRequest, filesystem.WatchDirResponse](
httpClient,
baseURL+FilesystemWatchDirProcedure,
connect.WithSchema(filesystemMethods.ByName("WatchDir")),
connect.WithClientOptions(opts...),
),
createWatcher: connect.NewClient[filesystem.CreateWatcherRequest, filesystem.CreateWatcherResponse](
httpClient,
baseURL+FilesystemCreateWatcherProcedure,
connect.WithSchema(filesystemMethods.ByName("CreateWatcher")),
connect.WithClientOptions(opts...),
),
getWatcherEvents: connect.NewClient[filesystem.GetWatcherEventsRequest, filesystem.GetWatcherEventsResponse](
httpClient,
baseURL+FilesystemGetWatcherEventsProcedure,
connect.WithSchema(filesystemMethods.ByName("GetWatcherEvents")),
connect.WithClientOptions(opts...),
),
removeWatcher: connect.NewClient[filesystem.RemoveWatcherRequest, filesystem.RemoveWatcherResponse](
httpClient,
baseURL+FilesystemRemoveWatcherProcedure,
connect.WithSchema(filesystemMethods.ByName("RemoveWatcher")),
connect.WithClientOptions(opts...),
),
}
}
// filesystemClient implements FilesystemClient.
type filesystemClient struct {
stat *connect.Client[filesystem.StatRequest, filesystem.StatResponse]
makeDir *connect.Client[filesystem.MakeDirRequest, filesystem.MakeDirResponse]
move *connect.Client[filesystem.MoveRequest, filesystem.MoveResponse]
listDir *connect.Client[filesystem.ListDirRequest, filesystem.ListDirResponse]
remove *connect.Client[filesystem.RemoveRequest, filesystem.RemoveResponse]
watchDir *connect.Client[filesystem.WatchDirRequest, filesystem.WatchDirResponse]
createWatcher *connect.Client[filesystem.CreateWatcherRequest, filesystem.CreateWatcherResponse]
getWatcherEvents *connect.Client[filesystem.GetWatcherEventsRequest, filesystem.GetWatcherEventsResponse]
removeWatcher *connect.Client[filesystem.RemoveWatcherRequest, filesystem.RemoveWatcherResponse]
}
// Stat calls filesystem.Filesystem.Stat.
func (c *filesystemClient) Stat(ctx context.Context, req *connect.Request[filesystem.StatRequest]) (*connect.Response[filesystem.StatResponse], error) {
return c.stat.CallUnary(ctx, req)
}
// MakeDir calls filesystem.Filesystem.MakeDir.
func (c *filesystemClient) MakeDir(ctx context.Context, req *connect.Request[filesystem.MakeDirRequest]) (*connect.Response[filesystem.MakeDirResponse], error) {
return c.makeDir.CallUnary(ctx, req)
}
// Move calls filesystem.Filesystem.Move.
func (c *filesystemClient) Move(ctx context.Context, req *connect.Request[filesystem.MoveRequest]) (*connect.Response[filesystem.MoveResponse], error) {
return c.move.CallUnary(ctx, req)
}
// ListDir calls filesystem.Filesystem.ListDir.
func (c *filesystemClient) ListDir(ctx context.Context, req *connect.Request[filesystem.ListDirRequest]) (*connect.Response[filesystem.ListDirResponse], error) {
return c.listDir.CallUnary(ctx, req)
}
// Remove calls filesystem.Filesystem.Remove.
func (c *filesystemClient) Remove(ctx context.Context, req *connect.Request[filesystem.RemoveRequest]) (*connect.Response[filesystem.RemoveResponse], error) {
return c.remove.CallUnary(ctx, req)
}
// WatchDir calls filesystem.Filesystem.WatchDir.
func (c *filesystemClient) WatchDir(ctx context.Context, req *connect.Request[filesystem.WatchDirRequest]) (*connect.ServerStreamForClient[filesystem.WatchDirResponse], error) {
return c.watchDir.CallServerStream(ctx, req)
}
// CreateWatcher calls filesystem.Filesystem.CreateWatcher.
func (c *filesystemClient) CreateWatcher(ctx context.Context, req *connect.Request[filesystem.CreateWatcherRequest]) (*connect.Response[filesystem.CreateWatcherResponse], error) {
return c.createWatcher.CallUnary(ctx, req)
}
// GetWatcherEvents calls filesystem.Filesystem.GetWatcherEvents.
func (c *filesystemClient) GetWatcherEvents(ctx context.Context, req *connect.Request[filesystem.GetWatcherEventsRequest]) (*connect.Response[filesystem.GetWatcherEventsResponse], error) {
return c.getWatcherEvents.CallUnary(ctx, req)
}
// RemoveWatcher calls filesystem.Filesystem.RemoveWatcher.
func (c *filesystemClient) RemoveWatcher(ctx context.Context, req *connect.Request[filesystem.RemoveWatcherRequest]) (*connect.Response[filesystem.RemoveWatcherResponse], error) {
return c.removeWatcher.CallUnary(ctx, req)
}
// FilesystemHandler is an implementation of the filesystem.Filesystem service.
type FilesystemHandler interface {
Stat(context.Context, *connect.Request[filesystem.StatRequest]) (*connect.Response[filesystem.StatResponse], error)
MakeDir(context.Context, *connect.Request[filesystem.MakeDirRequest]) (*connect.Response[filesystem.MakeDirResponse], error)
Move(context.Context, *connect.Request[filesystem.MoveRequest]) (*connect.Response[filesystem.MoveResponse], error)
ListDir(context.Context, *connect.Request[filesystem.ListDirRequest]) (*connect.Response[filesystem.ListDirResponse], error)
Remove(context.Context, *connect.Request[filesystem.RemoveRequest]) (*connect.Response[filesystem.RemoveResponse], error)
WatchDir(context.Context, *connect.Request[filesystem.WatchDirRequest], *connect.ServerStream[filesystem.WatchDirResponse]) error
// Non-streaming versions of WatchDir
CreateWatcher(context.Context, *connect.Request[filesystem.CreateWatcherRequest]) (*connect.Response[filesystem.CreateWatcherResponse], error)
GetWatcherEvents(context.Context, *connect.Request[filesystem.GetWatcherEventsRequest]) (*connect.Response[filesystem.GetWatcherEventsResponse], error)
RemoveWatcher(context.Context, *connect.Request[filesystem.RemoveWatcherRequest]) (*connect.Response[filesystem.RemoveWatcherResponse], error)
}
// NewFilesystemHandler builds an HTTP handler from the service implementation. It returns the path
// on which to mount the handler and the handler itself.
//
// By default, handlers support the Connect, gRPC, and gRPC-Web protocols with the binary Protobuf
// and JSON codecs. They also support gzip compression.
func NewFilesystemHandler(svc FilesystemHandler, opts ...connect.HandlerOption) (string, http.Handler) {
filesystemMethods := filesystem.File_filesystem_filesystem_proto.Services().ByName("Filesystem").Methods()
filesystemStatHandler := connect.NewUnaryHandler(
FilesystemStatProcedure,
svc.Stat,
connect.WithSchema(filesystemMethods.ByName("Stat")),
connect.WithHandlerOptions(opts...),
)
filesystemMakeDirHandler := connect.NewUnaryHandler(
FilesystemMakeDirProcedure,
svc.MakeDir,
connect.WithSchema(filesystemMethods.ByName("MakeDir")),
connect.WithHandlerOptions(opts...),
)
filesystemMoveHandler := connect.NewUnaryHandler(
FilesystemMoveProcedure,
svc.Move,
connect.WithSchema(filesystemMethods.ByName("Move")),
connect.WithHandlerOptions(opts...),
)
filesystemListDirHandler := connect.NewUnaryHandler(
FilesystemListDirProcedure,
svc.ListDir,
connect.WithSchema(filesystemMethods.ByName("ListDir")),
connect.WithHandlerOptions(opts...),
)
filesystemRemoveHandler := connect.NewUnaryHandler(
FilesystemRemoveProcedure,
svc.Remove,
connect.WithSchema(filesystemMethods.ByName("Remove")),
connect.WithHandlerOptions(opts...),
)
filesystemWatchDirHandler := connect.NewServerStreamHandler(
FilesystemWatchDirProcedure,
svc.WatchDir,
connect.WithSchema(filesystemMethods.ByName("WatchDir")),
connect.WithHandlerOptions(opts...),
)
filesystemCreateWatcherHandler := connect.NewUnaryHandler(
FilesystemCreateWatcherProcedure,
svc.CreateWatcher,
connect.WithSchema(filesystemMethods.ByName("CreateWatcher")),
connect.WithHandlerOptions(opts...),
)
filesystemGetWatcherEventsHandler := connect.NewUnaryHandler(
FilesystemGetWatcherEventsProcedure,
svc.GetWatcherEvents,
connect.WithSchema(filesystemMethods.ByName("GetWatcherEvents")),
connect.WithHandlerOptions(opts...),
)
filesystemRemoveWatcherHandler := connect.NewUnaryHandler(
FilesystemRemoveWatcherProcedure,
svc.RemoveWatcher,
connect.WithSchema(filesystemMethods.ByName("RemoveWatcher")),
connect.WithHandlerOptions(opts...),
)
return "/filesystem.Filesystem/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case FilesystemStatProcedure:
filesystemStatHandler.ServeHTTP(w, r)
case FilesystemMakeDirProcedure:
filesystemMakeDirHandler.ServeHTTP(w, r)
case FilesystemMoveProcedure:
filesystemMoveHandler.ServeHTTP(w, r)
case FilesystemListDirProcedure:
filesystemListDirHandler.ServeHTTP(w, r)
case FilesystemRemoveProcedure:
filesystemRemoveHandler.ServeHTTP(w, r)
case FilesystemWatchDirProcedure:
filesystemWatchDirHandler.ServeHTTP(w, r)
case FilesystemCreateWatcherProcedure:
filesystemCreateWatcherHandler.ServeHTTP(w, r)
case FilesystemGetWatcherEventsProcedure:
filesystemGetWatcherEventsHandler.ServeHTTP(w, r)
case FilesystemRemoveWatcherProcedure:
filesystemRemoveWatcherHandler.ServeHTTP(w, r)
default:
http.NotFound(w, r)
}
})
}
// UnimplementedFilesystemHandler returns CodeUnimplemented from all methods.
type UnimplementedFilesystemHandler struct{}
func (UnimplementedFilesystemHandler) Stat(context.Context, *connect.Request[filesystem.StatRequest]) (*connect.Response[filesystem.StatResponse], error) {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.Stat is not implemented"))
}
func (UnimplementedFilesystemHandler) MakeDir(context.Context, *connect.Request[filesystem.MakeDirRequest]) (*connect.Response[filesystem.MakeDirResponse], error) {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.MakeDir is not implemented"))
}
func (UnimplementedFilesystemHandler) Move(context.Context, *connect.Request[filesystem.MoveRequest]) (*connect.Response[filesystem.MoveResponse], error) {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.Move is not implemented"))
}
func (UnimplementedFilesystemHandler) ListDir(context.Context, *connect.Request[filesystem.ListDirRequest]) (*connect.Response[filesystem.ListDirResponse], error) {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.ListDir is not implemented"))
}
func (UnimplementedFilesystemHandler) Remove(context.Context, *connect.Request[filesystem.RemoveRequest]) (*connect.Response[filesystem.RemoveResponse], error) {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.Remove is not implemented"))
}
func (UnimplementedFilesystemHandler) WatchDir(context.Context, *connect.Request[filesystem.WatchDirRequest], *connect.ServerStream[filesystem.WatchDirResponse]) error {
return connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.WatchDir is not implemented"))
}
func (UnimplementedFilesystemHandler) CreateWatcher(context.Context, *connect.Request[filesystem.CreateWatcherRequest]) (*connect.Response[filesystem.CreateWatcherResponse], error) {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.CreateWatcher is not implemented"))
}
func (UnimplementedFilesystemHandler) GetWatcherEvents(context.Context, *connect.Request[filesystem.GetWatcherEventsRequest]) (*connect.Response[filesystem.GetWatcherEventsResponse], error) {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.GetWatcherEvents is not implemented"))
}
func (UnimplementedFilesystemHandler) RemoveWatcher(context.Context, *connect.Request[filesystem.RemoveWatcherRequest]) (*connect.Response[filesystem.RemoveWatcherResponse], error) {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("filesystem.Filesystem.RemoveWatcher is not implemented"))
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,310 @@
// Code generated by protoc-gen-connect-go. DO NOT EDIT.
//
// Source: process/process.proto
package processconnect
import (
connect "connectrpc.com/connect"
context "context"
errors "errors"
process "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
http "net/http"
strings "strings"
)
// NOTE(review): generated by protoc-gen-connect-go — change process.proto and
// regenerate instead of editing this file by hand.
// This is a compile-time assertion to ensure that this generated file and the connect package are
// compatible. If you get a compiler error that this constant is not defined, this code was
// generated with a version of connect newer than the one compiled into your binary. You can fix the
// problem by either regenerating this code with an older version of connect or updating the connect
// version compiled into your binary.
const _ = connect.IsAtLeastVersion1_13_0
const (
// ProcessName is the fully-qualified name of the Process service.
ProcessName = "process.Process"
)
// These constants are the fully-qualified names of the RPCs defined in this package. They're
// exposed at runtime as Spec.Procedure and as the final two segments of the HTTP route.
//
// Note that these are different from the fully-qualified method names used by
// google.golang.org/protobuf/reflect/protoreflect. To convert from these constants to
// reflection-formatted method names, remove the leading slash and convert the remaining slash to a
// period.
const (
// ProcessListProcedure is the fully-qualified name of the Process's List RPC.
ProcessListProcedure = "/process.Process/List"
// ProcessConnectProcedure is the fully-qualified name of the Process's Connect RPC.
ProcessConnectProcedure = "/process.Process/Connect"
// ProcessStartProcedure is the fully-qualified name of the Process's Start RPC.
ProcessStartProcedure = "/process.Process/Start"
// ProcessUpdateProcedure is the fully-qualified name of the Process's Update RPC.
ProcessUpdateProcedure = "/process.Process/Update"
// ProcessStreamInputProcedure is the fully-qualified name of the Process's StreamInput RPC.
ProcessStreamInputProcedure = "/process.Process/StreamInput"
// ProcessSendInputProcedure is the fully-qualified name of the Process's SendInput RPC.
ProcessSendInputProcedure = "/process.Process/SendInput"
// ProcessSendSignalProcedure is the fully-qualified name of the Process's SendSignal RPC.
ProcessSendSignalProcedure = "/process.Process/SendSignal"
// ProcessCloseStdinProcedure is the fully-qualified name of the Process's CloseStdin RPC.
ProcessCloseStdinProcedure = "/process.Process/CloseStdin"
)
// NOTE(review): generated interface — keep in sync with process.proto via buf generate.
// ProcessClient is a client for the process.Process service.
type ProcessClient interface {
List(context.Context, *connect.Request[process.ListRequest]) (*connect.Response[process.ListResponse], error)
Connect(context.Context, *connect.Request[process.ConnectRequest]) (*connect.ServerStreamForClient[process.ConnectResponse], error)
Start(context.Context, *connect.Request[process.StartRequest]) (*connect.ServerStreamForClient[process.StartResponse], error)
Update(context.Context, *connect.Request[process.UpdateRequest]) (*connect.Response[process.UpdateResponse], error)
// Client input stream ensures ordering of messages
StreamInput(context.Context) *connect.ClientStreamForClient[process.StreamInputRequest, process.StreamInputResponse]
SendInput(context.Context, *connect.Request[process.SendInputRequest]) (*connect.Response[process.SendInputResponse], error)
SendSignal(context.Context, *connect.Request[process.SendSignalRequest]) (*connect.Response[process.SendSignalResponse], error)
// Close stdin to signal EOF to the process.
// Only works for non-PTY processes. For PTY, send Ctrl+D (0x04) instead.
CloseStdin(context.Context, *connect.Request[process.CloseStdinRequest]) (*connect.Response[process.CloseStdinResponse], error)
}
// NOTE(review): generated constructor — one connect.Client per RPC, all sharing
// httpClient and baseURL; do not edit by hand.
// NewProcessClient constructs a client for the process.Process service. By default, it uses the
// Connect protocol with the binary Protobuf Codec, asks for gzipped responses, and sends
// uncompressed requests. To use the gRPC or gRPC-Web protocols, supply the connect.WithGRPC() or
// connect.WithGRPCWeb() options.
//
// The URL supplied here should be the base URL for the Connect or gRPC server (for example,
// http://api.acme.com or https://acme.com/grpc).
func NewProcessClient(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) ProcessClient {
baseURL = strings.TrimRight(baseURL, "/")
processMethods := process.File_process_process_proto.Services().ByName("Process").Methods()
return &processClient{
list: connect.NewClient[process.ListRequest, process.ListResponse](
httpClient,
baseURL+ProcessListProcedure,
connect.WithSchema(processMethods.ByName("List")),
connect.WithClientOptions(opts...),
),
connect: connect.NewClient[process.ConnectRequest, process.ConnectResponse](
httpClient,
baseURL+ProcessConnectProcedure,
connect.WithSchema(processMethods.ByName("Connect")),
connect.WithClientOptions(opts...),
),
start: connect.NewClient[process.StartRequest, process.StartResponse](
httpClient,
baseURL+ProcessStartProcedure,
connect.WithSchema(processMethods.ByName("Start")),
connect.WithClientOptions(opts...),
),
update: connect.NewClient[process.UpdateRequest, process.UpdateResponse](
httpClient,
baseURL+ProcessUpdateProcedure,
connect.WithSchema(processMethods.ByName("Update")),
connect.WithClientOptions(opts...),
),
streamInput: connect.NewClient[process.StreamInputRequest, process.StreamInputResponse](
httpClient,
baseURL+ProcessStreamInputProcedure,
connect.WithSchema(processMethods.ByName("StreamInput")),
connect.WithClientOptions(opts...),
),
sendInput: connect.NewClient[process.SendInputRequest, process.SendInputResponse](
httpClient,
baseURL+ProcessSendInputProcedure,
connect.WithSchema(processMethods.ByName("SendInput")),
connect.WithClientOptions(opts...),
),
sendSignal: connect.NewClient[process.SendSignalRequest, process.SendSignalResponse](
httpClient,
baseURL+ProcessSendSignalProcedure,
connect.WithSchema(processMethods.ByName("SendSignal")),
connect.WithClientOptions(opts...),
),
closeStdin: connect.NewClient[process.CloseStdinRequest, process.CloseStdinResponse](
httpClient,
baseURL+ProcessCloseStdinProcedure,
connect.WithSchema(processMethods.ByName("CloseStdin")),
connect.WithClientOptions(opts...),
),
}
}
// NOTE(review): generated implementation of ProcessClient — each method
// delegates to the matching per-RPC connect.Client; do not edit by hand.
// processClient implements ProcessClient.
type processClient struct {
list *connect.Client[process.ListRequest, process.ListResponse]
connect *connect.Client[process.ConnectRequest, process.ConnectResponse]
start *connect.Client[process.StartRequest, process.StartResponse]
update *connect.Client[process.UpdateRequest, process.UpdateResponse]
streamInput *connect.Client[process.StreamInputRequest, process.StreamInputResponse]
sendInput *connect.Client[process.SendInputRequest, process.SendInputResponse]
sendSignal *connect.Client[process.SendSignalRequest, process.SendSignalResponse]
closeStdin *connect.Client[process.CloseStdinRequest, process.CloseStdinResponse]
}
// List calls process.Process.List.
func (c *processClient) List(ctx context.Context, req *connect.Request[process.ListRequest]) (*connect.Response[process.ListResponse], error) {
return c.list.CallUnary(ctx, req)
}
// Connect calls process.Process.Connect.
func (c *processClient) Connect(ctx context.Context, req *connect.Request[process.ConnectRequest]) (*connect.ServerStreamForClient[process.ConnectResponse], error) {
return c.connect.CallServerStream(ctx, req)
}
// Start calls process.Process.Start.
func (c *processClient) Start(ctx context.Context, req *connect.Request[process.StartRequest]) (*connect.ServerStreamForClient[process.StartResponse], error) {
return c.start.CallServerStream(ctx, req)
}
// Update calls process.Process.Update.
func (c *processClient) Update(ctx context.Context, req *connect.Request[process.UpdateRequest]) (*connect.Response[process.UpdateResponse], error) {
return c.update.CallUnary(ctx, req)
}
// StreamInput calls process.Process.StreamInput.
func (c *processClient) StreamInput(ctx context.Context) *connect.ClientStreamForClient[process.StreamInputRequest, process.StreamInputResponse] {
return c.streamInput.CallClientStream(ctx)
}
// SendInput calls process.Process.SendInput.
func (c *processClient) SendInput(ctx context.Context, req *connect.Request[process.SendInputRequest]) (*connect.Response[process.SendInputResponse], error) {
return c.sendInput.CallUnary(ctx, req)
}
// SendSignal calls process.Process.SendSignal.
func (c *processClient) SendSignal(ctx context.Context, req *connect.Request[process.SendSignalRequest]) (*connect.Response[process.SendSignalResponse], error) {
return c.sendSignal.CallUnary(ctx, req)
}
// CloseStdin calls process.Process.CloseStdin.
func (c *processClient) CloseStdin(ctx context.Context, req *connect.Request[process.CloseStdinRequest]) (*connect.Response[process.CloseStdinResponse], error) {
return c.closeStdin.CallUnary(ctx, req)
}
// NOTE(review): generated server-side interface — keep in sync with process.proto.
// ProcessHandler is an implementation of the process.Process service.
type ProcessHandler interface {
List(context.Context, *connect.Request[process.ListRequest]) (*connect.Response[process.ListResponse], error)
Connect(context.Context, *connect.Request[process.ConnectRequest], *connect.ServerStream[process.ConnectResponse]) error
Start(context.Context, *connect.Request[process.StartRequest], *connect.ServerStream[process.StartResponse]) error
Update(context.Context, *connect.Request[process.UpdateRequest]) (*connect.Response[process.UpdateResponse], error)
// Client input stream ensures ordering of messages
StreamInput(context.Context, *connect.ClientStream[process.StreamInputRequest]) (*connect.Response[process.StreamInputResponse], error)
SendInput(context.Context, *connect.Request[process.SendInputRequest]) (*connect.Response[process.SendInputResponse], error)
SendSignal(context.Context, *connect.Request[process.SendSignalRequest]) (*connect.Response[process.SendSignalResponse], error)
// Close stdin to signal EOF to the process.
// Only works for non-PTY processes. For PTY, send Ctrl+D (0x04) instead.
CloseStdin(context.Context, *connect.Request[process.CloseStdinRequest]) (*connect.Response[process.CloseStdinResponse], error)
}
// NOTE(review): generated handler mux — builds one connect handler per RPC and
// routes by exact URL path, 404-ing anything else; do not edit by hand.
// NewProcessHandler builds an HTTP handler from the service implementation. It returns the path on
// which to mount the handler and the handler itself.
//
// By default, handlers support the Connect, gRPC, and gRPC-Web protocols with the binary Protobuf
// and JSON codecs. They also support gzip compression.
func NewProcessHandler(svc ProcessHandler, opts ...connect.HandlerOption) (string, http.Handler) {
processMethods := process.File_process_process_proto.Services().ByName("Process").Methods()
processListHandler := connect.NewUnaryHandler(
ProcessListProcedure,
svc.List,
connect.WithSchema(processMethods.ByName("List")),
connect.WithHandlerOptions(opts...),
)
processConnectHandler := connect.NewServerStreamHandler(
ProcessConnectProcedure,
svc.Connect,
connect.WithSchema(processMethods.ByName("Connect")),
connect.WithHandlerOptions(opts...),
)
processStartHandler := connect.NewServerStreamHandler(
ProcessStartProcedure,
svc.Start,
connect.WithSchema(processMethods.ByName("Start")),
connect.WithHandlerOptions(opts...),
)
processUpdateHandler := connect.NewUnaryHandler(
ProcessUpdateProcedure,
svc.Update,
connect.WithSchema(processMethods.ByName("Update")),
connect.WithHandlerOptions(opts...),
)
processStreamInputHandler := connect.NewClientStreamHandler(
ProcessStreamInputProcedure,
svc.StreamInput,
connect.WithSchema(processMethods.ByName("StreamInput")),
connect.WithHandlerOptions(opts...),
)
processSendInputHandler := connect.NewUnaryHandler(
ProcessSendInputProcedure,
svc.SendInput,
connect.WithSchema(processMethods.ByName("SendInput")),
connect.WithHandlerOptions(opts...),
)
processSendSignalHandler := connect.NewUnaryHandler(
ProcessSendSignalProcedure,
svc.SendSignal,
connect.WithSchema(processMethods.ByName("SendSignal")),
connect.WithHandlerOptions(opts...),
)
processCloseStdinHandler := connect.NewUnaryHandler(
ProcessCloseStdinProcedure,
svc.CloseStdin,
connect.WithSchema(processMethods.ByName("CloseStdin")),
connect.WithHandlerOptions(opts...),
)
return "/process.Process/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case ProcessListProcedure:
processListHandler.ServeHTTP(w, r)
case ProcessConnectProcedure:
processConnectHandler.ServeHTTP(w, r)
case ProcessStartProcedure:
processStartHandler.ServeHTTP(w, r)
case ProcessUpdateProcedure:
processUpdateHandler.ServeHTTP(w, r)
case ProcessStreamInputProcedure:
processStreamInputHandler.ServeHTTP(w, r)
case ProcessSendInputProcedure:
processSendInputHandler.ServeHTTP(w, r)
case ProcessSendSignalProcedure:
processSendSignalHandler.ServeHTTP(w, r)
case ProcessCloseStdinProcedure:
processCloseStdinHandler.ServeHTTP(w, r)
default:
http.NotFound(w, r)
}
})
}
// NOTE(review): generated stubs — embed this type to get forward-compatible
// CodeUnimplemented defaults for every Process RPC; do not edit by hand.
// UnimplementedProcessHandler returns CodeUnimplemented from all methods.
type UnimplementedProcessHandler struct{}
func (UnimplementedProcessHandler) List(context.Context, *connect.Request[process.ListRequest]) (*connect.Response[process.ListResponse], error) {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.List is not implemented"))
}
func (UnimplementedProcessHandler) Connect(context.Context, *connect.Request[process.ConnectRequest], *connect.ServerStream[process.ConnectResponse]) error {
return connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.Connect is not implemented"))
}
func (UnimplementedProcessHandler) Start(context.Context, *connect.Request[process.StartRequest], *connect.ServerStream[process.StartResponse]) error {
return connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.Start is not implemented"))
}
func (UnimplementedProcessHandler) Update(context.Context, *connect.Request[process.UpdateRequest]) (*connect.Response[process.UpdateResponse], error) {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.Update is not implemented"))
}
func (UnimplementedProcessHandler) StreamInput(context.Context, *connect.ClientStream[process.StreamInputRequest]) (*connect.Response[process.StreamInputResponse], error) {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.StreamInput is not implemented"))
}
func (UnimplementedProcessHandler) SendInput(context.Context, *connect.Request[process.SendInputRequest]) (*connect.Response[process.SendInputResponse], error) {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.SendInput is not implemented"))
}
func (UnimplementedProcessHandler) SendSignal(context.Context, *connect.Request[process.SendSignalRequest]) (*connect.Response[process.SendSignalResponse], error) {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.SendSignal is not implemented"))
}
func (UnimplementedProcessHandler) CloseStdin(context.Context, *connect.Request[process.CloseStdinRequest]) (*connect.Response[process.CloseStdinResponse], error) {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("process.Process.CloseStdin is not implemented"))
}

View File

@ -0,0 +1,108 @@
package filesystem
import (
"os"
"path/filepath"
"syscall"
"time"
)
// GetEntryFromPath stats path without following a trailing symlink (via
// os.Lstat) and converts the result into an EntryInfo.
// The os.Lstat error, if any, is returned unchanged alongside a zero EntryInfo.
func GetEntryFromPath(path string) (EntryInfo, error) {
	info, statErr := os.Lstat(path)
	if statErr != nil {
		return EntryInfo{}, statErr
	}

	return GetEntryInfo(path, info), nil
}
// GetEntryInfo converts an os.FileInfo (as returned by os.Lstat) into an
// EntryInfo for path.
//
// For symlinks, Type and Mode describe the *resolved* target (UnknownFileType
// and zero Mode when the target cannot be stat'ed, e.g. a broken or cyclic
// link), while Permissions still reflects the link itself. SymlinkTarget is
// always non-nil for a symlink; if resolution fails, followSymlink falls back
// to the original path. (The previous comment claimed the target might be
// left unset — it never is.)
func GetEntryInfo(path string, fileInfo os.FileInfo) EntryInfo {
	fileMode := fileInfo.Mode()

	var symlinkTarget *string
	if fileMode&os.ModeSymlink != 0 {
		// followSymlink never fails outright: on resolution errors it
		// returns path unchanged, so symlinkTarget is always set here.
		target := followSymlink(path)
		symlinkTarget = &target
	}

	var entryType FileType
	var mode os.FileMode
	if symlinkTarget == nil {
		entryType = getEntryType(fileMode)
		mode = fileMode.Perm()
	} else {
		// Classify a symlink by its target's mode; a broken/cyclic link
		// leaves entryType Unknown and mode at the zero value.
		targetInfo, err := os.Stat(*symlinkTarget)
		if err != nil {
			entryType = UnknownFileType
		} else {
			entryType = getEntryType(targetInfo.Mode())
			mode = targetInfo.Mode().Perm()
		}
	}

	entry := EntryInfo{
		Name:          fileInfo.Name(),
		Path:          path,
		Type:          entryType,
		Size:          fileInfo.Size(),
		Mode:          mode,
		Permissions:   fileMode.String(),
		ModifiedTime:  fileInfo.ModTime(),
		SymlinkTarget: symlinkTarget,
	}

	// Enrich with platform-specific stat data (uid/gid, atime/ctime/mtime)
	// when Sys() carries a *syscall.Stat_t. Otherwise the portable ModTime
	// set in the literal above stands — the old `else if` branch that
	// re-assigned fileInfo.ModTime() was dead code and has been removed.
	if base := getBase(fileInfo.Sys()); base != nil {
		entry.AccessedTime = toTimestamp(base.Atim)
		entry.CreatedTime = toTimestamp(base.Ctim)
		entry.ModifiedTime = toTimestamp(base.Mtim)
		entry.UID = base.Uid
		entry.GID = base.Gid
	}

	return entry
}
// getEntryType classifies a file mode as regular file, directory, or symlink.
// It inspects only the mode bits — it does NOT take a path and does NOT
// follow symlinks (the previous comment said otherwise); callers that need
// the link target's type must resolve the link first, as GetEntryInfo does.
func getEntryType(mode os.FileMode) FileType {
switch {
case mode.IsRegular():
return FileFileType
case mode.IsDir():
return DirectoryFileType
case mode&os.ModeSymlink == os.ModeSymlink:
return SymlinkFileType
default:
return UnknownFileType
}
}
// followSymlink returns the fully resolved target of path, following any
// chain of symlinks. If resolution fails (broken or cyclic link, missing
// file), the original path is returned unchanged as a best-effort fallback.
func followSymlink(path string) string {
	if resolved, err := filepath.EvalSymlinks(path); err == nil {
		return resolved
	}
	return path
}
func toTimestamp(spec syscall.Timespec) time.Time {
if spec.Sec == 0 && spec.Nsec == 0 {
return time.Time{}
}
return time.Unix(spec.Sec, spec.Nsec)
}
func getBase(sys any) *syscall.Stat_t {
st, _ := sys.(*syscall.Stat_t)
return st
}

View File

@ -0,0 +1,264 @@
package filesystem
import (
"os"
"os/user"
"path/filepath"
"strconv"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestGetEntryType checks getEntryType's classification of the lstat mode of
// a regular file, a directory, and a symlink (the symlink is classified as
// SymlinkFileType because os.Lstat does not follow it).
func TestGetEntryType(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
// Create test files
regularFile := filepath.Join(tempDir, "regular.txt")
require.NoError(t, os.WriteFile(regularFile, []byte("test content"), 0o644))
testDir := filepath.Join(tempDir, "testdir")
require.NoError(t, os.MkdirAll(testDir, 0o755))
symlink := filepath.Join(tempDir, "symlink")
require.NoError(t, os.Symlink(regularFile, symlink))
tests := []struct {
name string
path string
expected FileType
}{
{
name: "regular file",
path: regularFile,
expected: FileFileType,
},
{
name: "directory",
path: testDir,
expected: DirectoryFileType,
},
{
name: "symlink to file",
path: symlink,
expected: SymlinkFileType,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
// Lstat (not Stat) so the symlink case sees the link's own mode bits.
info, err := os.Lstat(tt.path)
require.NoError(t, err)
result := getEntryType(info.Mode())
assert.Equal(t, tt.expected, result)
})
}
}
// TestEntryInfoFromFileInfo_SymlinkChain verifies that a chain of symlinks
// (link1 → link2 → target dir) resolves to the final target's type while the
// Permissions string still identifies the entry as a link ("L").
func TestEntryInfoFromFileInfo_SymlinkChain(t *testing.T) {
t.Parallel()
// Base temporary directory. On macOS this lives under /var/folders/…
// which itself is a symlink to /private/var/folders/….
tempDir := t.TempDir()
// Create final target
target := filepath.Join(tempDir, "target")
require.NoError(t, os.MkdirAll(target, 0o755))
// Create a chain: link1 → link2 → target
link2 := filepath.Join(tempDir, "link2")
require.NoError(t, os.Symlink(target, link2))
link1 := filepath.Join(tempDir, "link1")
require.NoError(t, os.Symlink(link2, link1))
// run the test
result, err := GetEntryFromPath(link1)
require.NoError(t, err)
// verify the results
assert.Equal(t, "link1", result.Name)
assert.Equal(t, link1, result.Path)
assert.Equal(t, DirectoryFileType, result.Type) // Should resolve to final target type
assert.Contains(t, result.Permissions, "L")
// Canonicalize the expected target path to handle macOS symlink indirections
expectedTarget, err := filepath.EvalSymlinks(link1)
require.NoError(t, err)
assert.Equal(t, expectedTarget, *result.SymlinkTarget)
}
// TestEntryInfoFromFileInfo_DifferentPermissions checks Mode (perm bits only)
// and the rendered Permissions string across several chmod values.
// NOTE(review): os.WriteFile's mode is subject to the process umask, but a
// umask can only clear bits and these cases use none of the commonly-masked
// group/other write bits except via 0o644-style values that survive 022.
func TestEntryInfoFromFileInfo_DifferentPermissions(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
testCases := []struct {
name string
permissions os.FileMode
expectedMode os.FileMode
expectedString string
}{
{"read-only", 0o444, 0o444, "-r--r--r--"},
{"executable", 0o755, 0o755, "-rwxr-xr-x"},
{"write-only", 0o200, 0o200, "--w-------"},
{"no permissions", 0o000, 0o000, "----------"},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
testFile := filepath.Join(tempDir, tc.name+".txt")
require.NoError(t, os.WriteFile(testFile, []byte("test"), tc.permissions))
result, err := GetEntryFromPath(testFile)
require.NoError(t, err)
assert.Equal(t, tc.expectedMode, result.Mode)
assert.Equal(t, tc.expectedString, result.Permissions)
})
}
}
// TestEntryInfoFromFileInfo_EmptyFile verifies Size 0 and normal metadata for
// a zero-byte regular file.
func TestEntryInfoFromFileInfo_EmptyFile(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
emptyFile := filepath.Join(tempDir, "empty.txt")
require.NoError(t, os.WriteFile(emptyFile, []byte{}, 0o600))
result, err := GetEntryFromPath(emptyFile)
require.NoError(t, err)
assert.Equal(t, "empty.txt", result.Name)
assert.Equal(t, int64(0), result.Size)
assert.Equal(t, os.FileMode(0o600), result.Mode)
assert.Equal(t, FileFileType, result.Type)
}
// TestEntryInfoFromFileInfo_CyclicSymlink: a self-referencing link cannot be
// resolved, so the entry type degrades to UnknownFileType without an error.
func TestEntryInfoFromFileInfo_CyclicSymlink(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
// Create cyclic symlink
cyclicSymlink := filepath.Join(tempDir, "cyclic")
require.NoError(t, os.Symlink(cyclicSymlink, cyclicSymlink))
result, err := GetEntryFromPath(cyclicSymlink)
require.NoError(t, err)
assert.Equal(t, "cyclic", result.Name)
assert.Equal(t, cyclicSymlink, result.Path)
assert.Equal(t, UnknownFileType, result.Type)
assert.Contains(t, result.Permissions, "L")
}
// TestEntryInfoFromFileInfo_BrokenSymlink: a link to a nonexistent path also
// yields UnknownFileType while still being reported as a link.
func TestEntryInfoFromFileInfo_BrokenSymlink(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
// Create broken symlink
brokenSymlink := filepath.Join(tempDir, "broken")
require.NoError(t, os.Symlink("/nonexistent", brokenSymlink))
result, err := GetEntryFromPath(brokenSymlink)
require.NoError(t, err)
assert.Equal(t, "broken", result.Name)
assert.Equal(t, brokenSymlink, result.Path)
assert.Equal(t, UnknownFileType, result.Type)
assert.Contains(t, result.Permissions, "L")
// SymlinkTarget might be empty if followSymlink fails
}
// TestEntryInfoFromFileInfo is the happy path: name, size, mode, permission
// string, UID/GID (matched against the current user), and a recent ModTime.
func TestEntryInfoFromFileInfo(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
// Create a regular file with known content and permissions
testFile := filepath.Join(tempDir, "test.txt")
testContent := []byte("Hello, World!")
require.NoError(t, os.WriteFile(testFile, testContent, 0o644))
// Get current user for ownership comparison
currentUser, err := user.Current()
require.NoError(t, err)
result, err := GetEntryFromPath(testFile)
require.NoError(t, err)
// Basic assertions
assert.Equal(t, "test.txt", result.Name)
assert.Equal(t, testFile, result.Path)
assert.Equal(t, int64(len(testContent)), result.Size)
assert.Equal(t, FileFileType, result.Type)
assert.Equal(t, os.FileMode(0o644), result.Mode)
assert.Contains(t, result.Permissions, "-rw-r--r--")
assert.Equal(t, currentUser.Uid, strconv.Itoa(int(result.UID)))
assert.Equal(t, currentUser.Gid, strconv.Itoa(int(result.GID)))
assert.NotNil(t, result.ModifiedTime)
assert.Empty(t, result.SymlinkTarget)
// Check that modified time is reasonable (within last minute)
modTime := result.ModifiedTime
assert.WithinDuration(t, time.Now(), modTime, time.Minute)
}
// TestEntryInfoFromFileInfo_Directory covers the plain-directory case.
func TestEntryInfoFromFileInfo_Directory(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
testDir := filepath.Join(tempDir, "testdir")
require.NoError(t, os.MkdirAll(testDir, 0o755))
result, err := GetEntryFromPath(testDir)
require.NoError(t, err)
assert.Equal(t, "testdir", result.Name)
assert.Equal(t, testDir, result.Path)
assert.Equal(t, DirectoryFileType, result.Type)
assert.Equal(t, os.FileMode(0o755), result.Mode)
assert.Equal(t, "drwxr-xr-x", result.Permissions)
assert.Empty(t, result.SymlinkTarget)
}
// TestEntryInfoFromFileInfo_Symlink: a single link to a file resolves to
// FileFileType, with SymlinkTarget set to the canonicalized target path.
func TestEntryInfoFromFileInfo_Symlink(t *testing.T) {
t.Parallel()
// Base temporary directory. On macOS this lives under /var/folders/…
// which itself is a symlink to /private/var/folders/….
tempDir := t.TempDir()
// Create target file
targetFile := filepath.Join(tempDir, "target.txt")
require.NoError(t, os.WriteFile(targetFile, []byte("target content"), 0o644))
// Create symlink
symlinkPath := filepath.Join(tempDir, "symlink")
require.NoError(t, os.Symlink(targetFile, symlinkPath))
// Use Lstat to get symlink info (not the target)
result, err := GetEntryFromPath(symlinkPath)
require.NoError(t, err)
assert.Equal(t, "symlink", result.Name)
assert.Equal(t, symlinkPath, result.Path)
assert.Equal(t, FileFileType, result.Type) // Should resolve to target type
assert.Contains(t, result.Permissions, "L") // Should show as symlink in permissions
// Canonicalize the expected target path to handle macOS /var → /private/var symlink
expectedTarget, err := filepath.EvalSymlinks(symlinkPath)
require.NoError(t, err)
assert.Equal(t, expectedTarget, *result.SymlinkTarget)
}

View File

@ -0,0 +1,30 @@
package filesystem
import (
"os"
"time"
)
// EntryInfo is a snapshot of a single filesystem entry, as produced by
// GetEntryInfo / GetEntryFromPath.
type EntryInfo struct {
// Name is the base name of the entry.
Name string
// Type classifies the entry; for symlinks it reflects the resolved target.
Type FileType
// Path is the path the entry was stat'ed at.
Path string
// Size in bytes as reported by lstat.
Size int64
// Mode carries only the permission bits (see GetEntryInfo, which uses Perm()).
Mode os.FileMode
// Permissions is the human-readable mode string, e.g. "-rw-r--r--" or
// "Lrwxrwxrwx" for a symlink.
Permissions string
// UID and GID are populated only when platform stat data (*syscall.Stat_t)
// is available from FileInfo.Sys(); otherwise they stay zero.
UID uint32
GID uint32
AccessedTime time.Time
// CreatedTime is filled from st_ctim.
// NOTE(review): on POSIX st_ctim is the inode *change* time, not true
// creation/birth time — confirm whether callers rely on that distinction.
CreatedTime time.Time
ModifiedTime time.Time
// SymlinkTarget is the resolved target when the entry is a symlink;
// nil for non-symlinks.
SymlinkTarget *string
}
// FileType enumerates the entry kinds reported in EntryInfo.Type. The
// numeric values are stable (they are serialized over the wire).
type FileType int32
const (
UnknownFileType FileType = 0
FileFileType FileType = 1
DirectoryFileType FileType = 2
SymlinkFileType FileType = 3
)

View File

@ -0,0 +1,164 @@
package id
import (
"errors"
"fmt"
"maps"
"regexp"
"slices"
"strings"
"github.com/dchest/uniuri"
"github.com/google/uuid"
)
var (
// caseInsensitiveAlphabet is the character set used for generated IDs:
// lowercase ASCII letters and digits only.
caseInsensitiveAlphabet = []byte("abcdefghijklmnopqrstuvwxyz1234567890")
// identifierRegex validates namespaces and template aliases.
identifierRegex = regexp.MustCompile(`^[a-z0-9-_]+$`)
// tagRegex is identifierRegex plus '.', for version-like tags such as "v1.0".
tagRegex = regexp.MustCompile(`^[a-z0-9-_.]+$`)
// sandboxIDRegex permits only lowercase alphanumerics (no '-', '_' or '.').
sandboxIDRegex = regexp.MustCompile(`^[a-z0-9]+$`)
)
const (
// DefaultTag is the implicit tag; ParseName normalizes ":default" to no tag.
DefaultTag = "default"
// TagSeparator splits "identifier:tag".
TagSeparator = ":"
// NamespaceSeparator splits "namespace/alias".
NamespaceSeparator = "/"
)
// Generate returns a random identifier drawn from the lowercase alphanumeric
// alphabet above. Its length is uniuri.UUIDLen — presumably 32 characters;
// confirm against the uniuri package if the exact length matters.
func Generate() string {
return uniuri.NewLenChars(uniuri.UUIDLen, caseInsensitiveAlphabet)
}
// ValidateSandboxID reports an error unless sandboxID is a non-empty string
// of lowercase ASCII letters and digits only.
func ValidateSandboxID(sandboxID string) error {
	if sandboxIDRegex.MatchString(sandboxID) {
		return nil
	}
	return fmt.Errorf("invalid sandbox ID: %q", sandboxID)
}
func cleanAndValidate(value, name string, re *regexp.Regexp) (string, error) {
cleaned := strings.ToLower(strings.TrimSpace(value))
if !re.MatchString(cleaned) {
return "", fmt.Errorf("invalid %s: %s", name, value)
}
return cleaned, nil
}
// validateTag normalizes tag (trim + lowercase) and validates it against
// tagRegex. Tags that parse as a UUID are rejected outright so a tag can
// never be mistaken for a build identifier.
func validateTag(tag string) (string, error) {
	cleaned, err := cleanAndValidate(tag, "tag", tagRegex)
	if err != nil {
		return "", err
	}
	if _, parseErr := uuid.Parse(cleaned); parseErr == nil {
		return "", errors.New("tag cannot be a UUID")
	}
	return cleaned, nil
}
// ValidateAndDeduplicateTags validates every tag (trim, lowercase, regex,
// non-UUID — see validateTag) and returns the cleaned tags with duplicates
// removed, in sorted order.
//
// Sorting fixes a real defect: the previous `slices.Collect(maps.Keys(seen))`
// returned keys in Go's randomized map-iteration order, so the result changed
// from run to run and could not be compared deterministically by callers or
// tests. The empty input now also yields a non-nil empty slice instead of nil.
func ValidateAndDeduplicateTags(tags []string) ([]string, error) {
	seen := make(map[string]struct{}, len(tags))
	for _, tag := range tags {
		cleanedTag, err := validateTag(tag)
		if err != nil {
			return nil, fmt.Errorf("invalid tag '%s': %w", tag, err)
		}
		seen[cleanedTag] = struct{}{}
	}
	// slices.Sorted consumes the key iterator and returns a sorted slice,
	// making the output order deterministic.
	result := slices.Sorted(maps.Keys(seen))
	if result == nil {
		result = []string{}
	}
	return result, nil
}
// SplitIdentifier splits "namespace/alias" at the first separator.
// A bare alias (no separator) yields a nil namespace; an explicit
// namespace — even an empty one, as in "/alias" — yields a pointer.
func SplitIdentifier(identifier string) (namespace *string, alias string) {
	ns, rest, hasNamespace := strings.Cut(identifier, NamespaceSeparator)
	if hasNamespace {
		return &ns, rest
	}
	return nil, ns
}
// ParseName parses and validates "namespace/alias:tag" or "alias:tag".
// It returns the cleaned identifier ("namespace/alias" or bare "alias") and
// the optional tag. Every component is trimmed, lowercased, and validated;
// a tag equal to DefaultTag is canonicalized to a nil tag.
func ParseName(input string) (identifier string, tag *string, err error) {
	input = strings.TrimSpace(input)

	// Peel off ":tag" first, then split "namespace/" from the alias.
	rawIdentifier, rawTag, hasTag := strings.Cut(input, TagSeparator)
	rawNamespace, rawAlias := SplitIdentifier(rawIdentifier)

	if hasTag {
		cleanedTag, tagErr := cleanAndValidate(rawTag, "tag", tagRegex)
		if tagErr != nil {
			return "", nil, tagErr
		}
		// ":default" means the same as no tag at all.
		if !strings.EqualFold(cleanedTag, DefaultTag) {
			tag = &cleanedTag
		}
	}

	var cleanedNamespace *string
	if rawNamespace != nil {
		ns, nsErr := cleanAndValidate(*rawNamespace, "namespace", identifierRegex)
		if nsErr != nil {
			return "", nil, nsErr
		}
		cleanedNamespace = &ns
	}

	cleanedAlias, aliasErr := cleanAndValidate(rawAlias, "template ID", identifierRegex)
	if aliasErr != nil {
		return "", nil, aliasErr
	}

	if cleanedNamespace != nil {
		return WithNamespace(*cleanedNamespace, cleanedAlias), tag, nil
	}
	return cleanedAlias, tag, nil
}
// WithTag returns the identifier with the given tag appended, e.g.
// "templateID:tag".
func WithTag(identifier, tag string) string {
	return strings.Join([]string{identifier, tag}, TagSeparator)
}
// WithNamespace prefixes alias with its namespace, e.g. "namespace/alias".
func WithNamespace(namespace, alias string) string {
	return fmt.Sprintf("%s%s%s", namespace, NamespaceSeparator, alias)
}
// ExtractAlias returns just the alias portion from an identifier (namespace/alias or alias).
func ExtractAlias(identifier string) string {
_, alias := SplitIdentifier(identifier)
return alias
}
// ValidateNamespaceMatchesTeam ensures that, when identifier carries an
// explicit namespace, it equals the team's slug; identifiers without a
// namespace always pass (nil error).
func ValidateNamespaceMatchesTeam(identifier, teamSlug string) error {
	namespace, _ := SplitIdentifier(identifier)
	if namespace == nil || *namespace == teamSlug {
		return nil
	}
	return fmt.Errorf("namespace '%s' must match your team '%s'", *namespace, teamSlug)
}

View File

@ -0,0 +1,380 @@
package id
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/utils"
)
func TestParseName(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
wantIdentifier string
wantTag *string
wantErr bool
}{
{
name: "bare alias only",
input: "my-template",
wantIdentifier: "my-template",
wantTag: nil,
},
{
name: "alias with tag",
input: "my-template:v1",
wantIdentifier: "my-template",
wantTag: utils.ToPtr("v1"),
},
{
name: "namespace and alias",
input: "acme/my-template",
wantIdentifier: "acme/my-template",
wantTag: nil,
},
{
name: "namespace, alias and tag",
input: "acme/my-template:v1",
wantIdentifier: "acme/my-template",
wantTag: utils.ToPtr("v1"),
},
{
name: "namespace with hyphens",
input: "my-team/my-template:prod",
wantIdentifier: "my-team/my-template",
wantTag: utils.ToPtr("prod"),
},
{
name: "default tag normalized to nil",
input: "my-template:default",
wantIdentifier: "my-template",
wantTag: nil,
},
{
name: "uppercase converted to lowercase",
input: "MyTemplate:Prod",
wantIdentifier: "mytemplate",
wantTag: utils.ToPtr("prod"),
},
{
name: "whitespace trimmed",
input: " my-template : v1 ",
wantIdentifier: "my-template",
wantTag: utils.ToPtr("v1"),
},
{
name: "invalid - empty namespace",
input: "/my-template",
wantErr: true,
},
{
name: "invalid - empty tag after colon",
input: "my-template:",
wantErr: true,
},
{
name: "invalid - special characters in alias",
input: "my template!",
wantErr: true,
},
{
name: "invalid - special characters in namespace",
input: "my team!/my-template",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
gotIdentifier, gotTag, err := ParseName(tt.input)
if tt.wantErr {
require.Error(t, err, "Expected ParseName() to return error, got")
return
}
require.NoError(t, err, "Expected ParseName() not to return error, got: %v", err)
assert.Equal(t, tt.wantIdentifier, gotIdentifier, "ParseName() identifier = %v, want %v", gotIdentifier, tt.wantIdentifier)
assert.Equal(t, tt.wantTag, gotTag, "ParseName() tag = %v, want %v", utils.Sprintp(gotTag), utils.Sprintp(tt.wantTag))
})
}
}
func TestWithNamespace(t *testing.T) {
t.Parallel()
got := WithNamespace("acme", "my-template")
want := "acme/my-template"
assert.Equal(t, want, got, "WithNamespace() = %q, want %q", got, want)
}
// TestSplitIdentifier verifies that SplitIdentifier separates an optional
// "<namespace>/" prefix from the alias, splitting on the first slash only.
func TestSplitIdentifier(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name          string
		identifier    string
		wantNamespace *string
		wantAlias     string
	}{
		{
			name:          "bare alias",
			identifier:    "my-template",
			wantNamespace: nil,
			wantAlias:     "my-template",
		},
		{
			name:          "with namespace",
			identifier:    "acme/my-template",
			wantNamespace: ptrStr("acme"),
			wantAlias:     "my-template",
		},
		{
			// A leading slash yields an empty but non-nil namespace.
			name:          "empty namespace prefix",
			identifier:    "/my-template",
			wantNamespace: ptrStr(""),
			wantAlias:     "my-template",
		},
		{
			// Only the first slash splits; the remainder stays in the alias.
			name:          "multiple slashes - only first split",
			identifier:    "a/b/c",
			wantNamespace: ptrStr("a"),
			wantAlias:     "b/c",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			gotNamespace, gotAlias := SplitIdentifier(tt.identifier)
			if tt.wantNamespace == nil {
				assert.Nil(t, gotNamespace)
			} else {
				require.NotNil(t, gotNamespace)
				assert.Equal(t, *tt.wantNamespace, *gotNamespace)
			}
			assert.Equal(t, tt.wantAlias, gotAlias)
		})
	}
}
// ptrStr returns a pointer to its argument; test helper for building
// *string expectations in table-driven cases.
func ptrStr(s string) *string {
	out := s
	return &out
}
// TestValidateAndDeduplicateTags covers tag syntax validation and
// deduplication of the input list (the "duplicate" case mixes letter case,
// so dedup appears to be case-insensitive — behavior pinned here).
func TestValidateAndDeduplicateTags(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name    string
		tags    []string
		want    []string
		wantErr bool
	}{
		{
			name:    "single valid tag",
			tags:    []string{"v1"},
			want:    []string{"v1"},
			wantErr: false,
		},
		{
			name:    "multiple unique tags",
			tags:    []string{"v1", "prod", "latest"},
			want:    []string{"v1", "prod", "latest"},
			wantErr: false,
		},
		{
			name:    "duplicate tags deduplicated",
			tags:    []string{"v1", "V1", "v1"},
			want:    []string{"v1"},
			wantErr: false,
		},
		{
			name:    "tags with dots and underscores",
			tags:    []string{"v1.0", "v1_1"},
			want:    []string{"v1.0", "v1_1"},
			wantErr: false,
		},
		{
			// UUID-shaped tags are rejected (presumably reserved for build IDs
			// — confirm against the implementation).
			name:    "invalid - UUID tag rejected",
			tags:    []string{"550e8400-e29b-41d4-a716-446655440000"},
			wantErr: true,
		},
		{
			name:    "invalid - special characters",
			tags:    []string{"v1!", "v2@"},
			wantErr: true,
		},
		{
			name:    "empty list returns empty",
			tags:    []string{},
			want:    []string{},
			wantErr: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			got, err := ValidateAndDeduplicateTags(tt.tags)
			if tt.wantErr {
				require.Error(t, err)
				return
			}
			require.NoError(t, err)
			// Order-insensitive comparison: only membership is asserted.
			assert.ElementsMatch(t, tt.want, got)
		})
	}
}
// TestValidateSandboxID accepts lowercase-alphanumeric IDs and rejects
// anything containing separators, braces, whitespace, control characters,
// uppercase letters, or symbols. The Redis-related cases guard against IDs
// that would break key construction (":" separators, "{}" hash slots).
func TestValidateSandboxID(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name    string
		input   string
		wantErr bool
	}{
		{
			name:    "canonical sandbox ID",
			input:   "i1a2b3c4d5e6f7g8h9j0k",
			wantErr: false,
		},
		{
			name:    "short alphanumeric",
			input:   "abc123",
			wantErr: false,
		},
		{
			name:    "all digits",
			input:   "1234567890",
			wantErr: false,
		},
		{
			name:    "all lowercase letters",
			input:   "abcdefghijklmnopqrst",
			wantErr: false,
		},
		{
			name:    "invalid - empty",
			input:   "",
			wantErr: true,
		},
		{
			name:    "invalid - contains colon (Redis separator)",
			input:   "abc:def",
			wantErr: true,
		},
		{
			name:    "invalid - contains open brace (Redis hash slot)",
			input:   "abc{def",
			wantErr: true,
		},
		{
			name:    "invalid - contains close brace (Redis hash slot)",
			input:   "abc}def",
			wantErr: true,
		},
		{
			name:    "invalid - contains newline",
			input:   "abc\ndef",
			wantErr: true,
		},
		{
			name:    "invalid - contains space",
			input:   "abc def",
			wantErr: true,
		},
		{
			name:    "invalid - contains hyphen",
			input:   "abc-def",
			wantErr: true,
		},
		{
			name:    "invalid - contains uppercase",
			input:   "abcDEF",
			wantErr: true,
		},
		{
			name:    "invalid - contains slash",
			input:   "abc/def",
			wantErr: true,
		},
		{
			name:    "invalid - contains null byte",
			input:   "abc\x00def",
			wantErr: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			err := ValidateSandboxID(tt.input)
			if tt.wantErr {
				require.Error(t, err)
			} else {
				require.NoError(t, err)
			}
		})
	}
}
// TestValidateNamespaceMatchesTeam checks that a namespaced identifier's
// prefix must equal the team slug; identifiers with no namespace pass.
func TestValidateNamespaceMatchesTeam(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name       string
		identifier string
		teamSlug   string
		wantErr    bool
	}{
		{
			name:       "bare alias - no namespace",
			identifier: "my-template",
			teamSlug:   "acme",
			wantErr:    false,
		},
		{
			name:       "matching namespace",
			identifier: "acme/my-template",
			teamSlug:   "acme",
			wantErr:    false,
		},
		{
			name:       "mismatched namespace",
			identifier: "other-team/my-template",
			teamSlug:   "acme",
			wantErr:    true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			err := ValidateNamespaceMatchesTeam(tt.identifier, tt.teamSlug)
			if tt.wantErr {
				require.Error(t, err)
			} else {
				require.NoError(t, err)
			}
		})
	}
}

View File

@ -0,0 +1,6 @@
package keys
// Prefixes prepended to generated credentials so their kind can be told
// apart from the raw value alone.
const (
	// ApiKeyPrefix marks API keys ("wrn_<hex>").
	ApiKeyPrefix = "wrn_"
	// AccessTokenPrefix marks access tokens ("sk_wrn_<hex>").
	AccessTokenPrefix = "sk_wrn_"
)

View File

@ -0,0 +1,5 @@
package keys
// Hasher produces a string digest of a raw key.
//
// NOTE(review): HMACSha256Hashing.Hash returns (string, error) and therefore
// does NOT satisfy this interface; only Sha256Hashing does. Confirm that is
// intentional.
type Hasher interface {
	Hash(key []byte) string
}

View File

@ -0,0 +1,25 @@
package keys
import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
)
// HMACSha256Hashing computes hex-encoded HMAC-SHA256 digests under a fixed
// secret key.
type HMACSha256Hashing struct {
	key []byte
}

// NewHMACSHA256Hashing returns a hasher that authenticates content with the
// given secret key.
func NewHMACSHA256Hashing(key []byte) *HMACSha256Hashing {
	return &HMACSha256Hashing{key: key}
}

// Hash returns the hex-encoded HMAC-SHA256 of content under the hasher's key.
func (h *HMACSha256Hashing) Hash(content []byte) (string, error) {
	digest := hmac.New(sha256.New, h.key)
	if _, err := digest.Write(content); err != nil {
		return "", err
	}
	return hex.EncodeToString(digest.Sum(nil)), nil
}

View File

@ -0,0 +1,74 @@
package keys
import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"testing"
"github.com/stretchr/testify/require"
)
// TestHMACSha256Hashing_ValidHash pins the digest of a known key/content
// pair so accidental algorithm or encoding changes are caught.
func TestHMACSha256Hashing_ValidHash(t *testing.T) {
	t.Parallel()

	key := []byte("test-key")
	hasher := NewHMACSHA256Hashing(key)
	content := []byte("hello world")
	expectedHash := "18c4b268f0bbf8471eda56af3e70b1d4613d734dc538b4940b59931c412a1591"

	actualHash, err := hasher.Hash(content)
	require.NoError(t, err)
	// Consistent testify assertion instead of the hand-rolled if/t.Errorf
	// comparison used previously.
	require.Equal(t, expectedHash, actualHash)
}
// TestHMACSha256Hashing_EmptyContent pins the digest of empty content —
// HMAC of "" under a key is still a defined, non-empty value.
func TestHMACSha256Hashing_EmptyContent(t *testing.T) {
	t.Parallel()

	key := []byte("test-key")
	hasher := NewHMACSHA256Hashing(key)
	content := []byte("")
	expectedHash := "2711cc23e9ab1b8a9bc0fe991238da92671624a9ebdaf1c1abec06e7e9a14f9b"

	actualHash, err := hasher.Hash(content)
	require.NoError(t, err)
	// Consistent testify assertion instead of the hand-rolled if/t.Errorf
	// comparison used previously.
	require.Equal(t, expectedHash, actualHash)
}
// TestHMACSha256Hashing_DifferentKey ensures digests are keyed: the same
// content hashed under different keys must not collide.
func TestHMACSha256Hashing_DifferentKey(t *testing.T) {
	t.Parallel()

	key := []byte("test-key")
	hasher := NewHMACSHA256Hashing(key)
	differentKeyHasher := NewHMACSHA256Hashing([]byte("different-key"))
	content := []byte("hello world")

	hashWithOriginalKey, err := hasher.Hash(content)
	require.NoError(t, err)
	hashWithDifferentKey, err := differentKeyHasher.Hash(content)
	require.NoError(t, err)

	// Consistent testify assertion instead of the hand-rolled if/t.Errorf
	// comparison used previously.
	require.NotEqual(t, hashWithOriginalKey, hashWithDifferentKey, "hashes with different keys should not match")
}
// TestHMACSha256Hashing_IdenticalResult cross-checks the hasher against a
// digest computed directly with the standard library.
func TestHMACSha256Hashing_IdenticalResult(t *testing.T) {
	t.Parallel()

	key := []byte("placeholder-hashing-key")
	content := []byte("test content for hashing")

	mac := hmac.New(sha256.New, key)
	// hash.Hash.Write is documented to never return an error; discard
	// explicitly instead of silently ignoring the return values.
	_, _ = mac.Write(content)
	expectedResult := hex.EncodeToString(mac.Sum(nil))

	hasher := NewHMACSHA256Hashing(key)
	actualResult, err := hasher.Hash(content)
	require.NoError(t, err)
	// Consistent testify assertion instead of the hand-rolled if/t.Errorf
	// comparison used previously.
	require.Equal(t, expectedResult, actualResult)
}

View File

@ -0,0 +1,99 @@
package keys
import (
"crypto/rand"
"encoding/hex"
"fmt"
"strings"
)
const (
	// identifierValueSuffixLength is how many trailing characters of a key
	// value remain visible in its mask.
	identifierValueSuffixLength = 4
	// identifierValuePrefixLength is how many leading characters remain visible.
	identifierValuePrefixLength = 2
	// keyLength is the number of random bytes generated per key; the
	// hex-encoded value is twice as long.
	keyLength = 20
)

// hasher is the package-wide digest used for generated keys.
var hasher Hasher = NewSHA256Hashing()

// Key bundles the raw (prefixed) key handed to the caller, its stored hash,
// and the masked form safe to display.
type Key struct {
	PrefixedRawValue string
	HashedValue      string
	Masked           MaskedIdentifier
}

// MaskedIdentifier describes a partially revealed key: its prefix, total
// value length, and the visible leading/trailing fragments.
type MaskedIdentifier struct {
	Prefix            string
	ValueLength       int
	MaskedValuePrefix string
	MaskedValueSuffix string
}
// MaskKey returns identifier masking properties in accordance to the OpenAPI
// response spec: the first identifierValuePrefixLength and the last
// identifierValueSuffixLength characters of value stay visible.
func MaskKey(prefix, value string) (MaskedIdentifier, error) {
	total := len(value)
	suffixStart := total - identifierValueSuffixLength

	switch {
	case suffixStart < 0:
		return MaskedIdentifier{}, fmt.Errorf("mask value length is less than identifier suffix length (%d)", identifierValueSuffixLength)
	case suffixStart == 0:
		return MaskedIdentifier{}, fmt.Errorf("mask value length is equal to identifier suffix length (%d), which would expose the entire identifier in the mask", identifierValueSuffixLength)
	}

	// Never let the visible prefix run into the visible suffix.
	prefixEnd := min(identifierValuePrefixLength, suffixStart)

	return MaskedIdentifier{
		Prefix:            prefix,
		ValueLength:       total,
		MaskedValuePrefix: value[:prefixEnd],
		MaskedValueSuffix: value[suffixStart:],
	}, nil
}
// GenerateKey creates a new random key: keyLength bytes from crypto/rand,
// hex-encoded, prefixed with prefix for the raw value, hashed (over the raw
// bytes, not the hex string) for storage, and masked for display.
func GenerateKey(prefix string) (Key, error) {
	keyBytes := make([]byte, keyLength)
	_, err := rand.Read(keyBytes)
	if err != nil {
		return Key{}, err
	}

	generatedIdentifier := hex.EncodeToString(keyBytes)

	mask, err := MaskKey(prefix, generatedIdentifier)
	if err != nil {
		return Key{}, err
	}

	return Key{
		PrefixedRawValue: prefix + generatedIdentifier,
		HashedValue:      hasher.Hash(keyBytes),
		Masked:           mask,
	}, nil
}
// VerifyKey strips prefix from key, hex-decodes the remainder, and returns
// the hash a stored key would have been persisted under. The returned errors
// deliberately reveal nothing about the key material.
func VerifyKey(prefix string, key string) (string, error) {
	keyValue, ok := strings.CutPrefix(key, prefix)
	if !ok {
		return "", fmt.Errorf("invalid key prefix")
	}

	keyBytes, err := hex.DecodeString(keyValue)
	if err != nil {
		return "", fmt.Errorf("invalid key")
	}

	return hasher.Hash(keyBytes), nil
}

View File

@ -0,0 +1,160 @@
package keys
import (
"fmt"
"strconv"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestMaskKey covers the success and error paths of MaskKey at the value
// length boundaries (shorter than, equal to, and longer than the visible
// suffix length).
func TestMaskKey(t *testing.T) {
	t.Parallel()

	t.Run("succeeds: value longer than suffix length", func(t *testing.T) {
		t.Parallel()
		masked, err := MaskKey("test_", "1234567890")
		require.NoError(t, err)
		assert.Equal(t, "test_", masked.Prefix)
		assert.Equal(t, "12", masked.MaskedValuePrefix)
		assert.Equal(t, "7890", masked.MaskedValueSuffix)
	})

	t.Run("succeeds: empty prefix, value longer than suffix length", func(t *testing.T) {
		t.Parallel()
		masked, err := MaskKey("", "1234567890")
		require.NoError(t, err)
		assert.Empty(t, masked.Prefix)
		assert.Equal(t, "12", masked.MaskedValuePrefix)
		assert.Equal(t, "7890", masked.MaskedValueSuffix)
	})

	t.Run("error: value length less than suffix length", func(t *testing.T) {
		t.Parallel()
		_, err := MaskKey("test", "123")
		require.Error(t, err)
		assert.EqualError(t, err, fmt.Sprintf("mask value length is less than identifier suffix length (%d)", identifierValueSuffixLength))
	})

	t.Run("error: value length equals suffix length", func(t *testing.T) {
		t.Parallel()
		_, err := MaskKey("test", "1234")
		require.Error(t, err)
		assert.EqualError(t, err, fmt.Sprintf("mask value length is equal to identifier suffix length (%d), which would expose the entire identifier in the mask", identifierValueSuffixLength))
	})
}
// TestGenerateKey checks the shape of generated keys: the raw value carries
// the prefix, the mask exposes only the configured prefix/suffix fragments,
// and the stored hash uses the "$sha256$" format.
func TestGenerateKey(t *testing.T) {
	t.Parallel()

	// Local keyLength is the hex-encoded value length: 20 random bytes
	// encode to 40 hex characters (intentionally shadows the package const).
	keyLength := 40

	t.Run("succeeds", func(t *testing.T) {
		t.Parallel()
		key, err := GenerateKey("test_")
		require.NoError(t, err)
		assert.Regexp(t, "^test_.*", key.PrefixedRawValue)
		assert.Equal(t, "test_", key.Masked.Prefix)
		assert.Equal(t, keyLength, key.Masked.ValueLength)
		assert.Regexp(t, "^[0-9a-f]{"+strconv.Itoa(identifierValuePrefixLength)+"}$", key.Masked.MaskedValuePrefix)
		assert.Regexp(t, "^[0-9a-f]{"+strconv.Itoa(identifierValueSuffixLength)+"}$", key.Masked.MaskedValueSuffix)
		assert.Regexp(t, "^\\$sha256\\$.*", key.HashedValue)
	})

	t.Run("no prefix", func(t *testing.T) {
		t.Parallel()
		key, err := GenerateKey("")
		require.NoError(t, err)
		assert.Regexp(t, "^[0-9a-f]{"+strconv.Itoa(keyLength)+"}$", key.PrefixedRawValue)
		assert.Empty(t, key.Masked.Prefix)
		assert.Equal(t, keyLength, key.Masked.ValueLength)
		assert.Regexp(t, "^[0-9a-f]{"+strconv.Itoa(identifierValuePrefixLength)+"}$", key.Masked.MaskedValuePrefix)
		assert.Regexp(t, "^[0-9a-f]{"+strconv.Itoa(identifierValueSuffixLength)+"}$", key.Masked.MaskedValueSuffix)
		assert.Regexp(t, "^\\$sha256\\$.*", key.HashedValue)
	})
}
// TestGetMaskedIdentifierProperties is a table-driven sweep of MaskKey
// error and success cases, including prefix/suffix overlap handling.
//
// NOTE(review): the test name references an older function name; it
// exercises MaskKey.
func TestGetMaskedIdentifierProperties(t *testing.T) {
	t.Parallel()

	type testCase struct {
		name              string
		prefix            string
		value             string
		expectedResult    MaskedIdentifier
		expectedErrString string
	}

	testCases := []testCase{
		// --- ERROR CASES (value's length <= identifierValueSuffixLength) ---
		{
			name:              "error: value length < suffix length (3 vs 4)",
			prefix:            "pk_",
			value:             "abc",
			expectedResult:    MaskedIdentifier{},
			expectedErrString: fmt.Sprintf("mask value length is less than identifier suffix length (%d)", identifierValueSuffixLength),
		},
		{
			name:              "error: value length == suffix length (4 vs 4)",
			prefix:            "sk_",
			value:             "abcd",
			expectedResult:    MaskedIdentifier{},
			expectedErrString: fmt.Sprintf("mask value length is equal to identifier suffix length (%d), which would expose the entire identifier in the mask", identifierValueSuffixLength),
		},
		{
			name:              "error: value length < suffix length (0 vs 4, empty value)",
			prefix:            "err_",
			value:             "",
			expectedResult:    MaskedIdentifier{},
			expectedErrString: fmt.Sprintf("mask value length is less than identifier suffix length (%d)", identifierValueSuffixLength),
		},
		// --- SUCCESS CASES (value's length > identifierValueSuffixLength) ---
		{
			name:   "success: value long (10), prefix val len fully used",
			prefix: "pk_",
			value:  "abcdefghij",
			expectedResult: MaskedIdentifier{
				Prefix:            "pk_",
				ValueLength:       10,
				MaskedValuePrefix: "ab",
				MaskedValueSuffix: "ghij",
			},
		},
		{
			// Visible prefix is capped so it cannot overlap the suffix.
			name:   "success: value medium (5), prefix val len truncated by overlap",
			prefix: "",
			value:  "abcde",
			expectedResult: MaskedIdentifier{
				Prefix:            "",
				ValueLength:       5,
				MaskedValuePrefix: "a",
				MaskedValueSuffix: "bcde",
			},
		},
		{
			name:   "success: value medium (6), prefix val len fits exactly",
			prefix: "pk_",
			value:  "abcdef",
			expectedResult: MaskedIdentifier{
				Prefix:            "pk_",
				ValueLength:       6,
				MaskedValuePrefix: "ab",
				MaskedValueSuffix: "cdef",
			},
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			result, err := MaskKey(tc.prefix, tc.value)
			if tc.expectedErrString != "" {
				require.EqualError(t, err, tc.expectedErrString)
				assert.Equal(t, tc.expectedResult, result)
			} else {
				require.NoError(t, err)
				assert.Equal(t, tc.expectedResult, result)
			}
		})
	}
}

View File

@ -0,0 +1,30 @@
package keys
import (
"crypto/sha256"
"encoding/base64"
"fmt"
)
type Sha256Hashing struct{}
func NewSHA256Hashing() *Sha256Hashing {
return &Sha256Hashing{}
}
func (h *Sha256Hashing) Hash(key []byte) string {
hashBytes := sha256.Sum256(key)
hash64 := base64.RawStdEncoding.EncodeToString(hashBytes[:])
return fmt.Sprintf(
"$sha256$%s",
hash64,
)
}
func (h *Sha256Hashing) HashWithoutPrefix(key []byte) string {
hashBytes := sha256.Sum256(key)
return base64.RawStdEncoding.EncodeToString(hashBytes[:])
}

View File

@ -0,0 +1,15 @@
package keys
import (
"testing"
"github.com/stretchr/testify/assert"
)
// TestSHA256Hashing checks that digests carry the "$sha256$" format prefix.
func TestSHA256Hashing(t *testing.T) {
	t.Parallel()

	hashed := NewSHA256Hashing().Hash([]byte("test"))
	assert.Regexp(t, "^\\$sha256\\$.*", hashed)
}

View File

@ -0,0 +1,20 @@
package keys
import (
"crypto/sha512"
"encoding/hex"
)
// HashAccessToken computes the SHA-512 hash of an access token.
func HashAccessToken(token string) string {
h := sha512.Sum512([]byte(token))
return hex.EncodeToString(h[:])
}
// HashAccessTokenBytes computes the SHA-512 hash of an access token from bytes.
func HashAccessTokenBytes(token []byte) string {
h := sha512.Sum512(token)
return hex.EncodeToString(h[:])
}

View File

@ -0,0 +1,47 @@
package smap
import (
cmap "github.com/orcaman/concurrent-map/v2"
)
// Map is a thin generic wrapper around orcaman/concurrent-map, exposing a
// smaller method set under this codebase's naming.
type Map[V any] struct {
	m cmap.ConcurrentMap[string, V]
}

// New returns an empty concurrent string-keyed map.
func New[V any]() *Map[V] {
	return &Map[V]{
		m: cmap.New[V](),
	}
}

// Remove deletes key if present.
func (m *Map[V]) Remove(key string) {
	m.m.Remove(key)
}

// Get returns the value stored under key, if any.
func (m *Map[V]) Get(key string) (V, bool) {
	return m.m.Get(key)
}

// Insert sets the value for key, overwriting any existing entry.
func (m *Map[V]) Insert(key string, value V) {
	m.m.Set(key, value)
}

// Upsert delegates to cmap.Upsert: cb merges the new value with any
// existing one; the result is stored and returned.
func (m *Map[V]) Upsert(key string, value V, cb cmap.UpsertCb[V]) V {
	return m.m.Upsert(key, value, cb)
}

// InsertIfAbsent stores value only when key is missing; reports whether it
// stored.
func (m *Map[V]) InsertIfAbsent(key string, value V) bool {
	return m.m.SetIfAbsent(key, value)
}

// Items returns all entries as a plain map (see cmap.Items for copy
// semantics).
func (m *Map[V]) Items() map[string]V {
	return m.m.Items()
}

// RemoveCb delegates conditional removal to cmap.RemoveCb; reports whether
// an entry was removed.
func (m *Map[V]) RemoveCb(key string, cb func(key string, v V, exists bool) bool) bool {
	return m.m.RemoveCb(key, cb)
}

// Count returns the number of entries.
func (m *Map[V]) Count() int {
	return m.m.Count()
}

View File

@ -0,0 +1,43 @@
package utils
import "fmt"
func ToPtr[T any](v T) *T {
return &v
}
func FromPtr[T any](s *T) T {
if s == nil {
var zero T
return zero
}
return *s
}
func Sprintp[T any](s *T) string {
if s == nil {
return "<nil>"
}
return fmt.Sprintf("%v", *s)
}
func DerefOrDefault[T any](s *T, defaultValue T) T {
if s == nil {
return defaultValue
}
return *s
}
func CastPtr[S any, T any](s *S, castFunc func(S) T) *T {
if s == nil {
return nil
}
t := castFunc(*s)
return &t
}

View File

@ -0,0 +1,27 @@
package utils
import (
"sync"
)
// AtomicMax tracks a running maximum int64, guarded by a mutex.
type AtomicMax struct {
	val int64
	mu  sync.Mutex
}

// NewAtomicMax returns an AtomicMax starting at zero.
func NewAtomicMax() *AtomicMax {
	return &AtomicMax{}
}

// SetToGreater raises the stored value to newValue and reports whether it
// did so. Despite the name, an equal newValue also "succeeds": the guard is
// strictly val > newValue, so newValue == val overwrites (a no-op) and
// returns true.
func (a *AtomicMax) SetToGreater(newValue int64) bool {
	a.mu.Lock()
	defer a.mu.Unlock()

	if a.val > newValue {
		return false
	}

	a.val = newValue
	return true
}

View File

@ -0,0 +1,76 @@
package utils
import (
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestAtomicMax_NewAtomicMax checks the zero initial state of a fresh value.
func TestAtomicMax_NewAtomicMax(t *testing.T) {
	t.Parallel()

	am := NewAtomicMax()
	require.NotNil(t, am)
	require.Equal(t, int64(0), am.val)
}
// TestAtomicMax_SetToGreater_InitialValue verifies updating from the zero
// state.
func TestAtomicMax_SetToGreater_InitialValue(t *testing.T) {
	t.Parallel()

	am := NewAtomicMax()
	// Should succeed when newValue > current
	assert.True(t, am.SetToGreater(10))
	assert.Equal(t, int64(10), am.val)
}
// TestAtomicMax_SetToGreater_EqualValue exercises the boundary its name
// promises: SetToGreater's guard is strictly val > newValue, so an equal
// newValue overwrites (a no-op) and reports true. The previous body passed
// 20 against 10 — a strictly greater value — and never tested equality.
func TestAtomicMax_SetToGreater_EqualValue(t *testing.T) {
	t.Parallel()

	am := NewAtomicMax()
	am.val = 10

	// Equal value still "succeeds" and leaves the stored maximum unchanged.
	assert.True(t, am.SetToGreater(10))
	assert.Equal(t, int64(10), am.val)
}
// TestAtomicMax_SetToGreater_GreaterValue verifies that a smaller value is
// rejected and the stored maximum is kept.
//
// NOTE(review): the name is misleading — this case passes a *smaller* value
// (5 against 10); consider renaming to ...LesserValue.
func TestAtomicMax_SetToGreater_GreaterValue(t *testing.T) {
	t.Parallel()

	am := NewAtomicMax()
	am.val = 10
	// Should fail when newValue < current, keeping the max value
	assert.False(t, am.SetToGreater(5))
	assert.Equal(t, int64(10), am.val)
}
// TestAtomicMax_SetToGreater_NegativeValues checks ordering works below zero
// (-2 > -5 must be accepted).
func TestAtomicMax_SetToGreater_NegativeValues(t *testing.T) {
	t.Parallel()

	am := NewAtomicMax()
	am.val = -5
	assert.True(t, am.SetToGreater(-2))
	assert.Equal(t, int64(-2), am.val)
}
// TestAtomicMax_SetToGreater_Concurrent hammers SetToGreater from many
// goroutines with values 0..99; the mutex must make the final value the
// true maximum regardless of interleaving.
func TestAtomicMax_SetToGreater_Concurrent(t *testing.T) {
	t.Parallel()

	am := NewAtomicMax()
	var wg sync.WaitGroup

	// Run 100 goroutines trying to update the value concurrently
	numGoroutines := 100
	wg.Add(numGoroutines)

	for i := range numGoroutines {
		go func(val int64) {
			defer wg.Done()
			am.SetToGreater(val)
		}(int64(i))
	}

	wg.Wait()

	// The final value should be 99 (the maximum value)
	assert.Equal(t, int64(99), am.val)
}

View File

@ -0,0 +1,51 @@
package utils
import "sync"
// Map is a type-safe generic facade over sync.Map.
type Map[K comparable, V any] struct {
	m sync.Map
}

// NewMap returns an empty concurrent map.
func NewMap[K comparable, V any]() *Map[K, V] {
	return &Map[K, V]{}
}

// Delete removes key if present.
func (m *Map[K, V]) Delete(key K) {
	m.m.Delete(key)
}

// Load returns the value stored under key, if any; the zero value and
// ok=false otherwise.
func (m *Map[K, V]) Load(key K) (value V, ok bool) {
	raw, ok := m.m.Load(key)
	if ok {
		value = raw.(V)
	}
	return value, ok
}

// LoadAndDelete atomically removes key and returns what was stored there.
func (m *Map[K, V]) LoadAndDelete(key K) (value V, loaded bool) {
	raw, loaded := m.m.LoadAndDelete(key)
	if loaded {
		value = raw.(V)
	}
	return value, loaded
}

// LoadOrStore returns the existing value for key if present; otherwise it
// stores and returns value. loaded reports whether the value pre-existed.
func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) {
	raw, loaded := m.m.LoadOrStore(key, value)
	return raw.(V), loaded
}

// Range calls f for each entry until f returns false.
func (m *Map[K, V]) Range(f func(key K, value V) bool) {
	m.m.Range(func(k, v any) bool {
		return f(k.(K), v.(V))
	})
}

// Store sets the value for key.
func (m *Map[K, V]) Store(key K, value V) {
	m.m.Store(key, value)
}

View File

@ -0,0 +1,43 @@
package utils
import (
"errors"
"mime"
"mime/multipart"
)
// CustomPart is a wrapper around multipart.Part that overloads the FileName method
type CustomPart struct {
*multipart.Part
}
// FileNameWithPath returns the filename parameter of the Part's Content-Disposition header.
// This method borrows from the original FileName method implementation but returns the full
// filename without using `filepath.Base`.
func (p *CustomPart) FileNameWithPath() (string, error) {
dispositionParams, err := p.parseContentDisposition()
if err != nil {
return "", err
}
filename, ok := dispositionParams["filename"]
if !ok {
return "", errors.New("filename not found in Content-Disposition header")
}
return filename, nil
}
func (p *CustomPart) parseContentDisposition() (map[string]string, error) {
v := p.Header.Get("Content-Disposition")
_, dispositionParams, err := mime.ParseMediaType(v)
if err != nil {
return nil, err
}
return dispositionParams, nil
}
// NewCustomPart creates a new CustomPart from a multipart.Part
func NewCustomPart(part *multipart.Part) *CustomPart {
return &CustomPart{Part: part}
}

View File

@ -0,0 +1,12 @@
package utils
import "path/filepath"
// FsnotifyPath creates an optionally recursive path for fsnotify/fsnotify internal implementation
func FsnotifyPath(path string, recursive bool) string {
if recursive {
return filepath.Join(path, "...")
}
return path
}

View File

@ -0,0 +1,290 @@
package main
import (
"context"
"flag"
"fmt"
"log"
"net/http"
"os"
"path/filepath"
"strconv"
"time"
"connectrpc.com/authn"
connectcors "connectrpc.com/cors"
"github.com/go-chi/chi/v5"
"github.com/rs/cors"
"git.omukk.dev/wrenn/sandbox/envd/internal/api"
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
"git.omukk.dev/wrenn/sandbox/envd/internal/host"
"git.omukk.dev/wrenn/sandbox/envd/internal/logs"
"git.omukk.dev/wrenn/sandbox/envd/internal/permissions"
publicport "git.omukk.dev/wrenn/sandbox/envd/internal/port"
"git.omukk.dev/wrenn/sandbox/envd/internal/services/cgroups"
filesystemRpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/filesystem"
processRpc "git.omukk.dev/wrenn/sandbox/envd/internal/services/process"
processSpec "git.omukk.dev/wrenn/sandbox/envd/internal/services/spec/process"
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
)
const (
	// Downstream timeout should be greater than upstream (in orchestrator proxy).
	idleTimeout = 640 * time.Second

	// maxAge caps how long browsers may cache CORS preflight responses.
	maxAge = 2 * time.Hour

	// defaultPort is the port envd listens on unless -port overrides it.
	defaultPort = 49983

	// portScannerInterval is how often open ports are rescanned for forwarding.
	portScannerInterval = 1000 * time.Millisecond

	// This is the default user used in the container if not specified otherwise.
	// It should be always overridden by the user in /init when building the template.
	defaultUser = "root"

	kilobyte = 1024
	megabyte = 1024 * kilobyte
)
var (
	// Version is the envd release version reported by -version.
	Version = "0.5.4"
	// commitSHA is injected at build time via -ldflags "-X=main.commitSHA=...".
	commitSHA string

	// Command-line flags; populated by parseFlags.
	isNotFC      bool
	port         int64
	versionFlag  bool
	commitFlag   bool
	startCmdFlag string
	cgroupRoot   string
)
// parseFlags registers and parses the command-line flags into the
// package-level flag variables above.
func parseFlags() {
	flag.BoolVar(
		&isNotFC,
		"isnotfc",
		false,
		"isNotFCmode prints all logs to stdout",
	)
	flag.BoolVar(
		&versionFlag,
		"version",
		false,
		"print envd version",
	)
	flag.BoolVar(
		&commitFlag,
		"commit",
		false,
		"print envd source commit",
	)
	flag.Int64Var(
		&port,
		"port",
		defaultPort,
		"a port on which the daemon should run",
	)
	flag.StringVar(
		&startCmdFlag,
		"cmd",
		"",
		"a command to run on the daemon start",
	)
	flag.StringVar(
		&cgroupRoot,
		"cgroup-root",
		"/sys/fs/cgroup",
		"cgroup root directory",
	)

	flag.Parse()
}
// withCORS wraps h with a permissive CORS policy: any origin, the common
// HTTP methods, all request headers, and the Connect RPC exposed headers
// plus a few extras; preflight results may be cached for maxAge.
func withCORS(h http.Handler) http.Handler {
	middleware := cors.New(cors.Options{
		AllowedOrigins: []string{"*"},
		AllowedMethods: []string{
			http.MethodHead,
			http.MethodGet,
			http.MethodPost,
			http.MethodPut,
			http.MethodPatch,
			http.MethodDelete,
		},
		AllowedHeaders: []string{"*"},
		ExposedHeaders: append(
			connectcors.ExposedHeaders(),
			"Location",
			"Cache-Control",
			"X-Content-Type-Options",
		),
		MaxAge: int(maxAge.Seconds()),
	})
	return middleware.Handler(h)
}
// main wires up the envd daemon: flag handling, MMDS polling (when running
// inside Firecracker), logging, the filesystem/process Connect services, the
// authenticated HTTP API with CORS, an optional start command, and port
// forwarding. It blocks in ListenAndServe.
func main() {
	parseFlags()

	// -version / -commit print and exit without starting the daemon.
	if versionFlag {
		fmt.Printf("%s\n", Version)
		return
	}

	if commitFlag {
		fmt.Printf("%s\n", commitSHA)
		return
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Best-effort: a failed run-dir creation is reported but not fatal.
	if err := os.MkdirAll(host.WrennRunDir, 0o755); err != nil {
		fmt.Fprintf(os.Stderr, "error creating wrenn run directory: %v\n", err)
	}

	defaults := &execcontext.Defaults{
		User:    defaultUser,
		EnvVars: utils.NewMap[string, string](),
	}

	// Expose whether we run inside a Firecracker VM both as a default env
	// var and as a read-only marker file under the run dir.
	isFCBoolStr := strconv.FormatBool(!isNotFC)
	defaults.EnvVars.Store("WRENN_SANDBOX", isFCBoolStr)
	if err := os.WriteFile(filepath.Join(host.WrennRunDir, ".WRENN_SANDBOX"), []byte(isFCBoolStr), 0o444); err != nil {
		fmt.Fprintf(os.Stderr, "error writing sandbox file: %v\n", err)
	}

	mmdsChan := make(chan *host.MMDSOpts, 1)
	// NOTE(review): defers run LIFO, so close(mmdsChan) executes before
	// cancel(); confirm PollForMMDSOpts cannot send on the closed channel.
	// (In practice main only exits via log.Fatalf, which skips defers.)
	defer close(mmdsChan)
	if !isNotFC {
		go host.PollForMMDSOpts(ctx, mmdsChan, defaults.EnvVars)
	}

	l := logs.NewLogger(ctx, isNotFC, mmdsChan)

	m := chi.NewRouter()

	envLogger := l.With().Str("logger", "envd").Logger()

	// Mount the filesystem Connect service on the router.
	fsLogger := l.With().Str("logger", "filesystem").Logger()
	filesystemRpc.Handle(m, &fsLogger, defaults)

	cgroupManager := createCgroupManager()
	defer func() {
		err := cgroupManager.Close()
		if err != nil {
			fmt.Fprintf(os.Stderr, "failed to close cgroup manager: %v\n", err)
		}
	}()

	// Mount the process Connect service; the returned service is reused
	// below for the optional start command.
	processLogger := l.With().Str("logger", "process").Logger()
	processService := processRpc.Handle(m, &processLogger, defaults, cgroupManager)

	// HTTP API handler stack: API mux -> authn middleware -> authorization
	// wrapper -> CORS.
	service := api.New(&envLogger, defaults, mmdsChan, isNotFC)
	handler := api.HandlerFromMux(service, m)

	middleware := authn.NewMiddleware(permissions.AuthenticateUsername)

	s := &http.Server{
		Handler: withCORS(
			service.WithAuthorization(
				middleware.Wrap(handler),
			),
		),
		Addr: fmt.Sprintf("0.0.0.0:%d", port),
		// We remove the timeouts as the connection is terminated by closing of the sandbox and keepalive close.
		ReadTimeout:  0,
		WriteTimeout: 0,
		IdleTimeout:  idleTimeout,
	}

	// TODO: Not used anymore in template build, replaced by direct envd command call.
	if startCmdFlag != "" {
		tag := "startCmd"
		cwd := "/home/user"
		user, err := permissions.GetUser("root")
		if err != nil {
			log.Fatalf("error getting user: %v", err) //nolint:gocritic // probably fine to bail if we're done?
		}

		if err = processService.InitializeStartProcess(ctx, user, &processSpec.StartRequest{
			Tag: &tag,
			Process: &processSpec.ProcessConfig{
				Envs: make(map[string]string),
				Cmd:  "/bin/bash",
				Args: []string{"-l", "-c", startCmdFlag},
				Cwd:  &cwd,
			},
		}); err != nil {
			log.Fatalf("error starting process: %v", err)
		}
	}

	// Bind all open ports on 127.0.0.1 and localhost to the eth0 interface
	portScanner := publicport.NewScanner(portScannerInterval)
	defer portScanner.Destroy()

	portLogger := l.With().Str("logger", "port-forwarder").Logger()
	portForwarder := publicport.NewForwarder(&portLogger, portScanner, cgroupManager)
	go portForwarder.StartForwarding(ctx)

	go portScanner.ScanAndBroadcast()

	err := s.ListenAndServe()
	if err != nil {
		log.Fatalf("error starting server: %v", err)
	}
}
// createCgroupManager builds a cgroup2 manager with tuned subtrees for PTYs,
// socat port forwarders, and user processes. On any failure the deferred
// nil-check substitutes a no-op manager so envd still runs without cgroup
// control.
func createCgroupManager() (m cgroups.Manager) {
	defer func() {
		if m == nil {
			fmt.Fprintf(os.Stderr, "falling back to no-op cgroup manager\n")
			m = cgroups.NewNoopManager()
		}
	}()

	metrics, err := host.GetMetrics()
	if err != nil {
		fmt.Fprintf(os.Stderr, "failed to calculate host metrics: %v\n", err)
		return nil
	}

	// try to keep 1/8 of the memory free, but no more than 128 MB
	maxMemoryReserved := uint64(float64(metrics.MemTotal) * .125)
	maxMemoryReserved = min(maxMemoryReserved, uint64(128)*megabyte)

	opts := []cgroups.Cgroup2ManagerOption{
		cgroups.WithCgroup2ProcessType(cgroups.ProcessTypePTY, "ptys", map[string]string{
			"cpu.weight": "200", // gets much preferred cpu access, to help keep these real time
		}),
		cgroups.WithCgroup2ProcessType(cgroups.ProcessTypeSocat, "socats", map[string]string{
			"cpu.weight": "150", // gets slightly preferred cpu access
			"memory.min": fmt.Sprintf("%d", 5*megabyte),
			"memory.low": fmt.Sprintf("%d", 8*megabyte),
		}),
		cgroups.WithCgroup2ProcessType(cgroups.ProcessTypeUser, "user", map[string]string{
			// Throttle user processes before host memory minus the reserve
			// is exhausted.
			"memory.high": fmt.Sprintf("%d", metrics.MemTotal-maxMemoryReserved),
			"cpu.weight":  "50", // less than envd, and less than core processes that default to 100
		}),
	}

	if cgroupRoot != "" {
		opts = append(opts, cgroups.WithCgroup2RootSysFSPath(cgroupRoot))
	}

	mgr, err := cgroups.NewCgroup2Manager(opts...)
	if err != nil {
		fmt.Fprintf(os.Stderr, "failed to create cgroup2 manager: %v\n", err)
		return nil
	}

	return mgr
}

14
envd/spec/buf.gen.yaml Normal file
View File

@ -0,0 +1,14 @@
version: v1
plugins:
- plugin: go
out: ../internal/services/spec
opt: paths=source_relative
- plugin: connect-go
out: ../internal/services/spec
opt: paths=source_relative
managed:
enabled: true
optimize_for: SPEED
go_package_prefix:
default: git.omukk.dev/wrenn/sandbox/envd/internal/services/spec

303
envd/spec/envd.yaml Normal file
View File

@ -0,0 +1,303 @@
openapi: 3.0.0
info:
title: envd
version: 0.1.1
description: API for managing files' content and controlling envd
tags:
- name: files
paths:
/health:
get:
summary: Check the health of the service
responses:
"204":
description: The service is healthy
/metrics:
get:
summary: Get the stats of the service
security:
- AccessTokenAuth: []
- {}
responses:
"200":
description: The resource usage metrics of the service
content:
application/json:
schema:
$ref: "#/components/schemas/Metrics"
/init:
post:
summary: Set initial vars, ensure the time and metadata is synced with the host
security:
- AccessTokenAuth: []
- {}
requestBody:
content:
application/json:
schema:
type: object
properties:
volumeMounts:
type: array
items:
$ref: "#/components/schemas/VolumeMount"
hyperloopIP:
type: string
description: IP address of the hyperloop server to connect to
envVars:
$ref: "#/components/schemas/EnvVars"
accessToken:
type: string
description: Access token for secure access to envd service
x-go-type: SecureToken
timestamp:
type: string
format: date-time
description: The current timestamp in RFC3339 format
defaultUser:
type: string
description: The default user to use for operations
defaultWorkdir:
type: string
description: The default working directory to use for operations
responses:
"204":
description: Env vars set, the time and metadata is synced with the host
/envs:
get:
summary: Get the environment variables
security:
- AccessTokenAuth: []
- {}
responses:
"200":
description: Environment variables
content:
application/json:
schema:
$ref: "#/components/schemas/EnvVars"
/files:
get:
summary: Download a file
tags: [files]
security:
- AccessTokenAuth: []
- {}
parameters:
- $ref: "#/components/parameters/FilePath"
- $ref: "#/components/parameters/User"
- $ref: "#/components/parameters/Signature"
- $ref: "#/components/parameters/SignatureExpiration"
responses:
"200":
$ref: "#/components/responses/DownloadSuccess"
"401":
$ref: "#/components/responses/InvalidUser"
"400":
$ref: "#/components/responses/InvalidPath"
"404":
$ref: "#/components/responses/FileNotFound"
"500":
$ref: "#/components/responses/InternalServerError"
post:
summary: Upload a file and ensure the parent directories exist. If the file exists, it will be overwritten.
tags: [files]
security:
- AccessTokenAuth: []
- {}
parameters:
- $ref: "#/components/parameters/FilePath"
- $ref: "#/components/parameters/User"
- $ref: "#/components/parameters/Signature"
- $ref: "#/components/parameters/SignatureExpiration"
requestBody:
$ref: "#/components/requestBodies/File"
responses:
"200":
$ref: "#/components/responses/UploadSuccess"
"400":
$ref: "#/components/responses/InvalidPath"
"401":
$ref: "#/components/responses/InvalidUser"
"500":
$ref: "#/components/responses/InternalServerError"
"507":
$ref: "#/components/responses/NotEnoughDiskSpace"
components:
securitySchemes:
AccessTokenAuth:
type: apiKey
in: header
name: X-Access-Token
parameters:
FilePath:
name: path
in: query
required: false
description: Path to the file, URL encoded. Can be relative to user's home directory.
schema:
type: string
User:
name: username
in: query
required: false
description: User used for setting the owner, or resolving relative paths.
schema:
type: string
Signature:
name: signature
in: query
required: false
description: Signature used for file access permission verification.
schema:
type: string
SignatureExpiration:
name: signature_expiration
in: query
required: false
description: Signature expiration used for defining the expiration time of the signature.
schema:
type: integer
requestBodies:
File:
required: true
content:
multipart/form-data:
schema:
type: object
properties:
file:
type: string
format: binary
responses:
UploadSuccess:
description: The file was uploaded successfully.
content:
application/json:
schema:
type: array
items:
$ref: "#/components/schemas/EntryInfo"
DownloadSuccess:
description: Entire file downloaded successfully.
content:
application/octet-stream:
schema:
type: string
format: binary
description: The file content
InvalidPath:
description: Invalid path
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
InternalServerError:
description: Internal server error
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
FileNotFound:
description: File not found
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
InvalidUser:
description: Invalid user
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
NotEnoughDiskSpace:
description: Not enough disk space
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
schemas:
Error:
required:
- message
- code
properties:
message:
type: string
description: Error message
code:
type: integer
description: Error code
EntryInfo:
required:
- path
- name
- type
properties:
path:
type: string
description: Path to the file
name:
type: string
description: Name of the file
type:
type: string
description: Type of the file
enum:
- file
EnvVars:
type: object
description: Environment variables to set
additionalProperties:
type: string
Metrics:
type: object
description: Resource usage metrics
properties:
ts:
type: integer
format: int64
description: Unix timestamp in UTC for current sandbox time
cpu_count:
type: integer
description: Number of CPU cores
cpu_used_pct:
type: number
format: float
description: CPU usage percentage
mem_total:
type: integer
description: Total virtual memory in bytes
mem_used:
type: integer
description: Used virtual memory in bytes
disk_used:
type: integer
description: Used disk space in bytes
disk_total:
type: integer
description: Total disk space in bytes
VolumeMount:
type: object
description: Volume
additionalProperties: false
properties:
nfs_target:
type: string
path:
type: string
required:
- nfs_target
- path

View File

@ -0,0 +1,135 @@
syntax = "proto3";
package filesystem;
import "google/protobuf/timestamp.proto";
// Filesystem is the RPC service for sandbox file-system access: metadata
// queries, directory listing, mutation (mkdir/move/remove), and change
// notification in both streaming and polling (watcher-handle) forms.
service Filesystem {
// Stat returns metadata (EntryInfo) for the entry at the given path.
rpc Stat(StatRequest) returns (StatResponse);
// MakeDir creates a directory and returns its EntryInfo.
rpc MakeDir(MakeDirRequest) returns (MakeDirResponse);
// Move renames/moves an entry and returns the EntryInfo at the destination.
rpc Move(MoveRequest) returns (MoveResponse);
// ListDir lists entries under a path (see ListDirRequest.depth).
rpc ListDir(ListDirRequest) returns (ListDirResponse);
// Remove deletes the entry at the given path.
rpc Remove(RemoveRequest) returns (RemoveResponse);
// WatchDir streams file-system events for a path until the client disconnects.
rpc WatchDir(WatchDirRequest) returns (stream WatchDirResponse);
// Non-streaming versions of WatchDir
// CreateWatcher registers a watcher and returns an id for later polling.
rpc CreateWatcher(CreateWatcherRequest) returns (CreateWatcherResponse);
// GetWatcherEvents returns events accumulated since the previous poll.
rpc GetWatcherEvents(GetWatcherEventsRequest) returns (GetWatcherEventsResponse);
// RemoveWatcher unregisters a watcher created by CreateWatcher.
rpc RemoveWatcher(RemoveWatcherRequest) returns (RemoveWatcherResponse);
}
// MoveRequest moves/renames `source` to `destination` (both absolute or
// sandbox-relative paths — exact resolution is handled server-side).
message MoveRequest {
string source = 1;
string destination = 2;
}
// MoveResponse carries the metadata of the entry at its new location.
message MoveResponse {
EntryInfo entry = 1;
}
// MakeDirRequest creates a directory at `path`.
message MakeDirRequest {
string path = 1;
}
// MakeDirResponse carries the metadata of the newly created directory.
message MakeDirResponse {
EntryInfo entry = 1;
}
// RemoveRequest deletes the entry at `path`.
message RemoveRequest {
string path = 1;
}
// RemoveResponse is intentionally empty; success is signaled by the RPC status.
message RemoveResponse {}
// StatRequest queries metadata for the entry at `path`.
message StatRequest {
string path = 1;
}
// StatResponse carries the metadata for the queried entry.
message StatResponse {
EntryInfo entry = 1;
}
// EntryInfo describes a single file-system entry (file, directory, or symlink).
message EntryInfo {
// Base name of the entry.
string name = 1;
FileType type = 2;
// Full path of the entry.
string path = 3;
// Size in bytes.
int64 size = 4;
// Raw mode bits; `permissions` is the human-readable form of the same
// information (presumably "rwxr-xr-x"-style — confirm against the server impl).
uint32 mode = 5;
string permissions = 6;
// Owning user and group names.
string owner = 7;
string group = 8;
google.protobuf.Timestamp modified_time = 9;
// If the entry is a symlink, this field contains the target of the symlink.
optional string symlink_target = 10;
}
// FileType classifies an EntryInfo. UNSPECIFIED is the proto3 zero value and
// indicates the server did not set a type.
enum FileType {
FILE_TYPE_UNSPECIFIED = 0;
FILE_TYPE_FILE = 1;
FILE_TYPE_DIRECTORY = 2;
FILE_TYPE_SYMLINK = 3;
}
// ListDirRequest lists entries under `path`.
message ListDirRequest {
string path = 1;
// Recursion depth for the listing (semantics of 0 are not visible here —
// confirm whether it means "unlimited" or "direct children only").
uint32 depth = 2;
}
// ListDirResponse carries every entry found by the listing.
message ListDirResponse {
repeated EntryInfo entries = 1;
}
// WatchDirRequest subscribes to file-system events under `path`;
// `recursive` extends the watch to subdirectories.
message WatchDirRequest {
string path = 1;
bool recursive = 2;
}
// FilesystemEvent is a single change notification: the affected name plus
// what happened to it (see EventType).
message FilesystemEvent {
string name = 1;
EventType type = 2;
}
// WatchDirResponse is one frame of the WatchDir stream: an initial start
// marker, a change event, or a keepalive (presumably to hold idle
// connections open — confirm against the server's stream handling).
message WatchDirResponse {
oneof event {
StartEvent start = 1;
FilesystemEvent filesystem = 2;
KeepAlive keepalive = 3;
}
message StartEvent {}
message KeepAlive {}
}
// CreateWatcherRequest registers a polling watcher on `path`; `recursive`
// extends it to subdirectories. Mirrors WatchDirRequest for the
// non-streaming API.
message CreateWatcherRequest {
string path = 1;
bool recursive = 2;
}
// CreateWatcherResponse returns the handle used by GetWatcherEvents /
// RemoveWatcher.
message CreateWatcherResponse {
string watcher_id = 1;
}
// GetWatcherEventsRequest polls a watcher for pending events.
message GetWatcherEventsRequest {
string watcher_id = 1;
}
// GetWatcherEventsResponse carries the events accumulated since the last poll.
message GetWatcherEventsResponse {
repeated FilesystemEvent events = 1;
}
// RemoveWatcherRequest unregisters a watcher by id.
message RemoveWatcherRequest {
string watcher_id = 1;
}
// RemoveWatcherResponse is intentionally empty; success is the RPC status.
message RemoveWatcherResponse {}
// EventType is the kind of change a FilesystemEvent reports.
enum EventType {
EVENT_TYPE_UNSPECIFIED = 0;
EVENT_TYPE_CREATE = 1;
EVENT_TYPE_WRITE = 2;
EVENT_TYPE_REMOVE = 3;
EVENT_TYPE_RENAME = 4;
EVENT_TYPE_CHMOD = 5;
}

3
envd/spec/generate.go Normal file
View File

@ -0,0 +1,3 @@
// Package spec holds the protobuf service definitions for envd and the
// buf-based configuration used to generate their Go bindings.
package spec

// Regenerate the RPC bindings from the .proto files in this directory by
// running `go generate ./...` (requires buf; output is driven by buf.gen.yaml).
//go:generate buf generate --template buf.gen.yaml

View File

@ -0,0 +1,171 @@
syntax = "proto3";
package process;
// Process is the RPC service for running and controlling commands inside the
// sandbox: lifecycle (Start/Connect), input delivery (streaming and unary),
// signals, PTY resizing (Update), and stdin EOF.
service Process {
// List returns all processes currently tracked by envd.
rpc List(ListRequest) returns (ListResponse);
// Connect attaches to an existing process and streams its output events.
rpc Connect(ConnectRequest) returns (stream ConnectResponse);
// Start launches a new process and streams its output events.
rpc Start(StartRequest) returns (stream StartResponse);
// Update modifies a running process (currently PTY size — see UpdateRequest).
rpc Update(UpdateRequest) returns (UpdateResponse);
// Client input stream ensures ordering of messages
rpc StreamInput(stream StreamInputRequest) returns (StreamInputResponse);
// SendInput delivers a single chunk of stdin/PTY input.
rpc SendInput(SendInputRequest) returns (SendInputResponse);
// SendSignal delivers a POSIX signal to a process (see Signal enum).
rpc SendSignal(SendSignalRequest) returns (SendSignalResponse);
// Close stdin to signal EOF to the process.
// Only works for non-PTY processes. For PTY, send Ctrl+D (0x04) instead.
rpc CloseStdin(CloseStdinRequest) returns (CloseStdinResponse);
}
// PTY describes the pseudo-terminal to allocate for a process; currently
// only its dimensions.
message PTY {
Size size = 1;
// Size is the terminal geometry in character cells.
message Size {
uint32 cols = 1;
uint32 rows = 2;
}
}
// ProcessConfig is the command specification for a process: executable,
// arguments, environment, and optional working directory.
message ProcessConfig {
string cmd = 1;
repeated string args = 2;
map<string, string> envs = 3;
// Working directory; when unset, the server chooses the default.
optional string cwd = 4;
}
// ListRequest has no parameters; List always returns every tracked process.
message ListRequest {}
// ProcessInfo describes one running process: its launch config, OS pid, and
// the optional client-assigned tag usable in ProcessSelector.
message ProcessInfo {
ProcessConfig config = 1;
uint32 pid = 2;
optional string tag = 3;
}
// ListResponse carries the full process list.
message ListResponse {
repeated ProcessInfo processes = 1;
}
// StartRequest launches a process. `pty` allocates a pseudo-terminal;
// `tag` lets clients address the process later via ProcessSelector.
message StartRequest {
ProcessConfig process = 1;
optional PTY pty = 2;
optional string tag = 3;
// This is optional for backwards compatibility.
// We default to true. New SDK versions will set this to false by default.
optional bool stdin = 4;
}
// UpdateRequest changes a running process; the only visible mutation is the
// PTY size (terminal resize).
message UpdateRequest {
ProcessSelector process = 1;
optional PTY pty = 2;
}
// UpdateResponse is intentionally empty; success is the RPC status.
message UpdateResponse {}
// ProcessEvent is one frame of a Start/Connect output stream: process
// started, output data, process ended, or a keepalive.
message ProcessEvent {
oneof event {
StartEvent start = 1;
DataEvent data = 2;
EndEvent end = 3;
KeepAlive keepalive = 4;
}
// StartEvent reports the OS pid once the process is running.
message StartEvent {
uint32 pid = 1;
}
// DataEvent carries one chunk of output; stdout/stderr for plain processes,
// a single combined `pty` channel for PTY processes.
message DataEvent {
oneof output {
bytes stdout = 1;
bytes stderr = 2;
bytes pty = 3;
}
}
// EndEvent reports termination. `sint32` is used because exit codes can be
// negative (zigzag encoding keeps negatives compact); `exited` and `status`
// carry the process state, `error` any server-side failure detail.
message EndEvent {
sint32 exit_code = 1;
bool exited = 2;
string status = 3;
optional string error = 4;
}
// KeepAlive holds idle streams open (presumably to defeat proxy/idle
// timeouts — confirm against the server's stream handling).
message KeepAlive {}
}
// StartResponse wraps a ProcessEvent for the Start stream.
message StartResponse {
ProcessEvent event = 1;
}
// ConnectResponse wraps a ProcessEvent for the Connect stream.
message ConnectResponse {
ProcessEvent event = 1;
}
// SendInputRequest delivers one chunk of input to a selected process
// (unary alternative to StreamInput).
message SendInputRequest {
ProcessSelector process = 1;
ProcessInput input = 2;
}
// SendInputResponse is intentionally empty; success is the RPC status.
message SendInputResponse {}
// ProcessInput is one input chunk: raw stdin bytes for plain processes, or
// terminal bytes for PTY processes.
message ProcessInput {
oneof input {
bytes stdin = 1;
bytes pty = 2;
}
}
// StreamInputRequest is one frame of the client input stream: a start frame
// selecting the target process, then ordered data frames, with keepalives.
message StreamInputRequest {
oneof event {
StartEvent start = 1;
DataEvent data = 2;
KeepAlive keepalive = 3;
}
// StartEvent selects which process the rest of the stream feeds.
message StartEvent {
ProcessSelector process = 1;
}
// DataEvent carries one input chunk. Field number 1 is skipped — presumably
// reserved from an earlier revision of this message; confirm before reusing.
message DataEvent {
ProcessInput input = 2;
}
message KeepAlive {}
}
// StreamInputResponse is intentionally empty; success is the RPC status.
message StreamInputResponse {}
// Signal enumerates deliverable signals; the enum values match the POSIX
// signal numbers (SIGKILL = 9, SIGTERM = 15).
enum Signal {
SIGNAL_UNSPECIFIED = 0;
SIGNAL_SIGTERM = 15;
SIGNAL_SIGKILL = 9;
}
// SendSignalRequest delivers `signal` to the selected process.
message SendSignalRequest {
ProcessSelector process = 1;
Signal signal = 2;
}
// SendSignalResponse is intentionally empty; success is the RPC status.
message SendSignalResponse {}
// CloseStdinRequest closes the selected process's stdin (EOF);
// see the CloseStdin RPC comment for the PTY caveat.
message CloseStdinRequest {
ProcessSelector process = 1;
}
// CloseStdinResponse is intentionally empty; success is the RPC status.
message CloseStdinResponse {}
// ConnectRequest attaches to an existing process's output stream.
message ConnectRequest {
ProcessSelector process = 1;
}
// ProcessSelector addresses a process either by OS pid or by the
// client-assigned tag provided at Start.
message ProcessSelector {
oneof selector {
uint32 pid = 1;
string tag = 2;
}
}

68
go.mod
View File

@ -1,17 +1,59 @@
module git.omukk.dev/wrenn/sandbox
go 1.25.0
require (
connectrpc.com/connect v1.19.1 // indirect
github.com/PuerkitoBio/purell v1.1.1 // indirect
github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 // indirect
github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d // indirect
github.com/containerd/fifo v1.0.0 // indirect
github.com/containernetworking/cni v1.0.1 // indirect
github.com/containernetworking/plugins v1.0.1 // indirect
github.com/firecracker-microvm/firecracker-go-sdk v1.0.0 // indirect
github.com/go-chi/chi/v5 v5.2.5 // indirect
github.com/go-openapi/analysis v0.21.2 // indirect
github.com/go-openapi/errors v0.20.2 // indirect
github.com/go-openapi/jsonpointer v0.19.5 // indirect
github.com/go-openapi/jsonreference v0.19.6 // indirect
github.com/go-openapi/loads v0.21.1 // indirect
github.com/go-openapi/runtime v0.24.0 // indirect
github.com/go-openapi/spec v0.20.4 // indirect
github.com/go-openapi/strfmt v0.21.2 // indirect
github.com/go-openapi/swag v0.21.1 // indirect
github.com/go-openapi/validate v0.22.0 // indirect
github.com/go-stack/stack v1.8.1 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/gorilla/websocket v1.5.3 // indirect
github.com/hashicorp/errwrap v1.0.0 // indirect
github.com/hashicorp/go-multierror v1.1.1 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/pgx/v5 v5.8.0 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/mdlayher/socket v0.4.1 // indirect
github.com/mdlayher/vsock v1.2.1 // indirect
github.com/mfridman/interpolate v0.0.2 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/oklog/ulid v1.3.1 // indirect
github.com/opentracing/opentracing-go v1.2.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pressly/goose/v3 v3.27.0 // indirect
github.com/prometheus/client_golang v1.23.2 // indirect
github.com/rs/cors v1.11.1 // indirect
github.com/sethvargo/go-retry v0.3.0 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/vishvananda/netlink v1.1.1-0.20210330154013-f5de75959ad5 // indirect
github.com/vishvananda/netns v0.0.0-20210104183010-2eb08e3e575f // indirect
go.mongodb.org/mongo-driver v1.8.3 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/crypto v0.48.0 // indirect
golang.org/x/net v0.50.0 // indirect
golang.org/x/sync v0.19.0 // indirect
golang.org/x/sys v0.41.0 // indirect
golang.org/x/text v0.34.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171 // indirect
google.golang.org/protobuf v1.36.11 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
)

1219
go.sum Normal file

File diff suppressed because it is too large Load Diff